=========
请所有通过搜索引擎找过来的同学注意

如果你搜的关键字里有 “Strassen” 的话,我不说什么

但是那些搜“C语言矩阵乘法”的同学,
我要说,如果你想搜的是Strassen的算法实现的话,建议你下次加上算法的名字,
有针对性的话,会高效很多

但是,让我很无奈,甚至愤怒的是,
如果你仅仅是想得到一个蛮力算矩阵乘法的实现,
我只能说,这么简单的3层for循环,你都不会,
孩子,你的人生已经废了!

你甚至对不起谭浩强的绿皮书
=========




刚刚把上次实验的代码修改完毕,反正闲着也是闲着,就随手升级了一下wp的后台程序。

然后又装了一个代码高亮的插件,现在就来测试一下。

下午要是依然无聊,我就把之前那个 【你问我答】 板块的第一期给写了……

简单说明,代码是前几天上课要交的作业,内容就是标题里的东西。
简单测试了一下,貌似能用。
至于会不会有某些特别邪恶的输入会导致程序崩溃,我就不清楚了……
特别提一下,用的是int型,所以不建议输入的数据太大……

ok,分割线一下

×××××××××××××××
以下内容含有晦涩代码
请技术无能者在父母陪同下进行阅读

×××××××××××××××

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
/*************************
 *  
 *  FILE NAME: 
 *  matrix.c
 *
 *  INTRODUCTION:
 *  Implement the matrix multiplication using 
 *  Strassen's algorithm. And a function of
 *  matrix addtion.
 *
 *  DESCRIPTION:
 *  1.
 *  int mult (int* m1, int* m2, int size)
 *  Multiply two matrices, m1 and m2, with 
 *  "size" rows and columns.
 *  Save the result in m1.
 *
 *  2.
 *  int add (int* m1, int* m2, int size, char flag)
 *  Add or minus two matrices as the function mult.
 *  When "flag" is "ADD" addtion,
 *  when "MINUS" subtraction.
 *  Also save the result in m1.
 *
 *  BIBLIOGRAPHY:
 *  Introduction to Algorithms
 *  28.2 Strassen's algorithm for matrix
 *  multiplication
 *
 *  AUTHOR:
 *  SRAY
 *
 *  DATE:
 *  2009.02.28
 *
 *  吐槽:
 *  以下代码为Sray为应付作业而写
 *
 *  欢迎需要交作业的同学前来抄袭
 *
 *  抄的时候,记得把最后面那个测试用的main函数删掉。。。
 *
 **************************/ 
 
 
#include <stdio.h>
#include <stdlib.h>
#define ADD (0)
#define MINUS (1)
 
#define SIZE (5)
#define MAX (25)
 
 
int add (int* m1, int* m2, int size, char flag){
      //  矩阵加减法,结果保存在m1中
      //  flag=ADD为加
      //  其他(MINUS)为减
 
      int n,m;
      m=size*size;
      for (n=0; n<m; ++n){
            if (flag==(ADD))
                  m1[n]=m1[n] + m2[n];
            else
                  m1[n]=m1[n] - m2[n];
      }
      return 0;
}
 
