Skip to content

Commit

Permalink
Use exported functions instead of calling initialize_weights in weigh…
Browse files Browse the repository at this point in the history
…ts loading (pytorch#1676)

Summary:
Pull Request resolved: pytorch#1676

Export a function to reset the embedding specs by target location

Reviewed By: RoshanPAN, houseroad

Differential Revision: D44338258

fbshipit-source-id: 502733e9f3a164450a02656d2822492fbf69f994
  • Loading branch information
qxy11 authored and facebook-github-bot committed Mar 31, 2023
1 parent d559a10 commit 1ac526f
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
Original file line number Diff line number Diff line change
@@ -2684,12 +2684,33 @@ def initialize_physical_weights_placements_and_offsets(

@torch.jit.export
def reset_weights_placements_and_offsets(
self,
self, device: torch.device, location: int
) -> None:
# Reset device/location denoted in embedding specs
self.reset_embedding_spec_location(device, location)
# Initialize all physical/logical weights placements and offsets without initializing large dev weights tensor
self.initialize_physical_weights_placements_and_offsets()
self.initialize_logical_weights_placements_and_offsets()

def reset_embedding_spec_location(
self, device: torch.device, location: int
) -> None:
# Overwrite location in embedding_specs with new location
# Use map since can't script enum call (ie. EmbeddingLocation(value))
INT_TO_EMBEDDING_LOCATION = {
0: EmbeddingLocation.DEVICE,
1: EmbeddingLocation.MANAGED,
2: EmbeddingLocation.MANAGED_CACHING,
3: EmbeddingLocation.HOST,
}
target_location = INT_TO_EMBEDDING_LOCATION[location]
self.current_device = device
self.row_alignment = 1 if target_location == EmbeddingLocation.HOST else 16
self.embedding_specs = [
(spec[0], spec[1], spec[2], spec[3], target_location)
for spec in self.embedding_specs
]

def _apply_split(
self,
dev_size: int,

0 comments on commit 1ac526f

Please sign in to comment.