
Commit

Add the boundary checker twice to check the indices before and after pruning (pytorch#1526)

Summary: Pull Request resolved: pytorch#1526

Reviewed By: jspark1105

Differential Revision: D42243361

fbshipit-source-id: 0254feecc376da4968b5f5776be3b0c2a4d42f7d
jianyuh authored and facebook-github-bot committed Dec 28, 2022
1 parent 0478c2e commit 96f1682
Showing 1 changed file with 49 additions and 6 deletions.
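Conceptually, the commit makes forward() validate indices twice whenever any pruning remapping is present: once against the pre-pruning row counts and, after remapping, again against the post-pruning row counts. Below is a minimal sketch of the resulting control flow; the function and parameter names are stand-ins for the module state shown in the diff, and torch.ops.fbgemm.bounds_check_indices requires fbgemm_gpu to be installed:

import torch
from typing import Callable, Optional

def forward_bounds_checks(
    pruning_enabled: bool,
    bounds_check_mode: int,    # a BoundsCheckMode value
    bounds_check_none: int,    # BoundsCheckMode.NONE.value
    original_rows_per_table: torch.Tensor,
    rows_per_table: torch.Tensor,
    indices: torch.Tensor,
    offsets: torch.Tensor,
    bounds_check_warning: torch.Tensor,
    per_sample_weights: Optional[torch.Tensor],
    remap: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    if pruning_enabled and bounds_check_mode != bounds_check_none:
        # First check: indices must address rows of the original, unpruned tables.
        torch.ops.fbgemm.bounds_check_indices(
            original_rows_per_table, indices, offsets,
            bounds_check_mode, bounds_check_warning, per_sample_weights,
        )
    indices = remap(indices, offsets)  # pruned rows become -1
    if bounds_check_mode != bounds_check_none:
        # Second check: remapped indices must address rows of the pruned tables.
        torch.ops.fbgemm.bounds_check_indices(
            rows_per_table, indices, offsets,
            bounds_check_mode, bounds_check_warning, per_sample_weights,
        )
    return indices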
fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py (49 additions & 6 deletions)
@@ -1852,6 +1852,10 @@ def max_ty_D(ty: SparseType) -> int:
"index_remapping_hash_table",
torch.empty(0, device=self.current_device, dtype=torch.int32),
)
self.register_buffer(
"original_rows_per_table",
torch.empty(0, device=self.current_device, dtype=torch.int64),
)
# pyre-fixme[4]: Attribute must be annotated.
self.index_remapping_hash_table_cpu = None
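The new buffer follows the usual nn.Module pattern: register an empty placeholder at construction time, then overwrite it once the pruning metadata is known (set_index_remappings below does exactly that). A self-contained sketch of the pattern, with Demo as a made-up module:

import torch

class Demo(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Placeholder; assigning a tensor to this attribute later replaces
        # the registered buffer (and what state_dict() sees).
        self.register_buffer(
            "original_rows_per_table", torch.empty(0, dtype=torch.int64)
        )

m = Demo()
m.original_rows_per_table = torch.tensor([100, 50], dtype=torch.int64)
print(m.state_dict()["original_rows_per_table"])  # tensor([100, 50])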

@@ -2212,8 +2216,28 @@ def forward(
        assert (
            self.weight_initialized
        ), "weight needs to be initialized before forward function"

        # First bound check: check if the indices/offsets are within the boundary
        # of the original embedding rows before pruning.
        # Note that this is only applied when we enable pruning (if the perf becomes
        # an issue, we can fuse it inside the remapping kernel).
        if (
            self.index_remapping_hash_table_cpu is not None
            or self.index_remapping_hash_table.numel() > 0
            or self.index_remappings_array.numel() > 0
        ):
            if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
                torch.ops.fbgemm.bounds_check_indices(
                    self.original_rows_per_table,
                    indices,
                    offsets,
                    self.bounds_check_mode_int,
                    self.bounds_check_warning,
                    per_sample_weights,
                )

        # Index remapping changes input indices, and some of them become -1 (pruned rows).
        # Hence, remapping should be done before prefetch and emb lookup
        # so that these operations are with the remapped indices.
        if self.index_remapping_hash_table_cpu is not None:
            indices = self.index_remapping_hash_table_cpu.lookup(indices, offsets)
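For intuition, here is a pure-Python reference of what a bounds check over (indices, offsets) verifies. The fused torch.ops.fbgemm.bounds_check_indices kernel is the real implementation; the feature-major (T x B) offsets layout and the assertion-based error handling below are simplifying assumptions:

import torch

def reference_bounds_check(
    rows_per_table: torch.Tensor,  # (T,) rows per feature
    indices: torch.Tensor,         # flattened bag indices
    offsets: torch.Tensor,         # (T * B + 1,) bag boundaries
) -> None:
    T = rows_per_table.numel()
    B = (offsets.numel() - 1) // T
    assert int(offsets[0]) == 0 and int(offsets[-1]) == indices.numel()
    for t in range(T):
        rows = int(rows_per_table[t])
        start, end = int(offsets[t * B]), int(offsets[(t + 1) * B])
        bag = indices[start:end]
        # Every index must address an existing row of table t.
        assert bool(((bag >= 0) & (bag < rows)).all()), (
            f"index out of bounds for table {t} (rows={rows})"
        )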
@@ -2238,7 +2262,9 @@ def forward(

        lxu_cache_locations = self.lxu_cache_locations_list.pop()

        # Second bound check: check if the indices/offsets are within the boundary
        # of the pruned embedding rows after pruning.
        # Note: we cast to int as a TorchScript workaround.
        if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
            torch.ops.fbgemm.bounds_check_indices(
                self.rows_per_table,
@@ -2668,6 +2694,7 @@ def set_index_remappings(
    ) -> None:
        rows: List[int] = [e[1] for e in self.embedding_specs]
        T = len(self.embedding_specs)
        # Hash mapping pruning
        if not use_array_for_index_remapping:
            capacities = [
                round_up(int(row * 1.0 / pruning_hash_load_factor), 32)
@@ -2683,12 +2710,17 @@
            hash_table_offsets = torch.tensor([0] + list(accumulate(capacities))).long()

            merged_index_remappings = [
                mapping if mapping is not None else Tensor(list(range(row)))
                for (mapping, row) in zip(index_remapping, rows)
            ]
            original_feature_rows = [
                mapping.numel() for mapping in merged_index_remappings
            ]
            self.original_rows_per_table = torch.tensor(
                [original_feature_rows[t] for t in self.feature_table_map],
                device=self.current_device,
                dtype=torch.int64,
            )
            dense_indices = torch.cat(merged_index_remappings, dim=0).int()
            indices = torch.cat(
                [torch.arange(row) for row in original_feature_rows], dim=0
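The hash-table sizing above applies the load factor to each table's row count, rounds up to a multiple of 32, and prefix-sums the capacities into per-table offsets. A toy run, assuming round_up(n, m) rounds n up to the nearest multiple of m as in fbgemm_gpu's helpers:

from itertools import accumulate

def round_up(n: int, m: int) -> int:
    return ((n + m - 1) // m) * m

rows = [1000, 250]  # example pre-pruning rows per table
pruning_hash_load_factor = 0.5
capacities = [round_up(int(r * 1.0 / pruning_hash_load_factor), 32) for r in rows]
hash_table_offsets = [0] + list(accumulate(capacities))
print(capacities)          # [2016, 512]
print(hash_table_offsets)  # [0, 2016, 2528]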
@@ -2716,19 +2748,30 @@
                self.current_device
            )
            self.index_remapping_hash_table_cpu = None
        # Array mapping pruning
        else:
            index_remappings_array_offsets = [0]
            original_feature_rows = []
            last_offset = 0
            for t, mapping in enumerate(index_remapping):
                if mapping is not None:
                    current_original_row = mapping.numel()
                    last_offset += current_original_row
                    original_feature_rows.append(current_original_row)
                else:
                    original_feature_rows.append(rows[t])
                index_remappings_array_offsets.append(last_offset)

            self.index_remappings_array_offsets = torch.tensor(
                index_remappings_array_offsets,
                device=self.current_device,
                dtype=torch.int64,
            )
            self.original_rows_per_table = torch.tensor(
                [original_feature_rows[t] for t in self.feature_table_map],
                device=self.current_device,
                dtype=torch.int64,
            )
            self.index_remappings_array = (
                torch.empty(0, dtype=torch.int32, device=self.current_device)
                if self.index_remappings_array_offsets[-1] == 0
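The array-mapping bookkeeping above is easy to trace by hand: each remapping tensor has one entry per original row (pruned rows carry -1), and None marks an unpruned table. A toy walk-through with made-up shapes, where the first table had three original rows with row 1 pruned:

import torch

rows = [100, 50]  # example pre-pruning rows per table
index_remapping = [torch.tensor([0, -1, 1], dtype=torch.int32), None]

index_remappings_array_offsets = [0]
original_feature_rows = []
last_offset = 0
for t, mapping in enumerate(index_remapping):
    if mapping is not None:
        current_original_row = mapping.numel()
        last_offset += current_original_row
        original_feature_rows.append(current_original_row)
    else:
        original_feature_rows.append(rows[t])
    index_remappings_array_offsets.append(last_offset)

print(index_remappings_array_offsets)  # [0, 3, 3]
print(original_feature_rows)           # [3, 50]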
