diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
index ff8ce4d094..c327d359cc 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
@@ -115,14 +115,16 @@ class CounterBasedRegularizationDefinition:
     [("record_cache_miss_counter", bool), ("record_tablewise_cache_miss", bool)],
 )
 
-
-@dataclass
-class SplitState:
-    dev_size: int
-    host_size: int
-    uvm_size: int
-    placements: List[EmbeddingLocation]
-    offsets: List[int]
+SplitState: NamedTuple = NamedTuple(
+    "SplitState",
+    [
+        ("dev_size", int),
+        ("host_size", int),
+        ("uvm_size", int),
+        ("placements", List[EmbeddingLocation]),
+        ("offsets", List[int]),
+    ],
+)
 
 
 def construct_split_state(
@@ -132,11 +134,11 @@ def construct_split_state(
     precision: SparseType = SparseType.FP32,
     int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET,
 ) -> SplitState:
-    placements = []
-    offsets = []
-    dev_size = 0
-    host_size = 0
-    uvm_size = 0
+    placements: List[EmbeddingLocation] = []
+    offsets: List[int] = []
+    dev_size: int = 0
+    host_size: int = 0
+    uvm_size: int = 0
     for num_embeddings, embedding_dim, location, _ in embedding_specs:
         assert (
             embedding_dim % 4 == 0
@@ -1935,8 +1937,8 @@ def nbit_construct_split_state(
     scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
     cacheline_alignment: bool = True,
 ) -> SplitState:
-    placements = []
-    offsets = []
+    placements = torch.jit.annotate(List[EmbeddingLocation], [])
+    offsets = torch.jit.annotate(List[int], [])
     dev_size = 0
     host_size = 0
     uvm_size = 0
@@ -1984,6 +1986,8 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
     cache_miss_counter: torch.Tensor
     uvm_cache_stats: torch.Tensor
     local_uvm_cache_stats: torch.Tensor
+    weights_offsets: torch.Tensor
+    weights_placements: torch.Tensor
 
     def __init__(
         self,
@@ -2165,21 +2169,7 @@ def max_ty_D(ty: SparseType) -> int:
             ]
         self.max_D_cache: int = max(cached_dims) if len(cached_dims) > 0 else 0
 
-        weight_split: SplitState = nbit_construct_split_state(
-            self.embedding_specs,
-            cacheable=True,
-            row_alignment=self.row_alignment,
-            scale_bias_size_in_bytes=self.scale_bias_size_in_bytes,
-            cacheline_alignment=cacheline_alignment,
-        )
-
-        self.weights_physical_placements: List[int] = [
-            t.value for t in weight_split.placements
-        ]
-        self.weights_physical_offsets: List[int] = weight_split.offsets
-        self.host_size: int = weight_split.host_size
-        self.dev_size: int = weight_split.dev_size
-        self.uvm_size: int = weight_split.uvm_size
+        self.initialize_physical_weights_placements_and_offsets(cacheline_alignment)
         self.enforce_hbm: bool = enforce_hbm
 
         # Assign weights after weights and weights_offsets are initialized.
@@ -2192,7 +2182,8 @@ def max_ty_D(ty: SparseType) -> int:
                 self.weights_physical_offsets,
                 self.enforce_hbm,
             )
-            self.assign_embedding_weights(weight_lists)  # type: ignore
+            # pyre-fixme [6]: In call `IntNBitTableBatchedEmbeddingBagsCodegen.assign_embedding_weights`, for 1st positional argument, expected `List[Tuple[Tensor, Optional[Tensor]]]` but got `List[Tuple[Tensor, Tensor]]`.
+            self.assign_embedding_weights(weight_lists)
 
         # Handle index remapping for embedding pruning.
         self.register_buffer(
@@ -2654,6 +2645,51 @@ def forward(
             fp8_exponent_bias=self.fp8_exponent_bias,
         )
 
+    def initialize_logical_weights_placements_and_offsets(
+        self,
+    ) -> None:
+        assert len(self.weights_physical_offsets) == len(self.embedding_specs)
+        assert len(self.weights_physical_offsets) == len(
+            self.weights_physical_placements
+        )
+        offsets = [self.weights_physical_offsets[t] for t in self.feature_table_map]
+        placements = [
+            self.weights_physical_placements[t] for t in self.feature_table_map
+        ]
+        self.weights_offsets = torch.tensor(
+            offsets, device=self.current_device, dtype=torch.int64
+        )
+        self.weights_placements = torch.tensor(
+            placements, device=self.current_device, dtype=torch.int32
+        )
+
+    def initialize_physical_weights_placements_and_offsets(
+        self,
+        cacheline_alignment: bool = True,
+    ) -> None:
+        # Initialize physical weights placements and offsets
+        # and host/dev/uvm sizes
+        weight_split: SplitState = nbit_construct_split_state(
+            self.embedding_specs,
+            cacheable=True,
+            row_alignment=self.row_alignment,
+            scale_bias_size_in_bytes=self.scale_bias_size_in_bytes,
+            cacheline_alignment=cacheline_alignment,
+        )
+        self.weights_physical_placements = [t.value for t in weight_split.placements]
+        self.weights_physical_offsets = weight_split.offsets
+        self.host_size = weight_split.host_size
+        self.dev_size = weight_split.dev_size
+        self.uvm_size = weight_split.uvm_size
+
+    @torch.jit.export
+    def reset_weights_placements_and_offsets(
+        self,
+    ) -> None:
+        # Initialize all physical/logical weights placements and offsets without initializing large dev weights tensor
+        self.initialize_physical_weights_placements_and_offsets()
+        self.initialize_logical_weights_placements_and_offsets()
+
     def _apply_split(
         self,
         dev_size: int,
@@ -2672,14 +2708,7 @@ def _apply_split(
         self.dev_size = dev_size
         self.uvm_size = uvm_size
 
-        offsets = [offsets[t] for t in self.feature_table_map]
-        placements = [placements[t] for t in self.feature_table_map]
-        self.weights_offsets = torch.tensor(
-            offsets, device=self.current_device, dtype=torch.int64
-        )
-        self.weights_placements = torch.tensor(
-            placements, device=self.current_device, dtype=torch.int32
-        )
+        self.initialize_logical_weights_placements_and_offsets()
 
         if dev_size > 0:
             self.weights_dev = torch.zeros(
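
Usage note (not part of the patch above): the new @torch.jit.export method
reset_weights_placements_and_offsets() rebuilds the physical placement/offset
metadata and the logical weights_placements / weights_offsets buffers without
reallocating the large dev weights tensor. A minimal sketch of how it might be
exercised follows; the single-table embedding_specs entry and the
fill_random_weights() call are illustrative assumptions rather than part of this
diff, and the torch.jit.script() step assumes the module scripts cleanly (which
the NamedTuple and torch.jit.annotate changes above appear intended to enable).

    import torch
    from fbgemm_gpu.split_embedding_configs import SparseType
    from fbgemm_gpu.split_table_batched_embeddings_ops import (
        EmbeddingLocation,
        IntNBitTableBatchedEmbeddingBagsCodegen,
    )

    # One hypothetical INT4 table: (feature name, rows, dim, weight type, location).
    tbe = IntNBitTableBatchedEmbeddingBagsCodegen(
        embedding_specs=[("t0", 1000, 64, SparseType.INT4, EmbeddingLocation.HOST)],
    )
    tbe.fill_random_weights()

    # @torch.jit.export keeps the method callable on the scripted module, so the
    # placement/offset buffers can be rebuilt in place after scripting.
    scripted = torch.jit.script(tbe)
    scripted.reset_weights_placements_and_offsets()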