02月 28th, 2009Strassen的矩阵乘法【c语言实现】
=========
请所有通过搜索引擎找过来的同学注意
如果你搜的关键字里有 “Strassen” 的话,我不说什么
但是那些搜“C语言矩阵乘法”的同学,
我要说,如果你想搜的是Strassen的算法实现的话,建议你下次加上算法的名字,
有针对性的话,会高效很多
但是,让我很无奈,甚至愤怒的是,
如果你仅仅是想得到一个蛮力算矩阵乘法的实现,
我只能说,这么简单的3层for循环,你都不会,
孩子,你的人生已经废了!
你甚至对不起谭浩强的绿皮书!
=========
刚刚把上次实验的代码修改完毕,反正闲着也是闲着,就随手升级了一下wp的后台程序。
然后又装了一个代码高亮的插件,现在就来测试一下。
下午要是依然无聊,我就把之前那个 【你问我答】 板块的第一期给写了……
简单说明,代码是前几天上课要交的作业,内容就是标题里的东西。
简单测试了一下,貌似能用。
至于会不会有某些特别邪恶的输入会导致程序崩溃,我就不清楚了……
特别提一下,用的是int型,所以不建议输入的数据太大……
ok,分割线一下
×××××××××××××××
以下内容含有晦涩代码
请技术无能者在父母陪同下进行阅读
×××××××××××××××
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 | /************************* * * FILE NAME: * matrix.c * * INTRODUCTION: * Implement the matrix multiplication using * Strassen's algorithm. And a function of * matrix addtion. * * DESCRIPTION: * 1. * int mult (int* m1, int* m2, int size) * Multiply two matrices, m1 and m2, with * "size" rows and columns. * Save the result in m1. * * 2. * int add (int* m1, int* m2, int size, char flag) * Add or minus two matrices as the function mult. * When "flag" is "ADD" addtion, * when "MINUS" subtraction. * Also save the result in m1. * * BIBLIOGRAPHY: * Introduction to Algorithms * 28.2 Strassen's algorithm for matrix * multiplication * * AUTHOR: * SRAY * * DATE: * 2009.02.28 * * 吐槽: * 以下代码为Sray为应付作业而写 * * 欢迎需要交作业的同学前来抄袭 * * 抄的时候,记得把最后面那个测试用的main函数删掉。。。 * **************************/ #include <stdio.h> #include <stdlib.h> #define ADD (0) #define MINUS (1) #define SIZE (5) #define MAX (25) int add (int* m1, int* m2, int size, char flag){ // 矩阵加减法,结果保存在m1中 // flag=ADD为加 // 其他(MINUS)为减 int n,m; m=size*size; for (n=0; n<m; ++n){ if (flag==(ADD)) m1[n]=m1[n] + m2[n]; else m1[n]=m1[n] - m2[n]; } return 0; } static int copy (int* m1, int* m2, int size, char flag) { // 复制矩阵 // flag=0时,将m1全部复制到m2中 // flag=5时,将m1的右下区块复制到m2中,如下图 // ┌┬────┐ // ├┼────┤ // ││ │ // ││ 5 │ // ││ │ // ││ │ // └┴────┘ // 其他值,复制m1中对应区块 至m2中 // ┌─┬─┐ // │1 │2 │ // ├─┼─┤ // │3 │4 │ // └─┴─┘ int i; int m=size*size; int ns=size/2;//ns=new size if (flag==0) { for (i=0; i<m; ++i) { m2[i]=m1[i]; } return 0; } else if (flag==1) { m=m/4; for (i=0; i<m; ++i) m2[i]=m1[i%ns+size*(i/ns)]; return 0; } else if (flag==2) { m=m/4; for (i=0; i<m; ++i) m2[i]=m1[i%ns+ns+size*(i/ns)]; return 0; } else if (flag==3) { m=m/4; for (i=0; i<m; ++i) m2[i]=m1[i%ns+size*(ns+i/ns)]; return 0; } else if (flag==4) { m=m/4; for (i=0; i<m; ++i) m2[i]=m1[i%ns+ns+size*(ns+i/ns)]; return 0; } else if (flag==5) { ns=size-1; m=ns*ns; for (i=0; i<m; ++i) m2[i]=m1[1+i%ns+size*(1+i/ns)]; return 0; } else { printf ("ERROR:no such flag."); return -1; } } int mult (int* m1, int* m2, int size) { // 矩阵乘法 // 结果保存在m1中 int *subm[11]; int i,j; int m; int ns; if (size==1) { // 貌似这一段不会用到…… // 直接size=2了…… m1[0]=m1[0]*m2[0]; return 0; } if (size==2) { // 边长为2的时候 // 直接计算 subm[0]=(int*)malloc(4*sizeof(int)); subm[0][0]=m1[0]*m2[0]+m1[1]*m2[2]; subm[0][1]=m1[0]*m2[1]+m1[1]*m2[3]; subm[0][2]=m1[2]*m2[0]+m1[3]*m2[2]; subm[0][3]=m1[2]*m2[1]+m1[3]*m2[3]; copy (subm[0], m1,2,0); free (subm[0]); return 0; } if (size%2==0) { // 边长为偶数的时候 // 参照 Strassen的矩阵乘法算法 // 具体符号标记参见 // 算法导论 28.2节 m=size*size/4; ns=size/2; for (i=0; i<11; ++i) { subm[i]=(int*)malloc(m*sizeof(int)); } //计算P1至P7 //subm 0-6 分别保存 P1 - P7 //subm 7-10 保存中间变量 copy (m1,subm[0],size,1); copy (m2,subm[7],size,2); copy (m2,subm[8],size,4); add (subm[7], subm[8],ns,MINUS); mult (subm[0],subm[7],ns); copy (m1,subm[1],size,1); copy (m1,subm[7],size,2); add (subm[1],subm[7],ns,ADD); copy (m2,subm[7],size,4); mult (subm[1],subm[7],ns); copy (m1,subm[2],size,3); copy (m1,subm[7],size,4); add (subm[2],subm[7],ns,ADD); copy (m2,subm[7],size,1); mult (subm[2],subm[7],ns); copy (m1,subm[3],size,4); copy (m2,subm[7],size,3); copy (m2,subm[8],size,1); add (subm[7],subm[8],ns,MINUS); mult (subm[3],subm[7],ns); copy (m1,subm[4],size,1); copy (m1,subm[7],size,4); add (subm[4],subm[7],ns,ADD); copy (m2,subm[7],size,1); copy (m2,subm[8],size,4); add (subm[7],subm[8],ns,ADD); mult (subm[4],subm[7],ns); copy (m1,subm[5],size,2); copy (m1,subm[7],size,4); add (subm[5],subm[7],ns,MINUS); copy (m2,subm[7],size,3); copy (m2,subm[8],size,4); add (subm[7],subm[8],ns,ADD); mult (subm[5],subm[7],ns); copy (m1,subm[6],size,1); copy (m1,subm[7],size,3); add (subm[6],subm[7],ns,MINUS); copy (m2,subm[7],size,1); copy (m2,subm[8],size,2); add (subm[7],subm[8],ns,ADD); mult (subm[6],subm[7],ns); // 计算r,s,t,u // r=sub[7] s=sub[8] // t=sub[9] u=sub[10] copy (subm[2],subm[9],ns,0); add (subm[9],subm[3],ns,ADD); //t copy (subm[0],subm[8],ns,0); add (subm[8],subm[1],ns,ADD); //s copy (subm[4],subm[7],ns,0); add (subm[7],subm[3],ns,ADD); add (subm[7],subm[1],ns,MINUS); add (subm[7],subm[5],ns,ADD); //R copy (subm[4],subm[10],ns,0); add (subm[10],subm[0],ns,ADD); add (subm[10],subm[2],ns,MINUS); add (subm[10],subm[6],ns,MINUS); //U // 拼接 r s t u for (i=0; i<m; ++i) { m1[i%ns+size*(i/ns)]=subm[7][i]; m1[i%ns+ns+size*(i/ns)]=subm[8][i]; m1[i%ns+size*(i/ns+ns)]=subm[9][i]; m1[i%ns+ns+size*(i/ns+ns)]=subm[10][i]; } //释放临时空间 for (i=0; i<11; ++i) free (subm[i]); return 0; } else { // 对于非偶数边长,做如下分割 // ┌──┬──────┐ // │1*1 │ 1*(n-1) │ // ├──┼──────┤ // │ │ │ // │ │ │ // │n-1 │ n-1 │ // │ * │ * │ // │ 1 │ n-1 │ // │ │ │ // │ │ │ // └──┴──────┘ ns=size-1; m=ns*ns; subm[0]=(int*)malloc(ns*sizeof(int)); subm[1]=(int*)malloc(ns*sizeof(int)); subm[2]=(int*)malloc(m*sizeof(int)); subm[3]=(int*)malloc(m*sizeof(int)); // subm 2,3 分别保存 m1,m2的右下角部分 copy (m1,subm[2],size,5); copy (m2,subm[3],size,5); // 递归求解 d*h mult (subm[2],subm[3],ns); // 求解 c*f // 并把两次结果相加,得到 U部分(n-1 * n-1) for (i=0; i<ns; ++i) for (j=0; j<ns; ++j) subm[3][i*ns + j]= m1[(i+1)*size] * m2[j+1]; add (subm[2],subm[3],ns,ADD); // 计算 a*g,c*e for (i=0; i<ns; ++i) { subm[0][i]=m1[0]*m2[i+1]; subm[1][i]=m2[0]*m1[size*(i+1)]; } // 计算 b*h,d*g for (i=0; i<ns; ++i) for (j=0; j<ns; ++j) { subm[0][i] += m1[j+1] * m2[i+1+size*(j+1)]; subm[1][i] += m2[size*(j+1)] * m1[size*(i+1)+j+1]; } //计算左上角1×1的r m1[0]=m1[0]*m2[0]; for (i=0; i<ns; ++i) m1[0] += m1[i+1] * m2[size*(i+1)]; //拼接 for (i=1; i<size; ++i) { m1[i]=subm[0][i-1]; m1[size*i]=subm[1][i-1]; } for (i=0; i<m; ++i) m1[1+i%ns+ size*(1+i/ns)] = subm[2][i]; //释放临时空间 for (i=0; i<4; ++i) free (subm[i]); return 0; } } int main (){ int m1[MAX]; int m2[MAX]; int m3[4]; int i; for (i=0; i<MAX; ++i) { m1[i]=1; m2[i]=i; } // add(m1,m2,4,MINUS); // copy (m2,m3,4,4); for (i=0; i<MAX; ++i) { printf (" %d ",m1[i]); if (i%SIZE==(SIZE-1)) printf("\n"); } printf ("*************\n"); for (i=0; i<MAX; ++i) { printf (" %d ",m2[i]); if (i%SIZE==(SIZE-1)) printf("\n"); } printf ("*************\n"); mult (m1,m2,SIZE); for (i=0; i<MAX; ++i) { printf (" %d ",m1[i]); if (i%SIZE==(SIZE-1)) printf("\n"); } printf ("*************\n"); /* for (i=0; i<4; ++i) { printf (" %d ",m3[i]); if (i%2==1) printf("\n"); } */ getchar(); return 0; } |
末了,多嘟囔几句。
这个算法不仅可以算边长是2的n次方的矩阵
也可以计算任意边长的方阵
我在草稿纸上简单比划了一下,最后的时间复杂度依然是 n的2.81次方
还有其他问题的,可以留言
有空我就解释……
然后,函数用的是递归调用,效能上会比较FC,不过既然只是交个作业……
不用关心那么多的吧……
再补充几条科普知识:
关于StranssenStrassen算法的中文wiki
http://zh.wikipedia.org/wiki/施特拉森演算法
目前已知最nb的矩阵算法,时间复杂度是n的2.4次方不到(英文)
http://en.wikipedia.org/wiki/Coppersmith–Winograd_algorithm
另外,根据Hopcroft和Kerr在1971年的证明,两个矩阵相乘,按照2×2的四块划分的话,能达到的最好上界就是S算法里的那个2.807。如果想得到更好的结果,譬如说C-W算法里的2.376就要使用其他更神奇的技巧……
好吧,那个1971的证明我不知道发表在哪里。网上看来的,那里也没说出处。倒是在上面那个英文wiki页面里,看到一个群论的单词……

