Commit: 423409c940fc65a55e6306c9c2a482aec33f5d98
Parent: cac06ceb44ff0bf8626c0fcc95a8292998237d58
Author: Randy Palamar
Date: Fri, 10 Apr 2026 10:28:12 -0600
core/filter.glsl: fix demodulation for odd input sample count
During normal imaging with acquired datasets the sample count is
always even, so we could safely divide by 2 and treat samples in
pairs for the purpose of demodulating. Simulated datasets, however,
can have an odd number of samples. If we just divide by 2, the
input stride becomes incorrect and the filter shader will not load
the correct samples. If we move the division by 2 into the shader,
we can ensure that the addressing of the current sample batch
remains correct while simply dropping the final sample.
closes #39
Diffstat:
2 files changed, 41 insertions(+), 35 deletions(-)
diff --git a/beamformer_core.c b/beamformer_core.c
@@ -355,13 +355,13 @@ plan_compute_pipeline(BeamformerComputePlan *cp, BeamformerParameterBlock *pb, A
BeamformerComputeGraphNode *root_node = push_struct(&scratch, BeamformerComputeGraphNode);
root_node->kind = BeamformerShaderKind_Count;
root_node->input_data_kind = input_data_kind;
- root_node->input_stride.x = 1; // Sample Stride
- root_node->input_stride.y = input_sample_count * acquisition_count; // Channel Stride
- root_node->input_stride.z = input_sample_count; // Receive Event Stride
+ root_node->input_stride.x = 1; // Sample Stride
+ root_node->input_stride.y = pb->parameters.sample_count * acquisition_count; // Channel Stride
+ root_node->input_stride.z = pb->parameters.sample_count; // Receive Event Stride
root_node->output_data_kind = input_data_kind;
- root_node->output_stride.x = 1; // Sample Stride
- root_node->output_stride.y = input_sample_count * acquisition_count; // Channel Stride
- root_node->output_stride.z = input_sample_count; // Receive Event Stride
+ root_node->output_stride.x = 1; // Sample Stride
+ root_node->output_stride.y = pb->parameters.sample_count * acquisition_count; // Channel Stride
+ root_node->output_stride.z = pb->parameters.sample_count; // Receive Event Stride
root_node->next = root_node->prev = root_node;
for EachIndex(pb->pipeline.shader_count, it) {
@@ -721,6 +721,7 @@ stream_append_shader_header(Stream *s, i32 reloadable_index, BeamformerShaderDes
"#define f32 float32_t\n"
"#define f16 float16_t\n"
"#define s32 int32_t\n"
+ "#define u64 uint64_t\n"
"#define u32 uint32_t\n"
"#define s16 int16_t\n"
"#define u16 uint16_t\n"
@@ -758,17 +759,18 @@ stream_append_shader_header(Stream *s, i32 reloadable_index, BeamformerShaderDes
}
if (sd) {
- if (sd->input_data_kind != BeamformerDataKind_Count) {
- stream_append_s8s(s, s8("#define InputDataType "),
- beamformer_data_kind_glsl_type[sd->input_data_kind], s8("\n"));
- stream_append_s8s(s, s8("#define InputDataKind DataKind_"),
- beamformer_data_kind_s8[sd->input_data_kind], s8("\n"));
- }
- if (sd->output_data_kind != BeamformerDataKind_Count) {
- stream_append_s8s(s, s8("#define OutputDataType "),
- beamformer_data_kind_glsl_type[sd->output_data_kind], s8("\n"));
- stream_append_s8s(s, s8("#define OutputDataKind DataKind_"),
- beamformer_data_kind_s8[sd->output_data_kind], s8("\n"));
+ BeamformerDataKind data_kinds[] = {sd->input_data_kind, sd->output_data_kind};
+ s8 line_prefixes[] = {s8_comp("Input"), s8_comp("Output")};
+ for EachElement(data_kinds, it) {
+ if (data_kinds[it] != BeamformerDataKind_Count) {
+ stream_append_s8s(s, s8("#define "), line_prefixes[it], s8("DataType "),
+ beamformer_data_kind_glsl_type[data_kinds[it]],
+ s8("\n#define "), line_prefixes[it], s8("DataKind DataKind_"),
+ beamformer_data_kind_s8[data_kinds[it]],
+ s8("\n#define "), line_prefixes[it], s8("DataKindByteSize "));
+ stream_append_u64(s, beamformer_data_kind_byte_size[data_kinds[it]]);
+ stream_append_byte(s, '\n');
+ }
}
stream_append_byte(s, '\n');
diff --git a/shaders/filter.glsl b/shaders/filter.glsl
@@ -63,12 +63,6 @@ SAMPLE_TYPE rotate_iq(SAMPLE_TYPE iq, uint index)
}
#endif
-SAMPLE_TYPE sample_rf(uint index)
-{
- SAMPLE_TYPE result = SAMPLE_TYPE(Input(input_data).x[index]);
- return result;
-}
-
shared SAMPLE_TYPE rf[DecimationRate * gl_WorkGroupSize.x + FilterLength - 1];
void main()
@@ -77,7 +71,6 @@ void main()
uint channel = gl_GlobalInvocationID.y;
uint transmit = gl_GlobalInvocationID.z;
- uint in_offset = InputChannelStride * channel + InputTransmitStride * transmit;
uint thread_index = gl_LocalInvocationIndex;
uint thread_count = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
/////////////////////////
@@ -85,7 +78,18 @@ void main()
{
bool offset_wraps = (DecimationRate * gl_WorkGroupID.x * gl_WorkGroupSize.x) < (FilterLength - 1);
- in_offset += DecimationRate * gl_WorkGroupID.x * gl_WorkGroupSize.x - (FilterLength - 1);
+ u32 in_offset = InputDataKindByteSize * (InputChannelStride * channel + InputTransmitStride * transmit);
+ // NOTE(rnp): when demodulating we want to load 2 elements at a time but the
+ // input strides were specified in terms of a single element. therefore we
+ // must divide this by two. by doing this here we can gracefully handle
+ // the case where there are an odd number of samples (this drops the last one).
+ if (Demodulate != 0)
+ in_offset /= 2;
+
+ // NOTE(rnp): broken out to avoid overflow from the subtraction
+ u64 input_address = input_data + in_offset;
+ input_address += InputDataKindByteSize * (DecimationRate * gl_WorkGroupID.x * gl_WorkGroupSize.x);
+ input_address -= InputDataKindByteSize * (FilterLength - 1);
uint total_samples = rf.length();
uint samples_per_thread = total_samples / thread_count;
@@ -95,15 +99,14 @@ void main()
const SAMPLE_TYPE scale = SAMPLE_TYPE(bool(ComplexFilter) ? 1 : sqrt(2.0f));
for (uint i = 0; i < samples_this_thread; i++) {
uint index = thread_count * i + thread_index;
- if (offset_wraps && index < FilterLength - 1) {
- rf[index] = SAMPLE_TYPE(0);
- } else {
+ SAMPLE_TYPE s = SAMPLE_TYPE(0);
+ if (!offset_wraps || index >= FilterLength - 1) {
+ s = SAMPLE_TYPE(Input(input_address).x[index]);
#if Demodulate
- rf[index] = scale * rotate_iq(sample_rf(in_offset + index) * SAMPLE_TYPE(1, -1), index);
- #else
- rf[index] = sample_rf(in_offset + index);
+ s = scale * rotate_iq(s * SAMPLE_TYPE(1, -1), index);
#endif
}
+ rf[index] = s;
}
}
barrier();
@@ -116,15 +119,16 @@ void main()
u32 out_offset = OutputChannelStride * channel +
OutputTransmitStride * transmit +
- OutputSampleStride * out_sample;
+ OutputSampleStride * out_sample +
+ output_element_offset;
#if BatchSampleCount
// NOTE(rnp): deinterleave
- output_data[output_element_offset + out_offset] = OutputDataType(result.x);
+ output_data[out_offset] = OutputDataType(result.x);
out_offset += BatchSampleCount;
- output_data[output_element_offset + out_offset] = OutputDataType(result.y);
+ output_data[out_offset] = OutputDataType(result.y);
#else
- output_data[output_element_offset + out_offset] = OutputDataType(result);
+ output_data[out_offset] = OutputDataType(result);
#endif
}
}