📄 matmult.br
字号:
// matmult.br
// tests kernel splitting of many-output kernel

#include <stdio.h>

// 4x4 matrix product where the right-hand operand is supplied already
// transposed.
//
//   a0..a3  the four float4 rows of the left-hand matrix
//   b0..b3  the four float4 rows of the TRANSPOSED right-hand matrix
//           (i.e. the columns of the right-hand matrix)
//   c0..c3  outputs; component j of c_i is dot( a_i, b_j )
//
// With the inputs interpreted as above, c0..c3 are the rows of the product
// of the left-hand matrix and the (untransposed) right-hand matrix.
kernel void matmult4_pretransposed(
    float4 a0, float4 a1, float4 a2, float4 a3,
    float4 b0, float4 b1, float4 b2, float4 b3,
    out float4 c0<>, out float4 c1<>, out float4 c2<>, out float4 c3<> )
{
    c0 = float4( dot( a0, b0 ), dot( a0, b1 ), dot( a0, b2 ), dot( a0, b3 ) );
    c1 = float4( dot( a1, b0 ), dot( a1, b1 ), dot( a1, b2 ), dot( a1, b3 ) );
    c2 = float4( dot( a2, b0 ), dot( a2, b1 ), dot( a2, b2 ), dot( a2, b3 ) );
    c3 = float4( dot( a3, b0 ), dot( a3, b1 ), dot( a3, b2 ), dot( a3, b3 ) );
}
// Transposes a 4x4 matrix given as four float4 values.
//
//   x0..x3  inputs, read as the four rows of the matrix
//   y0..y3  outputs; y0 collects the .x components of x0..x3, y1 the .y
//           components, y2 the .z components and y3 the .w components,
//           so y_i holds column i of the input.
kernel void transpose4( float4 x0, float4 x1, float4 x2, float4 x3,
out float4 y0<>, out float4 y1<>, out float4 y2<>, out float4 y3<> )
{
// Row i of the result is built from component i of every input row.
y0 = float4( x0.x, x1.x, x2.x, x3.x );
y1 = float4( x0.y, x1.y, x2.y, x3.y );
y2 = float4( x0.z, x1.z, x2.z, x3.z );
y3 = float4( x0.w, x1.w, x2.w, x3.w );
}
// Multiplies two matrices that are each passed as four 4x4 blocks
// (x, y, z, w), every block being four float4 rows (suffix 0..3).
//
// The blocks are combined as a 2x2 block product:
//
//   cx = ax*bx + ay*bz        cy = ax*by + ay*bw
//   cz = az*bx + aw*bz        cw = az*by + aw*bw
//
// Each block of b is transposed once with transpose4 so that every 4x4
// product can be evaluated by matmult4_pretransposed. All sixteen output
// rows (cx0..cw3) are written by this single kernel.
kernel void matmult8(
    float4 ax0<>, float4 ax1<>, float4 ax2<>, float4 ax3<>,
    float4 ay0<>, float4 ay1<>, float4 ay2<>, float4 ay3<>,
    float4 az0<>, float4 az1<>, float4 az2<>, float4 az3<>,
    float4 aw0<>, float4 aw1<>, float4 aw2<>, float4 aw3<>,
    float4 bx0<>, float4 bx1<>, float4 bx2<>, float4 bx3<>,
    float4 by0<>, float4 by1<>, float4 by2<>, float4 by3<>,
    float4 bz0<>, float4 bz1<>, float4 bz2<>, float4 bz3<>,
    float4 bw0<>, float4 bw1<>, float4 bw2<>, float4 bw3<>,
    out float4 cx0<>, out float4 cx1<>, out float4 cx2<>, out float4 cx3<>,
    out float4 cy0<>, out float4 cy1<>, out float4 cy2<>, out float4 cy3<>,
    out float4 cz0<>, out float4 cz1<>, out float4 cz2<>, out float4 cz3<>,
    out float4 cw0<>, out float4 cw1<>, out float4 cw2<>, out float4 cw3<>
)
{
    // Transposed copies of the four blocks of b.
    float4 bxt0, bxt1, bxt2, bxt3,
           byt0, byt1, byt2, byt3,
           bzt0, bzt1, bzt2, bzt3,
           bwt0, bwt1, bwt2, bwt3;

    /* Disabled: staging temporaries for the results.
    float4 dx0, dx1, dx2, dx3,
           dy0, dy1, dy2, dy3,
           dz0, dz1, dz2, dz3,
           dw0, dw1, dw2, dw3;*/

    // Scratch rows holding the most recent 4x4 partial product.
    float4 p0, p1, p2, p3;

    transpose4( bx0, bx1, bx2, bx3, bxt0, bxt1, bxt2, bxt3 );
    transpose4( by0, by1, by2, by3, byt0, byt1, byt2, byt3 );
    transpose4( bz0, bz1, bz2, bz3, bzt0, bzt1, bzt2, bzt3 );
    transpose4( bw0, bw1, bw2, bw3, bwt0, bwt1, bwt2, bwt3 );

    // cx = ax*bx + ay*bz
    matmult4_pretransposed( ax0, ax1, ax2, ax3, bxt0, bxt1, bxt2, bxt3, p0, p1, p2, p3 );
    cx0 = p0;
    cx1 = p1;
    cx2 = p2;
    cx3 = p3;
    matmult4_pretransposed( ay0, ay1, ay2, ay3, bzt0, bzt1, bzt2, bzt3, p0, p1, p2, p3 );
    cx0 = cx0 + p0;
    cx1 = cx1 + p1;
    cx2 = cx2 + p2;
    cx3 = cx3 + p3;

    // cy = ax*by + ay*bw
    matmult4_pretransposed( ax0, ax1, ax2, ax3, byt0, byt1, byt2, byt3, p0, p1, p2, p3 );
    cy0 = p0;
    cy1 = p1;
    cy2 = p2;
    cy3 = p3;
    matmult4_pretransposed( ay0, ay1, ay2, ay3, bwt0, bwt1, bwt2, bwt3, p0, p1, p2, p3 );
    cy0 = cy0 + p0;
    cy1 = cy1 + p1;
    cy2 = cy2 + p2;
    cy3 = cy3 + p3;

    // cz = az*bx + aw*bz
    matmult4_pretransposed( az0, az1, az2, az3, bxt0, bxt1, bxt2, bxt3, p0, p1, p2, p3 );
    cz0 = p0;
    cz1 = p1;
    cz2 = p2;
    cz3 = p3;
    matmult4_pretransposed( aw0, aw1, aw2, aw3, bzt0, bzt1, bzt2, bzt3, p0, p1, p2, p3 );
    cz0 = cz0 + p0;
    cz1 = cz1 + p1;
    cz2 = cz2 + p2;
    cz3 = cz3 + p3;

    // cw = az*by + aw*bw
    matmult4_pretransposed( az0, az1, az2, az3, byt0, byt1, byt2, byt3, p0, p1, p2, p3 );
    cw0 = p0;
    cw1 = p1;
    cw2 = p2;
    cw3 = p3;
    matmult4_pretransposed( aw0, aw1, aw2, aw3, bwt0, bwt1, bwt2, bwt3, p0, p1, p2, p3 );
    cw0 = cw0 + p0;
    cw1 = cw1 + p1;
    cw2 = cw2 + p2;
    cw3 = cw3 + p3;

    /* Disabled: copy-back from the staging temporaries.
    cx0 = dx0; cx1 = dx1; cx2 = dx2; cx3 = dx3;
    cy0 = dy0; cy1 = dy1; cy2 = dy2; cy3 = dy3;
    cz0 = dz0; cz1 = dz1; cz2 = dz2; cz3 = dz3;
    cw0 = dw0; cw1 = dw1; cw2 = dw2; cw3 = dw3;*/
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -