decode.glsl (5505B)
/* See LICENSE for license details. */

/* Hadamard decode compute shader.
 *
 * main() dispatches to one of three implementations chosen at compile time:
 *   CooperativeMatrix -> run_decode_coop()   (subgroup cooperative-matrix tiles)
 *   UseSharedMemory   -> run_decode_large()  (one RF row staged in shared memory per local y)
 *   otherwise         -> run_decode_small()  (one full RF row per invocation, private memory)
 *
 * Every path computes, per (time_sample, channel):
 *   out[j] = (1 / TransmitCount) * sum_i rf[i] * H[i][j]
 * where H is read row-major as Hadamard.x[TransmitCount * i + j] and rf is read from
 * RF.x[TransmitCount * ChunkChannelCount * time_sample + TransmitCount * channel + i].
 *
 * Names used but not defined in this file (InputDataType, OutputDataType, f16/f32/s32/u32,
 * u32vec2, rf_buffer, hadamard_buffer, output_buffer, TransmitCount, ChunkChannelCount,
 * ToProcess, Output{Channel,Transmit,Sample}Stride, CooperativeMatrix*, UseSharedMemory,
 * DecodeMode, DecodeMode_Hadamard) are presumably injected by the host as a preamble or
 * defines -- confirm against the shader build code. */

#if CooperativeMatrix
#extension GL_KHR_cooperative_matrix : require
#extension GL_KHR_memory_scope_semantics : require
#endif

/* Input samples; element i of a (time_sample, channel) row is the i-th transmit. */
layout(std430, buffer_reference, buffer_reference_align = 64) restrict readonly buffer RF {
	InputDataType x[];
};

/* Decoded output; addressed with the Output*Stride constants (or a TransmitCount stride
 * in the cooperative-matrix path). */
layout(std430, buffer_reference, buffer_reference_align = 64) restrict writeonly buffer Output {
	OutputDataType x[];
};

/* TransmitCount x TransmitCount decode matrix, read row-major. */
layout(std430, buffer_reference, buffer_reference_align = 64) restrict readonly buffer Hadamard {
	f16 x[];
};

/* Read RF element `index` and convert it to OutputDataType. */
OutputDataType sample_rf_data(u32 index)
{
	OutputDataType result = OutputDataType(RF(rf_buffer).x[index]);
	return result;
}

#if UseSharedMemory

/* One RF row (all TransmitCount samples for one channel) per local y invocation row. */
shared InputDataType rf[gl_WorkGroupSize.y][TransmitCount];

/* Shared-memory decode.
 *
 * Invocation (x, y, z) handles channel = global y, time_sample = global z, and the
 * ToProcess output transmits starting at transmit = global x * ToProcess.
 *
 * Phase 1: the gl_WorkGroupSize.x invocations that share a local y copy that channel's
 *          TransmitCount RF samples into `rf`. The remainder TransmitCount % gl_WorkGroupSize.x
 *          is spread one extra sample each over the lowest local x invocations.
 * Phase 2: each invocation accumulates its ToProcess outputs from the staged row and the
 *          Hadamard matrix, then scales by 1 / TransmitCount.
 * Phase 3: results are stored. Outputs with transmit + i >= TransmitCount (the partial
 *          last tile when TransmitCount is not a multiple of gl_WorkGroupSize.x * ToProcess)
 *          are not written. */
