Skip to content

Commit

Permalink
Refactor Vulkan Descriptor Type and Shader Stage set up.
Browse files Browse the repository at this point in the history
  • Loading branch information
luigifcruz committed May 26, 2024
1 parent 8f8c510 commit c9237a8
Show file tree
Hide file tree
Showing 12 changed files with 70 additions and 67 deletions.
8 changes: 8 additions & 0 deletions include/jetstream/render/base/buffer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ class Buffer {
Config config;
};

static constexpr Buffer::Target operator&(Buffer::Target a, Buffer::Target b) {
return static_cast<Buffer::Target>(static_cast<U64>(a) & static_cast<U64>(b));
}

static constexpr Buffer::Target operator|(Buffer::Target a, Buffer::Target b) {
return static_cast<Buffer::Target>(static_cast<U64>(a) | static_cast<U64>(b));
}

} // namespace Jetstream::Render

#endif
3 changes: 0 additions & 3 deletions include/jetstream/render/metal/kernel.hh
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ class KernelImp<Device::Metal> : public Kernel {

std::vector<std::shared_ptr<BufferImp<Device::Metal>>> buffers;

static Result checkShaderCompilation(U64);
static Result checkProgramCompilation(U64);

friend class SurfaceImp<Device::Metal>;
};

Expand Down
3 changes: 0 additions & 3 deletions include/jetstream/render/metal/program.hh
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ class ProgramImp<Device::Metal> : public Program {
std::vector<std::shared_ptr<TextureImp<Device::Metal>>> textures;
std::vector<std::pair<std::shared_ptr<BufferImp<Device::Metal>>, Program::Target>> buffers;

static Result checkShaderCompilation(U64);
static Result checkProgramCompilation(U64);

friend class SurfaceImp<Device::Metal>;
};

Expand Down
5 changes: 0 additions & 5 deletions include/jetstream/render/vulkan/buffer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,9 @@ class BufferImp<Device::Vulkan> : public Buffer {
return buffer;
}

constexpr const VkDescriptorType& getDescriptorType() const {
return descriptorType;
}

private:
VkBuffer buffer;
VkDeviceMemory memory;
VkDescriptorType descriptorType;

