Types for fbgemm (pytorch#580)
Summary: Pull Request resolved: pytorch#580

Reviewed By: jianyuh

Differential Revision: D27345748

fbshipit-source-id: b7ea69efa73ec90be5c7c24e036e0aa85a3a4191
r-barnes authored and facebook-github-bot committed Mar 30, 2021
1 parent d8ac3eb commit e839120
Showing 4 changed files with 108 additions and 149 deletions.
37 changes: 19 additions & 18 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -78,22 +78,19 @@ def generate_requests(
         ]
         all_indices[it + 1, t, reused_indices] = all_indices[it, t, reused_indices]
 
-    rs = [
-        get_table_batched_offsets_from_dense(all_indices[it].view(T, B, L))
-        + (
-            torch.randn(
-                T * B * L,
-                device=torch.cuda.current_device(),
-                dtype=torch.float16
-                if weights_precision == SparseType.FP16
-                else torch.float32,
-            )
-            if weighted
-            else None,
-        )
-        for it in range(iters)
-    ]
-    # pyre-fixme[7]
+    rs = []
+    for it in range(iters):
+        weights_tensor = None if not weighted else torch.randn(
+            T * B * L,
+            device=torch.cuda.current_device(),
+            dtype=torch.float16
+            if weights_precision == SparseType.FP16
+            else torch.float32,
+        )
+        rs.append(
+            get_table_batched_offsets_from_dense(all_indices[it].view(T, B, L))
+            + (weights_tensor,)
+        )
     return rs
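The rewrite above trades a list comprehension for an explicit loop. In the old version the inline `... if weighted else None` buried an Optional inside the comprehension, which pyre could not reconcile with the declared return type (hence the `# pyre-fixme[7]` this change deletes); naming the value as `weights_tensor` makes the element type explicit. A minimal sketch of the pattern, with a hypothetical function name and simplified shapes:

    from typing import List, Optional, Tuple

    import torch

    def make_requests(
        iters: int, n: int, weighted: bool
    ) -> List[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        rs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = []
        for it in range(iters):
            # Naming the Optional value gives the checker one explicit
            # element type instead of a conditional buried in a comprehension.
            weights_tensor: Optional[torch.Tensor] = (
                torch.randn(n) if weighted else None
            )
            rs.append((torch.arange(n), weights_tensor))
        return rs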


@@ -441,8 +438,11 @@ def uvm(
         indices = torch.cat([rs_uvm[0], rs_gpu[0]])
         lengths = [L_uvm] * (T_uvm * B) + [L] * (T_gpu * B)
         offsets = torch.tensor(([0] + np.cumsum(lengths).tolist())).int().cuda()
-        # pyre-fixme[6]
-        per_sample_weights = torch.cat([rs_uvm[2], rs_gpu[2]]) if weighted else None
+        per_sample_weights = None
+        if weighted:
+            assert (this_rs_uvm_weights := rs_uvm[2]) is not None
+            assert (this_rs_gpu_weights := rs_gpu[2]) is not None
+            per_sample_weights = torch.cat([this_rs_uvm_weights, this_rs_gpu_weights])
         requests.append((indices, offsets, per_sample_weights))
 
         # forward
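This hunk removes a `# pyre-fixme[6]` suppression instead of reformatting it: asserting on a walrus-bound name both checks at runtime that the weights are present and lets the checker narrow the new local from Optional[Tensor] to Tensor before the `torch.cat` call. A standalone sketch of the idiom (the function and its signature are illustrative, not part of the benchmark):

    from typing import Optional

    import torch

    def concat_weights(
        uvm: Optional[torch.Tensor], gpu: Optional[torch.Tensor]
    ) -> torch.Tensor:
        # Each assert checks the runtime invariant and narrows the bound
        # name to a non-Optional Tensor for the type checker.
        assert (uvm_weights := uvm) is not None
        assert (gpu_weights := gpu) is not None
        return torch.cat([uvm_weights, gpu_weights])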
@@ -624,7 +624,8 @@ def cache( # noqa C901
     exchanged_cache_lines = []
     NOT_FOUND = -1
     for indices, offsets, _ in requests:
-        # pyre-fixme[16]
+        # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no attribute
+        #  `lxu_cache_state`.
         old_lxu_cache_state = emb.lxu_cache_state.clone()
         emb.prefetch(indices.long(), offsets.long())
         exchanged_cache_lines.append(
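The hunk above only reformats a suppression: the bare `# pyre-fixme[16]` becomes pyre's long-form comment, which records the class and the missing attribute. Error [16] typically fires when an attribute is installed dynamically (for example via `setattr` or `register_buffer`) so the checker cannot see it. A small self-contained illustration, with a hypothetical `Codegen` class standing in for `SplitTableBatchedEmbeddingBagsCodegen`:

    class Codegen:
        def __init__(self) -> None:
            # Attribute installed dynamically, so the checker cannot see it.
            setattr(self, "lxu_cache_state", [0, 0, 0])

    emb = Codegen()
    # pyre-fixme[16]: `Codegen` has no attribute `lxu_cache_state`.
    print(emb.lxu_cache_state)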
6 changes: 2 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
@@ -25,8 +25,7 @@ class EmbOptimType(enum.Enum):
     PARTIAL_ROWWISE_LAMB = "partial_row_wise_lamb"
     ROWWISE_ADAGRAD = "row_wise_adagrad"
 
-    # pyre-fixme[3]: Return type must be annotated.
-    def __str__(self):
+    def __str__(self) -> str:
         return self.value


@@ -36,8 +35,7 @@ class SparseType(enum.Enum):
     FP16 = "fp16"
     INT8 = "int8"
 
-    # pyre-fixme[3]: Return type must be annotated.
-    def __str__(self):
+    def __str__(self) -> str:
         return self.value


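Both hunks in this file make the same fix: `__str__` gains an explicit `-> str` annotation, which lets the `# pyre-fixme[3]` suppressions be dropped. A self-contained version, abridged to the enum members visible in the diff:

    import enum

    class SparseType(enum.Enum):
        FP16 = "fp16"
        INT8 = "int8"

        def __str__(self) -> str:
            # Each member's value is its string payload, so the annotated
            # return type is accurate.
            return self.value

    print(str(SparseType.FP16))  # prints "fp16"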
(Diffs for the remaining two changed files are not shown here.)
