Pyre for fbgemm (pytorch#558)
Summary: Pull Request resolved: pytorch#558

Reviewed By: jianyuh

Differential Revision: D27165949

fbshipit-source-id: d88b0fe85913ec9b0714d28a8802e28aa7bcbe98
r-barnes authored and facebook-github-bot committed Mar 24, 2021
1 parent a8097ee commit 203f7ff
Showing 5 changed files with 330 additions and 156 deletions.
12 changes: 8 additions & 4 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -13,10 +13,11 @@
import numpy as np
import torch
from fbgemm_gpu.split_table_batched_embeddings_ops import OptimType, SparseType
+from typing import Dict

logging.basicConfig(level=logging.DEBUG)

-PRECISION_SIZE_MULTIPLIER = {
+PRECISION_SIZE_MULTIPLIER: Dict[SparseType, int] = {
SparseType.FP32: 4,
SparseType.FP16: 2,
SparseType.INT8: 1,
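
Note: the table above maps each SparseType to its size in bytes per element; the benchmark uses it to convert lookup counts into memory traffic. A minimal sketch of that use, with hypothetical batch and dimension values (not taken from this diff):

def read_bytes(num_lookups: int, dim: int, bytes_per_elem: int) -> int:
    # bytes_per_elem comes from PRECISION_SIZE_MULTIPLIER[dtype] (4, 2, or 1 above).
    return num_lookups * dim * bytes_per_elem

assert read_bytes(2048 * 20, 128, 2) == 10_485_760  # ~10 MiB of reads at FP16
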
@@ -98,8 +99,9 @@ def generate_requests(

def benchmark_requests(
requests: List[Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]],
+# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
f: Callable,
-):
+) -> float:
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
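
For context, these benchmark helpers follow the standard CUDA-event timing idiom; the full bodies are collapsed in this view, so here is a hedged reconstruction of the pattern (an illustration, not the exact fbgemm code):

import torch
from typing import Callable

def time_gpu_seconds(fn: Callable[[], None]) -> float:
    # Drain any pending GPU work so the measurement starts clean.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    fn()
    end_event.record()
    # Wait for the recorded work to finish before reading the timer.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) * 1.0e-3  # ms -> s
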
@@ -113,9 +115,11 @@

def benchmark_pipelined_requests(
requests: List[Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]],
+# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
f: Callable,
+# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
g: Callable,
-):
+) -> Tuple[float, float]:
torch.cuda.synchronize()
start_events = [
(torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
@@ -150,7 +154,7 @@ def benchmark_pipelined_requests(


@click.group()
-def cli():
+def cli() -> None:
pass


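The `# pyre-fixme[24]` comments above suppress Pyre's complaint that bare `Callable` lacks parameter and return types. A sketch of what a fully parameterized signature could look like, assuming `f` is invoked once per request tuple (this illustrates the error code, not what the commit does):

from typing import Callable, List, Optional, Tuple
import torch

Request = Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]

def benchmark_requests_typed(
    requests: List[Request],
    # Fully typed: takes (indices, offsets, per_sample_weights), returns anything.
    f: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], object],
) -> float:
    ...  # same timing body as benchmark_requests above
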
121 changes: 64 additions & 57 deletions fbgemm_gpu/codegen/embedding_backward_code_generator.py
@@ -6,11 +6,21 @@
# LICENSE file in the root directory of this source tree.

import argparse
import collections
import os
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple

import jinja2

+args: argparse.Namespace
+_: List[str]
+TENSOR: int
+INT_TENSOR: int
+LONG_TENSOR: int
+INT: int
+FLOAT: int


parser = argparse.ArgumentParser()
# By default the source template files are in the same folder as
# embedding_backward_code_generator.py;
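
The block of bare annotations added above (`args: argparse.Namespace`, `TENSOR: int`, ...) pre-declares module globals that are only assigned further down (e.g. `TENSOR, ... = range(5)`), so Pyre knows their types at first use. A minimal standalone illustration of the idiom, with hypothetical names:

import argparse
from typing import List

# Declared first, assigned later; the checker trusts the declared types.
args: argparse.Namespace
leftover: List[str]

parser = argparse.ArgumentParser()
args, leftover = parser.parse_known_args([])
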
@@ -26,72 +36,72 @@
env.globals["dense"] = False


-def write(filename, s):
+def write(filename: str, s: str) -> None:
with open(os.path.join(args.install_dir, filename), "w") as f:
f.write(s)


-def _arg_constructor(type: str, name: str, gpu: bool = True, precision: int = 32):
+def _arg_constructor(type: str, name: str, gpu: bool = True, precision: int = 32) -> str:
return (
f"{name}.packed_accessor{precision}<{type}, 1, RestrictPtrTraits>()"
if gpu
else f"auto {name}_accessor = {name}.accessor<{type}, 1>()"
)


-def _arg(type: str, name: str, precision: int = 32):
+def _arg(type: str, name: str, precision: int = 32) -> str:
return f"PackedTensorAccessor{precision}<{type}, 1, RestrictPtrTraits> {name}"


-def acc_cache_tensor_arg_constructor(name):
+def acc_cache_tensor_arg_constructor(name: str) -> str:
return _arg_constructor("acc_type<cache_t, true>", name, precision=64)


-def acc_cache_tensor_arg(name):
+def acc_cache_tensor_arg(name: str) -> str:
return _arg("acc_type<cache_t, true>", name, precision=64)


-def long_tensor_arg_constructor(name, gpu=True):
+def long_tensor_arg_constructor(name: str, gpu: bool = True) -> str:
return _arg_constructor("int64_t", name, gpu=gpu)


-def long_tensor_arg(name):
+def long_tensor_arg(name: str) -> str:
return _arg("int64_t", name)


-def int_tensor_arg_constructor(name, gpu=True):
+def int_tensor_arg_constructor(name: str, gpu: bool = True) -> str:
return _arg_constructor("int32_t", name, gpu=gpu)


-def int_tensor_arg(name):
+def int_tensor_arg(name: str) -> str:
return _arg("int32_t", name)


-def host_accessor_constructor(name):
+def host_accessor_constructor(name: str) -> str:
return _arg_constructor("acc_type<scalar_t, true>", name, gpu=False)


-def tensor_arg(name):
+def tensor_arg(name: str) -> str:
return f"Tensor {name}"


-def double_arg(name):
+def double_arg(name: str) -> str:
return f"double {name}"


-def float_arg(name):
+def float_arg(name: str) -> str:
return f"float {name}"


-def int64_arg(name):
+def int64_arg(name: str) -> str:
return f"int64_t {name}"


-def int_arg(name):
+def int_arg(name: str) -> str:
return f"int {name}"
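
Each helper above emits a C++ argument declaration or accessor-construction string for the generated CUDA/CPU kernels. Example outputs, derived directly from the definitions above:

acc_cache_tensor_arg("momentum1")
# -> "PackedTensorAccessor64<acc_type<cache_t, true>, 1, RestrictPtrTraits> momentum1"

acc_cache_tensor_arg_constructor("momentum1")
# -> "momentum1.packed_accessor64<acc_type<cache_t, true>, 1, RestrictPtrTraits>()"

int64_arg("iter")
# -> "int64_t iter"
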


-def generate(**kwargs):
+def generate(**kwargs: Any) -> None:
gen_args = kwargs["args"]

# Generates CUDA variants.
@@ -142,27 +152,24 @@ def generate(**kwargs):
)


-Args = collections.namedtuple(
-    "Args",
-    [
-        "split_kernel_args",
-        "split_kernel_arg_constructors",
-        "split_host_accessor_constructors",
-        "split_function_args",
-        "split_saved_tensors",
-        "split_tensors",
-        "saved_data",
-        "split_function_arg_names",
-        "split_function_schemas",
-        "split_variables",
-    ],
-)
+@dataclass
+class Args:
+    split_kernel_args: List[str]
+    split_kernel_arg_constructors: List[str]
+    split_host_accessor_constructors: List[str]
+    split_function_args: List[str]
+    split_saved_tensors: List[str]
+    split_tensors: List[str]
+    saved_data: List[Tuple[str, str]]
+    split_function_arg_names: List[str]
+    split_function_schemas: List[str]
+    split_variables: List[str]

TENSOR, INT_TENSOR, LONG_TENSOR, INT, FLOAT = range(5)
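
Replacing the `namedtuple` with a `@dataclass` is what gives each field a checkable element type; with `namedtuple`, Pyre sees every attribute as `Any`. A small standalone illustration of the difference:

import collections
from dataclasses import dataclass
from typing import List

# Before: attribute types are opaque to the checker.
UntypedArgs = collections.namedtuple("UntypedArgs", ["split_tensors"])

# After: split_tensors is known to be List[str].
@dataclass
class TypedArgs:
    split_tensors: List[str]
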


-def make_args(arg_spec):
-def make_kernel_arg(ty, name):
+def make_args(arg_spec: List[Tuple[int, str]]) -> Dict[str, Any]:
+def make_kernel_arg(ty: int, name: str) -> str:
return {
TENSOR: acc_cache_tensor_arg,
INT_TENSOR: int_tensor_arg,
@@ -171,7 +178,7 @@ def make_kernel_arg(ty, name):
FLOAT: float_arg,
}[ty](name)

-def make_kernel_arg_constructor(ty, name):
+def make_kernel_arg_constructor(ty: int, name: str) -> str:
return {
TENSOR: acc_cache_tensor_arg_constructor,
INT_TENSOR: int_tensor_arg_constructor,
@@ -180,7 +187,7 @@ def make_kernel_arg_constructor(ty, name):
FLOAT: lambda x: x,
}[ty](name)

-def make_host_accessor_constructor(ty, name):
+def make_host_accessor_constructor(ty: int, name: str) -> str:
return {
TENSOR: host_accessor_constructor,
INT_TENSOR: lambda x: int_tensor_arg_constructor(x, gpu=False),
@@ -189,7 +196,7 @@ def make_host_accessor_constructor(ty, name):
FLOAT: lambda x: "",
}[ty](name)

-def make_function_arg(ty, name):
+def make_function_arg(ty: int, name: str) -> str:
return {
TENSOR: tensor_arg,
INT_TENSOR: tensor_arg,
@@ -198,7 +205,7 @@ def make_function_arg(ty, name):
FLOAT: double_arg,
}[ty](name)

-def make_function_schema_arg(ty, name):
+def make_function_schema_arg(ty: int, name: str) -> str:
return {
TENSOR: tensor_arg,
INT_TENSOR: tensor_arg,
@@ -207,10 +214,10 @@ def make_function_schema_arg(ty, name):
FLOAT: float_arg,
}[ty](name)

-def make_ivalue_cast(ty):
+def make_ivalue_cast(ty: int) -> str:
return {INT: "toInt", FLOAT: "toDouble"}[ty]

-def make_args_for_compute_device(split_arg_spec):
+def make_args_for_compute_device(split_arg_spec: List[Tuple[int, str]]) -> Args:
return Args(
split_kernel_args=[
make_kernel_arg(ty, name) for (ty, name) in split_arg_spec
@@ -275,7 +282,7 @@ def make_args_for_compute_device(split_arg_spec):
return {"cpu": cpu, "cuda": cuda}
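
`make_args` takes a list of `(type-tag, name)` pairs and returns an `Args` bundle per compute device. A hypothetical call, with a spec invented for illustration (the real optimizer specs live in the functions below):

spec = [
    (TENSOR, "momentum1"),    # optimizer state tensor
    (FLOAT, "eps"),           # scalar hyperparameter
    (FLOAT, "learning_rate"),
]
per_device = make_args(spec)
cuda_args = per_device["cuda"]  # Args of CUDA kernel signature strings
cpu_args = per_device["cpu"]    # Args of the CPU variants
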


-def adagrad():
+def adagrad() -> None:
split_weight_update = """
Vec4T<cache_t> m_t(&momentum1[idx * D + d]);
m_t.acc.x += grad.acc.x * grad.acc.x;
@@ -310,7 +317,7 @@ def adagrad():
)


-def table_info_precomputation(momentum_prefix="momentum1"):
+def table_info_precomputation(momentum_prefix: str = "momentum1") -> str:
template = """
// table_begin -> (E, D, {momentum_prefix}_row_begin).
std::map<int64_t, std::tuple<int64_t, int64_t, int64_t>> table_info_map;
@@ -332,7 +339,7 @@ def table_info_precomputation(momentum_prefix="momentum1"):
return template.replace("{momentum_prefix}", momentum_prefix)


-def rowwise_adagrad():
+def rowwise_adagrad() -> None:
split_weight_update = """
weight_new.fma_(grad, -multiplier);
"""
@@ -384,7 +391,7 @@ def rowwise_adagrad():
)


-def approx_rowwise_adagrad():
+def approx_rowwise_adagrad() -> None:
split_weight_update = """
weight_new.fma_(grad, -multiplier);
assert(false); // approx rowwise AdaGrad is not supported on GPU
@@ -424,7 +431,7 @@ def approx_rowwise_adagrad():
)


-def sgd():
+def sgd() -> None:
split_weight_update = """
weight_new.fma_(grad, -learning_rate);
"""
@@ -443,7 +450,7 @@ def sgd():
)


-def approx_sgd():
+def approx_sgd() -> None:
split_weight_update = """
// approx_sgd not supported for GPU.
// Just do the same thing as exact sgd to avoid unused variable warning.
@@ -463,7 +470,7 @@ def approx_sgd():
)


-def lamb():
+def lamb() -> None:
split_precomputation = """
acc_type<cache_t, true> weight_sum_sq = 0.0;
acc_type<cache_t, true> rtw_sum_sq = 0.0;
@@ -533,7 +540,7 @@ def lamb():
)


-def partial_rowwise_lamb():
+def partial_rowwise_lamb() -> None:
split_precomputation = """
acc_type<cache_t, true> g_local_sum_square = 0.0;
@@ -619,7 +626,7 @@ def partial_rowwise_lamb():
)


-def adam():
+def adam() -> None:
split_weight_update = """
Vec4T<cache_t> m_t(&momentum1[idx * D + d]);
m_t.acc.x *= beta1;
@@ -669,7 +676,7 @@ def adam():
)


-def partial_rowwise_adam():
+def partial_rowwise_adam() -> None:
split_precomputation = """
acc_type<cache_t, true> g_local_sum_square = 0.0;
#pragma unroll kMaxVecsPerThread
@@ -729,7 +736,7 @@ def partial_rowwise_adam():
)


-def lars_sgd():
+def lars_sgd() -> None:
split_precomputation = """
acc_type<cache_t, true> weight_sum_sq = 0.0;
acc_type<cache_t, true> grad_sum_sq = 0.0;
@@ -787,7 +794,7 @@ def lars_sgd():
)


-def forward_split():
+def forward_split() -> None:
template = env.get_template("embedding_forward_split_template.cu")

src_cu = template.render(weighted=False)
@@ -801,15 +808,15 @@ def forward_split():
write("gen_embedding_forward_dense_weighted_codegen_cuda.cu", src_cu)


-def backward_indices():
+def backward_indices() -> None:
template = env.get_template("embedding_backward_split_indice_weights_template.cu")
src_cu = template.render()
write("gen_embedding_backward_split_indice_weights_codegen_cuda.cu", src_cu)
src_cu = template.render(dense=True)
write("gen_embedding_backward_dense_indice_weights_codegen_cuda.cu", src_cu)


-def backward_dense():
+def backward_dense() -> None:
generate(
optimizer="dense",
dense=True,
@@ -821,13 +828,13 @@ def backward_dense():
)


-def gen__init__py():
+def gen__init__py() -> None:
template = env.get_template("__init__.template")
src_py = template.render()
write("__init__.py", src_py)


-def emb_codegen(install_dir=None, is_fbcode=True):
+def emb_codegen(install_dir: Optional[str] = None, is_fbcode: bool = True) -> None:
if install_dir is not None and len(install_dir) != 0:
args.install_dir = install_dir
args.is_fbcode = is_fbcode
@@ -848,7 +855,7 @@ def emb_codegen(install_dir=None, is_fbcode=True):
gen__init__py()


-def main():
+def main() -> None:
emb_codegen()
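
Based only on the signature shown above, the generator can also be driven directly from Python; the output directory here is illustrative:

# Hypothetical: emit the generated sources into a local build directory.
emb_codegen(install_dir="build/fbgemm_codegen", is_fbcode=False)
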


(The remaining three of the five changed files were not loaded in this view.)
