filter.glsl (4517B)
/* See LICENSE for license details. */

/* FIR filter stage with optional demodulation and decimation.
 *
 * Each workgroup first copies the window of input samples it needs into the shared
 * `rf` cache (optionally demodulating them on the way in). Then every invocation whose
 * `out_sample` is below SampleCount / DecimationRate convolves FilterLength cached
 * samples, starting at rf[DecimationRate * thread_index], with the filter coefficients
 * and writes one output sample.
 *
 * NOTE(review): `rf` is sized from gl_WorkGroupSize.x only and is filled for a single
 * (channel, transmit) pair taken from gl_GlobalInvocationID.y/z, so this appears to
 * assume a workgroup of size (N, 1, 1) -- confirm against the dispatch code.
 */

/* storage type of one cached sample */
#if (InputDataKind == DataKind_Int16Complex || \
     (InputDataKind == DataKind_Int16   && Demodulate) || \
     (InputDataKind == DataKind_Float16 && Demodulate))
	#define SAMPLE_TYPE f16vec2
#elif InputDataKind == DataKind_Int16
	#define SAMPLE_TYPE f16
#elif InputDataKind == DataKind_Float32 && Demodulate
	#define SAMPLE_TYPE f32vec2
#endif

#ifndef SAMPLE_TYPE
	#define SAMPLE_TYPE InputDataType
#endif

/* the accumulated result is a (real, imaginary) pair when the input is complex or
 * when the samples are demodulated */
#define ComplexSampleType (InputDataKind == DataKind_Float32Complex || \
                           InputDataKind == DataKind_Float16Complex || \
                           InputDataKind == DataKind_Int16Complex   || \
                           Demodulate)
#if ComplexSampleType
	#define RESULT_TYPE f32vec2
#else
	#define RESULT_TYPE f32
#endif

#if ComplexFilter
	#define FILTER_TYPE f32vec2
#else
	#define FILTER_TYPE f32
#endif

/* complex sample with complex coefficient uses a complex product; every other
 * combination is a plain (scalar or component-wise) multiply */
#if ComplexFilter && ComplexSampleType
	#define apply_filter(iq, h) complex_mul(f32vec2(iq), f32vec2(h))
#else
	#define apply_filter(iq, h) ((iq) * (h))
#endif

layout(std430, buffer_reference, buffer_reference_align = 64) restrict readonly buffer Input {
	InputDataType x[];
};

layout(set = ShaderResourceKind_Buffer, binding = ShaderBufferSlot_PingPong) buffer Output {
	OutputDataType output_data[];
};

layout(std430, buffer_reference, buffer_reference_align = 64) restrict readonly buffer Filter {
	FILTER_TYPE values[FilterLength];
};

/* complex product a * b, with each vec2 holding (real, imaginary) */
f32vec2 complex_mul(f32vec2 a, f32vec2 b)
{
	mat2 m = mat2(b.x, b.y, -b.y, b.x);
	f32vec2 result = m * a;
	return result;
}

#if Demodulate
/* Rotate `iq` by exp(-i * 2pi * DemodulationFrequency * index / SamplingFrequency).
 * `index` must be the absolute element position within the channel/transmit so the
 * phase is continuous across workgroups. The phase is reduced to whole cycles in
 * [-0.5, 0.5) before scaling by 2pi. This keeps the cos/sin argument in [-pi, pi),
 * the range on which GPU implementations guarantee their accuracy. Subtracting whole
 * cycles does not change the rotation. */
SAMPLE_TYPE rotate_iq(SAMPLE_TYPE iq, uint index)
{
	float cycles = float(index) * (float(DemodulationFrequency) / float(SamplingFrequency));
	cycles       = fract(cycles + 0.5) - 0.5;
	float arg    = radians(360) * cycles;
	SAMPLE_TYPE result = SAMPLE_TYPE(complex_mul(iq, f32vec2(cos(arg), -sin(arg))));
	return result;
}
#endif

/* per-workgroup input window: DecimationRate samples per invocation plus the
 * FilterLength - 1 samples of history needed by the first output */
shared SAMPLE_TYPE rf[DecimationRate * gl_WorkGroupSize.x + FilterLength - 1];

void main()
{
	uint out_sample = gl_GlobalInvocationID.x;
	uint channel    = gl_GlobalInvocationID.y;
	uint transmit   = gl_GlobalInvocationID.z;

	uint thread_index = gl_LocalInvocationIndex;
	uint thread_count = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
	/////////////////////////
	// NOTE: sample caching
	{
		// NOTE: rf[i] holds input element (block_start + i - (FilterLength - 1)) of the
		// current channel/transmit
		uint block_start = DecimationRate * gl_WorkGroupID.x * gl_WorkGroupSize.x;

		u32 in_offset = InputDataKindByteSize * (InputChannelStride * channel + InputTransmitStride * transmit);
		// NOTE(rnp): when demodulating we want to load 2 elements at a time but the
		// input strides were specified in terms of a single element. therefore we
		// must divide this by two. by doing this here we can gracefully handle
		// the case where there are an odd number of samples (this drops the last one).
		if (Demodulate != 0)
			in_offset /= 2;

		// NOTE(rnp): broken out to avoid overflow from the subtraction
		u64 input_address = input_data + in_offset;
		input_address += InputDataKindByteSize * block_start;
		input_address -= InputDataKindByteSize * (FilterLength - 1);

		// NOTE: spread the cache fill evenly over all invocations; the first
		// `leftover_count` invocations load one extra element
		uint total_samples       = rf.length();
		uint samples_per_thread  = total_samples / thread_count;
		uint leftover_count      = total_samples % thread_count;
		uint samples_this_thread = samples_per_thread + uint(thread_index < leftover_count);

		// NOTE(review): demodulated samples are scaled by 1 for complex filters and by
		// sqrt(2) otherwise; the reason is not visible in this file
		const SAMPLE_TYPE scale = SAMPLE_TYPE(bool(ComplexFilter) ? 1 : sqrt(2.0f));
		for (uint i = 0; i < samples_this_thread; i++) {
			uint index = thread_count * i + thread_index;
			SAMPLE_TYPE s = SAMPLE_TYPE(0);
			// NOTE: positions before the first element of the channel/transmit are zero
			// padded. test the absolute position so that a workgroup with
			// 0 < block_start < FilterLength - 1 still loads its valid history.
			if (block_start + index >= FilterLength - 1) {
				s = SAMPLE_TYPE(Input(input_address).x[index]);
				#if Demodulate
				// NOTE: the phase must follow the absolute element position, not the
				// position inside this workgroup's cache
				uint sample_index = block_start + index - (FilterLength - 1);
				s = scale * rotate_iq(s * SAMPLE_TYPE(1, -1), sample_index);
				#endif
			}
			rf[index] = s;
		}
	}
	barrier();

	if (out_sample < SampleCount / DecimationRate) {
		RESULT_TYPE result = RESULT_TYPE(0);
		uint offset = DecimationRate * thread_index;
		for (uint j = 0; j < FilterLength; j++)
			result += apply_filter(rf[offset + j], Filter(filter_coefficients).values[j]);

		u32 out_offset = OutputChannelStride  * channel    +
		                 OutputTransmitStride * transmit   +
		                 OutputSampleStride   * out_sample +
		                 output_element_offset;

		#if BatchSampleCount
		// NOTE(rnp): deinterleave
		output_data[out_offset] = OutputDataType(result.x);
		out_offset += BatchSampleCount;
		output_data[out_offset] = OutputDataType(result.y);
		#else
		output_data[out_offset] = OutputDataType(result);
		#endif
	}
}