📄 at_gather_fetch.br
字号:
// at_gather_fetch.br
// test address-translation code for
// fetching elements of a gather stream
//
// Flow (see main): fill a 4D stream `g` with each element's own index,
// verify that on the host, then gather from `g` into a differently
// shaped 4D stream using a swizzled index and verify every fetched
// element is the one that was asked for.

#include <stdio.h>
#include <stdlib.h>   // malloc, free, exit

// arbitrary weird stream sizes:
#define RESULT_X 33
#define RESULT_Y 13
#define RESULT_Z 20
#define RESULT_W 9

// The gather extents are the result extents permuted the same way
// gather4 swizzles its index (.zwxy), so every swizzled result index
// lands inside the gather stream:
//   GATHER_X == RESULT_Z, GATHER_Y == RESULT_W,
//   GATHER_Z == RESULT_X, GATHER_W == RESULT_Y
#define GATHER_X 20
#define GATHER_Y 9
#define GATHER_Z 33
#define GATHER_W 13

// Overwrite every element of `result` with -1.
kernel void clear4( out float4 result<> ) {
  result = -1;
}

// Store each element's own 4D index into that element.
kernel void getIndexStream4( out float4 result<> ) {
  result = (indexof result).xyzw;
}

// For each output element, fetch g[] at the output's index permuted
// to (z, w, x, y).
kernel void gather4( float4 g[][][][], out float4 result<> )
{
  float4 i = (indexof result).xyzw;
  // swizzle the index
  i = i.zwxy;
  result = g[i];
}

// Compare one component read back from a stream against its expected
// integer value (the float is truncated to int for the comparison).
// On mismatch, report the stream name, the element position printed as
// [w,z,y,x], the component name and both values, then stop the program.
void checkItem(
  const char* streamName, const char* dimName,
  int expected, float actual,
  int x, int y, int z, int w )
{
  if( (int)actual != expected )
  {
    printf("FAIL: %s[%d,%d,%d,%d].%s : expected %d got %f\n",
      streamName, w, z, y, x, dimName, expected, actual );
    // NOTE(review): this exits with status 0 even though the test
    // failed; presumably the harness judges by the printed output
    // rather than the exit code -- confirm before changing.
    exit(0);
  }
}

// Verify the index stream: walking the buffer with x varying fastest
// and w slowest, each element's four floats must equal its own
// (x, y, z, w) position.
void checkGather( float* inData )
{
  float* data = inData;
  int x, y, z, w;

  for( w = 0; w < GATHER_W; w++ )
  {
    for( z = 0; z < GATHER_Z; z++ )
    {
      for( y = 0; y < GATHER_Y; y++ )
      {
        for( x = 0; x < GATHER_X; x++ )
        {
          checkItem( "gather", "x", x, *data++, x, y, z, w );
          checkItem( "gather", "y", y, *data++, x, y, z, w );
          checkItem( "gather", "z", z, *data++, x, y, z, w );
          checkItem( "gather", "w", w, *data++, x, y, z, w );
        }
      }
    }
  }
}

// Verify the gathered stream: the element at (x, y, z, w) must hold
// the index-stream value stored at the swizzled position (z, w, x, y),
// i.e. the floats (z, w, x, y) in that order.
void checkResult( float* inData )
{
  float* data = inData;
  int x, y, z, w;
  int xg, yg, zg, wg;

  for( w = 0; w < RESULT_W; w++ )
  {
    for( z = 0; z < RESULT_Z; z++ )
    {
      for( y = 0; y < RESULT_Y; y++ )
      {
        for( x = 0; x < RESULT_X; x++ )
        {
          // swizzle the index (must mirror i.zwxy in gather4)
          xg = z;
          yg = w;
          zg = x;
          wg = y;

          checkItem( "result", "x", xg, *data++, x, y, z, w );
          checkItem( "result", "y", yg, *data++, x, y, z, w );
          checkItem( "result", "z", zg, *data++, x, y, z, w );
          checkItem( "result", "w", wg, *data++, x, y, z, w );
        }
      }
    }
  }
}

// Run the test. Prints "pass" and returns 0 when both checks succeed.
// A failed check terminates inside checkItem; a failed allocation
// prints a FAIL line and returns 1.
int main( int argc, char** argv )
{
  float4 result< RESULT_W, RESULT_Z, RESULT_Y, RESULT_X >;
  float4 g< GATHER_W, GATHER_Z, GATHER_Y, GATHER_X >;

  // Host-side buffers, one float4 per stream element.
  float* data_result = (float*)malloc( RESULT_W*RESULT_Z*RESULT_Y*RESULT_X*sizeof(float4) );
  float* data_g = (float*)malloc( GATHER_W*GATHER_Z*GATHER_Y*GATHER_X*sizeof(float4) );

  if( data_result == NULL || data_g == NULL )
  {
    // Without this guard the stream contents would be copied through
    // a null pointer below.
    printf( "FAIL: could not allocate host buffers\n" );
    free( data_result );
    free( data_g );
    return 1;
  }

  // Build the index stream and confirm it on the host before using it
  // as the gather source.
  getIndexStream4( g );
  streamWrite( g, data_g );
  checkGather( data_g );

  // Pre-fill the output with -1 before the gather kernel runs.
  clear4( result );
  gather4( g, result );
  streamWrite( result, data_result );
  checkResult( data_result );

  printf( "pass\n" );

  free( data_result );
  free( data_g );
  return 0;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -