Fix the function schema for split embedding backward (pytorch#2212)
Summary:
Pull Request resolved: pytorch#2212

The momentum1_host argument is a write-back tensor, so in the function schema it should be annotated as "Tensor(b!)" so that the CPU fallback correctly writes data back to it.
I can't find a more elegant way to fix this for now.

Reviewed By: jspark1105

Differential Revision: D52082756

fbshipit-source-id: f19ad2332ec5a0f150ad37e7203cb8682fad26a6
egienvalue authored and facebook-github-bot committed Dec 14, 2023
1 parent cfb8d11 commit 12a08d6
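
For context, "Tensor(b!)" is a schema alias annotation: it tells the dispatcher that the argument is mutated by the operator, which is what allows a fallback (such as the CPU fallback) to copy the result back into the caller's tensor. The sketch below is a minimal, hypothetical example of registering an operator with such an annotation; the namespace, operator name, and kernel are illustrative and are not part of this commit.

#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical in-place kernel: scales `weights` and returns it.
at::Tensor& scale_weights_(at::Tensor& weights, double factor) {
  weights.mul_(factor);
  return weights;
}

TORCH_LIBRARY_FRAGMENT(example_ns, m) {
  // "(a!)" marks `weights` as mutated; without the annotation, a
  // dispatcher fallback may run the kernel on a temporary copy and the
  // caller's tensor would never receive the updated values.
  m.def("scale_weights_(Tensor(a!) weights, float factor) -> Tensor(a!)");
  m.impl("scale_weights_", c10::DispatchKey::CPU, TORCH_FN(scale_weights_));
}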
Showing 1 changed file with 2 additions and 2 deletions.
fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp
@@ -403,9 +403,9 @@ for (const auto d : c10::irange(D)) {

 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   {% if not dense %}
-  m.def("split_embedding_backward_codegen_{{ optimizer }}_cpu(Tensor grad_output, Tensor(a!) host_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets,int pooling_mode, Tensor indice_weights, bool stochastic_rounding, {{ (args.split_function_args | join(", ")).replace("double", "float").replace("int64_t", "int")}}, int output_dtype = 0) -> ()");
+  m.def("split_embedding_backward_codegen_{{ optimizer }}_cpu(Tensor grad_output, Tensor(a!) host_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets,int pooling_mode, Tensor indice_weights, bool stochastic_rounding, {{ (args.split_function_args | join(", ")).replace("double", "float").replace("int64_t", "int").replace("Tensor momentum1_host", "Tensor(b!) momentum1_host")}}, int output_dtype = 0) -> ()");
   {% else %}
-  m.def("split_embedding_backward_codegen_{{ optimizer }}_cpu(Tensor grad_output, Tensor(a!) host_weights, Tensor weights_offsets, Tensor D_offsets, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets,int pooling_mode, Tensor indice_weights, {{ (args.split_function_args | join(", ")).replace("double", "float").replace("int64_t", "int")}}) -> Tensor");
+  m.def("split_embedding_backward_codegen_{{ optimizer }}_cpu(Tensor grad_output, Tensor(a!) host_weights, Tensor weights_offsets, Tensor D_offsets, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets,int pooling_mode, Tensor indice_weights, {{ (args.split_function_args | join(", ")).replace("double", "float").replace("int64_t", "int").replace("Tensor momentum1_host", "Tensor(b!) momentum1_host")}}) -> Tensor");
   {% endif %}
   DISPATCH_TO_CPU("split_embedding_backward_codegen_{{ optimizer }}_cpu", split_embedding_backward_codegen_{{ optimizer }}_cpu);
 }
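
The net effect on the rendered schema, shown here as an illustrative fragment with the surrounding optimizer arguments elided, is that the added Jinja replace turns

  ..., Tensor momentum1_host, ...

into

  ..., Tensor(b!) momentum1_host, ...

so the CPU fallback now treats momentum1_host as an argument it must write back.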
