allow all uvm table placement in split_table_batched_embeddings_benchmark (pytorch#594)

Summary:
Pull Request resolved: pytorch#594

Although the UVM benchmark path allows T_uvm to be equal to T, the benchmark crashes in that case because the code always expects T_gpu > 0. This diff moves the non-UVM-allocated portions of the benchmark behind a T_gpu > 0 check, enabling benchmarking of the T_uvm == T case.
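
A minimal sketch of the guard pattern this change introduces, with placeholder helpers standing in for the real calls (build_tables and run_forward below are hypothetical; the actual benchmark constructs SplitTableBatchedEmbeddingBagsCodegen instances and times them with benchmark_requests):

```python
# Self-contained sketch of the control flow this diff introduces.
# build_tables / run_forward are illustrative placeholders, not the real
# SplitTableBatchedEmbeddingBagsCodegen or benchmark_requests APIs.

def build_tables(num_tables: int, location: str) -> str:
    # Stand-in for constructing embedding tables at the given location.
    return f"{num_tables} tables @ {location}"

def run_forward(label: str, tables: str) -> None:
    # Stand-in for benchmarking a forward pass and logging bandwidth/latency.
    print(f"{label}: benchmarking {tables}")

def uvm_benchmark(num_tables: int, uvm_tables: int) -> None:
    T, T_uvm = num_tables, uvm_tables
    assert 0 < T_uvm <= T, "T_uvm must be in (0, T]"
    T_gpu = T - T_uvm

    # The UVM-resident tables are always benchmarked.
    run_forward("UVM Forward", build_tables(T_uvm, "managed"))

    # Device-resident and mixed paths only exist when some tables stay on
    # the GPU; previously they ran unconditionally and crashed when T_uvm == T.
    if T_gpu > 0:
        run_forward("GPU Forward", build_tables(T_gpu, "device"))
        run_forward("Mixed Forward", build_tables(T, "managed + device"))

uvm_benchmark(num_tables=4, uvm_tables=4)  # the all-UVM case now runs
```

Keeping the UVM path outside the guard preserves the existing behavior when T_gpu > 0 while making the all-UVM configuration valid.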

Reviewed By: jianyuh

Differential Revision: D27856247

fbshipit-source-id: e4012004c3540fc89fa4eb2fd08e2cc550b79f5e
gsethi523 authored and facebook-github-bot committed Apr 19, 2021
1 parent 625201f commit 698c2dc
1 changed file with 97 additions and 89 deletions: fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -365,6 +365,7 @@ def uvm(
T = num_tables
T_uvm = uvm_tables
assert T_uvm <= T
assert T_uvm > 0, f"T_uvm specified {T_uvm} <= 0. If not testing UVM, please use device benchmark."
T_gpu = T - T_uvm
L_uvm = uvm_bag_size

@@ -393,43 +394,44 @@ def uvm(
if weights_precision == SparseType.INT8:
emb_uvm.init_embedding_weights_uniform(-0.0003, 0.0003)

emb_gpu = SplitTableBatchedEmbeddingBagsCodegen(
[
(
E,
d,
EmbeddingLocation.DEVICE,
ComputeDevice.CUDA,
)
for d in Ds[T_uvm:]
],
weights_precision=weights_precision,
stochastic_rounding=stoc,
).cuda()

if weights_precision == SparseType.INT8:
emb_gpu.init_embedding_weights_uniform(-0.0003, 0.0003)

emb_mixed = SplitTableBatchedEmbeddingBagsCodegen(
[
(
E,
d,
managed_option,
ComputeDevice.CUDA,
)
for (d, managed_option) in zip(
Ds,
[EmbeddingLocation.MANAGED] * T_uvm
+ [EmbeddingLocation.DEVICE] * T_gpu,
)
],
weights_precision=weights_precision,
stochastic_rounding=stoc,
).cuda()

if weights_precision == SparseType.INT8:
emb_mixed.init_embedding_weights_uniform(-0.0003, 0.0003)
if T_gpu > 0:
emb_gpu = SplitTableBatchedEmbeddingBagsCodegen(
[
(
E,
d,
EmbeddingLocation.DEVICE,
ComputeDevice.CUDA,
)
for d in Ds[T_uvm:]
],
weights_precision=weights_precision,
stochastic_rounding=stoc,
).cuda()

if weights_precision == SparseType.INT8:
emb_gpu.init_embedding_weights_uniform(-0.0003, 0.0003)

emb_mixed = SplitTableBatchedEmbeddingBagsCodegen(
[
(
E,
d,
managed_option,
ComputeDevice.CUDA,
)
for (d, managed_option) in zip(
Ds,
[EmbeddingLocation.MANAGED] * T_uvm
+ [EmbeddingLocation.DEVICE] * T_gpu,
)
],
weights_precision=weights_precision,
stochastic_rounding=stoc,
).cuda()

if weights_precision == SparseType.INT8:
emb_mixed.init_embedding_weights_uniform(-0.0003, 0.0003)

requests_uvm = generate_requests(
iters,
@@ -442,46 +444,21 @@ def uvm(
weights_precision=weights_precision,
weighted=weighted,
)
requests_gpu = generate_requests(
iters,
B,
T_gpu,
L,
E,
reuse=reuse,
alpha=alpha,
weights_precision=weights_precision,
weighted=False,
)
requests = []
for rs_uvm, rs_gpu in zip(requests_uvm, requests_gpu):
indices = torch.cat([rs_uvm[0], rs_gpu[0]])
lengths = [L_uvm] * (T_uvm * B) + [L] * (T_gpu * B)
offsets = torch.tensor(([0] + np.cumsum(lengths).tolist())).int().cuda()
per_sample_weights = None
if weighted:
assert (this_rs_uvm_weights := rs_uvm[2]) is not None
assert (this_rs_gpu_weights := rs_gpu[2]) is not None
per_sample_weights = torch.cat([this_rs_uvm_weights, this_rs_gpu_weights])
requests.append((indices, offsets, per_sample_weights))

# forward
time_per_iter = benchmark_requests(
requests_gpu,
lambda indices, offsets, per_sample_weights: emb_gpu.forward(
indices.long(),
offsets.long(),
per_sample_weights,
),
)
param_size_multiplier = PRECISION_SIZE_MULTIPLIER[weights_precision]
if T_gpu > 0:
requests_gpu = generate_requests(
iters,
B,
T_gpu,
L,
E,
reuse=reuse,
alpha=alpha,
weights_precision=weights_precision,
weighted=False,
)

logging.info(
f"GPU Forward, B: {B}, "
f"E: {E}, T: {T_gpu}, D: {D}, L: {L}, W: {weighted}, "
f"BW: {param_size_multiplier * B * sum(Ds[T_uvm:]) * L / time_per_iter / 1.0e9: .2f}GB/s, " # noqa: B950
f"T: {time_per_iter * 1.0e6:.0f}us"
)
param_size_multiplier = PRECISION_SIZE_MULTIPLIER[weights_precision]

time_per_iter = benchmark_requests(
requests_uvm,
@@ -498,20 +475,51 @@ def uvm(
f"T: {time_per_iter * 1.0e6:.0f}us"
)

time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb_mixed.forward(
indices.long(),
offsets.long(),
per_sample_weights,
),
)
logging.info(
f"Mixed Forward, B: {B}, "
f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
f"BW: {param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f}GB/s, " # noqa: B950
f"T: {time_per_iter * 1.0e6:.0f}us"
)
if T_gpu > 0:
requests = []
for rs_uvm, rs_gpu in zip(requests_uvm, requests_gpu):
indices = torch.cat([rs_uvm[0], rs_gpu[0]])
lengths = [L_uvm] * (T_uvm * B) + [L] * (T_gpu * B)
offsets = torch.tensor(([0] + np.cumsum(lengths).tolist())).int().cuda()
per_sample_weights = None
if weighted:
assert (this_rs_uvm_weights := rs_uvm[2]) is not None
assert (this_rs_gpu_weights := rs_gpu[2]) is not None
per_sample_weights = torch.cat([this_rs_uvm_weights, this_rs_gpu_weights])
requests.append((indices, offsets, per_sample_weights))

# forward
time_per_iter = benchmark_requests(
requests_gpu,
lambda indices, offsets, per_sample_weights: emb_gpu.forward(
indices.long(),
offsets.long(),
per_sample_weights,
),
)

logging.info(
f"GPU Forward, B: {B}, "
f"E: {E}, T: {T_gpu}, D: {D}, L: {L}, W: {weighted}, "
f"BW: {param_size_multiplier * B * sum(Ds[T_uvm:]) * L / time_per_iter / 1.0e9: .2f}GB/s, " # noqa: B950
f"T: {time_per_iter * 1.0e6:.0f}us"
)


time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb_mixed.forward(
indices.long(),
offsets.long(),
per_sample_weights,
),
)
logging.info(
f"Mixed Forward, B: {B}, "
f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
f"BW: {param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f}GB/s, " # noqa: B950
f"T: {time_per_iter * 1.0e6:.0f}us"
)


@cli.command()
