// File: matmult_test.br
// Font size:
// matmult_test.br
// tests kernel splitting of many-output kernel
#include <stdio.h>
#include <stdlib.h>
// Declaration of the matmult8 kernel; its body is not part of this file.
// Parameters, all float4 streams:
//   ax0..aw3 - 16 input streams
//   bx0..bw3 - 16 input streams
//   cx0..cw3 - 16 output streams (marked `out`)
// NOTE(review): the name suggests the c streams receive a matrix product of
// the a and b streams - confirm against the kernel definition.
kernel void matmult8(
float4 ax0<>, float4 ax1<>, float4 ax2<>, float4 ax3<>,
float4 ay0<>, float4 ay1<>, float4 ay2<>, float4 ay3<>,
float4 az0<>, float4 az1<>, float4 az2<>, float4 az3<>,
float4 aw0<>, float4 aw1<>, float4 aw2<>, float4 aw3<>,
float4 bx0<>, float4 bx1<>, float4 bx2<>, float4 bx3<>,
float4 by0<>, float4 by1<>, float4 by2<>, float4 by3<>,
float4 bz0<>, float4 bz1<>, float4 bz2<>, float4 bz3<>,
float4 bw0<>, float4 bw1<>, float4 bw2<>, float4 bw3<>,
out float4 cx0<>, out float4 cx1<>, out float4 cx2<>, out float4 cx3<>,
out float4 cy0<>, out float4 cy1<>, out float4 cy2<>, out float4 cy3<>,
out float4 cz0<>, out float4 cz1<>, out float4 cz2<>, out float4 cz3<>,
out float4 cw0<>, out float4 cw1<>, out float4 cw2<>, out float4 cw3<>
);
// Timer defined outside this file. testMatmult reports the difference between
// two readings; the unit of the returned value is not visible here - TODO confirm.
extern int getTime();
// Writes inValue into the first inSize*inSize elements of outData.
// Nothing is written when inSize*inSize is not positive.
void fillData( int inSize, float4* outData, float4 inValue )
{
    int remaining = inSize*inSize;
    float4* cursor = outData;

    // Walk the buffer front to back, one element per step.
    while( remaining > 0 )
    {
        *cursor = inValue;
        cursor++;
        remaining--;
    }
}
// Times inCount invocations of the matmult8 kernel on inSize x inSize float4
// streams and stores the elapsed getTime() difference in *outTime.
//
//   inSize  - edge length of every stream; each holds inSize*inSize float4s.
//   inCount - number of kernel invocations inside the timed region.
//   outTime - receives stop - start, in whatever unit getTime() returns.
//
// The timed region covers the inCount kernel calls plus the readback of all
// sixteen output streams. One untimed call and one readback run beforehand.
// The host staging buffers are released before returning; the process is
// aborted with a message if they cannot be allocated.
void testMatmult( int inSize, int inCount, int* outTime )
{
    float4 ax0<inSize,inSize>; float4 ax1<inSize,inSize>; float4 ax2<inSize,inSize>; float4 ax3<inSize,inSize>;
    float4 ay0<inSize,inSize>; float4 ay1<inSize,inSize>; float4 ay2<inSize,inSize>; float4 ay3<inSize,inSize>;
    float4 az0<inSize,inSize>; float4 az1<inSize,inSize>; float4 az2<inSize,inSize>; float4 az3<inSize,inSize>;
    float4 aw0<inSize,inSize>; float4 aw1<inSize,inSize>; float4 aw2<inSize,inSize>; float4 aw3<inSize,inSize>;
    float4 bx0<inSize,inSize>; float4 bx1<inSize,inSize>; float4 bx2<inSize,inSize>; float4 bx3<inSize,inSize>;
    float4 by0<inSize,inSize>; float4 by1<inSize,inSize>; float4 by2<inSize,inSize>; float4 by3<inSize,inSize>;
    float4 bz0<inSize,inSize>; float4 bz1<inSize,inSize>; float4 bz2<inSize,inSize>; float4 bz3<inSize,inSize>;
    float4 bw0<inSize,inSize>; float4 bw1<inSize,inSize>; float4 bw2<inSize,inSize>; float4 bw3<inSize,inSize>;
    float4 cx0<inSize,inSize>; float4 cx1<inSize,inSize>; float4 cx2<inSize,inSize>; float4 cx3<inSize,inSize>;
    float4 cy0<inSize,inSize>; float4 cy1<inSize,inSize>; float4 cy2<inSize,inSize>; float4 cy3<inSize,inSize>;
    float4 cz0<inSize,inSize>; float4 cz1<inSize,inSize>; float4 cz2<inSize,inSize>; float4 cz3<inSize,inSize>;
    float4 cw0<inSize,inSize>; float4 cw1<inSize,inSize>; float4 cw2<inSize,inSize>; float4 cw3<inSize,inSize>;
    float4* data;   // scratch buffer: refilled for each b stream, then reused for readback
    float4* data0;  // all elements (0,0,0,0)
    float4* dataX;  // all elements (1,0,0,0)
    float4* dataY;  // all elements (0,1,0,0)
    float4* dataZ;  // all elements (0,0,1,0)
    float4* dataW;  // all elements (0,0,0,1)
    int i;
    int start, stop;

    data = (float4*)malloc( inSize*inSize*sizeof(float4) );
    data0 = (float4*)malloc( inSize*inSize*sizeof(float4) );
    dataX = (float4*)malloc( inSize*inSize*sizeof(float4) );
    dataY = (float4*)malloc( inSize*inSize*sizeof(float4) );
    dataZ = (float4*)malloc( inSize*inSize*sizeof(float4) );
    dataW = (float4*)malloc( inSize*inSize*sizeof(float4) );

    // Fail loudly here rather than dereferencing a null buffer in fillData.
    if( !data || !data0 || !dataX || !dataY || !dataZ || !dataW )
    {
        fprintf( stderr, "testMatmult: out of memory for %dx%d float4 buffers\n", inSize, inSize );
        abort();
    }

    fillData( inSize, data0, float4(0,0,0,0) );
    fillData( inSize, dataX, float4(1,0,0,0) );
    fillData( inSize, dataY, float4(0,1,0,0) );
    fillData( inSize, dataZ, float4(0,0,1,0) );
    fillData( inSize, dataW, float4(0,0,0,1) );

    // a inputs: the ax* and aw* groups get the four unit vectors,
    // the ay* and az* groups are all zero.
    streamRead( ax0, dataX );
    streamRead( ax1, dataY );
    streamRead( ax2, dataZ );
    streamRead( ax3, dataW );
    streamRead( ay0, data0 );
    streamRead( ay1, data0 );
    streamRead( ay2, data0 );
    streamRead( ay3, data0 );
    streamRead( az0, data0 );
    streamRead( az1, data0 );
    streamRead( az2, data0 );
    streamRead( az3, data0 );
    streamRead( aw0, dataX );
    streamRead( aw1, dataY );
    streamRead( aw2, dataZ );
    streamRead( aw3, dataW );

    // b inputs: the sixteen streams carry the consecutive values 0..63,
    // four per stream, staged through the scratch buffer.
    fillData( inSize, data, float4( 0, 1, 2, 3 ) );
    streamRead( bx0, data );
    fillData( inSize, data, float4( 4, 5, 6, 7 ) );
    streamRead( bx1, data );
    fillData( inSize, data, float4( 8, 9, 10, 11 ) );
    streamRead( bx2, data );
    fillData( inSize, data, float4( 12, 13, 14, 15 ) );
    streamRead( bx3, data );
    fillData( inSize, data, float4( 16, 17, 18, 19 ) );
    streamRead( by0, data );
    fillData( inSize, data, float4( 20, 21, 22, 23 ) );
    streamRead( by1, data );
    fillData( inSize, data, float4( 24, 25, 26, 27 ) );
    streamRead( by2, data );
    fillData( inSize, data, float4( 28, 29, 30, 31 ) );
    streamRead( by3, data );
    fillData( inSize, data, float4( 32, 33, 34, 35 ) );
    streamRead( bz0, data );
    fillData( inSize, data, float4( 36, 37, 38, 39 ) );
    streamRead( bz1, data );
    fillData( inSize, data, float4( 40, 41, 42, 43 ) );
    streamRead( bz2, data );
    fillData( inSize, data, float4( 44, 45, 46, 47 ) );
    streamRead( bz3, data );
    fillData( inSize, data, float4( 48, 49, 50, 51 ) );
    streamRead( bw0, data );
    fillData( inSize, data, float4( 52, 53, 54, 55 ) );
    streamRead( bw1, data );
    fillData( inSize, data, float4( 56, 57, 58, 59 ) );
    streamRead( bw2, data );
    fillData( inSize, data, float4( 60, 61, 62, 63 ) );
    streamRead( bw3, data );

    // One untimed kernel call followed by a readback of cw3.
    // NOTE(review): presumably this absorbs one-time setup cost and waits
    // for it to finish before the timer starts - confirm.
    matmult8(
        ax0, ax1, ax2, ax3,
        ay0, ay1, ay2, ay3,
        az0, az1, az2, az3,
        aw0, aw1, aw2, aw3,
        bx0, bx1, bx2, bx3,
        by0, by1, by2, by3,
        bz0, bz1, bz2, bz3,
        bw0, bw1, bw2, bw3,
        cx0, cx1, cx2, cx3,
        cy0, cy1, cy2, cy3,
        cz0, cz1, cz2, cz3,
        cw0, cw1, cw2, cw3
    );
    streamWrite( cw3, data );

    start = getTime();
    for( i = 0; i < inCount; i++ )
    {
        matmult8(
            ax0, ax1, ax2, ax3,
            ay0, ay1, ay2, ay3,
            az0, az1, az2, az3,
            aw0, aw1, aw2, aw3,
            bx0, bx1, bx2, bx3,
            by0, by1, by2, by3,
            bz0, bz1, bz2, bz3,
            bw0, bw1, bw2, bw3,
            cx0, cx1, cx2, cx3,
            cy0, cy1, cy2, cy3,
            cz0, cz1, cz2, cz3,
            cw0, cw1, cw2, cw3
        );
    }

    // Read every output stream back into the scratch buffer while the timer
    // is still running. Each readback overwrites the previous one and the
    // values are not inspected.
    // NOTE(review): presumably the readbacks make the timed region wait for
    // all outputs to be produced - confirm.
    streamWrite( cx0, data );
    streamWrite( cx1, data );
    streamWrite( cx2, data );
    streamWrite( cx3, data );
    streamWrite( cy0, data );
    streamWrite( cy1, data );
    streamWrite( cy2, data );
    streamWrite( cy3, data );
    streamWrite( cz0, data );
    streamWrite( cz1, data );
    streamWrite( cz2, data );
    streamWrite( cz3, data );
    streamWrite( cw0, data );
    streamWrite( cw1, data );
    streamWrite( cw2, data );
    streamWrite( cw3, data );
    stop = getTime();
    *outTime = stop - start;

    // Release the host staging buffers; previously they leaked on every call.
    free( data );
    free( data0 );
    free( dataX );
    free( dataY );
    free( dataZ );
    free( dataW );
}
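/*
 * Code-viewer keyboard shortcuts (web page residue, not part of the program):
 *   Copy code             Ctrl + C
 *   Search code           Ctrl + F
 *   Full-screen mode      F11
 *   Toggle theme          Ctrl + Shift + D
 *   Show shortcuts        ?
 *   Increase font size    Ctrl + =
 *   Decrease font size    Ctrl + -
 */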