friend class SurfaceImp<Device::Vulkan>;
friend class ProgramImp<Device::Vulkan>;
Expand Down
3 changes: 1 addition & 2 deletions include/jetstream/render/vulkan/kernel.hh
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ class KernelImp<Device::Vulkan> : public Kernel {

std::vector<std::shared_ptr<BufferImp<Device::Vulkan>>> buffers;

static Result checkShaderCompilation(U64);
static Result checkProgramCompilation(U64);
static VkDescriptorType BufferDescriptorType(const std::shared_ptr<Buffer>& buffer);

friend class SurfaceImp<Device::Vulkan>;
};
Expand Down
3 changes: 2 additions & 1 deletion include/jetstream/render/vulkan/program.hh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class ProgramImp<Device::Vulkan> : public Program {
std::vector<std::shared_ptr<TextureImp<Device::Vulkan>>> textures;
std::vector<std::pair<std::shared_ptr<BufferImp<Device::Vulkan>>, Program::Target>> buffers;

static VkShaderStageFlags TargetToVulkan(const Program::Target& target);
static VkShaderStageFlags TargetToShaderStage(const Program::Target& target);
static VkDescriptorType BufferDescriptorType(const std::shared_ptr<Buffer>& buffer);

friend class SurfaceImp<Device::Vulkan>;
};
Expand Down
2 changes: 0 additions & 2 deletions include/jetstream/render/webgpu/program.hh
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ class ProgramImp<Device::WebGPU> : public Program {
std::vector<std::shared_ptr<TextureImp<Device::WebGPU>>> textures;
std::vector<std::pair<std::shared_ptr<BufferImp<Device::WebGPU>>, Program::Target>> buffers;

static Result checkShaderCompilation(U64);
static Result checkProgramCompilation(U64);
static wgpu::ShaderStage TargetToWebGPU(const Program::Target& target);

friend class SurfaceImp<Device::WebGPU>;
Expand Down
3 changes: 2 additions & 1 deletion src/memory/vulkan/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ Implementation::TensorBuffer(std::shared_ptr<TensorStorageMetadata>&,
// TODO: Add a global way to specify usage.
bufferInfo.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
VK_BUFFER_USAGE_TRANSFER_DST_BIT |
VK_BUFFER_USAGE_VERTEX_BUFFER_BIT;
VK_BUFFER_USAGE_VERTEX_BUFFER_BIT |
VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
bufferInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
bufferInfo.pNext = &extImageCreateInfo;

Expand Down
8 changes: 4 additions & 4 deletions src/modules/lineplot/generic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ Result Lineplot<D, T>::createPresent() {
cfg.buffer = gridPoints.data();
cfg.elementByteSize = sizeof(F32);
cfg.size = gridPoints.size();
cfg.target = Render::Buffer::Target::VERTEX;
cfg.target = Render::Buffer::Target::STORAGE;
cfg.enableZeroCopy = false;
JST_CHECK(window->build(gridPointsBuffer, cfg));
}
Expand All @@ -108,7 +108,7 @@ Result Lineplot<D, T>::createPresent() {
cfg.buffer = buffer;
cfg.elementByteSize = sizeof(F32);
cfg.size = gridVertices.size();
cfg.target = Render::Buffer::Target::VERTEX;
cfg.target = Render::Buffer::Target::VERTEX | Render::Buffer::Target::STORAGE;
cfg.enableZeroCopy = enableZeroCopy;
JST_CHECK(window->build(gridVerticesBuffer, cfg));
}
Expand Down Expand Up @@ -165,7 +165,7 @@ Result Lineplot<D, T>::createPresent() {
cfg.buffer = buffer;
cfg.elementByteSize = sizeof(F32);
cfg.size = signalPoints.size();
cfg.target = Render::Buffer::Target::VERTEX;
cfg.target = Render::Buffer::Target::STORAGE;
cfg.enableZeroCopy = enableZeroCopy;
JST_CHECK(window->build(signalPointsBuffer, cfg));
}
Expand All @@ -177,7 +177,7 @@ Result Lineplot<D, T>::createPresent() {
cfg.buffer = buffer;
cfg.elementByteSize = sizeof(F32);
cfg.size = signalVertices.size();
cfg.target = Render::Buffer::Target::VERTEX;
cfg.target = Render::Buffer::Target::VERTEX | Render::Buffer::Target::STORAGE;
cfg.enableZeroCopy = enableZeroCopy;
JST_CHECK(window->build(signalVerticesBuffer, cfg));
}
Expand Down
55 changes: 17 additions & 38 deletions src/render/vulkan/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,6 @@ Result Implementation::create() {
auto& device = Backend::State<Device::Vulkan>()->getDevice();
auto& physicalDevice = Backend::State<Device::Vulkan>()->getPhysicalDevice();

switch (config.target) {
case Target::VERTEX:
break;
case Target::VERTEX_INDICES:
break;
case Target::STORAGE:
descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
break;
case Target::UNIFORM:
descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
break;
case Target::STORAGE_DYNAMIC:
descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC;
break;
case Target::UNIFORM_DYNAMIC:
descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC;
break;
}

if (config.enableZeroCopy) {
buffer = reinterpret_cast<VkBuffer>(config.buffer);
} else {
Expand All @@ -41,25 +22,23 @@ Result Implementation::create() {
VkBufferUsageFlags bufferUsageFlag = 0;
bufferUsageFlag |= VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
bufferUsageFlag |= VK_BUFFER_USAGE_TRANSFER_DST_BIT;
switch (config.target) {
case Target::VERTEX:
bufferUsageFlag |= VK_BUFFER_USAGE_VERTEX_BUFFER_BIT;
break;
case Target::VERTEX_INDICES:
bufferUsageFlag |= VK_BUFFER_USAGE_INDEX_BUFFER_BIT;
break;
case Target::STORAGE:
bufferUsageFlag |= VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
break;
case Target::UNIFORM:
bufferUsageFlag |= VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
break;
case Target::STORAGE_DYNAMIC:
bufferUsageFlag |= VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
break;
case Target::UNIFORM_DYNAMIC:
bufferUsageFlag |= VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
break;

if ((config.target & Target::VERTEX) == Target::VERTEX) {
bufferUsageFlag |= VK_BUFFER_USAGE_VERTEX_BUFFER_BIT;
}

if ((config.target & Target::VERTEX_INDICES) == Target::VERTEX_INDICES) {
bufferUsageFlag |= VK_BUFFER_USAGE_INDEX_BUFFER_BIT;
}

if ((config.target & Target::STORAGE) == Target::STORAGE ||
(config.target & Target::STORAGE_DYNAMIC) == Target::STORAGE_DYNAMIC) {
bufferUsageFlag |= VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
}

if ((config.target & Target::UNIFORM) == Target::UNIFORM ||
(config.target & Target::UNIFORM_DYNAMIC) == Target::UNIFORM_DYNAMIC) {
bufferUsageFlag |= VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
}

// Create buffer.
Expand Down
21 changes: 17 additions & 4 deletions src/render/vulkan/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ Result Implementation::create() {

VkDescriptorSetLayoutBinding binding{};
binding.binding = i;
// TODO: Why buffer->getDescriptorType() doesn't work?
binding.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
binding.descriptorType = BufferDescriptorType(buffer);
binding.descriptorCount = 1;
binding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;

Expand Down Expand Up @@ -89,8 +88,7 @@ Result Implementation::create() {
descriptorWriteBuffer.dstSet = descriptorSet;
descriptorWriteBuffer.dstBinding = i;
descriptorWriteBuffer.dstArrayElement = 0;
// TODO: Why buffer->getDescriptorType() doesn't work?
descriptorWriteBuffer.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
descriptorWriteBuffer.descriptorType = BufferDescriptorType(buffer);
descriptorWriteBuffer.descriptorCount = 1;
descriptorWriteBuffer.pBufferInfo = &bufferInfo;

Expand Down Expand Up @@ -169,4 +167,19 @@ Result Implementation::encode(VkCommandBuffer& commandBuffer) {
return Result::SUCCESS;
}

VkDescriptorType Implementation::BufferDescriptorType(const std::shared_ptr<Buffer>& buffer) {
const auto& bufferType = buffer->getConfig().target;

if ((bufferType & Buffer::Target::UNIFORM) == Buffer::Target::UNIFORM) {
return VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
}

if ((bufferType & Buffer::Target::STORAGE) == Buffer::Target::STORAGE) {
return VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
}

JST_ERROR("[VULKAN] Invalid buffer usage.");
throw Result::ERROR;
}

} // namespace Jetstream::Render
23 changes: 19 additions & 4 deletions src/render/vulkan/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ Result Implementation::create(VkRenderPass& renderPass,
auto& [buffer, target] = buffers[i];

VkDescriptorSetLayoutBinding binding{};
binding.descriptorType = buffer->getDescriptorType();
binding.descriptorType = BufferDescriptorType(buffer);
binding.descriptorCount = 1;
binding.stageFlags = TargetToVulkan(target);
binding.stageFlags = TargetToShaderStage(target);
binding.binding = bindingOffset++;

bindings.push_back(binding);
Expand Down Expand Up @@ -130,7 +130,7 @@ Result Implementation::create(VkRenderPass& renderPass,
descriptorWriteBuffer.dstSet = descriptorSet;
descriptorWriteBuffer.dstBinding = bindingOffset++;
descriptorWriteBuffer.dstArrayElement = 0;
descriptorWriteBuffer.descriptorType = buffer->getDescriptorType();
descriptorWriteBuffer.descriptorType = BufferDescriptorType(buffer);
descriptorWriteBuffer.descriptorCount = 1;
descriptorWriteBuffer.pBufferInfo = &bufferInfo;

Expand Down Expand Up @@ -341,7 +341,7 @@ Result Implementation::encode(VkCommandBuffer& commandBuffer, VkRenderPass&) {
return Result::SUCCESS;
}

VkShaderStageFlags Implementation::TargetToVulkan(const Program::Target& target) {
VkShaderStageFlags Implementation::TargetToShaderStage(const Program::Target& target) {
VkShaderStageFlags flags = 0;

if (static_cast<U8>(target & Program::Target::VERTEX) > 0) {
Expand All @@ -355,4 +355,19 @@ VkShaderStageFlags Implementation::TargetToVulkan(const Program::Target& target)
return flags;
}

VkDescriptorType Implementation::BufferDescriptorType(const std::shared_ptr<Buffer>& buffer) {
const auto& bufferType = buffer->getConfig().target;

if ((bufferType & Buffer::Target::UNIFORM) == Buffer::Target::UNIFORM) {
return VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
}

if ((bufferType & Buffer::Target::STORAGE) == Buffer::Target::STORAGE) {
return VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
}

JST_ERROR("[VULKAN] Invalid buffer usage.");
throw Result::ERROR;
}

} // namespace Jetstream::Render

0 comments on commit c9237a8

Please sign in to comment.