Move batched unary embedding benchmark and refactoring (pytorch#981)
Summary:
Pull Request resolved: pytorch#981

- Move the batched unary embedding benchmark to OSS.
- Rename "fbgemm_gpu.bench.utils" to "bench_utils" for OSS compatibility (selected via an "open_source" flag check).
- Use FBGEMM's benchmark_torch_function as the common bench utility function in hpc; a sketch of the expected interface follows.
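
For context, benchmark_torch_function(f, args, iters) is used below as returning a tuple of (average elapsed seconds per iteration, last output). A minimal sketch of such a helper, assuming CUDA-event timing with a CPU-timer fallback (the actual bench_utils implementation may differ):

import time

import torch


def benchmark_torch_function(f, args, iters=100):
    """Return (average seconds per iteration, last output) for f(*args)."""
    # Warm-up run so one-time costs (CUDA context, caching allocators)
    # are excluded from the measurement.
    output = f(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        for _ in range(iters):
            output = f(*args)
        end_event.record()
        torch.cuda.synchronize()
        # elapsed_time() is in milliseconds; convert to seconds per iteration.
        elapsed = start_event.elapsed_time(end_event) * 1.0e-3 / iters
    else:
        start = time.time()
        for _ in range(iters):
            output = f(*args)
        elapsed = (time.time() - start) / iters
    return elapsed, output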

Reviewed By: yinghai

Differential Revision: D34854189

fbshipit-source-id: c57513a6e9b4cf4060ce3028036802f0b59950f7
jianyuh authored and facebook-github-bot committed Mar 14, 2022
1 parent abc7b74 commit 6c66609
Showing 4 changed files with 176 additions and 14 deletions.
160 changes: 160 additions & 0 deletions fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py
@@ -0,0 +1,160 @@
# pyre-unsafe
import functools
from math import sqrt
from typing import List, Tuple

import click
import fbgemm_gpu
import fbgemm_gpu.batched_unary_embeddings_ops as batched_unary_embeddings_ops
import numpy as np
import torch

open_source: bool = getattr(fbgemm_gpu, "open_source", False)

if open_source:
    # pyre-ignore[21]
    from bench_utils import benchmark_torch_function
else:
    from fbgemm_gpu.bench.bench_utils import benchmark_torch_function

    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")


def generate_unary_feature(
    batch_size: int, num_embeddings: int
) -> Tuple[List, List, List]:
    """Generate one random embedding index per sample, as parallel
    (lengths, offsets, indices) lists with a trailing offset."""
    lengths = []
    offsets = []
    indices = []
    offset = 0
    for _ in range(batch_size):
        n_indices = 1
        indices += (
            np.round(np.random.random(n_indices) * (num_embeddings - 1))
            .astype(int)
            .tolist()
        )
        offsets.append(offset)
        offset += 1
        lengths.append(n_indices)
    offsets.append(offset)
    return (lengths, offsets, indices)
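# Illustration (hypothetical values): generate_unary_feature(3, 10) might return
# lengths == [1, 1, 1], offsets == [0, 1, 2, 3], indices == [4, 0, 7]: exactly
# one index per sample, with the trailing offset closing the last bag.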


class MyModule(torch.nn.Module):
    """Reference implementation: one EmbeddingBag of dim 1 per (task, table) pair."""

    def __init__(self, num_tasks: int, hash_sizes: List[int]) -> None:
        super().__init__()
        self.num_tasks = num_tasks
        self.hash_sizes = hash_sizes
        self.emb_modules = torch.nn.ModuleList()
        for _ in range(num_tasks):
            for h in self.hash_sizes:
                emb = torch.nn.EmbeddingBag(
                    num_embeddings=h,
                    embedding_dim=1,
                    mode="sum",
                    sparse=False,
                    include_last_offset=True,
                )
                emb.weight = torch.nn.Parameter(
                    torch.empty([h, 1]).uniform_(-sqrt(1 / h), sqrt(1 / h))
                )
                self.emb_modules.append(emb)

    def forward(
        self, offsets: List[torch.Tensor], indices: List[torch.Tensor]
    ) -> torch.Tensor:
        tt_list = []
        for n in range(self.num_tasks):
            t_list = []
            for i in range(len(self.hash_sizes)):
                t = self.emb_modules[n * len(self.hash_sizes) + i](
                    offsets=offsets[i].long(), input=indices[i].long()
                )
                t_list.append(t)
            tt = torch.cat(t_list, dim=1)
            tt_list.append(tt)
        # Output shape: (num_tasks, batch_size, num_tables).
        return torch.cat(tt_list).view(self.num_tasks, -1, len(self.hash_sizes))


@click.command()
@click.option("--batch-size", default=512)
@click.option("--num-tables", default=2)
@click.option("--num-tasks", default=3)
@click.option("--repeats", default=100)
def main(batch_size, num_tables, num_tasks, repeats):
    device = torch.device("cuda", 0)
    torch.cuda.set_device(device)
    hash_sizes = list(np.random.choice(range(50, 250), size=(num_tables)))
    lengths = []
    offsets = []
    indices = []
    for h in hash_sizes:
        l, o, i = generate_unary_feature(batch_size, h)
        lengths.append(torch.IntTensor(l).to(device))
        offsets.append(torch.IntTensor(o).to(device))
        indices.append(torch.IntTensor(i).to(device))
    lengths_tensor = torch.cat(lengths)
    indices_tensor = torch.cat(indices)
    # Build a single offsets tensor for the batched op: offsets[0] = 0 and
    # offsets[k] = sum of the first k lengths (inclusive cumsum shifted by one).
    offsets_tensor = torch.zeros(
        lengths_tensor.numel() + 1,
        dtype=lengths_tensor.dtype,
        device=lengths_tensor.device,
    )
    offsets_tensor[1:] = torch.ops.fbgemm.asynchronous_inclusive_cumsum(
        lengths_tensor.view(-1)
    )

    # Forward: copy the reference weights into the batched op and check parity.
    ref_emb = MyModule(num_tasks, hash_sizes).to(device)
    unary_emb = batched_unary_embeddings_ops.BatchedUnaryEmbeddingBag(
        num_tasks, hash_sizes
    ).to(device)
    for i, param in enumerate(unary_emb.split_embedding_weights()):
        param.detach().copy_(ref_emb.emb_modules[i].weight)
    output_ref = ref_emb(offsets, indices)
    output = unary_emb(offsets_tensor, indices_tensor)
    torch.testing.assert_allclose(output_ref, output)

    # Backward: check that the weight gradients match as well.
    d_output = torch.randn([num_tasks, batch_size, len(hash_sizes)]).to(device) * 0.1
    output_ref.backward(d_output)
    output.backward(d_output)
    d_weight_ref = []
    for emb in ref_emb.emb_modules:
        d_weight_ref.append(emb.weight.grad)
    d_weight_ref = torch.cat(d_weight_ref).view(num_tasks, -1)
    d_weight = unary_emb.weight.grad
    torch.testing.assert_allclose(d_weight_ref, d_weight.squeeze())

    # Benchmarks. (An A100 has a 40 MB L2 cache, so these small tables may be
    # largely cache-resident.)
    elapse, _ = benchmark_torch_function(ref_emb, (offsets, indices), iters=repeats)
    print("PyTorch EmbeddingBag forward", elapse)

    elapse, _ = benchmark_torch_function(
        unary_emb,
        (offsets_tensor, indices_tensor),
        iters=repeats,
    )
    print("Batched Unary Emb forward", elapse)

    output = ref_emb(offsets, indices)
    output.backward(d_output, retain_graph=True)
    elapse, _ = benchmark_torch_function(
        functools.partial(output.backward, retain_graph=True),
        (d_output,),
        iters=repeats,
    )
    print("PyTorch EmbeddingBag backward", elapse)

    output = unary_emb(offsets_tensor, indices_tensor)
    elapse, _ = benchmark_torch_function(
        functools.partial(output.backward, retain_graph=True),
        (d_output,),
        iters=repeats,
    )
    print("Batched Unary Emb backward", elapse)


if __name__ == "__main__":
    main()
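
Assuming fbgemm_gpu is installed and a CUDA device is available, the new benchmark can be invoked with the click options defined above, for example:

python fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py --batch-size 512 --num-tables 2 --num-tasks 3 --repeats 100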
fbgemm_gpu/bench/utils.py → fbgemm_gpu/bench/bench_utils.py: file renamed without changes.
18 changes: 8 additions & 10 deletions fbgemm_gpu/bench/merge_embeddings_benchmark.py
@@ -12,21 +12,19 @@
 from typing import Tuple, List
 
 import click
+import fbgemm_gpu
 import numpy as np
 import tabulate
 import torch
-from fbgemm_gpu.bench.utils import benchmark_torch_function
 
-try:
+open_source: bool = getattr(fbgemm_gpu, "open_source", False)
+
+if open_source:
     # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
-except Exception:
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings")
-    torch.ops.load_library(
-        "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu"
-    )
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+    from bench_utils import benchmark_torch_function
+else:
+    from fbgemm_gpu.bench.bench_utils import benchmark_torch_function
+
 
 from fbgemm_gpu.split_table_batched_embeddings_ops import (
     SparseType,
12 changes: 8 additions & 4 deletions fbgemm_gpu/bench/quantize_ops_benchmark.py
@@ -8,15 +8,19 @@
 import random
 
 import click
+import fbgemm_gpu
 import torch
-from fbgemm_gpu.bench.utils import benchmark_torch_function
 
 logging.basicConfig(level=logging.DEBUG)
 
-try:
+open_source: bool = getattr(fbgemm_gpu, "open_source", False)
+
+if open_source:
     # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
-except Exception:
+    from bench_utils import benchmark_torch_function
+else:
+    from fbgemm_gpu.bench.bench_utils import benchmark_torch_function
+
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
