📄 arrmult.sl
字号:
% arrmult.sl -- regression test for the `#` matrix-multiplication operator.
% Results of `#` are compared against explicit reference implementations
% for 1-d, 2-d and 3-d operands of integer, float, double and (when the
% interpreter has it) complex type.
%
% NOTE(review): `failed` and `print` are provided by inc.sl (loaded below).
% This script presumably relies on `failed` aborting the run -- confirm
% against inc.sl.

_debug_info = 1; () = evalfile ("inc.sl");

print ("Testing Matrix Multiplications ...");

#ifexists Double_Type

% Dot product of two 1-d arrays: evaluates `a # b` and leaves element 0
% of the result on the stack.
static define dot_prod (a, b)
{
   (a # b)[0];			       %  transpose not needed for 1-d arrays
}

% Sum of the elements of a 1-d array, computed as its dot product with a
% Double_Type vector of ones of the same length.
static define sum (a)
{
   variable ones = Double_Type [length (a)] + 1;
   dot_prod (a, ones);
}

if (1+2+3+4+5 != sum([1,2,3,4,5])) failed ("sum");
#ifexists Complex_Type
if (1+2i != sum ([1,2i])) failed ("sum complex");
#endif

% Reference product of two 2-d arrays built element by element:
% c[i,j] is the dot product of row i of `a` with column j of `b`.
% The element type of `c` is the type `#` yields for one element of each
% operand.  Calls `failed` if the inner dimensions do not agree.
define mult (a, b)
{
   variable dims_a, dims_b;
   variable nr_a, nr_b, nc_a, nc_b;
   variable i, j;
   variable c;

   (dims_a,,) = array_info (a);
   (dims_b,,) = array_info (b);

   nr_a = dims_a[0];
   nc_a = dims_a[1];
   nr_b = dims_b[0];
   nc_b = dims_b[1];

   % Columns of `a` must match rows of `b`; otherwise the row/column
   % dot products below have mismatched lengths.
   if (nc_a != nr_b)
     failed ("mult: cannot multiply [%d,%d] by [%d,%d]", nr_a, nc_a, nr_b, nc_b);

   c = _typeof ([a[0,0]]#[b[0,0]])[nr_a, nc_b];

   for (i = 0; i < nr_a; i++)
     {
	for (j = 0; j < nc_b; j++)
	  c[i,j] = dot_prod (a[i,*], b[*,j]);
     }
   return c;
}

% Element-wise comparison of two arrays.
% Returns 0 when every element matches.  Otherwise it reports the first
% mismatching pair through vmessage and returns 1.
static define arr_cmp (a, b)
{
   variable i = where (b != a);

   if (length (i) == 0)
     return 0;

   % Indexing with the 1-d index array from where() gives 1-d selections.
   a = a[i];
   b = b[i];
   vmessage ("%S != %S\n", a[0], b[0]);
   return 1;
}

% Check that the reference product mult(a,b) agrees with a#b.
static define test (a, b)
{
   if (0 != arr_cmp (mult (a,b), a#b))
     failed ("%S # %S", a, b);
}

variable A, B;

#ifexists Complex_Type
A = [1+2i];
B = [3+4i];
reshape (A, [1, 1]);
reshape (B, [1, 1]);
test (A,B);
#endif

% Test integers
A = _reshape ([[1, 2, 3], [4, 5, 6]], [2,3]);
B = _reshape ([[7,8,9],[1,2,4]], [2,3]);
B = transpose (B);
test (A, B);

% Mixed integer/float, then double, then float/double operands
B *= 1f;
test (A, B);
B *= 1.0;
test (A,B);
A *= 1f;
test (A,B);

#ifexists Complex_Type
B += 2i;
test (A,B);
A += 3i;
test (A,B);
B = Real(B);
test (A,B);

% Now try an empty array
if (Complex_Type != _typeof (Complex_Type[0,0,0] # Complex_Type[0]))
  failed ("[]#[]");
#endif

% And finally, do a 3-d array:
A = _reshape ([1:2*3*4], [2,3,4]);
B = _reshape ([1:4*5*6], [4,5,6]);
static variable C = A#B;

% C should be a [2,3,5,6] matrix.  Let's check via brute force:
% c[i,j,l,m] must equal sum over k of a[i,j,k] * b[k,l,m].
static define multiply_3d (a, b, c)
{
   variable i, j, k, l, m;
   variable dims_a, dims_b, dims_c;
   variable s;			       %  named `s` so the static sum() is not shadowed

   (dims_a,,) = array_info (a);
   (dims_b,,) = array_info (b);
   (dims_c,,) = array_info (c);

   if (dims_a[2] != dims_b[0])
     failed ("multiply_3d: inner dimensions %d and %d differ", dims_a[2], dims_b[0]);

   % The element loop alone would not necessarily notice a result whose
   % extent along some axis is wrong, so check the shape explicitly.
   if (length (dims_c) != 4)
     failed ("multiply_3d: result has rank %d, expected 4", length (dims_c));
   else if (length (where (dims_c != [dims_a[0], dims_a[1], dims_b[1], dims_b[2]])))
     failed ("multiply_3d: result has the wrong shape");

   for (i = 0; i < dims_a[0]; i++)
     {
	for (j = 0; j < dims_a[1]; j++)
	  {
	     for (l = 0; l < dims_b[1]; l++)
	       {
		  for (m = 0; m < dims_b[2]; m++)
		    {
		       s = 0;
		       for (k = 0; k < dims_b[0]; k++)
			 s += a[i,j,k] * b[k,l,m];

		       if (s != c[i,j,l,m])
			 failed ("multiply_3d");
		    }
	       }
	  }
     }
}

multiply_3d (A, B, C);

print ("Ok\n");
#else
print ("Not available\n");
#endif

exit (0);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -