-
Notifications
You must be signed in to change notification settings - Fork 0
/
buffer.rs
144 lines (123 loc) · 4.52 KB
/
buffer.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
use std::{marker::PhantomData, mem, sync::Arc};
use bytemuck::Pod;
use wgpu::util::DeviceExt as _;
use crate::Context;
/// How kernel code is permitted to access a storage buffer.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum BufferAccess {
    /// Kernel code may only load values from the buffer.
    ///
    /// Maps to a `var<storage, read>` declaration in WGSL.
    ReadOnly,
    /// Kernel code may both load values from and store values into the buffer.
    ///
    /// Maps to a `var<storage, read_write>` declaration in WGSL.
    ReadWrite,
}
/// Buffer on the GPU that stores homogeneous data.
///
/// With multiple elements it acts as an `array<T>` in kernel code.
/// With a single element it can also act as something of type `T`.
///
/// The buffer keeps a shared (`Arc`) reference to the device it was created
/// on, so it can issue further device/queue operations (allocation, writes,
/// read-back) without the caller passing the context again.
#[derive(Debug)]
pub struct Buffer<T> {
    /// Shared handle to the crate's `Device` this buffer was allocated from
    /// (used for its `handle` to create resources and its `queue` to submit work).
    pub(crate) device: Arc<crate::Device>,
    /// The underlying wgpu buffer object holding the GPU-side bytes.
    pub(crate) handle: wgpu::Buffer,
    // Zero-sized marker tying the element type `T` to the buffer without
    // storing any `T` on the CPU. NOTE(review): `Vec<T>` rather than `T` is
    // presumably meant to express "owns a sequence of T" — confirm intent.
    _marker: PhantomData<Vec<T>>,
}
impl<T: Pod> Buffer<T> {
// cheesy workaround to be able to make a const bitflag
// see: https://github.com/bitflags/bitflags/issues/180
const USAGES: wgpu::BufferUsages = wgpu::BufferUsages::from_bits_truncate(
wgpu::BufferUsages::STORAGE.bits()
| wgpu::BufferUsages::COPY_DST.bits()
| wgpu::BufferUsages::COPY_SRC.bits(),
);
/// Allocate a buffer on the GPU with `capacity` **elements of T**.
///
/// # Panics
///
/// - if capacity exceeds the limit of `max_buffer_size` (with a default
/// value of **2^30 bytes** that can be configured in `ContextInfo`).
pub fn new(context: &Context, capacity: wgpu::BufferAddress) -> Self {
let buffer = context
.device
.handle
.create_buffer(&wgpu::BufferDescriptor {
label: Some("buffer"),
size: capacity * mem::size_of::<T>() as wgpu::BufferAddress,
usage: Self::USAGES,
mapped_at_creation: false,
});
Self {
device: Arc::clone(&context.device),
handle: buffer,
_marker: PhantomData,
}
}
/// Creates an empty buffer able to store the same ammount of data that `original` does.
pub fn empty_like(original: &Self) -> Self {
let buffer = original
.device
.handle
.create_buffer(&wgpu::BufferDescriptor {
label: Some("buffer"),
size: original.handle.size(),
usage: Self::USAGES,
mapped_at_creation: false,
});
Self {
device: Arc::clone(&original.device),
handle: buffer,
_marker: PhantomData,
}
}
/// Write to a buffer starting at `index`.
///
/// # Panics
///
/// - if `data` overruns the buffer from any index.
pub fn write(&self, data: &[T], index: wgpu::BufferAddress) {
let offset = index * mem::size_of::<T>() as u64;
self.device
.queue
.write_buffer(&self.handle, offset, bytemuck::cast_slice(data));
}
/// Allocates a buffer on the GPU and initializes it with data.
pub fn from_slice(context: &Context, data: &[T]) -> Self {
let buffer = context
.device
.handle
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("buffer"),
contents: bytemuck::cast_slice(data),
usage: Self::USAGES,
});
Self {
device: Arc::clone(&context.device),
handle: buffer,
_marker: PhantomData,
}
}
/// Reads the contents of the buffer into a Vec.
pub fn read_to_vec(&self) -> Vec<T> {
let dst_buffer = self.device.handle.create_buffer(&wgpu::BufferDescriptor {
label: Some("Destination copy buffer"),
size: self.handle.size(),
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let mut encoder =
self.device
.handle
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("Copy buffer command encoder"),
});
encoder.copy_buffer_to_buffer(&self.handle, 0, &dst_buffer, 0, dst_buffer.size());
self.device.queue.submit(std::iter::once(encoder.finish()));
let dst_slice = dst_buffer.slice(..);
dst_slice.map_async(wgpu::MapMode::Read, move |_| {});
self.device.handle.poll(wgpu::Maintain::Wait);
let data = dst_slice.get_mapped_range();
bytemuck::cast_slice(&data).to_vec()
}
}