Break up fbgemm_gpu:split_table_batched_embeddings_ops into 2 targets (pytorch#1783)

Summary:
Pull Request resolved: pytorch#1783

- Create two separate targets, corresponding to training and inference, to replace `fbgemm_gpu:split_table_batched_embeddings_ops`

Reviewed By: sryap

Differential Revision: D46041414

fbshipit-source-id: a601727ca35f3a36c14b4f3a24b863a9b0a17dbe
q10 authored and facebook-github-bot committed on May 24, 2023 (1 parent: a0a0413; commit 042cf19)
Showing 15 changed files with 3,901 additions and 3,884 deletions.
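
Taken together, the import changes in the diffs below imply the following split of the old fbgemm_gpu.split_table_batched_embeddings_ops module. This is a sketch inferred only from the imports touched in this commit, not an exhaustive listing of what each new module exports:

    # Enums and helper types shared by training and inference:
    from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
        BoundsCheckMode,
        CacheAlgorithm,
        EmbeddingLocation,
        PoolingMode,
        RecordCacheMetrics,
    )

    # Quantized (int-N-bit) embedding lookup for inference:
    from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
        IntNBitTableBatchedEmbeddingBagsCodegen,
        rounded_row_size_in_bytes,
    )

    # Trainable embedding-bag implementations:
    from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
        ComputeDevice,
        DenseTableBatchedEmbeddingBagsCodegen,
        SplitTableBatchedEmbeddingBagsCodegen,
    )

    # SparseType and EmbOptimType live in split_embedding_configs,
    # not in the replaced module:
    from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType

Call sites that previously imported everything from fbgemm_gpu.split_table_batched_embeddings_ops now pick the target matching their use case, as each diff below shows.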
fbgemm_gpu/bench/merge_embeddings_benchmark.py (6 additions & 3 deletions)
@@ -17,12 +17,15 @@
 import tabulate
 import torch
 
-from fbgemm_gpu.split_table_batched_embeddings_ops import (
+from fbgemm_gpu.split_embedding_configs import SparseType
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     BoundsCheckMode,
     EmbeddingLocation,
+)
+from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
     IntNBitTableBatchedEmbeddingBagsCodegen,
-    SparseType,
 )
+
 from torch import Tensor
 from torch.profiler import profile, ProfilerActivity

@@ -228,7 +231,7 @@ def print_p2p_bandwidth(
     print(table)
 
 
-def benchmark(
+def benchmark(  # noqa C901
     all_to_one_only: bool,
     sum_reduce_to_one_only: bool,
     num_ads: int,
fbgemm_gpu/bench/split_embeddings_cache_benchmark.py (6 additions & 1 deletion)
@@ -11,12 +11,16 @@
 import click
 import numpy as np
 import torch
+
 from fbgemm_gpu.split_embedding_configs import SparseType
-from fbgemm_gpu.split_table_batched_embeddings_ops import (
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     CacheAlgorithm,
     EmbeddingLocation,
+)
+from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
     IntNBitTableBatchedEmbeddingBagsCodegen,
 )
+
 from torch import nn, Tensor
 
 logging.basicConfig(level=logging.DEBUG)
@@ -25,6 +29,7 @@
     # pyre-ignore[21]
     from fbgemm_gpu import open_source  # noqa: F401
 except Exception:
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils")
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings"
     )
fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py (12 additions & 6 deletions)
@@ -18,20 +18,26 @@
 import click
 import fbgemm_gpu
 import numpy as np
 import torch
 
 from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
 from fbgemm_gpu.split_embedding_utils import generate_requests, get_device, round_up
-from fbgemm_gpu.split_table_batched_embeddings_ops import (
+
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     BoundsCheckMode,
     CacheAlgorithm,
-    ComputeDevice,
-    DenseTableBatchedEmbeddingBagsCodegen,
     EmbeddingLocation,
-    IntNBitTableBatchedEmbeddingBagsCodegen,
-    OptimType,
     PoolingMode,
     RecordCacheMetrics,
-    SplitTableBatchedEmbeddingBagsCodegen,
 )
+from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
+    IntNBitTableBatchedEmbeddingBagsCodegen,
+    rounded_row_size_in_bytes,
+    SparseType,
+)
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+    ComputeDevice,
+    DenseTableBatchedEmbeddingBagsCodegen,
+    SplitTableBatchedEmbeddingBagsCodegen,
+)
 from torch import Tensor
fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py (1 addition & 1 deletion)
@@ -14,7 +14,7 @@
 import torch
 from fbgemm_gpu.bench.bench_utils import benchmark_requests
 from fbgemm_gpu.split_embedding_utils import generate_requests, round_up
-from fbgemm_gpu.split_table_batched_embeddings_ops import (
+from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
     IntNBitTableBatchedEmbeddingBagsCodegen,
 )
 from fbgemm_gpu.ssd_split_table_batched_embeddings_ops import (
fbgemm_gpu/fbgemm_gpu/_fbgemm_gpu_docs.py (7 additions & 4 deletions)
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import fbgemm_gpu
-import fbgemm_gpu.split_table_batched_embeddings_ops
+import fbgemm_gpu.split_table_batched_embeddings_ops_training
 import torch  # usort:skip
 
 Tensor = torch.Tensor
@@ -231,7 +231,7 @@ def add_docs(method, docstr):
 
 
 add_docs(
-    fbgemm_gpu.split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen,
+    fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen,
     """
     SplitTableBatchedEmbeddingBagsCodegen(embedding_specs, feature_table_map=None, cache_algorithm=CacheAlgorithm.LRU, cache_load_factor=0.2, cache_sets=0, cache_reserved_memory=0.0, cache_precision=SparseType.FP32, weights_precision=SparseType.FP32, output_dtype=SparseType.FP32, enforce_hbm=False, optimizer=OptimType.EXACT_SGD, record_cache_metrics=None, stochastic_rounding=True, gradient_clipping=False, max_gradient=1.0, learning_rate=0.01, eps=1.0e-8, momentum=0.9, weight_decay=0.0, weight_decay_mode=WeightDecayMode.NONE, eta=0.001, beta1=0.9, beta2=0.999, pooling_mode=PoolingMode.SUM, device=None, bounds_check_mode=BoundsCheckMode.WARNING) -> None
@@ -304,9 +304,12 @@ def add_docs(method, docstr):
     Example:
         >>> import torch
-        >>> from fbgemm_gpu.split_table_batched_embeddings_ops import (
-        >>>    SplitTableBatchedEmbeddingBagsCodegen,
+        >>>
+        >>> from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
         >>>    EmbeddingLocation,
+        >>> )
+        >>> from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+        >>>    SplitTableBatchedEmbeddingBagsCodegen,
         >>>    ComputeDevice,
         >>> )
         >>>
fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py (15 additions & 10 deletions)
@@ -11,11 +11,19 @@
 import math
 from typing import Optional, Tuple
 
-import fbgemm_gpu.split_table_batched_embeddings_ops as split_table_batched_embeddings_ops  # usort:skip
 import numpy as np
 import torch
 
 from fbgemm_gpu.split_embedding_configs import QuantizationConfig, SparseType
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
+from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
+    IntNBitTableBatchedEmbeddingBagsCodegen,
+)
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+    ComputeDevice,
+    SplitTableBatchedEmbeddingBagsCodegen,
+)
+
 from torch import nn, Tensor  # usort:skip
@@ -60,7 +68,7 @@ def _prune_embs(
         self,
         idx: int,
         num_rows: int,
-        module: split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen,
+        module: SplitTableBatchedEmbeddingBagsCodegen,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         # TODO(yingz): Avoid DtoH / HtoD overhead.
         weights = module.split_embedding_weights()[idx].cpu()
@@ -143,13 +151,10 @@ def _process_split_embs(self, model: nn.Module) -> None:
         for name, child in model.named_children():
             if isinstance(
                 child,
-                split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen,
+                SplitTableBatchedEmbeddingBagsCodegen,
             ):
                 embedding_specs = []
-                use_cpu = (
-                    child.embedding_specs[0][3]
-                    == split_table_batched_embeddings_ops.ComputeDevice.CPU
-                )
+                use_cpu = child.embedding_specs[0][3] == ComputeDevice.CPU
                 for E, D, _, _ in child.embedding_specs:
                     weights_ty = self.quantize_type
                     if D % weights_ty.align_size() != 0:
@@ -174,9 +179,9 @@ def _process_split_embs(self, model: nn.Module) -> None:
                         pruned_weight.size()[0],
                         D,
                         weight_ty,
-                        split_table_batched_embeddings_ops.EmbeddingLocation.HOST
+                        EmbeddingLocation.HOST
                         if use_cpu
-                        else split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
+                        else EmbeddingLocation.DEVICE,
                     )
                 )
             index_remapping_list.append(index_remapping)
@@ -186,7 +191,7 @@ def _process_split_embs(self, model: nn.Module) -> None:
 
             is_fp8_weight = self.quantize_type == SparseType.FP8
 
-            q_child = split_table_batched_embeddings_ops.IntNBitTableBatchedEmbeddingBagsCodegen(
+            q_child = IntNBitTableBatchedEmbeddingBagsCodegen(
                 embedding_specs=new_embedding_specs,
                 index_remapping=index_remapping_list
                 if self.pruning_ratio is not None
fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py (1 addition & 1 deletion)
@@ -87,7 +87,7 @@ def b_indices(
     return b(to_device(indices, use_cpu))
 
 
-def generate_requests(
+def generate_requests(  # noqa C901
     iters: int,
     B: int,
     T: int,