Skip to content

Commit

Permalink
implement lambertian material
Browse files Browse the repository at this point in the history
  • Loading branch information
MrMondrian committed Aug 25, 2024
1 parent 10da4e2 commit 5464fa2
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 24 deletions.
25 changes: 23 additions & 2 deletions src/hittable.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
use glam::Vec3;

#[repr(C)]
#[derive(Copy, Clone, Debug,)]
#[derive(Copy, Clone, Debug)]
pub struct Hittable {
kind: u32,
_padding: [u32; 3], // Padding to align with the next field
sphere: Sphere,
material: Material,
}


impl Hittable {
pub fn new(kind: u32, sphere: Sphere) -> Self {
pub fn new(kind: u32, sphere: Sphere, material: Material) -> Self {
Self {
kind,
_padding: [0; 3],
sphere,
material,
}
}
}
Expand All @@ -40,3 +42,22 @@ impl Sphere {

unsafe impl bytemuck::Pod for Sphere {}
unsafe impl bytemuck::Zeroable for Sphere {}

#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct Material {
albedo: Vec3,
kind: u32,
}

impl Material {
pub fn new(albedo: Vec3, kind: u32) -> Self {
Self {
albedo,
kind,
}
}
}

unsafe impl bytemuck::Pod for Material {}
unsafe impl bytemuck::Zeroable for Material {}
6 changes: 4 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,12 @@ async fn exec(event_loop: EventLoop<()>, window: Window) {
});

let sphere1 = Sphere::new(Vec3::new(0.0, 0.0, -1.0), 0.5);
let material1 = Material::new(Vec3::new(0.8, 0.3, 0.3), 0);
let sphere2: Sphere = Sphere::new(Vec3::new(0.0, -100.5, -1.0), 100.0);
let material2 = Material::new(Vec3::new(0.8, 0.8, 0.0), 0);

let hittable1 = Hittable::new(0, sphere1);
let hittable2 = Hittable::new(0, sphere2);
let hittable1 = Hittable::new(0, sphere1, material1);
let hittable2 = Hittable::new(0, sphere2, material2);

let hittable_list = vec![hittable1, hittable2];
let hittable_list_buffer = device.create_buffer_init(
Expand Down
79 changes: 59 additions & 20 deletions src/shader.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
let ray_origin = camera.center;
let ray_direction = pixel_loc - ray_origin;
let ray = Ray(ray_origin, ray_direction);
for(var k = u32(0); k < 4; k = k + 1) {
for(var k = u32(0); k < 2; k = k + 1) {
seed = sample_vec3(seed.yzx);
color += ray_color(ray, seed);
}
seed += color.xyz;
}
}
return color / f32(camera.samples_per_pixel * camera.samples_per_pixel * 4);
return color / f32(camera.samples_per_pixel * camera.samples_per_pixel * 2);

}

