Fix typo (pytorch#506)
Summary:
Pull Request resolved: pytorch#506

As title: rename the misspelled commom_args to common_args in the split embedding codegen lookup invoker template and in split_table_batched_embeddings_ops.py.

Reviewed By: jspark1105

Differential Revision: D26236508

fbshipit-source-id: 8aa19b7b8a54d40fd7ac52573ef9df3749857501
jianyuh authored and facebook-github-bot committed Feb 4, 2021
1 parent a4de38e commit fe1fe6d
Showing 2 changed files with 42 additions and 42 deletions.
66 changes: 33 additions & 33 deletions fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template
@@ -22,7 +22,7 @@ torch.ops.load_library("fbgemm_gpu.so")
 
 
 def invoke(
-    commom_args: CommonArgs,
+    common_args: CommonArgs,
     optimizer_args: OptimizerArgs,
     {% if "momentum1_dev" in args.split_function_arg_names %}
     momentum1: Momentum,
@@ -34,22 +34,22 @@ def invoke(
     iter: int,
     {% endif %}
 ) -> torch.Tensor:
-    if (commom_args.host_weights.numel() > 0):
+    if (common_args.host_weights.numel() > 0):
         return torch.ops.fb.split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(
-            # commom_args
-            host_weights=commom_args.host_weights,
-            weights_placements=commom_args.weights_placements,
-            weights_offsets=commom_args.weights_offsets,
-            D_offsets=commom_args.D_offsets,
-            total_D=commom_args.total_D,
-            max_D=commom_args.max_D,
-            hash_size_cumsum=commom_args.hash_size_cumsum,
-            total_hash_size_bits=commom_args.total_hash_size_bits,
-            indices=commom_args.indices,
-            offsets=commom_args.offsets,
-            pooling_mode=commom_args.pooling_mode,
-            indice_weights=commom_args.indice_weights,
-            feature_requires_grad=commom_args.feature_requires_grad,
+            # common_args
+            host_weights=common_args.host_weights,
+            weights_placements=common_args.weights_placements,
+            weights_offsets=common_args.weights_offsets,
+            D_offsets=common_args.D_offsets,
+            total_D=common_args.total_D,
+            max_D=common_args.max_D,
+            hash_size_cumsum=common_args.hash_size_cumsum,
+            total_hash_size_bits=common_args.total_hash_size_bits,
+            indices=common_args.indices,
+            offsets=common_args.offsets,
+            pooling_mode=common_args.pooling_mode,
+            indice_weights=common_args.indice_weights,
+            feature_requires_grad=common_args.feature_requires_grad,
             # optimizer_args
             gradient_clipping = optimizer_args.gradient_clipping,
             max_gradient=optimizer_args.max_gradient,
@@ -94,23 +94,23 @@ def invoke(
         )
     else:
         return torch.ops.fb.split_embedding_codegen_lookup_{{ optimizer }}_function(
-            # commom_args
-            dev_weights=commom_args.dev_weights,
-            uvm_weights=commom_args.uvm_weights,
-            lxu_cache_weights=commom_args.lxu_cache_weights,
-            weights_placements=commom_args.weights_placements,
-            weights_offsets=commom_args.weights_offsets,
-            D_offsets=commom_args.D_offsets,
-            total_D=commom_args.total_D,
-            max_D=commom_args.max_D,
-            hash_size_cumsum=commom_args.hash_size_cumsum,
-            total_hash_size_bits=commom_args.total_hash_size_bits,
-            indices=commom_args.indices,
-            offsets=commom_args.offsets,
-            pooling_mode=commom_args.pooling_mode,
-            indice_weights=commom_args.indice_weights,
-            feature_requires_grad=commom_args.feature_requires_grad,
-            lxu_cache_locations=commom_args.lxu_cache_locations,
+            # common_args
+            dev_weights=common_args.dev_weights,
+            uvm_weights=common_args.uvm_weights,
+            lxu_cache_weights=common_args.lxu_cache_weights,
+            weights_placements=common_args.weights_placements,
+            weights_offsets=common_args.weights_offsets,
+            D_offsets=common_args.D_offsets,
+            total_D=common_args.total_D,
+            max_D=common_args.max_D,
+            hash_size_cumsum=common_args.hash_size_cumsum,
+            total_hash_size_bits=common_args.total_hash_size_bits,
+            indices=common_args.indices,
+            offsets=common_args.offsets,
+            pooling_mode=common_args.pooling_mode,
+            indice_weights=common_args.indice_weights,
+            feature_requires_grad=common_args.feature_requires_grad,
+            lxu_cache_locations=common_args.lxu_cache_locations,
             # optimizer_args
             gradient_clipping = optimizer_args.gradient_clipping,
             max_gradient=optimizer_args.max_gradient,
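
For context, a minimal sketch of the dispatch this template generates (hypothetical, simplified names; the real generated invoke passes the full CommonArgs and OptimizerArgs field sets shown in the hunks above, with {{ optimizer }} substituted per optimizer): the lookup routes to the _cpu op exactly when host-resident weights are present and to the CUDA/UVM op otherwise. pick_path and the example tensors below are illustrative only.

import torch

# Illustrative only: mirrors the host_weights.numel() > 0 test the generated
# invoke() uses to choose between the _cpu op and the CUDA/UVM op.
def pick_path(host_weights: torch.Tensor) -> str:
    return "cpu_kernel" if host_weights.numel() > 0 else "cuda_kernel"

print(pick_path(torch.zeros(10)))  # host weights present -> "cpu_kernel"
print(pick_path(torch.empty(0)))   # no host weights      -> "cuda_kernel"
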
18 changes: 9 additions & 9 deletions fbgemm_gpu/split_table_batched_embeddings_ops.py
@@ -378,7 +378,7 @@ def forward(
             if len(self.lxu_cache_locations_list) == 0
             else self.lxu_cache_locations_list.pop(0)
         )
-        commom_args = invokers.lookup_args.CommonArgs(
+        common_args = invokers.lookup_args.CommonArgs(
             # pyre-fixme[16]
             dev_weights=self.weights_dev,
             # pyre-fixme[16]
@@ -407,7 +407,7 @@
         )
 
         if self.optimizer == OptimType.EXACT_SGD:
-            return invokers.lookup_sgd.invoke(commom_args, self.optimizer_args)
+            return invokers.lookup_sgd.invoke(common_args, self.optimizer_args)
 
         momentum1 = invokers.lookup_args.Momentum(
             dev=self.momentum1_dev,
@@ -419,15 +419,15 @@
 
         if self.optimizer == OptimType.LARS_SGD:
             return invokers.lookup_lars_sgd.invoke(
-                commom_args, self.optimizer_args, momentum1
+                common_args, self.optimizer_args, momentum1
             )
         if self.optimizer == OptimType.EXACT_ADAGRAD:
             return invokers.lookup_adagrad.invoke(
-                commom_args, self.optimizer_args, momentum1
+                common_args, self.optimizer_args, momentum1
             )
         if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
             return invokers.lookup_rowwise_adagrad.invoke(
-                commom_args, self.optimizer_args, momentum1
+                common_args, self.optimizer_args, momentum1
             )
 
         momentum2 = invokers.lookup_args.Momentum(
@@ -444,19 +444,19 @@
 
         if self.optimizer == OptimType.ADAM:
             return invokers.lookup_adam.invoke(
-                commom_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
+                common_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
             )
         if self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
             return invokers.lookup_partial_rowwise_adam.invoke(
-                commom_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
+                common_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
            )
        if self.optimizer == OptimType.LAMB:
            return invokers.lookup_lamb.invoke(
-                commom_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
+                common_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
            )
        if self.optimizer == OptimType.PARTIAL_ROWWISE_LAMB:
            return invokers.lookup_partial_rowwise_lamb.invoke(
-                commom_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
+                common_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
            )
 
        raise ValueError(f"Invalid OptimType: {self.optimizer}")
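
The forward() changes above all touch one dispatch pattern: each OptimType maps to a generated invoker, and the invoker's invoke() signature grows with the optimizer's state (SGD takes only common_args and optimizer_args; LARS-SGD and the AdaGrad variants add momentum1; the Adam and LAMB variants add momentum2 and the iteration counter). A minimal, self-contained sketch of that pattern, with a stand-in enum and string return values rather than the library's real invokers:

from enum import Enum

class OptimType(Enum):  # stand-in for the real OptimType enum
    EXACT_SGD = "sgd"
    EXACT_ROWWISE_ADAGRAD = "rowwise_adagrad"
    ADAM = "adam"

def dispatch(optimizer, common_args, optimizer_args,
             momentum1=None, momentum2=None, iteration=None):
    # Returns (invoker_name, args) instead of calling the real invokers.
    if optimizer == OptimType.EXACT_SGD:
        return "lookup_sgd", (common_args, optimizer_args)
    if optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
        return "lookup_rowwise_adagrad", (common_args, optimizer_args, momentum1)
    if optimizer == OptimType.ADAM:
        return "lookup_adam", (common_args, optimizer_args,
                               momentum1, momentum2, iteration)
    raise ValueError(f"Invalid OptimType: {optimizer}")

print(dispatch(OptimType.EXACT_SGD, "common_args", "optimizer_args")[0])  # lookup_sgd
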
