From 203f7ff6e07d62b042e7d755fd1f4789d978e4d1 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Wed, 24 Mar 2021 11:11:04 -0700 Subject: [PATCH] Pyre for fbgemm (#558) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/558 Reviewed By: jianyuh Differential Revision: D27165949 fbshipit-source-id: d88b0fe85913ec9b0714d28a8802e28aa7bcbe98 --- ...plit_table_batched_embeddings_benchmark.py | 12 +- .../embedding_backward_code_generator.py | 121 ++++++----- .../fbgemm_gpu/split_embedding_configs.py | 2 + .../split_table_batched_embeddings_ops.py | 153 +++++++++++--- .../split_table_batched_embeddings_test.py | 198 ++++++++++++------ 5 files changed, 330 insertions(+), 156 deletions(-) diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index af2d40094e..43e588250d 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -13,10 +13,11 @@ import numpy as np import torch from fbgemm_gpu.split_table_batched_embeddings_ops import OptimType, SparseType +from typing import Dict logging.basicConfig(level=logging.DEBUG) -PRECISION_SIZE_MULTIPLIER = { +PRECISION_SIZE_MULTIPLIER: Dict[SparseType, int] = { SparseType.FP32: 4, SparseType.FP16: 2, SparseType.INT8: 1, @@ -98,8 +99,9 @@ def generate_requests( def benchmark_requests( requests: List[Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. f: Callable, -): +) -> float: torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) @@ -113,9 +115,11 @@ def benchmark_requests( def benchmark_pipelined_requests( requests: List[Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. f: Callable, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. g: Callable, -): +) -> Tuple[float, float]: torch.cuda.synchronize() start_events = [ (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) @@ -150,7 +154,7 @@ def benchmark_pipelined_requests( @click.group() -def cli(): +def cli() -> None: pass diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py index da6a7b8c00..8756464333 100644 --- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py +++ b/fbgemm_gpu/codegen/embedding_backward_code_generator.py @@ -6,11 +6,21 @@ # LICENSE file in the root directory of this source tree. import argparse -import collections import os +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple import jinja2 +args: argparse.Namespace +_: List[str] +TENSOR: int +INT_TENSOR: int +LONG_TENSOR: int +INT: int +FLOAT: int + + parser = argparse.ArgumentParser() # By default the source template files are in the same folder as # embedding_backward_code_generator.py; @@ -26,12 +36,12 @@ env.globals["dense"] = False -def write(filename, s): +def write(filename: str, s: str) -> None: with open(os.path.join(args.install_dir, filename), "w") as f: f.write(s) -def _arg_constructor(type: str, name: str, gpu: bool = True, precision: int = 32): +def _arg_constructor(type: str, name: str, gpu: bool = True, precision: int = 32) -> str: return ( f"{name}.packed_accessor{precision}<{type}, 1, RestrictPtrTraits>()" if gpu @@ -39,59 +49,59 @@ def _arg_constructor(type: str, name: str, gpu: bool = True, precision: int = 32 ) -def _arg(type: str, name: str, precision: int = 32): +def _arg(type: str, name: str, precision: int = 32) -> str: return f"PackedTensorAccessor{precision}<{type}, 1, RestrictPtrTraits> {name}" -def acc_cache_tensor_arg_constructor(name): +def acc_cache_tensor_arg_constructor(name: str) -> str: return _arg_constructor("acc_type", name, precision=64) -def acc_cache_tensor_arg(name): +def acc_cache_tensor_arg(name: str) -> str: return _arg("acc_type", name, precision=64) -def long_tensor_arg_constructor(name, gpu=True): +def long_tensor_arg_constructor(name: str, gpu: bool = True) -> str: return _arg_constructor("int64_t", name, gpu=gpu) -def long_tensor_arg(name): +def long_tensor_arg(name: str) -> str: return _arg("int64_t", name) -def int_tensor_arg_constructor(name, gpu=True): +def int_tensor_arg_constructor(name: str, gpu: bool = True) -> str: return _arg_constructor("int32_t", name, gpu=gpu) -def int_tensor_arg(name): +def int_tensor_arg(name: str) -> str: return _arg("int32_t", name) -def host_accessor_constructor(name): +def host_accessor_constructor(name: str) -> str: return _arg_constructor("acc_type", name, gpu=False) -def tensor_arg(name): +def tensor_arg(name: str) -> str: return f"Tensor {name}" -def double_arg(name): +def double_arg(name: str) -> str: return f"double {name}" -def float_arg(name): +def float_arg(name: str) -> str: return f"float {name}" -def int64_arg(name): +def int64_arg(name: str) -> str: return f"int64_t {name}" -def int_arg(name): +def int_arg(name: str) -> str: return f"int {name}" -def generate(**kwargs): +def generate(**kwargs: Any) -> None: gen_args = kwargs["args"] # Generates CUDA variants. @@ -142,27 +152,24 @@ def generate(**kwargs): ) -Args = collections.namedtuple( - "Args", - [ - "split_kernel_args", - "split_kernel_arg_constructors", - "split_host_accessor_constructors", - "split_function_args", - "split_saved_tensors", - "split_tensors", - "saved_data", - "split_function_arg_names", - "split_function_schemas", - "split_variables", - ], -) +@dataclass +class Args: + split_kernel_args: List[str] + split_kernel_arg_constructors: List[str] + split_host_accessor_constructors: List[str] + split_function_args: List[str] + split_saved_tensors: List[str] + split_tensors: List[str] + saved_data: List[Tuple[str, str]] + split_function_arg_names: List[str] + split_function_schemas: List[str] + split_variables: List[str] TENSOR, INT_TENSOR, LONG_TENSOR, INT, FLOAT = range(5) -def make_args(arg_spec): - def make_kernel_arg(ty, name): +def make_args(arg_spec: List[Tuple[int,str]]) -> Dict[str, Any]: + def make_kernel_arg(ty: int, name: str) -> str: return { TENSOR: acc_cache_tensor_arg, INT_TENSOR: int_tensor_arg, @@ -171,7 +178,7 @@ def make_kernel_arg(ty, name): FLOAT: float_arg, }[ty](name) - def make_kernel_arg_constructor(ty, name): + def make_kernel_arg_constructor(ty: int, name: str) -> str: return { TENSOR: acc_cache_tensor_arg_constructor, INT_TENSOR: int_tensor_arg_constructor, @@ -180,7 +187,7 @@ def make_kernel_arg_constructor(ty, name): FLOAT: lambda x: x, }[ty](name) - def make_host_accessor_constructor(ty, name): + def make_host_accessor_constructor(ty: int, name: str) -> str: return { TENSOR: host_accessor_constructor, INT_TENSOR: lambda x: int_tensor_arg_constructor(x, gpu=False), @@ -189,7 +196,7 @@ def make_host_accessor_constructor(ty, name): FLOAT: lambda x: "", }[ty](name) - def make_function_arg(ty, name): + def make_function_arg(ty: int, name: str) -> str: return { TENSOR: tensor_arg, INT_TENSOR: tensor_arg, @@ -198,7 +205,7 @@ def make_function_arg(ty, name): FLOAT: double_arg, }[ty](name) - def make_function_schema_arg(ty, name): + def make_function_schema_arg(ty: int, name: str) -> str: return { TENSOR: tensor_arg, INT_TENSOR: tensor_arg, @@ -207,10 +214,10 @@ def make_function_schema_arg(ty, name): FLOAT: float_arg, }[ty](name) - def make_ivalue_cast(ty): + def make_ivalue_cast(ty: int) -> str: return {INT: "toInt", FLOAT: "toDouble"}[ty] - def make_args_for_compute_device(split_arg_spec): + def make_args_for_compute_device(split_arg_spec: List[Tuple[int, str]]) -> Args: return Args( split_kernel_args=[ make_kernel_arg(ty, name) for (ty, name) in split_arg_spec @@ -275,7 +282,7 @@ def make_args_for_compute_device(split_arg_spec): return {"cpu": cpu, "cuda": cuda} -def adagrad(): +def adagrad() -> None: split_weight_update = """ Vec4T m_t(&momentum1[idx * D + d]); m_t.acc.x += grad.acc.x * grad.acc.x; @@ -310,7 +317,7 @@ def adagrad(): ) -def table_info_precomputation(momentum_prefix="momentum1"): +def table_info_precomputation(momentum_prefix: str="momentum1") -> str: template = """ // table_begin -> (E, D, {momentum_prefix}_row_begin). std::map> table_info_map; @@ -332,7 +339,7 @@ def table_info_precomputation(momentum_prefix="momentum1"): return template.replace("{momentum_prefix}", momentum_prefix) -def rowwise_adagrad(): +def rowwise_adagrad() -> None: split_weight_update = """ weight_new.fma_(grad, -multiplier); """ @@ -384,7 +391,7 @@ def rowwise_adagrad(): ) -def approx_rowwise_adagrad(): +def approx_rowwise_adagrad() -> None: split_weight_update = """ weight_new.fma_(grad, -multiplier); assert(false); // approx rowwise AdaGrad is not supported on GPU @@ -424,7 +431,7 @@ def approx_rowwise_adagrad(): ) -def sgd(): +def sgd() -> None: split_weight_update = """ weight_new.fma_(grad, -learning_rate); """ @@ -443,7 +450,7 @@ def sgd(): ) -def approx_sgd(): +def approx_sgd() -> None: split_weight_update = """ // approx_sgd not supported for GPU. // Just do the same thing as exact sgd to avoid unused variable warning. @@ -463,7 +470,7 @@ def approx_sgd(): ) -def lamb(): +def lamb() -> None: split_precomputation = """ acc_type weight_sum_sq = 0.0; acc_type rtw_sum_sq = 0.0; @@ -533,7 +540,7 @@ def lamb(): ) -def partial_rowwise_lamb(): +def partial_rowwise_lamb() -> None: split_precomputation = """ acc_type g_local_sum_square = 0.0; @@ -619,7 +626,7 @@ def partial_rowwise_lamb(): ) -def adam(): +def adam() -> None: split_weight_update = """ Vec4T m_t(&momentum1[idx * D + d]); m_t.acc.x *= beta1; @@ -669,7 +676,7 @@ def adam(): ) -def partial_rowwise_adam(): +def partial_rowwise_adam() -> None: split_precomputation = """ acc_type g_local_sum_square = 0.0; #pragma unroll kMaxVecsPerThread @@ -729,7 +736,7 @@ def partial_rowwise_adam(): ) -def lars_sgd(): +def lars_sgd() -> None: split_precomputation = """ acc_type weight_sum_sq = 0.0; acc_type grad_sum_sq = 0.0; @@ -787,7 +794,7 @@ def lars_sgd(): ) -def forward_split(): +def forward_split() -> None: template = env.get_template("embedding_forward_split_template.cu") src_cu = template.render(weighted=False) @@ -801,7 +808,7 @@ def forward_split(): write("gen_embedding_forward_dense_weighted_codegen_cuda.cu", src_cu) -def backward_indices(): +def backward_indices() -> None: template = env.get_template("embedding_backward_split_indice_weights_template.cu") src_cu = template.render() write("gen_embedding_backward_split_indice_weights_codegen_cuda.cu", src_cu) @@ -809,7 +816,7 @@ def backward_indices(): write("gen_embedding_backward_dense_indice_weights_codegen_cuda.cu", src_cu) -def backward_dense(): +def backward_dense() -> None: generate( optimizer="dense", dense=True, @@ -821,13 +828,13 @@ def backward_dense(): ) -def gen__init__py(): +def gen__init__py() -> None: template = env.get_template("__init__.template") src_py = template.render() write("__init__.py", src_py) -def emb_codegen(install_dir=None, is_fbcode=True): +def emb_codegen(install_dir: Optional[str] = None, is_fbcode: bool=True) -> None: if install_dir is not None and len(install_dir) != 0: args.install_dir = install_dir args.is_fbcode = is_fbcode @@ -848,7 +855,7 @@ def emb_codegen(install_dir=None, is_fbcode=True): gen__init__py() -def main(): +def main() -> None: emb_codegen() diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py index f408c00f79..1c6972fb1b 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py +++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py @@ -25,6 +25,7 @@ class EmbOptimType(enum.Enum): PARTIAL_ROWWISE_LAMB = "partial_row_wise_lamb" ROWWISE_ADAGRAD = "row_wise_adagrad" + # pyre-fixme[3]: Return type must be annotated. def __str__(self): return self.value @@ -35,6 +36,7 @@ class SparseType(enum.Enum): FP16 = "fp16" INT8 = "int8" + # pyre-fixme[3]: Return type must be annotated. def __str__(self): return self.value diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py index b146535e76..1314c6d76e 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py @@ -8,8 +8,8 @@ import enum import logging from dataclasses import dataclass +from itertools import accumulate from math import log2 -from numbers import Number from typing import Dict, List, Optional, Tuple import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers @@ -56,15 +56,6 @@ class SplitState: offsets: List[int] -def _cumsum(arr: List[Number]): - ret: List[Number] = [] - curr = 0 - for el in arr: - curr += el - ret.append(curr) - return ret - - def construct_split_state( embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]], rowwise: bool, @@ -190,12 +181,12 @@ def __init__( # noqa C901 beta1: float = 0.9, # used by LAMB and ADAM beta2: float = 0.999, # used by LAMB and ADAM pooling_mode: PoolingMode = PoolingMode.SUM, - ): + ) -> None: super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__() self.pooling_mode = pooling_mode self.weights_precision = weights_precision # NOTE: a placeholder to avoid multi-construction and make TorchScript work! - self.dummy_tensor = torch.tensor(0) + self.dummy_tensor: Tensor = torch.tensor(0) self.embedding_specs = embedding_specs (rows, dims, locations, compute_devices) = zip(*embedding_specs) @@ -205,10 +196,11 @@ def __init__( # noqa C901 assert all( cd == compute_devices[0] for cd in compute_devices ), "Heterogenous compute_devices are NOT supported!" - self.use_cpu = all(cd == ComputeDevice.CPU for cd in compute_devices) + self.use_cpu: bool = all(cd == ComputeDevice.CPU for cd in compute_devices) assert not self.use_cpu or all( loc == EmbeddingLocation.HOST for loc in locations ), "ComputeDevice.CPU is only for EmbeddingLocation.HOST!" + # pyre-fixme[4]: Attribute must be annotated. self.current_device = ( torch.device("cpu") if self.use_cpu else torch.cuda.current_device() ) @@ -218,8 +210,10 @@ def __init__( # noqa C901 torch.zeros(0, device=self.current_device, dtype=torch.float) ) + # pyre-fixme[4]: Attribute must be annotated. self.int8_emb_row_dim_offset = INT8_EMB_ROW_DIM_OFFSET + # pyre-fixme[4]: Attribute must be annotated. self.feature_table_map = ( feature_table_map if feature_table_map is not None else list(range(T_)) ) @@ -231,14 +225,17 @@ def __init__( # noqa C901 assert all(table_has_feature), "Each table must have at least one feature!" D_offsets = [dims[t] for t in self.feature_table_map] - D_offsets = [0] + _cumsum(D_offsets) + D_offsets = [0] + list(accumulate(D_offsets)) + # pyre-fixme[4]: Attribute must be annotated. self.total_D = D_offsets[-1] + # pyre-fixme[4]: Attribute must be annotated. self.max_D = max(dims) cached_dims = [ embedding_spec[1] for embedding_spec in embedding_specs if embedding_spec[2] == EmbeddingLocation.MANAGED_CACHING ] + # pyre-fixme[4]: Attribute must be annotated. self.max_D_cache = max(cached_dims) if len(cached_dims) > 0 else 0 self.register_buffer( @@ -246,7 +243,7 @@ def __init__( # noqa C901 torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32), ) - hash_size_cumsum = [0] + _cumsum(rows) + hash_size_cumsum = [0] + list(accumulate(rows)) self.total_hash_size_bits = int(log2(float(hash_size_cumsum[-1])) + 1) # The last element is to easily access # of rows of each table by # hash_size_cumsum[t + 1] - hash_size_cumsum[t] @@ -321,20 +318,25 @@ def __init__( # noqa C901 OptimType.EXACT_SGD, ): # NOTE: make TorchScript work! + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "momentum1_dev", torch.tensor([0], dtype=torch.int64), persistent=False ) + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "momentum1_host", torch.tensor([0], dtype=torch.int64), persistent=False ) + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "momentum1_uvm", torch.tensor([0], dtype=torch.int64), persistent=False ) + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "momentum1_placements", torch.tensor([0], dtype=torch.int64), persistent=False, ) + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "momentum1_offsets", torch.tensor([0], dtype=torch.int64), @@ -371,25 +373,31 @@ def __init__( # noqa C901 self.register_buffer("iter", torch.tensor([0], dtype=torch.int64)) else: # NOTE: make TorchScript work! + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "momentum2_dev", torch.tensor([0], dtype=torch.int64), persistent=False ) + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "momentum2_host", torch.tensor([0], dtype=torch.int64), persistent=False ) + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "momentum2_uvm", torch.tensor([0], dtype=torch.int64), persistent=False ) + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "momentum2_placements", torch.tensor([0], dtype=torch.int64), persistent=False, ) + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "momentum2_offsets", torch.tensor([0], dtype=torch.int64), persistent=False, ) + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "iter", torch.tensor([0], dtype=torch.int64), persistent=False ) @@ -417,6 +425,8 @@ def __init__( # noqa C901 self.step = 0 + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def get_states(self, prefix): if not hasattr(self, f"{prefix}_physical_placements"): return None @@ -433,6 +443,7 @@ def get_states(self, prefix): torch.tensor(offsets, dtype=torch.int64), ) + # pyre-fixme[3]: Return type must be annotated. def get_all_states(self): all_states = [] for prefix in ["weights", "momentum1", "momentum2"]: @@ -461,7 +472,6 @@ def forward( else self.lxu_cache_locations_list.pop(0) ) common_args = invokers.lookup_args.CommonArgs( - # pyre-fixme[16] placeholder_autograd_tensor=self.placeholder_autograd_tensor, # pyre-fixme[16] dev_weights=self.weights_dev, @@ -497,10 +507,20 @@ def forward( return invokers.lookup_approx_sgd.invoke(common_args, self.optimizer_args) momentum1 = invokers.lookup_args.Momentum( + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum1_dev`. dev=self.momentum1_dev, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum1_host`. host=self.momentum1_host, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum1_uvm`. uvm=self.momentum1_uvm, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum1_offsets`. offsets=self.momentum1_offsets, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum1_placements`. placements=self.momentum1_placements, ) @@ -523,13 +543,25 @@ def forward( ) momentum2 = invokers.lookup_args.Momentum( + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum2_dev`. dev=self.momentum2_dev, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum2_host`. host=self.momentum2_host, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum2_uvm`. uvm=self.momentum2_uvm, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum2_offsets`. offsets=self.momentum2_offsets, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum2_placements`. placements=self.momentum2_placements, ) # Ensure iter is always on CPU so the increment doesn't synchronize. + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no attribute + # `iter`. if self.iter.is_cuda: self.iter = self.iter.cpu() self.iter[0] += 1 @@ -553,7 +585,7 @@ def forward( raise ValueError(f"Invalid OptimType: {self.optimizer}") - def prefetch(self, indices: Tensor, offsets: Tensor): + def prefetch(self, indices: Tensor, offsets: Tensor) -> None: self.timestep += 1 self.timesteps_prefetched.append(self.timestep) # pyre-fixme[16] @@ -613,7 +645,7 @@ def prefetch(self, indices: Tensor, offsets: Tensor): ) ) - def init_embedding_weights_uniform(self, min_val, max_val): + def init_embedding_weights_uniform(self, min_val: float, max_val: float) -> None: splits = self.split_embedding_weights() if self.weights_precision == SparseType.INT8: # TODO: add in-place FloatToFused8BitRowwiseQuantized conversion @@ -630,8 +662,10 @@ def init_embedding_weights_uniform(self, min_val, max_val): for param in splits: param.uniform_(min_val, max_val) + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.jit.export`. @torch.jit.export - def split_embedding_weights(self): + def split_embedding_weights(self) -> List[Tensor]: """ Returns a list of weights, split by table """ @@ -639,13 +673,23 @@ def split_embedding_weights(self): for t, (rows, dim, _, _) in enumerate(self.embedding_specs): if self.weights_precision == SparseType.INT8: dim += self.int8_emb_row_dim_offset + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `weights_physical_placements`. placement = self.weights_physical_placements[t] + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `weights_physical_offsets`. offset = self.weights_physical_offsets[t] if placement == EmbeddingLocation.DEVICE.value: + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `weights_dev`. weights = self.weights_dev elif placement == EmbeddingLocation.HOST.value: + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `weights_host`. weights = self.weights_host else: + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `weights_uvm`. weights = self.weights_uvm splits.append( weights.detach()[offset : offset + rows * dim].view(rows, dim) @@ -694,11 +738,17 @@ def split_optimizer_states(self) -> List[Tuple[torch.Tensor]]: """ def get_optimizer_states( + # pyre-fixme[2]: Parameter must be annotated. state_dev, + # pyre-fixme[2]: Parameter must be annotated. state_host, + # pyre-fixme[2]: Parameter must be annotated. state_uvm, + # pyre-fixme[2]: Parameter must be annotated. state_offsets, + # pyre-fixme[2]: Parameter must be annotated. state_placements, + # pyre-fixme[2]: Parameter must be annotated. rowwise, ) -> List[torch.Tensor]: splits = [] @@ -726,8 +776,14 @@ def get_optimizer_states( ): states.append( get_optimizer_states( + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum1_dev`. self.momentum1_dev, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum1_host`. self.momentum1_host, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum1_uvm`. self.momentum1_uvm, # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` # has no attribute `momentum1_physical_offsets`. @@ -746,8 +802,14 @@ def get_optimizer_states( ): states.append( get_optimizer_states( + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum2_dev`. self.momentum2_dev, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum2_host`. self.momentum2_host, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `momentum2_uvm`. self.momentum2_uvm, # pyre-fixme[16] self.momentum2_physical_offsets, @@ -779,23 +841,41 @@ def _set_learning_rate(self, lr: float) -> float: self.optimizer_args = self.optimizer_args._replace(learning_rate=lr) return 0.0 + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.jit.export`. @torch.jit.export - def flush(self): + def flush(self) -> None: + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no attribute + # `lxu_cache_weights`. if not self.lxu_cache_weights.numel(): return torch.ops.fb.lxu_cache_flush( + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `weights_uvm`. self.weights_uvm, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `cache_hash_size_cumsum`. self.cache_hash_size_cumsum, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `cache_index_table_map`. self.cache_index_table_map, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `weights_offsets`. self.weights_offsets, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `D_offsets`. self.D_offsets, self.total_D, + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `lxu_cache_state`. self.lxu_cache_state, self.lxu_cache_weights, self.stochastic_rounding, ) - def _apply_split(self, split, prefix, dtype, enforce_hbm=False): + # pyre-fixme[2]: Parameter must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def _apply_split(self, split, prefix, dtype: torch.dtype, enforce_hbm: bool = False) -> None: setattr(self, f"{prefix}_physical_placements", split.placements) setattr(self, f"{prefix}_physical_offsets", split.offsets) @@ -862,13 +942,18 @@ def _apply_split(self, split, prefix, dtype, enforce_hbm=False): def _apply_cache_state( self, + # pyre-fixme[2]: Parameter must be annotated. cache_state, + # pyre-fixme[2]: Parameter must be annotated. cache_algorithm, + # pyre-fixme[2]: Parameter must be annotated. cache_load_factor, + # pyre-fixme[2]: Parameter must be annotated. cache_sets, + # pyre-fixme[2]: Parameter must be annotated. cache_reserved_memory, - dtype, - ): + dtype: torch.dtype, + ) -> None: self.cache_algorithm = cache_algorithm self.timestep = 1 self.timesteps_prefetched = [] @@ -886,26 +971,31 @@ def _apply_cache_state( torch.zeros(0, 0, device=self.current_device, dtype=dtype), ) # NOTE: make TorchScript work! + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "cache_hash_size_cumsum", torch.tensor([0], dtype=torch.int64), persistent=False, ) + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "total_cache_hash_size", torch.tensor([0], dtype=torch.int64), persistent=False, ) + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "cache_index_table_map", torch.tensor([0], dtype=torch.int64), persistent=False, ) + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "lxu_cache_state", torch.tensor([0], dtype=torch.int64), persistent=False, ) + # pyre-fixme[28]: Unexpected keyword argument `persistent`. self.register_buffer( "lxu_state", torch.tensor([0], dtype=torch.int64), @@ -981,6 +1071,7 @@ def _apply_cache_state( ) self.register_buffer( "lxu_state", + # pyre-fixme[28]: Unexpected keyword argument `size`. torch.zeros( size=(self.total_cache_hash_size + 1,) if cache_algorithm == CacheAlgorithm.LFU @@ -1027,11 +1118,12 @@ def __init__( feature_table_map: Optional[List[int]] = None, # [T] pooling_mode: PoolingMode = PoolingMode.SUM, use_cpu: bool = False, - ): # noqa C901 # tuple of (rows, dims,) + ) -> None: # noqa C901 # tuple of (rows, dims,) super(DenseTableBatchedEmbeddingBagsCodegen, self).__init__() self.pooling_mode = pooling_mode + # pyre-fixme[4]: Attribute must be annotated. self.current_device = ( torch.device("cpu") if use_cpu else torch.cuda.current_device() ) @@ -1047,7 +1139,7 @@ def __init__( T = len(feature_table_map) assert T_ <= T D_offsets = [dims[t] for t in feature_table_map] - D_offsets = [0] + _cumsum(D_offsets) + D_offsets = [0] + list(accumulate(D_offsets)) self.total_D = D_offsets[-1] self.max_D = max(dims) self.register_buffer( @@ -1056,7 +1148,7 @@ def __init__( ) assert self.D_offsets.numel() == T + 1 - hash_size_cumsum = [0] + _cumsum(rows) + hash_size_cumsum = [0] + list(accumulate(rows)) self.total_hash_size_bits = int(log2(float(hash_size_cumsum[-1])) + 1) # The last element is to easily access # of rows of each table by # hash_size_cumsum[t + 1] - hash_size_cumsum[t] @@ -1069,7 +1161,7 @@ def __init__( hash_size_cumsum, device=self.current_device, dtype=torch.int64 ), ) - weights_offsets = [0] + _cumsum([row * dim for (row, dim) in embedding_specs]) + weights_offsets = [0] + list(accumulate([row * dim for (row, dim) in embedding_specs])) self.weights = nn.Parameter( torch.randn( weights_offsets[-1], @@ -1094,6 +1186,7 @@ def __init__( row for (row, _) in embedding_specs[:t] ) + # pyre-fixme[4]: Attribute must be annotated. self.weights_physical_offsets = weights_offsets weights_offsets = [weights_offsets[t] for t in feature_table_map] self.register_buffer( @@ -1126,8 +1219,10 @@ def forward( feature_requires_grad=feature_requires_grad, ) + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.jit.export`. @torch.jit.export - def split_embedding_weights(self): + def split_embedding_weights(self) -> List[Tensor]: """ Returns a list of weights, split by table """ @@ -1139,7 +1234,7 @@ def split_embedding_weights(self): ) return splits - def init_embedding_weights_uniform(self, min_val, max_val): + def init_embedding_weights_uniform(self, min_val: float, max_val: float) -> None: splits = self.split_embedding_weights() for param in splits: param.uniform_(min_val, max_val) diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index e0e309f5fc..726c4718d3 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -34,12 +34,15 @@ def get_offsets_from_dense(indices: torch.Tensor) -> Tuple[torch.Tensor, torch.T ) -def to_device(t: torch.Tensor, use_cpu: bool): +def to_device(t: torch.Tensor, use_cpu: bool) -> torch.Tensor: return t.cpu() if use_cpu else t.cuda() +# pyre-fixme[3]: Return annotation cannot be `Any`. def b_indices( - b: Callable, x: torch.Tensor, per_sample_weights=None, use_cpu=False + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + # pyre-fixme[2]: Parameter must be annotated. + b: Callable, x: torch.Tensor, per_sample_weights=None, use_cpu: bool=False ) -> Any: (indices, offsets) = get_offsets_from_dense(x) return b( @@ -50,7 +53,7 @@ def b_indices( def get_table_batched_offsets_from_dense( - merged_indices: torch.Tensor, use_cpu=False + merged_indices: torch.Tensor, use_cpu: bool=False ) -> Tuple[torch.Tensor, torch.Tensor]: (T, B, L) = merged_indices.size() lengths = np.ones((T, B)) * L @@ -134,6 +137,9 @@ class SplitTableBatchedEmbeddingsTest(unittest.TestCase): pooling_mode=st.sampled_from(split_table_batched_embeddings_ops.PoolingMode), use_cpu=st.booleans() if torch.cuda.is_available() else st.just(True), ) + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `[hypothesis.HealthCheck.filter_too_much]` to decorator factory + # `hypothesis.settings`. @settings( verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, @@ -142,19 +148,21 @@ class SplitTableBatchedEmbeddingsTest(unittest.TestCase): ) def test_forward( self, - T, - D, - B, - log_E, - L, - weights_precision, - weighted, - mixed, - use_cache, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weights_precision: SparseType, + weighted: bool, + mixed: bool, + use_cache: bool, + # pyre-fixme[2]: Parameter must be annotated. cache_algorithm, + # pyre-fixme[2]: Parameter must be annotated. pooling_mode, - use_cpu, - ): + use_cpu: bool, + ) -> None: # NOTE: cache is not applicable to CPU version. assume(not use_cpu or not use_cache) # NOTE: limit (T * B * L * D) to avoid timeout for CPU version! @@ -230,9 +238,11 @@ def test_forward( xws = [xw.half() for xw in xws] fs = ( + # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`. [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)] if not weighted else [ + # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`. b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu) for (b, x, xw) in zip(bs, xs, xws) ] @@ -259,6 +269,7 @@ def test_forward( cc = torch.jit.script(cc) for t in range(T): + # pyre-fixme[16]: `Tensor` has no attribute `weight`. cc.split_embedding_weights()[t].data.copy_(bs[t].weight) x = torch.cat([x.view(1, B, L) for x in xs], dim=0) @@ -290,6 +301,9 @@ def test_forward( pooling_mode=st.sampled_from(split_table_batched_embeddings_ops.PoolingMode), use_cpu=st.booleans() if torch.cuda.is_available() else st.just(True), ) + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `[hypothesis.HealthCheck.filter_too_much]` to decorator factory + # `hypothesis.settings`. @settings( verbosity=Verbosity.verbose, max_examples=10, @@ -298,18 +312,19 @@ def test_forward( ) def test_backward_dense( self, - T, - D, - B, - log_E, - L, - weights_precision, - weighted, - mixed, - long_segments, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weights_precision: SparseType, + weighted: bool, + mixed: bool, + long_segments: bool, + # pyre-fixme[2]: Parameter must be annotated. pooling_mode, - use_cpu, - ): + use_cpu: bool, + ) -> None: # NOTE: torch.autograd.gradcheck() is too time-consuming for CPU version # so we have to limit (T * B * L * D)! assume(not use_cpu or T * B * L * D <= 2048) @@ -373,9 +388,11 @@ def test_backward_dense( xws = [xw.half() for xw in xws] fs = ( + # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`. [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)] if not weighted else [ + # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`. b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu) for (b, x, xw) in zip(bs, xs, xws) ] @@ -383,6 +400,7 @@ def test_backward_dense( gos = [torch.randn_like(f) for f in fs] [f.backward(go) for (f, go) in zip(fs, gos)] + # pyre-fixme[16]: `Tensor` has no attribute `weight`. grad_weights = torch.cat([b.weight.grad.view(-1) for b in bs]) if weights_precision == SparseType.FP16 and not use_cpu: grad_weights = grad_weights.half() @@ -426,6 +444,8 @@ def test_backward_dense( rtol=5.0e-3 if weights_precision == SparseType.FP16 else 1.0e-4, ) + # pyre-fixme[16]: `DenseTableBatchedEmbeddingBagsCodegen` has no attribute + # `double`. cc = split_table_batched_embeddings_ops.DenseTableBatchedEmbeddingBagsCodegen( [(E, D) for (E, D) in zip(Es, Ds)], # NOTE: only SUM pooling can work with per_sample_weights! @@ -458,6 +478,9 @@ def test_backward_dense( use_cpu=st.booleans() if torch.cuda.is_available() else st.just(True), exact=st.booleans(), ) + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `[hypothesis.HealthCheck.filter_too_much]` to decorator factory + # `hypothesis.settings`. @settings( verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, @@ -466,21 +489,23 @@ def test_backward_dense( ) def test_backward_sgd( # noqa C901 self, - T, - D, - B, - log_E, - L, - weights_precision, - weighted, - mixed, - use_cache, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weights_precision: SparseType, + weighted: bool, + mixed: bool, + use_cache: bool, + # pyre-fixme[2]: Parameter must be annotated. cache_algorithm, - long_segments, + long_segments: bool, + # pyre-fixme[2]: Parameter must be annotated. pooling_mode, - use_cpu, - exact, - ): + use_cpu: bool, + exact: bool, + ) -> None: # NOTE: cache is not applicable to CPU version. assume(not use_cpu or not use_cache) # NOTE: limit (T * B * L * D) to avoid timeout for CPU version! @@ -552,6 +577,7 @@ def test_backward_sgd( # noqa C901 feature_table_map = list(range(T)) table_to_replicate = T // 2 + # pyre-fixme[6]: Expected `HalfTensor` for 2nd param but got `Tensor`. bs.insert(table_to_replicate, bs[table_to_replicate]) feature_table_map.insert(table_to_replicate, table_to_replicate) @@ -579,9 +605,11 @@ def test_backward_sgd( # noqa C901 xws = [xw.half() for xw in xws] fs = ( + # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`. [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)] if not weighted else [ + # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`. b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu) for (b, x, xw) in zip(bs, xs, xws) ] @@ -591,6 +619,7 @@ def test_backward_sgd( # noqa C901 # do SGD update lr = 0.05 del bs[table_to_replicate] + # pyre-fixme[16]: `Tensor` has no attribute `weight`. new_weights = [(b.weight - b.weight.grad * lr) for b in bs] cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen( @@ -622,6 +651,7 @@ def test_backward_sgd( # noqa C901 for t in range(T): torch.testing.assert_allclose( cc.split_embedding_weights()[t], + # pyre-fixme[16]: `float` has no attribute `half`. new_weights[t].half() if weights_precision == SparseType.FP16 and not use_cpu else new_weights[t], @@ -651,6 +681,9 @@ def test_backward_sgd( # noqa C901 use_cpu=st.booleans() if torch.cuda.is_available() else st.just(True), exact=st.booleans(), ) + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `[hypothesis.HealthCheck.filter_too_much]` to decorator factory + # `hypothesis.settings`. @settings( verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, @@ -659,23 +692,25 @@ def test_backward_sgd( # noqa C901 ) def test_backward_adagrad( # noqa C901 self, - T, - D, - B, - log_E, - L, - D_gradcheck, - weights_precision, - stochastic_rounding, - weighted, - row_wise, - mixed, - use_cache, + T: int, + D: int, + B: int, + log_E: int, + L: int, + D_gradcheck: int, + weights_precision: SparseType, + stochastic_rounding: bool, + weighted: bool, + row_wise: bool, + mixed: bool, + use_cache: bool, + # pyre-fixme[2]: Parameter must be annotated. cache_algorithm, + # pyre-fixme[2]: Parameter must be annotated. pooling_mode, - use_cpu, - exact, - ): + use_cpu: bool, + exact: bool, + ) -> None: # NOTE: cache is not applicable to CPU version. assume(not use_cpu or not use_cache) # Approx AdaGrad only works with row_wise on CPU @@ -758,6 +793,7 @@ def test_backward_adagrad( # noqa C901 if exact: # autograd with shared embedding only works for exact table_to_replicate = T // 2 + # pyre-fixme[6]: Expected `HalfTensor` for 2nd param but got `Tensor`. bs.insert(table_to_replicate, bs[table_to_replicate]) feature_table_map.insert(table_to_replicate, table_to_replicate) @@ -780,9 +816,11 @@ def test_backward_adagrad( # noqa C901 xws = [xw.half() for xw in xws] fs = ( + # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`. [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)] if not weighted else [ + # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`. b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu) for (b, x, xw) in zip(bs, xs, xws) ] @@ -812,6 +850,7 @@ def test_backward_adagrad( # noqa C901 if exact: del bs[table_to_replicate] for t in range(T): + # pyre-fixme[16]: `Tensor` has no attribute `weight`. cc.split_embedding_weights()[t].data.copy_(bs[t].weight) x = torch.cat([x.view(1, B, L) for x in xs], dim=0) @@ -869,6 +908,8 @@ def test_backward_adagrad( # noqa C901 ) if use_cpu: # NOTE: GPU version of SplitTableBatchedEmbeddingBagsCodegen doesn't support double. + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `double`. cc = cc.double() per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu) @@ -891,6 +932,7 @@ def test_backward_adagrad( # noqa C901 param.requires_grad = False y = cc(indices, offsets, per_sample_weights) y.sum().backward() + # pyre-fixme[16]: `Tensor` has no attribute `grad`. indice_weight_grad_all = per_sample_weights.grad.clone().cpu() T_ = len(xws) feature_requires_grad = to_device( @@ -920,6 +962,9 @@ def test_backward_adagrad( # noqa C901 ) @unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available") + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value = + # 5)` to decorator factory `hypothesis.given`. @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=64), @@ -932,7 +977,8 @@ def test_backward_adagrad( # noqa C901 ), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) - def test_cache_pipeline(self, T, D, B, log_E, L, mixed, cache_algorithm): + # pyre-fixme[2]: Parameter must be annotated. + def test_cache_pipeline(self, T: int, D: int, B: int, log_E: int, L: int, mixed: bool, cache_algorithm) -> None: iters = 3 E = int(10 ** log_E) D = D * 4 @@ -1025,6 +1071,9 @@ def test_cache_pipeline(self, T, D, B, log_E, L, mixed, cache_algorithm): pooling_mode=st.sampled_from(split_table_batched_embeddings_ops.PoolingMode), use_cpu=st.booleans() if torch.cuda.is_available() else st.just(True), ) + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `[hypothesis.HealthCheck.filter_too_much]` to decorator factory + # `hypothesis.settings`. @settings( verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, @@ -1033,18 +1082,20 @@ def test_cache_pipeline(self, T, D, B, log_E, L, mixed, cache_algorithm): ) def test_backward_optimizers( # noqa C901 self, - T, - D, - B, - log_E, - L, - weighted, - mixed, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + mixed: bool, + # pyre-fixme[2]: Parameter must be annotated. optimizer, - long_segments, + long_segments: bool, + # pyre-fixme[2]: Parameter must be annotated. pooling_mode, - use_cpu, - ): + use_cpu: bool, + ) -> None: # NOTE: limit (T * B * L * D) to avoid timeout for CPU version! assume(not use_cpu or T * B * L * D <= 2048) assume( @@ -1122,9 +1173,11 @@ def test_backward_optimizers( # noqa C901 xws_acc_type = copy.deepcopy(xws) fs = ( + # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`. [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)] if not weighted else [ + # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`. b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu) for (b, x, xw) in zip(bs, xs, xws) ] @@ -1167,10 +1220,12 @@ def test_backward_optimizers( # noqa C901 [(E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)], optimizer=optimizer, pooling_mode=pooling_mode, + # pyre-fixme[6]: Expected `CacheAlgorithm` for 5th param but got `float`. **optimizer_kwargs, ) for t in range(T): + # pyre-fixme[16]: `Tensor` has no attribute `weight`. cc.split_embedding_weights()[t].data.copy_(bs[t].weight) x = torch.cat([x.view(1, B, L) for x in xs], dim=0) @@ -1236,6 +1291,8 @@ def test_backward_optimizers( # noqa C901 torch.testing.assert_allclose( m1.cpu(), m1_ref, atol=1.0e-4, rtol=1.0e-4 ) + # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no + # attribute `iter`. iter_ = cc.iter.item() v_hat_t = m2_ref / (1 - beta2 ** iter_) v_hat_t = v_hat_t if not rowwise else v_hat_t.view(v_hat_t.numel(), 1) @@ -1314,6 +1371,7 @@ def test_backward_optimizers( # noqa C901 torch.testing.assert_allclose( m1.index_select(dim=0, index=x[t].view(-1)).cpu(), + # pyre-fixme[16]: `float` has no attribute `index_select`. m1_ref.index_select(dim=0, index=x[t].view(-1).cpu()), atol=1.0e-4, rtol=1.0e-4, @@ -1330,21 +1388,29 @@ def test_backward_optimizers( # noqa C901 @unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available") class CUMemTest(unittest.TestCase): + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `hypothesis.strategies.lists(hypothesis.strategies.integers($parameter$min_value + # = 1, $parameter$max_value = 8), $parameter$min_size = 1, $parameter$max_size = + # 4)` to decorator factory `hypothesis.given`. @given( sizes=st.lists(st.integers(min_value=1, max_value=8), min_size=1, max_size=4) ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) - def test_is_uvm_tensor(self, sizes): + def test_is_uvm_tensor(self, sizes: List[int]) -> None: uvm_t = torch.ops.fb.new_managed_tensor( torch.zeros(*sizes, device="cuda:0", dtype=torch.float), sizes ) assert torch.ops.fb.is_uvm_tensor(uvm_t) + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `hypothesis.strategies.lists(hypothesis.strategies.integers($parameter$min_value + # = 1, $parameter$max_value = 8), $parameter$min_size = 1, $parameter$max_size = + # 4)` to decorator factory `hypothesis.given`. @given( sizes=st.lists(st.integers(min_value=1, max_value=8), min_size=1, max_size=4) ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) - def test_uvm_to_cpu(self, sizes): + def test_uvm_to_cpu(self, sizes: List[int]) -> None: uvm_t = torch.ops.fb.new_managed_tensor( torch.zeros(*sizes, device="cuda:0", dtype=torch.float), sizes )