Follow up on throw errors directly on host code for CUDA bounds check op (pytorch#1075)

Summary:
Pull Request resolved: pytorch#1075

Follow up for D35768276 (pytorch@7be1fcb): throw errors directly on host code.

Reviewed By: yinghai

Differential Revision: D35905891

fbshipit-source-id: f97047ff9cb27f7f169dc0223fa0295cc14a8fe8
jianyuh authored and facebook-github-bot committed Apr 26, 2022
1 parent cb064a2 commit aa1eefd
Showing 3 changed files with 54 additions and 102 deletions.
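In caller-facing terms, the change makes the size checks unconditional: a mis-sized offsets tensor now raises a RuntimeError from host code under every BoundsCheckMode, instead of being silently sliced (IGNORE) or only logged (WARNING). Below is a minimal sketch of that behavior, assuming an fbgemm_gpu install that registers torch.ops.fbgemm.bounds_check_indices with the six-argument signature used in the updated test; the sizes, values, and the BoundsCheckMode import path are illustrative, not taken from the commit.

import torch
import fbgemm_gpu  # noqa: F401 -- importing registers the torch.ops.fbgemm.* operators
from fbgemm_gpu.split_table_batched_embeddings_ops import BoundsCheckMode  # assumed path

B, T, L = 2, 2, 3  # batch size, number of tables, pooling factor (made-up values)
rows_per_table = torch.tensor([10] * T, dtype=torch.int64)
indices = torch.randint(0, 10, (B * T * L,), dtype=torch.int64)
offsets = torch.arange(0, B * T * L + 1, L, dtype=torch.int64)  # B * T + 1 entries
warning = torch.zeros(1, dtype=torch.int64)

# One extra entry makes offsets.size(0) == B * T + 2, violating the new host-side check.
bad_offsets = torch.cat([offsets, offsets[-1:]])
try:
    torch.ops.fbgemm.bounds_check_indices(
        rows_per_table, indices, bad_offsets, BoundsCheckMode.WARNING, warning, None
    )
except RuntimeError as err:
    print("rejected on the host:", err)  # previously WARNING/IGNORE would patch offsets instead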
43 changes: 10 additions & 33 deletions fbgemm_gpu/codegen/embedding_bounds_check.cu
@@ -172,39 +172,16 @@ void bounds_check_indices_cuda(
   }
   int64_t num_indices = indices.size(0);

-  if (bounds_check_mode == BoundsCheckMode::FATAL) {
-    TORCH_CHECK(offsets.size(0) == B * T + 1);
-    if (weights.has_value()) {
-      TORCH_CHECK(weights.value().size(0) == num_indices);
-    }
-  } else if (bounds_check_mode == BoundsCheckMode::WARNING) {
-    if (offsets.size(0) != B * T + 1) {
-      printf(
-          "EmbeddingBoundsCheck: offsets size is incorrect for "
-          "total batch size B: %ld, total table num T: %ld, "
-          " offsets size: %ld. Setting offsets size to be B * T + 1.\n",
-          static_cast<int64_t>(B),
-          static_cast<int64_t>(T),
-          static_cast<int64_t>(offsets.size(0)));
-      offsets = offsets.slice(0, 0, B * T + 1);
-    }
-    if (weights.has_value()) {
-      if (weights.value().size(0) != num_indices) {
-        printf(
-            "The size of weights are not consistent with indices. "
-            "Changing the weights to the same size as indices with all element 1.");
-        weights = at::ones({num_indices}, weights.value().options());
-      }
-    }
-  } else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
-    if (offsets.size(0) != B * T + 1) {
-      offsets = offsets.slice(0, 0, B * T + 1);
-    }
-    if (weights.has_value()) {
-      if (weights.value().size(0) != num_indices) {
-        weights = at::ones({num_indices}, weights.value().options());
-      }
-    }
+  TORCH_CHECK(
+      offsets.size(0) == B * T + 1,
+      "offsets size " + std::to_string(offsets.size(0)) +
+          " is not equal to B (" + std::to_string(B) + ") * T (" +
+          std::to_string(T) + ") + 1");
+  if (weights.has_value()) {
+    TORCH_CHECK(
+        weights.value().size(0) == num_indices,
+        "weights size " + std::to_string(weights.value().size(0)) +
+            " is not equal to indices size " + std::to_string(num_indices));
   }

   constexpr size_t kNumThreads = 256;
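The host-side TORCH_CHECK above runs before the kernel launch, so on the GPU path the failure now surfaces as an ordinary Python RuntimeError as well, even under BoundsCheckMode.IGNORE. A short sketch under the same illustrative setup and assumed import path as the example above, guarded so it only runs when CUDA is present:

import torch
import fbgemm_gpu  # noqa: F401
from fbgemm_gpu.split_table_batched_embeddings_ops import BoundsCheckMode  # assumed path

if torch.cuda.is_available():
    B, T, L = 2, 2, 3
    rows_per_table = torch.tensor([10] * T, dtype=torch.int64, device="cuda")
    indices = torch.randint(0, 10, (B * T * L,), dtype=torch.int64, device="cuda")
    offsets = torch.arange(0, B * T * L + 1, L, dtype=torch.int64, device="cuda")
    warning = torch.zeros(1, dtype=torch.int64, device="cuda")
    bad_offsets = torch.cat([offsets, offsets[-1:]])  # size B * T + 2
    try:
        torch.ops.fbgemm.bounds_check_indices(
            rows_per_table, indices, bad_offsets, BoundsCheckMode.IGNORE, warning, None
        )
    except RuntimeError as err:
        print("raised from host code before any kernel launch:", err)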
44 changes: 12 additions & 32 deletions fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp
@@ -53,23 +53,21 @@ void bounds_check_indices_cpu(
     auto indices_acc = indices.accessor<index_t, 1>();
     auto num_indices = indices.numel();

+    TORCH_CHECK(
+        offsets.size(0) == B * T + 1,
+        "offsets size " + std::to_string(offsets.size(0)) +
+            " is not equal to B (" + std::to_string(B) + ") * T (" +
+            std::to_string(T) + ") + 1");
+    if (weights.has_value()) {
+      TORCH_CHECK(
+          weights.value().size(0) == num_indices,
+          "weights size " + std::to_string(weights.value().size(0)) +
+              " is not equal to indices size " + std::to_string(num_indices));
+    }
+
     if (bounds_check_mode == BoundsCheckMode::FATAL) {
-      TORCH_CHECK(offsets.size(0) == B * T + 1);
       TORCH_CHECK(num_indices == offsets_acc[B * T]);
-      if (weights.has_value()) {
-        TORCH_CHECK(weights.value().size(0) == num_indices);
-      }
     } else if (bounds_check_mode == BoundsCheckMode::WARNING) {
-      if (offsets.size(0) != B * T + 1) {
-        if (__sync_fetch_and_add(&warning_acc[0], 1) == 0) {
-          LOG(ERROR) << "The offsets size is incorrect for "
-                     << "total batch size B: " << B
-                     << ", total table num T: " << T
-                     << ", offsets size: " << offsets.size(0)
-                     << ". Setting offsets size to be B * T + 1.";
-        }
-        offsets = offsets.slice(0, 0, B * T + 1);
-      }
       if (num_indices != offsets_acc[B * T]) {
         if (__sync_fetch_and_add(&warning_acc[0], 1) == 0) {
           LOG(ERROR)
@@ -81,28 +79,10 @@
         }
         offsets_acc[B * T] = num_indices;
       }
-      if (weights.has_value()) {
-        if (weights.value().size(0) != num_indices) {
-          if (__sync_fetch_and_add(&warning_acc[0], 1) == 0) {
-            LOG(ERROR)
-                << "The size of weights are not consistent with indices. "
-                << "Changing the weights to the same size as indices with all element 1.";
-          }
-          weights = at::ones({num_indices}, weights.value().options());
-        }
-      }
     } else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
-      if (offsets.size(0) != B * T + 1) {
-        offsets = offsets.slice(0, 0, B * T + 1);
-      }
       if (num_indices != offsets_acc[B * T]) {
         offsets_acc[B * T] = num_indices;
       }
-      if (weights.has_value()) {
-        if (weights.value().size(0) != num_indices) {
-          weights = at::ones(num_indices, weights.value().options());
-        }
-      }
     }

     for (auto t = 0; t < T; ++t) {
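The CPU path keeps its per-mode handling of the num_indices versus offsets[B * T] fix-up, but the offsets- and weights-size checks above it now always throw, mirroring the CUDA op. A sketch of the weights case, matching what the updated test asserts (same illustrative setup and assumed import path as above):

import torch
import fbgemm_gpu  # noqa: F401
from fbgemm_gpu.split_table_batched_embeddings_ops import BoundsCheckMode  # assumed path

B, T, L = 2, 2, 3
rows_per_table = torch.tensor([10] * T, dtype=torch.int64)
indices = torch.randint(0, 10, (B * T * L,), dtype=torch.int64)
offsets = torch.arange(0, B * T * L + 1, L, dtype=torch.int64)
warning = torch.zeros(1, dtype=torch.int64)
bad_weights = torch.rand(indices.numel() + 1)  # one element longer than indices

try:
    torch.ops.fbgemm.bounds_check_indices(
        rows_per_table, indices, offsets, BoundsCheckMode.WARNING, warning, bad_weights
    )
except RuntimeError as err:
    print("weights/indices size mismatch rejected:", err)  # was silently replaced with ones() before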
69 changes: 32 additions & 37 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -3890,6 +3890,7 @@ def test_bounds_check(
         if weighted:
             weights = weights.cuda()
         indices_copy = indices.clone()
+        offsets_copy = offsets.clone()
         torch.ops.fbgemm.bounds_check_indices(
             rows_per_table, indices, offsets, bounds_check_mode, warning, weights
         )
@@ -3920,6 +3921,7 @@

         # test offsets bound errors
         indices = indices_copy.clone()
+        offsets = offsets_copy.clone()
         if offsets.numel() > 0:
             offsets[0] = -100
         if offsets.numel() > 1:
@@ -3948,52 +3950,45 @@
                         weights,
                     )

-        # test offsets.size(0) ! = B * T + 1 case
-        indices = indices_copy.clone()
-        offsets = torch.cat(
-            (
-                offsets,
-                torch.tensor(
-                    [indices.numel()], dtype=offsets.dtype, device=offsets.device
+        # test offsets.size(0) ! = B * T + 1 case. Here we test with T >= 2 case.
+        # If T == 1, we will always get the even division.
+        if T >= 2:
+            indices = indices_copy.clone()
+            offsets = offsets_copy.clone()
+            offsets = torch.cat(
+                (
+                    offsets,
+                    torch.tensor(
+                        [indices.numel()] * (T - 1),
+                        dtype=offsets.dtype,
+                        device=offsets.device,
+                    ),
                 ),
-            ),
-            dim=0,
-        )
-        if bounds_check_mode != BoundsCheckMode.FATAL:
-            torch.ops.fbgemm.bounds_check_indices(
-                rows_per_table, indices, offsets, bounds_check_mode, warning, weights
+                dim=0,
             )
-        else:
-            if use_cpu:
-                with self.assertRaises(RuntimeError):
-                    torch.ops.fbgemm.bounds_check_indices(
-                        rows_per_table,
-                        indices,
-                        offsets,
-                        bounds_check_mode,
-                        warning,
-                        weights,
-                    )
+            with self.assertRaises(RuntimeError):
+                torch.ops.fbgemm.bounds_check_indices(
+                    rows_per_table,
+                    indices,
+                    offsets,
+                    bounds_check_mode,
+                    warning,
+                    weights,
+                )

         # test weights.size(0) != indices.size(0) case
         weights = torch.rand(
             (indices.size(0) + 1,), dtype=torch.float, device=indices.device
         )
-        if bounds_check_mode != BoundsCheckMode.FATAL:
+        with self.assertRaises(RuntimeError):
             torch.ops.fbgemm.bounds_check_indices(
-                rows_per_table, indices, offsets, bounds_check_mode, warning, weights
+                rows_per_table,
+                indices,
+                offsets,
+                bounds_check_mode,
+                warning,
+                weights,
             )
-        else:
-            if use_cpu:
-                with self.assertRaises(RuntimeError):
-                    torch.ops.fbgemm.bounds_check_indices(
-                        rows_per_table,
-                        indices,
-                        offsets,
-                        bounds_check_mode,
-                        warning,
-                        weights,
-                    )

     def test_pickle(self) -> None:
         tensor_queue = torch.classes.fbgemm.TensorQueue(torch.empty(0))
