ogl_beamforming

Ultrasound Beamforming Implemented with OpenGL
git clone anongit@rnpnr.xyz:ogl_beamforming.git
Log | Files | Refs | Feed | Submodules | README | LICENSE

filter.glsl (4517B)


      1 /* See LICENSE for license details. */
      2 #if  (InputDataKind == DataKind_Int16Complex          || \
      3       (InputDataKind == DataKind_Int16 && Demodulate) || \
      4       (InputDataKind == DataKind_Float16 && Demodulate))
      5   #define SAMPLE_TYPE f16vec2
      6 #elif InputDataKind == DataKind_Int16
      7   #define SAMPLE_TYPE f16
      8 #elif InputDataKind == DataKind_Float32 && Demodulate
      9   #define SAMPLE_TYPE f32vec2
     10 #endif
     11 
     12 #ifndef SAMPLE_TYPE
     13   #define SAMPLE_TYPE InputDataType
     14 #endif
     15 
     16 #define ComplexSampleType (InputDataKind == DataKind_Float32Complex || \
     17                            InputDataKind == DataKind_Float16Complex || \
     18                            InputDataKind == DataKind_Int16Complex   || \
     19                            Demodulate)
     20 #if ComplexSampleType
     21   #define RESULT_TYPE f32vec2
     22 #else
     23   #define RESULT_TYPE f32
     24 #endif
     25 
     26 #if ComplexFilter
     27   #define FILTER_TYPE f32vec2
     28 #else
     29   #define FILTER_TYPE f32
     30 #endif
     31 
     32 #if ComplexFilter && ComplexSampleType
     33   #define apply_filter(iq, h) complex_mul(f32vec2(iq), f32vec2(h))
     34 #else
     35   #define apply_filter(iq, h) ((iq) * (h))
     36 #endif
     37 
     38 layout(std430, buffer_reference, buffer_reference_align = 64) restrict readonly buffer Input {
     39 	InputDataType x[];
     40 };
     41 
     42 layout(set = ShaderResourceKind_Buffer, binding = ShaderBufferSlot_PingPong) buffer Output {
     43 	OutputDataType output_data[];
     44 };
     45 
     46 layout(std430, buffer_reference, buffer_reference_align = 64) restrict readonly buffer Filter {
     47 	FILTER_TYPE values[FilterLength];
     48 };
     49 
     50 f32vec2 complex_mul(f32vec2 a, f32vec2 b)
     51 {
     52 	mat2 m = mat2(b.x, b.y, -b.y, b.x);
     53 	f32vec2 result = m * a;
     54 	return result;
     55 }
     56 
     57 #if Demodulate
     58 SAMPLE_TYPE rotate_iq(SAMPLE_TYPE iq, uint index)
     59 {
     60 	float arg          = radians(360) * DemodulationFrequency * index / SamplingFrequency;
     61 	SAMPLE_TYPE result = SAMPLE_TYPE(complex_mul(iq, f32vec2(cos(arg), -sin(arg))));
     62 	return result;
     63 }
     64 #endif
     65 
     66 shared SAMPLE_TYPE rf[DecimationRate * gl_WorkGroupSize.x + FilterLength - 1];
     67 
     68 void main()
     69 {
     70 	uint out_sample = gl_GlobalInvocationID.x;
     71 	uint channel    = gl_GlobalInvocationID.y;
     72 	uint transmit   = gl_GlobalInvocationID.z;
     73 
     74 	uint thread_index = gl_LocalInvocationIndex;
     75 	uint thread_count = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
     76 	/////////////////////////
     77 	// NOTE: sample caching
     78 	{
     79 		bool offset_wraps = (DecimationRate * gl_WorkGroupID.x * gl_WorkGroupSize.x) < (FilterLength - 1);
     80 
     81 		u32 in_offset = InputDataKindByteSize * (InputChannelStride * channel + InputTransmitStride * transmit);
     82 		// NOTE(rnp): when demodulating we want to load 2 elements at a time but the
     83 		// input strides were specified in terms of a single element. therefore we
     84 		// must divide this by two. by doing this here we can gracefully handle
     85 		// the case where there are an odd number of samples (this drops the last one).
     86 		if (Demodulate != 0)
     87 			in_offset /= 2;
     88 
     89 		// NOTE(rnp): broken out to avoid overflow from the subtraction
     90 		u64 input_address = input_data + in_offset;
     91 		input_address += InputDataKindByteSize * (DecimationRate * gl_WorkGroupID.x * gl_WorkGroupSize.x);
     92 		input_address -= InputDataKindByteSize * (FilterLength - 1);
     93 
     94 		uint total_samples       = rf.length();
     95 		uint samples_per_thread  = total_samples / thread_count;
     96 		uint leftover_count      = total_samples % thread_count;
     97 		uint samples_this_thread = samples_per_thread + uint(thread_index < leftover_count);
     98 
     99 		const SAMPLE_TYPE scale = SAMPLE_TYPE(bool(ComplexFilter) ? 1 : sqrt(2.0f));
    100 		for (uint i = 0; i < samples_this_thread; i++) {
    101 			uint index = thread_count * i + thread_index;
    102 			SAMPLE_TYPE s = SAMPLE_TYPE(0);
    103 			if (!offset_wraps || index >= FilterLength - 1) {
    104 				s = SAMPLE_TYPE(Input(input_address).x[index]);
    105 				#if Demodulate
    106 				s = scale * rotate_iq(s * SAMPLE_TYPE(1, -1), index);
    107 				#endif
    108 			}
    109 			rf[index] = s;
    110 		}
    111 	}
    112 	barrier();
    113 
    114 	if (out_sample < SampleCount / DecimationRate) {
    115 		RESULT_TYPE result = RESULT_TYPE(0);
    116 		uint offset = DecimationRate * thread_index;
    117 		for (uint j = 0; j < FilterLength; j++)
    118 			result += apply_filter(rf[offset + j], Filter(filter_coefficients).values[j]);
    119 
    120 		u32 out_offset = OutputChannelStride  * channel +
    121 		                 OutputTransmitStride * transmit +
    122 		                 OutputSampleStride   * out_sample +
    123 		                 output_element_offset;
    124 
    125 		#if BatchSampleCount
    126 		// NOTE(rnp): deinterleave
    127 		output_data[out_offset] = OutputDataType(result.x);
    128 		out_offset += BatchSampleCount;
    129 		output_data[out_offset] = OutputDataType(result.y);
    130 		#else
    131 		output_data[out_offset] = OutputDataType(result);
    132 		#endif
    133 	}
    134 }