Skip to content

Commit

Permalink
feat: implement wgsl preprocessor
Browse files Browse the repository at this point in the history
  • Loading branch information
philpax committed Oct 12, 2022
1 parent 1d9f390 commit d1718d9
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 14 deletions.
7 changes: 7 additions & 0 deletions shaders/blit.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#include blit_common.wgsl

var<push_constant> view_index: u32;
@fragment
fn blit_fs_main(in: BlitVertexOutput) -> @location(0) vec4<f32> {
return textureSample(blit_texture, blit_sampler, in.uv_coords, i32(view_index));
}
9 changes: 1 addition & 8 deletions src/blit.wgsl → shaders/blit_common.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,13 @@ struct BlitVertexOutput {

@group(0) @binding(0)
var blit_texture: texture_2d_array<f32>;
@group(0)@binding(1)
@group(0) @binding(1)
var blit_sampler: sampler;

var<push_constant> view_index: u32;

@vertex
fn blit_vs_main(model: BlitVertexInput) -> BlitVertexOutput {
var out: BlitVertexOutput;
out.position = vec4<f32>(model.position, 1.0);
out.uv_coords = model.uv_coords;
return out;
}

@fragment
fn blit_fs_main(in: BlitVertexOutput) -> @location(0) vec4<f32> {
return textureSample(blit_texture, blit_sampler, in.uv_coords, i32(view_index));
}
File renamed without changes.
7 changes: 5 additions & 2 deletions src/blit_state.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use glam::{vec3, Vec3};
use std::borrow::Cow;
use std::{borrow::Cow, path::Path};
use wgpu::util::DeviceExt;

pub struct BlitState {
Expand All @@ -14,6 +14,7 @@ pub struct BlitState {
impl BlitState {
pub fn new(
device: &wgpu::Device,
preprocessor: &crate::wgsl::Preprocessor,
render_target_view: &wgpu::TextureView,
swapchain_format: wgpu::TextureFormat,
) -> Self {
Expand Down Expand Up @@ -92,7 +93,9 @@ impl BlitState {
};
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("blit.wgsl"))),
source: wgpu::ShaderSource::Wgsl(Cow::Owned(
preprocessor.preprocess("blit.wgsl").unwrap(),
)),
});
let render_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
label: None,
Expand Down
20 changes: 18 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::path::Path;

use anyhow::Context;
use glam::{vec3, vec4, Quat};
use wgpu::util::DeviceExt;
Expand All @@ -15,6 +17,8 @@ mod main_state;
mod texture;
mod types;

pub mod wgsl;

use blit_state::BlitState;
use camera::CameraState;
use main_state::MainState;
Expand Down Expand Up @@ -75,8 +79,15 @@ fn main() -> anyhow::Result<()> {

let mut camera_state = CameraState::new(&wgpu_state.device, window.inner_size());

let preprocessor = wgsl::Preprocessor::from_directory(Path::new("shaders"))?;

let swapchain_format = surface.get_supported_formats(&wgpu_state.adapter)[0];
let mut main_state = MainState::new(&wgpu_state.device, &camera_state, swapchain_format);
let mut main_state = MainState::new(
&wgpu_state.device,
&preprocessor,
&camera_state,
swapchain_format,
);

let mut config = {
let size = window.inner_size();
Expand All @@ -91,7 +102,12 @@ fn main() -> anyhow::Result<()> {
surface.configure(&wgpu_state.device, &config);
let mut depth_texture = Texture::new_depth_texture(&wgpu_state.device, &config);
let mut rt_texture = Texture::new_rt_texture(&wgpu_state.device, &config, swapchain_format);
let mut blit_state = BlitState::new(&wgpu_state.device, rt_texture.view(), swapchain_format);
let mut blit_state = BlitState::new(
&wgpu_state.device,
&preprocessor,
rt_texture.view(),
swapchain_format,
);

let triangle_vertex_buffer =
wgpu_state
Expand Down
7 changes: 5 additions & 2 deletions src/main_state.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use glam::{vec3, Mat4, Quat, Vec3};
use std::{borrow::Cow, num::NonZeroU32};
use std::{borrow::Cow, num::NonZeroU32, path::Path};
use wgpu::util::DeviceExt;

use crate::{
Expand All @@ -19,6 +19,7 @@ pub struct MainState {
impl MainState {
pub fn new(
device: &wgpu::Device,
preprocessor: &crate::wgsl::Preprocessor,
camera_state: &CameraState,
swapchain_format: wgpu::TextureFormat,
) -> Self {
Expand Down Expand Up @@ -47,7 +48,9 @@ impl MainState {

let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("main.wgsl"))),
source: wgpu::ShaderSource::Wgsl(Cow::Owned(
preprocessor.preprocess("main.wgsl").unwrap(),
)),
});
let vertex_buffer_layout = wgpu::VertexBufferLayout {
array_stride: std::mem::size_of::<Vertex>() as _,
Expand Down
88 changes: 88 additions & 0 deletions src/wgsl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use std::{
collections::HashMap,
path::{Path, PathBuf},
};

use anyhow::Context;

/// This is _not_ a robust preprocessor. It's the bare minimum to make this example work.
/// This *will* fall down at the first hurdle.
pub fn preprocess(files: &HashMap<PathBuf, String>, current_file: &str) -> anyhow::Result<String> {
let mut current_file = current_file.to_string();
while current_file.contains("#include") {
current_file = current_file
.lines()
.map(|l| match l.strip_prefix("#include ") {
Some(filename) => Ok(files
.get(&PathBuf::from(filename))
.context("failed to find file")?
.as_ref()),
None => Ok(l),
})
.collect::<anyhow::Result<Vec<&str>>>()?
.join("\n");
}
Ok(current_file)
}

pub struct Preprocessor {
files: HashMap<PathBuf, String>,
}
impl Preprocessor {
pub fn from_directory(path: &Path) -> std::io::Result<Self> {
Ok(Self {
files: std::fs::read_dir(path)?
.filter_map(Result::ok)
.map(|de| de.path())
.filter(|p| p.extension().unwrap_or_default() == "wgsl")
.map(|p| {
Ok((
PathBuf::from(
p.file_name()
.ok_or(std::io::Error::from(std::io::ErrorKind::NotFound))?,
),
std::fs::read_to_string(&p)?,
))
})
.collect::<std::io::Result<_>>()?,
})
}

pub fn preprocess(&self, filename: impl AsRef<Path>) -> anyhow::Result<String> {
preprocess(
&self.files,
self.files
.get(filename.as_ref())
.context("File not present!")?
.as_str(),
)
}
}

#[cfg(test)]
mod tests {
use std::path::PathBuf;

use super::preprocess;

#[test]
fn preprocess_can_include() {
let main_file = "#include blah.wgsl\n// and good night!";
let files = [
(PathBuf::from("foo.wgsl"), "// first file!".to_string()),
(
PathBuf::from("blah.wgsl"),
"#include foo.wgsl\n// hello world!".to_string(),
),
(PathBuf::from("main.wgsl"), main_file.to_string()),
]
.into_iter()
.collect();

let expected_output = r#"// first file!
// hello world!
// and good night!"#;

assert_eq!(preprocess(&files, main_file).unwrap(), expected_output);
}
}

0 comments on commit d1718d9

Please sign in to comment.