// File: cusvd_kernel.cu
// (Code-viewer residue: "Font size:")
#ifndef _BITONIC_KERNEL_H_
#define _BITONIC_KERNEL_H_
/*
 * ADD -- tree reduction (sum) of temp[0..255] into temp[0].
 *
 * Requirements at the point of use:
 *   - `tid`  : integer index of the calling thread, 0..255.
 *   - `temp` : float array with at least 256 entries that all participating
 *              threads can read and write (the cross-thread adds only make
 *              sense if it is memory shared by the block).
 *   - Every thread of the block must execute ADD, because it contains
 *     __syncthreads() barriers.
 *
 * Each step halves the active range (128, 64, ..., 1) and is followed by a
 * barrier so the next step sees the sums of the previous one.  There is NO
 * barrier after the final step: call __syncthreads() before threads other
 * than thread 0 read temp[0], and again before temp is overwritten.
 *
 * The last line deliberately has no trailing backslash; a continuation there
 * would splice whatever follows the macro into its replacement list.
 */
#define ADD if (tid < 128) { temp[tid] += temp[tid + 128]; } __syncthreads();\
            if (tid <  64) { temp[tid] += temp[tid +  64]; } __syncthreads();\
            if (tid <  32) { temp[tid] += temp[tid +  32]; } __syncthreads();\
            if (tid <  16) { temp[tid] += temp[tid +  16]; } __syncthreads();\
            if (tid <   8) { temp[tid] += temp[tid +   8]; } __syncthreads();\
            if (tid <   4) { temp[tid] += temp[tid +   4]; } __syncthreads();\
            if (tid <   2) { temp[tid] += temp[tid +   2]; } __syncthreads();\
            if (tid <   1) { temp[tid] += temp[tid +   1]; }
/*
 * JUDGE -- turn three scalars into a rotation pair (cs, ss).
 *
 * Requirements at the point of use: assignable float variables named
 * `c`, `cs`, `ss`, `l` and `t` must be in scope.
 *   in : cs, ss, c   (presumably the two squared column norms and the column
 *                     dot product of a Jacobi step -- confirm at the call site)
 *   out: cs, ss      (overwritten with the rotation coefficients,
 *                     cs = 1/sqrt(1 + t*t), ss = t * cs)
 *   scratch: l, t
 *
 * If |c| is below 1e-7 the identity rotation (cs = 1, ss = 0) is produced and
 * the division by c is skipped.
 *
 * The threshold uses float literals so the comparison stays in single
 * precision; the branch taken is the same as with the double literals for
 * every float value of c.
 *
 * NOTE(review): signbit(l) is non-zero only when l is negative, so t is 0
 * (identity rotation) for every l >= 0 and |l| + sqrt(1 + l*l) for l < 0.
 * The textbook Jacobi choice is t = sign(l) / (|l| + sqrt(1 + l*l)).  The
 * existing behaviour is kept unchanged here -- confirm which one is intended.
 *
 * NOTE(review): the macro both reads and writes cs/ss.  If these variables are
 * shared between threads, let a single thread run JUDGE between two barriers;
 * otherwise one thread can read cs after another has already overwritten it.
 */
#define JUDGE if( c < 0.0000001f && c > -0.0000001f )\
{ cs = 1.0f; ss = 0.0f; }\
else\
{\
    l = ( ss - cs ) * 0.5f / c;\
    t = signbit(l) * ( fabsf(l) + sqrtf( 1.0f + l * l ) );\
    cs = rsqrtf( 1.0f + t * t );\
    ss = t * cs;\
}
/*
 * Block-wide sum of one float per thread, for a block of exactly 256 threads.
 *
 * temp    : scratch array of 256 floats shared by the block.
 * tid     : threadIdx.x of the caller (0..255).
 * partial : this thread's contribution.
 *
 * Must be called by every thread of the block (it contains barriers).
 * Returns the total to every thread.  The trailing barrier guarantees that all
 * threads have read temp[0] before the caller reuses temp.
 */
__device__ static float bjrotBlockSum(float * temp, const int tid, const float partial)
{
    temp[tid] = partial;
    __syncthreads();                    // all partials are in place

    // Pairwise tree: 128, 64, ..., 1 active threads, barrier after each step.
    #pragma unroll
    for (int stride = 128; stride > 0; stride >>= 1)
    {
        if (tid < stride) { temp[tid] += temp[tid + stride]; }
        __syncthreads();
    }

    const float total = temp[0];
    __syncthreads();                    // nobody overwrites temp[0] before all have read it
    return total;
}

/*
 * Rotation coefficients from the two squared norms (alpha, beta) and the dot
 * product (gamma) of a row pair.  Writes cos to *cosOut and sin to *sinOut.
 *
 * |gamma| < 1e-7 yields the identity rotation and avoids dividing by gamma.
 *
 * NOTE(review): signbit(zeta) is non-zero only for negative zeta, so zeta >= 0
 * gives t = 0, i.e. no rotation.  The textbook Jacobi formula is
 * t = sign(zeta) / (|zeta| + sqrt(1 + zeta*zeta)).  The original arithmetic is
 * kept as is -- confirm the intended formula.
 */
__device__ static void bjrotCoefficients(const float alpha, const float beta, const float gamma,
                                         float * cosOut, float * sinOut)
{
    if (gamma < 0.0000001f && gamma > -0.0000001f)
    {
        *cosOut = 1.0f;
        *sinOut = 0.0f;
        return;
    }

    const float zeta = (beta - alpha) * 0.5f / gamma;
    const float t    = signbit(zeta) * (fabsf(zeta) + sqrtf(1.0f + zeta * zeta));
    const float cosv = rsqrtf(1.0f + t * t);

    *cosOut = cosv;
    *sinOut = t * cosv;
}

/*
 * bjrot -- rotate one pair of rows of W (and the same pair of rows of U).
 * Presumably one step of a one-sided Jacobi SVD sweep -- confirm with caller.
 *
 * Launch layout: 1-D grid, one block per row pair, exactly 256 threads per
 * block (the reduction scratch has 256 entries and each thread handles
 * elements tid and tid + 256, i.e. 512 elements per row).  NUM, defined
 * outside this block, is the row stride in floats and must be >= 512.
 * Static shared memory: 256 floats.
 *
 * d_w_i, d_u_i : input matrices; block b reads rows 2b and 2b + 1.
 * d_w_o, d_u_o : output matrices; the rotated rows are written to rows
 *                (int)d_index[2b] and (int)d_index[2b + 1].
 * d_index      : destination row numbers stored as float (truncated to int).
 *
 * Per block: alpha = |w0|^2, beta = |w1|^2, gamma = w0 . w1 are reduced over
 * the block, a rotation (cs, ss) is derived from them, and
 *     row0' = row0 * cs - row1 * ss
 *     row1' = row0 * ss + row1 * cs
 * is applied to both the W pair and the U pair.
 *
 * All per-thread data and the rotation scalars live in registers; only the
 * reduction scratch is shared, and every read of its result is fenced by
 * barriers, so no thread can observe a half-updated value.
 */
__global__ static void bjrot(float * d_w_o, float * d_w_i,float * d_u_o, float * d_u_i, float * d_index)
{
    const int tid    = threadIdx.x;
    const int hi     = tid + 256;               // second element handled by this thread
    const int rowLen = (NUM);
    const int bstart = (blockIdx.x << 1);

    const int i_index0 = bstart * rowLen;
    const int i_index1 = (bstart + 1) * rowLen;
    // Destination rows are stored as floats; truncate exactly like the implicit conversion did.
    const int o_index0 = static_cast<int>(d_index[bstart]) * rowLen;
    const int o_index1 = static_cast<int>(d_index[bstart + 1]) * rowLen;

    __shared__ float temp[256];

    // This thread's two elements of each W row.
    const float w0a = d_w_i[i_index0 + tid];
    const float w0b = d_w_i[i_index0 + hi];
    const float w1a = d_w_i[i_index1 + tid];
    const float w1b = d_w_i[i_index1 + hi];

    // Squared norms and dot product of the two W rows.
    const float alpha = bjrotBlockSum(temp, tid, w0a * w0a + w0b * w0b);
    const float beta  = bjrotBlockSum(temp, tid, w1a * w1a + w1b * w1b);
    const float gamma = bjrotBlockSum(temp, tid, w0a * w1a + w0b * w1b);

    // Every thread derives the same coefficients from the same three sums.
    float cs;
    float ss;
    bjrotCoefficients(alpha, beta, gamma, &cs, &ss);

    // Load the U pair before any store is issued.
    const float u0a = d_u_i[i_index0 + tid];
    const float u0b = d_u_i[i_index0 + hi];
    const float u1a = d_u_i[i_index1 + tid];
    const float u1b = d_u_i[i_index1 + hi];

    d_w_o[o_index0 + tid] = w0a * cs - w1a * ss;
    d_w_o[o_index0 + hi]  = w0b * cs - w1b * ss;
    d_w_o[o_index1 + tid] = w0a * ss + w1a * cs;
    d_w_o[o_index1 + hi]  = w0b * ss + w1b * cs;

    d_u_o[o_index0 + tid] = u0a * cs - u1a * ss;
    d_u_o[o_index0 + hi]  = u0b * cs - u1b * ss;
    d_u_o[o_index1 + tid] = u1a * cs + u0a * ss;
    d_u_o[o_index1 + hi]  = u1b * cs + u0b * ss;
}
#endif // _BITONIC_KERNEL_H_
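// Code-viewer residue (keyboard shortcut help), not part of the source:
//   Copy code:            Ctrl + C
//   Search code:          Ctrl + F
//   Full-screen mode:     F11
//   Toggle theme:         Ctrl + Shift + D
//   Show shortcuts:       ?
//   Increase font size:   Ctrl + =
//   Decrease font size:   Ctrl + -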