Commit: 3e4bea29377e32bee5ef97cc5efef02310587b1b
Parent: 218fb7e140c058278ee9396b17569571ceb18e9e
Author: Randy Palamar
Date: Sun, 18 Jan 2026 12:28:57 -0700
vulkan: implement compute shader compilation api
Diffstat:
| M | beamformer_internal.h | | | 11 | +++++++++-- |
| M | vulkan.c | | | 135 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------- |
| M | vulkan.h | | | 60 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++- |
3 files changed, 165 insertions(+), 41 deletions(-)
diff --git a/beamformer_internal.h b/beamformer_internal.h
@@ -21,13 +21,13 @@
#define os_path_separator() (s8){.data = &os_system_info()->path_separator_byte, .len = 1}
+/* Opaque handle handed out by the vk_* API. In vulkan.c the single value word
+ * stores a VulkanEntity pointer cast to u64; a value of 0 is treated as
+ * "no entity" (see ValidVulkanHandle in vulkan.c). */
+typedef struct { u64 value[1]; } VulkanHandle;
+
typedef enum {
GPUBufferCreateFlags_HostWritable = 1 << 0,
GPUBufferCreateFlags_MemoryOnly = 1 << 1,
} GPUBufferCreateFlags;
-typedef struct { u64 value[1]; } VulkanHandle;
-
typedef struct {
u64 gpu_pointer;
i64 size;
@@ -70,6 +70,13 @@ DEBUG_IMPORT void vk_buffer_release(GPUBuffer *);
DEBUG_IMPORT void vk_buffer_range_upload(GPUBuffer *, void *data, u64 offset, u64 size, b32 non_temporal);
DEBUG_IMPORT u64 vk_round_up_to_sync_size(u64, u64 min);
+/* NOTE: Compute shaders do not have bindings. Data should be passed using push constants.
+ * In particular the push constants should contain pointers to gpu memory using the
+ * BufferDeviceAddress extension. */
+// TODO(rnp): change this to accept SPIR-V directly and accept BakeParameters as specialization data
+DEBUG_IMPORT VulkanHandle vk_compute_shader(s8 text, s8 name);
+DEBUG_IMPORT void vk_compute_shader_release(VulkanHandle);
+
// NOTE: temporary API
DEBUG_IMPORT b32 vk_buffer_needs_sync(GPUBuffer *);
diff --git a/vulkan.c b/vulkan.c
@@ -5,6 +5,8 @@
#define glslang_info(s) s8("[glslang] " s)
#define vulkan_info(s) s8("[vulkan] " s)
+/* Evaluates to nonzero when the handle's value word is nonzero.
+ * NOTE: call sites are written as `if ValidVulkanHandle(h) { ... }`, so the
+ * expansion must remain wrapped in its outer parentheses. */
+#define ValidVulkanHandle(h) ((h).value[0] != 0)
+
typedef enum {
VulkanQueueKind_Graphics,
VulkanQueueKind_Compute,
@@ -28,9 +30,15 @@ typedef struct {
VulkanMemoryKind memory_kind;
} VulkanBuffer;
+/* A compute pipeline together with the pipeline layout it was created with.
+ * Both members are filled in by vk_compute_pipeline_from_shader_text() and are
+ * destroyed as a pair in vk_compute_shader_release(). A pipeline of 0 is used
+ * by vk_compute_shader() as the "creation failed" indicator. */
+typedef struct {
+ VkPipeline pipeline;
+ VkPipelineLayout layout;
+} VulkanShader;
+
typedef enum {
VulkanEntityKind_Buffer,
VulkanEntityKind_Semaphore,
+ VulkanEntityKind_Shader,
} VulkanEntityKind;
typedef struct VulkanEntity VulkanEntity;
@@ -38,9 +46,10 @@ struct VulkanEntity {
VulkanEntity * next;
VulkanEntityKind kind;
union {
- VulkanBuffer buffer;
- VkSemaphore semaphore;
- };
+ VulkanBuffer buffer;
+ VkSemaphore semaphore;
+ VulkanShader shader;
+ } as;
};
typedef alignas(64) struct {
@@ -63,8 +72,8 @@ typedef struct {
VkDevice device;
VkPhysicalDevice physical_device;
- // NOTE(rnp): fallback module for when a compute shader fails to compile
- VkShaderModule default_compute_module;
+ // NOTE(rnp): fallback for when a compute shader fails to compile
+ VulkanShader default_compute_shader;
GPUInfo gpu_info;
@@ -229,6 +238,14 @@ vk_entity_release(VulkanEntity *entity)
}
}
+/* Resolves a handle to the payload union of the entity it refers to.
+ *
+ * h:    handle whose value word holds a VulkanEntity pointer
+ * kind: the entity kind the caller expects
+ *
+ * Returns the address of the entity's `as` union; the caller assigns it to the
+ * member type matching `kind`. Asserts that the handle is nonzero and that the
+ * entity's recorded kind equals `kind`. */
+function void *
+vk_entity_data(VulkanHandle h, VulkanEntityKind kind)
+{
+	VulkanEntity *entity = (VulkanEntity *)h.value[0];
+	// NOTE: validity is checked first so the kind is never read through a zero handle
+	assert(ValidVulkanHandle(h));
+	assert(entity->kind == kind);
+	void *result = &entity->as;
+	return result;
+}
+
#define glslang_log(a, ...) glslang_log_(a, arg_list(s8, __VA_ARGS__))
function void
glslang_log_(Arena arena, s8 *items, uz count)
@@ -323,20 +340,41 @@ vk_shader_kind_to_glslang_shader_kind(u32 kind)
}
function VkShaderModule
-vk_compile_shader_module(u32 kind, s8 shader, s8 name)
+vk_compile_shader_module(Arena arena, u32 kind, s8 text, s8 name)
{
VkShaderModule result = 0;
+ s8 spirv = glsl_to_spirv(&arena, vk_shader_kind_to_glslang_shader_kind(kind), text, name);
+ VkShaderModuleCreateInfo create_info = {
+ .sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
+ .codeSize = (uz)spirv.len,
+ .pCode = (u32 *)spirv.data,
+ };
+ if (spirv.len > 0) vkCreateShaderModule(vulkan_context->device, &create_info, 0, &result);
+ return result;
+}
- DeferLoop(take_lock(&vulkan_context->arena_lock, -1), release_lock(&vulkan_context->arena_lock))
- {
- Arena arena = vulkan_context->arena;
- s8 spirv = glsl_to_spirv(&arena, vk_shader_kind_to_glslang_shader_kind(kind), shader, name);
- VkShaderModuleCreateInfo create_info = {
- .sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
- .codeSize = (uz)spirv.len,
- .pCode = (u32 *)spirv.data,
+function VulkanShader
+vk_compute_pipeline_from_shader_text(Arena arena, s8 text, s8 name)
+{
+ VulkanShader result = {0};
+ VkShaderModule module = vk_compile_shader_module(arena, VK_SHADER_STAGE_COMPUTE_BIT, text, name);
+ if (module) {
+ VkPipelineLayoutCreateInfo pli = {.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO};
+ vkCreatePipelineLayout(vulkan_context->device, &pli, 0, &result.layout);
+
+ VkComputePipelineCreateInfo pi = {
+ .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
+ .layout = result.layout,
+ .stage = {
+ .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
+ .stage = VK_SHADER_STAGE_COMPUTE_BIT,
+ .module = module,
+ .pName = "main",
+ },
};
- if (spirv.len > 0) vkCreateShaderModule(vulkan_context->device, &create_info, 0, &result);
+
+ vkCreateComputePipelines(vulkan_context->device, 0, 1, &pi, 0, &result.pipeline);
+ vkDestroyShaderModule(vulkan_context->device, module, 0);
}
return result;
@@ -727,8 +765,9 @@ vk_load(OSLibrary vulkan_library_handle, Arena *memory, Stream *err)
fatal(stream_to_s8(err));
}
- vulkan_context->entity_arena = sub_arena_end(memory, KB(64), KB(4));
- vulkan_context->arena = sub_arena_end(memory, KB(96), KB(4));
+ VulkanContext *vk = vulkan_context;
+ vk->entity_arena = sub_arena_end(memory, KB(64), KB(4));
+ vk->arena = sub_arena_end(memory, KB(96), KB(4));
vk_load_instance();
vk_load_physical_device(vulkan_context->arena, err);
@@ -740,8 +779,7 @@ vk_load(OSLibrary vulkan_library_handle, Arena *memory, Stream *err)
"void main() {}\n"
"\n");
- vulkan_context->default_compute_module = vk_compile_shader_module(VK_SHADER_STAGE_COMPUTE_BIT,
- default_compute_shader,
+ vk->default_compute_shader = vk_compute_pipeline_from_shader_text(vk->arena, default_compute_shader,
s8("error_compute_shader"));
// TODO: setup render pipeline
@@ -762,10 +800,8 @@ DEBUG_IMPORT void
vk_buffer_release(GPUBuffer *b)
{
VulkanContext *vk = vulkan_context;
- VulkanEntity *e = (VulkanEntity *)b->buffer.value[0];
- if (e) {
- assert(e->kind == VulkanEntityKind_Buffer);
- VulkanBuffer *vb = &e->buffer;
+ if ValidVulkanHandle(b->buffer) {
+ VulkanBuffer *vb = vk_entity_data(b->buffer, VulkanEntityKind_Buffer);
// TODO(rnp): this happens implicitly, probably just delete this if block
if (vb->host_pointer)
vkUnmapMemory(vk->device, vb->memory);
@@ -777,7 +813,7 @@ vk_buffer_release(GPUBuffer *b)
if (vb->memory_kind != VulkanMemoryKind_Host)
vk->gpu_info.gpu_heap_used -= b->size;
- vk_entity_release(e);
+ vk_entity_release((VulkanEntity *)b->buffer.value[0]);
}
zero_struct(b);
}
@@ -788,7 +824,7 @@ vk_buffer_allocate(GPUBuffer *b, iz size, GPUBufferCreateFlags flags, OSHandle *
vk_buffer_release(b);
VulkanContext *vk = vulkan_context;
VulkanEntity *e = vk_entity_allocate(VulkanEntityKind_Buffer);
- VulkanBuffer *vb = &e->buffer;
+ VulkanBuffer *vb = &e->as.buffer;
b->buffer.value[0] = (u64)e;
@@ -871,10 +907,8 @@ DEBUG_IMPORT b32
vk_buffer_needs_sync(GPUBuffer *b)
{
b32 result = 0;
- VulkanEntity *e = (VulkanEntity *)b->buffer.value[0];
- if (e) {
- assert(e->kind == VulkanEntityKind_Buffer);
- VulkanBuffer *vb = &e->buffer;
+ if ValidVulkanHandle(b->buffer) {
+ VulkanBuffer *vb = vk_entity_data(b->buffer, VulkanEntityKind_Buffer);
// TODO(rnp): not correct check. need to check if we used transfer queue
result = vb->memory_kind != VulkanMemoryKind_BAR;
@@ -894,13 +928,8 @@ vk_round_up_to_sync_size(u64 size, u64 min)
DEBUG_IMPORT void
vk_buffer_range_upload(GPUBuffer *b, void *data, u64 offset, u64 size, b32 non_temporal)
{
- assert(ValidHandle(b->buffer));
-
VulkanContext *vk = vulkan_context;
- VulkanEntity *e = (VulkanEntity *)b->buffer.value[0];
- VulkanBuffer *vb = &e->buffer;
-
- assert(e->kind == VulkanEntityKind_Buffer);
+ VulkanBuffer *vb = vk_entity_data(b->buffer, VulkanEntityKind_Buffer);
switch (vb->memory_kind) {
case VulkanMemoryKind_Host:
@@ -945,14 +974,14 @@ vk_semaphore_create(OSHandle *export)
VulkanEntity *e = vk_entity_allocate(VulkanEntityKind_Semaphore);
VulkanHandle result = {(u64)e};
- vkCreateSemaphore(vk->device, &sci, 0, &e->semaphore);
+ vkCreateSemaphore(vk->device, &sci, 0, &e->as.semaphore);
if (export) {
if (OS_WINDOWS) {
VkSemaphoreGetWin32HandleInfoKHR ghi = {
.sType = VK_STRUCTURE_TYPE_SEMAPHORE_GET_WIN32_HANDLE_INFO_KHR,
.handleType = VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_BIT,
- .semaphore = e->semaphore,
+ .semaphore = e->as.semaphore,
};
void *handle;
vkGetSemaphoreWin32HandleKHR(vk->device, &ghi, &handle);
@@ -961,7 +990,7 @@ vk_semaphore_create(OSHandle *export)
VkSemaphoreGetFdInfoKHR ghi = {
.sType = VK_STRUCTURE_TYPE_SEMAPHORE_GET_FD_INFO_KHR,
.handleType = VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD_BIT,
- .semaphore = e->semaphore,
+ .semaphore = e->as.semaphore,
};
i32 handle;
vkGetSemaphoreFdKHR(vk->device, &ghi, &handle);
@@ -971,3 +1000,33 @@ vk_semaphore_create(OSHandle *export)
return result;
}
+
+/* Builds a compute pipeline from shader source and wraps it in a new entity.
+ *
+ * text: shader source, forwarded to vk_compute_pipeline_from_shader_text()
+ * name: shader name, forwarded to the same helper
+ *
+ * Returns a handle to a freshly allocated VulkanEntityKind_Shader entity. When
+ * no pipeline could be created the entity refers to the context's
+ * default_compute_shader instead. Release the handle with
+ * vk_compute_shader_release().
+ *
+ * NOTE(review): the helper creates the pipeline layout from a create info with
+ * only sType set, i.e. without push constant ranges; shaders that declare push
+ * constants presumably need that extended - confirm. */
+DEBUG_IMPORT VulkanHandle
+vk_compute_shader(s8 text, s8 name)
+{
+	VulkanContext *vk    = vulkan_context;
+	VulkanHandle  result = {0};
+	DeferLoop(take_lock(&vk->arena_lock, -1), release_lock(&vk->arena_lock))
+	{
+		// NOTE: the helpers receive a copy, so vk->arena itself is left unmodified
+		Arena arena = vk->arena;
+
+		VulkanEntity *e = vk_entity_allocate(VulkanEntityKind_Shader);
+		result = (VulkanHandle){(u64)e};
+
+		e->as.shader = vk_compute_pipeline_from_shader_text(arena, text, name);
+		if (e->as.shader.pipeline == 0) {
+			// NOTE: the helper creates the layout before the pipeline, so a failed
+			// pipeline creation can still hand back a live layout. Destroy it here;
+			// the fallback assignment below would otherwise drop the only copy.
+			if (e->as.shader.layout)
+				vkDestroyPipelineLayout(vk->device, e->as.shader.layout, 0);
+			e->as.shader = vk->default_compute_shader;
+		}
+	}
+	return result;
+}
+
+/* Releases a shader entity created by vk_compute_shader().
+ *
+ * h: shader handle; a zero handle is ignored.
+ *
+ * The pipeline and its layout are destroyed unless the entity's pipeline is the
+ * context's default_compute_shader pipeline, which is shared and must stay
+ * alive. The entity itself is released in both cases. */
+DEBUG_IMPORT void
+vk_compute_shader_release(VulkanHandle h)
+{
+	if (!ValidVulkanHandle(h))
+		return;
+
+	VulkanContext *vk     = vulkan_context;
+	VulkanShader  *shader = vk_entity_data(h, VulkanEntityKind_Shader);
+
+	b32 is_fallback = shader->pipeline == vk->default_compute_shader.pipeline;
+	if (!is_fallback) {
+		vkDestroyPipeline(vk->device, shader->pipeline, 0);
+		vkDestroyPipelineLayout(vk->device, shader->layout, 0);
+	}
+
+	vk_entity_release((VulkanEntity *)h.value[0]);
+}
diff --git a/vulkan.h b/vulkan.h
@@ -43,6 +43,7 @@ VK_HANDLE(VkPipelineCache);
VK_HANDLE(VkPipelineLayout);
VK_HANDLE(VkQueue);
VK_HANDLE(VkRenderPass);
+VK_HANDLE(VkSampler);
VK_HANDLE(VkSemaphore);
VK_HANDLE(VkShaderModule);
VK_HANDLE(VkSurfaceKHR);
@@ -147,7 +148,7 @@ typedef enum {
} VkDeviceQueueCreateFlagBits;
typedef VkFlags VkDeviceQueueCreateFlags;
-typedef enum VkPipelineStageFlagBits {
+typedef enum {
VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT = 0x00000001,
VK_PIPELINE_STAGE_DRAW_INDIRECT_BIT = 0x00000002,
VK_PIPELINE_STAGE_VERTEX_INPUT_BIT = 0x00000004,
@@ -1066,12 +1067,47 @@ typedef VkFlags VkPipelineRasterizationStateCreateFlags;
typedef VkFlags VkPipelineMultisampleStateCreateFlags;
typedef enum {
+ VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT = 0x00000001,
+ VK_DESCRIPTOR_SET_LAYOUT_CREATE_UPDATE_AFTER_BIND_POOL_BIT = 0x00000002,
+ VK_DESCRIPTOR_SET_LAYOUT_CREATE_HOST_ONLY_POOL_BIT_EXT = 0x00000004,
+ VK_DESCRIPTOR_SET_LAYOUT_CREATE_DESCRIPTOR_BUFFER_BIT_EXT = 0x00000010,
+ VK_DESCRIPTOR_SET_LAYOUT_CREATE_EMBEDDED_IMMUTABLE_SAMPLERS_BIT_EXT = 0x00000020,
+ VK_DESCRIPTOR_SET_LAYOUT_CREATE_PER_STAGE_BIT_NV = 0x00000040,
+ VK_DESCRIPTOR_SET_LAYOUT_CREATE_INDIRECT_BINDABLE_BIT_NV = 0x00000080,
+ VK_DESCRIPTOR_SET_LAYOUT_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF
+} VkDescriptorSetLayoutCreateFlagBits;
+typedef VkFlags VkDescriptorSetLayoutCreateFlags;
+
+typedef enum {
VK_ATTACHMENT_DESCRIPTION_MAY_ALIAS_BIT = 0x00000001,
VK_ATTACHMENT_DESCRIPTION_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF
} VkAttachmentDescriptionFlagBits;
typedef VkFlags VkAttachmentDescriptionFlags;
typedef enum {
+ VK_DESCRIPTOR_TYPE_SAMPLER = 0,
+ VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER = 1,
+ VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE = 2,
+ VK_DESCRIPTOR_TYPE_STORAGE_IMAGE = 3,
+ VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER = 4,
+ VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER = 5,
+ VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER = 6,
+ VK_DESCRIPTOR_TYPE_STORAGE_BUFFER = 7,
+ VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC = 8,
+ VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC = 9,
+ VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT = 10,
+ VK_DESCRIPTOR_TYPE_INLINE_UNIFORM_BLOCK = 1000138000,
+ VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR = 1000150000,
+ VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_NV = 1000165000,
+ VK_DESCRIPTOR_TYPE_SAMPLE_WEIGHT_IMAGE_QCOM = 1000440000,
+ VK_DESCRIPTOR_TYPE_BLOCK_MATCH_IMAGE_QCOM = 1000440001,
+ VK_DESCRIPTOR_TYPE_TENSOR_ARM = 1000460000,
+ VK_DESCRIPTOR_TYPE_MUTABLE_EXT = 1000351000,
+ VK_DESCRIPTOR_TYPE_PARTITIONED_ACCELERATION_STRUCTURE_NV = 1000570000,
+ VK_DESCRIPTOR_TYPE_MAX_ENUM = 0x7FFFFFFF
+} VkDescriptorType;
+
+typedef enum {
VK_ATTACHMENT_LOAD_OP_LOAD = 0,
VK_ATTACHMENT_LOAD_OP_CLEAR = 1,
VK_ATTACHMENT_LOAD_OP_DONT_CARE = 2,
@@ -2141,6 +2177,23 @@ typedef struct {
const VkClearValue * pClearValues;
} VkRenderPassBeginInfo;
+/* Describes one binding of a descriptor set layout: a binding number, a
+ * VkDescriptorType, a descriptor count, shader stage flags and a pointer to
+ * VkSampler handles.
+ * NOTE(review): presumably a verbatim mirror of the Vulkan API struct of the
+ * same name, in which case field order and types must not change - confirm
+ * against the Vulkan headers. */
+typedef struct {
+ uint32_t binding;
+ VkDescriptorType descriptorType;
+ uint32_t descriptorCount;
+ VkShaderStageFlags stageFlags;
+ const VkSampler * pImmutableSamplers;
+} VkDescriptorSetLayoutBinding;
+
+/* Creation parameters for a descriptor set layout: the usual sType/pNext
+ * header, VkDescriptorSetLayoutCreateFlags, and a pointer to
+ * VkDescriptorSetLayoutBinding entries (presumably bindingCount of them).
+ * NOTE(review): presumably a verbatim mirror of the Vulkan API struct of the
+ * same name - confirm against the Vulkan headers. No procedure consuming this
+ * struct is added to VkDeviceProcedureList in the hunks shown here. */
+typedef struct {
+ VkStructureType sType;
+ const void * pNext;
+ VkDescriptorSetLayoutCreateFlags flags;
+ uint32_t bindingCount;
+ const VkDescriptorSetLayoutBinding * pBindings;
+} VkDescriptorSetLayoutCreateInfo;
+
+
/* X(name, ret, params) */
#define VkLoaderProcedureList \
X(vkGetInstanceProcAddr, void *, (VkInstance instance, const char *pName)) \
@@ -2163,9 +2216,14 @@ typedef struct {
/* X(name, ret, params) */
#define VkDeviceProcedureList \
X(vkAllocateMemory, VkResult, (VkDevice device, const VkMemoryAllocateInfo *pAllocateInfo, const VkAllocationCallbacks *pAllocator, VkDeviceMemory *pMemory)) \
+ X(vkCreateComputePipelines, VkResult, (VkDevice device, VkPipelineCache pipelineCache, uint32_t createInfoCount, const VkComputePipelineCreateInfo *pCreateInfos, const VkAllocationCallbacks *pAllocator, VkPipeline *pPipelines)) \
+ X(vkCreatePipelineLayout, VkResult, (VkDevice device, const VkPipelineLayoutCreateInfo *pCreateInfo, const VkAllocationCallbacks *pAllocator, VkPipelineLayout *pPipelineLayout)) \
X(vkCreateSemaphore, VkResult, (VkDevice device, const VkSemaphoreCreateInfo *pCreateInfo, const VkAllocationCallbacks *pAllocator, VkSemaphore *pSemaphore)) \
X(vkCreateShaderModule, VkResult, (VkDevice device, const VkShaderModuleCreateInfo *pCreateInfo, const VkAllocationCallbacks *pAllocator, VkShaderModule *pShaderModule)) \
X(vkDestroyBuffer, void, (VkDevice device, VkBuffer buffer, const VkAllocationCallbacks *pAllocator)) \
+ X(vkDestroyPipeline, void, (VkDevice device, VkPipeline pipeline, const VkAllocationCallbacks *pAllocator)) \
+ X(vkDestroyPipelineLayout, void, (VkDevice device, VkPipelineLayout pipelineLayout, const VkAllocationCallbacks *pAllocator)) \
+ X(vkDestroyShaderModule, void, (VkDevice device, VkShaderModule shaderModule, const VkAllocationCallbacks *pAllocator)) \
X(vkFlushMappedMemoryRanges, VkResult, (VkDevice device, uint32_t memoryRangeCount, const VkMappedMemoryRange *pMemoryRanges)) \
X(vkFreeMemory, void, (VkDevice device, VkDeviceMemory memory, const VkAllocationCallbacks *pAllocator)) \
X(vkGetDeviceQueue, void, (VkDevice device, uint32_t queueFamilyIndex, uint32_t queueIndex, VkQueue *pQueue)) \