Skip to content

Commit

Permalink
fix approx sgd on cpu (pytorch#566)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#566

We should add test cases for all new features we add.

Reviewed By: jianyuh

Differential Revision: D27237308

fbshipit-source-id: b6b1aebbe809fb4aa199d528e78acecd8694d123
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Mar 22, 2021
1 parent 99f3034 commit ffff7a3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
2 changes: 1 addition & 1 deletion fbgemm_gpu/codegen/embedding_backward_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def approx_sgd():
weight_new.fma_(grad, -learning_rate);
"""
split_weight_update_cpu = """
host_weights_data[embedding_begin + d] += learning_rate * grad_val;
host_weights_data[embedding_begin + d] -= learning_rate * grad_val;
"""

generate(
Expand Down
6 changes: 5 additions & 1 deletion fbgemm_gpu/test/split_table_batched_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ def test_backward_dense(
long_segments=st.booleans(),
pooling_mode=st.sampled_from(split_table_batched_embeddings_ops.PoolingMode),
use_cpu=st.booleans() if torch.cuda.is_available() else st.just(True),
exact=st.booleans(),
)
@settings(
verbosity=Verbosity.verbose,
Expand All @@ -478,12 +479,15 @@ def test_backward_sgd( # noqa C901
long_segments,
pooling_mode,
use_cpu,
exact,
):
# NOTE: cache is not applicable to CPU version.
assume(not use_cpu or not use_cache)
# NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
assume(not use_cpu or T * B * L * D <= 2048)
assume(not (use_cpu and weights_precision == SparseType.FP16))
# GPU only does exact sgd
assume(use_cpu or exact)

assume(
pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM
Expand Down Expand Up @@ -591,7 +595,7 @@ def test_backward_sgd( # noqa C901

cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
[(E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)],
optimizer=OptimType.EXACT_SGD,
optimizer=OptimType.EXACT_SGD if exact else OptimType.SGD,
feature_table_map=feature_table_map,
learning_rate=0.05,
weights_precision=weights_precision,
Expand Down

0 comments on commit ffff7a3

Please sign in to comment.