算法上机实验----矩阵乘法

矩阵乘法

数据结构定义:

  • MatFrame: 标定矩阵某一块
  • Mat : 矩阵的表示形式
1
2
3
4
5
6
7
8
9
10
11
12
13
#include <stdio.h>
#include <stdlib.h>
/* A rectangular window into a matrix: where it starts and how large it is. */
typedef struct _MatFrame {
    int x;   /* row index of the window's first element (used as data[x + i]) */
    int y;   /* column index of the window's first element (used as data[..][y + j]) */
    int row; /* number of rows covered by the window */
    int col; /* number of columns covered by the window */
} MatFrame;
/* Integer matrix: one flat element buffer plus a table of row pointers. */
typedef struct _Mat {
    int *_data; /* the actual element storage */
    int **data; /* row pointers, enabling two-dimensional data[i][j] indexing */
    int row;    /* number of rows */
    int col;    /* number of columns */
} Mat;

数据结构基本操作

NewMatFrame 用于快速创建MatFrame

1
2
3
4
5
6
7
8
MatFrame NewMatFrame( int x, int y, int row, int col ) {
MatFrame mf;
mf.x = x;
mf.y = y;
mf.row = row;
mf.col = col;
return mf;
}

mat_create 创建一个新的矩阵,注意要用mat_free释放

1
2
3
4
5
6
7
8
9
10
11
/*
 * Allocate a row x col integer matrix.
 *
 * The elements live in one contiguous, zero-initialized buffer (_data);
 * data[i] points at the first element of row i inside that buffer.
 *
 * Returns the new matrix, or NULL if a dimension is negative, the element
 * count overflows, or an allocation fails. Release the result with
 * mat_free().
 */
Mat *mat_create( int row, int col ) {
    if ( row < 0 || col < 0 ) {
        return NULL;
    }
    size_t rows = (size_t) row;
    size_t cols = (size_t) col;
    // Reject shapes whose element count does not fit in size_t.
    if ( cols != 0 && rows > (size_t) -1 / cols ) {
        return NULL;
    }
    size_t cells = rows * cols;

    Mat *mat = malloc( sizeof *mat );
    if ( mat == NULL ) {
        return NULL;
    }
    // A zero-sized request may legitimately yield NULL; ask for at least one
    // element so that NULL always means "allocation failed".
    mat->_data = calloc( cells ? cells : 1, sizeof *mat->_data );
    mat->data = calloc( rows ? rows : 1, sizeof *mat->data );
    if ( mat->_data == NULL || mat->data == NULL ) {
        free( mat->data );
        free( mat->_data );
        free( mat );
        return NULL;
    }
    // Each row holds `col` elements, so consecutive rows start `col` apart.
    for ( size_t i = 0; i < rows; i++ ) {
        mat->data[ i ] = mat->_data + cols * i;
    }
    mat->row = row;
    mat->col = col;
    return mat;
}

mat_copy Mat的浅拷贝

1
2
3
4
5
6
7
8
/*
 * Shallow-copy a matrix header.
 *
 * The returned Mat is a new struct whose _data and data pointers are the
 * SAME pointers as in `mat`; no element is duplicated. Because mat_free()
 * frees those buffers, calling mat_free() on both the copy and the original
 * would free them twice -- release only the header of the copy with free().
 *
 * Returns NULL if the header cannot be allocated.
 */
Mat *mat_copy( Mat *mat ) {
    Mat *new_mat = malloc( sizeof *new_mat );
    if ( new_mat == NULL ) {
        return NULL;
    }
    // Struct assignment copies the pointers and the dimensions verbatim.
    *new_mat = *mat;
    return new_mat;
}

mat_free 用create方式创建的都要用此函数释放

1
2
3
4
5
/*
 * Release a matrix obtained from mat_create(): the row-pointer table, the
 * element buffer, and the Mat struct itself.
 *
 * Passing NULL is a no-op (mirroring free()), so results of functions that
 * return NULL on a shape mismatch can be released unconditionally.
 */
void mat_free( Mat *mat ) {
    if ( mat == NULL ) {
        return;
    }
    free( mat->data );
    free( mat->_data );
    free( mat );
}

mat_init 用于初始化矩阵, random标定是否用随机数初始化

1
2
3
4
5
6
7
/*
 * Fill every element of `mat`, visiting elements row by row.
 *
 * random - non-zero: each element becomes rand() % 5;
 *          zero:     each element becomes 0.
 */
void mat_init( Mat *mat, int random ) {
    for ( int r = 0; r < mat->row; r++ ) {
        int *cells = mat->data[ r ];
        for ( int c = 0; c < mat->col; c++ ) {
            if ( random ) {
                cells[ c ] = rand() % 5;
            } else {
                cells[ c ] = 0;
            }
        }
    }
}

mat_print 用于打印矩阵

1
2
3
4
5
6
7
8
9
10
11
12
/*
 * Print the matrix to stdout, one row per line with elements separated by
 * a single space, followed by one blank line.
 */
