Fix 10x model
Summary: Add missing weight init for dense embeddings

Reviewed By: xing-liu

Differential Revision: D26453218

fbshipit-source-id: 5cf46239d9f0cfb58604ee2c28fc1e4ed52481a7
Dmytro Ivchenko authored and facebook-github-bot committed Feb 16, 2021
1 parent 24e92f0 commit b51f3e6
Showing 1 changed file with 5 additions and 0 deletions.
fbgemm_gpu/split_table_batched_embeddings_ops.py (5 additions, 0 deletions)
@@ -1042,3 +1042,8 @@ def split_embedding_weights(self):
                 self.weights.detach()[offset : offset + rows * dim].view(rows, dim)
             )
         return splits
+
+    def init_embedding_weights_uniform(self, min_val, max_val):
+        splits = self.split_embedding_weights()
+        for param in splits:
+            param.uniform_(min_val, max_val)
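
For context, the pattern the new method relies on can be sketched in plain PyTorch: split_embedding_weights() hands out per-table views into one flat weight buffer, so calling the in-place uniform_() on each view writes directly into that buffer. The TinyDenseEmbeddings class below is a hypothetical stand-in used only to illustrate this, not the fbgemm_gpu API; the table shapes and init bounds are arbitrary.

import torch

class TinyDenseEmbeddings:
    """Hypothetical stand-in: all tables share one flat weight buffer."""

    def __init__(self, embedding_specs):
        # embedding_specs: list of (rows, dim) pairs, one per table.
        self.embedding_specs = embedding_specs
        total = sum(rows * dim for rows, dim in embedding_specs)
        self.weights = torch.zeros(total)

    def split_embedding_weights(self):
        # Return a (rows, dim) view into the flat buffer for each table.
        splits, offset = [], 0
        for rows, dim in self.embedding_specs:
            splits.append(self.weights[offset : offset + rows * dim].view(rows, dim))
            offset += rows * dim
        return splits

    def init_embedding_weights_uniform(self, min_val, max_val):
        # Same shape as the patch: uniform_() is in-place, so it writes
        # through each per-table view into the shared flat buffer.
        for param in self.split_embedding_weights():
            param.uniform_(min_val, max_val)

emb = TinyDenseEmbeddings([(10, 4), (20, 8)])
emb.init_embedding_weights_uniform(-0.01, 0.01)
assert float(emb.weights.abs().max()) <= 0.01

In fbgemm_gpu itself, the added init_embedding_weights_uniform(min_val, max_val) would presumably be called on the dense table-batched embedding module after construction, with bounds chosen by the caller (for example ±1/sqrt(dim)).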