static int copy (int* m1, int* m2, int size, char flag) {
      // 复制矩阵
      // flag=0时,将m1全部复制到m2中
      // flag=5时,将m1的右下区块复制到m2中,如下图
      // ┌┬────┐
      // ├┼────┤
      // ││        │
      // ││  5     │
      // ││        │
      // ││        │
      // └┴────┘
      //  其他值,复制m1中对应区块 至m2中
      // ┌─┬─┐
      // │1 │2 │
      // ├─┼─┤
      // │3 │4 │
      // └─┴─┘
 
      int i;
      int m=size*size;
      int ns=size/2;//ns=new size
      if (flag==0) {
            for (i=0; i<m; ++i) {
                  m2[i]=m1[i];
            }
            return 0;
      }
      else if (flag==1) {
            m=m/4;
            for (i=0; i<m; ++i) 
                  m2[i]=m1[i%ns+size*(i/ns)];
            return 0;
      }
      else if (flag==2) {
            m=m/4;
            for (i=0; i<m; ++i)
                  m2[i]=m1[i%ns+ns+size*(i/ns)];
            return 0;
      }
      else if (flag==3) {
            m=m/4;
            for (i=0; i<m; ++i)
                  m2[i]=m1[i%ns+size*(ns+i/ns)];
            return 0;
      }
      else if (flag==4) {
            m=m/4;
            for (i=0; i<m; ++i)
                  m2[i]=m1[i%ns+ns+size*(ns+i/ns)];
            return 0;
      }
      else if (flag==5) {
            ns=size-1;
            m=ns*ns;
            for (i=0; i<m; ++i)
                  m2[i]=m1[1+i%ns+size*(1+i/ns)];
            return 0;
      }
      else {
            printf ("ERROR:no such flag.");
            return -1;
      }
 
 
}
 
 
int mult (int* m1, int* m2, int size) {
      //  矩阵乘法
      //  结果保存在m1中
 
      int *subm[11];
      int i,j;
 
      int m;
      int ns;
 
      if (size==1) {
            // 貌似这一段不会用到……
            // 直接size=2了……
            m1[0]=m1[0]*m2[0];
            return 0;
      }
      if (size==2) {
            //  边长为2的时候
            //  直接计算
            subm[0]=(int*)malloc(4*sizeof(int));
            subm[0][0]=m1[0]*m2[0]+m1[1]*m2[2];
            subm[0][1]=m1[0]*m2[1]+m1[1]*m2[3];
            subm[0][2]=m1[2]*m2[0]+m1[3]*m2[2];
            subm[0][3]=m1[2]*m2[1]+m1[3]*m2[3];
            copy (subm[0], m1,2,0);
            free (subm[0]);
            return 0;
      }
      if (size%2==0) {
            //  边长为偶数的时候
            //  参照 Strassen的矩阵乘法算法
            //  具体符号标记参见
            //  算法导论 28.2节
 
            m=size*size/4;
            ns=size/2;
            for (i=0; i<11; ++i) {
                  subm[i]=(int*)malloc(m*sizeof(int));
            }
 
            //计算P1至P7
            //subm 0-6 分别保存 P1 - P7
            //subm 7-10 保存中间变量
 
            copy (m1,subm[0],size,1);
            copy (m2,subm[7],size,2);
            copy (m2,subm[8],size,4);
            add (subm[7], subm[8],ns,MINUS);
            mult (subm[0],subm[7],ns);
 
            copy (m1,subm[1],size,1);
            copy (m1,subm[7],size,2);
            add (subm[1],subm[7],ns,ADD);
            copy (m2,subm[7],size,4);
            mult (subm[1],subm[7],ns);
 
            copy (m1,subm[2],size,3);
            copy (m1,subm[7],size,4);
            add (subm[2],subm[7],ns,ADD);
            copy (m2,subm[7],size,1);
            mult (subm[2],subm[7],ns);
 
            copy (m1,subm[3],size,4);
            copy (m2,subm[7],size,3);
            copy (m2,subm[8],size,1);
            add (subm[7],subm[8],ns,MINUS);
            mult (subm[3],subm[7],ns);
 
            copy (m1,subm[4],size,1);
            copy (m1,subm[7],size,4);
            add (subm[4],subm[7],ns,ADD);
            copy (m2,subm[7],size,1);
            copy (m2,subm[8],size,4);
            add (subm[7],subm[8],ns,ADD);
            mult (subm[4],subm[7],ns);
 
            copy (m1,subm[5],size,2);
            copy (m1,subm[7],size,4);
            add (subm[5],subm[7],ns,MINUS);
            copy (m2,subm[7],size,3);
            copy (m2,subm[8],size,4);
            add (subm[7],subm[8],ns,ADD);
            mult (subm[5],subm[7],ns);
 
            copy (m1,subm[6],size,1);
            copy (m1,subm[7],size,3);
            add (subm[6],subm[7],ns,MINUS);
            copy (m2,subm[7],size,1);
            copy (m2,subm[8],size,2);
            add (subm[7],subm[8],ns,ADD);
            mult (subm[6],subm[7],ns);
            //  计算r,s,t,u
            //  r=sub[7]   s=sub[8]  
            //  t=sub[9]  u=sub[10]
 
            copy (subm[2],subm[9],ns,0);
            add (subm[9],subm[3],ns,ADD);
            //t
 
            copy (subm[0],subm[8],ns,0);
            add (subm[8],subm[1],ns,ADD);
            //s
 
            copy (subm[4],subm[7],ns,0); 
            add (subm[7],subm[3],ns,ADD);
            add (subm[7],subm[1],ns,MINUS);
            add (subm[7],subm[5],ns,ADD);
            //R
 
            copy (subm[4],subm[10],ns,0); 
            add (subm[10],subm[0],ns,ADD);
            add (subm[10],subm[2],ns,MINUS);
            add (subm[10],subm[6],ns,MINUS);
            //U
 
 
            // 拼接 r s t u
            for (i=0; i<m; ++i) {
                  m1[i%ns+size*(i/ns)]=subm[7][i];
                  m1[i%ns+ns+size*(i/ns)]=subm[8][i];
                  m1[i%ns+size*(i/ns+ns)]=subm[9][i];
                  m1[i%ns+ns+size*(i/ns+ns)]=subm[10][i];
            }
 
            //释放临时空间
            for (i=0; i<11; ++i)
                  free (subm[i]);
 
            return 0;
      }
      else {
            //  对于非偶数边长,做如下分割
            // ┌──┬──────┐
            // │1*1 │  1*(n-1)   │
            // ├──┼──────┤
            // │    │            │
            // │    │            │
            // │n-1 │     n-1    │
            // │ *  │      *     │
            // │ 1  │     n-1    │
            // │    │            │
            // │    │            │
            // └──┴──────┘
 
            ns=size-1;
            m=ns*ns;
            subm[0]=(int*)malloc(ns*sizeof(int));
            subm[1]=(int*)malloc(ns*sizeof(int));
            subm[2]=(int*)malloc(m*sizeof(int));
            subm[3]=(int*)malloc(m*sizeof(int));
 
            //  subm 2,3 分别保存 m1,m2的右下角部分
            copy (m1,subm[2],size,5);
            copy (m2,subm[3],size,5);
            //  递归求解 d*h
            mult (subm[2],subm[3],ns);
            //  求解 c*f
            //  并把两次结果相加,得到 U部分(n-1 * n-1)
            for (i=0; i<ns; ++i) 
                  for (j=0; j<ns; ++j)
                        subm[3][i*ns + j]= m1[(i+1)*size] * m2[j+1];
            add (subm[2],subm[3],ns,ADD);
 
            // 计算 a*g,c*e
            for (i=0; i<ns; ++i) {
                  subm[0][i]=m1[0]*m2[i+1];
                  subm[1][i]=m2[0]*m1[size*(i+1)];
            }
            //  计算 b*h,d*g
            for (i=0; i<ns; ++i) 
                  for (j=0; j<ns; ++j) {
                        subm[0][i] += m1[j+1] * m2[i+1+size*(j+1)];
                        subm[1][i] += m2[size*(j+1)] * m1[size*(i+1)+j+1];
                  }
            //计算左上角1×1的r
            m1[0]=m1[0]*m2[0];
            for (i=0; i<ns; ++i) 
                  m1[0] += m1[i+1] * m2[size*(i+1)];
 
            //拼接
 
            for (i=1; i<size; ++i) {
                  m1[i]=subm[0][i-1];
                  m1[size*i]=subm[1][i-1];
            }
            for (i=0; i<m; ++i) 
                  m1[1+i%ns+ size*(1+i/ns)] = subm[2][i];
            //释放临时空间
            for (i=0; i<4; ++i)
                  free (subm[i]);
            return 0;
      }
 
}
 
 
 
 
 