Expand All @@ -77,17 +77,22 @@ fn sample_vec3(rng_seed: vec3<f32>) -> vec3<f32> {

fn ray_color(ray: Ray, seed: vec3<f32>) -> vec4<f32> {
var hits = 0u;
var attenuations = array<f32, 100>();
var attenuations = array<vec3<f32>, 100>();
var curr_ray = ray;
var mutable_seed = seed;
for(var depth = 0u; depth < camera.max_depth; depth = depth + 1u) {
let record = get_hit_record(curr_ray, 0.001, max_f32);
if record.hit {
let direction = random_vec3_on_hemisphere(record.normal, mutable_seed);
attenuations[depth] = 0.5;
curr_ray = Ray(record.p, direction);
hits++;
mutable_seed = sample_vec3(seed);
let hit_record = get_hit_record(curr_ray, 0.001, max_f32);
if hit_record.hit {
let scatter_record = scatter(hit_record.material, curr_ray, hit_record, mutable_seed);
if scatter_record.hit {
attenuations[hits] = scatter_record.attenuation;
curr_ray = scatter_record.scattered;
hits = hits + 1u;
mutable_seed = sample_vec3(mutable_seed);
}
else {
break;
}
}
else {
break;
Expand All @@ -104,7 +109,7 @@ fn ray_color(ray: Ray, seed: vec3<f32>) -> vec4<f32> {

fn get_hit_record(r: Ray, t_min: f32, t_max: f32) -> HitRecord {
var closest_so_far = max_f32;
var record = HitRecord(false, 0.0, vec3(0.0, 0.0, 0.0), vec3(0.0, 0.0, 0.0));
var record = null_hit_record();
for (var idx = 0u; idx < arrayLength(&hitabble_list); idx = idx + 1u) {
let sphere = hitabble_list[idx];
let temp_record = hit_object(sphere, r, t_min, closest_so_far);
Expand All @@ -118,24 +123,24 @@ fn get_hit_record(r: Ray, t_min: f32, t_max: f32) -> HitRecord {

fn hit_object(hitable: Hitable, r: Ray, t_min: f32, t_max: f32) -> HitRecord {
if hitable.kind == SPHERE {
return hit_sphere(hitable.sphere, r, t_min, t_max);
return hit_sphere(hitable, r, t_min, t_max);
}
return HitRecord(false, 0.0, vec3(0.0, 0.0, 0.0), vec3(0.0, 0.0, 0.0));
return null_hit_record();
}


fn at(ray: Ray, t: f32) -> vec3<f32> {
return ray.origin + t * ray.direction;
}

fn hit_sphere(sphere: Sphere, r: Ray, ray_tmin: f32, ray_tmax: f32) -> HitRecord {
let oc = sphere.center - r.origin;
fn hit_sphere(hitable: Hitable, r: Ray, ray_tmin: f32, ray_tmax: f32) -> HitRecord {
let oc = hitable.sphere.center - r.origin;
let a = dot(r.direction, r.direction);
let half_b = dot(oc,r.direction);
let c = dot(oc,oc) - sphere.radius * sphere.radius;
let c = dot(oc,oc) - hitable.sphere.radius * hitable.sphere.radius;
let discriminant = half_b*half_b - a*c;
if discriminant < 0.0 {
return HitRecord(false, 0.0, vec3(0.0, 0.0, 0.0), vec3(0.0, 0.0, 0.0));
return null_hit_record();
}

let sqrtd = sqrt(discriminant);
Expand All @@ -144,13 +149,13 @@ fn hit_sphere(sphere: Sphere, r: Ray, ray_tmin: f32, ray_tmax: f32) -> HitRecord
if root <= ray_tmin || ray_tmax <= root {
root = (half_b + sqrtd) / a;
if root <= ray_tmin || ray_tmax <= root {
return HitRecord(false, 0.0, vec3(0.0, 0.0, 0.0), vec3(0.0, 0.0, 0.0));
return null_hit_record();
}
}

let p = at(r,root);
let normal = normalize((p - sphere.center) / sphere.radius);
var record = HitRecord(true,root,p,normal);
let normal = normalize((p - hitable.sphere.center) / hitable.sphere.radius);
var record = HitRecord(true,root,p,normal, hitable.material);
record.normal = set_front_face(record, r);
return record;

Expand All @@ -164,6 +169,24 @@ fn set_front_face(rec: HitRecord, r: Ray) -> vec3<f32> {
return rec.normal;
}

fn scatter(material: Material, r: Ray, rec: HitRecord, seed: vec3<f32>) -> ScatterRecord {
if material.kind == LAMBERTIAN {
return scatter_lambertian(material, r, rec, seed);
}
return ScatterRecord(false, vec3(0.0, 0.0, 0.0), Ray(vec3(0.0, 0.0, 0.0), vec3(0.0, 0.0, 0.0)));
}

fn scatter_lambertian(material: Material, r: Ray, rec: HitRecord, seed: vec3<f32>) -> ScatterRecord {
let scatter_ray = random_vec3_on_hemisphere(rec.normal, seed);
let scattered = Ray(rec.p, scatter_ray);
let attenuation = material.albedo;
return ScatterRecord(true, attenuation, scattered);
}

fn null_hit_record() -> HitRecord {
return HitRecord(false, 0.0, vec3(0.0, 0.0, 0.0), vec3(0.0, 0.0, 0.0), Material(vec3(0.0, 0.0, 0.0), 0));
}

struct Ray {
origin: vec3<f32>,
direction: vec3<f32>,
Expand All @@ -175,6 +198,7 @@ const max_f32 = 3.40282347e+38;
struct Hitable {
kind: u32,
sphere: Sphere,
material: Material,
}

struct Sphere {
Expand All @@ -187,8 +211,23 @@ struct HitRecord {
t: f32,
p: vec3<f32>,
normal: vec3<f32>,
material: Material,
}

struct ScatterRecord {
hit: bool,
attenuation: vec3<f32>,
scattered: Ray,
}

struct Material {
albedo: vec3<f32>,
kind: u32,
}

const LAMBERTIAN = u32(0);
const METAL = u32(1);

fn random_vec3_on_hemisphere(normal: vec3<f32>, rng_seed: vec3<f32>) -> vec3<f32> {
let p = normal + sample_vec3(rng_seed);
var normed = normalize(p);
Expand Down

0 comments on commit 5464fa2

Please sign in to comment.