make split_table_batched_embeddings_benchmark work on cpu
Summary: Make the device benchmark runnable on CPU. For now the code path switches based on torch.cuda.is_available(); later we can make it controllable via command-line args.

Reviewed By: jianyuh

Differential Revision: D27558925

fbshipit-source-id: 0af0d5fd4a199dab2bd26c897efe387f08f5d1a8
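
A possible shape for that follow-up, a command-line switch layered on the same torch.cuda.is_available() default, is sketched below. This is a hypothetical extension, not part of this commit: the --use-device flag and parse_device helper are illustrative names, and only click and torch (both already used by the benchmark) are assumed.

import click
import torch


# Hypothetical sketch of the command-line control mentioned in the summary;
# the flag name and helper are illustrative, not part of this commit.
def parse_device(choice: str) -> torch.device:
    # "auto" keeps this commit's behavior: CUDA when available, else CPU.
    if choice == "auto":
        return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    return torch.device(choice)


@click.command()
@click.option("--use-device", type=click.Choice(["auto", "cpu", "cuda"]), default="auto")
def bench(use_device: str) -> None:
    device = parse_device(use_device)
    click.echo(f"running benchmark on {device}")


if __name__ == "__main__":
    bench()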
jspark1105 authored and facebook-github-bot committed Apr 5, 2021
1 parent 62c7209 commit f5dbb5a
Showing 1 changed file with 85 additions and 67 deletions.
fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -6,14 +6,20 @@
# LICENSE file in the root directory of this source tree.

import logging
-from typing import Callable, List, Optional, Tuple
+import time
+from typing import Callable, Dict, List, Optional, Tuple

import click
import fbgemm_gpu.split_table_batched_embeddings_ops as split_table_batched_embeddings_ops
import numpy as np
import torch
-from fbgemm_gpu.split_table_batched_embeddings_ops import OptimType, SparseType
-from typing import Dict
+from fbgemm_gpu.split_table_batched_embeddings_ops import (
+    CacheAlgorithm,
+    ComputeDevice,
+    EmbeddingLocation,
+    OptimType,
+    SparseType,
+    SplitTableBatchedEmbeddingBagsCodegen,
+)

logging.basicConfig(level=logging.DEBUG)

@@ -28,6 +34,12 @@ def div_round_up(a: int, b: int) -> int:
return int((a + b - 1) // b) * b


+def get_device() -> torch.device:
+    return (
+        torch.cuda.current_device()
+        if torch.cuda.is_available()
+        else torch.device("cpu")
+    )


# Merged indices with shape (T, B, L) -> (flattened indices with shape
# (T * B * L), offsets with shape (T * B + 1))
def get_table_batched_offsets_from_dense(
@@ -37,8 +49,8 @@ def get_table_batched_offsets_from_dense(
lengths = np.ones((T, B)) * L
flat_lengths = lengths.flatten()
return (
-        merged_indices.long().contiguous().view(-1).cuda(),
-        torch.tensor(([0] + np.cumsum(flat_lengths).tolist())).long().cuda(),
+        merged_indices.long().contiguous().view(-1).to(get_device()),
+        torch.tensor(([0] + np.cumsum(flat_lengths).tolist())).long().to(get_device()),
)


@@ -61,31 +73,35 @@ def generate_requests(
low=0,
high=E,
size=(iters, T, B * L),
-            device=torch.cuda.current_device(),
+            device=get_device(),
dtype=torch.int32,
)
else:
all_indices = (
torch.as_tensor(np.random.zipf(a=alpha, size=(iters, T, B * L)))
-            .to(torch.cuda.current_device())
+            .to(get_device())
.int()
% E
)
for it in range(iters - 1):
for t in range(T):
-            reused_indices = torch.randperm(B * L, device=torch.cuda.current_device())[
+            reused_indices = torch.randperm(B * L, device=get_device())[
: int(B * L * reuse)
]
all_indices[it + 1, t, reused_indices] = all_indices[it, t, reused_indices]

rs = []
for it in range(iters):
-        weights_tensor = None if not weighted else torch.randn(
-            T * B * L,
-            device=torch.cuda.current_device(),
-            dtype=torch.float16
-            if weights_precision == SparseType.FP16
-            else torch.float32,
+        weights_tensor = (
+            None
+            if not weighted
+            else torch.randn(
+                T * B * L,
+                device=get_device(),
+                dtype=torch.float16
+                if weights_precision == SparseType.FP16
+                else torch.float32,
+            )
         )
rs.append(
get_table_batched_offsets_from_dense(all_indices[it].view(T, B, L))
@@ -99,15 +115,21 @@ def benchmark_requests(
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
f: Callable,
) -> float:
-    torch.cuda.synchronize()
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-    start_event.record()
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+    else:
+        start_time = time.time()
     for (indices, offsets, weights) in requests:
         f(indices, offsets, weights)
-    end_event.record()
-    torch.cuda.synchronize()
-    return (start_event.elapsed_time(end_event) * 1.0e-3) / len(requests)
+    if torch.cuda.is_available():
+        end_event.record()
+        torch.cuda.synchronize()
+        return (start_event.elapsed_time(end_event) * 1.0e-3) / len(requests)
+    else:
+        return (time.time() - start_time) / len(requests)


def benchmark_pipelined_requests(
@@ -203,7 +225,7 @@ def device(  # noqa C901
torch.tensor(
[1 if t in weighted_requires_grad_tables else 0 for t in range(T)]
)
-            .cuda()
+            .to(get_device())
.int()
)
else:
@@ -219,17 +241,21 @@ def device(  # noqa C901
optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD

if managed == "device":
-        managed_option = split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE
+        managed_option = (
+            EmbeddingLocation.DEVICE
+            if torch.cuda.is_available()
+            else EmbeddingLocation.HOST
+        )
     else:
-        managed_option = split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED
+        managed_option = EmbeddingLocation.MANAGED

-    emb = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
+    emb = SplitTableBatchedEmbeddingBagsCodegen(
[
(
E,
d,
managed_option,
-                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
+                ComputeDevice.CUDA if torch.cuda.is_available() else ComputeDevice.CPU,
)
for d in Ds
],
@@ -238,7 +264,7 @@ def device(  # noqa C901
eps=0.1,
weights_precision=weights_precision,
stochastic_rounding=stoc,
-    ).cuda()
+    ).to(get_device())
if weights_precision == SparseType.INT8:
emb.init_embedding_weights_uniform(-0.0003, 0.0003)

@@ -283,7 +309,7 @@ def device(  # noqa C901
f"T: {time_per_iter * 1.0e6:.0f}us"
)

-    grad_output = torch.randn(B, sum(Ds)).cuda()
+    grad_output = torch.randn(B, sum(Ds)).to(get_device())
# backward
time_per_iter = benchmark_requests(
requests,
@@ -352,13 +378,13 @@ def uvm(
D = np.average(Ds)
else:
Ds = [D] * T
-    emb_uvm = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
+    emb_uvm = SplitTableBatchedEmbeddingBagsCodegen(
[
(
E,
d,
-                split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
-                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
+                EmbeddingLocation.MANAGED,
+                ComputeDevice.CUDA,
)
for d in Ds[:T_uvm]
],
@@ -369,13 +395,13 @@
if weights_precision == SparseType.INT8:
emb_uvm.init_embedding_weights_uniform(-0.0003, 0.0003)

-    emb_gpu = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
+    emb_gpu = SplitTableBatchedEmbeddingBagsCodegen(
[
(
E,
d,
-                split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE,
-                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
+                EmbeddingLocation.DEVICE,
+                ComputeDevice.CUDA,
)
for d in Ds[T_uvm:]
],
@@ -386,27 +412,23 @@
if weights_precision == SparseType.INT8:
emb_gpu.init_embedding_weights_uniform(-0.0003, 0.0003)

-    emb_mixed = (
-        split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
-            [
-                (
-                    E,
-                    d,
-                    managed_option,
-                    split_table_batched_embeddings_ops.ComputeDevice.CUDA,
-                )
-                for (d, managed_option) in zip(
-                    Ds,
-                    [split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED]
-                    * T_uvm
-                    + [split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE]
-                    * T_gpu,
-                )
-            ],
-            weights_precision=weights_precision,
-            stochastic_rounding=stoc,
-        ).cuda()
-    )
+    emb_mixed = SplitTableBatchedEmbeddingBagsCodegen(
+        [
+            (
+                E,
+                d,
+                managed_option,
+                ComputeDevice.CUDA,
+            )
+            for (d, managed_option) in zip(
+                Ds,
+                [EmbeddingLocation.MANAGED] * T_uvm
+                + [EmbeddingLocation.DEVICE] * T_gpu,
+            )
+        ],
+        weights_precision=weights_precision,
+        stochastic_rounding=stoc,
+    ).cuda()

if weights_precision == SparseType.INT8:
emb_mixed.init_embedding_weights_uniform(-0.0003, 0.0003)
@@ -535,11 +557,7 @@ def cache(  # noqa C901
L = bag_size
E = num_embeddings
T = num_tables
-    cache_alg = (
-        split_table_batched_embeddings_ops.CacheAlgorithm.LRU
-        if cache_algorithm == "lru"
-        else split_table_batched_embeddings_ops.CacheAlgorithm.LFU
-    )
+    cache_alg = CacheAlgorithm.LRU if cache_algorithm == "lru" else CacheAlgorithm.LFU
if mixed:
Ds = [
div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
@@ -549,13 +567,13 @@
else:
Ds = [D] * T

-    emb_nc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
+    emb_nc = SplitTableBatchedEmbeddingBagsCodegen(
[
(
E,
d,
-                split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED,
-                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
+                EmbeddingLocation.MANAGED,
+                ComputeDevice.CUDA,
)
for d in Ds
],
@@ -567,13 +585,13 @@
if weights_precision == SparseType.INT8:
emb_nc.init_embedding_weights_uniform(-0.0003, 0.0003)

-    emb = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
+    emb = SplitTableBatchedEmbeddingBagsCodegen(
[
(
E,
d,
-                split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING,
-                split_table_batched_embeddings_ops.ComputeDevice.CUDA,
+                EmbeddingLocation.MANAGED_CACHING,
+                ComputeDevice.CUDA,
)
for d in Ds
],
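The timing pattern this commit introduces in benchmark_requests, CUDA events on GPU and wall-clock time on CPU, distills into a standalone helper. A minimal sketch assuming only torch; the benchmark name and n_iters parameter are illustrative, not part of the library:

import time
from typing import Callable

import torch


def benchmark(f: Callable[[], None], n_iters: int = 10) -> float:
    # Average seconds per call: CUDA events when a GPU is present
    # (GPU work is async, so wall-clock alone would under-measure),
    # plain wall-clock time otherwise.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
    else:
        start_time = time.time()
    for _ in range(n_iters):
        f()
    if torch.cuda.is_available():
        end_event.record()
        torch.cuda.synchronize()
        # elapsed_time() reports milliseconds; convert to seconds.
        return (start_event.elapsed_time(end_event) * 1.0e-3) / n_iters
    return (time.time() - start_time) / n_iters


a = torch.randn(1024, 1024, device="cuda" if torch.cuda.is_available() else "cpu")
print(f"{benchmark(lambda: a @ a) * 1.0e6:.0f}us per matmul")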

