Commit: da92b5ec7915c6ad2fd1401bc40349e2706c9c2c
Parent: 9b71d439faf4ff9a5e4f5a92a036e0ffd0be29b9
Author: Randy Palamar
Date: Mon, 16 Mar 2026 13:44:55 -0600
core: use cooperative matrix extension for decoding
When the code for loading matrix elements is written correctly
cooperative matrix operations are able to get very close to
maximum device FLOP utilization. This is true for both NVIDIA and
AMD. On AMD it is the only way to access the WMMA (wave matrix
multiply accumulate) instructions from a non hip shader.
Currently this only implements the simplest version without
prestaging into shared memory. That will come soon. Even without
that this version provides a nice performance boost for the
standard case (128 x 128).
The handling around reshaping and data types in this commit is
kind of a mess and a lot of less common cases are likely broken.
This will be addressed next.
Diffstat:
10 files changed, 766 insertions(+), 283 deletions(-)
diff --git a/beamformer.meta b/beamformer.meta
@@ -1,4 +1,4 @@
-@Constant(16) ChannelChunkCount
+@Constant(16) ChunkChannelCount
@Constant(4) FilterSlots
@Constant(4096) MaxBacklogFrames
@Constant(256) MaxChannelCount
@@ -289,17 +289,18 @@
@Bake
{
[DataKind data_kind U32]
- [DilateOutput dilate_output U32]
- [UseSharedMemory use_shared_memory U32]
+ [UseSharedMemory use_shared_memory B32]
[DecodeMode decode_mode U32]
- [InputChannelStride input_channel_stride U32]
- [InputSampleStride input_sample_stride U32]
- [InputTransmitStride input_transmit_stride U32]
[OutputChannelStride output_channel_stride U32]
[OutputSampleStride output_sample_stride U32]
[OutputTransmitStride output_transmit_stride U32]
[ToProcess to_process U32]
[TransmitCount transmit_count U32]
+ [ChunkChannelCount chunk_channel_count U32]
+ [CooperativeMatrix cooperative_matrix B32]
+ [CooperativeMatrixM cooperative_matrix_m U32]
+ [CooperativeMatrixN cooperative_matrix_n U32]
+ [CooperativeMatrixK cooperative_matrix_k U32]
}
@PushConstants
@@ -335,6 +336,7 @@
[OutputSampleStride output_sample_stride U32]
[OutputTransmitStride output_transmit_stride U32]
[SampleCount sample_count U32]
+ [BatchSampleCount batch_sample_count U32]
[DemodulationFrequency demodulation_frequency F32]
[SamplingFrequency sampling_frequency F32]
@@ -371,7 +373,7 @@
[AcquisitionCount acquisition_count U32]
[AcquisitionKind acquisition_kind U32]
[ChannelCount channel_count U32]
- [ChannelChunkCount channel_chunk_count U32]
+ [ChunkChannelCount chunk_channel_count U32]
[InterpolationMode interpolation_mode U32]
[SampleCount sample_count U32]
[TransmitReceiveOrientation transmit_receive_orientation U32]
@@ -440,6 +442,35 @@
[output_size_z U32]
}
}
+
+ @Shader(reshape.glsl) Reshape
+ {
+ @Enumeration DataKind
+
+ @Bake
+ {
+ [InputDataKind input_data_kind U32]
+ [OutputDataKind output_data_kind U32]
+ [SizeX size_x U32]
+ [SizeY size_y U32]
+ [SizeZ size_z U32]
+ [InputStrideX input_stride_x U32]
+ [InputStrideY input_stride_y U32]
+ [InputStrideZ input_stride_z U32]
+ [OutputStrideX output_stride_x U32]
+ [OutputStrideY output_stride_y U32]
+ [OutputStrideZ output_stride_z U32]
+ [Interleave interleave B32]
+ [Deinterleave deinterleave B32]
+ }
+
+ @PushConstants
+ {
+ [output_buffer U64]
+ [left_input_buffer U64]
+ [right_input_buffer U64]
+ }
+ }
}
// NOTE: general compute shaders which do not need baking
diff --git a/beamformer_core.c b/beamformer_core.c
@@ -1,12 +1,10 @@
/* See LICENSE for license details. */
/* TODO(rnp):
+ * [ ]: bug? HERCULES might be broken, we may need to to chunk on transmits instead of channels
* [ ]: refactor: plan_compute should build its own "command graph" which tracks
* dependencies better. It is very important that unnecessary barriers are
* not placed between compute stages which requires knowledge of the entire
* graph.
- * [ ]: refactor: DecodeMode_None should use a different mapping and optional conversion shader
- * for rf only mode with no filter and demod/filter should gain the OutputFloats flag for iq
- * case and rf mode with filter; this can also be used instead of first pass uniform
* [ ]: refactor: replace UploadRF with just the scratch_rf_size variable,
* use below to spin wait in library
* [ ]: utilize umonitor/umwait (intel), monitorx/mwaitx (amd), and wfe/sev (aarch64)
@@ -248,6 +246,16 @@ dispatch_for_output(uv3 layout, iv3 points)
return result;
}
+function uv3
+decode_data_stride(b32 input, u32 samples, u32 channels, u32 acquisitions)
+{
+ uv3 result;
+ result.x = input ? channels * acquisitions : 1;
+ result.y = input ? acquisitions : samples * acquisitions;
+ result.z = input ? 1 : samples;
+ return result;
+}
+
function b32
compute_plan_push_shader(BeamformerComputePlan *p, BeamformerShaderKind shader, BeamformerShaderParameters *sp)
{
@@ -278,8 +286,32 @@ plan_compute_pipeline(BeamformerComputePlan *cp, BeamformerParameterBlock *pb)
if (demodulate) run_cuda_hilbert = 0;
- BeamformerDataKind data_kind = pb->pipeline.data_kind;
- cp->iq_pipeline = beamformer_data_kind_complex[data_kind] || demodulate || run_cuda_hilbert;
+ BeamformerDataKind input_data_kind = pb->pipeline.data_kind;
+ cp->iq_pipeline = beamformer_data_kind_complex[input_data_kind] || demodulate || run_cuda_hilbert;
+
+ BeamformerDataKind das_data_kind = cp->iq_pipeline ? BeamformerDataKind_Float32Complex
+ : BeamformerDataKind_Float32;
+
+ read_only local_persist BeamformerDataKind input_to_intermediate_data_kind[] = {
+ [BeamformerDataKind_Int16] = BeamformerDataKind_Float16,
+ [BeamformerDataKind_Float16] = BeamformerDataKind_Float16,
+ [BeamformerDataKind_Float32] = BeamformerDataKind_Float32,
+ [BeamformerDataKind_Int16Complex] = BeamformerDataKind_Float16,
+ [BeamformerDataKind_Float16Complex] = BeamformerDataKind_Float16,
+ [BeamformerDataKind_Float32Complex] = BeamformerDataKind_Float32,
+ };
+ read_only local_persist b8 input_needs_deinterleave[] = {
+ [BeamformerDataKind_Int16] = 0,
+ [BeamformerDataKind_Float16] = 0,
+ [BeamformerDataKind_Float32] = 0,
+ [BeamformerDataKind_Int16Complex] = 1,
+ [BeamformerDataKind_Float16Complex] = 1,
+ [BeamformerDataKind_Float32Complex] = 1,
+ };
+ BeamformerDataKind intermediate_data_kind = input_to_intermediate_data_kind[input_data_kind];
+
+ cp->raw_channel_byte_stride = pb->parameters.sample_count * pb->parameters.acquisition_count
+ * beamformer_data_kind_byte_size[input_data_kind];
f32 sampling_frequency = pb->parameters.sampling_frequency;
u32 decimation_rate = Max(pb->parameters.decimation_rate, 1);
@@ -289,13 +321,11 @@ plan_compute_pipeline(BeamformerComputePlan *cp, BeamformerParameterBlock *pb)
sampling_frequency /= 2 * (f32)decimation_rate;
}
- cp->raw_channel_byte_stride = pb->parameters.sample_count * pb->parameters.acquisition_count * beamformer_data_kind_byte_size[data_kind];
-
cp->channel_count = pb->parameters.channel_count;
- u32 channel_chunk_count = Min(cp->channel_count, BeamformerChannelChunkCount);
+ u32 chunk_channel_count = Min(cp->channel_count, BeamformerChunkChannelCount);
- cp->rf_size = sample_count * pb->parameters.acquisition_count * channel_chunk_count;
+ cp->rf_size = sample_count * pb->parameters.acquisition_count * chunk_channel_count;
if (cp->iq_pipeline) cp->rf_size *= 8;
else cp->rf_size *= 4;
@@ -331,46 +361,79 @@ plan_compute_pipeline(BeamformerComputePlan *cp, BeamformerParameterBlock *pb)
BeamformerShaderKind *last_shader = cp->pipeline.shaders + slot - 1;
assert(first || ((*last_shader == BeamformerShaderKind_Demodulate ||
*last_shader == BeamformerShaderKind_Filter)));
+ b32 decode = pb->parameters.decode_mode != BeamformerDecodeMode_None;
+ if (first && compute_plan_push_shader(cp, BeamformerShaderKind_Reshape, sp)) {
+ sd = cp->shader_descriptors + cp->pipeline.shader_count - 1;
+
+ sd->layout = (uv3){{subgroup_size, 1, 1}};
+
+ sd->dispatch.x = (u32)(ceil_f32((f32)pb->parameters.sample_count / sd->layout.x));
+ sd->dispatch.y = chunk_channel_count;
+ sd->dispatch.z = pb->parameters.acquisition_count;
+
+ uv3 output_stride = decode_data_stride(decode, pb->parameters.sample_count,
+ chunk_channel_count, pb->parameters.acquisition_count);
+
+ BeamformerReshapeBakeParameters *rb = &sd->bake.Reshape;
+ rb->input_data_kind = input_data_kind;
+ rb->output_data_kind = decode ? intermediate_data_kind : BeamformerDataKind_Float32;
+ rb->size_x = pb->parameters.sample_count;
+ rb->size_y = chunk_channel_count;
+ rb->size_z = pb->parameters.acquisition_count;
+ rb->input_stride_x = 1;
+ rb->input_stride_y = pb->parameters.acquisition_count * pb->parameters.sample_count;
+ rb->input_stride_z = pb->parameters.sample_count;
+ rb->output_stride_x = output_stride.x;
+ rb->output_stride_y = output_stride.y;
+ rb->output_stride_z = output_stride.z;
+ rb->interleave = 0;
+ rb->deinterleave = decode ? input_needs_deinterleave[input_data_kind] : 0;
+ }
- if ((first || pb->parameters.decode_mode != BeamformerDecodeMode_None) &&
- compute_plan_push_shader(cp, shader, sp))
- {
- BeamformerDecodeBakeParameters *db = &sd->bake.Decode;
+ if (decode && compute_plan_push_shader(cp, shader, sp)) {
+ sd = cp->shader_descriptors + cp->pipeline.shader_count - 1;
- db->data_kind = data_kind;
- if (!first) {
- if (data_kind == BeamformerDataKind_Int16) {
- db->data_kind = BeamformerDataKind_Float16Complex;
- } else {
- db->data_kind = BeamformerDataKind_Float32Complex;
- }
- }
-
- db->decode_mode = pb->parameters.decode_mode;
- db->transmit_count = pb->parameters.acquisition_count;
+ BeamformerDecodeBakeParameters *db = &sd->bake.Decode;
+ db->data_kind = intermediate_data_kind;
- u32 channel_stride = pb->parameters.acquisition_count * pb->parameters.sample_count;
- db->input_sample_stride = first? 1 : ld->bake.Filter.output_sample_stride;
- db->input_channel_stride = first? channel_stride : ld->bake.Filter.output_channel_stride;
- db->input_transmit_stride = first? pb->parameters.sample_count : 1;
+ u32 decode_sample_count = demodulate ? 2 * sample_count : sample_count;
+ db->decode_mode = pb->parameters.decode_mode;
+ db->transmit_count = pb->parameters.acquisition_count;
+ db->chunk_channel_count = chunk_channel_count;
+ // NOTE(rnp): ignored when using coop matrices
db->output_sample_stride = das_sample_stride;
db->output_channel_stride = das_channel_stride;
db->output_transmit_stride = das_transmit_stride;
- if (first) {
- db->output_channel_stride *= decimation_rate;
- db->output_transmit_stride *= decimation_rate;
- }
- db->dilate_output = run_cuda_hilbert;
- db->to_process = 1;
+ db->to_process = 1;
+
+ b32 use_coop_matrix = vk_gpu_info()->cooperative_matrix &&
+ db->data_kind == BeamformerDataKind_Float16 &&
+ (db->transmit_count % 16 == 0) &&
+ (chunk_channel_count % 16 == 0);
+ b32 extra_reshape = 0;
if (db->decode_mode == BeamformerDecodeMode_None) {
sd->layout = (uv3){{subgroup_size, 1, 1}};
- sd->dispatch.x = (u32)ceil_f32((f32)sample_count / (f32)sd->layout.x);
- sd->dispatch.y = (u32)ceil_f32((f32)channel_chunk_count / (f32)sd->layout.y);
+ sd->dispatch.x = (u32)ceil_f32((f32)decode_sample_count / (f32)sd->layout.x);
+ sd->dispatch.y = (u32)ceil_f32((f32)chunk_channel_count / (f32)sd->layout.y);
sd->dispatch.z = (u32)ceil_f32((f32)pb->parameters.acquisition_count / (f32)sd->layout.z);
+ } else if (use_coop_matrix) {
+ extra_reshape = 1;
+ // TODO(rnp): shared memory for larger sizes
+
+ sd->layout = (uv3){{subgroup_size, 1, 1}};
+
+ db->cooperative_matrix = 1;
+ db->cooperative_matrix_m = 16;
+ db->cooperative_matrix_n = 16;
+ db->cooperative_matrix_k = 16;
+
+ sd->dispatch.x = db->transmit_count / db->cooperative_matrix_n;
+ sd->dispatch.y = chunk_channel_count / db->cooperative_matrix_m;
+ sd->dispatch.z = decode_sample_count;
} else if (db->transmit_count > 40) {
db->use_shared_memory = 1;
@@ -381,24 +444,46 @@ plan_compute_pipeline(BeamformerComputePlan *cp, BeamformerParameterBlock *pb)
db->transmit_count == 96 || db->transmit_count == 160;
sd->layout = (uv3){{4, 1, use_16z? 16 : 32}};
- sd->dispatch.x = (u32)ceil_f32((f32)sample_count / (f32)sd->layout.x);
- sd->dispatch.y = (u32)ceil_f32((f32)channel_chunk_count / (f32)sd->layout.y);
+ sd->dispatch.x = (u32)ceil_f32((f32)decode_sample_count / (f32)sd->layout.x);
+ sd->dispatch.y = (u32)ceil_f32((f32)chunk_channel_count / (f32)sd->layout.y);
sd->dispatch.z = (u32)ceil_f32((f32)pb->parameters.acquisition_count / (f32)sd->layout.z / (f32)db->to_process);
} else {
/* NOTE(rnp): register caching. using more threads will cause the compiler to do
* contortions to avoid spilling registers. using less gives higher performance */
sd->layout = (uv3){{subgroup_size / 2, 1, 1}};
- sd->dispatch.x = (u32)ceil_f32((f32)sample_count / (f32)sd->layout.x);
- sd->dispatch.y = (u32)ceil_f32((f32)channel_chunk_count / (f32)sd->layout.y);
+ sd->dispatch.x = (u32)ceil_f32((f32)decode_sample_count / (f32)sd->layout.x);
+ sd->dispatch.y = (u32)ceil_f32((f32)chunk_channel_count / (f32)sd->layout.y);
sd->dispatch.z = 1;
}
- if (first) sd->dispatch.x *= decimation_rate;
-
- /* NOTE(rnp): decode 2 samples per dispatch when data is i16 */
- if (first && data_kind == BeamformerDataKind_Int16)
- sd->dispatch.x = (u32)ceil_f32((f32)sd->dispatch.x / 2);
+ if (extra_reshape && compute_plan_push_shader(cp, BeamformerShaderKind_Reshape, sp)) {
+ cp->q_rf_data_offset = chunk_channel_count * sample_count * pb->parameters.acquisition_count *
+ beamformer_data_kind_byte_size[BeamformerDataKind_Float32];
+
+ sd = cp->shader_descriptors + cp->pipeline.shader_count - 1;
+ sd->layout.x = Min(subgroup_size, db->transmit_count);
+ sd->layout.y = subgroup_size / sd->layout.x;
+ sd->layout.z = 1;
+
+ sd->dispatch.x = (u32)(ceil_f32((f32)db->transmit_count / sd->layout.x));
+ sd->dispatch.y = (u32)(ceil_f32((f32)chunk_channel_count / sd->layout.y));
+ sd->dispatch.z = sample_count;
+
+ BeamformerReshapeBakeParameters *rb = &sd->bake.Reshape;
+ rb->input_data_kind = BeamformerDataKind_Float32;
+ rb->output_data_kind = BeamformerDataKind_Float32;
+ rb->size_x = db->transmit_count;
+ rb->size_y = chunk_channel_count;
+ rb->size_z = sample_count;
+ rb->input_stride_x = 1;
+ rb->input_stride_y = db->transmit_count;
+ rb->input_stride_z = chunk_channel_count * db->transmit_count;
+ rb->output_stride_x = das_transmit_stride;
+ rb->output_stride_y = das_channel_stride;
+ rb->output_stride_z = das_sample_stride;
+ rb->interleave = cp->iq_pipeline;
+ }
}
}break;
@@ -417,8 +502,12 @@ plan_compute_pipeline(BeamformerComputePlan *cp, BeamformerParameterBlock *pb)
fb->demodulate = demod;
fb->complex_filter = f->parameters.complex;
- fb->data_kind = data_kind;
- if (!first) fb->data_kind = BeamformerDataKind_Float32;
+ // NOTE(rnp): if we are decoding we need to deinterleave I and Q channels
+ if (pb->parameters.decode_mode != BeamformerDecodeMode_None)
+ fb->batch_sample_count = chunk_channel_count * sample_count * pb->parameters.acquisition_count;
+
+ fb->data_kind = input_data_kind;
+ if (!first) fb->data_kind = intermediate_data_kind;
/* NOTE(rnp): when we are demodulating we pretend that the sampler was alternating
* between sampling the I portion and the Q portion of an IQ signal. Therefore there
@@ -448,8 +537,8 @@ plan_compute_pipeline(BeamformerComputePlan *cp, BeamformerParameterBlock *pb)
fb->output_floats = 1;
} else {
/* NOTE(rnp): output optimized layout for decoding */
- fb->output_channel_stride = das_channel_stride;
- fb->output_sample_stride = pb->parameters.acquisition_count;
+ fb->output_channel_stride = pb->parameters.acquisition_count;
+ fb->output_sample_stride = pb->parameters.acquisition_count * chunk_channel_count;
fb->output_transmit_stride = 1;
}
} else {
@@ -471,7 +560,7 @@ plan_compute_pipeline(BeamformerComputePlan *cp, BeamformerParameterBlock *pb)
sd->layout = (uv3){{subgroup_size, 1, 1}};
sd->dispatch.x = (u32)ceil_f32((f32)sample_count / (f32)sd->layout.x);
- sd->dispatch.y = (u32)ceil_f32((f32)channel_chunk_count / (f32)sd->layout.y);
+ sd->dispatch.y = (u32)ceil_f32((f32)chunk_channel_count / (f32)sd->layout.y);
sd->dispatch.z = (u32)ceil_f32((f32)pb->parameters.acquisition_count / (f32)sd->layout.z);
}
}break;
@@ -481,12 +570,7 @@ plan_compute_pipeline(BeamformerComputePlan *cp, BeamformerParameterBlock *pb)
cp->first_image_shader_index = cp->pipeline.shader_count;
BeamformerDASBakeParameters *db = &sd->bake.DAS;
- db->data_kind = BeamformerDataKind_Float32;
- if (cp->iq_pipeline) db->data_kind = BeamformerDataKind_Float32Complex;
-
- cp->voxel_transform = m4_mul(cp->ui_voxel_transform, pb->parameters.das_voxel_transform);
- cp->xdc_element_pitch = pb->parameters.xdc_element_pitch;
-
+ db->data_kind = das_data_kind;
db->sampling_frequency = sampling_frequency;
db->demodulation_frequency = pb->parameters.demodulation_frequency;
db->speed_of_sound = pb->parameters.speed_of_sound;
@@ -496,15 +580,18 @@ plan_compute_pipeline(BeamformerComputePlan *cp, BeamformerParameterBlock *pb)
db->sample_count = sample_count;
db->channel_count = pb->parameters.channel_count;
db->acquisition_count = pb->parameters.acquisition_count;
+ db->chunk_channel_count = chunk_channel_count;
db->interpolation_mode = pb->parameters.interpolation_mode;
db->transmit_angle = pb->parameters.focal_vector.E[0];
db->focus_depth = pb->parameters.focal_vector.E[1];
db->transmit_receive_orientation = pb->parameters.transmit_receive_orientation;
- db->channel_chunk_count = channel_chunk_count;
// NOTE(rnp): old gcc will miscompile an assignment
mem_copy(cp->xdc_transform.E, pb->parameters.xdc_transform.E, sizeof(cp->xdc_transform));
+ cp->voxel_transform = m4_mul(cp->ui_voxel_transform, pb->parameters.das_voxel_transform);
+ cp->xdc_element_pitch = pb->parameters.xdc_element_pitch;
+
u32 id = pb->parameters.acquisition_kind;
if (id == BeamformerAcquisitionKind_UFORCES || id == BeamformerAcquisitionKind_FORCES)
cp->voxel_transform = m4_mul(cp->xdc_transform, cp->voxel_transform);
@@ -544,7 +631,7 @@ plan_compute_pipeline(BeamformerComputePlan *cp, BeamformerParameterBlock *pb)
default:{}break;
}
}
- cp->pipeline.data_kind = data_kind;
+ cp->pipeline.data_kind = input_data_kind;
if (cp->first_image_shader_index == 0)
cp->first_image_shader_index = cp->pipeline.shader_count;
@@ -735,7 +822,7 @@ beamformer_commit_parameter_block(BeamformerCtx *ctx, BeamformerComputePlan *cp,
}
if (cp->hadamard_order != (i32)cp->acquisition_count)
- update_hadamard(cp, (i32)cp->acquisition_count, 0, arena);
+ update_hadamard(cp, (i32)cp->acquisition_count, vk_gpu_info()->cooperative_matrix, arena);
}break;
case BeamformerParameterBlockRegion_ChannelMapping:{
@@ -831,7 +918,7 @@ do_compute_shader(BeamformerCtx *ctx, VulkanHandle cmd, BeamformerComputePlan *c
// NOTE(rnp): first pass or last stage output
{
.gpu_buffer = &cc->ping_pong_buffer,
- .offset = pp_output_pointer - cc->ping_pong_buffer.gpu_pointer,
+ .offset = pp_input_pointer - cc->ping_pong_buffer.gpu_pointer,
.size = pp_size,
},
// NOTE(rnp): output for DAS
@@ -879,14 +966,14 @@ do_compute_shader(BeamformerCtx *ctx, VulkanHandle cmd, BeamformerComputePlan *c
BeamformerFilterPushConstants pc = {
.filter_coefficients = cp->filters[filter_slot].buffer.gpu_pointer,
.input_data = shader_slot == 0 ? rf_pointer : pp_input_pointer,
- .output_element_offset = output_index * pp_size / element_size,
+ .output_element_offset = 2 * output_index * pp_size / element_size,
};
GPUMemoryBarrierInfo memory_barriers[] = {
// NOTE(rnp): last stage output
{
.gpu_buffer = &cc->ping_pong_buffer,
- .offset = pp_output_pointer - cc->ping_pong_buffer.gpu_pointer,
+ .offset = pp_input_pointer - cc->ping_pong_buffer.gpu_pointer,
.size = pp_size,
},
// NOTE(rnp): output for DAS
@@ -1003,6 +1090,41 @@ do_compute_shader(BeamformerCtx *ctx, VulkanHandle cmd, BeamformerComputePlan *c
vk_command_dispatch_compute(cmd, dispatch);
}break;
+ case BeamformerShaderKind_Reshape:{
+ BeamformerReshapePushConstants pc = {
+ .left_input_buffer = pp_input_pointer,
+ .right_input_buffer = pp_input_pointer + cp->q_rf_data_offset,
+ };
+
+ if ((shader_slot + 1) == das_index) pc.output_buffer = pp_das_pointer;
+ else pc.output_buffer = pp_output_pointer;
+
+ GPUMemoryBarrierInfo memory_barriers[]= {
+ // NOTE(rnp): first pass or last stage output
+ {
+ .gpu_buffer = &cc->ping_pong_buffer,
+ .offset = pp_input_pointer - cc->ping_pong_buffer.gpu_pointer,
+ .size = pp_size,
+ },
+ // NOTE(rnp): output for DAS
+ {
+ .gpu_buffer = &cc->ping_pong_buffer,
+ .offset = pp_das_pointer - cc->ping_pong_buffer.gpu_pointer,
+ .size = pp_size,
+ },
+ };
+
+ u32 barrier_count = 1;
+ if (shader_slot + 1 == das_index)
+ barrier_count++;
+
+ vk_command_buffer_memory_barriers(cmd, memory_barriers, barrier_count);
+ vk_command_push_constants(cmd, 0, sizeof(pc), &pc);
+ vk_command_dispatch_compute(cmd, dispatch);
+
+ cc->ping_pong_input_index = !cc->ping_pong_input_index;
+ }break;
+
// NOTE(rnp): invalid stages should be filtered in planning phase
InvalidDefaultCase;
}
@@ -1216,7 +1338,7 @@ complete_queue(BeamformerCtx *ctx, BeamformWorkQueue *q, Arena *arena)
for (u32 channel_offset = 0;
channel_offset < cp->channel_count;
- channel_offset += BeamformerChannelChunkCount)
+ channel_offset += BeamformerChunkChannelCount)
{
u64 rf_pointer = rf->buffer.gpu_pointer + slot * rf->active_rf_size;
rf_pointer += cp->raw_channel_byte_stride * channel_offset;
@@ -1244,7 +1366,7 @@ complete_queue(BeamformerCtx *ctx, BeamformWorkQueue *q, Arena *arena)
/* NOTE(rnp): this blocks until work completes */
u64 *timestamps = vk_command_read_timestamps(VulkanTimeline_Compute, &scratch);
- i32 steps = ((i32)cp->channel_count / BeamformerChannelChunkCount) - 1;
+ i32 steps = ((i32)cp->channel_count / BeamformerChunkChannelCount) - 1;
i32 step = 0;
u32 shader_index = 0;
u64 last_time = timestamps[0] > 0 ? timestamps[1] : 0;
diff --git a/beamformer_internal.h b/beamformer_internal.h
@@ -94,6 +94,8 @@ typedef struct {
u16 max_msaa_samples;
u16 subgroup_size;
+ b32 cooperative_matrix;
+
u32 max_image_dimension_2D;
// NOTE(rnp): vulkan compute will output to a buffer so this won't be relevant
u32 max_image_dimension_3D;
@@ -302,6 +304,7 @@ struct BeamformerComputePlan {
u32 rf_size;
i32 hadamard_order;
b32 iq_pipeline;
+ u32 q_rf_data_offset;
m4 voxel_transform;
m4 ui_voxel_transform;
diff --git a/generated/beamformer.meta.c b/generated/beamformer.meta.c
@@ -3,7 +3,7 @@
// GENERATED CODE
// NOTE: Constants (Integer)
-#define BeamformerChannelChunkCount (16)
+#define BeamformerChunkChannelCount (16)
#define BeamformerFilterSlots (4)
#define BeamformerMaxBacklogFrames (4096)
#define BeamformerMaxChannelCount (256)
@@ -105,16 +105,17 @@ typedef enum {
BeamformerShaderKind_Sum = 6,
BeamformerShaderKind_MinMax = 7,
BeamformerShaderKind_CoherencyWeighting = 8,
- BeamformerShaderKind_BufferClear = 9,
- BeamformerShaderKind_RenderBeamformed = 10,
+ BeamformerShaderKind_Reshape = 9,
+ BeamformerShaderKind_BufferClear = 10,
+ BeamformerShaderKind_RenderBeamformed = 11,
BeamformerShaderKind_Count,
BeamformerShaderKind_ComputeFirst = BeamformerShaderKind_CudaDecode,
BeamformerShaderKind_ComputeLast = BeamformerShaderKind_MinMax,
BeamformerShaderKind_ComputeCount = 8,
BeamformerShaderKind_ComputeHelpersFirst = BeamformerShaderKind_CoherencyWeighting,
- BeamformerShaderKind_ComputeHelpersLast = BeamformerShaderKind_CoherencyWeighting,
- BeamformerShaderKind_ComputeHelpersCount = 1,
+ BeamformerShaderKind_ComputeHelpersLast = BeamformerShaderKind_Reshape,
+ BeamformerShaderKind_ComputeHelpersCount = 2,
BeamformerShaderKind_ComputeInternalFirst = BeamformerShaderKind_BufferClear,
BeamformerShaderKind_ComputeInternalLast = BeamformerShaderKind_BufferClear,
BeamformerShaderKind_ComputeInternalCount = 1,
@@ -125,17 +126,18 @@ typedef enum {
typedef struct {
u32 data_kind;
- u32 dilate_output;
- u32 use_shared_memory;
+ b32 use_shared_memory;
u32 decode_mode;
- u32 input_channel_stride;
- u32 input_sample_stride;
- u32 input_transmit_stride;
u32 output_channel_stride;
u32 output_sample_stride;
u32 output_transmit_stride;
u32 to_process;
u32 transmit_count;
+ u32 chunk_channel_count;
+ b32 cooperative_matrix;
+ u32 cooperative_matrix_m;
+ u32 cooperative_matrix_n;
+ u32 cooperative_matrix_k;
} BeamformerDecodeBakeParameters;
typedef struct {
@@ -152,6 +154,7 @@ typedef struct {
u32 output_sample_stride;
u32 output_transmit_stride;
u32 sample_count;
+ u32 batch_sample_count;
f32 demodulation_frequency;
f32 sampling_frequency;
} BeamformerFilterBakeParameters;
@@ -165,7 +168,7 @@ typedef struct {
u32 acquisition_count;
u32 acquisition_kind;
u32 channel_count;
- u32 channel_chunk_count;
+ u32 chunk_channel_count;
u32 interpolation_mode;
u32 sample_count;
u32 transmit_receive_orientation;
@@ -183,6 +186,22 @@ typedef struct {
} BeamformerCoherencyWeightingBakeParameters;
typedef struct {
+ u32 input_data_kind;
+ u32 output_data_kind;
+ u32 size_x;
+ u32 size_y;
+ u32 size_z;
+ u32 input_stride_x;
+ u32 input_stride_y;
+ u32 input_stride_z;
+ u32 output_stride_x;
+ u32 output_stride_y;
+ u32 output_stride_z;
+ b32 interleave;
+ b32 deinterleave;
+} BeamformerReshapeBakeParameters;
+
+typedef struct {
u64 hadamard_buffer;
u64 rf_buffer;
u64 output_buffer;
@@ -228,6 +247,12 @@ typedef struct {
} BeamformerCoherencyWeightingPushConstants;
typedef struct {
+ u64 output_buffer;
+ u64 left_input_buffer;
+ u64 right_input_buffer;
+} BeamformerReshapePushConstants;
+
+typedef struct {
uv4 clear_v4;
u64 data;
u32 bins;
@@ -400,6 +425,7 @@ typedef union {
BeamformerFilterBakeParameters Filter;
BeamformerDASBakeParameters DAS;
BeamformerCoherencyWeightingBakeParameters CoherencyWeighting;
+ BeamformerReshapeBakeParameters Reshape;
} BeamformerShaderBakeParameters;
read_only global u8 beamformer_data_kind_element_size[] = {
@@ -503,6 +529,7 @@ read_only global s8 beamformer_shader_names[] = {
s8_comp("Sum"),
s8_comp("MinMax"),
s8_comp("CoherencyWeighting"),
+ s8_comp("Reshape"),
s8_comp("BufferClear"),
s8_comp("RenderBeamformed"),
};
@@ -514,6 +541,7 @@ read_only global BeamformerShaderKind beamformer_reloadable_shader_kinds[] = {
BeamformerShaderKind_Sum,
BeamformerShaderKind_MinMax,
BeamformerShaderKind_CoherencyWeighting,
+ BeamformerShaderKind_Reshape,
BeamformerShaderKind_BufferClear,
BeamformerShaderKind_RenderBeamformed,
};
@@ -525,6 +553,7 @@ read_only global s8 *beamformer_reloadable_shader_files[] = {
(s8 []){s8_comp("sum.glsl")},
(s8 []){s8_comp("min_max.glsl")},
(s8 []){s8_comp("coherency_weighting.glsl")},
+ (s8 []){s8_comp("reshape.glsl")},
(s8 []){s8_comp("buffer_clear.glsl")},
(s8 []){s8_comp("render_3d.vert.glsl"), s8_comp("render_3d.frag.glsl")},
};
@@ -541,6 +570,7 @@ read_only global i32 beamformer_shader_reloadable_index_by_shader[] = {
5,
6,
7,
+ 8,
};
read_only global i32 beamformer_reloadable_compute_shader_info_indices[] = {
@@ -553,14 +583,15 @@ read_only global i32 beamformer_reloadable_compute_shader_info_indices[] = {
read_only global i32 beamformer_reloadable_compute_helpers_shader_info_indices[] = {
5,
+ 6,
};
read_only global i32 beamformer_reloadable_compute_internal_shader_info_indices[] = {
- 6,
+ 7,
};
read_only global i32 beamformer_reloadable_render_shader_info_indices[] = {
- 7,
+ 8,
};
read_only global s8 beamformer_shader_global_header_strings[] = {
@@ -667,6 +698,13 @@ read_only global s8 beamformer_shader_global_header_strings[] = {
"\n"),
s8_comp(""
"layout(push_constant, std430) uniform PushConstants {\n"
+ " uint64_t output_buffer;\n"
+ " uint64_t left_input_buffer;\n"
+ " uint64_t right_input_buffer;\n"
+ "};\n"
+ "\n"),
+ s8_comp(""
+ "layout(push_constant, std430) uniform PushConstants {\n"
" u32vec4 clear_v4;\n"
" uint64_t data;\n"
" uint32_t bins;\n"
@@ -699,6 +737,7 @@ read_only global b8 beamformer_shader_has_primitive[] = {
0,
0,
0,
+ 0,
1,
};
@@ -710,6 +749,7 @@ read_only global b8 beamformer_shader_primitive_is_vertex[] = {
0,
0,
0,
+ 0,
1,
};
@@ -720,8 +760,9 @@ read_only global i32 *beamformer_shader_header_vectors[] = {
(i32 []){0, 12},
0,
(i32 []){0, 13},
- (i32 []){14},
- (i32 []){0, 15},
+ (i32 []){0, 14},
+ (i32 []){15},
+ (i32 []){0, 16},
};
read_only global i32 beamformer_shader_header_vector_lengths[] = {
@@ -731,6 +772,7 @@ read_only global i32 beamformer_shader_header_vector_lengths[] = {
2,
0,
2,
+ 2,
1,
2,
};
@@ -738,17 +780,18 @@ read_only global i32 beamformer_shader_header_vector_lengths[] = {
read_only global s8 *beamformer_shader_bake_parameter_names[] = {
(s8 []){
s8_comp("DataKind"),
- s8_comp("DilateOutput"),
s8_comp("UseSharedMemory"),
s8_comp("DecodeMode"),
- s8_comp("InputChannelStride"),
- s8_comp("InputSampleStride"),
- s8_comp("InputTransmitStride"),
s8_comp("OutputChannelStride"),
s8_comp("OutputSampleStride"),
s8_comp("OutputTransmitStride"),
s8_comp("ToProcess"),
s8_comp("TransmitCount"),
+ s8_comp("ChunkChannelCount"),
+ s8_comp("CooperativeMatrix"),
+ s8_comp("CooperativeMatrixM"),
+ s8_comp("CooperativeMatrixN"),
+ s8_comp("CooperativeMatrixK"),
},
(s8 []){
s8_comp("DataKind"),
@@ -764,6 +807,7 @@ read_only global s8 *beamformer_shader_bake_parameter_names[] = {
s8_comp("OutputSampleStride"),
s8_comp("OutputTransmitStride"),
s8_comp("SampleCount"),
+ s8_comp("BatchSampleCount"),
s8_comp("DemodulationFrequency"),
s8_comp("SamplingFrequency"),
},
@@ -776,7 +820,7 @@ read_only global s8 *beamformer_shader_bake_parameter_names[] = {
s8_comp("AcquisitionCount"),
s8_comp("AcquisitionKind"),
s8_comp("ChannelCount"),
- s8_comp("ChannelChunkCount"),
+ s8_comp("ChunkChannelCount"),
s8_comp("InterpolationMode"),
s8_comp("SampleCount"),
s8_comp("TransmitReceiveOrientation"),
@@ -793,28 +837,45 @@ read_only global s8 *beamformer_shader_bake_parameter_names[] = {
(s8 []){
s8_comp("DataKind"),
},
+ (s8 []){
+ s8_comp("InputDataKind"),
+ s8_comp("OutputDataKind"),
+ s8_comp("SizeX"),
+ s8_comp("SizeY"),
+ s8_comp("SizeZ"),
+ s8_comp("InputStrideX"),
+ s8_comp("InputStrideY"),
+ s8_comp("InputStrideZ"),
+ s8_comp("OutputStrideX"),
+ s8_comp("OutputStrideY"),
+ s8_comp("OutputStrideZ"),
+ s8_comp("Interleave"),
+ s8_comp("Deinterleave"),
+ },
0,
0,
};
read_only global u32 beamformer_shader_bake_parameter_float_bits[] = {
0x00000000UL,
- 0x00006000UL,
+ 0x0000c000UL,
0x0007f000UL,
0x00000000UL,
0x00000000UL,
0x00000000UL,
0x00000000UL,
0x00000000UL,
+ 0x00000000UL,
};
read_only global u8 beamformer_shader_bake_parameter_counts[] = {
- 12,
- 15,
+ 13,
+ 16,
19,
0,
0,
1,
+ 13,
0,
0,
};
@@ -826,6 +887,7 @@ read_only global u8 beamformer_shader_push_constant_sizes[] = {
sizeof(BeamformerSumPushConstants),
0,
sizeof(BeamformerCoherencyWeightingPushConstants),
+ sizeof(BeamformerReshapePushConstants),
sizeof(BeamformerBufferClearPushConstants),
sizeof(BeamformerRenderBeamformedPushConstants),
};
diff --git a/shaders/das.glsl b/shaders/das.glsl
@@ -5,14 +5,14 @@
#define RESULT_COHERENT_CAST(a) (a).x
#define RESULT_INCOHERENT_CAST(a) (a).y
#endif
- #define SAMPLE_TYPE float
+ #define SAMPLE_TYPE f32
#elif DataKind == DataKind_Float32Complex
#if CoherencyWeighting
#define RESULT_TYPE vec3
#define RESULT_COHERENT_CAST(a) (a).xy
#define RESULT_INCOHERENT_CAST(a) (a).z
#endif
- #define SAMPLE_TYPE vec2
+ #define SAMPLE_TYPE f32vec2
#else
#error DataKind unsupported for DAS
#endif
@@ -47,7 +47,6 @@ layout(std430, buffer_reference) buffer IncoherentOutput {
f32 x[];
};
-
#define RX_ORIENTATION(tx_rx) bitfieldExtract((tx_rx), 0, 4)
#define TX_ORIENTATION(tx_rx) bitfieldExtract((tx_rx), 4, 4)
@@ -89,19 +88,20 @@ SAMPLE_TYPE cubic(const int offset, const float t)
SAMPLE_TYPE T1 = C_SPLINE * (P2 - samples[0]);
SAMPLE_TYPE T2 = C_SPLINE * (samples[3] - P1);
-#if DataKind == DataKind_Float32
+ #if DataKind == DataKind_Float32
vec4 C = vec4(P1.x, P2.x, T1.x, T2.x);
float result = dot(S, h * C);
-#elif DataKind == DataKind_Float32Complex
+ #elif DataKind == DataKind_Float32Complex
mat2x4 C = mat2x4(vec4(P1.x, P2.x, T1.x, T2.x), vec4(P1.y, P2.y, T1.y, T2.y));
vec2 result = S * h * C;
-#endif
+ #endif
return result;
}
SAMPLE_TYPE sample_rf(const int rf_offset, const float index)
{
SAMPLE_TYPE result = SAMPLE_TYPE(0);
+
switch (InterpolationMode) {
case InterpolationMode_Nearest:{
if (int(index) >= 0 && int(round(index)) < SampleCount)
@@ -215,7 +215,7 @@ RESULT_TYPE RCA(const vec3 world_point)
int rf_offset = int(rf_element_offset) + acquisition * SampleCount;
rf_offset -= int(InterpolationMode == InterpolationMode_Cubic);
- for (int chunk_channel = 0; chunk_channel < ChannelChunkCount; chunk_channel++) {
+ for (int chunk_channel = 0; chunk_channel < ChunkChannelCount; chunk_channel++) {
int rx_channel = channel_offset + chunk_channel;
vec3 rx_center = vec3(rx_channel * xdc_element_pitch, 0);
vec2 receive_vector = xdc_world_point - rca_plane_projection(rx_center, rx_rows);
@@ -246,8 +246,8 @@ RESULT_TYPE HERCULES(const vec3 world_point)
const float apodization_test = 0.25f / (f_number_over_z * f_number_over_z);
RESULT_TYPE result = RESULT_TYPE(0);
- for (float chunk_channel = 0; chunk_channel < float(ChannelChunkCount); chunk_channel += 1.0f) {
- float rx_channel = float(channel_offset) + chunk_channel;
+ for (f32 chunk_channel = 0; chunk_channel < f32(ChunkChannelCount); chunk_channel += 1.0f) {
+ f32 rx_channel = f32(channel_offset) + chunk_channel;
int rf_offset = int(rf_element_offset) + int(chunk_channel) * SampleCount * AcquisitionCount + Sparse * SampleCount;
rf_offset -= int(InterpolationMode == InterpolationMode_Cubic);
@@ -294,7 +294,7 @@ RESULT_TYPE FORCES(const vec3 xdc_world_point)
float transmit_y_delta = xdc_world_point.y - xdc_element_pitch.y * ChannelCount / 2;
float transmit_yz_squared = transmit_y_delta * transmit_y_delta + z_delta_squared;
- for (float chunk_channel = 0; chunk_channel < float(ChannelChunkCount); chunk_channel += 1.0f) {
+ for (f32 chunk_channel = 0; chunk_channel < f32(ChunkChannelCount); chunk_channel += 1.0f) {
float rx_channel = float(channel_offset) + chunk_channel;
float receive_x_delta = xdc_world_point.x - rx_channel * xdc_element_pitch.x;
float a_arg = abs(FNumber * receive_x_delta / xdc_world_point.z);
diff --git a/shaders/decode.glsl b/shaders/decode.glsl
@@ -1,36 +1,21 @@
/* See LICENSE for license details. */
-/* NOTE(rnp): invoked with samples x channels x transmits
- * Each instance extracts a single time sample from a single channel for all transmits
- * and does a dot product with the appropriate row of the bound hadamard matrix
- * (unless decode_mode == DECODE_MODE_NONE). The result of this dot product is stored in the
- * output. In bulk this has the effect of computing a matrix multiply of the
- * sample-transmit plane with the bound hadamard matrix.
- */
+#if CooperativeMatrix
+#extension GL_KHR_cooperative_matrix : require
+#extension GL_KHR_memory_scope_semantics : require
+#endif
#if DataKind == DataKind_Float32
- #define INPUT_DATA_TYPE float
- #define SAMPLE_DATA_TYPE float
-#elif DataKind == DataKind_Float32Complex
- #define INPUT_DATA_TYPE vec2
- #define SAMPLE_DATA_TYPE vec2
-#elif DataKind == DataKind_Float16Complex
- #define INPUT_DATA_TYPE f16vec2
- #define SAMPLE_DATA_TYPE vec2
+ #define INPUT_DATA_TYPE f32
#elif DataKind == DataKind_Float16
- #define INPUT_DATA_TYPE float16_t
- #define SAMPLE_DATA_TYPE float
-#elif DataKind == DataKind_Int16Complex
- #define INPUT_DATA_TYPE i16vec2
- #define SAMPLE_DATA_TYPE vec2
+ #define INPUT_DATA_TYPE f16
#elif DataKind == DataKind_Int16
- #define INPUT_DATA_TYPE int16_t
- #define SAMPLE_DATA_TYPE float
+ #define INPUT_DATA_TYPE s16
#else
#error unsupported data kind for Decode
#endif
-// TODO(rnp): fix DilateOutput
+#define SAMPLE_DATA_TYPE f32
layout(std430, buffer_reference, buffer_reference_align = 64) restrict readonly buffer RF {
INPUT_DATA_TYPE values[];
@@ -45,7 +30,7 @@ layout(std430, buffer_reference, buffer_reference_align = 64) restrict writeonly
};
layout(std430, buffer_reference, buffer_reference_align = 64) restrict readonly buffer Hadamard {
- float16_t values[];
+ f16 values[];
};
SAMPLE_DATA_TYPE sample_rf_data(uint index)
@@ -55,6 +40,7 @@ SAMPLE_DATA_TYPE sample_rf_data(uint index)
}
#if UseSharedMemory
+
shared INPUT_DATA_TYPE rf[gl_WorkGroupSize.x * TransmitCount];
void run_decode_large(void)
{
@@ -69,7 +55,7 @@ void run_decode_large(void)
uint leftover_samples = rf.length() % thread_count;
uint samples_this_thread = samples_per_thread + uint(thread_index < leftover_samples);
- uint rf_offset = InputChannelStride * channel + TransmitCount * gl_WorkGroupID.x * gl_WorkGroupSize.x;
+ u32 rf_offset = TransmitCount * ChunkChannelCount * gl_WorkGroupID.x * gl_WorkGroupSize.x + TransmitCount * channel;
for (uint i = 0; i < samples_this_thread; i++) {
uint index = i * thread_count + thread_index;
@@ -107,11 +93,56 @@ void run_decode_large(void)
}
#endif
+#if CooperativeMatrix
+
+void run_decode_coop(void)
+{
+ #if UseSharedMemory
+ #else
+
+ u32vec2 tile_index = gl_WorkGroupID.xy;
+ u32 time_sample = gl_WorkGroupID.z;
+
+ coopmat<f16, gl_ScopeSubgroup, CooperativeMatrixM, CooperativeMatrixK, gl_MatrixUseA> rf_matrix;
+ coopmat<f16, gl_ScopeSubgroup, CooperativeMatrixK, CooperativeMatrixN, gl_MatrixUseB> hadamard_matrix;
+ coopmat<f32, gl_ScopeSubgroup, CooperativeMatrixM, CooperativeMatrixN, gl_MatrixUseAccumulator> result;
+ result = coopmat<f32, gl_ScopeSubgroup, CooperativeMatrixM, CooperativeMatrixN, gl_MatrixUseAccumulator>(0.0f);
+
+ u32 result_row = CooperativeMatrixM * tile_index.y;
+ u32 result_col = CooperativeMatrixN * tile_index.x;
+
+ u32 offset = ChunkChannelCount * TransmitCount * time_sample;
+
+ for (u32 k = 0; k < TransmitCount; k += CooperativeMatrixK) {
+ u32 rf_tile_row = CooperativeMatrixM * tile_index.y;
+ u32 rf_tile_col = k;
+ coopMatLoad(rf_matrix, RF(rf_buffer).values, offset + TransmitCount * rf_tile_row + rf_tile_col,
+ TransmitCount, gl_CooperativeMatrixLayoutRowMajor);
+
+ u32 hadamard_tile_row = k;
+ u32 hadamard_tile_col = CooperativeMatrixN * tile_index.x;
+ coopMatLoad(hadamard_matrix, Hadamard(hadamard_buffer).values,
+ TransmitCount * hadamard_tile_row + hadamard_tile_col, TransmitCount,
+ gl_CooperativeMatrixLayoutRowMajor);
+
+ result = coopMatMulAdd(rf_matrix, hadamard_matrix, result);
+ }
+
+ for (s32 i = 0; i < result.length(); i++)
+ result[i] = result[i] / f32(TransmitCount);
+
+ Output out_buffer = Output(output_buffer);
+ coopMatStore(result, out_buffer.values, offset + TransmitCount * result_row + result_col,
+ TransmitCount, gl_CooperativeMatrixLayoutRowMajor);
+ #endif
+}
+#endif
+
void run_decode_small(void)
{
- uint time_sample = gl_GlobalInvocationID.x;
- uint channel = gl_GlobalInvocationID.y;
- uint rf_offset = InputChannelStride * channel + TransmitCount * time_sample;
+ u32 time_sample = gl_GlobalInvocationID.x;
+ u32 channel = gl_GlobalInvocationID.y;
+ u32 rf_offset = TransmitCount * ChunkChannelCount * time_sample + TransmitCount * channel;
if (time_sample < OutputTransmitStride) {
INPUT_DATA_TYPE rf[TransmitCount];
@@ -142,50 +173,14 @@ void run_decode_small(void)
void main()
{
switch (DecodeMode) {
- case DecodeMode_None:{
- uint time_sample = gl_GlobalInvocationID.x;
- uint channel = gl_GlobalInvocationID.y;
- uint transmit = gl_GlobalInvocationID.z;
-
- if (time_sample < OutputTransmitStride) {
- uint in_off = InputChannelStride * channel +
- InputTransmitStride * transmit +
- InputSampleStride * time_sample;
-
- uint out_off = OutputChannelStride * channel +
- OutputTransmitStride * transmit +
- OutputSampleStride * time_sample;
-
- Output(output_buffer).values[out_off] = sample_rf_data(in_off);
- }
- }break;
case DecodeMode_Hadamard:{
- if (first_pass) {
- uint time_sample = gl_GlobalInvocationID.x;
- uint channel = gl_GlobalInvocationID.y;
- uint transmit = gl_GlobalInvocationID.z * ToProcess;
- if (time_sample < InputTransmitStride) {
- uint out_off = InputChannelStride * channel + TransmitCount * time_sample;
- uint in_off = InputChannelStride * channel + InputSampleStride * time_sample;
- #if UseSharedMemory
- in_off += InputTransmitStride * transmit;
- out_off += transmit;
- for (uint i = 0; i < ToProcess; i++, in_off += InputTransmitStride) {
- if (transmit + i < TransmitCount)
- OutputRF(output_rf_buffer).values[out_off + i] = RF(rf_buffer).values[in_off];
- }
- #else
- for (uint i = 0; i < TransmitCount; i++, in_off += InputTransmitStride)
- OutputRF(output_rf_buffer).values[out_off + i] = RF(rf_buffer).values[in_off];
- #endif
- }
- } else {
- #if UseSharedMemory
- run_decode_large();
- #else
- run_decode_small();
- #endif
- }
+ #if CooperativeMatrix
+ run_decode_coop();
+ #elif UseSharedMemory
+ run_decode_large();
+ #else
+ run_decode_small();
+ #endif
}break;
}
}
diff --git a/shaders/filter.glsl b/shaders/filter.glsl
@@ -1,50 +1,98 @@
/* See LICENSE for license details. */
-/* TODO(rnp): bug: this won't filter RF data correctly */
-#define SAMPLE_TYPE f32vec2
-#if DataKind == DataKind_Float32
- #define DATA_TYPE f32vec2
- #define RESULT_TYPE_CAST(v) (v)
- #define SAMPLE_TYPE_CAST(v) (v)
-#else
- #define DATA_TYPE i16vec2
- #define SAMPLE_TYPE_CAST(v) (v)
+#if DataKind == DataKind_Float32Complex || (DataKind == DataKind_Float32 && Demodulate)
+ #define INPUT_TYPE f32vec2
+ #define SAMPLE_TYPE f32vec2
+ #if BatchSampleCount
+ #define OUTPUT_TYPE f32
+ #else
+ #define OUTPUT_TYPE f32vec2
+ #endif
+#elif DataKind == DataKind_Float32
+ #define INPUT_TYPE f32
+ #define SAMPLE_TYPE f32
+ #define OUTPUT_TYPE f32
+#elif DataKind == DataKind_Float16Complex || (DataKind == DataKind_Float16 && Demodulate)
+ #define INPUT_TYPE f16vec2
+ #define SAMPLE_TYPE f16vec2
+ #if OutputFloats
+ #if BatchSampleCount
+ #define OUTPUT_TYPE f32
+ #else
+ #define OUTPUT_TYPE f32vec2
+ #endif
+ #else
+ #if BatchSampleCount
+ #define OUTPUT_TYPE f16
+ #else
+ #define OUTPUT_TYPE f16vec2
+ #endif
+ #endif
+#elif DataKind == DataKind_Float16
+ #define INPUT_TYPE f16
+ #define SAMPLE_TYPE f16
+ #define OUTPUT_TYPE f16
+#elif DataKind == DataKind_Int16Complex || (DataKind == DataKind_Int16 && Demodulate)
+ #define INPUT_TYPE s16vec2
+ #define SAMPLE_TYPE f16vec2
#if OutputFloats
- #define OUT_DATA_TYPE f32vec2
- #define RESULT_TYPE_CAST(v) f32vec2(v)
+ #if BatchSampleCount
+ #define OUTPUT_TYPE f32
+ #else
+ #define OUTPUT_TYPE f32vec2
+ #endif
#else
- #define OUT_DATA_TYPE f16vec2
- #define RESULT_TYPE_CAST(v) f16vec2(v)
+ #if BatchSampleCount
+ #define OUTPUT_TYPE f16
+ #else
+ #define OUTPUT_TYPE f16vec2
+ #endif
#endif
+#elif DataKind == DataKind_Int16
+ #define INPUT_TYPE s16
+ #define SAMPLE_TYPE f16
+ #define OUTPUT_TYPE f16
+#else
+ #error unsupported data kind
#endif
-#ifndef OUT_DATA_TYPE
- #define OUT_DATA_TYPE DATA_TYPE
+#define ComplexSampleType (DataKind == DataKind_Float32Complex || \
+ DataKind == DataKind_Float16Complex || \
+ DataKind == DataKind_Int16Complex || \
+ Demodulate)
+#if ComplexSampleType
+ #define RESULT_TYPE f32vec2
+#else
+ #define RESULT_TYPE f32
#endif
#if ComplexFilter
- #define FILTER_TYPE vec2
- #define apply_filter(iq, h) complex_mul((iq), (h))
+ #define FILTER_TYPE f32vec2
+#else
+ #define FILTER_TYPE f32
+#endif
+
+#if ComplexFilter && ComplexSampleType
+ #define apply_filter(iq, h) complex_mul(f32vec2(iq), f32vec2(h))
#else
- #define FILTER_TYPE float
- #define apply_filter(iq, h) ((iq) * (h))
+ #define apply_filter(iq, h) ((iq) * (h))
#endif
layout(std430, buffer_reference, buffer_reference_align = 64) restrict readonly buffer Input {
- DATA_TYPE values[];
+ INPUT_TYPE x[];
};
layout(set = ShaderResourceKind_Buffer, binding = ShaderBufferSlot_PingPong) buffer Output {
- OUT_DATA_TYPE output_data[];
+ OUTPUT_TYPE output_data[];
};
layout(std430, buffer_reference, buffer_reference_align = 64) restrict readonly buffer Filter {
FILTER_TYPE values[FilterLength];
};
-SAMPLE_TYPE complex_mul(SAMPLE_TYPE a, SAMPLE_TYPE b)
+f32vec2 complex_mul(f32vec2 a, f32vec2 b)
{
mat2 m = mat2(b.x, b.y, -b.y, b.x);
- SAMPLE_TYPE result = SAMPLE_TYPE(m * a);
+ f32vec2 result = m * a;
return result;
}
@@ -52,14 +100,14 @@ SAMPLE_TYPE complex_mul(SAMPLE_TYPE a, SAMPLE_TYPE b)
SAMPLE_TYPE rotate_iq(SAMPLE_TYPE iq, uint index)
{
float arg = radians(360) * DemodulationFrequency * index / SamplingFrequency;
- SAMPLE_TYPE result = complex_mul(iq, SAMPLE_TYPE(cos(arg), -sin(arg)));
+ SAMPLE_TYPE result = SAMPLE_TYPE(complex_mul(iq, f32vec2(cos(arg), -sin(arg))));
return result;
}
#endif
SAMPLE_TYPE sample_rf(uint index)
{
- SAMPLE_TYPE result = SAMPLE_TYPE_CAST(Input(input_data).values[index]);
+ SAMPLE_TYPE result = SAMPLE_TYPE(Input(input_data).x[index]);
return result;
}
@@ -86,14 +134,14 @@ void main()
uint leftover_count = total_samples % thread_count;
uint samples_this_thread = samples_per_thread + uint(thread_index < leftover_count);
- const float scale = bool(ComplexFilter) ? 1 : sqrt(2.0f);
+ const SAMPLE_TYPE scale = SAMPLE_TYPE(bool(ComplexFilter) ? 1 : sqrt(2.0f));
for (uint i = 0; i < samples_this_thread; i++) {
uint index = thread_count * i + thread_index;
if (offset_wraps && index < FilterLength - 1) {
rf[index] = SAMPLE_TYPE(0);
} else {
#if Demodulate
- rf[index] = scale * rotate_iq(sample_rf(in_offset + index) * vec2(1, -1), index);
+ rf[index] = scale * rotate_iq(sample_rf(in_offset + index) * SAMPLE_TYPE(1, -1), index);
#else
rf[index] = sample_rf(in_offset + index);
#endif
@@ -103,14 +151,22 @@ void main()
barrier();
if (out_sample < SampleCount / DecimationRate) {
- SAMPLE_TYPE result = SAMPLE_TYPE(0);
+ RESULT_TYPE result = RESULT_TYPE(0);
uint offset = DecimationRate * thread_index;
for (uint j = 0; j < FilterLength; j++)
result += apply_filter(rf[offset + j], Filter(filter_coefficients).values[j]);
- uint out_offset = OutputChannelStride * channel +
- OutputTransmitStride * transmit +
- OutputSampleStride * out_sample;
- output_data[output_element_offset + out_offset] = RESULT_TYPE_CAST(result);
+ u32 out_offset = OutputChannelStride * channel +
+ OutputTransmitStride * transmit +
+ OutputSampleStride * out_sample;
+
+ #if BatchSampleCount
+ // NOTE(rnp): deinterleave
+ output_data[output_element_offset + out_offset] = OUTPUT_TYPE(result.x);
+ out_offset += BatchSampleCount;
+ output_data[output_element_offset + out_offset] = OUTPUT_TYPE(result.y);
+ #else
+ output_data[output_element_offset + out_offset] = OUTPUT_TYPE(result);
+ #endif
}
}
diff --git a/shaders/reshape.glsl b/shaders/reshape.glsl
@@ -0,0 +1,107 @@
+/* See LICENSE for license details. */
+
+#if InputDataKind == DataKind_Float32Complex
+ #define Input Float32Complex
+#elif InputDataKind == DataKind_Float32
+ #define Input Float32
+#elif InputDataKind == DataKind_Float16Complex || InputDataKind == DataKind_Int16Complex
+ #define Input Int16Complex
+#elif InputDataKind == DataKind_Float16 || InputDataKind == DataKind_Int16
+ #define Input Int16
+#else
+ #error unsupported data kind for Reshape
+#endif
+
+#if OutputDataKind == DataKind_Float32Complex
+ #if Interleave
+ #define InterleaveWide 1
+ #define Output Float32V4
+ #define OutputKind f32vec4
+ #else
+ #define Output Float32Complex
+ #define OutputKind f32vec2
+ #endif
+#elif OutputDataKind == DataKind_Float32
+ #if Interleave
+ #define Output Float32Complex
+ #define OutputKind f32vec2
+ #else
+ #define Output Float32
+ #define OutputKind f32
+ #endif
+#elif OutputDataKind == DataKind_Float16Complex || OutputDataKind == DataKind_Int16Complex
+ #if Interleave
+ #define InterleaveWide 1
+ #define Output Int16V4
+ #define OutputKind i16vec4
+ #else
+ #define Output Int16Complex
+ #define OutputKind s16vec2
+ #endif
+#elif OutputDataKind == DataKind_Float16 || OutputDataKind == DataKind_Int16
+ #if Interleave
+ #define Output Int16Complex
+ #define OutputKind s16vec2
+ #else
+ #define Output Int16
+ #define OutputKind s16
+ #endif
+#else
+ #error unsupported data kind for Reshape
+#endif
+
+#ifndef InterleaveWide
+ #define InterleaveWide 0
+#endif
+
+layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Int16 {
+ s16 x[];
+};
+
+layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Int16Complex {
+ s16vec2 x[];
+};
+
+layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Int16V4 {
+ i16vec4 x[];
+};
+
+layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Float32 {
+ f32 x[];
+};
+
+layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Float32Complex {
+ f32vec2 x[];
+};
+
+layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Float32V4 {
+ f32vec4 x[];
+};
+
+void main(void)
+{
+ if (all(lessThan(gl_GlobalInvocationID, uvec3(SizeX, SizeY, SizeZ)))) {
+ u32 x = gl_GlobalInvocationID.x;
+ u32 y = gl_GlobalInvocationID.y;
+ u32 z = gl_GlobalInvocationID.z;
+
+ u32 input_index = InputStrideX * x + InputStrideY * y + InputStrideZ * z;
+ u32 output_index = OutputStrideX * x + OutputStrideY * y + OutputStrideZ * z;
+
+ OutputKind out_value = OutputKind(0);
+
+ #if Interleave && InterleaveWide
+ out_value.xy = Input(left_input_buffer).x[input_index];
+ out_value.zw = Input(right_input_buffer).x[input_index];
+ #else
+ #if Interleave
+ out_value[0] = Input(left_input_buffer).x[input_index];
+ out_value[1] = Input(right_input_buffer).x[input_index];
+ #else
+ out_value = Input(left_input_buffer).x[input_index];
+ #endif
+ #endif
+
+ Output(output_buffer).x[output_index] = out_value;
+ }
+}
diff --git a/vulkan.c b/vulkan.c
@@ -188,16 +188,15 @@ read_only global const char *vk_required_instance_extensions[] = {
X("VK_KHR_timeline_semaphore") \
VK_OS_REQUIRED_DEVICE_EXTENSIONS_LIST
-#define X(str) str,
-read_only global const char *vk_required_device_extensions[] = {
-VK_REQUIRED_DEVICE_EXTENSIONS_LIST
-};
+#define X(str) s8_comp(str),
+read_only global s8 vk_required_device_extensions[] = {VK_REQUIRED_DEVICE_EXTENSIONS_LIST};
#undef X
-#define X(str) sizeof(str) - 1,
-read_only global u32 vk_required_device_extension_name_lengths[] = {
-VK_REQUIRED_DEVICE_EXTENSIONS_LIST
-};
+#define VK_OPTIONAL_DEVICE_EXTENSIONS_LIST \
+ X(VK_KHR, cooperative_matrix) \
+
+#define X(p, s, ...) s8_comp(#p "_" #s),
+read_only global s8 vk_optional_device_extensions[] = {VK_OPTIONAL_DEVICE_EXTENSIONS_LIST};
#undef X
#define VK_REQUIRED_PHYSICAL_FEATURES \
@@ -211,6 +210,7 @@ VK_REQUIRED_DEVICE_EXTENSIONS_LIST
X(bufferDeviceAddress) \
X(shaderFloat16) \
X(timelineSemaphore) \
+ X(vulkanMemoryModel) \
#define VK_REQUIRED_PHYSICAL_13_FEATURES \
X(dynamicRendering) \
@@ -220,11 +220,8 @@ VK_REQUIRED_DEVICE_EXTENSIONS_LIST
X(VK_KHR, shader_non_semantic_info) \
X(VK_KHR, shader_relaxed_extended_instruction) \
-#define X(p, s, ...) #p "_" #s,
-read_only global const char *vk_debug_extensions[] = {VK_DEBUG_EXTENSIONS};
-#undef X
-#define X(p, s, ...) sizeof(#p "_" #s) - 1,
-read_only global u32 vk_debug_extension_name_lengths[] = {VK_DEBUG_EXTENSIONS};
+#define X(p, s, ...) s8_comp(#p "_" #s),
+read_only global s8 vk_debug_extensions[] = {VK_DEBUG_EXTENSIONS};
#undef X
#define VK_INSTANCE_DEBUG_EXTENSIONS_LIST \
@@ -234,13 +231,24 @@ read_only global u32 vk_debug_extension_name_lengths[] = {VK_DEBUG_EXTENSIONS};
read_only global s8 vk_instance_debug_extensions[] = {VK_INSTANCE_DEBUG_EXTENSIONS_LIST};
#undef X
-global union {
- struct {
- #define X(_, name, ...) b8 name;
- VK_DEBUG_EXTENSIONS
- #undef X
- };
- b8 E[countof(vk_debug_extensions)];
+global struct {
+ union {
+ struct {
+ #define X(_, name, ...) b8 name;
+ VK_OPTIONAL_DEVICE_EXTENSIONS_LIST
+ #undef X
+ };
+ b8 E[countof(vk_optional_device_extensions)];
+ } optional;
+
+ union {
+ struct {
+ #define X(_, name, ...) b8 name;
+ VK_DEBUG_EXTENSIONS
+ #undef X
+ };
+ b8 E[countof(vk_debug_extensions)];
+ } debug;
union {
struct {
@@ -250,7 +258,12 @@ global union {
};
b8 E[countof(vk_instance_debug_extensions)];
} instance;
-} vulkan_debug;
+} vulkan_config;
+
+#define MAX_ENABLED_EXTENSIONS ( countof(vk_required_device_extensions) \
+ + countof(vk_optional_device_extensions) \
+ + countof(vk_debug_extensions) \
+ )
global VulkanContext vulkan_context[1];
@@ -355,7 +368,7 @@ vk_label_object_(VkObjectType kind, u64 handle, s8 label, s8 extra)
{
local_persist u8 buffer[1024];
Stream sb = arena_stream(arena_from_memory(buffer, sizeof(buffer)));
- if (vulkan_debug.instance.debug_utils && label.len > 0) {
+ if (vulkan_config.instance.debug_utils && label.len > 0) {
stream_append_s8s(&sb, label, s8(" ("), extra, s8(")"));
stream_append_byte(&sb, 0);
if (!sb.errors) {
@@ -473,7 +486,7 @@ glsl_to_spirv(Arena *arena, u32 kind, s8 shader_text, s8 name)
if (glslang_program_link(program, messages)) {
glslang_spv_options_t options = {.validate = 1,};
- if (vulkan_debug.shader_non_semantic_info) {
+ if (vulkan_config.debug.shader_non_semantic_info) {
options.generate_debug_info = 1;
options.emit_nonsemantic_shader_debug_info = 1;
options.emit_nonsemantic_shader_debug_source = 1;
@@ -1045,7 +1058,7 @@ vk_load_instance(Arena arena, Stream *err)
if (s8_equal(vk_instance_debug_extensions[it], instance_ext_s8s[i])) {
u32 index = enabled_instance_extensions_count++;
enabled_instance_extensions[index] = (char *)vk_instance_debug_extensions[it].data;
- vulkan_debug.instance.E[it] = 1;
+ vulkan_config.instance.E[it] = 1;
break;
}
}
@@ -1147,15 +1160,9 @@ vk_load_physical_device(Arena arena, Stream *err)
ext_str8s[index] = c_str_to_s8(extensions[index].extensionName);
b8 *supported = push_array(&scratch, b8, countof(vk_required_device_extensions));
- for (u32 index = 0; index < extension_count; index++) {
- for EachElement(vk_required_device_extensions, it) {
- s8 test = {
- .data = (u8 *)vk_required_device_extensions[it],
- .len = vk_required_device_extension_name_lengths[it],
- };
- supported[it] |= s8_equal(test, ext_str8s[index]);
- }
- }
+ for EachIndex(extension_count, index)
+ for EachElement(vk_required_device_extensions, it)
+ supported[it] |= s8_equal(vk_required_device_extensions[it], ext_str8s[index]);
u32 supported_count = 0;
for EachElement(vk_required_device_extensions, it)
@@ -1167,26 +1174,21 @@ vk_load_physical_device(Arena arena, Stream *err)
missing_count > 1 ? s8("s") : s8(""), s8(":\n"));
for EachElement(vk_required_device_extensions, it) {
if (!supported[it]) {
- s8 name = {
- .data = (u8 *)vk_required_device_extensions[it],
- .len = vk_required_device_extension_name_lengths[it],
- };
+ s8 name = vk_required_device_extensions[it];
stream_append_s8s(err, vulkan_info(" "), name, s8("\n"));
}
}
fatal(stream_to_s8(err));
}
+ for EachIndex(extension_count, index)
+ for EachElement(vk_optional_device_extensions, it)
+ vulkan_config.optional.E[it] |= s8_equal(vk_optional_device_extensions[it], ext_str8s[index]);
+
#if BEAMFORMER_DEBUG
- for (u32 index = 0; index < extension_count; index++) {
- for EachElement(vk_debug_extensions, it) {
- s8 test = {
- .data = (u8 *)vk_debug_extensions[it],
- .len = vk_debug_extension_name_lengths[it],
- };
- vulkan_debug.E[it] |= s8_equal(test, ext_str8s[index]);
- }
- }
+ for EachIndex(extension_count, index)
+ for EachElement(vk_debug_extensions, it)
+ vulkan_config.debug.E[it] |= s8_equal(vk_debug_extensions[it], ext_str8s[index]);
#endif
}
@@ -1259,6 +1261,39 @@ vk_load_physical_device(Arena arena, Stream *err)
fatal(stream_to_s8(err));
}
}
+
+ if (vulkan_config.optional.cooperative_matrix) {
+ Arena scratch = arena;
+ u32 property_count = 0;
+ vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(vk->physical_device, &property_count, 0);
+
+ VkCooperativeMatrixPropertiesKHR *mat = push_array(&scratch, VkCooperativeMatrixPropertiesKHR, property_count);
+
+ // NOTE(rnp): validation layer stupidity
+ for EachIndex(property_count, it)
+ mat[it].sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
+
+ vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(vk->physical_device, &property_count, mat);
+ b32 supported = 0;
+ // TODO(rnp): for now the requirements are hardcoded, it is possible to support a couple
+ // variations if needed.
+ for EachIndex(property_count, it) {
+ b32 match = 1;
+ supported &= mat[it].scope == VK_SCOPE_SUBGROUP_KHR;
+
+ supported &= mat[it].MSize == 16;
+ supported &= mat[it].NSize == 16;
+ supported &= mat[it].KSize == 16;
+
+ supported &= mat[it].AType == VK_COMPONENT_TYPE_FLOAT16_KHR;
+ supported &= mat[it].BType == VK_COMPONENT_TYPE_FLOAT16_KHR;
+ supported &= mat[it].CType == VK_COMPONENT_TYPE_FLOAT32_KHR;
+ supported &= mat[it].ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR;
+
+ supported |= match;
+ }
+ vk->gpu_info.cooperative_matrix = supported;
+ }
}
VkPhysicalDeviceMemoryProperties2 mp = {.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2};
@@ -1498,64 +1533,85 @@ vk_load_queues(Arena *memory, Stream *err)
queue_info_filled[base_q] = 1;
}
- VkPhysicalDeviceVulkan13Features v13f = {
- .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_3_FEATURES,
- #define X(name, ...) .name = 1,
- VK_REQUIRED_PHYSICAL_13_FEATURES
- #undef X
+ u32 enabled_count = 0;
+ const char *enabled_extensions[MAX_ENABLED_EXTENSIONS];
+
+ for EachElement(vk_required_device_extensions, it)
+ enabled_extensions[enabled_count++] = (char *)vk_required_device_extensions[it].data;
+
+ for EachElement(vk_optional_device_extensions, it)
+ if (vulkan_config.optional.E[it])
+ enabled_extensions[enabled_count++] = (char *)vk_optional_device_extensions[it].data;
+
+ for EachElement(vk_debug_extensions, it)
+ if (vulkan_config.debug.E[it])
+ enabled_extensions[enabled_count++] = (char *)vk_debug_extensions[it].data;
+
+ VkDeviceCreateInfo device_create_info = {
+ .sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO,
+ .pQueueCreateInfos = queue_create_infos,
+ .queueCreateInfoCount = queue_create_index,
+ .ppEnabledExtensionNames = enabled_extensions,
+ .enabledExtensionCount = enabled_count,
};
VkPhysicalDeviceShaderRelaxedExtendedInstructionFeaturesKHR pdsre = {
.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_RELAXED_EXTENDED_INSTRUCTION_FEATURES_KHR,
.shaderRelaxedExtendedInstruction = 1,
};
- if (vulkan_debug.shader_relaxed_extended_instruction) v13f.pNext = &pdsre;
+ if (vulkan_config.debug.shader_relaxed_extended_instruction) {
+ pdsre.pNext = (void *)device_create_info.pNext;
+ device_create_info.pNext = &pdsre;
+ }
+
+ VkPhysicalDeviceCooperativeMatrixFeaturesKHR coop_mat_features = {
+ .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR,
+ .cooperativeMatrix = 1,
+ .cooperativeMatrixRobustBufferAccess = 0,
+ };
+ if (vk->gpu_info.cooperative_matrix) {
+ coop_mat_features.pNext = (void *)device_create_info.pNext;
+ device_create_info.pNext = &coop_mat_features;
+ }
+
+ VkPhysicalDeviceVulkan13Features v13f = {
+ .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_3_FEATURES,
+ .pNext = (void *)device_create_info.pNext,
+ #define X(name, ...) .name = 1,
+ VK_REQUIRED_PHYSICAL_13_FEATURES
+ #undef X
+ };
+ device_create_info.pNext = &v13f;
VkPhysicalDeviceVulkan12Features v12f = {
.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES,
- .pNext = &v13f,
+ .pNext = (void *)device_create_info.pNext,
#define X(name, ...) .name = 1,
VK_REQUIRED_PHYSICAL_12_FEATURES
#undef X
};
+ device_create_info.pNext = &v12f;
VkPhysicalDeviceVulkan11Features v11f = {
.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES,
- .pNext = &v12f,
+ .pNext = (void *)device_create_info.pNext,
#define X(name, ...) .name = 1,
VK_REQUIRED_PHYSICAL_11_FEATURES
#undef X
};
+ device_create_info.pNext = &v11f;
+
VkPhysicalDeviceFeatures2 device_features = {
.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2,
- .pNext = &v11f,
+ .pNext = (void *)device_create_info.pNext,
.features = {
#define X(name, ...) .name = 1,
VK_REQUIRED_PHYSICAL_FEATURES
#undef X
},
};
+ device_create_info.pNext = &device_features;
- Arena arena = *memory;
- u32 enabled_count = countof(vk_required_device_extensions) + countof(vk_debug_extensions);
- const char **enabled_extensions = push_array(&arena, const char *, enabled_count);
-
- enabled_count = 0;
- for EachElement(vk_required_device_extensions, it)
- enabled_extensions[enabled_count++] = vk_required_device_extensions[it];
-
- for EachElement(vk_debug_extensions, it)
- if (vulkan_debug.E[it])
- enabled_extensions[enabled_count++] = vk_debug_extensions[it];
-
- VkDeviceCreateInfo device_create_info = {
- .sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO,
- .pNext = &device_features,
- .pQueueCreateInfos = queue_create_infos,
- .queueCreateInfoCount = queue_create_index,
- .ppEnabledExtensionNames = enabled_extensions,
- .enabledExtensionCount = enabled_count,
- };
vkCreateDevice(vk->physical_device, &device_create_info, 0, &vk->device);
#define X(name, ...) name = (name##_fn *)vkGetDeviceProcAddr(vk->device, #name);
diff --git a/vulkan.h b/vulkan.h
@@ -138,6 +138,8 @@ typedef enum {
VK_STRUCTURE_TYPE_COPY_BUFFER_INFO_2 = 1000337000,
VK_STRUCTURE_TYPE_BUFFER_COPY_2 = 1000337006,
VK_STRUCTURE_TYPE_FORMAT_PROPERTIES_3 = 1000360000,
+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR = 1000506000,
+ VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR = 1000506001,
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_RELAXED_EXTENDED_INSTRUCTION_FEATURES_KHR = 1000558000,
VK_STRUCTURE_TYPE_MAX_ENUM = 0x7FFFFFFF,
} VkStructureType;
@@ -2969,6 +2971,55 @@ typedef struct {
} VkWriteDescriptorSet;
typedef enum {
+ VK_COMPONENT_TYPE_FLOAT16_KHR = 0,
+ VK_COMPONENT_TYPE_FLOAT32_KHR = 1,
+ VK_COMPONENT_TYPE_FLOAT64_KHR = 2,
+ VK_COMPONENT_TYPE_SINT8_KHR = 3,
+ VK_COMPONENT_TYPE_SINT16_KHR = 4,
+ VK_COMPONENT_TYPE_SINT32_KHR = 5,
+ VK_COMPONENT_TYPE_SINT64_KHR = 6,
+ VK_COMPONENT_TYPE_UINT8_KHR = 7,
+ VK_COMPONENT_TYPE_UINT16_KHR = 8,
+ VK_COMPONENT_TYPE_UINT32_KHR = 9,
+ VK_COMPONENT_TYPE_UINT64_KHR = 10,
+ VK_COMPONENT_TYPE_BFLOAT16_KHR = 1000141000,
+ VK_COMPONENT_TYPE_SINT8_PACKED_NV = 1000491000,
+ VK_COMPONENT_TYPE_UINT8_PACKED_NV = 1000491001,
+ VK_COMPONENT_TYPE_FLOAT8_E4M3_EXT = 1000491002,
+ VK_COMPONENT_TYPE_FLOAT8_E5M2_EXT = 1000491003,
+ VK_COMPONENT_TYPE_MAX_ENUM_KHR = 0x7FFFFFFF
+} VkComponentTypeKHR;
+
+typedef enum {
+ VK_SCOPE_DEVICE_KHR = 1,
+ VK_SCOPE_WORKGROUP_KHR = 2,
+ VK_SCOPE_SUBGROUP_KHR = 3,
+ VK_SCOPE_QUEUE_FAMILY_KHR = 5,
+ VK_SCOPE_MAX_ENUM_KHR = 0x7FFFFFFF
+} VkScopeKHR;
+
+typedef struct {
+ VkStructureType sType;
+ void * pNext;
+ uint32_t MSize;
+ uint32_t NSize;
+ uint32_t KSize;
+ VkComponentTypeKHR AType;
+ VkComponentTypeKHR BType;
+ VkComponentTypeKHR CType;
+ VkComponentTypeKHR ResultType;
+ VkBool32 saturatingAccumulation;
+ VkScopeKHR scope;
+} VkCooperativeMatrixPropertiesKHR;
+
+typedef struct {
+ VkStructureType sType;
+ void * pNext;
+ VkBool32 cooperativeMatrix;
+ VkBool32 cooperativeMatrixRobustBufferAccess;
+} VkPhysicalDeviceCooperativeMatrixFeaturesKHR;
+
+typedef enum {
VK_VALIDATION_FEATURE_ENABLE_GPU_ASSISTED_EXT = 0,
VK_VALIDATION_FEATURE_ENABLE_GPU_ASSISTED_RESERVE_BINDING_SLOT_EXT = 1,
VK_VALIDATION_FEATURE_ENABLE_BEST_PRACTICES_EXT = 2,
@@ -3022,13 +3073,13 @@ typedef struct {
X(vkEnumerateDeviceExtensionProperties, VkResult, (VkPhysicalDevice physicalDevice, const char *pLayerName, uint32_t *pPropertyCount, VkExtensionProperties *pProperties)) \
X(vkEnumeratePhysicalDevices, VkResult, (VkInstance instance, uint32_t *pPhysicalDeviceCount, VkPhysicalDevice *pPhysicalDevices)) \
X(vkGetDeviceProcAddr, void *, (VkDevice device, const char *pName)) \
+ X(vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR, VkResult, (VkPhysicalDevice physicalDevice, uint32_t *pPropertyCount, VkCooperativeMatrixPropertiesKHR *pProperties)) \
X(vkGetPhysicalDeviceFeatures2, void, (VkPhysicalDevice physicalDevice, VkPhysicalDeviceFeatures2 *pFeatures)) \
X(vkGetPhysicalDeviceFormatProperties2, void, (VkPhysicalDevice physicalDevice, VkFormat format, VkFormatProperties2 *pFormatProperties)) \
X(vkGetPhysicalDeviceMemoryProperties2, void, (VkPhysicalDevice physicalDevice, VkPhysicalDeviceMemoryProperties2 *pMemoryProperties)) \
X(vkGetPhysicalDeviceProperties2, void, (VkPhysicalDevice physicalDevice, VkPhysicalDeviceProperties2 *pProperties)) \
X(vkGetPhysicalDeviceQueueFamilyProperties, void, (VkPhysicalDevice physicalDevice, uint32_t *pQueueFamilyPropertyCount, VkQueueFamilyProperties *pQueueFamilyProperties)) \
-
/* X(name, ret, params) */
#define VkDeviceProcedureList \
X(vkAllocateCommandBuffers, VkResult, (VkDevice device, const VkCommandBufferAllocateInfo *pAllocateInfo, VkCommandBuffer *pCommandBuffers)) \