int main (){
      int m1[MAX];
      int m2[MAX];
      int m3[4];
 
      int i;
      for (i=0; i<MAX; ++i) {
            m1[i]=1;
            m2[i]=i;
      }
    //  add(m1,m2,4,MINUS);
    //  copy (m2,m3,4,4);
 
      for (i=0; i<MAX; ++i) {
            printf (" %d   ",m1[i]);
            if (i%SIZE==(SIZE-1)) printf("\n");
      }
      printf ("*************\n");
 
      for (i=0; i<MAX; ++i) {
            printf (" %d   ",m2[i]);
            if (i%SIZE==(SIZE-1)) printf("\n");
      }
      printf ("*************\n");
 
 
      mult (m1,m2,SIZE);
 
      for (i=0; i<MAX; ++i) {
            printf (" %d   ",m1[i]);
            if (i%SIZE==(SIZE-1)) printf("\n");
      }
      printf ("*************\n");
 
 
 
/*
      for (i=0; i<4; ++i) {
            printf (" %d   ",m3[i]);
            if (i%2==1) printf("\n");
      }
 
	  */
	getchar();
      return 0;
}

末了,多嘟囔几句。
这个算法不仅可以算边长是2的n次方的矩阵
也可以计算任意边长的方阵
我在草稿纸上简单比划了一下,最后的时间复杂度依然是 n的2.81次方

还有其他问题的,可以留言
有空我就解释……

然后,函数用的是递归调用,效能上会比较FC,不过既然只是交个作业……
不用关心那么多的吧……

再补充几条科普知识:
关于StranssenStrassen算法的中文wiki
http://zh.wikipedia.org/wiki/施特拉森演算法

目前已知最nb的矩阵算法,时间复杂度是n的2.4次方不到(英文)
http://en.wikipedia.org/wiki/Coppersmith–Winograd_algorithm

另外,根据Hopcroft和Kerr在1971年的证明,两个矩阵相乘,按照2×2的四块划分的话,能达到的最好上界就是S算法里的那个2.807。如果想得到更好的结果,譬如说C-W算法里的2.376就要使用其他更神奇的技巧……

好吧,那个1971的证明我不知道发表在哪里。网上看来的,那里也没说出处。倒是在上面那个英文wiki页面里,看到一个群论的单词……

其实我还扯了更多