Fixed: batch memory corruption due to overrunning temp_state buffer. (turboderp#128)

Without this fix, column_remap_cuda would try to write its remapped output into a temp buffer sized for a batch size of 1. This worked as long as `bsz * seq_len` stayed below 2048; any larger and the write would overrun the buffer and corrupt memory.
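For illustration, a minimal sketch of the sizing invariant that was violated, assuming a `max_input_len` of 2048 and an example hidden dimension; `remap_fits` and `dim` are illustrative stand-ins, while `temp_state_size` mirrors the new field this commit threads through `prepare_buffers_cuda`:

```python
# Hypothetical sketch of the violated invariant, not the real allocator.
max_input_len = 2048  # assumed default; matches the 2048 threshold above
dim = 4096            # example hidden dimension

# The buffer's capacity in elements, sized for a batch size of 1.
temp_state_size = max_input_len * dim

def remap_fits(bsz: int, seq_len: int) -> bool:
    # column_remap_cuda writes (bsz * seq_len) rows of `dim` halfs,
    # so this mirrors the commit's new TORCH_CHECK:
    # temp_state_size >= x_height * dim, with x_height == bsz * seq_len.
    return bsz * seq_len * dim <= temp_state_size

print(remap_fits(1, 2048))  # True: exactly fills the buffer
print(remap_fits(4, 1024))  # False: 4096 rows overrun a 2048-row buffer
```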

The fix is to use a smaller chunk size that fits in the buffer when the batch is larger. This should not impact performance much, since we still work on roughly the same number of elements in parallel, apart from a small rounding loss when `max_input_len` isn't evenly divisible by the batch size. Worst case, the user can always just increase `max_input_len`.
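To make the fix concrete, here is a minimal sketch of the resulting chunking arithmetic; `chunk_sizes` is a hypothetical helper, not part of the repo, but the division and `min` mirror the model.py change in the diff below:

```python
def chunk_sizes(q_len: int, bsz: int, max_input_len: int = 2048):
    """Yield chunk sizes such that bsz * chunk_size never exceeds
    max_input_len, the token capacity temp_state was sized for."""
    effective_max = max_input_len // bsz  # shrink work size with batch size
    assert effective_max > 0, "batch too large for max_input_len"
    remaining = q_len
    while remaining > 0:
        chunk = min(remaining, effective_max)
        yield chunk
        remaining -= chunk

# bsz=4: chunks of 512 tokens keep 4 * 512 == 2048 rows inside the buffer.
print(list(chunk_sizes(q_len=1300, bsz=4)))  # [512, 512, 276]
```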
aljungberg authored Jul 3, 2023
1 parent 693ef40 commit 63ebf2e
Showing 6 changed files with 18 additions and 1 deletion.
4 changes: 4 additions & 0 deletions exllama_ext/cuda_buffers.cu
@@ -10,13 +10,15 @@ CudaBuffers::CudaBuffers
(
int _device,
half* _temp_state,
+ int _temp_state_size,
half* _temp_mlp,
float* _temp_zeros_float,
half* _temp_dq,
int _max_zeros_float
) :
device(_device),
temp_state(_temp_state),
+ temp_state_size(_temp_state_size),
temp_mlp(_temp_mlp),
temp_zeros_float(_temp_zeros_float),
temp_dq(_temp_dq),
@@ -65,6 +67,7 @@ void prepare_buffers_cuda
(
int _device,
half* _temp_state,
+ int _temp_state_size,
half* _temp_mlp,
float* _temp_zeros_float,
half* _temp_dq,
@@ -75,6 +78,7 @@ void prepare_buffers_cuda
(
_device,
_temp_state,
+ _temp_state_size,
_temp_mlp,
_temp_zeros_float,
_temp_dq,
3 changes: 3 additions & 0 deletions exllama_ext/cuda_buffers.cuh
@@ -18,6 +18,7 @@ public:
int device;

half* temp_state; // [max_hidden_rows * intermediate_size]
+ int temp_state_size;
half* temp_mlp; // [hidden_dim * intermediate_size]
float* temp_zeros_float; // [max_hidden_rows]
half* temp_dq; // size of largest quant tensor * 8
@@ -36,6 +37,7 @@ public:
(
int _device,
half* _temp_state,
+ int _temp_state_size,
half* _temp_mlp,
float* _temp_zeros_float,
half* _temp_dq,
@@ -52,6 +54,7 @@ void prepare_buffers_cuda
(
int _device,
half* _temp_state,
+ int _temp_state_size,
half* _temp_mlp,
float* _temp_zeros_float,
half* _temp_dq,
1 change: 1 addition & 0 deletions exllama_ext/cuda_func/q4_matmul.cu
@@ -237,6 +237,7 @@ void q4_matmul_recons_cuda
const half* x_mapped = x;
if (w->cuda_x_map)
{
+ TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
}
1 change: 1 addition & 0 deletions exllama_ext/cuda_func/q4_mlp.cu
@@ -127,6 +127,7 @@ void q4_mlp_cuda
// temp_x = rms_layernorm(x)

half* temp_x = buffers->temp_state + height * dim; // TODO: ..
+ TORCH_CHECK(buffers->temp_state_size >= 2 * height * dim, "temp_state buffer too small");
rms_norm_cuda(tuningParams, x, rms_norm_weight, temp_x, epsilon, height, dim, device_index);

// temp_mlp[0] = temp_x @ gate
5 changes: 5 additions & 0 deletions exllama_ext/exllama_ext.cpp
@@ -55,6 +55,7 @@ void check_cuda(cudaError_t ret)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
+ #define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")

#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
@@ -141,6 +142,8 @@ void prepare_buffers
(
device_index,
(half*) temp_state.data_ptr(),
+ // buffer size used for sanity checks
+ temp_state.numel(),
(half*) temp_mlp.data_ptr(),
(float*) temp_zeros_float.data_ptr(),
(half*) temp_dq.data_ptr(),
@@ -337,6 +340,8 @@ void column_remap
int height = x.size(0);
int width = x.size(1);

+ TORCH_CHECK_BUFFER_SIZE(x_new, height * width);

const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

column_remap_cuda
5 changes: 4 additions & 1 deletion model.py
@@ -815,6 +815,9 @@ def forward(self,

q_len = input_ids.shape[-1]
remaining_q_len = q_len
+ bsz = input_ids.shape[0]
+ # The buffers can only fit max_input_len tokens, so with larger batch sizes we reduce our work size correspondingly.
+ effective_max_input_len = self.config.max_input_len // bsz

# Split forward pass

@@ -825,7 +828,7 @@ def forward(self,

# Limit chunk_size to max_input_len

- chunk_size = min(remaining_q_len, self.config.max_input_len)
+ chunk_size = min(remaining_q_len, effective_max_input_len)

# Limit chunk_size to keep size of attention operation <= max_attention_size

