Apply pyfmt formatting (pytorch#690)
Summary:
Pull Request resolved: pytorch#690

Just formatting.

Reviewed By: jspark1105

Differential Revision: D30770516

fbshipit-source-id: 857f58a536608d6999552dc1de5710766721b238
jianyuh authored and facebook-github-bot committed Sep 7, 2021
1 parent 7b49986 commit 3ce04fc
Showing 1 changed file with 93 additions and 42 deletions.
135 changes: 93 additions & 42 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
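All of the changes are mechanical reformatting by pyfmt (the Black-based formatter used internally at Meta): long call expressions are wrapped to one argument per line, trailing commas are added, and lines are kept within the length limit. A hedged, self-contained illustration of that pattern (the dict and values below are made up for illustration; the real changes follow in the diff):

# Before formatting: all arguments packed onto one long line.
config = dict(table_name="rows_per_table", device="cuda:0", dtype="int64", rows=[10, 20, 30], use_cpu=False)

# After formatting: one argument per line, trailing comma, within the line-length limit.
config = dict(
    table_name="rows_per_table",
    device="cuda:0",
    dtype="int64",
    rows=[10, 20, 30],
    use_cpu=False,
)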
@@ -64,11 +64,13 @@ class BoundsCheckMode(enum.IntEnum):
# No bounds checks.
NONE = 3


RecordCacheMetrics: NamedTuple = NamedTuple(
"RecordCacheMetrics",
[("record_cache_miss_counter", bool), ("record_tablewise_cache_miss", bool)]
[("record_cache_miss_counter", bool), ("record_tablewise_cache_miss", bool)],
)


@dataclass
class SplitState:
dev_size: int
@@ -290,14 +292,14 @@ def __init__( # noqa C901
self.register_buffer(
"rows_per_table",
torch.tensor(
[rows[t] for t in self.feature_table_map], device=self.current_device, dtype=torch.int64
)
[rows[t] for t in self.feature_table_map],
device=self.current_device,
dtype=torch.int64,
),
)
self.register_buffer(
"bounds_check_warning",
torch.tensor(
[0], device=self.current_device, dtype=torch.int64
)
torch.tensor([0], device=self.current_device, dtype=torch.int64),
)

weight_split = construct_split_state(
@@ -409,17 +411,25 @@ def __init__( # noqa C901
prefix="momentum2",
dtype=torch.float32,
)
self.register_buffer("iter", torch.zeros(1, dtype=torch.int64, device=self.current_device))
self.register_buffer(
"iter", torch.zeros(1, dtype=torch.int64, device=self.current_device)
)
else:
# NOTE: make TorchScript work!
self.register_buffer(
"momentum2_dev", torch.zeros(1, dtype=torch.int64, device=self.current_device), persistent=False
"momentum2_dev",
torch.zeros(1, dtype=torch.int64, device=self.current_device),
persistent=False,
)
self.register_buffer(
"momentum2_host", torch.zeros(1, dtype=torch.int64, device=self.current_device), persistent=False
"momentum2_host",
torch.zeros(1, dtype=torch.int64, device=self.current_device),
persistent=False,
)
self.register_buffer(
"momentum2_uvm", torch.zeros(1, dtype=torch.int64, device=self.current_device), persistent=False
"momentum2_uvm",
torch.zeros(1, dtype=torch.int64, device=self.current_device),
persistent=False,
)
self.register_buffer(
"momentum2_placements",
@@ -432,7 +442,9 @@ def __init__( # noqa C901
persistent=False,
)
self.register_buffer(
"iter", torch.zeros(1, dtype=torch.int64, device=self.current_device), persistent=False
"iter",
torch.zeros(1, dtype=torch.int64, device=self.current_device),
persistent=False,
)

cache_state = construct_cache_state(embedding_specs, self.feature_table_map)
@@ -681,15 +693,22 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
offsets,
)

if self.record_cache_metrics.record_cache_miss_counter or self.record_cache_metrics.record_tablewise_cache_miss:
if (
self.record_cache_metrics.record_cache_miss_counter
or self.record_cache_metrics.record_tablewise_cache_miss
):
lxu_cache_locations = torch.ops.fb.lxu_cache_lookup(
linear_cache_indices,
self.lxu_cache_state,
)
if self.record_cache_metrics.record_cache_miss_counter:
self._update_cache_miss_counter(lxu_cache_locations, linear_cache_indices)
self._update_cache_miss_counter(
lxu_cache_locations, linear_cache_indices
)
if self.record_cache_metrics.record_tablewise_cache_miss:
self._update_tablewise_cache_miss(lxu_cache_locations, linear_cache_indices, offsets)
self._update_tablewise_cache_miss(
lxu_cache_locations, linear_cache_indices, offsets
)

if self.cache_algorithm == CacheAlgorithm.LRU:
torch.ops.fb.lru_cache_populate(
@@ -724,19 +743,24 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
assert (
len(self.lxu_cache_locations_list) < self.max_prefetch_depth
), f"self.lxu_cache_locations_list has grown to size: {len(self.lxu_cache_locations_list)}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()"
self.lxu_cache_locations_list.append(torch.ops.fb.lxu_cache_lookup(
self.lxu_cache_locations_list.append(
torch.ops.fb.lxu_cache_lookup(
linear_cache_indices,
self.lxu_cache_state,
)
)

def _update_cache_miss_counter(
self, lxu_cache_locations: Tensor, linear_cache_indices: Tensor,
self,
lxu_cache_locations: Tensor,
linear_cache_indices: Tensor,
) -> None:
CACHE_MISS = -1
CACHE_HIT = -2

cache_missed_locations = torch.where(lxu_cache_locations == CACHE_MISS, linear_cache_indices, CACHE_HIT)
cache_missed_locations = torch.where(
lxu_cache_locations == CACHE_MISS, linear_cache_indices, CACHE_HIT
)
unique_ids_list = torch.unique(cache_missed_locations)
unique_ids_count_list = torch.where(unique_ids_list == CACHE_HIT, 0, 1)

@@ -755,7 +779,10 @@ def _update_cache_miss_counter
self.cache_miss_counter[1] += miss_count

def _update_tablewise_cache_miss(
self, lxu_cache_locations: Tensor, linear_cache_indices: Tensor, offsets: Tensor,
self,
lxu_cache_locations: Tensor,
linear_cache_indices: Tensor,
offsets: Tensor,
) -> None:
CACHE_MISS = -1
CACHE_HIT = -2
@@ -765,7 +792,9 @@ def _update_tablewise_cache_miss(
# positional only parameter to call `len` but got `typing.Union[Tensor, nn.Module]`.
num_tables = len(self.cache_hash_size_cumsum) - 1
num_offsets_per_table = (len(offsets) - 1) // num_tables
cache_missed_locations = torch.where(lxu_cache_locations == CACHE_MISS, linear_cache_indices, CACHE_HIT)
cache_missed_locations = torch.where(
lxu_cache_locations == CACHE_MISS, linear_cache_indices, CACHE_HIT
)

for i in range(num_tables):
start = offsets[i * num_offsets_per_table]
@@ -1419,6 +1448,7 @@ def forward(
feature_requires_grad,
)


def round_up(a: int, b: int) -> int:
return int((a + b - 1) // b) * b

@@ -1428,6 +1458,7 @@ def rounded_row_size_in_bytes(dim: int, weight_ty: SparseType) -> int:
# align each row to 16-byte boundaries.
return round_up(r, 16)


def unpadded_row_size_in_bytes(dim: int, weight_ty: SparseType) -> int:
r = {
SparseType.FP16.value: dim * 2,
@@ -1491,19 +1522,15 @@ def __init__(

def max_ty_D(ty: SparseType) -> int:
return max(
[
dim
for dim, weight_ty in zip(dims, weights_tys)
if weight_ty == ty
],
[dim for dim, weight_ty in zip(dims, weights_tys) if weight_ty == ty],
default=0,
)

self.max_int2_D: int = max_ty_D(SparseType.INT2)
self.max_int4_D: int = max_ty_D(SparseType.INT4)
self.max_int8_D: int = max_ty_D(SparseType.INT8)
self.max_float16_D: int = max_ty_D(SparseType.FP16)


self.register_buffer(
"D_offsets",
torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32),
@@ -1513,22 +1540,25 @@ def max_ty_D(ty: SparseType) -> int:
self.register_buffer(
"rows_per_table",
torch.tensor(
[rows[t] for t in feature_table_map], device=self.current_device, dtype=torch.int64
)
[rows[t] for t in feature_table_map],
device=self.current_device,
dtype=torch.int64,
),
)
self.register_buffer(
"bounds_check_warning",
torch.tensor(
[0], device=self.current_device, dtype=torch.int64
)
torch.tensor([0], device=self.current_device, dtype=torch.int64),
)

def align_to_cacheline(a: int) -> int:
# align each table to 128b cache line boundary.
return round_up(a, 128)

weights_offsets = [0] + np.cumsum(
[align_to_cacheline(row * rounded_row_size_in_bytes(dim, weight_ty)) for _, row, dim, weight_ty in embedding_specs]
[
align_to_cacheline(row * rounded_row_size_in_bytes(dim, weight_ty))
for _, row, dim, weight_ty in embedding_specs
]
).tolist()
self.table_size: int = weights_offsets[-1]
weights = torch.randint(
@@ -1545,28 +1575,35 @@ def align_to_cacheline(a: int) -> int:
_, row, dim, weight_ty = embedding_specs[t]
assert self.weights[
weights_offsets[t] : weights_offsets[t + 1]
].numel() == align_to_cacheline(row * rounded_row_size_in_bytes(dim, weight_ty))
].numel() == align_to_cacheline(
row * rounded_row_size_in_bytes(dim, weight_ty)
)

weights_offsets = [weights_offsets[t] for t in feature_table_map]
weights_tys_int = [weights_tys[t].as_int() for t in feature_table_map]

self.register_buffer(
"weights_offsets",
torch.tensor(weights_offsets, device=self.current_device, dtype=torch.int64),
torch.tensor(
weights_offsets, device=self.current_device, dtype=torch.int64
),
)
self.register_buffer(
"weights_tys",
torch.tensor(weights_tys_int, device=self.current_device, dtype=torch.uint8),
torch.tensor(
weights_tys_int, device=self.current_device, dtype=torch.uint8
),
)

# Assign weights after weights and weights_offsets are initialized.
if weight_lists:
self.assign_embedding_weights(weight_lists) # type: ignore
self.assign_embedding_weights(weight_lists) # type: ignore

if index_remapping:
capacities = [
round_up(int(row * 1.0 / load_factor), 32)
if index_remap is not None else 0
if index_remap is not None
else 0
for (index_remap, row) in zip(index_remapping, rows)
]
hash_table = torch.empty(
@@ -1603,14 +1640,17 @@ def align_to_cacheline(a: int) -> int:
"index_remapping_hash_table", hash_table.to(self.current_device)
)
self.register_buffer(
"index_remapping_hash_table_offsets", hash_table_offsets.to(self.current_device)
"index_remapping_hash_table_offsets",
hash_table_offsets.to(self.current_device),
)

if use_cpu:
# pyre-fixme[4]: Attribute must be annotated.
self.index_remapping_hash_table_cpu = torch.classes.fb.PrunedMapCPU()
assert T == T_
self.index_remapping_hash_table_cpu.insert(indices, dense_indices, offsets, T_)
self.index_remapping_hash_table_cpu.insert(
indices, dense_indices, offsets, T_
)
else:
self.index_remapping_hash_table_cpu = None
else:
@@ -1621,7 +1661,6 @@ def align_to_cacheline(a: int) -> int:

self.index_remapping_hash_table_cpu = None


def forward(
self,
indices: Tensor,
@@ -1641,7 +1680,13 @@ def forward(
)
# We cast to int as a TorchScript workaround.
if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
torch.ops.fb.bounds_check_indices(self.rows_per_table, indices, offsets, self.bounds_check_mode_int, self.bounds_check_warning)
torch.ops.fb.bounds_check_indices(
self.rows_per_table,
indices,
offsets,
self.bounds_check_mode_int,
self.bounds_check_warning,
)
return torch.ops.fb.int_nbit_split_embedding_codegen_lookup_function(
dev_weights=self.weights,
weights_offsets=self.weights_offsets,
@@ -1670,8 +1715,14 @@ def split_embedding_weights(self) -> List[Tuple[Tensor, Optional[Tensor]]]:
offset : offset + rows * rounded_row_size_in_bytes(dim, weight_ty)
].view(rows, rounded_row_size_in_bytes(dim, weight_ty))
# remove the padding at the end of each row.
weights_shifts = weights_shifts[:, :unpadded_row_size_in_bytes(dim, weight_ty)]
if weight_ty == SparseType.INT8 or weight_ty == SparseType.INT4 or weight_ty == SparseType.INT2:
weights_shifts = weights_shifts[
:, : unpadded_row_size_in_bytes(dim, weight_ty)
]
if (
weight_ty == SparseType.INT8
or weight_ty == SparseType.INT4
or weight_ty == SparseType.INT2
):
splits.append(
(
weights_shifts[:, 4:],

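Several of the reformatted blocks above implement the module's cache-miss accounting: lookups that miss return a CACHE_MISS sentinel (-1), the missed indices are de-duplicated, and the number of unique misses is accumulated. A minimal, runnable sketch of that counting pattern, using made-up input tensors rather than the module's real buffers:

import torch

CACHE_MISS = -1
CACHE_HIT = -2

# Illustrative inputs: lxu_cache_locations is CACHE_MISS where the cache lookup
# missed, otherwise a valid slot; linear_cache_indices are the requested ids.
lxu_cache_locations = torch.tensor([5, CACHE_MISS, 9, CACHE_MISS, CACHE_MISS])
linear_cache_indices = torch.tensor([10, 42, 11, 42, 77])

# Keep the requested id wherever the lookup missed; mark hits with the sentinel.
cache_missed_locations = torch.where(
    lxu_cache_locations == CACHE_MISS, linear_cache_indices, CACHE_HIT
)

# De-duplicate, then count everything that is not the hit sentinel.
unique_ids_list = torch.unique(cache_missed_locations)
unique_ids_count_list = torch.where(unique_ids_list == CACHE_HIT, 0, 1)
miss_count = torch.sum(unique_ids_count_list)

print(int(miss_count))  # 2 unique missed ids (42 and 77)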
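The diff also touches the row-size helpers (round_up, rounded_row_size_in_bytes, align_to_cacheline) that size the packed weight buffer. A worked example of that arithmetic follows; only the FP16 branch of unpadded_row_size_in_bytes is visible in the diff, so the other SparseType branches are omitted here and the sample sizes are made up:

def round_up(a: int, b: int) -> int:
    # Round a up to the nearest multiple of b (same helper as above).
    return int((a + b - 1) // b) * b


def unpadded_row_size_in_bytes_fp16(dim: int) -> int:
    # An FP16 row of width dim takes dim * 2 bytes (the only branch shown in the diff).
    return dim * 2


def rounded_row_size_in_bytes_fp16(dim: int) -> int:
    # Each row is aligned to a 16-byte boundary.
    return round_up(unpadded_row_size_in_bytes_fp16(dim), 16)


def align_to_cacheline(a: int) -> int:
    # Each table is aligned to a 128-byte cache-line boundary.
    return round_up(a, 128)


# Example: an FP16 table with 100 rows of dimension 123.
dim, rows = 123, 100
unpadded = unpadded_row_size_in_bytes_fp16(dim)  # 123 * 2 = 246 bytes
padded = rounded_row_size_in_bytes_fp16(dim)     # rounded up to 256 bytes per row
table_bytes = align_to_cacheline(rows * padded)  # 25_600 bytes, already cache-line aligned
print(unpadded, padded, table_bytes)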