void mat_print( Mat *mat ) {
    for ( int i = 0; i < mat->row; i++ ) {
        for ( int j = 0; j < mat->col; j++ ) {
            // Emit the separator only between elements, so a row never
            // ends with a trailing space.
            if ( j > 0 ) {
                printf( " " );
            }
            printf( "%d", mat->data[ i ][ j ] );
        }
        printf( "\n" );
    }
    printf( "\n" );
}

mat_add 矩阵相加,返回一个新的,相加过后的矩阵

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/*
 * Element-wise sum of two matrices.
 *
 * Returns a newly created matrix holding mat1 + mat2, or NULL when the two
 * operands do not have the same number of rows and columns. The inputs are
 * not modified.
 */
Mat *mat_add( Mat *mat1, Mat *mat2 ) {
    int same_shape = mat1->row == mat2->row && mat1->col == mat2->col;
    if ( !same_shape ) {
        return NULL;
    }

    Mat *sum = mat_create( mat1->row, mat2->col );
    mat_init( sum, 0 );
    for ( int r = 0; r < sum->row; r++ ) {
        const int *lhs = mat1->data[ r ];
        const int *rhs = mat2->data[ r ];
        int *out = sum->data[ r ];
        for ( int c = 0; c < sum->col; c++ ) {
            out[ c ] = lhs[ c ] + rhs[ c ];
        }
    }
    return sum;
}

mat_copy_with_frame 将src矩阵的某一块拷贝到dest矩阵的某一块

1
2
3
4
5
6
7
8
9
10
11
12
13
/*
 * Copy the block of `src` described by `smf` into the block of `dest`
 * described by `dmf`, element by element in row-major order.
 *
 * Does nothing when the two frames differ in row or column count.
 */
void mat_copy_with_frame( Mat *dest, MatFrame dmf, Mat *src, MatFrame smf ) {
    if ( dmf.row == smf.row && dmf.col == smf.col ) {
        for ( int r = 0; r < dmf.row; r++ ) {
            int dest_r = dmf.x + r;
            int src_r = smf.x + r;
            for ( int c = 0; c < dmf.col; c++ ) {
                dest->data[ dest_r ][ dmf.y + c ] =
                    src->data[ src_r ][ smf.y + c ];
            }
        }
    }
}

算法的具体实施

经典算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/*
 * Matrix product computed straight from the definition (triple loop).
 *
 * Returns a newly created mat1->row x mat2->col matrix, or NULL when
 * mat1->col differs from mat2->row.
 */
Mat *mat_multiply( Mat *mat1, Mat *mat2 ) {
    if ( mat1->col != mat2->row ) {
        return NULL;
    }

    Mat *product = mat_create( mat1->row, mat2->col );
    mat_init( product, 0 );
    for ( int r = 0; r < product->row; r++ ) {
        for ( int c = 0; c < product->col; c++ ) {
            // Accumulate the dot product of row r of mat1 and column c of mat2.
            int *cell = &product->data[ r ][ c ];
            for ( int k = 0; k < mat1->col; k++ ) {
                *cell += mat1->data[ r ][ k ] * mat2->data[ k ][ c ];
            }
        }
    }
    return product;
}

分治算法

假设矩阵为方阵,尺寸是 $2^n \times 2^n$
将矩阵乘法形式:

$$C = A * B$$

转换成:

