Skip to content

Commit

Permalink
device guards
Browse files Browse the repository at this point in the history
  • Loading branch information
BlackSamorez committed Feb 18, 2024
1 parent 4b08a7e commit 8d81cdb
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions inference_lib/src/aqlm/inference_kernels/cuda_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <torch/all.h>
#include <torch/python.h>
#include <c10/cuda/CUDAGuard.h>

void code1x16_matvec_cuda(
const void* A,
Expand All @@ -25,6 +26,7 @@ void code1x16_matvec(
torch::Tensor& C,
const torch::Tensor& codebook
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0);
int prob_k = B.size(0);
code1x16_matvec_cuda(
Expand Down Expand Up @@ -81,6 +83,7 @@ void code2x8_matvec(
torch::Tensor& C,
const torch::Tensor& codebook
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0);
int prob_k = B.size(0);
code2x8_matvec_cuda(
Expand Down

0 comments on commit 8d81cdb

Please sign in to comment.