02月 28th, 2009 at 13:05
前几天才学的
囧三 更加恶毒地答复:
02月 28th, 2009 at 13:14
欢迎前来抄作业
08月 20th, 2009 at 17:28
倒数第四段:Strassen -> Stranssen
囧三 更加恶毒地答复:
08月 20th, 2009 at 21:55
@knighter,
谢谢……
实在很汗,大概那个错误是我全文唯一一个自己输入的吧……
嗯,其他的应该都是复制输入……
这就去改
10月 14th, 2009 at 16:16
我不是来抄作业的,只是来找代码参考的,你这程序跑512×512(为了让你的划分尽量平均,以提高效率)的矩阵要16秒多,我三重循环才要1.9秒,虽说不是太在意效率,但你也不能比三重循环还慢啊~
囧三 更加恶毒地答复:
10月 14th, 2009 at 21:46
@Nova,
首先感谢对我的blog的关心
是这样的,你的这个问题,我当时赶作业的时候也发现了。
而对于这个问题,我的理解是这样的
我这个S算法的实现是一个递归的自调用过程。
而三层for是一个顺序的过程
我们当时上数据结构,老师第一节课就说过:
永远不要使用递归!
简单的例子就是用for和递归fibonacci数列,或者更简单的1加到100
我想,如果通过改变程序的结构,把递归改写成循环,
速度肯定会提高很多,
而且,我想应该会算的比三层for要快
02月 22nd, 2010 at 22:39
矩阵维度为奇数的时候,你总是分离开1个维度并不见得非常高效,比如n=1023的时候:)
我觉得更好的处理方式是把它变为一个n+1维的矩阵,用0填充多余元素,然后在结果中分离出来。
囧三 更加恶毒地答复:
02月 22nd, 2010 at 23:04
@feng,
呵呵,自己应付作业的粗糙代码,着实配不上如此细致的分析。
刚刚拿笔随手画了一下。这样通过临时提高维度方式,看上去的确比我的方法要简单不少。
也是,我当时一直是在想,如何持续的降低阶数,没有想到过还有如此的欲扬先抑的方法。
代码我是懒得再改了,但还是非常感谢
04月 23rd, 2010 at 02:25
05月 1st, 2010 at 18:30
我收这个进来…Coppersmith—Winograd好像米有