void run_decode_large(void)
{
	u32 transmit    = gl_GlobalInvocationID.x * ToProcess;
	u32 channel     = gl_GlobalInvocationID.y;
	u32 time_sample = gl_GlobalInvocationID.z;

	const u32 samples_per_thread = TransmitCount / gl_WorkGroupSize.x;
	const u32 leftover_samples   = TransmitCount % gl_WorkGroupSize.x;

	u32 thread_index_x      = gl_LocalInvocationID.x;
	u32 samples_this_thread = samples_per_thread + u32(thread_index_x < leftover_samples);

	/* NOTE(review): the offset uses gl_WorkGroupID.z, while time_sample uses
	 * gl_GlobalInvocationID.z. They agree only when gl_WorkGroupSize.z == 1;
	 * presumably that holds -- TODO confirm. */
	u32 rf_offset = TransmitCount * ChunkChannelCount * gl_WorkGroupID.z + TransmitCount * channel;

	/* Strided cooperative load: sample `index` is loaded by local x == index % gl_WorkGroupSize.x. */
	for (u32 i = 0; i < samples_this_thread; i++) {
		u32 index = i * gl_WorkGroupSize.x + thread_index_x;
		rf[gl_LocalInvocationID.y][index] = RF(rf_buffer).x[rf_offset + index];
	}

	/* barrier() must be reached by every invocation of the workgroup (uniform control
	 * flow), so it stays outside the time_sample guard below. */
	barrier();

	OutputDataType result[ToProcess];
	if (time_sample < OutputTransmitStride) {
		for (s32 i = 0; i < ToProcess; i++)
			result[i] = OutputDataType(0);

		for (s32 j = 0; j < TransmitCount; j++) {
			OutputDataType s = OutputDataType(rf[gl_LocalInvocationID.y][j]);
			for (s32 i = 0; i < ToProcess; i++) {
				/* When TransmitCount is not a multiple of gl_WorkGroupSize.x * ToProcess, the
				 * last tile has transmit + i >= TransmitCount. Those results are discarded by the
				 * store guard below, but the Hadamard read itself must stay inside the
				 * TransmitCount x TransmitCount matrix. For j == TransmitCount - 1 the unclamped
				 * index would run past its end, so clamp the column. The condition is a constant
				 * expression, so it should fold away in the divisible case. */
				u32 column = i + transmit;
				if (TransmitCount % (gl_WorkGroupSize.x * ToProcess) != 0)
					column = min(column, u32(TransmitCount) - 1u);
				result[i] += s * Hadamard(hadamard_buffer).x[TransmitCount * j + column];
			}
		}

		for (uint i = 0; i < ToProcess; i++)
			result[i] /= float(TransmitCount);
	}

	/* NOTE(rnp): DO NOT combine with above; compiler shits the bed on TransmitCount == 80
	 * and it kills performance. reinvestigate when we further optimize */
	if (time_sample < OutputTransmitStride) {
		uint out_off = OutputChannelStride  * channel +
		               OutputTransmitStride * transmit +
		               OutputSampleStride   * time_sample;

		/* Skip the tail of a partial last tile; the first test is a constant expression. */
		for (uint i = 0; i < ToProcess; i++, out_off += OutputTransmitStride)
			if (TransmitCount % (gl_WorkGroupSize.x * ToProcess) == 0 || transmit + i < TransmitCount)
				Output(output_buffer).x[out_off] = result[i];
	}
}
#endif

#if CooperativeMatrix

/* Cooperative-matrix decode: one subgroup-scope tile of the product RF * H per workgroup.
 *
 * gl_WorkGroupID.xy selects the output tile (x -> column block of width CooperativeMatrixN,
 * y -> row block of height CooperativeMatrixM), and gl_WorkGroupID.z selects the time
 * sample. RF tiles are loaded row-major with row stride TransmitCount starting at
 * ChunkChannelCount * TransmitCount * time_sample. Rows of that slab are therefore
 * channels, and columns are transmits.
 *
 * NOTE(review): the k loop steps by CooperativeMatrixK and loads a full K-wide tile each
 * time; this presumably requires TransmitCount % CooperativeMatrixK == 0 -- TODO confirm.
 * NOTE(review): unlike the other paths, there is no `time_sample < OutputTransmitStride`
 * check here. The store also uses a TransmitCount row stride from the same base offset as
 * the RF read, rather than the Output*Stride constants -- confirm that the host allocates
 * and dispatches accordingly.
 * NOTE(review): when UseSharedMemory is also enabled this function body is empty, but
 * main() still selects it, so nothing is written -- confirm that combination is never built. */