$$
\left[ \begin{matrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{matrix} \right] = \left[ \begin{matrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{matrix} \right] \left[ \begin{matrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{matrix} \right]
$$

然后分别计算每一块 $C_{ij}$:$C_{ij} = A_{i1} B_{1j} + A_{i2} B_{2j}$($i, j \in \{1, 2\}$)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
/*
 * Multiply the n x n block of mat1 starting at (mf1.x, mf1.y) by the n x n
 * block of mat2 starting at (mf2.x, mf2.y) with the definition-based triple
 * loop, storing the product in the block of res starting at
 * (resmf.x, resmf.y).
 */
static void mat_multiply_frame_direct( Mat *mat1, MatFrame mf1, Mat *mat2,
                                       MatFrame mf2, Mat *res, MatFrame resmf,
                                       int n ) {
    for ( int i = 0; i < n; i++ ) {
        for ( int j = 0; j < n; j++ ) {
            int sum = 0;
            for ( int k = 0; k < n; k++ ) {
                sum += mat1->data[ mf1.x + i ][ mf1.y + k ] *
                       mat2->data[ mf2.x + k ][ mf2.y + j ];
            }
            res->data[ resmf.x + i ][ resmf.y + j ] = sum;
        }
    }
}

/*
 * Recursive worker of the divide-and-conquer product.
 *
 * Computes (block mf1 of mat1) * (block mf2 of mat2) and writes it into
 * block resmf of res. The block size n is taken from mf1.row; all three
 * frames are treated as n x n squares, and every element of the result
 * block is overwritten.
 *
 * An even-sized block is split into four quadrants, each computed as
 *   C_ij = A_i1 * B_1j + A_i2 * B_2j
 * from two recursive half-size products. A block whose size is odd
 * (including the 1 x 1 case) cannot be halved evenly and is multiplied
 * directly instead; a non-positive size is a no-op.
 */
void _mat_multiply_recursive( Mat *mat1, MatFrame mf1, Mat *mat2, MatFrame mf2,
                              Mat *res, MatFrame resmf ) {
    int n = mf1.row;
    if ( n <= 0 ) {
        return;
    }
    if ( n % 2 != 0 ) {
        mat_multiply_frame_direct( mat1, mf1, mat2, mf2, res, resmf, n );
        return;
    }

    int tn = n / 2;
    Mat *t1 = mat_create( tn, tn );
    Mat *t2 = mat_create( tn, tn );
    if ( t1 == NULL || t2 == NULL ) {
        // No scratch space available: the direct product needs none.
        if ( t1 != NULL ) {
            mat_free( t1 );
        }
        if ( t2 != NULL ) {
            mat_free( t2 );
        }
        mat_multiply_frame_direct( mat1, mf1, mat2, mf2, res, resmf, n );
        return;
    }

    MatFrame tmf = NewMatFrame( 0, 0, tn, tn );
    for ( int bi = 0; bi < 2; bi++ ) {
        for ( int bj = 0; bj < 2; bj++ ) {
            int arow = mf1.x + bi * tn; // first row of A_{bi,*}
            int bcol = mf2.y + bj * tn; // first column of B_{*,bj}

            // t1 = A_{bi,1} * B_{1,bj}
            _mat_multiply_recursive( mat1, NewMatFrame( arow, mf1.y, tn, tn ),
                                     mat2, NewMatFrame( mf2.x, bcol, tn, tn ),
                                     t1, tmf );
            // t2 = A_{bi,2} * B_{2,bj}
            _mat_multiply_recursive(
                mat1, NewMatFrame( arow, mf1.y + tn, tn, tn ), mat2,
                NewMatFrame( mf2.x + tn, bcol, tn, tn ), t2, tmf );

            // C_{bi,bj} = t1 + t2, written straight into the result block.
            int crow = resmf.x + bi * tn;
            int ccol = resmf.y + bj * tn;
            for ( int i = 0; i < tn; i++ ) {
                for ( int j = 0; j < tn; j++ ) {
                    res->data[ crow + i ][ ccol + j ] =
                        t1->data[ i ][ j ] + t2->data[ i ][ j ];
                }
            }
        }
    }
    mat_free( t1 );
    mat_free( t2 );
}
/*
 * Public entry point of the divide-and-conquer matrix product.
 *
 * Both operands must be square and of the same size; otherwise NULL is
 * returned. The quadrant recursion halves the size at every level, which is
 * only exact when the size is a power of two, so any other size is handed
 * to the definition-based mat_multiply() instead.
 *
 * Returns a newly created matrix holding mat1 * mat2, or NULL on a shape
 * mismatch or when the result matrix could not be created.
 */
Mat *mat_multiply_recursive( Mat *mat1, Mat *mat2 ) {
    if ( mat1->row != mat1->col || mat2->row != mat2->col ||
         mat1->col != mat2->row ) {
        return NULL;
    }
    int n = mat1->row;
    // n & (n - 1) clears the lowest set bit; it is zero only for powers of two.
    if ( n <= 0 || ( n & ( n - 1 ) ) != 0 ) {
        return mat_multiply( mat1, mat2 );
    }
    Mat *res = mat_create( n, n );
    if ( res == NULL ) {
        return NULL;
    }
    mat_init( res, 0 );
    _mat_multiply_recursive( mat1, NewMatFrame( 0, 0, mat1->row, mat1->col ),
                             mat2, NewMatFrame( 0, 0, mat2->row, mat2->col ),
                             res, NewMatFrame( 0, 0, res->row, res->col ) );
    return res;
}

主函数

调用上面的方法,记得free

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
/*
 * Demo driver: builds two random 16 x 16 matrices, prints them, then prints
 * their product computed first by the classic algorithm and then by the
 * divide-and-conquer algorithm. Every matrix is released before exit.
 */
int main() {
    enum { SIZE = 16 };

    Mat *lhs = mat_create( SIZE, SIZE );
    Mat *rhs = mat_create( SIZE, SIZE );
    mat_init( lhs, 1 );
    mat_init( rhs, 1 );
    mat_print( lhs );
    mat_print( rhs );

    // Classic definition-based product.
    Mat *classic = mat_multiply( lhs, rhs );
    mat_print( classic );
    mat_free( classic );

    // Divide-and-conquer product of the same operands.
    Mat *divided = mat_multiply_recursive( lhs, rhs );
    mat_print( divided );
    mat_free( divided );

    mat_free( lhs );
    mat_free( rhs );
    return 0;
}