11049번: 행렬 곱셈 순서 (acmicpc.net)
11049번 : 행렬 곱셈 순서
크기가 N×M인 행렬 A와 M×K인 B를 곱할 때 필요한 곱셈 연산의 수는 총 N×M×K번이다. 행렬 N개를 곱하는데 필요한 곱셈 연산의 수는 행렬을 곱하는 순서에 따라 달라지게 된다.
예를 들어, A의 크기가 5×3이고, B의 크기가 3×2, C의 크기가 2×6인 경우에 행렬의 곱 ABC를 구하는 경우를 생각해보자.
- AB를 먼저 곱하고 C를 곱하는 경우 (AB)C에 필요한 곱셈 연산의 수는 5×3×2 + 5×2×6 = 30 + 60 = 90번이다.
- BC를 먼저 곱하고 A를 곱하는 경우 A(BC)에 필요한 곱셈 연산의 수는 3×2×6 + 5×3×6 = 36 + 90 = 126번이다.
같은 곱셈이지만, 곱셈을 하는 순서에 따라서 곱셈 연산의 수가 달라진다.
행렬 N개의 크기가 주어졌을 때, 모든 행렬을 곱하는데 필요한 곱셈 연산 횟수의 최솟값을 구하는 프로그램을 작성하시오. 입력으로 주어진 행렬의 순서를 바꾸면 안 된다.
입력
첫째 줄에 행렬의 개수 N(1 ≤ N ≤ 500)이 주어진다.
둘째 줄부터 N개 줄에는 행렬의 크기 r과 c가 주어진다. (1 ≤ r, c ≤ 500)
항상 순서대로 곱셈을 할 수 있는 크기만 입력으로 주어진다.
출력
첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력한다. 정답은 231-1 보다 작거나 같은 자연수이다. 또한, 최악의 순서로 연산해도 연산 횟수가 231-1보다 작거나 같다.
생각해 볼 점
11066번 문제와 매우 흡사한 형태입니다.
행렬의 순서는 바뀌지 않고, 곱하는 순서만 정해주면 되는 문제이므로, 인접한 행렬만 곱할 수 있습니다.
예를 들면, A, B, C 행렬을 입력 받았을 때, A * C * B와 같은 곱셈은 이루어지지 않는다는 뜻입니다.
저는 정답을 구할 함수를 solution이라 칭하고, 인자 값은 Left와 Right로 하겠습니다.
solution(Left, Right)는 Left번 째 행렬에서 Right번 째까지의 행렬을 모두 곱했을 때 곱한 횟수의 최소값이라고 하겠습니다.
행렬이 A, B, C를 입력 받았다고 가정하고, 각각의 index를 0,1,2라고 하면
solution(0, 2)를 구하면 정답입니다.
solution(0, 2)는
1. [solution(0, 0) + solution(1, 2) + 이번 곱셈 횟수]
2. [solution(0, 1) + solution(2, 2) + 이번 곱셈 횟수]
두 가지 중 작은 값을 취할 것입니다.
solution(0, 0)과 같이 Left와 Right 값이 같으면, 곱한 횟수가 없으므로 0을 반환하고,
solution(1, 2)는 곧 B * C의 곱셈 횟수를 반환합니다.
1번은 결국 A * (B * C)의 곱셈 횟수를 나타내며, 식으로 나타내면
0 + (B*C의 곱셈 횟수) + 이번에 시행될 곱셈 횟수 입니다.
이번 곱셈 횟수를 구하는 방법은, [Left의 앞의 값 * 분할 값 * Right의 뒤의 값]입니다.
예를 들어,
A = 4 x 3
B = 3 x 6
C = 6 x 8
일 경우,
A * (B * C)의 곱셈 값은
Left의 앞 = 4
분할 값 = A 와 B 사이에서 분할 되었으므로 3
Right의 뒤 = 8
4 * 3 * 8입니다.
예시의 풀이)
solution(0, 2) = min( solution(0, 0) + solution(1, 2) + 4 * 3 * 8 ,
solution(0, 1) + solution(2, 2) + 4 * 6 * 8)
= min( 0 + solution(1, 1) + solution(2, 2) + 3 * 6 * 8 + 4 * 3 * 8 ,
solution(0, 0) + solution(1, 1) + 0 + 4 * 3 * 6 + 4 * 6 * 8 )
= min( 0 + 0 + 0 + 144 + 96 , 0 + 0 + 0 + 72 + 192 )
= min( 240, 264 )
= 240
코드
#include <iostream>
using namespace std;
int N;
pair<int, int> *matrix; //행렬의 x, y를 저장
int **dp; //dp[left][right] = left ~ right까지 행렬을 모두 곱한 횟수의 최소값
int solution(int left, int right)
{
if(left == right) return 0;
if(dp[left][right] != 0) return dp[left][right];
int min_result = 2147483647; //2^31 -1
for(int i = left; i < right; i++)
{
int result = solution(left, i) + solution(i + 1, right);
result += matrix[left].first * matrix[i].second * matrix[right].second; // 행렬의 곱의 횟수
if(result < min_result) min_result = result;
}
dp[left][right] = min_result;
return min_result;
}
int main()
{
scanf("%d", &N);
matrix = new pair<int,int>[N];
dp = new int*[N];
for(int i = 0; i < N; i++)
{
pair<int, int> input;
scanf("%d %d", &input.first, &input.second);
matrix[i] = input;
dp[i] = new int[N];
fill_n(dp[i], N, 0);
}
printf("%d", solution(0, N-1));
for(int i = 0; i < N; i++) delete[] dp[i];
delete[] dp;
delete[] matrix;
return 0;
}
그 외
'공부 및 정리 > 백준 코드' 카테고리의 다른 글
[C++]백준 - 12865번 문제 (0) | 2021.08.05 |
---|---|
[C++]백준 - 1912번 문제 (0) | 2021.08.05 |
[C++]백준 - 9251번 문제 (0) | 2021.08.01 |
[C++]백준 - 7569번 문제 (0) | 2021.08.01 |
[C++]백준 - 1929번 문제 (0) | 2021.08.01 |