void run_decode_coop(void)
{
#if UseSharedMemory
#else

	u32vec2 tile_index  = gl_WorkGroupID.xy;
	u32     time_sample = gl_WorkGroupID.z;

	coopmat<f16, gl_ScopeSubgroup, CooperativeMatrixM, CooperativeMatrixK, gl_MatrixUseA> rf_matrix;
	coopmat<f16, gl_ScopeSubgroup, CooperativeMatrixK, CooperativeMatrixN, gl_MatrixUseB> hadamard_matrix;
	coopmat<f32, gl_ScopeSubgroup, CooperativeMatrixM, CooperativeMatrixN, gl_MatrixUseAccumulator> result;
	result = coopmat<f32, gl_ScopeSubgroup, CooperativeMatrixM, CooperativeMatrixN, gl_MatrixUseAccumulator>(0.0f);

	u32 result_row = CooperativeMatrixM * tile_index.y;
	u32 result_col = CooperativeMatrixN * tile_index.x;

	u32 offset = ChunkChannelCount * TransmitCount * time_sample;

	/* Accumulate result += RF[rows, k:k+K] * H[k:k+K, cols] over the shared transmit dimension. */
	for (u32 k = 0; k < TransmitCount; k += CooperativeMatrixK) {
		u32 rf_tile_row = CooperativeMatrixM * tile_index.y;
		u32 rf_tile_col = k;
		coopMatLoad(rf_matrix, RF(rf_buffer).x, offset + TransmitCount * rf_tile_row + rf_tile_col,
		            TransmitCount, gl_CooperativeMatrixLayoutRowMajor);

		u32 hadamard_tile_row = k;
		u32 hadamard_tile_col = CooperativeMatrixN * tile_index.x;
		coopMatLoad(hadamard_matrix, Hadamard(hadamard_buffer).x,
		            TransmitCount * hadamard_tile_row + hadamard_tile_col, TransmitCount,
		            gl_CooperativeMatrixLayoutRowMajor);

		result = coopMatMulAdd(rf_matrix, hadamard_matrix, result);
	}

	/* Scale by 1 / TransmitCount; each invocation touches only its own result.length() elements. */
	for (s32 i = 0; i < result.length(); i++)
		result[i] = result[i] / f32(TransmitCount);

	Output out_buffer = Output(output_buffer);
	coopMatStore(result, out_buffer.x, offset + TransmitCount * result_row + result_col,
	             TransmitCount, gl_CooperativeMatrixLayoutRowMajor);
#endif
}
#endif

/* Private-memory decode: each invocation decodes one full (time_sample, channel) row.
 *
 * time_sample = global x, channel = global y. The TransmitCount RF samples are copied into
 * a private array, multiplied by the whole Hadamard matrix, and scaled by 1 / TransmitCount.
 * Output i is written at OutputChannelStride * channel + OutputSampleStride * time_sample
 * + OutputTransmitStride * i. Invocations with time_sample >= OutputTransmitStride do nothing. */
void run_decode_small(void)
{
	u32 time_sample = gl_GlobalInvocationID.x;
	u32 channel     = gl_GlobalInvocationID.y;
	u32 rf_offset   = TransmitCount * ChunkChannelCount * time_sample + TransmitCount * channel;

	if (time_sample < OutputTransmitStride) {
		/* Function-scope array; when UseSharedMemory is set this shadows the shared `rf`. */
		InputDataType rf[TransmitCount];
		for (s32 j = 0; j < TransmitCount; j++)
			rf[j] = RF(rf_buffer).x[rf_offset + j];

		OutputDataType result[TransmitCount];
		for (s32 j = 0; j < TransmitCount; j++)
			result[j] = OutputDataType(0);

		/* result[j] = sum_i rf[i] * H[i][j], iterating H row by row. */
		for (s32 i = 0; i < TransmitCount; i++) {
			OutputDataType s = OutputDataType(rf[i]);
			for (s32 j = 0; j < TransmitCount; j++) {
				result[j] += s * Hadamard(hadamard_buffer).x[TransmitCount * i + j];
			}
		}

		for (int i = 0; i < TransmitCount; i++)
			result[i] /= float(TransmitCount);

		uint out_off = OutputChannelStride * channel +
		               OutputSampleStride  * time_sample;
		for (int i = 0; i < TransmitCount; i++, out_off += OutputTransmitStride)
			Output(output_buffer).x[out_off] = result[i];
	}
}

/* Entry point: only DecodeMode_Hadamard is handled; any other DecodeMode writes nothing. */
void main()
{
	switch (DecodeMode) {
	case DecodeMode_Hadamard:{
#if CooperativeMatrix
		run_decode_coop();
#elif UseSharedMemory
		run_decode_large();
#else
		run_decode_small();
#endif
	}break;
	}
}