reshape.glsl (2088B)
/* See LICENSE for license details. */

/* Reshape compute shader: each invocation with gl_GlobalInvocationID inside
 * (SizeX, SizeY, SizeZ) copies one element from the input buffer(s) to the
 * output buffer. Source and destination are addressed with independent
 * per-axis element strides, so this can re-lay-out data between layouts.
 *
 * Names used here but not defined in this file (presumably injected by the
 * host before compilation, e.g. a prepended header or #defines -- confirm):
 *   InputDataKind, OutputDataKind, DataKind_*  select the Input/Output views
 *   Interleave                                 nonzero selects the two-input path
 *   SizeX, SizeY, SizeZ                        bounds of the processed region
 *   InputStrideX/Y/Z, OutputStrideX/Y/Z        element strides per axis
 *   left_input_buffer, right_input_buffer,
 *   output_buffer                              addresses cast to the views below
 *   s16, s16vec2, f32, f32vec2, u32            scalar/vector type aliases
 *
 * Float16 kinds are accessed through the 16-bit integer views (Int16 /
 * Int16Complex), so this shader moves them as s16 values and never decodes
 * half floats. NOTE(review): if a Float16 input is paired with a Float32
 * output, the raw 16-bit pattern is converted as an integer value (or
 * rejected by the compiler) -- confirm that pairing is never requested.
 */

/* Input: buffer-reference view used to read source elements.
 * Interleave mode packs two scalar inputs into one vec2, so it cannot
 * accept complex (vec2) input kinds; reject those up front with a clear
 * message instead of a type error inside main(). */
#if InputDataKind == DataKind_Float32Complex
#define Input Float32Complex
#if Interleave
#error Reshape Interleave mode requires a real input data kind
#endif
#elif InputDataKind == DataKind_Float32
#define Input Float32
#elif InputDataKind == DataKind_Float16Complex || InputDataKind == DataKind_Int16Complex
#define Input Int16Complex
#if Interleave
#error Reshape Interleave mode requires a real input data kind
#endif
#elif InputDataKind == DataKind_Float16 || InputDataKind == DataKind_Int16
#define Input Int16
#else
#error unsupported input data kind for Reshape
#endif

/* Output: buffer-reference view used to write results; OutputKind is the
 * matching element type held in registers before the store.
 * Interleave mode writes components [0] and [1], so it needs a complex
 * (vec2) output kind; scalar output kinds are rejected here. */
#if OutputDataKind == DataKind_Float32Complex
#define Output     Float32Complex
#define OutputKind f32vec2
#elif OutputDataKind == DataKind_Float32
#define Output     Float32
#define OutputKind f32
#if Interleave
#error Reshape Interleave mode requires a complex output data kind
#endif
#elif OutputDataKind == DataKind_Float16Complex || OutputDataKind == DataKind_Int16Complex
#define Output     Int16Complex
#define OutputKind s16vec2
#elif OutputDataKind == DataKind_Float16 || OutputDataKind == DataKind_Int16
#define Output     Int16
#define OutputKind s16
#if Interleave
#error Reshape Interleave mode requires a complex output data kind
#endif
#else
#error unsupported output data kind for Reshape
#endif

/* Typed views over raw device addresses (GL_EXT_buffer_reference). Each is an
 * unbounded array of one element type; all declare 8-byte base alignment. */
layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Int16 {
	s16 x[];
};

layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Int16Complex {
	s16vec2 x[];
};

layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Float32 {
	f32 x[];
};

layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Float32Complex {
	f32vec2 x[];
};

void main(void)
{
	/* The dispatch may be rounded up to whole workgroups; skip invocations
	 * that fall outside the requested region. */
	if (all(lessThan(gl_GlobalInvocationID, uvec3(SizeX, SizeY, SizeZ)))) {
		u32 x = gl_GlobalInvocationID.x;
		u32 y = gl_GlobalInvocationID.y;
		u32 z = gl_GlobalInvocationID.z;

		/* Flat element indices. NOTE(review): computed in u32, so the
		 * stride * coordinate sums must stay below 2^32 -- confirm for the
		 * largest buffers this is dispatched on. */
		u32 input_index  = InputStrideX  * x + InputStrideY  * y + InputStrideZ  * z;
		u32 output_index = OutputStrideX * x + OutputStrideY * y + OutputStrideZ * z;

		OutputKind out_value = OutputKind(0);

#if Interleave
		/* Two scalar inputs at the same index become the two components of
		 * one output element (component 0 from left, 1 from right). */
		out_value[0] = Input(left_input_buffer).x[input_index];
		out_value[1] = Input(right_input_buffer).x[input_index];
#else
		/* Plain copy of the left input element, converted to OutputKind. */
		out_value = Input(left_input_buffer).x[input_index];
#endif

		Output(output_buffer).x[output_index] = out_value;
	}
}