diff --git a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat
index 2af6651ba05d7..29c213ad4246c 100644
--- a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat
+++ b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat
@@ -36,7 +36,7 @@ popd
=======
:: Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014
-pip install "ninja==1.10.0.post1" future "hypothesis==5.35.1" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest pytest-xdist pytest-shard pytest-rerunfailures "xdoctest==1.0.2" "pygments==2.12.0" "opt-einsum>=3.3"
+pip install "ninja==1.10.0.post1" future "hypothesis==5.35.1" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest pytest-xdist pytest-shard pytest-rerunfailures sympy "xdoctest==1.0.2" "pygments==2.12.0" "opt-einsum>=3.3"
if errorlevel 1 exit /b
if not errorlevel 0 exit /b
diff --git a/benchmarks/dynamo/README.md b/benchmarks/dynamo/README.md
new file mode 100644
index 0000000000000..5307e77b9b173
--- /dev/null
+++ b/benchmarks/dynamo/README.md
@@ -0,0 +1,50 @@
+# Torchdynamo Benchmarks
+
+## What We Benchmark
+TorchDynamo provides a benchmark harness that takes care of uniformly benchmarking different models. It interleaves runs of eager and dynamo to avoid machine noise/variability issues, and reports results based on medians along with P-values.
+
+The runner integrates with models from TorchBenchmark, HuggingFace and TIMM suites and covers both training and inference.
+
+The infrastructure allows us to specify a loss function. For torchbench models, we use .sum().backward() call in place of the native loss function. For TIMM models, we use a CrossEntropy loss. And HF models contain a loss function inside the model itself, so we don't need any special loss computation handling.
+
+Training benchmarks approximate training by running the model forward, computing loss and then running backward. We entirely skip the optimizer step today.
+
+Inference benchmarks and Training benchmarks measure correctness by comparing dynamo and eager model outputs given fixed inputs and seeds.
+
+## Setup
+
+### Machine
+We run benchmarks on AWS machines (p4d.24xlarge) using 8xNVidia A100 40GB cards. We suggest using Cuda 11.6 for consistency.
+
+### Benchmarks
+Make sure to carefully follow the [torchbench installation](https://github.com/pytorch/benchmark#installation) instructions, taking care to build the auxiliary libraries (torchvision, torchtext) from a matching version to your pytorch version.
+
+For HF and TIMM models, the scripts already install the transformers and timm package respectively on the first run.
+
+## Runbook
+
+### Basic Usage
+There are a lot of flags in the benchmark runner, and it can be confusing to know which settings to use or what machine to run it on. In order to support apples-to-apples comparison, we have provided the following 'standard' settings in `runner.py`. This script is a wrapper over the common benchmarking infrastructure and simplifies the flags. We will continually update `runner.py` with the latest and most relevant compilers for training and inference. It also provides some graph utilities to visualize and compare results. Some of the example commands are
+
+**Inference Commands**
+* Inference compilers on torchbench models - `python benchmarks/runner.py --suites=torchbench --inference --dtypes=float16`
+
+**Training Commands**
+* Training compilers on TIMM models - `python benchmarks/runner.py --suites=timm_models --training --dtypes=float32 --output-dir=timm_logs`
+* AOTAutograd Training compiler on TIMM models - `python benchmarks/runner.py --suites=timm_models --training --dtypes=float32 --compilers=aot_nvfuser --output-dir=timm_logs`
+
+Running runner.py generates a file named `run.sh`. This file contains the actual commands that invoke the common benchmarking infrastructure with the appropriate flags. Which brings us to the advanced usage.
+
+### Advanced Usage
+
+One could directly call `torchbench.py`, `huggingface.py` or `timm_models.py` with the necessary flags. There are a lot of flags in the benchmarks runner. Some of the examples are as follows. These are subject to change.
+
+**Inference Commands**
+* TorchScript NVFuser Inference - `python benchmarks/torchbench.py -dcuda -n100 --speedup-ts`
+* TorchInductor CUDA Graphs Inference - `python benchmarks/torchbench.py -dcuda --inductor-settings --float32 -n50 --inductor`
+
+**Training Commands**
+* Torchscript (with TorchDynamo capture) NVFuser Training - `python benchmarks/torchbench.py --float32 -dcuda --training --nvfuser --speedup-dynamo-ts --use-eval-mode`
+* AOTAutograd Torchscript NVFuser Training - `python benchmarks/torchbench.py --float32 -dcuda --training --nvfuser --accuracy-aot-ts-mincut --use-eval-mode`
+
+Above commands are for torchbench models. You can simply replace `torchbench.py` with `huggingface.py` for HF models, and `timm_model.py` for TIMM models.
diff --git a/benchmarks/dynamo/__init__.py b/benchmarks/dynamo/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
new file mode 100644
index 0000000000000..fe3caf475f61d
--- /dev/null
+++ b/benchmarks/dynamo/common.py
@@ -0,0 +1,2021 @@
+#!/usr/bin/env python3
+import argparse
+import collections
+import copy
+import csv
+import functools
+import io
+import logging
+import os
+import random
+import signal
+import subprocess
+import sys
+import time
+import warnings
+
+import numpy as np
+import pandas as pd
+import torch
+
+import torch._dynamo
+import torch._dynamo.utils
+from microbenchmarks.operator_inp_utils import OperatorInputsMode
+from scipy.stats import gmean, ttest_ind
+from torch._dynamo.optimizations import backends
+from torch._dynamo.optimizations.log_args import conv_args_analysis
+from torch._dynamo.profiler import fx_insert_profiling, Profiler
+from torch._dynamo.testing import dummy_fx_compile, format_speedup, same
+from torch._dynamo.utils import clone_inputs
+from torch._inductor.utils import fresh_triton_cache
+from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.utils._pytree import tree_map
+
+try:
+ from functorch._src.aot_autograd import set_model_name
+except ImportError:
+
+ def set_model_name(name):
+ pass
+
+
+log = logging.getLogger(__name__)
+
+# We are primarily interested in TF32
+torch.backends.cuda.matmul.allow_tf32 = True
+
+current_name = ""
+current_device = ""
+current_batch_size = None
+output_filename = None
+
+CI_SKIP_AOT_EAGER_INFERENCE = [
+ # TorchBench
+ "demucs", # OOM
+ # Huggingface
+ "AllenaiLongformerBase",
+ "BartForConditionalGeneration", # OOM
+]
+
+CI_SKIP_AOT_EAGER_TRAINING = [
+ *CI_SKIP_AOT_EAGER_INFERENCE,
+ # TorchBench
+ "Background_Matting", # fp64_OOM
+ "moco",
+ "pytorch_struct",
+ "vision_maskrcnn",
+ # Huggingface
+ "AlbertForMaskedLM", # OOM
+ "AlbertForQuestionAnswering", # OOM
+ "BigBird",
+ "M2M100ForConditionalGeneration", # OOM
+ "PegasusForConditionalGeneration", # OOM
+ "XGLMForCausalLM", # OOM
+ "XLNetLMHeadModel", # OOM
+ "YituTechConvBert",
+ # TIMM
+ "cait_m36_384", # fp64_OOM
+ "convit_base", # fp64_OOM
+ "mobilevit_s", # Accuracy
+ "xcit_large_24_p8_224", # fp64_OOM
+]
+
+CI_SKIP_INDCUTOR_INFERENCE = [
+ *CI_SKIP_AOT_EAGER_INFERENCE,
+ # TorchBench
+ "detectron2",
+ "hf_Reformer",
+ "moco", # accuracy
+ "pyhpc_equation_of_state", # Accuracy
+ "pyhpc_turbulent_kinetic_energy", # Accuracy
+ "tacotron2",
+ "vision_maskrcnn", # accuracy
+ "yolov3", # Accuracy
+ # Huggingface
+ "BigBird",
+ "YituTechConvBert",
+ # TIMM
+ "cait_m36_384", # Accuracy
+ "ghostnet_100", # Accuracy
+ "swin_base_patch4_window7_224", # Accuracy
+]
+
+CI_SKIP_INDUCTOR_TRAINING = [
+ # CI does not check accuracy for inductor training yet
+ # *CI_SKIP_AOT_EAGER_TRAINING,
+ # *CI_SKIP_INDCUTOR_INFERENCE,
+ # TorchBench
+ "attention_is_all_you_need_pytorch",
+ "drq",
+ "hf_Albert",
+ "hf_Bart",
+ "hf_GPT2",
+ "hf_Reformer",
+ "mobilenet_v3_large",
+ "moco",
+ "pytorch_struct",
+ "vgg16",
+ "speech_transformer", # from functionalization
+ "vision_maskrcnn", # from functionalization
+ "timm_efficientnet", # from functionalization (only fails for inductor)
+ "hf_Bert",
+ "soft_actor_critic",
+ "tacotron2",
+ "yolov3",
+ # OOM
+ "Background_Matting",
+ "fastNLP_Bert",
+ "hf_BigBird",
+ "mobilenet_v2",
+ "mobilenet_v2_quantized_qat",
+ "resnet50_quantized_qat",
+ "timm_regnet",
+ # Huggingface
+ "AllenaiLongformerBase",
+ "AlbertForMaskedLM", # OOM
+ "BartForConditionalGeneration", # OOM
+ "M2M100ForConditionalGeneration", # OOM
+ "MBartForConditionalGeneration", # OOM
+ "MT5ForConditionalGeneration", # OOM
+ "PegasusForConditionalGeneration", # OOM
+ "XGLMForCausalLM", # fp64_OOM
+ # OOM
+ "BigBird",
+ "TrOCRForCausalLM",
+ "AlbertForQuestionAnswering",
+ # TIMM
+ "cait_m36_384", # fp64_OOM
+ "coat_lite_mini", # time out
+ "convit_base", # fp64_OOM
+ "rexnet_100", # accuracy
+ "swin_base_patch4_window7_224",
+ "twins_pcpvt_base", # time out
+ "xcit_large_24_p8_224", # fp64_OOM
+]
+
+
+def output_csv(filename, headers, row):
+ assert filename
+ existed = os.path.exists(filename)
+ output = csv.writer(
+ io.TextIOWrapper(
+ open(filename, "ab", buffering=0),
+ "utf-8",
+ write_through=True,
+ ),
+ lineterminator="\n",
+ )
+ if not existed:
+ output.writerow(headers)
+ output.writerow([(f"{x:.4f}" if isinstance(x, float) else x) for x in row])
+
+
+class NullContext:
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+
+@functools.lru_cache(None)
+def patch_torch_manual_seed():
+ """Make torch manual seed deterministic. Helps with accuracy testing."""
+
+ def deterministic_torch_manual_seed(*args, **kwargs):
+ from torch._C import default_generator
+
+ seed = 1337
+ import torch.cuda
+
+ if not torch.cuda._is_in_bad_fork():
+ torch.cuda.manual_seed_all(seed)
+ return default_generator.manual_seed(seed)
+
+ torch.manual_seed = deterministic_torch_manual_seed
+
+
+def synchronize():
+ pass
+
+
+def print_summary(filename):
+ if not (filename and os.path.exists(filename)):
+ return
+ data = pd.read_csv(filename)
+ width = max(map(len, data.columns))
+ for col in data.columns:
+ try:
+ if col in ("dev", "name", "batch_size"):
+ continue
+ elif col in ("pct_ops", "pct_time"):
+ print(col.ljust(width), f"{data[col].mean():.1%}")
+ elif col in ("graphs", "graph_calls", "captured_ops", "total_ops"):
+ print(col.ljust(width), f"{data[col].mean():.1f}")
+ elif col in ("compilation_latency"):
+ print(col.ljust(width), f"mean={data[col].mean():.1f} seconds")
+ elif col in ("compression_ratio"):
+ print(col.ljust(width), f"mean={data[col].mean():.1f}x")
+ else:
+ cdata = data[col].clip(1)
+ print(
+ col.ljust(width),
+ f"gmean={gmean(cdata):.2f}x mean={cdata.mean():.2f}x",
+ )
+ except Exception:
+ pass
+
+
+def timed(model, model_iter_fn, example_inputs, times=1, return_result=False):
+ synchronize()
+ reset_rng_state()
+ t0 = time.perf_counter()
+ # Dont collect outputs to correctly measure timing
+ for _ in range(times):
+ result = model_iter_fn(model, example_inputs, collect_outputs=False)
+ synchronize()
+ t1 = time.perf_counter()
+ return (t1 - t0, result) if return_result else t1 - t0
+
+
+class Stats:
+ totals = collections.defaultdict(collections.Counter)
+
+ @classmethod
+ def reset_counters(cls):
+ for k, v in torch._dynamo.utils.counters.items():
+ cls.totals[k].update(v)
+ ok = torch._dynamo.utils.counters["frames"]["ok"]
+ total = torch._dynamo.utils.counters["frames"]["total"]
+ torch._dynamo.utils.counters.clear()
+ return ok, total
+
+ @classmethod
+ def print_summary(cls):
+ for k, v in sorted(cls.totals.items()):
+ lines = "\n ".join(map(str, v.most_common(50)))
+ print(f"STATS {k}\n {lines}")
+
+ @classmethod
+ def aot_summary(cls):
+ return [cls.totals["aot_autograd"]["total"], cls.totals["aot_autograd"]["ok"]]
+
+
+def coverage_experiment(args, model_iter_fn, model, example_inputs):
+ """
+ Test operator/model coverage of TorchDynamo and record statistics
+ taken from a profiler. This target is mainly intended to check
+ correctness.
+
+ Writes to ./coverage.csv
+ """
+ profiler = Profiler()
+ frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
+ with profiler.prof:
+ frozen_model_iter_fn(model, example_inputs)
+ coverage_result = profiler.results()
+ output_csv(
+ output_filename,
+ (
+ "dev",
+ "name",
+ "batch_size",
+ "graphs",
+ "graph_calls",
+ "captured_ops",
+ "total_ops",
+ "pct_ops",
+ "pct_time",
+ ),
+ [
+ current_device,
+ current_name,
+ current_batch_size,
+ ]
+ + coverage_result.tocsv(),
+ )
+ return coverage_result
+
+
+def speedup_experiment_fx2trt(args, model_iter_fn, model, example_inputs):
+ """
+ Measure speedups over eager using the trt inference backend. TRT backend is based fx graph
+ generated by torch._dynamo.
+ Writes to ./speedups_fx2trt.csv
+ """
+ return speedup_experiment(args, model_iter_fn, model, example_inputs)
+
+
+def recompile_profiler_experiment(args, model_iter_fn, model, example_inputs):
+ prof = torch._dynamo.utils.CompileProfiler()
+ opt_model_iter_fn = torch._dynamo.optimize(prof, nopython=args.nopython)(
+ model_iter_fn
+ )
+ opt_model_iter_fn(model, example_inputs)
+ output_csv(
+ output_filename, ["model", "profiler report"], [current_name, prof.report()]
+ )
+ met = prof.get_metrics()
+ guard_failures = len(met["guard_failures"])
+ return [guard_failures]
+
+
+def randomize_input(inputs):
+ if isinstance(inputs, (list, tuple)):
+ return type(inputs)([randomize_input(x) for x in inputs])
+ elif isinstance(inputs, torch.Tensor):
+ if inputs.dtype in (torch.float32, torch.float64):
+ torch._dynamo.utils.counters["randomize_input"]["times"] += 1
+ return torch.randn_like(inputs)
+ elif inputs.dtype == torch.int64:
+ # Note: we can not simply tune integer tensors as follows
+ # `return torch.randint_like(inputs, high=inputs.max().item())`
+ # This may break some invariants between tensors.
+ # E.g. in embedding lookup case, one tensor is the length
+ # and another is an indices tensor.
+ return inputs
+ else:
+ raise RuntimeError(
+ f"randomize_input need support tensor of type {inputs.dtype}"
+ )
+ else:
+ raise RuntimeError(
+ f"randomize_input can not handle input of type {type(inputs)}"
+ )
+
+
+def cold_start_experiment(args, model_iter_fn, model, example_inputs, optimize_ctx):
+ compile_iters = 2
+ total_iters = compile_iters + 2
+ timings = np.zeros((total_iters, 2), np.float64)
+ # if we randomize the input, we should also check the result is correct
+ should_check_result = should_randomize_input = args.randomize_input
+ is_correct = True
+
+ optimized_model_iter_fn = optimize_ctx(model_iter_fn)
+ for rep in range(total_iters):
+ inputs = (
+ randomize_input(copy.deepcopy(example_inputs))
+ if should_randomize_input
+ else example_inputs
+ )
+
+ # interleave the runs to handle frequency scaling and load changes
+ timings[rep, 0], expected_output = timed(
+ model, model_iter_fn, inputs, return_result=True
+ )
+ timings[rep, 1], actual_output = timed(
+ model, optimized_model_iter_fn, inputs, return_result=True
+ )
+ if should_check_result:
+ is_correct = is_correct and same(expected_output, actual_output)
+ pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue
+ worst = np.max(timings, axis=0)
+
+ def breakeven(dynamo_times, eager_times):
+ """
+ Solve for the number of iterations it takes dynamo to 'catch up' with eager,
+ taking into account the time it spent compiling. Assumes all compilation
+ happens up front and the model is static thereafter, which is definitely not
+ true in general but might be across torchbench.
+
+ dc1, dc2 = dynamo compilation iterations (with Prof Exec)
+ d, e = dynamo, eager warmed up iteration
+ B = num iters to break even
+ dc1 + dc2 + (B-2)d = B*e
+ B = (dc1 + dc2 - 2d) / (e - d)
+ """
+ dc1, dc2, d = dynamo_times[0], dynamo_times[1], np.median(dynamo_times[2:])
+ e = np.median(eager_times)
+ if d < e:
+ return (dc1 + dc2 + 2 * d) / (e - d)
+ else:
+ # if optimized dynamo is not faster than eager we'll compute
+ # a nonsense negative number
+ return 0
+
+ speedup = worst[0] / worst[1]
+ eager_times, dynamo_times = timings[:, 0], timings[:, 1]
+ output_csv(
+ output_filename,
+ ("dev", "name", "batch_size", "cold-start speedup", "breakeven iters"),
+ [
+ current_device,
+ current_name,
+ current_batch_size,
+ float(speedup),
+ breakeven(dynamo_times, eager_times),
+ ],
+ )
+
+ def format_speedup(
+ speedup, pvalue, breakeven_iters, is_correct=True, pvalue_threshold=0.1
+ ):
+ if not is_correct:
+ return "ERROR"
+ if pvalue > pvalue_threshold:
+ return f"{speedup:.3f}x breakeven={breakeven_iters:.2f} iters SAME"
+ return f"{speedup:.3f}x breakeven={breakeven_iters:.2f} iters p={pvalue:.2f}"
+
+ return format_speedup(
+ speedup, pvalue, breakeven(dynamo_times, eager_times), is_correct=is_correct
+ )
+
+
+def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
+ """
+ Measure speedups over eager.
+
+ Writes to ./speedups.csv
+ """
+ if args.dynamic_shapes:
+ return speedup_experiment_ds(args, model_iter_fn, model, example_inputs)
+
+ timings = np.zeros((args.repeat, 2), np.float64)
+ # if we randomize the input, we should also check the result is correct
+ should_check_result = should_randomize_input = args.randomize_input
+ is_correct = True
+
+ import contextlib
+
+ @contextlib.contextmanager
+ def maybe_profile(*args, **kwargs):
+ if kwargs.pop("enabled", True):
+ with torch.profiler.profile(*args, **kwargs) as p:
+ yield p
+ else:
+ yield
+
+ with maybe_profile(enabled=args.export_profiler_trace) as p:
+ frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
+ for rep in range(args.repeat):
+ inputs = (
+ randomize_input(copy.deepcopy(example_inputs))
+ if should_randomize_input
+ else example_inputs
+ )
+
+ # interleave the runs to handle frequency scaling and load changes
+ timings[rep, 0], expected_output = timed(
+ model, model_iter_fn, inputs, return_result=True
+ )
+ timings[rep, 1], actual_output = timed(
+ model, frozen_model_iter_fn, inputs, return_result=True
+ )
+ if should_check_result:
+ is_correct = is_correct and same(expected_output, actual_output)
+ if args.export_profiler_trace:
+ name = args.profiler_trace_name + "_" + model.name + ".json"
+ name = os.path.join(torch._dynamo.config.base_dir, name)
+ p.export_chrome_trace(name)
+ pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue
+ median = np.median(timings, axis=0)
+ speedup = median[0] / median[1]
+ if args.dump_raw_metrics:
+ np.save(
+ f"{output_filename[:-4]}-raw_timings-{current_name}-{current_device}.npy",
+ timings,
+ )
+
+ headers = ("dev", "name", "batch_size", "speedup")
+ row = [current_device, current_name, current_batch_size, float(speedup)]
+ if "compilation_latency" in kwargs:
+ headers = headers + ("compilation_latency", "compression_ratio")
+ row.append(kwargs["compilation_latency"])
+ row.append(kwargs["compression_ratio"])
+
+ output_csv(
+ output_filename,
+ headers,
+ row,
+ )
+ headers, data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True)
+ assert (
+ output_filename.find(".csv") > 0
+ ), f"expected output_filename to be a .csv, but got {output_filename}"
+ output_csv(
+ output_filename[:-4] + "_compilation_metrics.csv",
+ ["dev", "name", "batch_size"] + headers,
+ [current_device, current_name, current_batch_size] + data,
+ )
+ return format_speedup(speedup, pvalue, is_correct=is_correct)
+
+
+def speedup_experiment_ds(args, model_iter_fn, model, example_inputs):
+ """
+ Run dynamic shapes benchmarks.
+
+ Requires dynamic shape compatible models, which provide a list of example inputs.
+
+ Warms up using the first input example and then iterates the inputs,
+ measuring (and expecting minimal) variance between the runtime for different examples.
+
+ """
+ timings = np.zeros((args.repeat, len(example_inputs), 2), np.float64)
+
+ if args.repeat > 5:
+ print(
+ f"\ndynamic shapes experiments are slow, consider setting --repeat less than {args.repeat}\n"
+ )
+
+ nwarmup = 4
+ for rep in range(args.repeat):
+ # Start each rep fresh, e.g. only warmup on example 0
+ torch._dynamo.reset()
+ optimized_model_iter_fn = optimize_ctx(model_iter_fn)
+ for _ in range(nwarmup):
+ optimized_model_iter_fn(model, example_inputs[0])
+
+ for input_idx, inputs in enumerate(example_inputs):
+ # interleave the runs to handle frequency scaling and load changes
+ timings[rep, input_idx, 0] = timed(
+ model, model_iter_fn, inputs, return_result=False
+ )
+ # different from regular speedup_experiment, we _DO_ want to allow recompilation
+ timings[rep, input_idx, 1] = timed(
+ model, optimized_model_iter_fn, inputs, return_result=False
+ )
+ medians = np.median(timings, axis=0)
+ speedups = list(medians[:, 0] / medians[:, 1])
+ speedups_mean = np.mean(speedups)
+ speedups_median = np.median(speedups)
+ speedups_var = np.var(speedups)
+
+ # TODO this x[0] is not going to work in general but bert only has 1 input
+ shapes = [x[0].shape for x in example_inputs]
+ shape_keys = sorted(set(shapes))
+ shape_speedups = {
+ shape: list(
+ map(
+ lambda it: it[1],
+ filter(lambda it: it[0] == shape, zip(shapes, speedups)),
+ )
+ )
+ for shape in shape_keys
+ }
+ output_str = (
+ f"mean: {speedups_mean:.3f}, median: {speedups_median:.3f}, var: {speedups_var:.3f}"
+ + "\nSpeedups by shape: "
+ + "\n".join(
+ [
+ f"{shape}: "
+ + ", ".join([f"{speedup: .3g}" for speedup in shape_speedups[shape]])
+ for shape in shape_keys
+ ]
+ )
+ )
+ output_csv(
+ output_filename,
+ ("dev", "name", "batch_size", "speedup mean", "speedup median", "speedup var"),
+ [
+ current_device,
+ current_name,
+ current_batch_size,
+ speedups_mean,
+ speedups_median,
+ speedups_var,
+ ],
+ )
+ return output_str
+
+
+def overhead_experiment(*args, model_iter_fn):
+ """
+ Measure overheads of TorchDynamo by running with no backend (only
+ eager+FX), and reporting speedup/slowdown over eager.
+
+ Writes to ./overheads.csv
+ """
+ return speedup_experiment(*args, model_iter_fn)
+
+
+def print_fx(gm, example_inputs):
+ print(gm.graph)
+ return gm
+
+
+def print_aten_ops(gm, example_inputs):
+ from functorch.compile import aot_module
+
+ def trace_printer(gm, _):
+ print(gm.graph)
+ return gm
+
+ return aot_module(gm, fw_compiler=trace_printer, bw_compiler=trace_printer)
+
+
+def baselines(models, model_iter_fn, example_inputs, args):
+ """
+ Common measurement code across all baseline experiments.
+ """
+ models = list(models)
+ for idx, (name, model) in enumerate(models):
+ if idx == 0:
+ result0 = model_iter_fn(model, example_inputs)
+ elif model is not None:
+ try:
+ result = model_iter_fn(model, example_inputs)
+ if same(result0, result):
+ continue
+ print(name, "is INCORRECT")
+ except Exception:
+ log.exception("error checking %s", name)
+ models[idx] = (name, None)
+ timings = np.zeros((args.repeat, len(models)), np.float64)
+ timings.fill(1.0e10)
+ for rep in range(args.repeat):
+ for idx, (name, model) in enumerate(models):
+ if model is not None:
+ try:
+ timings[rep, idx] = timed(model, model_iter_fn, example_inputs)
+ except Exception:
+ pass
+ pvalue = [
+ ttest_ind(timings[:, 0], timings[:, i]).pvalue
+ for i in range(1, timings.shape[1])
+ ]
+ median = np.median(timings, axis=0)
+ speedup = median[0] / median[1:]
+ for idx, (name, model) in enumerate(models[1:]):
+ if model is None:
+ speedup[idx] = 0.0
+ result = " ".join(
+ [
+ format_speedup(s, p, m is not None)
+ for s, p, m in zip(speedup, pvalue, [m for n, m in models[1:]])
+ ]
+ )
+ output_csv(
+ output_filename,
+ ("dev", "name", "batch_size") + tuple(n for n, m in models[1:]),
+ [current_device, current_name, current_batch_size]
+ + [f"{x:.4f}" for x in speedup],
+ )
+ return result
+
+
+def try_script(model, example_inputs):
+ try:
+ return torch.jit.script(model)
+ except Exception:
+ return None
+
+
+def speedup_experiment_ts(args, model_iter_fn, model, example_inputs):
+ """
+ Measure baseline performance (without using TorchDynamo) of TorchScript and optimize_for_inference.
+
+ Writes to ./baseline_ts.csv
+ """
+ if args.training:
+ return baselines(
+ [
+ ("eager", model),
+ ("ts", try_script(model, example_inputs)),
+ ],
+ model_iter_fn,
+ example_inputs,
+ args,
+ )
+
+ return baselines(
+ [
+ ("eager", model),
+ ("ts", try_script(model, example_inputs)),
+ (
+ "ofi",
+ backends.ofi(try_script(model, example_inputs), example_inputs),
+ ),
+ # ("nnc", backends.nnc(try_script(model, example_inputs), example_inputs)),
+ # ("nvfuser", backends.nvfuser(try_script(model, example_inputs), example_inputs)),
+ ],
+ model_iter_fn,
+ example_inputs,
+ args,
+ )
+
+
+def speedup_experiment_sr(args, model_iter_fn, model, example_inputs):
+ """
+ Measure baseline performance (without using TorchDynamo) of static runtime.
+
+ Writes to ./baseline_sr.csv
+ """
+
+ if current_name not in ("opacus_cifar10", "timm_nfnet", "hf_T5"):
+ sr = backends.static_runtime(try_script(model, example_inputs), example_inputs)
+ else:
+ # segfaults on these models
+ sr = None
+ return baselines(
+ [
+ ("eager", model),
+ (
+ "sr",
+ sr,
+ ),
+ ],
+ model_iter_fn,
+ example_inputs,
+ args,
+ )
+
+
+def speedup_experiment_onnx(args, model_iter_fn, model, example_inputs):
+ """
+ Measure baseline performance (without using TorchDynamo) of ONNXRT and TensorFlow.
+
+ Writes to ./baseline_onnx.csv
+ """
+ if current_device == "cpu":
+ m_onnxrt = backends.onnxrt_cpu(
+ try_script(model, example_inputs), example_inputs
+ )
+ else:
+ m_onnxrt = backends.onnxrt_cuda(
+ try_script(model, example_inputs), example_inputs
+ )
+
+ if current_name != "timm_resnest":
+ m_onnx2tf = backends.onnx2tf(try_script(model, example_inputs), example_inputs)
+ else:
+ # this one takes 8+ hours to finish
+ m_onnx2tf = None
+
+ return baselines(
+ [
+ ("eager", model),
+ ("onnxrt", m_onnxrt),
+ ("onnx2tf", m_onnx2tf),
+ ],
+ model_iter_fn,
+ example_inputs,
+ args,
+ )
+
+
+def speedup_experiment_trt(args, model_iter_fn, model, example_inputs):
+ """
+ Measure baseline performance (without using TorchDynamo) of TensorRT.
+
+ Writes to ./baseline_trt.csv
+ """
+ m_onnx2trt = backends.onnx2tensorrt(
+ try_script(model, example_inputs), example_inputs
+ )
+
+ m_torch2trt = backends.torch2trt(model, example_inputs)
+
+ if current_name != "opacus_cifar10":
+ m_fx2trt = backends.fx2trt(model, example_inputs)
+ else:
+ # fx2trt infinite loops on one model
+ m_fx2trt = None
+
+ return baselines(
+ [
+ ("eager", model),
+ ("onnx2trt", m_onnx2trt),
+ ("torch2trt", m_torch2trt),
+ ("fx2trt", m_fx2trt),
+ ],
+ model_iter_fn,
+ example_inputs,
+ args,
+ )
+
+
+def read_batch_size_from_file(args, filename, model_name):
+ batch_size = None
+ if os.path.exists("benchmarks"):
+ filename = os.path.join("benchmarks", filename)
+ assert os.path.exists(filename), filename
+ with open(filename, "r") as f:
+ lines = f.readlines()
+ lines = [i.split(",") for i in lines if len(i.strip()) > 0]
+ for val in lines:
+ cur_name, b = val
+ if model_name == cur_name:
+ batch_size = int(b)
+ if batch_size is None:
+ log.warning("Could not find batch size for {}".format(model_name))
+ elif batch_size == -1:
+ raise RuntimeError(
+ f"Batch size is unset for {model_name} in {args.batch_size_file}"
+ )
+ print(f"batch size: {batch_size}")
+ return batch_size
+
+
+class TimeOutException(Exception):
+ pass
+
+
+def alarm_handler(signum, frame):
+ raise TimeOutException()
+
+
+def exit_after(s):
+ """
+ Decorator to raise TimeoutException if the fn is taking more than s seconds
+ to run.
+ """
+
+ def outer(fn):
+ def inner(*args, **kwargs):
+ signal.signal(signal.SIGALRM, alarm_handler)
+ signal.alarm(s)
+ try:
+ result = fn(*args, **kwargs)
+ finally:
+ signal.alarm(0)
+ return result
+
+ return inner
+
+ return outer
+
+
+def get_peak_memory():
+ return torch.cuda.max_memory_allocated() / 10**9
+
+
+def null_experiment(args, model_iter_fn, model, example_inputs):
+ """
+ A no-op experiment useful for making sure TorchBenchark alone works properly.
+ """
+
+ return []
+
+
+def cast_to(dtype, model, inputs):
+ # cast model and inputs to fp16
+ if dtype == torch.float16:
+ model = model.half()
+ else:
+ model = model.to(dtype)
+
+ inputs = tree_map(
+ lambda x: x.to(dtype)
+ if isinstance(x, torch.Tensor) and x.is_floating_point()
+ else x,
+ inputs,
+ )
+ return model, inputs
+
+
+def cast_to_fp16(model, inputs):
+ return cast_to(torch.float16, model, inputs)
+
+
+def cast_to_fp64(model, inputs):
+ return cast_to(torch.float64, model, inputs)
+
+
+def cast_to_fp32(model, inputs):
+ return cast_to(torch.float32, model, inputs)
+
+
+def reset_rng_state():
+ torch.manual_seed(1337)
+ random.seed(1337)
+ np.random.seed(1337)
+
+
+class DummyGradScaler:
+ def scale(self, loss):
+ return loss
+
+
+def maybe_fresh_cache(fn):
+ def inner(self, *args, **kwargs):
+ cache_minder = NullContext()
+ if self.args.cold_start_latency:
+ cache_entries = {}
+ cache_minder = fresh_triton_cache(cache_entries)
+
+ try:
+ with cache_minder:
+ return fn(self, *args, **kwargs)
+ finally:
+ dump_cache = False
+ if dump_cache and self.args.cold_start_latency:
+ output_csv(
+ output_filename[:-4] + "_triton_cache.csv",
+ ["dev", "name", "batch_size", "triton_cache"],
+ [
+ current_device,
+ current_name,
+ current_batch_size,
+ cache_entries,
+ ],
+ )
+
+ return inner
+
+
+class BenchmarkRunner:
+ def __init__(self):
+ self.model_iter_fn = None
+ self.use_amp = False
+ self.grad_scaler = DummyGradScaler()
+ self.autocast = NullContext
+ self._args = None
+
+ def setup_amp(self):
+ if self.args.amp and self.args.training:
+ assert self.args.devices == ["cuda"], "AMP is supported only for CUDA"
+ # AMP training can lead to small loss values which can undeflow
+ # gradient values returning in zero gradients. To solve this
+ # problem, PyTorch introduces GradScaler. GradScaler is a stateful
+ # structure, that scales the loss values to prevent underflow. Loss
+ # values are big at the beginning of training (therefore not
+ # requiring scaling), while loss value tends to be small as network
+ # starts getting better (requiring scaling). GradScaler manages all
+ # of this fine tuning, checking the gradients are turning to inf,
+ # discarding such batches.
+
+ # Since we are not running a long iteration, default value of
+ # init_scale 65536 is going to turn all gradients to inf. Therefore,
+ # we just use a init_scale of 2.0 for benchmarking purpose.
+ self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
+ self.autocast = torch.cuda.amp.autocast
+
+ def init_optimizer(self, device, params):
+ param_list = list(params)
+ if device == "cuda" and len(param_list) != 0:
+ # capturable is only supported on cuda at the moment
+ self.optimizer = torch.optim.Adam(param_list, capturable=True)
+ else:
+ self.optimizer = None
+
+ @property
+ def args(self):
+ return self._args
+
+ @args.setter
+ def args(self, args):
+ self._args = args
+
+ @property
+ def skip_models(self):
+ return set()
+
+ @property
+ def slow_models(self):
+ return set()
+
+ @property
+ def very_slow_models(self):
+ return set()
+
+ @property
+ def non_deterministic_models(self):
+ return set()
+
+ @property
+ def skip_not_suitable_for_training_models(self):
+ return set()
+
+ @property
+ def failing_torchinductor_models(self):
+ return set()
+
+ @property
+ def failing_fx2trt_models(self):
+ return set()
+
+ @property
+ def failing_dynamic_shape_models(self):
+ return set()
+
+ @property
+ def skip_accuracy_checks_large_models_dashboard(self):
+ return set()
+
+ @property
+ def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
+ raise NotImplementedError()
+
+ @property
+ def equal_nan(self):
+ equal_nan = True
+ if self.args.float32:
+ equal_nan = False
+ return equal_nan
+
+ def iter_models(self, args):
+ for model_name in self.iter_model_names(args):
+ for device in args.devices:
+ try:
+ yield self.load_model(
+ device,
+ model_name,
+ batch_size=args.batch_size,
+ )
+ except NotImplementedError:
+ continue # bad benchmark implementation
+
+ def validate_model(self, model, example_inputs):
+ """
+ Runs the eager model with example inputs to ensure that eager passes.
+ """
+ model = copy.deepcopy(model)
+ example_inputs = clone_inputs(example_inputs)
+ if self.args.float32:
+ model, example_inputs = cast_to_fp32(model, example_inputs)
+ elif self.args.float16:
+ model, example_inputs = cast_to_fp16(model, example_inputs)
+
+ try:
+ self.model_iter_fn(model, example_inputs)
+ except Exception:
+ raise NotImplementedError("Eager model failed to run")
+
+ def maybe_cast(self, model, example_inputs):
+ model = copy.deepcopy(model)
+ example_inputs = clone_inputs(example_inputs)
+ if self.args.float32:
+ model, example_inputs = cast_to_fp32(model, example_inputs)
+ elif self.args.float16:
+ model, example_inputs = cast_to_fp16(model, example_inputs)
+ return model, example_inputs
+
+ def decay_batch_exp(self, batch_size, factor=0.5, divisor=2):
+ out_batch_size = batch_size * factor
+ if out_batch_size > divisor:
+ out_batch_size = (out_batch_size + 1) // divisor * divisor
+ else:
+ out_batch_size = batch_size - 1
+ return max(0, int(out_batch_size))
+
+ def batch_size_finder(self, device, model_name, initial_batch_size=128):
+ batch_size = initial_batch_size
+ while batch_size >= 1:
+ torch.cuda.empty_cache()
+ try:
+ device, name, model, example_inputs, _ = self.load_model(
+ device,
+ model_name,
+ batch_size,
+ )
+ self.model_iter_fn(model, example_inputs)
+ return batch_size
+ except RuntimeError as e:
+ error_str = str(e)
+ if "channels_last" in error_str:
+ break
+ batch_size = self.decay_batch_exp(batch_size)
+ return 1
+
+ def optimizer_step(self):
+ if self.optimizer is not None:
+ self.optimizer.step()
+
+ def get_benchmark_indices(self, length):
+ start = self._args.partition_id * (length // self._args.total_partitions)
+ end = (
+ (self._args.partition_id + 1) * (length // self._args.total_partitions)
+ if self._args.partition_id < self._args.total_partitions - 1
+ else length
+ )
+ return start, end
+
+ def check_accuracy(self, name, model, example_inputs, optimize_ctx, experiment):
+ """
+ Checks accuracy.
+ 1) Collect the outputs with fp64 datatype. This is useful for error checking.
+ 2) Checks if eager itself has variations.
+ """
+
+ def record_status(accuracy_status):
+ """
+ Records the status in the csv file
+ """
+ if current_name in self.non_deterministic_models:
+ if accuracy_status in ("pass", "eager_variation", "fail_accuracy"):
+ accuracy_status = "pass"
+
+ output_csv(
+ output_filename,
+ ("dev", "name", "batch_size", "accuracy"),
+ [current_device, current_name, current_batch_size, accuracy_status],
+ )
+ return "PASS" if accuracy_status in ("pass", "pass_due_to_skip") else "FAIL"
+
+ tolerance, cos_similarity = self.get_tolerance_and_cosine_flag(
+ self.args.training, current_device, name
+ )
+
+ if name in self.skip_accuracy_checks_large_models_dashboard:
+ return record_status("pass_due_to_skip")
+
+ # Collect the fp64 reference outputs to be used later for accuracy checking.
+ fp64_outputs = None
+ try:
+ fp64_outputs = self.model_iter_fn(
+ *cast_to_fp64(
+ copy.deepcopy(model),
+ clone_inputs(example_inputs),
+ )
+ )
+ except Exception:
+ log.warning(f"fp64 golden ref were not generated for {name}")
+ fp64_outputs = None
+ if self.args.ci and self.args.training:
+ return record_status("fp64_OOM")
+
+ # Cast the model to float16/float32 as necessary
+ model, example_inputs = self.maybe_cast(model, example_inputs)
+
+ accuracy_status = "pass"
+
+ with self.pick_grad(name, self.args.training):
+ # Get results of native pytorch
+ reset_rng_state()
+ correct_result = self.model_iter_fn(
+ copy.deepcopy(model), clone_inputs(example_inputs)
+ )
+
+ # Rerun native pytorch
+ reset_rng_state()
+ correct_rerun_result = self.model_iter_fn(
+ copy.deepcopy(model), clone_inputs(example_inputs)
+ )
+ if not same(
+ correct_result,
+ correct_rerun_result,
+ fp64_outputs,
+ equal_nan=self.equal_nan,
+ ):
+ accuracy_status = "eager_variation"
+ return record_status(accuracy_status)
+ correct_rerun_result = None
+
+ # Run with Dynamo
+ reset_rng_state()
+ torch._dynamo.reset()
+ try:
+ optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
+ new_result = optimized_model_iter_fn(model, example_inputs)
+ except Exception as e:
+ accuracy_status = "fail_to_run"
+ print(
+ "TorchDynamo optimized model failed to run because of following error"
+ )
+ log.exception(e)
+ return record_status(accuracy_status)
+
+ if not same(
+ correct_result,
+ new_result,
+ fp64_outputs,
+ equal_nan=self.equal_nan,
+ cos_similarity=cos_similarity,
+ tol=tolerance,
+ ):
+ if self.args.skip_accuracy_check:
+ accuracy_status = "pass_due_to_skip"
+ else:
+ accuracy_status = "fail_accuracy"
+ return record_status(accuracy_status)
+
+ return record_status(accuracy_status)
+
+ def run_performance_test(
+ self, name, model, example_inputs, optimize_ctx, experiment
+ ):
+ def warmup(fn, model, example_inputs, mode, niters=5):
+ peak_mem = 0
+ try:
+ if current_device == "cuda":
+ torch.cuda.reset_peak_memory_stats()
+ torch.cuda.empty_cache()
+ t0 = time.perf_counter()
+ for _ in range(niters):
+ fn(model, example_inputs)
+ t1 = time.perf_counter()
+ latency = t1 - t0
+ if current_device == "cuda":
+ peak_mem = get_peak_memory()
+ except Exception as e:
+ log.exception(f"Failed for {mode} {e}")
+ return sys.exit(-1)
+ return latency, peak_mem
+
+ # Cast the model to float16/float32 as necessary
+ model, example_inputs = self.maybe_cast(model, example_inputs)
+ with self.pick_grad(name, self.args.training):
+ ok, total = Stats.reset_counters()
+ experiment_kwargs = {}
+ results = []
+
+ eager_latency, eager_peak_mem = warmup(
+ self.model_iter_fn, model, example_inputs, "eager"
+ )
+ optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
+ dynamo_latency, dynamo_peak_mem = warmup(
+ optimized_model_iter_fn, model, example_inputs, "dynamo"
+ )
+
+ compilation_time = dynamo_latency - eager_latency
+ compression_ratio = eager_peak_mem / dynamo_peak_mem
+ # print(
+ # f"memory: eager: {eager_peak_mem:.2f} GB, "
+ # f"dynamo: {dynamo_peak_mem:.2f} GB, "
+ # f"ratio: {compression_ratio:.2f}"
+ # )
+
+ if experiment.func is speedup_experiment:
+ experiment_kwargs["compilation_latency"] = compilation_time
+ experiment_kwargs["compression_ratio"] = compression_ratio
+
+ if experiment.func is coverage_experiment:
+ ok, total = Stats.reset_counters()
+ results = []
+ # run with torch._dynamo few times to populate the cache
+ for _ in range(3):
+ optimized_model_iter_fn(model, example_inputs)
+ _, frames_second_pass = Stats.reset_counters() # should be 0
+ if frames_second_pass > 0:
+ optimized_model_iter_fn(model, example_inputs)
+ _, frames_third_pass = Stats.reset_counters() # should be 0
+ else:
+ frames_third_pass = 0
+
+ results.append(
+ f"{ok:3}/{total:3} +{frames_third_pass} frames {compilation_time:3.0f}s"
+ )
+
+ if not hasattr(model, name):
+ model.name = name
+ results.append(experiment(model, example_inputs, **experiment_kwargs))
+ return " ".join(map(str, results))
+
+ def compare_branches(
+ self,
+ name,
+ model,
+ example_inputs,
+ optimize_ctx,
+ experiment,
+ diff=False,
+ branch=None,
+ ):
+ assert branch is None, "Branch set during top level flow."
+ import git
+
+ repo = git.Repo(
+ "../torch._dynamo"
+ ) # Hack assumption of torchbenchmark positioning
+ curr_branch = repo.active_branch.name
+ if curr_branch != "main":
+ if repo.is_dirty():
+ raise RuntimeError(
+ "--diff_main called on dirty branch. Commit, stash, or reset."
+ )
+ # Run current
+ try:
+ self.run_one_model(
+ name,
+ model,
+ self.model_iter_fn,
+ example_inputs,
+ optimize_ctx,
+ experiment,
+ diff=False,
+ branch=curr_branch,
+ )
+ # Swap to main
+ repo.git.checkout("main")
+ # Run main
+ self.run_one_model(
+ name,
+ model,
+ self.model_iter_fn,
+ example_inputs,
+ optimize_ctx,
+ experiment,
+ diff=False,
+ branch="main",
+ )
+ finally:
+ # Swap back
+ repo.git.checkout(curr_branch)
+ return
+ else:
+ raise RuntimeError(
+ "--diff_main called on main branch, what are you diffing?"
+ )
+
+ @maybe_fresh_cache
+ def run_one_model(
+ self,
+ name,
+ model,
+ example_inputs,
+ optimize_ctx,
+ experiment,
+ diff=False,
+ branch=None,
+ ):
+ if diff:
+ self.compare_branches(
+ name, model, example_inputs, optimize_ctx, experiment, diff, branch
+ )
+ elif branch:
+ print("RUNNING ON BRANCH:", branch)
+ mode = "train" if self.args.training else "eval"
+ print(f"{current_device:4} {mode:5} {current_name:34} ", end="", flush=True)
+ if self.args.accuracy:
+ status = self.check_accuracy(
+ name, model, example_inputs, optimize_ctx, experiment
+ )
+ print(status)
+ elif self.args.performance:
+ status = self.run_performance_test(
+ name, model, example_inputs, optimize_ctx, experiment
+ )
+ print(status)
+
+
+def help(fn):
+ return fn.__doc__
+
+
+def parse_args():
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--filter", "-k", action="append", help="filter benchmarks with regexp"
+ )
+ parser.add_argument(
+ "--exclude", "-x", action="append", help="filter benchmarks with regexp"
+ )
+ parser.add_argument(
+ "--total-partitions",
+ type=int,
+ default=1,
+ choices=range(1, 10),
+ help="Total number of partitions we want to divide the benchmark suite into",
+ )
+ parser.add_argument(
+ "--partition-id",
+ type=int,
+ default=0,
+ help="ID of the benchmark suite partition to be run. Used to divide CI tasks",
+ )
+ parser.add_argument("--devices", "-d", action="append", help="cpu or cuda")
+ parser.add_argument(
+ "--repeat", "-n", type=int, default=30, help="number of timing runs"
+ )
+ parser.add_argument(
+ "--randomize-input",
+ action="store_true",
+ help="Whether to randomize the input values. Dimensions will be kept the same.",
+ )
+ parser.add_argument(
+ "--threads", "-t", type=int, help="number of threads to use for eager"
+ )
+ parser.add_argument(
+ "--nopython", action="store_true", help="Turn graph breaks into errors"
+ )
+ parser.add_argument(
+ "--no-skip",
+ action="store_true",
+ help="run models that are in the global SKIP list",
+ )
+ parser.add_argument(
+ "--prims-nvfuser", action="store_true", help="user prims + nvfuser backend"
+ )
+ parser.add_argument(
+ "--dump-raw-metrics",
+ action="store_true",
+ help="dump raw timing metrics from speedup experiment",
+ )
+ parser.add_argument(
+ "--log-operator-inputs",
+ action="store_true",
+ default=False,
+ )
+ parser.add_argument(
+ "--channels-last",
+ action="store_true",
+ default=False,
+ help="use channels last format",
+ )
+ parser.add_argument("--batch_size", type=int, help="batch size for benchmarking")
+ parser.add_argument(
+ "--batch-size-file", type=str, help="String to load batch size from"
+ )
+ parser.add_argument("--cosine", action="store_true", help="use cosine similarity")
+ parser.add_argument(
+ "--ci", action="store_true", help="Flag to tell that its a CI run"
+ )
+ parser.add_argument(
+ "--dashboard", action="store_true", help="Flag to tell that its a Dashboard run"
+ )
+ parser.add_argument(
+ "--skip-fp64-check", action="store_true", help="skip accuracy check using fp64"
+ )
+ parser.add_argument(
+ "--fast", "-f", action="store_true", help="skip slow benchmarks"
+ )
+ parser.add_argument("--only", help="Run just one model")
+ parser.add_argument(
+ "--training",
+ action="store_true",
+ help="Performs training",
+ )
+ parser.add_argument(
+ "--dynamic-shapes",
+ action="store_true",
+ help="Runs a dynamic shapes version of the benchmark, if available.",
+ )
+ parser.add_argument(
+ "--use-eval-mode",
+ action="store_true",
+ help="sets model.eval() to reduce randomness",
+ )
+ parser.add_argument(
+ "--skip-accuracy-check",
+ action="store_true",
+ help="keeps running even when accuracy fails",
+ )
+ parser.add_argument(
+ "--generate-aot-autograd-stats",
+ action="store_true",
+ help="Generates AOT Autograd stats like how mnay graphs are sent to AOT",
+ )
+ parser.add_argument(
+ "--inductor-settings",
+ action="store_true",
+ help="Use same settings as --inductor for baseline comparisons",
+ )
+ parser.add_argument(
+ "--raise-on-assertion-error",
+ action="store_true",
+ help="Fail a benchmark if torch._dynamo triggers an internal assertion",
+ )
+ parser.add_argument(
+ "--raise-on-backend-error",
+ action="store_true",
+ help="Fail a benchmark if backend throws an exception",
+ )
+ parser.add_argument(
+ "--output",
+ help="Overrides the output filename",
+ )
+ parser.add_argument(
+ "--export-profiler-trace",
+ action="store_true",
+ help="exports trace of kineto profiler",
+ )
+ parser.add_argument("--profiler_trace_name", help="Overwrites exported trace name")
+
+ parser.add_argument(
+ "--diff_main",
+ action="store_true",
+ help="Delta this branch against main. In the future, we may add support for picking the branch.",
+ )
+
+ parser.add_argument(
+ "--cold_start_latency",
+ action="store_true",
+ help="Use a fresh triton cachedir when running each model, to force cold-start compile.",
+ )
+
+ group_fuser = parser.add_mutually_exclusive_group()
+ # --nvfuser is now the default, keep the option to not break scripts
+ group_fuser.add_argument("--nvfuser", action="store_true", help=argparse.SUPPRESS)
+ group_fuser.add_argument("--nnc", action="store_true", help="enable NNC for GPUs")
+
+ group_prec = parser.add_mutually_exclusive_group()
+ group_prec.add_argument("--float16", action="store_true", help="cast model to fp16")
+ group_prec.add_argument("--float32", action="store_true", help="cast model to fp32")
+ group_prec.add_argument(
+ "--amp", action="store_true", help="use automatic mixed precision"
+ )
+
+ group_printout = parser.add_mutually_exclusive_group()
+ group_printout.add_argument(
+ "--verbose", "-v", action="store_true", help="enable verbose debug printouts"
+ )
+ group_printout.add_argument(
+ "--quiet", "-q", action="store_true", help="suppress debug printouts"
+ )
+
+ group = parser.add_mutually_exclusive_group()
+ group.add_argument(
+ "--coverage", action="store_true", help="(default) " + help(coverage_experiment)
+ )
+ group.add_argument(
+ "--speedup-ltc",
+ action="store_true",
+ help="speedup using the ltc backend",
+ )
+ group.add_argument(
+ "--speedup-ltc-trivial",
+ action="store_true",
+ help="speedup using the ltc backend without reusing compiled graph",
+ )
+ group.add_argument(
+ "--cold-start", action="store_true", help=help(cold_start_experiment)
+ )
+ group.add_argument(
+ "--overhead", action="store_true", help=help(overhead_experiment)
+ )
+ group.add_argument(
+ "--speedup-ts", action="store_true", help=help(speedup_experiment_ts)
+ )
+ group.add_argument(
+ "--speedup-sr", action="store_true", help=help(speedup_experiment_sr)
+ )
+ group.add_argument(
+ "--speedup-onnx", action="store_true", help=help(speedup_experiment_onnx)
+ )
+ group.add_argument(
+ "--speedup-trt", action="store_true", help=help(speedup_experiment_trt)
+ )
+ group.add_argument(
+ "--speedup-dynamo-ts",
+ action="store_true",
+ help="TorchDynamo frontend with torchscript backend",
+ )
+ group.add_argument(
+ "--speedup-fx2trt", action="store_true", help=help(speedup_experiment_fx2trt)
+ )
+ group.add_argument(
+ "--speedup-fx2trt-fp16",
+ action="store_true",
+ help=help(speedup_experiment_fx2trt),
+ )
+ group.add_argument(
+ "--print-fx",
+ action="store_true",
+ help="Print fx traces captured from model",
+ )
+ group.add_argument(
+ "--print-aten-ops",
+ action="store_true",
+ help="Print traces of aten ops captured by AOT autograd",
+ )
+ group.add_argument(
+ "--inductor",
+ action="store_true",
+ help="Measure speedup with TorchInductor",
+ )
+ group.add_argument(
+ "--inductor-dynamic",
+ action="store_true",
+ help="Measure speedup with TorchInductor",
+ )
+ group.add_argument(
+ "--backend",
+ choices=torch._dynamo.list_backends(),
+ help="measure speedup with a given backend",
+ )
+ group.add_argument("--nothing", action="store_true", help=help(null_experiment))
+ group.add_argument(
+ "--log-conv-args",
+ action="store_true",
+ help="Dump convolution input/weight/bias's shape/stride/dtype and other options to json",
+ )
+ group.add_argument(
+ "--recompile_profiler",
+ action="store_true",
+ help="Run the dynamo recompilation profiler on each model.",
+ )
+ group.add_argument(
+ "--find-batch-sizes",
+ action="store_true",
+ help="finds the largest batch size that could fit on GPUs",
+ )
+
+ mode_group = parser.add_mutually_exclusive_group(required=True)
+ mode_group.add_argument(
+ "--accuracy",
+ action="store_true",
+ help="Checks accuracy with small batch size and eval mode",
+ )
+ mode_group.add_argument(
+ "--performance", action="store_true", help="Measures performance speedup"
+ )
+ args = parser.parse_args()
+ return args
+
+
+def main(runner, original_dir=None):
+ args = parse_args()
+
+ # Pass the parsed args object to benchmark runner object
+ runner.args = args
+
+ # defaults
+ args.filter = args.filter or [r"."]
+ args.exclude = args.exclude or [r"^$"]
+
+ if args.ci:
+ # Only dump error on CI
+ args.quiet = True
+ args.repeat = 2
+ if args.backend == "aot_eager":
+ args.exclude = (
+ CI_SKIP_AOT_EAGER_TRAINING
+ if args.training
+ else CI_SKIP_AOT_EAGER_INFERENCE
+ )
+ elif args.inductor:
+ args.exclude = (
+ CI_SKIP_INDUCTOR_TRAINING
+ if args.training
+ else CI_SKIP_INDCUTOR_INFERENCE
+ )
+
+ if args.accuracy:
+ # Use small batch size. We use >1 batch size to ensure we test
+ # batch_norm type of operators that work on batch dims.
+ # TODO - Go through the failures for batch size = 2
+ if args.batch_size is None:
+ if runner.suite_name == "huggingface":
+ args.batch_size = 1
+ else:
+ args.batch_size = 2
+
+ # Remove sources of randomness
+ args.use_eval_mode = True
+
+ # Remove randomeness when torch manual seed is called
+ patch_torch_manual_seed()
+
+ # Some models e.g. yolov3 assert batch size on n_gpus
+ if "CUDA_VISIBLE_DEVICES" not in os.environ:
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+ # Stricter check to disable fallbacks
+ args.raise_on_assertion_error = True
+ args.raise_on_backend_error = True
+
+ elif args.performance:
+ # Ensure that we test on real scenarios
+ args.use_eval_mode = False
+
+ if args.partition_id > args.total_partitions or args.partition_id < 0:
+ print("Invalid partition id")
+ return sys.exit(-1)
+
+ if not args.devices:
+ if torch.cuda.is_available():
+ args.devices = ["cuda"]
+ else:
+ log.warning("torch.cuda.is_available() == False, using CPU")
+ args.devices = ["cpu"]
+
+ if args.devices != ["cpu"] and torch.cuda.is_available():
+ global synchronize
+ synchronize = torch.cuda.synchronize
+
+ if (
+ args.devices == ["cuda"]
+ and torch.cuda.get_device_properties(0).total_memory < 25 * 2**30
+ ):
+ # OOM errors on an RTX 3090 with 24gb RAM
+ runner.skip_models.update(
+ {
+ # torchbench
+ "hf_Longformer",
+ "timm_nfnet",
+ "timm_efficientdet",
+ # timm
+ "beit_base_patch16_224",
+ "cait_m36_384",
+ "convmixer_768_32",
+ "deit_base_distilled_patch16_224",
+ "dm_nfnet_f0",
+ "dpn107",
+ "dm_nfnet_f0",
+ }
+ )
+ if args.training:
+ runner.skip_models.add("hf_T5")
+
+ if torch._dynamo.config.dynamic_shapes:
+ # TODO(jansel): fix bugs in these
+ runner.skip_models.update(runner.failing_dynamic_shape_models)
+
+ if args.nnc:
+ torch._C._jit_override_can_fuse_on_cpu(True)
+ torch._C._jit_override_can_fuse_on_gpu(True)
+ torch._C._jit_set_texpr_fuser_enabled(True)
+ torch._C._jit_set_nvfuser_enabled(False)
+
+ if args.threads:
+ torch.set_num_threads(args.threads)
+
+ if args.verbose:
+ torch._dynamo.config.log_level = logging.DEBUG
+
+ if args.quiet:
+ torch._dynamo.config.log_level = logging.ERROR
+
+ torch._dynamo.config.raise_on_assertion_error = args.raise_on_assertion_error
+ torch._dynamo.config.raise_on_backend_error = args.raise_on_backend_error
+
+ if args.training:
+ runner.model_iter_fn = runner.forward_and_backward_pass
+ runner.skip_models.update(runner.skip_not_suitable_for_training_models)
+ else:
+ runner.model_iter_fn = runner.forward_pass
+
+ if args.fast:
+ runner.skip_models.update(runner.slow_models)
+
+ if args.devices == ["cpu"]:
+ runner.skip_models.update(runner.very_slow_models)
+
+ if args.inductor or args.inductor_dynamic or args.inductor_settings:
+ runner.skip_models.update(runner.failing_torchinductor_models)
+ if args.float16:
+ # TODO(jansel): check if correctness issue is real
+ runner.skip_models.add("yolov3")
+
+ if args.float16:
+ # these give `INCORRECT - Variation in Eager runs itself` sometimes
+ runner.non_deterministic_models.update(
+ {
+ "demucs",
+ "pyhpc_equation_of_state",
+ "timm_efficientdet",
+ "pyhpc_isoneutral_mixing",
+ "pyhpc_turbulent_kinetic_energy",
+ "shufflenet_v2_x1_0",
+ }
+ )
+
+ if args.no_skip:
+ runner.skip_models.clear()
+
+ experiment = null_experiment
+ global current_name, current_device, current_batch_size, output_filename, optimize_ctx
+ optimize_ctx = NullContext()
+
+ if args.overhead:
+ optimize_ctx = torch._dynamo.optimize(dummy_fx_compile, nopython=args.nopython)
+ experiment = speedup_experiment
+ output_filename = "overheads.csv"
+ elif args.cold_start:
+ optimize_ctx = torch._dynamo.optimize("aot_nvfuser", nopython=args.nopython)
+ experiment = cold_start_experiment
+ assert args.nvfuser, "TODO - Add another aot string for mem fusion with NNC"
+ backend_str = "nvfuser" if args.nvfuser else "nnc"
+ output_filename = f"cold_start_{backend_str}.csv"
+ # TODO(whc) should we move this to a more general part of the script?
+ torch.backends.cuda.matmul.allow_tf32 = True
+ elif args.inductor or args.inductor_dynamic:
+ import torch._inductor.config
+
+ torch._inductor.config.debug = args.verbose
+ if args.threads:
+ torch._inductor.config.cpp.threads = args.threads
+
+ if args.inductor_dynamic:
+ torch._inductor.config.triton.cudagraphs = False
+ torch._inductor.config.dynamic_shapes = True
+ else:
+ torch._inductor.config.dynamic_shapes = False
+ if args.export_profiler_trace:
+ print("Profiling requested, setting cudagraphs to False")
+ torch._inductor.config.triton.cudagraphs = False
+
+ optimize_ctx = torch._dynamo.optimize("inductor", nopython=args.nopython)
+ experiment = speedup_experiment
+ output_filename = "inductor.csv"
+ elif args.speedup_ltc:
+ optimize_ctx = torch._dynamo.optimize(
+ backends.ltc_reuse_graph, nopython=args.nopython
+ )
+ experiment = speedup_experiment
+ output_filename = "speedups_ltc.csv"
+ elif args.speedup_ltc_trivial:
+ optimize_ctx = torch._dynamo.optimize(
+ backends.ltc_trivial, nopython=args.nopython
+ )
+ experiment = speedup_experiment
+ output_filename = "speedups_ltc_trivial.csv"
+ elif args.speedup_ts:
+ experiment = speedup_experiment_ts
+ output_filename = "baseline_ts.csv"
+ elif args.speedup_sr:
+ experiment = speedup_experiment_sr
+ output_filename = "baseline_sr.csv"
+ elif args.speedup_onnx:
+ experiment = speedup_experiment_onnx
+ output_filename = "baseline_onnx.csv"
+ elif args.speedup_trt:
+ experiment = speedup_experiment_trt
+ output_filename = "baseline_trt.csv"
+ elif args.speedup_dynamo_ts:
+ optimize_ctx = torch._dynamo.optimize(backends.ts, nopython=args.nopython)
+ experiment = speedup_experiment
+ output_filename = "speedup_dynamo_ts.csv"
+ elif args.speedup_fx2trt:
+ optimize_ctx = torch._dynamo.optimize(
+ backends.fx2trt_compiler, nopython=args.nopython
+ )
+ experiment = speedup_experiment_fx2trt
+ output_filename = "speedups_fx2trt.csv"
+ runner.skip_models.update(runner.failing_fx2trt_models)
+ args.float32 = True
+ args.float16 = False
+ args.cosine = True
+ elif args.speedup_fx2trt_fp16:
+ optimize_ctx = torch._dynamo.optimize(
+ backends.fx2trt_compiler_fp16, nopython=args.nopython
+ )
+ experiment = speedup_experiment_fx2trt
+ output_filename = "speedups_fx2trt_fp16.csv"
+ args.float32 = False
+ args.float16 = True
+ args.cosine = True
+ elif args.prims_nvfuser:
+ optimize_ctx = torch._dynamo.optimize("prims_nvfuser", nopython=args.nopython)
+ experiment = speedup_experiment
+ backend_str = "prims_nvfuser"
+ output_filename = f"accuracy_aot_{backend_str}.csv"
+ elif args.print_fx:
+ optimize_ctx = torch._dynamo.optimize(
+ print_fx,
+ nopython=args.nopython,
+ )
+ elif args.print_aten_ops:
+ optimize_ctx = torch._dynamo.optimize(
+ print_aten_ops,
+ nopython=args.nopython,
+ )
+ elif args.nothing:
+ pass
+ elif args.backend:
+ optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
+ experiment = speedup_experiment
+ if args.accuracy:
+ output_filename = f"accuracy_{args.backend}.csv"
+ else:
+ output_filename = f"speedup_{args.backend}.csv"
+ elif args.log_conv_args:
+ optimize_ctx = torch._dynamo.optimize(
+ conv_args_analysis, nopython=args.nopython
+ )
+ output_filename = "log_conv_args.csv"
+ elif args.recompile_profiler:
+ output_filename = "recompile_profiler_log.csv"
+ experiment = recompile_profiler_experiment
+ else:
+ optimize_ctx = torch._dynamo.optimize(
+ fx_insert_profiling, nopython=args.nopython
+ )
+ experiment = coverage_experiment
+ output_filename = "coverage.csv"
+
+ runner.setup_amp()
+
+ if args.output:
+ output_filename = args.output
+
+ if output_filename:
+ output_filename = os.path.join(torch._dynamo.config.base_dir, output_filename)
+
+ if args.find_batch_sizes and args.only:
+ for device in args.devices:
+ batch_size = runner.batch_size_finder(device, args.only)
+ print(args.only, batch_size)
+ output_csv(output_filename, [], [args.only, batch_size])
+ return
+
+ if args.export_profiler_trace:
+ if args.profiler_trace_name is None:
+ if args.backend:
+ args.profiler_trace_name = args.backend
+ elif args.inductor or args.inductor_dynamic:
+ args.profiler_trace_name = "inductor"
+ else:
+ args.profiler_trace_name = "profile"
+ else:
+ args.profiler_trace_name = args.profiler_trace_name
+
+ experiment = functools.partial(experiment, args, runner.model_iter_fn)
+
+ if args.only:
+ model_name = args.only
+ for device in args.devices:
+ batch_size = args.batch_size
+ if args.batch_size_file:
+ batch_size = read_batch_size_from_file(
+ args, args.batch_size_file, model_name
+ )
+ try:
+ device, name, model, example_inputs, batch_size = runner.load_model(
+ device,
+ model_name,
+ batch_size=batch_size,
+ )
+ except NotImplementedError as e:
+ print(e)
+ import traceback
+
+ print(traceback.format_exc())
+ logging.warn(f"{args.only} failed to load")
+ continue # bad benchmark implementation
+
+ current_name = name
+ current_device = device
+ current_batch_size = batch_size
+ set_model_name(name)
+
+ if args.float32:
+ model, example_inputs = cast_to_fp32(model, example_inputs)
+ elif args.float16:
+ model, example_inputs = cast_to_fp16(model, example_inputs)
+
+ if args.log_operator_inputs:
+ log_operator_inputs(
+ model, example_inputs, runner.model_iter_fn, name, args
+ )
+ continue
+
+ runner.run_one_model(
+ name,
+ model,
+ example_inputs,
+ optimize_ctx,
+ experiment,
+ diff=args.diff_main,
+ )
+ if args.generate_aot_autograd_stats:
+ stats_file = output_filename.split(".csv")[0] + "_stats.csv"
+ output_csv(
+ stats_file,
+ ("dev", "name", "batch_size", "total_aot_graphs", "ok_aot_graphs"),
+ [
+ current_device,
+ current_name,
+ current_batch_size,
+ *Stats.aot_summary(),
+ ],
+ )
+ else:
+ if output_filename and os.path.exists(output_filename):
+ os.unlink(output_filename)
+ if original_dir:
+ os.chdir(original_dir)
+ for name in runner.iter_model_names(args):
+ current_name = name
+ placeholder_batch_size = 0
+ try:
+ subprocess.check_call([sys.executable] + sys.argv + [f"--only={name}"])
+ except subprocess.SubprocessError:
+ print("ERROR")
+ for device in args.devices:
+ output_csv(
+ output_filename, [], [device, name, placeholder_batch_size, 0.0]
+ )
+ print_summary(output_filename)
+
+
+def log_operator_inputs(model, example_inputs, model_iter_fn, name, args):
+ mode = "training" if args.training else "eval"
+ output = os.path.join(os.path.dirname(args.output), f"{name}_{mode}.txt")
+
+ # TODO - add option for coalescing inputs over multiple runs
+ if os.path.exists(output):
+ print(f"Skipping {name}, {output} already exists")
+ return
+
+ print(f"Running {name}")
+
+ operator_mode = OperatorInputsMode()
+ fake_tensor_mode = FakeTensorMode()
+
+ with torch._subclasses.fake_tensor.FakeCopyMode(fake_tensor_mode):
+ model_fake = copy.deepcopy(model)
+ example_inputs_fake = copy.deepcopy(example_inputs)
+ try:
+ with fake_tensor_mode, operator_mode:
+ model_iter_fn(model_fake, example_inputs_fake, collect_outputs=False)
+ except Exception as e:
+ print(f"{name} failed to run with fake tensors, trying real. Exception: {e}")
+ operator_mode = OperatorInputsMode()
+ try:
+ with operator_mode:
+ model_iter_fn(model, example_inputs, collect_outputs=False)
+ except Exception as e2:
+ print(f"{name} failed to run with real. Exception: {e2}")
+ raise
+
+ print(f"Writing output to {output}")
+ operator_mode.log_to_file(output)
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.WARNING)
+ warnings.filterwarnings("ignore")
+ main()
diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py
new file mode 100755
index 0000000000000..87d2131087d6b
--- /dev/null
+++ b/benchmarks/dynamo/huggingface.py
@@ -0,0 +1,543 @@
+#!/usr/bin/env python3
+import importlib
+import logging
+import os
+import re
+import subprocess
+import sys
+import warnings
+
+import torch
+from common import BenchmarkRunner, main
+
+from torch._dynamo.testing import collect_results
+from torch._dynamo.utils import clone_inputs
+
+log = logging.getLogger(__name__)
+
+
+def pip_install(package):
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
+
+
+# Disable the flake warnings for the imports. Flake8 does not provide a way to
+# disable just warning for the entire file. Disabling flake8 entirely.
+# flake8: noqa
+imports = [
+ "AlbertForPreTraining",
+ "AutoConfig",
+ "AutoModelForCausalLM",
+ "AutoModelForMaskedLM",
+ "AutoModelForSeq2SeqLM",
+ "BigBirdConfig",
+ "BlenderbotForConditionalGeneration",
+ "BlenderbotModel",
+ "BlenderbotSmallForConditionalGeneration",
+ "BlenderbotSmallModel",
+ "CLIPModel",
+ "CLIPVisionModel",
+ "ElectraForPreTraining",
+ "GPT2ForSequenceClassification",
+ "GPTJForSequenceClassification",
+ "GPTNeoForSequenceClassification",
+ "HubertForSequenceClassification",
+ "LxmertForPreTraining",
+ "LxmertForQuestionAnswering",
+ "MarianForCausalLM",
+ "MarianModel",
+ "MarianMTModel",
+ "PegasusForConditionalGeneration",
+ "PegasusModel",
+ "ReformerConfig",
+ "ViTForImageClassification",
+ "ViTForMaskedImageModeling",
+ "ViTModel",
+]
+
+
+try:
+ mod = importlib.import_module("transformers")
+ for cls in imports:
+ if not hasattr(mod, cls):
+ raise ModuleNotFoundError
+except ModuleNotFoundError:
+ print("Installing HuggingFace Transformers...")
+ pip_install("git+https://github.com/huggingface/transformers.git#egg=transformers")
+finally:
+ for cls in imports:
+ exec(f"from transformers import {cls}")
+
+
+USE_HALF_BATCH_SIZE = True
+
+
+# These models contain the models present in huggingface_models_list. It is a
+# combination of models supported by HF Fx parser and some manually supplied
+# models. For these models, we already know the largest batch size that can fit
+# on A100 GPUs - 40 GB.
+BATCH_SIZE_KNOWN_MODELS = dict()
+
+
+# Get the list of models and their batch sizes
+MODELS_FILENAME = "huggingface_models_list.txt"
+if os.path.exists("benchmarks"):
+ MODELS_FILENAME = os.path.join("benchmarks", MODELS_FILENAME)
+assert os.path.exists(MODELS_FILENAME)
+with open(MODELS_FILENAME, "r") as fh:
+ lines = fh.readlines()
+ lines = [line.rstrip() for line in lines]
+ for line in lines:
+ model_name, batch_size = line.split(",")
+ batch_size = int(batch_size)
+ BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size
+assert len(BATCH_SIZE_KNOWN_MODELS)
+
+
+SKIP = {
+ # Difficult to run and compare
+ "Reformer",
+ # Fails deepcopy
+ "BlenderbotForCausalLM",
+ "BlenderbotForConditionalGeneration",
+ "GPTJForCausalLM",
+ "GPTJForQuestionAnswering",
+ "GPTNeoForCausalLM",
+ "GPTNeoForSequenceClassification",
+ # Fails with even batch size = 1
+ "DebertaV2ForMaskedLM",
+ "DebertaV2ForQuestionAnswering",
+}
+
+# TODO - Fails even after fake tensors
+USE_SMALL_BATCH_SIZE = {
+ "AlbertForMaskedLM": 2,
+ "AlbertForPreTraining": 4,
+ "AlbertForQuestionAnswering": 2,
+ "BartForCausalLM": 2,
+ "BartForConditionalGeneration": 1,
+ "BlenderbotSmallForConditionalGeneration": 32,
+ "DebertaForMaskedLM": 4,
+ "DebertaForQuestionAnswering": 4,
+ "DebertaV2ForMaskedLM": 1,
+ "DebertaV2ForQuestionAnswering": 1,
+ "DistilBertForMaskedLM": 16,
+ "ElectraForCausalLM": 1,
+ "GPTNeoForCausalLM": 1,
+ "GPTNeoForSequenceClassification": 1,
+ "M2M100ForConditionalGeneration": 2,
+ "MT5ForConditionalGeneration": 2,
+ "MegatronBertForCausalLM": 2,
+ "OPTForCausalLM": 4,
+ "PegasusForCausalLM": 8,
+ "PegasusForConditionalGeneration": 4,
+ "RobertaForCausalLM": 4,
+ "TrOCRForCausalLM": 8,
+ "XGLMForCausalLM": 1,
+ "XLNetLMHeadModel": 4,
+}
+
+
+def get_module_cls_by_model_name(model_cls_name):
+ _module_by_model_name = {
+ "Speech2Text2Decoder": "transformers.models.speech_to_text_2.modeling_speech_to_text_2",
+ "TrOCRDecoder": "transformers.models.trocr.modeling_trocr",
+ }
+ module_name = _module_by_model_name.get(model_cls_name, "transformers")
+ module = importlib.import_module(module_name)
+ return getattr(module, model_cls_name)
+
+
+def get_sequence_length(model_cls, model_name):
+ if model_name.startswith(("Bert", "Roberta", "Blenderbot")):
+ seq_length = 128
+ elif model_name.startswith(("GPT2", "Bart", "T5")):
+ seq_length = 1024
+ elif model_name in ("AllenaiLongformerBase", "BigBird"):
+ seq_length = 1024
+ elif "Reformer" in model_name:
+ seq_length = 4096
+ elif model_name.startswith(
+ ("Albert", "Deberta", "Layout", "Electra", "XLNet")
+ ) or model_name in ("DistillGPT2", "GoogleFnet", "YituTechConvBert", "CamemBert"):
+ seq_length = 512
+ else:
+ log.warning(
+ f"Sequence Length not defined for {model_name}. Choosing 128 arbitrarily"
+ )
+ seq_length = 128
+ return seq_length
+
+
+def generate_inputs_for_model(
+ model_cls, model, model_name, bs, device, include_loss_args=False
+):
+ # TODO - Check if following values are representative
+ num_choices = 3
+ num_visual_features = 42
+ seq_length = get_sequence_length(model_cls, model_name)
+ vocab_size = model.config.vocab_size
+ if model_name.endswith("MultipleChoice"):
+ input = rand_int_tensor(device, 0, vocab_size, (bs, num_choices, seq_length))
+ elif model_name.startswith("Roberta"):
+ input = rand_int_tensor(device, 0, 1, (bs, seq_length))
+ else:
+ input = rand_int_tensor(device, 0, vocab_size, (bs, seq_length))
+
+ if "Bart" in model_name:
+ input[:, -1] = model.config.eos_token_id
+
+ input_dict = {"input_ids": input}
+
+ if (
+ model_name.startswith("T5")
+ or model_name.startswith("M2M100")
+ or model_name.startswith("MT5")
+ or model_cls
+ in [
+ BlenderbotModel,
+ BlenderbotSmallModel,
+ BlenderbotForConditionalGeneration,
+ BlenderbotSmallForConditionalGeneration,
+ PegasusModel,
+ PegasusForConditionalGeneration,
+ MarianModel,
+ MarianMTModel,
+ ]
+ ):
+ input_dict["decoder_input_ids"] = input
+
+ if model_name.startswith("Lxmert"):
+ visual_feat_dim, visual_pos_dim = (
+ model.config.visual_feat_dim,
+ model.config.visual_pos_dim,
+ )
+ input_dict["visual_feats"] = torch.randn(
+ bs, num_visual_features, visual_feat_dim
+ )
+ input_dict["visual_pos"] = torch.randn(bs, num_visual_features, visual_pos_dim)
+
+ if include_loss_args:
+ if model_name.endswith("PreTraining"):
+ if model_cls in [ElectraForPreTraining, LxmertForPreTraining]:
+ input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs, seq_length))
+ else:
+ label_name = (
+ "sentence_order_label"
+ if model_cls in [AlbertForPreTraining]
+ else "next_sentence_label"
+ )
+ input_dict["labels"] = (
+ rand_int_tensor(device, 0, vocab_size, (bs, seq_length)),
+ )
+ input_dict[label_name] = rand_int_tensor(device, 0, 1, (bs,))
+ elif model_name.endswith("QuestionAnswering"):
+ input_dict["start_positions"] = rand_int_tensor(
+ device, 0, seq_length, (bs,)
+ )
+ input_dict["end_positions"] = rand_int_tensor(device, 0, seq_length, (bs,))
+ elif (
+ model_name.endswith("MaskedLM")
+ or model_name.endswith("HeadModel")
+ or model_name.endswith("CausalLM")
+ or model_name.endswith("DoubleHeadsModel")
+ ):
+ input_dict["labels"] = rand_int_tensor(
+ device, 0, vocab_size, (bs, seq_length)
+ )
+ elif model_name.endswith("TokenClassification"):
+ input_dict["labels"] = rand_int_tensor(
+ device, 0, model.config.num_labels - 1, (bs, seq_length)
+ )
+ elif model_name.endswith("MultipleChoice"):
+ input_dict["labels"] = rand_int_tensor(device, 0, num_choices, (bs,))
+ elif model_name.endswith("SequenceClassification"):
+ input_dict["labels"] = rand_int_tensor(
+ device, 0, model.config.num_labels - 1, (bs,)
+ )
+ elif model_name.endswith("NextSentencePrediction"):
+ input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs,))
+ elif model_name.endswith("ForConditionalGeneration"):
+ input_dict["labels"] = rand_int_tensor(
+ device, 0, vocab_size - 1, (bs, seq_length)
+ )
+ elif model_name in EXTRA_MODELS:
+ input_dict["labels"] = rand_int_tensor(
+ device, 0, vocab_size, (bs, seq_length)
+ )
+ else:
+ raise NotImplementedError(
+ f"Class {model_name} unsupported for training test "
+ )
+
+ return input_dict
+
+
+def rand_int_tensor(device, low, high, shape):
+ return torch.randint(
+ low,
+ high,
+ shape,
+ device=device,
+ dtype=torch.int64,
+ requires_grad=False,
+ )
+
+
+EXTRA_MODELS = {
+ "AllenaiLongformerBase": (
+ AutoConfig.from_pretrained("allenai/longformer-base-4096"),
+ AutoModelForMaskedLM,
+ ),
+ "Reformer": (
+ ReformerConfig(),
+ AutoModelForMaskedLM,
+ ),
+ "T5Small": (
+ AutoConfig.from_pretrained("t5-small"),
+ AutoModelForSeq2SeqLM,
+ ),
+ "BigBird": (
+ BigBirdConfig(attention_type="block_sparse"),
+ AutoModelForMaskedLM,
+ ),
+ "DistillGPT2": (
+ AutoConfig.from_pretrained("distilgpt2"),
+ AutoModelForCausalLM,
+ ),
+ "GoogleFnet": (
+ AutoConfig.from_pretrained("google/fnet-base"),
+ AutoModelForMaskedLM,
+ ),
+ "YituTechConvBert": (
+ AutoConfig.from_pretrained("YituTech/conv-bert-base"),
+ AutoModelForMaskedLM,
+ ),
+ "CamemBert": (
+ AutoConfig.from_pretrained("camembert-base"),
+ AutoModelForMaskedLM,
+ ),
+}
+
+
+class HuggingfaceRunner(BenchmarkRunner):
+ def __init__(self):
+ super(HuggingfaceRunner, self).__init__()
+ self.suite_name = "huggingface"
+
+ def load_model(
+ self,
+ device,
+ model_name,
+ batch_size=None,
+ ):
+
+ is_training = self.args.training
+ use_eval_mode = self.args.use_eval_mode
+ dtype = torch.float32
+ if model_name not in EXTRA_MODELS:
+ model_cls = get_module_cls_by_model_name(model_name)
+ config_cls = model_cls.config_class
+ config = config_cls()
+
+ # NB: some models need a pad token defined to handle BS > 1
+ if (
+ model_cls
+ in [
+ GPT2ForSequenceClassification,
+ GPTNeoForSequenceClassification,
+ GPTJForSequenceClassification,
+ ]
+ or model_cls.__name__.startswith("Roberta")
+ or model_cls.__name__.startswith("Marian")
+ ):
+ config.pad_token_id = 0
+
+ else:
+ config, model_cls = EXTRA_MODELS[model_name]
+
+ if "auto" in model_cls.__module__:
+ # Handle auto classes
+ model = model_cls.from_config(config).to(device, dtype=dtype)
+ else:
+ model = model_cls(config).to(device, dtype=dtype)
+
+ if model_name in BATCH_SIZE_KNOWN_MODELS:
+ batch_size_default = BATCH_SIZE_KNOWN_MODELS[model_name]
+ elif batch_size is None:
+ batch_size_default = 16
+ log.warning(
+ "Batch size not specified for {model_name}. Setting batch_size=16"
+ )
+
+ if batch_size is None:
+ batch_size = batch_size_default
+ if model_name in USE_SMALL_BATCH_SIZE:
+ batch_size = USE_SMALL_BATCH_SIZE[model_name]
+ log.warning(
+ f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"
+ )
+ elif USE_HALF_BATCH_SIZE and batch_size >= 2:
+ batch_size = int(batch_size / 2)
+ log.warning(
+ f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"
+ )
+
+ example_inputs = generate_inputs_for_model(
+ model_cls, model, model_name, batch_size, device, include_loss_args=True
+ )
+
+ # So we can check for correct gradients without eliminating the dropout computation
+ for attr in dir(config):
+ if "drop" in attr and isinstance(getattr(config, attr), float):
+ setattr(config, attr, 1e-30)
+
+ if is_training and not use_eval_mode:
+ model.train()
+ else:
+ model.eval()
+
+ self.init_optimizer(device, model.parameters())
+
+ self.validate_model(model, example_inputs)
+ return device, model_name, model, example_inputs, batch_size
+
+ def iter_model_names(self, args):
+ model_names = list(BATCH_SIZE_KNOWN_MODELS.keys()) + list(EXTRA_MODELS.keys())
+ model_names = set(model_names)
+ model_names = sorted(model_names)
+
+ start, end = self.get_benchmark_indices(len(model_names))
+ for index, model_name in enumerate(model_names):
+ if index < start or index >= end:
+ continue
+ if (
+ not re.search("|".join(args.filter), model_name, re.I)
+ or re.search("|".join(args.exclude), model_name, re.I)
+ or model_name in SKIP
+ ):
+ continue
+ yield model_name
+
+ def pick_grad(self, name, is_training):
+ if is_training:
+ return torch.enable_grad()
+ else:
+ return torch.no_grad()
+
+ def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
+ cosine = self.args.cosine
+ if is_training:
+ return 1e-2, cosine
+ return 1e-3, cosine
+
+ def compute_loss(self, pred):
+ return pred[0]
+
+ def forward_pass(self, mod, inputs, collect_outputs=True):
+ return mod(**inputs)
+
+ def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
+ cloned_inputs = clone_inputs(inputs)
+ mod.zero_grad(True)
+ with self.autocast():
+ pred = mod(**cloned_inputs)
+ loss = self.compute_loss(pred)
+ self.grad_scaler.scale(loss).backward()
+ self.optimizer_step()
+ if collect_outputs:
+ return collect_results(mod, pred, loss, cloned_inputs)
+ return None
+
+
+def refresh_model_names_and_batch_sizes():
+ """
+ This function reads the HF Fx tracer supported models and finds the largest
+ batch size that could fit on the GPU with PyTorch eager.
+
+ The resulting data is written in huggingface_models_list.txt.
+
+ Note - We only need to run this function if we believe that HF Fx tracer now
+ supports more models.
+ """
+ import transformers.utils.fx as hf_fx
+
+ family = dict()
+ lm_seen = set()
+ family_seen = set()
+ for cls_name in hf_fx._SUPPORTED_MODELS:
+
+ if "For" not in cls_name:
+ continue
+
+ model_cls = get_module_cls_by_model_name(cls_name)
+
+ # TODO: AttributeError: '*Config' object has no attribute 'vocab_size'
+ if model_cls in [
+ CLIPModel,
+ CLIPVisionModel,
+ SwinForImageClassification,
+ SwinForImageClassification,
+ SwinForMaskedImageModeling,
+ SwinModel,
+ ViTForImageClassification,
+ ViTForMaskedImageModeling,
+ ViTModel,
+ ]:
+ continue
+
+ # TODO: AssertionError: Padding_idx must be within num_embeddings
+ if model_cls in [MarianForCausalLM, MarianMTModel, MarianModel]:
+ continue
+
+ # TODO: "model is not supported yet" from HFTracer
+ if model_cls in [HubertForSequenceClassification]:
+ continue
+
+ # TODO: shape mismatch in loss calculation
+ if model_cls in [LxmertForQuestionAnswering]:
+ continue
+
+ family_name = cls_name.split("For")[0]
+ if family_name not in family:
+ family[family_name] = []
+ if cls_name.endswith(("MaskedLM", "CausalLM")) and family_name not in lm_seen:
+ family[family_name].append(cls_name)
+ lm_seen.add(family_name)
+ elif (
+ cls_name.endswith(
+ ("SequenceClassification", "ConditionalGeneration", "QuestionAnswering")
+ )
+ and family_name not in family_seen
+ ):
+ family[family_name].append(cls_name)
+ family_seen.add(family_name)
+ elif cls_name.endswith("ImageClassification"):
+ family[family_name].append(cls_name)
+
+ chosen_models = set()
+ for members in family.values():
+ chosen_models.update(set(members))
+
+ # Add the EXTRA_MODELS
+ chosen_models.update(set(EXTRA_MODELS.keys()))
+
+ for model_name in sorted(chosen_models):
+ try:
+ subprocess.check_call(
+ [sys.executable]
+ + sys.argv
+ + ["--find-batch-sizes"]
+ + [f"--only={model_name}"]
+ + [f"--output={MODELS_FILENAME}"]
+ )
+ except subprocess.SubprocessError:
+ log.warning(f"Failed to find suitable batch size for {model_name}")
+
+
+if __name__ == "__main__":
+ # Code to refresh model names and batch sizes
+ # if "--find-batch-sizes" not in sys.argv:
+ # refresh_model_names_and_batch_sizes()
+ logging.basicConfig(level=logging.WARNING)
+ warnings.filterwarnings("ignore")
+ main(HuggingfaceRunner())
diff --git a/benchmarks/dynamo/huggingface_models_list.txt b/benchmarks/dynamo/huggingface_models_list.txt
new file mode 100644
index 0000000000000..8272c79b12bda
--- /dev/null
+++ b/benchmarks/dynamo/huggingface_models_list.txt
@@ -0,0 +1,53 @@
+AlbertForMaskedLM,8
+AlbertForQuestionAnswering,8
+AllenaiLongformerBase,1
+BartForCausalLM,16
+BartForConditionalGeneration,4
+BertForMaskedLM,128
+BertForQuestionAnswering,128
+BigBird,1
+BlenderbotForCausalLM,32
+BlenderbotForConditionalGeneration,32
+BlenderbotSmallForCausalLM,128
+BlenderbotSmallForConditionalGeneration,128
+CamemBert,1
+DebertaForMaskedLM,32
+DebertaForQuestionAnswering,32
+DebertaV2ForMaskedLM,8
+DebertaV2ForQuestionAnswering,8
+DistilBertForMaskedLM,64
+DistilBertForQuestionAnswering,64
+DistillGPT2,1
+ElectraForCausalLM,64
+ElectraForQuestionAnswering,128
+GPT2ForSequenceClassification,8
+GPTJForCausalLM,1
+GPTJForQuestionAnswering,1
+GPTNeoForCausalLM,8
+GPTNeoForSequenceClassification,8
+GoogleFnet,1
+LayoutLMForMaskedLM,32
+LayoutLMForSequenceClassification,32
+M2M100ForConditionalGeneration,8
+MBartForCausalLM,32
+MBartForConditionalGeneration,16
+MT5ForConditionalGeneration,8
+MegatronBertForCausalLM,16
+MegatronBertForQuestionAnswering,16
+MobileBertForMaskedLM,32
+MobileBertForQuestionAnswering,64
+OPTForCausalLM,32
+PLBartForCausalLM,32
+PLBartForConditionalGeneration,16
+PegasusForCausalLM,32
+PegasusForConditionalGeneration,16
+Reformer,1
+RobertaForCausalLM,128
+RobertaForQuestionAnswering,128
+Speech2Text2ForCausalLM,128
+T5ForConditionalGeneration,8
+T5Small,1
+TrOCRForCausalLM,32
+XGLMForCausalLM,8
+XLNetLMHeadModel,128
+YituTechConvBert,1
diff --git a/benchmarks/dynamo/microbenchmarks/__init__.py b/benchmarks/dynamo/microbenchmarks/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/benchmarks/dynamo/microbenchmarks/bench_autotune_conv.py b/benchmarks/dynamo/microbenchmarks/bench_autotune_conv.py
new file mode 100644
index 0000000000000..ca8aeca85a284
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/bench_autotune_conv.py
@@ -0,0 +1,170 @@
+import model
+import torch
+
+import torch._dynamo
+import torch._inductor
+import torch._inductor.config as config
+import torch._inductor.triton_ops
+import triton
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to True.
+torch.backends.cuda.matmul.allow_tf32 = True
+# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+torch.backends.cudnn.allow_tf32 = True
+# config.debug = True
+config.triton.convolution = "autotune"
+
+
+# conv benchmarks
+conv_confs = [
+ triton.testing.Benchmark(
+ x_names=["layout"],
+ x_vals=["nchw", "nhwc"],
+ line_arg="provider",
+ line_vals=["aten", "autotune", "triton_conv", "triton_conv1x1"],
+ line_names=["aten", "autotune", "triton_conv", "triton_conv1x1"],
+ ylabel="TFLOPS",
+ plot_name=f"resnet50-conv{i}-perf",
+ args={
+ "BATCH": BATCH,
+ "IN_H": IN_H,
+ "IN_W": IN_W,
+ "IN_C": IN_C,
+ "KERNEL_N": KERNEL_N,
+ "KERNEL_H": KERNEL_H,
+ "KERNEL_W": KERNEL_W,
+ "stride": stride,
+ "padding": padding,
+ },
+ )
+ for i, (
+ IN_H,
+ IN_W,
+ IN_C,
+ KERNEL_H,
+ KERNEL_W,
+ KERNEL_N,
+ stride,
+ padding,
+ ) in enumerate(model.resnet50_layers)
+ for BATCH in [32]
+]
+
+
+@triton.testing.perf_report(conv_confs)
+def bench_op(
+ # Tensor dimensions
+ BATCH,
+ IN_C,
+ IN_H,
+ IN_W,
+ KERNEL_N,
+ KERNEL_H,
+ KERNEL_W,
+ # provider
+ provider,
+ # parameters of conv
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ groups=1,
+ dtype=torch.float32,
+ layout="nhwc",
+ warmup=25,
+ rep=75,
+):
+
+ skip = False
+ # allocate inputs, nchw
+ x = torch.randn((BATCH, IN_C, IN_H, IN_W), dtype=dtype, device="cuda")
+ w = torch.randn(
+ (KERNEL_N, IN_C // groups, KERNEL_H, KERNEL_W), dtype=dtype, device="cuda"
+ )
+ bias = torch.randn((KERNEL_N), dtype=dtype, device="cuda")
+ if layout == "nhwc":
+ x = x.to(memory_format=torch.channels_last)
+ w = w.to(memory_format=torch.channels_last)
+ OUT_H = (
+ IN_H + 2 * padding[0] - dilation[0] * (KERNEL_H - 1) - 1 + stride[0]
+ ) // stride[0]
+ OUT_W = (
+ IN_W + 2 * padding[1] - dilation[1] * (KERNEL_W - 1) - 1 + stride[1]
+ ) // stride[1]
+
+ tflops = (
+ lambda ms: 2.0
+ * BATCH
+ * OUT_H
+ * OUT_W
+ * IN_C
+ * KERNEL_H
+ * KERNEL_W
+ * KERNEL_N
+ / ms
+ * 1e-9
+ )
+ if provider == "aten":
+
+ def fn():
+ return torch.conv2d(x, w, bias, stride, padding, dilation, groups)
+
+ elif provider == "triton_conv":
+
+ def fn():
+ return torch._inductor.triton_ops.conv(
+ x, w, bias, stride, padding, dilation, False, (0, 0), groups
+ )
+
+ elif provider == "triton_conv1x1":
+
+ def fn():
+ return torch._inductor.triton_ops.conv1x1(
+ x, w, bias, stride, padding, dilation, False, (0, 0), groups
+ )
+
+ if KERNEL_H != 1 or KERNEL_W != 1:
+ skip = True
+
+ elif provider == "autotune":
+
+ @torch._dynamo.optimize("inductor")
+ def wrap_conv(*args, **kwargs):
+ return torch.conv2d(*args, **kwargs)
+
+ def fn():
+ return wrap_conv(x, w, bias, stride, padding, dilation, groups)
+
+ # use cuda graph for fair comparison
+ elif provider != "autotune" and not skip:
+ # prepare new tensor
+ new_x = x.clone()
+ new_w = w.clone()
+ new_bias = bias.clone()
+
+ # warmp up for cudagraph
+ s = torch.cuda.Stream()
+ s.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(s):
+ for i in range(3):
+ fn()
+ torch.cuda.current_stream().wait_stream(s)
+
+ # capture
+ g = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(g):
+ fn()
+
+ def fn():
+ x.copy_(new_x)
+ w.copy_(new_w)
+ bias.copy_(new_bias)
+ return g.replay()
+
+ if not skip:
+ ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+ return tflops(ms), tflops(max_ms), tflops(min_ms)
+ else:
+ return 0, 0, 0
+
+
+bench_op.run(print_data=True)
diff --git a/benchmarks/dynamo/microbenchmarks/bench_conv.py b/benchmarks/dynamo/microbenchmarks/bench_conv.py
new file mode 100644
index 0000000000000..6279af6854a1b
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/bench_conv.py
@@ -0,0 +1,144 @@
+import model
+import torch
+
+import torch._inductor.triton_ops
+import triton
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to True.
+torch.backends.cuda.matmul.allow_tf32 = True
+# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+torch.backends.cudnn.allow_tf32 = True
+
+# https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/
+useCudaGraph = False
+
+# conv benchmarks
+conv_confs = [
+ triton.testing.Benchmark(
+ x_names=["layout"],
+ x_vals=["nchw", "nhwc"],
+ line_arg="provider",
+ line_vals=["cublas", "triton"],
+ line_names=["cuBLAS", "Triton"],
+ ylabel="TFLOPS",
+ plot_name=f"resnet50-conv{i}-perf",
+ args={
+ "BATCH": BATCH,
+ "IN_H": IN_H,
+ "IN_W": IN_W,
+ "IN_C": IN_C,
+ "KERNEL_N": KERNEL_N,
+ "KERNEL_H": KERNEL_H,
+ "KERNEL_W": KERNEL_W,
+ "stride": stride,
+ "padding": padding,
+ },
+ )
+ for i, (
+ IN_H,
+ IN_W,
+ IN_C,
+ KERNEL_H,
+ KERNEL_W,
+ KERNEL_N,
+ stride,
+ padding,
+ ) in enumerate(model.resnet50_layers)
+ for BATCH in [32]
+]
+
+
+@triton.testing.perf_report(conv_confs)
+def bench_op(
+ # Tensor dimensions
+ BATCH,
+ IN_C,
+ IN_H,
+ IN_W,
+ KERNEL_N,
+ KERNEL_H,
+ KERNEL_W,
+ # provider
+ provider,
+ # parameters of conv
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ groups=1,
+ dtype=torch.float32,
+ layout="nhwc",
+ warmup=25,
+ rep=75,
+):
+
+ # allocate inputs, nchw
+ x = torch.randn((BATCH, IN_C, IN_H, IN_W), dtype=dtype, device="cuda")
+ w = torch.randn(
+ (KERNEL_N, IN_C // groups, KERNEL_H, KERNEL_W), dtype=dtype, device="cuda"
+ )
+ bias = torch.randn((KERNEL_N), dtype=dtype, device="cuda")
+ if layout == "nhwc":
+ x = x.to(memory_format=torch.channels_last)
+ w = w.to(memory_format=torch.channels_last)
+ OUT_H = (
+ IN_H + 2 * padding[0] - dilation[0] * (KERNEL_H - 1) - 1 + stride[0]
+ ) // stride[0]
+ OUT_W = (
+ IN_W + 2 * padding[1] - dilation[1] * (KERNEL_W - 1) - 1 + stride[1]
+ ) // stride[1]
+
+ tflops = (
+ lambda ms: 2.0
+ * BATCH
+ * OUT_H
+ * OUT_W
+ * IN_C
+ * KERNEL_H
+ * KERNEL_W
+ * KERNEL_N
+ / ms
+ * 1e-9
+ )
+ if provider == "cublas":
+
+ def fn():
+ return torch.conv2d(x, w, bias, stride, padding, dilation, groups)
+
+ elif provider == "triton":
+
+ def fn():
+ return torch._inductor.triton_ops.conv(
+ x, w, bias, stride, padding, dilation, False, (0, 0), groups
+ )
+
+ # useCudaGraph won't change the TFLOPs,
+ # because do_bench() clear L2 cache to hide the latency of CPU launch time
+ if useCudaGraph:
+ new_x = x.clone()
+ new_w = w.clone()
+ new_bias = bias.clone()
+
+ # warmp up for cudagraph
+ s = torch.cuda.Stream()
+ s.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(s):
+ for i in range(3):
+ fn()
+ torch.cuda.current_stream().wait_stream(s)
+
+ # capture
+ g = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(g):
+ fn()
+
+ def fn():
+ x.copy_(new_x)
+ w.copy_(new_w)
+ bias.copy_(new_bias)
+ return g.replay()
+
+ ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+ return tflops(ms), tflops(max_ms), tflops(min_ms)
+
+
+bench_op.run(print_data=True)
diff --git a/benchmarks/dynamo/microbenchmarks/bench_conv1x1.py b/benchmarks/dynamo/microbenchmarks/bench_conv1x1.py
new file mode 100644
index 0000000000000..bb70aed272065
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/bench_conv1x1.py
@@ -0,0 +1,140 @@
+import model
+import torch
+
+import torch._inductor.triton_ops
+import triton
+
+# https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/
+useCudaGraph = False
+
+# conv benchmarks
+conv_confs = [
+ triton.testing.Benchmark(
+ x_names=["layout"],
+ x_vals=["nchw", "nhwc"],
+ line_arg="provider",
+ line_vals=["cublas", "triton"],
+ line_names=["cuBLAS", "Triton"],
+ ylabel="TFLOPS",
+ plot_name=f"resnet50-conv1x1-{i}-performance",
+ args={
+ "BATCH": BATCH,
+ "IN_H": IN_H,
+ "IN_W": IN_W,
+ "IN_C": IN_C,
+ "KERNEL_N": KERNEL_N,
+ "KERNEL_H": KERNEL_H,
+ "KERNEL_W": KERNEL_W,
+ "stride": stride,
+ "padding": padding,
+ },
+ )
+ for i, (
+ IN_H,
+ IN_W,
+ IN_C,
+ KERNEL_H,
+ KERNEL_W,
+ KERNEL_N,
+ stride,
+ padding,
+ ) in enumerate(model.resnet50_layers)
+ if KERNEL_H == 1 and KERNEL_W == 1
+ for BATCH in [32]
+]
+
+
+@triton.testing.perf_report(conv_confs)
+def bench_op(
+ # Tensor dimensions
+ BATCH,
+ IN_C,
+ IN_H,
+ IN_W,
+ KERNEL_N,
+ KERNEL_H,
+ KERNEL_W,
+ # provider
+ provider,
+ # parameters of conv
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ groups=1,
+ dtype=torch.float32,
+ layout="nhwc",
+ warmup=25,
+ rep=75,
+):
+
+ # allocate inputs, nchw
+ x = torch.randn((BATCH, IN_C, IN_H, IN_W), dtype=dtype, device="cuda")
+ w = torch.randn(
+ (KERNEL_N, IN_C // groups, KERNEL_H, KERNEL_W), dtype=dtype, device="cuda"
+ )
+ bias = torch.randn((KERNEL_N), dtype=dtype, device="cuda")
+ if layout == "nhwc":
+ x = x.to(memory_format=torch.channels_last)
+ w = w.to(memory_format=torch.channels_last)
+ OUT_H = (
+ IN_H + 2 * padding[0] - dilation[0] * (KERNEL_H - 1) - 1 + stride[0]
+ ) // stride[0]
+ OUT_W = (
+ IN_W + 2 * padding[1] - dilation[1] * (KERNEL_W - 1) - 1 + stride[1]
+ ) // stride[1]
+
+ tflops = (
+ lambda ms: 2.0
+ * BATCH
+ * OUT_H
+ * OUT_W
+ * IN_C
+ * KERNEL_H
+ * KERNEL_W
+ * KERNEL_N
+ / ms
+ * 1e-9
+ )
+
+ if provider == "cublas":
+
+ def fn():
+ return torch.conv2d(x, w, bias, stride, padding, dilation, groups)
+
+ elif provider == "triton":
+
+ def fn():
+ return torch._inductor.triton_ops.conv1x1(
+ x, w, bias, stride, padding, dilation, False, (0, 0), groups
+ )
+
+ if useCudaGraph:
+ # prepare new data
+ new_x = x.clone()
+ new_w = w.clone()
+ new_bias = bias.clone()
+
+ # warmp up for cudagraph
+ s = torch.cuda.Stream()
+ s.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(s):
+ for i in range(3):
+ fn()
+ torch.cuda.current_stream().wait_stream(s)
+
+ # capture
+ g = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(g):
+ fn()
+
+ def fn():
+ x.copy_(new_x)
+ w.copy_(new_w)
+ bias.copy_(new_bias)
+ return g.replay()
+
+ ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+ return tflops(ms), tflops(max_ms), tflops(min_ms)
+
+
+bench_op.run(print_data=True)
diff --git a/benchmarks/dynamo/microbenchmarks/bench_conv_fusion.py b/benchmarks/dynamo/microbenchmarks/bench_conv_fusion.py
new file mode 100644
index 0000000000000..d36c37c5a204c
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/bench_conv_fusion.py
@@ -0,0 +1,298 @@
+# flake8: noqa
+import model
+import torch
+
+import torch._dynamo
+import torch._inductor.config
+import triton
+from prettytable import PrettyTable
+
+# torch._inductor.config.debug = True
+torch._inductor.config.triton.convolution = "triton"
+torch._inductor.config.triton.dense_indexing = True
+torch.manual_seed(0)
+useCudaGraph = True
+
+
+class Func(object):
+ # conv
+ @torch._dynamo.optimize("inductor")
+ def conv_torchinductor(x, w, bias, stride, padding, dilation, groups):
+ y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
+ return y
+
+ # conv
+ def conv(x, w, bias, stride, padding, dilation, groups):
+ y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
+ return y
+
+ # conv+bias
+ @torch._dynamo.optimize("inductor")
+ def conv_add_torchinductor(x, w, bias, stride, padding, dilation, groups):
+ y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
+ return y
+
+ # conv+bias
+ def conv_add(x, w, bias, stride, padding, dilation, groups):
+ y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
+ return y
+
+ # relu(conv)
+ @torch._dynamo.optimize("inductor")
+ def conv_relu_torchinductor(x, w, bias, stride, padding, dilation, groups):
+ y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
+ return torch.relu(y)
+
+ # relu(conv)
+ def conv_relu(x, w, bias, stride, padding, dilation, groups):
+ y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
+ return torch.relu(y)
+
+ # relu(conv+bias)
+ @torch._dynamo.optimize("inductor")
+ def conv_add_relu_torchinductor(x, w, bias, stride, padding, dilation, groups):
+ y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
+ return torch.relu(y)
+
+ # relu(conv+bias)
+ def conv_add_relu(x, w, bias, stride, padding, dilation, groups):
+ y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
+ return torch.relu(y)
+
+ # bn(conv)
+ @torch._dynamo.optimize("inductor")
+ def conv_bn_torchinductor(
+ x,
+ w,
+ bias,
+ stride,
+ padding,
+ dilation,
+ groups,
+ running_mean,
+ running_var,
+ bn_weight,
+ bn_bias,
+ ):
+ y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
+ y = torch.batch_norm(
+ y,
+ weight=bn_weight,
+ bias=bn_bias,
+ running_mean=running_mean,
+ running_var=running_var,
+ training=False,
+ momentum=1,
+ eps=1e-5,
+ cudnn_enabled=True,
+ )
+ return y
+
+ # bn(conv)
+ def conv_bn(
+ x,
+ w,
+ bias,
+ stride,
+ padding,
+ dilation,
+ groups,
+ running_mean,
+ running_var,
+ bn_weight,
+ bn_bias,
+ ):
+ y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
+ y = torch.batch_norm(
+ y,
+ weight=bn_weight,
+ bias=bn_bias,
+ running_mean=running_mean,
+ running_var=running_var,
+ training=False,
+ momentum=1,
+ eps=1e-5,
+ cudnn_enabled=True,
+ )
+ return y
+
+ # relu(bn(conv))
+ @torch._dynamo.optimize("inductor")
+ def conv_bn_relu_torchinductor(
+ x,
+ w,
+ bias,
+ stride,
+ padding,
+ dilation,
+ groups,
+ running_mean,
+ running_var,
+ bn_weight,
+ bn_bias,
+ ):
+ y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
+ y = torch.batch_norm(
+ y,
+ weight=bn_weight,
+ bias=bn_bias,
+ running_mean=running_mean,
+ running_var=running_var,
+ training=False,
+ momentum=1,
+ eps=1e-5,
+ cudnn_enabled=True,
+ )
+ return torch.relu(y)
+
+ # relu(bn(conv))
+ def conv_bn_relu(
+ x,
+ w,
+ bias,
+ stride,
+ padding,
+ dilation,
+ groups,
+ running_mean,
+ running_var,
+ bn_weight,
+ bn_bias,
+ ):
+ y = torch.conv2d(x, w, None, stride, padding, dilation, groups)
+ y = torch.batch_norm(
+ y,
+ weight=bn_weight,
+ bias=bn_bias,
+ running_mean=running_mean,
+ running_var=running_var,
+ training=False,
+ momentum=1,
+ eps=1e-5,
+ cudnn_enabled=True,
+ )
+ return torch.relu(y)
+
+
+def cuda_graph(fn, x, w, bias):
+ new_x = x.clone()
+ new_w = w.clone()
+ if bias is not None:
+ new_bias = bias.clone()
+
+ # warmp up for cudagraph
+ s = torch.cuda.Stream()
+ s.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(s):
+ for i in range(3):
+ fn()
+ torch.cuda.current_stream().wait_stream(s)
+
+ # capture
+ g = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(g):
+ fn()
+
+ def fn():
+ x.copy_(new_x)
+ w.copy_(new_w)
+ if bias is not None:
+ bias.copy_(new_bias)
+ return g.replay()
+
+ return fn
+
+
+def bench(layer_params, layer_id, p, fusion_types=[""]):
+ BATCH = 32
+ IN_H, IN_W, IN_C, KERNEL_H, KERNEL_W, KERNEL_N, stride, padding = layer_params
+ dilation, groups = (1, 1), 1
+ dtype = torch.float32
+
+ OUT_H = (
+ IN_H + 2 * padding[0] - dilation[0] * (KERNEL_H - 1) - 1 + stride[0]
+ ) // stride[0]
+ OUT_W = (
+ IN_W + 2 * padding[1] - dilation[1] * (KERNEL_W - 1) - 1 + stride[1]
+ ) // stride[1]
+ tflops = (
+ lambda ms: 2.0
+ * BATCH
+ * OUT_H
+ * OUT_W
+ * IN_C
+ * KERNEL_H
+ * KERNEL_W
+ * KERNEL_N
+ / ms
+ * 1e-9
+ )
+
+ # allocate inputs, nchw
+ x = torch.randn((BATCH, IN_C, IN_H, IN_W), dtype=dtype, device="cuda")
+ w = torch.randn(
+ (KERNEL_N, IN_C // groups, KERNEL_H, KERNEL_W), dtype=dtype, device="cuda"
+ )
+
+ row = [layer_id]
+ for fusion_type in fusion_types:
+
+ if fusion_type == "":
+ conv_torchinductor = getattr(Func, "conv_torchinductor")
+ conv = getattr(Func, "conv")
+ else:
+ conv_torchinductor = getattr(Func, f"conv_{fusion_type}_torchinductor")
+ conv = getattr(Func, f"conv_{fusion_type}")
+
+ if "add" in fusion_type:
+ bias = torch.randn((KERNEL_N,), dtype=dtype, device="cuda")
+ else:
+ bias = None
+
+ args = (x, w, bias, stride, padding, dilation, groups)
+
+ if "bn" in fusion_type:
+ running_mean = torch.randn((KERNEL_N), dtype=dtype, device="cuda")
+ running_var = torch.randn((KERNEL_N), dtype=dtype, device="cuda")
+ bn_weight = torch.randn((KERNEL_N), dtype=dtype, device="cuda")
+ bn_bias = torch.randn((KERNEL_N), dtype=dtype, device="cuda")
+ args += (
+ running_mean,
+ running_var,
+ bn_weight,
+ bn_bias,
+ )
+
+ def fn_conv():
+ return conv(*args)
+
+ def fn_conv_torchinductor():
+ return conv_torchinductor(*args)
+
+ if useCudaGraph:
+ fn_conv = cuda_graph(fn_conv, x, w, bias)
+
+ torch_conv_ms, _, _ = triton.testing.do_bench(fn_conv)
+ triton_conv_ms, _, _ = triton.testing.do_bench(fn_conv_torchinductor)
+ row.extend([tflops(torch_conv_ms), tflops(triton_conv_ms)])
+
+ p.add_row(row)
+
+
+fusion_types = ["", "add", "relu", "add_relu", "bn", "bn_relu"]
+p = PrettyTable()
+field_names = ["layer"]
+for fusion_type in fusion_types:
+ if fusion_type == "":
+ field_names.append("torch conv")
+ field_names.append("triton conv")
+ else:
+ field_names.append(f"torch conv+{fusion_type}")
+ field_names.append(f"triton conv+{fusion_type}")
+
+p.field_names = field_names
+p.float_format = ".3"
+for id, layer in enumerate(model.resnet50_layers):
+ bench(layer, id, p, fusion_types)
+
+print(p)
diff --git a/benchmarks/dynamo/microbenchmarks/bench_mm_fusion.py b/benchmarks/dynamo/microbenchmarks/bench_mm_fusion.py
new file mode 100644
index 0000000000000..eb7ce72aea35f
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/bench_mm_fusion.py
@@ -0,0 +1,121 @@
+# flake8: noqa
+import torch
+
+import torch._dynamo
+import torch._inductor.config
+import triton
+from prettytable import PrettyTable
+
+# torch._inductor.config.debug = True
+torch._inductor.config.triton.dense_indexing = True
+torch.manual_seed(0)
+
+
+# The flag below controls whether to allow TF32 on matmul.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
+class Func(object):
+ # mm
+ @torch._dynamo.optimize("inductor")
+ def mm(a, b, bias):
+ y = torch.mm(a, b)
+ return y
+
+ # mm+bias
+ @torch._dynamo.optimize("inductor")
+ def mm_add(a, b, bias):
+ y = torch.mm(a, b)
+ return y + bias
+
+ # relu(mm)
+ @torch._dynamo.optimize("inductor")
+ def mm_relu(a, b, bias):
+ y = torch.mm(a, b)
+ return torch.relu(y)
+
+ # relu(mm+bias)
+ @torch._dynamo.optimize("inductor")
+ def mm_add_relu(a, b, bias):
+ y = torch.mm(a, b)
+ y += bias
+ return torch.relu(y)
+
+
+def bench(shape, layer_id, p, fusion_types=[""]):
+ dtype = torch.float16
+ M, K = shape[0]
+ _, N = shape[1]
+ torch.manual_seed(0)
+ # allocate inputs
+ a = torch.randn(shape[0], device="cuda", dtype=dtype)
+ b = torch.randn(shape[1], device="cuda", dtype=dtype)
+
+ def tflops(ms):
+ return M * K * N / ms * 1e-9
+
+ row = [layer_id]
+ for fusion_type in fusion_types:
+
+ if fusion_type == "":
+ fn_mm = getattr(Func, "mm")
+ else:
+ fn_mm = getattr(Func, f"mm_{fusion_type}")
+
+ if "add" in fusion_type:
+ bias = torch.randn((M, N), dtype=dtype, device="cuda")
+ else:
+ bias = None
+
+ args = (a, b, bias)
+
+ def fn():
+ return fn_mm(*args)
+
+ torch._inductor.config.triton.mm = "aten"
+ torch_mm_ms, _, _ = triton.testing.do_bench(fn)
+ torch._inductor.config.triton.mm = "triton"
+ # reset to force code gen new python code
+ torch._dynamo.reset()
+ torch._inductor.metrics.reset()
+ triton_mm_ms, _, _ = triton.testing.do_bench(fn)
+ assert (
+ torch._inductor.metrics.generated_kernel_count == 1
+ ), "codegen #kernel != 1"
+ row.extend([tflops(torch_mm_ms), tflops(triton_mm_ms)])
+
+ p.add_row(row)
+
+
+fusion_types = ["", "add", "relu", "add_relu"]
+shapes = [
+ # alexnet
+ ([128, 9216], [9216, 4096]),
+ ([128, 4096], [4096, 4096]),
+ ([128, 4096], [4096, 1000]),
+ # BERT
+ ([2048, 768], [768, 768]),
+ ([2048, 768], [768, 3072]),
+ ([2048, 3072], [3072, 768]),
+ # hf_GPT2
+ ([1024, 768], [768, 768]),
+ ([1024, 768], [768, 3072]),
+ ([1024, 3072], [3072, 768]),
+ ([1024, 768], [768, 2304]),
+]
+p = PrettyTable()
+field_names = ["layer"]
+for fusion_type in fusion_types:
+ if fusion_type == "":
+ field_names.append("torch mm")
+ field_names.append("triton mm")
+ else:
+ field_names.append(f"torch mm+{fusion_type}")
+ field_names.append(f"triton mm+{fusion_type}")
+
+p.field_names = field_names
+p.float_format = ".3"
+for id, shape in enumerate(shapes):
+ bench(shape, id, p, fusion_types)
+
+print(p)
diff --git a/benchmarks/dynamo/microbenchmarks/benchmark_helper.py b/benchmarks/dynamo/microbenchmarks/benchmark_helper.py
new file mode 100644
index 0000000000000..971d7c15c8cd6
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/benchmark_helper.py
@@ -0,0 +1,13 @@
+from torch.utils.benchmark import Timer
+
+
+def time_with_torch_timer(fn, args, kwargs=None, iters=100):
+ kwargs = kwargs or {}
+ env = {"args": args, "kwargs": kwargs, "fn": fn}
+ fn_call = "fn(*args, **kwargs)"
+
+ # Measure end-to-end time
+ timer = Timer(stmt=f"{fn_call}", globals=env)
+ tt = timer.timeit(iters)
+
+ return tt
diff --git a/benchmarks/dynamo/microbenchmarks/inductor_bmm.py b/benchmarks/dynamo/microbenchmarks/inductor_bmm.py
new file mode 100644
index 0000000000000..7ac296a58ad8c
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/inductor_bmm.py
@@ -0,0 +1,61 @@
+import torch
+
+import torch._dynamo
+import torch._dynamo.config
+import torch._inductor.config as config
+from benchmark_helper import time_with_torch_timer
+
+
+@torch._dynamo.optimize("inductor", nopython=True)
+def inductor_aten_bmm(a, b):
+ return torch.bmm(a, b)
+
+
+@torch._dynamo.optimize("inductor", nopython=True)
+def inductor_triton_bmm(a, b):
+ return torch.bmm(a, b)
+
+
+def torch_bmm(a, b):
+ return torch.bmm(a, b)
+
+
+def test_total_time(shapes):
+ print("shape; torch bmm; inductor aten bmm; inductor triton bmm")
+ for i in range(len(shapes)):
+ a_shape, b_shape = shapes[i]
+ print(a_shape, "x", b_shape, end="; ")
+ a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
+ b = torch.randn(b_shape, device="cuda", dtype=a.dtype)
+
+ config.triton.use_bmm = False
+ inductor_aten_bmm(a, b)
+
+ config.triton.use_bmm = True
+ inductor_triton_bmm(a, b)
+
+ torch_ms = time_with_torch_timer(torch_bmm, (a, b)).mean * 1000
+
+ config.triton.use_bmm = False
+ ind_aten_ms = time_with_torch_timer(inductor_aten_bmm, (a, b)).mean * 1000
+
+ config.triton.use_bmm = True
+ ind_triton_ms = time_with_torch_timer(inductor_triton_bmm, (a, b)).mean * 1000
+
+ print(torch_ms, ind_aten_ms, ind_triton_ms, sep="; ")
+
+
+if __name__ == "__main__":
+ shapes = [
+ # BERT (all)
+ ([192, 128, 64], [192, 64, 128]),
+ ([192, 128, 128], [192, 128, 64]),
+ # hf_GPT2 (all)
+ ([12, 1024, 1024], [12, 1024, 64]),
+ ([12, 1024, 64], [12, 64, 1024]),
+ # hf_Albert (all)
+ ([12, 512, 64], [12, 64, 512]),
+ ([12, 512, 512], [12, 512, 64]),
+ ]
+
+ test_total_time(shapes)
diff --git a/benchmarks/dynamo/microbenchmarks/inductor_mm.py b/benchmarks/dynamo/microbenchmarks/inductor_mm.py
new file mode 100644
index 0000000000000..deb3d8f8b6042
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/inductor_mm.py
@@ -0,0 +1,134 @@
+import torch
+
+import torch._dynamo
+import torch._dynamo.config
+import torch._inductor.config as config
+import triton
+from benchmark_helper import time_with_torch_timer
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to True.
+torch.backends.cuda.matmul.allow_tf32 = True
+# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+torch.backends.cudnn.allow_tf32 = True
+
+
+@torch._dynamo.optimize("inductor", nopython=True)
+def inductor_aten_mm(a, b):
+ return torch.mm(a, b)
+
+
+@torch._dynamo.optimize("inductor", nopython=True)
+def inductor_triton_mm(a, b):
+ return torch.mm(a, b)
+
+
+def torch_mm(a, b):
+ return torch.mm(a, b)
+
+
+def triton_mm(a, b):
+ return triton.ops.matmul(a, b)
+
+
+def test_total_time(shapes):
+ print("shape; torch mm; triton mm; inductor aten mm; inductor triton mm")
+ for i in range(len(shapes)):
+ a_shape, b_shape = shapes[i]
+ print(a_shape, "x", b_shape, end="; ")
+ a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
+ b = torch.randn(b_shape, device="cuda", dtype=a.dtype)
+
+ config.triton.mm = "aten"
+ inductor_aten_mm(a, b)
+
+ config.triton.mm = "triton"
+ inductor_triton_mm(a, b)
+
+ torch_ms = time_with_torch_timer(torch_mm, (a, b)).mean * 1000
+
+ triton_ms = time_with_torch_timer(triton_mm, (a, b)).mean * 1000
+
+ config.triton.mm = "aten"
+ ind_aten_ms = time_with_torch_timer(inductor_aten_mm, (a, b)).mean * 1000
+
+ config.triton.mm = "triton"
+ ind_triton_ms = time_with_torch_timer(inductor_triton_mm, (a, b)).mean * 1000
+
+ print(torch_ms, triton_ms, ind_aten_ms, ind_triton_ms, sep="; ")
+
+ torch._dynamo.reset()
+
+
+def test_GPU_time(shapes):
+ print("shape; torch mm; triton mm; inductor aten mm; inductor triton mm")
+ for i in range(len(shapes)):
+ a_shape, b_shape = shapes[i]
+ print(a_shape, "x", b_shape, end="; ")
+ a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
+ b = torch.randn(b_shape, device="cuda", dtype=a.dtype)
+
+ config.triton.mm = "aten"
+ inductor_aten_mm(a, b)
+
+ config.triton.mm = "triton"
+ inductor_triton_mm(a, b)
+
+ torch_ms, _, _ = triton.testing.do_bench(lambda: torch_mm(a, b))
+ triton_ms, _, _ = triton.testing.do_bench(lambda: triton_mm(a, b))
+ ind_aten_ms, _, _ = triton.testing.do_bench(lambda: inductor_aten_mm(a, b))
+ ind_triton_ms, _, _ = triton.testing.do_bench(lambda: inductor_triton_mm(a, b))
+ print(torch_ms, triton_ms, ind_aten_ms, ind_triton_ms, sep="; ")
+
+ torch._dynamo.reset()
+
+
+if __name__ == "__main__":
+ shapes = [
+ # alexnet
+ ([128, 9216], [9216, 4096]),
+ ([128, 4096], [4096, 4096]),
+ ([128, 4096], [4096, 1000]),
+ # BERT
+ ([2048, 768], [768, 768]),
+ ([2048, 768], [768, 3072]),
+ ([2048, 3072], [3072, 768]),
+ # hf_GPT2
+ ([1024, 768], [768, 768]),
+ ([1024, 768], [768, 3072]),
+ ([1024, 3072], [3072, 768]),
+ ([1024, 768], [768, 2304]),
+ ]
+ print("test total time")
+ test_total_time(shapes)
+
+ print("test GPU time")
+ test_GPU_time(shapes)
+
+
+# Results Preview on AWS AI cluster
+"""
+test total time
+shape; torch mm; triton mm; inductor aten mm; inductor triton mm
+[128, 9216] x [9216, 4096]; 0.07240759208798409; 0.10885953903198242; 0.20063146017491817; 0.20054904278367758
+[128, 4096] x [4096, 4096]; 0.03640300128608942; 0.10960095096379519; 0.09948539081960917; 0.0996188772842288
+[128, 4096] x [4096, 1000]; 0.02215010579675436; 0.12592008337378502; 0.031120930798351765; 0.0370654184371233
+[2048, 768] x [768, 768]; 0.023501068353652954; 0.10804693214595318; 0.03004650119692087; 0.0276932492852211
+[2048, 768] x [768, 3072]; 0.045639658346772194; 0.10883208829909563; 0.062736920081079; 0.06480381824076176
+[2048, 3072] x [3072, 768]; 0.054093082435429096; 0.10804777964949608; 0.08744294755160809; 0.07766005117446184
+[1024, 768] x [768, 768]; 0.021525858901441097; 0.10909941978752613; 0.02656651195138693; 0.02683836966753006
+[1024, 768] x [768, 3072]; 0.027319076471030712; 0.10825308971107006; 0.040118801407516; 0.039282338693737984
+[1024, 3072] x [3072, 768]; 0.034132059663534164; 0.10594133753329515; 0.05069758277386427; 0.04572632722556591
+[1024, 768] x [768, 2304]; 0.02529360819607973; 0.10486091021448374; 0.03724239766597748; 0.036449190229177475
+test GPU time
+shape; torch mm; triton mm; inductor aten mm; inductor triton mm
+[128, 9216] x [9216, 4096]; 0.09113600105047226; 0.09011200070381165; 0.21606400609016418; 0.21606400609016418
+[128, 4096] x [4096, 4096]; 0.053247999399900436; 0.05222399905323982; 0.1157120019197464; 0.1157120019197464
+[128, 4096] x [4096, 1000]; 0.026623999699950218; 0.02969600073993206; 0.04710400104522705; 0.05222399905323982
+[2048, 768] x [768, 768]; 0.02457600086927414; 0.020479999482631683; 0.04095999896526337; 0.03993599861860275
+[2048, 768] x [768, 3072]; 0.05119999870657921; 0.05222399905323982; 0.07475200295448303; 0.07577600330114365
+[2048, 3072] x [3072, 768]; 0.05939200147986412; 0.05222399905323982; 0.09830400347709656; 0.0870399996638298
+[1024, 768] x [768, 768]; 0.01945599913597107; 0.016383999958634377; 0.03276799991726875; 0.03276799991726875
+[1024, 768] x [768, 3072]; 0.03174399957060814; 0.03276799991726875; 0.053247999399900436; 0.053247999399900436
+[1024, 3072] x [3072, 768]; 0.04403200000524521; 0.03379200026392937; 0.06860800087451935; 0.062463998794555664
+[1024, 768] x [768, 2304]; 0.02969600073993206; 0.02969600073993206; 0.04915200173854828; 0.048128001391887665
+"""
diff --git a/benchmarks/dynamo/microbenchmarks/matmul_relu.py b/benchmarks/dynamo/microbenchmarks/matmul_relu.py
new file mode 100644
index 0000000000000..629b574617ec3
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/matmul_relu.py
@@ -0,0 +1,100 @@
+import torch
+
+import torch._dynamo
+import torch._inductor.config as inductor_config
+from benchmark_helper import time_with_torch_timer
+
+inductor_config.triton.mm = "triton"
+
+
+@torch._dynamo.optimize("inductor", nopython=True)
+def inductor_mm(a, b):
+ return torch.mm(a, b)
+
+
+def torch_mm_relu(a, b):
+ return torch.nn.functional.relu(torch.mm(a, b))
+
+
+def torch_mm(a, b):
+ return torch.mm(a, b)
+
+
+if __name__ == "__main__":
+ # Real shapes from torchbench
+ a_shapes = [
+ [2048, 768],
+ [64, 1280],
+ [2048, 768],
+ [32, 2048],
+ [1, 39200],
+ [128, 3072],
+ [16, 1280],
+ ]
+ b_shapes = [
+ [768, 3072],
+ [1280, 1000],
+ [768, 768],
+ [2048, 1000],
+ [39200, 50],
+ [3072, 1000],
+ [1280, 1000],
+ ]
+
+ # Artificial larger shapes
+ a_shapes += [[10240, 512], [10240, 1024]]
+ b_shapes += [[512, 10240], [1024, 10240]]
+
+ for i in range(len(a_shapes)):
+ a_shape = a_shapes[i]
+ b_shape = b_shapes[i]
+ print("Shape:", a_shape, "x", b_shape)
+ a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
+ b = torch.randn(b_shape, device="cuda", dtype=a.dtype)
+
+ time_with_torch_timer(torch_mm, (a, b), string_id="torch mm")
+ time_with_torch_timer(torch_mm_relu, (a, b), string_id="torch mm + relu")
+ time_with_torch_timer(inductor_mm, (a, b), string_id="inductor mm")
+
+
+# Results obtained on the AWS AI cluster
+# CPU: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
+# GPU: NVIDIA A100-SXM 40GB memory
+"""
+Shape: [2048, 768] x [768, 3072]
+torch mm mean: 0.0592 ms
+torch mm + relu mean: 0.0759 ms
+inductor mm mean: 0.0653 ms
+Shape: [64, 1280] x [1280, 1000]
+torch mm mean: 0.0231 ms
+torch mm + relu mean: 0.0316 ms
+inductor mm mean: 0.0252 ms
+Shape: [2048, 768] x [768, 768]
+torch mm mean: 0.0190 ms
+torch mm + relu mean: 0.0277 ms
+inductor mm mean: 0.0274 ms
+Shape: [32, 2048] x [2048, 1000]
+torch mm mean: 0.0188 ms
+torch mm + relu mean: 0.0290 ms
+inductor mm mean: 0.0244 ms
+Shape: [1, 39200] x [39200, 50]
+torch mm mean: 0.0134 ms
+torch mm + relu mean: 0.0234 ms
+inductor mm mean: 0.0290 ms
+Shape: [128, 3072] x [3072, 1000]
+torch mm mean: 0.0181 ms
+torch mm + relu mean: 0.0322 ms
+inductor mm mean: 0.0319 ms
+Shape: [16, 1280] x [1280, 1000]
+torch mm mean: 0.0188 ms
+torch mm + relu mean: 0.0289 ms
+inductor mm mean: 0.0255 ms
+Shape: [10240, 512] x [512, 10240]
+torch mm mean: 0.4589 ms
+torch mm + relu mean: 0.7896 ms
+inductor mm mean: 0.5090 ms
+Shape: [10240, 1024] x [1024, 10240]
+torch mm mean: 0.9152 ms
+torch mm + relu mean: 1.2124 ms
+inductor mm mean: 0.9462 ms
+"""
diff --git a/benchmarks/dynamo/microbenchmarks/microbench.py b/benchmarks/dynamo/microbenchmarks/microbench.py
new file mode 100755
index 0000000000000..cab1bdc444d70
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/microbench.py
@@ -0,0 +1,176 @@
+#!/usr/bin/env python3
+import argparse
+import inspect
+import sys
+
+import numpy as np
+import tabulate
+import torch
+
+import torch._inductor
+from torch._dynamo.optimizations.backends import cudagraphs_inner
+from torch._dynamo.testing import same
+from torch._inductor.compile_fx import compile_fx
+from torch._inductor.utils import timed
+
+try:
+ import test.test_torchinductor as tti
+except ImportError:
+ tti = None
+
+
+def compute_speedups(args, models, example_inputs):
+ expected = models[0](*example_inputs)
+ for model in models[1:]:
+ actual = model(*example_inputs)
+ assert same(actual, expected), expected[0] - actual[0]
+
+ timings = np.zeros((args.repeat, len(models)), np.float64)
+ for rep in range(args.repeat):
+ # interleave the runs to handle frequency scaling and load changes
+ for m, model in enumerate(models):
+ timings[rep, m] = timed(model, example_inputs)
+ median = np.median(timings, axis=0)
+ return (median[0] / median[1:]).tolist()
+
+
+def microbenchmark(args, model, example_inputs):
+ compiled_fn = compile_fx(torch.fx.symbolic_trace(model), example_inputs)
+ cudagraphs_eager = cudagraphs_inner(model, example_inputs, copy_outputs=False)
+ cudagraphs_jit = cudagraphs_inner(
+ torch.jit.trace(model, example_inputs), example_inputs, copy_outputs=False
+ )
+ return compute_speedups(
+ args,
+ [cudagraphs_eager, cudagraphs_jit, compiled_fn],
+ example_inputs,
+ )
+
+
+class MyModel1(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.model = torch.nn.Sequential(
+ torch.nn.Linear(1024, 1024),
+ torch.nn.ReLU(),
+ )
+
+ def forward(self, input):
+ # return (self.model(input) + 1,)
+ return (self.model(input),)
+
+
+class MyModel2(torch.nn.Module):
+ def forward(self, x, y):
+ # return x / (torch.abs(x) + 1.0),
+ return (x + y,)
+
+
+class MicroBenchmarks:
+ @staticmethod
+ def add(a, b):
+ return (a + b,)
+
+ @staticmethod
+ def scale(x, m, d):
+ return ((x - m) / torch.clip(d, 1e-4),)
+
+ @staticmethod
+ def abs_norm(x):
+ return (x / (torch.abs(x) + 1),)
+
+ @staticmethod
+ def add_relu_softmax(x, a):
+ return (torch.softmax(torch.relu(x + a), -1),)
+
+ @staticmethod
+ def sum(a, b):
+ return ((a + b).sum(),)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--filter", "-k", action="append", help="filter benchmarks with regexp"
+ )
+ parser.add_argument(
+ "--exclude", "-x", action="append", help="filter benchmarks with regexp"
+ )
+ parser.add_argument("--devices", "-d", action="append", help="cpu or cuda")
+ parser.add_argument("--size", "-s", action="append", help="cpu or cuda")
+ parser.add_argument(
+ "--repeat", "-n", type=int, default=30, help="number of timing runs"
+ )
+ parser.add_argument(
+ "--threads", "-t", type=int, help="number of threads to use for eager"
+ )
+ parser.add_argument(
+ "--verbose", "-v", action="store_true", help="enable verbose debug printouts"
+ )
+ parser.add_argument(
+ "--nvfuser", action="store_true", help="enable nvfuser globally"
+ )
+ parser.add_argument("--transpose", action="store_true", help="transpose one input")
+ parser.add_argument("--broadcast", action="store_true", help="broadcast one input")
+ args = parser.parse_args()
+
+ # defaults
+ args.devices = args.devices or ["cpu", "cuda"]
+ args.filter = args.filter or [r"."]
+ args.exclude = args.exclude or [r"^$"]
+ args.size = args.size or [64, 256, 1024, 4096, 8192]
+
+ if args.nvfuser:
+ torch._C._jit_override_can_fuse_on_cpu(False)
+ torch._C._jit_override_can_fuse_on_gpu(False)
+ torch._C._jit_set_texpr_fuser_enabled(False)
+ torch._C._jit_set_nvfuser_enabled(True)
+ else:
+ torch._C._jit_override_can_fuse_on_cpu(torch._C._llvm_enabled())
+ torch._C._jit_override_can_fuse_on_gpu(True)
+ torch._C._jit_set_texpr_fuser_enabled(True)
+ if torch.cuda.is_available():
+ torch._C._jit_set_nvfuser_enabled(False)
+
+ if args.threads:
+ torch.set_num_threads(args.threads)
+ torch._inductor.config.cpp.threads = args.threads
+
+ if args.verbose:
+ torch._inductor.config.debug = True
+
+ torch._inductor.config.triton.autotune = True
+
+ rows = []
+ for model in (MicroBenchmarks.sum,):
+ nargs = len(inspect.signature(model).parameters)
+ for device in args.devices:
+ for n in args.size:
+ n = int(n)
+ sys.stdout.write(f"{model.__name__:10} {device:4} {n:5} ")
+ sys.stdout.flush()
+ inputs = [torch.rand((n, n), device=device) for _ in range(nargs)]
+ if args.broadcast:
+ inputs[-1] = torch.rand((1, n), device=device)
+ if args.transpose:
+ inputs[-1] = inputs[-1].transpose(0, 1)
+ result = microbenchmark(args, model, inputs)
+ rows.append([model.__name__, device, str(n)] + result)
+ print(" ".join(f"{v:.2f}x" for v in result))
+
+ print(
+ tabulate.tabulate(
+ rows,
+ headers=[
+ "model",
+ "dev",
+ "n",
+ "ts",
+ "inductor",
+ ],
+ )
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmarks/dynamo/microbenchmarks/model.py b/benchmarks/dynamo/microbenchmarks/model.py
new file mode 100644
index 0000000000000..c926b6c79d0ad
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/model.py
@@ -0,0 +1,26 @@
+# resnet50 layer shape
+resnet50_layers = (
+ # IN_H, IN_W, IN_C, KERNEL_H, KERNEL_W, KERNEL_N, stride, padding
+ (224, 224, 3, 7, 7, 64, (2, 2), (0, 0)),
+ # conv2_x
+ (56, 56, 64, 1, 1, 64, (1, 1), (0, 0)),
+ (56, 56, 64, 3, 3, 64, (1, 1), (0, 0)),
+ (56, 56, 64, 1, 1, 256, (1, 1), (0, 0)),
+ # conv3_x
+ (56, 56, 256, 1, 1, 128, (2, 2), (0, 0)),
+ (28, 28, 128, 3, 3, 128, (1, 1), (0, 0)),
+ (28, 28, 128, 1, 1, 512, (1, 1), (0, 0)),
+ # conv4_x
+ (28, 28, 512, 1, 1, 256, (2, 2), (0, 0)),
+ (14, 14, 256, 3, 3, 256, (1, 1), (0, 0)),
+ (14, 14, 256, 1, 1, 1024, (1, 1), (0, 0)),
+ # conv5_x
+ (14, 14, 1024, 1, 1, 512, (2, 2), (0, 0)),
+ (7, 7, 512, 3, 3, 512, (1, 1), (0, 0)),
+ (7, 7, 512, 1, 1, 2048, (1, 1), (0, 0)),
+)
+
+alexnet_layers = (
+ # IN_H, IN_W, IN_C, KERNEL_H, KERNEL_W, KERNEL_N, stride, padding
+ (224, 224, 3, 11, 11, 64, (4, 4), (2, 2)),
+)
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/AlbertForMaskedLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/AlbertForMaskedLM_training.txt
new file mode 100644
index 0000000000000..b2374b7faa537
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/AlbertForMaskedLM_training.txt
@@ -0,0 +1,115 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([1024, 30000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([1024, 30000], f16), T([1024, 30000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([2, 64, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([2, 64, 512, 512], f16), T([2, 64, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([2, 1, 1, 512], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([2, 64, 512, 64], f16), [128, 512, 64]), {})
+cnt: 12, ((T([2, 64, 64, 512], f16), [128, 64, 512]), {})
+cnt: 12, ((T([128, 512, 512], f16), [2, 64, 512, 512]), {})
+cnt: 12, ((T([128, 512, 64], f16), [2, 64, 512, 64]), {})
+cnt: 36, ((T([2, 512, 64, 64], f16), [2, 512, 4096]), {})
+cnt: 12, ((T([2, 512, 4096], f16), [1024, 4096]), {})
+Operator: aten.add.Tensor
+cnt: 4, ((T([2, 512, 128], f16), T([2, 512, 128], f16)), {})
+cnt: 12, ((T([2, 64, 512, 512], f16), T([2, 1, 1, 512], f16)), {})
+cnt: 72, ((T([2, 512, 4096], f16), T([2, 512, 4096], f16)), {})
+cnt: 36, ((T([2, 512, 16384], f16), T([2, 512, 16384], f16)), {})
+cnt: 12, ((T([2, 512, 16384], f16), 1.0), {})
+cnt: 1, ((T([2, 512, 128], f16), 1.0), {})
+cnt: 99, ((T([4096], f16), T([4096], f16)), {})
+cnt: 11, ((T([4096, 16384], f16), T([4096, 16384], f16)), {})
+cnt: 11, ((T([16384], f16), T([16384], f16)), {})
+cnt: 11, ((T([16384, 4096], f16), T([16384, 4096], f16)), {})
+cnt: 44, ((T([4096, 4096], f16), T([4096, 4096], f16)), {})
+cnt: 1, ((T([30000, 128], f16), T([30000, 128], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([2, 512, 128], f16), T([1, 512, 128], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([4096], f16), T([1024, 128], f16), T([128, 4096], f16, stride=(1, 128))), {})
+cnt: 48, ((T([4096], f16), T([1024, 4096], f16), T([4096, 4096], f16, stride=(1, 4096))), {})
+cnt: 12, ((T([16384], f16), T([1024, 4096], f16), T([4096, 16384], f16, stride=(1, 4096))), {})
+cnt: 12, ((T([4096], f16), T([1024, 16384], f16), T([16384, 4096], f16, stride=(1, 16384))), {})
+cnt: 1, ((T([128], f16), T([1024, 4096], f16), T([4096, 128], f16, stride=(1, 4096))), {})
+cnt: 1, ((T([30000], f16), T([1024, 128], f16), T([128, 30000], f16, stride=(1, 128))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([128, 512, 64], f16), T([128, 64, 512], f16)), {})
+cnt: 12, ((T([128, 512, 512], f16), T([128, 512, 64], f16)), {})
+cnt: 12, ((T([128, 512, 512], f16, stride=(262144, 1, 512)), T([128, 512, 64], f16)), {})
+cnt: 12, ((T([128, 512, 64], f16), T([128, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 12, ((T([128, 64, 512], f16, stride=(32768, 1, 64)), T([128, 512, 512], f16)), {})
+cnt: 12, ((T([128, 512, 512], f16), T([128, 512, 64], f16, stride=(32768, 1, 512))), {})
+Operator: aten.clone.default
+cnt: 2, ((T([2, 512], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([2, 512], i64), T([2, 512], i64)), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([2, 64, 512, 512], f16), 8.0), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30000, 128], f16), T([2, 512], i64), 0), {})
+cnt: 1, ((T([2, 128], f16), T([2, 512], i64, stride=(0, 1))), {})
+cnt: 1, ((T([512, 128], f16), T([1, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 128], f16), T([1, 512], i64), 512, -1, False), {})
+cnt: 1, ((T([2, 512, 128], f16), T([2, 512], i64, stride=(0, 1)), 2, -1, False), {})
+cnt: 1, ((T([2, 512, 128], f16), T([2, 512], i64), 30000, 0, False), {})
+Operator: aten.mm.default
+cnt: 1, ((T([1024, 30000], f16), T([30000, 128], f16)), {})
+cnt: 1, ((T([30000, 1024], f16, stride=(1, 30000)), T([1024, 128], f16)), {})
+cnt: 1, ((T([1024, 128], f16), T([128, 4096], f16)), {})
+cnt: 1, ((T([128, 1024], f16, stride=(1, 128)), T([1024, 4096], f16)), {})
+cnt: 12, ((T([1024, 4096], f16), T([4096, 16384], f16)), {})
+cnt: 12, ((T([4096, 1024], f16, stride=(1, 4096)), T([1024, 16384], f16)), {})
+cnt: 12, ((T([1024, 16384], f16), T([16384, 4096], f16)), {})
+cnt: 12, ((T([16384, 1024], f16, stride=(1, 16384)), T([1024, 4096], f16)), {})
+cnt: 48, ((T([1024, 4096], f16), T([4096, 4096], f16)), {})
+cnt: 48, ((T([4096, 1024], f16, stride=(1, 4096)), T([1024, 4096], f16)), {})
+cnt: 1, ((T([1024, 4096], f16), T([4096, 128], f16)), {})
+cnt: 1, ((T([4096, 1024], f16, stride=(1, 4096)), T([1024, 128], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 1, ((T([2, 512, 128], f16), 3.0), {})
+cnt: 12, ((T([2, 512, 16384], f16), 3.0), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([2, 1, 1, 512], f16), -65504.0), {})
+cnt: 24, ((T([2, 512, 16384], f16), 0.5), {})
+cnt: 24, ((T([2, 512, 16384], f16), 0.044715), {})
+cnt: 24, ((T([2, 512, 16384], f16), 0.7978845608028654), {})
+cnt: 48, ((T([2, 512, 16384], f16), T([2, 512, 16384], f16)), {})
+cnt: 2, ((T([2, 512, 128], f16), 0.5), {})
+cnt: 2, ((T([2, 512, 128], f16), 0.044715), {})
+cnt: 2, ((T([2, 512, 128], f16), 0.7978845608028654), {})
+cnt: 4, ((T([2, 512, 128], f16), T([2, 512, 128], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 2, ((T([2, 512, 128], f16), [128], T([128], f16), T([128], f16), 1e-12), {})
+cnt: 24, ((T([2, 512, 4096], f16), [4096], T([4096], f16), T([4096], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 2, ((T([2, 512, 128], f16), T([2, 512, 128], f16), [128], T([2, 512, 1], f32), T([2, 512, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+cnt: 24, ((T([2, 512, 4096], f16), T([2, 512, 4096], f16), [4096], T([2, 512, 1], f32), T([2, 512, 1], f32), T([4096], f16), T([4096], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([1024, 30000], f16), T([1024], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([1024, 30000], f16), T([1024], i64), None, 1, -100), {})
+Operator: aten.pow.Tensor_Scalar
+cnt: 12, ((T([2, 512, 16384], f16), 3.0), {})
+cnt: 1, ((T([2, 512, 128], f16), 3.0), {})
+cnt: 1, ((T([2, 512, 128], f16), 2.0), {})
+cnt: 12, ((T([2, 512, 16384], f16), 2.0), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([2, 1, 1, 512], f16), 1.0), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([1024, 30000], f16), [0], True), {})
+cnt: 1, ((T([1024, 128], f16), [0], True), {})
+cnt: 61, ((T([1024, 4096], f16), [0], True), {})
+cnt: 12, ((T([1024, 16384], f16), [0], True), {})
+cnt: 1, ((T([2, 512, 128], f16), [0], True), {})
+Operator: aten.tanh.default
+cnt: 12, ((T([2, 512, 16384], f16),), {})
+cnt: 1, ((T([2, 512, 128], f16),), {})
+Operator: aten.tanh_backward.default
+cnt: 1, ((T([2, 512, 128], f16), T([2, 512, 128], f16)), {})
+cnt: 12, ((T([2, 512, 16384], f16), T([2, 512, 16384], f16)), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/AlbertForQuestionAnswering_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/AlbertForQuestionAnswering_training.txt
new file mode 100644
index 0000000000000..8e25df92770b6
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/AlbertForQuestionAnswering_training.txt
@@ -0,0 +1,110 @@
+Operator: aten._log_softmax.default
+cnt: 2, ((T([2, 512], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 2, ((T([2, 512], f16), T([2, 512], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([2, 64, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([2, 64, 512, 512], f16), T([2, 64, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([2, 1, 1, 512], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([2, 64, 512, 64], f16), [128, 512, 64]), {})
+cnt: 12, ((T([2, 64, 64, 512], f16), [128, 64, 512]), {})
+cnt: 12, ((T([128, 512, 512], f16), [2, 64, 512, 512]), {})
+cnt: 12, ((T([128, 512, 64], f16), [2, 64, 512, 64]), {})
+cnt: 36, ((T([2, 512, 64, 64], f16), [2, 512, 4096]), {})
+cnt: 12, ((T([2, 512, 4096], f16), [1024, 4096]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([2, 512, 128], f16), T([2, 512, 128], f16)), {})
+cnt: 12, ((T([2, 64, 512, 512], f16), T([2, 1, 1, 512], f16)), {})
+cnt: 72, ((T([2, 512, 4096], f16), T([2, 512, 4096], f16)), {})
+cnt: 36, ((T([2, 512, 16384], f16), T([2, 512, 16384], f16)), {})
+cnt: 12, ((T([2, 512, 16384], f16), 1.0), {})
+cnt: 1, ((T([], f16), T([], f16)), {})
+cnt: 99, ((T([4096], f16), T([4096], f16)), {})
+cnt: 11, ((T([4096, 16384], f16), T([4096, 16384], f16)), {})
+cnt: 11, ((T([16384], f16), T([16384], f16)), {})
+cnt: 11, ((T([16384, 4096], f16), T([16384, 4096], f16)), {})
+cnt: 44, ((T([4096, 4096], f16), T([4096, 4096], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([2, 512, 128], f16), T([1, 512, 128], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([4096], f16), T([1024, 128], f16), T([128, 4096], f16, stride=(1, 128))), {})
+cnt: 48, ((T([4096], f16), T([1024, 4096], f16), T([4096, 4096], f16, stride=(1, 4096))), {})
+cnt: 12, ((T([16384], f16), T([1024, 4096], f16), T([4096, 16384], f16, stride=(1, 4096))), {})
+cnt: 12, ((T([4096], f16), T([1024, 16384], f16), T([16384, 4096], f16, stride=(1, 16384))), {})
+cnt: 1, ((T([2], f16), T([1024, 4096], f16), T([4096, 2], f16, stride=(1, 4096))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([128, 512, 64], f16), T([128, 64, 512], f16)), {})
+cnt: 12, ((T([128, 512, 512], f16), T([128, 512, 64], f16)), {})
+cnt: 12, ((T([128, 512, 512], f16, stride=(262144, 1, 512)), T([128, 512, 64], f16)), {})
+cnt: 12, ((T([128, 512, 64], f16), T([128, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 12, ((T([128, 64, 512], f16, stride=(32768, 1, 64)), T([128, 512, 512], f16)), {})
+cnt: 12, ((T([128, 512, 512], f16), T([128, 512, 64], f16, stride=(32768, 1, 512))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([2, 512, 1], f16), T([2, 512, 1], f16)], 2), {})
+Operator: aten.clamp.default
+cnt: 2, ((T([2], i64), 0, 512), {})
+Operator: aten.clone.default
+cnt: 1, ((T([2, 512], i64),), {})
+cnt: 2, ((T([2], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([2, 512], i64), T([2, 512], i64)), {})
+cnt: 2, ((T([2], i64), T([2], i64)), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([2, 64, 512, 512], f16), 8.0), {})
+cnt: 2, ((T([], f16), 2), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30000, 128], f16), T([2, 512], i64), 0), {})
+cnt: 1, ((T([2, 128], f16), T([2, 512], i64, stride=(0, 1))), {})
+cnt: 1, ((T([512, 128], f16), T([1, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 128], f16), T([1, 512], i64), 512, -1, False), {})
+cnt: 1, ((T([2, 512, 128], f16), T([2, 512], i64, stride=(0, 1)), 2, -1, False), {})
+cnt: 1, ((T([2, 512, 128], f16), T([2, 512], i64), 30000, 0, False), {})
+Operator: aten.mm.default
+cnt: 1, ((T([1024, 2], f16), T([2, 4096], f16)), {})
+cnt: 1, ((T([2, 1024], f16, stride=(1, 2)), T([1024, 4096], f16)), {})
+cnt: 12, ((T([1024, 4096], f16), T([4096, 16384], f16)), {})
+cnt: 12, ((T([4096, 1024], f16, stride=(1, 4096)), T([1024, 16384], f16)), {})
+cnt: 12, ((T([1024, 16384], f16), T([16384, 4096], f16)), {})
+cnt: 12, ((T([16384, 1024], f16, stride=(1, 16384)), T([1024, 4096], f16)), {})
+cnt: 48, ((T([1024, 4096], f16), T([4096, 4096], f16)), {})
+cnt: 48, ((T([4096, 1024], f16, stride=(1, 4096)), T([1024, 4096], f16)), {})
+cnt: 1, ((T([1024, 4096], f16), T([4096, 128], f16)), {})
+cnt: 1, ((T([4096, 1024], f16, stride=(1, 4096)), T([1024, 128], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 12, ((T([2, 512, 16384], f16), 3.0), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([2, 1, 1, 512], f16), -65504.0), {})
+cnt: 24, ((T([2, 512, 16384], f16), 0.5), {})
+cnt: 24, ((T([2, 512, 16384], f16), 0.044715), {})
+cnt: 24, ((T([2, 512, 16384], f16), 0.7978845608028654), {})
+cnt: 48, ((T([2, 512, 16384], f16), T([2, 512, 16384], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 1, ((T([2, 512, 128], f16), [128], T([128], f16), T([128], f16), 1e-12), {})
+cnt: 24, ((T([2, 512, 4096], f16), [4096], T([4096], f16), T([4096], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 24, ((T([2, 512, 4096], f16), T([2, 512, 4096], f16), [4096], T([2, 512, 1], f32), T([2, 512, 1], f32), T([4096], f16), T([4096], f16), [True, True, True]), {})
+cnt: 1, ((T([2, 512, 128], f16), T([2, 512, 128], f16), [128], T([2, 512, 1], f32), T([2, 512, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 2, ((T([], f16), T([2, 512], f16), T([2], i64), None, 1, 512, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 2, ((T([2, 512], f16), T([2], i64), None, 1, 512), {})
+Operator: aten.pow.Tensor_Scalar
+cnt: 12, ((T([2, 512, 16384], f16), 3.0), {})
+cnt: 12, ((T([2, 512, 16384], f16), 2.0), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([2, 1, 1, 512], f16), 1.0), {})
+Operator: aten.split.Tensor
+cnt: 1, ((T([2, 512, 2], f16), 1, -1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([1024, 2], f16), [0], True), {})
+cnt: 61, ((T([1024, 4096], f16), [0], True), {})
+cnt: 12, ((T([1024, 16384], f16), [0], True), {})
+cnt: 1, ((T([2, 512, 128], f16), [0], True), {})
+Operator: aten.tanh.default
+cnt: 12, ((T([2, 512, 16384], f16),), {})
+Operator: aten.tanh_backward.default
+cnt: 12, ((T([2, 512, 16384], f16), T([2, 512, 16384], f16)), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/AllenaiLongformerBase_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/AllenaiLongformerBase_training.txt
new file mode 100644
index 0000000000000..5cf27686039e4
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/AllenaiLongformerBase_training.txt
@@ -0,0 +1,186 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([1024, 50265], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([1024, 50265], f16), T([1024, 50265], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([1, 1024, 12, 513], f16, stride=(6303744, 513, 525312, 1)), -1, True), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([1, 1024, 12, 513], f32), T([1, 1024, 12, 513], f32), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([1, 1, 1, 1024], f32),), {'dtype': f16})
+cnt: 1, ((T([1, 1024], b8),), {'dtype': i32})
+cnt: 1, ((T([1, 1024], i64),), {'dtype': i32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([1, 1024], i32),), {'dtype': i64})
+cnt: 12, ((T([1, 1024, 1, 1], b8),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 12, ((T([1, 1024, 12, 513], f32),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 12, ((T([1, 1024, 12, 513], f16, stride=(6303744, 513, 525312, 1)),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 12, ((T([12, 3, 512, 64, 1], f16), [36, 512, 64]), {})
+cnt: 12, ((T([12, 3, 64, 512, 1], f16), [36, 64, 512]), {})
+cnt: 12, ((T([12, 4, 768, 64, 1], f16), [48, 768, 64]), {})
+cnt: 24, ((T([1024, 1, 12, 64], f16), [1024, 1, 768]), {})
+cnt: 12, ((T([12, 4, 256, 1, 64], f16), [48, 256, 64]), {})
+cnt: 12, ((T([12, 4, 768, 64], i64), [2359296]), {})
+cnt: 12, ((T([12, 3, 512, 64], f16), [1179648]), {})
+cnt: 24, ((T([12, 3, 512, 64], i64), [1179648]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([1, 1024], i64), 1), {})
+cnt: 50, ((T([1, 1024, 768], f16), T([1, 1024, 768], f16)), {})
+cnt: 36, ((T([12, 3, 512, 513], f16), T([12, 3, 512, 513], f16)), {})
+cnt: 24, ((T([1024, 1, 768], f16), T([1024, 1, 768], f16)), {})
+cnt: 1, ((T([50265, 768], f16), T([50265, 768], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 12, ((T([1, 1024, 12, 513], f16, stride=(6303744, 513, 525312, 1)), T([1, 1024, 1, 513], f16)), {})
+Operator: aten.addmm.default
+cnt: 49, ((T([768], f16), T([1024, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([1024, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([1024, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([50265], f16), T([1024, 768], f16), T([768, 50265], f16, stride=(1, 768))), {})
+Operator: aten.any.default
+cnt: 1, ((T([1024], b8),), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([36, 512, 64], f16), T([36, 64, 512], f16)), {})
+cnt: 12, ((T([48, 256, 768], f16, stride=(197120, 769, 1)), T([48, 768, 64], f16)), {})
+cnt: 12, ((T([48, 768, 256], f16, stride=(197120, 1, 769)), T([48, 256, 64], f16)), {})
+cnt: 12, ((T([48, 256, 64], f16), T([48, 64, 768], f16, stride=(49152, 1, 64))), {})
+cnt: 12, ((T([36, 64, 512], f16, stride=(32768, 1, 64)), T([36, 512, 512], f16)), {})
+cnt: 12, ((T([36, 512, 512], f16), T([36, 512, 64], f16, stride=(32768, 1, 512))), {})
+Operator: aten.clone.default
+cnt: 2, ((T([1, 1024], i64),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 12, ((T([12, 3, 512, 512], f16), [0, 0, 0, 1], 0.0), {})
+cnt: 12, ((T([1, 3, 512, 512], f16), [0, 0, 0, 1], 0.0), {})
+cnt: 12, ((T([12, 1024, 64], f16, stride=(64, 768, 1)), [0, 0, 256, 256], -1.0), {})
+cnt: 12, ((T([12, 4, 256, 513], f16, stride=(513, 1575936, 6156, 1)), [0, 257], 0.0), {})
+cnt: 12, ((T([12, 4, 256, 770], f16), [0, -257]), {})
+cnt: 12, ((T([12, 1536, 64], f16), [0, 0, -256, -256]), {})
+cnt: 12, ((T([12, 3, 513, 512], f16), [0, 0, 0, -1]), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([1, 1024], i64), T([1, 1024], i64)), {})
+cnt: 12, ((T([12, 3, 256, 257], f16, stride=(525312, 131328, 513, 1)), T([12, 3, 256, 257], f16, stride=(787968, 262656, 513, 1))), {})
+cnt: 12, ((T([12, 256, 257], f16, stride=(525312, 513, 1)), T([12, 256, 257], f16, stride=(787968, 513, 1))), {})
+cnt: 12, ((T([12, 3, 256, 256], f16, stride=(525312, 131328, 513, 1)), T([12, 3, 256, 256], f16, stride=(787968, 262656, 513, 1))), {})
+cnt: 12, ((T([12, 255, 255], f16, stride=(525312, 513, 1)), T([12, 255, 255], f16, stride=(787968, 513, 1))), {})
+cnt: 12, ((T([1, 3, 256, 257], f16, stride=(525312, 131328, 513, 1)), T([1, 3, 256, 257], f16, stride=(787968, 262656, 513, 1))), {})
+cnt: 12, ((T([1, 256, 257], f16, stride=(525312, 513, 1)), T([1, 256, 257], f16, stride=(787968, 513, 1))), {})
+cnt: 12, ((T([1, 3, 256, 256], f16, stride=(525312, 131328, 513, 1)), T([1, 3, 256, 256], f16, stride=(787968, 262656, 513, 1))), {})
+cnt: 12, ((T([1, 255, 255], f16, stride=(525312, 513, 1)), T([1, 255, 255], f16, stride=(787968, 513, 1))), {})
+cnt: 12, ((T([1024, 12, 513], f16, stride=(513, 525312, 1)), T([1024, 12, 513], f16)), {})
+cnt: 84, ((T([12, 4, 256, 513], f16), T([12, 4, 256, 513], f16)), {})
+cnt: 12, ((T([1, 1024, 12, 513], f16, stride=(6303744, 513, 525312, 1)), T([1, 1024, 12, 513], f16)), {})
+cnt: 24, ((T([1, 256, 12, 257], f16, stride=(6303744, 513, 525312, 1)), T([1, 256, 12, 257], f16)), {})
+cnt: 12, ((T([12, 255, 255], f16, stride=(525312, 513, 1)), T([12, 255, 255], f16)), {})
+cnt: 12, ((T([12, 3, 256, 256], f16, stride=(525312, 131328, 513, 1)), T([12, 3, 256, 256], f16)), {})
+cnt: 12, ((T([12, 256, 257], f16, stride=(525312, 513, 1)), T([12, 256, 257], f16)), {})
+cnt: 24, ((T([1024, 768], f16), T([1024, 768], f16)), {})
+cnt: 12, ((T([1024, 1, 768], f16), T([1024, 1, 768], f16)), {})
+Operator: aten.cumsum.default
+cnt: 1, ((T([1, 1024], i32), 1), {})
+Operator: aten.div.Tensor
+cnt: 12, ((T([1024, 1, 768], f16), 8.0), {})
+Operator: aten.div_.Tensor
+cnt: 12, ((T([1024, 1, 768], f16), 8.0), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50265, 768], f16), T([1, 1024], i64), 1), {})
+cnt: 1, ((T([4098, 768], f16), T([1, 1024], i64), 1), {})
+cnt: 1, ((T([1, 768], f16), T([1, 1024], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 1024, 768], f16), T([1, 1024], i64), 1, -1, False), {})
+cnt: 1, ((T([1, 1024, 768], f16), T([1, 1024], i64), 4098, 1, False), {})
+cnt: 1, ((T([1, 1024, 768], f16), T([1, 1024], i64), 50265, 1, False), {})
+Operator: aten.eq.Scalar
+cnt: 24, ((T([1, 256, 12, 257], f16, stride=(65792, 257, 0, 1)), 1), {})
+cnt: 24, ((T([1, 256, 1, 257], f16), 1), {})
+Operator: aten.flip.default
+cnt: 24, ((T([256, 257], f16), [0]), {})
+cnt: 24, ((T([1, 256, 1, 257], f16), [1, 3]), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([1, 1024, 3072], f16),), {})
+cnt: 1, ((T([1, 1024, 768], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([1, 1024, 768], f16), T([1, 1024, 768], f16)), {})
+cnt: 12, ((T([1, 1024, 3072], f16), T([1, 1024, 3072], f16)), {})
+Operator: aten.gt.Scalar
+cnt: 1, ((T([1, 1024], f16), 0), {})
+Operator: aten.index_add_.default
+cnt: 12, ((T([1179648], f16), 0, T([2359296], i64), T([2359296], f16)), {})
+cnt: 24, ((T([786432], f16), 0, T([1179648], i64), T([1179648], f16)), {})
+Operator: aten.lt.Scalar
+cnt: 1, ((T([1, 1024], f16), 0), {})
+Operator: aten.masked_fill.Scalar
+cnt: 12, ((T([1, 1024, 1, 1], f16), T([1, 1024, 1, 1], b8), -65504.0), {})
+cnt: 12, ((T([1, 1024, 12, 513], f32), T([1, 1024, 1, 1], b8), 0.0), {})
+cnt: 12, ((T([1, 1024, 12, 513], f32, stride=(6303744, 513, 525312, 1)), T([1, 1024, 1, 1], b8), 0), {})
+cnt: 24, ((T([1, 256, 12, 257], f16), T([1, 256, 12, 257], b8), 0), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 24, ((T([1, 256, 12, 257], f16, stride=(6303744, 513, 525312, 1)), T([1, 256, 12, 257], b8), -inf), {})
+cnt: 24, ((T([1, 256, 1, 257], f16, stride=(525312, 513, 525312, 1)), T([1, 256, 1, 257], b8), -inf), {})
+Operator: aten.mm.default
+cnt: 1, ((T([1024, 50265], f16), T([50265, 768], f16)), {})
+cnt: 1, ((T([50265, 1024], f16, stride=(1, 50265)), T([1024, 768], f16)), {})
+cnt: 49, ((T([1024, 768], f16), T([768, 768], f16)), {})
+cnt: 49, ((T([768, 1024], f16, stride=(1, 768)), T([1024, 768], f16)), {})
+cnt: 12, ((T([1024, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 1024], f16, stride=(1, 768)), T([1024, 3072], f16)), {})
+cnt: 12, ((T([1024, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 1024], f16, stride=(1, 3072)), T([1024, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([1, 1, 1, 1024], f16), -65504.0), {})
+cnt: 1, ((T([1, 1024], i32), T([1, 1024], i32)), {})
+cnt: 12, ((T([1, 3, 512, 1], f16, stride=(1024, 256, 1, 1)), T([1, 3, 1, 512], f16, stride=(1024, 256, 1, 1))), {})
+Operator: aten.native_layer_norm.default
+cnt: 26, ((T([1, 1024, 768], f16), [768], T([768], f16), T([768], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 26, ((T([1, 1024, 768], f16), T([1, 1024, 768], f16), [768], T([1, 1024, 1], f32), T([1, 1024, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.ne.Scalar
+cnt: 1, ((T([1, 1024], i64), 1), {})
+cnt: 12, ((T([1, 1024], f16), 0), {})
+Operator: aten.new_empty.default
+cnt: 12, ((T([12, 3, 512, 513], f16), [12, 4, 256, 513]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 12, ((T([1, 3, 512, 513], f16), [1, 4, 256, 513]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+Operator: aten.new_empty_strided.default
+cnt: 84, ((T([12, 4, 256, 513], f16), [12, 4, 256, 513], [525312, 131328, 513, 1]), {})
+cnt: 12, ((T([1024, 768], f16), [1024, 768], [768, 1]), {})
+Operator: aten.new_ones.default
+cnt: 12, ((T([1, 1024, 12, 513], f16, stride=(6303744, 513, 525312, 1)), [256, 257]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 12, ((T([1, 1024, 1, 1], f16), [1, 1024, 1, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 12, ((T([1, 1024, 1, 513], f16), [256, 257]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+Operator: aten.new_zeros.default
+cnt: 12, ((T([12, 4, 768, 64], f16), [1179648]), {})
+cnt: 12, ((T([1024, 12, 513], f16), [6303744]), {})
+cnt: 12, ((T([12, 3, 512, 64], f16, stride=(98304, 32768, 1, 512)), [786432]), {})
+cnt: 12, ((T([12, 3, 512, 64], f16), [786432]), {})
+cnt: 12, ((T([1024, 768], f16), [786432]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([1024, 50265], f16), T([1024], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([1024, 50265], f16), T([1024], i64), None, 1, -100), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([1, 1, 1, 1024], f16), 1.0), {})
+Operator: aten.select_backward.default
+cnt: 12, ((T([12, 512, 513], f16), [12, 3, 512, 513], 1, 0), {})
+cnt: 12, ((T([12, 512, 513], f16), [12, 3, 512, 513], 1, -1), {})
+Operator: aten.slice_backward.default
+cnt: 12, ((T([12, 4, 256, 768], f16), [12, 4, 256, 769], 3, 0, -1, 1), {})
+cnt: 12, ((T([12, 4, 256, 769], f16), [12, 4, 256, 769], 2, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([12, 4, 256, 769], f16), [12, 4, 256, 769], 1, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([12, 4, 256, 769], f16), [12, 4, 256, 769], 0, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([12, 4, 196864], f16), [12, 4, 197120], 2, 0, -256, 1), {})
+cnt: 12, ((T([12, 4, 197120], f16), [12, 4, 197120], 1, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([12, 4, 197120], f16), [12, 4, 197120], 0, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([12, 255, 255], f16), [12, 255, 513], 2, -255, 9223372036854775807, 1), {})
+cnt: 12, ((T([12, 255, 513], f16), [12, 512, 513], 1, 0, 255, 1), {})
+cnt: 48, ((T([12, 3, 512, 513], f16), [12, 3, 512, 513], 0, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([12, 3, 256, 256], f16), [12, 3, 256, 513], 3, 257, 9223372036854775807, 1), {})
+cnt: 12, ((T([12, 3, 256, 513], f16), [12, 3, 512, 513], 2, -257, -1, 1), {})
+cnt: 24, ((T([12, 3, 512, 513], f16), [12, 3, 512, 513], 1, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([12, 256, 257], f16), [12, 256, 513], 2, 0, 257, 1), {})
+cnt: 12, ((T([12, 256, 513], f16), [12, 512, 513], 1, 256, 9223372036854775807, 1), {})
+cnt: 12, ((T([12, 3, 256, 257], f16), [12, 3, 256, 513], 3, 0, 257, 1), {})
+cnt: 12, ((T([12, 3, 256, 513], f16), [12, 3, 512, 513], 2, 0, 256, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([1024, 50265], f16), [0], True), {})
+cnt: 61, ((T([1024, 768], f16), [0], True), {})
+cnt: 12, ((T([1024, 3072], f16), [0], True), {})
+Operator: aten.tril.default
+cnt: 24, ((T([256, 257], f16),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BartForCausalLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BartForCausalLM_training.txt
new file mode 100644
index 0000000000000..25d8b0b7a02ac
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BartForCausalLM_training.txt
@@ -0,0 +1,73 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([4096, 50265], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([4096, 50265], f16), T([4096, 50265], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([64, 1024, 1024], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([64, 1024, 1024], f16), T([64, 1024, 1024], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([1024, 1024], f32),), {'dtype': f16})
+cnt: 1, ((T([4, 1, 1024, 1024], f16, stride=(0, 1048576, 1024, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([4, 1024, 16, 64], f16), [4, 1024, 1024]), {})
+cnt: 1, ((T([4096, 50265], f16), [4, 1024, 50265]), {})
+cnt: 12, ((T([4, 16, 1024, 64], f16), [64, 1024, 64]), {})
+cnt: 12, ((T([4, 1024, 1024], f16), [4096, 1024]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([1024], i64), 1), {})
+cnt: 1, ((T([4, 1024], i64, stride=(0, 1)), 2), {})
+cnt: 73, ((T([4, 1024, 1024], f16), T([4, 1024, 1024], f16)), {})
+cnt: 12, ((T([4, 16, 1024, 1024], f16), T([4, 1, 1024, 1024], f16)), {})
+cnt: 1, ((T([50265, 1024], f16), T([50265, 1024], f16)), {})
+Operator: aten.addmm.default
+cnt: 48, ((T([1024], f16), T([4096, 1024], f16), T([1024, 1024], f16, stride=(1, 1024))), {})
+cnt: 12, ((T([4096], f16), T([4096, 1024], f16), T([1024, 4096], f16, stride=(1, 1024))), {})
+cnt: 12, ((T([1024], f16), T([4096, 4096], f16), T([4096, 1024], f16, stride=(1, 4096))), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([64, 1024, 64], f16), T([64, 64, 1024], f16, stride=(65536, 1, 64))), {})
+cnt: 24, ((T([64, 1024, 1024], f16), T([64, 1024, 64], f16)), {})
+cnt: 12, ((T([64, 1024, 1024], f16, stride=(1048576, 1, 1024)), T([64, 1024, 64], f16)), {})
+cnt: 12, ((T([64, 64, 1024], f16, stride=(65536, 1, 64)), T([64, 1024, 1024], f16)), {})
+Operator: aten.clone.default
+cnt: 2, ((T([4, 1024], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([4, 1024], i64), T([4, 1024], i64)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50265, 1024], f16), T([4, 1024], i64), 1), {})
+cnt: 1, ((T([1026, 1024], f16), T([4, 1024], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([4, 1024, 1024], f16), T([4, 1024], i64), 1026, -1, False), {})
+cnt: 1, ((T([4, 1024, 1024], f16), T([4, 1024], i64), 50265, 1, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([4, 1024, 4096], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([4, 1024, 4096], f16), T([4, 1024, 4096], f16)), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([1024], i64), T([1024, 1], i64)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([1024, 1024], f32), T([1024, 1024], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([4096, 1024], f16), T([1024, 50265], f16, stride=(1, 1024))), {})
+cnt: 1, ((T([50265, 4096], f16, stride=(1, 50265)), T([4096, 1024], f16)), {})
+cnt: 1, ((T([4096, 50265], f16), T([50265, 1024], f16)), {})
+cnt: 12, ((T([4096, 1024], f16), T([1024, 4096], f16)), {})
+cnt: 12, ((T([1024, 4096], f16, stride=(1, 1024)), T([4096, 4096], f16)), {})
+cnt: 12, ((T([4096, 4096], f16), T([4096, 1024], f16)), {})
+cnt: 12, ((T([4096, 4096], f16, stride=(1, 4096)), T([4096, 1024], f16)), {})
+cnt: 48, ((T([4096, 1024], f16), T([1024, 1024], f16)), {})
+cnt: 48, ((T([1024, 4096], f16, stride=(1, 1024)), T([4096, 1024], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([4, 1024, 1024], f16), 1.0), {})
+cnt: 24, ((T([4, 1024, 1024], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 25, ((T([4, 1024, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 25, ((T([4, 1024, 1024], f16), T([4, 1024, 1024], f16), [1024], T([4, 1024, 1], f32), T([4, 1024, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([4096, 50265], f16), T([4096], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([4096, 50265], f16), T([4096], i64), None, 1, -100), {})
+Operator: aten.sum.SymInt
+cnt: 60, ((T([4096, 1024], f16), [0], True), {})
+cnt: 12, ((T([4096, 4096], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BartForConditionalGeneration_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BartForConditionalGeneration_training.txt
new file mode 100644
index 0000000000000..0e388c6062e74
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BartForConditionalGeneration_training.txt
@@ -0,0 +1,89 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([2048, 50265], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([2048, 50265], f16), T([2048, 50265], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 36, ((T([32, 1024, 1024], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 36, ((T([32, 1024, 1024], f16), T([32, 1024, 1024], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([1024, 1024], f32),), {'dtype': f16})
+cnt: 1, ((T([2, 1, 1024, 1024], f16, stride=(0, 1048576, 1024, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 108, ((T([2, 1024, 16, 64], f16), [2, 1024, 1024]), {})
+cnt: 1, ((T([2048, 50265], f16), [2, 1024, 50265]), {})
+cnt: 36, ((T([2, 16, 1024, 64], f16), [32, 1024, 64]), {})
+cnt: 36, ((T([2, 1024, 1024], f16), [2048, 1024]), {})
+Operator: aten.add.Tensor
+cnt: 2, ((T([2, 1024], i64, stride=(0, 1)), 2), {})
+cnt: 193, ((T([2, 1024, 1024], f16), T([2, 1024, 1024], f16)), {})
+cnt: 1, ((T([1024], i64), 1), {})
+cnt: 12, ((T([2, 16, 1024, 1024], f16), T([2, 1, 1024, 1024], f16)), {})
+cnt: 1, ((T([2, 1024, 50265], f16), T([1, 50265], f16)), {})
+cnt: 2, ((T([50265, 1024], f16), T([50265, 1024], f16)), {})
+Operator: aten.addmm.default
+cnt: 144, ((T([1024], f16), T([2048, 1024], f16), T([1024, 1024], f16, stride=(1, 1024))), {})
+cnt: 24, ((T([4096], f16), T([2048, 1024], f16), T([1024, 4096], f16, stride=(1, 1024))), {})
+cnt: 24, ((T([1024], f16), T([2048, 4096], f16), T([4096, 1024], f16, stride=(1, 4096))), {})
+Operator: aten.any.default
+cnt: 24, ((T([2, 1024, 1024], b8),), {})
+Operator: aten.bmm.default
+cnt: 72, ((T([32, 1024, 64], f16), T([32, 64, 1024], f16, stride=(65536, 1, 64))), {})
+cnt: 72, ((T([32, 1024, 1024], f16), T([32, 1024, 64], f16)), {})
+cnt: 36, ((T([32, 1024, 1024], f16, stride=(1048576, 1, 1024)), T([32, 1024, 64], f16)), {})
+cnt: 36, ((T([32, 64, 1024], f16, stride=(65536, 1, 64)), T([32, 1024, 1024], f16)), {})
+Operator: aten.clone.default
+cnt: 2, ((T([2, 1024], i64),), {})
+cnt: 1, ((T([2, 1023], i64, stride=(1024, 1)),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([2, 1024], i64), T([2, 1024], i64)), {})
+cnt: 1, ((T([2, 1023], i64, stride=(1024, 1)), T([2, 1023], i64)), {})
+Operator: aten.embedding.default
+cnt: 2, ((T([50265, 1024], f16), T([2, 1024], i64), 1), {})
+cnt: 2, ((T([1026, 1024], f16), T([2, 1024], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 2, ((T([2, 1024, 1024], f16), T([2, 1024], i64), 1026, -1, False), {})
+cnt: 2, ((T([2, 1024, 1024], f16), T([2, 1024], i64), 50265, 1, False), {})
+Operator: aten.eq.Scalar
+cnt: 1, ((T([2, 1024], i64), -100), {})
+Operator: aten.fill_.Tensor
+cnt: 1, ((T([2], i64, stride=(1024,)), T([], i64)), {})
+Operator: aten.gelu.default
+cnt: 24, ((T([2, 1024, 4096], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 24, ((T([2, 1024, 4096], f16), T([2, 1024, 4096], f16)), {})
+Operator: aten.isinf.default
+cnt: 12, ((T([2, 1024, 1024], f16),), {})
+Operator: aten.isnan.default
+cnt: 12, ((T([2, 1024, 1024], f16),), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([1024], i64), T([1024, 1], i64)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([2, 1024], i64), T([2, 1024], b8), 1), {})
+cnt: 1, ((T([1024, 1024], f32), T([1024, 1024], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([2048, 1024], f16), T([1024, 50265], f16, stride=(1, 1024))), {})
+cnt: 1, ((T([50265, 2048], f16, stride=(1, 50265)), T([2048, 1024], f16)), {})
+cnt: 1, ((T([2048, 50265], f16), T([50265, 1024], f16)), {})
+cnt: 24, ((T([2048, 1024], f16), T([1024, 4096], f16)), {})
+cnt: 24, ((T([1024, 2048], f16, stride=(1, 1024)), T([2048, 4096], f16)), {})
+cnt: 24, ((T([2048, 4096], f16), T([4096, 1024], f16)), {})
+cnt: 24, ((T([4096, 2048], f16, stride=(1, 4096)), T([2048, 1024], f16)), {})
+cnt: 144, ((T([2048, 1024], f16), T([1024, 1024], f16)), {})
+cnt: 144, ((T([1024, 2048], f16, stride=(1, 1024)), T([2048, 1024], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 4, ((T([2, 1024, 1024], f16), 1.0), {})
+cnt: 72, ((T([2, 1024, 1024], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 62, ((T([2, 1024, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 62, ((T([2, 1024, 1024], f16), T([2, 1024, 1024], f16), [1024], T([2, 1024, 1], f32), T([2, 1024, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+Operator: aten.new_zeros.default
+cnt: 1, ((T([2, 1024], i64), [2, 1024]), {'dtype': i64, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([2048, 50265], f16), T([2048], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([2048, 50265], f16), T([2048], i64), None, 1, -100), {})
+Operator: aten.sum.SymInt
+cnt: 168, ((T([2048, 1024], f16), [0], True), {})
+cnt: 24, ((T([2048, 4096], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BertForMaskedLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BertForMaskedLM_training.txt
new file mode 100644
index 0000000000000..5cd41366b65e7
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BertForMaskedLM_training.txt
@@ -0,0 +1,81 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([8192, 30522], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([8192, 30522], f16), T([8192, 30522], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([64, 12, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([64, 12, 128, 128], f16), T([64, 12, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([64, 1, 1, 128], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([64, 12, 128, 64], f16), [768, 128, 64]), {})
+cnt: 12, ((T([64, 12, 64, 128], f16), [768, 64, 128]), {})
+cnt: 12, ((T([768, 128, 128], f16), [64, 12, 128, 128]), {})
+cnt: 12, ((T([768, 128, 64], f16), [64, 12, 128, 64]), {})
+cnt: 24, ((T([64, 128, 12, 64], f16), [64, 128, 768]), {})
+cnt: 12, ((T([64, 128, 768], f16), [8192, 768]), {})
+Operator: aten.add.Tensor
+cnt: 73, ((T([64, 128, 768], f16), T([64, 128, 768], f16)), {})
+cnt: 12, ((T([64, 12, 128, 128], f16), T([64, 1, 1, 128], f16)), {})
+cnt: 1, ((T([30522, 768], f16), T([30522, 768], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([64, 128, 768], f16), T([1, 128, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 49, ((T([768], f16), T([8192, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([8192, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([8192, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([30522], f16), T([8192, 768], f16), T([768, 30522], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([768, 128, 64], f16), T([768, 64, 128], f16)), {})
+cnt: 12, ((T([768, 128, 128], f16), T([768, 128, 64], f16)), {})
+cnt: 12, ((T([768, 128, 128], f16, stride=(16384, 1, 128)), T([768, 128, 64], f16)), {})
+cnt: 12, ((T([768, 128, 64], f16), T([768, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 12, ((T([768, 64, 128], f16, stride=(8192, 1, 64)), T([768, 128, 128], f16)), {})
+cnt: 12, ((T([768, 128, 128], f16), T([768, 128, 64], f16, stride=(8192, 1, 128))), {})
+Operator: aten.clone.default
+cnt: 2, ((T([64, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([64, 128], i64), T([64, 128], i64)), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([64, 12, 128, 128], f16), 8.0), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30522, 768], f16), T([64, 128], i64), 0), {})
+cnt: 1, ((T([2, 768], f16), T([64, 128], i64, stride=(0, 1))), {})
+cnt: 1, ((T([512, 768], f16), T([1, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 128, 768], f16), T([1, 128], i64), 512, -1, False), {})
+cnt: 1, ((T([64, 128, 768], f16), T([64, 128], i64, stride=(0, 1)), 2, -1, False), {})
+cnt: 1, ((T([64, 128, 768], f16), T([64, 128], i64), 30522, 0, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([64, 128, 3072], f16),), {})
+cnt: 1, ((T([64, 128, 768], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([64, 128, 768], f16), T([64, 128, 768], f16)), {})
+cnt: 12, ((T([64, 128, 3072], f16), T([64, 128, 3072], f16)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([8192, 30522], f16), T([30522, 768], f16)), {})
+cnt: 1, ((T([30522, 8192], f16, stride=(1, 30522)), T([8192, 768], f16)), {})
+cnt: 49, ((T([8192, 768], f16), T([768, 768], f16)), {})
+cnt: 49, ((T([768, 8192], f16, stride=(1, 768)), T([8192, 768], f16)), {})
+cnt: 12, ((T([8192, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 8192], f16, stride=(1, 768)), T([8192, 3072], f16)), {})
+cnt: 12, ((T([8192, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 8192], f16, stride=(1, 3072)), T([8192, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([64, 1, 1, 128], f16), -65504.0), {})
+Operator: aten.native_layer_norm.default
+cnt: 26, ((T([64, 128, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 26, ((T([64, 128, 768], f16), T([64, 128, 768], f16), [768], T([64, 128, 1], f32), T([64, 128, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([8192, 30522], f16), T([8192], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([8192, 30522], f16), T([8192], i64), None, 1, -100), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([64, 1, 1, 128], f16), 1.0), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([8192, 30522], f16), [0], True), {})
+cnt: 61, ((T([8192, 768], f16), [0], True), {})
+cnt: 12, ((T([8192, 3072], f16), [0], True), {})
+cnt: 1, ((T([64, 128, 768], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BertForQuestionAnswering_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BertForQuestionAnswering_training.txt
new file mode 100644
index 0000000000000..463fb6ada1578
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BertForQuestionAnswering_training.txt
@@ -0,0 +1,88 @@
+Operator: aten._log_softmax.default
+cnt: 2, ((T([64, 128], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 2, ((T([64, 128], f16), T([64, 128], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([64, 12, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([64, 12, 128, 128], f16), T([64, 12, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([64, 1, 1, 128], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([64, 12, 128, 64], f16), [768, 128, 64]), {})
+cnt: 12, ((T([64, 12, 64, 128], f16), [768, 64, 128]), {})
+cnt: 12, ((T([768, 128, 128], f16), [64, 12, 128, 128]), {})
+cnt: 12, ((T([768, 128, 64], f16), [64, 12, 128, 64]), {})
+cnt: 24, ((T([64, 128, 12, 64], f16), [64, 128, 768]), {})
+cnt: 12, ((T([64, 128, 768], f16), [8192, 768]), {})
+Operator: aten.add.Tensor
+cnt: 73, ((T([64, 128, 768], f16), T([64, 128, 768], f16)), {})
+cnt: 12, ((T([64, 12, 128, 128], f16), T([64, 1, 1, 128], f16)), {})
+cnt: 1, ((T([], f16), T([], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([64, 128, 768], f16), T([1, 128, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 48, ((T([768], f16), T([8192, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([8192, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([8192, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([2], f16), T([8192, 768], f16), T([768, 2], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([768, 128, 64], f16), T([768, 64, 128], f16)), {})
+cnt: 12, ((T([768, 128, 128], f16), T([768, 128, 64], f16)), {})
+cnt: 12, ((T([768, 128, 128], f16, stride=(16384, 1, 128)), T([768, 128, 64], f16)), {})
+cnt: 12, ((T([768, 128, 64], f16), T([768, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 12, ((T([768, 64, 128], f16, stride=(8192, 1, 64)), T([768, 128, 128], f16)), {})
+cnt: 12, ((T([768, 128, 128], f16), T([768, 128, 64], f16, stride=(8192, 1, 128))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 128, 1], f16), T([64, 128, 1], f16)], 2), {})
+Operator: aten.clamp.default
+cnt: 2, ((T([64], i64), 0, 128), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 128], i64),), {})
+cnt: 2, ((T([64], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 128], i64), T([64, 128], i64)), {})
+cnt: 2, ((T([64], i64), T([64], i64)), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([64, 12, 128, 128], f16), 8.0), {})
+cnt: 2, ((T([], f16), 2), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30522, 768], f16), T([64, 128], i64), 0), {})
+cnt: 1, ((T([2, 768], f16), T([64, 128], i64, stride=(0, 1))), {})
+cnt: 1, ((T([512, 768], f16), T([1, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 128, 768], f16), T([1, 128], i64), 512, -1, False), {})
+cnt: 1, ((T([64, 128, 768], f16), T([64, 128], i64, stride=(0, 1)), 2, -1, False), {})
+cnt: 1, ((T([64, 128, 768], f16), T([64, 128], i64), 30522, 0, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([64, 128, 3072], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([64, 128, 3072], f16), T([64, 128, 3072], f16)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([8192, 2], f16), T([2, 768], f16)), {})
+cnt: 1, ((T([2, 8192], f16, stride=(1, 2)), T([8192, 768], f16)), {})
+cnt: 12, ((T([8192, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 8192], f16, stride=(1, 768)), T([8192, 3072], f16)), {})
+cnt: 12, ((T([8192, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 8192], f16, stride=(1, 3072)), T([8192, 768], f16)), {})
+cnt: 48, ((T([8192, 768], f16), T([768, 768], f16)), {})
+cnt: 48, ((T([768, 8192], f16, stride=(1, 768)), T([8192, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([64, 1, 1, 128], f16), -65504.0), {})
+Operator: aten.native_layer_norm.default
+cnt: 25, ((T([64, 128, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 25, ((T([64, 128, 768], f16), T([64, 128, 768], f16), [768], T([64, 128, 1], f32), T([64, 128, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 2, ((T([], f16), T([64, 128], f16), T([64], i64), None, 1, 128, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 2, ((T([64, 128], f16), T([64], i64), None, 1, 128), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([64, 1, 1, 128], f16), 1.0), {})
+Operator: aten.split.Tensor
+cnt: 1, ((T([64, 128, 2], f16), 1, -1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([8192, 2], f16), [0], True), {})
+cnt: 60, ((T([8192, 768], f16), [0], True), {})
+cnt: 12, ((T([8192, 3072], f16), [0], True), {})
+cnt: 1, ((T([64, 128, 768], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BigBird_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BigBird_training.txt
new file mode 100644
index 0000000000000..7bc500b33d95d
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BigBird_training.txt
@@ -0,0 +1,237 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([1024, 50358], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([1024, 50358], f16), T([1024, 50358], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 24, ((T([1, 12, 64, 1024], f16), -1, False), {})
+cnt: 24, ((T([1, 12, 64, 448], f16), -1, False), {})
+cnt: 12, ((T([1, 12, 12, 64, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 24, ((T([1, 12, 64, 1024], f16), T([1, 12, 64, 1024], f16), -1, f16), {})
+cnt: 24, ((T([1, 12, 64, 448], f16), T([1, 12, 64, 448], f16), -1, f16), {})
+cnt: 12, ((T([1, 12, 12, 64, 512], f16), T([1, 12, 12, 64, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 12, ((T([1, 1, 12, 64, 192], f32),), {'dtype': f16})
+cnt: 12, ((T([1, 1, 1024, 1], f32),), {'dtype': f16})
+cnt: 12, ((T([1, 1, 1, 1024], f32),), {'dtype': f16})
+cnt: 12, ((T([12, 14, 3], i32),), {'dtype': i64, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 24, ((T([1, 12, 16, 64, 64], f16), [192, 64, 64]), {})
+cnt: 24, ((T([1, 12, 12, 64, 64], f16), [144, 64, 64]), {})
+cnt: 24, ((T([1, 12, 12, 192, 64], f16), [144, 192, 64]), {})
+cnt: 24, ((T([1, 1024, 12, 64], f16), [1, 1024, 768]), {})
+Operator: aten.add.Tensor
+cnt: 76, ((T([1, 1024, 768], f16), T([1, 1024, 768], f16)), {})
+cnt: 24, ((T([504], i64), T([504], i64)), {})
+cnt: 36, ((T([1, 1024, 3072], f16), T([1, 1024, 3072], f16)), {})
+cnt: 12, ((T([1, 1024, 3072], f16), 1.0), {})
+cnt: 1, ((T([1, 1024, 768], f16), 1.0), {})
+cnt: 360, ((T([1, 12, 16, 64, 64], f16), T([1, 12, 16, 64, 64], f16)), {})
+cnt: 36, ((T([1, 12, 12, 64, 512], f16), T([1, 12, 12, 64, 512], f16)), {})
+cnt: 48, ((T([1, 12, 14, 192, 64], f16), T([1, 12, 14, 192, 64], f16)), {})
+cnt: 36, ((T([1, 12, 12, 64, 64], f16), T([1, 12, 12, 64, 64], f16)), {})
+cnt: 24, ((T([1, 12, 1024, 64], f16), T([1, 12, 1024, 64], f16)), {})
+cnt: 12, ((T([1, 12, 1024, 64], f16, stride=(786432, 65536, 1, 1024)), T([1, 12, 1024, 64], f16, stride=(786432, 65536, 1, 1024))), {})
+cnt: 12, ((T([1, 12, 1024, 64], f16, stride=(786432, 65536, 1, 1024)), T([1, 12, 1024, 64], f16)), {})
+cnt: 1, ((T([50358, 768], f16), T([50358, 768], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([1, 1024, 768], f16), T([1, 1024, 768], f16)), {})
+cnt: 24, ((T([1, 12, 64, 1024], f16), T([1, 1, 1, 1024], f16)), {})
+cnt: 24, ((T([1, 12, 64, 448], f16), T([1, 12, 64, 448], f32)), {})
+cnt: 12, ((T([1, 12, 12, 64, 192], f16), T([1, 1, 12, 64, 192], f16)), {})
+cnt: 24, ((T([1, 12, 12, 64, 64], f16), T([1, 1, 1, 1, 64], f16)), {})
+cnt: 12, ((T([1, 12, 12, 64, 192], f16), T([1, 12, 12, 64, 192], f32)), {})
+cnt: 36, ((T([1, 12, 12, 64, 64], f16), T([1, 12, 12, 64, 64], f16)), {})
+Operator: aten.addmm.default
+cnt: 49, ((T([768], f16), T([1024, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([1024, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([1024, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([768], f16), T([1, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 1, ((T([50358], f16), T([1024, 768], f16), T([768, 50358], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 48, ((T([12, 64, 64], f16, stride=(64, 768, 1)), T([12, 64, 1024], f16, stride=(64, 1, 768))), {})
+cnt: 48, ((T([12, 64, 1024], f16), T([12, 1024, 64], f16, stride=(64, 768, 1))), {})
+cnt: 48, ((T([12, 64, 64], f16, stride=(64, 768, 1)), T([12, 64, 448], f16, stride=(28672, 1, 64))), {})
+cnt: 48, ((T([12, 64, 448], f16), T([12, 448, 64], f16)), {})
+cnt: 48, ((T([144, 64, 64], f16), T([144, 64, 192], f16, stride=(12288, 1, 64))), {})
+cnt: 24, ((T([12, 768, 64], f16, stride=(64, 768, 1)), T([12, 64, 64], f16, stride=(64, 1, 768))), {})
+cnt: 24, ((T([144, 64, 192], f16, stride=(32768, 512, 1)), T([144, 192, 64], f16)), {})
+cnt: 24, ((T([12, 768, 64], f16, stride=(393216, 512, 1)), T([12, 64, 64], f16, stride=(64, 768, 1))), {})
+cnt: 24, ((T([12, 1024, 64], f16, stride=(65536, 1, 1024)), T([12, 64, 64], f16, stride=(64, 768, 1))), {})
+cnt: 24, ((T([12, 64, 64], f16, stride=(64, 1, 768)), T([12, 64, 1024], f16)), {})
+cnt: 24, ((T([12, 448, 64], f16, stride=(28672, 1, 448)), T([12, 64, 64], f16, stride=(64, 768, 1))), {})
+cnt: 24, ((T([12, 64, 64], f16, stride=(64, 1, 768)), T([12, 64, 448], f16)), {})
+cnt: 24, ((T([12, 64, 768], f16, stride=(393216, 1, 512)), T([12, 768, 64], f16)), {})
+cnt: 24, ((T([12, 768, 64], f16), T([12, 64, 64], f16, stride=(64, 1, 768))), {})
+cnt: 24, ((T([144, 192, 64], f16, stride=(32768, 1, 512)), T([144, 64, 64], f16)), {})
+cnt: 24, ((T([12, 64, 768], f16, stride=(64, 1, 768)), T([12, 768, 64], f16)), {})
+cnt: 24, ((T([12, 768, 64], f16), T([12, 64, 64], f16, stride=(64, 768, 1))), {})
+cnt: 24, ((T([144, 64, 64], f16, stride=(4096, 1, 64)), T([144, 64, 192], f16)), {})
+cnt: 24, ((T([144, 64, 192], f16), T([144, 192, 64], f16)), {})
+Operator: aten.cat.default
+cnt: 1, (([T([1, 12, 64], f32), T([1, 12, 64], f32), T([1, 12, 64], f32)], 2), {})
+cnt: 12, (([T([1, 12, 14, 3], i64)],), {})
+cnt: 48, (([T([1, 12, 64, 64], f16, stride=(768, 64, 768, 1)), T([1, 12, 64, 64], f16, stride=(768, 64, 768, 1)), T([1, 12, 64, 64], f16, stride=(768, 64, 768, 1)), T([1, 12, 64, 64], f16, stride=(768, 64, 768, 1)), T([1, 12, 192, 64], f16, stride=(2064384, 172032, 64, 1))], 2), {})
+cnt: 12, (([T([1, 1, 1, 192], f16), T([1, 1, 1, 64], f16), T([1, 1, 1, 192], f16)], 3), {})
+cnt: 24, (([T([1, 12, 64, 256], f32), T([1, 12, 64, 192], f32, stride=(2064384, 172032, 192, 1))], 3), {})
+cnt: 24, (([T([1, 12, 12, 64, 64], f16, stride=(768, 64, 49152, 768, 1)), T([1, 12, 12, 64, 64], f16, stride=(768, 64, 49152, 768, 1)), T([1, 12, 12, 64, 64], f16, stride=(768, 64, 49152, 768, 1))], 3), {})
+cnt: 12, (([T([1, 12, 12, 64, 64], f16), T([1, 12, 12, 64, 192], f16), T([1, 12, 12, 64, 192], f16), T([1, 12, 12, 64, 64], f16)], -1), {})
+cnt: 12, (([T([1, 1, 1, 64], f16), T([1, 1, 1, 192], f16), T([1, 1, 1, 192], f16)], 3), {})
+cnt: 12, (([T([1, 12, 1, 64, 64], f16), T([1, 12, 1, 64, 64], f16), T([1, 12, 12, 64, 64], f16), T([1, 12, 1, 64, 64], f16), T([1, 12, 1, 64, 64], f16)], 2), {})
+Operator: aten.clone.default
+cnt: 2, ((T([1, 1024], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([1, 1024], i64), T([1, 1024], i64)), {})
+cnt: 12, ((T([12, 12, 64, 64], f16), T([12, 12, 64, 64], f16, stride=(64, 49152, 768, 1))), {})
+cnt: 36, ((T([144, 64, 64], f16), T([144, 64, 64], f16)), {})
+cnt: 36, ((T([1, 12, 12, 64, 64], f16), T([1, 12, 12, 64, 64], f16)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50358, 768], f16), T([1, 1024], i64), 0), {})
+cnt: 1, ((T([2, 768], f16), T([1, 1024], i64)), {})
+cnt: 1, ((T([4096, 768], f16), T([1, 1024], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 1024, 768], f16), T([1, 1024], i64), 4096, -1, False), {})
+cnt: 1, ((T([1, 1024, 768], f16), T([1, 1024], i64), 2, -1, False), {})
+cnt: 1, ((T([1, 1024, 768], f16), T([1, 1024], i64), 50358, 0, False), {})
+Operator: aten.floor_divide.default
+cnt: 24, ((T([504], i64), 42), {})
+Operator: aten.index.Tensor
+cnt: 12, ((T([16, 64], f32), [T([504], i64)]), {})
+Operator: aten.index_add.default
+cnt: 24, ((T([192, 64, 64], f16), 0, T([504], i64), T([504, 64, 64], f16)), {})
+Operator: aten.index_select.default
+cnt: 24, ((T([192, 64, 64], f16), 0, T([504], i64)), {})
+Operator: aten.minimum.default
+cnt: 24, ((T([1, 1, 1, 448], f16), T([1, 12, 64, 448], f32)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([1024, 50358], f16), T([50358, 768], f16)), {})
+cnt: 1, ((T([50358, 1024], f16, stride=(1, 50358)), T([1024, 768], f16)), {})
+cnt: 37, ((T([1024, 768], f16), T([768, 768], f16)), {})
+cnt: 37, ((T([768, 1024], f16, stride=(1, 768)), T([1024, 768], f16)), {})
+cnt: 12, ((T([1024, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 1024], f16, stride=(1, 768)), T([1024, 3072], f16)), {})
+cnt: 12, ((T([1024, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 1024], f16, stride=(1, 3072)), T([1024, 768], f16)), {})
+cnt: 12, ((T([1024, 768], f16, stride=(1, 1024)), T([768, 768], f16)), {})
+cnt: 12, ((T([768, 1024], f16), T([1024, 768], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 1, ((T([1, 1024, 768], f16), 3.0), {})
+cnt: 12, ((T([1, 1024, 3072], f16), 3.0), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([1, 12, 64, 1], f32), T([1, 12, 1, 192], f32)), {})
+cnt: 12, ((T([1, 1, 14, 64, 1], f32), T([1, 12, 14, 1, 192], f32)), {})
+cnt: 24, ((T([504], i64), 16), {})
+cnt: 48, ((T([1, 12, 64, 1024], f16), 0.125), {})
+cnt: 24, ((T([1, 1, 1, 1024], f16), -10000.0), {})
+cnt: 48, ((T([1, 12, 64, 448], f16), 0.125), {})
+cnt: 24, ((T([1, 12, 64, 448], f32), -10000.0), {})
+cnt: 24, ((T([1, 12, 12, 64, 192], f16), 0.125), {})
+cnt: 24, ((T([1, 12, 12, 64, 64], f16), 0.125), {})
+cnt: 12, ((T([1, 1, 12, 64, 192], f16), -10000.0), {})
+cnt: 24, ((T([1, 1, 1, 1, 64], f16), -10000.0), {})
+cnt: 12, ((T([1, 12, 12, 64, 192], f32), -10000.0), {})
+cnt: 12, ((T([1, 12, 1024, 64], f16), T([1, 1, 1024, 1], f16)), {})
+cnt: 24, ((T([1, 1024, 3072], f16), 0.5), {})
+cnt: 24, ((T([1, 1024, 3072], f16), 0.044715), {})
+cnt: 24, ((T([1, 1024, 3072], f16), 0.7978845608028654), {})
+cnt: 48, ((T([1, 1024, 3072], f16), T([1, 1024, 3072], f16)), {})
+cnt: 2, ((T([1, 1024, 768], f16), 0.5), {})
+cnt: 2, ((T([1, 1024, 768], f16), 0.044715), {})
+cnt: 2, ((T([1, 1024, 768], f16), 0.7978845608028654), {})
+cnt: 4, ((T([1, 1024, 768], f16), T([1, 1024, 768], f16)), {})
+cnt: 12, ((T([1, 12, 1024, 64], f16, stride=(786432, 64, 768, 1)), T([1, 1, 1024, 1], f16)), {})
+cnt: 24, ((T([1, 12, 12, 64, 64], f16, stride=(4718592, 393216, 32768, 512, 1)), 0.125), {})
+cnt: 24, ((T([1, 12, 12, 64, 192], f16, stride=(4718592, 393216, 32768, 512, 1)), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 26, ((T([1, 1024, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 26, ((T([1, 1024, 768], f16), T([1, 1024, 768], f16), [768], T([1, 1024, 1], f32), T([1, 1024, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.new_empty_strided.default
+cnt: 36, ((T([144, 64, 64], f16), [144, 64, 64], [4096, 64, 1]), {})
+Operator: aten.new_ones.default
+cnt: 24, ((T([1, 1, 1, 1024], f16), [1, 1, 1, 192]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 24, ((T([1, 12, 14, 64, 192], f32), [1, 12, 64, 256]), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+Operator: aten.new_zeros.default
+cnt: 12, ((T([12, 12, 64, 64], f16, stride=(64, 49152, 768, 1)), [589824]), {})
+cnt: 24, ((T([504, 64, 64], f16), [192, 64, 64]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([1024, 50358], f16), T([1024], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([1024, 50358], f16), T([1024], i64), None, 1, -100), {})
+Operator: aten.pow.Tensor_Scalar
+cnt: 12, ((T([1, 1024, 3072], f16), 3.0), {})
+cnt: 1, ((T([1, 1024, 768], f16), 3.0), {})
+cnt: 1, ((T([1, 1024, 768], f16), 2.0), {})
+cnt: 12, ((T([1, 1024, 3072], f16), 2.0), {})
+Operator: aten.rsub.Scalar
+cnt: 24, ((T([1, 1, 1, 1024], f16), 1.0), {})
+cnt: 24, ((T([1, 12, 64, 448], f32), 1.0), {})
+cnt: 12, ((T([1, 1, 12, 64, 192], f16), 1.0), {})
+cnt: 24, ((T([1, 1, 1, 1, 64], f16), 1.0), {})
+cnt: 12, ((T([1, 12, 12, 64, 192], f32, stride=(2064384, 172032, 12288, 192, 1)), 1.0), {})
+Operator: aten.select_backward.default
+cnt: 24, ((T([1, 12, 64, 64], f16), [1, 12, 16, 64, 64], 2, -1), {})
+cnt: 12, ((T([1, 12, 64, 64], f16), [1, 12, 16, 64, 64], 2, -2), {})
+cnt: 12, ((T([1, 12, 192, 64], f16, stride=(344064, 28672, 64, 1)), [1, 12, 14, 192, 64], 2, -1), {})
+cnt: 24, ((T([1, 12, 64, 64], f16, stride=(344064, 28672, 64, 1)), [1, 12, 16, 64, 64], 2, -1), {})
+cnt: 12, ((T([1, 12, 64, 64], f16, stride=(344064, 28672, 64, 1)), [1, 12, 16, 64, 64], 2, -2), {})
+cnt: 12, ((T([1, 12, 64, 64], f16, stride=(344064, 28672, 64, 1)), [1, 12, 16, 64, 64], 2, -3), {})
+cnt: 24, ((T([1, 12, 64, 64], f16, stride=(344064, 28672, 64, 1)), [1, 12, 16, 64, 64], 2, 0), {})
+cnt: 12, ((T([1, 12, 192, 64], f16, stride=(344064, 28672, 1, 448)), [1, 12, 14, 192, 64], 2, -1), {})
+cnt: 24, ((T([1, 12, 64, 64], f16, stride=(344064, 28672, 1, 448)), [1, 12, 16, 64, 64], 2, -1), {})
+cnt: 12, ((T([1, 12, 64, 64], f16, stride=(344064, 28672, 1, 448)), [1, 12, 16, 64, 64], 2, -2), {})
+cnt: 12, ((T([1, 12, 64, 64], f16, stride=(344064, 28672, 1, 448)), [1, 12, 16, 64, 64], 2, -3), {})
+cnt: 24, ((T([1, 12, 64, 64], f16, stride=(344064, 28672, 1, 448)), [1, 12, 16, 64, 64], 2, 0), {})
+cnt: 24, ((T([1, 12, 64, 64], f16), [1, 12, 16, 64, 64], 2, 0), {})
+cnt: 12, ((T([1, 12, 64, 64], f16, stride=(64, 4096, 1, 64)), [1, 12, 16, 64, 64], 2, -1), {})
+cnt: 12, ((T([1, 12, 64, 64], f16, stride=(64, 4096, 1, 64)), [1, 12, 16, 64, 64], 2, 0), {})
+cnt: 12, ((T([1, 12, 64, 64], f16), [1, 12, 16, 64, 64], 2, 1), {})
+cnt: 12, ((T([1, 12, 192, 64], f16, stride=(344064, 28672, 64, 1)), [1, 12, 14, 192, 64], 2, 0), {})
+cnt: 12, ((T([1, 12, 64, 64], f16, stride=(344064, 28672, 64, 1)), [1, 12, 16, 64, 64], 2, 2), {})
+cnt: 12, ((T([1, 12, 64, 64], f16, stride=(344064, 28672, 64, 1)), [1, 12, 16, 64, 64], 2, 1), {})
+cnt: 12, ((T([1, 12, 192, 64], f16, stride=(344064, 28672, 1, 448)), [1, 12, 14, 192, 64], 2, 0), {})
+cnt: 12, ((T([1, 12, 64, 64], f16, stride=(344064, 28672, 1, 448)), [1, 12, 16, 64, 64], 2, 2), {})
+cnt: 12, ((T([1, 12, 64, 64], f16, stride=(344064, 28672, 1, 448)), [1, 12, 16, 64, 64], 2, 1), {})
+Operator: aten.slice_backward.default
+cnt: 372, ((T([1, 12, 16, 64, 64], f16), [1, 12, 16, 64, 64], 1, 0, 9223372036854775807, 1), {})
+cnt: 372, ((T([1, 12, 16, 64, 64], f16), [1, 12, 16, 64, 64], 0, 0, 9223372036854775807, 1), {})
+cnt: 72, ((T([1, 12, 14, 192, 64], f16), [1, 12, 14, 192, 64], 1, 0, 9223372036854775807, 1), {})
+cnt: 72, ((T([1, 12, 14, 192, 64], f16), [1, 12, 14, 192, 64], 0, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([1, 12, 12, 64, 64], f16), [1, 12, 12, 64, 512], 4, -64, 9223372036854775807, 1), {})
+cnt: 48, ((T([1, 12, 12, 64, 512], f16), [1, 12, 12, 64, 512], 3, 0, 9223372036854775807, 1), {})
+cnt: 48, ((T([1, 12, 12, 64, 512], f16), [1, 12, 12, 64, 512], 2, 0, 9223372036854775807, 1), {})
+cnt: 48, ((T([1, 12, 12, 64, 512], f16), [1, 12, 12, 64, 512], 1, 0, 9223372036854775807, 1), {})
+cnt: 48, ((T([1, 12, 12, 64, 512], f16), [1, 12, 12, 64, 512], 0, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([1, 12, 12, 64, 64], f16), [1, 12, 12, 64, 512], 4, 0, 64, 1), {})
+cnt: 12, ((T([1, 12, 12, 192, 64], f16), [1, 12, 14, 192, 64], 2, 1, -1, 1), {})
+cnt: 12, ((T([1, 12, 12, 64, 192], f16), [1, 12, 12, 64, 512], 4, 256, -64, 1), {})
+cnt: 12, ((T([1, 12, 12, 64, 192], f16), [1, 12, 12, 64, 512], 4, 64, 256, 1), {})
+cnt: 12, ((T([1, 12, 12, 192, 64], f16, stride=(1769472, 147456, 12288, 1, 192)), [1, 12, 14, 192, 64], 2, 1, -1, 1), {})
+cnt: 12, ((T([1, 12, 12, 64, 64], f16), [1, 12, 16, 64, 64], 2, 2, -2, 1), {})
+cnt: 12, ((T([1, 12, 12, 64, 64], f16, stride=(1769472, 147456, 12288, 64, 1)), [1, 12, 16, 64, 64], 2, 3, -1, 1), {})
+cnt: 12, ((T([1, 12, 12, 64, 64], f16, stride=(1769472, 147456, 12288, 64, 1)), [1, 12, 16, 64, 64], 2, 2, -2, 1), {})
+cnt: 12, ((T([1, 12, 12, 64, 64], f16, stride=(1769472, 147456, 12288, 64, 1)), [1, 12, 16, 64, 64], 2, 1, -3, 1), {})
+cnt: 12, ((T([1, 12, 12, 64, 64], f16, stride=(1769472, 147456, 12288, 1, 192)), [1, 12, 16, 64, 64], 2, 3, -1, 1), {})
+cnt: 12, ((T([1, 12, 12, 64, 64], f16, stride=(1769472, 147456, 12288, 1, 192)), [1, 12, 16, 64, 64], 2, 2, -2, 1), {})
+cnt: 12, ((T([1, 12, 12, 64, 64], f16, stride=(1769472, 147456, 12288, 1, 192)), [1, 12, 16, 64, 64], 2, 1, -3, 1), {})
+Operator: aten.stack.default
+cnt: 12, (([T([504, 64], f32)],), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([1024, 50358], f16), [0], True), {})
+cnt: 49, ((T([1024, 768], f16), [0], True), {})
+cnt: 12, ((T([1024, 3072], f16), [0], True), {})
+cnt: 12, ((T([1024, 768], f16, stride=(1, 1024)), [0], True), {})
+Operator: aten.tanh.default
+cnt: 12, ((T([1, 1024, 3072], f16),), {})
+cnt: 1, ((T([1, 768], f16),), {})
+cnt: 1, ((T([1, 1024, 768], f16),), {})
+Operator: aten.tanh_backward.default
+cnt: 1, ((T([1, 1024, 768], f16), T([1, 1024, 768], f16)), {})
+cnt: 12, ((T([1, 1024, 3072], f16), T([1, 1024, 3072], f16)), {})
+Operator: aten.unbind.int
+cnt: 12, ((T([1, 16, 64], f32),), {})
+cnt: 12, ((T([1, 12, 14, 3], i64),), {})
+Operator: aten.unsqueeze_.default
+cnt: 1, ((T([1, 12, 64, 192], f32), 1), {})
+cnt: 12, ((T([12, 14, 3], i64), 0), {})
+cnt: 48, ((T([1, 12, 64, 64], f16), 2), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BlenderbotSmallForCausalLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BlenderbotSmallForCausalLM_training.txt
new file mode 100644
index 0000000000000..3bb0b46b03980
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BlenderbotSmallForCausalLM_training.txt
@@ -0,0 +1,74 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([8192, 50265], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([8192, 50265], f16), T([8192, 50265], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 8, ((T([1024, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 8, ((T([1024, 128, 128], f16), T([1024, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([128, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([64, 1, 128, 128], f16, stride=(0, 16384, 128, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 24, ((T([64, 128, 16, 32], f16), [64, 128, 512]), {})
+cnt: 1, ((T([8192, 50265], f16), [64, 128, 50265]), {})
+cnt: 8, ((T([64, 16, 128, 32], f16), [1024, 128, 32]), {})
+cnt: 8, ((T([64, 128, 512], f16), [8192, 512]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([128], i64), 1), {})
+cnt: 1, ((T([64, 128, 512], f16), T([128, 512], f16)), {})
+cnt: 8, ((T([64, 16, 128, 128], f16), T([64, 1, 128, 128], f16)), {})
+cnt: 48, ((T([64, 128, 512], f16), T([64, 128, 512], f16)), {})
+cnt: 1, ((T([50265, 512], f16), T([50265, 512], f16)), {})
+Operator: aten.addmm.default
+cnt: 32, ((T([512], f16), T([8192, 512], f16), T([512, 512], f16, stride=(1, 512))), {})
+cnt: 8, ((T([2048], f16), T([8192, 512], f16), T([512, 2048], f16, stride=(1, 512))), {})
+cnt: 8, ((T([512], f16), T([8192, 2048], f16), T([2048, 512], f16, stride=(1, 2048))), {})
+Operator: aten.bmm.default
+cnt: 16, ((T([1024, 128, 32], f16), T([1024, 32, 128], f16, stride=(4096, 1, 32))), {})
+cnt: 16, ((T([1024, 128, 128], f16), T([1024, 128, 32], f16)), {})
+cnt: 8, ((T([1024, 128, 128], f16, stride=(16384, 1, 128)), T([1024, 128, 32], f16)), {})
+cnt: 8, ((T([1024, 32, 128], f16, stride=(4096, 1, 32)), T([1024, 128, 128], f16)), {})
+Operator: aten.clone.default
+cnt: 2, ((T([64, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([64, 128], i64), T([64, 128], i64)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50265, 512], f16), T([64, 128], i64), 0), {})
+cnt: 1, ((T([512, 512], f16), T([128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([128, 512], f16), T([128], i64), 512, -1, False), {})
+cnt: 1, ((T([64, 128, 512], f16), T([64, 128], i64), 50265, 0, False), {})
+Operator: aten.gelu.default
+cnt: 8, ((T([64, 128, 2048], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 8, ((T([64, 128, 2048], f16), T([64, 128, 2048], f16)), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([128], i64), T([128, 1], i64)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([128, 128], f32), T([128, 128], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([8192, 512], f16), T([512, 50265], f16, stride=(1, 512))), {})
+cnt: 1, ((T([50265, 8192], f16, stride=(1, 50265)), T([8192, 512], f16)), {})
+cnt: 1, ((T([8192, 50265], f16), T([50265, 512], f16)), {})
+cnt: 8, ((T([8192, 512], f16), T([512, 2048], f16)), {})
+cnt: 8, ((T([512, 8192], f16, stride=(1, 512)), T([8192, 2048], f16)), {})
+cnt: 8, ((T([8192, 2048], f16), T([2048, 512], f16)), {})
+cnt: 8, ((T([2048, 8192], f16, stride=(1, 2048)), T([8192, 512], f16)), {})
+cnt: 32, ((T([8192, 512], f16), T([512, 512], f16)), {})
+cnt: 32, ((T([512, 8192], f16, stride=(1, 512)), T([8192, 512], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([64, 128, 512], f16), 1.0), {})
+cnt: 16, ((T([64, 128, 512], f16), 0.1767766952966369), {})
+Operator: aten.native_layer_norm.default
+cnt: 17, ((T([64, 128, 512], f16), [512], T([512], f16), T([512], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 17, ((T([64, 128, 512], f16), T([64, 128, 512], f16), [512], T([64, 128, 1], f32), T([64, 128, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([8192, 50265], f16), T([8192], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([8192, 50265], f16), T([8192], i64), None, 1, -100), {})
+Operator: aten.sum.SymInt
+cnt: 40, ((T([8192, 512], f16), [0], True), {})
+cnt: 8, ((T([8192, 2048], f16), [0], True), {})
+cnt: 1, ((T([64, 128, 512], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BlenderbotSmallForConditionalGeneration_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BlenderbotSmallForConditionalGeneration_training.txt
new file mode 100644
index 0000000000000..866fb90264184
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/BlenderbotSmallForConditionalGeneration_training.txt
@@ -0,0 +1,81 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([8192, 50265], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([8192, 50265], f16), T([8192, 50265], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 24, ((T([1024, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 24, ((T([1024, 128, 128], f16), T([1024, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([128, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([64, 1, 128, 128], f16, stride=(0, 16384, 128, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 72, ((T([64, 128, 16, 32], f16), [64, 128, 512]), {})
+cnt: 1, ((T([8192, 50265], f16), [64, 128, 50265]), {})
+cnt: 24, ((T([64, 16, 128, 32], f16), [1024, 128, 32]), {})
+cnt: 24, ((T([64, 128, 512], f16), [8192, 512]), {})
+Operator: aten.add.Tensor
+cnt: 2, ((T([64, 128, 512], f16), T([128, 512], f16)), {})
+cnt: 127, ((T([64, 128, 512], f16), T([64, 128, 512], f16)), {})
+cnt: 1, ((T([128], i64), 1), {})
+cnt: 8, ((T([64, 16, 128, 128], f16), T([64, 1, 128, 128], f16)), {})
+cnt: 1, ((T([64, 128, 50265], f16), T([1, 50265], f16)), {})
+cnt: 2, ((T([50265, 512], f16), T([50265, 512], f16)), {})
+Operator: aten.addmm.default
+cnt: 96, ((T([512], f16), T([8192, 512], f16), T([512, 512], f16, stride=(1, 512))), {})
+cnt: 16, ((T([2048], f16), T([8192, 512], f16), T([512, 2048], f16, stride=(1, 512))), {})
+cnt: 16, ((T([512], f16), T([8192, 2048], f16), T([2048, 512], f16, stride=(1, 2048))), {})
+Operator: aten.any.default
+cnt: 16, ((T([64, 128, 512], b8),), {})
+Operator: aten.bmm.default
+cnt: 48, ((T([1024, 128, 32], f16), T([1024, 32, 128], f16, stride=(4096, 1, 32))), {})
+cnt: 48, ((T([1024, 128, 128], f16), T([1024, 128, 32], f16)), {})
+cnt: 24, ((T([1024, 128, 128], f16, stride=(16384, 1, 128)), T([1024, 128, 32], f16)), {})
+cnt: 24, ((T([1024, 32, 128], f16, stride=(4096, 1, 32)), T([1024, 128, 128], f16)), {})
+Operator: aten.clone.default
+cnt: 3, ((T([64, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 3, ((T([64, 128], i64), T([64, 128], i64)), {})
+Operator: aten.embedding.default
+cnt: 2, ((T([50265, 512], f16), T([64, 128], i64), 0), {})
+cnt: 2, ((T([512, 512], f16), T([128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 2, ((T([128, 512], f16), T([128], i64), 512, -1, False), {})
+cnt: 2, ((T([64, 128, 512], f16), T([64, 128], i64), 50265, 0, False), {})
+Operator: aten.gelu.default
+cnt: 16, ((T([64, 128, 2048], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 16, ((T([64, 128, 2048], f16), T([64, 128, 2048], f16)), {})
+Operator: aten.isinf.default
+cnt: 8, ((T([64, 128, 512], f16),), {})
+Operator: aten.isnan.default
+cnt: 8, ((T([64, 128, 512], f16),), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([128], i64), T([128, 1], i64)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([128, 128], f32), T([128, 128], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([8192, 512], f16), T([512, 50265], f16, stride=(1, 512))), {})
+cnt: 1, ((T([50265, 8192], f16, stride=(1, 50265)), T([8192, 512], f16)), {})
+cnt: 1, ((T([8192, 50265], f16), T([50265, 512], f16)), {})
+cnt: 16, ((T([8192, 512], f16), T([512, 2048], f16)), {})
+cnt: 16, ((T([512, 8192], f16, stride=(1, 512)), T([8192, 2048], f16)), {})
+cnt: 16, ((T([8192, 2048], f16), T([2048, 512], f16)), {})
+cnt: 16, ((T([2048, 8192], f16, stride=(1, 2048)), T([8192, 512], f16)), {})
+cnt: 96, ((T([8192, 512], f16), T([512, 512], f16)), {})
+cnt: 96, ((T([512, 8192], f16, stride=(1, 512)), T([8192, 512], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 4, ((T([64, 128, 512], f16), 1.0), {})
+cnt: 48, ((T([64, 128, 512], f16), 0.1767766952966369), {})
+Operator: aten.native_layer_norm.default
+cnt: 42, ((T([64, 128, 512], f16), [512], T([512], f16), T([512], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 42, ((T([64, 128, 512], f16), T([64, 128, 512], f16), [512], T([64, 128, 1], f32), T([64, 128, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([8192, 50265], f16), T([8192], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([8192, 50265], f16), T([8192], i64), None, 1, -100), {})
+Operator: aten.sum.SymInt
+cnt: 112, ((T([8192, 512], f16), [0], True), {})
+cnt: 16, ((T([8192, 2048], f16), [0], True), {})
+cnt: 2, ((T([64, 128, 512], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/CamemBert_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/CamemBert_training.txt
new file mode 100644
index 0000000000000..2ce6229b7d4b5
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/CamemBert_training.txt
@@ -0,0 +1,88 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([512, 32005], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([512, 32005], f16), T([512, 32005], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([1, 12, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([1, 12, 512, 512], f16), T([1, 12, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([1, 1, 1, 512], f32),), {'dtype': f16})
+cnt: 1, ((T([1, 512], b8),), {'dtype': i32})
+cnt: 1, ((T([1, 512], i64),), {'dtype': i32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([1, 512], i32),), {'dtype': i64})
+Operator: aten._unsafe_view.default
+cnt: 12, ((T([12, 512, 512], f16), [1, 12, 512, 512]), {})
+cnt: 12, ((T([12, 512, 64], f16), [1, 12, 512, 64]), {})
+cnt: 24, ((T([1, 512, 12, 64], f16), [1, 512, 768]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([1, 512], i32), 0), {})
+cnt: 1, ((T([1, 512], i64), 1), {})
+cnt: 73, ((T([1, 512, 768], f16), T([1, 512, 768], f16)), {})
+cnt: 12, ((T([1, 12, 512, 512], f16), T([1, 1, 1, 512], f16)), {})
+cnt: 1, ((T([32005, 768], f16), T([32005, 768], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 49, ((T([768], f16), T([512, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([512, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([512, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([32005], f16), T([512, 768], f16), T([768, 32005], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([12, 512, 64], f16, stride=(64, 768, 1)), T([12, 64, 512], f16, stride=(64, 1, 768))), {})
+cnt: 24, ((T([12, 512, 512], f16), T([12, 512, 64], f16, stride=(64, 768, 1))), {})
+cnt: 12, ((T([12, 512, 512], f16, stride=(262144, 1, 512)), T([12, 512, 64], f16, stride=(64, 768, 1))), {})
+cnt: 12, ((T([12, 64, 512], f16, stride=(64, 1, 768)), T([12, 512, 512], f16)), {})
+Operator: aten.clone.default
+cnt: 2, ((T([1, 512], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([1, 512], i64), T([1, 512], i64)), {})
+Operator: aten.cumsum.default
+cnt: 1, ((T([1, 512], i32), 1), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([1, 12, 512, 512], f16), 8.0), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([32005, 768], f16), T([1, 512], i64), 1), {})
+cnt: 1, ((T([1, 768], f16), T([1, 512], i64)), {})
+cnt: 1, ((T([514, 768], f16), T([1, 512], i64), 1), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 514, 1, False), {})
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 1, -1, False), {})
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 32005, 1, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([1, 512, 3072], f16),), {})
+cnt: 1, ((T([1, 512, 768], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512, 768], f16)), {})
+cnt: 12, ((T([1, 512, 3072], f16), T([1, 512, 3072], f16)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([512, 32005], f16), T([32005, 768], f16)), {})
+cnt: 1, ((T([32005, 512], f16, stride=(1, 32005)), T([512, 768], f16)), {})
+cnt: 37, ((T([512, 768], f16), T([768, 768], f16)), {})
+cnt: 37, ((T([768, 512], f16, stride=(1, 768)), T([512, 768], f16)), {})
+cnt: 12, ((T([512, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 512], f16, stride=(1, 768)), T([512, 3072], f16)), {})
+cnt: 12, ((T([512, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 512], f16, stride=(1, 3072)), T([512, 768], f16)), {})
+cnt: 12, ((T([512, 768], f16, stride=(1, 512)), T([768, 768], f16)), {})
+cnt: 12, ((T([768, 512], f16), T([512, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([1, 1, 1, 512], f16), -65504.0), {})
+cnt: 1, ((T([1, 512], i32), T([1, 512], i32)), {})
+Operator: aten.native_layer_norm.default
+cnt: 26, ((T([1, 512, 768], f16), [768], T([768], f16), T([768], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 26, ((T([1, 512, 768], f16), T([1, 512, 768], f16), [768], T([1, 512, 1], f32), T([1, 512, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.ne.Scalar
+cnt: 1, ((T([1, 512], i64), 1), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([512, 32005], f16), T([512], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([512, 32005], f16), T([512], i64), None, 1, -100), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([1, 1, 1, 512], f16), 1.0), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([512, 32005], f16), [0], True), {})
+cnt: 49, ((T([512, 768], f16), [0], True), {})
+cnt: 12, ((T([512, 3072], f16), [0], True), {})
+cnt: 12, ((T([512, 768], f16, stride=(1, 512)), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DebertaForMaskedLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DebertaForMaskedLM_training.txt
new file mode 100644
index 0000000000000..f3146c3fd934f
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DebertaForMaskedLM_training.txt
@@ -0,0 +1,132 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([2048, 50265], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([2048, 50265], f16), T([2048, 50265], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([4, 12, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([4, 12, 512, 512], f16), T([4, 12, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 25, ((T([4, 512, 768], f16),), {'dtype': f32})
+cnt: 25, ((T([4, 512, 768], f32),), {'dtype': f16})
+cnt: 1, ((T([4, 512, 1], f32),), {'dtype': f16})
+cnt: 1, ((T([4, 1, 512, 512], f32),), {'dtype': torch.uint8})
+cnt: 12, ((T([], f32),), {'dtype': f16, 'device': "torch.device('cpu')"})
+cnt: 12, ((T([4, 1, 512, 512], u8),), {'dtype': torch.bool})
+cnt: 25, ((T([4, 512, 768], f16),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 25, ((T([4, 512, 768], f32),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 12, ((T([2048, 2304], f16), [4, 512, 2304]), {})
+cnt: 36, ((T([4, 12, 512, 64], f16), [48, 512, 64]), {})
+cnt: 12, ((T([4, 12, 64, 512], f16), [48, 64, 512]), {})
+cnt: 12, ((T([48, 512, 512], f16), [4, 12, 512, 512]), {})
+cnt: 12, ((T([48, 512, 64], f16), [4, 12, 512, 64]), {})
+cnt: 12, ((T([4, 512, 12, 192], f16), [4, 512, 2304]), {})
+Operator: aten.add.Tensor
+cnt: 25, ((T([4, 512, 1], f32), 1e-07), {})
+cnt: 25, ((T([4, 512, 768], f16), T([768], f16)), {})
+cnt: 24, ((T([4, 12, 512, 64], f16, stride=(1179648, 192, 2304, 1)), T([1, 12, 1, 64], f16)), {})
+cnt: 48, ((T([4, 512, 768], f16), T([4, 512, 768], f16)), {})
+cnt: 50, ((T([4, 512, 768], f32), T([4, 512, 768], f32)), {})
+cnt: 25, ((T([4, 512, 1], f32), T([4, 512, 1], f32)), {})
+cnt: 1, ((T([50265, 768], f16), T([50265, 768], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([4, 512, 768], f16), T([1, 512, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 13, ((T([768], f16), T([2048, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([2048, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([2048, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([50265], f16), T([2048, 768], f16), T([768, 50265], f16, stride=(1, 768))), {})
+Operator: aten.bitwise_not.default
+cnt: 12, ((T([4, 1, 512, 512], b8),), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([48, 512, 64], f16), T([48, 64, 512], f16)), {})
+cnt: 12, ((T([48, 512, 512], f16), T([48, 512, 64], f16)), {})
+cnt: 12, ((T([48, 512, 512], f16, stride=(262144, 1, 512)), T([48, 512, 64], f16)), {})
+cnt: 12, ((T([48, 512, 64], f16), T([48, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 12, ((T([48, 64, 512], f16, stride=(32768, 1, 64)), T([48, 512, 512], f16)), {})
+cnt: 12, ((T([48, 512, 512], f16), T([48, 512, 64], f16, stride=(32768, 1, 512))), {})
+Operator: aten.cat.default
+cnt: 12, (([T([4, 12, 512, 64], f16), T([4, 12, 512, 64], f16, stride=(393216, 32768, 1, 512)), T([4, 12, 512, 64], f16)], 3), {})
+Operator: aten.clone.default
+cnt: 2, ((T([4, 512], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([4, 512], i64), T([4, 512], i64)), {})
+Operator: aten.div.Scalar
+cnt: 50, ((T([4, 512, 768], f32, stride=(512, 1, 0)), 768), {})
+Operator: aten.div.Tensor
+cnt: 100, ((T([4, 512, 768], f32), T([4, 512, 1], f32)), {})
+cnt: 12, ((T([4, 12, 512, 64], f16, stride=(393216, 64, 768, 1)), T([], f16)), {})
+cnt: 25, ((T([4, 512, 1], f32), T([4, 512, 1], f32)), {})
+cnt: 12, ((T([4, 12, 512, 64], f16), T([], f16)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50265, 768], f16), T([4, 512], i64), 0), {})
+cnt: 1, ((T([512, 768], f16), T([1, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 512, -1, False), {})
+cnt: 1, ((T([4, 512, 768], f16), T([4, 512], i64), 50265, 0, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([4, 512, 3072], f16),), {})
+cnt: 1, ((T([4, 512, 768], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([4, 512, 768], f16), T([4, 512, 768], f16)), {})
+cnt: 12, ((T([4, 512, 3072], f16), T([4, 512, 3072], f16)), {})
+Operator: aten.masked_fill.Tensor
+cnt: 12, ((T([4, 12, 512, 512], f16), T([4, 1, 512, 512], b8), T([], f32)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 12, ((T([4, 12, 512, 512], f16), T([4, 1, 512, 512], b8), 0), {})
+Operator: aten.mean.dim
+cnt: 50, ((T([4, 512, 768], f32), [-1], True), {})
+Operator: aten.mm.default
+cnt: 12, ((T([2048, 768], f16), T([768, 2304], f16, stride=(1, 768))), {})
+cnt: 1, ((T([2048, 50265], f16), T([50265, 768], f16)), {})
+cnt: 1, ((T([50265, 2048], f16, stride=(1, 50265)), T([2048, 768], f16)), {})
+cnt: 13, ((T([2048, 768], f16), T([768, 768], f16)), {})
+cnt: 13, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 768], f16)), {})
+cnt: 12, ((T([2048, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 3072], f16)), {})
+cnt: 12, ((T([2048, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 2048], f16, stride=(1, 3072)), T([2048, 768], f16)), {})
+cnt: 12, ((T([2304, 2048], f16, stride=(1, 2304)), T([2048, 768], f16)), {})
+cnt: 12, ((T([2048, 2304], f16), T([2304, 768], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 25, ((T([4, 512, 1], f32), 2), {})
+cnt: 25, ((T([4, 512, 768], f32), 2.0), {})
+Operator: aten.mul.Tensor
+cnt: 25, ((T([768], f16), T([4, 512, 768], f16)), {})
+cnt: 2, ((T([4, 512, 768], f16), T([4, 512, 1], f16)), {})
+cnt: 1, ((T([4, 1, 1, 512], f32), T([4, 1, 512, 1], f32)), {})
+cnt: 12, ((T([], f32), 1), {})
+cnt: 25, ((T([4, 512, 768], f16), T([768], f16)), {})
+cnt: 25, ((T([4, 512, 768], f16), T([4, 512, 768], f16)), {})
+cnt: 50, ((T([4, 512, 768], f32), T([4, 512, 768], f32)), {})
+Operator: aten.native_layer_norm.default
+cnt: 1, ((T([4, 512, 768], f16), [768], T([768], f16), T([768], f16), 1e-07), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 1, ((T([4, 512, 768], f16), T([4, 512, 768], f16), [768], T([4, 512, 1], f32), T([4, 512, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.neg.default
+cnt: 75, ((T([4, 512, 768], f32),), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([2048, 50265], f16), T([2048], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([2048, 50265], f16), T([2048], i64), None, 1, -100), {})
+Operator: aten.pow.Tensor_Scalar
+cnt: 25, ((T([4, 512, 768], f32), 2), {})
+cnt: 25, ((T([4, 512, 768], f32), 1.0), {})
+Operator: aten.slice_backward.default
+cnt: 24, ((T([1, 1, 768], f16), [1, 1, 768], 2, 0, 9223372036854775807, 1), {})
+Operator: aten.split.Tensor
+cnt: 12, ((T([4, 12, 512, 192], f16, stride=(1179648, 192, 2304, 1)), 64, -1), {})
+Operator: aten.sqrt.default
+cnt: 25, ((T([4, 512, 1], f32),), {})
+cnt: 12, ((T([], f32),), {})
+Operator: aten.sub.Tensor
+cnt: 50, ((T([4, 512, 768], f32), T([4, 512, 1], f32)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([2048, 50265], f16), [0], True), {})
+cnt: 25, ((T([2048, 768], f16), [0], True), {})
+cnt: 50, ((T([4, 512, 768], f16), [0, 1], True), {})
+cnt: 75, ((T([4, 512, 768], f32), [2], True), {})
+cnt: 12, ((T([2048, 3072], f16), [0], True), {})
+cnt: 24, ((T([4, 12, 512, 64], f16), [0, 2], True), {})
+cnt: 1, ((T([4, 512, 768], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DebertaForQuestionAnswering_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DebertaForQuestionAnswering_training.txt
new file mode 100644
index 0000000000000..cd06e0d09756d
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DebertaForQuestionAnswering_training.txt
@@ -0,0 +1,133 @@
+Operator: aten._log_softmax.default
+cnt: 2, ((T([4, 512], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 2, ((T([4, 512], f16), T([4, 512], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([4, 12, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([4, 12, 512, 512], f16), T([4, 12, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 25, ((T([4, 512, 768], f16),), {'dtype': f32})
+cnt: 25, ((T([4, 512, 768], f32),), {'dtype': f16})
+cnt: 1, ((T([4, 512, 1], f32),), {'dtype': f16})
+cnt: 1, ((T([4, 1, 512, 512], f32),), {'dtype': torch.uint8})
+cnt: 12, ((T([], f32),), {'dtype': f16, 'device': "torch.device('cpu')"})
+cnt: 12, ((T([4, 1, 512, 512], u8),), {'dtype': torch.bool})
+cnt: 25, ((T([4, 512, 768], f16),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 25, ((T([4, 512, 768], f32),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 12, ((T([2048, 2304], f16), [4, 512, 2304]), {})
+cnt: 36, ((T([4, 12, 512, 64], f16), [48, 512, 64]), {})
+cnt: 12, ((T([4, 12, 64, 512], f16), [48, 64, 512]), {})
+cnt: 12, ((T([48, 512, 512], f16), [4, 12, 512, 512]), {})
+cnt: 12, ((T([48, 512, 64], f16), [4, 12, 512, 64]), {})
+cnt: 12, ((T([4, 512, 12, 192], f16), [4, 512, 2304]), {})
+Operator: aten.add.Tensor
+cnt: 25, ((T([4, 512, 1], f32), 1e-07), {})
+cnt: 25, ((T([4, 512, 768], f16), T([768], f16)), {})
+cnt: 24, ((T([4, 12, 512, 64], f16, stride=(1179648, 192, 2304, 1)), T([1, 12, 1, 64], f16)), {})
+cnt: 48, ((T([4, 512, 768], f16), T([4, 512, 768], f16)), {})
+cnt: 1, ((T([], f16), T([], f16)), {})
+cnt: 50, ((T([4, 512, 768], f32), T([4, 512, 768], f32)), {})
+cnt: 25, ((T([4, 512, 1], f32), T([4, 512, 1], f32)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([4, 512, 768], f16), T([1, 512, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 12, ((T([768], f16), T([2048, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([2048, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([2048, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([2], f16), T([2048, 768], f16), T([768, 2], f16, stride=(1, 768))), {})
+Operator: aten.bitwise_not.default
+cnt: 12, ((T([4, 1, 512, 512], b8),), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([48, 512, 64], f16), T([48, 64, 512], f16)), {})
+cnt: 12, ((T([48, 512, 512], f16), T([48, 512, 64], f16)), {})
+cnt: 12, ((T([48, 512, 512], f16, stride=(262144, 1, 512)), T([48, 512, 64], f16)), {})
+cnt: 12, ((T([48, 512, 64], f16), T([48, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 12, ((T([48, 64, 512], f16, stride=(32768, 1, 64)), T([48, 512, 512], f16)), {})
+cnt: 12, ((T([48, 512, 512], f16), T([48, 512, 64], f16, stride=(32768, 1, 512))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([4, 512, 1], f16), T([4, 512, 1], f16)], 2), {})
+cnt: 12, (([T([4, 12, 512, 64], f16), T([4, 12, 512, 64], f16, stride=(393216, 32768, 1, 512)), T([4, 12, 512, 64], f16)], 3), {})
+Operator: aten.clamp.default
+cnt: 2, ((T([4], i64), 0, 512), {})
+Operator: aten.clone.default
+cnt: 1, ((T([4, 512], i64),), {})
+cnt: 2, ((T([4], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([4, 512], i64), T([4, 512], i64)), {})
+cnt: 2, ((T([4], i64), T([4], i64)), {})
+Operator: aten.div.Scalar
+cnt: 50, ((T([4, 512, 768], f32, stride=(512, 1, 0)), 768), {})
+Operator: aten.div.Tensor
+cnt: 100, ((T([4, 512, 768], f32), T([4, 512, 1], f32)), {})
+cnt: 12, ((T([4, 12, 512, 64], f16, stride=(393216, 64, 768, 1)), T([], f16)), {})
+cnt: 2, ((T([], f16), 2), {})
+cnt: 25, ((T([4, 512, 1], f32), T([4, 512, 1], f32)), {})
+cnt: 12, ((T([4, 12, 512, 64], f16), T([], f16)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50265, 768], f16), T([4, 512], i64), 0), {})
+cnt: 1, ((T([512, 768], f16), T([1, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 512, -1, False), {})
+cnt: 1, ((T([4, 512, 768], f16), T([4, 512], i64), 50265, 0, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([4, 512, 3072], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([4, 512, 3072], f16), T([4, 512, 3072], f16)), {})
+Operator: aten.masked_fill.Tensor
+cnt: 12, ((T([4, 12, 512, 512], f16), T([4, 1, 512, 512], b8), T([], f32)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 12, ((T([4, 12, 512, 512], f16), T([4, 1, 512, 512], b8), 0), {})
+Operator: aten.mean.dim
+cnt: 50, ((T([4, 512, 768], f32), [-1], True), {})
+Operator: aten.mm.default
+cnt: 12, ((T([2048, 768], f16), T([768, 2304], f16, stride=(1, 768))), {})
+cnt: 1, ((T([2048, 2], f16), T([2, 768], f16)), {})
+cnt: 1, ((T([2, 2048], f16, stride=(1, 2)), T([2048, 768], f16)), {})
+cnt: 12, ((T([2048, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 3072], f16)), {})
+cnt: 12, ((T([2048, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 2048], f16, stride=(1, 3072)), T([2048, 768], f16)), {})
+cnt: 12, ((T([2048, 768], f16), T([768, 768], f16)), {})
+cnt: 12, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 768], f16)), {})
+cnt: 12, ((T([2304, 2048], f16, stride=(1, 2304)), T([2048, 768], f16)), {})
+cnt: 12, ((T([2048, 2304], f16), T([2304, 768], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 25, ((T([4, 512, 1], f32), 2), {})
+cnt: 25, ((T([4, 512, 768], f32), 2.0), {})
+Operator: aten.mul.Tensor
+cnt: 25, ((T([768], f16), T([4, 512, 768], f16)), {})
+cnt: 2, ((T([4, 512, 768], f16), T([4, 512, 1], f16)), {})
+cnt: 1, ((T([4, 1, 1, 512], f32), T([4, 1, 512, 1], f32)), {})
+cnt: 12, ((T([], f32), 1), {})
+cnt: 25, ((T([4, 512, 768], f16), T([768], f16)), {})
+cnt: 25, ((T([4, 512, 768], f16), T([4, 512, 768], f16)), {})
+cnt: 50, ((T([4, 512, 768], f32), T([4, 512, 768], f32)), {})
+Operator: aten.neg.default
+cnt: 75, ((T([4, 512, 768], f32),), {})
+Operator: aten.nll_loss_backward.default
+cnt: 2, ((T([], f16), T([4, 512], f16), T([4], i64), None, 1, 512, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 2, ((T([4, 512], f16), T([4], i64), None, 1, 512), {})
+Operator: aten.pow.Tensor_Scalar
+cnt: 25, ((T([4, 512, 768], f32), 2), {})
+cnt: 25, ((T([4, 512, 768], f32), 1.0), {})
+Operator: aten.slice_backward.default
+cnt: 24, ((T([1, 1, 768], f16), [1, 1, 768], 2, 0, 9223372036854775807, 1), {})
+Operator: aten.split.Tensor
+cnt: 12, ((T([4, 12, 512, 192], f16, stride=(1179648, 192, 2304, 1)), 64, -1), {})
+cnt: 1, ((T([4, 512, 2], f16), 1, -1), {})
+Operator: aten.sqrt.default
+cnt: 25, ((T([4, 512, 1], f32),), {})
+cnt: 12, ((T([], f32),), {})
+Operator: aten.sub.Tensor
+cnt: 50, ((T([4, 512, 768], f32), T([4, 512, 1], f32)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([2048, 2], f16), [0], True), {})
+cnt: 50, ((T([4, 512, 768], f16), [0, 1], True), {})
+cnt: 75, ((T([4, 512, 768], f32), [2], True), {})
+cnt: 24, ((T([2048, 768], f16), [0], True), {})
+cnt: 12, ((T([2048, 3072], f16), [0], True), {})
+cnt: 24, ((T([4, 12, 512, 64], f16), [0, 2], True), {})
+cnt: 1, ((T([4, 512, 768], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DebertaV2ForMaskedLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DebertaV2ForMaskedLM_training.txt
new file mode 100644
index 0000000000000..157e119eeefc0
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DebertaV2ForMaskedLM_training.txt
@@ -0,0 +1,85 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([512, 128100], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([512, 128100], f16), T([512, 128100], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 24, ((T([1, 24, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 24, ((T([1, 24, 512, 512], f16), T([1, 24, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([1, 512, 1], f32),), {'dtype': f16})
+cnt: 1, ((T([1, 1, 512, 512], f32),), {'dtype': torch.uint8})
+cnt: 24, ((T([], f32),), {'dtype': f16, 'device': "torch.device('cpu')"})
+cnt: 24, ((T([1, 1, 512, 512], u8),), {'dtype': torch.bool})
+Operator: aten._unsafe_view.default
+cnt: 48, ((T([1, 512, 24, 64], f16), [1, 512, 1536]), {})
+Operator: aten.add.Tensor
+cnt: 144, ((T([1, 512, 1536], f16), T([1, 512, 1536], f16)), {})
+cnt: 1, ((T([128100, 1536], f16), T([128100, 1536], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([1, 512, 1536], f16), T([1, 512, 1536], f16)), {})
+Operator: aten.addmm.default
+cnt: 97, ((T([1536], f16), T([512, 1536], f16), T([1536, 1536], f16, stride=(1, 1536))), {})
+cnt: 24, ((T([6144], f16), T([512, 1536], f16), T([1536, 6144], f16, stride=(1, 1536))), {})
+cnt: 24, ((T([1536], f16), T([512, 6144], f16), T([6144, 1536], f16, stride=(1, 6144))), {})
+cnt: 1, ((T([128100], f16), T([512, 1536], f16), T([1536, 128100], f16, stride=(1, 1536))), {})
+Operator: aten.bitwise_not.default
+cnt: 24, ((T([1, 1, 512, 512], b8),), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([24, 512, 64], f16), T([24, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 48, ((T([24, 512, 512], f16), T([24, 512, 64], f16)), {})
+cnt: 24, ((T([24, 512, 512], f16, stride=(262144, 1, 512)), T([24, 512, 64], f16, stride=(64, 1536, 1))), {})
+cnt: 24, ((T([24, 512, 64], f16, stride=(64, 1536, 1)), T([24, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 24, ((T([24, 64, 512], f16, stride=(32768, 1, 64)), T([24, 512, 512], f16)), {})
+Operator: aten.clone.default
+cnt: 2, ((T([1, 512], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([1, 512], i64), T([1, 512], i64)), {})
+Operator: aten.div.Tensor
+cnt: 48, ((T([24, 512, 512], f16), T([], f16)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([128100, 1536], f16), T([1, 512], i64), 0), {})
+cnt: 1, ((T([512, 1536], f16), T([1, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 1536], f16), T([1, 512], i64), 512, -1, False), {})
+cnt: 1, ((T([1, 512, 1536], f16), T([1, 512], i64), 128100, 0, False), {})
+Operator: aten.gelu.default
+cnt: 24, ((T([1, 512, 6144], f16),), {})
+cnt: 1, ((T([1, 512, 1536], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([1, 512, 1536], f16), T([1, 512, 1536], f16)), {})
+cnt: 24, ((T([1, 512, 6144], f16), T([1, 512, 6144], f16)), {})
+Operator: aten.masked_fill.Tensor
+cnt: 24, ((T([1, 24, 512, 512], f16), T([1, 1, 512, 512], b8), T([], f32)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 24, ((T([1, 24, 512, 512], f16), T([1, 1, 512, 512], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([512, 128100], f16), T([128100, 1536], f16)), {})
+cnt: 1, ((T([128100, 512], f16, stride=(1, 128100)), T([512, 1536], f16)), {})
+cnt: 73, ((T([512, 1536], f16), T([1536, 1536], f16)), {})
+cnt: 73, ((T([1536, 512], f16, stride=(1, 1536)), T([512, 1536], f16)), {})
+cnt: 24, ((T([512, 1536], f16), T([1536, 6144], f16)), {})
+cnt: 24, ((T([1536, 512], f16, stride=(1, 1536)), T([512, 6144], f16)), {})
+cnt: 24, ((T([512, 6144], f16), T([6144, 1536], f16)), {})
+cnt: 24, ((T([6144, 512], f16, stride=(1, 6144)), T([512, 1536], f16)), {})
+cnt: 24, ((T([512, 1536], f16, stride=(1, 512)), T([1536, 1536], f16)), {})
+cnt: 24, ((T([1536, 512], f16), T([512, 1536], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([1, 512, 1536], f16), T([1, 512, 1], f16)), {})
+cnt: 1, ((T([1, 1, 1, 512], f32), T([1, 1, 512, 1], f32)), {})
+cnt: 24, ((T([], f32), 1), {})
+Operator: aten.native_layer_norm.default
+cnt: 50, ((T([1, 512, 1536], f16), [1536], T([1536], f16), T([1536], f16), 1e-07), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 50, ((T([1, 512, 1536], f16), T([1, 512, 1536], f16), [1536], T([1, 512, 1], f32), T([1, 512, 1], f32), T([1536], f16), T([1536], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([512, 128100], f16), T([512], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([512, 128100], f16), T([512], i64), None, 1, -100), {})
+Operator: aten.sqrt.default
+cnt: 24, ((T([], f32),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([512, 128100], f16), [0], True), {})
+cnt: 97, ((T([512, 1536], f16), [0], True), {})
+cnt: 24, ((T([512, 6144], f16), [0], True), {})
+cnt: 24, ((T([512, 1536], f16, stride=(1, 512)), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DebertaV2ForQuestionAnswering_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DebertaV2ForQuestionAnswering_training.txt
new file mode 100644
index 0000000000000..94ffa58562aa6
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DebertaV2ForQuestionAnswering_training.txt
@@ -0,0 +1,92 @@
+Operator: aten._log_softmax.default
+cnt: 2, ((T([1, 512], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 2, ((T([1, 512], f16), T([1, 512], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 24, ((T([1, 24, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 24, ((T([1, 24, 512, 512], f16), T([1, 24, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([1, 512, 1], f32),), {'dtype': f16})
+cnt: 1, ((T([1, 1, 512, 512], f32),), {'dtype': torch.uint8})
+cnt: 24, ((T([], f32),), {'dtype': f16, 'device': "torch.device('cpu')"})
+cnt: 24, ((T([1, 1, 512, 512], u8),), {'dtype': torch.bool})
+Operator: aten._unsafe_view.default
+cnt: 48, ((T([1, 512, 24, 64], f16), [1, 512, 1536]), {})
+Operator: aten.add.Tensor
+cnt: 144, ((T([1, 512, 1536], f16), T([1, 512, 1536], f16)), {})
+cnt: 1, ((T([], f16), T([], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([1, 512, 1536], f16), T([1, 512, 1536], f16)), {})
+Operator: aten.addmm.default
+cnt: 96, ((T([1536], f16), T([512, 1536], f16), T([1536, 1536], f16, stride=(1, 1536))), {})
+cnt: 24, ((T([6144], f16), T([512, 1536], f16), T([1536, 6144], f16, stride=(1, 1536))), {})
+cnt: 24, ((T([1536], f16), T([512, 6144], f16), T([6144, 1536], f16, stride=(1, 6144))), {})
+cnt: 1, ((T([2], f16), T([512, 1536], f16), T([1536, 2], f16, stride=(1, 1536))), {})
+Operator: aten.bitwise_not.default
+cnt: 24, ((T([1, 1, 512, 512], b8),), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([24, 512, 64], f16), T([24, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 48, ((T([24, 512, 512], f16), T([24, 512, 64], f16)), {})
+cnt: 24, ((T([24, 512, 512], f16, stride=(262144, 1, 512)), T([24, 512, 64], f16, stride=(64, 1536, 1))), {})
+cnt: 24, ((T([24, 512, 64], f16, stride=(64, 1536, 1)), T([24, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 24, ((T([24, 64, 512], f16, stride=(32768, 1, 64)), T([24, 512, 512], f16)), {})
+Operator: aten.cat.default
+cnt: 1, (([T([1, 512, 1], f16), T([1, 512, 1], f16)], 2), {})
+Operator: aten.clamp.default
+cnt: 2, ((T([1], i64), 0, 512), {})
+Operator: aten.clone.default
+cnt: 1, ((T([1, 512], i64),), {})
+cnt: 2, ((T([1], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([1, 512], i64), T([1, 512], i64)), {})
+cnt: 2, ((T([1], i64), T([1], i64)), {})
+Operator: aten.div.Tensor
+cnt: 48, ((T([24, 512, 512], f16), T([], f16)), {})
+cnt: 2, ((T([], f16), 2), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([128100, 1536], f16), T([1, 512], i64), 0), {})
+cnt: 1, ((T([512, 1536], f16), T([1, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 1536], f16), T([1, 512], i64), 512, -1, False), {})
+cnt: 1, ((T([1, 512, 1536], f16), T([1, 512], i64), 128100, 0, False), {})
+Operator: aten.gelu.default
+cnt: 24, ((T([1, 512, 6144], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 24, ((T([1, 512, 6144], f16), T([1, 512, 6144], f16)), {})
+Operator: aten.masked_fill.Tensor
+cnt: 24, ((T([1, 24, 512, 512], f16), T([1, 1, 512, 512], b8), T([], f32)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 24, ((T([1, 24, 512, 512], f16), T([1, 1, 512, 512], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([512, 2], f16), T([2, 1536], f16)), {})
+cnt: 1, ((T([2, 512], f16, stride=(1, 2)), T([512, 1536], f16)), {})
+cnt: 24, ((T([512, 1536], f16), T([1536, 6144], f16)), {})
+cnt: 24, ((T([1536, 512], f16, stride=(1, 1536)), T([512, 6144], f16)), {})
+cnt: 24, ((T([512, 6144], f16), T([6144, 1536], f16)), {})
+cnt: 24, ((T([6144, 512], f16, stride=(1, 6144)), T([512, 1536], f16)), {})
+cnt: 72, ((T([512, 1536], f16), T([1536, 1536], f16)), {})
+cnt: 72, ((T([1536, 512], f16, stride=(1, 1536)), T([512, 1536], f16)), {})
+cnt: 24, ((T([512, 1536], f16, stride=(1, 512)), T([1536, 1536], f16)), {})
+cnt: 24, ((T([1536, 512], f16), T([512, 1536], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([1, 512, 1536], f16), T([1, 512, 1], f16)), {})
+cnt: 1, ((T([1, 1, 1, 512], f32), T([1, 1, 512, 1], f32)), {})
+cnt: 24, ((T([], f32), 1), {})
+Operator: aten.native_layer_norm.default
+cnt: 49, ((T([1, 512, 1536], f16), [1536], T([1536], f16), T([1536], f16), 1e-07), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 49, ((T([1, 512, 1536], f16), T([1, 512, 1536], f16), [1536], T([1, 512, 1], f32), T([1, 512, 1], f32), T([1536], f16), T([1536], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 2, ((T([], f16), T([1, 512], f16), T([1], i64), None, 1, 512, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 2, ((T([1, 512], f16), T([1], i64), None, 1, 512), {})
+Operator: aten.split.Tensor
+cnt: 1, ((T([1, 512, 2], f16), 1, -1), {})
+Operator: aten.sqrt.default
+cnt: 24, ((T([], f32),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([512, 2], f16), [0], True), {})
+cnt: 96, ((T([512, 1536], f16), [0], True), {})
+cnt: 24, ((T([512, 6144], f16), [0], True), {})
+cnt: 24, ((T([512, 1536], f16, stride=(1, 512)), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DistilBertForMaskedLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DistilBertForMaskedLM_training.txt
new file mode 100644
index 0000000000000..37d0d4707d8af
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DistilBertForMaskedLM_training.txt
@@ -0,0 +1,78 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([2048, 30522], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([2048, 30522], f16), T([2048, 30522], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 6, ((T([16, 12, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 6, ((T([16, 12, 128, 128], f16), T([16, 12, 128, 128], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 18, ((T([16, 12, 128, 64], f16), [192, 128, 64]), {})
+cnt: 6, ((T([16, 12, 64, 128], f16), [192, 64, 128]), {})
+cnt: 6, ((T([192, 128, 128], f16), [16, 12, 128, 128]), {})
+cnt: 6, ((T([192, 128, 64], f16), [16, 12, 128, 64]), {})
+cnt: 12, ((T([16, 128, 12, 64], f16), [16, 128, 768]), {})
+cnt: 6, ((T([16, 128, 768], f16), [2048, 768]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([16, 128, 768], f16), T([1, 128, 768], f16)), {})
+cnt: 36, ((T([16, 128, 768], f16), T([16, 128, 768], f16)), {})
+cnt: 1, ((T([30522, 768], f16), T([30522, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 25, ((T([768], f16), T([2048, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 6, ((T([3072], f16), T([2048, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 6, ((T([768], f16), T([2048, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([30522], f16), T([2048, 768], f16), T([768, 30522], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 6, ((T([192, 128, 64], f16), T([192, 64, 128], f16)), {})
+cnt: 6, ((T([192, 128, 128], f16), T([192, 128, 64], f16)), {})
+cnt: 6, ((T([192, 128, 128], f16, stride=(16384, 1, 128)), T([192, 128, 64], f16)), {})
+cnt: 6, ((T([192, 128, 64], f16), T([192, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 6, ((T([192, 64, 128], f16, stride=(8192, 1, 64)), T([192, 128, 128], f16)), {})
+cnt: 6, ((T([192, 128, 128], f16), T([192, 128, 64], f16, stride=(8192, 1, 128))), {})
+Operator: aten.clone.default
+cnt: 2, ((T([16, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([16, 128], i64), T([16, 128], i64)), {})
+Operator: aten.div.Tensor
+cnt: 6, ((T([16, 12, 128, 64], f16, stride=(98304, 64, 768, 1)), 8.0), {})
+cnt: 6, ((T([16, 12, 128, 64], f16), 8.0), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30522, 768], f16), T([16, 128], i64), 0), {})
+cnt: 1, ((T([512, 768], f16), T([1, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 128, 768], f16), T([1, 128], i64), 512, -1, False), {})
+cnt: 1, ((T([16, 128, 768], f16), T([16, 128], i64), 30522, 0, False), {})
+Operator: aten.eq.Scalar
+cnt: 6, ((T([16, 128], f32), 0), {})
+Operator: aten.gelu.default
+cnt: 6, ((T([16, 128, 3072], f16),), {})
+cnt: 1, ((T([16, 128, 768], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([16, 128, 768], f16), T([16, 128, 768], f16)), {})
+cnt: 6, ((T([16, 128, 3072], f16), T([16, 128, 3072], f16)), {})
+Operator: aten.masked_fill.Scalar
+cnt: 6, ((T([16, 12, 128, 128], f16), T([16, 12, 128, 128], b8, stride=(128, 0, 0, 1)), 0), {})
+Operator: aten.masked_fill.Tensor
+cnt: 6, ((T([16, 12, 128, 128], f16), T([16, 12, 128, 128], b8, stride=(128, 0, 0, 1)), T([], f32)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([2048, 30522], f16), T([30522, 768], f16)), {})
+cnt: 1, ((T([30522, 2048], f16, stride=(1, 30522)), T([2048, 768], f16)), {})
+cnt: 25, ((T([2048, 768], f16), T([768, 768], f16)), {})
+cnt: 25, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 768], f16)), {})
+cnt: 6, ((T([2048, 768], f16), T([768, 3072], f16)), {})
+cnt: 6, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 3072], f16)), {})
+cnt: 6, ((T([2048, 3072], f16), T([3072, 768], f16)), {})
+cnt: 6, ((T([3072, 2048], f16, stride=(1, 3072)), T([2048, 768], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 14, ((T([16, 128, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 14, ((T([16, 128, 768], f16), T([16, 128, 768], f16), [768], T([16, 128, 1], f32), T([16, 128, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([2048, 30522], f16), T([2048], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([2048, 30522], f16), T([2048], i64), None, 1, -100), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([2048, 30522], f16), [0], True), {})
+cnt: 31, ((T([2048, 768], f16), [0], True), {})
+cnt: 6, ((T([2048, 3072], f16), [0], True), {})
+cnt: 1, ((T([16, 128, 768], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DistilBertForQuestionAnswering_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DistilBertForQuestionAnswering_training.txt
new file mode 100644
index 0000000000000..350ed80182bdc
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DistilBertForQuestionAnswering_training.txt
@@ -0,0 +1,85 @@
+Operator: aten._log_softmax.default
+cnt: 2, ((T([32, 128], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 2, ((T([32, 128], f16), T([32, 128], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 6, ((T([32, 12, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 6, ((T([32, 12, 128, 128], f16), T([32, 12, 128, 128], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 18, ((T([32, 12, 128, 64], f16), [384, 128, 64]), {})
+cnt: 6, ((T([32, 12, 64, 128], f16), [384, 64, 128]), {})
+cnt: 6, ((T([384, 128, 128], f16), [32, 12, 128, 128]), {})
+cnt: 6, ((T([384, 128, 64], f16), [32, 12, 128, 64]), {})
+cnt: 12, ((T([32, 128, 12, 64], f16), [32, 128, 768]), {})
+cnt: 6, ((T([32, 128, 768], f16), [4096, 768]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([32, 128, 768], f16), T([1, 128, 768], f16)), {})
+cnt: 36, ((T([32, 128, 768], f16), T([32, 128, 768], f16)), {})
+cnt: 1, ((T([], f16), T([], f16)), {})
+Operator: aten.addmm.default
+cnt: 24, ((T([768], f16), T([4096, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 6, ((T([3072], f16), T([4096, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 6, ((T([768], f16), T([4096, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([2], f16), T([4096, 768], f16), T([768, 2], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 6, ((T([384, 128, 64], f16), T([384, 64, 128], f16)), {})
+cnt: 6, ((T([384, 128, 128], f16), T([384, 128, 64], f16)), {})
+cnt: 6, ((T([384, 128, 128], f16, stride=(16384, 1, 128)), T([384, 128, 64], f16)), {})
+cnt: 6, ((T([384, 128, 64], f16), T([384, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 6, ((T([384, 64, 128], f16, stride=(8192, 1, 64)), T([384, 128, 128], f16)), {})
+cnt: 6, ((T([384, 128, 128], f16), T([384, 128, 64], f16, stride=(8192, 1, 128))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([32, 128, 1], f16), T([32, 128, 1], f16)], 2), {})
+Operator: aten.clamp.default
+cnt: 2, ((T([32], i64), 0, 128), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 128], i64),), {})
+cnt: 2, ((T([32], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 128], i64), T([32, 128], i64)), {})
+cnt: 2, ((T([32], i64), T([32], i64)), {})
+Operator: aten.div.Tensor
+cnt: 6, ((T([32, 12, 128, 64], f16, stride=(98304, 64, 768, 1)), 8.0), {})
+cnt: 2, ((T([], f16), 2), {})
+cnt: 6, ((T([32, 12, 128, 64], f16), 8.0), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30522, 768], f16), T([32, 128], i64), 0), {})
+cnt: 1, ((T([512, 768], f16), T([1, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 128, 768], f16), T([1, 128], i64), 512, -1, False), {})
+cnt: 1, ((T([32, 128, 768], f16), T([32, 128], i64), 30522, 0, False), {})
+Operator: aten.eq.Scalar
+cnt: 6, ((T([32, 128], f32), 0), {})
+Operator: aten.gelu.default
+cnt: 6, ((T([32, 128, 3072], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 6, ((T([32, 128, 3072], f16), T([32, 128, 3072], f16)), {})
+Operator: aten.masked_fill.Scalar
+cnt: 6, ((T([32, 12, 128, 128], f16), T([32, 12, 128, 128], b8, stride=(128, 0, 0, 1)), 0), {})
+Operator: aten.masked_fill.Tensor
+cnt: 6, ((T([32, 12, 128, 128], f16), T([32, 12, 128, 128], b8, stride=(128, 0, 0, 1)), T([], f32)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([4096, 2], f16), T([2, 768], f16)), {})
+cnt: 1, ((T([2, 4096], f16, stride=(1, 2)), T([4096, 768], f16)), {})
+cnt: 6, ((T([4096, 768], f16), T([768, 3072], f16)), {})
+cnt: 6, ((T([768, 4096], f16, stride=(1, 768)), T([4096, 3072], f16)), {})
+cnt: 6, ((T([4096, 3072], f16), T([3072, 768], f16)), {})
+cnt: 6, ((T([3072, 4096], f16, stride=(1, 3072)), T([4096, 768], f16)), {})
+cnt: 24, ((T([4096, 768], f16), T([768, 768], f16)), {})
+cnt: 24, ((T([768, 4096], f16, stride=(1, 768)), T([4096, 768], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 13, ((T([32, 128, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 13, ((T([32, 128, 768], f16), T([32, 128, 768], f16), [768], T([32, 128, 1], f32), T([32, 128, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 2, ((T([], f16), T([32, 128], f16), T([32], i64), None, 1, 128, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 2, ((T([32, 128], f16), T([32], i64), None, 1, 128), {})
+Operator: aten.split.Tensor
+cnt: 1, ((T([32, 128, 2], f16), 1, -1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([4096, 2], f16), [0], True), {})
+cnt: 30, ((T([4096, 768], f16), [0], True), {})
+cnt: 6, ((T([4096, 3072], f16), [0], True), {})
+cnt: 1, ((T([32, 128, 768], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DistillGPT2_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DistillGPT2_training.txt
new file mode 100644
index 0000000000000..5654c4bbd4d9f
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/DistillGPT2_training.txt
@@ -0,0 +1,91 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([511, 50257], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([511, 50257], f16), T([511, 50257], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 6, ((T([1, 12, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 6, ((T([1, 12, 512, 512], f16), T([1, 12, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 6, ((T([1, 1, 512, 512], u8, stride=(1048576, 1048576, 1024, 1)),), {'dtype': torch.bool})
+cnt: 6, ((T([], f16),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 6, ((T([12, 512, 512], f16), [1, 12, 512, 512]), {})
+cnt: 6, ((T([12, 512, 64], f16), [1, 12, 512, 64]), {})
+cnt: 1, ((T([512, 50257], f16), [1, 512, 50257]), {})
+cnt: 12, ((T([1, 512, 12, 64], f16), [1, 512, 768]), {})
+Operator: aten.add.Tensor
+cnt: 25, ((T([1, 512, 768], f16), T([1, 512, 768], f16)), {})
+cnt: 18, ((T([1, 512, 3072], f16), T([1, 512, 3072], f16)), {})
+cnt: 6, ((T([1, 512, 3072], f16), 1.0), {})
+cnt: 1, ((T([50257, 768], f16), T([50257, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 6, ((T([2304], f16), T([512, 768], f16), T([768, 2304], f16)), {})
+cnt: 6, ((T([768], f16), T([512, 768], f16), T([768, 768], f16)), {})
+cnt: 6, ((T([3072], f16), T([512, 768], f16), T([768, 3072], f16)), {})
+cnt: 6, ((T([768], f16), T([512, 3072], f16), T([3072, 768], f16)), {})
+Operator: aten.bmm.default
+cnt: 6, ((T([12, 512, 64], f16, stride=(64, 2304, 1)), T([12, 64, 512], f16, stride=(64, 1, 2304))), {})
+cnt: 12, ((T([12, 512, 512], f16), T([12, 512, 64], f16, stride=(64, 2304, 1))), {})
+cnt: 6, ((T([12, 512, 512], f16, stride=(262144, 1, 512)), T([12, 512, 64], f16, stride=(64, 768, 1))), {})
+cnt: 6, ((T([12, 512, 64], f16, stride=(64, 768, 1)), T([12, 64, 512], f16, stride=(64, 1, 2304))), {})
+cnt: 6, ((T([12, 64, 512], f16, stride=(64, 1, 2304)), T([12, 512, 512], f16)), {})
+Operator: aten.cat.default
+cnt: 6, (([T([1, 512, 768], f16), T([1, 512, 768], f16, stride=(512, 1, 512)), T([1, 512, 768], f16)], 2), {})
+Operator: aten.clone.default
+cnt: 2, ((T([1, 512], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([1, 512], i64), T([1, 512], i64)), {})
+Operator: aten.div.Tensor
+cnt: 12, ((T([1, 12, 512, 512], f16), T([], f16)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50257, 768], f16), T([1, 512], i64)), {})
+cnt: 1, ((T([1024, 768], f16), T([1, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 1024, -1, False), {})
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 50257, -1, False), {})
+Operator: aten.mm.default
+cnt: 1, ((T([512, 768], f16), T([768, 50257], f16, stride=(1, 768))), {})
+cnt: 1, ((T([50257, 512], f16, stride=(1, 50257)), T([512, 768], f16)), {})
+cnt: 1, ((T([512, 50257], f16), T([50257, 768], f16)), {})
+cnt: 6, ((T([512, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 6, ((T([3072, 512], f16, stride=(1, 3072)), T([512, 768], f16)), {})
+cnt: 6, ((T([512, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 6, ((T([768, 512], f16, stride=(1, 768)), T([512, 3072], f16)), {})
+cnt: 6, ((T([512, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 6, ((T([768, 512], f16, stride=(1, 768)), T([512, 768], f16)), {})
+cnt: 6, ((T([512, 2304], f16), T([2304, 768], f16, stride=(1, 2304))), {})
+cnt: 6, ((T([768, 512], f16, stride=(1, 768)), T([512, 2304], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 6, ((T([1, 512, 3072], f16), 3.0), {})
+Operator: aten.mul.Tensor
+cnt: 12, ((T([1, 512, 3072], f16), 0.5), {})
+cnt: 12, ((T([1, 512, 3072], f16), 0.044715), {})
+cnt: 12, ((T([1, 512, 3072], f16), 0.7978845608028654), {})
+cnt: 24, ((T([1, 512, 3072], f16), T([1, 512, 3072], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 13, ((T([1, 512, 768], f16), [768], T([768], f16), T([768], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 13, ((T([1, 512, 768], f16), T([1, 512, 768], f16), [768], T([1, 512, 1], f32), T([1, 512, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([511, 50257], f16), T([511], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([511, 50257], f16), T([511], i64), None, 1, -100), {})
+Operator: aten.pow.Tensor_Scalar
+cnt: 6, ((T([1, 512, 3072], f16), 3.0), {})
+cnt: 6, ((T([1, 512, 3072], f16), 2.0), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([1, 511, 50257], f16), [1, 511, 50257], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([1, 511, 50257], f16), [1, 512, 50257], 1, 0, -1, 1), {})
+Operator: aten.split.Tensor
+cnt: 6, ((T([1, 512, 2304], f16), 768, 2), {})
+Operator: aten.sum.SymInt
+cnt: 12, ((T([512, 768], f16), [0], True), {})
+cnt: 6, ((T([512, 3072], f16), [0], True), {})
+cnt: 6, ((T([512, 2304], f16), [0], True), {})
+Operator: aten.tanh.default
+cnt: 6, ((T([1, 512, 3072], f16),), {})
+Operator: aten.tanh_backward.default
+cnt: 6, ((T([1, 512, 3072], f16), T([1, 512, 3072], f16)), {})
+Operator: aten.where.self
+cnt: 12, ((T([1, 1, 512, 512], b8), T([1, 12, 512, 512], f16), T([], f16)), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/ElectraForCausalLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/ElectraForCausalLM_training.txt
new file mode 100644
index 0000000000000..adbb45be62697
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/ElectraForCausalLM_training.txt
@@ -0,0 +1,92 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([511, 30522], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([511, 30522], f16), T([511, 30522], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([1, 4, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([1, 4, 512, 512], f16), T([1, 4, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([1, 1, 1, 512], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 12, ((T([4, 512, 512], f16), [1, 4, 512, 512]), {})
+cnt: 12, ((T([4, 512, 64], f16), [1, 4, 512, 64]), {})
+cnt: 24, ((T([1, 512, 4, 64], f16), [1, 512, 256]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([1, 512, 128], f16), T([1, 512, 128], f16)), {})
+cnt: 12, ((T([1, 4, 512, 512], f16), T([1, 1, 1, 512], f16)), {})
+cnt: 72, ((T([1, 512, 256], f16), T([1, 512, 256], f16)), {})
+cnt: 1, ((T([30522, 128], f16), T([30522, 128], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([1, 512, 128], f16), T([1, 512, 128], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([256], f16), T([512, 128], f16), T([128, 256], f16, stride=(1, 128))), {})
+cnt: 48, ((T([256], f16), T([512, 256], f16), T([256, 256], f16, stride=(1, 256))), {})
+cnt: 12, ((T([1024], f16), T([512, 256], f16), T([256, 1024], f16, stride=(1, 256))), {})
+cnt: 12, ((T([256], f16), T([512, 1024], f16), T([1024, 256], f16, stride=(1, 1024))), {})
+cnt: 1, ((T([128], f16), T([512, 256], f16), T([256, 128], f16, stride=(1, 256))), {})
+cnt: 1, ((T([30522], f16), T([512, 128], f16), T([128, 30522], f16, stride=(1, 128))), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([4, 512, 64], f16, stride=(64, 256, 1)), T([4, 64, 512], f16, stride=(64, 1, 256))), {})
+cnt: 24, ((T([4, 512, 512], f16), T([4, 512, 64], f16, stride=(64, 256, 1))), {})
+cnt: 12, ((T([4, 512, 512], f16, stride=(262144, 1, 512)), T([4, 512, 64], f16, stride=(64, 256, 1))), {})
+cnt: 12, ((T([4, 64, 512], f16, stride=(64, 1, 256)), T([4, 512, 512], f16)), {})
+Operator: aten.clone.default
+cnt: 2, ((T([1, 512], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([1, 512], i64), T([1, 512], i64)), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([1, 4, 512, 512], f16), 8.0), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30522, 128], f16), T([1, 512], i64), 0), {})
+cnt: 1, ((T([2, 128], f16), T([1, 512], i64)), {})
+cnt: 1, ((T([512, 128], f16), T([1, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 128], f16), T([1, 512], i64), 512, -1, False), {})
+cnt: 1, ((T([1, 512, 128], f16), T([1, 512], i64), 2, -1, False), {})
+cnt: 1, ((T([1, 512, 128], f16), T([1, 512], i64), 30522, 0, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([1, 512, 1024], f16),), {})
+cnt: 1, ((T([1, 512, 128], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([1, 512, 128], f16), T([1, 512, 128], f16)), {})
+cnt: 12, ((T([1, 512, 1024], f16), T([1, 512, 1024], f16)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([512, 30522], f16), T([30522, 128], f16)), {})
+cnt: 1, ((T([30522, 512], f16, stride=(1, 30522)), T([512, 128], f16)), {})
+cnt: 1, ((T([512, 128], f16), T([128, 256], f16)), {})
+cnt: 1, ((T([128, 512], f16, stride=(1, 128)), T([512, 256], f16)), {})
+cnt: 12, ((T([512, 256], f16), T([256, 1024], f16)), {})
+cnt: 12, ((T([256, 512], f16, stride=(1, 256)), T([512, 1024], f16)), {})
+cnt: 12, ((T([512, 1024], f16), T([1024, 256], f16)), {})
+cnt: 12, ((T([1024, 512], f16, stride=(1, 1024)), T([512, 256], f16)), {})
+cnt: 36, ((T([512, 256], f16), T([256, 256], f16)), {})
+cnt: 36, ((T([256, 512], f16, stride=(1, 256)), T([512, 256], f16)), {})
+cnt: 12, ((T([512, 256], f16, stride=(1, 512)), T([256, 256], f16)), {})
+cnt: 12, ((T([256, 512], f16), T([512, 256], f16)), {})
+cnt: 1, ((T([512, 256], f16), T([256, 128], f16)), {})
+cnt: 1, ((T([256, 512], f16, stride=(1, 256)), T([512, 128], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([1, 1, 1, 512], f16), -65504.0), {})
+Operator: aten.native_layer_norm.default
+cnt: 2, ((T([1, 512, 128], f16), [128], T([128], f16), T([128], f16), 1e-12), {})
+cnt: 24, ((T([1, 512, 256], f16), [256], T([256], f16), T([256], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 2, ((T([1, 512, 128], f16), T([1, 512, 128], f16), [128], T([1, 512, 1], f32), T([1, 512, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+cnt: 24, ((T([1, 512, 256], f16), T([1, 512, 256], f16), [256], T([1, 512, 1], f32), T([1, 512, 1], f32), T([256], f16), T([256], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([511, 30522], f16), T([511], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([511, 30522], f16), T([511], i64), None, 1, -100), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([1, 1, 1, 512], f16), 1.0), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([1, 511, 30522], f16), [1, 511, 30522], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([1, 511, 30522], f16), [1, 512, 30522], 1, 0, -1, 1), {})
+cnt: 1, ((T([1, 512, 30522], f16), [1, 512, 30522], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([512, 30522], f16), [0], True), {})
+cnt: 1, ((T([512, 128], f16), [0], True), {})
+cnt: 49, ((T([512, 256], f16), [0], True), {})
+cnt: 12, ((T([512, 1024], f16), [0], True), {})
+cnt: 12, ((T([512, 256], f16, stride=(1, 512)), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/ElectraForQuestionAnswering_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/ElectraForQuestionAnswering_training.txt
new file mode 100644
index 0000000000000..c2e4a8beb5222
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/ElectraForQuestionAnswering_training.txt
@@ -0,0 +1,94 @@
+Operator: aten._log_softmax.default
+cnt: 2, ((T([64, 512], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 2, ((T([64, 512], f16), T([64, 512], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([64, 4, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([64, 4, 512, 512], f16), T([64, 4, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([64, 1, 1, 512], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([64, 4, 512, 64], f16), [256, 512, 64]), {})
+cnt: 12, ((T([64, 4, 64, 512], f16), [256, 64, 512]), {})
+cnt: 12, ((T([256, 512, 512], f16), [64, 4, 512, 512]), {})
+cnt: 12, ((T([256, 512, 64], f16), [64, 4, 512, 64]), {})
+cnt: 24, ((T([64, 512, 4, 64], f16), [64, 512, 256]), {})
+cnt: 12, ((T([64, 512, 256], f16), [32768, 256]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([64, 512, 128], f16), T([64, 512, 128], f16)), {})
+cnt: 12, ((T([64, 4, 512, 512], f16), T([64, 1, 1, 512], f16)), {})
+cnt: 72, ((T([64, 512, 256], f16), T([64, 512, 256], f16)), {})
+cnt: 1, ((T([], f16), T([], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([64, 512, 128], f16), T([1, 512, 128], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([256], f16), T([32768, 128], f16), T([128, 256], f16, stride=(1, 128))), {})
+cnt: 48, ((T([256], f16), T([32768, 256], f16), T([256, 256], f16, stride=(1, 256))), {})
+cnt: 12, ((T([1024], f16), T([32768, 256], f16), T([256, 1024], f16, stride=(1, 256))), {})
+cnt: 12, ((T([256], f16), T([32768, 1024], f16), T([1024, 256], f16, stride=(1, 1024))), {})
+cnt: 1, ((T([2], f16), T([32768, 256], f16), T([256, 2], f16, stride=(1, 256))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([256, 512, 64], f16), T([256, 64, 512], f16)), {})
+cnt: 12, ((T([256, 512, 512], f16), T([256, 512, 64], f16)), {})
+cnt: 12, ((T([256, 512, 512], f16, stride=(262144, 1, 512)), T([256, 512, 64], f16)), {})
+cnt: 12, ((T([256, 512, 64], f16), T([256, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 12, ((T([256, 64, 512], f16, stride=(32768, 1, 64)), T([256, 512, 512], f16)), {})
+cnt: 12, ((T([256, 512, 512], f16), T([256, 512, 64], f16, stride=(32768, 1, 512))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 512, 1], f16), T([64, 512, 1], f16)], 2), {})
+Operator: aten.clamp.default
+cnt: 2, ((T([64], i64), 0, 512), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 512], i64),), {})
+cnt: 2, ((T([64], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 512], i64), T([64, 512], i64)), {})
+cnt: 2, ((T([64], i64), T([64], i64)), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([64, 4, 512, 512], f16), 8.0), {})
+cnt: 2, ((T([], f16), 2), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30522, 128], f16), T([64, 512], i64), 0), {})
+cnt: 1, ((T([2, 128], f16), T([64, 512], i64, stride=(0, 1))), {})
+cnt: 1, ((T([512, 128], f16), T([1, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 128], f16), T([1, 512], i64), 512, -1, False), {})
+cnt: 1, ((T([64, 512, 128], f16), T([64, 512], i64, stride=(0, 1)), 2, -1, False), {})
+cnt: 1, ((T([64, 512, 128], f16), T([64, 512], i64), 30522, 0, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([64, 512, 1024], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([64, 512, 1024], f16), T([64, 512, 1024], f16)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([32768, 2], f16), T([2, 256], f16)), {})
+cnt: 1, ((T([2, 32768], f16, stride=(1, 2)), T([32768, 256], f16)), {})
+cnt: 12, ((T([32768, 256], f16), T([256, 1024], f16)), {})
+cnt: 12, ((T([256, 32768], f16, stride=(1, 256)), T([32768, 1024], f16)), {})
+cnt: 12, ((T([32768, 1024], f16), T([1024, 256], f16)), {})
+cnt: 12, ((T([1024, 32768], f16, stride=(1, 1024)), T([32768, 256], f16)), {})
+cnt: 48, ((T([32768, 256], f16), T([256, 256], f16)), {})
+cnt: 48, ((T([256, 32768], f16, stride=(1, 256)), T([32768, 256], f16)), {})
+cnt: 1, ((T([32768, 256], f16), T([256, 128], f16)), {})
+cnt: 1, ((T([256, 32768], f16, stride=(1, 256)), T([32768, 128], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([64, 1, 1, 512], f16), -65504.0), {})
+Operator: aten.native_layer_norm.default
+cnt: 1, ((T([64, 512, 128], f16), [128], T([128], f16), T([128], f16), 1e-12), {})
+cnt: 24, ((T([64, 512, 256], f16), [256], T([256], f16), T([256], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 24, ((T([64, 512, 256], f16), T([64, 512, 256], f16), [256], T([64, 512, 1], f32), T([64, 512, 1], f32), T([256], f16), T([256], f16), [True, True, True]), {})
+cnt: 1, ((T([64, 512, 128], f16), T([64, 512, 128], f16), [128], T([64, 512, 1], f32), T([64, 512, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 2, ((T([], f16), T([64, 512], f16), T([64], i64), None, 1, 512, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 2, ((T([64, 512], f16), T([64], i64), None, 1, 512), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([64, 1, 1, 512], f16), 1.0), {})
+Operator: aten.split.Tensor
+cnt: 1, ((T([64, 512, 2], f16), 1, -1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32768, 2], f16), [0], True), {})
+cnt: 61, ((T([32768, 256], f16), [0], True), {})
+cnt: 12, ((T([32768, 1024], f16), [0], True), {})
+cnt: 1, ((T([64, 512, 128], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/GPT2ForSequenceClassification_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/GPT2ForSequenceClassification_training.txt
new file mode 100644
index 0000000000000..4be61bd96d909
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/GPT2ForSequenceClassification_training.txt
@@ -0,0 +1,106 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([4, 2], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([4, 2], f16), T([4, 2], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([4, 12, 1024, 1024], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([4, 12, 1024, 1024], f16), T([4, 12, 1024, 1024], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 12, ((T([1, 1, 1024, 1024], u8),), {'dtype': torch.bool})
+cnt: 12, ((T([], f16),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([4, 12, 1024, 64], f16), [48, 1024, 64]), {})
+cnt: 12, ((T([4, 12, 64, 1024], f16), [48, 64, 1024]), {})
+cnt: 12, ((T([48, 1024, 1024], f16), [4, 12, 1024, 1024]), {})
+cnt: 12, ((T([48, 1024, 64], f16), [4, 12, 1024, 64]), {})
+cnt: 1, ((T([4096, 2], f16), [4, 1024, 2]), {})
+cnt: 24, ((T([4, 1024, 12, 64], f16), [4, 1024, 768]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([4, 1024, 768], f16), T([1, 1024, 768], f16)), {})
+cnt: 48, ((T([4, 1024, 768], f16), T([4, 1024, 768], f16)), {})
+cnt: 36, ((T([4, 1024, 3072], f16), T([4, 1024, 3072], f16)), {})
+cnt: 12, ((T([4, 1024, 3072], f16), 1.0), {})
+Operator: aten.addmm.default
+cnt: 12, ((T([2304], f16), T([4096, 768], f16), T([768, 2304], f16)), {})
+cnt: 12, ((T([768], f16), T([4096, 768], f16), T([768, 768], f16)), {})
+cnt: 12, ((T([3072], f16), T([4096, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768], f16), T([4096, 3072], f16), T([3072, 768], f16)), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([48, 1024, 64], f16), T([48, 64, 1024], f16)), {})
+cnt: 12, ((T([48, 1024, 1024], f16), T([48, 1024, 64], f16)), {})
+cnt: 12, ((T([48, 1024, 1024], f16, stride=(1048576, 1, 1024)), T([48, 1024, 64], f16)), {})
+cnt: 12, ((T([48, 1024, 64], f16), T([48, 64, 1024], f16, stride=(65536, 1, 64))), {})
+cnt: 12, ((T([48, 64, 1024], f16, stride=(65536, 1, 64)), T([48, 1024, 1024], f16)), {})
+cnt: 12, ((T([48, 1024, 1024], f16), T([48, 1024, 64], f16, stride=(65536, 1, 1024))), {})
+Operator: aten.cat.default
+cnt: 12, (([T([4, 1024, 768], f16), T([4, 1024, 768], f16, stride=(786432, 1, 1024)), T([4, 1024, 768], f16)], 2), {})
+Operator: aten.clone.default
+cnt: 1, ((T([4, 1024], i64),), {})
+cnt: 1, ((T([4], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([4, 1024], i64), T([4, 1024], i64)), {})
+cnt: 1, ((T([4], i64), T([4], i64)), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([4, 12, 1024, 1024], f16), T([], f16)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50257, 768], f16), T([4, 1024], i64)), {})
+cnt: 1, ((T([1024, 768], f16), T([1, 1024], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 1024, 768], f16), T([1, 1024], i64), 1024, -1, False), {})
+cnt: 1, ((T([4, 1024, 768], f16), T([4, 1024], i64), 50257, -1, False), {})
+Operator: aten.index.Tensor
+cnt: 1, ((T([4, 1024, 2], f16), [T([4], i64), T([4], i64)]), {})
+Operator: aten.index_put.default
+cnt: 1, ((T([4, 1024, 2], f16), [T([4], i64), T([4], i64)], T([4, 2], f16), True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([4096, 768], f16), T([768, 2], f16, stride=(1, 768))), {})
+cnt: 1, ((T([2, 4096], f16, stride=(1, 2)), T([4096, 768], f16)), {})
+cnt: 1, ((T([4096, 2], f16), T([2, 768], f16)), {})
+cnt: 12, ((T([4096, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072, 4096], f16, stride=(1, 3072)), T([4096, 768], f16)), {})
+cnt: 12, ((T([4096, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 12, ((T([768, 4096], f16, stride=(1, 768)), T([4096, 3072], f16)), {})
+cnt: 12, ((T([4096, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768, 4096], f16, stride=(1, 768)), T([4096, 768], f16)), {})
+cnt: 12, ((T([4096, 2304], f16), T([2304, 768], f16, stride=(1, 2304))), {})
+cnt: 12, ((T([768, 4096], f16, stride=(1, 768)), T([4096, 2304], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 12, ((T([4, 1024, 3072], f16), 3.0), {})
+Operator: aten.mul.Tensor
+cnt: 24, ((T([4, 1024, 3072], f16), 0.5), {})
+cnt: 24, ((T([4, 1024, 3072], f16), 0.044715), {})
+cnt: 24, ((T([4, 1024, 3072], f16), 0.7978845608028654), {})
+cnt: 48, ((T([4, 1024, 3072], f16), T([4, 1024, 3072], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 25, ((T([4, 1024, 768], f16), [768], T([768], f16), T([768], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 25, ((T([4, 1024, 768], f16), T([4, 1024, 768], f16), [768], T([4, 1024, 1], f32), T([4, 1024, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.ne.Scalar
+cnt: 1, ((T([4, 1024], i64), 0), {})
+Operator: aten.new_zeros.default
+cnt: 1, ((T([4, 2], f16), [4, 1024, 2]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([4, 2], f16), T([4], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([4, 2], f16), T([4], i64), None, 1, -100), {})
+Operator: aten.pow.Tensor_Scalar
+cnt: 12, ((T([4, 1024, 3072], f16), 3.0), {})
+cnt: 12, ((T([4, 1024, 3072], f16), 2.0), {})
+Operator: aten.split.Tensor
+cnt: 12, ((T([4, 1024, 2304], f16), 768, 2), {})
+Operator: aten.sub.Tensor
+cnt: 1, ((T([4], i64), 1), {})
+Operator: aten.sum.SymInt
+cnt: 24, ((T([4096, 768], f16), [0], True), {})
+cnt: 12, ((T([4096, 3072], f16), [0], True), {})
+cnt: 12, ((T([4096, 2304], f16), [0], True), {})
+cnt: 1, ((T([4, 1024, 768], f16), [0], True), {})
+Operator: aten.sum.dim_IntList
+cnt: 1, ((T([4, 1024], b8), [-1]), {})
+Operator: aten.tanh.default
+cnt: 12, ((T([4, 1024, 3072], f16),), {})
+Operator: aten.tanh_backward.default
+cnt: 12, ((T([4, 1024, 3072], f16), T([4, 1024, 3072], f16)), {})
+Operator: aten.where.self
+cnt: 24, ((T([1, 1, 1024, 1024], b8), T([4, 12, 1024, 1024], f16), T([], f16)), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/GPTNeoForCausalLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/GPTNeoForCausalLM_training.txt
new file mode 100644
index 0000000000000..013350f4bc8cb
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/GPTNeoForCausalLM_training.txt
@@ -0,0 +1,96 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([127, 50257], f32), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([127, 50257], f32), T([127, 50257], f32), 1, f32), {})
+Operator: aten._softmax.default
+cnt: 24, ((T([1, 16, 128, 128], f32), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 24, ((T([1, 16, 128, 128], f32), T([1, 16, 128, 128], f32), -1, f32), {})
+Operator: aten._to_copy.default
+cnt: 48, ((T([1, 16, 128, 128], f16, stride=(262144, 128, 2048, 1)),), {'dtype': f32})
+cnt: 24, ((T([1, 1, 128, 128], u8, stride=(4194304, 4194304, 2048, 1)),), {'dtype': torch.bool})
+cnt: 24, ((T([], f32),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 24, ((T([1, 16, 128, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([1, 128, 50257], f16),), {'dtype': f32})
+cnt: 1, ((T([1, 128, 50257], f32),), {'dtype': f16})
+cnt: 1, ((T([], f32),), {'dtype': f16})
+cnt: 1, ((T([], f16),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([1, 128, 50257], f32),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 24, ((T([1, 16, 128, 128], f16),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 24, ((T([1, 16, 128, 128], f32, stride=(262144, 16384, 1, 128)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 24, ((T([1, 16, 128, 128], f32),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 72, ((T([128, 2048], f16), [1, 128, 2048]), {})
+cnt: 24, ((T([16, 128, 128], f32), [1, 16, 128, 128]), {})
+cnt: 24, ((T([16, 128, 128], f16), [1, 16, 128, 128]), {})
+cnt: 1, ((T([128, 50257], f16), [1, 128, 50257]), {})
+cnt: 48, ((T([1, 128, 16, 128], f16), [1, 128, 2048]), {})
+Operator: aten.add.Tensor
+cnt: 145, ((T([1, 128, 2048], f16), T([1, 128, 2048], f16)), {})
+cnt: 72, ((T([1, 128, 8192], f16), T([1, 128, 8192], f16)), {})
+cnt: 24, ((T([1, 128, 8192], f16), 1.0), {})
+cnt: 1, ((T([50257, 2048], f16), T([50257, 2048], f16)), {})
+Operator: aten.addmm.default
+cnt: 24, ((T([2048], f16), T([128, 2048], f16), T([2048, 2048], f16, stride=(1, 2048))), {})
+cnt: 24, ((T([8192], f16), T([128, 2048], f16), T([2048, 8192], f16, stride=(1, 2048))), {})
+cnt: 24, ((T([2048], f16), T([128, 8192], f16), T([8192, 2048], f16, stride=(1, 8192))), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([16, 128, 128], f32, stride=(128, 2048, 1)), T([16, 128, 128], f32, stride=(128, 1, 2048))), {})
+cnt: 24, ((T([16, 128, 128], f16), T([16, 128, 128], f16, stride=(128, 2048, 1))), {})
+cnt: 24, ((T([16, 128, 128], f16, stride=(16384, 1, 128)), T([16, 128, 128], f16, stride=(128, 2048, 1))), {})
+cnt: 24, ((T([16, 128, 128], f16, stride=(128, 2048, 1)), T([16, 128, 128], f16, stride=(128, 1, 2048))), {})
+cnt: 24, ((T([16, 128, 128], f32, stride=(128, 1, 2048)), T([16, 128, 128], f32)), {})
+cnt: 24, ((T([16, 128, 128], f32), T([16, 128, 128], f32, stride=(128, 2048, 1))), {})
+Operator: aten.clone.default
+cnt: 2, ((T([1, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([1, 128], i64), T([1, 128], i64)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50257, 2048], f16), T([1, 128], i64)), {})
+cnt: 1, ((T([2048, 2048], f16), T([1, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 128, 2048], f16), T([1, 128], i64), 2048, -1, False), {})
+cnt: 1, ((T([1, 128, 2048], f16), T([1, 128], i64), 50257, -1, False), {})
+Operator: aten.mm.default
+cnt: 72, ((T([128, 2048], f16), T([2048, 2048], f16, stride=(1, 2048))), {})
+cnt: 1, ((T([128, 2048], f16), T([2048, 50257], f16, stride=(1, 2048))), {})
+cnt: 1, ((T([50257, 128], f16, stride=(1, 50257)), T([128, 2048], f16)), {})
+cnt: 1, ((T([128, 50257], f16), T([50257, 2048], f16)), {})
+cnt: 24, ((T([128, 2048], f16), T([2048, 8192], f16)), {})
+cnt: 24, ((T([2048, 128], f16, stride=(1, 2048)), T([128, 8192], f16)), {})
+cnt: 24, ((T([128, 8192], f16), T([8192, 2048], f16)), {})
+cnt: 24, ((T([8192, 128], f16, stride=(1, 8192)), T([128, 2048], f16)), {})
+cnt: 72, ((T([128, 2048], f16), T([2048, 2048], f16)), {})
+cnt: 72, ((T([2048, 128], f16, stride=(1, 2048)), T([128, 2048], f16)), {})
+cnt: 24, ((T([2048, 128], f16), T([128, 2048], f16)), {})
+cnt: 24, ((T([128, 2048], f16, stride=(1, 128)), T([2048, 2048], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 24, ((T([1, 128, 8192], f16), 3.0), {})
+Operator: aten.mul.Tensor
+cnt: 48, ((T([1, 128, 8192], f16), 0.5), {})
+cnt: 48, ((T([1, 128, 8192], f16), 0.044715), {})
+cnt: 48, ((T([1, 128, 8192], f16), 0.7978845608028654), {})
+cnt: 96, ((T([1, 128, 8192], f16), T([1, 128, 8192], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 49, ((T([1, 128, 2048], f16), [2048], T([2048], f16), T([2048], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 49, ((T([1, 128, 2048], f16), T([1, 128, 2048], f16), [2048], T([1, 128, 1], f32), T([1, 128, 1], f32), T([2048], f16), T([2048], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f32), T([127, 50257], f32), T([127], i64), None, 1, -100, T([], f32)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([127, 50257], f32), T([127], i64), None, 1, -100), {})
+Operator: aten.pow.Tensor_Scalar
+cnt: 24, ((T([1, 128, 8192], f16), 3.0), {})
+cnt: 24, ((T([1, 128, 8192], f16), 2.0), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([1, 127, 50257], f32), [1, 127, 50257], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([1, 127, 50257], f32), [1, 128, 50257], 1, 0, -1, 1), {})
+Operator: aten.sum.SymInt
+cnt: 48, ((T([128, 2048], f16), [0], True), {})
+cnt: 24, ((T([128, 8192], f16), [0], True), {})
+Operator: aten.tanh.default
+cnt: 24, ((T([1, 128, 8192], f16),), {})
+Operator: aten.tanh_backward.default
+cnt: 24, ((T([1, 128, 8192], f16), T([1, 128, 8192], f16)), {})
+Operator: aten.where.self
+cnt: 48, ((T([1, 1, 128, 128], b8), T([1, 16, 128, 128], f32), T([], f32)), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/GPTNeoForSequenceClassification_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/GPTNeoForSequenceClassification_training.txt
new file mode 100644
index 0000000000000..a537c2d6c04fb
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/GPTNeoForSequenceClassification_training.txt
@@ -0,0 +1,101 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([1, 2], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([1, 2], f16), T([1, 2], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 24, ((T([1, 16, 128, 128], f32), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 24, ((T([1, 16, 128, 128], f32), T([1, 16, 128, 128], f32), -1, f32), {})
+Operator: aten._to_copy.default
+cnt: 48, ((T([1, 16, 128, 128], f16, stride=(262144, 128, 2048, 1)),), {'dtype': f32})
+cnt: 24, ((T([1, 1, 128, 128], u8, stride=(4194304, 4194304, 2048, 1)),), {'dtype': torch.bool})
+cnt: 24, ((T([], f32),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 24, ((T([1, 16, 128, 128], f32),), {'dtype': f16})
+cnt: 24, ((T([1, 16, 128, 128], f16),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 24, ((T([1, 16, 128, 128], f32, stride=(262144, 16384, 1, 128)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 24, ((T([1, 16, 128, 128], f32),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 72, ((T([128, 2048], f16), [1, 128, 2048]), {})
+cnt: 24, ((T([16, 128, 128], f32), [1, 16, 128, 128]), {})
+cnt: 24, ((T([16, 128, 128], f16), [1, 16, 128, 128]), {})
+cnt: 1, ((T([128, 2], f16), [1, 128, 2]), {})
+cnt: 48, ((T([1, 128, 16, 128], f16), [1, 128, 2048]), {})
+Operator: aten.add.Tensor
+cnt: 145, ((T([1, 128, 2048], f16), T([1, 128, 2048], f16)), {})
+cnt: 72, ((T([1, 128, 8192], f16), T([1, 128, 8192], f16)), {})
+cnt: 24, ((T([1, 128, 8192], f16), 1.0), {})
+Operator: aten.addmm.default
+cnt: 24, ((T([2048], f16), T([128, 2048], f16), T([2048, 2048], f16, stride=(1, 2048))), {})
+cnt: 24, ((T([8192], f16), T([128, 2048], f16), T([2048, 8192], f16, stride=(1, 2048))), {})
+cnt: 24, ((T([2048], f16), T([128, 8192], f16), T([8192, 2048], f16, stride=(1, 8192))), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([16, 128, 128], f32, stride=(128, 2048, 1)), T([16, 128, 128], f32, stride=(128, 1, 2048))), {})
+cnt: 24, ((T([16, 128, 128], f16), T([16, 128, 128], f16, stride=(128, 2048, 1))), {})
+cnt: 24, ((T([16, 128, 128], f16, stride=(16384, 1, 128)), T([16, 128, 128], f16, stride=(128, 2048, 1))), {})
+cnt: 24, ((T([16, 128, 128], f16, stride=(128, 2048, 1)), T([16, 128, 128], f16, stride=(128, 1, 2048))), {})
+cnt: 24, ((T([16, 128, 128], f32, stride=(128, 1, 2048)), T([16, 128, 128], f32)), {})
+cnt: 24, ((T([16, 128, 128], f32), T([16, 128, 128], f32, stride=(128, 2048, 1))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([1, 128], i64),), {})
+cnt: 1, ((T([1], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([1, 128], i64), T([1, 128], i64)), {})
+cnt: 1, ((T([1], i64), T([1], i64)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50257, 2048], f16), T([1, 128], i64)), {})
+cnt: 1, ((T([2048, 2048], f16), T([1, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 128, 2048], f16), T([1, 128], i64), 2048, -1, False), {})
+cnt: 1, ((T([1, 128, 2048], f16), T([1, 128], i64), 50257, -1, False), {})
+Operator: aten.index.Tensor
+cnt: 1, ((T([1, 128, 2], f16), [T([1], i64), T([1], i64)]), {})
+Operator: aten.index_put.default
+cnt: 1, ((T([1, 128, 2], f16), [T([1], i64), T([1], i64)], T([1, 2], f16), True), {})
+Operator: aten.mm.default
+cnt: 72, ((T([128, 2048], f16), T([2048, 2048], f16, stride=(1, 2048))), {})
+cnt: 1, ((T([128, 2048], f16), T([2048, 2], f16, stride=(1, 2048))), {})
+cnt: 1, ((T([2, 128], f16, stride=(1, 2)), T([128, 2048], f16)), {})
+cnt: 1, ((T([128, 2], f16), T([2, 2048], f16)), {})
+cnt: 24, ((T([128, 2048], f16), T([2048, 8192], f16)), {})
+cnt: 24, ((T([2048, 128], f16, stride=(1, 2048)), T([128, 8192], f16)), {})
+cnt: 24, ((T([128, 8192], f16), T([8192, 2048], f16)), {})
+cnt: 24, ((T([8192, 128], f16, stride=(1, 8192)), T([128, 2048], f16)), {})
+cnt: 72, ((T([128, 2048], f16), T([2048, 2048], f16)), {})
+cnt: 72, ((T([2048, 128], f16, stride=(1, 2048)), T([128, 2048], f16)), {})
+cnt: 24, ((T([2048, 128], f16), T([128, 2048], f16)), {})
+cnt: 24, ((T([128, 2048], f16, stride=(1, 128)), T([2048, 2048], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 24, ((T([1, 128, 8192], f16), 3.0), {})
+Operator: aten.mul.Tensor
+cnt: 48, ((T([1, 128, 8192], f16), 0.5), {})
+cnt: 48, ((T([1, 128, 8192], f16), 0.044715), {})
+cnt: 48, ((T([1, 128, 8192], f16), 0.7978845608028654), {})
+cnt: 96, ((T([1, 128, 8192], f16), T([1, 128, 8192], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 49, ((T([1, 128, 2048], f16), [2048], T([2048], f16), T([2048], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 49, ((T([1, 128, 2048], f16), T([1, 128, 2048], f16), [2048], T([1, 128, 1], f32), T([1, 128, 1], f32), T([2048], f16), T([2048], f16), [True, True, True]), {})
+Operator: aten.ne.Scalar
+cnt: 1, ((T([1, 128], i64), 0), {})
+Operator: aten.new_zeros.default
+cnt: 1, ((T([1, 2], f16), [1, 128, 2]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([1, 2], f16), T([1], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([1, 2], f16), T([1], i64), None, 1, -100), {})
+Operator: aten.pow.Tensor_Scalar
+cnt: 24, ((T([1, 128, 8192], f16), 3.0), {})
+cnt: 24, ((T([1, 128, 8192], f16), 2.0), {})
+Operator: aten.sub.Tensor
+cnt: 1, ((T([1], i64), 1), {})
+Operator: aten.sum.SymInt
+cnt: 48, ((T([128, 2048], f16), [0], True), {})
+cnt: 24, ((T([128, 8192], f16), [0], True), {})
+Operator: aten.sum.dim_IntList
+cnt: 1, ((T([1, 128], b8), [-1]), {})
+Operator: aten.tanh.default
+cnt: 24, ((T([1, 128, 8192], f16),), {})
+Operator: aten.tanh_backward.default
+cnt: 24, ((T([1, 128, 8192], f16), T([1, 128, 8192], f16)), {})
+Operator: aten.where.self
+cnt: 48, ((T([1, 1, 128, 128], b8), T([1, 16, 128, 128], f32), T([], f32)), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/GoogleFnet_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/GoogleFnet_training.txt
new file mode 100644
index 0000000000000..c234ce838bf7b
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/GoogleFnet_training.txt
@@ -0,0 +1,83 @@
+Operator: aten._fft_c2c.default
+cnt: 12, ((T([1, 512, 768], c32), [1, 2], 0, True), {})
+cnt: 12, ((T([1, 512, 768], c32), [1, 2], 0, False), {})
+Operator: aten._log_softmax.default
+cnt: 1, ((T([512, 32000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([512, 32000], f16), T([512, 32000], f16), 1, f16), {})
+Operator: aten._to_copy.default
+cnt: 12, ((T([1, 512, 768], f16),), {'dtype': c32})
+Operator: aten.add.Tensor
+cnt: 28, ((T([1, 512, 768], f16), T([1, 512, 768], f16)), {})
+cnt: 24, ((T([1, 512, 768], f16), T([1, 512, 768], f16, stride=(786432, 1536, 2))), {})
+cnt: 36, ((T([1, 512, 3072], f16), T([1, 512, 3072], f16)), {})
+cnt: 12, ((T([1, 512, 3072], f16), 1.0), {})
+cnt: 1, ((T([1, 512, 768], f16), 1.0), {})
+cnt: 1, ((T([32000, 768], f16), T([32000, 768], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 2, ((T([768], f16), T([512, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([512, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([512, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([768], f16), T([1, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 1, ((T([32000], f16), T([512, 768], f16), T([768, 32000], f16, stride=(1, 768))), {})
+Operator: aten.clone.default
+cnt: 2, ((T([1, 512], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([1, 512], i64), T([1, 512], i64)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([32000, 768], f16), T([1, 512], i64), 3), {})
+cnt: 1, ((T([4, 768], f16), T([1, 512], i64)), {})
+cnt: 1, ((T([512, 768], f16), T([1, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 512, -1, False), {})
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 4, -1, False), {})
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 32000, 3, False), {})
+Operator: aten.mm.default
+cnt: 1, ((T([512, 32000], f16), T([32000, 768], f16)), {})
+cnt: 1, ((T([32000, 512], f16, stride=(1, 32000)), T([512, 768], f16)), {})
+cnt: 2, ((T([512, 768], f16), T([768, 768], f16)), {})
+cnt: 2, ((T([768, 512], f16, stride=(1, 768)), T([512, 768], f16)), {})
+cnt: 12, ((T([512, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 512], f16, stride=(1, 768)), T([512, 3072], f16)), {})
+cnt: 12, ((T([512, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 512], f16, stride=(1, 3072)), T([512, 768], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 1, ((T([1, 512, 768], f16), 3.0), {})
+cnt: 12, ((T([1, 512, 3072], f16), 3.0), {})
+Operator: aten.mul.Tensor
+cnt: 24, ((T([1, 512, 3072], f16), 0.5), {})
+cnt: 24, ((T([1, 512, 3072], f16), 0.044715), {})
+cnt: 24, ((T([1, 512, 3072], f16), 0.7978845608028654), {})
+cnt: 48, ((T([1, 512, 3072], f16), T([1, 512, 3072], f16)), {})
+cnt: 2, ((T([1, 512, 768], f16), 0.5), {})
+cnt: 2, ((T([1, 512, 768], f16), 0.044715), {})
+cnt: 2, ((T([1, 512, 768], f16), 0.7978845608028654), {})
+cnt: 4, ((T([1, 512, 768], f16), T([1, 512, 768], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 26, ((T([1, 512, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 26, ((T([1, 512, 768], f16), T([1, 512, 768], f16), [768], T([1, 512, 1], f32), T([1, 512, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([512, 32000], f16), T([512], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([512, 32000], f16), T([512], i64), None, 1, -100), {})
+Operator: aten.pow.Tensor_Scalar
+cnt: 12, ((T([1, 512, 3072], f16), 3.0), {})
+cnt: 1, ((T([1, 512, 768], f16), 3.0), {})
+cnt: 1, ((T([1, 512, 768], f16), 2.0), {})
+cnt: 12, ((T([1, 512, 3072], f16), 2.0), {})
+Operator: aten.select_backward.default
+cnt: 12, ((T([1, 512, 768], f16), [1, 512, 768, 2], 3, 0), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([512, 32000], f16), [0], True), {})
+cnt: 14, ((T([512, 768], f16), [0], True), {})
+cnt: 12, ((T([512, 3072], f16), [0], True), {})
+Operator: aten.tanh.default
+cnt: 12, ((T([1, 512, 3072], f16),), {})
+cnt: 1, ((T([1, 768], f16),), {})
+cnt: 1, ((T([1, 512, 768], f16),), {})
+Operator: aten.tanh_backward.default
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512, 768], f16)), {})
+cnt: 12, ((T([1, 512, 3072], f16), T([1, 512, 3072], f16)), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/LayoutLMForMaskedLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/LayoutLMForMaskedLM_training.txt
new file mode 100644
index 0000000000000..e10fea3367ca7
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/LayoutLMForMaskedLM_training.txt
@@ -0,0 +1,90 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([8192, 30522], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([8192, 30522], f16), T([8192, 30522], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([16, 12, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([16, 12, 512, 512], f16), T([16, 12, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([16, 1, 1, 512], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([16, 12, 512, 64], f16), [192, 512, 64]), {})
+cnt: 12, ((T([16, 12, 64, 512], f16), [192, 64, 512]), {})
+cnt: 12, ((T([192, 512, 512], f16), [16, 12, 512, 512]), {})
+cnt: 12, ((T([192, 512, 64], f16), [16, 12, 512, 64]), {})
+cnt: 24, ((T([16, 512, 12, 64], f16), [16, 512, 768]), {})
+cnt: 12, ((T([16, 512, 768], f16), [8192, 768]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([16, 512, 768], f16), T([1, 512, 768], f16)), {})
+cnt: 79, ((T([16, 512, 768], f16), T([16, 512, 768], f16)), {})
+cnt: 12, ((T([16, 12, 512, 512], f16), T([16, 1, 1, 512], f16)), {})
+cnt: 2, ((T([1024, 768], f16), T([1024, 768], f16)), {})
+cnt: 1, ((T([30522, 768], f16), T([30522, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 49, ((T([768], f16), T([8192, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([8192, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([8192, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([768], f16), T([16, 768], f16, stride=(393216, 1)), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 1, ((T([30522], f16), T([8192, 768], f16), T([768, 30522], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([192, 512, 64], f16), T([192, 64, 512], f16)), {})
+cnt: 12, ((T([192, 512, 512], f16), T([192, 512, 64], f16)), {})
+cnt: 12, ((T([192, 512, 512], f16, stride=(262144, 1, 512)), T([192, 512, 64], f16)), {})
+cnt: 12, ((T([192, 512, 64], f16), T([192, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 12, ((T([192, 64, 512], f16, stride=(32768, 1, 64)), T([192, 512, 512], f16)), {})
+cnt: 12, ((T([192, 512, 512], f16), T([192, 512, 64], f16, stride=(32768, 1, 512))), {})
+Operator: aten.clone.default
+cnt: 2, ((T([16, 512], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([16, 512], i64), T([16, 512], i64)), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([16, 12, 512, 512], f16), 8.0), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30522, 768], f16), T([16, 512], i64), 0), {})
+cnt: 1, ((T([512, 768], f16), T([1, 512], i64)), {})
+cnt: 4, ((T([1024, 768], f16), T([16, 512], i64, stride=(2048, 4))), {})
+cnt: 2, ((T([1024, 768], f16), T([16, 512], i64)), {})
+cnt: 1, ((T([2, 768], f16), T([16, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([16, 512, 768], f16), T([16, 512], i64), 2, -1, False), {})
+cnt: 2, ((T([16, 512, 768], f16), T([16, 512], i64), 1024, -1, False), {})
+cnt: 4, ((T([16, 512, 768], f16), T([16, 512], i64, stride=(2048, 4)), 1024, -1, False), {})
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 512, -1, False), {})
+cnt: 1, ((T([16, 512, 768], f16), T([16, 512], i64), 30522, 0, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([16, 512, 3072], f16),), {})
+cnt: 1, ((T([16, 512, 768], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([16, 512, 768], f16), T([16, 512, 768], f16)), {})
+cnt: 12, ((T([16, 512, 3072], f16), T([16, 512, 3072], f16)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([8192, 30522], f16), T([30522, 768], f16)), {})
+cnt: 1, ((T([30522, 8192], f16, stride=(1, 30522)), T([8192, 768], f16)), {})
+cnt: 49, ((T([8192, 768], f16), T([768, 768], f16)), {})
+cnt: 49, ((T([768, 8192], f16, stride=(1, 768)), T([8192, 768], f16)), {})
+cnt: 12, ((T([8192, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 8192], f16, stride=(1, 768)), T([8192, 3072], f16)), {})
+cnt: 12, ((T([8192, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 8192], f16, stride=(1, 3072)), T([8192, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([16, 1, 1, 512], f16), -65504.0), {})
+Operator: aten.native_layer_norm.default
+cnt: 26, ((T([16, 512, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 26, ((T([16, 512, 768], f16), T([16, 512, 768], f16), [768], T([16, 512, 1], f32), T([16, 512, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([8192, 30522], f16), T([8192], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([8192, 30522], f16), T([8192], i64), None, 1, -100), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([16, 1, 1, 512], f16), 1.0), {})
+Operator: aten.sub.Tensor
+cnt: 2, ((T([16, 512], i64, stride=(2048, 4)), T([16, 512], i64, stride=(2048, 4))), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([8192, 30522], f16), [0], True), {})
+cnt: 61, ((T([8192, 768], f16), [0], True), {})
+cnt: 12, ((T([8192, 3072], f16), [0], True), {})
+cnt: 1, ((T([16, 512, 768], f16), [0], True), {})
+Operator: aten.tanh.default
+cnt: 1, ((T([16, 768], f16),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/LayoutLMForSequenceClassification_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/LayoutLMForSequenceClassification_training.txt
new file mode 100644
index 0000000000000..3d06f14961a04
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/LayoutLMForSequenceClassification_training.txt
@@ -0,0 +1,98 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([16, 2], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([16, 2], f16), T([16, 2], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([16, 12, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([16, 12, 512, 512], f16), T([16, 12, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([16, 1, 1, 512], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([16, 12, 512, 64], f16), [192, 512, 64]), {})
+cnt: 12, ((T([16, 12, 64, 512], f16), [192, 64, 512]), {})
+cnt: 12, ((T([192, 512, 512], f16), [16, 12, 512, 512]), {})
+cnt: 12, ((T([192, 512, 64], f16), [16, 12, 512, 64]), {})
+cnt: 24, ((T([16, 512, 12, 64], f16), [16, 512, 768]), {})
+cnt: 12, ((T([16, 512, 768], f16), [8192, 768]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([16, 512, 768], f16), T([1, 512, 768], f16)), {})
+cnt: 79, ((T([16, 512, 768], f16), T([16, 512, 768], f16)), {})
+cnt: 12, ((T([16, 12, 512, 512], f16), T([16, 1, 1, 512], f16)), {})
+cnt: 2, ((T([1024, 768], f16), T([1024, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 48, ((T([768], f16), T([8192, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([8192, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([8192, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([768], f16), T([16, 768], f16, stride=(393216, 1)), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 1, ((T([2], f16), T([16, 768], f16), T([768, 2], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([192, 512, 64], f16), T([192, 64, 512], f16)), {})
+cnt: 12, ((T([192, 512, 512], f16), T([192, 512, 64], f16)), {})
+cnt: 12, ((T([192, 512, 512], f16, stride=(262144, 1, 512)), T([192, 512, 64], f16)), {})
+cnt: 12, ((T([192, 512, 64], f16), T([192, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 12, ((T([192, 64, 512], f16, stride=(32768, 1, 64)), T([192, 512, 512], f16)), {})
+cnt: 12, ((T([192, 512, 512], f16), T([192, 512, 64], f16, stride=(32768, 1, 512))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([16, 512], i64),), {})
+cnt: 1, ((T([16], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([16, 512], i64), T([16, 512], i64)), {})
+cnt: 1, ((T([16], i64), T([16], i64)), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([16, 12, 512, 512], f16), 8.0), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30522, 768], f16), T([16, 512], i64), 0), {})
+cnt: 1, ((T([512, 768], f16), T([1, 512], i64)), {})
+cnt: 4, ((T([1024, 768], f16), T([16, 512], i64, stride=(2048, 4))), {})
+cnt: 2, ((T([1024, 768], f16), T([16, 512], i64)), {})
+cnt: 1, ((T([2, 768], f16), T([16, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([16, 512, 768], f16), T([16, 512], i64), 2, -1, False), {})
+cnt: 2, ((T([16, 512, 768], f16), T([16, 512], i64), 1024, -1, False), {})
+cnt: 4, ((T([16, 512, 768], f16), T([16, 512], i64, stride=(2048, 4)), 1024, -1, False), {})
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 512, -1, False), {})
+cnt: 1, ((T([16, 512, 768], f16), T([16, 512], i64), 30522, 0, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([16, 512, 3072], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([16, 512, 3072], f16), T([16, 512, 3072], f16)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([16, 2], f16), T([2, 768], f16)), {})
+cnt: 1, ((T([2, 16], f16, stride=(1, 2)), T([16, 768], f16)), {})
+cnt: 1, ((T([16, 768], f16), T([768, 768], f16)), {})
+cnt: 1, ((T([768, 16], f16, stride=(1, 768)), T([16, 768], f16, stride=(393216, 1))), {})
+cnt: 12, ((T([8192, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 8192], f16, stride=(1, 768)), T([8192, 3072], f16)), {})
+cnt: 12, ((T([8192, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 8192], f16, stride=(1, 3072)), T([8192, 768], f16)), {})
+cnt: 48, ((T([8192, 768], f16), T([768, 768], f16)), {})
+cnt: 48, ((T([768, 8192], f16, stride=(1, 768)), T([8192, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([16, 1, 1, 512], f16), -65504.0), {})
+Operator: aten.native_layer_norm.default
+cnt: 25, ((T([16, 512, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 25, ((T([16, 512, 768], f16), T([16, 512, 768], f16), [768], T([16, 512, 1], f32), T([16, 512, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([16, 2], f16), T([16], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([16, 2], f16), T([16], i64), None, 1, -100), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([16, 1, 1, 512], f16), 1.0), {})
+Operator: aten.select_backward.default
+cnt: 1, ((T([16, 768], f16), [16, 512, 768], 1, 0), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([16, 512, 768], f16), [16, 512, 768], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.sub.Tensor
+cnt: 2, ((T([16, 512], i64, stride=(2048, 4)), T([16, 512], i64, stride=(2048, 4))), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([16, 2], f16), [0], True), {})
+cnt: 1, ((T([16, 768], f16), [0], True), {})
+cnt: 60, ((T([8192, 768], f16), [0], True), {})
+cnt: 12, ((T([8192, 3072], f16), [0], True), {})
+cnt: 1, ((T([16, 512, 768], f16), [0], True), {})
+Operator: aten.tanh.default
+cnt: 1, ((T([16, 768], f16),), {})
+Operator: aten.tanh_backward.default
+cnt: 1, ((T([16, 768], f16), T([16, 768], f16)), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/M2M100ForConditionalGeneration_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/M2M100ForConditionalGeneration_training.txt
new file mode 100644
index 0000000000000..bafa9de2de0a6
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/M2M100ForConditionalGeneration_training.txt
@@ -0,0 +1,88 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([256, 128112], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([256, 128112], f16), T([256, 128112], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 36, ((T([32, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 36, ((T([32, 128, 128], f16), T([32, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 2, ((T([2, 128], b8),), {'dtype': i32})
+cnt: 2, ((T([2, 128], i64),), {'dtype': i32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 2, ((T([2, 128], i32),), {'dtype': i64})
+cnt: 1, ((T([128, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([2, 1, 128, 128], f16, stride=(0, 16384, 128, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 108, ((T([2, 128, 16, 64], f16), [2, 128, 1024]), {})
+cnt: 1, ((T([256, 128112], f16), [2, 128, 128112]), {})
+cnt: 36, ((T([2, 16, 128, 64], f16), [32, 128, 64]), {})
+cnt: 36, ((T([2, 128, 1024], f16), [256, 1024]), {})
+Operator: aten.add.Tensor
+cnt: 2, ((T([2, 128], i32), 0), {})
+cnt: 2, ((T([2, 128], i64), 1), {})
+cnt: 193, ((T([2, 128, 1024], f16), T([2, 128, 1024], f16)), {})
+cnt: 1, ((T([128], i64), 1), {})
+cnt: 12, ((T([2, 16, 128, 128], f16), T([2, 1, 128, 128], f16)), {})
+cnt: 2, ((T([128112, 1024], f16), T([128112, 1024], f16)), {})
+Operator: aten.addmm.default
+cnt: 144, ((T([1024], f16), T([256, 1024], f16), T([1024, 1024], f16, stride=(1, 1024))), {})
+cnt: 24, ((T([4096], f16), T([256, 1024], f16), T([1024, 4096], f16, stride=(1, 1024))), {})
+cnt: 24, ((T([1024], f16), T([256, 4096], f16), T([4096, 1024], f16, stride=(1, 4096))), {})
+Operator: aten.any.default
+cnt: 24, ((T([2, 128, 1024], b8),), {})
+Operator: aten.bmm.default
+cnt: 72, ((T([32, 128, 64], f16), T([32, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 72, ((T([32, 128, 128], f16), T([32, 128, 64], f16)), {})
+cnt: 36, ((T([32, 128, 128], f16, stride=(16384, 1, 128)), T([32, 128, 64], f16)), {})
+cnt: 36, ((T([32, 64, 128], f16, stride=(8192, 1, 64)), T([32, 128, 128], f16)), {})
+Operator: aten.clone.default
+cnt: 3, ((T([2, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 3, ((T([2, 128], i64), T([2, 128], i64)), {})
+Operator: aten.cumsum.default
+cnt: 2, ((T([2, 128], i32), 1), {})
+Operator: aten.embedding.default
+cnt: 2, ((T([128112, 1024], f16), T([2, 128], i64), 1), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 2, ((T([2, 128, 1024], f16), T([2, 128], i64), 128112, 1, False), {})
+Operator: aten.index_select.default
+cnt: 2, ((T([1026, 1024], f16), 0, T([256], i64)), {})
+Operator: aten.isinf.default
+cnt: 12, ((T([2, 128, 1024], f16),), {})
+Operator: aten.isnan.default
+cnt: 12, ((T([2, 128, 1024], f16),), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([128], i64), T([128, 1], i64)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([128, 128], f32), T([128, 128], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([256, 1024], f16), T([1024, 128112], f16, stride=(1, 1024))), {})
+cnt: 1, ((T([128112, 256], f16, stride=(1, 128112)), T([256, 1024], f16)), {})
+cnt: 1, ((T([256, 128112], f16), T([128112, 1024], f16)), {})
+cnt: 24, ((T([256, 1024], f16), T([1024, 4096], f16)), {})
+cnt: 24, ((T([1024, 256], f16, stride=(1, 1024)), T([256, 4096], f16)), {})
+cnt: 24, ((T([256, 4096], f16), T([4096, 1024], f16)), {})
+cnt: 24, ((T([4096, 256], f16, stride=(1, 4096)), T([256, 1024], f16)), {})
+cnt: 144, ((T([256, 1024], f16), T([1024, 1024], f16)), {})
+cnt: 144, ((T([1024, 256], f16, stride=(1, 1024)), T([256, 1024], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 4, ((T([2, 128, 1024], f16), 32.0), {})
+cnt: 2, ((T([2, 128], i32), T([2, 128], i32)), {})
+cnt: 72, ((T([2, 128, 1024], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 62, ((T([2, 128, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 62, ((T([2, 128, 1024], f16), T([2, 128, 1024], f16), [1024], T([2, 128, 1], f32), T([2, 128, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+Operator: aten.ne.Scalar
+cnt: 2, ((T([2, 128], i64), 1), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([256, 128112], f16), T([256], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([256, 128112], f16), T([256], i64), None, 1, -100), {})
+Operator: aten.relu.default
+cnt: 24, ((T([2, 128, 4096], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 168, ((T([256, 1024], f16), [0], True), {})
+cnt: 24, ((T([256, 4096], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 24, ((T([2, 128, 4096], f16), T([2, 128, 4096], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MBartForCausalLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MBartForCausalLM_training.txt
new file mode 100644
index 0000000000000..288b2cd2cbb2e
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MBartForCausalLM_training.txt
@@ -0,0 +1,73 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([2048, 50265], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([2048, 50265], f16), T([2048, 50265], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([256, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([256, 128, 128], f16), T([256, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([128, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([16, 1, 128, 128], f16, stride=(0, 16384, 128, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([16, 128, 16, 64], f16), [16, 128, 1024]), {})
+cnt: 1, ((T([2048, 50265], f16), [16, 128, 50265]), {})
+cnt: 12, ((T([16, 16, 128, 64], f16), [256, 128, 64]), {})
+cnt: 12, ((T([16, 128, 1024], f16), [2048, 1024]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([128], i64), 1), {})
+cnt: 1, ((T([16, 128], i64, stride=(0, 1)), 2), {})
+cnt: 73, ((T([16, 128, 1024], f16), T([16, 128, 1024], f16)), {})
+cnt: 12, ((T([16, 16, 128, 128], f16), T([16, 1, 128, 128], f16)), {})
+cnt: 1, ((T([50265, 1024], f16), T([50265, 1024], f16)), {})
+Operator: aten.addmm.default
+cnt: 48, ((T([1024], f16), T([2048, 1024], f16), T([1024, 1024], f16, stride=(1, 1024))), {})
+cnt: 12, ((T([4096], f16), T([2048, 1024], f16), T([1024, 4096], f16, stride=(1, 1024))), {})
+cnt: 12, ((T([1024], f16), T([2048, 4096], f16), T([4096, 1024], f16, stride=(1, 4096))), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([256, 128, 64], f16), T([256, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 24, ((T([256, 128, 128], f16), T([256, 128, 64], f16)), {})
+cnt: 12, ((T([256, 128, 128], f16, stride=(16384, 1, 128)), T([256, 128, 64], f16)), {})
+cnt: 12, ((T([256, 64, 128], f16, stride=(8192, 1, 64)), T([256, 128, 128], f16)), {})
+Operator: aten.clone.default
+cnt: 2, ((T([16, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([16, 128], i64), T([16, 128], i64)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50265, 1024], f16), T([16, 128], i64), 1), {})
+cnt: 1, ((T([1026, 1024], f16), T([16, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([16, 128, 1024], f16), T([16, 128], i64), 1026, -1, False), {})
+cnt: 1, ((T([16, 128, 1024], f16), T([16, 128], i64), 50265, 1, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([16, 128, 4096], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([16, 128, 4096], f16), T([16, 128, 4096], f16)), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([128], i64), T([128, 1], i64)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([128, 128], f32), T([128, 128], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([2048, 1024], f16), T([1024, 50265], f16, stride=(1, 1024))), {})
+cnt: 1, ((T([50265, 2048], f16, stride=(1, 50265)), T([2048, 1024], f16)), {})
+cnt: 1, ((T([2048, 50265], f16), T([50265, 1024], f16)), {})
+cnt: 12, ((T([2048, 1024], f16), T([1024, 4096], f16)), {})
+cnt: 12, ((T([1024, 2048], f16, stride=(1, 1024)), T([2048, 4096], f16)), {})
+cnt: 12, ((T([2048, 4096], f16), T([4096, 1024], f16)), {})
+cnt: 12, ((T([4096, 2048], f16, stride=(1, 4096)), T([2048, 1024], f16)), {})
+cnt: 48, ((T([2048, 1024], f16), T([1024, 1024], f16)), {})
+cnt: 48, ((T([1024, 2048], f16, stride=(1, 1024)), T([2048, 1024], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([16, 128, 1024], f16), 1.0), {})
+cnt: 24, ((T([16, 128, 1024], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 26, ((T([16, 128, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 26, ((T([16, 128, 1024], f16), T([16, 128, 1024], f16), [1024], T([16, 128, 1], f32), T([16, 128, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([2048, 50265], f16), T([2048], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([2048, 50265], f16), T([2048], i64), None, 1, -100), {})
+Operator: aten.sum.SymInt
+cnt: 60, ((T([2048, 1024], f16), [0], True), {})
+cnt: 12, ((T([2048, 4096], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MBartForConditionalGeneration_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MBartForConditionalGeneration_training.txt
new file mode 100644
index 0000000000000..2ca11dd081846
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MBartForConditionalGeneration_training.txt
@@ -0,0 +1,94 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([1024, 50265], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([1024, 50265], f16), T([1024, 50265], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 36, ((T([128, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 36, ((T([128, 128, 128], f16), T([128, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([128, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([8, 1, 128, 128], f16, stride=(0, 16384, 128, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 108, ((T([8, 128, 16, 64], f16), [8, 128, 1024]), {})
+cnt: 1, ((T([1024, 50265], f16), [8, 128, 50265]), {})
+cnt: 36, ((T([8, 16, 128, 64], f16), [128, 128, 64]), {})
+cnt: 36, ((T([8, 128, 1024], f16), [1024, 1024]), {})
+Operator: aten.add.Tensor
+cnt: 2, ((T([8, 128], i64, stride=(0, 1)), 2), {})
+cnt: 193, ((T([8, 128, 1024], f16), T([8, 128, 1024], f16)), {})
+cnt: 1, ((T([128], i64), 1), {})
+cnt: 12, ((T([8, 16, 128, 128], f16), T([8, 1, 128, 128], f16)), {})
+cnt: 1, ((T([8, 128, 50265], f16), T([1, 50265], f16)), {})
+cnt: 2, ((T([50265, 1024], f16), T([50265, 1024], f16)), {})
+Operator: aten.addmm.default
+cnt: 144, ((T([1024], f16), T([1024, 1024], f16), T([1024, 1024], f16, stride=(1, 1024))), {})
+cnt: 24, ((T([4096], f16), T([1024, 1024], f16), T([1024, 4096], f16, stride=(1, 1024))), {})
+cnt: 24, ((T([1024], f16), T([1024, 4096], f16), T([4096, 1024], f16, stride=(1, 4096))), {})
+Operator: aten.any.default
+cnt: 24, ((T([8, 128, 1024], b8),), {})
+Operator: aten.bmm.default
+cnt: 72, ((T([128, 128, 64], f16), T([128, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 72, ((T([128, 128, 128], f16), T([128, 128, 64], f16)), {})
+cnt: 36, ((T([128, 128, 128], f16, stride=(16384, 1, 128)), T([128, 128, 64], f16)), {})
+cnt: 36, ((T([128, 64, 128], f16, stride=(8192, 1, 64)), T([128, 128, 128], f16)), {})
+Operator: aten.clone.default
+cnt: 3, ((T([8, 128], i64),), {})
+cnt: 1, ((T([8, 127], i64, stride=(128, 1)),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([8, 128], i64), T([8, 128], i64)), {})
+cnt: 1, ((T([8, 127], i64, stride=(128, 1)), T([8, 127], i64)), {})
+cnt: 1, ((T([8], i64, stride=(128,)), T([8], i64)), {})
+Operator: aten.embedding.default
+cnt: 2, ((T([50265, 1024], f16), T([8, 128], i64), 1), {})
+cnt: 2, ((T([1026, 1024], f16), T([8, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 2, ((T([8, 128, 1024], f16), T([8, 128], i64), 1026, -1, False), {})
+cnt: 2, ((T([8, 128, 1024], f16), T([8, 128], i64), 50265, 1, False), {})
+Operator: aten.eq.Scalar
+cnt: 1, ((T([8, 128], i64), -100), {})
+Operator: aten.gather.default
+cnt: 1, ((T([8, 128], i64), 1, T([8, 1], i64)), {})
+Operator: aten.gelu.default
+cnt: 24, ((T([8, 128, 4096], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 24, ((T([8, 128, 4096], f16), T([8, 128, 4096], f16)), {})
+Operator: aten.isinf.default
+cnt: 12, ((T([8, 128, 1024], f16),), {})
+Operator: aten.isnan.default
+cnt: 12, ((T([8, 128, 1024], f16),), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([128], i64), T([128, 1], i64)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([8, 128], i64), T([8, 128], b8), 1), {})
+cnt: 1, ((T([128, 128], f32), T([128, 128], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([1024, 1024], f16), T([1024, 50265], f16, stride=(1, 1024))), {})
+cnt: 1, ((T([50265, 1024], f16, stride=(1, 50265)), T([1024, 1024], f16)), {})
+cnt: 1, ((T([1024, 50265], f16), T([50265, 1024], f16)), {})
+cnt: 24, ((T([1024, 1024], f16), T([1024, 4096], f16)), {})
+cnt: 24, ((T([1024, 1024], f16, stride=(1, 1024)), T([1024, 4096], f16)), {})
+cnt: 24, ((T([1024, 4096], f16), T([4096, 1024], f16)), {})
+cnt: 24, ((T([4096, 1024], f16, stride=(1, 4096)), T([1024, 1024], f16)), {})
+cnt: 144, ((T([1024, 1024], f16), T([1024, 1024], f16)), {})
+cnt: 144, ((T([1024, 1024], f16, stride=(1, 1024)), T([1024, 1024], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 4, ((T([8, 128, 1024], f16), 1.0), {})
+cnt: 72, ((T([8, 128, 1024], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 64, ((T([8, 128, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 64, ((T([8, 128, 1024], f16), T([8, 128, 1024], f16), [1024], T([8, 128, 1], f32), T([8, 128, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+Operator: aten.ne.Scalar
+cnt: 1, ((T([8, 128], i64), 1), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([1024, 50265], f16), T([1024], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([1024, 50265], f16), T([1024], i64), None, 1, -100), {})
+Operator: aten.sub.Tensor
+cnt: 1, ((T([8], i64), 1), {})
+Operator: aten.sum.SymInt
+cnt: 168, ((T([1024, 1024], f16), [0], True), {})
+cnt: 24, ((T([1024, 4096], f16), [0], True), {})
+Operator: aten.sum.dim_IntList
+cnt: 1, ((T([8, 128], b8), [1]), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MegatronBertForCausalLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MegatronBertForCausalLM_training.txt
new file mode 100644
index 0000000000000..efe2661fcc679
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MegatronBertForCausalLM_training.txt
@@ -0,0 +1,85 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([254, 29056], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([254, 29056], f16), T([254, 29056], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 24, ((T([2, 16, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 24, ((T([2, 16, 128, 128], f16), T([2, 16, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([2, 1, 1, 128], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 72, ((T([2, 16, 128, 64], f16), [32, 128, 64]), {})
+cnt: 24, ((T([2, 16, 64, 128], f16), [32, 64, 128]), {})
+cnt: 24, ((T([32, 128, 128], f16), [2, 16, 128, 128]), {})
+cnt: 24, ((T([32, 128, 64], f16), [2, 16, 128, 64]), {})
+cnt: 48, ((T([2, 128, 16, 64], f16), [2, 128, 1024]), {})
+cnt: 24, ((T([2, 128, 1024], f16), [256, 1024]), {})
+Operator: aten.add.Tensor
+cnt: 145, ((T([2, 128, 1024], f16), T([2, 128, 1024], f16)), {})
+cnt: 24, ((T([2, 16, 128, 128], f16), T([2, 1, 1, 128], f16)), {})
+cnt: 1, ((T([29056, 1024], f16), T([29056, 1024], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([2, 128, 1024], f16), T([1, 128, 1024], f16)), {})
+Operator: aten.addmm.default
+cnt: 97, ((T([1024], f16), T([256, 1024], f16), T([1024, 1024], f16, stride=(1, 1024))), {})
+cnt: 24, ((T([4096], f16), T([256, 1024], f16), T([1024, 4096], f16, stride=(1, 1024))), {})
+cnt: 24, ((T([1024], f16), T([256, 4096], f16), T([4096, 1024], f16, stride=(1, 4096))), {})
+cnt: 1, ((T([29056], f16), T([256, 1024], f16), T([1024, 29056], f16, stride=(1, 1024))), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([32, 128, 64], f16), T([32, 64, 128], f16)), {})
+cnt: 24, ((T([32, 128, 128], f16), T([32, 128, 64], f16)), {})
+cnt: 24, ((T([32, 128, 128], f16, stride=(16384, 1, 128)), T([32, 128, 64], f16)), {})
+cnt: 24, ((T([32, 128, 64], f16), T([32, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 24, ((T([32, 64, 128], f16, stride=(8192, 1, 64)), T([32, 128, 128], f16)), {})
+cnt: 24, ((T([32, 128, 128], f16), T([32, 128, 64], f16, stride=(8192, 1, 128))), {})
+Operator: aten.clone.default
+cnt: 2, ((T([2, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([2, 128], i64), T([2, 128], i64)), {})
+Operator: aten.div.Tensor
+cnt: 48, ((T([2, 16, 128, 128], f16), 8.0), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([29056, 1024], f16), T([2, 128], i64), 0), {})
+cnt: 1, ((T([2, 1024], f16), T([2, 128], i64)), {})
+cnt: 1, ((T([512, 1024], f16), T([1, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 128, 1024], f16), T([1, 128], i64), 512, -1, False), {})
+cnt: 1, ((T([2, 128, 1024], f16), T([2, 128], i64), 2, -1, False), {})
+cnt: 1, ((T([2, 128, 1024], f16), T([2, 128], i64), 29056, 0, False), {})
+Operator: aten.gelu.default
+cnt: 24, ((T([2, 128, 4096], f16),), {})
+cnt: 1, ((T([2, 128, 1024], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([2, 128, 1024], f16), T([2, 128, 1024], f16)), {})
+cnt: 24, ((T([2, 128, 4096], f16), T([2, 128, 4096], f16)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([256, 29056], f16), T([29056, 1024], f16)), {})
+cnt: 1, ((T([29056, 256], f16, stride=(1, 29056)), T([256, 1024], f16)), {})
+cnt: 97, ((T([256, 1024], f16), T([1024, 1024], f16)), {})
+cnt: 97, ((T([1024, 256], f16, stride=(1, 1024)), T([256, 1024], f16)), {})
+cnt: 24, ((T([256, 1024], f16), T([1024, 4096], f16)), {})
+cnt: 24, ((T([1024, 256], f16, stride=(1, 1024)), T([256, 4096], f16)), {})
+cnt: 24, ((T([256, 4096], f16), T([4096, 1024], f16)), {})
+cnt: 24, ((T([4096, 256], f16, stride=(1, 4096)), T([256, 1024], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([2, 1, 1, 128], f16), -65504.0), {})
+Operator: aten.native_layer_norm.default
+cnt: 50, ((T([2, 128, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 50, ((T([2, 128, 1024], f16), T([2, 128, 1024], f16), [1024], T([2, 128, 1], f32), T([2, 128, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([254, 29056], f16), T([254], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([254, 29056], f16), T([254], i64), None, 1, -100), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([2, 1, 1, 128], f16), 1.0), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([2, 127, 29056], f16), [2, 127, 29056], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([2, 127, 29056], f16), [2, 128, 29056], 1, 0, -1, 1), {})
+cnt: 1, ((T([2, 128, 29056], f16), [2, 128, 29056], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([256, 29056], f16), [0], True), {})
+cnt: 121, ((T([256, 1024], f16), [0], True), {})
+cnt: 24, ((T([256, 4096], f16), [0], True), {})
+cnt: 1, ((T([2, 128, 1024], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MegatronBertForQuestionAnswering_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MegatronBertForQuestionAnswering_training.txt
new file mode 100644
index 0000000000000..5c1861e54231a
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MegatronBertForQuestionAnswering_training.txt
@@ -0,0 +1,88 @@
+Operator: aten._log_softmax.default
+cnt: 2, ((T([8, 128], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 2, ((T([8, 128], f16), T([8, 128], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 24, ((T([8, 16, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 24, ((T([8, 16, 128, 128], f16), T([8, 16, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([8, 1, 1, 128], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 72, ((T([8, 16, 128, 64], f16), [128, 128, 64]), {})
+cnt: 24, ((T([8, 16, 64, 128], f16), [128, 64, 128]), {})
+cnt: 24, ((T([128, 128, 128], f16), [8, 16, 128, 128]), {})
+cnt: 24, ((T([128, 128, 64], f16), [8, 16, 128, 64]), {})
+cnt: 48, ((T([8, 128, 16, 64], f16), [8, 128, 1024]), {})
+cnt: 24, ((T([8, 128, 1024], f16), [1024, 1024]), {})
+Operator: aten.add.Tensor
+cnt: 145, ((T([8, 128, 1024], f16), T([8, 128, 1024], f16)), {})
+cnt: 24, ((T([8, 16, 128, 128], f16), T([8, 1, 1, 128], f16)), {})
+cnt: 1, ((T([], f16), T([], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([8, 128, 1024], f16), T([1, 128, 1024], f16)), {})
+Operator: aten.addmm.default
+cnt: 96, ((T([1024], f16), T([1024, 1024], f16), T([1024, 1024], f16, stride=(1, 1024))), {})
+cnt: 24, ((T([4096], f16), T([1024, 1024], f16), T([1024, 4096], f16, stride=(1, 1024))), {})
+cnt: 24, ((T([1024], f16), T([1024, 4096], f16), T([4096, 1024], f16, stride=(1, 4096))), {})
+cnt: 1, ((T([2], f16), T([1024, 1024], f16), T([1024, 2], f16, stride=(1, 1024))), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([128, 128, 64], f16), T([128, 64, 128], f16)), {})
+cnt: 24, ((T([128, 128, 128], f16), T([128, 128, 64], f16)), {})
+cnt: 24, ((T([128, 128, 128], f16, stride=(16384, 1, 128)), T([128, 128, 64], f16)), {})
+cnt: 24, ((T([128, 128, 64], f16), T([128, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 24, ((T([128, 64, 128], f16, stride=(8192, 1, 64)), T([128, 128, 128], f16)), {})
+cnt: 24, ((T([128, 128, 128], f16), T([128, 128, 64], f16, stride=(8192, 1, 128))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([8, 128, 1], f16), T([8, 128, 1], f16)], 2), {})
+Operator: aten.clamp.default
+cnt: 2, ((T([8], i64), 0, 128), {})
+Operator: aten.clone.default
+cnt: 1, ((T([8, 128], i64),), {})
+cnt: 2, ((T([8], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([8, 128], i64), T([8, 128], i64)), {})
+cnt: 2, ((T([8], i64), T([8], i64)), {})
+Operator: aten.div.Tensor
+cnt: 48, ((T([8, 16, 128, 128], f16), 8.0), {})
+cnt: 2, ((T([], f16), 2), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([29056, 1024], f16), T([8, 128], i64), 0), {})
+cnt: 1, ((T([2, 1024], f16), T([8, 128], i64)), {})
+cnt: 1, ((T([512, 1024], f16), T([1, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 128, 1024], f16), T([1, 128], i64), 512, -1, False), {})
+cnt: 1, ((T([8, 128, 1024], f16), T([8, 128], i64), 2, -1, False), {})
+cnt: 1, ((T([8, 128, 1024], f16), T([8, 128], i64), 29056, 0, False), {})
+Operator: aten.gelu.default
+cnt: 24, ((T([8, 128, 4096], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 24, ((T([8, 128, 4096], f16), T([8, 128, 4096], f16)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([1024, 2], f16), T([2, 1024], f16)), {})
+cnt: 1, ((T([2, 1024], f16, stride=(1, 2)), T([1024, 1024], f16)), {})
+cnt: 24, ((T([1024, 1024], f16), T([1024, 4096], f16)), {})
+cnt: 24, ((T([1024, 1024], f16, stride=(1, 1024)), T([1024, 4096], f16)), {})
+cnt: 24, ((T([1024, 4096], f16), T([4096, 1024], f16)), {})
+cnt: 24, ((T([4096, 1024], f16, stride=(1, 4096)), T([1024, 1024], f16)), {})
+cnt: 96, ((T([1024, 1024], f16), T([1024, 1024], f16)), {})
+cnt: 96, ((T([1024, 1024], f16, stride=(1, 1024)), T([1024, 1024], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([8, 1, 1, 128], f16), -65504.0), {})
+Operator: aten.native_layer_norm.default
+cnt: 49, ((T([8, 128, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 49, ((T([8, 128, 1024], f16), T([8, 128, 1024], f16), [1024], T([8, 128, 1], f32), T([8, 128, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 2, ((T([], f16), T([8, 128], f16), T([8], i64), None, 1, 128, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 2, ((T([8, 128], f16), T([8], i64), None, 1, 128), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([8, 1, 1, 128], f16), 1.0), {})
+Operator: aten.split.Tensor
+cnt: 1, ((T([8, 128, 2], f16), 1, -1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([1024, 2], f16), [0], True), {})
+cnt: 120, ((T([1024, 1024], f16), [0], True), {})
+cnt: 24, ((T([1024, 4096], f16), [0], True), {})
+cnt: 1, ((T([8, 128, 1024], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MobileBertForMaskedLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MobileBertForMaskedLM_training.txt
new file mode 100644
index 0000000000000..e6b91aa0181ec
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MobileBertForMaskedLM_training.txt
@@ -0,0 +1,112 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([2048, 30522], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([2048, 30522], f16), T([2048, 30522], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 24, ((T([16, 4, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 24, ((T([16, 4, 128, 128], f16), T([16, 4, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([16, 1, 1, 128], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 72, ((T([16, 4, 128, 32], f16), [64, 128, 32]), {})
+cnt: 24, ((T([16, 4, 32, 128], f16), [64, 32, 128]), {})
+cnt: 24, ((T([64, 128, 128], f16), [16, 4, 128, 128]), {})
+cnt: 24, ((T([64, 128, 32], f16), [16, 4, 128, 32]), {})
+cnt: 1, ((T([2048, 30522], f16), [16, 128, 30522]), {})
+cnt: 48, ((T([16, 128, 4, 32], f16), [16, 128, 128]), {})
+cnt: 24, ((T([16, 128, 128], f16), [2048, 128]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([16, 128, 512], f16), T([1, 128, 512], f16)), {})
+cnt: 97, ((T([16, 128, 512], f16), T([16, 128, 512], f16)), {})
+cnt: 25, ((T([16, 128, 512], f16), T([512], f16)), {})
+cnt: 168, ((T([16, 128, 128], f16), T([128], f16)), {})
+cnt: 24, ((T([16, 4, 128, 128], f16), T([16, 1, 1, 128], f16)), {})
+cnt: 241, ((T([16, 128, 128], f16), T([16, 128, 128], f16)), {})
+cnt: 1, ((T([16, 128, 128], f16, stride=(49152, 384, 1)), T([16, 128, 128], f16)), {})
+cnt: 1, ((T([30522, 128], f16, stride=(1, 30522)), T([30522, 128], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([16, 128, 30522], f16), T([30522], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([512], f16), T([2048, 384], f16), T([384, 512], f16, stride=(1, 384))), {})
+cnt: 168, ((T([128], f16), T([2048, 512], f16), T([512, 128], f16, stride=(1, 512))), {})
+cnt: 72, ((T([128], f16), T([2048, 128], f16), T([128, 128], f16, stride=(1, 128))), {})
+cnt: 120, ((T([512], f16), T([2048, 128], f16), T([128, 512], f16, stride=(1, 128))), {})
+cnt: 1, ((T([512], f16), T([2048, 512], f16), T([512, 512], f16, stride=(1, 512))), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([64, 128, 32], f16), T([64, 32, 128], f16)), {})
+cnt: 24, ((T([64, 128, 128], f16), T([64, 128, 32], f16)), {})
+cnt: 24, ((T([64, 128, 128], f16, stride=(16384, 1, 128)), T([64, 128, 32], f16)), {})
+cnt: 24, ((T([64, 128, 32], f16), T([64, 32, 128], f16, stride=(4096, 1, 32))), {})
+cnt: 24, ((T([64, 32, 128], f16, stride=(4096, 1, 32)), T([64, 128, 128], f16)), {})
+cnt: 24, ((T([64, 128, 128], f16), T([64, 128, 32], f16, stride=(4096, 1, 128))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([16, 128, 128], f16), T([16, 128, 128], f16), T([16, 128, 128], f16)], 2), {})
+cnt: 1, (([T([128, 30522], f16, stride=(1, 128)), T([384, 30522], f16)],), {})
+Operator: aten.clone.default
+cnt: 2, ((T([16, 128], i64),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 1, ((T([16, 127, 128], f16, stride=(16384, 128, 1)), [0, 0, 0, 1, 0, 0], 0.0), {})
+cnt: 1, ((T([16, 127, 128], f16, stride=(16384, 128, 1)), [0, 0, 1, 0, 0, 0], 0.0), {})
+cnt: 1, ((T([16, 128, 128], f16, stride=(49152, 384, 1)), [0, 0, -1, 0, 0, 0]), {})
+cnt: 1, ((T([16, 128, 128], f16, stride=(49152, 384, 1)), [0, 0, 0, -1, 0, 0]), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([16, 128], i64), T([16, 128], i64)), {})
+cnt: 1, ((T([30522, 128], f16), T([30522, 128], f16, stride=(1, 30522))), {})
+Operator: aten.div.Tensor
+cnt: 48, ((T([16, 4, 128, 128], f16), 5.656854249492381), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30522, 128], f16), T([16, 128], i64), 0), {})
+cnt: 1, ((T([512, 512], f16), T([1, 128], i64)), {})
+cnt: 1, ((T([2, 512], f16), T([16, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([16, 128, 512], f16), T([16, 128], i64), 2, -1, False), {})
+cnt: 1, ((T([1, 128, 512], f16), T([1, 128], i64), 512, -1, False), {})
+cnt: 1, ((T([16, 128, 128], f16), T([16, 128], i64), 30522, 0, False), {})
+Operator: aten.mm.default
+cnt: 1, ((T([2048, 512], f16), T([512, 30522], f16)), {})
+cnt: 1, ((T([512, 2048], f16, stride=(1, 512)), T([2048, 30522], f16)), {})
+cnt: 1, ((T([2048, 30522], f16), T([30522, 512], f16, stride=(1, 30522))), {})
+cnt: 1, ((T([2048, 512], f16), T([512, 512], f16)), {})
+cnt: 1, ((T([512, 2048], f16, stride=(1, 512)), T([2048, 512], f16)), {})
+cnt: 120, ((T([2048, 512], f16), T([512, 128], f16)), {})
+cnt: 120, ((T([512, 2048], f16, stride=(1, 512)), T([2048, 128], f16)), {})
+cnt: 168, ((T([2048, 128], f16), T([128, 512], f16)), {})
+cnt: 168, ((T([128, 2048], f16, stride=(1, 128)), T([2048, 512], f16)), {})
+cnt: 72, ((T([2048, 128], f16), T([128, 128], f16)), {})
+cnt: 72, ((T([128, 2048], f16, stride=(1, 128)), T([2048, 128], f16)), {})
+cnt: 1, ((T([2048, 512], f16), T([512, 384], f16)), {})
+cnt: 1, ((T([512, 2048], f16, stride=(1, 512)), T([2048, 384], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([16, 1, 1, 128], f16), -65504.0), {})
+cnt: 50, ((T([16, 128, 512], f16), T([512], f16)), {})
+cnt: 336, ((T([16, 128, 128], f16), T([128], f16)), {})
+cnt: 25, ((T([16, 128, 512], f16), T([16, 128, 512], f16)), {})
+cnt: 168, ((T([16, 128, 128], f16), T([16, 128, 128], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 1, ((T([16, 128, 512], f16), [512], T([512], f16), T([512], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 1, ((T([16, 128, 512], f16), T([16, 128, 512], f16), [512], T([16, 128, 1], f32), T([16, 128, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+Operator: aten.new_empty_strided.default
+cnt: 1, ((T([30522, 128], f16, stride=(1, 30522)), [30522, 128], [128, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([2048, 30522], f16), T([2048], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([2048, 30522], f16), T([2048], i64), None, 1, -100), {})
+Operator: aten.relu.default
+cnt: 97, ((T([16, 128, 512], f16),), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([16, 1, 1, 128], f16), 1.0), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([16, 127, 128], f16), [16, 128, 128], 1, 0, -1, 1), {})
+cnt: 2, ((T([16, 128, 128], f16), [16, 128, 128], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([16, 127, 128], f16), [16, 128, 128], 1, 1, 9223372036854775807, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([16, 128, 30522], f16), [0, 1], True), {})
+cnt: 122, ((T([2048, 512], f16), [0], True), {})
+cnt: 50, ((T([16, 128, 512], f16), [0, 1], True), {})
+cnt: 336, ((T([16, 128, 128], f16), [0, 1], True), {})
+cnt: 240, ((T([2048, 128], f16), [0], True), {})
+cnt: 1, ((T([16, 128, 512], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 97, ((T([16, 128, 512], f16), T([16, 128, 512], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MobileBertForQuestionAnswering_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MobileBertForQuestionAnswering_training.txt
new file mode 100644
index 0000000000000..c5e7b0f51c677
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/MobileBertForQuestionAnswering_training.txt
@@ -0,0 +1,106 @@
+Operator: aten._log_softmax.default
+cnt: 2, ((T([32, 128], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 2, ((T([32, 128], f16), T([32, 128], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 24, ((T([32, 4, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 24, ((T([32, 4, 128, 128], f16), T([32, 4, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([32, 1, 1, 128], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 72, ((T([32, 4, 128, 32], f16), [128, 128, 32]), {})
+cnt: 24, ((T([32, 4, 32, 128], f16), [128, 32, 128]), {})
+cnt: 24, ((T([128, 128, 128], f16), [32, 4, 128, 128]), {})
+cnt: 24, ((T([128, 128, 32], f16), [32, 4, 128, 32]), {})
+cnt: 48, ((T([32, 128, 4, 32], f16), [32, 128, 128]), {})
+cnt: 24, ((T([32, 128, 128], f16), [4096, 128]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([32, 128, 512], f16), T([1, 128, 512], f16)), {})
+cnt: 97, ((T([32, 128, 512], f16), T([32, 128, 512], f16)), {})
+cnt: 25, ((T([32, 128, 512], f16), T([512], f16)), {})
+cnt: 168, ((T([32, 128, 128], f16), T([128], f16)), {})
+cnt: 24, ((T([32, 4, 128, 128], f16), T([32, 1, 1, 128], f16)), {})
+cnt: 241, ((T([32, 128, 128], f16), T([32, 128, 128], f16)), {})
+cnt: 1, ((T([], f16), T([], f16)), {})
+cnt: 1, ((T([32, 128, 128], f16, stride=(49152, 384, 1)), T([32, 128, 128], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([512], f16), T([4096, 384], f16), T([384, 512], f16, stride=(1, 384))), {})
+cnt: 168, ((T([128], f16), T([4096, 512], f16), T([512, 128], f16, stride=(1, 512))), {})
+cnt: 72, ((T([128], f16), T([4096, 128], f16), T([128, 128], f16, stride=(1, 128))), {})
+cnt: 120, ((T([512], f16), T([4096, 128], f16), T([128, 512], f16, stride=(1, 128))), {})
+cnt: 1, ((T([2], f16), T([4096, 512], f16), T([512, 2], f16, stride=(1, 512))), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([128, 128, 32], f16), T([128, 32, 128], f16)), {})
+cnt: 24, ((T([128, 128, 128], f16), T([128, 128, 32], f16)), {})
+cnt: 24, ((T([128, 128, 128], f16, stride=(16384, 1, 128)), T([128, 128, 32], f16)), {})
+cnt: 24, ((T([128, 128, 32], f16), T([128, 32, 128], f16, stride=(4096, 1, 32))), {})
+cnt: 24, ((T([128, 32, 128], f16, stride=(4096, 1, 32)), T([128, 128, 128], f16)), {})
+cnt: 24, ((T([128, 128, 128], f16), T([128, 128, 32], f16, stride=(4096, 1, 128))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([32, 128, 128], f16), T([32, 128, 128], f16), T([32, 128, 128], f16)], 2), {})
+cnt: 1, (([T([32, 128, 1], f16), T([32, 128, 1], f16)], 2), {})
+Operator: aten.clamp.default
+cnt: 2, ((T([32], i64), 0, 128), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 128], i64),), {})
+cnt: 2, ((T([32], i64),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 1, ((T([32, 127, 128], f16, stride=(16384, 128, 1)), [0, 0, 0, 1, 0, 0], 0.0), {})
+cnt: 1, ((T([32, 127, 128], f16, stride=(16384, 128, 1)), [0, 0, 1, 0, 0, 0], 0.0), {})
+cnt: 1, ((T([32, 128, 128], f16, stride=(49152, 384, 1)), [0, 0, -1, 0, 0, 0]), {})
+cnt: 1, ((T([32, 128, 128], f16, stride=(49152, 384, 1)), [0, 0, 0, -1, 0, 0]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 128], i64), T([32, 128], i64)), {})
+cnt: 2, ((T([32], i64), T([32], i64)), {})
+Operator: aten.div.Tensor
+cnt: 48, ((T([32, 4, 128, 128], f16), 5.656854249492381), {})
+cnt: 2, ((T([], f16), 2), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30522, 128], f16), T([32, 128], i64), 0), {})
+cnt: 1, ((T([512, 512], f16), T([1, 128], i64)), {})
+cnt: 1, ((T([2, 512], f16), T([32, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([32, 128, 512], f16), T([32, 128], i64), 2, -1, False), {})
+cnt: 1, ((T([1, 128, 512], f16), T([1, 128], i64), 512, -1, False), {})
+cnt: 1, ((T([32, 128, 128], f16), T([32, 128], i64), 30522, 0, False), {})
+Operator: aten.mm.default
+cnt: 1, ((T([4096, 2], f16), T([2, 512], f16)), {})
+cnt: 1, ((T([2, 4096], f16, stride=(1, 2)), T([4096, 512], f16)), {})
+cnt: 120, ((T([4096, 512], f16), T([512, 128], f16)), {})
+cnt: 120, ((T([512, 4096], f16, stride=(1, 512)), T([4096, 128], f16)), {})
+cnt: 168, ((T([4096, 128], f16), T([128, 512], f16)), {})
+cnt: 168, ((T([128, 4096], f16, stride=(1, 128)), T([4096, 512], f16)), {})
+cnt: 72, ((T([4096, 128], f16), T([128, 128], f16)), {})
+cnt: 72, ((T([128, 4096], f16, stride=(1, 128)), T([4096, 128], f16)), {})
+cnt: 1, ((T([4096, 512], f16), T([512, 384], f16)), {})
+cnt: 1, ((T([512, 4096], f16, stride=(1, 512)), T([4096, 384], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([32, 1, 1, 128], f16), -65504.0), {})
+cnt: 50, ((T([32, 128, 512], f16), T([512], f16)), {})
+cnt: 336, ((T([32, 128, 128], f16), T([128], f16)), {})
+cnt: 25, ((T([32, 128, 512], f16), T([32, 128, 512], f16)), {})
+cnt: 168, ((T([32, 128, 128], f16), T([32, 128, 128], f16)), {})
+Operator: aten.nll_loss_backward.default
+cnt: 2, ((T([], f16), T([32, 128], f16), T([32], i64), None, 1, 128, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 2, ((T([32, 128], f16), T([32], i64), None, 1, 128), {})
+Operator: aten.relu.default
+cnt: 96, ((T([32, 128, 512], f16),), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([32, 1, 1, 128], f16), 1.0), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([32, 127, 128], f16), [32, 128, 128], 1, 0, -1, 1), {})
+cnt: 2, ((T([32, 128, 128], f16), [32, 128, 128], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 127, 128], f16), [32, 128, 128], 1, 1, 9223372036854775807, 1), {})
+Operator: aten.split.Tensor
+cnt: 1, ((T([32, 128, 2], f16), 1, -1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([4096, 2], f16), [0], True), {})
+cnt: 50, ((T([32, 128, 512], f16), [0, 1], True), {})
+cnt: 121, ((T([4096, 512], f16), [0], True), {})
+cnt: 336, ((T([32, 128, 128], f16), [0, 1], True), {})
+cnt: 240, ((T([4096, 128], f16), [0], True), {})
+cnt: 1, ((T([32, 128, 512], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 96, ((T([32, 128, 512], f16), T([32, 128, 512], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/OPTForCausalLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/OPTForCausalLM_training.txt
new file mode 100644
index 0000000000000..533b1875674b2
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/OPTForCausalLM_training.txt
@@ -0,0 +1,103 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([508, 50272], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([508, 50272], f16), T([508, 50272], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([48, 128, 128], f16), -1, True), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([48, 128, 128], f32), T([48, 128, 128], f32), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([4, 128], b8),), {'dtype': i64})
+cnt: 1, ((T([128, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([4, 1, 128, 128], f16, stride=(0, 16384, 128, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([4, 1, 128, 128], b8, stride=(128, 128, 0, 1)),), {'dtype': f16})
+cnt: 1, ((T([4, 1, 128, 128], f16),), {'dtype': torch.bool})
+cnt: 12, ((T([48, 128, 128], f32),), {'dtype': f16})
+cnt: 12, ((T([48, 128, 128], f16),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([4, 128, 12, 64], f16), [4, 128, 768]), {})
+cnt: 1, ((T([512, 50272], f16), [4, 128, 50272]), {})
+cnt: 12, ((T([4, 12, 128, 64], f16), [48, 128, 64]), {})
+cnt: 12, ((T([4, 128, 768], f16), [512, 768]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([4, 128], i64), 2), {})
+cnt: 1, ((T([128], i64), 1), {})
+cnt: 1, ((T([4, 1, 128, 128], f16), T([4, 1, 128, 128], f16)), {})
+cnt: 49, ((T([4, 128, 768], f16), T([4, 128, 768], f16)), {})
+cnt: 12, ((T([4, 12, 128, 128], f16), T([4, 1, 128, 128], f16)), {})
+cnt: 24, ((T([512, 768], f16), T([512, 768], f16)), {})
+cnt: 1, ((T([50272, 768], f16), T([50272, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 48, ((T([768], f16), T([512, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([512, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([512, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([48, 128, 64], f16), T([48, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 24, ((T([48, 128, 128], f16), T([48, 128, 64], f16)), {})
+cnt: 12, ((T([48, 128, 128], f16, stride=(16384, 1, 128)), T([48, 128, 64], f16)), {})
+cnt: 12, ((T([48, 64, 128], f16, stride=(8192, 1, 64)), T([48, 128, 128], f16)), {})
+Operator: aten.clone.default
+cnt: 2, ((T([4, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([4, 128], i64), T([4, 128], i64)), {})
+Operator: aten.cumsum.default
+cnt: 1, ((T([4, 128], i64), 1), {})
+Operator: aten.div.Scalar
+cnt: 12, ((T([4, 12, 128, 128], f16), 2), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50272, 768], f16), T([4, 128], i64), 1), {})
+cnt: 1, ((T([2050, 768], f16), T([4, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([4, 128, 768], f16), T([4, 128], i64), 2050, -1, False), {})
+cnt: 1, ((T([4, 128, 768], f16), T([4, 128], i64), 50272, 1, False), {})
+Operator: aten.eq.Tensor
+cnt: 12, ((T([4, 12, 128, 128], f16), T([], f32)), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([128], i64), T([128, 1], i64)), {})
+cnt: 12, ((T([4, 12, 128, 128], f16), T([], f32)), {})
+Operator: aten.masked_fill.Scalar
+cnt: 1, ((T([4, 1, 128, 128], f16), T([4, 1, 128, 128], b8), -65504.0), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([128, 128], f32), T([128, 128], b8), 0), {})
+cnt: 12, ((T([4, 12, 128, 128], f16), T([4, 12, 128, 128], b8), 0), {})
+Operator: aten.maximum.default
+cnt: 12, ((T([4, 12, 128, 128], f16), T([], f32)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([512, 768], f16), T([768, 50272], f16, stride=(1, 768))), {})
+cnt: 1, ((T([50272, 512], f16, stride=(1, 50272)), T([512, 768], f16)), {})
+cnt: 1, ((T([512, 50272], f16), T([50272, 768], f16)), {})
+cnt: 12, ((T([512, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 512], f16, stride=(1, 768)), T([512, 3072], f16)), {})
+cnt: 12, ((T([512, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 512], f16, stride=(1, 3072)), T([512, 768], f16)), {})
+cnt: 48, ((T([512, 768], f16), T([768, 768], f16)), {})
+cnt: 48, ((T([768, 512], f16, stride=(1, 768)), T([512, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([4, 128], i64), T([4, 128], i64)), {})
+cnt: 24, ((T([4, 128, 768], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 13, ((T([4, 128, 768], f16), [768], T([768], f16), T([768], f16), 1e-05), {})
+cnt: 12, ((T([512, 768], f16), [768], T([768], f16), T([768], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 13, ((T([4, 128, 768], f16), T([4, 128, 768], f16), [768], T([4, 128, 1], f32), T([4, 128, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+cnt: 12, ((T([512, 768], f16), T([512, 768], f16), [768], T([512, 1], f32), T([512, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([508, 50272], f16), T([508], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([508, 50272], f16), T([508], i64), None, 1, -100), {})
+Operator: aten.relu.default
+cnt: 12, ((T([512, 3072], f16),), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([4, 1, 128, 128], f16), 1.0), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([4, 127, 50272], f16), [4, 127, 50272], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([4, 127, 50272], f16), [4, 128, 50272], 1, 0, -1, 1), {})
+Operator: aten.sub.Tensor
+cnt: 1, ((T([4, 128], i64), 1), {})
+Operator: aten.sum.SymInt
+cnt: 60, ((T([512, 768], f16), [0], True), {})
+cnt: 12, ((T([512, 3072], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 12, ((T([512, 3072], f16), T([512, 3072], f16), 0), {})
+Operator: aten.where.self
+cnt: 12, ((T([4, 12, 128, 128], b8), T([4, 12, 128, 128], f16), T([4, 12, 128, 128], f16)), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/PLBartForCausalLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/PLBartForCausalLM_training.txt
new file mode 100644
index 0000000000000..7617876fd4aad
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/PLBartForCausalLM_training.txt
@@ -0,0 +1,73 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([2048, 50005], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([2048, 50005], f16), T([2048, 50005], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 6, ((T([192, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 6, ((T([192, 128, 128], f16), T([192, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([128, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([16, 1, 128, 128], f16, stride=(0, 16384, 128, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 18, ((T([16, 128, 12, 64], f16), [16, 128, 768]), {})
+cnt: 1, ((T([2048, 50005], f16), [16, 128, 50005]), {})
+cnt: 6, ((T([16, 12, 128, 64], f16), [192, 128, 64]), {})
+cnt: 6, ((T([16, 128, 768], f16), [2048, 768]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([128], i64), 1), {})
+cnt: 1, ((T([16, 128], i64, stride=(0, 1)), 2), {})
+cnt: 37, ((T([16, 128, 768], f16), T([16, 128, 768], f16)), {})
+cnt: 6, ((T([16, 12, 128, 128], f16), T([16, 1, 128, 128], f16)), {})
+cnt: 1, ((T([50005, 768], f16), T([50005, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 24, ((T([768], f16), T([2048, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 6, ((T([3072], f16), T([2048, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 6, ((T([768], f16), T([2048, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([192, 128, 64], f16), T([192, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 12, ((T([192, 128, 128], f16), T([192, 128, 64], f16)), {})
+cnt: 6, ((T([192, 128, 128], f16, stride=(16384, 1, 128)), T([192, 128, 64], f16)), {})
+cnt: 6, ((T([192, 64, 128], f16, stride=(8192, 1, 64)), T([192, 128, 128], f16)), {})
+Operator: aten.clone.default
+cnt: 2, ((T([16, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([16, 128], i64), T([16, 128], i64)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50005, 768], f16), T([16, 128], i64), 1), {})
+cnt: 1, ((T([1026, 768], f16), T([16, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([16, 128, 768], f16), T([16, 128], i64), 1026, -1, False), {})
+cnt: 1, ((T([16, 128, 768], f16), T([16, 128], i64), 50005, 1, False), {})
+Operator: aten.gelu.default
+cnt: 6, ((T([16, 128, 3072], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 6, ((T([16, 128, 3072], f16), T([16, 128, 3072], f16)), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([128], i64), T([128, 1], i64)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([128, 128], f32), T([128, 128], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([2048, 768], f16), T([768, 50005], f16, stride=(1, 768))), {})
+cnt: 1, ((T([50005, 2048], f16, stride=(1, 50005)), T([2048, 768], f16)), {})
+cnt: 1, ((T([2048, 50005], f16), T([50005, 768], f16)), {})
+cnt: 6, ((T([2048, 768], f16), T([768, 3072], f16)), {})
+cnt: 6, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 3072], f16)), {})
+cnt: 6, ((T([2048, 3072], f16), T([3072, 768], f16)), {})
+cnt: 6, ((T([3072, 2048], f16, stride=(1, 3072)), T([2048, 768], f16)), {})
+cnt: 24, ((T([2048, 768], f16), T([768, 768], f16)), {})
+cnt: 24, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([16, 128, 768], f16), 27.712812921102035), {})
+cnt: 12, ((T([16, 128, 768], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 13, ((T([16, 128, 768], f16), [768], T([768], f16), T([768], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 13, ((T([16, 128, 768], f16), T([16, 128, 768], f16), [768], T([16, 128, 1], f32), T([16, 128, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([2048, 50005], f16), T([2048], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([2048, 50005], f16), T([2048], i64), None, 1, -100), {})
+Operator: aten.sum.SymInt
+cnt: 30, ((T([2048, 768], f16), [0], True), {})
+cnt: 6, ((T([2048, 3072], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/PLBartForConditionalGeneration_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/PLBartForConditionalGeneration_training.txt
new file mode 100644
index 0000000000000..55115055a052d
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/PLBartForConditionalGeneration_training.txt
@@ -0,0 +1,94 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([1024, 50005], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([1024, 50005], f16), T([1024, 50005], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 18, ((T([96, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 18, ((T([96, 128, 128], f16), T([96, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([128, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([8, 1, 128, 128], f16, stride=(0, 16384, 128, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 54, ((T([8, 128, 12, 64], f16), [8, 128, 768]), {})
+cnt: 1, ((T([1024, 50005], f16), [8, 128, 50005]), {})
+cnt: 18, ((T([8, 12, 128, 64], f16), [96, 128, 64]), {})
+cnt: 18, ((T([8, 128, 768], f16), [1024, 768]), {})
+Operator: aten.add.Tensor
+cnt: 2, ((T([8, 128], i64, stride=(0, 1)), 2), {})
+cnt: 97, ((T([8, 128, 768], f16), T([8, 128, 768], f16)), {})
+cnt: 1, ((T([128], i64), 1), {})
+cnt: 6, ((T([8, 12, 128, 128], f16), T([8, 1, 128, 128], f16)), {})
+cnt: 1, ((T([8, 128, 50005], f16), T([1, 50005], f16)), {})
+cnt: 2, ((T([50005, 768], f16), T([50005, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 72, ((T([768], f16), T([1024, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([1024, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([1024, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+Operator: aten.any.default
+cnt: 12, ((T([8, 128, 768], b8),), {})
+Operator: aten.bmm.default
+cnt: 36, ((T([96, 128, 64], f16), T([96, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 36, ((T([96, 128, 128], f16), T([96, 128, 64], f16)), {})
+cnt: 18, ((T([96, 128, 128], f16, stride=(16384, 1, 128)), T([96, 128, 64], f16)), {})
+cnt: 18, ((T([96, 64, 128], f16, stride=(8192, 1, 64)), T([96, 128, 128], f16)), {})
+Operator: aten.clone.default
+cnt: 3, ((T([8, 128], i64),), {})
+cnt: 1, ((T([8, 127], i64, stride=(128, 1)),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([8, 128], i64), T([8, 128], i64)), {})
+cnt: 1, ((T([8, 127], i64, stride=(128, 1)), T([8, 127], i64)), {})
+cnt: 1, ((T([8], i64, stride=(128,)), T([8], i64)), {})
+Operator: aten.embedding.default
+cnt: 2, ((T([50005, 768], f16), T([8, 128], i64), 1), {})
+cnt: 2, ((T([1026, 768], f16), T([8, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 2, ((T([8, 128, 768], f16), T([8, 128], i64), 1026, -1, False), {})
+cnt: 2, ((T([8, 128, 768], f16), T([8, 128], i64), 50005, 1, False), {})
+Operator: aten.eq.Scalar
+cnt: 1, ((T([8, 128], i64), -100), {})
+Operator: aten.gather.default
+cnt: 1, ((T([8, 128], i64), 1, T([8, 1], i64)), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([8, 128, 3072], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([8, 128, 3072], f16), T([8, 128, 3072], f16)), {})
+Operator: aten.isinf.default
+cnt: 6, ((T([8, 128, 768], f16),), {})
+Operator: aten.isnan.default
+cnt: 6, ((T([8, 128, 768], f16),), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([128], i64), T([128, 1], i64)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([8, 128], i64), T([8, 128], b8), 1), {})
+cnt: 1, ((T([128, 128], f32), T([128, 128], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([1024, 768], f16), T([768, 50005], f16, stride=(1, 768))), {})
+cnt: 1, ((T([50005, 1024], f16, stride=(1, 50005)), T([1024, 768], f16)), {})
+cnt: 1, ((T([1024, 50005], f16), T([50005, 768], f16)), {})
+cnt: 12, ((T([1024, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 1024], f16, stride=(1, 768)), T([1024, 3072], f16)), {})
+cnt: 12, ((T([1024, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 1024], f16, stride=(1, 3072)), T([1024, 768], f16)), {})
+cnt: 72, ((T([1024, 768], f16), T([768, 768], f16)), {})
+cnt: 72, ((T([768, 1024], f16, stride=(1, 768)), T([1024, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 4, ((T([8, 128, 768], f16), 27.712812921102035), {})
+cnt: 36, ((T([8, 128, 768], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 32, ((T([8, 128, 768], f16), [768], T([768], f16), T([768], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 32, ((T([8, 128, 768], f16), T([8, 128, 768], f16), [768], T([8, 128, 1], f32), T([8, 128, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.ne.Scalar
+cnt: 1, ((T([8, 128], i64), 1), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([1024, 50005], f16), T([1024], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([1024, 50005], f16), T([1024], i64), None, 1, -100), {})
+Operator: aten.sub.Tensor
+cnt: 1, ((T([8], i64), 1), {})
+Operator: aten.sum.SymInt
+cnt: 84, ((T([1024, 768], f16), [0], True), {})
+cnt: 12, ((T([1024, 3072], f16), [0], True), {})
+Operator: aten.sum.dim_IntList
+cnt: 1, ((T([8, 128], b8), [1]), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/PegasusForCausalLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/PegasusForCausalLM_training.txt
new file mode 100644
index 0000000000000..1341c27983983
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/PegasusForCausalLM_training.txt
@@ -0,0 +1,72 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([1024, 50265], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([1024, 50265], f16), T([1024, 50265], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([128, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([128, 128, 128], f16), T([128, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([128, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([8, 1, 128, 128], f16, stride=(0, 16384, 128, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([8, 128, 16, 64], f16), [8, 128, 1024]), {})
+cnt: 1, ((T([1024, 50265], f16), [8, 128, 50265]), {})
+cnt: 12, ((T([8, 16, 128, 64], f16), [128, 128, 64]), {})
+cnt: 12, ((T([8, 128, 1024], f16), [1024, 1024]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([128], i64), 1), {})
+cnt: 1, ((T([8, 128, 1024], f16), T([128, 1024], f16)), {})
+cnt: 12, ((T([8, 16, 128, 128], f16), T([8, 1, 128, 128], f16)), {})
+cnt: 72, ((T([8, 128, 1024], f16), T([8, 128, 1024], f16)), {})
+cnt: 1, ((T([50265, 1024], f16), T([50265, 1024], f16)), {})
+Operator: aten.addmm.default
+cnt: 48, ((T([1024], f16), T([1024, 1024], f16), T([1024, 1024], f16, stride=(1, 1024))), {})
+cnt: 12, ((T([4096], f16), T([1024, 1024], f16), T([1024, 4096], f16, stride=(1, 1024))), {})
+cnt: 12, ((T([1024], f16), T([1024, 4096], f16), T([4096, 1024], f16, stride=(1, 4096))), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([128, 128, 64], f16), T([128, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 24, ((T([128, 128, 128], f16), T([128, 128, 64], f16)), {})
+cnt: 12, ((T([128, 128, 128], f16, stride=(16384, 1, 128)), T([128, 128, 64], f16)), {})
+cnt: 12, ((T([128, 64, 128], f16, stride=(8192, 1, 64)), T([128, 128, 128], f16)), {})
+Operator: aten.clone.default
+cnt: 2, ((T([8, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([8, 128], i64), T([8, 128], i64)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50265, 1024], f16), T([8, 128], i64), 0), {})
+cnt: 1, ((T([1024, 1024], f16), T([128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([8, 128, 1024], f16), T([8, 128], i64), 50265, 0, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([8, 128, 4096], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([8, 128, 4096], f16), T([8, 128, 4096], f16)), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([128], i64), T([128, 1], i64)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([128, 128], f32), T([128, 128], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([1024, 1024], f16), T([1024, 50265], f16, stride=(1, 1024))), {})
+cnt: 1, ((T([50265, 1024], f16, stride=(1, 50265)), T([1024, 1024], f16)), {})
+cnt: 1, ((T([1024, 50265], f16), T([50265, 1024], f16)), {})
+cnt: 12, ((T([1024, 1024], f16), T([1024, 4096], f16)), {})
+cnt: 12, ((T([1024, 1024], f16, stride=(1, 1024)), T([1024, 4096], f16)), {})
+cnt: 12, ((T([1024, 4096], f16), T([4096, 1024], f16)), {})
+cnt: 12, ((T([4096, 1024], f16, stride=(1, 4096)), T([1024, 1024], f16)), {})
+cnt: 48, ((T([1024, 1024], f16), T([1024, 1024], f16)), {})
+cnt: 48, ((T([1024, 1024], f16, stride=(1, 1024)), T([1024, 1024], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([8, 128, 1024], f16), 1.0), {})
+cnt: 24, ((T([8, 128, 1024], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 25, ((T([8, 128, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 25, ((T([8, 128, 1024], f16), T([8, 128, 1024], f16), [1024], T([8, 128, 1], f32), T([8, 128, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([1024, 50265], f16), T([1024], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([1024, 50265], f16), T([1024], i64), None, 1, -100), {})
+Operator: aten.sum.SymInt
+cnt: 60, ((T([1024, 1024], f16), [0], True), {})
+cnt: 12, ((T([1024, 4096], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/PegasusForConditionalGeneration_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/PegasusForConditionalGeneration_training.txt
new file mode 100644
index 0000000000000..970513d4b3547
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/PegasusForConditionalGeneration_training.txt
@@ -0,0 +1,79 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([512, 50265], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([512, 50265], f16), T([512, 50265], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 36, ((T([64, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 36, ((T([64, 128, 128], f16), T([64, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([128, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([4, 1, 128, 128], f16, stride=(0, 16384, 128, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 108, ((T([4, 128, 16, 64], f16), [4, 128, 1024]), {})
+cnt: 1, ((T([512, 50265], f16), [4, 128, 50265]), {})
+cnt: 36, ((T([4, 16, 128, 64], f16), [64, 128, 64]), {})
+cnt: 36, ((T([4, 128, 1024], f16), [512, 1024]), {})
+Operator: aten.add.Tensor
+cnt: 2, ((T([4, 128, 1024], f16), T([128, 1024], f16)), {})
+cnt: 191, ((T([4, 128, 1024], f16), T([4, 128, 1024], f16)), {})
+cnt: 1, ((T([128], i64), 1), {})
+cnt: 12, ((T([4, 16, 128, 128], f16), T([4, 1, 128, 128], f16)), {})
+cnt: 1, ((T([4, 128, 50265], f16), T([1, 50265], f16)), {})
+cnt: 2, ((T([50265, 1024], f16), T([50265, 1024], f16)), {})
+Operator: aten.addmm.default
+cnt: 144, ((T([1024], f16), T([512, 1024], f16), T([1024, 1024], f16, stride=(1, 1024))), {})
+cnt: 24, ((T([4096], f16), T([512, 1024], f16), T([1024, 4096], f16, stride=(1, 1024))), {})
+cnt: 24, ((T([1024], f16), T([512, 4096], f16), T([4096, 1024], f16, stride=(1, 4096))), {})
+Operator: aten.any.default
+cnt: 24, ((T([4, 128, 1024], b8),), {})
+Operator: aten.bmm.default
+cnt: 72, ((T([64, 128, 64], f16), T([64, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 72, ((T([64, 128, 128], f16), T([64, 128, 64], f16)), {})
+cnt: 36, ((T([64, 128, 128], f16, stride=(16384, 1, 128)), T([64, 128, 64], f16)), {})
+cnt: 36, ((T([64, 64, 128], f16, stride=(8192, 1, 64)), T([64, 128, 128], f16)), {})
+Operator: aten.clone.default
+cnt: 3, ((T([4, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 3, ((T([4, 128], i64), T([4, 128], i64)), {})
+Operator: aten.embedding.default
+cnt: 2, ((T([50265, 1024], f16), T([4, 128], i64), 0), {})
+cnt: 2, ((T([1024, 1024], f16), T([128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 2, ((T([4, 128, 1024], f16), T([4, 128], i64), 50265, 0, False), {})
+Operator: aten.gelu.default
+cnt: 24, ((T([4, 128, 4096], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 24, ((T([4, 128, 4096], f16), T([4, 128, 4096], f16)), {})
+Operator: aten.isinf.default
+cnt: 12, ((T([4, 128, 1024], f16),), {})
+Operator: aten.isnan.default
+cnt: 12, ((T([4, 128, 1024], f16),), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([128], i64), T([128, 1], i64)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([128, 128], f32), T([128, 128], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([512, 1024], f16), T([1024, 50265], f16, stride=(1, 1024))), {})
+cnt: 1, ((T([50265, 512], f16, stride=(1, 50265)), T([512, 1024], f16)), {})
+cnt: 1, ((T([512, 50265], f16), T([50265, 1024], f16)), {})
+cnt: 24, ((T([512, 1024], f16), T([1024, 4096], f16)), {})
+cnt: 24, ((T([1024, 512], f16, stride=(1, 1024)), T([512, 4096], f16)), {})
+cnt: 24, ((T([512, 4096], f16), T([4096, 1024], f16)), {})
+cnt: 24, ((T([4096, 512], f16, stride=(1, 4096)), T([512, 1024], f16)), {})
+cnt: 144, ((T([512, 1024], f16), T([1024, 1024], f16)), {})
+cnt: 144, ((T([1024, 512], f16, stride=(1, 1024)), T([512, 1024], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 4, ((T([4, 128, 1024], f16), 1.0), {})
+cnt: 72, ((T([4, 128, 1024], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 62, ((T([4, 128, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 62, ((T([4, 128, 1024], f16), T([4, 128, 1024], f16), [1024], T([4, 128, 1], f32), T([4, 128, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([512, 50265], f16), T([512], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([512, 50265], f16), T([512], i64), None, 1, -100), {})
+Operator: aten.sum.SymInt
+cnt: 168, ((T([512, 1024], f16), [0], True), {})
+cnt: 24, ((T([512, 4096], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/RobertaForCausalLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/RobertaForCausalLM_training.txt
new file mode 100644
index 0000000000000..25b78750deb50
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/RobertaForCausalLM_training.txt
@@ -0,0 +1,94 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([508, 30522], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([508, 30522], f16), T([508, 30522], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([4, 12, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([4, 12, 128, 128], f16), T([4, 12, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([4, 1, 1, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([4, 128], b8),), {'dtype': i32})
+cnt: 1, ((T([4, 128], i64),), {'dtype': i32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([4, 128], i32),), {'dtype': i64})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([4, 12, 128, 64], f16), [48, 128, 64]), {})
+cnt: 12, ((T([4, 12, 64, 128], f16), [48, 64, 128]), {})
+cnt: 12, ((T([48, 128, 128], f16), [4, 12, 128, 128]), {})
+cnt: 12, ((T([48, 128, 64], f16), [4, 12, 128, 64]), {})
+cnt: 24, ((T([4, 128, 12, 64], f16), [4, 128, 768]), {})
+cnt: 12, ((T([4, 128, 768], f16), [512, 768]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([4, 128], i32), 0), {})
+cnt: 1, ((T([4, 128], i64), 0), {})
+cnt: 73, ((T([4, 128, 768], f16), T([4, 128, 768], f16)), {})
+cnt: 12, ((T([4, 12, 128, 128], f16), T([4, 1, 1, 128], f16)), {})
+cnt: 1, ((T([30522, 768], f16), T([30522, 768], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([4, 128, 768], f16), T([4, 128, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 49, ((T([768], f16), T([512, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([512, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([512, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([30522], f16), T([512, 768], f16), T([768, 30522], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([48, 128, 64], f16), T([48, 64, 128], f16)), {})
+cnt: 12, ((T([48, 128, 128], f16), T([48, 128, 64], f16)), {})
+cnt: 12, ((T([48, 128, 128], f16, stride=(16384, 1, 128)), T([48, 128, 64], f16)), {})
+cnt: 12, ((T([48, 128, 64], f16), T([48, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 12, ((T([48, 64, 128], f16, stride=(8192, 1, 64)), T([48, 128, 128], f16)), {})
+cnt: 12, ((T([48, 128, 128], f16), T([48, 128, 64], f16, stride=(8192, 1, 128))), {})
+Operator: aten.clone.default
+cnt: 2, ((T([4, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([4, 128], i64), T([4, 128], i64)), {})
+Operator: aten.cumsum.default
+cnt: 1, ((T([4, 128], i32), 1), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([4, 12, 128, 128], f16), 8.0), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30522, 768], f16), T([4, 128], i64), 0), {})
+cnt: 1, ((T([2, 768], f16), T([4, 128], i64, stride=(0, 1))), {})
+cnt: 1, ((T([512, 768], f16), T([4, 128], i64), 0), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([4, 128, 768], f16), T([4, 128], i64), 512, 0, False), {})
+cnt: 1, ((T([4, 128, 768], f16), T([4, 128], i64, stride=(0, 1)), 2, -1, False), {})
+cnt: 1, ((T([4, 128, 768], f16), T([4, 128], i64), 30522, 0, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([4, 128, 3072], f16),), {})
+cnt: 1, ((T([4, 128, 768], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([4, 128, 768], f16), T([4, 128, 768], f16)), {})
+cnt: 12, ((T([4, 128, 3072], f16), T([4, 128, 3072], f16)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([512, 30522], f16), T([30522, 768], f16)), {})
+cnt: 1, ((T([30522, 512], f16, stride=(1, 30522)), T([512, 768], f16)), {})
+cnt: 49, ((T([512, 768], f16), T([768, 768], f16)), {})
+cnt: 49, ((T([768, 512], f16, stride=(1, 768)), T([512, 768], f16)), {})
+cnt: 12, ((T([512, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 512], f16, stride=(1, 768)), T([512, 3072], f16)), {})
+cnt: 12, ((T([512, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 512], f16, stride=(1, 3072)), T([512, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([4, 1, 1, 128], f16), -65504.0), {})
+cnt: 1, ((T([4, 128], i32), T([4, 128], i32)), {})
+Operator: aten.native_layer_norm.default
+cnt: 26, ((T([4, 128, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 26, ((T([4, 128, 768], f16), T([4, 128, 768], f16), [768], T([4, 128, 1], f32), T([4, 128, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.ne.Scalar
+cnt: 1, ((T([4, 128], i64), 0), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([508, 30522], f16), T([508], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([508, 30522], f16), T([508], i64), None, 1, -100), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([4, 1, 1, 128], f16), 1.0), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([4, 127, 30522], f16), [4, 127, 30522], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([4, 127, 30522], f16), [4, 128, 30522], 1, 0, -1, 1), {})
+cnt: 1, ((T([4, 128, 30522], f16), [4, 128, 30522], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([512, 30522], f16), [0], True), {})
+cnt: 61, ((T([512, 768], f16), [0], True), {})
+cnt: 12, ((T([512, 3072], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/RobertaForQuestionAnswering_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/RobertaForQuestionAnswering_training.txt
new file mode 100644
index 0000000000000..02cf28ea08677
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/RobertaForQuestionAnswering_training.txt
@@ -0,0 +1,97 @@
+Operator: aten._log_softmax.default
+cnt: 2, ((T([64, 128], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 2, ((T([64, 128], f16), T([64, 128], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([64, 12, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([64, 12, 128, 128], f16), T([64, 12, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([64, 1, 1, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([64, 128], b8),), {'dtype': i32})
+cnt: 1, ((T([64, 128], i64),), {'dtype': i32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([64, 128], i32),), {'dtype': i64})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([64, 12, 128, 64], f16), [768, 128, 64]), {})
+cnt: 12, ((T([64, 12, 64, 128], f16), [768, 64, 128]), {})
+cnt: 12, ((T([768, 128, 128], f16), [64, 12, 128, 128]), {})
+cnt: 12, ((T([768, 128, 64], f16), [64, 12, 128, 64]), {})
+cnt: 24, ((T([64, 128, 12, 64], f16), [64, 128, 768]), {})
+cnt: 12, ((T([64, 128, 768], f16), [8192, 768]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([64, 128], i32), 0), {})
+cnt: 1, ((T([64, 128], i64), 0), {})
+cnt: 73, ((T([64, 128, 768], f16), T([64, 128, 768], f16)), {})
+cnt: 12, ((T([64, 12, 128, 128], f16), T([64, 1, 1, 128], f16)), {})
+cnt: 1, ((T([], f16), T([], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([64, 128, 768], f16), T([64, 128, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 48, ((T([768], f16), T([8192, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([8192, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([8192, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([2], f16), T([8192, 768], f16), T([768, 2], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([768, 128, 64], f16), T([768, 64, 128], f16)), {})
+cnt: 12, ((T([768, 128, 128], f16), T([768, 128, 64], f16)), {})
+cnt: 12, ((T([768, 128, 128], f16, stride=(16384, 1, 128)), T([768, 128, 64], f16)), {})
+cnt: 12, ((T([768, 128, 64], f16), T([768, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 12, ((T([768, 64, 128], f16, stride=(8192, 1, 64)), T([768, 128, 128], f16)), {})
+cnt: 12, ((T([768, 128, 128], f16), T([768, 128, 64], f16, stride=(8192, 1, 128))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 128, 1], f16), T([64, 128, 1], f16)], 2), {})
+Operator: aten.clamp.default
+cnt: 2, ((T([64], i64), 0, 128), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 128], i64),), {})
+cnt: 2, ((T([64], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 128], i64), T([64, 128], i64)), {})
+cnt: 2, ((T([64], i64), T([64], i64)), {})
+Operator: aten.cumsum.default
+cnt: 1, ((T([64, 128], i32), 1), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([64, 12, 128, 128], f16), 8.0), {})
+cnt: 2, ((T([], f16), 2), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30522, 768], f16), T([64, 128], i64), 0), {})
+cnt: 1, ((T([2, 768], f16), T([64, 128], i64, stride=(0, 1))), {})
+cnt: 1, ((T([512, 768], f16), T([64, 128], i64), 0), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([64, 128, 768], f16), T([64, 128], i64), 512, 0, False), {})
+cnt: 1, ((T([64, 128, 768], f16), T([64, 128], i64, stride=(0, 1)), 2, -1, False), {})
+cnt: 1, ((T([64, 128, 768], f16), T([64, 128], i64), 30522, 0, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([64, 128, 3072], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([64, 128, 3072], f16), T([64, 128, 3072], f16)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([8192, 2], f16), T([2, 768], f16)), {})
+cnt: 1, ((T([2, 8192], f16, stride=(1, 2)), T([8192, 768], f16)), {})
+cnt: 12, ((T([8192, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 8192], f16, stride=(1, 768)), T([8192, 3072], f16)), {})
+cnt: 12, ((T([8192, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 8192], f16, stride=(1, 3072)), T([8192, 768], f16)), {})
+cnt: 48, ((T([8192, 768], f16), T([768, 768], f16)), {})
+cnt: 48, ((T([768, 8192], f16, stride=(1, 768)), T([8192, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([64, 1, 1, 128], f16), -65504.0), {})
+cnt: 1, ((T([64, 128], i32), T([64, 128], i32)), {})
+Operator: aten.native_layer_norm.default
+cnt: 25, ((T([64, 128, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 25, ((T([64, 128, 768], f16), T([64, 128, 768], f16), [768], T([64, 128, 1], f32), T([64, 128, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.ne.Scalar
+cnt: 1, ((T([64, 128], i64), 0), {})
+Operator: aten.nll_loss_backward.default
+cnt: 2, ((T([], f16), T([64, 128], f16), T([64], i64), None, 1, 128, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 2, ((T([64, 128], f16), T([64], i64), None, 1, 128), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([64, 1, 1, 128], f16), 1.0), {})
+Operator: aten.split.Tensor
+cnt: 1, ((T([64, 128, 2], f16), 1, -1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([8192, 2], f16), [0], True), {})
+cnt: 60, ((T([8192, 768], f16), [0], True), {})
+cnt: 12, ((T([8192, 3072], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/Speech2Text2ForCausalLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/Speech2Text2ForCausalLM_training.txt
new file mode 100644
index 0000000000000..a816e067e3636
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/Speech2Text2ForCausalLM_training.txt
@@ -0,0 +1,82 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([8192, 10000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([8192, 10000], f16), T([8192, 10000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 6, ((T([256, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 6, ((T([256, 128, 128], f16), T([256, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([128, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([64, 1, 128, 128], f16, stride=(0, 16384, 128, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([64, 128], b8),), {'dtype': i32})
+cnt: 1, ((T([64, 128], i64),), {'dtype': i32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([64, 128], i32),), {'dtype': i64})
+Operator: aten._unsafe_view.default
+cnt: 18, ((T([64, 128, 4, 64], f16), [64, 128, 256]), {})
+cnt: 1, ((T([8192, 10000], f16), [64, 128, 10000]), {})
+cnt: 6, ((T([64, 4, 128, 64], f16), [256, 128, 64]), {})
+cnt: 6, ((T([64, 128, 256], f16), [8192, 256]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([128], i64), 1), {})
+cnt: 1, ((T([64, 128], i32), 0), {})
+cnt: 1, ((T([64, 128], i64), 1), {})
+cnt: 37, ((T([64, 128, 256], f16), T([64, 128, 256], f16)), {})
+cnt: 6, ((T([64, 4, 128, 128], f16), T([64, 1, 128, 128], f16)), {})
+cnt: 1, ((T([10000, 256], f16), T([10000, 256], f16)), {})
+Operator: aten.addmm.default
+cnt: 24, ((T([256], f16), T([8192, 256], f16), T([256, 256], f16, stride=(1, 256))), {})
+cnt: 6, ((T([2048], f16), T([8192, 256], f16), T([256, 2048], f16, stride=(1, 256))), {})
+cnt: 6, ((T([256], f16), T([8192, 2048], f16), T([2048, 256], f16, stride=(1, 2048))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([256, 128, 64], f16), T([256, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 12, ((T([256, 128, 128], f16), T([256, 128, 64], f16)), {})
+cnt: 6, ((T([256, 128, 128], f16, stride=(16384, 1, 128)), T([256, 128, 64], f16)), {})
+cnt: 6, ((T([256, 64, 128], f16, stride=(8192, 1, 64)), T([256, 128, 128], f16)), {})
+Operator: aten.clone.default
+cnt: 2, ((T([64, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([64, 128], i64), T([64, 128], i64)), {})
+Operator: aten.cumsum.default
+cnt: 1, ((T([64, 128], i32), 1), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([10000, 256], f16), T([64, 128], i64), 1), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([64, 128, 256], f16), T([64, 128], i64), 10000, 1, False), {})
+Operator: aten.index_select.default
+cnt: 1, ((T([1026, 256], f16), 0, T([8192], i64)), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([128], i64), T([128, 1], i64)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([128, 128], f32), T([128, 128], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([8192, 256], f16), T([256, 10000], f16, stride=(1, 256))), {})
+cnt: 1, ((T([10000, 8192], f16, stride=(1, 10000)), T([8192, 256], f16)), {})
+cnt: 1, ((T([8192, 10000], f16), T([10000, 256], f16)), {})
+cnt: 6, ((T([8192, 256], f16), T([256, 2048], f16)), {})
+cnt: 6, ((T([256, 8192], f16, stride=(1, 256)), T([8192, 2048], f16)), {})
+cnt: 6, ((T([8192, 2048], f16), T([2048, 256], f16)), {})
+cnt: 6, ((T([2048, 8192], f16, stride=(1, 2048)), T([8192, 256], f16)), {})
+cnt: 24, ((T([8192, 256], f16), T([256, 256], f16)), {})
+cnt: 24, ((T([256, 8192], f16, stride=(1, 256)), T([8192, 256], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([64, 128, 256], f16), 16.0), {})
+cnt: 1, ((T([64, 128], i32), T([64, 128], i32)), {})
+cnt: 12, ((T([64, 128, 256], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 12, ((T([64, 128, 256], f16), [256], T([256], f16), T([256], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 12, ((T([64, 128, 256], f16), T([64, 128, 256], f16), [256], T([64, 128, 1], f32), T([64, 128, 1], f32), T([256], f16), T([256], f16), [True, True, True]), {})
+Operator: aten.ne.Scalar
+cnt: 1, ((T([64, 128], i64), 1), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([8192, 10000], f16), T([8192], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([8192, 10000], f16), T([8192], i64), None, 1, -100), {})
+Operator: aten.relu.default
+cnt: 6, ((T([64, 128, 2048], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 30, ((T([8192, 256], f16), [0], True), {})
+cnt: 6, ((T([8192, 2048], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 6, ((T([64, 128, 2048], f16), T([64, 128, 2048], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/TrOCRForCausalLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/TrOCRForCausalLM_training.txt
new file mode 100644
index 0000000000000..97c3b304cee47
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/TrOCRForCausalLM_training.txt
@@ -0,0 +1,73 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([1024, 50265], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([1024, 50265], f16), T([1024, 50265], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([128, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([128, 128, 128], f16), T([128, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([128, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([8, 1, 128, 128], f16, stride=(0, 16384, 128, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([8, 128, 16, 64], f16), [8, 128, 1024]), {})
+cnt: 1, ((T([1024, 50265], f16), [8, 128, 50265]), {})
+cnt: 12, ((T([8, 16, 128, 64], f16), [128, 128, 64]), {})
+cnt: 12, ((T([8, 128, 1024], f16), [1024, 1024]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([8, 128], i64, stride=(0, 1)), 2), {})
+cnt: 73, ((T([8, 128, 1024], f16), T([8, 128, 1024], f16)), {})
+cnt: 1, ((T([128], i64), 1), {})
+cnt: 12, ((T([8, 16, 128, 128], f16), T([8, 1, 128, 128], f16)), {})
+cnt: 1, ((T([50265, 1024], f16), T([50265, 1024], f16)), {})
+Operator: aten.addmm.default
+cnt: 48, ((T([1024], f16), T([1024, 1024], f16), T([1024, 1024], f16, stride=(1, 1024))), {})
+cnt: 12, ((T([4096], f16), T([1024, 1024], f16), T([1024, 4096], f16, stride=(1, 1024))), {})
+cnt: 12, ((T([1024], f16), T([1024, 4096], f16), T([4096, 1024], f16, stride=(1, 4096))), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([128, 128, 64], f16), T([128, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 24, ((T([128, 128, 128], f16), T([128, 128, 64], f16)), {})
+cnt: 12, ((T([128, 128, 128], f16, stride=(16384, 1, 128)), T([128, 128, 64], f16)), {})
+cnt: 12, ((T([128, 64, 128], f16, stride=(8192, 1, 64)), T([128, 128, 128], f16)), {})
+Operator: aten.clone.default
+cnt: 2, ((T([8, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([8, 128], i64), T([8, 128], i64)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50265, 1024], f16), T([8, 128], i64), 1), {})
+cnt: 1, ((T([514, 1024], f16), T([8, 128], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([8, 128, 1024], f16), T([8, 128], i64), 514, -1, False), {})
+cnt: 1, ((T([8, 128, 1024], f16), T([8, 128], i64), 50265, 1, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([8, 128, 4096], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([8, 128, 4096], f16), T([8, 128, 4096], f16)), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([128], i64), T([128, 1], i64)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([128, 128], f32), T([128, 128], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([1024, 1024], f16), T([1024, 50265], f16, stride=(1, 1024))), {})
+cnt: 1, ((T([50265, 1024], f16, stride=(1, 50265)), T([1024, 1024], f16)), {})
+cnt: 1, ((T([1024, 50265], f16), T([50265, 1024], f16)), {})
+cnt: 12, ((T([1024, 1024], f16), T([1024, 4096], f16)), {})
+cnt: 12, ((T([1024, 1024], f16, stride=(1, 1024)), T([1024, 4096], f16)), {})
+cnt: 12, ((T([1024, 4096], f16), T([4096, 1024], f16)), {})
+cnt: 12, ((T([4096, 1024], f16, stride=(1, 4096)), T([1024, 1024], f16)), {})
+cnt: 48, ((T([1024, 1024], f16), T([1024, 1024], f16)), {})
+cnt: 48, ((T([1024, 1024], f16, stride=(1, 1024)), T([1024, 1024], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([8, 128, 1024], f16), 1.0), {})
+cnt: 24, ((T([8, 128, 1024], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 25, ((T([8, 128, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 25, ((T([8, 128, 1024], f16), T([8, 128, 1024], f16), [1024], T([8, 128, 1], f32), T([8, 128, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([1024, 50265], f16), T([1024], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([1024, 50265], f16), T([1024], i64), None, 1, -100), {})
+Operator: aten.sum.SymInt
+cnt: 60, ((T([1024, 1024], f16), [0], True), {})
+cnt: 12, ((T([1024, 4096], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/XGLMForCausalLM_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/XGLMForCausalLM_training.txt
new file mode 100644
index 0000000000000..a8317b48f20dd
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/XGLMForCausalLM_training.txt
@@ -0,0 +1,88 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([256, 256008], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([256, 256008], f16), T([256, 256008], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 24, ((T([32, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 24, ((T([32, 128, 128], f16), T([32, 128, 128], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([128, 128], f32),), {'dtype': f16})
+cnt: 1, ((T([2, 1, 128, 128], f16, stride=(0, 16384, 128, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([2, 128], b8),), {'dtype': i32})
+cnt: 1, ((T([2, 128], i64),), {'dtype': i32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([2, 128], i32),), {'dtype': i64})
+Operator: aten._unsafe_view.default
+cnt: 72, ((T([2, 128, 16, 64], f16), [2, 128, 1024]), {})
+cnt: 1, ((T([256, 256008], f16), [2, 128, 256008]), {})
+cnt: 24, ((T([2, 16, 128, 64], f16), [32, 128, 64]), {})
+cnt: 24, ((T([2, 128, 1024], f16), [256, 1024]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([128], i64), 1), {})
+cnt: 1, ((T([2, 128], i32), 0), {})
+cnt: 1, ((T([2, 128], i64), 1), {})
+cnt: 145, ((T([2, 128, 1024], f16), T([2, 128, 1024], f16)), {})
+cnt: 24, ((T([2, 16, 128, 128], f16), T([2, 1, 128, 128], f16)), {})
+cnt: 1, ((T([256008, 1024], f16), T([256008, 1024], f16)), {})
+Operator: aten.addmm.default
+cnt: 96, ((T([1024], f16), T([256, 1024], f16), T([1024, 1024], f16, stride=(1, 1024))), {})
+cnt: 24, ((T([4096], f16), T([256, 1024], f16), T([1024, 4096], f16, stride=(1, 1024))), {})
+cnt: 24, ((T([1024], f16), T([256, 4096], f16), T([4096, 1024], f16, stride=(1, 4096))), {})
+Operator: aten.bmm.default
+cnt: 48, ((T([32, 128, 64], f16), T([32, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 48, ((T([32, 128, 128], f16), T([32, 128, 64], f16)), {})
+cnt: 24, ((T([32, 128, 128], f16, stride=(16384, 1, 128)), T([32, 128, 64], f16)), {})
+cnt: 24, ((T([32, 64, 128], f16, stride=(8192, 1, 64)), T([32, 128, 128], f16)), {})
+Operator: aten.clone.default
+cnt: 2, ((T([2, 128], i64),), {})
+cnt: 1, ((T([2, 127], i64, stride=(128, 1)),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([2, 128], i64), T([2, 128], i64)), {})
+cnt: 1, ((T([2, 127], i64, stride=(128, 1)), T([2, 127], i64)), {})
+Operator: aten.cumsum.default
+cnt: 1, ((T([2, 128], i32), 1), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([256008, 1024], f16), T([2, 128], i64), 1), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([2, 128, 1024], f16), T([2, 128], i64), 256008, 1, False), {})
+Operator: aten.fill_.Tensor
+cnt: 1, ((T([2], i64, stride=(128,)), T([], i64)), {})
+Operator: aten.gelu.default
+cnt: 24, ((T([2, 128, 4096], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 24, ((T([2, 128, 4096], f16), T([2, 128, 4096], f16)), {})
+Operator: aten.index_select.default
+cnt: 1, ((T([2050, 1024], f16), 0, T([256], i64)), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([128], i64), T([128, 1], i64)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([128, 128], f32), T([128, 128], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([256, 1024], f16), T([1024, 256008], f16, stride=(1, 1024))), {})
+cnt: 1, ((T([256008, 256], f16, stride=(1, 256008)), T([256, 1024], f16)), {})
+cnt: 1, ((T([256, 256008], f16), T([256008, 1024], f16)), {})
+cnt: 24, ((T([256, 1024], f16), T([1024, 4096], f16)), {})
+cnt: 24, ((T([1024, 256], f16, stride=(1, 1024)), T([256, 4096], f16)), {})
+cnt: 24, ((T([256, 4096], f16), T([4096, 1024], f16)), {})
+cnt: 24, ((T([4096, 256], f16, stride=(1, 4096)), T([256, 1024], f16)), {})
+cnt: 96, ((T([256, 1024], f16), T([1024, 1024], f16)), {})
+cnt: 96, ((T([1024, 256], f16, stride=(1, 1024)), T([256, 1024], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([2, 128, 1024], f16), 32.0), {})
+cnt: 1, ((T([2, 128], i32), T([2, 128], i32)), {})
+cnt: 48, ((T([2, 128, 1024], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 49, ((T([2, 128, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 49, ((T([2, 128, 1024], f16), T([2, 128, 1024], f16), [1024], T([2, 128, 1], f32), T([2, 128, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+Operator: aten.ne.Scalar
+cnt: 1, ((T([2, 128], i64), 1), {})
+Operator: aten.new_zeros.default
+cnt: 1, ((T([2, 128], i64), [2, 128]), {'dtype': i64, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([256, 256008], f16), T([256], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([256, 256008], f16), T([256], i64), None, 1, -100), {})
+Operator: aten.sum.SymInt
+cnt: 120, ((T([256, 1024], f16), [0], True), {})
+cnt: 24, ((T([256, 4096], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/XLNetLMHeadModel_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/XLNetLMHeadModel_training.txt
new file mode 100644
index 0000000000000..f3056de63d924
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/XLNetLMHeadModel_training.txt
@@ -0,0 +1,105 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([2048, 32000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([2048, 32000], f16), T([2048, 32000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 24, ((T([4, 16, 512, 512], f16), 3, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 24, ((T([4, 16, 512, 512], f16), T([4, 16, 512, 512], f16), 3, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([1024, 4, 1024], f32, stride=(1024, 0, 1)),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 24, ((T([1024, 4, 1024], f32),), {'dtype': f16, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 24, ((T([512, 4, 64, 16, 1], f16), [1, 2048, 1024]), {})
+cnt: 24, ((T([64, 16, 1024, 1, 1], f16), [1, 1024, 1024]), {})
+cnt: 24, ((T([4, 16, 512, 1, 64], f16), [64, 512, 64]), {})
+cnt: 24, ((T([1024, 4, 1, 16, 64], f16), [1, 4096, 1024]), {})
+cnt: 72, ((T([512, 4, 1, 16, 64], f16), [1, 2048, 1024]), {})
+Operator: aten.add.Tensor
+cnt: 48, ((T([512, 4, 16, 64], f16), T([16, 64], f16)), {})
+cnt: 24, ((T([4, 16, 512, 512], f16), T([4, 16, 512, 512], f16)), {})
+cnt: 24, ((T([4, 16, 512, 512], f16), 0), {})
+cnt: 144, ((T([512, 4, 1024], f16), T([512, 4, 1024], f16)), {})
+cnt: 24, ((T([512, 4, 16, 64], f16, stride=(64, 524288, 32768, 1)), T([512, 4, 16, 64], f16, stride=(64, 524288, 32768, 1))), {})
+cnt: 1, ((T([32000, 1024], f16), T([32000, 1024], f16)), {})
+Operator: aten.addmm.default
+cnt: 24, ((T([4096], f16), T([2048, 1024], f16), T([1024, 4096], f16, stride=(1, 1024))), {})
+cnt: 24, ((T([1024], f16), T([2048, 4096], f16), T([4096, 1024], f16, stride=(1, 4096))), {})
+cnt: 1, ((T([32000], f16), T([2048, 1024], f16), T([1024, 32000], f16, stride=(1, 1024))), {})
+Operator: aten.bmm.default
+cnt: 96, ((T([1, 2048, 1024], f16), T([1, 1024, 1024], f16)), {})
+cnt: 24, ((T([1, 4096, 1024], f16), T([1, 1024, 1024], f16)), {})
+cnt: 24, ((T([64, 512, 64], f16, stride=(64, 4096, 1)), T([64, 64, 512], f16, stride=(64, 1, 4096))), {})
+cnt: 24, ((T([64, 512, 64], f16, stride=(64, 4096, 1)), T([64, 64, 1024], f16, stride=(64, 1, 4096))), {})
+cnt: 48, ((T([64, 512, 512], f16), T([64, 512, 64], f16, stride=(64, 4096, 1))), {})
+cnt: 96, ((T([1, 1024, 2048], f16, stride=(2097152, 1, 1024)), T([1, 2048, 1024], f16)), {})
+cnt: 96, ((T([1, 2048, 1024], f16), T([1, 1024, 1024], f16, stride=(1048576, 1, 1024))), {})
+cnt: 24, ((T([64, 512, 512], f16, stride=(262144, 1, 512)), T([64, 512, 64], f16)), {})
+cnt: 24, ((T([64, 512, 64], f16), T([64, 64, 512], f16, stride=(64, 1, 4096))), {})
+cnt: 24, ((T([64, 64, 512], f16, stride=(64, 1, 4096)), T([64, 512, 1024], f16)), {})
+cnt: 24, ((T([64, 512, 1024], f16), T([64, 1024, 64], f16, stride=(64, 4096, 1))), {})
+cnt: 24, ((T([64, 64, 512], f16, stride=(64, 1, 4096)), T([64, 512, 512], f16)), {})
+cnt: 24, ((T([1, 1024, 4096], f16, stride=(4194304, 1, 1024)), T([1, 4096, 1024], f16)), {})
+Operator: aten.cat.default
+cnt: 1, (([T([1024, 512], f32), T([1024, 512], f32)], -1), {})
+Operator: aten.clone.default
+cnt: 2, ((T([4, 512], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([4, 512], i64), T([4, 512], i64)), {})
+cnt: 24, ((T([1024, 16, 64], f16), T([1024, 16, 64], f16, stride=(1, 1024, 16384))), {})
+Operator: aten.cos.default
+cnt: 1, ((T([1024, 512], f32),), {})
+Operator: aten.div.Tensor
+cnt: 1, ((T([512], f32), 1024), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([32000, 1024], f16), T([512, 4], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([512, 4, 1024], f16), T([512, 4], i64), 32000, -1, False), {})
+Operator: aten.gelu.default
+cnt: 24, ((T([512, 4, 4096], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 24, ((T([512, 4, 4096], f16), T([512, 4, 4096], f16)), {})
+Operator: aten.index_add.default
+cnt: 24, ((T([4, 16, 512, 1023], f16), 3, T([512], i64), T([4, 16, 512, 512], f16)), {})
+Operator: aten.index_select.default
+cnt: 24, ((T([4, 16, 512, 1023], f16, stride=(8388608, 524288, 1023, 1)), 3, T([512], i64)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([2048, 32000], f16), T([32000, 1024], f16)), {})
+cnt: 1, ((T([32000, 2048], f16, stride=(1, 32000)), T([2048, 1024], f16)), {})
+cnt: 24, ((T([2048, 1024], f16), T([1024, 4096], f16)), {})
+cnt: 24, ((T([1024, 2048], f16, stride=(1, 1024)), T([2048, 4096], f16)), {})
+cnt: 24, ((T([2048, 4096], f16), T([4096, 1024], f16)), {})
+cnt: 24, ((T([4096, 2048], f16, stride=(1, 4096)), T([2048, 1024], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([512], f32), 1), {})
+cnt: 1, ((T([1024, 1], f32), T([1, 512], f32)), {})
+cnt: 48, ((T([4, 16, 512, 512], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 48, ((T([512, 4, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 1, ((T([512, 4, 1024], f16, stride=(1024, 524288, 1)), T([512, 4, 1024], f16), [1024], T([512, 4, 1], f32), T([512, 4, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+cnt: 47, ((T([512, 4, 1024], f16), T([512, 4, 1024], f16), [1024], T([512, 4, 1], f32), T([512, 4, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+Operator: aten.new_empty_strided.default
+cnt: 24, ((T([1024, 16, 64], f16, stride=(1, 1024, 16384)), [1024, 16, 64], [1024, 64, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.new_zeros.default
+cnt: 24, ((T([4, 16, 512, 512], f16), [4, 16, 512, 1023]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([2048, 32000], f16), T([2048], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([2048, 32000], f16), T([2048], i64), None, 1, -100), {})
+Operator: aten.pow.Scalar
+cnt: 1, ((10000, T([512], f32)), {})
+Operator: aten.reciprocal.default
+cnt: 1, ((T([512], f32),), {})
+Operator: aten.sin.default
+cnt: 1, ((T([1024, 512], f32),), {})
+Operator: aten.slice_backward.default
+cnt: 24, ((T([4, 16, 1023, 512], f16), [4, 16, 1023, 512], 3, 0, 9223372036854775807, 1), {})
+cnt: 24, ((T([4, 16, 1023, 512], f16), [4, 16, 1024, 512], 2, 1, 9223372036854775807, 1), {})
+cnt: 24, ((T([4, 16, 1024, 512], f16), [4, 16, 1024, 512], 1, 0, 9223372036854775807, 1), {})
+cnt: 24, ((T([4, 16, 1024, 512], f16), [4, 16, 1024, 512], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([2048, 32000], f16), [0], True), {})
+cnt: 24, ((T([2048, 1024], f16), [0], True), {})
+cnt: 24, ((T([2048, 4096], f16), [0], True), {})
+cnt: 48, ((T([512, 4, 16, 64], f16, stride=(64, 524288, 32768, 1)), [0, 1], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/YituTechConvBert_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/YituTechConvBert_training.txt
new file mode 100644
index 0000000000000..d1a6dcccdaa19
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/hf_train/YituTechConvBert_training.txt
@@ -0,0 +1,119 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([512, 30522], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([512, 30522], f16), T([512, 30522], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([3072, 9, 1], f16), 1, False), {})
+cnt: 12, ((T([1, 6, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([1, 6, 512, 512], f16), T([1, 6, 512, 512], f16), -1, f16), {})
+cnt: 12, ((T([3072, 9, 1], f16), T([3072, 9, 1], f16), 1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([1, 1, 1, 512], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 12, ((T([1, 512, 54], f16), [1, 512, 54]), {})
+cnt: 12, ((T([1, 512, 384, 9], f16), [3072, 64, 9]), {})
+cnt: 12, ((T([3072, 64, 1], f16), [3072, 64, 1]), {})
+cnt: 12, ((T([6, 512, 512], f16), [1, 6, 512, 512]), {})
+cnt: 12, ((T([6, 512, 64], f16), [1, 6, 512, 64]), {})
+cnt: 12, ((T([512, 384], f16), [3072, 64, 1]), {})
+cnt: 24, ((T([1, 512, 6, 64], f16), [1, 512, 384]), {})
+Operator: aten.add.Tensor
+cnt: 86, ((T([1, 512, 768], f16), T([1, 512, 768], f16)), {})
+cnt: 12, ((T([1, 512, 54], f16), T([54], f16)), {})
+cnt: 12, ((T([1, 6, 512, 512], f16), T([1, 1, 1, 512], f16)), {})
+cnt: 12, ((T([1, 512, 384], f16), T([1, 512, 384], f16)), {})
+cnt: 12, ((T([1, 512, 768], f16), T([1, 512, 768], f16, stride=(393216, 1, 512))), {})
+cnt: 1, ((T([30522, 768], f16), T([30522, 768], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 12, ((T([1, 384, 512], f16), T([384, 1], f16)), {})
+Operator: aten.addmm.default
+cnt: 48, ((T([384], f16), T([512, 768], f16), T([768, 384], f16, stride=(1, 768))), {})
+cnt: 13, ((T([768], f16), T([512, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([512, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([512, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([30522], f16), T([512, 768], f16), T([768, 30522], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([1, 512, 384], f16, stride=(512, 1, 512)), T([1, 384, 54], f16, stride=(384, 1, 384))), {})
+cnt: 12, ((T([3072, 64, 9], f16), T([3072, 9, 1], f16)), {})
+cnt: 12, ((T([6, 512, 64], f16, stride=(64, 384, 1)), T([6, 64, 512], f16, stride=(64, 1, 384))), {})
+cnt: 24, ((T([6, 512, 512], f16), T([6, 512, 64], f16, stride=(64, 384, 1))), {})
+cnt: 12, ((T([6, 512, 512], f16, stride=(262144, 1, 512)), T([6, 512, 64], f16, stride=(64, 768, 1))), {})
+cnt: 12, ((T([6, 512, 64], f16, stride=(64, 768, 1)), T([6, 64, 512], f16, stride=(64, 1, 384))), {})
+cnt: 12, ((T([6, 64, 512], f16, stride=(64, 1, 384)), T([6, 512, 512], f16)), {})
+cnt: 12, ((T([3072, 9, 64], f16, stride=(576, 1, 9)), T([3072, 64, 1], f16)), {})
+cnt: 12, ((T([3072, 64, 1], f16), T([3072, 1, 9], f16)), {})
+cnt: 12, ((T([1, 384, 512], f16), T([1, 512, 54], f16)), {})
+cnt: 12, ((T([1, 512, 54], f16), T([1, 54, 384], f16)), {})
+Operator: aten.cat.default
+cnt: 12, (([T([1, 512, 6, 64], f16), T([1, 512, 6, 64], f16)], 2), {})
+Operator: aten.clone.default
+cnt: 2, ((T([1, 512], i64),), {})
+Operator: aten.convolution.default
+cnt: 12, ((T([1, 768, 512], f16, stride=(393216, 1, 768)), T([768, 1, 9], f16), None, [1], [4], [1], False, [0], 768), {})
+cnt: 12, ((T([1, 768, 512], f16), T([384, 768, 1], f16), None, [1], [0], [1], False, [0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 12, ((T([1, 384, 512], f16, stride=(196608, 1, 384)), T([1, 768, 512], f16), T([384, 768, 1], f16), [0], [1], [0], [1], False, [0], 1, [True, True, False]), {})
+cnt: 12, ((T([1, 768, 512], f16), T([1, 768, 512], f16, stride=(393216, 1, 768)), T([768, 1, 9], f16), [0], [1], [4], [1], False, [0], 768, [True, True, False]), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([1, 512], i64), T([1, 512], i64)), {})
+cnt: 12, ((T([54, 384], f16), T([54, 384], f16, stride=(1, 54))), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([1, 6, 512, 512], f16), 8.0), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30522, 768], f16), T([1, 512], i64), 0), {})
+cnt: 1, ((T([512, 768], f16), T([1, 512], i64)), {})
+cnt: 1, ((T([2, 768], f16), T([1, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 2, -1, False), {})
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 512, -1, False), {})
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 30522, 0, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([1, 512, 3072], f16),), {})
+cnt: 1, ((T([1, 512, 768], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512, 768], f16)), {})
+cnt: 12, ((T([1, 512, 3072], f16), T([1, 512, 3072], f16)), {})
+Operator: aten.im2col.default
+cnt: 12, ((T([1, 384, 512, 1], f16), [9, 1], [1, 1], [4, 0], [1, 1]), {})
+Operator: aten.im2col_backward.default
+cnt: 12, ((T([1, 3456, 512], f16, stride=(1769472, 1, 3456)), [512, 1], [9, 1], [1, 1], [4, 0], [1, 1]), {})
+Operator: aten.mm.default
+cnt: 1, ((T([512, 30522], f16), T([30522, 768], f16)), {})
+cnt: 1, ((T([30522, 512], f16, stride=(1, 30522)), T([512, 768], f16)), {})
+cnt: 13, ((T([512, 768], f16), T([768, 768], f16)), {})
+cnt: 13, ((T([768, 512], f16, stride=(1, 768)), T([512, 768], f16)), {})
+cnt: 12, ((T([512, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 512], f16, stride=(1, 768)), T([512, 3072], f16)), {})
+cnt: 12, ((T([512, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 512], f16, stride=(1, 3072)), T([512, 768], f16)), {})
+cnt: 24, ((T([512, 384], f16, stride=(1, 512)), T([384, 768], f16)), {})
+cnt: 24, ((T([384, 512], f16), T([512, 768], f16)), {})
+cnt: 24, ((T([512, 384], f16), T([384, 768], f16)), {})
+cnt: 24, ((T([384, 512], f16, stride=(1, 384)), T([512, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([1, 1, 1, 512], f16), -65504.0), {})
+cnt: 12, ((T([1, 512, 384], f16, stride=(196608, 1, 512)), T([1, 512, 384], f16)), {})
+cnt: 12, ((T([1, 512, 384], f16), T([1, 512, 384], f16, stride=(196608, 1, 512))), {})
+cnt: 12, ((T([1, 512, 384], f16), T([1, 512, 384], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 26, ((T([1, 512, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 26, ((T([1, 512, 768], f16), T([1, 512, 768], f16), [768], T([1, 512, 1], f32), T([1, 512, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.new_empty_strided.default
+cnt: 12, ((T([54, 384], f16, stride=(1, 54)), [54, 384], [384, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([512, 30522], f16), T([512], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([512, 30522], f16), T([512], i64), None, 1, -100), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([1, 1, 1, 512], f16), 1.0), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([512, 30522], f16), [0], True), {})
+cnt: 25, ((T([512, 768], f16), [0], True), {})
+cnt: 12, ((T([512, 3072], f16), [0], True), {})
+cnt: 24, ((T([512, 384], f16, stride=(1, 512)), [0], True), {})
+cnt: 12, ((T([1, 512, 54], f16), [0, 1], True), {})
+cnt: 12, ((T([1, 384, 54], f16), [0], True), {})
+cnt: 12, ((T([1, 384, 512], f16, stride=(196608, 1, 384)), [0, 2], True), {})
+cnt: 24, ((T([512, 384], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/adv_inception_v3_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/adv_inception_v3_training.txt
new file mode 100644
index 0000000000000..c11cd6890c765
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/adv_inception_v3_training.txt
@@ -0,0 +1,239 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 4, ((T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16)), {})
+cnt: 3, ((T([128, 2048, 8, 8], f16), T([128, 2048, 8, 8], f16)), {})
+cnt: 3, ((T([128, 1280, 8, 8], f16), T([128, 1280, 8, 8], f16)), {})
+cnt: 14, ((T([128, 768, 17, 17], f16), T([128, 768, 17, 17], f16)), {})
+cnt: 5, ((T([128, 288, 35, 35], f16), T([128, 288, 35, 35], f16)), {})
+cnt: 3, ((T([128, 256, 35, 35], f16), T([128, 256, 35, 35], f16)), {})
+cnt: 3, ((T([128, 192, 35, 35], f16), T([128, 192, 35, 35], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 94, ((T([], i64), 1), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([128, 192, 35, 35], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 1, ((T([128, 256, 35, 35], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 1, ((T([128, 288, 35, 35], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 4, ((T([128, 768, 17, 17], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), [3, 3], [1, 1], [1, 1]), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([128, 2048, 8, 8], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), T([128, 1280, 8, 8], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+cnt: 4, ((T([128, 768, 17, 17], f16), T([128, 768, 17, 17], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 288, 35, 35], f16), T([128, 288, 35, 35], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 256, 35, 35], f16), T([128, 256, 35, 35], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 192, 35, 35], f16), T([128, 192, 35, 35], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+Operator: aten.cat.default
+cnt: 1, (([T([128, 64, 35, 35], f16), T([128, 64, 35, 35], f16), T([128, 96, 35, 35], f16), T([128, 32, 35, 35], f16)], 1), {})
+cnt: 2, (([T([128, 64, 35, 35], f16), T([128, 64, 35, 35], f16), T([128, 96, 35, 35], f16), T([128, 64, 35, 35], f16)], 1), {})
+cnt: 1, (([T([128, 384, 17, 17], f16), T([128, 96, 17, 17], f16), T([128, 288, 17, 17], f16)], 1), {})
+cnt: 4, (([T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16)], 1), {})
+cnt: 1, (([T([128, 320, 8, 8], f16), T([128, 192, 8, 8], f16), T([128, 768, 8, 8], f16)], 1), {})
+cnt: 4, (([T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16)], 1), {})
+cnt: 2, (([T([128, 320, 8, 8], f16), T([128, 768, 8, 8], f16), T([128, 768, 8, 8], f16), T([128, 192, 8, 8], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 299, 299], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 299, 299], f16), T([32, 3, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 149, 149], f16), T([32, 32, 3, 3], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 147, 147], f16), T([64, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 73, 73], f16), T([80, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 80, 73, 73], f16), T([192, 80, 3, 3], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 192, 35, 35], f16), T([64, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 35, 35], f16), T([48, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 48, 35, 35], f16), T([64, 48, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 64, 35, 35], f16), T([96, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 96, 35, 35], f16), T([96, 96, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 35, 35], f16), T([32, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 256, 35, 35], f16), T([64, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 35, 35], f16), T([48, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 288, 35, 35], f16), T([64, 288, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 288, 35, 35], f16), T([48, 288, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 288, 35, 35], f16), T([384, 288, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 35, 35], f16), T([96, 96, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 12, ((T([128, 768, 17, 17], f16), T([192, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 768, 17, 17], f16), T([128, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 128, 17, 17], f16), T([128, 128, 1, 7], f16), None, [1, 1], [0, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 17, 17], f16), T([192, 128, 7, 1], f16), None, [1, 1], [3, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 128, 17, 17], f16), T([128, 128, 7, 1], f16), None, [1, 1], [3, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 17, 17], f16), T([192, 128, 1, 7], f16), None, [1, 1], [0, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 768, 17, 17], f16), T([160, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 160, 17, 17], f16), T([160, 160, 1, 7], f16), None, [1, 1], [0, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 160, 17, 17], f16), T([192, 160, 7, 1], f16), None, [1, 1], [3, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 160, 17, 17], f16), T([160, 160, 7, 1], f16), None, [1, 1], [3, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 160, 17, 17], f16), T([192, 160, 1, 7], f16), None, [1, 1], [0, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 192, 17, 17], f16), T([192, 192, 1, 7], f16), None, [1, 1], [0, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 192, 17, 17], f16), T([192, 192, 7, 1], f16), None, [1, 1], [3, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 17, 17], f16), T([320, 192, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 17, 17], f16), T([192, 192, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), T([320, 1280, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), T([384, 1280, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 384, 8, 8], f16), T([384, 384, 1, 3], f16), None, [1, 1], [0, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 384, 8, 8], f16), T([384, 384, 3, 1], f16), None, [1, 1], [1, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), T([448, 1280, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 448, 8, 8], f16), T([384, 448, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), T([192, 1280, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([320, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([384, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([448, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([192, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 192, 8, 8], f16), T([128, 2048, 8, 8], f16), T([192, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16), T([384, 384, 3, 1], f16), [0], [1, 1], [1, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16), T([384, 384, 1, 3], f16), [0], [1, 1], [0, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 384, 8, 8], f16), T([128, 448, 8, 8], f16), T([384, 448, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 448, 8, 8], f16), T([128, 2048, 8, 8], f16), T([448, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 384, 8, 8], f16), T([128, 2048, 8, 8], f16), T([384, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 320, 8, 8], f16), T([128, 2048, 8, 8], f16), T([320, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 8, 8], f16), T([128, 1280, 8, 8], f16), T([192, 1280, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 448, 8, 8], f16), T([128, 1280, 8, 8], f16), T([448, 1280, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 384, 8, 8], f16), T([128, 1280, 8, 8], f16), T([384, 1280, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 320, 8, 8], f16), T([128, 1280, 8, 8], f16), T([320, 1280, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 8, 8], f16), T([128, 192, 17, 17], f16), T([192, 192, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), T([192, 192, 7, 1], f16), [0], [1, 1], [3, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), T([192, 192, 1, 7], f16), [0], [1, 1], [0, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 12, ((T([128, 192, 17, 17], f16), T([128, 768, 17, 17], f16), T([192, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 320, 8, 8], f16), T([128, 192, 17, 17], f16), T([320, 192, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 192, 17, 17], f16), T([128, 160, 17, 17], f16), T([192, 160, 1, 7], f16), [0], [1, 1], [0, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 160, 17, 17], f16), T([128, 160, 17, 17], f16), T([160, 160, 7, 1], f16), [0], [1, 1], [3, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 160, 17, 17], f16), T([128, 160, 17, 17], f16), T([160, 160, 1, 7], f16), [0], [1, 1], [0, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 160, 17, 17], f16), T([128, 768, 17, 17], f16), T([160, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 192, 17, 17], f16), T([128, 160, 17, 17], f16), T([192, 160, 7, 1], f16), [0], [1, 1], [3, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 17, 17], f16), T([128, 128, 17, 17], f16), T([192, 128, 1, 7], f16), [0], [1, 1], [0, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 128, 17, 17], f16), T([128, 128, 17, 17], f16), T([128, 128, 7, 1], f16), [0], [1, 1], [3, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 128, 17, 17], f16), T([128, 128, 17, 17], f16), T([128, 128, 1, 7], f16), [0], [1, 1], [0, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 128, 17, 17], f16), T([128, 768, 17, 17], f16), T([128, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 17, 17], f16), T([128, 128, 17, 17], f16), T([192, 128, 7, 1], f16), [0], [1, 1], [3, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 17, 17], f16), T([128, 96, 35, 35], f16), T([96, 96, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 96, 35, 35], f16), T([128, 64, 35, 35], f16), T([96, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 64, 35, 35], f16), T([128, 288, 35, 35], f16), T([64, 288, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 384, 17, 17], f16), T([128, 288, 35, 35], f16), T([384, 288, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 96, 35, 35], f16), T([128, 96, 35, 35], f16), T([96, 96, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 64, 35, 35], f16), T([128, 48, 35, 35], f16), T([64, 48, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 35, 35], f16), T([128, 288, 35, 35], f16), T([48, 288, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 64, 35, 35], f16), T([128, 256, 35, 35], f16), T([64, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 35, 35], f16), T([128, 256, 35, 35], f16), T([48, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 35, 35], f16), T([128, 192, 35, 35], f16), T([32, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 64, 35, 35], f16), T([128, 192, 35, 35], f16), T([64, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 35, 35], f16), T([128, 192, 35, 35], f16), T([48, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 71, 71], f16), T([128, 80, 73, 73], f16), T([192, 80, 3, 3], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 80, 73, 73], f16), T([128, 64, 73, 73], f16), T([80, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 147, 147], f16), T([128, 32, 147, 147], f16), T([64, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 147, 147], f16), T([128, 32, 149, 149], f16), T([32, 32, 3, 3], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 149, 149], f16), T([128, 3, 299, 299], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 299, 299], f16), T([128, 3, 299, 299], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 2048, 8, 8], f16, stride=(2048, 1, 0, 0)), 64), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([128, 64, 147, 147], f16), [3, 3], [2, 2]), {})
+cnt: 1, ((T([128, 192, 71, 71], f16), [3, 3], [2, 2]), {})
+cnt: 1, ((T([128, 288, 35, 35], f16), [3, 3], [2, 2]), {})
+cnt: 1, ((T([128, 768, 17, 17], f16), [3, 3], [2, 2]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([128, 768, 8, 8], f16, stride=(81920, 64, 8, 1)), T([128, 768, 17, 17], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([128, 768, 8, 8], i64)), {})
+cnt: 1, ((T([128, 288, 17, 17], f16, stride=(221952, 289, 17, 1)), T([128, 288, 35, 35], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([128, 288, 17, 17], i64)), {})
+cnt: 1, ((T([128, 192, 35, 35], f16), T([128, 192, 71, 71], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([128, 192, 35, 35], i64)), {})
+cnt: 1, ((T([128, 64, 73, 73], f16), T([128, 64, 147, 147], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([128, 64, 73, 73], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 2048, 8, 8], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 2048], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 32, 149, 149], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 32, 147, 147], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 64, 147, 147], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 80, 73, 73], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 192, 71, 71], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 0.001), {})
+cnt: 12, ((T([128, 64, 35, 35], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 0.001), {})
+cnt: 3, ((T([128, 48, 35, 35], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f16), True, 0.1, 0.001), {})
+cnt: 7, ((T([128, 96, 35, 35], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 32, 35, 35], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 384, 17, 17], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 96, 17, 17], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 0.001), {})
+cnt: 26, ((T([128, 192, 17, 17], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 0.001), {})
+cnt: 6, ((T([128, 128, 17, 17], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 0.001), {})
+cnt: 12, ((T([128, 160, 17, 17], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), True, 0.1, 0.001), {})
+cnt: 3, ((T([128, 320, 8, 8], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), True, 0.1, 0.001), {})
+cnt: 3, ((T([128, 192, 8, 8], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 0.001), {})
+cnt: 12, ((T([128, 384, 8, 8], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([128, 448, 8, 8], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f16), True, 0.1, 0.001), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 3, ((T([128, 192, 8, 8], f16), T([128, 192, 8, 8], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 0.001, [True, True, True]), {})
+cnt: 12, ((T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([128, 448, 8, 8], f16), T([128, 448, 8, 8], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f32), T([448], f32), True, 0.001, [True, True, True]), {})
+cnt: 3, ((T([128, 320, 8, 8], f16), T([128, 320, 8, 8], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), True, 0.001, [True, True, True]), {})
+cnt: 26, ((T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 0.001, [True, True, True]), {})
+cnt: 12, ((T([128, 160, 17, 17], f16), T([128, 160, 17, 17], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), True, 0.001, [True, True, True]), {})
+cnt: 6, ((T([128, 128, 17, 17], f16), T([128, 128, 17, 17], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 96, 17, 17], f16), T([128, 96, 17, 17], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 0.001, [True, True, True]), {})
+cnt: 7, ((T([128, 96, 35, 35], f16), T([128, 96, 35, 35], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 0.001, [True, True, True]), {})
+cnt: 12, ((T([128, 64, 35, 35], f16), T([128, 64, 35, 35], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 384, 17, 17], f16), T([128, 384, 17, 17], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 0.001, [True, True, True]), {})
+cnt: 3, ((T([128, 48, 35, 35], f16), T([128, 48, 35, 35], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f32), T([48], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 35, 35], f16), T([128, 32, 35, 35], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 192, 71, 71], f16), T([128, 192, 71, 71], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 80, 73, 73], f16), T([128, 80, 73, 73], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 147, 147], f16), T([128, 64, 147, 147], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 147, 147], f16), T([128, 32, 147, 147], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 149, 149], f16), T([128, 32, 149, 149], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 0.001, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 32, 149, 149], f16),), {})
+cnt: 1, ((T([128, 32, 147, 147], f16),), {})
+cnt: 1, ((T([128, 64, 147, 147], f16),), {})
+cnt: 1, ((T([128, 80, 73, 73], f16),), {})
+cnt: 1, ((T([128, 192, 71, 71], f16),), {})
+cnt: 12, ((T([128, 64, 35, 35], f16),), {})
+cnt: 3, ((T([128, 48, 35, 35], f16),), {})
+cnt: 7, ((T([128, 96, 35, 35], f16),), {})
+cnt: 1, ((T([128, 32, 35, 35], f16),), {})
+cnt: 1, ((T([128, 384, 17, 17], f16),), {})
+cnt: 1, ((T([128, 96, 17, 17], f16),), {})
+cnt: 26, ((T([128, 192, 17, 17], f16),), {})
+cnt: 6, ((T([128, 128, 17, 17], f16),), {})
+cnt: 12, ((T([128, 160, 17, 17], f16),), {})
+cnt: 3, ((T([128, 320, 8, 8], f16),), {})
+cnt: 3, ((T([128, 192, 8, 8], f16),), {})
+cnt: 12, ((T([128, 384, 8, 8], f16),), {})
+cnt: 2, ((T([128, 448, 8, 8], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 2, ((T([128, 192, 8, 8], f16, stride=(131072, 64, 8, 1)), T([128, 192, 8, 8], f16), 0), {})
+cnt: 8, ((T([128, 384, 8, 8], f16, stride=(131072, 64, 8, 1)), T([128, 384, 8, 8], f16), 0), {})
+cnt: 4, ((T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16), 0), {})
+cnt: 2, ((T([128, 448, 8, 8], f16), T([128, 448, 8, 8], f16), 0), {})
+cnt: 2, ((T([128, 320, 8, 8], f16, stride=(131072, 64, 8, 1)), T([128, 320, 8, 8], f16), 0), {})
+cnt: 1, ((T([128, 192, 8, 8], f16, stride=(81920, 64, 8, 1)), T([128, 192, 8, 8], f16), 0), {})
+cnt: 10, ((T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), 0), {})
+cnt: 1, ((T([128, 320, 8, 8], f16, stride=(81920, 64, 8, 1)), T([128, 320, 8, 8], f16), 0), {})
+cnt: 16, ((T([128, 192, 17, 17], f16, stride=(221952, 289, 17, 1)), T([128, 192, 17, 17], f16), 0), {})
+cnt: 12, ((T([128, 160, 17, 17], f16), T([128, 160, 17, 17], f16), 0), {})
+cnt: 6, ((T([128, 128, 17, 17], f16), T([128, 128, 17, 17], f16), 0), {})
+cnt: 1, ((T([128, 96, 17, 17], f16, stride=(221952, 289, 17, 1)), T([128, 96, 17, 17], f16), 0), {})
+cnt: 4, ((T([128, 96, 35, 35], f16), T([128, 96, 35, 35], f16), 0), {})
+cnt: 4, ((T([128, 64, 35, 35], f16), T([128, 64, 35, 35], f16), 0), {})
+cnt: 1, ((T([128, 384, 17, 17], f16, stride=(221952, 289, 17, 1)), T([128, 384, 17, 17], f16), 0), {})
+cnt: 6, ((T([128, 64, 35, 35], f16, stride=(352800, 1225, 35, 1)), T([128, 64, 35, 35], f16), 0), {})
+cnt: 2, ((T([128, 96, 35, 35], f16, stride=(352800, 1225, 35, 1)), T([128, 96, 35, 35], f16), 0), {})
+cnt: 3, ((T([128, 48, 35, 35], f16), T([128, 48, 35, 35], f16), 0), {})
+cnt: 1, ((T([128, 32, 35, 35], f16, stride=(313600, 1225, 35, 1)), T([128, 32, 35, 35], f16), 0), {})
+cnt: 1, ((T([128, 96, 35, 35], f16, stride=(313600, 1225, 35, 1)), T([128, 96, 35, 35], f16), 0), {})
+cnt: 2, ((T([128, 64, 35, 35], f16, stride=(313600, 1225, 35, 1)), T([128, 64, 35, 35], f16), 0), {})
+cnt: 1, ((T([128, 192, 71, 71], f16), T([128, 192, 71, 71], f16), 0), {})
+cnt: 1, ((T([128, 80, 73, 73], f16), T([128, 80, 73, 73], f16), 0), {})
+cnt: 1, ((T([128, 64, 147, 147], f16), T([128, 64, 147, 147], f16), 0), {})
+cnt: 1, ((T([128, 32, 147, 147], f16), T([128, 32, 147, 147], f16), 0), {})
+cnt: 1, ((T([128, 32, 149, 149], f16), T([128, 32, 149, 149], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/beit_base_patch16_224_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/beit_base_patch16_224_training.txt
new file mode 100644
index 0000000000000..c4df651ef1037
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/beit_base_patch16_224_training.txt
@@ -0,0 +1,100 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([64, 12, 197, 197], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([64, 12, 197, 197], f16), T([64, 12, 197, 197], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([64, 12, 197, 64], f16), [768, 197, 64]), {})
+cnt: 12, ((T([64, 12, 64, 197], f16), [768, 64, 197]), {})
+cnt: 12, ((T([768, 197, 197], f16), [64, 12, 197, 197]), {})
+cnt: 12, ((T([768, 197, 64], f16), [64, 12, 197, 64]), {})
+cnt: 12, ((T([64, 197, 12, 64], f16), [64, 197, 768]), {})
+cnt: 12, ((T([64, 197, 3, 12, 64], f16), [64, 197, 2304]), {})
+Operator: aten.add.Tensor
+cnt: 12, ((T([64, 12, 197, 197], f16), T([1, 12, 197, 197], f16)), {})
+cnt: 48, ((T([64, 197, 768], f16), T([64, 197, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 12, ((T([2304], f16), T([12608, 768], f16), T([768, 2304], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([12608, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([12608, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([12608, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([1000], f16), T([64, 768], f16), T([768, 1000], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([768, 197, 64], f16), T([768, 64, 197], f16)), {})
+cnt: 12, ((T([768, 197, 197], f16), T([768, 197, 64], f16)), {})
+cnt: 12, ((T([768, 197, 197], f16, stride=(38809, 1, 197)), T([768, 197, 64], f16)), {})
+cnt: 12, ((T([768, 197, 64], f16), T([768, 64, 197], f16, stride=(12608, 1, 64))), {})
+cnt: 12, ((T([768, 64, 197], f16, stride=(12608, 1, 64)), T([768, 197, 197], f16)), {})
+cnt: 12, ((T([768, 197, 197], f16), T([768, 197, 64], f16, stride=(12608, 1, 197))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 1, 768], f16, stride=(0, 768, 1)), T([64, 196, 768], f16, stride=(150528, 1, 196))], 1), {})
+cnt: 12, (([T([768], f16), T([768], f16), T([768], f16)],), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([768, 3, 16, 16], f16), T([768], f16), [16, 16], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 768, 14, 14], f16, stride=(151296, 1, 10752, 768)), T([64, 3, 224, 224], f16), T([768, 3, 16, 16], f16), [768], [16, 16], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([64, 196, 768], f16, stride=(768, 0, 1)), 196), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([64, 197, 3072], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([64, 197, 3072], f16), T([64, 197, 3072], f16)), {})
+Operator: aten.index.Tensor
+cnt: 12, ((T([732, 12], f16), [T([38809], i64)]), {})
+Operator: aten.index_put.default
+cnt: 12, ((T([732, 12], f16), [T([38809], i64)], T([38809, 12], f16, stride=(1, 38809)), True), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([64, 196, 768], f16, stride=(151296, 768, 1)), [1]), {})
+Operator: aten.mm.default
+cnt: 1, ((T([64, 1000], f16), T([1000, 768], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 768], f16)), {})
+cnt: 12, ((T([12608, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 12608], f16, stride=(1, 768)), T([12608, 3072], f16)), {})
+cnt: 12, ((T([12608, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 12608], f16, stride=(1, 3072)), T([12608, 768], f16)), {})
+cnt: 12, ((T([12608, 768], f16), T([768, 768], f16)), {})
+cnt: 12, ((T([768, 12608], f16, stride=(1, 768)), T([12608, 768], f16)), {})
+cnt: 12, ((T([12608, 2304], f16), T([2304, 768], f16)), {})
+cnt: 12, ((T([2304, 12608], f16, stride=(1, 2304)), T([12608, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 12, ((T([64, 12, 197, 64], f16, stride=(453888, 64, 2304, 1)), 0.125), {})
+cnt: 24, ((T([768], f16), T([64, 197, 768], f16)), {})
+cnt: 24, ((T([64, 197, 768], f16), T([768], f16)), {})
+cnt: 24, ((T([64, 197, 768], f16), T([64, 197, 768], f16)), {})
+cnt: 12, ((T([64, 12, 197, 64], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 24, ((T([64, 197, 768], f16), [768], T([768], f16), T([768], f16), 1e-06), {})
+cnt: 1, ((T([64, 768], f16), [768], T([768], f16), T([768], f16), 1e-06), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 1, ((T([64, 768], f16), T([64, 768], f16), [768], T([64, 1], f32), T([64, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+cnt: 24, ((T([64, 197, 768], f16), T([64, 197, 768], f16), [768], T([64, 197, 1], f32), T([64, 197, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.new_zeros.default
+cnt: 12, ((T([38809, 12], f16, stride=(1, 38809)), [732, 12]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([64, 196, 768], f16), [64, 197, 768], 1, 1, 9223372036854775807, 1), {})
+cnt: 1, ((T([64, 197, 768], f16), [64, 197, 768], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.stack.default
+cnt: 12, (([T([64, 12, 197, 64], f16), T([64, 12, 197, 64], f16, stride=(151296, 12608, 1, 197)), T([64, 12, 197, 64], f16)],), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 24, ((T([64, 197, 768], f16), [0, 1], True), {})
+cnt: 24, ((T([12608, 768], f16), [0], True), {})
+cnt: 12, ((T([12608, 3072], f16), [0], True), {})
+cnt: 12, ((T([64, 12, 197, 197], f16), [0], True), {})
+cnt: 12, ((T([12608, 2304], f16), [0], True), {})
+cnt: 1, ((T([64, 1, 768], f16, stride=(151296, 768, 1)), [0], True), {})
+Operator: aten.unbind.int
+cnt: 12, ((T([3, 64, 12, 197, 64], f16, stride=(768, 453888, 64, 2304, 1)),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/botnet26t_256_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/botnet26t_256_training.txt
new file mode 100644
index 0000000000000..4f2a25afb62e0
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/botnet26t_256_training.txt
@@ -0,0 +1,244 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 2, ((T([512, 256, 256], f16), -1, False), {})
+cnt: 1, ((T([512, 64, 64], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 1, ((T([512, 64, 64], f16), T([512, 64, 64], f16), -1, f16), {})
+cnt: 2, ((T([512, 256, 256], f16), T([512, 256, 256], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 3, ((T([128, 256, 16, 16], f16), [512, 64, 256]), {})
+cnt: 2, ((T([512, 256, 256], f16), [512, 256, 256]), {})
+cnt: 2, ((T([512, 16, 16, 64], f16), [131072, 64]), {})
+cnt: 4, ((T([131072, 31], f16), [512, 16, 16, 31]), {})
+cnt: 2, ((T([512, 16, 16, 16, 16], f16), [512, 256, 256]), {})
+cnt: 1, ((T([512, 256, 64], f16), [512, 256, 64]), {})
+cnt: 3, ((T([512, 64, 256], f16), [128, 256, 16, 16]), {})
+cnt: 3, ((T([128, 512, 16, 16], f16), [512, 128, 256]), {})
+cnt: 2, ((T([512, 16, 16, 128], f16), [131072, 128]), {})
+cnt: 1, ((T([512, 256, 128], f16), [512, 256, 128]), {})
+cnt: 3, ((T([512, 128, 256], f16), [128, 512, 16, 16]), {})
+cnt: 3, ((T([128, 512, 8, 8], f16), [512, 128, 64]), {})
+cnt: 1, ((T([512, 64, 64], f16), [512, 64, 64]), {})
+cnt: 2, ((T([512, 8, 8, 128], f16), [32768, 128]), {})
+cnt: 2, ((T([32768, 15], f16), [512, 8, 8, 15]), {})
+cnt: 1, ((T([512, 8, 8, 8, 8], f16), [512, 64, 64]), {})
+cnt: 1, ((T([512, 64, 128], f16), [512, 64, 128]), {})
+cnt: 3, ((T([512, 128, 64], f16), [128, 512, 8, 8]), {})
+cnt: 1, ((T([512, 8, 8, 128], f16), [512, 64, 128]), {})
+cnt: 1, ((T([512, 16, 16, 128], f16), [512, 256, 128]), {})
+cnt: 1, ((T([512, 16, 16, 64], f16), [512, 256, 64]), {})
+Operator: aten.add.Tensor
+cnt: 31, ((T([], i64), 1), {})
+cnt: 4, ((T([128, 256, 64, 64], f16), T([128, 256, 64, 64], f16)), {})
+cnt: 4, ((T([128, 512, 32, 32], f16), T([128, 512, 32, 32], f16)), {})
+cnt: 4, ((T([128, 1024, 16, 16], f16), T([128, 1024, 16, 16], f16)), {})
+cnt: 2, ((T([512, 16, 16, 16, 16], f16, stride=(8432, 31, 527, 1, 0)), T([512, 16, 16, 16, 16], f16, stride=(8432, 527, 31, 0, 1))), {})
+cnt: 2, ((T([512, 256, 256], f16), T([512, 256, 256], f16)), {})
+cnt: 3, ((T([128, 2048, 8, 8], f16), T([128, 2048, 8, 8], f16)), {})
+cnt: 1, ((T([512, 8, 8, 8, 8], f16, stride=(1080, 15, 135, 1, 0)), T([512, 8, 8, 8, 8], f16, stride=(1080, 135, 15, 0, 1))), {})
+cnt: 1, ((T([512, 64, 64], f16), T([512, 64, 64], f16)), {})
+cnt: 1, ((T([512, 8, 8, 128], f16, stride=(8192, 128, 1024, 1)), T([512, 8, 8, 128], f16)), {})
+cnt: 1, ((T([512, 64, 128], f16), T([512, 64, 128], f16)), {})
+cnt: 1, ((T([512, 16, 16, 128], f16, stride=(32768, 128, 2048, 1)), T([512, 16, 16, 128], f16)), {})
+cnt: 1, ((T([512, 256, 128], f16), T([512, 256, 128], f16)), {})
+cnt: 1, ((T([512, 16, 16, 64], f16, stride=(16384, 64, 1024, 1)), T([512, 16, 16, 64], f16)), {})
+cnt: 1, ((T([512, 256, 64], f16), T([512, 256, 64], f16)), {})
+cnt: 1, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([128, 512, 16, 16], f16), [2, 2], [2, 2]), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([128, 512, 8, 8], f16), T([128, 512, 16, 16], f16), [2, 2], [2, 2], [0, 0], False, True, None), {})
+Operator: aten.bmm.default
+cnt: 2, ((T([512, 256, 64], f16, stride=(16384, 1, 256)), T([512, 64, 256], f16)), {})
+cnt: 2, ((T([512, 256, 256], f16), T([512, 256, 64], f16, stride=(16384, 1, 256))), {})
+cnt: 2, ((T([512, 256, 128], f16, stride=(32768, 1, 256)), T([512, 128, 256], f16)), {})
+cnt: 2, ((T([512, 256, 256], f16), T([512, 256, 128], f16, stride=(32768, 1, 256))), {})
+cnt: 2, ((T([512, 64, 128], f16, stride=(8192, 1, 64)), T([512, 128, 64], f16)), {})
+cnt: 2, ((T([512, 64, 64], f16), T([512, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 1, ((T([512, 64, 64], f16, stride=(4096, 1, 64)), T([512, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 1, ((T([512, 128, 64], f16), T([512, 64, 64], f16)), {})
+cnt: 1, ((T([512, 256, 256], f16, stride=(65536, 1, 256)), T([512, 256, 128], f16, stride=(32768, 1, 256))), {})
+cnt: 1, ((T([512, 128, 256], f16), T([512, 256, 256], f16)), {})
+cnt: 1, ((T([512, 256, 256], f16, stride=(65536, 1, 256)), T([512, 256, 64], f16, stride=(16384, 1, 256))), {})
+cnt: 1, ((T([512, 64, 256], f16), T([512, 256, 256], f16)), {})
+Operator: aten.cat.default
+cnt: 1, (([T([128, 512, 8, 8], f16), T([128, 512, 8, 8], f16), T([128, 512, 8, 8], f16)], 1), {})
+cnt: 1, (([T([128, 512, 16, 16], f16), T([128, 512, 16, 16], f16), T([128, 512, 16, 16], f16)], 1), {})
+cnt: 1, (([T([128, 256, 16, 16], f16), T([128, 256, 16, 16], f16), T([128, 256, 16, 16], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 256, 256], f16),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 4, ((T([8192, 16, 31], f16), [0, 1], 0.0), {})
+cnt: 4, ((T([8192, 512], f16), [0, 15], 0.0), {})
+cnt: 2, ((T([4096, 8, 15], f16), [0, 1], 0.0), {})
+cnt: 2, ((T([4096, 128], f16), [0, 7], 0.0), {})
+cnt: 2, ((T([4096, 135], f16), [0, -7]), {})
+cnt: 2, ((T([4096, 8, 16], f16), [0, -1]), {})
+cnt: 4, ((T([8192, 527], f16), [0, -15]), {})
+cnt: 4, ((T([8192, 16, 32], f16), [0, -1]), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 256, 256], f16), T([24, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 24, 128, 128], f16), T([32, 24, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([64, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 64, 64], f16), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 64, 64, 64], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 64, 64, 64], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 64, 64], f16), T([64, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 64, 64], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 128, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 128, 32, 32], f16), T([512, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 64, 64], f16), T([512, 256, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 32, 32], f16), T([128, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 32, 32], f16), T([128, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 32, 32], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 32, 32], f16), T([256, 256, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 256, 16, 16], f16), T([1024, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 32, 32], f16), T([1024, 512, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1024, 16, 16], f16), T([256, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), T([768, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1024, 16, 16], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([1536, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 512, 8, 8], f16), T([2048, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1024, 16, 16], f16), T([2048, 1024, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([512, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 8, 8], f16), T([1536, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 2, ((T([128, 2048, 8, 8], f16), T([128, 512, 8, 8], f16), T([2048, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1536, 8, 8], f16), T([128, 512, 8, 8], f16), T([1536, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 8, 8], f16), T([128, 2048, 8, 8], f16), T([512, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([128, 1024, 16, 16], f16), T([2048, 1024, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1536, 16, 16], f16), T([128, 512, 16, 16], f16), T([1536, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([128, 1024, 16, 16], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 1024, 16, 16], f16), T([128, 256, 16, 16], f16), T([1024, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 768, 16, 16], f16), T([128, 256, 16, 16], f16), T([768, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), T([128, 1024, 16, 16], f16), T([256, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1024, 16, 16], f16), T([128, 512, 32, 32], f16), T([1024, 512, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), T([128, 256, 32, 32], f16), T([256, 256, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 32, 32], f16), T([128, 512, 32, 32], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 512, 32, 32], f16), T([128, 128, 32, 32], f16), T([512, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 32, 32], f16), T([128, 128, 32, 32], f16), T([128, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 32, 32], f16), T([128, 512, 32, 32], f16), T([128, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 32, 32], f16), T([128, 256, 64, 64], f16), T([512, 256, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 32, 32], f16), T([128, 128, 64, 64], f16), T([128, 128, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 256, 64, 64], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 256, 64, 64], f16), T([128, 64, 64, 64], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 64, 64], f16), T([128, 256, 64, 64], f16), T([64, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16), T([64, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 128, 128], f16), T([128, 32, 128, 128], f16), T([64, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([128, 24, 128, 128], f16), T([32, 24, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 128, 128], f16), T([128, 3, 256, 256], f16), T([24, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 256, 256], f16), T([128, 3, 256, 256], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 2048, 8, 8], f16, stride=(2048, 1, 0, 0)), 64), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([128, 64, 128, 128], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([128, 64, 64, 64], f16), T([128, 64, 128, 128], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([128, 64, 64, 64], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 2048, 8, 8], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 2, ((T([131072, 64], f16), T([64, 31], f16, stride=(1, 64))), {})
+cnt: 2, ((T([131072, 128], f16), T([128, 31], f16, stride=(1, 128))), {})
+cnt: 2, ((T([32768, 128], f16), T([128, 15], f16, stride=(1, 128))), {})
+cnt: 1, ((T([128, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 2048], f16)), {})
+cnt: 2, ((T([15, 32768], f16, stride=(1, 15)), T([32768, 128], f16)), {})
+cnt: 2, ((T([32768, 15], f16), T([15, 128], f16)), {})
+cnt: 2, ((T([31, 131072], f16, stride=(1, 31)), T([131072, 128], f16)), {})
+cnt: 2, ((T([131072, 31], f16), T([31, 128], f16)), {})
+cnt: 2, ((T([31, 131072], f16, stride=(1, 31)), T([131072, 64], f16)), {})
+cnt: 2, ((T([131072, 31], f16), T([31, 64], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([512, 256, 256], f16), 0.125), {})
+cnt: 2, ((T([512, 256, 256], f16), 0.08838834764831845), {})
+cnt: 2, ((T([512, 64, 64], f16), 0.08838834764831845), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 24, 128, 128], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 64, 128, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 256, 64, 64], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 128, 32, 32], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 512, 32, 32], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 256, 32, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 256, 16, 16], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 1024, 16, 16], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 512, 8, 8], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 2048, 8, 8], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 3, ((T([128, 2048, 8, 8], f16), T([128, 2048, 8, 8], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 512, 8, 8], f16), T([128, 512, 8, 8], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([128, 512, 16, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 1024, 16, 16], f16), T([128, 1024, 16, 16], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 256, 16, 16], f16), T([128, 256, 16, 16], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 32, 32], f16), T([128, 256, 32, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 512, 32, 32], f16), T([128, 512, 32, 32], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 128, 32, 32], f16), T([128, 128, 32, 32], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 256, 64, 64], f16), T([128, 256, 64, 64], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 128, 128], f16), T([128, 64, 128, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([128, 32, 128, 128], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 24, 128, 128], f16), T([128, 24, 128, 128], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 24, 128, 128], f16),), {})
+cnt: 1, ((T([128, 32, 128, 128], f16),), {})
+cnt: 1, ((T([128, 64, 128, 128], f16),), {})
+cnt: 4, ((T([128, 64, 64, 64], f16),), {})
+cnt: 2, ((T([128, 256, 64, 64], f16),), {})
+cnt: 1, ((T([128, 128, 64, 64], f16),), {})
+cnt: 3, ((T([128, 128, 32, 32], f16),), {})
+cnt: 2, ((T([128, 512, 32, 32], f16),), {})
+cnt: 1, ((T([128, 256, 32, 32], f16),), {})
+cnt: 3, ((T([128, 256, 16, 16], f16),), {})
+cnt: 2, ((T([128, 1024, 16, 16], f16),), {})
+cnt: 1, ((T([128, 512, 16, 16], f16),), {})
+cnt: 3, ((T([128, 512, 8, 8], f16),), {})
+cnt: 2, ((T([128, 2048, 8, 8], f16),), {})
+Operator: aten.slice_backward.default
+cnt: 2, ((T([4096, 8, 8], f16), [4096, 8, 15], 2, 7, 9223372036854775807, 1), {})
+cnt: 2, ((T([4096, 8, 15], f16), [4096, 9, 15], 1, 0, 8, 1), {})
+cnt: 2, ((T([4096, 9, 15], f16), [4096, 9, 15], 0, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([8192, 16, 16], f16), [8192, 16, 31], 2, 15, 9223372036854775807, 1), {})
+cnt: 4, ((T([8192, 16, 31], f16), [8192, 17, 31], 1, 0, 16, 1), {})
+cnt: 4, ((T([8192, 17, 31], f16), [8192, 17, 31], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.split_with_sizes.default
+cnt: 1, ((T([128, 768, 16, 16], f16), [256, 256, 256], 1), {})
+cnt: 1, ((T([128, 1536, 16, 16], f16), [512, 512, 512], 1), {})
+cnt: 1, ((T([128, 1536, 8, 8], f16), [512, 512, 512], 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 1, ((T([512, 8, 8, 8, 8], f16, stride=(4096, 64, 1, 512, 8)), [2], True), {})
+cnt: 1, ((T([512, 8, 8, 8, 8], f16, stride=(4096, 512, 8, 64, 1)), [2], True), {})
+cnt: 2, ((T([512, 16, 16, 16, 16], f16, stride=(65536, 256, 1, 4096, 16)), [2], True), {})
+cnt: 2, ((T([512, 16, 16, 16, 16], f16, stride=(65536, 4096, 16, 256, 1)), [2], True), {})
+Operator: aten.threshold_backward.default
+cnt: 2, ((T([128, 2048, 8, 8], f16), T([128, 2048, 8, 8], f16), 0), {})
+cnt: 3, ((T([128, 512, 8, 8], f16), T([128, 512, 8, 8], f16), 0), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([128, 512, 16, 16], f16), 0), {})
+cnt: 2, ((T([128, 1024, 16, 16], f16), T([128, 1024, 16, 16], f16), 0), {})
+cnt: 3, ((T([128, 256, 16, 16], f16), T([128, 256, 16, 16], f16), 0), {})
+cnt: 1, ((T([128, 256, 32, 32], f16), T([128, 256, 32, 32], f16), 0), {})
+cnt: 2, ((T([128, 512, 32, 32], f16), T([128, 512, 32, 32], f16), 0), {})
+cnt: 3, ((T([128, 128, 32, 32], f16), T([128, 128, 32, 32], f16), 0), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 128, 64, 64], f16), 0), {})
+cnt: 2, ((T([128, 256, 64, 64], f16), T([128, 256, 64, 64], f16), 0), {})
+cnt: 4, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16), 0), {})
+cnt: 1, ((T([128, 64, 128, 128], f16), T([128, 64, 128, 128], f16), 0), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([128, 32, 128, 128], f16), 0), {})
+cnt: 1, ((T([128, 24, 128, 128], f16), T([128, 24, 128, 128], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/cait_m36_384_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/cait_m36_384_training.txt
new file mode 100644
index 0000000000000..b49e975750829
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/cait_m36_384_training.txt
@@ -0,0 +1,149 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([2, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([2, 1000], f16), T([2, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 36, ((T([2, 16, 576, 576], f16, stride=(5308416, 1, 9216, 16)), -1, False), {})
+cnt: 2, ((T([2, 16, 1, 577], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 2, ((T([2, 16, 1, 577], f16), T([2, 16, 1, 577], f16), -1, f16), {})
+cnt: 36, ((T([2, 16, 576, 576], f16, stride=(5308416, 1, 9216, 16)), T([2, 16, 576, 576], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 108, ((T([2, 16, 576, 48], f16), [32, 576, 48]), {})
+cnt: 36, ((T([2, 16, 48, 576], f16), [32, 48, 576]), {})
+cnt: 36, ((T([32, 576, 576], f16), [2, 16, 576, 576]), {})
+cnt: 144, ((T([2, 576, 576, 16], f16), [663552, 16]), {})
+cnt: 72, ((T([663552, 16], f16), [2, 576, 576, 16]), {})
+cnt: 72, ((T([2, 16, 576, 576], f16), [32, 576, 576]), {})
+cnt: 36, ((T([32, 576, 48], f16), [2, 16, 576, 48]), {})
+cnt: 36, ((T([2, 576, 16, 48], f16), [2, 576, 768]), {})
+cnt: 2, ((T([2, 16, 48, 577], f16), [32, 48, 577]), {})
+cnt: 2, ((T([32, 1, 577], f16), [2, 16, 1, 577]), {})
+cnt: 2, ((T([2, 16, 577, 48], f16), [32, 577, 48]), {})
+cnt: 2, ((T([32, 1, 48], f16), [2, 16, 1, 48]), {})
+cnt: 2, ((T([2, 577, 16, 48], f16), [2, 577, 768]), {})
+cnt: 2, ((T([2, 577, 768], f16), [1154, 768]), {})
+cnt: 36, ((T([2, 576, 3, 16, 48], f16), [2, 576, 2304]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([2, 576, 768], f16, stride=(442368, 1, 576)), T([1, 576, 768], f16)), {})
+cnt: 72, ((T([2, 576, 576, 16], f16), T([16], f16)), {})
+cnt: 72, ((T([2, 576, 768], f16, stride=(442368, 1, 576)), T([2, 576, 768], f16)), {})
+cnt: 1, ((T([2, 1, 768], f16, stride=(0, 768, 1)), T([2, 1, 768], f16)), {})
+cnt: 4, ((T([2, 1, 768], f16), T([2, 1, 768], f16)), {})
+cnt: 1, ((T([2, 1, 768], f16, stride=(443136, 768, 1)), T([2, 1, 768], f16)), {})
+cnt: 4, ((T([2, 577, 768], f16), T([2, 577, 768], f16)), {})
+cnt: 2, ((T([2, 1, 768], f16), T([2, 1, 768], f16, stride=(443136, 768, 1))), {})
+cnt: 1, ((T([2, 576, 768], f16, stride=(443136, 768, 1)), T([2, 576, 768], f16, stride=(443136, 768, 1))), {})
+cnt: 1, ((T([2, 576, 768], f16), T([2, 576, 768], f16, stride=(443136, 768, 1))), {})
+cnt: 72, ((T([2, 576, 768], f16), T([2, 576, 768], f16)), {})
+cnt: 72, ((T([3, 2, 16, 576, 48], f16), T([3, 2, 16, 576, 48], f16)), {})
+Operator: aten.addmm.default
+cnt: 36, ((T([2304], f16), T([1152, 768], f16), T([768, 2304], f16, stride=(1, 768))), {})
+cnt: 36, ((T([768], f16), T([1152, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 36, ((T([3072], f16), T([1152, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 36, ((T([768], f16), T([1152, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 2, ((T([768], f16), T([2, 768], f16, stride=(443136, 1)), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 4, ((T([768], f16), T([1154, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 2, ((T([768], f16), T([2, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 2, ((T([3072], f16), T([2, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 2, ((T([768], f16), T([2, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([1000], f16), T([2, 768], f16, stride=(443136, 1)), T([768, 1000], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 36, ((T([32, 576, 48], f16), T([32, 48, 576], f16)), {})
+cnt: 36, ((T([32, 576, 576], f16), T([32, 576, 48], f16)), {})
+cnt: 2, ((T([32, 1, 48], f16), T([32, 48, 577], f16)), {})
+cnt: 2, ((T([32, 1, 577], f16), T([32, 577, 48], f16)), {})
+cnt: 2, ((T([32, 577, 1], f16), T([32, 1, 48], f16)), {})
+cnt: 2, ((T([32, 1, 48], f16), T([32, 48, 577], f16, stride=(27696, 1, 48))), {})
+cnt: 2, ((T([32, 48, 1], f16), T([32, 1, 577], f16)), {})
+cnt: 2, ((T([32, 1, 577], f16), T([32, 577, 48], f16, stride=(27696, 1, 577))), {})
+cnt: 36, ((T([32, 576, 576], f16, stride=(331776, 1, 576)), T([32, 576, 48], f16)), {})
+cnt: 36, ((T([32, 576, 48], f16), T([32, 48, 576], f16, stride=(27648, 1, 48))), {})
+cnt: 36, ((T([32, 48, 576], f16, stride=(27648, 1, 48)), T([32, 576, 576], f16)), {})
+cnt: 36, ((T([32, 576, 576], f16), T([32, 576, 48], f16, stride=(27648, 1, 576))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([2, 1, 768], f16, stride=(0, 768, 1)), T([2, 576, 768], f16, stride=(442368, 1, 576))], 1), {})
+cnt: 2, (([T([2, 1, 768], f16), T([2, 576, 768], f16, stride=(442368, 1, 576))], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([2, 3, 384, 384], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([2, 3, 384, 384], f16), T([768, 3, 16, 16], f16), T([768], f16), [16, 16], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([2, 768, 24, 24], f16, stride=(442368, 1, 18432, 768)), T([2, 3, 384, 384], f16), T([768, 3, 16, 16], f16), [768], [16, 16], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([2, 3, 384, 384], f16), T([2, 3, 384, 384], f16)), {})
+Operator: aten.gelu.default
+cnt: 36, ((T([2, 576, 3072], f16),), {})
+cnt: 2, ((T([2, 1, 3072], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 2, ((T([2, 1, 3072], f16), T([2, 1, 3072], f16)), {})
+cnt: 36, ((T([2, 576, 3072], f16), T([2, 576, 3072], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([2], i64),), {})
+Operator: aten.mm.default
+cnt: 72, ((T([663552, 16], f16), T([16, 16], f16, stride=(1, 16))), {})
+cnt: 1, ((T([2, 1000], f16), T([1000, 768], f16)), {})
+cnt: 1, ((T([1000, 2], f16, stride=(1, 1000)), T([2, 768], f16, stride=(443136, 1))), {})
+cnt: 2, ((T([2, 768], f16), T([768, 3072], f16)), {})
+cnt: 2, ((T([768, 2], f16, stride=(1, 768)), T([2, 3072], f16)), {})
+cnt: 2, ((T([2, 3072], f16), T([3072, 768], f16)), {})
+cnt: 2, ((T([3072, 2], f16, stride=(1, 3072)), T([2, 768], f16)), {})
+cnt: 4, ((T([2, 768], f16), T([768, 768], f16)), {})
+cnt: 2, ((T([768, 2], f16, stride=(1, 768)), T([2, 768], f16)), {})
+cnt: 4, ((T([1154, 768], f16), T([768, 768], f16)), {})
+cnt: 4, ((T([768, 1154], f16, stride=(1, 768)), T([1154, 768], f16)), {})
+cnt: 2, ((T([768, 2], f16, stride=(1, 768)), T([2, 768], f16, stride=(443136, 1))), {})
+cnt: 36, ((T([1152, 768], f16), T([768, 3072], f16)), {})
+cnt: 36, ((T([768, 1152], f16, stride=(1, 768)), T([1152, 3072], f16)), {})
+cnt: 36, ((T([1152, 3072], f16), T([3072, 768], f16)), {})
+cnt: 36, ((T([3072, 1152], f16, stride=(1, 3072)), T([1152, 768], f16)), {})
+cnt: 36, ((T([1152, 768], f16), T([768, 768], f16)), {})
+cnt: 36, ((T([768, 1152], f16, stride=(1, 768)), T([1152, 768], f16)), {})
+cnt: 72, ((T([16, 663552], f16, stride=(1, 16)), T([663552, 16], f16)), {})
+cnt: 72, ((T([663552, 16], f16), T([16, 16], f16)), {})
+cnt: 36, ((T([1152, 2304], f16), T([2304, 768], f16)), {})
+cnt: 36, ((T([2304, 1152], f16, stride=(1, 2304)), T([1152, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 36, ((T([2, 16, 576, 48], f16, stride=(1327104, 48, 2304, 1)), 0.14433756729740643), {})
+cnt: 72, ((T([768], f16), T([2, 576, 768], f16)), {})
+cnt: 4, ((T([2, 16, 1, 48], f16), 0.14433756729740643), {})
+cnt: 4, ((T([768], f16), T([2, 1, 768], f16)), {})
+cnt: 1, ((T([2, 1, 768], f16, stride=(443136, 768, 1)), T([768], f16)), {})
+cnt: 1, ((T([2, 1, 768], f16, stride=(443136, 768, 1)), T([2, 1, 768], f16)), {})
+cnt: 3, ((T([2, 1, 768], f16), T([768], f16)), {})
+cnt: 3, ((T([2, 1, 768], f16), T([2, 1, 768], f16)), {})
+cnt: 72, ((T([2, 576, 768], f16), T([768], f16)), {})
+cnt: 72, ((T([2, 576, 768], f16), T([2, 576, 768], f16)), {})
+cnt: 36, ((T([2, 16, 576, 48], f16), 0.14433756729740643), {})
+Operator: aten.native_layer_norm.default
+cnt: 72, ((T([2, 576, 768], f16, stride=(442368, 1, 576)), [768], T([768], f16), T([768], f16), 1e-06), {})
+cnt: 3, ((T([2, 577, 768], f16), [768], T([768], f16), T([768], f16), 1e-06), {})
+cnt: 2, ((T([2, 1, 768], f16), [768], T([768], f16), T([768], f16), 1e-06), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 3, ((T([2, 577, 768], f16), T([2, 577, 768], f16), [768], T([2, 577, 1], f32), T([2, 577, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+cnt: 2, ((T([2, 1, 768], f16), T([2, 1, 768], f16), [768], T([2, 1, 1], f32), T([2, 1, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+cnt: 72, ((T([2, 576, 768], f16), T([2, 576, 768], f16, stride=(442368, 1, 576)), [768], T([2, 576, 1], f32), T([2, 576, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([2, 1000], f16), T([2], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([2, 1000], f16), T([2], i64), None, 1, -100), {})
+Operator: aten.select_backward.default
+cnt: 3, ((T([2, 768], f16), [2, 577, 768], 1, 0), {})
+cnt: 36, ((T([2, 16, 576, 48], f16), [3, 2, 16, 576, 48], 0, 2), {})
+cnt: 36, ((T([2, 16, 576, 48], f16, stride=(442368, 27648, 1, 576)), [3, 2, 16, 576, 48], 0, 1), {})
+cnt: 36, ((T([2, 16, 576, 48], f16), [3, 2, 16, 576, 48], 0, 0), {})
+Operator: aten.slice_backward.default
+cnt: 3, ((T([2, 577, 768], f16), [2, 577, 768], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([2, 1000], f16), [0], True), {})
+cnt: 4, ((T([2, 1, 768], f16), [0, 1], True), {})
+cnt: 6, ((T([2, 768], f16), [0], True), {})
+cnt: 2, ((T([2, 3072], f16), [0], True), {})
+cnt: 4, ((T([1154, 768], f16), [0], True), {})
+cnt: 1, ((T([2, 1, 768], f16), [0], True), {})
+cnt: 72, ((T([2, 576, 768], f16), [0, 1], True), {})
+cnt: 72, ((T([1152, 768], f16), [0], True), {})
+cnt: 36, ((T([1152, 3072], f16), [0], True), {})
+cnt: 72, ((T([2, 576, 576, 16], f16, stride=(5308416, 576, 1, 331776)), [0, 1, 2], True), {})
+cnt: 36, ((T([1152, 2304], f16), [0], True), {})
+cnt: 1, ((T([2, 576, 768], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/coat_lite_mini_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/coat_lite_mini_training.txt
new file mode 100644
index 0000000000000..cba167ebdb848
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/coat_lite_mini_training.txt
@@ -0,0 +1,348 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 2, ((T([128, 8, 3137, 8], f16, stride=(602304, 8, 192, 1)), 2, False), {})
+cnt: 2, ((T([128, 8, 785, 16], f16, stride=(301440, 16, 384, 1)), 2, False), {})
+cnt: 2, ((T([128, 8, 197, 40], f16, stride=(189120, 40, 960, 1)), 2, False), {})
+cnt: 2, ((T([128, 8, 50, 64], f16, stride=(76800, 64, 1536, 1)), 2, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 2, ((T([128, 8, 50, 64], f16, stride=(25600, 3200, 1, 50)), T([128, 8, 50, 64], f16), 2, f16), {})
+cnt: 2, ((T([128, 8, 197, 40], f16, stride=(63040, 7880, 1, 197)), T([128, 8, 197, 40], f16), 2, f16), {})
+cnt: 2, ((T([128, 8, 785, 16], f16, stride=(100480, 12560, 1, 785)), T([128, 8, 785, 16], f16), 2, f16), {})
+cnt: 2, ((T([128, 8, 3137, 8], f16, stride=(200768, 25096, 1, 3137)), T([128, 8, 3137, 8], f16), 2, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 6, ((T([128, 8, 3137, 8], f16), [1024, 3137, 8]), {})
+cnt: 2, ((T([1024, 8, 8], f16), [128, 8, 8, 8]), {})
+cnt: 2, ((T([1024, 3137, 8], f16), [128, 8, 3137, 8]), {})
+cnt: 2, ((T([128, 3137, 8, 8], f16), [128, 3137, 64]), {})
+cnt: 6, ((T([128, 8, 785, 16], f16), [1024, 785, 16]), {})
+cnt: 2, ((T([1024, 16, 16], f16), [128, 8, 16, 16]), {})
+cnt: 2, ((T([1024, 785, 16], f16), [128, 8, 785, 16]), {})
+cnt: 2, ((T([128, 785, 8, 16], f16), [128, 785, 128]), {})
+cnt: 6, ((T([128, 8, 197, 40], f16), [1024, 197, 40]), {})
+cnt: 2, ((T([1024, 40, 40], f16), [128, 8, 40, 40]), {})
+cnt: 2, ((T([1024, 197, 40], f16), [128, 8, 197, 40]), {})
+cnt: 2, ((T([128, 197, 8, 40], f16), [128, 197, 320]), {})
+cnt: 6, ((T([128, 8, 50, 64], f16), [1024, 50, 64]), {})
+cnt: 2, ((T([1024, 64, 64], f16), [128, 8, 64, 64]), {})
+cnt: 2, ((T([1024, 50, 64], f16), [128, 8, 50, 64]), {})
+cnt: 2, ((T([128, 50, 8, 64], f16), [128, 50, 512]), {})
+cnt: 2, ((T([128, 50, 3, 8, 64], f16), [128, 50, 1536]), {})
+cnt: 2, ((T([128, 197, 3, 8, 40], f16), [128, 197, 960]), {})
+cnt: 2, ((T([128, 785, 3, 8, 16], f16), [128, 785, 384]), {})
+cnt: 2, ((T([128, 3137, 3, 8, 8], f16), [128, 3137, 192]), {})
+Operator: aten.add.Tensor
+cnt: 2, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16, stride=(200768, 1, 3584, 64))), {})
+cnt: 6, ((T([128, 8, 3137, 8], f16), T([128, 8, 3137, 8], f16)), {})
+cnt: 10, ((T([128, 3137, 64], f16), T([128, 3137, 64], f16)), {})
+cnt: 2, ((T([128, 128, 28, 28], f16), T([128, 128, 28, 28], f16, stride=(100480, 1, 3584, 128))), {})
+cnt: 6, ((T([128, 8, 785, 16], f16), T([128, 8, 785, 16], f16)), {})
+cnt: 10, ((T([128, 785, 128], f16), T([128, 785, 128], f16)), {})
+cnt: 2, ((T([128, 320, 14, 14], f16), T([128, 320, 14, 14], f16, stride=(63040, 1, 4480, 320))), {})
+cnt: 6, ((T([128, 8, 197, 40], f16), T([128, 8, 197, 40], f16)), {})
+cnt: 10, ((T([128, 197, 320], f16), T([128, 197, 320], f16)), {})
+cnt: 2, ((T([128, 512, 7, 7], f16), T([128, 512, 7, 7], f16, stride=(25600, 1, 3584, 512))), {})
+cnt: 6, ((T([128, 8, 50, 64], f16), T([128, 8, 50, 64], f16)), {})
+cnt: 10, ((T([128, 50, 512], f16), T([128, 50, 512], f16)), {})
+cnt: 4, ((T([3, 128, 8, 50, 64], f16), T([3, 128, 8, 50, 64], f16)), {})
+cnt: 2, ((T([128, 512, 7, 7], f16, stride=(25600, 1, 3584, 512)), T([128, 512, 7, 7], f16, stride=(25088, 1, 3584, 512))), {})
+cnt: 1, ((T([192, 1, 7, 7], f16), T([192, 1, 7, 7], f16)), {})
+cnt: 2, ((T([192], f16), T([192], f16)), {})
+cnt: 1, ((T([192, 1, 5, 5], f16), T([192, 1, 5, 5], f16)), {})
+cnt: 2, ((T([128, 1, 3, 3], f16), T([128, 1, 3, 3], f16)), {})
+cnt: 2, ((T([128], f16), T([128], f16)), {})
+cnt: 1, ((T([512, 1, 3, 3], f16), T([512, 1, 3, 3], f16)), {})
+cnt: 1, ((T([512], f16), T([512], f16)), {})
+cnt: 4, ((T([3, 128, 8, 197, 40], f16), T([3, 128, 8, 197, 40], f16)), {})
+cnt: 2, ((T([128, 320, 14, 14], f16, stride=(63040, 1, 4480, 320)), T([128, 320, 14, 14], f16, stride=(62720, 1, 4480, 320))), {})
+cnt: 1, ((T([120, 1, 7, 7], f16), T([120, 1, 7, 7], f16)), {})
+cnt: 2, ((T([120], f16), T([120], f16)), {})
+cnt: 1, ((T([120, 1, 5, 5], f16), T([120, 1, 5, 5], f16)), {})
+cnt: 1, ((T([80, 1, 3, 3], f16), T([80, 1, 3, 3], f16)), {})
+cnt: 1, ((T([80], f16), T([80], f16)), {})
+cnt: 1, ((T([320, 1, 3, 3], f16), T([320, 1, 3, 3], f16)), {})
+cnt: 1, ((T([320], f16), T([320], f16)), {})
+cnt: 4, ((T([3, 128, 8, 785, 16], f16), T([3, 128, 8, 785, 16], f16)), {})
+cnt: 2, ((T([128, 128, 28, 28], f16, stride=(100480, 1, 3584, 128)), T([128, 128, 28, 28], f16, stride=(100352, 1, 3584, 128))), {})
+cnt: 1, ((T([48, 1, 7, 7], f16), T([48, 1, 7, 7], f16)), {})
+cnt: 2, ((T([48], f16), T([48], f16)), {})
+cnt: 1, ((T([48, 1, 5, 5], f16), T([48, 1, 5, 5], f16)), {})
+cnt: 1, ((T([32, 1, 3, 3], f16), T([32, 1, 3, 3], f16)), {})
+cnt: 1, ((T([32], f16), T([32], f16)), {})
+cnt: 4, ((T([3, 128, 8, 3137, 8], f16), T([3, 128, 8, 3137, 8], f16)), {})
+cnt: 2, ((T([128, 64, 56, 56], f16, stride=(200768, 1, 3584, 64)), T([128, 64, 56, 56], f16, stride=(200704, 1, 3584, 64))), {})
+cnt: 1, ((T([24, 1, 7, 7], f16), T([24, 1, 7, 7], f16)), {})
+cnt: 2, ((T([24], f16), T([24], f16)), {})
+cnt: 1, ((T([24, 1, 5, 5], f16), T([24, 1, 5, 5], f16)), {})
+cnt: 1, ((T([16, 1, 3, 3], f16), T([16, 1, 3, 3], f16)), {})
+cnt: 1, ((T([16], f16), T([16], f16)), {})
+cnt: 1, ((T([64, 1, 3, 3], f16), T([64, 1, 3, 3], f16)), {})
+cnt: 1, ((T([64], f16), T([64], f16)), {})
+Operator: aten.addmm.default
+cnt: 2, ((T([192], f16), T([401536, 64], f16), T([64, 192], f16, stride=(1, 64))), {})
+cnt: 2, ((T([64], f16), T([401536, 64], f16), T([64, 64], f16, stride=(1, 64))), {})
+cnt: 2, ((T([512], f16), T([401536, 64], f16), T([64, 512], f16, stride=(1, 64))), {})
+cnt: 2, ((T([64], f16), T([401536, 512], f16), T([512, 64], f16, stride=(1, 512))), {})
+cnt: 2, ((T([384], f16), T([100480, 128], f16), T([128, 384], f16, stride=(1, 128))), {})
+cnt: 2, ((T([128], f16), T([100480, 128], f16), T([128, 128], f16, stride=(1, 128))), {})
+cnt: 2, ((T([1024], f16), T([100480, 128], f16), T([128, 1024], f16, stride=(1, 128))), {})
+cnt: 2, ((T([128], f16), T([100480, 1024], f16), T([1024, 128], f16, stride=(1, 1024))), {})
+cnt: 2, ((T([960], f16), T([25216, 320], f16), T([320, 960], f16, stride=(1, 320))), {})
+cnt: 2, ((T([320], f16), T([25216, 320], f16), T([320, 320], f16, stride=(1, 320))), {})
+cnt: 2, ((T([1280], f16), T([25216, 320], f16), T([320, 1280], f16, stride=(1, 320))), {})
+cnt: 2, ((T([320], f16), T([25216, 1280], f16), T([1280, 320], f16, stride=(1, 1280))), {})
+cnt: 2, ((T([1536], f16), T([6400, 512], f16), T([512, 1536], f16, stride=(1, 512))), {})
+cnt: 2, ((T([512], f16), T([6400, 512], f16), T([512, 512], f16, stride=(1, 512))), {})
+cnt: 2, ((T([2048], f16), T([6400, 512], f16), T([512, 2048], f16, stride=(1, 512))), {})
+cnt: 2, ((T([512], f16), T([6400, 2048], f16), T([2048, 512], f16, stride=(1, 2048))), {})
+cnt: 1, ((T([1000], f16), T([128, 512], f16, stride=(25600, 1)), T([512, 1000], f16, stride=(1, 512))), {})
+Operator: aten.bmm.default
+cnt: 4, ((T([1024, 8, 3137], f16, stride=(25096, 1, 8)), T([1024, 3137, 8], f16)), {})
+cnt: 4, ((T([1024, 3137, 8], f16), T([1024, 8, 8], f16)), {})
+cnt: 4, ((T([1024, 16, 785], f16, stride=(12560, 1, 16)), T([1024, 785, 16], f16)), {})
+cnt: 4, ((T([1024, 785, 16], f16), T([1024, 16, 16], f16)), {})
+cnt: 4, ((T([1024, 40, 197], f16, stride=(7880, 1, 40)), T([1024, 197, 40], f16)), {})
+cnt: 4, ((T([1024, 197, 40], f16), T([1024, 40, 40], f16)), {})
+cnt: 4, ((T([1024, 64, 50], f16, stride=(3200, 1, 64)), T([1024, 50, 64], f16)), {})
+cnt: 4, ((T([1024, 50, 64], f16), T([1024, 64, 64], f16)), {})
+cnt: 2, ((T([1024, 50, 64], f16), T([1024, 64, 64], f16, stride=(4096, 1, 64))), {})
+cnt: 2, ((T([1024, 64, 64], f16), T([1024, 64, 50], f16, stride=(3200, 1, 64))), {})
+cnt: 2, ((T([1024, 197, 40], f16), T([1024, 40, 40], f16, stride=(1600, 1, 40))), {})
+cnt: 2, ((T([1024, 40, 40], f16), T([1024, 40, 197], f16, stride=(7880, 1, 40))), {})
+cnt: 2, ((T([1024, 785, 16], f16), T([1024, 16, 16], f16, stride=(256, 1, 16))), {})
+cnt: 2, ((T([1024, 16, 16], f16), T([1024, 16, 785], f16, stride=(12560, 1, 16))), {})
+cnt: 2, ((T([1024, 3137, 8], f16), T([1024, 8, 8], f16, stride=(64, 1, 8))), {})
+cnt: 2, ((T([1024, 8, 8], f16), T([1024, 8, 3137], f16, stride=(25096, 1, 8))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([128, 1, 64], f16, stride=(0, 64, 1)), T([128, 3136, 64], f16)], 1), {})
+cnt: 2, (([T([128, 1, 64], f16, stride=(200768, 64, 1)), T([128, 3136, 64], f16, stride=(200704, 1, 3136))], 1), {})
+cnt: 2, (([T([128, 16, 56, 56], f16), T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16)], 1), {})
+cnt: 1, (([T([128, 1, 128], f16, stride=(0, 128, 1)), T([128, 784, 128], f16)], 1), {})
+cnt: 2, (([T([128, 1, 128], f16, stride=(100480, 128, 1)), T([128, 784, 128], f16, stride=(100352, 1, 784))], 1), {})
+cnt: 2, (([T([128, 32, 28, 28], f16), T([128, 48, 28, 28], f16), T([128, 48, 28, 28], f16)], 1), {})
+cnt: 1, (([T([128, 1, 320], f16, stride=(0, 320, 1)), T([128, 196, 320], f16)], 1), {})
+cnt: 2, (([T([128, 1, 320], f16, stride=(63040, 320, 1)), T([128, 196, 320], f16, stride=(62720, 1, 196))], 1), {})
+cnt: 2, (([T([128, 80, 14, 14], f16), T([128, 120, 14, 14], f16), T([128, 120, 14, 14], f16)], 1), {})
+cnt: 1, (([T([128, 1, 512], f16, stride=(0, 512, 1)), T([128, 49, 512], f16)], 1), {})
+cnt: 2, (([T([128, 1, 512], f16, stride=(25600, 512, 1)), T([128, 49, 512], f16, stride=(25088, 1, 49))], 1), {})
+cnt: 2, (([T([128, 128, 7, 7], f16), T([128, 192, 7, 7], f16), T([128, 192, 7, 7], f16)], 1), {})
+cnt: 2, (([T([128, 128, 7, 7], f16, stride=(6272, 1, 896, 128)), T([128, 192, 7, 7], f16, stride=(9408, 1, 1344, 192)), T([128, 192, 7, 7], f16, stride=(9408, 1, 1344, 192))], 1), {})
+cnt: 2, (([T([128, 80, 14, 14], f16, stride=(15680, 1, 1120, 80)), T([128, 120, 14, 14], f16, stride=(23520, 1, 1680, 120)), T([128, 120, 14, 14], f16, stride=(23520, 1, 1680, 120))], 1), {})
+cnt: 2, (([T([128, 32, 28, 28], f16, stride=(25088, 1, 896, 32)), T([128, 48, 28, 28], f16, stride=(37632, 1, 1344, 48)), T([128, 48, 28, 28], f16, stride=(37632, 1, 1344, 48))], 1), {})
+cnt: 2, (([T([128, 16, 56, 56], f16, stride=(50176, 1, 896, 16)), T([128, 24, 56, 56], f16, stride=(75264, 1, 1344, 24)), T([128, 24, 56, 56], f16, stride=(75264, 1, 1344, 24))], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 2, ((T([128, 8, 3136, 8], f16, stride=(200704, 8, 64, 1)), [0, 0, 1, 0, 0, 0], 0.0), {})
+cnt: 2, ((T([128, 8, 784, 16], f16, stride=(100352, 16, 128, 1)), [0, 0, 1, 0, 0, 0], 0.0), {})
+cnt: 2, ((T([128, 8, 196, 40], f16, stride=(62720, 40, 320, 1)), [0, 0, 1, 0, 0, 0], 0.0), {})
+cnt: 2, ((T([128, 8, 49, 64], f16, stride=(25088, 64, 512, 1)), [0, 0, 1, 0, 0, 0], 0.0), {})
+cnt: 2, ((T([128, 8, 50, 64], f16, stride=(25600, 64, 512, 1)), [0, 0, -1, 0, 0, 0]), {})
+cnt: 2, ((T([128, 8, 197, 40], f16, stride=(63040, 40, 320, 1)), [0, 0, -1, 0, 0, 0]), {})
+cnt: 2, ((T([128, 8, 785, 16], f16, stride=(100480, 16, 128, 1)), [0, 0, -1, 0, 0, 0]), {})
+cnt: 2, ((T([128, 8, 3137, 8], f16, stride=(200768, 8, 64, 1)), [0, 0, -1, 0, 0, 0]), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([64, 3, 4, 4], f16), T([64], f16), [4, 4], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 64, 56, 56], f16, stride=(200768, 1, 3584, 64)), T([64, 1, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 2, ((T([128, 16, 56, 56], f16, stride=(602304, 1, 10752, 192)), T([16, 1, 3, 3], f16), T([16], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 16), {})
+cnt: 2, ((T([128, 24, 56, 56], f16, stride=(602304, 1, 10752, 192)), T([24, 1, 5, 5], f16), T([24], f16), [1, 1], [2, 2], [1, 1], False, [0, 0], 24), {})
+cnt: 2, ((T([128, 24, 56, 56], f16, stride=(602304, 1, 10752, 192)), T([24, 1, 7, 7], f16), T([24], f16), [1, 1], [3, 3], [1, 1], False, [0, 0], 24), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 2, 2], f16), T([128], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 128, 28, 28], f16, stride=(100480, 1, 3584, 128)), T([128, 1, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 128), {})
+cnt: 2, ((T([128, 32, 28, 28], f16, stride=(301440, 1, 10752, 384)), T([32, 1, 3, 3], f16), T([32], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 2, ((T([128, 48, 28, 28], f16, stride=(301440, 1, 10752, 384)), T([48, 1, 5, 5], f16), T([48], f16), [1, 1], [2, 2], [1, 1], False, [0, 0], 48), {})
+cnt: 2, ((T([128, 48, 28, 28], f16, stride=(301440, 1, 10752, 384)), T([48, 1, 7, 7], f16), T([48], f16), [1, 1], [3, 3], [1, 1], False, [0, 0], 48), {})
+cnt: 1, ((T([128, 128, 28, 28], f16), T([320, 128, 2, 2], f16), T([320], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 320, 14, 14], f16, stride=(63040, 1, 4480, 320)), T([320, 1, 3, 3], f16), T([320], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 320), {})
+cnt: 2, ((T([128, 80, 14, 14], f16, stride=(189120, 1, 13440, 960)), T([80, 1, 3, 3], f16), T([80], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 80), {})
+cnt: 2, ((T([128, 120, 14, 14], f16, stride=(189120, 1, 13440, 960)), T([120, 1, 5, 5], f16), T([120], f16), [1, 1], [2, 2], [1, 1], False, [0, 0], 120), {})
+cnt: 2, ((T([128, 120, 14, 14], f16, stride=(189120, 1, 13440, 960)), T([120, 1, 7, 7], f16), T([120], f16), [1, 1], [3, 3], [1, 1], False, [0, 0], 120), {})
+cnt: 1, ((T([128, 320, 14, 14], f16), T([512, 320, 2, 2], f16), T([512], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 512, 7, 7], f16, stride=(25600, 1, 3584, 512)), T([512, 1, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 512), {})
+cnt: 2, ((T([128, 128, 7, 7], f16, stride=(76800, 1, 10752, 1536)), T([128, 1, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 128), {})
+cnt: 2, ((T([128, 192, 7, 7], f16, stride=(76800, 1, 10752, 1536)), T([192, 1, 5, 5], f16), T([192], f16), [1, 1], [2, 2], [1, 1], False, [0, 0], 192), {})
+cnt: 2, ((T([128, 192, 7, 7], f16, stride=(76800, 1, 10752, 1536)), T([192, 1, 7, 7], f16), T([192], f16), [1, 1], [3, 3], [1, 1], False, [0, 0], 192), {})
+Operator: aten.convolution_backward.default
+cnt: 2, ((T([128, 192, 7, 7], f16, stride=(25088, 1, 3584, 512)), T([128, 192, 7, 7], f16, stride=(76800, 1, 10752, 1536)), T([192, 1, 7, 7], f16), [192], [1, 1], [3, 3], [1, 1], False, [0, 0], 192, [True, True, True]), {})
+cnt: 2, ((T([128, 192, 7, 7], f16, stride=(25088, 1, 3584, 512)), T([128, 192, 7, 7], f16, stride=(76800, 1, 10752, 1536)), T([192, 1, 5, 5], f16), [192], [1, 1], [2, 2], [1, 1], False, [0, 0], 192, [True, True, True]), {})
+cnt: 2, ((T([128, 128, 7, 7], f16, stride=(25088, 1, 3584, 512)), T([128, 128, 7, 7], f16, stride=(76800, 1, 10752, 1536)), T([128, 1, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 128, [True, True, True]), {})
+cnt: 2, ((T([128, 512, 7, 7], f16, stride=(25600, 1, 3584, 512)), T([128, 512, 7, 7], f16, stride=(25600, 1, 3584, 512)), T([512, 1, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 512, [True, True, True]), {})
+cnt: 1, ((T([128, 512, 7, 7], f16, stride=(25088, 1, 3584, 512)), T([128, 320, 14, 14], f16), T([512, 320, 2, 2], f16), [512], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 120, 14, 14], f16, stride=(62720, 1, 4480, 320)), T([128, 120, 14, 14], f16, stride=(189120, 1, 13440, 960)), T([120, 1, 7, 7], f16), [120], [1, 1], [3, 3], [1, 1], False, [0, 0], 120, [True, True, True]), {})
+cnt: 2, ((T([128, 120, 14, 14], f16, stride=(62720, 1, 4480, 320)), T([128, 120, 14, 14], f16, stride=(189120, 1, 13440, 960)), T([120, 1, 5, 5], f16), [120], [1, 1], [2, 2], [1, 1], False, [0, 0], 120, [True, True, True]), {})
+cnt: 2, ((T([128, 80, 14, 14], f16, stride=(62720, 1, 4480, 320)), T([128, 80, 14, 14], f16, stride=(189120, 1, 13440, 960)), T([80, 1, 3, 3], f16), [80], [1, 1], [1, 1], [1, 1], False, [0, 0], 80, [True, True, True]), {})
+cnt: 2, ((T([128, 320, 14, 14], f16, stride=(63040, 1, 4480, 320)), T([128, 320, 14, 14], f16, stride=(63040, 1, 4480, 320)), T([320, 1, 3, 3], f16), [320], [1, 1], [1, 1], [1, 1], False, [0, 0], 320, [True, True, True]), {})
+cnt: 1, ((T([128, 320, 14, 14], f16, stride=(62720, 1, 4480, 320)), T([128, 128, 28, 28], f16), T([320, 128, 2, 2], f16), [320], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 48, 28, 28], f16, stride=(100352, 1, 3584, 128)), T([128, 48, 28, 28], f16, stride=(301440, 1, 10752, 384)), T([48, 1, 7, 7], f16), [48], [1, 1], [3, 3], [1, 1], False, [0, 0], 48, [True, True, True]), {})
+cnt: 2, ((T([128, 48, 28, 28], f16, stride=(100352, 1, 3584, 128)), T([128, 48, 28, 28], f16, stride=(301440, 1, 10752, 384)), T([48, 1, 5, 5], f16), [48], [1, 1], [2, 2], [1, 1], False, [0, 0], 48, [True, True, True]), {})
+cnt: 2, ((T([128, 32, 28, 28], f16, stride=(100352, 1, 3584, 128)), T([128, 32, 28, 28], f16, stride=(301440, 1, 10752, 384)), T([32, 1, 3, 3], f16), [32], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, True]), {})
+cnt: 2, ((T([128, 128, 28, 28], f16, stride=(100480, 1, 3584, 128)), T([128, 128, 28, 28], f16, stride=(100480, 1, 3584, 128)), T([128, 1, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 128, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 28, 28], f16, stride=(100352, 1, 3584, 128)), T([128, 64, 56, 56], f16), T([128, 64, 2, 2], f16), [128], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 24, 56, 56], f16, stride=(200704, 1, 3584, 64)), T([128, 24, 56, 56], f16, stride=(602304, 1, 10752, 192)), T([24, 1, 7, 7], f16), [24], [1, 1], [3, 3], [1, 1], False, [0, 0], 24, [True, True, True]), {})
+cnt: 2, ((T([128, 24, 56, 56], f16, stride=(200704, 1, 3584, 64)), T([128, 24, 56, 56], f16, stride=(602304, 1, 10752, 192)), T([24, 1, 5, 5], f16), [24], [1, 1], [2, 2], [1, 1], False, [0, 0], 24, [True, True, True]), {})
+cnt: 2, ((T([128, 16, 56, 56], f16, stride=(200704, 1, 3584, 64)), T([128, 16, 56, 56], f16, stride=(602304, 1, 10752, 192)), T([16, 1, 3, 3], f16), [16], [1, 1], [1, 1], [1, 1], False, [0, 0], 16, [True, True, True]), {})
+cnt: 2, ((T([128, 64, 56, 56], f16, stride=(200768, 1, 3584, 64)), T([128, 64, 56, 56], f16, stride=(200768, 1, 3584, 64)), T([64, 1, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 64, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 56, 56], f16, stride=(200704, 1, 3584, 64)), T([128, 3, 224, 224], f16), T([64, 3, 4, 4], f16), [64], [4, 4], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.gelu.default
+cnt: 2, ((T([128, 3137, 512], f16),), {})
+cnt: 2, ((T([128, 785, 1024], f16),), {})
+cnt: 2, ((T([128, 197, 1280], f16),), {})
+cnt: 2, ((T([128, 50, 2048], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 2, ((T([128, 50, 2048], f16), T([128, 50, 2048], f16)), {})
+cnt: 2, ((T([128, 197, 1280], f16), T([128, 197, 1280], f16)), {})
+cnt: 2, ((T([128, 785, 1024], f16), T([128, 785, 1024], f16)), {})
+cnt: 2, ((T([128, 3137, 512], f16), T([128, 3137, 512], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 512], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 512], f16, stride=(25600, 1))), {})
+cnt: 2, ((T([6400, 512], f16), T([512, 2048], f16)), {})
+cnt: 2, ((T([512, 6400], f16, stride=(1, 512)), T([6400, 2048], f16)), {})
+cnt: 2, ((T([6400, 2048], f16), T([2048, 512], f16)), {})
+cnt: 2, ((T([2048, 6400], f16, stride=(1, 2048)), T([6400, 512], f16)), {})
+cnt: 2, ((T([6400, 512], f16), T([512, 512], f16)), {})
+cnt: 2, ((T([512, 6400], f16, stride=(1, 512)), T([6400, 512], f16)), {})
+cnt: 2, ((T([6400, 1536], f16), T([1536, 512], f16)), {})
+cnt: 2, ((T([1536, 6400], f16, stride=(1, 1536)), T([6400, 512], f16)), {})
+cnt: 2, ((T([25216, 320], f16), T([320, 1280], f16)), {})
+cnt: 2, ((T([320, 25216], f16, stride=(1, 320)), T([25216, 1280], f16)), {})
+cnt: 2, ((T([25216, 1280], f16), T([1280, 320], f16)), {})
+cnt: 2, ((T([1280, 25216], f16, stride=(1, 1280)), T([25216, 320], f16)), {})
+cnt: 2, ((T([25216, 320], f16), T([320, 320], f16)), {})
+cnt: 2, ((T([320, 25216], f16, stride=(1, 320)), T([25216, 320], f16)), {})
+cnt: 2, ((T([25216, 960], f16), T([960, 320], f16)), {})
+cnt: 2, ((T([960, 25216], f16, stride=(1, 960)), T([25216, 320], f16)), {})
+cnt: 2, ((T([100480, 128], f16), T([128, 1024], f16)), {})
+cnt: 2, ((T([128, 100480], f16, stride=(1, 128)), T([100480, 1024], f16)), {})
+cnt: 2, ((T([100480, 1024], f16), T([1024, 128], f16)), {})
+cnt: 2, ((T([1024, 100480], f16, stride=(1, 1024)), T([100480, 128], f16)), {})
+cnt: 2, ((T([100480, 128], f16), T([128, 128], f16)), {})
+cnt: 2, ((T([128, 100480], f16, stride=(1, 128)), T([100480, 128], f16)), {})
+cnt: 2, ((T([100480, 384], f16), T([384, 128], f16)), {})
+cnt: 2, ((T([384, 100480], f16, stride=(1, 384)), T([100480, 128], f16)), {})
+cnt: 2, ((T([401536, 64], f16), T([64, 512], f16)), {})
+cnt: 2, ((T([64, 401536], f16, stride=(1, 64)), T([401536, 512], f16)), {})
+cnt: 2, ((T([401536, 512], f16), T([512, 64], f16)), {})
+cnt: 2, ((T([512, 401536], f16, stride=(1, 512)), T([401536, 64], f16)), {})
+cnt: 2, ((T([401536, 64], f16), T([64, 64], f16)), {})
+cnt: 2, ((T([64, 401536], f16, stride=(1, 64)), T([401536, 64], f16)), {})
+cnt: 2, ((T([401536, 192], f16), T([192, 64], f16)), {})
+cnt: 2, ((T([192, 401536], f16, stride=(1, 192)), T([401536, 64], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([128, 8, 3136, 8], f16, stride=(602304, 8, 192, 1)), T([128, 8, 3136, 8], f16, stride=(200704, 25088, 1, 3136))), {})
+cnt: 2, ((T([128, 8, 3137, 8], f16), 0.3535533905932738), {})
+cnt: 2, ((T([128, 8, 784, 16], f16, stride=(301440, 16, 384, 1)), T([128, 8, 784, 16], f16, stride=(100352, 12544, 1, 784))), {})
+cnt: 2, ((T([128, 8, 785, 16], f16), 0.25), {})
+cnt: 2, ((T([128, 8, 196, 40], f16, stride=(189120, 40, 960, 1)), T([128, 8, 196, 40], f16, stride=(62720, 7840, 1, 196))), {})
+cnt: 2, ((T([128, 8, 197, 40], f16), 0.15811388300841897), {})
+cnt: 2, ((T([128, 8, 49, 64], f16, stride=(76800, 64, 1536, 1)), T([128, 8, 49, 64], f16, stride=(25088, 3136, 1, 49))), {})
+cnt: 2, ((T([128, 8, 50, 64], f16), 0.125), {})
+cnt: 2, ((T([128, 8, 50, 64], f16, stride=(25600, 64, 512, 1)), 0.125), {})
+cnt: 2, ((T([128, 8, 49, 64], f16, stride=(25088, 64, 512, 1)), T([128, 8, 49, 64], f16, stride=(76800, 64, 1536, 1))), {})
+cnt: 2, ((T([128, 8, 49, 64], f16, stride=(25088, 64, 512, 1)), T([128, 8, 49, 64], f16, stride=(25088, 3136, 1, 49))), {})
+cnt: 2, ((T([128, 8, 197, 40], f16, stride=(63040, 40, 320, 1)), 0.15811388300841897), {})
+cnt: 2, ((T([128, 8, 196, 40], f16, stride=(62720, 40, 320, 1)), T([128, 8, 196, 40], f16, stride=(189120, 40, 960, 1))), {})
+cnt: 2, ((T([128, 8, 196, 40], f16, stride=(62720, 40, 320, 1)), T([128, 8, 196, 40], f16, stride=(62720, 7840, 1, 196))), {})
+cnt: 2, ((T([128, 8, 785, 16], f16, stride=(100480, 16, 128, 1)), 0.25), {})
+cnt: 2, ((T([128, 8, 784, 16], f16, stride=(100352, 16, 128, 1)), T([128, 8, 784, 16], f16, stride=(301440, 16, 384, 1))), {})
+cnt: 2, ((T([128, 8, 784, 16], f16, stride=(100352, 16, 128, 1)), T([128, 8, 784, 16], f16, stride=(100352, 12544, 1, 784))), {})
+cnt: 2, ((T([128, 8, 3137, 8], f16, stride=(200768, 8, 64, 1)), 0.3535533905932738), {})
+cnt: 2, ((T([128, 8, 3136, 8], f16, stride=(200704, 8, 64, 1)), T([128, 8, 3136, 8], f16, stride=(602304, 8, 192, 1))), {})
+cnt: 2, ((T([128, 8, 3136, 8], f16, stride=(200704, 8, 64, 1)), T([128, 8, 3136, 8], f16, stride=(200704, 25088, 1, 3136))), {})
+Operator: aten.native_layer_norm.default
+cnt: 1, ((T([128, 3136, 64], f16, stride=(200704, 1, 3136)), [64], T([64], f16), T([64], f16), 1e-05), {})
+cnt: 4, ((T([128, 3137, 64], f16), [64], T([64], f16), T([64], f16), 1e-06), {})
+cnt: 1, ((T([128, 784, 128], f16, stride=(100352, 1, 784)), [128], T([128], f16), T([128], f16), 1e-05), {})
+cnt: 4, ((T([128, 785, 128], f16), [128], T([128], f16), T([128], f16), 1e-06), {})
+cnt: 1, ((T([128, 196, 320], f16, stride=(62720, 1, 196)), [320], T([320], f16), T([320], f16), 1e-05), {})
+cnt: 4, ((T([128, 197, 320], f16), [320], T([320], f16), T([320], f16), 1e-06), {})
+cnt: 1, ((T([128, 49, 512], f16, stride=(25088, 1, 49)), [512], T([512], f16), T([512], f16), 1e-05), {})
+cnt: 5, ((T([128, 50, 512], f16), [512], T([512], f16), T([512], f16), 1e-06), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 5, ((T([128, 50, 512], f16), T([128, 50, 512], f16), [512], T([128, 50, 1], f32), T([128, 50, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+cnt: 1, ((T([128, 49, 512], f16, stride=(25600, 512, 1)), T([128, 49, 512], f16, stride=(25088, 1, 49)), [512], T([128, 49, 1], f32), T([128, 49, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+cnt: 4, ((T([128, 197, 320], f16), T([128, 197, 320], f16), [320], T([128, 197, 1], f32), T([128, 197, 1], f32), T([320], f16), T([320], f16), [True, True, True]), {})
+cnt: 1, ((T([128, 196, 320], f16, stride=(63040, 320, 1)), T([128, 196, 320], f16, stride=(62720, 1, 196)), [320], T([128, 196, 1], f32), T([128, 196, 1], f32), T([320], f16), T([320], f16), [True, True, True]), {})
+cnt: 4, ((T([128, 785, 128], f16), T([128, 785, 128], f16), [128], T([128, 785, 1], f32), T([128, 785, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+cnt: 1, ((T([128, 784, 128], f16, stride=(100480, 128, 1)), T([128, 784, 128], f16, stride=(100352, 1, 784)), [128], T([128, 784, 1], f32), T([128, 784, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+cnt: 4, ((T([128, 3137, 64], f16), T([128, 3137, 64], f16), [64], T([128, 3137, 1], f32), T([128, 3137, 1], f32), T([64], f16), T([64], f16), [True, True, True]), {})
+cnt: 1, ((T([128, 3136, 64], f16, stride=(200768, 64, 1)), T([128, 3136, 64], f16, stride=(200704, 1, 3136)), [64], T([128, 3136, 1], f32), T([128, 3136, 1], f32), T([64], f16), T([64], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.select_backward.default
+cnt: 1, ((T([128, 512], f16), [128, 50, 512], 1, 0), {})
+cnt: 2, ((T([128, 8, 50, 64], f16), [3, 128, 8, 50, 64], 0, 2), {})
+cnt: 2, ((T([128, 8, 50, 64], f16), [3, 128, 8, 50, 64], 0, 1), {})
+cnt: 2, ((T([128, 8, 50, 64], f16), [3, 128, 8, 50, 64], 0, 0), {})
+cnt: 2, ((T([128, 8, 197, 40], f16), [3, 128, 8, 197, 40], 0, 2), {})
+cnt: 2, ((T([128, 8, 197, 40], f16), [3, 128, 8, 197, 40], 0, 1), {})
+cnt: 2, ((T([128, 8, 197, 40], f16), [3, 128, 8, 197, 40], 0, 0), {})
+cnt: 2, ((T([128, 8, 785, 16], f16), [3, 128, 8, 785, 16], 0, 2), {})
+cnt: 2, ((T([128, 8, 785, 16], f16), [3, 128, 8, 785, 16], 0, 1), {})
+cnt: 2, ((T([128, 8, 785, 16], f16), [3, 128, 8, 785, 16], 0, 0), {})
+cnt: 2, ((T([128, 8, 3137, 8], f16), [3, 128, 8, 3137, 8], 0, 2), {})
+cnt: 2, ((T([128, 8, 3137, 8], f16), [3, 128, 8, 3137, 8], 0, 1), {})
+cnt: 2, ((T([128, 8, 3137, 8], f16), [3, 128, 8, 3137, 8], 0, 0), {})
+Operator: aten.slice_backward.default
+cnt: 5, ((T([128, 50, 512], f16), [128, 50, 512], 0, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 49, 64], f16, stride=(25088, 64, 512, 1)), [128, 8, 49, 64], 3, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 49, 64], f16), [128, 8, 50, 64], 2, 1, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 50, 64], f16), [128, 8, 50, 64], 1, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 50, 64], f16), [128, 8, 50, 64], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 49, 512], f16), [128, 50, 512], 1, 1, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 1, 512], f16, stride=(25600, 512, 1)), [128, 50, 512], 1, 0, 1, 1), {})
+cnt: 1, ((T([128, 196, 320], f16, stride=(62720, 1, 196)), [128, 196, 320], 2, 0, 9223372036854775807, 1), {})
+cnt: 3, ((T([128, 196, 320], f16), [128, 197, 320], 1, 1, 9223372036854775807, 1), {})
+cnt: 5, ((T([128, 197, 320], f16), [128, 197, 320], 0, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 196, 40], f16, stride=(62720, 40, 320, 1)), [128, 8, 196, 40], 3, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 196, 40], f16), [128, 8, 197, 40], 2, 1, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 197, 40], f16), [128, 8, 197, 40], 1, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 197, 40], f16), [128, 8, 197, 40], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 1, 320], f16, stride=(63040, 320, 1)), [128, 197, 320], 1, 0, 1, 1), {})
+cnt: 1, ((T([128, 784, 128], f16, stride=(100352, 1, 784)), [128, 784, 128], 2, 0, 9223372036854775807, 1), {})
+cnt: 3, ((T([128, 784, 128], f16), [128, 785, 128], 1, 1, 9223372036854775807, 1), {})
+cnt: 5, ((T([128, 785, 128], f16), [128, 785, 128], 0, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 784, 16], f16, stride=(100352, 16, 128, 1)), [128, 8, 784, 16], 3, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 784, 16], f16), [128, 8, 785, 16], 2, 1, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 785, 16], f16), [128, 8, 785, 16], 1, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 785, 16], f16), [128, 8, 785, 16], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 1, 128], f16, stride=(100480, 128, 1)), [128, 785, 128], 1, 0, 1, 1), {})
+cnt: 1, ((T([128, 3136, 64], f16, stride=(200704, 1, 3136)), [128, 3136, 64], 2, 0, 9223372036854775807, 1), {})
+cnt: 3, ((T([128, 3136, 64], f16), [128, 3137, 64], 1, 1, 9223372036854775807, 1), {})
+cnt: 5, ((T([128, 3137, 64], f16), [128, 3137, 64], 0, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 3136, 8], f16, stride=(200704, 8, 64, 1)), [128, 8, 3136, 8], 3, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 3136, 8], f16), [128, 8, 3137, 8], 2, 1, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 3137, 8], f16), [128, 8, 3137, 8], 1, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 8, 3137, 8], f16), [128, 8, 3137, 8], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 1, 64], f16, stride=(200768, 64, 1)), [128, 3137, 64], 1, 0, 1, 1), {})
+Operator: aten.split_with_sizes.default
+cnt: 2, ((T([128, 64, 56, 56], f16, stride=(602304, 1, 10752, 192)), [16, 24, 24], 1), {})
+cnt: 2, ((T([128, 128, 28, 28], f16, stride=(301440, 1, 10752, 384)), [32, 48, 48], 1), {})
+cnt: 2, ((T([128, 320, 14, 14], f16, stride=(189120, 1, 13440, 960)), [80, 120, 120], 1), {})
+cnt: 2, ((T([128, 512, 7, 7], f16, stride=(76800, 1, 10752, 1536)), [128, 192, 192], 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 4, ((T([6400, 512], f16), [0], True), {})
+cnt: 2, ((T([6400, 2048], f16), [0], True), {})
+cnt: 2, ((T([6400, 1536], f16), [0], True), {})
+cnt: 1, ((T([128, 1, 512], f16, stride=(25600, 512, 1)), [0], True), {})
+cnt: 4, ((T([25216, 320], f16), [0], True), {})
+cnt: 2, ((T([25216, 1280], f16), [0], True), {})
+cnt: 2, ((T([25216, 960], f16), [0], True), {})
+cnt: 1, ((T([128, 1, 320], f16, stride=(63040, 320, 1)), [0], True), {})
+cnt: 4, ((T([100480, 128], f16), [0], True), {})
+cnt: 2, ((T([100480, 1024], f16), [0], True), {})
+cnt: 2, ((T([100480, 384], f16), [0], True), {})
+cnt: 1, ((T([128, 1, 128], f16, stride=(100480, 128, 1)), [0], True), {})
+cnt: 4, ((T([401536, 64], f16), [0], True), {})
+cnt: 2, ((T([401536, 512], f16), [0], True), {})
+cnt: 2, ((T([401536, 192], f16), [0], True), {})
+cnt: 1, ((T([128, 1, 64], f16, stride=(200768, 64, 1)), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/convmixer_768_32_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/convmixer_768_32_training.txt
new file mode 100644
index 0000000000000..a41c3378022c5
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/convmixer_768_32_training.txt
@@ -0,0 +1,45 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([32, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([32, 1000], f16), T([32, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 64, ((T([32, 768, 32, 32], f16), T([32, 768, 32, 32], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 65, ((T([], i64), 1), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([32, 768], f16), T([768, 1000], f16, stride=(1, 768))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([768, 3, 7, 7], f16), T([768], f16), [7, 7], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 32, ((T([32, 768, 32, 32], f16), T([768, 1, 7, 7], f16), T([768], f16), [1, 1], [3, 3], [1, 1], False, [0, 0], 768), {})
+cnt: 32, ((T([32, 768, 32, 32], f16), T([768, 768, 1, 1], f16), T([768], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 32, ((T([32, 768, 32, 32], f16), T([32, 768, 32, 32], f16), T([768, 768, 1, 1], f16), [768], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 32, ((T([32, 768, 32, 32], f16), T([32, 768, 32, 32], f16), T([768, 1, 7, 7], f16), [768], [1, 1], [3, 3], [1, 1], False, [0, 0], 768, [True, True, True]), {})
+cnt: 1, ((T([32, 768, 32, 32], f16), T([32, 3, 224, 224], f16), T([768, 3, 7, 7], f16), [768], [7, 7], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([32, 768, 32, 32], f16, stride=(768, 1, 0, 0)), 1024), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([32], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([32, 768, 32, 32], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([32, 1000], f16), T([1000, 768], f16)), {})
+cnt: 1, ((T([1000, 32], f16, stride=(1, 1000)), T([32, 768], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 65, ((T([32, 768, 32, 32], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 65, ((T([32, 768, 32, 32], f16), T([32, 768, 32, 32], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f32), T([768], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([32, 1000], f16), T([32], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([32, 1000], f16), T([32], i64), None, 1, -100), {})
+Operator: aten.relu.default
+cnt: 65, ((T([32, 768, 32, 32], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 65, ((T([32, 768, 32, 32], f16), T([32, 768, 32, 32], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/convnext_base_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/convnext_base_training.txt
new file mode 100644
index 0000000000000..8e67418f598fe
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/convnext_base_training.txt
@@ -0,0 +1,210 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([32, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([32, 1000], f16), T([32, 1000], f16), 1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 3, ((T([100352, 512], f16), [32, 56, 56, 512]), {})
+cnt: 3, ((T([100352, 128], f16), [32, 56, 56, 128]), {})
+cnt: 3, ((T([25088, 1024], f16), [32, 28, 28, 1024]), {})
+cnt: 3, ((T([25088, 256], f16), [32, 28, 28, 256]), {})
+cnt: 27, ((T([6272, 2048], f16), [32, 14, 14, 2048]), {})
+cnt: 27, ((T([6272, 512], f16), [32, 14, 14, 512]), {})
+cnt: 3, ((T([1568, 4096], f16), [32, 7, 7, 4096]), {})
+cnt: 3, ((T([1568, 1024], f16), [32, 7, 7, 1024]), {})
+cnt: 3, ((T([32, 7, 7, 1024], f16), [1568, 1024]), {})
+Operator: aten.add.Tensor
+cnt: 3, ((T([32, 56, 56, 512], f16), T([512], f16)), {})
+cnt: 3, ((T([32, 56, 56, 128], f16), T([128], f16)), {})
+cnt: 7, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128))), {})
+cnt: 1, ((T([32, 1, 56, 56], f16), 1e-06), {})
+cnt: 1, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([128, 1, 1], f16)), {})
+cnt: 3, ((T([32, 28, 28, 1024], f16), T([1024], f16)), {})
+cnt: 3, ((T([32, 28, 28, 256], f16), T([256], f16)), {})
+cnt: 7, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256))), {})
+cnt: 1, ((T([32, 1, 28, 28], f16), 1e-06), {})
+cnt: 1, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([256, 1, 1], f16)), {})
+cnt: 27, ((T([32, 14, 14, 2048], f16), T([2048], f16)), {})
+cnt: 27, ((T([32, 14, 14, 512], f16), T([512], f16)), {})
+cnt: 55, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512))), {})
+cnt: 1, ((T([32, 1, 14, 14], f16), 1e-06), {})
+cnt: 1, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), T([512, 1, 1], f16)), {})
+cnt: 3, ((T([32, 7, 7, 4096], f16), T([4096], f16)), {})
+cnt: 3, ((T([32, 7, 7, 1024], f16), T([1024], f16)), {})
+cnt: 3, ((T([32, 1024, 7, 7], f16, stride=(50176, 1, 7168, 1024)), T([32, 1024, 7, 7], f16, stride=(50176, 1, 7168, 1024))), {})
+cnt: 3, ((T([32, 1024, 7, 7], f16), T([32, 1024, 7, 7], f16, stride=(50176, 1, 7168, 1024))), {})
+cnt: 1, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), T([32, 512, 14, 14], f16)), {})
+cnt: 1, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([32, 256, 28, 28], f16)), {})
+cnt: 1, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([32, 128, 56, 56], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([32, 1024], f16), T([1024, 1000], f16, stride=(1, 1024))), {})
+Operator: aten.as_strided_.default
+cnt: 1, ((T([32, 1024, 1, 1], f16), [32, 1024, 1, 1], [1024, 1, 1024, 1024]), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([128, 3, 4, 4], f16), T([128], f16), [4, 4], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([128, 1, 7, 7], f16), T([128], f16), [1, 1], [3, 3], [1, 1], False, [0, 0], 128), {})
+cnt: 1, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([256, 128, 2, 2], f16), T([256], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([256, 1, 7, 7], f16), T([256], f16), [1, 1], [3, 3], [1, 1], False, [0, 0], 256), {})
+cnt: 1, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([512, 256, 2, 2], f16), T([512], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 27, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), T([512, 1, 7, 7], f16), T([512], f16), [1, 1], [3, 3], [1, 1], False, [0, 0], 512), {})
+cnt: 1, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), T([1024, 512, 2, 2], f16), T([1024], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 1024, 7, 7], f16, stride=(50176, 1, 7168, 1024)), T([1024, 1, 7, 7], f16), T([1024], f16), [1, 1], [3, 3], [1, 1], False, [0, 0], 1024), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([32, 1024, 7, 7], f16, stride=(50176, 1, 7168, 1024)), T([32, 1024, 7, 7], f16, stride=(50176, 1, 7168, 1024)), T([1024, 1, 7, 7], f16), [1024], [1, 1], [3, 3], [1, 1], False, [0, 0], 1024, [True, True, True]), {})
+cnt: 1, ((T([32, 1024, 7, 7], f16), T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), T([1024, 512, 2, 2], f16), [1024], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 27, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), T([512, 1, 7, 7], f16), [512], [1, 1], [3, 3], [1, 1], False, [0, 0], 512, [True, True, True]), {})
+cnt: 1, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([512, 256, 2, 2], f16), [512], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([256, 1, 7, 7], f16), [256], [1, 1], [3, 3], [1, 1], False, [0, 0], 256, [True, True, True]), {})
+cnt: 1, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([256, 128, 2, 2], f16), [256], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([128, 1, 7, 7], f16), [128], [1, 1], [3, 3], [1, 1], False, [0, 0], 128, [True, True, True]), {})
+cnt: 1, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([32, 3, 224, 224], f16), T([128, 3, 4, 4], f16), [128], [4, 4], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 224, 224], f16)), {})
+cnt: 1, ((T([32, 1024], f16), T([32, 1024], f16)), {})
+cnt: 1, ((T([1024, 512, 2, 2], f16), T([1024, 512, 2, 2], f16, stride=(2048, 1, 1024, 512))), {})
+cnt: 1, ((T([512, 256, 2, 2], f16), T([512, 256, 2, 2], f16, stride=(1024, 1, 512, 256))), {})
+cnt: 1, ((T([256, 128, 2, 2], f16), T([256, 128, 2, 2], f16, stride=(512, 1, 256, 128))), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([32, 1024, 7, 7], f16, stride=(1024, 1, 0, 0)), 49), {})
+cnt: 1, ((T([32, 512, 14, 14], f16, stride=(196, 0, 14, 1)), 512), {})
+cnt: 1, ((T([32, 256, 28, 28], f16, stride=(784, 0, 28, 1)), 256), {})
+cnt: 1, ((T([32, 128, 56, 56], f16, stride=(3136, 0, 56, 1)), 128), {})
+Operator: aten.gelu.default
+cnt: 3, ((T([32, 56, 56, 512], f16),), {})
+cnt: 3, ((T([32, 28, 28, 1024], f16),), {})
+cnt: 27, ((T([32, 14, 14, 2048], f16),), {})
+cnt: 3, ((T([32, 7, 7, 4096], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 3, ((T([32, 7, 7, 4096], f16), T([32, 7, 7, 4096], f16)), {})
+cnt: 27, ((T([32, 14, 14, 2048], f16), T([32, 14, 14, 2048], f16)), {})
+cnt: 3, ((T([32, 28, 28, 1024], f16), T([32, 28, 28, 1024], f16)), {})
+cnt: 3, ((T([32, 56, 56, 512], f16), T([32, 56, 56, 512], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([32], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([32, 1024, 7, 7], f16, stride=(50176, 1, 7168, 1024)), [-1, -2], True), {})
+cnt: 1, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), [1], True), {})
+cnt: 1, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), [1], True), {})
+cnt: 1, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), [1], True), {})
+Operator: aten.mm.default
+cnt: 3, ((T([100352, 128], f16), T([128, 512], f16, stride=(1, 128))), {})
+cnt: 3, ((T([100352, 512], f16), T([512, 128], f16, stride=(1, 512))), {})
+cnt: 3, ((T([25088, 256], f16), T([256, 1024], f16, stride=(1, 256))), {})
+cnt: 3, ((T([25088, 1024], f16), T([1024, 256], f16, stride=(1, 1024))), {})
+cnt: 27, ((T([6272, 512], f16), T([512, 2048], f16, stride=(1, 512))), {})
+cnt: 27, ((T([6272, 2048], f16), T([2048, 512], f16, stride=(1, 2048))), {})
+cnt: 3, ((T([1568, 1024], f16), T([1024, 4096], f16, stride=(1, 1024))), {})
+cnt: 3, ((T([1568, 4096], f16), T([4096, 1024], f16, stride=(1, 4096))), {})
+cnt: 1, ((T([32, 1000], f16), T([1000, 1024], f16)), {})
+cnt: 1, ((T([1000, 32], f16, stride=(1, 1000)), T([32, 1024], f16)), {})
+cnt: 3, ((T([1024, 1568], f16, stride=(1, 1024)), T([1568, 4096], f16)), {})
+cnt: 3, ((T([1568, 1024], f16), T([1024, 4096], f16)), {})
+cnt: 3, ((T([4096, 1568], f16, stride=(1, 4096)), T([1568, 1024], f16)), {})
+cnt: 3, ((T([1568, 4096], f16), T([4096, 1024], f16)), {})
+cnt: 27, ((T([512, 6272], f16, stride=(1, 512)), T([6272, 2048], f16)), {})
+cnt: 27, ((T([6272, 512], f16), T([512, 2048], f16)), {})
+cnt: 27, ((T([2048, 6272], f16, stride=(1, 2048)), T([6272, 512], f16)), {})
+cnt: 27, ((T([6272, 2048], f16), T([2048, 512], f16)), {})
+cnt: 3, ((T([256, 25088], f16, stride=(1, 256)), T([25088, 1024], f16)), {})
+cnt: 3, ((T([25088, 256], f16), T([256, 1024], f16)), {})
+cnt: 3, ((T([1024, 25088], f16, stride=(1, 1024)), T([25088, 256], f16)), {})
+cnt: 3, ((T([25088, 1024], f16), T([1024, 256], f16)), {})
+cnt: 3, ((T([128, 100352], f16, stride=(1, 128)), T([100352, 512], f16)), {})
+cnt: 3, ((T([100352, 128], f16), T([128, 512], f16)), {})
+cnt: 3, ((T([512, 100352], f16, stride=(1, 512)), T([100352, 128], f16)), {})
+cnt: 3, ((T([100352, 512], f16), T([512, 128], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 1, ((T([32, 1, 14, 14], f16), -0.5), {})
+cnt: 1, ((T([32, 1, 14, 14], f16), 0.00390625), {})
+cnt: 1, ((T([32, 1, 28, 28], f16), -0.5), {})
+cnt: 1, ((T([32, 1, 28, 28], f16), 0.0078125), {})
+cnt: 1, ((T([32, 1, 56, 56], f16), -0.5), {})
+cnt: 1, ((T([32, 1, 56, 56], f16), 0.015625), {})
+Operator: aten.mul.Tensor
+cnt: 6, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([1, 128, 1, 1], f16)), {})
+cnt: 2, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([32, 1, 56, 56], f16)), {})
+cnt: 2, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([128, 1, 1], f16)), {})
+cnt: 6, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([1, 256, 1, 1], f16)), {})
+cnt: 2, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([32, 1, 28, 28], f16)), {})
+cnt: 2, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([256, 1, 1], f16)), {})
+cnt: 54, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), T([1, 512, 1, 1], f16)), {})
+cnt: 2, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), T([32, 1, 14, 14], f16)), {})
+cnt: 2, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), T([512, 1, 1], f16)), {})
+cnt: 3, ((T([32, 1024, 7, 7], f16, stride=(50176, 1, 7168, 1024)), T([1, 1024, 1, 1], f16)), {})
+cnt: 3, ((T([32, 1024, 7, 7], f16), T([32, 1024, 7, 7], f16, stride=(50176, 1, 7168, 1024))), {})
+cnt: 3, ((T([32, 1024, 7, 7], f16), T([1, 1024, 1, 1], f16)), {})
+cnt: 29, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512))), {})
+cnt: 1, ((T([32, 1, 14, 14], f16), T([32, 1, 14, 14], f16)), {})
+cnt: 1, ((T([32, 1, 14, 14], f16), T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512))), {})
+cnt: 5, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256))), {})
+cnt: 1, ((T([32, 1, 28, 28], f16), T([32, 1, 28, 28], f16)), {})
+cnt: 1, ((T([32, 1, 28, 28], f16), T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256))), {})
+cnt: 5, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128))), {})
+cnt: 1, ((T([32, 1, 56, 56], f16), T([32, 1, 56, 56], f16)), {})
+cnt: 1, ((T([32, 1, 56, 56], f16), T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128))), {})
+Operator: aten.native_layer_norm.default
+cnt: 1, ((T([32, 56, 56, 128], f16, stride=(401408, 56, 1, 3136)), [128], T([128], f16), T([128], f16), 1e-06), {})
+cnt: 3, ((T([32, 56, 56, 128], f16), [128], T([128], f16), T([128], f16), 1e-06), {})
+cnt: 3, ((T([32, 28, 28, 256], f16), [256], T([256], f16), T([256], f16), 1e-06), {})
+cnt: 27, ((T([32, 14, 14, 512], f16), [512], T([512], f16), T([512], f16), 1e-06), {})
+cnt: 3, ((T([32, 7, 7, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-06), {})
+cnt: 1, ((T([32, 1, 1, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-06), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 1, ((T([32, 1, 1, 1024], f16), T([32, 1, 1, 1024], f16), [1024], T([32, 1, 1, 1], f32), T([32, 1, 1, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+cnt: 3, ((T([32, 7, 7, 1024], f16), T([32, 7, 7, 1024], f16), [1024], T([32, 7, 7, 1], f32), T([32, 7, 7, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+cnt: 27, ((T([32, 14, 14, 512], f16), T([32, 14, 14, 512], f16), [512], T([32, 14, 14, 1], f32), T([32, 14, 14, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+cnt: 3, ((T([32, 28, 28, 256], f16), T([32, 28, 28, 256], f16), [256], T([32, 28, 28, 1], f32), T([32, 28, 28, 1], f32), T([256], f16), T([256], f16), [True, True, True]), {})
+cnt: 3, ((T([32, 56, 56, 128], f16), T([32, 56, 56, 128], f16), [128], T([32, 56, 56, 1], f32), T([32, 56, 56, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+cnt: 1, ((T([32, 56, 56, 128], f16), T([32, 56, 56, 128], f16, stride=(401408, 56, 1, 3136)), [128], T([32, 56, 56, 1], f32), T([32, 56, 56, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+Operator: aten.neg.default
+cnt: 1, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)),), {})
+cnt: 1, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)),), {})
+cnt: 1, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)),), {})
+Operator: aten.new_empty_strided.default
+cnt: 1, ((T([1024, 512, 2, 2], f16, stride=(2048, 1, 1024, 512)), [1024, 512, 2, 2], [2048, 4, 2, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([512, 256, 2, 2], f16, stride=(1024, 1, 512, 256)), [512, 256, 2, 2], [1024, 4, 2, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([256, 128, 2, 2], f16, stride=(512, 1, 256, 128)), [256, 128, 2, 2], [512, 4, 2, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.new_zeros.default
+cnt: 1, ((T([32, 1024], f16), [32768]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([32, 1000], f16), T([32], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([32, 1000], f16), T([32], i64), None, 1, -100), {})
+Operator: aten.pow.Tensor_Scalar
+cnt: 1, ((T([32, 1, 14, 14], f16), 3), {})
+cnt: 1, ((T([32, 1, 28, 28], f16), 3), {})
+cnt: 1, ((T([32, 1, 56, 56], f16), 3), {})
+Operator: aten.rsqrt.default
+cnt: 1, ((T([32, 1, 56, 56], f16),), {})
+cnt: 1, ((T([32, 1, 28, 28], f16),), {})
+cnt: 1, ((T([32, 1, 14, 14], f16),), {})
+Operator: aten.slice_backward.default
+cnt: 2, ((T([512], f16), [512], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([256], f16), [256], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([128], f16), [128], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.sub.Tensor
+cnt: 2, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([32, 1, 56, 56], f16)), {})
+cnt: 2, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([32, 1, 28, 28], f16)), {})
+cnt: 2, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), T([32, 1, 14, 14], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32, 1000], f16), [0], True), {})
+cnt: 3, ((T([32, 1024, 7, 7], f16), [0, 2, 3], True), {})
+cnt: 3, ((T([32, 7, 7, 1024], f16, stride=(50176, 7, 1, 49)), [0, 1, 2], True), {})
+cnt: 3, ((T([32, 7, 7, 4096], f16), [0, 1, 2], True), {})
+cnt: 29, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), [0, 2, 3], True), {})
+cnt: 2, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), [1], True), {})
+cnt: 27, ((T([32, 14, 14, 512], f16), [0, 1, 2], True), {})
+cnt: 27, ((T([32, 14, 14, 2048], f16), [0, 1, 2], True), {})
+cnt: 5, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), [0, 2, 3], True), {})
+cnt: 2, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), [1], True), {})
+cnt: 3, ((T([32, 28, 28, 256], f16), [0, 1, 2], True), {})
+cnt: 3, ((T([32, 28, 28, 1024], f16), [0, 1, 2], True), {})
+cnt: 5, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), [0, 2, 3], True), {})
+cnt: 2, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), [1], True), {})
+cnt: 3, ((T([32, 56, 56, 128], f16), [0, 1, 2], True), {})
+cnt: 3, ((T([32, 56, 56, 512], f16), [0, 1, 2], True), {})
+Operator: aten.var_mean.correction
+cnt: 1, ((T([32, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), [1]), {'correction': 0, 'keepdim': True})
+cnt: 1, ((T([32, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), [1]), {'correction': 0, 'keepdim': True})
+cnt: 1, ((T([32, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), [1]), {'correction': 0, 'keepdim': True})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/crossvit_9_240_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/crossvit_9_240_training.txt
new file mode 100644
index 0000000000000..eea124ed321f9
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/crossvit_9_240_training.txt
@@ -0,0 +1,203 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 3, ((T([64, 4, 401, 401], f16), -1, False), {})
+cnt: 9, ((T([64, 4, 197, 197], f16), -1, False), {})
+cnt: 3, ((T([64, 4, 1, 197], f16), -1, False), {})
+cnt: 3, ((T([64, 4, 1, 401], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 3, ((T([64, 4, 1, 401], f16), T([64, 4, 1, 401], f16), -1, f16), {})
+cnt: 3, ((T([64, 4, 1, 197], f16), T([64, 4, 1, 197], f16), -1, f16), {})
+cnt: 9, ((T([64, 4, 197, 197], f16), T([64, 4, 197, 197], f16), -1, f16), {})
+cnt: 3, ((T([64, 4, 401, 401], f16), T([64, 4, 401, 401], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 12, ((T([64, 4, 401, 32], f16), [256, 401, 32]), {})
+cnt: 6, ((T([64, 4, 32, 401], f16), [256, 32, 401]), {})
+cnt: 3, ((T([256, 401, 401], f16), [64, 4, 401, 401]), {})
+cnt: 3, ((T([256, 401, 32], f16), [64, 4, 401, 32]), {})
+cnt: 6, ((T([64, 401, 4, 32], f16), [64, 401, 128]), {})
+cnt: 30, ((T([64, 4, 197, 64], f16), [256, 197, 64]), {})
+cnt: 12, ((T([64, 4, 64, 197], f16), [256, 64, 197]), {})
+cnt: 9, ((T([256, 197, 197], f16), [64, 4, 197, 197]), {})
+cnt: 9, ((T([256, 197, 64], f16), [64, 4, 197, 64]), {})
+cnt: 12, ((T([64, 197, 4, 64], f16), [64, 197, 256]), {})
+cnt: 3, ((T([64, 256], f16), [64, 1, 256]), {})
+cnt: 3, ((T([256, 1, 197], f16), [64, 4, 1, 197]), {})
+cnt: 3, ((T([256, 1, 64], f16), [64, 4, 1, 64]), {})
+cnt: 3, ((T([64, 128], f16), [64, 1, 128]), {})
+cnt: 3, ((T([256, 1, 401], f16), [64, 4, 1, 401]), {})
+cnt: 3, ((T([256, 1, 32], f16), [64, 4, 1, 32]), {})
+cnt: 3, ((T([64, 401, 128], f16), [25664, 128]), {})
+cnt: 3, ((T([64, 197, 256], f16), [12608, 256]), {})
+cnt: 9, ((T([64, 197, 3, 4, 64], f16), [64, 197, 768]), {})
+cnt: 3, ((T([64, 401, 3, 4, 32], f16), [64, 401, 384]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([64, 401, 128], f16), T([1, 401, 128], f16)), {})
+cnt: 1, ((T([64, 197, 256], f16), T([1, 197, 256], f16)), {})
+cnt: 27, ((T([64, 401, 128], f16), T([64, 401, 128], f16)), {})
+cnt: 51, ((T([64, 197, 256], f16), T([64, 197, 256], f16)), {})
+cnt: 3, ((T([64, 1, 256], f16), T([256], f16)), {})
+cnt: 3, ((T([64, 1, 256], f16, stride=(50432, 256, 1)), T([64, 1, 256], f16)), {})
+cnt: 3, ((T([64, 1, 128], f16), T([128], f16)), {})
+cnt: 3, ((T([64, 1, 128], f16, stride=(51328, 128, 1)), T([64, 1, 128], f16)), {})
+Operator: aten.addmm.default
+cnt: 6, ((T([384], f16), T([25664, 128], f16), T([128, 384], f16, stride=(1, 128))), {})
+cnt: 9, ((T([128], f16), T([25664, 128], f16), T([128, 128], f16, stride=(1, 128))), {})
+cnt: 3, ((T([128], f16), T([25664, 384], f16), T([384, 128], f16, stride=(1, 384))), {})
+cnt: 18, ((T([768], f16), T([12608, 256], f16), T([256, 768], f16, stride=(1, 256))), {})
+cnt: 15, ((T([256], f16), T([12608, 256], f16), T([256, 256], f16, stride=(1, 256))), {})
+cnt: 9, ((T([256], f16), T([12608, 768], f16), T([768, 256], f16, stride=(1, 768))), {})
+cnt: 6, ((T([256], f16), T([64, 128], f16), T([128, 256], f16, stride=(1, 128))), {})
+cnt: 6, ((T([128], f16), T([64, 256], f16), T([256, 128], f16, stride=(1, 256))), {})
+cnt: 3, ((T([256], f16), T([64, 256], f16), T([256, 256], f16, stride=(1, 256))), {})
+cnt: 3, ((T([128], f16), T([64, 128], f16), T([128, 128], f16, stride=(1, 128))), {})
+cnt: 1, ((T([1000], f16), T([64, 128], f16, stride=(51328, 1)), T([128, 1000], f16, stride=(1, 128))), {})
+cnt: 1, ((T([1000], f16), T([64, 256], f16, stride=(50432, 1)), T([256, 1000], f16, stride=(1, 256))), {})
+Operator: aten.bmm.default
+cnt: 3, ((T([256, 401, 32], f16), T([256, 32, 401], f16)), {})
+cnt: 3, ((T([256, 401, 401], f16), T([256, 401, 32], f16)), {})
+cnt: 9, ((T([256, 197, 64], f16), T([256, 64, 197], f16)), {})
+cnt: 9, ((T([256, 197, 197], f16), T([256, 197, 64], f16)), {})
+cnt: 3, ((T([256, 1, 64], f16), T([256, 64, 197], f16)), {})
+cnt: 3, ((T([256, 1, 197], f16), T([256, 197, 64], f16)), {})
+cnt: 3, ((T([256, 1, 32], f16), T([256, 32, 401], f16)), {})
+cnt: 3, ((T([256, 1, 401], f16), T([256, 401, 32], f16)), {})
+cnt: 3, ((T([256, 401, 1], f16), T([256, 1, 32], f16)), {})
+cnt: 3, ((T([256, 1, 32], f16), T([256, 32, 401], f16, stride=(12832, 1, 32))), {})
+cnt: 3, ((T([256, 32, 1], f16), T([256, 1, 401], f16)), {})
+cnt: 3, ((T([256, 1, 401], f16), T([256, 401, 32], f16, stride=(12832, 1, 401))), {})
+cnt: 3, ((T([256, 197, 1], f16), T([256, 1, 64], f16)), {})
+cnt: 3, ((T([256, 1, 64], f16), T([256, 64, 197], f16, stride=(12608, 1, 64))), {})
+cnt: 3, ((T([256, 64, 1], f16), T([256, 1, 197], f16)), {})
+cnt: 3, ((T([256, 1, 197], f16), T([256, 197, 64], f16, stride=(12608, 1, 197))), {})
+cnt: 9, ((T([256, 197, 197], f16, stride=(38809, 1, 197)), T([256, 197, 64], f16)), {})
+cnt: 9, ((T([256, 197, 64], f16), T([256, 64, 197], f16, stride=(12608, 1, 64))), {})
+cnt: 9, ((T([256, 64, 197], f16, stride=(12608, 1, 64)), T([256, 197, 197], f16)), {})
+cnt: 9, ((T([256, 197, 197], f16), T([256, 197, 64], f16, stride=(12608, 1, 197))), {})
+cnt: 3, ((T([256, 401, 401], f16, stride=(160801, 1, 401)), T([256, 401, 32], f16)), {})
+cnt: 3, ((T([256, 401, 32], f16), T([256, 32, 401], f16, stride=(12832, 1, 32))), {})
+cnt: 3, ((T([256, 32, 401], f16, stride=(12832, 1, 32)), T([256, 401, 401], f16)), {})
+cnt: 3, ((T([256, 401, 401], f16), T([256, 401, 32], f16, stride=(12832, 1, 401))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 1, 128], f16, stride=(0, 128, 1)), T([64, 400, 128], f16, stride=(51200, 1, 400))], 1), {})
+cnt: 1, (([T([64, 1, 256], f16, stride=(0, 256, 1)), T([64, 196, 256], f16, stride=(50176, 1, 196))], 1), {})
+cnt: 6, (([T([64, 1, 256], f16), T([64, 196, 256], f16, stride=(50432, 256, 1))], 1), {})
+cnt: 6, (([T([64, 1, 128], f16), T([64, 400, 128], f16, stride=(51328, 128, 1))], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 240, 240], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 240, 240], f16), T([128, 3, 12, 12], f16), T([128], f16), [12, 12], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 3, 224, 224], f16), T([256, 3, 16, 16], f16), T([256], f16), [16, 16], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 256, 14, 14], f16, stride=(50432, 1, 3584, 256)), T([64, 3, 224, 224], f16), T([256, 3, 16, 16], f16), [256], [16, 16], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+cnt: 1, ((T([64, 128, 20, 20], f16, stride=(51328, 1, 2560, 128)), T([64, 3, 240, 240], f16), T([128, 3, 12, 12], f16), [128], [12, 12], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 240, 240], f16), T([64, 3, 240, 240], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([2, 64, 1000], f16, stride=(0, 1000, 1)), 2), {})
+Operator: aten.gelu.default
+cnt: 3, ((T([64, 401, 384], f16),), {})
+cnt: 9, ((T([64, 197, 768], f16),), {})
+cnt: 6, ((T([64, 1, 128], f16),), {})
+cnt: 6, ((T([64, 1, 256], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 6, ((T([64, 1, 128], f16), T([64, 1, 128], f16)), {})
+cnt: 6, ((T([64, 1, 256], f16), T([64, 1, 256], f16)), {})
+cnt: 9, ((T([64, 197, 768], f16), T([64, 197, 768], f16)), {})
+cnt: 3, ((T([64, 401, 384], f16), T([64, 401, 384], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([2, 64, 1000], f16), [0]), {})
+Operator: aten.mm.default
+cnt: 3, ((T([64, 256], f16, stride=(50432, 1)), T([256, 256], f16, stride=(1, 256))), {})
+cnt: 3, ((T([64, 128], f16, stride=(51328, 1)), T([128, 128], f16, stride=(1, 128))), {})
+cnt: 1, ((T([64, 1000], f16), T([1000, 256], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 256], f16, stride=(50432, 1))), {})
+cnt: 1, ((T([64, 1000], f16), T([1000, 128], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 128], f16, stride=(51328, 1))), {})
+cnt: 6, ((T([64, 256], f16, stride=(50432, 1)), T([256, 128], f16)), {})
+cnt: 6, ((T([256, 64], f16, stride=(1, 50432)), T([64, 128], f16)), {})
+cnt: 6, ((T([64, 128], f16), T([128, 128], f16)), {})
+cnt: 3, ((T([128, 64], f16, stride=(1, 128)), T([64, 128], f16)), {})
+cnt: 9, ((T([25664, 128], f16), T([128, 128], f16)), {})
+cnt: 9, ((T([128, 25664], f16, stride=(1, 128)), T([25664, 128], f16)), {})
+cnt: 3, ((T([128, 64], f16, stride=(1, 128)), T([64, 128], f16, stride=(51328, 1))), {})
+cnt: 6, ((T([64, 128], f16, stride=(51328, 1)), T([128, 256], f16)), {})
+cnt: 6, ((T([128, 64], f16, stride=(1, 51328)), T([64, 256], f16)), {})
+cnt: 6, ((T([64, 256], f16), T([256, 256], f16)), {})
+cnt: 3, ((T([256, 64], f16, stride=(1, 256)), T([64, 256], f16)), {})
+cnt: 15, ((T([12608, 256], f16), T([256, 256], f16)), {})
+cnt: 15, ((T([256, 12608], f16, stride=(1, 256)), T([12608, 256], f16)), {})
+cnt: 3, ((T([256, 64], f16, stride=(1, 256)), T([64, 256], f16, stride=(50432, 1))), {})
+cnt: 9, ((T([12608, 256], f16), T([256, 768], f16)), {})
+cnt: 9, ((T([256, 12608], f16, stride=(1, 256)), T([12608, 768], f16)), {})
+cnt: 18, ((T([12608, 768], f16), T([768, 256], f16)), {})
+cnt: 18, ((T([768, 12608], f16, stride=(1, 768)), T([12608, 256], f16)), {})
+cnt: 3, ((T([25664, 128], f16), T([128, 384], f16)), {})
+cnt: 3, ((T([128, 25664], f16, stride=(1, 128)), T([25664, 384], f16)), {})
+cnt: 6, ((T([25664, 384], f16), T([384, 128], f16)), {})
+cnt: 6, ((T([384, 25664], f16, stride=(1, 384)), T([25664, 128], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 6, ((T([64, 4, 401, 401], f16), 0.1767766952966369), {})
+cnt: 18, ((T([64, 4, 197, 197], f16), 0.125), {})
+cnt: 6, ((T([64, 4, 1, 197], f16), 0.125), {})
+cnt: 6, ((T([64, 4, 1, 401], f16), 0.1767766952966369), {})
+Operator: aten.native_layer_norm.default
+cnt: 10, ((T([64, 401, 128], f16), [128], T([128], f16), T([128], f16), 1e-06), {})
+cnt: 22, ((T([64, 197, 256], f16), [256], T([256], f16), T([256], f16), 1e-06), {})
+cnt: 3, ((T([64, 1, 128], f16, stride=(51328, 128, 1)), [128], T([128], f16), T([128], f16), 1e-06), {})
+cnt: 3, ((T([64, 1, 256], f16, stride=(50432, 256, 1)), [256], T([256], f16), T([256], f16), 1e-06), {})
+cnt: 3, ((T([64, 1, 256], f16), [256], T([256], f16), T([256], f16), 1e-06), {})
+cnt: 3, ((T([64, 1, 128], f16), [128], T([128], f16), T([128], f16), 1e-06), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 22, ((T([64, 197, 256], f16), T([64, 197, 256], f16), [256], T([64, 197, 1], f32), T([64, 197, 1], f32), T([256], f16), T([256], f16), [True, True, True]), {})
+cnt: 10, ((T([64, 401, 128], f16), T([64, 401, 128], f16), [128], T([64, 401, 1], f32), T([64, 401, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+cnt: 3, ((T([64, 1, 128], f16), T([64, 1, 128], f16), [128], T([64, 1, 1], f32), T([64, 1, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+cnt: 3, ((T([64, 1, 256], f16), T([64, 1, 256], f16), [256], T([64, 1, 1], f32), T([64, 1, 1], f32), T([256], f16), T([256], f16), [True, True, True]), {})
+cnt: 3, ((T([64, 1, 256], f16), T([64, 1, 256], f16, stride=(50432, 256, 1)), [256], T([64, 1, 1], f32), T([64, 1, 1], f32), T([256], f16), T([256], f16), [True, True, True]), {})
+cnt: 3, ((T([64, 1, 128], f16), T([64, 1, 128], f16, stride=(51328, 128, 1)), [128], T([64, 1, 1], f32), T([64, 1, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.select_backward.default
+cnt: 1, ((T([64, 256], f16), [64, 197, 256], 1, 0), {})
+cnt: 1, ((T([64, 128], f16), [64, 401, 128], 1, 0), {})
+Operator: aten.slice_backward.default
+cnt: 16, ((T([64, 197, 256], f16), [64, 197, 256], 0, 0, 9223372036854775807, 1), {})
+cnt: 16, ((T([64, 401, 128], f16), [64, 401, 128], 0, 0, 9223372036854775807, 1), {})
+cnt: 6, ((T([64, 196, 256], f16, stride=(50432, 256, 1)), [64, 197, 256], 1, 1, 9223372036854775807, 1), {})
+cnt: 3, ((T([64, 1, 128], f16), [64, 1, 128], 0, 0, 9223372036854775807, 1), {})
+cnt: 9, ((T([64, 1, 128], f16), [64, 401, 128], 1, 0, 1, 1), {})
+cnt: 6, ((T([64, 400, 128], f16, stride=(51328, 128, 1)), [64, 401, 128], 1, 1, 9223372036854775807, 1), {})
+cnt: 3, ((T([64, 1, 256], f16), [64, 1, 256], 0, 0, 9223372036854775807, 1), {})
+cnt: 9, ((T([64, 1, 256], f16), [64, 197, 256], 1, 0, 1, 1), {})
+Operator: aten.stack.default
+cnt: 1, (([T([64, 1000], f16), T([64, 1000], f16)],), {})
+cnt: 9, (([T([64, 4, 197, 64], f16), T([64, 4, 197, 64], f16, stride=(50432, 12608, 1, 197)), T([64, 4, 197, 64], f16)],), {})
+cnt: 3, (([T([64, 4, 401, 32], f16), T([64, 4, 401, 32], f16, stride=(51328, 12832, 1, 401)), T([64, 4, 401, 32], f16)],), {})
+Operator: aten.sum.SymInt
+cnt: 2, ((T([64, 1000], f16), [0], True), {})
+cnt: 6, ((T([64, 256], f16, stride=(50432, 1)), [0], True), {})
+cnt: 3, ((T([64, 128], f16), [0], True), {})
+cnt: 12, ((T([25664, 128], f16), [0], True), {})
+cnt: 3, ((T([64, 1, 128], f16), [0, 1], True), {})
+cnt: 6, ((T([64, 128], f16, stride=(51328, 1)), [0], True), {})
+cnt: 3, ((T([64, 256], f16), [0], True), {})
+cnt: 24, ((T([12608, 256], f16), [0], True), {})
+cnt: 3, ((T([64, 1, 256], f16), [0, 1], True), {})
+cnt: 18, ((T([12608, 768], f16), [0], True), {})
+cnt: 6, ((T([25664, 384], f16), [0], True), {})
+cnt: 1, ((T([64, 197, 256], f16), [0], True), {})
+cnt: 1, ((T([64, 1, 256], f16, stride=(50432, 256, 1)), [0], True), {})
+cnt: 1, ((T([64, 401, 128], f16), [0], True), {})
+cnt: 1, ((T([64, 1, 128], f16, stride=(51328, 128, 1)), [0], True), {})
+Operator: aten.unbind.int
+cnt: 3, ((T([3, 64, 4, 401, 32], f16, stride=(128, 153984, 32, 384, 1)),), {})
+cnt: 9, ((T([3, 64, 4, 197, 64], f16, stride=(256, 151296, 64, 768, 1)),), {})
+cnt: 1, ((T([2, 64, 1000], f16),), {})
+Operator: aten.upsample_bicubic2d.vec
+cnt: 1, ((T([64, 3, 240, 240], f16), [224, 224], False, None), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/cspdarknet53_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/cspdarknet53_training.txt
new file mode 100644
index 0000000000000..9332a617dadd6
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/cspdarknet53_training.txt
@@ -0,0 +1,177 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 67, ((T([], i64), 1), {})
+cnt: 1, ((T([64, 64, 128, 128], f16), T([64, 64, 128, 128], f16, stride=(2097152, 16384, 128, 1))), {})
+cnt: 1, ((T([64, 64, 64, 64], f16), T([64, 64, 64, 64], f16, stride=(524288, 4096, 64, 1))), {})
+cnt: 3, ((T([64, 64, 64, 64], f16), T([64, 64, 64, 64], f16)), {})
+cnt: 1, ((T([64, 128, 32, 32], f16), T([64, 128, 32, 32], f16, stride=(262144, 1024, 32, 1))), {})
+cnt: 15, ((T([64, 128, 32, 32], f16), T([64, 128, 32, 32], f16)), {})
+cnt: 1, ((T([64, 256, 16, 16], f16), T([64, 256, 16, 16], f16, stride=(131072, 256, 16, 1))), {})
+cnt: 15, ((T([64, 256, 16, 16], f16), T([64, 256, 16, 16], f16)), {})
+cnt: 1, ((T([64, 512, 8, 8], f16), T([64, 512, 8, 8], f16, stride=(65536, 64, 8, 1))), {})
+cnt: 7, ((T([64, 512, 8, 8], f16), T([64, 512, 8, 8], f16)), {})
+cnt: 1, ((T([64, 1024, 8, 8], f16), T([64, 1024, 8, 8], f16)), {})
+cnt: 1, ((T([64, 512, 16, 16], f16), T([64, 512, 16, 16], f16)), {})
+cnt: 1, ((T([64, 256, 32, 32], f16), T([64, 256, 32, 32], f16)), {})
+cnt: 1, ((T([64, 128, 64, 64], f16), T([64, 128, 64, 64], f16)), {})
+cnt: 1, ((T([64, 64, 128, 128], f16), T([64, 64, 128, 128], f16)), {})
+cnt: 1, ((T([64, 128, 128, 128], f16), T([64, 128, 128, 128], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([64, 1024], f16), T([1024, 1000], f16, stride=(1, 1024))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 64, 128, 128], f16, stride=(2097152, 16384, 128, 1)), T([64, 64, 128, 128], f16)], 1), {})
+cnt: 1, (([T([64, 64, 64, 64], f16, stride=(524288, 4096, 64, 1)), T([64, 64, 64, 64], f16)], 1), {})
+cnt: 1, (([T([64, 128, 32, 32], f16, stride=(262144, 1024, 32, 1)), T([64, 128, 32, 32], f16)], 1), {})
+cnt: 1, (([T([64, 256, 16, 16], f16, stride=(131072, 256, 16, 1)), T([64, 256, 16, 16], f16)], 1), {})
+cnt: 1, (([T([64, 512, 8, 8], f16, stride=(65536, 64, 8, 1)), T([64, 512, 8, 8], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 256, 256], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 256, 256], f16), T([32, 3, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 32, 256, 256], f16), T([64, 32, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 128, 128], f16), T([128, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 128, 128], f16, stride=(2097152, 16384, 128, 1)), T([32, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 32, 128, 128], f16), T([64, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 128, 128], f16), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 128, 128], f16), T([64, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 128, 128], f16), T([128, 64, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 128, 64, 64], f16), T([128, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 64, 64], f16, stride=(524288, 4096, 64, 1)), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 64, 64, 64], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 64, 64, 64], f16), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 64, 64], f16), T([256, 128, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 256, 32, 32], f16), T([256, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 32, 32], f16, stride=(262144, 1024, 32, 1)), T([128, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([64, 128, 32, 32], f16), T([128, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([64, 128, 32, 32], f16), T([128, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 32, 32], f16), T([512, 256, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 512, 16, 16], f16), T([512, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 16, 16], f16, stride=(131072, 256, 16, 1)), T([256, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([64, 256, 16, 16], f16), T([256, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([64, 256, 16, 16], f16), T([256, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 16, 16], f16), T([1024, 512, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 1024, 8, 8], f16), T([1024, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 8, 8], f16, stride=(65536, 64, 8, 1)), T([512, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([64, 512, 8, 8], f16), T([512, 512, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([64, 512, 8, 8], f16), T([512, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 2, ((T([64, 1024, 8, 8], f16), T([64, 1024, 8, 8], f16), T([1024, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([64, 512, 8, 8], f16), T([64, 512, 8, 8], f16), T([512, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([64, 512, 8, 8], f16), T([64, 512, 8, 8], f16), T([512, 512, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 8, 8], f16), T([64, 512, 8, 8], f16, stride=(65536, 64, 8, 1)), T([512, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 1024, 8, 8], f16), T([64, 512, 16, 16], f16), T([1024, 512, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 512, 16, 16], f16), T([64, 512, 16, 16], f16), T([512, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 8, ((T([64, 256, 16, 16], f16), T([64, 256, 16, 16], f16), T([256, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 8, ((T([64, 256, 16, 16], f16), T([64, 256, 16, 16], f16), T([256, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 256, 16, 16], f16), T([64, 256, 16, 16], f16, stride=(131072, 256, 16, 1)), T([256, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 16, 16], f16), T([64, 256, 32, 32], f16), T([512, 256, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 256, 32, 32], f16), T([64, 256, 32, 32], f16), T([256, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 8, ((T([64, 128, 32, 32], f16), T([64, 128, 32, 32], f16), T([128, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 8, ((T([64, 128, 32, 32], f16), T([64, 128, 32, 32], f16), T([128, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 32, 32], f16), T([64, 128, 32, 32], f16, stride=(262144, 1024, 32, 1)), T([128, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 256, 32, 32], f16), T([64, 128, 64, 64], f16), T([256, 128, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 128, 64, 64], f16), T([64, 128, 64, 64], f16), T([128, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 64, 64, 64], f16), T([64, 64, 64, 64], f16), T([64, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 64, 64, 64], f16), T([64, 64, 64, 64], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 64, 64], f16), T([64, 64, 64, 64], f16, stride=(524288, 4096, 64, 1)), T([64, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 64, 64], f16), T([64, 64, 128, 128], f16), T([128, 64, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 128, 128], f16), T([64, 128, 128, 128], f16), T([64, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 128, 128], f16), T([64, 64, 128, 128], f16), T([64, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 128, 128], f16), T([64, 32, 128, 128], f16), T([64, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 128, 128], f16), T([64, 64, 128, 128], f16, stride=(2097152, 16384, 128, 1)), T([32, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 128, 128], f16), T([64, 64, 128, 128], f16), T([128, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 128, 128], f16), T([64, 32, 256, 256], f16), T([64, 32, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 256, 256], f16), T([64, 3, 256, 256], f16), T([32, 3, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 256, 256], f16), T([64, 3, 256, 256], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([64, 1024, 8, 8], f16, stride=(1024, 1, 0, 0)), 64), {})
+Operator: aten.leaky_relu_.default
+cnt: 1, ((T([64, 32, 256, 256], f16),), {})
+cnt: 4, ((T([64, 64, 128, 128], f16),), {})
+cnt: 1, ((T([64, 128, 128, 128], f16),), {})
+cnt: 1, ((T([64, 32, 128, 128], f16),), {})
+cnt: 3, ((T([64, 128, 64, 64], f16),), {})
+cnt: 5, ((T([64, 64, 64, 64], f16),), {})
+cnt: 3, ((T([64, 256, 32, 32], f16),), {})
+cnt: 17, ((T([64, 128, 32, 32], f16),), {})
+cnt: 3, ((T([64, 512, 16, 16], f16),), {})
+cnt: 17, ((T([64, 256, 16, 16], f16),), {})
+cnt: 3, ((T([64, 1024, 8, 8], f16),), {})
+cnt: 9, ((T([64, 512, 8, 8], f16),), {})
+Operator: aten.leaky_relu_backward.default
+cnt: 3, ((T([64, 1024, 8, 8], f16), T([64, 1024, 8, 8], f16), 0.01, True), {})
+cnt: 1, ((T([64, 512, 8, 8], f16, stride=(65536, 64, 8, 1)), T([64, 512, 8, 8], f16), 0.01, True), {})
+cnt: 8, ((T([64, 512, 8, 8], f16), T([64, 512, 8, 8], f16), 0.01, True), {})
+cnt: 3, ((T([64, 512, 16, 16], f16), T([64, 512, 16, 16], f16), 0.01, True), {})
+cnt: 1, ((T([64, 256, 16, 16], f16, stride=(131072, 256, 16, 1)), T([64, 256, 16, 16], f16), 0.01, True), {})
+cnt: 16, ((T([64, 256, 16, 16], f16), T([64, 256, 16, 16], f16), 0.01, True), {})
+cnt: 3, ((T([64, 256, 32, 32], f16), T([64, 256, 32, 32], f16), 0.01, True), {})
+cnt: 1, ((T([64, 128, 32, 32], f16, stride=(262144, 1024, 32, 1)), T([64, 128, 32, 32], f16), 0.01, True), {})
+cnt: 16, ((T([64, 128, 32, 32], f16), T([64, 128, 32, 32], f16), 0.01, True), {})
+cnt: 3, ((T([64, 128, 64, 64], f16), T([64, 128, 64, 64], f16), 0.01, True), {})
+cnt: 1, ((T([64, 64, 64, 64], f16, stride=(524288, 4096, 64, 1)), T([64, 64, 64, 64], f16), 0.01, True), {})
+cnt: 4, ((T([64, 64, 64, 64], f16), T([64, 64, 64, 64], f16), 0.01, True), {})
+cnt: 3, ((T([64, 64, 128, 128], f16), T([64, 64, 128, 128], f16), 0.01, True), {})
+cnt: 1, ((T([64, 64, 128, 128], f16, stride=(2097152, 16384, 128, 1)), T([64, 64, 128, 128], f16), 0.01, True), {})
+cnt: 1, ((T([64, 32, 128, 128], f16), T([64, 32, 128, 128], f16), 0.01, True), {})
+cnt: 1, ((T([64, 128, 128, 128], f16), T([64, 128, 128, 128], f16), 0.01, True), {})
+cnt: 1, ((T([64, 32, 256, 256], f16), T([64, 32, 256, 256], f16), 0.01, True), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([64, 1024, 8, 8], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([64, 1000], f16), T([1000, 1024], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 1024], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([64, 32, 256, 256], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([64, 64, 128, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 128, 128, 128], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 32, 128, 128], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([64, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([64, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([64, 256, 32, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 17, ((T([64, 128, 32, 32], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([64, 512, 16, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 17, ((T([64, 256, 16, 16], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([64, 1024, 8, 8], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 9, ((T([64, 512, 8, 8], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 3, ((T([64, 1024, 8, 8], f16), T([64, 1024, 8, 8], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 9, ((T([64, 512, 8, 8], f16), T([64, 512, 8, 8], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([64, 512, 16, 16], f16), T([64, 512, 16, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 17, ((T([64, 256, 16, 16], f16), T([64, 256, 16, 16], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([64, 256, 32, 32], f16), T([64, 256, 32, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 17, ((T([64, 128, 32, 32], f16), T([64, 128, 32, 32], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([64, 128, 64, 64], f16), T([64, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([64, 64, 64, 64], f16), T([64, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([64, 64, 128, 128], f16), T([64, 64, 128, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 32, 128, 128], f16), T([64, 32, 128, 128], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 128, 128, 128], f16), T([64, 128, 128, 128], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 32, 256, 256], f16), T([64, 32, 256, 256], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([64, 512, 8, 8], f16), [64, 1024, 8, 8], 1, 512, 9223372036854775807, 1), {})
+cnt: 2, ((T([64, 1024, 8, 8], f16), [64, 1024, 8, 8], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([64, 512, 8, 8], f16, stride=(65536, 64, 8, 1)), [64, 1024, 8, 8], 1, 0, 512, 1), {})
+cnt: 1, ((T([64, 256, 16, 16], f16), [64, 512, 16, 16], 1, 256, 9223372036854775807, 1), {})
+cnt: 2, ((T([64, 512, 16, 16], f16), [64, 512, 16, 16], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([64, 256, 16, 16], f16, stride=(131072, 256, 16, 1)), [64, 512, 16, 16], 1, 0, 256, 1), {})
+cnt: 1, ((T([64, 128, 32, 32], f16), [64, 256, 32, 32], 1, 128, 9223372036854775807, 1), {})
+cnt: 2, ((T([64, 256, 32, 32], f16), [64, 256, 32, 32], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([64, 128, 32, 32], f16, stride=(262144, 1024, 32, 1)), [64, 256, 32, 32], 1, 0, 128, 1), {})
+cnt: 1, ((T([64, 64, 64, 64], f16), [64, 128, 64, 64], 1, 64, 9223372036854775807, 1), {})
+cnt: 2, ((T([64, 128, 64, 64], f16), [64, 128, 64, 64], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([64, 64, 64, 64], f16, stride=(524288, 4096, 64, 1)), [64, 128, 64, 64], 1, 0, 64, 1), {})
+cnt: 1, ((T([64, 64, 128, 128], f16), [64, 128, 128, 128], 1, 64, 9223372036854775807, 1), {})
+cnt: 2, ((T([64, 128, 128, 128], f16), [64, 128, 128, 128], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([64, 64, 128, 128], f16, stride=(2097152, 16384, 128, 1)), [64, 128, 128, 128], 1, 0, 64, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/deit_base_distilled_patch16_224_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/deit_base_distilled_patch16_224_training.txt
new file mode 100644
index 0000000000000..486ee80cd59a3
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/deit_base_distilled_patch16_224_training.txt
@@ -0,0 +1,87 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([64, 12, 198, 198], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([64, 12, 198, 198], f16), T([64, 12, 198, 198], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([64, 12, 198, 64], f16), [768, 198, 64]), {})
+cnt: 12, ((T([64, 12, 64, 198], f16), [768, 64, 198]), {})
+cnt: 12, ((T([768, 198, 198], f16), [64, 12, 198, 198]), {})
+cnt: 12, ((T([768, 198, 64], f16), [64, 12, 198, 64]), {})
+cnt: 12, ((T([64, 198, 12, 64], f16), [64, 198, 768]), {})
+cnt: 12, ((T([64, 198, 3, 12, 64], f16), [64, 198, 2304]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([64, 198, 768], f16), T([1, 198, 768], f16)), {})
+cnt: 49, ((T([64, 198, 768], f16), T([64, 198, 768], f16)), {})
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16)), {})
+Operator: aten.addmm.default
+cnt: 12, ((T([2304], f16), T([12672, 768], f16), T([768, 2304], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([12672, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([12672, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([12672, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 2, ((T([1000], f16), T([64, 768], f16, stride=(152064, 1)), T([768, 1000], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([768, 198, 64], f16), T([768, 64, 198], f16)), {})
+cnt: 12, ((T([768, 198, 198], f16), T([768, 198, 64], f16)), {})
+cnt: 12, ((T([768, 198, 198], f16, stride=(39204, 1, 198)), T([768, 198, 64], f16)), {})
+cnt: 12, ((T([768, 198, 64], f16), T([768, 64, 198], f16, stride=(12672, 1, 64))), {})
+cnt: 12, ((T([768, 64, 198], f16, stride=(12672, 1, 64)), T([768, 198, 198], f16)), {})
+cnt: 12, ((T([768, 198, 198], f16), T([768, 198, 64], f16, stride=(12672, 1, 198))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 1, 768], f16, stride=(0, 768, 1)), T([64, 1, 768], f16, stride=(0, 768, 1)), T([64, 196, 768], f16, stride=(150528, 1, 196))], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([768, 3, 16, 16], f16), T([768], f16), [16, 16], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 768, 14, 14], f16, stride=(152064, 1, 10752, 768)), T([64, 3, 224, 224], f16), T([768, 3, 16, 16], f16), [768], [16, 16], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([64, 1000], f16), 2), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([64, 198, 3072], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([64, 198, 3072], f16), T([64, 198, 3072], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.mm.default
+cnt: 2, ((T([64, 1000], f16), T([1000, 768], f16)), {})
+cnt: 2, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 768], f16, stride=(152064, 1))), {})
+cnt: 12, ((T([12672, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 12672], f16, stride=(1, 768)), T([12672, 3072], f16)), {})
+cnt: 12, ((T([12672, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 12672], f16, stride=(1, 3072)), T([12672, 768], f16)), {})
+cnt: 12, ((T([12672, 768], f16), T([768, 768], f16)), {})
+cnt: 12, ((T([768, 12672], f16, stride=(1, 768)), T([12672, 768], f16)), {})
+cnt: 12, ((T([12672, 2304], f16), T([2304, 768], f16)), {})
+cnt: 12, ((T([2304, 12672], f16, stride=(1, 2304)), T([12672, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 24, ((T([64, 12, 198, 198], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 25, ((T([64, 198, 768], f16), [768], T([768], f16), T([768], f16), 1e-06), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 25, ((T([64, 198, 768], f16), T([64, 198, 768], f16), [768], T([64, 198, 1], f32), T([64, 198, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.select_backward.default
+cnt: 1, ((T([64, 768], f16), [64, 198, 768], 1, 1), {})
+cnt: 1, ((T([64, 768], f16), [64, 198, 768], 1, 0), {})
+Operator: aten.slice_backward.default
+cnt: 2, ((T([64, 198, 768], f16), [64, 198, 768], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.stack.default
+cnt: 12, (([T([64, 12, 198, 64], f16), T([64, 12, 198, 64], f16, stride=(152064, 12672, 1, 198)), T([64, 12, 198, 64], f16)],), {})
+Operator: aten.sum.SymInt
+cnt: 2, ((T([64, 1000], f16), [0], True), {})
+cnt: 24, ((T([12672, 768], f16), [0], True), {})
+cnt: 12, ((T([12672, 3072], f16), [0], True), {})
+cnt: 12, ((T([12672, 2304], f16), [0], True), {})
+cnt: 1, ((T([64, 198, 768], f16), [0], True), {})
+cnt: 2, ((T([64, 1, 768], f16, stride=(152064, 768, 1)), [0], True), {})
+Operator: aten.unbind.int
+cnt: 12, ((T([3, 64, 12, 198, 64], f16, stride=(768, 456192, 64, 2304, 1)),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/densenet121_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/densenet121_training.txt
new file mode 100644
index 0000000000000..983f9ccb10448
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/densenet121_training.txt
@@ -0,0 +1,616 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 121, ((T([], i64), 1), {})
+cnt: 1, ((T([64, 512, 7, 7], f16, stride=(50176, 49, 7, 1)), T([64, 512, 7, 7], f16, stride=(48608, 49, 7, 1))), {})
+cnt: 15, ((T([64, 32, 7, 7], f16, stride=(50176, 49, 7, 1)), T([64, 32, 7, 7], f16, stride=(48608, 49, 7, 1))), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16, stride=(47040, 49, 7, 1))), {})
+cnt: 14, ((T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16, stride=(47040, 49, 7, 1))), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16, stride=(45472, 49, 7, 1))), {})
+cnt: 13, ((T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16, stride=(45472, 49, 7, 1))), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16, stride=(43904, 49, 7, 1))), {})
+cnt: 12, ((T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16, stride=(43904, 49, 7, 1))), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16, stride=(42336, 49, 7, 1))), {})
+cnt: 11, ((T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16, stride=(42336, 49, 7, 1))), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16, stride=(40768, 49, 7, 1))), {})
+cnt: 10, ((T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16, stride=(40768, 49, 7, 1))), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16, stride=(39200, 49, 7, 1))), {})
+cnt: 9, ((T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16, stride=(39200, 49, 7, 1))), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16, stride=(37632, 49, 7, 1))), {})
+cnt: 8, ((T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16, stride=(37632, 49, 7, 1))), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16, stride=(36064, 49, 7, 1))), {})
+cnt: 7, ((T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16, stride=(36064, 49, 7, 1))), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16, stride=(34496, 49, 7, 1))), {})
+cnt: 6, ((T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16, stride=(34496, 49, 7, 1))), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16, stride=(32928, 49, 7, 1))), {})
+cnt: 5, ((T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16, stride=(32928, 49, 7, 1))), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16, stride=(31360, 49, 7, 1))), {})
+cnt: 4, ((T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16, stride=(31360, 49, 7, 1))), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16, stride=(29792, 49, 7, 1))), {})
+cnt: 3, ((T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16, stride=(29792, 49, 7, 1))), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16, stride=(28224, 49, 7, 1))), {})
+cnt: 2, ((T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16, stride=(28224, 49, 7, 1))), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16, stride=(26656, 49, 7, 1))), {})
+cnt: 1, ((T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16, stride=(26656, 49, 7, 1))), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16)), {})
+cnt: 1, ((T([64, 256, 14, 14], f16, stride=(200704, 196, 14, 1)), T([64, 256, 14, 14], f16, stride=(194432, 196, 14, 1))), {})
+cnt: 23, ((T([64, 32, 14, 14], f16, stride=(200704, 196, 14, 1)), T([64, 32, 14, 14], f16, stride=(194432, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(188160, 196, 14, 1))), {})
+cnt: 22, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(188160, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(181888, 196, 14, 1))), {})
+cnt: 21, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(181888, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(175616, 196, 14, 1))), {})
+cnt: 20, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(175616, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(169344, 196, 14, 1))), {})
+cnt: 19, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(169344, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(163072, 196, 14, 1))), {})
+cnt: 18, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(163072, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(156800, 196, 14, 1))), {})
+cnt: 17, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(156800, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(150528, 196, 14, 1))), {})
+cnt: 16, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(150528, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(144256, 196, 14, 1))), {})
+cnt: 15, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(144256, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(137984, 196, 14, 1))), {})
+cnt: 14, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(137984, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(131712, 196, 14, 1))), {})
+cnt: 13, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(131712, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(125440, 196, 14, 1))), {})
+cnt: 12, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(125440, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(119168, 196, 14, 1))), {})
+cnt: 11, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(119168, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(112896, 196, 14, 1))), {})
+cnt: 10, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(112896, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(106624, 196, 14, 1))), {})
+cnt: 9, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(106624, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(100352, 196, 14, 1))), {})
+cnt: 8, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(100352, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(94080, 196, 14, 1))), {})
+cnt: 7, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(94080, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(87808, 196, 14, 1))), {})
+cnt: 6, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(87808, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(81536, 196, 14, 1))), {})
+cnt: 5, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(81536, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(75264, 196, 14, 1))), {})
+cnt: 4, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(75264, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(68992, 196, 14, 1))), {})
+cnt: 3, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(68992, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(62720, 196, 14, 1))), {})
+cnt: 2, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(62720, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16, stride=(56448, 196, 14, 1))), {})
+cnt: 1, ((T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16, stride=(56448, 196, 14, 1))), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16)), {})
+cnt: 1, ((T([64, 128, 28, 28], f16, stride=(401408, 784, 28, 1)), T([64, 128, 28, 28], f16, stride=(376320, 784, 28, 1))), {})
+cnt: 11, ((T([64, 32, 28, 28], f16, stride=(401408, 784, 28, 1)), T([64, 32, 28, 28], f16, stride=(376320, 784, 28, 1))), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16, stride=(351232, 784, 28, 1))), {})
+cnt: 10, ((T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16, stride=(351232, 784, 28, 1))), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16, stride=(326144, 784, 28, 1))), {})
+cnt: 9, ((T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16, stride=(326144, 784, 28, 1))), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16, stride=(301056, 784, 28, 1))), {})
+cnt: 8, ((T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16, stride=(301056, 784, 28, 1))), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16, stride=(275968, 784, 28, 1))), {})
+cnt: 7, ((T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16, stride=(275968, 784, 28, 1))), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16, stride=(250880, 784, 28, 1))), {})
+cnt: 6, ((T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16, stride=(250880, 784, 28, 1))), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16, stride=(225792, 784, 28, 1))), {})
+cnt: 5, ((T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16, stride=(225792, 784, 28, 1))), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16, stride=(200704, 784, 28, 1))), {})
+cnt: 4, ((T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16, stride=(200704, 784, 28, 1))), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16, stride=(175616, 784, 28, 1))), {})
+cnt: 3, ((T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16, stride=(175616, 784, 28, 1))), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16, stride=(150528, 784, 28, 1))), {})
+cnt: 2, ((T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16, stride=(150528, 784, 28, 1))), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16, stride=(125440, 784, 28, 1))), {})
+cnt: 1, ((T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16, stride=(125440, 784, 28, 1))), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16)), {})
+cnt: 1, ((T([64, 64, 56, 56], f16, stride=(802816, 3136, 56, 1)), T([64, 64, 56, 56], f16, stride=(702464, 3136, 56, 1))), {})
+cnt: 5, ((T([64, 32, 56, 56], f16, stride=(802816, 3136, 56, 1)), T([64, 32, 56, 56], f16, stride=(702464, 3136, 56, 1))), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16, stride=(602112, 3136, 56, 1))), {})
+cnt: 4, ((T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16, stride=(602112, 3136, 56, 1))), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16, stride=(501760, 3136, 56, 1))), {})
+cnt: 3, ((T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16, stride=(501760, 3136, 56, 1))), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16, stride=(401408, 3136, 56, 1))), {})
+cnt: 2, ((T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16, stride=(401408, 3136, 56, 1))), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16, stride=(301056, 3136, 56, 1))), {})
+cnt: 1, ((T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16, stride=(301056, 3136, 56, 1))), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([64, 1024], f16), T([1024, 1000], f16, stride=(1, 1024))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([64, 128, 56, 56], f16), [2, 2], [2, 2]), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), [2, 2], [2, 2]), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), [2, 2], [2, 2]), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 14, 14], f16), [2, 2], [2, 2], [0, 0], False, True, None), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 28, 28], f16), [2, 2], [2, 2], [0, 0], False, True, None), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 56, 56], f16), [2, 2], [2, 2], [0, 0], False, True, None), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 64, 56, 56], f16)], 1), {})
+cnt: 1, (([T([64, 64, 56, 56], f16), T([64, 32, 56, 56], f16)], 1), {})
+cnt: 1, (([T([64, 64, 56, 56], f16), T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16)], 1), {})
+cnt: 1, (([T([64, 64, 56, 56], f16), T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16)], 1), {})
+cnt: 1, (([T([64, 64, 56, 56], f16), T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16)], 1), {})
+cnt: 1, (([T([64, 64, 56, 56], f16), T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16)], 1), {})
+cnt: 1, (([T([64, 64, 56, 56], f16), T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16), T([64, 32, 56, 56], f16)], 1), {})
+cnt: 1, (([T([64, 128, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 128, 28, 28], f16), T([64, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 128, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 128, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 128, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 128, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 128, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 128, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 128, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 128, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 128, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 128, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 128, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16), T([64, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 256, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16), T([64, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 512, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16), T([64, 32, 7, 7], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([128, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 128, 56, 56], f16), T([32, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 96, 56, 56], f16), T([128, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([128, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 160, 56, 56], f16), T([128, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 192, 56, 56], f16), T([128, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 224, 56, 56], f16), T([128, 224, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 56, 56], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([128, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 12, ((T([64, 128, 28, 28], f16), T([32, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 160, 28, 28], f16), T([128, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 192, 28, 28], f16), T([128, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 224, 28, 28], f16), T([128, 224, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 288, 28, 28], f16), T([128, 288, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 320, 28, 28], f16), T([128, 320, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 352, 28, 28], f16), T([128, 352, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 384, 28, 28], f16), T([128, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 416, 28, 28], f16), T([128, 416, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 448, 28, 28], f16), T([128, 448, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 480, 28, 28], f16), T([128, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 28, 28], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 24, ((T([64, 128, 14, 14], f16), T([32, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 288, 14, 14], f16), T([128, 288, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 320, 14, 14], f16), T([128, 320, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 352, 14, 14], f16), T([128, 352, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 384, 14, 14], f16), T([128, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 416, 14, 14], f16), T([128, 416, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 448, 14, 14], f16), T([128, 448, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 480, 14, 14], f16), T([128, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([128, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 544, 14, 14], f16), T([128, 544, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 576, 14, 14], f16), T([128, 576, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 608, 14, 14], f16), T([128, 608, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 640, 14, 14], f16), T([128, 640, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 672, 14, 14], f16), T([128, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 704, 14, 14], f16), T([128, 704, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 736, 14, 14], f16), T([128, 736, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 768, 14, 14], f16), T([128, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 800, 14, 14], f16), T([128, 800, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 832, 14, 14], f16), T([128, 832, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 864, 14, 14], f16), T([128, 864, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 896, 14, 14], f16), T([128, 896, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 928, 14, 14], f16), T([128, 928, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 960, 14, 14], f16), T([128, 960, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 992, 14, 14], f16), T([128, 992, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([128, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 16, ((T([64, 128, 7, 7], f16), T([32, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 544, 7, 7], f16), T([128, 544, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 576, 7, 7], f16), T([128, 576, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 608, 7, 7], f16), T([128, 608, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 640, 7, 7], f16), T([128, 640, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 672, 7, 7], f16), T([128, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 704, 7, 7], f16), T([128, 704, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 736, 7, 7], f16), T([128, 736, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 768, 7, 7], f16), T([128, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 800, 7, 7], f16), T([128, 800, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 832, 7, 7], f16), T([128, 832, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 864, 7, 7], f16), T([128, 864, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 896, 7, 7], f16), T([128, 896, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 928, 7, 7], f16), T([128, 928, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([128, 960, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 992, 7, 7], f16), T([128, 992, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 32, 7, 7], f16, stride=(50176, 49, 7, 1)), T([64, 128, 7, 7], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 992, 7, 7], f16), T([128, 992, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 15, ((T([64, 32, 7, 7], f16), T([64, 128, 7, 7], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 960, 7, 7], f16), T([128, 960, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 928, 7, 7], f16), T([128, 928, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 896, 7, 7], f16), T([128, 896, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 864, 7, 7], f16), T([128, 864, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 832, 7, 7], f16), T([128, 832, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 800, 7, 7], f16), T([128, 800, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 768, 7, 7], f16), T([128, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 736, 7, 7], f16), T([128, 736, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 704, 7, 7], f16), T([128, 704, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 672, 7, 7], f16), T([128, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 640, 7, 7], f16), T([128, 640, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 608, 7, 7], f16), T([128, 608, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 576, 7, 7], f16), T([128, 576, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 544, 7, 7], f16), T([128, 544, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 7, 7], f16), T([64, 512, 7, 7], f16), T([128, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([64, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 14, 14], f16, stride=(200704, 196, 14, 1)), T([64, 128, 14, 14], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 992, 14, 14], f16), T([128, 992, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 23, ((T([64, 32, 14, 14], f16), T([64, 128, 14, 14], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 960, 14, 14], f16), T([128, 960, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 928, 14, 14], f16), T([128, 928, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 896, 14, 14], f16), T([128, 896, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 864, 14, 14], f16), T([128, 864, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 832, 14, 14], f16), T([128, 832, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 800, 14, 14], f16), T([128, 800, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 768, 14, 14], f16), T([128, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 736, 14, 14], f16), T([128, 736, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 704, 14, 14], f16), T([128, 704, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 672, 14, 14], f16), T([128, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 640, 14, 14], f16), T([128, 640, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 608, 14, 14], f16), T([128, 608, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 576, 14, 14], f16), T([128, 576, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 544, 14, 14], f16), T([128, 544, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 512, 14, 14], f16), T([128, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 480, 14, 14], f16), T([128, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 448, 14, 14], f16), T([128, 448, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 416, 14, 14], f16), T([128, 416, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 384, 14, 14], f16), T([128, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 352, 14, 14], f16), T([128, 352, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 320, 14, 14], f16), T([128, 320, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 288, 14, 14], f16), T([128, 288, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 14, 14], f16), T([64, 256, 14, 14], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([64, 512, 28, 28], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 28, 28], f16, stride=(401408, 784, 28, 1)), T([64, 128, 28, 28], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 480, 28, 28], f16), T([128, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 11, ((T([64, 32, 28, 28], f16), T([64, 128, 28, 28], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 448, 28, 28], f16), T([128, 448, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 416, 28, 28], f16), T([128, 416, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 384, 28, 28], f16), T([128, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 352, 28, 28], f16), T([128, 352, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 320, 28, 28], f16), T([128, 320, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 288, 28, 28], f16), T([128, 288, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 256, 28, 28], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 224, 28, 28], f16), T([128, 224, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 192, 28, 28], f16), T([128, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 160, 28, 28], f16), T([128, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16), T([128, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 256, 56, 56], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 56, 56], f16, stride=(802816, 3136, 56, 1)), T([64, 128, 56, 56], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 224, 56, 56], f16), T([128, 224, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([64, 32, 56, 56], f16), T([64, 128, 56, 56], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 192, 56, 56], f16), T([128, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 160, 56, 56], f16), T([128, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 128, 56, 56], f16), T([128, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 96, 56, 56], f16), T([128, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 64, 56, 56], f16), T([128, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64, 3, 224, 224], f16), T([64, 3, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([64, 1024, 7, 7], f16, stride=(1024, 1, 0, 0)), 49), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([64, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([64, 64, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([64, 1024, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([64, 1000], f16), T([1000, 1024], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 1024], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([64, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 160, 56, 56], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 192, 56, 56], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 224, 56, 56], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 13, ((T([64, 128, 28, 28], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 160, 28, 28], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 192, 28, 28], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 224, 28, 28], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 288, 28, 28], f16), T([288], f16), T([288], f16), T([288], f16), T([288], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 320, 28, 28], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 352, 28, 28], f16), T([352], f16), T([352], f16), T([352], f16), T([352], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 384, 28, 28], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 416, 28, 28], f16), T([416], f16), T([416], f16), T([416], f16), T([416], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 448, 28, 28], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 480, 28, 28], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 24, ((T([64, 128, 14, 14], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 288, 14, 14], f16), T([288], f16), T([288], f16), T([288], f16), T([288], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 320, 14, 14], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 352, 14, 14], f16), T([352], f16), T([352], f16), T([352], f16), T([352], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 384, 14, 14], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 416, 14, 14], f16), T([416], f16), T([416], f16), T([416], f16), T([416], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 448, 14, 14], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 544, 14, 14], f16), T([544], f16), T([544], f16), T([544], f16), T([544], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 576, 14, 14], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 608, 14, 14], f16), T([608], f16), T([608], f16), T([608], f16), T([608], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 640, 14, 14], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 704, 14, 14], f16), T([704], f16), T([704], f16), T([704], f16), T([704], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 736, 14, 14], f16), T([736], f16), T([736], f16), T([736], f16), T([736], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 768, 14, 14], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 800, 14, 14], f16), T([800], f16), T([800], f16), T([800], f16), T([800], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 832, 14, 14], f16), T([832], f16), T([832], f16), T([832], f16), T([832], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 864, 14, 14], f16), T([864], f16), T([864], f16), T([864], f16), T([864], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 896, 14, 14], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 928, 14, 14], f16), T([928], f16), T([928], f16), T([928], f16), T([928], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 960, 14, 14], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 992, 14, 14], f16), T([992], f16), T([992], f16), T([992], f16), T([992], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 16, ((T([64, 128, 7, 7], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 544, 7, 7], f16), T([544], f16), T([544], f16), T([544], f16), T([544], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 576, 7, 7], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 608, 7, 7], f16), T([608], f16), T([608], f16), T([608], f16), T([608], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 640, 7, 7], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 704, 7, 7], f16), T([704], f16), T([704], f16), T([704], f16), T([704], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 736, 7, 7], f16), T([736], f16), T([736], f16), T([736], f16), T([736], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 768, 7, 7], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 800, 7, 7], f16), T([800], f16), T([800], f16), T([800], f16), T([800], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 832, 7, 7], f16), T([832], f16), T([832], f16), T([832], f16), T([832], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 864, 7, 7], f16), T([864], f16), T([864], f16), T([864], f16), T([864], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 896, 7, 7], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 928, 7, 7], f16), T([928], f16), T([928], f16), T([928], f16), T([928], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 992, 7, 7], f16), T([992], f16), T([992], f16), T([992], f16), T([992], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([64, 1024, 7, 7], f16), T([64, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 16, ((T([64, 128, 7, 7], f16), T([64, 128, 7, 7], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 992, 7, 7], f16), T([64, 992, 7, 7], f16), T([992], f16), T([992], f16), T([992], f16), T([992], f32), T([992], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([64, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 928, 7, 7], f16), T([64, 928, 7, 7], f16), T([928], f16), T([928], f16), T([928], f16), T([928], f32), T([928], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 896, 7, 7], f16), T([64, 896, 7, 7], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f32), T([896], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 864, 7, 7], f16), T([64, 864, 7, 7], f16), T([864], f16), T([864], f16), T([864], f16), T([864], f32), T([864], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 832, 7, 7], f16), T([64, 832, 7, 7], f16), T([832], f16), T([832], f16), T([832], f16), T([832], f32), T([832], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 800, 7, 7], f16), T([64, 800, 7, 7], f16), T([800], f16), T([800], f16), T([800], f16), T([800], f32), T([800], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 768, 7, 7], f16), T([64, 768, 7, 7], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f32), T([768], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 736, 7, 7], f16), T([64, 736, 7, 7], f16), T([736], f16), T([736], f16), T([736], f16), T([736], f32), T([736], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 704, 7, 7], f16), T([64, 704, 7, 7], f16), T([704], f16), T([704], f16), T([704], f16), T([704], f32), T([704], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 672, 7, 7], f16), T([64, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 640, 7, 7], f16), T([64, 640, 7, 7], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f32), T([640], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 608, 7, 7], f16), T([64, 608, 7, 7], f16), T([608], f16), T([608], f16), T([608], f16), T([608], f32), T([608], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 576, 7, 7], f16), T([64, 576, 7, 7], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f32), T([576], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 544, 7, 7], f16), T([64, 544, 7, 7], f16), T([544], f16), T([544], f16), T([544], f16), T([544], f32), T([544], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 1024, 14, 14], f16), T([64, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 24, ((T([64, 128, 14, 14], f16), T([64, 128, 14, 14], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 992, 14, 14], f16), T([64, 992, 14, 14], f16), T([992], f16), T([992], f16), T([992], f16), T([992], f32), T([992], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 960, 14, 14], f16), T([64, 960, 14, 14], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 928, 14, 14], f16), T([64, 928, 14, 14], f16), T([928], f16), T([928], f16), T([928], f16), T([928], f32), T([928], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 896, 14, 14], f16), T([64, 896, 14, 14], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f32), T([896], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 864, 14, 14], f16), T([64, 864, 14, 14], f16), T([864], f16), T([864], f16), T([864], f16), T([864], f32), T([864], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 832, 14, 14], f16), T([64, 832, 14, 14], f16), T([832], f16), T([832], f16), T([832], f16), T([832], f32), T([832], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 800, 14, 14], f16), T([64, 800, 14, 14], f16), T([800], f16), T([800], f16), T([800], f16), T([800], f32), T([800], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 768, 14, 14], f16), T([64, 768, 14, 14], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f32), T([768], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 736, 14, 14], f16), T([64, 736, 14, 14], f16), T([736], f16), T([736], f16), T([736], f16), T([736], f32), T([736], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 704, 14, 14], f16), T([64, 704, 14, 14], f16), T([704], f16), T([704], f16), T([704], f16), T([704], f32), T([704], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 672, 14, 14], f16), T([64, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 640, 14, 14], f16), T([64, 640, 14, 14], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f32), T([640], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 608, 14, 14], f16), T([64, 608, 14, 14], f16), T([608], f16), T([608], f16), T([608], f16), T([608], f32), T([608], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 576, 14, 14], f16), T([64, 576, 14, 14], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f32), T([576], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 544, 14, 14], f16), T([64, 544, 14, 14], f16), T([544], f16), T([544], f16), T([544], f16), T([544], f32), T([544], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 480, 14, 14], f16), T([64, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 448, 14, 14], f16), T([64, 448, 14, 14], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f32), T([448], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 416, 14, 14], f16), T([64, 416, 14, 14], f16), T([416], f16), T([416], f16), T([416], f16), T([416], f32), T([416], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 384, 14, 14], f16), T([64, 384, 14, 14], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 352, 14, 14], f16), T([64, 352, 14, 14], f16), T([352], f16), T([352], f16), T([352], f16), T([352], f32), T([352], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 320, 14, 14], f16), T([64, 320, 14, 14], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 288, 14, 14], f16), T([64, 288, 14, 14], f16), T([288], f16), T([288], f16), T([288], f16), T([288], f32), T([288], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 512, 28, 28], f16), T([64, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 13, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 480, 28, 28], f16), T([64, 480, 28, 28], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 448, 28, 28], f16), T([64, 448, 28, 28], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f32), T([448], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 416, 28, 28], f16), T([64, 416, 28, 28], f16), T([416], f16), T([416], f16), T([416], f16), T([416], f32), T([416], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 384, 28, 28], f16), T([64, 384, 28, 28], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 352, 28, 28], f16), T([64, 352, 28, 28], f16), T([352], f16), T([352], f16), T([352], f16), T([352], f32), T([352], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 320, 28, 28], f16), T([64, 320, 28, 28], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 288, 28, 28], f16), T([64, 288, 28, 28], f16), T([288], f16), T([288], f16), T([288], f16), T([288], f32), T([288], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 224, 28, 28], f16), T([64, 224, 28, 28], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f32), T([224], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 192, 28, 28], f16), T([64, 192, 28, 28], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 160, 28, 28], f16), T([64, 160, 28, 28], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 256, 56, 56], f16), T([64, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([64, 128, 56, 56], f16), T([64, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 224, 56, 56], f16), T([64, 224, 56, 56], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f32), T([224], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 192, 56, 56], f16), T([64, 192, 56, 56], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 160, 56, 56], f16), T([64, 160, 56, 56], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 96, 56, 56], f16), T([64, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([64, 64, 112, 112], f16),), {})
+cnt: 1, ((T([64, 64, 56, 56], f16),), {})
+cnt: 7, ((T([64, 128, 56, 56], f16),), {})
+cnt: 1, ((T([64, 96, 56, 56], f16),), {})
+cnt: 1, ((T([64, 160, 56, 56], f16),), {})
+cnt: 1, ((T([64, 192, 56, 56], f16),), {})
+cnt: 1, ((T([64, 224, 56, 56], f16),), {})
+cnt: 1, ((T([64, 256, 56, 56], f16),), {})
+cnt: 13, ((T([64, 128, 28, 28], f16),), {})
+cnt: 1, ((T([64, 160, 28, 28], f16),), {})
+cnt: 1, ((T([64, 192, 28, 28], f16),), {})
+cnt: 1, ((T([64, 224, 28, 28], f16),), {})
+cnt: 1, ((T([64, 256, 28, 28], f16),), {})
+cnt: 1, ((T([64, 288, 28, 28], f16),), {})
+cnt: 1, ((T([64, 320, 28, 28], f16),), {})
+cnt: 1, ((T([64, 352, 28, 28], f16),), {})
+cnt: 1, ((T([64, 384, 28, 28], f16),), {})
+cnt: 1, ((T([64, 416, 28, 28], f16),), {})
+cnt: 1, ((T([64, 448, 28, 28], f16),), {})
+cnt: 1, ((T([64, 480, 28, 28], f16),), {})
+cnt: 1, ((T([64, 512, 28, 28], f16),), {})
+cnt: 1, ((T([64, 256, 14, 14], f16),), {})
+cnt: 24, ((T([64, 128, 14, 14], f16),), {})
+cnt: 1, ((T([64, 288, 14, 14], f16),), {})
+cnt: 1, ((T([64, 320, 14, 14], f16),), {})
+cnt: 1, ((T([64, 352, 14, 14], f16),), {})
+cnt: 1, ((T([64, 384, 14, 14], f16),), {})
+cnt: 1, ((T([64, 416, 14, 14], f16),), {})
+cnt: 1, ((T([64, 448, 14, 14], f16),), {})
+cnt: 1, ((T([64, 480, 14, 14], f16),), {})
+cnt: 1, ((T([64, 512, 14, 14], f16),), {})
+cnt: 1, ((T([64, 544, 14, 14], f16),), {})
+cnt: 1, ((T([64, 576, 14, 14], f16),), {})
+cnt: 1, ((T([64, 608, 14, 14], f16),), {})
+cnt: 1, ((T([64, 640, 14, 14], f16),), {})
+cnt: 1, ((T([64, 672, 14, 14], f16),), {})
+cnt: 1, ((T([64, 704, 14, 14], f16),), {})
+cnt: 1, ((T([64, 736, 14, 14], f16),), {})
+cnt: 1, ((T([64, 768, 14, 14], f16),), {})
+cnt: 1, ((T([64, 800, 14, 14], f16),), {})
+cnt: 1, ((T([64, 832, 14, 14], f16),), {})
+cnt: 1, ((T([64, 864, 14, 14], f16),), {})
+cnt: 1, ((T([64, 896, 14, 14], f16),), {})
+cnt: 1, ((T([64, 928, 14, 14], f16),), {})
+cnt: 1, ((T([64, 960, 14, 14], f16),), {})
+cnt: 1, ((T([64, 992, 14, 14], f16),), {})
+cnt: 1, ((T([64, 1024, 14, 14], f16),), {})
+cnt: 1, ((T([64, 512, 7, 7], f16),), {})
+cnt: 16, ((T([64, 128, 7, 7], f16),), {})
+cnt: 1, ((T([64, 544, 7, 7], f16),), {})
+cnt: 1, ((T([64, 576, 7, 7], f16),), {})
+cnt: 1, ((T([64, 608, 7, 7], f16),), {})
+cnt: 1, ((T([64, 640, 7, 7], f16),), {})
+cnt: 1, ((T([64, 672, 7, 7], f16),), {})
+cnt: 1, ((T([64, 704, 7, 7], f16),), {})
+cnt: 1, ((T([64, 736, 7, 7], f16),), {})
+cnt: 1, ((T([64, 768, 7, 7], f16),), {})
+cnt: 1, ((T([64, 800, 7, 7], f16),), {})
+cnt: 1, ((T([64, 832, 7, 7], f16),), {})
+cnt: 1, ((T([64, 864, 7, 7], f16),), {})
+cnt: 1, ((T([64, 896, 7, 7], f16),), {})
+cnt: 1, ((T([64, 928, 7, 7], f16),), {})
+cnt: 1, ((T([64, 960, 7, 7], f16),), {})
+cnt: 1, ((T([64, 992, 7, 7], f16),), {})
+cnt: 1, ((T([64, 1024, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([64, 1024, 7, 7], f16), T([64, 1024, 7, 7], f16), 0), {})
+cnt: 16, ((T([64, 128, 7, 7], f16), T([64, 128, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 992, 7, 7], f16), T([64, 992, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([64, 960, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 928, 7, 7], f16), T([64, 928, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 896, 7, 7], f16), T([64, 896, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 864, 7, 7], f16), T([64, 864, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 832, 7, 7], f16), T([64, 832, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 800, 7, 7], f16), T([64, 800, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 768, 7, 7], f16), T([64, 768, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 736, 7, 7], f16), T([64, 736, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 704, 7, 7], f16), T([64, 704, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 672, 7, 7], f16), T([64, 672, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 640, 7, 7], f16), T([64, 640, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 608, 7, 7], f16), T([64, 608, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 576, 7, 7], f16), T([64, 576, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 544, 7, 7], f16), T([64, 544, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 1024, 14, 14], f16), T([64, 1024, 14, 14], f16), 0), {})
+cnt: 24, ((T([64, 128, 14, 14], f16), T([64, 128, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 992, 14, 14], f16), T([64, 992, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 960, 14, 14], f16), T([64, 960, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 928, 14, 14], f16), T([64, 928, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 896, 14, 14], f16), T([64, 896, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 864, 14, 14], f16), T([64, 864, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 832, 14, 14], f16), T([64, 832, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 800, 14, 14], f16), T([64, 800, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 768, 14, 14], f16), T([64, 768, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 736, 14, 14], f16), T([64, 736, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 704, 14, 14], f16), T([64, 704, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 672, 14, 14], f16), T([64, 672, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 640, 14, 14], f16), T([64, 640, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 608, 14, 14], f16), T([64, 608, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 576, 14, 14], f16), T([64, 576, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 544, 14, 14], f16), T([64, 544, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 480, 14, 14], f16), T([64, 480, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 448, 14, 14], f16), T([64, 448, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 416, 14, 14], f16), T([64, 416, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 384, 14, 14], f16), T([64, 384, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 352, 14, 14], f16), T([64, 352, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 320, 14, 14], f16), T([64, 320, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 288, 14, 14], f16), T([64, 288, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 512, 28, 28], f16), T([64, 512, 28, 28], f16), 0), {})
+cnt: 13, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16), 0), {})
+cnt: 1, ((T([64, 480, 28, 28], f16), T([64, 480, 28, 28], f16), 0), {})
+cnt: 1, ((T([64, 448, 28, 28], f16), T([64, 448, 28, 28], f16), 0), {})
+cnt: 1, ((T([64, 416, 28, 28], f16), T([64, 416, 28, 28], f16), 0), {})
+cnt: 1, ((T([64, 384, 28, 28], f16), T([64, 384, 28, 28], f16), 0), {})
+cnt: 1, ((T([64, 352, 28, 28], f16), T([64, 352, 28, 28], f16), 0), {})
+cnt: 1, ((T([64, 320, 28, 28], f16), T([64, 320, 28, 28], f16), 0), {})
+cnt: 1, ((T([64, 288, 28, 28], f16), T([64, 288, 28, 28], f16), 0), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16), 0), {})
+cnt: 1, ((T([64, 224, 28, 28], f16), T([64, 224, 28, 28], f16), 0), {})
+cnt: 1, ((T([64, 192, 28, 28], f16), T([64, 192, 28, 28], f16), 0), {})
+cnt: 1, ((T([64, 160, 28, 28], f16), T([64, 160, 28, 28], f16), 0), {})
+cnt: 1, ((T([64, 256, 56, 56], f16), T([64, 256, 56, 56], f16), 0), {})
+cnt: 7, ((T([64, 128, 56, 56], f16), T([64, 128, 56, 56], f16), 0), {})
+cnt: 1, ((T([64, 224, 56, 56], f16), T([64, 224, 56, 56], f16), 0), {})
+cnt: 1, ((T([64, 192, 56, 56], f16), T([64, 192, 56, 56], f16), 0), {})
+cnt: 1, ((T([64, 160, 56, 56], f16), T([64, 160, 56, 56], f16), 0), {})
+cnt: 1, ((T([64, 96, 56, 56], f16), T([64, 96, 56, 56], f16), 0), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16), 0), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64, 64, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/dla102_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/dla102_training.txt
new file mode 100644
index 0000000000000..68226f899cee0
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/dla102_training.txt
@@ -0,0 +1,189 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([64, 1024, 7, 7], f16), T([64, 1024, 7, 7], f16, stride=(125440, 49, 7, 1))), {})
+cnt: 1, ((T([64, 1024, 7, 7], f16, stride=(125440, 49, 7, 1)), T([64, 1024, 7, 7], f16)), {})
+cnt: 1, ((T([64, 1024, 7, 7], f16), T([64, 1024, 7, 7], f16)), {})
+cnt: 1, ((T([64, 512, 7, 7], f16, stride=(125440, 49, 7, 1)), T([64, 512, 7, 7], f16)), {})
+cnt: 16, ((T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16)), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16, stride=(551936, 196, 14, 1))), {})
+cnt: 4, ((T([64, 512, 14, 14], f16, stride=(551936, 196, 14, 1)), T([64, 512, 14, 14], f16)), {})
+cnt: 4, ((T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16, stride=(200704, 196, 14, 1))), {})
+cnt: 4, ((T([64, 512, 14, 14], f16, stride=(200704, 196, 14, 1)), T([64, 512, 14, 14], f16)), {})
+cnt: 2, ((T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16, stride=(301056, 196, 14, 1))), {})
+cnt: 4, ((T([64, 512, 14, 14], f16, stride=(301056, 196, 14, 1)), T([64, 512, 14, 14], f16)), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16, stride=(401408, 196, 14, 1))), {})
+cnt: 3, ((T([64, 512, 14, 14], f16, stride=(401408, 196, 14, 1)), T([64, 512, 14, 14], f16)), {})
+cnt: 9, ((T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16)), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16, stride=(903168, 784, 28, 1))), {})
+cnt: 3, ((T([64, 256, 28, 28], f16, stride=(903168, 784, 28, 1)), T([64, 256, 28, 28], f16)), {})
+cnt: 2, ((T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16, stride=(401408, 784, 28, 1))), {})
+cnt: 2, ((T([64, 256, 28, 28], f16, stride=(401408, 784, 28, 1)), T([64, 256, 28, 28], f16)), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16, stride=(602112, 784, 28, 1))), {})
+cnt: 2, ((T([64, 256, 28, 28], f16, stride=(602112, 784, 28, 1)), T([64, 256, 28, 28], f16)), {})
+cnt: 3, ((T([64, 128, 56, 56], f16), T([64, 128, 56, 56], f16)), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 128, 56, 56], f16, stride=(802816, 3136, 56, 1))), {})
+cnt: 1, ((T([64, 128, 56, 56], f16, stride=(802816, 3136, 56, 1)), T([64, 128, 56, 56], f16)), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 105, ((T([], i64), 1), {})
+cnt: 3, ((T([64, 128, 56, 56], f16), T([64, 128, 56, 56], f16)), {})
+cnt: 12, ((T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16)), {})
+cnt: 24, ((T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16)), {})
+cnt: 3, ((T([64, 1024, 7, 7], f16), T([64, 1024, 7, 7], f16)), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 128, 56, 56], f16), T([64, 128, 56, 56], f16)], 1), {})
+cnt: 2, (([T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16), T([64, 128, 28, 28], f16), T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16)], 1), {})
+cnt: 4, (([T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16)], 1), {})
+cnt: 2, (([T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16), T([64, 256, 14, 14], f16), T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 1024, 7, 7], f16), T([64, 1024, 7, 7], f16), T([64, 512, 7, 7], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([16, 3, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 16, 224, 224], f16), T([16, 16, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 16, 224, 224], f16), T([32, 16, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 32, 56, 56], f16), T([128, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([64, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64, 64, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 64, 56, 56], f16), T([128, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 56, 56], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 9, ((T([64, 128, 28, 28], f16), T([256, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([128, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([128, 128, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([64, 256, 28, 28], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([64, 128, 28, 28], f16), T([128, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 512, 28, 28], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 768, 28, 28], f16), T([256, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 1152, 28, 28], f16), T([256, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 17, ((T([64, 256, 14, 14], f16), T([512, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([256, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([256, 256, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 15, ((T([64, 512, 14, 14], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 15, ((T([64, 256, 14, 14], f16), T([256, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([64, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 1536, 14, 14], f16), T([512, 1536, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 2048, 14, 14], f16), T([512, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 2816, 14, 14], f16), T([512, 2816, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 512, 7, 7], f16), T([1024, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([512, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([512, 512, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 1024, 7, 7], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([512, 512, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 2560, 7, 7], f16), T([1024, 2560, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 1024, 1, 1], f16), T([1000, 1024, 1, 1], f16), T([1000], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 1000, 1, 1], f16), T([64, 1024, 1, 1], f16), T([1000, 1024, 1, 1], f16), [1000], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 1024, 7, 7], f16), T([64, 2560, 7, 7], f16), T([1024, 2560, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 1024, 7, 7], f16), T([64, 512, 7, 7], f16), T([1024, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16), T([512, 512, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 1024, 7, 7], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 14, 14], f16), T([512, 512, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16), T([512, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([64, 2816, 14, 14], f16), T([512, 2816, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 17, ((T([64, 512, 14, 14], f16), T([64, 256, 14, 14], f16), T([512, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 15, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16), T([256, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 15, ((T([64, 256, 14, 14], f16), T([64, 512, 14, 14], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([64, 512, 14, 14], f16), T([64, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 512, 14, 14], f16), T([64, 1536, 14, 14], f16), T([512, 1536, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([64, 2048, 14, 14], f16), T([512, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 28, 28], f16), T([256, 256, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16), T([256, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([64, 1152, 28, 28], f16), T([256, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 9, ((T([64, 256, 28, 28], f16), T([64, 128, 28, 28], f16), T([256, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 7, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16), T([128, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 7, ((T([64, 128, 28, 28], f16), T([64, 256, 28, 28], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 256, 28, 28], f16), T([64, 512, 28, 28], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([64, 768, 28, 28], f16), T([256, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 56, 56], f16), T([128, 128, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 128, 56, 56], f16), T([128, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 256, 56, 56], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 128, 56, 56], f16), T([64, 64, 56, 56], f16), T([128, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 128, 56, 56], f16), T([64, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 112, 112], f16), T([64, 64, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64, 32, 112, 112], f16), T([64, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 32, 56, 56], f16), T([128, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([64, 16, 224, 224], f16), T([32, 16, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 16, 224, 224], f16), T([64, 16, 224, 224], f16), T([16, 16, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 16, 224, 224], f16), T([64, 3, 224, 224], f16), T([16, 3, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([64, 1024, 7, 7], f16, stride=(1024, 1, 0, 0)), 49), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([64, 32, 112, 112], f16), [2, 2], [2, 2]), {})
+cnt: 3, ((T([64, 128, 56, 56], f16), [2, 2], [2, 2]), {})
+cnt: 4, ((T([64, 256, 28, 28], f16), [2, 2], [2, 2]), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), [2, 2], [2, 2]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 14, 14], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([64, 512, 7, 7], i64)), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 28, 28], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([64, 256, 14, 14], i64)), {})
+cnt: 1, ((T([64, 256, 14, 14], f16, stride=(551936, 196, 14, 1)), T([64, 256, 28, 28], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([64, 256, 14, 14], i64)), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 56, 56], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([64, 128, 28, 28], i64)), {})
+cnt: 1, ((T([64, 128, 28, 28], f16, stride=(903168, 784, 28, 1)), T([64, 128, 56, 56], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([64, 128, 28, 28], i64)), {})
+cnt: 1, ((T([64, 32, 56, 56], f16), T([64, 32, 112, 112], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([64, 32, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([64, 1024, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([64, 16, 224, 224], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([64, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([64, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 14, ((T([64, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 15, ((T([64, 128, 28, 28], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 26, ((T([64, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 31, ((T([64, 256, 14, 14], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([64, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([64, 512, 7, 7], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 4, ((T([64, 1024, 7, 7], f16), T([64, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 26, ((T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 31, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 14, ((T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 15, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([64, 128, 56, 56], f16), T([64, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([64, 16, 224, 224], f16), T([64, 16, 224, 224], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([64, 16, 224, 224], f16),), {})
+cnt: 1, ((T([64, 32, 112, 112], f16),), {})
+cnt: 1, ((T([64, 64, 112, 112], f16),), {})
+cnt: 3, ((T([64, 64, 56, 56], f16),), {})
+cnt: 4, ((T([64, 128, 56, 56], f16),), {})
+cnt: 15, ((T([64, 128, 28, 28], f16),), {})
+cnt: 13, ((T([64, 256, 28, 28], f16),), {})
+cnt: 31, ((T([64, 256, 14, 14], f16),), {})
+cnt: 25, ((T([64, 512, 14, 14], f16),), {})
+cnt: 3, ((T([64, 512, 7, 7], f16),), {})
+cnt: 3, ((T([64, 1024, 7, 7], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 3, ((T([64, 1024, 7, 7], f16), T([64, 1024, 7, 7], f16), 0), {})
+cnt: 3, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16), 0), {})
+cnt: 25, ((T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16), 0), {})
+cnt: 31, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16), 0), {})
+cnt: 13, ((T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16), 0), {})
+cnt: 15, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16), 0), {})
+cnt: 4, ((T([64, 128, 56, 56], f16), T([64, 128, 56, 56], f16), 0), {})
+cnt: 3, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16), 0), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64, 64, 112, 112], f16), 0), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16), 0), {})
+cnt: 2, ((T([64, 16, 224, 224], f16), T([64, 16, 224, 224], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/dm_nfnet_f0_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/dm_nfnet_f0_training.txt
new file mode 100644
index 0000000000000..683e671e28665
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/dm_nfnet_f0_training.txt
@@ -0,0 +1,296 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 3, ((T([128, 256, 48, 48], f16), T([128, 256, 48, 48], f16)), {})
+cnt: 6, ((T([128, 512, 24, 24], f16), T([128, 512, 24, 24], f16)), {})
+cnt: 18, ((T([128, 1536, 12, 12], f16), T([128, 1536, 12, 12], f16)), {})
+cnt: 8, ((T([128, 1536, 6, 6], f16), T([128, 1536, 6, 6], f16)), {})
+cnt: 1, ((T([128, 128, 48, 48], f16), T([128, 128, 48, 48], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 3072], f16), T([3072, 1000], f16, stride=(1, 3072))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([128, 256, 48, 48], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+cnt: 1, ((T([128, 512, 24, 24], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+cnt: 1, ((T([128, 1536, 12, 12], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([128, 1536, 6, 6], f16), T([128, 1536, 12, 12], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+cnt: 1, ((T([128, 512, 12, 12], f16), T([128, 512, 24, 24], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+cnt: 1, ((T([128, 256, 24, 24], f16), T([128, 256, 48, 48], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 192, 192], f16),), {})
+cnt: 1, ((T([128, 256, 48, 48], f16),), {})
+cnt: 2, ((T([128, 512, 24, 24], f16),), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16),), {})
+cnt: 3, ((T([128, 1536, 6, 6], f16),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 1, ((T([128, 3, 192, 192], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([128, 64, 96, 96], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([128, 256, 48, 48], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([128, 768, 24, 24], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([128, 768, 12, 12], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([128, 768, 13, 13], f16), [0, -1, 0, -1]), {})
+cnt: 1, ((T([128, 768, 25, 25], f16), [0, -1, 0, -1]), {})
+cnt: 1, ((T([128, 256, 49, 49], f16), [0, -1, 0, -1]), {})
+cnt: 1, ((T([128, 64, 97, 97], f16), [0, -1, 0, -1]), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 193, 193], f16), T([16, 3, 3, 3], f16), T([16], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 96, 96], f16), T([32, 16, 3, 3], f16), T([32], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), T([64, 32, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 97, 97], f16), T([128, 64, 3, 3], f16), T([128], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 128, 48, 48], f16), T([256, 128, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 48, 48], f16), T([128, 128, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 128, 48, 48], f16), T([128, 128, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 1, 1], f16), T([128, 256, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 1, 1], f16), T([256, 128, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 256, 24, 24], f16), T([512, 256, 1, 1], f16), T([512], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 48, 48], f16), T([256, 256, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 49, 49], f16), T([256, 128, 3, 3], f16), T([256], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 2), {})
+cnt: 3, ((T([128, 256, 24, 24], f16), T([256, 128, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 2, ((T([128, 512, 1, 1], f16), T([256, 512, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 256, 1, 1], f16), T([512, 256, 1, 1], f16), T([512], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 24, 24], f16), T([256, 512, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 12, 12], f16), T([1536, 512, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 24, 24], f16), T([768, 512, 1, 1], f16), T([768], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 768, 25, 25], f16), T([768, 128, 3, 3], f16), T([768], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 6), {})
+cnt: 11, ((T([128, 768, 12, 12], f16), T([768, 128, 3, 3], f16), T([768], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 6), {})
+cnt: 6, ((T([128, 768, 12, 12], f16), T([1536, 768, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 9, ((T([128, 1536, 1, 1], f16), T([768, 1536, 1, 1], f16), T([768], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 9, ((T([128, 768, 1, 1], f16), T([1536, 768, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), T([768, 1536, 1, 1], f16), T([768], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1536, 6, 6], f16), T([1536, 1536, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 768, 13, 13], f16), T([768, 128, 3, 3], f16), T([768], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 6), {})
+cnt: 5, ((T([128, 768, 6, 6], f16), T([768, 128, 3, 3], f16), T([768], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 6), {})
+cnt: 3, ((T([128, 768, 6, 6], f16), T([1536, 768, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 1536, 6, 6], f16), T([768, 1536, 1, 1], f16), T([768], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1536, 6, 6], f16), T([3072, 1536, 1, 1], f16), T([3072], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 3072, 6, 6], f16), T([128, 1536, 6, 6], f16), T([3072, 1536, 1, 1], f16), [3072], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 9, ((T([128, 1536, 1, 1], f16), T([128, 768, 1, 1], f16), T([1536, 768, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 9, ((T([128, 768, 1, 1], f16), T([128, 1536, 1, 1], f16), T([768, 1536, 1, 1], f16), [768], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([128, 1536, 6, 6], f16), T([128, 768, 6, 6], f16), T([1536, 768, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 5, ((T([128, 768, 6, 6], f16), T([128, 768, 6, 6], f16), T([768, 128, 3, 3], f16), [768], [1, 1], [1, 1], [1, 1], False, [0, 0], 6, [True, True, True]), {})
+cnt: 2, ((T([128, 768, 6, 6], f16), T([128, 1536, 6, 6], f16), T([768, 1536, 1, 1], f16), [768], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 768, 6, 6], f16), T([128, 768, 13, 13], f16), T([768, 128, 3, 3], f16), [768], [2, 2], [0, 0], [1, 1], False, [0, 0], 6, [True, True, True]), {})
+cnt: 6, ((T([128, 768, 12, 12], f16), T([128, 1536, 12, 12], f16), T([768, 1536, 1, 1], f16), [768], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 1536, 6, 6], f16), T([128, 1536, 6, 6], f16), T([1536, 1536, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), T([128, 768, 12, 12], f16), T([1536, 768, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 11, ((T([128, 768, 12, 12], f16), T([128, 768, 12, 12], f16), T([768, 128, 3, 3], f16), [768], [1, 1], [1, 1], [1, 1], False, [0, 0], 6, [True, True, True]), {})
+cnt: 1, ((T([128, 768, 12, 12], f16), T([128, 768, 25, 25], f16), T([768, 128, 3, 3], f16), [768], [2, 2], [0, 0], [1, 1], False, [0, 0], 6, [True, True, True]), {})
+cnt: 1, ((T([128, 768, 24, 24], f16), T([128, 512, 24, 24], f16), T([768, 512, 1, 1], f16), [768], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 1536, 12, 12], f16), T([128, 512, 12, 12], f16), T([1536, 512, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 512, 1, 1], f16), T([128, 256, 1, 1], f16), T([512, 256, 1, 1], f16), [512], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 256, 1, 1], f16), T([128, 512, 1, 1], f16), T([256, 512, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([128, 512, 24, 24], f16), T([128, 256, 24, 24], f16), T([512, 256, 1, 1], f16), [512], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([128, 256, 24, 24], f16), T([128, 256, 24, 24], f16), T([256, 128, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 2, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 24, 24], f16), T([128, 512, 24, 24], f16), T([256, 512, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 24, 24], f16), T([128, 256, 49, 49], f16), T([256, 128, 3, 3], f16), [256], [2, 2], [0, 0], [1, 1], False, [0, 0], 2, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 48, 48], f16), T([128, 256, 48, 48], f16), T([256, 256, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 1, 1], f16), T([128, 128, 1, 1], f16), T([256, 128, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 1, 1], f16), T([128, 256, 1, 1], f16), T([128, 256, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), T([128, 128, 48, 48], f16), T([256, 128, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 128, 48, 48], f16), T([128, 128, 48, 48], f16), T([128, 128, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 48, 48], f16), T([128, 128, 48, 48], f16), T([128, 128, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 48, 48], f16), T([128, 64, 97, 97], f16), T([128, 64, 3, 3], f16), [128], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 96, 96], f16), T([128, 32, 96, 96], f16), T([64, 32, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), T([128, 16, 96, 96], f16), T([32, 16, 3, 3], f16), [32], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 16, 96, 96], f16), T([128, 3, 193, 193], f16), T([16, 3, 3, 3], f16), [16], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 192, 192], f16), T([128, 3, 192, 192], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 3072, 6, 6], f16, stride=(3072, 1, 0, 0)), 36), {})
+cnt: 3, ((T([128, 1536, 6, 6], f16, stride=(1536, 1, 0, 0)), 36), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16, stride=(1536, 1, 0, 0)), 144), {})
+cnt: 2, ((T([128, 512, 24, 24], f16, stride=(512, 1, 0, 0)), 576), {})
+cnt: 1, ((T([128, 256, 48, 48], f16, stride=(256, 1, 0, 0)), 2304), {})
+Operator: aten.gelu.default
+cnt: 1, ((T([128, 16, 96, 96], f16),), {})
+cnt: 1, ((T([128, 32, 96, 96], f16),), {})
+cnt: 1, ((T([128, 64, 96, 96], f16),), {})
+cnt: 4, ((T([128, 128, 48, 48], f16),), {})
+cnt: 2, ((T([128, 256, 48, 48], f16),), {})
+cnt: 5, ((T([128, 256, 24, 24], f16),), {})
+cnt: 2, ((T([128, 512, 24, 24], f16),), {})
+cnt: 1, ((T([128, 768, 24, 24], f16),), {})
+cnt: 18, ((T([128, 768, 12, 12], f16),), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16),), {})
+cnt: 8, ((T([128, 768, 6, 6], f16),), {})
+cnt: 2, ((T([128, 1536, 6, 6], f16),), {})
+cnt: 1, ((T([128, 3072, 6, 6], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([128, 3072, 6, 6], f16), T([128, 3072, 6, 6], f16)), {})
+cnt: 8, ((T([128, 768, 6, 6], f16), T([128, 768, 6, 6], f16)), {})
+cnt: 2, ((T([128, 1536, 6, 6], f16), T([128, 1536, 6, 6], f16)), {})
+cnt: 18, ((T([128, 768, 12, 12], f16), T([128, 768, 12, 12], f16)), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), T([128, 1536, 12, 12], f16)), {})
+cnt: 1, ((T([128, 768, 24, 24], f16), T([128, 768, 24, 24], f16)), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), T([128, 512, 24, 24], f16)), {})
+cnt: 5, ((T([128, 256, 24, 24], f16), T([128, 256, 24, 24], f16)), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), T([128, 256, 48, 48], f16)), {})
+cnt: 4, ((T([128, 128, 48, 48], f16), T([128, 128, 48, 48], f16)), {})
+cnt: 1, ((T([128, 64, 96, 96], f16), T([128, 64, 96, 96], f16)), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), T([128, 32, 96, 96], f16)), {})
+cnt: 1, ((T([128, 16, 96, 96], f16), T([128, 16, 96, 96], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 256, 48, 48], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), [2, 3], True), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), [2, 3], True), {})
+cnt: 3, ((T([128, 1536, 6, 6], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 3072, 6, 6], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 3072], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 3072], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([16, 1, 1, 1], f16), 0.19245008972987526), {})
+cnt: 2, ((T([32, 1, 1, 1], f16), 0.08333333333333333), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.05892556509887896), {})
+cnt: 2, ((T([128, 1, 1, 1], f16), 0.041666666666666664), {})
+cnt: 2, ((T([128, 128, 48, 48], f16), 1.0), {})
+cnt: 4, ((T([256, 1, 1, 1], f16), 0.08838834764831845), {})
+cnt: 2, ((T([128, 1, 1, 1], f16), 0.08838834764831845), {})
+cnt: 4, ((T([128, 1, 1, 1], f16), 0.02946278254943948), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), T([128, 256, 1, 1], f16)), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), 2.0), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), 0.2), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), 0.9805806756909201), {})
+cnt: 6, ((T([512, 1, 1, 1], f16), 0.0625), {})
+cnt: 2, ((T([256, 1, 1, 1], f16), 0.0625), {})
+cnt: 8, ((T([256, 1, 1, 1], f16), 0.02946278254943948), {})
+cnt: 4, ((T([128, 512, 24, 24], f16), T([128, 512, 1, 1], f16)), {})
+cnt: 4, ((T([128, 512, 24, 24], f16), 2.0), {})
+cnt: 4, ((T([128, 512, 24, 24], f16), 0.2), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), 0.9805806756909201), {})
+cnt: 2, ((T([256, 1, 1, 1], f16), 0.04419417382415922), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), 0.9622504486493761), {})
+cnt: 2, ((T([1536, 1, 1, 1], f16), 0.04419417382415922), {})
+cnt: 2, ((T([768, 1, 1, 1], f16), 0.04419417382415922), {})
+cnt: 36, ((T([768, 1, 1, 1], f16), 0.02946278254943948), {})
+cnt: 18, ((T([1536, 1, 1, 1], f16), 0.03608439182435161), {})
+cnt: 12, ((T([128, 1536, 12, 12], f16), T([128, 1536, 1, 1], f16)), {})
+cnt: 12, ((T([128, 1536, 12, 12], f16), 2.0), {})
+cnt: 12, ((T([128, 1536, 12, 12], f16), 0.2), {})
+cnt: 2, ((T([128, 1536, 12, 12], f16), 0.9805806756909201), {})
+cnt: 16, ((T([768, 1, 1, 1], f16), 0.02551551815399144), {})
+cnt: 2, ((T([128, 1536, 12, 12], f16), 0.9622504486493761), {})
+cnt: 2, ((T([128, 1536, 12, 12], f16), 0.9449111825230679), {})
+cnt: 2, ((T([128, 1536, 12, 12], f16), 0.9284766908852592), {})
+cnt: 2, ((T([128, 1536, 12, 12], f16), 0.9128709291752768), {})
+cnt: 2, ((T([128, 1536, 12, 12], f16), 0.8980265101338745), {})
+cnt: 2, ((T([1536, 1, 1, 1], f16), 0.02551551815399144), {})
+cnt: 6, ((T([128, 1536, 6, 6], f16), T([128, 1536, 1, 1], f16)), {})
+cnt: 6, ((T([128, 1536, 6, 6], f16), 2.0), {})
+cnt: 6, ((T([128, 1536, 6, 6], f16), 0.2), {})
+cnt: 2, ((T([128, 1536, 6, 6], f16), 0.9805806756909201), {})
+cnt: 2, ((T([128, 1536, 6, 6], f16), 0.9622504486493761), {})
+cnt: 2, ((T([3072, 1, 1, 1], f16), 0.02551551815399144), {})
+cnt: 1, ((T([128, 3072, 6, 6], f16), 1.7015043497085571), {})
+cnt: 6, ((T([128, 1536, 6, 6], f16), T([128, 1536, 6, 6], f16)), {})
+cnt: 3, ((T([128, 1536, 6, 6], f16), T([], f16)), {})
+cnt: 8, ((T([128, 768, 6, 6], f16), 1.7015043497085571), {})
+cnt: 2, ((T([128, 1536, 6, 6], f16), 1.7015043497085571), {})
+cnt: 18, ((T([128, 768, 12, 12], f16), 1.7015043497085571), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), 1.7015043497085571), {})
+cnt: 12, ((T([128, 1536, 12, 12], f16), T([128, 1536, 12, 12], f16)), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), T([], f16)), {})
+cnt: 1, ((T([128, 768, 24, 24], f16), 1.7015043497085571), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), 1.7015043497085571), {})
+cnt: 4, ((T([128, 512, 24, 24], f16), T([128, 512, 24, 24], f16)), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), T([], f16)), {})
+cnt: 5, ((T([128, 256, 24, 24], f16), 1.7015043497085571), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), 1.7015043497085571), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), T([128, 256, 48, 48], f16)), {})
+cnt: 1, ((T([128, 256, 48, 48], f16), T([], f16)), {})
+cnt: 4, ((T([128, 128, 48, 48], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 64, 96, 96], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 16, 96, 96], f16), 1.7015043497085571), {})
+Operator: aten.mul_.Tensor
+cnt: 1, ((T([128, 16, 96, 96], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 64, 96, 96], f16), 1.7015043497085571), {})
+cnt: 4, ((T([128, 128, 48, 48], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 256, 48, 48], f16), T([], f16)), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), 1.7015043497085571), {})
+cnt: 5, ((T([128, 256, 24, 24], f16), 1.7015043497085571), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), T([], f16)), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 768, 24, 24], f16), 1.7015043497085571), {})
+cnt: 18, ((T([128, 768, 12, 12], f16), 1.7015043497085571), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), T([], f16)), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), 1.7015043497085571), {})
+cnt: 8, ((T([128, 768, 6, 6], f16), 1.7015043497085571), {})
+cnt: 3, ((T([128, 1536, 6, 6], f16), T([], f16)), {})
+cnt: 2, ((T([128, 1536, 6, 6], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 3072, 6, 6], f16), 1.7015043497085571), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([1, 16, 27], f16), T([16], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 32, 144], f16), T([32], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 64, 288], f16), T([64], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 128, 576], f16), T([128], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 2, ((T([1, 256, 128], f16), T([256], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 128, 128], f16), T([128], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 2, ((T([1, 128, 1152], f16), T([128], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 3, ((T([1, 512, 256], f16), T([512], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 256, 256], f16), T([256], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 4, ((T([1, 256, 1152], f16), T([256], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 256, 512], f16), T([256], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 1536, 512], f16), T([1536], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 768, 512], f16), T([768], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 18, ((T([1, 768, 1152], f16), T([768], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 9, ((T([1, 1536, 768], f16), T([1536], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 8, ((T([1, 768, 1536], f16), T([768], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 1536, 1536], f16), T([1536], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 3072, 1536], f16), T([3072], f16), None, None, None, True, 0.0, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([1, 3072, 1536], f16), T([1, 3072, 1536], f16), T([3072], f16), None, None, T([3072], f32), T([3072], f32), True, 1e-05, [True, True, False]), {})
+cnt: 9, ((T([1, 1536, 768], f16), T([1, 1536, 768], f16), T([1536], f16), None, None, T([1536], f32), T([1536], f32), True, 1e-05, [True, True, False]), {})
+cnt: 18, ((T([1, 768, 1152], f16), T([1, 768, 1152], f16), T([768], f16), None, None, T([768], f32), T([768], f32), True, 1e-05, [True, True, False]), {})
+cnt: 8, ((T([1, 768, 1536], f16), T([1, 768, 1536], f16), T([768], f16), None, None, T([768], f32), T([768], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 1536, 1536], f16), T([1, 1536, 1536], f16), T([1536], f16), None, None, T([1536], f32), T([1536], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 768, 512], f16), T([1, 768, 512], f16), T([768], f16), None, None, T([768], f32), T([768], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 1536, 512], f16), T([1, 1536, 512], f16), T([1536], f16), None, None, T([1536], f32), T([1536], f32), True, 1e-05, [True, True, False]), {})
+cnt: 3, ((T([1, 512, 256], f16), T([1, 512, 256], f16), T([512], f16), None, None, T([512], f32), T([512], f32), True, 1e-05, [True, True, False]), {})
+cnt: 4, ((T([1, 256, 1152], f16), T([1, 256, 1152], f16), T([256], f16), None, None, T([256], f32), T([256], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 256, 512], f16), T([1, 256, 512], f16), T([256], f16), None, None, T([256], f32), T([256], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 256, 256], f16), T([1, 256, 256], f16), T([256], f16), None, None, T([256], f32), T([256], f32), True, 1e-05, [True, True, False]), {})
+cnt: 2, ((T([1, 256, 128], f16), T([1, 256, 128], f16), T([256], f16), None, None, T([256], f32), T([256], f32), True, 1e-05, [True, True, False]), {})
+cnt: 2, ((T([1, 128, 1152], f16), T([1, 128, 1152], f16), T([128], f16), None, None, T([128], f32), T([128], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 128, 128], f16), T([1, 128, 128], f16), T([128], f16), None, None, T([128], f32), T([128], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 128, 576], f16), T([1, 128, 576], f16), T([128], f16), None, None, T([128], f32), T([128], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 64, 288], f16), T([1, 64, 288], f16), T([64], f16), None, None, T([64], f32), T([64], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 32, 144], f16), T([1, 32, 144], f16), T([32], f16), None, None, T([32], f32), T([32], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 16, 27], f16), T([1, 16, 27], f16), T([16], f16), None, None, T([16], f32), T([16], f32), True, 1e-05, [True, True, False]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 128, 1, 1], f16),), {})
+cnt: 2, ((T([128, 256, 1, 1], f16),), {})
+cnt: 9, ((T([128, 768, 1, 1], f16),), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([128, 256, 1, 1], f16),), {})
+cnt: 2, ((T([128, 512, 1, 1], f16),), {})
+cnt: 9, ((T([128, 1536, 1, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 9, ((T([128, 1536, 1, 1], f16), T([128, 1536, 1, 1], f16)), {})
+cnt: 2, ((T([128, 512, 1, 1], f16), T([128, 512, 1, 1], f16)), {})
+cnt: 1, ((T([128, 256, 1, 1], f16), T([128, 256, 1, 1], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 3, ((T([128, 1536, 6, 6], f16), [2, 3], True), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 256, 48, 48], f16), [2, 3], True), {})
+Operator: aten.sum.default
+cnt: 3, ((T([128, 1536, 6, 6], f16),), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16),), {})
+cnt: 2, ((T([128, 512, 24, 24], f16),), {})
+cnt: 1, ((T([128, 256, 48, 48], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 9, ((T([128, 768, 1, 1], f16), T([128, 768, 1, 1], f16), 0), {})
+cnt: 2, ((T([128, 256, 1, 1], f16), T([128, 256, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 128, 1, 1], f16), T([128, 128, 1, 1], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/dpn107_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/dpn107_training.txt
new file mode 100644
index 0000000000000..d1572e4cd2ce0
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/dpn107_training.txt
@@ -0,0 +1,545 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([32, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([32, 1000], f16), T([32, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 111, ((T([], i64), 1), {})
+cnt: 1, ((T([32, 256, 56, 56], f16, stride=(928256, 3136, 56, 1)), T([32, 256, 56, 56], f16, stride=(865536, 3136, 56, 1))), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16, stride=(865536, 3136, 56, 1))), {})
+cnt: 1, ((T([32, 512, 28, 28], f16, stride=(501760, 784, 28, 1)), T([32, 512, 28, 28], f16, stride=(451584, 784, 28, 1))), {})
+cnt: 7, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16, stride=(451584, 784, 28, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16, stride=(225792, 196, 14, 1)), T([32, 1024, 14, 14], f16, stride=(213248, 196, 14, 1))), {})
+cnt: 19, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(213248, 196, 14, 1))), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16, stride=(112896, 49, 7, 1)), T([32, 2048, 7, 7], f16, stride=(106624, 49, 7, 1))), {})
+cnt: 2, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16, stride=(106624, 49, 7, 1))), {})
+cnt: 3, ((T([32, 2176, 7, 7], f16), T([32, 2176, 7, 7], f16)), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16, stride=(131712, 49, 7, 1)), T([32, 2048, 7, 7], f16, stride=(125440, 49, 7, 1))), {})
+cnt: 1, ((T([32, 512, 7, 7], f16, stride=(131712, 49, 7, 1)), T([32, 512, 7, 7], f16, stride=(125440, 49, 7, 1))), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16, stride=(119168, 49, 7, 1))), {})
+cnt: 1, ((T([32, 384, 7, 7], f16, stride=(25088, 49, 7, 1)), T([32, 384, 7, 7], f16, stride=(119168, 49, 7, 1))), {})
+cnt: 1, ((T([32, 2304, 7, 7], f16), T([32, 2304, 7, 7], f16)), {})
+cnt: 1, ((T([32, 2432, 14, 14], f16), T([32, 2432, 14, 14], f16)), {})
+cnt: 20, ((T([32, 1088, 14, 14], f16), T([32, 1088, 14, 14], f16)), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16, stride=(476672, 196, 14, 1)), T([32, 1024, 14, 14], f16, stride=(464128, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1344, 14, 14], f16, stride=(476672, 196, 14, 1)), T([32, 1344, 14, 14], f16, stride=(464128, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(451584, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1280, 14, 14], f16, stride=(263424, 196, 14, 1)), T([32, 1280, 14, 14], f16, stride=(451584, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(439040, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1216, 14, 14], f16, stride=(250880, 196, 14, 1)), T([32, 1216, 14, 14], f16, stride=(439040, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(426496, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1152, 14, 14], f16, stride=(238336, 196, 14, 1)), T([32, 1152, 14, 14], f16, stride=(426496, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(413952, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1088, 14, 14], f16, stride=(225792, 196, 14, 1)), T([32, 1088, 14, 14], f16, stride=(413952, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(401408, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16, stride=(213248, 196, 14, 1)), T([32, 1024, 14, 14], f16, stride=(401408, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(388864, 196, 14, 1))), {})
+cnt: 1, ((T([32, 960, 14, 14], f16, stride=(200704, 196, 14, 1)), T([32, 960, 14, 14], f16, stride=(388864, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(376320, 196, 14, 1))), {})
+cnt: 1, ((T([32, 896, 14, 14], f16, stride=(188160, 196, 14, 1)), T([32, 896, 14, 14], f16, stride=(376320, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(363776, 196, 14, 1))), {})
+cnt: 1, ((T([32, 832, 14, 14], f16, stride=(175616, 196, 14, 1)), T([32, 832, 14, 14], f16, stride=(363776, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(351232, 196, 14, 1))), {})
+cnt: 1, ((T([32, 768, 14, 14], f16, stride=(163072, 196, 14, 1)), T([32, 768, 14, 14], f16, stride=(351232, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(338688, 196, 14, 1))), {})
+cnt: 1, ((T([32, 704, 14, 14], f16, stride=(150528, 196, 14, 1)), T([32, 704, 14, 14], f16, stride=(338688, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(326144, 196, 14, 1))), {})
+cnt: 1, ((T([32, 640, 14, 14], f16, stride=(137984, 196, 14, 1)), T([32, 640, 14, 14], f16, stride=(326144, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(313600, 196, 14, 1))), {})
+cnt: 1, ((T([32, 576, 14, 14], f16, stride=(125440, 196, 14, 1)), T([32, 576, 14, 14], f16, stride=(313600, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(301056, 196, 14, 1))), {})
+cnt: 1, ((T([32, 512, 14, 14], f16, stride=(112896, 196, 14, 1)), T([32, 512, 14, 14], f16, stride=(301056, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(288512, 196, 14, 1))), {})
+cnt: 1, ((T([32, 448, 14, 14], f16, stride=(100352, 196, 14, 1)), T([32, 448, 14, 14], f16, stride=(288512, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(275968, 196, 14, 1))), {})
+cnt: 1, ((T([32, 384, 14, 14], f16, stride=(87808, 196, 14, 1)), T([32, 384, 14, 14], f16, stride=(275968, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(263424, 196, 14, 1))), {})
+cnt: 1, ((T([32, 320, 14, 14], f16, stride=(75264, 196, 14, 1)), T([32, 320, 14, 14], f16, stride=(263424, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(250880, 196, 14, 1))), {})
+cnt: 1, ((T([32, 256, 14, 14], f16, stride=(62720, 196, 14, 1)), T([32, 256, 14, 14], f16, stride=(250880, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16, stride=(238336, 196, 14, 1))), {})
+cnt: 1, ((T([32, 192, 14, 14], f16, stride=(50176, 196, 14, 1)), T([32, 192, 14, 14], f16, stride=(238336, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1152, 14, 14], f16), T([32, 1152, 14, 14], f16)), {})
+cnt: 1, ((T([32, 1152, 28, 28], f16), T([32, 1152, 28, 28], f16)), {})
+cnt: 8, ((T([32, 576, 28, 28], f16), T([32, 576, 28, 28], f16)), {})
+cnt: 1, ((T([32, 512, 28, 28], f16, stride=(903168, 784, 28, 1)), T([32, 512, 28, 28], f16, stride=(852992, 784, 28, 1))), {})
+cnt: 1, ((T([32, 576, 28, 28], f16, stride=(903168, 784, 28, 1)), T([32, 576, 28, 28], f16, stride=(852992, 784, 28, 1))), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16, stride=(802816, 784, 28, 1))), {})
+cnt: 1, ((T([32, 512, 28, 28], f16, stride=(451584, 784, 28, 1)), T([32, 512, 28, 28], f16, stride=(802816, 784, 28, 1))), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16, stride=(752640, 784, 28, 1))), {})
+cnt: 1, ((T([32, 448, 28, 28], f16, stride=(401408, 784, 28, 1)), T([32, 448, 28, 28], f16, stride=(752640, 784, 28, 1))), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16, stride=(702464, 784, 28, 1))), {})
+cnt: 1, ((T([32, 384, 28, 28], f16, stride=(351232, 784, 28, 1)), T([32, 384, 28, 28], f16, stride=(702464, 784, 28, 1))), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16, stride=(652288, 784, 28, 1))), {})
+cnt: 1, ((T([32, 320, 28, 28], f16, stride=(301056, 784, 28, 1)), T([32, 320, 28, 28], f16, stride=(652288, 784, 28, 1))), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16, stride=(602112, 784, 28, 1))), {})
+cnt: 1, ((T([32, 256, 28, 28], f16, stride=(250880, 784, 28, 1)), T([32, 256, 28, 28], f16, stride=(602112, 784, 28, 1))), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16, stride=(551936, 784, 28, 1))), {})
+cnt: 1, ((T([32, 192, 28, 28], f16, stride=(200704, 784, 28, 1)), T([32, 192, 28, 28], f16, stride=(551936, 784, 28, 1))), {})
+cnt: 1, ((T([32, 640, 28, 28], f16), T([32, 640, 28, 28], f16)), {})
+cnt: 1, ((T([32, 376, 56, 56], f16), T([32, 376, 56, 56], f16)), {})
+cnt: 4, ((T([32, 276, 56, 56], f16), T([32, 276, 56, 56], f16)), {})
+cnt: 1, ((T([32, 256, 56, 56], f16, stride=(1179136, 3136, 56, 1)), T([32, 256, 56, 56], f16, stride=(1116416, 3136, 56, 1))), {})
+cnt: 1, ((T([32, 100, 56, 56], f16, stride=(1179136, 3136, 56, 1)), T([32, 100, 56, 56], f16, stride=(1116416, 3136, 56, 1))), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16, stride=(1053696, 3136, 56, 1))), {})
+cnt: 1, ((T([32, 80, 56, 56], f16, stride=(313600, 3136, 56, 1)), T([32, 80, 56, 56], f16, stride=(1053696, 3136, 56, 1))), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16, stride=(990976, 3136, 56, 1))), {})
+cnt: 1, ((T([32, 60, 56, 56], f16, stride=(250880, 3136, 56, 1)), T([32, 60, 56, 56], f16, stride=(990976, 3136, 56, 1))), {})
+cnt: 1, ((T([32, 296, 56, 56], f16), T([32, 296, 56, 56], f16)), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16)), {})
+Operator: aten.cat.default
+cnt: 1, (([T([32, 40, 56, 56], f16, stride=(928256, 3136, 56, 1)), T([32, 20, 56, 56], f16, stride=(865536, 3136, 56, 1))], 1), {})
+cnt: 1, (([T([32, 256, 56, 56], f16), T([32, 60, 56, 56], f16)], 1), {})
+cnt: 1, (([T([32, 60, 56, 56], f16), T([32, 20, 56, 56], f16, stride=(865536, 3136, 56, 1))], 1), {})
+cnt: 1, (([T([32, 256, 56, 56], f16), T([32, 80, 56, 56], f16)], 1), {})
+cnt: 1, (([T([32, 80, 56, 56], f16), T([32, 20, 56, 56], f16, stride=(865536, 3136, 56, 1))], 1), {})
+cnt: 1, (([T([32, 256, 56, 56], f16), T([32, 100, 56, 56], f16)], 1), {})
+cnt: 1, (([T([32, 100, 56, 56], f16), T([32, 20, 56, 56], f16, stride=(865536, 3136, 56, 1))], 1), {})
+cnt: 1, (([T([32, 256, 56, 56], f16), T([32, 120, 56, 56], f16)], 1), {})
+cnt: 1, (([T([32, 128, 28, 28], f16, stride=(501760, 784, 28, 1)), T([32, 64, 28, 28], f16, stride=(451584, 784, 28, 1))], 1), {})
+cnt: 1, (([T([32, 512, 28, 28], f16), T([32, 192, 28, 28], f16)], 1), {})
+cnt: 1, (([T([32, 192, 28, 28], f16), T([32, 64, 28, 28], f16, stride=(451584, 784, 28, 1))], 1), {})
+cnt: 1, (([T([32, 512, 28, 28], f16), T([32, 256, 28, 28], f16)], 1), {})
+cnt: 1, (([T([32, 256, 28, 28], f16), T([32, 64, 28, 28], f16, stride=(451584, 784, 28, 1))], 1), {})
+cnt: 1, (([T([32, 512, 28, 28], f16), T([32, 320, 28, 28], f16)], 1), {})
+cnt: 1, (([T([32, 320, 28, 28], f16), T([32, 64, 28, 28], f16, stride=(451584, 784, 28, 1))], 1), {})
+cnt: 1, (([T([32, 512, 28, 28], f16), T([32, 384, 28, 28], f16)], 1), {})
+cnt: 1, (([T([32, 384, 28, 28], f16), T([32, 64, 28, 28], f16, stride=(451584, 784, 28, 1))], 1), {})
+cnt: 1, (([T([32, 512, 28, 28], f16), T([32, 448, 28, 28], f16)], 1), {})
+cnt: 1, (([T([32, 448, 28, 28], f16), T([32, 64, 28, 28], f16, stride=(451584, 784, 28, 1))], 1), {})
+cnt: 1, (([T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16)], 1), {})
+cnt: 1, (([T([32, 512, 28, 28], f16), T([32, 64, 28, 28], f16, stride=(451584, 784, 28, 1))], 1), {})
+cnt: 1, (([T([32, 512, 28, 28], f16), T([32, 576, 28, 28], f16)], 1), {})
+cnt: 1, (([T([32, 576, 28, 28], f16), T([32, 64, 28, 28], f16, stride=(451584, 784, 28, 1))], 1), {})
+cnt: 1, (([T([32, 512, 28, 28], f16), T([32, 640, 28, 28], f16)], 1), {})
+cnt: 1, (([T([32, 128, 14, 14], f16, stride=(225792, 196, 14, 1)), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 192, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 192, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 256, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 256, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 320, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 320, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 384, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 384, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 448, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 448, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 512, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 512, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 576, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 576, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 640, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 640, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 704, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 704, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 768, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 768, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 832, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 832, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 896, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 896, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 960, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 960, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 1088, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 1088, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 1152, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 1152, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 1216, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 1216, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 1280, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 1280, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 1344, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 1344, 14, 14], f16), T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1))], 1), {})
+cnt: 1, (([T([32, 1024, 14, 14], f16), T([32, 1408, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 256, 7, 7], f16, stride=(112896, 49, 7, 1)), T([32, 128, 7, 7], f16, stride=(106624, 49, 7, 1))], 1), {})
+cnt: 1, (([T([32, 2048, 7, 7], f16), T([32, 384, 7, 7], f16)], 1), {})
+cnt: 1, (([T([32, 384, 7, 7], f16), T([32, 128, 7, 7], f16, stride=(106624, 49, 7, 1))], 1), {})
+cnt: 1, (([T([32, 2048, 7, 7], f16), T([32, 512, 7, 7], f16)], 1), {})
+cnt: 1, (([T([32, 512, 7, 7], f16), T([32, 128, 7, 7], f16, stride=(106624, 49, 7, 1))], 1), {})
+cnt: 1, (([T([32, 2048, 7, 7], f16), T([32, 640, 7, 7], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([128, 3, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([296, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([200, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([32, 200, 56, 56], f16), T([200, 4, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 50), {})
+cnt: 4, ((T([32, 200, 56, 56], f16), T([276, 200, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 316, 56, 56], f16), T([200, 316, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 336, 56, 56], f16), T([200, 336, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 356, 56, 56], f16), T([200, 356, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 376, 56, 56], f16), T([640, 376, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 376, 56, 56], f16), T([400, 376, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 400, 56, 56], f16), T([400, 8, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 50), {})
+cnt: 8, ((T([32, 400, 28, 28], f16), T([576, 400, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 704, 28, 28], f16), T([400, 704, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([32, 400, 28, 28], f16), T([400, 8, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 50), {})
+cnt: 1, ((T([32, 768, 28, 28], f16), T([400, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 832, 28, 28], f16), T([400, 832, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 896, 28, 28], f16), T([400, 896, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 960, 28, 28], f16), T([400, 960, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1024, 28, 28], f16), T([400, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1088, 28, 28], f16), T([400, 1088, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1152, 28, 28], f16), T([1152, 1152, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1152, 28, 28], f16), T([800, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 800, 28, 28], f16), T([800, 16, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 50), {})
+cnt: 20, ((T([32, 800, 14, 14], f16), T([1088, 800, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1216, 14, 14], f16), T([800, 1216, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 19, ((T([32, 800, 14, 14], f16), T([800, 16, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 50), {})
+cnt: 1, ((T([32, 1280, 14, 14], f16), T([800, 1280, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1344, 14, 14], f16), T([800, 1344, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1408, 14, 14], f16), T([800, 1408, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1472, 14, 14], f16), T([800, 1472, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1536, 14, 14], f16), T([800, 1536, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1600, 14, 14], f16), T([800, 1600, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1664, 14, 14], f16), T([800, 1664, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1728, 14, 14], f16), T([800, 1728, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1792, 14, 14], f16), T([800, 1792, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1856, 14, 14], f16), T([800, 1856, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1920, 14, 14], f16), T([800, 1920, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1984, 14, 14], f16), T([800, 1984, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 2048, 14, 14], f16), T([800, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 2112, 14, 14], f16), T([800, 2112, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 2176, 14, 14], f16), T([800, 2176, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 2240, 14, 14], f16), T([800, 2240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 2304, 14, 14], f16), T([800, 2304, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 2368, 14, 14], f16), T([800, 2368, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 2432, 14, 14], f16), T([2304, 2432, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 2432, 14, 14], f16), T([1600, 2432, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1600, 14, 14], f16), T([1600, 32, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 50), {})
+cnt: 3, ((T([32, 1600, 7, 7], f16), T([2176, 1600, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 2432, 7, 7], f16), T([1600, 2432, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 1600, 7, 7], f16), T([1600, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 50), {})
+cnt: 1, ((T([32, 2560, 7, 7], f16), T([1600, 2560, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 2688, 1, 1], f16), T([1000, 2688, 1, 1], f16), T([1000], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([32, 1000, 1, 1], f16), T([32, 2688, 1, 1], f16), T([1000, 2688, 1, 1], f16), [1000], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 2176, 7, 7], f16), T([32, 1600, 7, 7], f16), T([2176, 1600, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 1600, 7, 7], f16), T([32, 1600, 7, 7], f16), T([1600, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 50, [True, True, False]), {})
+cnt: 1, ((T([32, 1600, 7, 7], f16), T([32, 2560, 7, 7], f16), T([1600, 2560, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1600, 7, 7], f16), T([32, 2432, 7, 7], f16), T([1600, 2432, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1600, 7, 7], f16), T([32, 1600, 14, 14], f16), T([1600, 32, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 50, [True, True, False]), {})
+cnt: 1, ((T([32, 1600, 14, 14], f16), T([32, 2432, 14, 14], f16), T([1600, 2432, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 2304, 7, 7], f16), T([32, 2432, 14, 14], f16), T([2304, 2432, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 20, ((T([32, 1088, 14, 14], f16), T([32, 800, 14, 14], f16), T([1088, 800, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 19, ((T([32, 800, 14, 14], f16), T([32, 800, 14, 14], f16), T([800, 16, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 50, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 2368, 14, 14], f16), T([800, 2368, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 2304, 14, 14], f16), T([800, 2304, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 2240, 14, 14], f16), T([800, 2240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 2176, 14, 14], f16), T([800, 2176, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 2112, 14, 14], f16), T([800, 2112, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 2048, 14, 14], f16), T([800, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 1984, 14, 14], f16), T([800, 1984, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 1920, 14, 14], f16), T([800, 1920, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 1856, 14, 14], f16), T([800, 1856, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 1792, 14, 14], f16), T([800, 1792, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 1728, 14, 14], f16), T([800, 1728, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 1664, 14, 14], f16), T([800, 1664, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 1600, 14, 14], f16), T([800, 1600, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 1536, 14, 14], f16), T([800, 1536, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 1472, 14, 14], f16), T([800, 1472, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 1408, 14, 14], f16), T([800, 1408, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 1344, 14, 14], f16), T([800, 1344, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 1280, 14, 14], f16), T([800, 1280, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 1216, 14, 14], f16), T([800, 1216, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 14, 14], f16), T([32, 800, 28, 28], f16), T([800, 16, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 50, [True, True, False]), {})
+cnt: 1, ((T([32, 800, 28, 28], f16), T([32, 1152, 28, 28], f16), T([800, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1152, 14, 14], f16), T([32, 1152, 28, 28], f16), T([1152, 1152, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 8, ((T([32, 576, 28, 28], f16), T([32, 400, 28, 28], f16), T([576, 400, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 7, ((T([32, 400, 28, 28], f16), T([32, 400, 28, 28], f16), T([400, 8, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 50, [True, True, False]), {})
+cnt: 1, ((T([32, 400, 28, 28], f16), T([32, 1088, 28, 28], f16), T([400, 1088, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 400, 28, 28], f16), T([32, 1024, 28, 28], f16), T([400, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 400, 28, 28], f16), T([32, 960, 28, 28], f16), T([400, 960, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 400, 28, 28], f16), T([32, 896, 28, 28], f16), T([400, 896, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 400, 28, 28], f16), T([32, 832, 28, 28], f16), T([400, 832, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 400, 28, 28], f16), T([32, 768, 28, 28], f16), T([400, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 400, 28, 28], f16), T([32, 704, 28, 28], f16), T([400, 704, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 400, 28, 28], f16), T([32, 400, 56, 56], f16), T([400, 8, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 50, [True, True, False]), {})
+cnt: 1, ((T([32, 400, 56, 56], f16), T([32, 376, 56, 56], f16), T([400, 376, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 640, 28, 28], f16), T([32, 376, 56, 56], f16), T([640, 376, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([32, 276, 56, 56], f16), T([32, 200, 56, 56], f16), T([276, 200, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([32, 200, 56, 56], f16), T([32, 200, 56, 56], f16), T([200, 4, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 50, [True, True, False]), {})
+cnt: 1, ((T([32, 200, 56, 56], f16), T([32, 356, 56, 56], f16), T([200, 356, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 200, 56, 56], f16), T([32, 336, 56, 56], f16), T([200, 336, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 200, 56, 56], f16), T([32, 316, 56, 56], f16), T([200, 316, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 200, 56, 56], f16), T([32, 128, 56, 56], f16), T([200, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 296, 56, 56], f16), T([32, 128, 56, 56], f16), T([296, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 112, 112], f16), T([32, 3, 224, 224], f16), T([128, 3, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([32, 2688, 7, 7], f16, stride=(2688, 1, 0, 0)), 49), {})
+Operator: aten.elu.default
+cnt: 1, ((T([32, 2688, 7, 7], f16), 1.0), {})
+Operator: aten.elu_backward.default
+cnt: 1, ((T([32, 2688, 7, 7], f16), 1.0, 1, 1, False, T([32, 2688, 7, 7], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([32], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([32, 128, 112, 112], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([32, 128, 56, 56], f16), T([32, 128, 112, 112], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([32, 128, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([32, 2688, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([32, 128, 112, 112], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([32, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 0.001), {})
+cnt: 8, ((T([32, 200, 56, 56], f16), T([200], f16), T([200], f16), T([200], f16), T([200], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 316, 56, 56], f16), T([316], f16), T([316], f16), T([316], f16), T([316], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 336, 56, 56], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 356, 56, 56], f16), T([356], f16), T([356], f16), T([356], f16), T([356], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([32, 376, 56, 56], f16), T([376], f16), T([376], f16), T([376], f16), T([376], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 400, 56, 56], f16), T([400], f16), T([400], f16), T([400], f16), T([400], f16), True, 0.1, 0.001), {})
+cnt: 15, ((T([32, 400, 28, 28], f16), T([400], f16), T([400], f16), T([400], f16), T([400], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 704, 28, 28], f16), T([704], f16), T([704], f16), T([704], f16), T([704], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 768, 28, 28], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 832, 28, 28], f16), T([832], f16), T([832], f16), T([832], f16), T([832], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 896, 28, 28], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 960, 28, 28], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 1024, 28, 28], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 1088, 28, 28], f16), T([1088], f16), T([1088], f16), T([1088], f16), T([1088], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([32, 1152, 28, 28], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 800, 28, 28], f16), T([800], f16), T([800], f16), T([800], f16), T([800], f16), True, 0.1, 0.001), {})
+cnt: 39, ((T([32, 800, 14, 14], f16), T([800], f16), T([800], f16), T([800], f16), T([800], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 1216, 14, 14], f16), T([1216], f16), T([1216], f16), T([1216], f16), T([1216], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 1280, 14, 14], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 1344, 14, 14], f16), T([1344], f16), T([1344], f16), T([1344], f16), T([1344], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 1408, 14, 14], f16), T([1408], f16), T([1408], f16), T([1408], f16), T([1408], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 1472, 14, 14], f16), T([1472], f16), T([1472], f16), T([1472], f16), T([1472], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 1536, 14, 14], f16), T([1536], f16), T([1536], f16), T([1536], f16), T([1536], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([32, 1600, 14, 14], f16), T([1600], f16), T([1600], f16), T([1600], f16), T([1600], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 1664, 14, 14], f16), T([1664], f16), T([1664], f16), T([1664], f16), T([1664], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 1728, 14, 14], f16), T([1728], f16), T([1728], f16), T([1728], f16), T([1728], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 1792, 14, 14], f16), T([1792], f16), T([1792], f16), T([1792], f16), T([1792], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 1856, 14, 14], f16), T([1856], f16), T([1856], f16), T([1856], f16), T([1856], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 1920, 14, 14], f16), T([1920], f16), T([1920], f16), T([1920], f16), T([1920], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 1984, 14, 14], f16), T([1984], f16), T([1984], f16), T([1984], f16), T([1984], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 2048, 14, 14], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 2112, 14, 14], f16), T([2112], f16), T([2112], f16), T([2112], f16), T([2112], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 2176, 14, 14], f16), T([2176], f16), T([2176], f16), T([2176], f16), T([2176], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 2240, 14, 14], f16), T([2240], f16), T([2240], f16), T([2240], f16), T([2240], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 2304, 14, 14], f16), T([2304], f16), T([2304], f16), T([2304], f16), T([2304], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 2368, 14, 14], f16), T([2368], f16), T([2368], f16), T([2368], f16), T([2368], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([32, 2432, 14, 14], f16), T([2432], f16), T([2432], f16), T([2432], f16), T([2432], f16), True, 0.1, 0.001), {})
+cnt: 5, ((T([32, 1600, 7, 7], f16), T([1600], f16), T([1600], f16), T([1600], f16), T([1600], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 2432, 7, 7], f16), T([2432], f16), T([2432], f16), T([2432], f16), T([2432], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 2560, 7, 7], f16), T([2560], f16), T([2560], f16), T([2560], f16), T([2560], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([32, 2688, 7, 7], f16), T([2688], f16), T([2688], f16), T([2688], f16), T([2688], f16), True, 0.1, 0.001), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([32, 2688, 7, 7], f16), T([32, 2688, 7, 7], f16), T([2688], f16), T([2688], f16), T([2688], f16), T([2688], f32), T([2688], f32), True, 0.001, [True, True, True]), {})
+cnt: 5, ((T([32, 1600, 7, 7], f16), T([32, 1600, 7, 7], f16), T([1600], f16), T([1600], f16), T([1600], f16), T([1600], f32), T([1600], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 2560, 7, 7], f16), T([32, 2560, 7, 7], f16), T([2560], f16), T([2560], f16), T([2560], f16), T([2560], f32), T([2560], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 2432, 7, 7], f16), T([32, 2432, 7, 7], f16), T([2432], f16), T([2432], f16), T([2432], f16), T([2432], f32), T([2432], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([32, 1600, 14, 14], f16), T([32, 1600, 14, 14], f16), T([1600], f16), T([1600], f16), T([1600], f16), T([1600], f32), T([1600], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([32, 2432, 14, 14], f16), T([32, 2432, 14, 14], f16), T([2432], f16), T([2432], f16), T([2432], f16), T([2432], f32), T([2432], f32), True, 0.001, [True, True, True]), {})
+cnt: 39, ((T([32, 800, 14, 14], f16), T([32, 800, 14, 14], f16), T([800], f16), T([800], f16), T([800], f16), T([800], f32), T([800], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 2368, 14, 14], f16), T([32, 2368, 14, 14], f16), T([2368], f16), T([2368], f16), T([2368], f16), T([2368], f32), T([2368], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 2304, 14, 14], f16), T([32, 2304, 14, 14], f16), T([2304], f16), T([2304], f16), T([2304], f16), T([2304], f32), T([2304], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 2240, 14, 14], f16), T([32, 2240, 14, 14], f16), T([2240], f16), T([2240], f16), T([2240], f16), T([2240], f32), T([2240], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 2176, 14, 14], f16), T([32, 2176, 14, 14], f16), T([2176], f16), T([2176], f16), T([2176], f16), T([2176], f32), T([2176], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 2112, 14, 14], f16), T([32, 2112, 14, 14], f16), T([2112], f16), T([2112], f16), T([2112], f16), T([2112], f32), T([2112], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 2048, 14, 14], f16), T([32, 2048, 14, 14], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 1984, 14, 14], f16), T([32, 1984, 14, 14], f16), T([1984], f16), T([1984], f16), T([1984], f16), T([1984], f32), T([1984], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 1920, 14, 14], f16), T([32, 1920, 14, 14], f16), T([1920], f16), T([1920], f16), T([1920], f16), T([1920], f32), T([1920], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 1856, 14, 14], f16), T([32, 1856, 14, 14], f16), T([1856], f16), T([1856], f16), T([1856], f16), T([1856], f32), T([1856], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 1792, 14, 14], f16), T([32, 1792, 14, 14], f16), T([1792], f16), T([1792], f16), T([1792], f16), T([1792], f32), T([1792], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 1728, 14, 14], f16), T([32, 1728, 14, 14], f16), T([1728], f16), T([1728], f16), T([1728], f16), T([1728], f32), T([1728], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 1664, 14, 14], f16), T([32, 1664, 14, 14], f16), T([1664], f16), T([1664], f16), T([1664], f16), T([1664], f32), T([1664], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 1536, 14, 14], f16), T([32, 1536, 14, 14], f16), T([1536], f16), T([1536], f16), T([1536], f16), T([1536], f32), T([1536], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 1472, 14, 14], f16), T([32, 1472, 14, 14], f16), T([1472], f16), T([1472], f16), T([1472], f16), T([1472], f32), T([1472], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 1408, 14, 14], f16), T([32, 1408, 14, 14], f16), T([1408], f16), T([1408], f16), T([1408], f16), T([1408], f32), T([1408], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 1344, 14, 14], f16), T([32, 1344, 14, 14], f16), T([1344], f16), T([1344], f16), T([1344], f16), T([1344], f32), T([1344], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 1280, 14, 14], f16), T([32, 1280, 14, 14], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f32), T([1280], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 1216, 14, 14], f16), T([32, 1216, 14, 14], f16), T([1216], f16), T([1216], f16), T([1216], f16), T([1216], f32), T([1216], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 800, 28, 28], f16), T([32, 800, 28, 28], f16), T([800], f16), T([800], f16), T([800], f16), T([800], f32), T([800], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([32, 1152, 28, 28], f16), T([32, 1152, 28, 28], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f32), T([1152], f32), True, 0.001, [True, True, True]), {})
+cnt: 15, ((T([32, 400, 28, 28], f16), T([32, 400, 28, 28], f16), T([400], f16), T([400], f16), T([400], f16), T([400], f32), T([400], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 1088, 28, 28], f16), T([32, 1088, 28, 28], f16), T([1088], f16), T([1088], f16), T([1088], f16), T([1088], f32), T([1088], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 1024, 28, 28], f16), T([32, 1024, 28, 28], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 960, 28, 28], f16), T([32, 960, 28, 28], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 896, 28, 28], f16), T([32, 896, 28, 28], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f32), T([896], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 832, 28, 28], f16), T([32, 832, 28, 28], f16), T([832], f16), T([832], f16), T([832], f16), T([832], f32), T([832], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 768, 28, 28], f16), T([32, 768, 28, 28], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f32), T([768], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 704, 28, 28], f16), T([32, 704, 28, 28], f16), T([704], f16), T([704], f16), T([704], f16), T([704], f32), T([704], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 400, 56, 56], f16), T([32, 400, 56, 56], f16), T([400], f16), T([400], f16), T([400], f16), T([400], f32), T([400], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([32, 376, 56, 56], f16), T([32, 376, 56, 56], f16), T([376], f16), T([376], f16), T([376], f16), T([376], f32), T([376], f32), True, 0.001, [True, True, True]), {})
+cnt: 8, ((T([32, 200, 56, 56], f16), T([32, 200, 56, 56], f16), T([200], f16), T([200], f16), T([200], f16), T([200], f32), T([200], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 356, 56, 56], f16), T([32, 356, 56, 56], f16), T([356], f16), T([356], f16), T([356], f16), T([356], f32), T([356], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 336, 56, 56], f16), T([32, 336, 56, 56], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f32), T([336], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 316, 56, 56], f16), T([32, 316, 56, 56], f16), T([316], f16), T([316], f16), T([316], f16), T([316], f32), T([316], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 128, 112, 112], f16), T([32, 128, 112, 112], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 0.001, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([32, 1000], f16), T([32], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([32, 1000], f16), T([32], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([32, 128, 112, 112], f16),), {})
+cnt: 2, ((T([32, 128, 56, 56], f16),), {})
+cnt: 8, ((T([32, 200, 56, 56], f16),), {})
+cnt: 1, ((T([32, 316, 56, 56], f16),), {})
+cnt: 1, ((T([32, 336, 56, 56], f16),), {})
+cnt: 1, ((T([32, 356, 56, 56], f16),), {})
+cnt: 2, ((T([32, 376, 56, 56], f16),), {})
+cnt: 1, ((T([32, 400, 56, 56], f16),), {})
+cnt: 15, ((T([32, 400, 28, 28], f16),), {})
+cnt: 1, ((T([32, 704, 28, 28], f16),), {})
+cnt: 1, ((T([32, 768, 28, 28], f16),), {})
+cnt: 1, ((T([32, 832, 28, 28], f16),), {})
+cnt: 1, ((T([32, 896, 28, 28], f16),), {})
+cnt: 1, ((T([32, 960, 28, 28], f16),), {})
+cnt: 1, ((T([32, 1024, 28, 28], f16),), {})
+cnt: 1, ((T([32, 1088, 28, 28], f16),), {})
+cnt: 2, ((T([32, 1152, 28, 28], f16),), {})
+cnt: 1, ((T([32, 800, 28, 28], f16),), {})
+cnt: 39, ((T([32, 800, 14, 14], f16),), {})
+cnt: 1, ((T([32, 1216, 14, 14], f16),), {})
+cnt: 1, ((T([32, 1280, 14, 14], f16),), {})
+cnt: 1, ((T([32, 1344, 14, 14], f16),), {})
+cnt: 1, ((T([32, 1408, 14, 14], f16),), {})
+cnt: 1, ((T([32, 1472, 14, 14], f16),), {})
+cnt: 1, ((T([32, 1536, 14, 14], f16),), {})
+cnt: 2, ((T([32, 1600, 14, 14], f16),), {})
+cnt: 1, ((T([32, 1664, 14, 14], f16),), {})
+cnt: 1, ((T([32, 1728, 14, 14], f16),), {})
+cnt: 1, ((T([32, 1792, 14, 14], f16),), {})
+cnt: 1, ((T([32, 1856, 14, 14], f16),), {})
+cnt: 1, ((T([32, 1920, 14, 14], f16),), {})
+cnt: 1, ((T([32, 1984, 14, 14], f16),), {})
+cnt: 1, ((T([32, 2048, 14, 14], f16),), {})
+cnt: 1, ((T([32, 2112, 14, 14], f16),), {})
+cnt: 1, ((T([32, 2176, 14, 14], f16),), {})
+cnt: 1, ((T([32, 2240, 14, 14], f16),), {})
+cnt: 1, ((T([32, 2304, 14, 14], f16),), {})
+cnt: 1, ((T([32, 2368, 14, 14], f16),), {})
+cnt: 2, ((T([32, 2432, 14, 14], f16),), {})
+cnt: 5, ((T([32, 1600, 7, 7], f16),), {})
+cnt: 1, ((T([32, 2432, 7, 7], f16),), {})
+cnt: 1, ((T([32, 2560, 7, 7], f16),), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([32, 128, 7, 7], f16, stride=(131712, 49, 7, 1)), [32, 128, 7, 7], 3, 0, 9223372036854775807, 1), {})
+cnt: 3, ((T([32, 128, 7, 7], f16), [32, 128, 7, 7], 2, 0, 9223372036854775807, 1), {})
+cnt: 3, ((T([32, 128, 7, 7], f16), [32, 2176, 7, 7], 1, 2048, 9223372036854775807, 1), {})
+cnt: 6, ((T([32, 2176, 7, 7], f16), [32, 2176, 7, 7], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16, stride=(131712, 49, 7, 1)), [32, 2048, 7, 7], 3, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([32, 2048, 7, 7], f16), [32, 2048, 7, 7], 2, 0, 9223372036854775807, 1), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16), [32, 2176, 7, 7], 1, 0, 2048, 1), {})
+cnt: 1, ((T([32, 128, 7, 7], f16, stride=(25088, 49, 7, 1)), [32, 128, 7, 7], 3, 0, 9223372036854775807, 1), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16), [32, 2048, 7, 7], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 128, 7, 7], f16, stride=(18816, 49, 7, 1)), [32, 128, 7, 7], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 256, 7, 7], f16, stride=(18816, 49, 7, 1)), [32, 256, 7, 7], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 256, 7, 7], f16), [32, 256, 7, 7], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 256, 7, 7], f16), [32, 2304, 7, 7], 1, 2048, 9223372036854775807, 1), {})
+cnt: 2, ((T([32, 2304, 7, 7], f16), [32, 2304, 7, 7], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16), [32, 2304, 7, 7], 1, 0, 2048, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(476672, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 20, ((T([32, 64, 14, 14], f16), [32, 64, 14, 14], 2, 0, 9223372036854775807, 1), {})
+cnt: 20, ((T([32, 64, 14, 14], f16), [32, 1088, 14, 14], 1, 1024, 9223372036854775807, 1), {})
+cnt: 40, ((T([32, 1088, 14, 14], f16), [32, 1088, 14, 14], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16, stride=(476672, 196, 14, 1)), [32, 1024, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 21, ((T([32, 1024, 14, 14], f16), [32, 1024, 14, 14], 2, 0, 9223372036854775807, 1), {})
+cnt: 20, ((T([32, 1024, 14, 14], f16), [32, 1088, 14, 14], 1, 0, 1024, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(263424, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 20, ((T([32, 1024, 14, 14], f16), [32, 1024, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(250880, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(238336, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(225792, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(213248, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(200704, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(188160, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(175616, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(163072, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(150528, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(137984, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(125440, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(112896, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(100352, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(87808, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(75264, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(62720, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(50176, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 14, 14], f16, stride=(37632, 196, 14, 1)), [32, 64, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 128, 14, 14], f16, stride=(37632, 196, 14, 1)), [32, 128, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 128, 14, 14], f16), [32, 128, 14, 14], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 128, 14, 14], f16), [32, 1152, 14, 14], 1, 1024, 9223372036854775807, 1), {})
+cnt: 2, ((T([32, 1152, 14, 14], f16), [32, 1152, 14, 14], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), [32, 1152, 14, 14], 1, 0, 1024, 1), {})
+cnt: 1, ((T([32, 64, 28, 28], f16, stride=(903168, 784, 28, 1)), [32, 64, 28, 28], 3, 0, 9223372036854775807, 1), {})
+cnt: 8, ((T([32, 64, 28, 28], f16), [32, 64, 28, 28], 2, 0, 9223372036854775807, 1), {})
+cnt: 8, ((T([32, 64, 28, 28], f16), [32, 576, 28, 28], 1, 512, 9223372036854775807, 1), {})
+cnt: 16, ((T([32, 576, 28, 28], f16), [32, 576, 28, 28], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 512, 28, 28], f16, stride=(903168, 784, 28, 1)), [32, 512, 28, 28], 3, 0, 9223372036854775807, 1), {})
+cnt: 9, ((T([32, 512, 28, 28], f16), [32, 512, 28, 28], 2, 0, 9223372036854775807, 1), {})
+cnt: 8, ((T([32, 512, 28, 28], f16), [32, 576, 28, 28], 1, 0, 512, 1), {})
+cnt: 1, ((T([32, 64, 28, 28], f16, stride=(451584, 784, 28, 1)), [32, 64, 28, 28], 3, 0, 9223372036854775807, 1), {})
+cnt: 8, ((T([32, 512, 28, 28], f16), [32, 512, 28, 28], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 28, 28], f16, stride=(401408, 784, 28, 1)), [32, 64, 28, 28], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 28, 28], f16, stride=(351232, 784, 28, 1)), [32, 64, 28, 28], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 28, 28], f16, stride=(301056, 784, 28, 1)), [32, 64, 28, 28], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 28, 28], f16, stride=(250880, 784, 28, 1)), [32, 64, 28, 28], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 28, 28], f16, stride=(200704, 784, 28, 1)), [32, 64, 28, 28], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 64, 28, 28], f16, stride=(150528, 784, 28, 1)), [32, 64, 28, 28], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 128, 28, 28], f16, stride=(150528, 784, 28, 1)), [32, 128, 28, 28], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 128, 28, 28], f16), [32, 128, 28, 28], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 128, 28, 28], f16), [32, 640, 28, 28], 1, 512, 9223372036854775807, 1), {})
+cnt: 2, ((T([32, 640, 28, 28], f16), [32, 640, 28, 28], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), [32, 640, 28, 28], 1, 0, 512, 1), {})
+cnt: 1, ((T([32, 20, 56, 56], f16, stride=(1179136, 3136, 56, 1)), [32, 20, 56, 56], 3, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([32, 20, 56, 56], f16), [32, 20, 56, 56], 2, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([32, 20, 56, 56], f16), [32, 276, 56, 56], 1, 256, 9223372036854775807, 1), {})
+cnt: 8, ((T([32, 276, 56, 56], f16), [32, 276, 56, 56], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 256, 56, 56], f16, stride=(1179136, 3136, 56, 1)), [32, 256, 56, 56], 3, 0, 9223372036854775807, 1), {})
+cnt: 5, ((T([32, 256, 56, 56], f16), [32, 256, 56, 56], 2, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([32, 256, 56, 56], f16), [32, 276, 56, 56], 1, 0, 256, 1), {})
+cnt: 1, ((T([32, 20, 56, 56], f16, stride=(313600, 3136, 56, 1)), [32, 20, 56, 56], 3, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([32, 256, 56, 56], f16), [32, 256, 56, 56], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 20, 56, 56], f16, stride=(250880, 3136, 56, 1)), [32, 20, 56, 56], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 20, 56, 56], f16, stride=(188160, 3136, 56, 1)), [32, 20, 56, 56], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 40, 56, 56], f16, stride=(188160, 3136, 56, 1)), [32, 40, 56, 56], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 40, 56, 56], f16), [32, 40, 56, 56], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 40, 56, 56], f16), [32, 296, 56, 56], 1, 256, 9223372036854775807, 1), {})
+cnt: 2, ((T([32, 296, 56, 56], f16), [32, 296, 56, 56], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), [32, 296, 56, 56], 1, 0, 256, 1), {})
+Operator: aten.threshold_backward.default
+cnt: 5, ((T([32, 1600, 7, 7], f16), T([32, 1600, 7, 7], f16), 0), {})
+cnt: 1, ((T([32, 2560, 7, 7], f16), T([32, 2560, 7, 7], f16), 0), {})
+cnt: 1, ((T([32, 2432, 7, 7], f16), T([32, 2432, 7, 7], f16), 0), {})
+cnt: 2, ((T([32, 1600, 14, 14], f16), T([32, 1600, 14, 14], f16), 0), {})
+cnt: 2, ((T([32, 2432, 14, 14], f16), T([32, 2432, 14, 14], f16), 0), {})
+cnt: 39, ((T([32, 800, 14, 14], f16), T([32, 800, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 2368, 14, 14], f16), T([32, 2368, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 2304, 14, 14], f16), T([32, 2304, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 2240, 14, 14], f16), T([32, 2240, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 2176, 14, 14], f16), T([32, 2176, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 2112, 14, 14], f16), T([32, 2112, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 2048, 14, 14], f16), T([32, 2048, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 1984, 14, 14], f16), T([32, 1984, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 1920, 14, 14], f16), T([32, 1920, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 1856, 14, 14], f16), T([32, 1856, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 1792, 14, 14], f16), T([32, 1792, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 1728, 14, 14], f16), T([32, 1728, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 1664, 14, 14], f16), T([32, 1664, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 1536, 14, 14], f16), T([32, 1536, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 1472, 14, 14], f16), T([32, 1472, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 1408, 14, 14], f16), T([32, 1408, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 1344, 14, 14], f16), T([32, 1344, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 1280, 14, 14], f16), T([32, 1280, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 1216, 14, 14], f16), T([32, 1216, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 800, 28, 28], f16), T([32, 800, 28, 28], f16), 0), {})
+cnt: 2, ((T([32, 1152, 28, 28], f16), T([32, 1152, 28, 28], f16), 0), {})
+cnt: 15, ((T([32, 400, 28, 28], f16), T([32, 400, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 1088, 28, 28], f16), T([32, 1088, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 1024, 28, 28], f16), T([32, 1024, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 960, 28, 28], f16), T([32, 960, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 896, 28, 28], f16), T([32, 896, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 832, 28, 28], f16), T([32, 832, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 768, 28, 28], f16), T([32, 768, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 704, 28, 28], f16), T([32, 704, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 400, 56, 56], f16), T([32, 400, 56, 56], f16), 0), {})
+cnt: 2, ((T([32, 376, 56, 56], f16), T([32, 376, 56, 56], f16), 0), {})
+cnt: 8, ((T([32, 200, 56, 56], f16), T([32, 200, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 356, 56, 56], f16), T([32, 356, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 336, 56, 56], f16), T([32, 336, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 316, 56, 56], f16), T([32, 316, 56, 56], f16), 0), {})
+cnt: 2, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 128, 112, 112], f16), T([32, 128, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/eca_botnext26ts_256_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/eca_botnext26ts_256_training.txt
new file mode 100644
index 0000000000000..ab778074aa37f
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/eca_botnext26ts_256_training.txt
@@ -0,0 +1,288 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 2, ((T([512, 256, 256], f16), -1, False), {})
+cnt: 1, ((T([512, 64, 64], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 1, ((T([512, 64, 64], f16), T([512, 64, 64], f16), -1, f16), {})
+cnt: 2, ((T([512, 256, 256], f16), T([512, 256, 256], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 4, ((T([128, 64, 16, 16], f16), [512, 16, 256]), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), [512, 64, 256]), {})
+cnt: 2, ((T([512, 256, 256], f16), [512, 256, 256]), {})
+cnt: 4, ((T([512, 16, 16, 16], f16), [131072, 16]), {})
+cnt: 4, ((T([131072, 31], f16), [512, 16, 16, 31]), {})
+cnt: 2, ((T([512, 16, 16, 16, 16], f16), [512, 256, 256]), {})
+cnt: 1, ((T([512, 256, 64], f16), [512, 256, 64]), {})
+cnt: 2, ((T([512, 64, 256], f16), [128, 256, 16, 16]), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), [512, 128, 256]), {})
+cnt: 1, ((T([512, 256, 128], f16), [512, 256, 128]), {})
+cnt: 2, ((T([512, 128, 256], f16), [128, 512, 16, 16]), {})
+cnt: 2, ((T([128, 64, 8, 8], f16), [512, 16, 64]), {})
+cnt: 1, ((T([128, 512, 8, 8], f16), [512, 128, 64]), {})
+cnt: 1, ((T([512, 64, 64], f16), [512, 64, 64]), {})
+cnt: 2, ((T([512, 8, 8, 16], f16), [32768, 16]), {})
+cnt: 2, ((T([32768, 15], f16), [512, 8, 8, 15]), {})
+cnt: 1, ((T([512, 8, 8, 8, 8], f16), [512, 64, 64]), {})
+cnt: 1, ((T([512, 64, 128], f16), [512, 64, 128]), {})
+cnt: 2, ((T([512, 128, 64], f16), [128, 512, 8, 8]), {})
+cnt: 1, ((T([512, 8, 8, 16], f16), [512, 64, 16]), {})
+cnt: 1, ((T([512, 16, 64], f16), [128, 64, 8, 8]), {})
+cnt: 2, ((T([512, 16, 16, 16], f16), [512, 256, 16]), {})
+cnt: 2, ((T([512, 16, 256], f16), [128, 64, 16, 16]), {})
+Operator: aten.add.Tensor
+cnt: 31, ((T([], i64), 1), {})
+cnt: 4, ((T([128, 256, 64, 64], f16), T([128, 256, 64, 64], f16)), {})
+cnt: 4, ((T([128, 512, 32, 32], f16), T([128, 512, 32, 32], f16)), {})
+cnt: 4, ((T([128, 1024, 16, 16], f16), T([128, 1024, 16, 16], f16)), {})
+cnt: 2, ((T([512, 16, 16, 16, 16], f16, stride=(8432, 31, 527, 1, 0)), T([512, 16, 16, 16, 16], f16, stride=(8432, 527, 31, 0, 1))), {})
+cnt: 2, ((T([512, 256, 256], f16), T([512, 256, 256], f16)), {})
+cnt: 3, ((T([128, 2048, 8, 8], f16), T([128, 2048, 8, 8], f16)), {})
+cnt: 1, ((T([512, 8, 8, 8, 8], f16, stride=(1080, 15, 135, 1, 0)), T([512, 8, 8, 8, 8], f16, stride=(1080, 135, 15, 0, 1))), {})
+cnt: 1, ((T([512, 64, 64], f16), T([512, 64, 64], f16)), {})
+cnt: 1, ((T([512, 8, 8, 16], f16, stride=(1024, 16, 128, 1)), T([512, 8, 8, 16], f16)), {})
+cnt: 1, ((T([512, 64, 16], f16), T([512, 64, 16], f16)), {})
+cnt: 2, ((T([512, 16, 16, 16], f16, stride=(4096, 16, 256, 1)), T([512, 16, 16, 16], f16)), {})
+cnt: 2, ((T([512, 256, 16], f16), T([512, 256, 16], f16)), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), T([128, 256, 16, 16], f16)), {})
+cnt: 2, ((T([128, 128, 32, 32], f16), T([128, 128, 32, 32], f16)), {})
+cnt: 3, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([128, 512, 16, 16], f16), [2, 2], [2, 2]), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([128, 512, 8, 8], f16), T([128, 512, 16, 16], f16), [2, 2], [2, 2], [0, 0], False, True, None), {})
+Operator: aten.bmm.default
+cnt: 2, ((T([512, 256, 16], f16, stride=(4096, 1, 256)), T([512, 16, 256], f16)), {})
+cnt: 1, ((T([512, 256, 256], f16), T([512, 256, 64], f16, stride=(16384, 1, 256))), {})
+cnt: 1, ((T([512, 256, 256], f16), T([512, 256, 128], f16, stride=(32768, 1, 256))), {})
+cnt: 1, ((T([512, 64, 16], f16, stride=(1024, 1, 64)), T([512, 16, 64], f16)), {})
+cnt: 1, ((T([512, 64, 64], f16), T([512, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 1, ((T([512, 64, 64], f16, stride=(4096, 1, 64)), T([512, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 1, ((T([512, 64, 128], f16, stride=(8192, 1, 64)), T([512, 128, 64], f16)), {})
+cnt: 1, ((T([512, 16, 64], f16), T([512, 64, 64], f16)), {})
+cnt: 1, ((T([512, 64, 64], f16), T([512, 64, 16], f16, stride=(1024, 1, 64))), {})
+cnt: 1, ((T([512, 256, 256], f16, stride=(65536, 1, 256)), T([512, 256, 128], f16, stride=(32768, 1, 256))), {})
+cnt: 1, ((T([512, 256, 128], f16, stride=(32768, 1, 256)), T([512, 128, 256], f16)), {})
+cnt: 2, ((T([512, 16, 256], f16), T([512, 256, 256], f16)), {})
+cnt: 2, ((T([512, 256, 256], f16), T([512, 256, 16], f16, stride=(4096, 1, 256))), {})
+cnt: 1, ((T([512, 256, 256], f16, stride=(65536, 1, 256)), T([512, 256, 64], f16, stride=(16384, 1, 256))), {})
+cnt: 1, ((T([512, 256, 64], f16, stride=(16384, 1, 256)), T([512, 64, 256], f16)), {})
+Operator: aten.cat.default
+cnt: 1, (([T([128, 64, 8, 8], f16), T([128, 64, 8, 8], f16), T([128, 512, 8, 8], f16)], 1), {})
+cnt: 1, (([T([128, 64, 16, 16], f16), T([128, 64, 16, 16], f16), T([128, 512, 16, 16], f16)], 1), {})
+cnt: 1, (([T([128, 64, 16, 16], f16), T([128, 64, 16, 16], f16), T([128, 256, 16, 16], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 256, 256], f16),), {})
+cnt: 1, ((T([128, 24, 128, 128], f16),), {})
+cnt: 1, ((T([128, 32, 128, 128], f16),), {})
+cnt: 1, ((T([128, 64, 128, 128], f16),), {})
+cnt: 4, ((T([128, 64, 64, 64], f16),), {})
+cnt: 2, ((T([128, 256, 64, 64], f16),), {})
+cnt: 1, ((T([128, 128, 64, 64], f16),), {})
+cnt: 3, ((T([128, 128, 32, 32], f16),), {})
+cnt: 2, ((T([128, 512, 32, 32], f16),), {})
+cnt: 1, ((T([128, 256, 32, 32], f16),), {})
+cnt: 3, ((T([128, 256, 16, 16], f16),), {})
+cnt: 2, ((T([128, 1024, 16, 16], f16),), {})
+cnt: 1, ((T([128, 512, 16, 16], f16),), {})
+cnt: 3, ((T([128, 512, 8, 8], f16),), {})
+cnt: 2, ((T([128, 2048, 8, 8], f16),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 4, ((T([8192, 16, 31], f16), [0, 1], 0.0), {})
+cnt: 4, ((T([8192, 512], f16), [0, 15], 0.0), {})
+cnt: 2, ((T([4096, 8, 15], f16), [0, 1], 0.0), {})
+cnt: 2, ((T([4096, 128], f16), [0, 7], 0.0), {})
+cnt: 2, ((T([4096, 135], f16), [0, -7]), {})
+cnt: 2, ((T([4096, 8, 16], f16), [0, -1]), {})
+cnt: 4, ((T([8192, 527], f16), [0, -15]), {})
+cnt: 4, ((T([8192, 16, 32], f16), [0, -1]), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 256, 256], f16), T([24, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 24, 128, 128], f16), T([32, 24, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([64, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 64, 64], f16), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 64, 64, 64], f16), T([64, 16, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 4), {})
+cnt: 2, ((T([128, 1, 64], f16), T([1, 1, 3], f16), None, [1], [1], [1], False, [0], 1), {})
+cnt: 3, ((T([128, 64, 64, 64], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 64, 64], f16), T([64, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 64, 64], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 16, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 2, ((T([128, 1, 128], f16), T([1, 1, 5], f16), None, [1], [2], [1], False, [0], 1), {})
+cnt: 2, ((T([128, 128, 32, 32], f16), T([512, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 64, 64], f16), T([512, 256, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 32, 32], f16), T([128, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 32, 32], f16), T([128, 16, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 1, ((T([128, 512, 32, 32], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 32, 32], f16), T([256, 16, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 16), {})
+cnt: 1, ((T([128, 1, 256], f16), T([1, 1, 5], f16), None, [1], [2], [1], False, [0], 1), {})
+cnt: 2, ((T([128, 256, 16, 16], f16), T([1024, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 32, 32], f16), T([1024, 512, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1024, 16, 16], f16), T([256, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), T([384, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1024, 16, 16], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([640, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 512, 8, 8], f16), T([2048, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1024, 16, 16], f16), T([2048, 1024, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([512, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 8, 8], f16), T([640, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 2, ((T([128, 2048, 8, 8], f16), T([128, 512, 8, 8], f16), T([2048, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 640, 8, 8], f16), T([128, 512, 8, 8], f16), T([640, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 8, 8], f16), T([128, 2048, 8, 8], f16), T([512, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([128, 1024, 16, 16], f16), T([2048, 1024, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 640, 16, 16], f16), T([128, 512, 16, 16], f16), T([640, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([128, 1024, 16, 16], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 1024, 16, 16], f16), T([128, 256, 16, 16], f16), T([1024, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 384, 16, 16], f16), T([128, 256, 16, 16], f16), T([384, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), T([128, 1024, 16, 16], f16), T([256, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1024, 16, 16], f16), T([128, 512, 32, 32], f16), T([1024, 512, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1, 256], f16), T([128, 1, 256], f16), T([1, 1, 5], f16), [0], [1], [2], [1], False, [0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), T([128, 256, 32, 32], f16), T([256, 16, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 16, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 32, 32], f16), T([128, 512, 32, 32], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 512, 32, 32], f16), T([128, 128, 32, 32], f16), T([512, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 1, 128], f16), T([128, 1, 128], f16), T([1, 1, 5], f16), [0], [1], [2], [1], False, [0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 32, 32], f16), T([128, 128, 32, 32], f16), T([128, 16, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 32, 32], f16), T([128, 512, 32, 32], f16), T([128, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 32, 32], f16), T([128, 256, 64, 64], f16), T([512, 256, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 32, 32], f16), T([128, 128, 64, 64], f16), T([128, 16, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 256, 64, 64], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 256, 64, 64], f16), T([128, 64, 64, 64], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 1, 64], f16), T([128, 1, 64], f16), T([1, 1, 3], f16), [0], [1], [1], [1], False, [0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16), T([64, 16, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 64, 64], f16), T([128, 256, 64, 64], f16), T([64, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16), T([64, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 128, 128], f16), T([128, 32, 128, 128], f16), T([64, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([128, 24, 128, 128], f16), T([32, 24, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 128, 128], f16), T([128, 3, 256, 256], f16), T([24, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 256, 256], f16), T([128, 3, 256, 256], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 2048, 8, 8], f16, stride=(2048, 1, 0, 0)), 64), {})
+cnt: 1, ((T([128, 256, 16, 16], f16, stride=(256, 1, 0, 0)), 256), {})
+cnt: 2, ((T([128, 128, 32, 32], f16, stride=(128, 1, 0, 0)), 1024), {})
+cnt: 2, ((T([128, 64, 64, 64], f16, stride=(64, 1, 0, 0)), 4096), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([128, 64, 128, 128], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([128, 64, 64, 64], f16), T([128, 64, 128, 128], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([128, 64, 64, 64], i64)), {})
+Operator: aten.mean.dim
+cnt: 2, ((T([128, 64, 64, 64], f16), [2, 3]), {})
+cnt: 2, ((T([128, 128, 32, 32], f16), [2, 3]), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), [2, 3]), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 4, ((T([131072, 16], f16), T([16, 31], f16, stride=(1, 16))), {})
+cnt: 2, ((T([32768, 16], f16), T([16, 15], f16, stride=(1, 16))), {})
+cnt: 1, ((T([128, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 2048], f16)), {})
+cnt: 2, ((T([15, 32768], f16, stride=(1, 15)), T([32768, 16], f16)), {})
+cnt: 2, ((T([32768, 15], f16), T([15, 16], f16)), {})
+cnt: 4, ((T([31, 131072], f16, stride=(1, 31)), T([131072, 16], f16)), {})
+cnt: 4, ((T([131072, 31], f16), T([31, 16], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 4, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16, stride=(64, 1, 0, 0))), {})
+cnt: 4, ((T([128, 128, 32, 32], f16), T([128, 128, 32, 32], f16, stride=(128, 1, 0, 0))), {})
+cnt: 2, ((T([128, 256, 16, 16], f16), T([128, 256, 16, 16], f16, stride=(256, 1, 0, 0))), {})
+cnt: 4, ((T([512, 256, 256], f16), 0.25), {})
+cnt: 2, ((T([512, 64, 64], f16), 0.25), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), T([128, 256, 16, 16], f16)), {})
+cnt: 2, ((T([128, 128, 32, 32], f16), T([128, 128, 32, 32], f16)), {})
+cnt: 2, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 24, 128, 128], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 64, 128, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 256, 64, 64], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 128, 32, 32], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 512, 32, 32], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 256, 32, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 256, 16, 16], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 1024, 16, 16], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 512, 8, 8], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 2048, 8, 8], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 3, ((T([128, 2048, 8, 8], f16), T([128, 2048, 8, 8], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 512, 8, 8], f16), T([128, 512, 8, 8], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([128, 512, 16, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 1024, 16, 16], f16), T([128, 1024, 16, 16], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 256, 16, 16], f16), T([128, 256, 16, 16], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 32, 32], f16), T([128, 256, 32, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 512, 32, 32], f16), T([128, 512, 32, 32], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 128, 32, 32], f16), T([128, 128, 32, 32], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 256, 64, 64], f16), T([128, 256, 64, 64], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 128, 128], f16), T([128, 64, 128, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([128, 32, 128, 128], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 24, 128, 128], f16), T([128, 24, 128, 128], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.sigmoid.default
+cnt: 2, ((T([128, 1, 64], f16),), {})
+cnt: 2, ((T([128, 1, 128], f16),), {})
+cnt: 1, ((T([128, 1, 256], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 1, ((T([128, 1, 256], f16), T([128, 1, 256], f16)), {})
+cnt: 2, ((T([128, 1, 128], f16), T([128, 1, 128], f16)), {})
+cnt: 2, ((T([128, 1, 64], f16), T([128, 1, 64], f16)), {})
+Operator: aten.silu_.default
+cnt: 1, ((T([128, 24, 128, 128], f16),), {})
+cnt: 1, ((T([128, 32, 128, 128], f16),), {})
+cnt: 1, ((T([128, 64, 128, 128], f16),), {})
+cnt: 4, ((T([128, 64, 64, 64], f16),), {})
+cnt: 2, ((T([128, 256, 64, 64], f16),), {})
+cnt: 1, ((T([128, 128, 64, 64], f16),), {})
+cnt: 3, ((T([128, 128, 32, 32], f16),), {})
+cnt: 2, ((T([128, 512, 32, 32], f16),), {})
+cnt: 1, ((T([128, 256, 32, 32], f16),), {})
+cnt: 3, ((T([128, 256, 16, 16], f16),), {})
+cnt: 2, ((T([128, 1024, 16, 16], f16),), {})
+cnt: 1, ((T([128, 512, 16, 16], f16),), {})
+cnt: 3, ((T([128, 512, 8, 8], f16),), {})
+cnt: 2, ((T([128, 2048, 8, 8], f16),), {})
+Operator: aten.silu_backward.default
+cnt: 2, ((T([128, 2048, 8, 8], f16), T([128, 2048, 8, 8], f16)), {})
+cnt: 3, ((T([128, 512, 8, 8], f16), T([128, 512, 8, 8], f16)), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([128, 512, 16, 16], f16)), {})
+cnt: 2, ((T([128, 1024, 16, 16], f16), T([128, 1024, 16, 16], f16)), {})
+cnt: 3, ((T([128, 256, 16, 16], f16), T([128, 256, 16, 16], f16)), {})
+cnt: 1, ((T([128, 256, 32, 32], f16), T([128, 256, 32, 32], f16)), {})
+cnt: 2, ((T([128, 512, 32, 32], f16), T([128, 512, 32, 32], f16)), {})
+cnt: 3, ((T([128, 128, 32, 32], f16), T([128, 128, 32, 32], f16)), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 128, 64, 64], f16)), {})
+cnt: 2, ((T([128, 256, 64, 64], f16), T([128, 256, 64, 64], f16)), {})
+cnt: 4, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16)), {})
+cnt: 1, ((T([128, 64, 128, 128], f16), T([128, 64, 128, 128], f16)), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([128, 32, 128, 128], f16)), {})
+cnt: 1, ((T([128, 24, 128, 128], f16), T([128, 24, 128, 128], f16)), {})
+Operator: aten.slice_backward.default
+cnt: 2, ((T([4096, 8, 8], f16), [4096, 8, 15], 2, 7, 9223372036854775807, 1), {})
+cnt: 2, ((T([4096, 8, 15], f16), [4096, 9, 15], 1, 0, 8, 1), {})
+cnt: 2, ((T([4096, 9, 15], f16), [4096, 9, 15], 0, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([8192, 16, 16], f16), [8192, 16, 31], 2, 15, 9223372036854775807, 1), {})
+cnt: 4, ((T([8192, 16, 31], f16), [8192, 17, 31], 1, 0, 16, 1), {})
+cnt: 4, ((T([8192, 17, 31], f16), [8192, 17, 31], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.split_with_sizes.default
+cnt: 1, ((T([128, 384, 16, 16], f16), [64, 64, 256], 1), {})
+cnt: 1, ((T([128, 640, 16, 16], f16), [64, 64, 512], 1), {})
+cnt: 1, ((T([128, 640, 8, 8], f16), [64, 64, 512], 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 1, ((T([512, 8, 8, 8, 8], f16, stride=(4096, 64, 1, 512, 8)), [2], True), {})
+cnt: 1, ((T([512, 8, 8, 8, 8], f16, stride=(4096, 512, 8, 64, 1)), [2], True), {})
+cnt: 2, ((T([512, 16, 16, 16, 16], f16, stride=(65536, 256, 1, 4096, 16)), [2], True), {})
+cnt: 2, ((T([512, 16, 16, 16, 16], f16, stride=(65536, 4096, 16, 256, 1)), [2], True), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 128, 32, 32], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 64, 64, 64], f16), [2, 3], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/eca_halonext26ts_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/eca_halonext26ts_training.txt
new file mode 100644
index 0000000000000..714fcdbbaf06b
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/eca_halonext26ts_training.txt
@@ -0,0 +1,343 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 1, ((T([1024, 4, 64, 144], f16), -1, False), {})
+cnt: 1, ((T([1024, 4, 16, 144], f16), -1, False), {})
+cnt: 1, ((T([1024, 1, 64, 144], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 1, ((T([1024, 1, 64, 144], f16), T([1024, 1, 64, 144], f16), -1, f16), {})
+cnt: 1, ((T([1024, 4, 16, 144], f16), T([1024, 4, 16, 144], f16), -1, f16), {})
+cnt: 1, ((T([1024, 4, 64, 144], f16), T([1024, 4, 64, 144], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 1, ((T([1024, 16, 8, 8, 2, 2], f16), [1024, 16, 64, 4]), {})
+cnt: 1, ((T([128, 384, 2, 2, 12, 12], f16), [1024, 48, 4, 144]), {})
+cnt: 1, ((T([1024, 4, 64, 16], f16), [4096, 64, 16]), {})
+cnt: 2, ((T([1024, 4, 16, 144], f16), [4096, 16, 144]), {})
+cnt: 1, ((T([4096, 64, 144], f16), [1024, 4, 64, 144]), {})
+cnt: 1, ((T([1024, 4, 64, 16], f16), [4096, 8, 8, 16]), {})
+cnt: 2, ((T([262144, 23], f16), [4096, 8, 8, 23]), {})
+cnt: 1, ((T([4096, 8, 8, 16], f16), [262144, 16]), {})
+cnt: 1, ((T([4096, 8, 8, 12, 12], f16), [1024, 4, 64, 144]), {})
+cnt: 1, ((T([1024, 4, 144, 32], f16), [4096, 144, 32]), {})
+cnt: 1, ((T([4096, 64, 32], f16), [1024, 4, 64, 32]), {})
+cnt: 1, ((T([1024, 32, 64, 4], f16), [32768, 8, 8, 2, 2]), {})
+cnt: 1, ((T([1024, 16, 4, 4, 2, 2], f16), [1024, 16, 16, 4]), {})
+cnt: 1, ((T([128, 640, 2, 2, 12, 12], f16), [1024, 80, 4, 144]), {})
+cnt: 1, ((T([1024, 4, 16, 16], f16), [4096, 16, 16]), {})
+cnt: 1, ((T([4096, 16, 144], f16), [1024, 4, 16, 144]), {})
+cnt: 1, ((T([1024, 4, 16, 16], f16), [4096, 4, 4, 16]), {})
+cnt: 2, ((T([65536, 23], f16), [4096, 4, 4, 23]), {})
+cnt: 1, ((T([4096, 4, 4, 16], f16), [65536, 16]), {})
+cnt: 1, ((T([4096, 4, 4, 12, 12], f16), [1024, 4, 16, 144]), {})
+cnt: 1, ((T([1024, 4, 144, 64], f16), [4096, 144, 64]), {})
+cnt: 1, ((T([4096, 16, 64], f16), [1024, 4, 16, 64]), {})
+cnt: 1, ((T([1024, 64, 16, 4], f16), [65536, 4, 4, 2, 2]), {})
+cnt: 1, ((T([1024, 64, 144], f16), [1024, 1, 64, 144]), {})
+cnt: 2, ((T([1024, 8, 8, 16], f16), [65536, 16]), {})
+cnt: 2, ((T([65536, 23], f16), [1024, 8, 8, 23]), {})
+cnt: 1, ((T([1024, 8, 8, 12, 12], f16), [1024, 1, 64, 144]), {})
+cnt: 1, ((T([1024, 64, 64], f16), [1024, 1, 64, 64]), {})
+cnt: 1, ((T([1024, 64, 64, 1], f16), [65536, 8, 8, 1, 1]), {})
+cnt: 1, ((T([1024, 8, 8, 16], f16), [1024, 1, 64, 16]), {})
+cnt: 1, ((T([1024, 80, 1, 144], f16), [128, 640, 1, 1, 12, 12]), {})
+cnt: 1, ((T([1024, 16, 1, 8, 1, 8], f16), [128, 128, 8, 8]), {})
+cnt: 1, ((T([65536, 4, 4, 2, 2], f16), [1024, 64, 16, 4]), {})
+cnt: 1, ((T([1024, 4, 16, 64], f16), [4096, 16, 64]), {})
+cnt: 1, ((T([4096, 4, 4, 16], f16), [1024, 4, 16, 16]), {})
+cnt: 1, ((T([1024, 80, 4, 144], f16), [128, 640, 2, 2, 12, 12]), {})
+cnt: 1, ((T([1024, 16, 2, 4, 2, 4], f16), [128, 128, 8, 8]), {})
+cnt: 1, ((T([32768, 8, 8, 2, 2], f16), [1024, 32, 64, 4]), {})
+cnt: 1, ((T([1024, 4, 64, 32], f16), [4096, 64, 32]), {})
+cnt: 1, ((T([4096, 8, 8, 16], f16), [1024, 4, 64, 16]), {})
+cnt: 1, ((T([1024, 48, 4, 144], f16), [128, 384, 2, 2, 12, 12]), {})
+cnt: 1, ((T([1024, 16, 2, 8, 2, 8], f16), [128, 128, 16, 16]), {})
+Operator: aten.add.Tensor
+cnt: 31, ((T([], i64), 1), {})
+cnt: 4, ((T([128, 256, 64, 64], f16), T([128, 256, 64, 64], f16)), {})
+cnt: 4, ((T([128, 512, 32, 32], f16), T([128, 512, 32, 32], f16)), {})
+cnt: 4, ((T([128, 1024, 16, 16], f16), T([128, 1024, 16, 16], f16)), {})
+cnt: 1, ((T([4096, 8, 8, 12, 12], f16, stride=(1656, 23, 207, 1, 0)), T([4096, 8, 8, 12, 12], f16, stride=(1656, 207, 23, 0, 1))), {})
+cnt: 1, ((T([1024, 4, 64, 144], f16), T([1024, 4, 64, 144], f16)), {})
+cnt: 1, ((T([4096, 4, 4, 12, 12], f16, stride=(460, 23, 115, 1, 0)), T([4096, 4, 4, 12, 12], f16, stride=(460, 115, 23, 0, 1))), {})
+cnt: 1, ((T([1024, 4, 16, 144], f16), T([1024, 4, 16, 144], f16)), {})
+cnt: 3, ((T([128, 2048, 8, 8], f16), T([128, 2048, 8, 8], f16)), {})
+cnt: 1, ((T([1024, 8, 8, 12, 12], f16, stride=(1656, 23, 207, 1, 0)), T([1024, 8, 8, 12, 12], f16, stride=(1656, 207, 23, 0, 1))), {})
+cnt: 1, ((T([1024, 1, 64, 144], f16), T([1024, 1, 64, 144], f16)), {})
+cnt: 1, ((T([1024, 8, 8, 16], f16, stride=(1024, 16, 128, 1)), T([1024, 8, 8, 16], f16)), {})
+cnt: 1, ((T([1024, 1, 64, 16], f16), T([1024, 1, 64, 16], f16)), {})
+cnt: 1, ((T([128, 512, 8, 8], f16), T([128, 512, 8, 8], f16)), {})
+cnt: 1, ((T([4096, 4, 4, 16], f16, stride=(256, 16, 64, 1)), T([4096, 4, 4, 16], f16)), {})
+cnt: 1, ((T([1024, 4, 16, 16], f16), T([1024, 4, 16, 16], f16)), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([128, 512, 16, 16], f16)), {})
+cnt: 1, ((T([4096, 8, 8, 16], f16, stride=(1024, 16, 128, 1)), T([4096, 8, 8, 16], f16)), {})
+cnt: 1, ((T([1024, 4, 64, 16], f16), T([1024, 4, 64, 16], f16)), {})
+cnt: 2, ((T([128, 256, 16, 16], f16), T([128, 256, 16, 16], f16)), {})
+cnt: 2, ((T([128, 128, 32, 32], f16), T([128, 128, 32, 32], f16)), {})
+cnt: 3, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.bmm.default
+cnt: 1, ((T([4096, 64, 16], f16), T([4096, 16, 144], f16)), {})
+cnt: 1, ((T([4096, 64, 144], f16), T([4096, 144, 32], f16)), {})
+cnt: 1, ((T([4096, 16, 16], f16), T([4096, 16, 144], f16)), {})
+cnt: 1, ((T([4096, 16, 144], f16), T([4096, 144, 64], f16)), {})
+cnt: 1, ((T([1024, 64, 16], f16, stride=(1024, 1, 64)), T([1024, 16, 144], f16, stride=(11520, 144, 1))), {})
+cnt: 1, ((T([1024, 64, 144], f16), T([1024, 144, 64], f16, stride=(11520, 1, 144))), {})
+cnt: 1, ((T([1024, 144, 64], f16, stride=(9216, 1, 144)), T([1024, 64, 64], f16, stride=(4096, 1, 64))), {})
+cnt: 1, ((T([1024, 64, 64], f16, stride=(4096, 1, 64)), T([1024, 64, 144], f16, stride=(11520, 144, 1))), {})
+cnt: 1, ((T([1024, 16, 64], f16), T([1024, 64, 144], f16)), {})
+cnt: 1, ((T([1024, 64, 144], f16), T([1024, 144, 16], f16, stride=(11520, 1, 144))), {})
+cnt: 1, ((T([4096, 144, 16], f16, stride=(2304, 1, 144)), T([4096, 16, 64], f16)), {})
+cnt: 1, ((T([4096, 16, 64], f16), T([4096, 64, 144], f16, stride=(9216, 1, 64))), {})
+cnt: 1, ((T([4096, 16, 16], f16, stride=(256, 1, 16)), T([4096, 16, 144], f16)), {})
+cnt: 1, ((T([4096, 16, 144], f16), T([4096, 144, 16], f16, stride=(2304, 1, 144))), {})
+cnt: 1, ((T([4096, 144, 64], f16, stride=(9216, 1, 144)), T([4096, 64, 32], f16)), {})
+cnt: 1, ((T([4096, 64, 32], f16), T([4096, 32, 144], f16, stride=(4608, 1, 32))), {})
+cnt: 1, ((T([4096, 16, 64], f16, stride=(1024, 1, 16)), T([4096, 64, 144], f16)), {})
+cnt: 1, ((T([4096, 64, 144], f16), T([4096, 144, 16], f16, stride=(2304, 1, 144))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([1024, 1, 144, 16], f16, stride=(2304, 2304, 1, 144)), T([1024, 1, 144, 64], f16)], 3), {})
+cnt: 1, (([T([1024, 4, 144, 16], f16, stride=(9216, 2304, 1, 144)), T([1024, 4, 144, 64], f16)], 3), {})
+cnt: 1, (([T([1024, 4, 144, 16], f16, stride=(9216, 2304, 1, 144)), T([1024, 4, 144, 32], f16)], 3), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 256, 256], f16),), {})
+cnt: 1, ((T([128, 24, 128, 128], f16),), {})
+cnt: 1, ((T([128, 32, 128, 128], f16),), {})
+cnt: 1, ((T([128, 64, 128, 128], f16),), {})
+cnt: 4, ((T([128, 64, 64, 64], f16),), {})
+cnt: 2, ((T([128, 256, 64, 64], f16),), {})
+cnt: 1, ((T([128, 128, 64, 64], f16),), {})
+cnt: 3, ((T([128, 128, 32, 32], f16),), {})
+cnt: 2, ((T([128, 512, 32, 32], f16),), {})
+cnt: 1, ((T([128, 256, 32, 32], f16),), {})
+cnt: 3, ((T([128, 256, 16, 16], f16),), {})
+cnt: 2, ((T([128, 1024, 16, 16], f16),), {})
+cnt: 1, ((T([128, 512, 16, 16], f16),), {})
+cnt: 3, ((T([128, 512, 8, 8], f16),), {})
+cnt: 2, ((T([128, 2048, 8, 8], f16),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 1, ((T([128, 384, 16, 16], f16), [2, 2, 2, 2], 0.0), {})
+cnt: 2, ((T([32768, 8, 23], f16), [0, 1], 0.0), {})
+cnt: 2, ((T([32768, 192], f16), [0, 15], 0.0), {})
+cnt: 1, ((T([128, 640, 16, 16], f16), [2, 2, 2, 2], 0.0), {})
+cnt: 2, ((T([16384, 4, 23], f16), [0, 1], 0.0), {})
+cnt: 2, ((T([16384, 96], f16), [0, 19], 0.0), {})
+cnt: 1, ((T([128, 640, 8, 8], f16), [2, 2, 2, 2], 0.0), {})
+cnt: 2, ((T([8192, 8, 23], f16), [0, 1], 0.0), {})
+cnt: 2, ((T([8192, 192], f16), [0, 15], 0.0), {})
+cnt: 2, ((T([8192, 207], f16), [0, -15]), {})
+cnt: 2, ((T([8192, 8, 24], f16), [0, -1]), {})
+cnt: 1, ((T([128, 640, 12, 12], f16), [-2, -2, -2, -2]), {})
+cnt: 2, ((T([16384, 115], f16), [0, -19]), {})
+cnt: 2, ((T([16384, 4, 24], f16), [0, -1]), {})
+cnt: 1, ((T([128, 640, 20, 20], f16), [-2, -2, -2, -2]), {})
+cnt: 2, ((T([32768, 207], f16), [0, -15]), {})
+cnt: 2, ((T([32768, 8, 24], f16), [0, -1]), {})
+cnt: 1, ((T([128, 384, 20, 20], f16), [-2, -2, -2, -2]), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 256, 256], f16), T([24, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 24, 128, 128], f16), T([32, 24, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([64, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 64, 64], f16), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 64, 64, 64], f16), T([64, 16, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 4), {})
+cnt: 2, ((T([128, 1, 64], f16), T([1, 1, 3], f16), None, [1], [1], [1], False, [0], 1), {})
+cnt: 3, ((T([128, 64, 64, 64], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 64, 64], f16), T([64, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 64, 64], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 16, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 2, ((T([128, 1, 128], f16), T([1, 1, 5], f16), None, [1], [2], [1], False, [0], 1), {})
+cnt: 2, ((T([128, 128, 32, 32], f16), T([512, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 64, 64], f16), T([512, 256, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 32, 32], f16), T([128, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 32, 32], f16), T([128, 16, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 1, ((T([128, 512, 32, 32], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 32, 32], f16), T([256, 16, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 16), {})
+cnt: 1, ((T([128, 1, 256], f16), T([1, 1, 5], f16), None, [1], [2], [1], False, [0], 1), {})
+cnt: 2, ((T([128, 256, 16, 16], f16), T([1024, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 32, 32], f16), T([1024, 512, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1024, 16, 16], f16), T([256, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), T([384, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1024, 16, 16], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([128, 512, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([640, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 512, 8, 8], f16), T([2048, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1024, 16, 16], f16), T([2048, 1024, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([512, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 8, 8], f16), T([128, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 8, 8], f16), T([640, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 2, ((T([128, 2048, 8, 8], f16), T([128, 512, 8, 8], f16), T([2048, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 640, 8, 8], f16), T([128, 512, 8, 8], f16), T([640, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 8, 8], f16), T([128, 512, 8, 8], f16), T([128, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 8, 8], f16), T([128, 2048, 8, 8], f16), T([512, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([128, 1024, 16, 16], f16), T([2048, 1024, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 640, 16, 16], f16), T([128, 512, 16, 16], f16), T([640, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 8, 8], f16), T([128, 512, 16, 16], f16), T([128, 512, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([128, 1024, 16, 16], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 1024, 16, 16], f16), T([128, 256, 16, 16], f16), T([1024, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 384, 16, 16], f16), T([128, 256, 16, 16], f16), T([384, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 16, 16], f16), T([128, 256, 16, 16], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), T([128, 1024, 16, 16], f16), T([256, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1024, 16, 16], f16), T([128, 512, 32, 32], f16), T([1024, 512, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1, 256], f16), T([128, 1, 256], f16), T([1, 1, 5], f16), [0], [1], [2], [1], False, [0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), T([128, 256, 32, 32], f16), T([256, 16, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 16, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 32, 32], f16), T([128, 512, 32, 32], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 512, 32, 32], f16), T([128, 128, 32, 32], f16), T([512, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 1, 128], f16), T([128, 1, 128], f16), T([1, 1, 5], f16), [0], [1], [2], [1], False, [0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 32, 32], f16), T([128, 128, 32, 32], f16), T([128, 16, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 32, 32], f16), T([128, 512, 32, 32], f16), T([128, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 32, 32], f16), T([128, 256, 64, 64], f16), T([512, 256, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 32, 32], f16), T([128, 128, 64, 64], f16), T([128, 16, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 256, 64, 64], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 256, 64, 64], f16), T([128, 64, 64, 64], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 1, 64], f16), T([128, 1, 64], f16), T([1, 1, 3], f16), [0], [1], [1], [1], False, [0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16), T([64, 16, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 64, 64], f16), T([128, 256, 64, 64], f16), T([64, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16), T([64, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 128, 128], f16), T([128, 32, 128, 128], f16), T([64, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([128, 24, 128, 128], f16), T([32, 24, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 128, 128], f16), T([128, 3, 256, 256], f16), T([24, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 256, 256], f16), T([128, 3, 256, 256], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 2048, 8, 8], f16, stride=(2048, 1, 0, 0)), 64), {})
+cnt: 1, ((T([128, 256, 16, 16], f16, stride=(256, 1, 0, 0)), 256), {})
+cnt: 2, ((T([128, 128, 32, 32], f16, stride=(128, 1, 0, 0)), 1024), {})
+cnt: 2, ((T([128, 64, 64, 64], f16, stride=(64, 1, 0, 0)), 4096), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([128, 64, 128, 128], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([128, 64, 64, 64], f16), T([128, 64, 128, 128], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([128, 64, 64, 64], i64)), {})
+Operator: aten.mean.dim
+cnt: 2, ((T([128, 64, 64, 64], f16), [2, 3]), {})
+cnt: 2, ((T([128, 128, 32, 32], f16), [2, 3]), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), [2, 3]), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 2, ((T([262144, 16], f16), T([16, 23], f16, stride=(1, 16))), {})
+cnt: 4, ((T([65536, 16], f16), T([16, 23], f16, stride=(1, 16))), {})
+cnt: 1, ((T([128, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 2048], f16)), {})
+cnt: 4, ((T([23, 65536], f16, stride=(1, 23)), T([65536, 16], f16)), {})
+cnt: 4, ((T([65536, 23], f16), T([23, 16], f16)), {})
+cnt: 2, ((T([23, 262144], f16, stride=(1, 23)), T([262144, 16], f16)), {})
+cnt: 2, ((T([262144, 23], f16), T([23, 16], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 4, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16, stride=(64, 1, 0, 0))), {})
+cnt: 4, ((T([128, 128, 32, 32], f16), T([128, 128, 32, 32], f16, stride=(128, 1, 0, 0))), {})
+cnt: 2, ((T([128, 256, 16, 16], f16), T([128, 256, 16, 16], f16, stride=(256, 1, 0, 0))), {})
+cnt: 2, ((T([1024, 4, 64, 144], f16), 0.25), {})
+cnt: 2, ((T([1024, 4, 16, 144], f16), 0.25), {})
+cnt: 2, ((T([1024, 1, 64, 144], f16), 0.25), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), T([128, 256, 16, 16], f16)), {})
+cnt: 2, ((T([128, 128, 32, 32], f16), T([128, 128, 32, 32], f16)), {})
+cnt: 2, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 24, 128, 128], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 64, 128, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 256, 64, 64], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 128, 32, 32], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 512, 32, 32], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 256, 32, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 256, 16, 16], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 1024, 16, 16], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 512, 8, 8], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 2048, 8, 8], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 3, ((T([128, 2048, 8, 8], f16), T([128, 2048, 8, 8], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 512, 8, 8], f16), T([128, 512, 8, 8], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([128, 512, 16, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 1024, 16, 16], f16), T([128, 1024, 16, 16], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 256, 16, 16], f16), T([128, 256, 16, 16], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 32, 32], f16), T([128, 256, 32, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 512, 32, 32], f16), T([128, 512, 32, 32], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 128, 32, 32], f16), T([128, 128, 32, 32], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 256, 64, 64], f16), T([128, 256, 64, 64], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 128, 128], f16), T([128, 64, 128, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([128, 32, 128, 128], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 24, 128, 128], f16), T([128, 24, 128, 128], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.sigmoid.default
+cnt: 2, ((T([128, 1, 64], f16),), {})
+cnt: 2, ((T([128, 1, 128], f16),), {})
+cnt: 1, ((T([128, 1, 256], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 1, ((T([128, 1, 256], f16), T([128, 1, 256], f16)), {})
+cnt: 2, ((T([128, 1, 128], f16), T([128, 1, 128], f16)), {})
+cnt: 2, ((T([128, 1, 64], f16), T([128, 1, 64], f16)), {})
+Operator: aten.silu_.default
+cnt: 1, ((T([128, 24, 128, 128], f16),), {})
+cnt: 1, ((T([128, 32, 128, 128], f16),), {})
+cnt: 1, ((T([128, 64, 128, 128], f16),), {})
+cnt: 4, ((T([128, 64, 64, 64], f16),), {})
+cnt: 2, ((T([128, 256, 64, 64], f16),), {})
+cnt: 1, ((T([128, 128, 64, 64], f16),), {})
+cnt: 3, ((T([128, 128, 32, 32], f16),), {})
+cnt: 2, ((T([128, 512, 32, 32], f16),), {})
+cnt: 1, ((T([128, 256, 32, 32], f16),), {})
+cnt: 3, ((T([128, 256, 16, 16], f16),), {})
+cnt: 2, ((T([128, 1024, 16, 16], f16),), {})
+cnt: 1, ((T([128, 512, 16, 16], f16),), {})
+cnt: 3, ((T([128, 512, 8, 8], f16),), {})
+cnt: 2, ((T([128, 2048, 8, 8], f16),), {})
+Operator: aten.silu_backward.default
+cnt: 2, ((T([128, 2048, 8, 8], f16), T([128, 2048, 8, 8], f16)), {})
+cnt: 3, ((T([128, 512, 8, 8], f16), T([128, 512, 8, 8], f16)), {})
+cnt: 1, ((T([128, 512, 16, 16], f16), T([128, 512, 16, 16], f16)), {})
+cnt: 2, ((T([128, 1024, 16, 16], f16), T([128, 1024, 16, 16], f16)), {})
+cnt: 3, ((T([128, 256, 16, 16], f16), T([128, 256, 16, 16], f16)), {})
+cnt: 1, ((T([128, 256, 32, 32], f16), T([128, 256, 32, 32], f16)), {})
+cnt: 2, ((T([128, 512, 32, 32], f16), T([128, 512, 32, 32], f16)), {})
+cnt: 3, ((T([128, 128, 32, 32], f16), T([128, 128, 32, 32], f16)), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 128, 64, 64], f16)), {})
+cnt: 2, ((T([128, 256, 64, 64], f16), T([128, 256, 64, 64], f16)), {})
+cnt: 4, ((T([128, 64, 64, 64], f16), T([128, 64, 64, 64], f16)), {})
+cnt: 1, ((T([128, 64, 128, 128], f16), T([128, 64, 128, 128], f16)), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([128, 32, 128, 128], f16)), {})
+cnt: 1, ((T([128, 24, 128, 128], f16), T([128, 24, 128, 128], f16)), {})
+Operator: aten.slice_backward.default
+cnt: 2, ((T([8192, 8, 12], f16), [8192, 8, 23], 2, 11, 9223372036854775807, 1), {})
+cnt: 2, ((T([8192, 8, 23], f16), [8192, 9, 23], 1, 0, 8, 1), {})
+cnt: 2, ((T([8192, 9, 23], f16), [8192, 9, 23], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([16384, 4, 12], f16), [16384, 4, 23], 2, 11, 9223372036854775807, 1), {})
+cnt: 2, ((T([16384, 4, 23], f16), [16384, 5, 23], 1, 0, 4, 1), {})
+cnt: 2, ((T([16384, 5, 23], f16), [16384, 5, 23], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([32768, 8, 12], f16), [32768, 8, 23], 2, 11, 9223372036854775807, 1), {})
+cnt: 2, ((T([32768, 8, 23], f16), [32768, 9, 23], 1, 0, 8, 1), {})
+cnt: 2, ((T([32768, 9, 23], f16), [32768, 9, 23], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.split_with_sizes.default
+cnt: 1, ((T([1024, 4, 144, 48], f16, stride=(27648, 144, 1, 576)), [16, 32], -1), {})
+cnt: 1, ((T([1024, 4, 144, 80], f16, stride=(46080, 144, 1, 576)), [16, 64], -1), {})
+cnt: 1, ((T([1024, 1, 144, 80], f16, stride=(11520, 144, 1, 144)), [16, 64], -1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 1, ((T([1024, 8, 12, 8, 12], f16, stride=(9216, 144, 1, 1152, 12)), [2], True), {})
+cnt: 1, ((T([1024, 8, 12, 8, 12], f16, stride=(9216, 1152, 12, 144, 1)), [2], True), {})
+cnt: 1, ((T([4096, 4, 12, 4, 12], f16, stride=(2304, 144, 1, 576, 12)), [2], True), {})
+cnt: 1, ((T([4096, 4, 12, 4, 12], f16, stride=(2304, 576, 12, 144, 1)), [2], True), {})
+cnt: 1, ((T([4096, 8, 12, 8, 12], f16, stride=(9216, 144, 1, 1152, 12)), [2], True), {})
+cnt: 1, ((T([4096, 8, 12, 8, 12], f16, stride=(9216, 1152, 12, 144, 1)), [2], True), {})
+cnt: 1, ((T([128, 256, 16, 16], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 128, 32, 32], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 64, 64, 64], f16), [2, 3], True), {})
+Operator: aten.unfold_backward.default
+cnt: 1, ((T([128, 640, 1, 1, 12, 12], f16), [128, 640, 1, 12, 12], 3, 12, 8), {})
+cnt: 1, ((T([128, 640, 1, 12, 12], f16), [128, 640, 12, 12], 2, 12, 8), {})
+cnt: 1, ((T([128, 640, 2, 2, 12, 12], f16), [128, 640, 2, 20, 12], 3, 12, 8), {})
+cnt: 1, ((T([128, 640, 2, 20, 12], f16), [128, 640, 20, 20], 2, 12, 8), {})
+cnt: 1, ((T([128, 384, 2, 2, 12, 12], f16), [128, 384, 2, 20, 12], 3, 12, 8), {})
+cnt: 1, ((T([128, 384, 2, 20, 12], f16), [128, 384, 20, 20], 2, 12, 8), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/ecaresnet101d_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/ecaresnet101d_training.txt
new file mode 100644
index 0000000000000..21e66cff13b0f
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/ecaresnet101d_training.txt
@@ -0,0 +1,195 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 5, ((T([64, 2048, 7, 7], f16), T([64, 2048, 7, 7], f16)), {})
+cnt: 46, ((T([64, 1024, 14, 14], f16), T([64, 1024, 14, 14], f16)), {})
+cnt: 8, ((T([64, 512, 28, 28], f16), T([64, 512, 28, 28], f16)), {})
+cnt: 6, ((T([64, 256, 56, 56], f16), T([64, 256, 56, 56], f16)), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 106, ((T([], i64), 1), {})
+cnt: 3, ((T([64, 256, 56, 56], f16), T([64, 256, 56, 56], f16)), {})
+cnt: 4, ((T([64, 512, 28, 28], f16), T([64, 512, 28, 28], f16)), {})
+cnt: 23, ((T([64, 1024, 14, 14], f16), T([64, 1024, 14, 14], f16)), {})
+cnt: 3, ((T([64, 2048, 7, 7], f16), T([64, 2048, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([64, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([64, 256, 56, 56], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+cnt: 1, ((T([64, 512, 28, 28], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+cnt: 1, ((T([64, 1024, 14, 14], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([64, 1024, 7, 7], f16), T([64, 1024, 14, 14], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([64, 512, 28, 28], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([64, 256, 56, 56], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([32, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([64, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 64, 56, 56], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([64, 64, 56, 56], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 1, 256], f16), T([1, 1, 5], f16), None, [1], [2], [1], False, [0], 1), {})
+cnt: 2, ((T([64, 256, 56, 56], f16), T([64, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 56, 56], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([128, 128, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([64, 128, 28, 28], f16), T([512, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([64, 1, 512], f16), T([1, 1, 5], f16), None, [1], [2], [1], False, [0], 1), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([512, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 512, 28, 28], f16), T([128, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 128, 28, 28], f16), T([128, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 28, 28], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([256, 256, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 23, ((T([64, 256, 14, 14], f16), T([1024, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 23, ((T([64, 1, 1024], f16), T([1, 1, 5], f16), None, [1], [2], [1], False, [0], 1), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([1024, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 22, ((T([64, 1024, 14, 14], f16), T([256, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 22, ((T([64, 256, 14, 14], f16), T([256, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([512, 512, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 512, 7, 7], f16), T([2048, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 1, 2048], f16), T([1, 1, 7], f16), None, [1], [3], [1], False, [0], 1), {})
+cnt: 1, ((T([64, 1024, 7, 7], f16), T([2048, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 2048, 7, 7], f16), T([512, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 512, 7, 7], f16), T([512, 512, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([64, 1, 2048], f16), T([64, 1, 2048], f16), T([1, 1, 7], f16), [0], [1], [3], [1], False, [0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 2048, 7, 7], f16), T([64, 512, 7, 7], f16), T([2048, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16), T([512, 512, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 512, 7, 7], f16), T([64, 2048, 7, 7], f16), T([512, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 2048, 7, 7], f16), T([64, 1024, 7, 7], f16), T([2048, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 14, 14], f16), T([512, 512, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([64, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 23, ((T([64, 1, 1024], f16), T([64, 1, 1024], f16), T([1, 1, 5], f16), [0], [1], [2], [1], False, [0], 1, [True, True, False]), {})
+cnt: 23, ((T([64, 1024, 14, 14], f16), T([64, 256, 14, 14], f16), T([1024, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 22, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16), T([256, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 22, ((T([64, 256, 14, 14], f16), T([64, 1024, 14, 14], f16), T([256, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 1024, 14, 14], f16), T([64, 512, 14, 14], f16), T([1024, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 256, 14, 14], f16), T([64, 256, 28, 28], f16), T([256, 256, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([64, 512, 28, 28], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([64, 1, 512], f16), T([64, 1, 512], f16), T([1, 1, 5], f16), [0], [1], [2], [1], False, [0], 1, [True, True, False]), {})
+cnt: 4, ((T([64, 512, 28, 28], f16), T([64, 128, 28, 28], f16), T([512, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16), T([128, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 128, 28, 28], f16), T([64, 512, 28, 28], f16), T([128, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 28, 28], f16), T([64, 256, 28, 28], f16), T([512, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 28, 28], f16), T([64, 128, 56, 56], f16), T([128, 128, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 256, 56, 56], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 1, 256], f16), T([64, 1, 256], f16), T([1, 1, 5], f16), [0], [1], [2], [1], False, [0], 1, [True, True, False]), {})
+cnt: 4, ((T([64, 256, 56, 56], f16), T([64, 64, 56, 56], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 64, 56, 56], f16), T([64, 256, 56, 56], f16), T([64, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16), T([64, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64, 32, 112, 112], f16), T([64, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16), T([32, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([64, 3, 224, 224], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 4, ((T([64, 2048, 7, 7], f16, stride=(2048, 1, 0, 0)), 49), {})
+cnt: 23, ((T([64, 1024, 14, 14], f16, stride=(1024, 1, 0, 0)), 196), {})
+cnt: 4, ((T([64, 512, 28, 28], f16, stride=(512, 1, 0, 0)), 784), {})
+cnt: 3, ((T([64, 256, 56, 56], f16, stride=(256, 1, 0, 0)), 3136), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([64, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([64, 64, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 3, ((T([64, 256, 56, 56], f16), [2, 3]), {})
+cnt: 4, ((T([64, 512, 28, 28], f16), [2, 3]), {})
+cnt: 23, ((T([64, 1024, 14, 14], f16), [2, 3]), {})
+cnt: 3, ((T([64, 2048, 7, 7], f16), [2, 3]), {})
+cnt: 1, ((T([64, 2048, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([64, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 2048], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 6, ((T([64, 256, 56, 56], f16), T([64, 256, 56, 56], f16, stride=(256, 1, 0, 0))), {})
+cnt: 8, ((T([64, 512, 28, 28], f16), T([64, 512, 28, 28], f16, stride=(512, 1, 0, 0))), {})
+cnt: 46, ((T([64, 1024, 14, 14], f16), T([64, 1024, 14, 14], f16, stride=(1024, 1, 0, 0))), {})
+cnt: 6, ((T([64, 2048, 7, 7], f16), T([64, 2048, 7, 7], f16, stride=(2048, 1, 0, 0))), {})
+cnt: 3, ((T([64, 2048, 7, 7], f16), T([64, 2048, 7, 7], f16)), {})
+cnt: 23, ((T([64, 1024, 14, 14], f16), T([64, 1024, 14, 14], f16)), {})
+cnt: 4, ((T([64, 512, 28, 28], f16), T([64, 512, 28, 28], f16)), {})
+cnt: 3, ((T([64, 256, 56, 56], f16), T([64, 256, 56, 56], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([64, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 6, ((T([64, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([64, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([64, 128, 28, 28], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([64, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 45, ((T([64, 256, 14, 14], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 24, ((T([64, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([64, 512, 7, 7], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([64, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 4, ((T([64, 2048, 7, 7], f16), T([64, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 24, ((T([64, 1024, 14, 14], f16), T([64, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 45, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([64, 512, 28, 28], f16), T([64, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([64, 256, 56, 56], f16), T([64, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([64, 32, 112, 112], f16),), {})
+cnt: 1, ((T([64, 64, 112, 112], f16),), {})
+cnt: 6, ((T([64, 64, 56, 56], f16),), {})
+cnt: 3, ((T([64, 256, 56, 56], f16),), {})
+cnt: 1, ((T([64, 128, 56, 56], f16),), {})
+cnt: 7, ((T([64, 128, 28, 28], f16),), {})
+cnt: 4, ((T([64, 512, 28, 28], f16),), {})
+cnt: 1, ((T([64, 256, 28, 28], f16),), {})
+cnt: 45, ((T([64, 256, 14, 14], f16),), {})
+cnt: 23, ((T([64, 1024, 14, 14], f16),), {})
+cnt: 1, ((T([64, 512, 14, 14], f16),), {})
+cnt: 5, ((T([64, 512, 7, 7], f16),), {})
+cnt: 3, ((T([64, 2048, 7, 7], f16),), {})
+Operator: aten.sigmoid.default
+cnt: 3, ((T([64, 1, 256], f16),), {})
+cnt: 4, ((T([64, 1, 512], f16),), {})
+cnt: 23, ((T([64, 1, 1024], f16),), {})
+cnt: 3, ((T([64, 1, 2048], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 3, ((T([64, 1, 2048], f16), T([64, 1, 2048], f16)), {})
+cnt: 23, ((T([64, 1, 1024], f16), T([64, 1, 1024], f16)), {})
+cnt: 4, ((T([64, 1, 512], f16), T([64, 1, 512], f16)), {})
+cnt: 3, ((T([64, 1, 256], f16), T([64, 1, 256], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 3, ((T([64, 2048, 7, 7], f16), [2, 3], True), {})
+cnt: 23, ((T([64, 1024, 14, 14], f16), [2, 3], True), {})
+cnt: 4, ((T([64, 512, 28, 28], f16), [2, 3], True), {})
+cnt: 3, ((T([64, 256, 56, 56], f16), [2, 3], True), {})
+Operator: aten.threshold_backward.default
+cnt: 3, ((T([64, 2048, 7, 7], f16), T([64, 2048, 7, 7], f16), 0), {})
+cnt: 5, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16), 0), {})
+cnt: 23, ((T([64, 1024, 14, 14], f16), T([64, 1024, 14, 14], f16), 0), {})
+cnt: 45, ((T([64, 256, 14, 14], f16), T([64, 256, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([64, 256, 28, 28], f16), 0), {})
+cnt: 4, ((T([64, 512, 28, 28], f16), T([64, 512, 28, 28], f16), 0), {})
+cnt: 7, ((T([64, 128, 28, 28], f16), T([64, 128, 28, 28], f16), 0), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 128, 56, 56], f16), 0), {})
+cnt: 3, ((T([64, 256, 56, 56], f16), T([64, 256, 56, 56], f16), 0), {})
+cnt: 6, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16), 0), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64, 64, 112, 112], f16), 0), {})
+cnt: 2, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/ese_vovnet19b_dw_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/ese_vovnet19b_dw_training.txt
new file mode 100644
index 0000000000000..f81cd27ece756
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/ese_vovnet19b_dw_training.txt
@@ -0,0 +1,182 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 23, ((T([], i64), 1), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 1024, 7, 7], f16)), {})
+cnt: 2, ((T([128, 224, 7, 7], f16, stride=(70560, 49, 7, 1)), T([128, 224, 7, 7], f16)), {})
+cnt: 1, ((T([128, 768, 7, 7], f16, stride=(70560, 49, 7, 1)), T([128, 768, 7, 7], f16)), {})
+cnt: 1, ((T([128, 768, 14, 14], f16), T([128, 768, 14, 14], f16)), {})
+cnt: 2, ((T([128, 192, 14, 14], f16, stride=(213248, 196, 14, 1)), T([128, 192, 14, 14], f16)), {})
+cnt: 1, ((T([128, 512, 14, 14], f16, stride=(213248, 196, 14, 1)), T([128, 512, 14, 14], f16)), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16)), {})
+cnt: 2, ((T([128, 160, 28, 28], f16, stride=(577024, 784, 28, 1)), T([128, 160, 28, 28], f16)), {})
+cnt: 1, ((T([128, 256, 28, 28], f16, stride=(577024, 784, 28, 1)), T([128, 256, 28, 28], f16)), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16)), {})
+cnt: 2, ((T([128, 128, 56, 56], f16, stride=(1404928, 3136, 56, 1)), T([128, 128, 56, 56], f16)), {})
+cnt: 1, ((T([128, 64, 56, 56], f16, stride=(1404928, 3136, 56, 1)), T([128, 64, 56, 56], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1024], f16), T([1024, 1000], f16, stride=(1, 1024))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([128, 64, 56, 56], f16), T([128, 128, 56, 56], f16), T([128, 128, 56, 56], f16), T([128, 128, 56, 56], f16)], 1), {})
+cnt: 1, (([T([128, 256, 28, 28], f16), T([128, 160, 28, 28], f16), T([128, 160, 28, 28], f16), T([128, 160, 28, 28], f16)], 1), {})
+cnt: 1, (([T([128, 512, 14, 14], f16), T([128, 192, 14, 14], f16), T([128, 192, 14, 14], f16), T([128, 192, 14, 14], f16)], 1), {})
+cnt: 1, (([T([128, 768, 7, 7], f16), T([128, 224, 7, 7], f16), T([128, 224, 7, 7], f16), T([128, 224, 7, 7], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([64, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([64, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([64, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 128, 56, 56], f16), T([128, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 128), {})
+cnt: 3, ((T([128, 128, 56, 56], f16), T([128, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 448, 56, 56], f16), T([256, 448, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 1, 1], f16), T([256, 256, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 28, 28], f16), T([160, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 160, 28, 28], f16), T([160, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 160), {})
+cnt: 3, ((T([128, 160, 28, 28], f16), T([160, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 736, 28, 28], f16), T([512, 736, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 1, 1], f16), T([512, 512, 1, 1], f16), T([512], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 14, 14], f16), T([192, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 192, 14, 14], f16), T([192, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 192), {})
+cnt: 3, ((T([128, 192, 14, 14], f16), T([192, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1088, 14, 14], f16), T([768, 1088, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 768, 1, 1], f16), T([768, 768, 1, 1], f16), T([768], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 768, 7, 7], f16), T([224, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 224, 7, 7], f16), T([224, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 224), {})
+cnt: 3, ((T([128, 224, 7, 7], f16), T([224, 224, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1440, 7, 7], f16), T([1024, 1440, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1024, 1, 1], f16), T([1024, 1024, 1, 1], f16), T([1024], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1024, 1, 1], f16), T([128, 1024, 1, 1], f16), T([1024, 1024, 1, 1], f16), [1024], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 1440, 7, 7], f16), T([1024, 1440, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 224, 7, 7], f16), T([128, 224, 7, 7], f16), T([224, 224, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 224, 7, 7], f16), T([128, 224, 7, 7], f16), T([224, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 224, [True, True, False]), {})
+cnt: 1, ((T([128, 224, 7, 7], f16), T([128, 768, 7, 7], f16), T([224, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 768, 1, 1], f16), T([128, 768, 1, 1], f16), T([768, 768, 1, 1], f16), [768], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 768, 14, 14], f16), T([128, 1088, 14, 14], f16), T([768, 1088, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 192, 14, 14], f16), T([128, 192, 14, 14], f16), T([192, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 192, 14, 14], f16), T([128, 192, 14, 14], f16), T([192, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 192, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 14, 14], f16), T([128, 512, 14, 14], f16), T([192, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 1, 1], f16), T([128, 512, 1, 1], f16), T([512, 512, 1, 1], f16), [512], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([128, 736, 28, 28], f16), T([512, 736, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 160, 28, 28], f16), T([128, 160, 28, 28], f16), T([160, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 160, 28, 28], f16), T([128, 160, 28, 28], f16), T([160, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 160, [True, True, False]), {})
+cnt: 1, ((T([128, 160, 28, 28], f16), T([128, 256, 28, 28], f16), T([160, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 1, 1], f16), T([128, 256, 1, 1], f16), T([256, 256, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([128, 448, 56, 56], f16), T([256, 448, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 128, 56, 56], f16), T([128, 128, 56, 56], f16), T([128, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 128, 56, 56], f16), T([128, 128, 56, 56], f16), T([128, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 128, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([128, 64, 56, 56], f16), T([128, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), T([64, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 112, 112], f16), T([64, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), T([64, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), T([64, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 3, 224, 224], f16), T([64, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 2, ((T([128, 1024, 7, 7], f16, stride=(1024, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 768, 14, 14], f16, stride=(768, 1, 0, 0)), 196), {})
+cnt: 1, ((T([128, 512, 28, 28], f16, stride=(512, 1, 0, 0)), 784), {})
+cnt: 1, ((T([128, 256, 56, 56], f16, stride=(256, 1, 0, 0)), 3136), {})
+Operator: aten.hardsigmoid.default
+cnt: 1, ((T([128, 256, 1, 1], f16),), {})
+cnt: 1, ((T([128, 512, 1, 1], f16),), {})
+cnt: 1, ((T([128, 768, 1, 1], f16),), {})
+cnt: 1, ((T([128, 1024, 1, 1], f16),), {})
+Operator: aten.hardsigmoid_backward.default
+cnt: 1, ((T([128, 1024, 1, 1], f16), T([128, 1024, 1, 1], f16)), {})
+cnt: 1, ((T([128, 768, 1, 1], f16), T([128, 768, 1, 1], f16)), {})
+cnt: 1, ((T([128, 512, 1, 1], f16), T([128, 512, 1, 1], f16)), {})
+cnt: 1, ((T([128, 256, 1, 1], f16), T([128, 256, 1, 1], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([128, 256, 56, 56], f16), [3, 3], [2, 2], [0, 0], [1, 1], True), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), [3, 3], [2, 2], [0, 0], [1, 1], True), {})
+cnt: 1, ((T([128, 768, 14, 14], f16), [3, 3], [2, 2], [0, 0], [1, 1], True), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([128, 768, 7, 7], f16), T([128, 768, 14, 14], f16), [3, 3], [2, 2], [0, 0], [1, 1], True, T([128, 768, 7, 7], i64)), {})
+cnt: 1, ((T([128, 512, 14, 14], f16), T([128, 512, 28, 28], f16), [3, 3], [2, 2], [0, 0], [1, 1], True, T([128, 512, 14, 14], i64)), {})
+cnt: 1, ((T([128, 256, 28, 28], f16), T([128, 256, 56, 56], f16), [3, 3], [2, 2], [0, 0], [1, 1], True, T([128, 256, 28, 28], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 256, 56, 56], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 768, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 1024], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 1024], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([128, 256, 56, 56], f16), T([128, 256, 1, 1], f16)), {})
+cnt: 2, ((T([128, 512, 28, 28], f16), T([128, 512, 1, 1], f16)), {})
+cnt: 2, ((T([128, 768, 14, 14], f16), T([128, 768, 1, 1], f16)), {})
+cnt: 2, ((T([128, 1024, 7, 7], f16), T([128, 1024, 1, 1], f16)), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 1024, 7, 7], f16)), {})
+cnt: 1, ((T([128, 768, 14, 14], f16), T([128, 768, 14, 14], f16)), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16)), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 160, 28, 28], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 192, 14, 14], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 768, 14, 14], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 224, 7, 7], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 224, 7, 7], f16), T([128, 224, 7, 7], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f32), T([224], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 768, 14, 14], f16), T([128, 768, 14, 14], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f32), T([768], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 192, 14, 14], f16), T([128, 192, 14, 14], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 160, 28, 28], f16), T([128, 160, 28, 28], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 128, 56, 56], f16), T([128, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([128, 64, 112, 112], f16),), {})
+cnt: 1, ((T([128, 64, 56, 56], f16),), {})
+cnt: 4, ((T([128, 128, 56, 56], f16),), {})
+cnt: 1, ((T([128, 256, 56, 56], f16),), {})
+cnt: 4, ((T([128, 160, 28, 28], f16),), {})
+cnt: 1, ((T([128, 512, 28, 28], f16),), {})
+cnt: 4, ((T([128, 192, 14, 14], f16),), {})
+cnt: 1, ((T([128, 768, 14, 14], f16),), {})
+cnt: 4, ((T([128, 224, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 768, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), [2, 3], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 1024, 7, 7], f16), 0), {})
+cnt: 1, ((T([128, 224, 7, 7], f16, stride=(70560, 49, 7, 1)), T([128, 224, 7, 7], f16), 0), {})
+cnt: 3, ((T([128, 224, 7, 7], f16), T([128, 224, 7, 7], f16), 0), {})
+cnt: 1, ((T([128, 768, 14, 14], f16), T([128, 768, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 192, 14, 14], f16, stride=(213248, 196, 14, 1)), T([128, 192, 14, 14], f16), 0), {})
+cnt: 3, ((T([128, 192, 14, 14], f16), T([128, 192, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 160, 28, 28], f16, stride=(577024, 784, 28, 1)), T([128, 160, 28, 28], f16), 0), {})
+cnt: 3, ((T([128, 160, 28, 28], f16), T([128, 160, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 128, 56, 56], f16, stride=(1404928, 3136, 56, 1)), T([128, 128, 56, 56], f16), 0), {})
+cnt: 3, ((T([128, 128, 56, 56], f16), T([128, 128, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), 0), {})
+cnt: 2, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/fbnetc_100_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/fbnetc_100_training.txt
new file mode 100644
index 0000000000000..4be2a0309a2e5
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/fbnetc_100_training.txt
@@ -0,0 +1,189 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 65, ((T([], i64), 1), {})
+cnt: 2, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16)), {})
+cnt: 4, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16)), {})
+cnt: 6, ((T([128, 32, 28, 28], f16), T([128, 32, 28, 28], f16)), {})
+cnt: 6, ((T([128, 64, 14, 14], f16), T([128, 64, 14, 14], f16)), {})
+cnt: 6, ((T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16)), {})
+cnt: 6, ((T([128, 184, 7, 7], f16), T([128, 184, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1984], f16), T([1984, 1000], f16, stride=(1, 1984))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([16, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 16, 112, 112], f16), T([16, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([16, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 16), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([96, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([96, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 96), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([24, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 24, 56, 56], f16), T([24, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([24, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 24), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([144, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([144, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 144), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([32, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 28, 28], f16), T([96, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 28, 28], f16), T([96, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 96), {})
+cnt: 1, ((T([128, 96, 28, 28], f16), T([32, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 32, 28, 28], f16), T([192, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 28, 28], f16), T([192, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 192), {})
+cnt: 2, ((T([128, 192, 28, 28], f16), T([32, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 28, 28], f16), T([192, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 192), {})
+cnt: 1, ((T([128, 192, 28, 28], f16), T([192, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 192), {})
+cnt: 2, ((T([128, 192, 14, 14], f16), T([64, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 14, 14], f16), T([192, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 14, 14], f16), T([192, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 192), {})
+cnt: 3, ((T([128, 64, 14, 14], f16), T([384, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 384, 14, 14], f16), T([384, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 384), {})
+cnt: 2, ((T([128, 384, 14, 14], f16), T([64, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 384, 14, 14], f16), T([112, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 112, 14, 14], f16), T([672, 112, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), T([672, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 672), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), T([112, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 112, 14, 14], f16), T([336, 112, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 336, 14, 14], f16), T([336, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 336), {})
+cnt: 1, ((T([128, 336, 14, 14], f16), T([112, 336, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([672, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 672), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([184, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 184, 7, 7], f16), T([1104, 184, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 1104, 7, 7], f16), T([1104, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 1104), {})
+cnt: 3, ((T([128, 1104, 7, 7], f16), T([184, 1104, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1104, 7, 7], f16), T([1104, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1104), {})
+cnt: 1, ((T([128, 1104, 7, 7], f16), T([352, 1104, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 352, 7, 7], f16), T([1984, 352, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1984, 7, 7], f16), T([128, 352, 7, 7], f16), T([1984, 352, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 352, 7, 7], f16), T([128, 1104, 7, 7], f16), T([352, 1104, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1104, 7, 7], f16), T([128, 1104, 7, 7], f16), T([1104, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1104, [True, True, False]), {})
+cnt: 4, ((T([128, 1104, 7, 7], f16), T([128, 184, 7, 7], f16), T([1104, 184, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 184, 7, 7], f16), T([128, 1104, 7, 7], f16), T([184, 1104, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 1104, 7, 7], f16), T([128, 1104, 7, 7], f16), T([1104, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 1104, [True, True, False]), {})
+cnt: 1, ((T([128, 184, 7, 7], f16), T([128, 672, 7, 7], f16), T([184, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 14, 14], f16), T([672, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 3, ((T([128, 672, 14, 14], f16), T([128, 112, 14, 14], f16), T([672, 112, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 112, 14, 14], f16), T([128, 336, 14, 14], f16), T([112, 336, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 336, 14, 14], f16), T([128, 336, 14, 14], f16), T([336, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 336, [True, True, False]), {})
+cnt: 1, ((T([128, 336, 14, 14], f16), T([128, 112, 14, 14], f16), T([336, 112, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 112, 14, 14], f16), T([128, 672, 14, 14], f16), T([112, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16), T([672, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 1, ((T([128, 112, 14, 14], f16), T([128, 384, 14, 14], f16), T([112, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16), T([384, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 384, [True, True, False]), {})
+cnt: 3, ((T([128, 384, 14, 14], f16), T([128, 64, 14, 14], f16), T([384, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 64, 14, 14], f16), T([128, 384, 14, 14], f16), T([64, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 64, 14, 14], f16), T([128, 192, 14, 14], f16), T([64, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 14, 14], f16), T([128, 192, 14, 14], f16), T([192, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 192, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 14, 14], f16), T([128, 64, 14, 14], f16), T([192, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 14, 14], f16), T([128, 192, 28, 28], f16), T([192, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 192, [True, True, False]), {})
+cnt: 3, ((T([128, 192, 28, 28], f16), T([128, 32, 28, 28], f16), T([192, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 32, 28, 28], f16), T([128, 192, 28, 28], f16), T([32, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 28, 28], f16), T([128, 192, 28, 28], f16), T([192, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 192, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 28, 28], f16), T([128, 192, 28, 28], f16), T([192, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 192, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 28, 28], f16), T([128, 96, 28, 28], f16), T([32, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 28, 28], f16), T([128, 96, 28, 28], f16), T([96, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 96, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 28, 28], f16), T([128, 32, 28, 28], f16), T([96, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 28, 28], f16), T([128, 144, 28, 28], f16), T([32, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 144, 56, 56], f16), T([144, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 144, [True, True, False]), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([128, 24, 56, 56], f16), T([144, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16), T([24, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16), T([24, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 24, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 96, 56, 56], f16), T([24, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 112, 112], f16), T([96, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 96, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([128, 16, 112, 112], f16), T([96, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 16, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 3, 224, 224], f16), T([16, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 1984, 7, 7], f16, stride=(1984, 1, 0, 0)), 49), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 1984, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 1984], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 1984], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 4, ((T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 32, 28, 28], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 96, 28, 28], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 192, 28, 28], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 192, 14, 14], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 64, 14, 14], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 6, ((T([128, 384, 14, 14], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 112, 14, 14], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 336, 14, 14], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 184, 7, 7], f16), T([184], f16), T([184], f16), T([184], f16), T([184], f16), True, 0.1, 1e-05), {})
+cnt: 8, ((T([128, 1104, 7, 7], f16), T([1104], f16), T([1104], f16), T([1104], f16), T([1104], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 352, 7, 7], f16), T([352], f16), T([352], f16), T([352], f16), T([352], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 1984, 7, 7], f16), T([1984], f16), T([1984], f16), T([1984], f16), T([1984], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([128, 1984, 7, 7], f16), T([128, 1984, 7, 7], f16), T([1984], f16), T([1984], f16), T([1984], f16), T([1984], f32), T([1984], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 352, 7, 7], f16), T([128, 352, 7, 7], f16), T([352], f16), T([352], f16), T([352], f16), T([352], f32), T([352], f32), True, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([128, 1104, 7, 7], f16), T([128, 1104, 7, 7], f16), T([1104], f16), T([1104], f16), T([1104], f16), T([1104], f32), T([1104], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 184, 7, 7], f16), T([128, 184, 7, 7], f16), T([184], f16), T([184], f16), T([184], f16), T([184], f32), T([184], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f32), T([112], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 336, 14, 14], f16), T([128, 336, 14, 14], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f32), T([336], f32), True, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 64, 14, 14], f16), T([128, 64, 14, 14], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 192, 14, 14], f16), T([128, 192, 14, 14], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 192, 28, 28], f16), T([128, 192, 28, 28], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 32, 28, 28], f16), T([128, 32, 28, 28], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 96, 28, 28], f16), T([128, 96, 28, 28], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 144, 28, 28], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([128, 144, 56, 56], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), True, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([128, 96, 112, 112], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 3, ((T([128, 16, 112, 112], f16),), {})
+cnt: 1, ((T([128, 96, 112, 112], f16),), {})
+cnt: 1, ((T([128, 96, 56, 56], f16),), {})
+cnt: 4, ((T([128, 24, 56, 56], f16),), {})
+cnt: 1, ((T([128, 144, 56, 56], f16),), {})
+cnt: 1, ((T([128, 144, 28, 28], f16),), {})
+cnt: 2, ((T([128, 96, 28, 28], f16),), {})
+cnt: 5, ((T([128, 192, 28, 28], f16),), {})
+cnt: 3, ((T([128, 192, 14, 14], f16),), {})
+cnt: 6, ((T([128, 384, 14, 14], f16),), {})
+cnt: 5, ((T([128, 672, 14, 14], f16),), {})
+cnt: 2, ((T([128, 336, 14, 14], f16),), {})
+cnt: 1, ((T([128, 672, 7, 7], f16),), {})
+cnt: 8, ((T([128, 1104, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1984, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([128, 1984, 7, 7], f16), T([128, 1984, 7, 7], f16), 0), {})
+cnt: 8, ((T([128, 1104, 7, 7], f16), T([128, 1104, 7, 7], f16), 0), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16), 0), {})
+cnt: 5, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16), 0), {})
+cnt: 2, ((T([128, 336, 14, 14], f16), T([128, 336, 14, 14], f16), 0), {})
+cnt: 6, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16), 0), {})
+cnt: 3, ((T([128, 192, 14, 14], f16), T([128, 192, 14, 14], f16), 0), {})
+cnt: 5, ((T([128, 192, 28, 28], f16), T([128, 192, 28, 28], f16), 0), {})
+cnt: 2, ((T([128, 96, 28, 28], f16), T([128, 96, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 144, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([128, 144, 56, 56], f16), 0), {})
+cnt: 4, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([128, 96, 112, 112], f16), 0), {})
+cnt: 3, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/fbnetv3_b_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/fbnetv3_b_training.txt
new file mode 100644
index 0000000000000..85ee90a54b645
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/fbnetv3_b_training.txt
@@ -0,0 +1,287 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 87, ((T([], i64), 1), {})
+cnt: 4, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16)), {})
+cnt: 6, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16)), {})
+cnt: 8, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16)), {})
+cnt: 8, ((T([128, 72, 14, 14], f16), T([128, 72, 14, 14], f16)), {})
+cnt: 10, ((T([128, 120, 14, 14], f16), T([128, 120, 14, 14], f16)), {})
+cnt: 10, ((T([128, 184, 7, 7], f16), T([128, 184, 7, 7], f16)), {})
+cnt: 1, ((T([128, 1104, 7, 7], f16), T([128, 1104, 7, 7], f16)), {})
+cnt: 5, ((T([128, 736, 7, 7], f16), T([128, 736, 7, 7], f16)), {})
+cnt: 1, ((T([128, 720, 7, 7], f16), T([128, 720, 7, 7], f16)), {})
+cnt: 6, ((T([128, 360, 14, 14], f16), T([128, 360, 14, 14], f16)), {})
+cnt: 5, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1984], f16), T([1984, 1000], f16, stride=(1, 1984))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+cnt: 3, ((T([128, 16, 112, 112], f16),), {})
+cnt: 1, ((T([128, 64, 112, 112], f16),), {})
+cnt: 1, ((T([128, 64, 56, 56], f16),), {})
+cnt: 6, ((T([128, 48, 56, 56], f16),), {})
+cnt: 1, ((T([128, 120, 56, 56], f16),), {})
+cnt: 9, ((T([128, 120, 28, 28], f16),), {})
+cnt: 1, ((T([128, 8, 1, 1], f16),), {})
+cnt: 4, ((T([128, 16, 1, 1], f16),), {})
+cnt: 1, ((T([128, 200, 28, 28], f16),), {})
+cnt: 1, ((T([128, 200, 14, 14], f16),), {})
+cnt: 8, ((T([128, 216, 14, 14], f16),), {})
+cnt: 12, ((T([128, 360, 14, 14], f16),), {})
+cnt: 1, ((T([128, 24, 1, 1], f16),), {})
+cnt: 6, ((T([128, 32, 1, 1], f16),), {})
+cnt: 1, ((T([128, 720, 14, 14], f16),), {})
+cnt: 1, ((T([128, 720, 7, 7], f16),), {})
+cnt: 10, ((T([128, 736, 7, 7], f16),), {})
+cnt: 6, ((T([128, 48, 1, 1], f16),), {})
+cnt: 2, ((T([128, 1104, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1344, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1984, 1, 1], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([16, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 16, 112, 112], f16), T([16, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 16), {})
+cnt: 2, ((T([128, 16, 112, 112], f16), T([16, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([64, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([64, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([24, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 24, 56, 56], f16), T([48, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 48, 56, 56], f16), T([48, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 48), {})
+cnt: 3, ((T([128, 48, 56, 56], f16), T([24, 48, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([120, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 120, 56, 56], f16), T([120, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 120), {})
+cnt: 1, ((T([128, 120, 1, 1], f16), T([8, 120, 1, 1], f16), T([8], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 8, 1, 1], f16), T([120, 8, 1, 1], f16), T([120], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 120, 28, 28], f16), T([40, 120, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 40, 28, 28], f16), T([120, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 120, 28, 28], f16), T([120, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 120), {})
+cnt: 4, ((T([128, 120, 1, 1], f16), T([16, 120, 1, 1], f16), T([16], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 16, 1, 1], f16), T([120, 16, 1, 1], f16), T([120], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([200, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 200, 28, 28], f16), T([200, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 200), {})
+cnt: 1, ((T([128, 200, 14, 14], f16), T([72, 200, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 72, 14, 14], f16), T([216, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 216, 14, 14], f16), T([216, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 216), {})
+cnt: 4, ((T([128, 216, 14, 14], f16), T([72, 216, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 72, 14, 14], f16), T([360, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 360, 14, 14], f16), T([360, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 360), {})
+cnt: 1, ((T([128, 360, 1, 1], f16), T([24, 360, 1, 1], f16), T([24], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 24, 1, 1], f16), T([360, 24, 1, 1], f16), T([360], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([128, 360, 14, 14], f16), T([120, 360, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 120, 14, 14], f16), T([360, 120, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 360, 14, 14], f16), T([360, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 360), {})
+cnt: 5, ((T([128, 360, 1, 1], f16), T([32, 360, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 32, 1, 1], f16), T([360, 32, 1, 1], f16), T([360], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 120, 14, 14], f16), T([720, 120, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 720, 14, 14], f16), T([720, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 720), {})
+cnt: 1, ((T([128, 720, 1, 1], f16), T([32, 720, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 1, 1], f16), T([720, 32, 1, 1], f16), T([720], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 720, 7, 7], f16), T([184, 720, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 184, 7, 7], f16), T([736, 184, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 736, 7, 7], f16), T([736, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 736), {})
+cnt: 5, ((T([128, 736, 1, 1], f16), T([48, 736, 1, 1], f16), T([48], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 48, 1, 1], f16), T([736, 48, 1, 1], f16), T([736], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 736, 7, 7], f16), T([184, 736, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 184, 7, 7], f16), T([1104, 184, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1104, 7, 7], f16), T([1104, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 1104), {})
+cnt: 1, ((T([128, 1104, 1, 1], f16), T([48, 1104, 1, 1], f16), T([48], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 48, 1, 1], f16), T([1104, 48, 1, 1], f16), T([1104], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1104, 7, 7], f16), T([224, 1104, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 224, 7, 7], f16), T([1344, 224, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1344, 1, 1], f16), T([1984, 1344, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1984, 1, 1], f16), T([128, 1344, 1, 1], f16), T([1984, 1344, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1344, 7, 7], f16), T([128, 224, 7, 7], f16), T([1344, 224, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 224, 7, 7], f16), T([128, 1104, 7, 7], f16), T([224, 1104, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1104, 1, 1], f16), T([128, 48, 1, 1], f16), T([1104, 48, 1, 1], f16), [1104], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 48, 1, 1], f16), T([128, 1104, 1, 1], f16), T([48, 1104, 1, 1], f16), [48], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 1104, 7, 7], f16), T([128, 1104, 7, 7], f16), T([1104, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 1104, [True, True, False]), {})
+cnt: 1, ((T([128, 1104, 7, 7], f16), T([128, 184, 7, 7], f16), T([1104, 184, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([128, 184, 7, 7], f16), T([128, 736, 7, 7], f16), T([184, 736, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([128, 736, 1, 1], f16), T([128, 48, 1, 1], f16), T([736, 48, 1, 1], f16), [736], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 5, ((T([128, 48, 1, 1], f16), T([128, 736, 1, 1], f16), T([48, 736, 1, 1], f16), [48], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 5, ((T([128, 736, 7, 7], f16), T([128, 736, 7, 7], f16), T([736, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 736, [True, True, False]), {})
+cnt: 5, ((T([128, 736, 7, 7], f16), T([128, 184, 7, 7], f16), T([736, 184, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 184, 7, 7], f16), T([128, 720, 7, 7], f16), T([184, 720, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 720, 1, 1], f16), T([128, 32, 1, 1], f16), T([720, 32, 1, 1], f16), [720], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 1, 1], f16), T([128, 720, 1, 1], f16), T([32, 720, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 720, 7, 7], f16), T([128, 720, 14, 14], f16), T([720, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 720, [True, True, False]), {})
+cnt: 1, ((T([128, 720, 14, 14], f16), T([128, 120, 14, 14], f16), T([720, 120, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([128, 120, 14, 14], f16), T([128, 360, 14, 14], f16), T([120, 360, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([128, 360, 1, 1], f16), T([128, 32, 1, 1], f16), T([360, 32, 1, 1], f16), [360], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 5, ((T([128, 32, 1, 1], f16), T([128, 360, 1, 1], f16), T([32, 360, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 5, ((T([128, 360, 14, 14], f16), T([128, 360, 14, 14], f16), T([360, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 360, [True, True, False]), {})
+cnt: 5, ((T([128, 360, 14, 14], f16), T([128, 120, 14, 14], f16), T([360, 120, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 360, 1, 1], f16), T([128, 24, 1, 1], f16), T([360, 24, 1, 1], f16), [360], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 24, 1, 1], f16), T([128, 360, 1, 1], f16), T([24, 360, 1, 1], f16), [24], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 360, 14, 14], f16), T([128, 360, 14, 14], f16), T([360, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 360, [True, True, False]), {})
+cnt: 1, ((T([128, 360, 14, 14], f16), T([128, 72, 14, 14], f16), T([360, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 72, 14, 14], f16), T([128, 216, 14, 14], f16), T([72, 216, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 216, 14, 14], f16), T([128, 216, 14, 14], f16), T([216, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 216, [True, True, False]), {})
+cnt: 4, ((T([128, 216, 14, 14], f16), T([128, 72, 14, 14], f16), T([216, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 72, 14, 14], f16), T([128, 200, 14, 14], f16), T([72, 200, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 200, 14, 14], f16), T([128, 200, 28, 28], f16), T([200, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 200, [True, True, False]), {})
+cnt: 1, ((T([128, 200, 28, 28], f16), T([128, 40, 28, 28], f16), T([200, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([128, 40, 28, 28], f16), T([128, 120, 28, 28], f16), T([40, 120, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 120, 1, 1], f16), T([128, 16, 1, 1], f16), T([120, 16, 1, 1], f16), [120], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([128, 16, 1, 1], f16), T([128, 120, 1, 1], f16), T([16, 120, 1, 1], f16), [16], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16), T([120, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 4, ((T([128, 120, 28, 28], f16), T([128, 40, 28, 28], f16), T([120, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 120, 1, 1], f16), T([128, 8, 1, 1], f16), T([120, 8, 1, 1], f16), [120], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 8, 1, 1], f16), T([128, 120, 1, 1], f16), T([8, 120, 1, 1], f16), [8], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 120, 28, 28], f16), T([128, 120, 56, 56], f16), T([120, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 1, ((T([128, 120, 56, 56], f16), T([128, 24, 56, 56], f16), T([120, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 24, 56, 56], f16), T([128, 48, 56, 56], f16), T([24, 48, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 48, 56, 56], f16), T([128, 48, 56, 56], f16), T([48, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 48, [True, True, False]), {})
+cnt: 3, ((T([128, 48, 56, 56], f16), T([128, 24, 56, 56], f16), T([48, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 64, 56, 56], f16), T([24, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 112, 112], f16), T([64, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 16, 112, 112], f16), T([64, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 16, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 3, 224, 224], f16), T([16, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 1344, 7, 7], f16, stride=(1344, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 1104, 7, 7], f16, stride=(1104, 1, 0, 0)), 49), {})
+cnt: 5, ((T([128, 736, 7, 7], f16, stride=(736, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 720, 7, 7], f16, stride=(720, 1, 0, 0)), 49), {})
+cnt: 6, ((T([128, 360, 14, 14], f16, stride=(360, 1, 0, 0)), 196), {})
+cnt: 5, ((T([128, 120, 28, 28], f16, stride=(120, 1, 0, 0)), 784), {})
+Operator: aten.hardsigmoid.default
+cnt: 5, ((T([128, 120, 1, 1], f16),), {})
+cnt: 6, ((T([128, 360, 1, 1], f16),), {})
+cnt: 1, ((T([128, 720, 1, 1], f16),), {})
+cnt: 5, ((T([128, 736, 1, 1], f16),), {})
+cnt: 1, ((T([128, 1104, 1, 1], f16),), {})
+Operator: aten.hardsigmoid_backward.default
+cnt: 1, ((T([128, 1104, 1, 1], f16), T([128, 1104, 1, 1], f16)), {})
+cnt: 5, ((T([128, 736, 1, 1], f16), T([128, 736, 1, 1], f16)), {})
+cnt: 1, ((T([128, 720, 1, 1], f16), T([128, 720, 1, 1], f16)), {})
+cnt: 6, ((T([128, 360, 1, 1], f16), T([128, 360, 1, 1], f16)), {})
+cnt: 5, ((T([128, 120, 1, 1], f16), T([128, 120, 1, 1], f16)), {})
+Operator: aten.hardswish_.default
+cnt: 3, ((T([128, 16, 112, 112], f16),), {})
+cnt: 1, ((T([128, 64, 112, 112], f16),), {})
+cnt: 1, ((T([128, 64, 56, 56], f16),), {})
+cnt: 6, ((T([128, 48, 56, 56], f16),), {})
+cnt: 1, ((T([128, 120, 56, 56], f16),), {})
+cnt: 9, ((T([128, 120, 28, 28], f16),), {})
+cnt: 1, ((T([128, 8, 1, 1], f16),), {})
+cnt: 4, ((T([128, 16, 1, 1], f16),), {})
+cnt: 1, ((T([128, 200, 28, 28], f16),), {})
+cnt: 1, ((T([128, 200, 14, 14], f16),), {})
+cnt: 8, ((T([128, 216, 14, 14], f16),), {})
+cnt: 12, ((T([128, 360, 14, 14], f16),), {})
+cnt: 1, ((T([128, 24, 1, 1], f16),), {})
+cnt: 6, ((T([128, 32, 1, 1], f16),), {})
+cnt: 1, ((T([128, 720, 14, 14], f16),), {})
+cnt: 1, ((T([128, 720, 7, 7], f16),), {})
+cnt: 10, ((T([128, 736, 7, 7], f16),), {})
+cnt: 6, ((T([128, 48, 1, 1], f16),), {})
+cnt: 2, ((T([128, 1104, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1344, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1984, 1, 1], f16),), {})
+Operator: aten.hardswish_backward.default
+cnt: 1, ((T([128, 1984, 1, 1], f16), T([128, 1984, 1, 1], f16)), {})
+cnt: 1, ((T([128, 1344, 7, 7], f16), T([128, 1344, 7, 7], f16)), {})
+cnt: 6, ((T([128, 48, 1, 1], f16), T([128, 48, 1, 1], f16)), {})
+cnt: 2, ((T([128, 1104, 7, 7], f16), T([128, 1104, 7, 7], f16)), {})
+cnt: 10, ((T([128, 736, 7, 7], f16), T([128, 736, 7, 7], f16)), {})
+cnt: 6, ((T([128, 32, 1, 1], f16), T([128, 32, 1, 1], f16)), {})
+cnt: 1, ((T([128, 720, 7, 7], f16), T([128, 720, 7, 7], f16)), {})
+cnt: 1, ((T([128, 720, 14, 14], f16), T([128, 720, 14, 14], f16)), {})
+cnt: 12, ((T([128, 360, 14, 14], f16), T([128, 360, 14, 14], f16)), {})
+cnt: 1, ((T([128, 24, 1, 1], f16), T([128, 24, 1, 1], f16)), {})
+cnt: 8, ((T([128, 216, 14, 14], f16), T([128, 216, 14, 14], f16)), {})
+cnt: 1, ((T([128, 200, 14, 14], f16), T([128, 200, 14, 14], f16)), {})
+cnt: 1, ((T([128, 200, 28, 28], f16), T([128, 200, 28, 28], f16)), {})
+cnt: 4, ((T([128, 16, 1, 1], f16), T([128, 16, 1, 1], f16)), {})
+cnt: 9, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16)), {})
+cnt: 1, ((T([128, 8, 1, 1], f16), T([128, 8, 1, 1], f16)), {})
+cnt: 1, ((T([128, 120, 56, 56], f16), T([128, 120, 56, 56], f16)), {})
+cnt: 6, ((T([128, 48, 56, 56], f16), T([128, 48, 56, 56], f16)), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16)), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16)), {})
+cnt: 3, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 5, ((T([128, 120, 28, 28], f16), [2, 3], True), {})
+cnt: 6, ((T([128, 360, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 720, 7, 7], f16), [2, 3], True), {})
+cnt: 5, ((T([128, 736, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 1104, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 1344, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 1984], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 1984], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 10, ((T([128, 120, 28, 28], f16), T([128, 120, 1, 1], f16)), {})
+cnt: 12, ((T([128, 360, 14, 14], f16), T([128, 360, 1, 1], f16)), {})
+cnt: 2, ((T([128, 720, 7, 7], f16), T([128, 720, 1, 1], f16)), {})
+cnt: 10, ((T([128, 736, 7, 7], f16), T([128, 736, 1, 1], f16)), {})
+cnt: 2, ((T([128, 1104, 7, 7], f16), T([128, 1104, 1, 1], f16)), {})
+cnt: 1, ((T([128, 1104, 7, 7], f16), T([128, 1104, 7, 7], f16)), {})
+cnt: 5, ((T([128, 736, 7, 7], f16), T([128, 736, 7, 7], f16)), {})
+cnt: 1, ((T([128, 720, 7, 7], f16), T([128, 720, 7, 7], f16)), {})
+cnt: 6, ((T([128, 360, 14, 14], f16), T([128, 360, 14, 14], f16)), {})
+cnt: 5, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 5, ((T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 6, ((T([128, 48, 56, 56], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 120, 56, 56], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f16), True, 0.1, 1e-05), {})
+cnt: 9, ((T([128, 120, 28, 28], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 200, 28, 28], f16), T([200], f16), T([200], f16), T([200], f16), T([200], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 200, 14, 14], f16), T([200], f16), T([200], f16), T([200], f16), T([200], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 72, 14, 14], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), True, 0.1, 1e-05), {})
+cnt: 8, ((T([128, 216, 14, 14], f16), T([216], f16), T([216], f16), T([216], f16), T([216], f16), True, 0.1, 1e-05), {})
+cnt: 12, ((T([128, 360, 14, 14], f16), T([360], f16), T([360], f16), T([360], f16), T([360], f16), True, 0.1, 1e-05), {})
+cnt: 6, ((T([128, 120, 14, 14], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 720, 14, 14], f16), T([720], f16), T([720], f16), T([720], f16), T([720], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 720, 7, 7], f16), T([720], f16), T([720], f16), T([720], f16), T([720], f16), True, 0.1, 1e-05), {})
+cnt: 6, ((T([128, 184, 7, 7], f16), T([184], f16), T([184], f16), T([184], f16), T([184], f16), True, 0.1, 1e-05), {})
+cnt: 10, ((T([128, 736, 7, 7], f16), T([736], f16), T([736], f16), T([736], f16), T([736], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 1104, 7, 7], f16), T([1104], f16), T([1104], f16), T([1104], f16), T([1104], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 224, 7, 7], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 1344, 7, 7], f16), T([1344], f16), T([1344], f16), T([1344], f16), T([1344], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([128, 1344, 7, 7], f16), T([128, 1344, 7, 7], f16), T([1344], f16), T([1344], f16), T([1344], f16), T([1344], f32), T([1344], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 224, 7, 7], f16), T([128, 224, 7, 7], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f32), T([224], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 1104, 7, 7], f16), T([128, 1104, 7, 7], f16), T([1104], f16), T([1104], f16), T([1104], f16), T([1104], f32), T([1104], f32), True, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([128, 184, 7, 7], f16), T([128, 184, 7, 7], f16), T([184], f16), T([184], f16), T([184], f16), T([184], f32), T([184], f32), True, 1e-05, [True, True, True]), {})
+cnt: 10, ((T([128, 736, 7, 7], f16), T([128, 736, 7, 7], f16), T([736], f16), T([736], f16), T([736], f16), T([736], f32), T([736], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 720, 7, 7], f16), T([128, 720, 7, 7], f16), T([720], f16), T([720], f16), T([720], f16), T([720], f32), T([720], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 720, 14, 14], f16), T([128, 720, 14, 14], f16), T([720], f16), T([720], f16), T([720], f16), T([720], f32), T([720], f32), True, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([128, 120, 14, 14], f16), T([128, 120, 14, 14], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f32), T([120], f32), True, 1e-05, [True, True, True]), {})
+cnt: 12, ((T([128, 360, 14, 14], f16), T([128, 360, 14, 14], f16), T([360], f16), T([360], f16), T([360], f16), T([360], f32), T([360], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 72, 14, 14], f16), T([128, 72, 14, 14], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), True, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([128, 216, 14, 14], f16), T([128, 216, 14, 14], f16), T([216], f16), T([216], f16), T([216], f16), T([216], f32), T([216], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 200, 14, 14], f16), T([128, 200, 14, 14], f16), T([200], f16), T([200], f16), T([200], f16), T([200], f32), T([200], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 200, 28, 28], f16), T([128, 200, 28, 28], f16), T([200], f16), T([200], f16), T([200], f16), T([200], f32), T([200], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), True, 1e-05, [True, True, True]), {})
+cnt: 9, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f32), T([120], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 120, 56, 56], f16), T([128, 120, 56, 56], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f32), T([120], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([128, 48, 56, 56], f16), T([128, 48, 56, 56], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f32), T([48], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 1, ((T([128, 1104, 7, 7], f16), [2, 3], True), {})
+cnt: 5, ((T([128, 736, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 720, 7, 7], f16), [2, 3], True), {})
+cnt: 6, ((T([128, 360, 14, 14], f16), [2, 3], True), {})
+cnt: 5, ((T([128, 120, 28, 28], f16), [2, 3], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gernet_l_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gernet_l_training.txt
new file mode 100644
index 0000000000000..1efcbbfec35ee
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gernet_l_training.txt
@@ -0,0 +1,118 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 57, ((T([], i64), 1), {})
+cnt: 2, ((T([128, 128, 64, 64], f16), T([128, 128, 64, 64], f16)), {})
+cnt: 4, ((T([128, 192, 32, 32], f16), T([128, 192, 32, 32], f16)), {})
+cnt: 12, ((T([128, 640, 16, 16], f16), T([128, 640, 16, 16], f16)), {})
+cnt: 17, ((T([128, 640, 8, 8], f16), T([128, 640, 8, 8], f16)), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([128, 32, 128, 128], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 2560], f16), T([2560, 1000], f16, stride=(1, 2560))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 256, 256], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 256, 256], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([128, 32, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([128, 32, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([192, 128, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 192, 32, 32], f16), T([192, 192, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([192, 128, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 32, 32], f16), T([160, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 160, 32, 32], f16), T([160, 160, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([128, 160, 16, 16], f16), T([640, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 32, 32], f16), T([640, 192, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 640, 16, 16], f16), T([160, 640, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 160, 16, 16], f16), T([160, 160, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 640, 16, 16], f16), T([1920, 640, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1920, 16, 16], f16), T([1920, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1920), {})
+cnt: 9, ((T([128, 1920, 8, 8], f16), T([640, 1920, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 640, 16, 16], f16), T([640, 640, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([128, 640, 8, 8], f16), T([1920, 640, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([128, 1920, 8, 8], f16), T([1920, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1920), {})
+cnt: 1, ((T([128, 640, 8, 8], f16), T([2560, 640, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 2560, 8, 8], f16), T([128, 640, 8, 8], f16), T([2560, 640, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 9, ((T([128, 640, 8, 8], f16), T([128, 1920, 8, 8], f16), T([640, 1920, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 8, ((T([128, 1920, 8, 8], f16), T([128, 1920, 8, 8], f16), T([1920, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1920, [True, True, False]), {})
+cnt: 8, ((T([128, 1920, 8, 8], f16), T([128, 640, 8, 8], f16), T([1920, 640, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 640, 8, 8], f16), T([128, 640, 16, 16], f16), T([640, 640, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1920, 8, 8], f16), T([128, 1920, 16, 16], f16), T([1920, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1920, [True, True, False]), {})
+cnt: 1, ((T([128, 1920, 16, 16], f16), T([128, 640, 16, 16], f16), T([1920, 640, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([128, 640, 16, 16], f16), T([128, 160, 16, 16], f16), T([640, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([128, 160, 16, 16], f16), T([128, 160, 16, 16], f16), T([160, 160, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([128, 160, 16, 16], f16), T([128, 640, 16, 16], f16), T([160, 640, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 640, 16, 16], f16), T([128, 192, 32, 32], f16), T([640, 192, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 160, 16, 16], f16), T([128, 160, 32, 32], f16), T([160, 160, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 160, 32, 32], f16), T([128, 192, 32, 32], f16), T([160, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 192, 32, 32], f16), T([128, 192, 32, 32], f16), T([192, 192, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 32, 32], f16), T([128, 128, 64, 64], f16), T([192, 128, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 32, 32], f16), T([128, 128, 64, 64], f16), T([192, 128, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 32, 128, 128], f16), T([128, 32, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 128, 64, 64], f16), T([128, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 64, 64], f16), T([128, 32, 128, 128], f16), T([128, 32, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([128, 3, 256, 256], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 256, 256], f16), T([128, 3, 256, 256], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 2560, 8, 8], f16, stride=(2560, 1, 0, 0)), 64), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 2560, 8, 8], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 2560], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 2560], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 32, 128, 128], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 192, 32, 32], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 160, 32, 32], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), True, 0.1, 1e-05), {})
+cnt: 11, ((T([128, 160, 16, 16], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([128, 640, 16, 16], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 1920, 16, 16], f16), T([1920], f16), T([1920], f16), T([1920], f16), T([1920], f16), True, 0.1, 1e-05), {})
+cnt: 17, ((T([128, 1920, 8, 8], f16), T([1920], f16), T([1920], f16), T([1920], f16), T([1920], f16), True, 0.1, 1e-05), {})
+cnt: 10, ((T([128, 640, 8, 8], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 2560, 8, 8], f16), T([2560], f16), T([2560], f16), T([2560], f16), T([2560], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([128, 2560, 8, 8], f16), T([128, 2560, 8, 8], f16), T([2560], f16), T([2560], f16), T([2560], f16), T([2560], f32), T([2560], f32), True, 1e-05, [True, True, True]), {})
+cnt: 10, ((T([128, 640, 8, 8], f16), T([128, 640, 8, 8], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f32), T([640], f32), True, 1e-05, [True, True, True]), {})
+cnt: 17, ((T([128, 1920, 8, 8], f16), T([128, 1920, 8, 8], f16), T([1920], f16), T([1920], f16), T([1920], f16), T([1920], f32), T([1920], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 1920, 16, 16], f16), T([128, 1920, 16, 16], f16), T([1920], f16), T([1920], f16), T([1920], f16), T([1920], f32), T([1920], f32), True, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([128, 640, 16, 16], f16), T([128, 640, 16, 16], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f32), T([640], f32), True, 1e-05, [True, True, True]), {})
+cnt: 11, ((T([128, 160, 16, 16], f16), T([128, 160, 16, 16], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 160, 32, 32], f16), T([128, 160, 32, 32], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 192, 32, 32], f16), T([128, 192, 32, 32], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 128, 64, 64], f16), T([128, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([128, 32, 128, 128], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 32, 128, 128], f16),), {})
+cnt: 2, ((T([128, 128, 64, 64], f16),), {})
+cnt: 4, ((T([128, 192, 32, 32], f16),), {})
+cnt: 1, ((T([128, 160, 32, 32], f16),), {})
+cnt: 11, ((T([128, 160, 16, 16], f16),), {})
+cnt: 6, ((T([128, 640, 16, 16], f16),), {})
+cnt: 1, ((T([128, 1920, 16, 16], f16),), {})
+cnt: 17, ((T([128, 1920, 8, 8], f16),), {})
+cnt: 9, ((T([128, 640, 8, 8], f16),), {})
+cnt: 1, ((T([128, 2560, 8, 8], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([128, 2560, 8, 8], f16), T([128, 2560, 8, 8], f16), 0), {})
+cnt: 9, ((T([128, 640, 8, 8], f16), T([128, 640, 8, 8], f16), 0), {})
+cnt: 17, ((T([128, 1920, 8, 8], f16), T([128, 1920, 8, 8], f16), 0), {})
+cnt: 1, ((T([128, 1920, 16, 16], f16), T([128, 1920, 16, 16], f16), 0), {})
+cnt: 6, ((T([128, 640, 16, 16], f16), T([128, 640, 16, 16], f16), 0), {})
+cnt: 11, ((T([128, 160, 16, 16], f16), T([128, 160, 16, 16], f16), 0), {})
+cnt: 1, ((T([128, 160, 32, 32], f16), T([128, 160, 32, 32], f16), 0), {})
+cnt: 4, ((T([128, 192, 32, 32], f16), T([128, 192, 32, 32], f16), 0), {})
+cnt: 2, ((T([128, 128, 64, 64], f16), T([128, 128, 64, 64], f16), 0), {})
+cnt: 1, ((T([128, 32, 128, 128], f16), T([128, 32, 128, 128], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/ghostnet_100_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/ghostnet_100_training.txt
new file mode 100644
index 0000000000000..15066dcc1a0c3
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/ghostnet_100_training.txt
@@ -0,0 +1,411 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([], i64), 1), {})
+cnt: 5, ((T([128, 80, 7, 7], f16, stride=(7840, 49, 7, 1)), T([128, 80, 7, 7], f16)), {})
+cnt: 2, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16)), {})
+cnt: 4, ((T([128, 480, 7, 7], f16, stride=(47040, 49, 7, 1)), T([128, 480, 7, 7], f16)), {})
+cnt: 4, ((T([128, 160, 7, 7], f16), T([128, 160, 7, 7], f16)), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16)), {})
+cnt: 2, ((T([128, 336, 14, 14], f16, stride=(131712, 196, 14, 1)), T([128, 336, 14, 14], f16)), {})
+cnt: 2, ((T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16)), {})
+cnt: 2, ((T([128, 56, 14, 14], f16, stride=(21952, 196, 14, 1)), T([128, 56, 14, 14], f16)), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16)), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16)), {})
+cnt: 1, ((T([128, 240, 14, 14], f16, stride=(94080, 196, 14, 1)), T([128, 240, 14, 14], f16)), {})
+cnt: 4, ((T([128, 80, 14, 14], f16), T([128, 80, 14, 14], f16)), {})
+cnt: 4, ((T([128, 40, 14, 14], f16, stride=(15680, 196, 14, 1)), T([128, 40, 14, 14], f16)), {})
+cnt: 2, ((T([128, 92, 14, 14], f16, stride=(36064, 196, 14, 1)), T([128, 92, 14, 14], f16)), {})
+cnt: 1, ((T([128, 100, 14, 14], f16, stride=(39200, 196, 14, 1)), T([128, 100, 14, 14], f16)), {})
+cnt: 1, ((T([128, 120, 28, 28], f16, stride=(188160, 784, 28, 1)), T([128, 120, 28, 28], f16)), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16)), {})
+cnt: 2, ((T([128, 20, 28, 28], f16, stride=(31360, 784, 28, 1)), T([128, 20, 28, 28], f16)), {})
+cnt: 1, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16)), {})
+cnt: 1, ((T([128, 60, 28, 28], f16, stride=(94080, 784, 28, 1)), T([128, 60, 28, 28], f16)), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([128, 72, 28, 28], f16)), {})
+cnt: 2, ((T([128, 36, 56, 56], f16, stride=(225792, 3136, 56, 1)), T([128, 36, 56, 56], f16)), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16)), {})
+cnt: 2, ((T([128, 12, 56, 56], f16, stride=(75264, 3136, 56, 1)), T([128, 12, 56, 56], f16)), {})
+cnt: 1, ((T([128, 24, 112, 112], f16, stride=(602112, 12544, 112, 1)), T([128, 24, 112, 112], f16)), {})
+cnt: 2, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16)), {})
+cnt: 2, ((T([128, 8, 112, 112], f16, stride=(200704, 12544, 112, 1)), T([128, 8, 112, 112], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 79, ((T([], i64), 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16)), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16)), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16)), {})
+cnt: 4, ((T([128, 80, 14, 14], f16), T([128, 80, 14, 14], f16)), {})
+cnt: 2, ((T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16)), {})
+cnt: 5, ((T([128, 160, 7, 7], f16), T([128, 160, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1280], f16), T([1280, 1000], f16, stride=(1, 1280))), {})
+Operator: aten.cat.default
+cnt: 2, (([T([128, 8, 112, 112], f16), T([128, 8, 112, 112], f16)], 1), {})
+cnt: 1, (([T([128, 24, 112, 112], f16), T([128, 24, 112, 112], f16)], 1), {})
+cnt: 2, (([T([128, 12, 56, 56], f16), T([128, 12, 56, 56], f16)], 1), {})
+cnt: 2, (([T([128, 36, 56, 56], f16), T([128, 36, 56, 56], f16)], 1), {})
+cnt: 2, (([T([128, 20, 28, 28], f16), T([128, 20, 28, 28], f16)], 1), {})
+cnt: 1, (([T([128, 60, 28, 28], f16), T([128, 60, 28, 28], f16)], 1), {})
+cnt: 1, (([T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16)], 1), {})
+cnt: 4, (([T([128, 40, 14, 14], f16), T([128, 40, 14, 14], f16)], 1), {})
+cnt: 1, (([T([128, 100, 14, 14], f16), T([128, 100, 14, 14], f16)], 1), {})
+cnt: 2, (([T([128, 92, 14, 14], f16), T([128, 92, 14, 14], f16)], 1), {})
+cnt: 1, (([T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16)], 1), {})
+cnt: 2, (([T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16)], 1), {})
+cnt: 2, (([T([128, 336, 14, 14], f16), T([128, 336, 14, 14], f16)], 1), {})
+cnt: 5, (([T([128, 80, 7, 7], f16), T([128, 80, 7, 7], f16)], 1), {})
+cnt: 4, (([T([128, 480, 7, 7], f16), T([128, 480, 7, 7], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([16, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 16, 112, 112], f16), T([8, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 8, 112, 112], f16), T([8, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([24, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 24, 112, 112], f16), T([24, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 24), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([48, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 48), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([12, 48, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 12, 56, 56], f16), T([12, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 12), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([16, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 16), {})
+cnt: 1, ((T([128, 16, 56, 56], f16), T([24, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([36, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 36, 56, 56], f16), T([36, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 36), {})
+cnt: 1, ((T([128, 72, 56, 56], f16), T([12, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 72, 56, 56], f16), T([72, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 72), {})
+cnt: 1, ((T([128, 72, 1, 1], f16), T([20, 72, 1, 1], f16), T([20], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 20, 1, 1], f16), T([72, 20, 1, 1], f16), T([72], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([20, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 20, 28, 28], f16), T([20, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 20), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([24, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 24), {})
+cnt: 1, ((T([128, 24, 28, 28], f16), T([40, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([60, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 60, 28, 28], f16), T([60, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 60), {})
+cnt: 1, ((T([128, 120, 1, 1], f16), T([32, 120, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 1, 1], f16), T([120, 32, 1, 1], f16), T([120], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 120, 28, 28], f16), T([20, 120, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([120, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 120, 28, 28], f16), T([120, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 120), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([240, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([40, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 40, 14, 14], f16), T([40, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 40), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([40, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 40), {})
+cnt: 1, ((T([128, 40, 14, 14], f16), T([80, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 80, 14, 14], f16), T([100, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 100, 14, 14], f16), T([100, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 100), {})
+cnt: 1, ((T([128, 200, 14, 14], f16), T([40, 200, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 80, 14, 14], f16), T([92, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 92, 14, 14], f16), T([92, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 92), {})
+cnt: 2, ((T([128, 184, 14, 14], f16), T([40, 184, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 80, 14, 14], f16), T([240, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([240, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([128, 480, 1, 1], f16), T([120, 480, 1, 1], f16), T([120], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 120, 1, 1], f16), T([480, 120, 1, 1], f16), T([480], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([56, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 56, 14, 14], f16), T([56, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 56), {})
+cnt: 1, ((T([128, 80, 14, 14], f16), T([80, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 80), {})
+cnt: 1, ((T([128, 80, 14, 14], f16), T([112, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 112, 14, 14], f16), T([336, 112, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 336, 14, 14], f16), T([336, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 336), {})
+cnt: 2, ((T([128, 672, 1, 1], f16), T([168, 672, 1, 1], f16), T([168], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 168, 1, 1], f16), T([672, 168, 1, 1], f16), T([672], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([56, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([672, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 672), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([80, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 80, 7, 7], f16), T([80, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 80), {})
+cnt: 1, ((T([128, 112, 14, 14], f16), T([112, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 112), {})
+cnt: 1, ((T([128, 112, 7, 7], f16), T([160, 112, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 160, 7, 7], f16), T([480, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 480, 7, 7], f16), T([480, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 480), {})
+cnt: 4, ((T([128, 960, 7, 7], f16), T([80, 960, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 960, 1, 1], f16), T([240, 960, 1, 1], f16), T([240], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 240, 1, 1], f16), T([960, 240, 1, 1], f16), T([960], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 160, 7, 7], f16), T([960, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 960, 1, 1], f16), T([1280, 960, 1, 1], f16), T([1280], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1280, 1, 1], f16), T([128, 960, 1, 1], f16), T([1280, 960, 1, 1], f16), [1280], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 960, 7, 7], f16), T([128, 160, 7, 7], f16), T([960, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([128, 80, 7, 7], f16), T([128, 80, 7, 7], f16), T([80, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 80, [True, True, False]), {})
+cnt: 4, ((T([128, 80, 7, 7], f16), T([128, 960, 7, 7], f16), T([80, 960, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 960, 1, 1], f16), T([128, 240, 1, 1], f16), T([960, 240, 1, 1], f16), [960], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 240, 1, 1], f16), T([128, 960, 1, 1], f16), T([240, 960, 1, 1], f16), [240], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([128, 480, 7, 7], f16), T([128, 480, 7, 7], f16), T([480, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 4, ((T([128, 480, 7, 7], f16), T([128, 160, 7, 7], f16), T([480, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 160, 7, 7], f16), T([128, 112, 7, 7], f16), T([160, 112, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 112, 7, 7], f16), T([128, 112, 14, 14], f16), T([112, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 112, [True, True, False]), {})
+cnt: 1, ((T([128, 80, 7, 7], f16), T([128, 672, 7, 7], f16), T([80, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 672, 1, 1], f16), T([128, 168, 1, 1], f16), T([672, 168, 1, 1], f16), [672], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 168, 1, 1], f16), T([128, 672, 1, 1], f16), T([168, 672, 1, 1], f16), [168], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 14, 14], f16), T([672, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 2, ((T([128, 336, 14, 14], f16), T([128, 336, 14, 14], f16), T([336, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 336, [True, True, False]), {})
+cnt: 2, ((T([128, 336, 14, 14], f16), T([128, 112, 14, 14], f16), T([336, 112, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([56, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 56, [True, True, False]), {})
+cnt: 1, ((T([128, 56, 14, 14], f16), T([128, 672, 14, 14], f16), T([56, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 112, 14, 14], f16), T([128, 80, 14, 14], f16), T([112, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 80, 14, 14], f16), T([128, 80, 14, 14], f16), T([80, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 80, [True, True, False]), {})
+cnt: 1, ((T([128, 56, 14, 14], f16), T([128, 480, 14, 14], f16), T([56, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 480, 1, 1], f16), T([128, 120, 1, 1], f16), T([480, 120, 1, 1], f16), [480], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 120, 1, 1], f16), T([128, 480, 1, 1], f16), T([120, 480, 1, 1], f16), [120], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16), T([240, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 80, 14, 14], f16), T([240, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 40, 14, 14], f16), T([128, 40, 14, 14], f16), T([40, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 40, [True, True, False]), {})
+cnt: 2, ((T([128, 40, 14, 14], f16), T([128, 184, 14, 14], f16), T([40, 184, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 92, 14, 14], f16), T([128, 92, 14, 14], f16), T([92, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 92, [True, True, False]), {})
+cnt: 2, ((T([128, 92, 14, 14], f16), T([128, 80, 14, 14], f16), T([92, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 40, 14, 14], f16), T([128, 200, 14, 14], f16), T([40, 200, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 100, 14, 14], f16), T([128, 100, 14, 14], f16), T([100, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 100, [True, True, False]), {})
+cnt: 1, ((T([128, 100, 14, 14], f16), T([128, 80, 14, 14], f16), T([100, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 80, 14, 14], f16), T([128, 40, 14, 14], f16), T([80, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 40, 14, 14], f16), T([128, 40, 28, 28], f16), T([40, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 40, [True, True, False]), {})
+cnt: 1, ((T([128, 40, 14, 14], f16), T([128, 240, 14, 14], f16), T([40, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 28, 28], f16), T([240, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16), T([120, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 1, ((T([128, 120, 28, 28], f16), T([128, 40, 28, 28], f16), T([120, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 20, 28, 28], f16), T([128, 20, 28, 28], f16), T([20, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 20, [True, True, False]), {})
+cnt: 1, ((T([128, 20, 28, 28], f16), T([128, 120, 28, 28], f16), T([20, 120, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 120, 1, 1], f16), T([128, 32, 1, 1], f16), T([120, 32, 1, 1], f16), [120], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 1, 1], f16), T([128, 120, 1, 1], f16), T([32, 120, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 60, 28, 28], f16), T([128, 60, 28, 28], f16), T([60, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 60, [True, True, False]), {})
+cnt: 1, ((T([128, 60, 28, 28], f16), T([128, 40, 28, 28], f16), T([60, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([128, 24, 28, 28], f16), T([40, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 28, 28], f16), T([128, 24, 56, 56], f16), T([24, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 24, [True, True, False]), {})
+cnt: 1, ((T([128, 20, 28, 28], f16), T([128, 72, 28, 28], f16), T([20, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 72, 1, 1], f16), T([128, 20, 1, 1], f16), T([72, 20, 1, 1], f16), [72], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 20, 1, 1], f16), T([128, 72, 1, 1], f16), T([20, 72, 1, 1], f16), [20], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([128, 72, 56, 56], f16), T([72, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 72, [True, True, False]), {})
+cnt: 2, ((T([128, 36, 56, 56], f16), T([128, 36, 56, 56], f16), T([36, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 36, [True, True, False]), {})
+cnt: 2, ((T([128, 36, 56, 56], f16), T([128, 24, 56, 56], f16), T([36, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 12, 56, 56], f16), T([128, 12, 56, 56], f16), T([12, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 12, [True, True, False]), {})
+cnt: 1, ((T([128, 12, 56, 56], f16), T([128, 72, 56, 56], f16), T([12, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 16, 56, 56], f16), T([24, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 56, 56], f16), T([128, 16, 112, 112], f16), T([16, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 16, [True, True, False]), {})
+cnt: 1, ((T([128, 12, 56, 56], f16), T([128, 48, 56, 56], f16), T([12, 48, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([128, 48, 112, 112], f16), T([48, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 48, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 112, 112], f16), T([128, 24, 112, 112], f16), T([24, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 24, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 112, 112], f16), T([128, 16, 112, 112], f16), T([24, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 8, 112, 112], f16), T([128, 8, 112, 112], f16), T([8, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 2, ((T([128, 8, 112, 112], f16), T([128, 16, 112, 112], f16), T([8, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 3, 224, 224], f16), T([16, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+cnt: 15, ((T([128, 160, 7, 7], f16), T([128, 160, 7, 7], f16)), {})
+cnt: 6, ((T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16)), {})
+cnt: 12, ((T([128, 80, 14, 14], f16), T([128, 80, 14, 14], f16)), {})
+cnt: 6, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16)), {})
+cnt: 6, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16)), {})
+cnt: 3, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16)), {})
+Operator: aten.div.Scalar
+cnt: 3, ((T([128, 960, 7, 7], f16, stride=(960, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 672, 7, 7], f16, stride=(672, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 672, 14, 14], f16, stride=(672, 1, 0, 0)), 196), {})
+cnt: 1, ((T([128, 480, 14, 14], f16, stride=(480, 1, 0, 0)), 196), {})
+cnt: 1, ((T([128, 120, 28, 28], f16, stride=(120, 1, 0, 0)), 784), {})
+cnt: 1, ((T([128, 72, 28, 28], f16, stride=(72, 1, 0, 0)), 784), {})
+Operator: aten.hardsigmoid.default
+cnt: 1, ((T([128, 72, 1, 1], f16),), {})
+cnt: 1, ((T([128, 120, 1, 1], f16),), {})
+cnt: 1, ((T([128, 480, 1, 1], f16),), {})
+cnt: 2, ((T([128, 672, 1, 1], f16),), {})
+cnt: 2, ((T([128, 960, 1, 1], f16),), {})
+Operator: aten.hardsigmoid_backward.default
+cnt: 2, ((T([128, 960, 1, 1], f16), T([128, 960, 1, 1], f16)), {})
+cnt: 2, ((T([128, 672, 1, 1], f16), T([128, 672, 1, 1], f16)), {})
+cnt: 1, ((T([128, 480, 1, 1], f16), T([128, 480, 1, 1], f16)), {})
+cnt: 1, ((T([128, 120, 1, 1], f16), T([128, 120, 1, 1], f16)), {})
+cnt: 1, ((T([128, 72, 1, 1], f16), T([128, 72, 1, 1], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 72, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 120, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 960, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 960, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 1280], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 1280], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([128, 72, 28, 28], f16), T([128, 72, 1, 1], f16)), {})
+cnt: 2, ((T([128, 120, 28, 28], f16), T([128, 120, 1, 1], f16)), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([128, 480, 1, 1], f16)), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), T([128, 672, 1, 1], f16)), {})
+cnt: 2, ((T([128, 672, 7, 7], f16), T([128, 672, 1, 1], f16)), {})
+cnt: 4, ((T([128, 960, 7, 7], f16), T([128, 960, 1, 1], f16)), {})
+cnt: 2, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16)), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16)), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16)), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16)), {})
+cnt: 1, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16)), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([128, 72, 28, 28], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 8, 112, 112], f16), T([8], f16), T([8], f16), T([8], f16), T([8], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 24, 112, 112], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 12, 56, 56], f16), T([12], f16), T([12], f16), T([12], f16), T([12], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 16, 56, 56], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 36, 56, 56], f16), T([36], f16), T([36], f16), T([36], f16), T([36], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 20, 28, 28], f16), T([20], f16), T([20], f16), T([20], f16), T([20], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 24, 28, 28], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 60, 28, 28], f16), T([60], f16), T([60], f16), T([60], f16), T([60], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 120, 28, 28], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 1e-05), {})
+cnt: 9, ((T([128, 40, 14, 14], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 100, 14, 14], f16), T([100], f16), T([100], f16), T([100], f16), T([100], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 92, 14, 14], f16), T([92], f16), T([92], f16), T([92], f16), T([92], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 56, 14, 14], f16), T([56], f16), T([56], f16), T([56], f16), T([56], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 112, 14, 14], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 336, 14, 14], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), True, 0.1, 1e-05), {})
+cnt: 10, ((T([128, 80, 7, 7], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 112, 7, 7], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 160, 7, 7], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), True, 0.1, 1e-05), {})
+cnt: 8, ((T([128, 480, 7, 7], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 80, 7, 7], f16, stride=(7840, 49, 7, 1)), T([128, 80, 7, 7], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 80, 7, 7], f16), T([128, 80, 7, 7], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), True, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([128, 480, 7, 7], f16), T([128, 480, 7, 7], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 160, 7, 7], f16), T([128, 160, 7, 7], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f32), T([112], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 336, 14, 14], f16), T([128, 336, 14, 14], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f32), T([336], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 56, 14, 14], f16, stride=(21952, 196, 14, 1)), T([128, 56, 14, 14], f16), T([56], f16), T([56], f16), T([56], f16), T([56], f32), T([56], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([56], f16), T([56], f16), T([56], f16), T([56], f32), T([56], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f32), T([112], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 80, 14, 14], f16), T([128, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 40, 14, 14], f16, stride=(15680, 196, 14, 1)), T([128, 40, 14, 14], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 40, 14, 14], f16), T([128, 40, 14, 14], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 92, 14, 14], f16), T([128, 92, 14, 14], f16), T([92], f16), T([92], f16), T([92], f16), T([92], f32), T([92], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 100, 14, 14], f16), T([128, 100, 14, 14], f16), T([100], f16), T([100], f16), T([100], f16), T([100], f32), T([100], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f32), T([120], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 20, 28, 28], f16, stride=(31360, 784, 28, 1)), T([128, 20, 28, 28], f16), T([20], f16), T([20], f16), T([20], f16), T([20], f32), T([20], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 20, 28, 28], f16), T([128, 20, 28, 28], f16), T([20], f16), T([20], f16), T([20], f16), T([20], f32), T([20], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 60, 28, 28], f16), T([128, 60, 28, 28], f16), T([60], f16), T([60], f16), T([60], f16), T([60], f32), T([60], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 24, 28, 28], f16), T([128, 24, 28, 28], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([128, 72, 28, 28], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 36, 56, 56], f16), T([128, 36, 56, 56], f16), T([36], f16), T([36], f16), T([36], f16), T([36], f32), T([36], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 12, 56, 56], f16, stride=(75264, 3136, 56, 1)), T([128, 12, 56, 56], f16), T([12], f16), T([12], f16), T([12], f16), T([12], f32), T([12], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 12, 56, 56], f16), T([128, 12, 56, 56], f16), T([12], f16), T([12], f16), T([12], f16), T([12], f32), T([12], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 16, 56, 56], f16), T([128, 16, 56, 56], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([128, 48, 56, 56], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f32), T([48], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 24, 112, 112], f16), T([128, 24, 112, 112], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 8, 112, 112], f16, stride=(200704, 12544, 112, 1)), T([128, 8, 112, 112], f16), T([8], f16), T([8], f16), T([8], f16), T([8], f32), T([8], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 8, 112, 112], f16), T([128, 8, 112, 112], f16), T([8], f16), T([8], f16), T([8], f16), T([8], f32), T([8], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.new_empty_strided.default
+cnt: 5, ((T([128, 160, 7, 7], f16), [128, 160, 7, 7], [7840, 49, 7, 1]), {})
+cnt: 2, ((T([128, 112, 14, 14], f16), [128, 112, 14, 14], [21952, 196, 14, 1]), {})
+cnt: 4, ((T([128, 80, 14, 14], f16), [128, 80, 14, 14], [15680, 196, 14, 1]), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), [128, 40, 28, 28], [31360, 784, 28, 1]), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), [128, 24, 56, 56], [75264, 3136, 56, 1]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), [128, 16, 112, 112], [200704, 12544, 112, 1]), {})
+Operator: aten.new_zeros.default
+cnt: 5, ((T([128, 160, 7, 7], f16), [1003520]), {})
+cnt: 2, ((T([128, 112, 14, 14], f16), [2809856]), {})
+cnt: 4, ((T([128, 80, 14, 14], f16), [2007040]), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), [4014080]), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), [9633792]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), [25690112]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 16, 112, 112], f16),), {})
+cnt: 2, ((T([128, 8, 112, 112], f16),), {})
+cnt: 2, ((T([128, 24, 112, 112], f16),), {})
+cnt: 4, ((T([128, 36, 56, 56], f16),), {})
+cnt: 1, ((T([128, 20, 1, 1], f16),), {})
+cnt: 2, ((T([128, 60, 28, 28], f16),), {})
+cnt: 1, ((T([128, 32, 1, 1], f16),), {})
+cnt: 2, ((T([128, 120, 28, 28], f16),), {})
+cnt: 2, ((T([128, 100, 14, 14], f16),), {})
+cnt: 4, ((T([128, 92, 14, 14], f16),), {})
+cnt: 2, ((T([128, 240, 14, 14], f16),), {})
+cnt: 1, ((T([128, 120, 1, 1], f16),), {})
+cnt: 4, ((T([128, 336, 14, 14], f16),), {})
+cnt: 2, ((T([128, 168, 1, 1], f16),), {})
+cnt: 8, ((T([128, 480, 7, 7], f16),), {})
+cnt: 2, ((T([128, 240, 1, 1], f16),), {})
+cnt: 1, ((T([128, 960, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1280, 1, 1], f16),), {})
+Operator: aten.slice_backward.default
+cnt: 4, ((T([128, 960, 7, 7], f16), [128, 960, 7, 7], 3, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 960, 7, 7], f16), [128, 960, 7, 7], 2, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([128, 960, 7, 7], f16), [128, 960, 7, 7], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), [128, 672, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), [128, 672, 14, 14], 2, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), [128, 672, 14, 14], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), [128, 480, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), [128, 480, 14, 14], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), [128, 480, 14, 14], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 184, 14, 14], f16), [128, 184, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 184, 14, 14], f16), [128, 184, 14, 14], 2, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 184, 14, 14], f16), [128, 184, 14, 14], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 200, 14, 14], f16), [128, 200, 14, 14], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 200, 14, 14], f16), [128, 200, 14, 14], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 200, 14, 14], f16), [128, 200, 14, 14], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), [128, 240, 28, 28], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), [128, 240, 28, 28], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), [128, 240, 28, 28], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 120, 28, 28], f16), [128, 120, 28, 28], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 120, 28, 28], f16), [128, 120, 28, 28], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 120, 28, 28], f16), [128, 120, 28, 28], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 72, 56, 56], f16), [128, 72, 56, 56], 3, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 72, 56, 56], f16), [128, 72, 56, 56], 2, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 72, 56, 56], f16), [128, 72, 56, 56], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), [128, 48, 112, 112], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), [128, 48, 112, 112], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), [128, 48, 112, 112], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), [128, 16, 112, 112], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), [128, 16, 112, 112], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), [128, 16, 112, 112], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 2, ((T([128, 960, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 120, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), [2, 3], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([128, 1280, 1, 1], f16), T([128, 1280, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16), 0), {})
+cnt: 2, ((T([128, 240, 1, 1], f16), T([128, 240, 1, 1], f16), 0), {})
+cnt: 4, ((T([128, 480, 7, 7], f16, stride=(47040, 49, 7, 1)), T([128, 480, 7, 7], f16), 0), {})
+cnt: 4, ((T([128, 480, 7, 7], f16), T([128, 480, 7, 7], f16), 0), {})
+cnt: 2, ((T([128, 168, 1, 1], f16), T([128, 168, 1, 1], f16), 0), {})
+cnt: 2, ((T([128, 336, 14, 14], f16, stride=(131712, 196, 14, 1)), T([128, 336, 14, 14], f16), 0), {})
+cnt: 2, ((T([128, 336, 14, 14], f16), T([128, 336, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 120, 1, 1], f16), T([128, 120, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 240, 14, 14], f16, stride=(94080, 196, 14, 1)), T([128, 240, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16), 0), {})
+cnt: 2, ((T([128, 92, 14, 14], f16, stride=(36064, 196, 14, 1)), T([128, 92, 14, 14], f16), 0), {})
+cnt: 2, ((T([128, 92, 14, 14], f16), T([128, 92, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 100, 14, 14], f16, stride=(39200, 196, 14, 1)), T([128, 100, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 100, 14, 14], f16), T([128, 100, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 120, 28, 28], f16, stride=(188160, 784, 28, 1)), T([128, 120, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 32, 1, 1], f16), T([128, 32, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 60, 28, 28], f16, stride=(94080, 784, 28, 1)), T([128, 60, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 60, 28, 28], f16), T([128, 60, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 20, 1, 1], f16), T([128, 20, 1, 1], f16), 0), {})
+cnt: 2, ((T([128, 36, 56, 56], f16, stride=(225792, 3136, 56, 1)), T([128, 36, 56, 56], f16), 0), {})
+cnt: 2, ((T([128, 36, 56, 56], f16), T([128, 36, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 24, 112, 112], f16, stride=(602112, 12544, 112, 1)), T([128, 24, 112, 112], f16), 0), {})
+cnt: 1, ((T([128, 24, 112, 112], f16), T([128, 24, 112, 112], f16), 0), {})
+cnt: 1, ((T([128, 8, 112, 112], f16, stride=(200704, 12544, 112, 1)), T([128, 8, 112, 112], f16), 0), {})
+cnt: 1, ((T([128, 8, 112, 112], f16), T([128, 8, 112, 112], f16), 0), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gluon_inception_v3_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gluon_inception_v3_training.txt
new file mode 100644
index 0000000000000..c11cd6890c765
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gluon_inception_v3_training.txt
@@ -0,0 +1,239 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 4, ((T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16)), {})
+cnt: 3, ((T([128, 2048, 8, 8], f16), T([128, 2048, 8, 8], f16)), {})
+cnt: 3, ((T([128, 1280, 8, 8], f16), T([128, 1280, 8, 8], f16)), {})
+cnt: 14, ((T([128, 768, 17, 17], f16), T([128, 768, 17, 17], f16)), {})
+cnt: 5, ((T([128, 288, 35, 35], f16), T([128, 288, 35, 35], f16)), {})
+cnt: 3, ((T([128, 256, 35, 35], f16), T([128, 256, 35, 35], f16)), {})
+cnt: 3, ((T([128, 192, 35, 35], f16), T([128, 192, 35, 35], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 94, ((T([], i64), 1), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([128, 192, 35, 35], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 1, ((T([128, 256, 35, 35], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 1, ((T([128, 288, 35, 35], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 4, ((T([128, 768, 17, 17], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), [3, 3], [1, 1], [1, 1]), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([128, 2048, 8, 8], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), T([128, 1280, 8, 8], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+cnt: 4, ((T([128, 768, 17, 17], f16), T([128, 768, 17, 17], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 288, 35, 35], f16), T([128, 288, 35, 35], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 256, 35, 35], f16), T([128, 256, 35, 35], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 192, 35, 35], f16), T([128, 192, 35, 35], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+Operator: aten.cat.default
+cnt: 1, (([T([128, 64, 35, 35], f16), T([128, 64, 35, 35], f16), T([128, 96, 35, 35], f16), T([128, 32, 35, 35], f16)], 1), {})
+cnt: 2, (([T([128, 64, 35, 35], f16), T([128, 64, 35, 35], f16), T([128, 96, 35, 35], f16), T([128, 64, 35, 35], f16)], 1), {})
+cnt: 1, (([T([128, 384, 17, 17], f16), T([128, 96, 17, 17], f16), T([128, 288, 17, 17], f16)], 1), {})
+cnt: 4, (([T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16)], 1), {})
+cnt: 1, (([T([128, 320, 8, 8], f16), T([128, 192, 8, 8], f16), T([128, 768, 8, 8], f16)], 1), {})
+cnt: 4, (([T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16)], 1), {})
+cnt: 2, (([T([128, 320, 8, 8], f16), T([128, 768, 8, 8], f16), T([128, 768, 8, 8], f16), T([128, 192, 8, 8], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 299, 299], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 299, 299], f16), T([32, 3, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 149, 149], f16), T([32, 32, 3, 3], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 147, 147], f16), T([64, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 73, 73], f16), T([80, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 80, 73, 73], f16), T([192, 80, 3, 3], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 192, 35, 35], f16), T([64, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 35, 35], f16), T([48, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 48, 35, 35], f16), T([64, 48, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 64, 35, 35], f16), T([96, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 96, 35, 35], f16), T([96, 96, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 35, 35], f16), T([32, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 256, 35, 35], f16), T([64, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 35, 35], f16), T([48, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 288, 35, 35], f16), T([64, 288, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 288, 35, 35], f16), T([48, 288, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 288, 35, 35], f16), T([384, 288, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 35, 35], f16), T([96, 96, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 12, ((T([128, 768, 17, 17], f16), T([192, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 768, 17, 17], f16), T([128, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 128, 17, 17], f16), T([128, 128, 1, 7], f16), None, [1, 1], [0, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 17, 17], f16), T([192, 128, 7, 1], f16), None, [1, 1], [3, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 128, 17, 17], f16), T([128, 128, 7, 1], f16), None, [1, 1], [3, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 17, 17], f16), T([192, 128, 1, 7], f16), None, [1, 1], [0, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 768, 17, 17], f16), T([160, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 160, 17, 17], f16), T([160, 160, 1, 7], f16), None, [1, 1], [0, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 160, 17, 17], f16), T([192, 160, 7, 1], f16), None, [1, 1], [3, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 160, 17, 17], f16), T([160, 160, 7, 1], f16), None, [1, 1], [3, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 160, 17, 17], f16), T([192, 160, 1, 7], f16), None, [1, 1], [0, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 192, 17, 17], f16), T([192, 192, 1, 7], f16), None, [1, 1], [0, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 192, 17, 17], f16), T([192, 192, 7, 1], f16), None, [1, 1], [3, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 17, 17], f16), T([320, 192, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 17, 17], f16), T([192, 192, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), T([320, 1280, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), T([384, 1280, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 384, 8, 8], f16), T([384, 384, 1, 3], f16), None, [1, 1], [0, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 384, 8, 8], f16), T([384, 384, 3, 1], f16), None, [1, 1], [1, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), T([448, 1280, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 448, 8, 8], f16), T([384, 448, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), T([192, 1280, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([320, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([384, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([448, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([192, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 192, 8, 8], f16), T([128, 2048, 8, 8], f16), T([192, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16), T([384, 384, 3, 1], f16), [0], [1, 1], [1, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16), T([384, 384, 1, 3], f16), [0], [1, 1], [0, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 384, 8, 8], f16), T([128, 448, 8, 8], f16), T([384, 448, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 448, 8, 8], f16), T([128, 2048, 8, 8], f16), T([448, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 384, 8, 8], f16), T([128, 2048, 8, 8], f16), T([384, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 320, 8, 8], f16), T([128, 2048, 8, 8], f16), T([320, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 8, 8], f16), T([128, 1280, 8, 8], f16), T([192, 1280, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 448, 8, 8], f16), T([128, 1280, 8, 8], f16), T([448, 1280, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 384, 8, 8], f16), T([128, 1280, 8, 8], f16), T([384, 1280, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 320, 8, 8], f16), T([128, 1280, 8, 8], f16), T([320, 1280, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 8, 8], f16), T([128, 192, 17, 17], f16), T([192, 192, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), T([192, 192, 7, 1], f16), [0], [1, 1], [3, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), T([192, 192, 1, 7], f16), [0], [1, 1], [0, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 12, ((T([128, 192, 17, 17], f16), T([128, 768, 17, 17], f16), T([192, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 320, 8, 8], f16), T([128, 192, 17, 17], f16), T([320, 192, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 192, 17, 17], f16), T([128, 160, 17, 17], f16), T([192, 160, 1, 7], f16), [0], [1, 1], [0, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 160, 17, 17], f16), T([128, 160, 17, 17], f16), T([160, 160, 7, 1], f16), [0], [1, 1], [3, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 160, 17, 17], f16), T([128, 160, 17, 17], f16), T([160, 160, 1, 7], f16), [0], [1, 1], [0, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 160, 17, 17], f16), T([128, 768, 17, 17], f16), T([160, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 192, 17, 17], f16), T([128, 160, 17, 17], f16), T([192, 160, 7, 1], f16), [0], [1, 1], [3, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 17, 17], f16), T([128, 128, 17, 17], f16), T([192, 128, 1, 7], f16), [0], [1, 1], [0, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 128, 17, 17], f16), T([128, 128, 17, 17], f16), T([128, 128, 7, 1], f16), [0], [1, 1], [3, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 128, 17, 17], f16), T([128, 128, 17, 17], f16), T([128, 128, 1, 7], f16), [0], [1, 1], [0, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 128, 17, 17], f16), T([128, 768, 17, 17], f16), T([128, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 17, 17], f16), T([128, 128, 17, 17], f16), T([192, 128, 7, 1], f16), [0], [1, 1], [3, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 17, 17], f16), T([128, 96, 35, 35], f16), T([96, 96, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 96, 35, 35], f16), T([128, 64, 35, 35], f16), T([96, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 64, 35, 35], f16), T([128, 288, 35, 35], f16), T([64, 288, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 384, 17, 17], f16), T([128, 288, 35, 35], f16), T([384, 288, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 96, 35, 35], f16), T([128, 96, 35, 35], f16), T([96, 96, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 64, 35, 35], f16), T([128, 48, 35, 35], f16), T([64, 48, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 35, 35], f16), T([128, 288, 35, 35], f16), T([48, 288, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 64, 35, 35], f16), T([128, 256, 35, 35], f16), T([64, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 35, 35], f16), T([128, 256, 35, 35], f16), T([48, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 35, 35], f16), T([128, 192, 35, 35], f16), T([32, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 64, 35, 35], f16), T([128, 192, 35, 35], f16), T([64, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 35, 35], f16), T([128, 192, 35, 35], f16), T([48, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 71, 71], f16), T([128, 80, 73, 73], f16), T([192, 80, 3, 3], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 80, 73, 73], f16), T([128, 64, 73, 73], f16), T([80, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 147, 147], f16), T([128, 32, 147, 147], f16), T([64, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 147, 147], f16), T([128, 32, 149, 149], f16), T([32, 32, 3, 3], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 149, 149], f16), T([128, 3, 299, 299], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 299, 299], f16), T([128, 3, 299, 299], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 2048, 8, 8], f16, stride=(2048, 1, 0, 0)), 64), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([128, 64, 147, 147], f16), [3, 3], [2, 2]), {})
+cnt: 1, ((T([128, 192, 71, 71], f16), [3, 3], [2, 2]), {})
+cnt: 1, ((T([128, 288, 35, 35], f16), [3, 3], [2, 2]), {})
+cnt: 1, ((T([128, 768, 17, 17], f16), [3, 3], [2, 2]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([128, 768, 8, 8], f16, stride=(81920, 64, 8, 1)), T([128, 768, 17, 17], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([128, 768, 8, 8], i64)), {})
+cnt: 1, ((T([128, 288, 17, 17], f16, stride=(221952, 289, 17, 1)), T([128, 288, 35, 35], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([128, 288, 17, 17], i64)), {})
+cnt: 1, ((T([128, 192, 35, 35], f16), T([128, 192, 71, 71], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([128, 192, 35, 35], i64)), {})
+cnt: 1, ((T([128, 64, 73, 73], f16), T([128, 64, 147, 147], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([128, 64, 73, 73], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 2048, 8, 8], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 2048], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 32, 149, 149], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 32, 147, 147], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 64, 147, 147], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 80, 73, 73], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 192, 71, 71], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 0.001), {})
+cnt: 12, ((T([128, 64, 35, 35], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 0.001), {})
+cnt: 3, ((T([128, 48, 35, 35], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f16), True, 0.1, 0.001), {})
+cnt: 7, ((T([128, 96, 35, 35], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 32, 35, 35], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 384, 17, 17], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 96, 17, 17], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 0.001), {})
+cnt: 26, ((T([128, 192, 17, 17], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 0.001), {})
+cnt: 6, ((T([128, 128, 17, 17], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 0.001), {})
+cnt: 12, ((T([128, 160, 17, 17], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), True, 0.1, 0.001), {})
+cnt: 3, ((T([128, 320, 8, 8], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), True, 0.1, 0.001), {})
+cnt: 3, ((T([128, 192, 8, 8], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 0.001), {})
+cnt: 12, ((T([128, 384, 8, 8], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([128, 448, 8, 8], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f16), True, 0.1, 0.001), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 3, ((T([128, 192, 8, 8], f16), T([128, 192, 8, 8], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 0.001, [True, True, True]), {})
+cnt: 12, ((T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([128, 448, 8, 8], f16), T([128, 448, 8, 8], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f32), T([448], f32), True, 0.001, [True, True, True]), {})
+cnt: 3, ((T([128, 320, 8, 8], f16), T([128, 320, 8, 8], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), True, 0.001, [True, True, True]), {})
+cnt: 26, ((T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 0.001, [True, True, True]), {})
+cnt: 12, ((T([128, 160, 17, 17], f16), T([128, 160, 17, 17], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), True, 0.001, [True, True, True]), {})
+cnt: 6, ((T([128, 128, 17, 17], f16), T([128, 128, 17, 17], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 96, 17, 17], f16), T([128, 96, 17, 17], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 0.001, [True, True, True]), {})
+cnt: 7, ((T([128, 96, 35, 35], f16), T([128, 96, 35, 35], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 0.001, [True, True, True]), {})
+cnt: 12, ((T([128, 64, 35, 35], f16), T([128, 64, 35, 35], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 384, 17, 17], f16), T([128, 384, 17, 17], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 0.001, [True, True, True]), {})
+cnt: 3, ((T([128, 48, 35, 35], f16), T([128, 48, 35, 35], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f32), T([48], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 35, 35], f16), T([128, 32, 35, 35], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 192, 71, 71], f16), T([128, 192, 71, 71], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 80, 73, 73], f16), T([128, 80, 73, 73], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 147, 147], f16), T([128, 64, 147, 147], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 147, 147], f16), T([128, 32, 147, 147], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 149, 149], f16), T([128, 32, 149, 149], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 0.001, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 32, 149, 149], f16),), {})
+cnt: 1, ((T([128, 32, 147, 147], f16),), {})
+cnt: 1, ((T([128, 64, 147, 147], f16),), {})
+cnt: 1, ((T([128, 80, 73, 73], f16),), {})
+cnt: 1, ((T([128, 192, 71, 71], f16),), {})
+cnt: 12, ((T([128, 64, 35, 35], f16),), {})
+cnt: 3, ((T([128, 48, 35, 35], f16),), {})
+cnt: 7, ((T([128, 96, 35, 35], f16),), {})
+cnt: 1, ((T([128, 32, 35, 35], f16),), {})
+cnt: 1, ((T([128, 384, 17, 17], f16),), {})
+cnt: 1, ((T([128, 96, 17, 17], f16),), {})
+cnt: 26, ((T([128, 192, 17, 17], f16),), {})
+cnt: 6, ((T([128, 128, 17, 17], f16),), {})
+cnt: 12, ((T([128, 160, 17, 17], f16),), {})
+cnt: 3, ((T([128, 320, 8, 8], f16),), {})
+cnt: 3, ((T([128, 192, 8, 8], f16),), {})
+cnt: 12, ((T([128, 384, 8, 8], f16),), {})
+cnt: 2, ((T([128, 448, 8, 8], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 2, ((T([128, 192, 8, 8], f16, stride=(131072, 64, 8, 1)), T([128, 192, 8, 8], f16), 0), {})
+cnt: 8, ((T([128, 384, 8, 8], f16, stride=(131072, 64, 8, 1)), T([128, 384, 8, 8], f16), 0), {})
+cnt: 4, ((T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16), 0), {})
+cnt: 2, ((T([128, 448, 8, 8], f16), T([128, 448, 8, 8], f16), 0), {})
+cnt: 2, ((T([128, 320, 8, 8], f16, stride=(131072, 64, 8, 1)), T([128, 320, 8, 8], f16), 0), {})
+cnt: 1, ((T([128, 192, 8, 8], f16, stride=(81920, 64, 8, 1)), T([128, 192, 8, 8], f16), 0), {})
+cnt: 10, ((T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), 0), {})
+cnt: 1, ((T([128, 320, 8, 8], f16, stride=(81920, 64, 8, 1)), T([128, 320, 8, 8], f16), 0), {})
+cnt: 16, ((T([128, 192, 17, 17], f16, stride=(221952, 289, 17, 1)), T([128, 192, 17, 17], f16), 0), {})
+cnt: 12, ((T([128, 160, 17, 17], f16), T([128, 160, 17, 17], f16), 0), {})
+cnt: 6, ((T([128, 128, 17, 17], f16), T([128, 128, 17, 17], f16), 0), {})
+cnt: 1, ((T([128, 96, 17, 17], f16, stride=(221952, 289, 17, 1)), T([128, 96, 17, 17], f16), 0), {})
+cnt: 4, ((T([128, 96, 35, 35], f16), T([128, 96, 35, 35], f16), 0), {})
+cnt: 4, ((T([128, 64, 35, 35], f16), T([128, 64, 35, 35], f16), 0), {})
+cnt: 1, ((T([128, 384, 17, 17], f16, stride=(221952, 289, 17, 1)), T([128, 384, 17, 17], f16), 0), {})
+cnt: 6, ((T([128, 64, 35, 35], f16, stride=(352800, 1225, 35, 1)), T([128, 64, 35, 35], f16), 0), {})
+cnt: 2, ((T([128, 96, 35, 35], f16, stride=(352800, 1225, 35, 1)), T([128, 96, 35, 35], f16), 0), {})
+cnt: 3, ((T([128, 48, 35, 35], f16), T([128, 48, 35, 35], f16), 0), {})
+cnt: 1, ((T([128, 32, 35, 35], f16, stride=(313600, 1225, 35, 1)), T([128, 32, 35, 35], f16), 0), {})
+cnt: 1, ((T([128, 96, 35, 35], f16, stride=(313600, 1225, 35, 1)), T([128, 96, 35, 35], f16), 0), {})
+cnt: 2, ((T([128, 64, 35, 35], f16, stride=(313600, 1225, 35, 1)), T([128, 64, 35, 35], f16), 0), {})
+cnt: 1, ((T([128, 192, 71, 71], f16), T([128, 192, 71, 71], f16), 0), {})
+cnt: 1, ((T([128, 80, 73, 73], f16), T([128, 80, 73, 73], f16), 0), {})
+cnt: 1, ((T([128, 64, 147, 147], f16), T([128, 64, 147, 147], f16), 0), {})
+cnt: 1, ((T([128, 32, 147, 147], f16), T([128, 32, 147, 147], f16), 0), {})
+cnt: 1, ((T([128, 32, 149, 149], f16), T([128, 32, 149, 149], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gluon_senet154_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gluon_senet154_training.txt
new file mode 100644
index 0000000000000..b766b8a41570c
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gluon_senet154_training.txt
@@ -0,0 +1,187 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([32, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([32, 1000], f16), T([32, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 5, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16)), {})
+cnt: 72, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16)), {})
+cnt: 16, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16)), {})
+cnt: 6, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16)), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 157, ((T([], i64), 1), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16)), {})
+cnt: 8, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16)), {})
+cnt: 36, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16)), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([32, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([64, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([128, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([128, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 128, 56, 56], f16), T([256, 2, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 4, ((T([32, 256, 56, 56], f16), T([256, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 256, 1, 1], f16), T([16, 256, 1, 1], f16), T([16], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 16, 1, 1], f16), T([256, 16, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([256, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 256, 56, 56], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([512, 4, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 9, ((T([32, 512, 28, 28], f16), T([512, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([32, 512, 1, 1], f16), T([32, 512, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([32, 32, 1, 1], f16), T([512, 32, 1, 1], f16), T([512], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([512, 256, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([32, 512, 28, 28], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([32, 256, 28, 28], f16), T([512, 4, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([1024, 8, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 37, ((T([32, 1024, 14, 14], f16), T([1024, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 36, ((T([32, 1024, 1, 1], f16), T([64, 1024, 1, 1], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 36, ((T([32, 64, 1, 1], f16), T([1024, 64, 1, 1], f16), T([1024], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([1024, 512, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 35, ((T([32, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 35, ((T([32, 512, 14, 14], f16), T([1024, 8, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([2048, 16, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16), T([2048, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 2048, 1, 1], f16), T([128, 2048, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 128, 1, 1], f16), T([2048, 128, 1, 1], f16), T([2048], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([2048, 1024, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 2048, 7, 7], f16), T([1024, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 1024, 7, 7], f16), T([2048, 16, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 64), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([32, 2048, 1, 1], f16), T([32, 128, 1, 1], f16), T([2048, 128, 1, 1], f16), [2048], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 128, 1, 1], f16), T([32, 2048, 1, 1], f16), T([128, 2048, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16), T([2048, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 2048, 7, 7], f16), T([32, 1024, 7, 7], f16), T([2048, 16, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 2, ((T([32, 1024, 7, 7], f16), T([32, 2048, 7, 7], f16), T([1024, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16), T([32, 1024, 14, 14], f16), T([2048, 1024, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16), T([32, 1024, 14, 14], f16), T([2048, 16, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 37, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16), T([1024, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 36, ((T([32, 1024, 1, 1], f16), T([32, 64, 1, 1], f16), T([1024, 64, 1, 1], f16), [1024], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 36, ((T([32, 64, 1, 1], f16), T([32, 1024, 1, 1], f16), T([64, 1024, 1, 1], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 35, ((T([32, 1024, 14, 14], f16), T([32, 512, 14, 14], f16), T([1024, 8, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 35, ((T([32, 512, 14, 14], f16), T([32, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 512, 28, 28], f16), T([1024, 512, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 512, 28, 28], f16), T([1024, 8, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 9, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16), T([512, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 8, ((T([32, 512, 1, 1], f16), T([32, 32, 1, 1], f16), T([512, 32, 1, 1], f16), [512], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 8, ((T([32, 32, 1, 1], f16), T([32, 512, 1, 1], f16), T([32, 512, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 7, ((T([32, 512, 28, 28], f16), T([32, 256, 28, 28], f16), T([512, 4, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 7, ((T([32, 256, 28, 28], f16), T([32, 512, 28, 28], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 256, 56, 56], f16), T([512, 256, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 256, 56, 56], f16), T([512, 4, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 4, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16), T([256, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 256, 1, 1], f16), T([32, 16, 1, 1], f16), T([256, 16, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 16, 1, 1], f16), T([32, 256, 1, 1], f16), T([16, 256, 1, 1], f16), [16], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), T([32, 128, 56, 56], f16), T([256, 2, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 2, ((T([32, 128, 56, 56], f16), T([32, 256, 56, 56], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([32, 128, 56, 56], f16), T([256, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), T([128, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 112, 112], f16), T([32, 64, 112, 112], f16), T([128, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 3, 224, 224], f16), T([64, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 4, ((T([32, 2048, 7, 7], f16, stride=(2048, 1, 0, 0)), 49), {})
+cnt: 36, ((T([32, 1024, 14, 14], f16, stride=(1024, 1, 0, 0)), 196), {})
+cnt: 8, ((T([32, 512, 28, 28], f16, stride=(512, 1, 0, 0)), 784), {})
+cnt: 3, ((T([32, 256, 56, 56], f16, stride=(256, 1, 0, 0)), 3136), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([32], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([32, 128, 112, 112], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([32, 128, 56, 56], f16), T([32, 128, 112, 112], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([32, 128, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 3, ((T([32, 256, 56, 56], f16), [2, 3], True), {})
+cnt: 8, ((T([32, 512, 28, 28], f16), [2, 3], True), {})
+cnt: 36, ((T([32, 1024, 14, 14], f16), [2, 3], True), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([32, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 32], f16, stride=(1, 1000)), T([32, 2048], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 6, ((T([32, 256, 56, 56], f16), T([32, 256, 1, 1], f16)), {})
+cnt: 16, ((T([32, 512, 28, 28], f16), T([32, 512, 1, 1], f16)), {})
+cnt: 72, ((T([32, 1024, 14, 14], f16), T([32, 1024, 1, 1], f16)), {})
+cnt: 6, ((T([32, 2048, 7, 7], f16), T([32, 2048, 1, 1], f16)), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16)), {})
+cnt: 36, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16)), {})
+cnt: 8, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16)), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([32, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 128, 112, 112], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 8, ((T([32, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 18, ((T([32, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([32, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 74, ((T([32, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 35, ((T([32, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([32, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([32, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 7, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 1024, 7, 7], f16), T([32, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 74, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 35, ((T([32, 512, 14, 14], f16), T([32, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 18, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([32, 256, 28, 28], f16), T([32, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 128, 112, 112], f16), T([32, 128, 112, 112], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([32, 1000], f16), T([32], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([32, 1000], f16), T([32], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([32, 64, 112, 112], f16),), {})
+cnt: 1, ((T([32, 128, 112, 112], f16),), {})
+cnt: 3, ((T([32, 128, 56, 56], f16),), {})
+cnt: 7, ((T([32, 256, 56, 56], f16),), {})
+cnt: 3, ((T([32, 16, 1, 1], f16),), {})
+cnt: 17, ((T([32, 512, 28, 28], f16),), {})
+cnt: 8, ((T([32, 32, 1, 1], f16),), {})
+cnt: 7, ((T([32, 256, 28, 28], f16),), {})
+cnt: 73, ((T([32, 1024, 14, 14], f16),), {})
+cnt: 36, ((T([32, 64, 1, 1], f16),), {})
+cnt: 35, ((T([32, 512, 14, 14], f16),), {})
+cnt: 6, ((T([32, 2048, 7, 7], f16),), {})
+cnt: 3, ((T([32, 128, 1, 1], f16),), {})
+cnt: 2, ((T([32, 1024, 7, 7], f16),), {})
+Operator: aten.sigmoid.default
+cnt: 3, ((T([32, 256, 1, 1], f16),), {})
+cnt: 8, ((T([32, 512, 1, 1], f16),), {})
+cnt: 36, ((T([32, 1024, 1, 1], f16),), {})
+cnt: 3, ((T([32, 2048, 1, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 3, ((T([32, 2048, 1, 1], f16), T([32, 2048, 1, 1], f16)), {})
+cnt: 36, ((T([32, 1024, 1, 1], f16), T([32, 1024, 1, 1], f16)), {})
+cnt: 8, ((T([32, 512, 1, 1], f16), T([32, 512, 1, 1], f16)), {})
+cnt: 3, ((T([32, 256, 1, 1], f16), T([32, 256, 1, 1], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32, 1000], f16), [0], True), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16), [2, 3], True), {})
+cnt: 36, ((T([32, 1024, 14, 14], f16), [2, 3], True), {})
+cnt: 8, ((T([32, 512, 28, 28], f16), [2, 3], True), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), [2, 3], True), {})
+Operator: aten.threshold_backward.default
+cnt: 6, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16), 0), {})
+cnt: 3, ((T([32, 128, 1, 1], f16), T([32, 128, 1, 1], f16), 0), {})
+cnt: 2, ((T([32, 1024, 7, 7], f16), T([32, 1024, 7, 7], f16), 0), {})
+cnt: 73, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16), 0), {})
+cnt: 36, ((T([32, 64, 1, 1], f16), T([32, 64, 1, 1], f16), 0), {})
+cnt: 35, ((T([32, 512, 14, 14], f16), T([32, 512, 14, 14], f16), 0), {})
+cnt: 17, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16), 0), {})
+cnt: 8, ((T([32, 32, 1, 1], f16), T([32, 32, 1, 1], f16), 0), {})
+cnt: 7, ((T([32, 256, 28, 28], f16), T([32, 256, 28, 28], f16), 0), {})
+cnt: 7, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16), 0), {})
+cnt: 3, ((T([32, 16, 1, 1], f16), T([32, 16, 1, 1], f16), 0), {})
+cnt: 3, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 128, 112, 112], f16), T([32, 128, 112, 112], f16), 0), {})
+cnt: 2, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gluon_xception65_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gluon_xception65_training.txt
new file mode 100644
index 0000000000000..53a6cc2148962
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gluon_xception65_training.txt
@@ -0,0 +1,155 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([32, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([32, 1000], f16), T([32, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 2, ((T([32, 128, 75, 75], f16), T([32, 128, 75, 75], f16)), {})
+cnt: 2, ((T([32, 256, 38, 38], f16), T([32, 256, 38, 38], f16)), {})
+cnt: 34, ((T([32, 728, 19, 19], f16), T([32, 728, 19, 19], f16)), {})
+cnt: 1, ((T([32, 1024, 10, 10], f16), T([32, 1024, 10, 10], f16)), {})
+cnt: 1, ((T([32, 64, 150, 150], f16), T([32, 64, 150, 150], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 132, ((T([], i64), 1), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([32, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 299, 299], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 299, 299], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 32, 150, 150], f16), T([64, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 150, 150], f16), T([128, 64, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 150, 150], f16), T([64, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([32, 64, 150, 150], f16), T([128, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 150, 150], f16), T([128, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 128), {})
+cnt: 1, ((T([32, 128, 150, 150], f16), T([128, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 150, 150], f16), T([128, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 128), {})
+cnt: 1, ((T([32, 128, 75, 75], f16), T([128, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 75, 75], f16), T([256, 128, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 75, 75], f16), T([128, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 128), {})
+cnt: 1, ((T([32, 128, 75, 75], f16), T([256, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 75, 75], f16), T([256, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 256), {})
+cnt: 1, ((T([32, 256, 75, 75], f16), T([256, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 75, 75], f16), T([256, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 256), {})
+cnt: 1, ((T([32, 256, 38, 38], f16), T([256, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 38, 38], f16), T([728, 256, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 38, 38], f16), T([256, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 256), {})
+cnt: 1, ((T([32, 256, 38, 38], f16), T([728, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 728, 38, 38], f16), T([728, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 728), {})
+cnt: 1, ((T([32, 728, 38, 38], f16), T([728, 728, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 728, 38, 38], f16), T([728, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 728), {})
+cnt: 50, ((T([32, 728, 19, 19], f16), T([728, 728, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 50, ((T([32, 728, 19, 19], f16), T([728, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 728), {})
+cnt: 1, ((T([32, 728, 19, 19], f16), T([1024, 728, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 728, 19, 19], f16), T([1024, 728, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1024, 19, 19], f16), T([1024, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1024), {})
+cnt: 1, ((T([32, 1024, 10, 10], f16), T([1024, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1024, 10, 10], f16), T([1024, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1024), {})
+cnt: 1, ((T([32, 1024, 10, 10], f16), T([1536, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 1536, 10, 10], f16), T([1536, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1536), {})
+cnt: 1, ((T([32, 1536, 10, 10], f16), T([1536, 1536, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1536, 10, 10], f16), T([2048, 1536, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([32, 2048, 10, 10], f16), T([32, 1536, 10, 10], f16), T([2048, 1536, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 1536, 10, 10], f16), T([32, 1536, 10, 10], f16), T([1536, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1536, [True, True, False]), {})
+cnt: 1, ((T([32, 1536, 10, 10], f16), T([32, 1536, 10, 10], f16), T([1536, 1536, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1536, 10, 10], f16), T([32, 1024, 10, 10], f16), T([1536, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 10, 10], f16), T([32, 1024, 10, 10], f16), T([1024, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1024, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 10, 10], f16), T([32, 1024, 10, 10], f16), T([1024, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 10, 10], f16), T([32, 1024, 19, 19], f16), T([1024, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1024, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 19, 19], f16), T([32, 728, 19, 19], f16), T([1024, 728, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 50, ((T([32, 728, 19, 19], f16), T([32, 728, 19, 19], f16), T([728, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 728, [True, True, False]), {})
+cnt: 50, ((T([32, 728, 19, 19], f16), T([32, 728, 19, 19], f16), T([728, 728, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 10, 10], f16), T([32, 728, 19, 19], f16), T([1024, 728, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 728, 19, 19], f16), T([32, 728, 38, 38], f16), T([728, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 728, [True, True, False]), {})
+cnt: 1, ((T([32, 728, 38, 38], f16), T([32, 728, 38, 38], f16), T([728, 728, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 728, 38, 38], f16), T([32, 728, 38, 38], f16), T([728, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 728, [True, True, False]), {})
+cnt: 1, ((T([32, 728, 38, 38], f16), T([32, 256, 38, 38], f16), T([728, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 38, 38], f16), T([32, 256, 38, 38], f16), T([256, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 256, [True, True, False]), {})
+cnt: 1, ((T([32, 728, 19, 19], f16), T([32, 256, 38, 38], f16), T([728, 256, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 38, 38], f16), T([32, 256, 38, 38], f16), T([256, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 38, 38], f16), T([32, 256, 75, 75], f16), T([256, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 256, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 75, 75], f16), T([32, 256, 75, 75], f16), T([256, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 75, 75], f16), T([32, 256, 75, 75], f16), T([256, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 256, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 75, 75], f16), T([32, 128, 75, 75], f16), T([256, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 75, 75], f16), T([32, 128, 75, 75], f16), T([128, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 128, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 38, 38], f16), T([32, 128, 75, 75], f16), T([256, 128, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 75, 75], f16), T([32, 128, 75, 75], f16), T([128, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 75, 75], f16), T([32, 128, 150, 150], f16), T([128, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 128, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 150, 150], f16), T([32, 128, 150, 150], f16), T([128, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 150, 150], f16), T([32, 128, 150, 150], f16), T([128, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 128, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 150, 150], f16), T([32, 64, 150, 150], f16), T([128, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 150, 150], f16), T([32, 64, 150, 150], f16), T([64, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 75, 75], f16), T([32, 64, 150, 150], f16), T([128, 64, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 150, 150], f16), T([32, 32, 150, 150], f16), T([64, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 32, 150, 150], f16), T([32, 3, 299, 299], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 299, 299], f16), T([32, 3, 299, 299], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([32, 2048, 10, 10], f16, stride=(2048, 1, 0, 0)), 100), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([32], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([32, 2048, 10, 10], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([32, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 32], f16, stride=(1, 1000)), T([32, 2048], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([32, 32, 150, 150], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([32, 64, 150, 150], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([32, 128, 75, 75], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 128, 150, 150], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([32, 256, 38, 38], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 256, 75, 75], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 102, ((T([32, 728, 19, 19], f16), T([728], f16), T([728], f16), T([728], f16), T([728], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 728, 38, 38], f16), T([728], f16), T([728], f16), T([728], f16), T([728], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([32, 1024, 10, 10], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 1024, 19, 19], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([32, 1536, 10, 10], f16), T([1536], f16), T([1536], f16), T([1536], f16), T([1536], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 2048, 10, 10], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([32, 2048, 10, 10], f16), T([32, 2048, 10, 10], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([32, 1536, 10, 10], f16), T([32, 1536, 10, 10], f16), T([1536], f16), T([1536], f16), T([1536], f16), T([1536], f32), T([1536], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([32, 1024, 10, 10], f16), T([32, 1024, 10, 10], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 1024, 19, 19], f16), T([32, 1024, 19, 19], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 102, ((T([32, 728, 19, 19], f16), T([32, 728, 19, 19], f16), T([728], f16), T([728], f16), T([728], f16), T([728], f32), T([728], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 728, 38, 38], f16), T([32, 728, 38, 38], f16), T([728], f16), T([728], f16), T([728], f16), T([728], f32), T([728], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([32, 256, 38, 38], f16), T([32, 256, 38, 38], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 256, 75, 75], f16), T([32, 256, 75, 75], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([32, 128, 75, 75], f16), T([32, 128, 75, 75], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 128, 150, 150], f16), T([32, 128, 150, 150], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 64, 150, 150], f16), T([32, 64, 150, 150], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 32, 150, 150], f16), T([32, 32, 150, 150], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([32, 1000], f16), T([32], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([32, 1000], f16), T([32], i64), None, 1, -100), {})
+Operator: aten.relu.default
+cnt: 1, ((T([32, 256, 38, 38], f16),), {})
+cnt: 17, ((T([32, 728, 19, 19], f16),), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([32, 32, 150, 150], f16),), {})
+cnt: 1, ((T([32, 64, 150, 150], f16),), {})
+cnt: 2, ((T([32, 128, 150, 150], f16),), {})
+cnt: 1, ((T([32, 128, 75, 75], f16),), {})
+cnt: 2, ((T([32, 256, 75, 75], f16),), {})
+cnt: 2, ((T([32, 728, 38, 38], f16),), {})
+cnt: 33, ((T([32, 728, 19, 19], f16),), {})
+cnt: 1, ((T([32, 1024, 19, 19], f16),), {})
+cnt: 1, ((T([32, 1024, 10, 10], f16),), {})
+cnt: 2, ((T([32, 1536, 10, 10], f16),), {})
+cnt: 1, ((T([32, 2048, 10, 10], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([32, 2048, 10, 10], f16), T([32, 2048, 10, 10], f16), 0), {})
+cnt: 2, ((T([32, 1536, 10, 10], f16), T([32, 1536, 10, 10], f16), 0), {})
+cnt: 1, ((T([32, 1024, 10, 10], f16), T([32, 1024, 10, 10], f16), 0), {})
+cnt: 1, ((T([32, 1024, 19, 19], f16), T([32, 1024, 19, 19], f16), 0), {})
+cnt: 50, ((T([32, 728, 19, 19], f16), T([32, 728, 19, 19], f16), 0), {})
+cnt: 2, ((T([32, 728, 38, 38], f16), T([32, 728, 38, 38], f16), 0), {})
+cnt: 1, ((T([32, 256, 38, 38], f16), T([32, 256, 38, 38], f16), 0), {})
+cnt: 2, ((T([32, 256, 75, 75], f16), T([32, 256, 75, 75], f16), 0), {})
+cnt: 1, ((T([32, 128, 75, 75], f16), T([32, 128, 75, 75], f16), 0), {})
+cnt: 2, ((T([32, 128, 150, 150], f16), T([32, 128, 150, 150], f16), 0), {})
+cnt: 1, ((T([32, 64, 150, 150], f16), T([32, 64, 150, 150], f16), 0), {})
+cnt: 1, ((T([32, 32, 150, 150], f16), T([32, 32, 150, 150], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gmixer_24_224_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gmixer_24_224_training.txt
new file mode 100644
index 0000000000000..3e4deb2860b67
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gmixer_24_224_training.txt
@@ -0,0 +1,83 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 24, ((T([64, 384, 384], f16), [64, 384, 384]), {})
+cnt: 24, ((T([64, 384, 196], f16), [24576, 196]), {})
+Operator: aten.add.Tensor
+cnt: 24, ((T([64, 384, 384], f16), T([384], f16)), {})
+cnt: 24, ((T([64, 196, 384], f16, stride=(75264, 1, 196)), T([64, 196, 384], f16, stride=(75264, 1, 196))), {})
+cnt: 24, ((T([64, 196, 384], f16, stride=(75264, 1, 196)), T([64, 196, 384], f16)), {})
+cnt: 24, ((T([64, 196, 384], f16), T([64, 196, 384], f16)), {})
+cnt: 24, ((T([64, 196, 384], f16), T([64, 196, 384], f16, stride=(75264, 1, 196))), {})
+Operator: aten.addmm.default
+cnt: 24, ((T([196], f16), T([24576, 192], f16), T([192, 196], f16, stride=(1, 192))), {})
+cnt: 24, ((T([1536], f16), T([12544, 384], f16), T([384, 1536], f16, stride=(1, 384))), {})
+cnt: 24, ((T([384], f16), T([12544, 768], f16), T([768, 384], f16, stride=(1, 768))), {})
+cnt: 1, ((T([1000], f16), T([64, 384], f16), T([384, 1000], f16, stride=(1, 384))), {})
+Operator: aten.bmm.default
+cnt: 24, ((T([64, 384, 196], f16, stride=(75264, 1, 384)), T([64, 196, 384], f16, stride=(0, 1, 196))), {})
+cnt: 24, ((T([64, 196, 384], f16), T([64, 384, 384], f16)), {})
+cnt: 24, ((T([64, 384, 384], f16), T([64, 384, 196], f16, stride=(0, 196, 1))), {})
+Operator: aten.cat.default
+cnt: 24, (([T([64, 196, 768], f16), T([64, 196, 768], f16)], 2), {})
+cnt: 24, (([T([64, 384, 192], f16), T([64, 384, 192], f16)], 2), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([384, 3, 16, 16], f16), T([384], f16), [16, 16], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 384, 14, 14], f16, stride=(75264, 1, 5376, 384)), T([64, 3, 224, 224], f16), T([384, 3, 16, 16], f16), [384], [16, 16], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+cnt: 24, ((T([384, 196], f16), T([384, 196], f16, stride=(1, 384))), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([64, 196, 384], f16, stride=(384, 0, 1)), 196), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([64, 196, 384], f16), [1]), {})
+Operator: aten.mm.default
+cnt: 1, ((T([64, 1000], f16), T([1000, 384], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 384], f16)), {})
+cnt: 24, ((T([12544, 384], f16), T([384, 768], f16)), {})
+cnt: 24, ((T([384, 12544], f16, stride=(1, 384)), T([12544, 768], f16)), {})
+cnt: 24, ((T([12544, 1536], f16), T([1536, 384], f16)), {})
+cnt: 24, ((T([1536, 12544], f16, stride=(1, 1536)), T([12544, 384], f16)), {})
+cnt: 24, ((T([24576, 196], f16), T([196, 192], f16)), {})
+cnt: 24, ((T([196, 24576], f16, stride=(1, 196)), T([24576, 192], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 24, ((T([64, 384, 192], f16, stride=(147456, 384, 1)), T([64, 384, 192], f16)), {})
+cnt: 24, ((T([64, 196, 768], f16, stride=(301056, 1536, 1)), T([64, 196, 768], f16)), {})
+cnt: 24, ((T([64, 196, 768], f16), T([64, 196, 768], f16, stride=(301056, 1536, 1))), {})
+cnt: 24, ((T([64, 196, 768], f16), T([64, 196, 768], f16)), {})
+cnt: 24, ((T([64, 384, 192], f16), T([64, 384, 192], f16, stride=(147456, 384, 1))), {})
+cnt: 24, ((T([64, 384, 192], f16), T([64, 384, 192], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 49, ((T([64, 196, 384], f16, stride=(75264, 1, 196)), [384], T([384], f16), T([384], f16), 1e-06), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 25, ((T([64, 196, 384], f16), T([64, 196, 384], f16, stride=(75264, 1, 196)), [384], T([64, 196, 1], f32), T([64, 196, 1], f32), T([384], f16), T([384], f16), [True, True, True]), {})
+cnt: 24, ((T([64, 196, 384], f16, stride=(75264, 1, 196)), T([64, 196, 384], f16, stride=(75264, 1, 196)), [384], T([64, 196, 1], f32), T([64, 196, 1], f32), T([384], f16), T([384], f16), [True, True, True]), {})
+Operator: aten.new_empty_strided.default
+cnt: 24, ((T([384, 196], f16, stride=(1, 384)), [384, 196], [196, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.silu.default
+cnt: 24, ((T([64, 384, 192], f16, stride=(147456, 384, 1)),), {})
+cnt: 24, ((T([64, 196, 768], f16, stride=(301056, 1536, 1)),), {})
+Operator: aten.silu_backward.default
+cnt: 24, ((T([64, 196, 768], f16), T([64, 196, 768], f16, stride=(301056, 1536, 1))), {})
+cnt: 24, ((T([64, 384, 192], f16), T([64, 384, 192], f16, stride=(147456, 384, 1))), {})
+Operator: aten.split.Tensor
+cnt: 24, ((T([64, 384, 384], f16), 192, -1), {})
+cnt: 24, ((T([64, 196, 1536], f16), 768, -1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 24, ((T([12544, 384], f16), [0], True), {})
+cnt: 24, ((T([12544, 1536], f16), [0], True), {})
+cnt: 24, ((T([24576, 196], f16), [0], True), {})
+cnt: 24, ((T([64, 384, 384], f16), [0, 1], True), {})
+cnt: 24, ((T([64, 196, 384], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gmlp_s16_224_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gmlp_s16_224_training.txt
new file mode 100644
index 0000000000000..81057185fc5e2
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/gmlp_s16_224_training.txt
@@ -0,0 +1,70 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 30, ((T([64, 768, 196], f16), [64, 768, 196]), {})
+Operator: aten.add.Tensor
+cnt: 30, ((T([64, 768, 196], f16), T([196], f16)), {})
+cnt: 30, ((T([64, 196, 256], f16, stride=(50176, 1, 196)), T([64, 196, 256], f16)), {})
+cnt: 30, ((T([64, 196, 256], f16), T([64, 196, 256], f16)), {})
+Operator: aten.addmm.default
+cnt: 30, ((T([1536], f16), T([12544, 256], f16), T([256, 1536], f16, stride=(1, 256))), {})
+cnt: 30, ((T([256], f16), T([12544, 768], f16), T([768, 256], f16, stride=(1, 768))), {})
+cnt: 1, ((T([1000], f16), T([64, 256], f16), T([256, 1000], f16, stride=(1, 256))), {})
+Operator: aten.bmm.default
+cnt: 30, ((T([64, 768, 196], f16, stride=(150528, 1, 768)), T([64, 196, 196], f16, stride=(0, 1, 196))), {})
+cnt: 30, ((T([64, 196, 768], f16), T([64, 768, 196], f16, stride=(150528, 1, 768))), {})
+cnt: 30, ((T([64, 768, 196], f16, stride=(150528, 1, 768)), T([64, 196, 196], f16, stride=(0, 196, 1))), {})
+Operator: aten.cat.default
+cnt: 30, (([T([64, 196, 768], f16), T([64, 196, 768], f16, stride=(150528, 1, 196))], 2), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([256, 3, 16, 16], f16), T([256], f16), [16, 16], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 256, 14, 14], f16, stride=(50176, 1, 3584, 256)), T([64, 3, 224, 224], f16), T([256, 3, 16, 16], f16), [256], [16, 16], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+cnt: 30, ((T([196, 196], f16), T([196, 196], f16, stride=(1, 196))), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([64, 196, 256], f16, stride=(256, 0, 1)), 196), {})
+Operator: aten.gelu.default
+cnt: 30, ((T([64, 196, 1536], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 30, ((T([64, 196, 1536], f16), T([64, 196, 1536], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([64, 196, 256], f16), [1]), {})
+Operator: aten.mm.default
+cnt: 1, ((T([64, 1000], f16), T([1000, 256], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 256], f16)), {})
+cnt: 30, ((T([12544, 256], f16), T([256, 768], f16)), {})
+cnt: 30, ((T([256, 12544], f16, stride=(1, 256)), T([12544, 768], f16)), {})
+cnt: 30, ((T([12544, 1536], f16), T([1536, 256], f16)), {})
+cnt: 30, ((T([1536, 12544], f16, stride=(1, 1536)), T([12544, 256], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 30, ((T([64, 196, 768], f16, stride=(301056, 1536, 1)), T([64, 196, 768], f16, stride=(150528, 1, 196))), {})
+cnt: 30, ((T([64, 196, 768], f16), T([64, 196, 768], f16, stride=(301056, 1536, 1))), {})
+cnt: 30, ((T([64, 196, 768], f16), T([64, 196, 768], f16, stride=(150528, 1, 196))), {})
+Operator: aten.native_layer_norm.default
+cnt: 31, ((T([64, 196, 256], f16, stride=(50176, 1, 196)), [256], T([256], f16), T([256], f16), 1e-06), {})
+cnt: 30, ((T([64, 196, 768], f16, stride=(301056, 1536, 1)), [768], T([768], f16), T([768], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 31, ((T([64, 196, 256], f16), T([64, 196, 256], f16, stride=(50176, 1, 196)), [256], T([64, 196, 1], f32), T([64, 196, 1], f32), T([256], f16), T([256], f16), [True, True, True]), {})
+cnt: 30, ((T([64, 196, 768], f16, stride=(150528, 1, 196)), T([64, 196, 768], f16, stride=(301056, 1536, 1)), [768], T([64, 196, 1], f32), T([64, 196, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.new_empty_strided.default
+cnt: 30, ((T([196, 196], f16, stride=(1, 196)), [196, 196], [196, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.split.Tensor
+cnt: 30, ((T([64, 196, 1536], f16), 768, -1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 30, ((T([12544, 256], f16), [0], True), {})
+cnt: 30, ((T([64, 768, 196], f16, stride=(150528, 1, 768)), [0, 1], True), {})
+cnt: 30, ((T([64, 196, 196], f16), [0], True), {})
+cnt: 30, ((T([12544, 1536], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/hardcorenas_a_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/hardcorenas_a_training.txt
new file mode 100644
index 0000000000000..18f12cb61ce13
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/hardcorenas_a_training.txt
@@ -0,0 +1,260 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 34, ((T([], i64), 1), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16)), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16)), {})
+cnt: 2, ((T([128, 80, 14, 14], f16), T([128, 80, 14, 14], f16)), {})
+cnt: 2, ((T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16)), {})
+cnt: 2, ((T([128, 192, 7, 7], f16), T([128, 192, 7, 7], f16)), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16)), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16)), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16)), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16)), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16)), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16)), {})
+cnt: 1, ((T([128, 72, 56, 56], f16), T([128, 72, 56, 56], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1280], f16), T([1280, 1000], f16, stride=(1, 1280))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+cnt: 1, ((T([128, 32, 112, 112], f16),), {})
+cnt: 1, ((T([128, 240, 28, 28], f16),), {})
+cnt: 1, ((T([128, 240, 14, 14], f16),), {})
+cnt: 4, ((T([128, 480, 14, 14], f16),), {})
+cnt: 3, ((T([128, 672, 14, 14], f16),), {})
+cnt: 1, ((T([128, 672, 7, 7], f16),), {})
+cnt: 2, ((T([128, 1152, 7, 7], f16),), {})
+cnt: 1, ((T([128, 960, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1280, 1, 1], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([32, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([16, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([48, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([48, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 48), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([24, 48, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([72, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 72, 56, 56], f16), T([72, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 72), {})
+cnt: 1, ((T([128, 72, 1, 1], f16), T([24, 72, 1, 1], f16), T([24], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 24, 1, 1], f16), T([72, 24, 1, 1], f16), T([72], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 72, 56, 56], f16), T([24, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 72, 56, 56], f16), T([72, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 72), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([40, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), T([240, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([240, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 240), {})
+cnt: 2, ((T([128, 240, 1, 1], f16), T([64, 240, 1, 1], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 64, 1, 1], f16), T([240, 64, 1, 1], f16), T([240], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([40, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([240, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([80, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 80, 14, 14], f16), T([480, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([480, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 480), {})
+cnt: 2, ((T([128, 480, 1, 1], f16), T([120, 480, 1, 1], f16), T([120], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 120, 1, 1], f16), T([480, 120, 1, 1], f16), T([480], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([80, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([112, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 112, 14, 14], f16), T([672, 112, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([672, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 672), {})
+cnt: 2, ((T([128, 672, 1, 1], f16), T([168, 672, 1, 1], f16), T([168], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 168, 1, 1], f16), T([672, 168, 1, 1], f16), T([672], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([112, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([672, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 672), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([192, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 7, 7], f16), T([1152, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), T([1152, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 1152), {})
+cnt: 1, ((T([128, 1152, 1, 1], f16), T([288, 1152, 1, 1], f16), T([288], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 288, 1, 1], f16), T([1152, 288, 1, 1], f16), T([1152], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), T([192, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 7, 7], f16), T([960, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 960, 1, 1], f16), T([1280, 960, 1, 1], f16), T([1280], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1280, 1, 1], f16), T([128, 960, 1, 1], f16), T([1280, 960, 1, 1], f16), [1280], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 960, 7, 7], f16), T([128, 192, 7, 7], f16), T([960, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 7, 7], f16), T([128, 1152, 7, 7], f16), T([192, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1152, 1, 1], f16), T([128, 288, 1, 1], f16), T([1152, 288, 1, 1], f16), [1152], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 288, 1, 1], f16), T([128, 1152, 1, 1], f16), T([288, 1152, 1, 1], f16), [288], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16), T([1152, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 1152, [True, True, False]), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), T([128, 192, 7, 7], f16), T([1152, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 7, 7], f16), T([128, 672, 7, 7], f16), T([192, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 672, 1, 1], f16), T([128, 168, 1, 1], f16), T([672, 168, 1, 1], f16), [672], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 168, 1, 1], f16), T([128, 672, 1, 1], f16), T([168, 672, 1, 1], f16), [168], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 14, 14], f16), T([672, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), T([128, 112, 14, 14], f16), T([672, 112, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 112, 14, 14], f16), T([128, 672, 14, 14], f16), T([112, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16), T([672, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 1, ((T([128, 112, 14, 14], f16), T([128, 480, 14, 14], f16), T([112, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 480, 1, 1], f16), T([128, 120, 1, 1], f16), T([480, 120, 1, 1], f16), [480], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 120, 1, 1], f16), T([128, 480, 1, 1], f16), T([120, 480, 1, 1], f16), [120], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), T([480, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([128, 80, 14, 14], f16), T([480, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 80, 14, 14], f16), T([128, 480, 14, 14], f16), T([80, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 80, 14, 14], f16), T([128, 240, 14, 14], f16), T([80, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 240, 1, 1], f16), T([128, 64, 1, 1], f16), T([240, 64, 1, 1], f16), [240], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 64, 1, 1], f16), T([128, 240, 1, 1], f16), T([64, 240, 1, 1], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 28, 28], f16), T([240, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 2, ((T([128, 240, 28, 28], f16), T([128, 40, 28, 28], f16), T([240, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([128, 240, 28, 28], f16), T([40, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16), T([240, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([128, 72, 28, 28], f16), T([40, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([128, 72, 56, 56], f16), T([72, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 72, [True, True, False]), {})
+cnt: 2, ((T([128, 72, 56, 56], f16), T([128, 24, 56, 56], f16), T([72, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 72, 56, 56], f16), T([24, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 72, 1, 1], f16), T([128, 24, 1, 1], f16), T([72, 24, 1, 1], f16), [72], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 24, 1, 1], f16), T([128, 72, 1, 1], f16), T([24, 72, 1, 1], f16), [24], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 72, 56, 56], f16), T([128, 72, 56, 56], f16), T([72, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 72, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 48, 56, 56], f16), T([24, 48, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([128, 48, 112, 112], f16), T([48, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 48, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([128, 16, 112, 112], f16), T([48, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 32, 112, 112], f16), T([16, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), T([32, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 3, 224, 224], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 960, 7, 7], f16, stride=(960, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16, stride=(1152, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 672, 7, 7], f16, stride=(672, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 672, 14, 14], f16, stride=(672, 1, 0, 0)), 196), {})
+cnt: 2, ((T([128, 480, 14, 14], f16, stride=(480, 1, 0, 0)), 196), {})
+cnt: 1, ((T([128, 240, 14, 14], f16, stride=(240, 1, 0, 0)), 196), {})
+cnt: 1, ((T([128, 240, 28, 28], f16, stride=(240, 1, 0, 0)), 784), {})
+cnt: 1, ((T([128, 72, 56, 56], f16, stride=(72, 1, 0, 0)), 3136), {})
+Operator: aten.hardsigmoid.default
+cnt: 1, ((T([128, 72, 1, 1], f16),), {})
+cnt: 2, ((T([128, 240, 1, 1], f16),), {})
+cnt: 2, ((T([128, 480, 1, 1], f16),), {})
+cnt: 2, ((T([128, 672, 1, 1], f16),), {})
+cnt: 1, ((T([128, 1152, 1, 1], f16),), {})
+Operator: aten.hardsigmoid_backward.default
+cnt: 1, ((T([128, 1152, 1, 1], f16), T([128, 1152, 1, 1], f16)), {})
+cnt: 2, ((T([128, 672, 1, 1], f16), T([128, 672, 1, 1], f16)), {})
+cnt: 2, ((T([128, 480, 1, 1], f16), T([128, 480, 1, 1], f16)), {})
+cnt: 2, ((T([128, 240, 1, 1], f16), T([128, 240, 1, 1], f16)), {})
+cnt: 1, ((T([128, 72, 1, 1], f16), T([128, 72, 1, 1], f16)), {})
+Operator: aten.hardswish_.default
+cnt: 1, ((T([128, 32, 112, 112], f16),), {})
+cnt: 1, ((T([128, 240, 28, 28], f16),), {})
+cnt: 1, ((T([128, 240, 14, 14], f16),), {})
+cnt: 4, ((T([128, 480, 14, 14], f16),), {})
+cnt: 3, ((T([128, 672, 14, 14], f16),), {})
+cnt: 1, ((T([128, 672, 7, 7], f16),), {})
+cnt: 2, ((T([128, 1152, 7, 7], f16),), {})
+cnt: 1, ((T([128, 960, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1280, 1, 1], f16),), {})
+Operator: aten.hardswish_backward.default
+cnt: 1, ((T([128, 1280, 1, 1], f16), T([128, 1280, 1, 1], f16)), {})
+cnt: 1, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16)), {})
+cnt: 2, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16)), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16)), {})
+cnt: 3, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16)), {})
+cnt: 4, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16)), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16)), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16)), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 72, 56, 56], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 960, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 1280], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 1280], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([128, 72, 56, 56], f16), T([128, 72, 1, 1], f16)), {})
+cnt: 2, ((T([128, 240, 28, 28], f16), T([128, 240, 1, 1], f16)), {})
+cnt: 2, ((T([128, 240, 14, 14], f16), T([128, 240, 1, 1], f16)), {})
+cnt: 4, ((T([128, 480, 14, 14], f16), T([128, 480, 1, 1], f16)), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), T([128, 672, 1, 1], f16)), {})
+cnt: 2, ((T([128, 672, 7, 7], f16), T([128, 672, 1, 1], f16)), {})
+cnt: 2, ((T([128, 1152, 7, 7], f16), T([128, 1152, 1, 1], f16)), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16)), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16)), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16)), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16)), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16)), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16)), {})
+cnt: 1, ((T([128, 72, 56, 56], f16), T([128, 72, 56, 56], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 72, 56, 56], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 112, 14, 14], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 192, 7, 7], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 1152, 7, 7], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 192, 7, 7], f16), T([128, 192, 7, 7], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f32), T([1152], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f32), T([112], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 80, 14, 14], f16), T([128, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([128, 72, 28, 28], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 72, 56, 56], f16), T([128, 72, 56, 56], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([128, 48, 56, 56], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f32), T([48], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([128, 48, 112, 112], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f32), T([48], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 32, 112, 112], f16),), {})
+cnt: 1, ((T([128, 48, 112, 112], f16),), {})
+cnt: 1, ((T([128, 48, 56, 56], f16),), {})
+cnt: 3, ((T([128, 72, 56, 56], f16),), {})
+cnt: 1, ((T([128, 24, 1, 1], f16),), {})
+cnt: 1, ((T([128, 72, 28, 28], f16),), {})
+cnt: 2, ((T([128, 240, 28, 28], f16),), {})
+cnt: 2, ((T([128, 64, 1, 1], f16),), {})
+cnt: 2, ((T([128, 120, 1, 1], f16),), {})
+cnt: 2, ((T([128, 168, 1, 1], f16),), {})
+cnt: 1, ((T([128, 288, 1, 1], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 72, 56, 56], f16), [2, 3], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([128, 288, 1, 1], f16), T([128, 288, 1, 1], f16), 0), {})
+cnt: 2, ((T([128, 168, 1, 1], f16), T([128, 168, 1, 1], f16), 0), {})
+cnt: 2, ((T([128, 120, 1, 1], f16), T([128, 120, 1, 1], f16), 0), {})
+cnt: 2, ((T([128, 64, 1, 1], f16), T([128, 64, 1, 1], f16), 0), {})
+cnt: 2, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([128, 72, 28, 28], f16), 0), {})
+cnt: 3, ((T([128, 72, 56, 56], f16), T([128, 72, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 24, 1, 1], f16), T([128, 24, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([128, 48, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([128, 48, 112, 112], f16), 0), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/hrnet_w18_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/hrnet_w18_training.txt
new file mode 100644
index 0000000000000..cf63431eecc20
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/hrnet_w18_training.txt
@@ -0,0 +1,247 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 69, ((T([128, 18, 56, 56], f16), T([128, 18, 56, 56], f16)), {})
+cnt: 70, ((T([128, 36, 28, 28], f16), T([128, 36, 28, 28], f16)), {})
+cnt: 64, ((T([128, 72, 14, 14], f16), T([128, 72, 14, 14], f16)), {})
+cnt: 31, ((T([128, 144, 7, 7], f16), T([128, 144, 7, 7], f16)), {})
+cnt: 1, ((T([128, 256, 28, 28], f16), T([128, 256, 28, 28], f16)), {})
+cnt: 1, ((T([128, 512, 14, 14], f16), T([128, 512, 14, 14], f16)), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 1024, 7, 7], f16)), {})
+cnt: 4, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16)), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 325, ((T([], i64), 1), {})
+cnt: 4, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16)), {})
+cnt: 32, ((T([128, 18, 56, 56], f16), T([128, 18, 56, 56], f16)), {})
+cnt: 32, ((T([128, 36, 28, 28], f16), T([128, 36, 28, 28], f16)), {})
+cnt: 28, ((T([128, 72, 14, 14], f16), T([128, 72, 14, 14], f16)), {})
+cnt: 12, ((T([128, 144, 7, 7], f16), T([128, 144, 7, 7], f16)), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([128, 128, 56, 56], f16)), {})
+cnt: 1, ((T([128, 256, 28, 28], f16), T([128, 256, 28, 28], f16)), {})
+cnt: 1, ((T([128, 512, 14, 14], f16), T([128, 512, 14, 14], f16)), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 1024, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([64, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([64, 64, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 64, 56, 56], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 64, 56, 56], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 256, 56, 56], f16), T([64, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([18, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([36, 256, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 64, ((T([128, 18, 56, 56], f16), T([18, 18, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 64, ((T([128, 36, 28, 28], f16), T([36, 36, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([128, 36, 28, 28], f16), T([18, 36, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([128, 18, 56, 56], f16), T([36, 18, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([128, 36, 28, 28], f16), T([72, 36, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 56, ((T([128, 72, 14, 14], f16), T([72, 72, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([128, 72, 14, 14], f16), T([18, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([128, 72, 14, 14], f16), T([36, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 10, ((T([128, 18, 56, 56], f16), T([18, 18, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([128, 18, 28, 28], f16), T([72, 18, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 72, 14, 14], f16), T([144, 72, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 24, ((T([128, 144, 7, 7], f16), T([144, 144, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 144, 7, 7], f16), T([18, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 144, 7, 7], f16), T([36, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 144, 7, 7], f16), T([72, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 18, 28, 28], f16), T([18, 18, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 18, 14, 14], f16), T([144, 18, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 36, 28, 28], f16), T([36, 36, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 36, 14, 14], f16), T([144, 36, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 18, 56, 56], f16), T([32, 18, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 56, 56], f16), T([32, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 56, 56], f16), T([128, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 18, 56, 56], f16), T([128, 18, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 36, 28, 28], f16), T([64, 36, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 28, 28], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 28, 28], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 36, 28, 28], f16), T([256, 36, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([256, 128, 3, 3], f16), T([256], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 72, 14, 14], f16), T([128, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 14, 14], f16), T([128, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 14, 14], f16), T([512, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 72, 14, 14], f16), T([512, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 28, 28], f16), T([512, 256, 3, 3], f16), T([512], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 144, 7, 7], f16), T([256, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 7, 7], f16), T([256, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 7, 7], f16), T([1024, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 144, 7, 7], f16), T([1024, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 14, 14], f16), T([1024, 512, 3, 3], f16), T([1024], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([2048, 1024, 1, 1], f16), T([2048], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 2048, 7, 7], f16), T([128, 1024, 7, 7], f16), T([2048, 1024, 1, 1], f16), [2048], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 512, 14, 14], f16), T([1024, 512, 3, 3], f16), [1024], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 144, 7, 7], f16), T([1024, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 256, 7, 7], f16), T([1024, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16), T([256, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 7, 7], f16), T([128, 144, 7, 7], f16), T([256, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 14, 14], f16), T([128, 256, 28, 28], f16), T([512, 256, 3, 3], f16), [512], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 512, 14, 14], f16), T([128, 72, 14, 14], f16), T([512, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 14, 14], f16), T([128, 128, 14, 14], f16), T([512, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16), T([128, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 14, 14], f16), T([128, 72, 14, 14], f16), T([128, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 28, 28], f16), T([128, 128, 56, 56], f16), T([256, 128, 3, 3], f16), [256], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 28, 28], f16), T([128, 36, 28, 28], f16), T([256, 36, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 28, 28], f16), T([128, 64, 28, 28], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 28, 28], f16), T([128, 36, 28, 28], f16), T([64, 36, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([128, 18, 56, 56], f16), T([128, 18, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([128, 32, 56, 56], f16), T([128, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), T([32, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 56, 56], f16), T([128, 18, 56, 56], f16), T([32, 18, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 144, 7, 7], f16), T([128, 72, 14, 14], f16), T([144, 72, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 144, 7, 7], f16), T([128, 36, 14, 14], f16), T([144, 36, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 36, 14, 14], f16), T([128, 36, 28, 28], f16), T([36, 36, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 144, 7, 7], f16), T([128, 18, 14, 14], f16), T([144, 18, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 18, 14, 14], f16), T([128, 18, 28, 28], f16), T([18, 18, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 10, ((T([128, 18, 28, 28], f16), T([128, 18, 56, 56], f16), T([18, 18, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 72, 7, 7], f16), T([128, 144, 7, 7], f16), T([72, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 8, ((T([128, 72, 14, 14], f16), T([128, 36, 28, 28], f16), T([72, 36, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 7, ((T([128, 72, 14, 14], f16), T([128, 18, 28, 28], f16), T([72, 18, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 36, 7, 7], f16), T([128, 144, 7, 7], f16), T([36, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 7, ((T([128, 36, 14, 14], f16), T([128, 72, 14, 14], f16), T([36, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 8, ((T([128, 36, 28, 28], f16), T([128, 18, 56, 56], f16), T([36, 18, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 18, 7, 7], f16), T([128, 144, 7, 7], f16), T([18, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 7, ((T([128, 18, 14, 14], f16), T([128, 72, 14, 14], f16), T([18, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 8, ((T([128, 18, 28, 28], f16), T([128, 36, 28, 28], f16), T([18, 36, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 24, ((T([128, 144, 7, 7], f16), T([128, 144, 7, 7], f16), T([144, 144, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 56, ((T([128, 72, 14, 14], f16), T([128, 72, 14, 14], f16), T([72, 72, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 64, ((T([128, 36, 28, 28], f16), T([128, 36, 28, 28], f16), T([36, 36, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 64, ((T([128, 18, 56, 56], f16), T([128, 18, 56, 56], f16), T([18, 18, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 36, 28, 28], f16), T([128, 256, 56, 56], f16), T([36, 256, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 18, 56, 56], f16), T([128, 256, 56, 56], f16), T([18, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([128, 256, 56, 56], f16), T([128, 64, 56, 56], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 64, 56, 56], f16), T([128, 256, 56, 56], f16), T([64, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), T([64, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 112, 112], f16), T([64, 64, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 3, 224, 224], f16), T([64, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 2048, 7, 7], f16, stride=(2048, 1, 0, 0)), 49), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 2048, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 2048], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 9, ((T([128, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 65, ((T([128, 18, 56, 56], f16), T([18], f16), T([18], f16), T([18], f16), T([18], f16), True, 0.1, 1e-05), {})
+cnt: 73, ((T([128, 36, 28, 28], f16), T([36], f16), T([36], f16), T([36], f16), T([36], f16), True, 0.1, 1e-05), {})
+cnt: 18, ((T([128, 18, 28, 28], f16), T([18], f16), T([18], f16), T([18], f16), T([18], f16), True, 0.1, 1e-05), {})
+cnt: 71, ((T([128, 72, 14, 14], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), True, 0.1, 1e-05), {})
+cnt: 10, ((T([128, 18, 14, 14], f16), T([18], f16), T([18], f16), T([18], f16), T([18], f16), True, 0.1, 1e-05), {})
+cnt: 10, ((T([128, 36, 14, 14], f16), T([36], f16), T([36], f16), T([36], f16), T([36], f16), True, 0.1, 1e-05), {})
+cnt: 34, ((T([128, 144, 7, 7], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 18, 7, 7], f16), T([18], f16), T([18], f16), T([18], f16), T([18], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 36, 7, 7], f16), T([36], f16), T([36], f16), T([36], f16), T([36], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 72, 7, 7], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 32, 56, 56], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 64, 28, 28], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 128, 14, 14], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 256, 7, 7], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([128, 2048, 7, 7], f16), T([128, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 1024, 7, 7], f16), T([128, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 512, 14, 14], f16), T([128, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 256, 28, 28], f16), T([128, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 128, 56, 56], f16), T([128, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 34, ((T([128, 144, 7, 7], f16), T([128, 144, 7, 7], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), True, 1e-05, [True, True, True]), {})
+cnt: 10, ((T([128, 36, 14, 14], f16), T([128, 36, 14, 14], f16), T([36], f16), T([36], f16), T([36], f16), T([36], f32), T([36], f32), True, 1e-05, [True, True, True]), {})
+cnt: 10, ((T([128, 18, 14, 14], f16), T([128, 18, 14, 14], f16), T([18], f16), T([18], f16), T([18], f16), T([18], f32), T([18], f32), True, 1e-05, [True, True, True]), {})
+cnt: 18, ((T([128, 18, 28, 28], f16), T([128, 18, 28, 28], f16), T([18], f16), T([18], f16), T([18], f16), T([18], f32), T([18], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 72, 7, 7], f16), T([128, 72, 7, 7], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), True, 1e-05, [True, True, True]), {})
+cnt: 71, ((T([128, 72, 14, 14], f16), T([128, 72, 14, 14], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 36, 7, 7], f16), T([128, 36, 7, 7], f16), T([36], f16), T([36], f16), T([36], f16), T([36], f32), T([36], f32), True, 1e-05, [True, True, True]), {})
+cnt: 73, ((T([128, 36, 28, 28], f16), T([128, 36, 28, 28], f16), T([36], f16), T([36], f16), T([36], f16), T([36], f32), T([36], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 18, 7, 7], f16), T([128, 18, 7, 7], f16), T([18], f16), T([18], f16), T([18], f16), T([18], f32), T([18], f32), True, 1e-05, [True, True, True]), {})
+cnt: 65, ((T([128, 18, 56, 56], f16), T([128, 18, 56, 56], f16), T([18], f16), T([18], f16), T([18], f16), T([18], f32), T([18], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 9, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu.default
+cnt: 8, ((T([128, 18, 56, 56], f16),), {})
+cnt: 8, ((T([128, 36, 28, 28], f16),), {})
+cnt: 10, ((T([128, 18, 28, 28], f16),), {})
+cnt: 7, ((T([128, 72, 14, 14], f16),), {})
+cnt: 3, ((T([128, 18, 14, 14], f16),), {})
+cnt: 3, ((T([128, 36, 14, 14], f16),), {})
+cnt: 3, ((T([128, 144, 7, 7], f16),), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 64, 112, 112], f16),), {})
+cnt: 9, ((T([128, 64, 56, 56], f16),), {})
+cnt: 4, ((T([128, 256, 56, 56], f16),), {})
+cnt: 65, ((T([128, 18, 56, 56], f16),), {})
+cnt: 65, ((T([128, 36, 28, 28], f16),), {})
+cnt: 57, ((T([128, 72, 14, 14], f16),), {})
+cnt: 25, ((T([128, 144, 7, 7], f16),), {})
+cnt: 2, ((T([128, 32, 56, 56], f16),), {})
+cnt: 1, ((T([128, 128, 56, 56], f16),), {})
+cnt: 2, ((T([128, 64, 28, 28], f16),), {})
+cnt: 2, ((T([128, 256, 28, 28], f16),), {})
+cnt: 2, ((T([128, 128, 14, 14], f16),), {})
+cnt: 2, ((T([128, 512, 14, 14], f16),), {})
+cnt: 2, ((T([128, 256, 7, 7], f16),), {})
+cnt: 2, ((T([128, 1024, 7, 7], f16),), {})
+cnt: 1, ((T([128, 2048, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([128, 2048, 7, 7], f16), T([128, 2048, 7, 7], f16), 0), {})
+cnt: 2, ((T([128, 1024, 7, 7], f16), T([128, 1024, 7, 7], f16), 0), {})
+cnt: 2, ((T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16), 0), {})
+cnt: 2, ((T([128, 512, 14, 14], f16), T([128, 512, 14, 14], f16), 0), {})
+cnt: 2, ((T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16), 0), {})
+cnt: 2, ((T([128, 256, 28, 28], f16), T([128, 256, 28, 28], f16), 0), {})
+cnt: 2, ((T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([128, 128, 56, 56], f16), 0), {})
+cnt: 2, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), 0), {})
+cnt: 28, ((T([128, 144, 7, 7], f16), T([128, 144, 7, 7], f16), 0), {})
+cnt: 3, ((T([128, 36, 14, 14], f16), T([128, 36, 14, 14], f16), 0), {})
+cnt: 3, ((T([128, 18, 14, 14], f16), T([128, 18, 14, 14], f16), 0), {})
+cnt: 10, ((T([128, 18, 28, 28], f16), T([128, 18, 28, 28], f16), 0), {})
+cnt: 64, ((T([128, 72, 14, 14], f16), T([128, 72, 14, 14], f16), 0), {})
+cnt: 73, ((T([128, 36, 28, 28], f16), T([128, 36, 28, 28], f16), 0), {})
+cnt: 73, ((T([128, 18, 56, 56], f16), T([128, 18, 56, 56], f16), 0), {})
+cnt: 4, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16), 0), {})
+cnt: 9, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), 0), {})
+Operator: aten.upsample_nearest2d.vec
+cnt: 8, ((T([128, 18, 28, 28], f16), None, [2.0, 2.0]), {})
+cnt: 7, ((T([128, 18, 14, 14], f16), None, [4.0, 4.0]), {})
+cnt: 7, ((T([128, 36, 14, 14], f16), None, [2.0, 2.0]), {})
+cnt: 3, ((T([128, 18, 7, 7], f16), None, [8.0, 8.0]), {})
+cnt: 3, ((T([128, 36, 7, 7], f16), None, [4.0, 4.0]), {})
+cnt: 3, ((T([128, 72, 7, 7], f16), None, [2.0, 2.0]), {})
+Operator: aten.upsample_nearest2d_backward.vec
+cnt: 3, ((T([128, 72, 14, 14], f16), None, [128, 72, 7, 7], [2.0, 2.0]), {})
+cnt: 3, ((T([128, 36, 28, 28], f16), None, [128, 36, 7, 7], [4.0, 4.0]), {})
+cnt: 7, ((T([128, 36, 28, 28], f16), None, [128, 36, 14, 14], [2.0, 2.0]), {})
+cnt: 3, ((T([128, 18, 56, 56], f16), None, [128, 18, 7, 7], [8.0, 8.0]), {})
+cnt: 7, ((T([128, 18, 56, 56], f16), None, [128, 18, 14, 14], [4.0, 4.0]), {})
+cnt: 8, ((T([128, 18, 56, 56], f16), None, [128, 18, 28, 28], [2.0, 2.0]), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/inception_v3_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/inception_v3_training.txt
new file mode 100644
index 0000000000000..c11cd6890c765
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/inception_v3_training.txt
@@ -0,0 +1,239 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 4, ((T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16)), {})
+cnt: 3, ((T([128, 2048, 8, 8], f16), T([128, 2048, 8, 8], f16)), {})
+cnt: 3, ((T([128, 1280, 8, 8], f16), T([128, 1280, 8, 8], f16)), {})
+cnt: 14, ((T([128, 768, 17, 17], f16), T([128, 768, 17, 17], f16)), {})
+cnt: 5, ((T([128, 288, 35, 35], f16), T([128, 288, 35, 35], f16)), {})
+cnt: 3, ((T([128, 256, 35, 35], f16), T([128, 256, 35, 35], f16)), {})
+cnt: 3, ((T([128, 192, 35, 35], f16), T([128, 192, 35, 35], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 94, ((T([], i64), 1), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([128, 192, 35, 35], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 1, ((T([128, 256, 35, 35], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 1, ((T([128, 288, 35, 35], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 4, ((T([128, 768, 17, 17], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), [3, 3], [1, 1], [1, 1]), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([128, 2048, 8, 8], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), T([128, 1280, 8, 8], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+cnt: 4, ((T([128, 768, 17, 17], f16), T([128, 768, 17, 17], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 288, 35, 35], f16), T([128, 288, 35, 35], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 256, 35, 35], f16), T([128, 256, 35, 35], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 192, 35, 35], f16), T([128, 192, 35, 35], f16), [3, 3], [1, 1], [1, 1], False, True, None), {})
+Operator: aten.cat.default
+cnt: 1, (([T([128, 64, 35, 35], f16), T([128, 64, 35, 35], f16), T([128, 96, 35, 35], f16), T([128, 32, 35, 35], f16)], 1), {})
+cnt: 2, (([T([128, 64, 35, 35], f16), T([128, 64, 35, 35], f16), T([128, 96, 35, 35], f16), T([128, 64, 35, 35], f16)], 1), {})
+cnt: 1, (([T([128, 384, 17, 17], f16), T([128, 96, 17, 17], f16), T([128, 288, 17, 17], f16)], 1), {})
+cnt: 4, (([T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16)], 1), {})
+cnt: 1, (([T([128, 320, 8, 8], f16), T([128, 192, 8, 8], f16), T([128, 768, 8, 8], f16)], 1), {})
+cnt: 4, (([T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16)], 1), {})
+cnt: 2, (([T([128, 320, 8, 8], f16), T([128, 768, 8, 8], f16), T([128, 768, 8, 8], f16), T([128, 192, 8, 8], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 299, 299], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 299, 299], f16), T([32, 3, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 149, 149], f16), T([32, 32, 3, 3], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 147, 147], f16), T([64, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 73, 73], f16), T([80, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 80, 73, 73], f16), T([192, 80, 3, 3], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 192, 35, 35], f16), T([64, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 35, 35], f16), T([48, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 48, 35, 35], f16), T([64, 48, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 64, 35, 35], f16), T([96, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 96, 35, 35], f16), T([96, 96, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 35, 35], f16), T([32, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 256, 35, 35], f16), T([64, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 35, 35], f16), T([48, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 288, 35, 35], f16), T([64, 288, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 288, 35, 35], f16), T([48, 288, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 288, 35, 35], f16), T([384, 288, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 35, 35], f16), T([96, 96, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 12, ((T([128, 768, 17, 17], f16), T([192, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 768, 17, 17], f16), T([128, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 128, 17, 17], f16), T([128, 128, 1, 7], f16), None, [1, 1], [0, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 17, 17], f16), T([192, 128, 7, 1], f16), None, [1, 1], [3, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 128, 17, 17], f16), T([128, 128, 7, 1], f16), None, [1, 1], [3, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 17, 17], f16), T([192, 128, 1, 7], f16), None, [1, 1], [0, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 768, 17, 17], f16), T([160, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 160, 17, 17], f16), T([160, 160, 1, 7], f16), None, [1, 1], [0, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 160, 17, 17], f16), T([192, 160, 7, 1], f16), None, [1, 1], [3, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 160, 17, 17], f16), T([160, 160, 7, 1], f16), None, [1, 1], [3, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 160, 17, 17], f16), T([192, 160, 1, 7], f16), None, [1, 1], [0, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 192, 17, 17], f16), T([192, 192, 1, 7], f16), None, [1, 1], [0, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 192, 17, 17], f16), T([192, 192, 7, 1], f16), None, [1, 1], [3, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 17, 17], f16), T([320, 192, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 17, 17], f16), T([192, 192, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), T([320, 1280, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), T([384, 1280, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 384, 8, 8], f16), T([384, 384, 1, 3], f16), None, [1, 1], [0, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 384, 8, 8], f16), T([384, 384, 3, 1], f16), None, [1, 1], [1, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), T([448, 1280, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 448, 8, 8], f16), T([384, 448, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1280, 8, 8], f16), T([192, 1280, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([320, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([384, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([448, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 2048, 8, 8], f16), T([192, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 192, 8, 8], f16), T([128, 2048, 8, 8], f16), T([192, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16), T([384, 384, 3, 1], f16), [0], [1, 1], [1, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16), T([384, 384, 1, 3], f16), [0], [1, 1], [0, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 384, 8, 8], f16), T([128, 448, 8, 8], f16), T([384, 448, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 448, 8, 8], f16), T([128, 2048, 8, 8], f16), T([448, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 384, 8, 8], f16), T([128, 2048, 8, 8], f16), T([384, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 320, 8, 8], f16), T([128, 2048, 8, 8], f16), T([320, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 8, 8], f16), T([128, 1280, 8, 8], f16), T([192, 1280, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 448, 8, 8], f16), T([128, 1280, 8, 8], f16), T([448, 1280, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 384, 8, 8], f16), T([128, 1280, 8, 8], f16), T([384, 1280, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 320, 8, 8], f16), T([128, 1280, 8, 8], f16), T([320, 1280, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 8, 8], f16), T([128, 192, 17, 17], f16), T([192, 192, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), T([192, 192, 7, 1], f16), [0], [1, 1], [3, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), T([192, 192, 1, 7], f16), [0], [1, 1], [0, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 12, ((T([128, 192, 17, 17], f16), T([128, 768, 17, 17], f16), T([192, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 320, 8, 8], f16), T([128, 192, 17, 17], f16), T([320, 192, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 192, 17, 17], f16), T([128, 160, 17, 17], f16), T([192, 160, 1, 7], f16), [0], [1, 1], [0, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 160, 17, 17], f16), T([128, 160, 17, 17], f16), T([160, 160, 7, 1], f16), [0], [1, 1], [3, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 160, 17, 17], f16), T([128, 160, 17, 17], f16), T([160, 160, 1, 7], f16), [0], [1, 1], [0, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 160, 17, 17], f16), T([128, 768, 17, 17], f16), T([160, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 192, 17, 17], f16), T([128, 160, 17, 17], f16), T([192, 160, 7, 1], f16), [0], [1, 1], [3, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 17, 17], f16), T([128, 128, 17, 17], f16), T([192, 128, 1, 7], f16), [0], [1, 1], [0, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 128, 17, 17], f16), T([128, 128, 17, 17], f16), T([128, 128, 7, 1], f16), [0], [1, 1], [3, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 128, 17, 17], f16), T([128, 128, 17, 17], f16), T([128, 128, 1, 7], f16), [0], [1, 1], [0, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 128, 17, 17], f16), T([128, 768, 17, 17], f16), T([128, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 17, 17], f16), T([128, 128, 17, 17], f16), T([192, 128, 7, 1], f16), [0], [1, 1], [3, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 17, 17], f16), T([128, 96, 35, 35], f16), T([96, 96, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 96, 35, 35], f16), T([128, 64, 35, 35], f16), T([96, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 64, 35, 35], f16), T([128, 288, 35, 35], f16), T([64, 288, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 384, 17, 17], f16), T([128, 288, 35, 35], f16), T([384, 288, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 96, 35, 35], f16), T([128, 96, 35, 35], f16), T([96, 96, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 64, 35, 35], f16), T([128, 48, 35, 35], f16), T([64, 48, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 35, 35], f16), T([128, 288, 35, 35], f16), T([48, 288, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 64, 35, 35], f16), T([128, 256, 35, 35], f16), T([64, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 35, 35], f16), T([128, 256, 35, 35], f16), T([48, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 35, 35], f16), T([128, 192, 35, 35], f16), T([32, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 64, 35, 35], f16), T([128, 192, 35, 35], f16), T([64, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 35, 35], f16), T([128, 192, 35, 35], f16), T([48, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 71, 71], f16), T([128, 80, 73, 73], f16), T([192, 80, 3, 3], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 80, 73, 73], f16), T([128, 64, 73, 73], f16), T([80, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 147, 147], f16), T([128, 32, 147, 147], f16), T([64, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 147, 147], f16), T([128, 32, 149, 149], f16), T([32, 32, 3, 3], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 149, 149], f16), T([128, 3, 299, 299], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 299, 299], f16), T([128, 3, 299, 299], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 2048, 8, 8], f16, stride=(2048, 1, 0, 0)), 64), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([128, 64, 147, 147], f16), [3, 3], [2, 2]), {})
+cnt: 1, ((T([128, 192, 71, 71], f16), [3, 3], [2, 2]), {})
+cnt: 1, ((T([128, 288, 35, 35], f16), [3, 3], [2, 2]), {})
+cnt: 1, ((T([128, 768, 17, 17], f16), [3, 3], [2, 2]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([128, 768, 8, 8], f16, stride=(81920, 64, 8, 1)), T([128, 768, 17, 17], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([128, 768, 8, 8], i64)), {})
+cnt: 1, ((T([128, 288, 17, 17], f16, stride=(221952, 289, 17, 1)), T([128, 288, 35, 35], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([128, 288, 17, 17], i64)), {})
+cnt: 1, ((T([128, 192, 35, 35], f16), T([128, 192, 71, 71], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([128, 192, 35, 35], i64)), {})
+cnt: 1, ((T([128, 64, 73, 73], f16), T([128, 64, 147, 147], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([128, 64, 73, 73], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 2048, 8, 8], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 2048], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 32, 149, 149], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 32, 147, 147], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 64, 147, 147], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 80, 73, 73], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 192, 71, 71], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 0.001), {})
+cnt: 12, ((T([128, 64, 35, 35], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 0.001), {})
+cnt: 3, ((T([128, 48, 35, 35], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f16), True, 0.1, 0.001), {})
+cnt: 7, ((T([128, 96, 35, 35], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 32, 35, 35], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 384, 17, 17], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 96, 17, 17], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 0.001), {})
+cnt: 26, ((T([128, 192, 17, 17], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 0.001), {})
+cnt: 6, ((T([128, 128, 17, 17], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 0.001), {})
+cnt: 12, ((T([128, 160, 17, 17], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), True, 0.1, 0.001), {})
+cnt: 3, ((T([128, 320, 8, 8], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), True, 0.1, 0.001), {})
+cnt: 3, ((T([128, 192, 8, 8], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 0.001), {})
+cnt: 12, ((T([128, 384, 8, 8], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([128, 448, 8, 8], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f16), True, 0.1, 0.001), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 3, ((T([128, 192, 8, 8], f16), T([128, 192, 8, 8], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 0.001, [True, True, True]), {})
+cnt: 12, ((T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([128, 448, 8, 8], f16), T([128, 448, 8, 8], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f32), T([448], f32), True, 0.001, [True, True, True]), {})
+cnt: 3, ((T([128, 320, 8, 8], f16), T([128, 320, 8, 8], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), True, 0.001, [True, True, True]), {})
+cnt: 26, ((T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 0.001, [True, True, True]), {})
+cnt: 12, ((T([128, 160, 17, 17], f16), T([128, 160, 17, 17], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), True, 0.001, [True, True, True]), {})
+cnt: 6, ((T([128, 128, 17, 17], f16), T([128, 128, 17, 17], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 96, 17, 17], f16), T([128, 96, 17, 17], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 0.001, [True, True, True]), {})
+cnt: 7, ((T([128, 96, 35, 35], f16), T([128, 96, 35, 35], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 0.001, [True, True, True]), {})
+cnt: 12, ((T([128, 64, 35, 35], f16), T([128, 64, 35, 35], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 384, 17, 17], f16), T([128, 384, 17, 17], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 0.001, [True, True, True]), {})
+cnt: 3, ((T([128, 48, 35, 35], f16), T([128, 48, 35, 35], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f32), T([48], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 35, 35], f16), T([128, 32, 35, 35], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 192, 71, 71], f16), T([128, 192, 71, 71], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 80, 73, 73], f16), T([128, 80, 73, 73], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 147, 147], f16), T([128, 64, 147, 147], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 147, 147], f16), T([128, 32, 147, 147], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 149, 149], f16), T([128, 32, 149, 149], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 0.001, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 32, 149, 149], f16),), {})
+cnt: 1, ((T([128, 32, 147, 147], f16),), {})
+cnt: 1, ((T([128, 64, 147, 147], f16),), {})
+cnt: 1, ((T([128, 80, 73, 73], f16),), {})
+cnt: 1, ((T([128, 192, 71, 71], f16),), {})
+cnt: 12, ((T([128, 64, 35, 35], f16),), {})
+cnt: 3, ((T([128, 48, 35, 35], f16),), {})
+cnt: 7, ((T([128, 96, 35, 35], f16),), {})
+cnt: 1, ((T([128, 32, 35, 35], f16),), {})
+cnt: 1, ((T([128, 384, 17, 17], f16),), {})
+cnt: 1, ((T([128, 96, 17, 17], f16),), {})
+cnt: 26, ((T([128, 192, 17, 17], f16),), {})
+cnt: 6, ((T([128, 128, 17, 17], f16),), {})
+cnt: 12, ((T([128, 160, 17, 17], f16),), {})
+cnt: 3, ((T([128, 320, 8, 8], f16),), {})
+cnt: 3, ((T([128, 192, 8, 8], f16),), {})
+cnt: 12, ((T([128, 384, 8, 8], f16),), {})
+cnt: 2, ((T([128, 448, 8, 8], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 2, ((T([128, 192, 8, 8], f16, stride=(131072, 64, 8, 1)), T([128, 192, 8, 8], f16), 0), {})
+cnt: 8, ((T([128, 384, 8, 8], f16, stride=(131072, 64, 8, 1)), T([128, 384, 8, 8], f16), 0), {})
+cnt: 4, ((T([128, 384, 8, 8], f16), T([128, 384, 8, 8], f16), 0), {})
+cnt: 2, ((T([128, 448, 8, 8], f16), T([128, 448, 8, 8], f16), 0), {})
+cnt: 2, ((T([128, 320, 8, 8], f16, stride=(131072, 64, 8, 1)), T([128, 320, 8, 8], f16), 0), {})
+cnt: 1, ((T([128, 192, 8, 8], f16, stride=(81920, 64, 8, 1)), T([128, 192, 8, 8], f16), 0), {})
+cnt: 10, ((T([128, 192, 17, 17], f16), T([128, 192, 17, 17], f16), 0), {})
+cnt: 1, ((T([128, 320, 8, 8], f16, stride=(81920, 64, 8, 1)), T([128, 320, 8, 8], f16), 0), {})
+cnt: 16, ((T([128, 192, 17, 17], f16, stride=(221952, 289, 17, 1)), T([128, 192, 17, 17], f16), 0), {})
+cnt: 12, ((T([128, 160, 17, 17], f16), T([128, 160, 17, 17], f16), 0), {})
+cnt: 6, ((T([128, 128, 17, 17], f16), T([128, 128, 17, 17], f16), 0), {})
+cnt: 1, ((T([128, 96, 17, 17], f16, stride=(221952, 289, 17, 1)), T([128, 96, 17, 17], f16), 0), {})
+cnt: 4, ((T([128, 96, 35, 35], f16), T([128, 96, 35, 35], f16), 0), {})
+cnt: 4, ((T([128, 64, 35, 35], f16), T([128, 64, 35, 35], f16), 0), {})
+cnt: 1, ((T([128, 384, 17, 17], f16, stride=(221952, 289, 17, 1)), T([128, 384, 17, 17], f16), 0), {})
+cnt: 6, ((T([128, 64, 35, 35], f16, stride=(352800, 1225, 35, 1)), T([128, 64, 35, 35], f16), 0), {})
+cnt: 2, ((T([128, 96, 35, 35], f16, stride=(352800, 1225, 35, 1)), T([128, 96, 35, 35], f16), 0), {})
+cnt: 3, ((T([128, 48, 35, 35], f16), T([128, 48, 35, 35], f16), 0), {})
+cnt: 1, ((T([128, 32, 35, 35], f16, stride=(313600, 1225, 35, 1)), T([128, 32, 35, 35], f16), 0), {})
+cnt: 1, ((T([128, 96, 35, 35], f16, stride=(313600, 1225, 35, 1)), T([128, 96, 35, 35], f16), 0), {})
+cnt: 2, ((T([128, 64, 35, 35], f16, stride=(313600, 1225, 35, 1)), T([128, 64, 35, 35], f16), 0), {})
+cnt: 1, ((T([128, 192, 71, 71], f16), T([128, 192, 71, 71], f16), 0), {})
+cnt: 1, ((T([128, 80, 73, 73], f16), T([128, 80, 73, 73], f16), 0), {})
+cnt: 1, ((T([128, 64, 147, 147], f16), T([128, 64, 147, 147], f16), 0), {})
+cnt: 1, ((T([128, 32, 147, 147], f16), T([128, 32, 147, 147], f16), 0), {})
+cnt: 1, ((T([128, 32, 149, 149], f16), T([128, 32, 149, 149], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/jx_nest_base_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/jx_nest_base_training.txt
new file mode 100644
index 0000000000000..ddb7593f59490
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/jx_nest_base_training.txt
@@ -0,0 +1,269 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 2, ((T([64, 4, 16, 196, 196], f16), -1, False), {})
+cnt: 2, ((T([64, 8, 4, 196, 196], f16), -1, False), {})
+cnt: 20, ((T([64, 16, 1, 196, 196], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 20, ((T([64, 16, 1, 196, 196], f16), T([64, 16, 1, 196, 196], f16), -1, f16), {})
+cnt: 2, ((T([64, 8, 4, 196, 196], f16), T([64, 8, 4, 196, 196], f16), -1, f16), {})
+cnt: 2, ((T([64, 4, 16, 196, 196], f16), T([64, 4, 16, 196, 196], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 2, ((T([64, 4, 4, 14, 14, 128], f16), [64, 16, 196, 128]), {})
+cnt: 2, ((T([200704, 384], f16), [64, 16, 196, 384]), {})
+cnt: 6, ((T([64, 4, 16, 196, 32], f16), [4096, 196, 32]), {})
+cnt: 2, ((T([64, 4, 16, 32, 196], f16), [4096, 32, 196]), {})
+cnt: 2, ((T([4096, 196, 196], f16), [64, 4, 16, 196, 196]), {})
+cnt: 2, ((T([4096, 196, 32], f16), [64, 4, 16, 196, 32]), {})
+cnt: 2, ((T([64, 16, 196, 32, 4], f16), [64, 16, 196, 128]), {})
+cnt: 4, ((T([200704, 128], f16), [64, 16, 196, 128]), {})
+cnt: 2, ((T([200704, 512], f16), [64, 16, 196, 512]), {})
+cnt: 2, ((T([64, 4, 14, 4, 14, 128], f16), [64, 56, 56, 128]), {})
+cnt: 2, ((T([64, 2, 2, 14, 14, 256], f16), [64, 4, 196, 256]), {})
+cnt: 2, ((T([50176, 768], f16), [64, 4, 196, 768]), {})
+cnt: 6, ((T([64, 8, 4, 196, 32], f16), [2048, 196, 32]), {})
+cnt: 2, ((T([64, 8, 4, 32, 196], f16), [2048, 32, 196]), {})
+cnt: 2, ((T([2048, 196, 196], f16), [64, 8, 4, 196, 196]), {})
+cnt: 2, ((T([2048, 196, 32], f16), [64, 8, 4, 196, 32]), {})
+cnt: 2, ((T([64, 4, 196, 32, 8], f16), [64, 4, 196, 256]), {})
+cnt: 4, ((T([50176, 256], f16), [64, 4, 196, 256]), {})
+cnt: 2, ((T([50176, 1024], f16), [64, 4, 196, 1024]), {})
+cnt: 2, ((T([64, 2, 14, 2, 14, 256], f16), [64, 28, 28, 256]), {})
+cnt: 20, ((T([12544, 1536], f16), [64, 1, 196, 1536]), {})
+cnt: 60, ((T([64, 16, 1, 196, 32], f16), [1024, 196, 32]), {})
+cnt: 20, ((T([64, 16, 1, 32, 196], f16), [1024, 32, 196]), {})
+cnt: 20, ((T([1024, 196, 196], f16), [64, 16, 1, 196, 196]), {})
+cnt: 20, ((T([1024, 196, 32], f16), [64, 16, 1, 196, 32]), {})
+cnt: 20, ((T([64, 1, 196, 32, 16], f16), [64, 1, 196, 512]), {})
+cnt: 40, ((T([12544, 512], f16), [64, 1, 196, 512]), {})
+cnt: 20, ((T([12544, 2048], f16), [64, 1, 196, 2048]), {})
+cnt: 40, ((T([64, 1, 196, 512], f16), [12544, 512]), {})
+cnt: 20, ((T([64, 1, 196, 3, 16, 32], f16), [64, 1, 196, 1536]), {})
+cnt: 2, ((T([64, 4, 196, 3, 8, 32], f16), [64, 4, 196, 768]), {})
+cnt: 2, ((T([64, 16, 196, 3, 4, 32], f16), [64, 16, 196, 384]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([64, 16, 196, 128], f16), T([1, 16, 196, 128], f16)), {})
+cnt: 2, ((T([64, 16, 196, 384], f16), T([384], f16)), {})
+cnt: 4, ((T([64, 16, 196, 128], f16), T([128], f16)), {})
+cnt: 8, ((T([64, 16, 196, 128], f16), T([64, 16, 196, 128], f16)), {})
+cnt: 2, ((T([64, 16, 196, 512], f16), T([512], f16)), {})
+cnt: 1, ((T([64, 4, 196, 256], f16), T([1, 4, 196, 256], f16)), {})
+cnt: 2, ((T([64, 4, 196, 768], f16), T([768], f16)), {})
+cnt: 4, ((T([64, 4, 196, 256], f16), T([256], f16)), {})
+cnt: 8, ((T([64, 4, 196, 256], f16), T([64, 4, 196, 256], f16)), {})
+cnt: 2, ((T([64, 4, 196, 1024], f16), T([1024], f16)), {})
+cnt: 1, ((T([64, 1, 196, 512], f16), T([1, 1, 196, 512], f16)), {})
+cnt: 20, ((T([64, 1, 196, 1536], f16), T([1536], f16)), {})
+cnt: 40, ((T([64, 1, 196, 512], f16), T([512], f16)), {})
+cnt: 40, ((T([64, 1, 196, 512], f16), T([64, 1, 196, 512], f16)), {})
+cnt: 20, ((T([64, 1, 196, 2048], f16), T([2048], f16)), {})
+cnt: 40, ((T([64, 1, 196, 512], f16, stride=(100352, 196, 1, 196)), T([64, 1, 196, 512], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([64, 512], f16), T([512, 1000], f16, stride=(1, 512))), {})
+Operator: aten.as_strided_.default
+cnt: 1, ((T([64, 512, 1, 1], f16), [64, 512, 1, 1], [512, 1, 512, 512]), {})
+Operator: aten.bernoulli_.float
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.9782608691602945), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.9565217383205891), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.9347826093435287), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.9130434766411781), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.8913043439388275), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.8695652186870575), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.8478260785341263), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.8260869532823563), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.8043478280305862), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.782608687877655), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.760869562625885), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.739130437374115), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.717391312122345), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.695652186870575), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.6739130318164825), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.6521739065647125), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.6304347813129425), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.6086956560611725), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.5869565308094025), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.5652174055576324), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.54347825050354), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.52173912525177), {})
+cnt: 2, ((T([64, 1, 1, 1], f16),), {})
+Operator: aten.bmm.default
+cnt: 2, ((T([4096, 196, 32], f16), T([4096, 32, 196], f16)), {})
+cnt: 2, ((T([4096, 196, 196], f16), T([4096, 196, 32], f16)), {})
+cnt: 2, ((T([2048, 196, 32], f16), T([2048, 32, 196], f16)), {})
+cnt: 2, ((T([2048, 196, 196], f16), T([2048, 196, 32], f16)), {})
+cnt: 20, ((T([1024, 196, 32], f16), T([1024, 32, 196], f16)), {})
+cnt: 20, ((T([1024, 196, 196], f16), T([1024, 196, 32], f16)), {})
+cnt: 20, ((T([1024, 196, 196], f16, stride=(38416, 1, 196)), T([1024, 196, 32], f16)), {})
+cnt: 20, ((T([1024, 196, 32], f16), T([1024, 32, 196], f16, stride=(6272, 1, 32))), {})
+cnt: 20, ((T([1024, 32, 196], f16, stride=(6272, 1, 32)), T([1024, 196, 196], f16)), {})
+cnt: 20, ((T([1024, 196, 196], f16), T([1024, 196, 32], f16, stride=(6272, 1, 196))), {})
+cnt: 2, ((T([2048, 196, 196], f16, stride=(38416, 1, 196)), T([2048, 196, 32], f16)), {})
+cnt: 2, ((T([2048, 196, 32], f16), T([2048, 32, 196], f16, stride=(6272, 1, 32))), {})
+cnt: 2, ((T([2048, 32, 196], f16, stride=(6272, 1, 32)), T([2048, 196, 196], f16)), {})
+cnt: 2, ((T([2048, 196, 196], f16), T([2048, 196, 32], f16, stride=(6272, 1, 196))), {})
+cnt: 2, ((T([4096, 196, 196], f16, stride=(38416, 1, 196)), T([4096, 196, 32], f16)), {})
+cnt: 2, ((T([4096, 196, 32], f16), T([4096, 32, 196], f16, stride=(6272, 1, 32))), {})
+cnt: 2, ((T([4096, 32, 196], f16, stride=(6272, 1, 32)), T([4096, 196, 196], f16)), {})
+cnt: 2, ((T([4096, 196, 196], f16), T([4096, 196, 32], f16, stride=(6272, 1, 196))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 1, ((T([64, 256, 56, 56], f16, stride=(802816, 1, 14336, 256)), [0, 1, 0, 1], -inf), {})
+cnt: 1, ((T([64, 512, 28, 28], f16, stride=(401408, 1, 14336, 512)), [0, 1, 0, 1], -inf), {})
+cnt: 1, ((T([64, 512, 29, 29], f16, stride=(430592, 1, 14848, 512)), [0, -1, 0, -1]), {})
+cnt: 1, ((T([64, 256, 57, 57], f16, stride=(831744, 1, 14592, 256)), [0, -1, 0, -1]), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([128, 3, 4, 4], f16), T([128], f16), [4, 4], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([256, 128, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([512, 256, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 512, 28, 28], f16, stride=(401408, 1, 14336, 512)), T([64, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([512, 256, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 256, 56, 56], f16, stride=(802816, 1, 14336, 256)), T([64, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([256, 128, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([64, 3, 224, 224], f16), T([128, 3, 4, 4], f16), [128], [4, 4], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+cnt: 1, ((T([64, 512], f16), T([64, 512], f16)), {})
+cnt: 1, ((T([512, 256, 3, 3], f16), T([512, 256, 3, 3], f16, stride=(2304, 1, 768, 256))), {})
+cnt: 1, ((T([256, 128, 3, 3], f16), T([256, 128, 3, 3], f16, stride=(1152, 1, 384, 128))), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([64, 512, 14, 14], f16, stride=(512, 1, 0, 0)), 196), {})
+Operator: aten.div_.Tensor
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.9782608691602945), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.9565217383205891), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.9347826093435287), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.9130434766411781), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.8913043439388275), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.8695652186870575), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.8478260785341263), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.8260869532823563), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.8043478280305862), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.782608687877655), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.760869562625885), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.739130437374115), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.717391312122345), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.695652186870575), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.6739130318164825), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.6521739065647125), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.6304347813129425), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.6086956560611725), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.5869565308094025), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.5652174055576324), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.54347825050354), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.52173912525177), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.5), {})
+Operator: aten.gelu.default
+cnt: 2, ((T([64, 16, 196, 512], f16),), {})
+cnt: 2, ((T([64, 4, 196, 1024], f16),), {})
+cnt: 20, ((T([64, 1, 196, 2048], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 20, ((T([64, 1, 196, 2048], f16), T([64, 1, 196, 2048], f16)), {})
+cnt: 2, ((T([64, 4, 196, 1024], f16), T([64, 4, 196, 1024], f16)), {})
+cnt: 2, ((T([64, 16, 196, 512], f16), T([64, 16, 196, 512], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([64, 256, 57, 57], f16, stride=(831744, 1, 14592, 256)), [3, 3], [2, 2]), {})
+cnt: 1, ((T([64, 512, 29, 29], f16, stride=(430592, 1, 14848, 512)), [3, 3], [2, 2]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([64, 512, 14, 14], f16), T([64, 512, 29, 29], f16, stride=(430592, 1, 14848, 512)), [3, 3], [2, 2], [0, 0], [1, 1], False, T([64, 512, 14, 14], i64, stride=(100352, 1, 7168, 512))), {})
+cnt: 1, ((T([64, 256, 28, 28], f16, stride=(200704, 1, 7168, 256)), T([64, 256, 57, 57], f16, stride=(831744, 1, 14592, 256)), [3, 3], [2, 2], [0, 0], [1, 1], False, T([64, 256, 28, 28], i64, stride=(200704, 1, 7168, 256))), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([64, 512, 14, 14], f16, stride=(100352, 1, 7168, 512)), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 2, ((T([200704, 128], f16), T([128, 384], f16, stride=(1, 128))), {})
+cnt: 2, ((T([200704, 128], f16), T([128, 128], f16, stride=(1, 128))), {})
+cnt: 2, ((T([200704, 128], f16), T([128, 512], f16, stride=(1, 128))), {})
+cnt: 2, ((T([200704, 512], f16), T([512, 128], f16, stride=(1, 512))), {})
+cnt: 2, ((T([50176, 256], f16), T([256, 768], f16, stride=(1, 256))), {})
+cnt: 2, ((T([50176, 256], f16), T([256, 256], f16, stride=(1, 256))), {})
+cnt: 2, ((T([50176, 256], f16), T([256, 1024], f16, stride=(1, 256))), {})
+cnt: 2, ((T([50176, 1024], f16), T([1024, 256], f16, stride=(1, 1024))), {})
+cnt: 20, ((T([12544, 512], f16), T([512, 1536], f16, stride=(1, 512))), {})
+cnt: 20, ((T([12544, 512], f16), T([512, 512], f16, stride=(1, 512))), {})
+cnt: 20, ((T([12544, 512], f16), T([512, 2048], f16, stride=(1, 512))), {})
+cnt: 20, ((T([12544, 2048], f16), T([2048, 512], f16, stride=(1, 2048))), {})
+cnt: 1, ((T([64, 1000], f16), T([1000, 512], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 512], f16)), {})
+cnt: 20, ((T([512, 12544], f16, stride=(1, 512)), T([12544, 2048], f16)), {})
+cnt: 20, ((T([12544, 512], f16), T([512, 2048], f16)), {})
+cnt: 20, ((T([2048, 12544], f16, stride=(1, 2048)), T([12544, 512], f16)), {})
+cnt: 20, ((T([12544, 2048], f16), T([2048, 512], f16)), {})
+cnt: 20, ((T([512, 12544], f16, stride=(1, 512)), T([12544, 512], f16)), {})
+cnt: 20, ((T([12544, 512], f16), T([512, 512], f16)), {})
+cnt: 20, ((T([1536, 12544], f16, stride=(1, 1536)), T([12544, 512], f16)), {})
+cnt: 20, ((T([12544, 1536], f16), T([1536, 512], f16)), {})
+cnt: 2, ((T([256, 50176], f16, stride=(1, 256)), T([50176, 1024], f16)), {})
+cnt: 2, ((T([50176, 256], f16), T([256, 1024], f16)), {})
+cnt: 2, ((T([1024, 50176], f16, stride=(1, 1024)), T([50176, 256], f16)), {})
+cnt: 2, ((T([50176, 1024], f16), T([1024, 256], f16)), {})
+cnt: 2, ((T([256, 50176], f16, stride=(1, 256)), T([50176, 256], f16)), {})
+cnt: 2, ((T([50176, 256], f16), T([256, 256], f16)), {})
+cnt: 2, ((T([768, 50176], f16, stride=(1, 768)), T([50176, 256], f16)), {})
+cnt: 2, ((T([50176, 768], f16), T([768, 256], f16)), {})
+cnt: 2, ((T([128, 200704], f16, stride=(1, 128)), T([200704, 512], f16)), {})
+cnt: 2, ((T([200704, 128], f16), T([128, 512], f16)), {})
+cnt: 2, ((T([512, 200704], f16, stride=(1, 512)), T([200704, 128], f16)), {})
+cnt: 2, ((T([200704, 512], f16), T([512, 128], f16)), {})
+cnt: 2, ((T([128, 200704], f16, stride=(1, 128)), T([200704, 128], f16)), {})
+cnt: 2, ((T([200704, 128], f16), T([128, 128], f16)), {})
+cnt: 2, ((T([384, 200704], f16, stride=(1, 384)), T([200704, 128], f16)), {})
+cnt: 2, ((T([200704, 384], f16), T([384, 128], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 4, ((T([64, 4, 16, 196, 196], f16), 0.1767766952966369), {})
+cnt: 4, ((T([64, 16, 196, 128], f16), T([64, 1, 1, 1], f16)), {})
+cnt: 4, ((T([64, 8, 4, 196, 196], f16), 0.1767766952966369), {})
+cnt: 8, ((T([64, 4, 196, 256], f16), T([64, 1, 1, 1], f16)), {})
+cnt: 40, ((T([64, 16, 1, 196, 196], f16), 0.1767766952966369), {})
+cnt: 40, ((T([64, 1, 196, 512], f16), T([64, 1, 1, 1], f16)), {})
+cnt: 40, ((T([64, 1, 196, 512], f16, stride=(100352, 196, 1, 196)), T([64, 1, 1, 1], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 4, ((T([64, 16, 196, 128], f16), [128], T([128], f16), T([128], f16), 1e-06), {})
+cnt: 1, ((T([64, 56, 56, 256], f16), [256], T([256], f16), T([256], f16), 1e-06), {})
+cnt: 4, ((T([64, 4, 196, 256], f16), [256], T([256], f16), T([256], f16), 1e-06), {})
+cnt: 1, ((T([64, 28, 28, 512], f16), [512], T([512], f16), T([512], f16), 1e-06), {})
+cnt: 40, ((T([64, 1, 196, 512], f16), [512], T([512], f16), T([512], f16), 1e-06), {})
+cnt: 1, ((T([64, 14, 14, 512], f16), [512], T([512], f16), T([512], f16), 1e-06), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 1, ((T([64, 14, 14, 512], f16, stride=(100352, 14, 1, 196)), T([64, 14, 14, 512], f16), [512], T([64, 14, 14, 1], f32), T([64, 14, 14, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+cnt: 40, ((T([64, 1, 196, 512], f16), T([64, 1, 196, 512], f16), [512], T([64, 1, 196, 1], f32), T([64, 1, 196, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+cnt: 1, ((T([64, 28, 28, 512], f16), T([64, 28, 28, 512], f16), [512], T([64, 28, 28, 1], f32), T([64, 28, 28, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+cnt: 4, ((T([64, 4, 196, 256], f16), T([64, 4, 196, 256], f16), [256], T([64, 4, 196, 1], f32), T([64, 4, 196, 1], f32), T([256], f16), T([256], f16), [True, True, True]), {})
+cnt: 1, ((T([64, 56, 56, 256], f16), T([64, 56, 56, 256], f16), [256], T([64, 56, 56, 1], f32), T([64, 56, 56, 1], f32), T([256], f16), T([256], f16), [True, True, True]), {})
+cnt: 4, ((T([64, 16, 196, 128], f16), T([64, 16, 196, 128], f16), [128], T([64, 16, 196, 1], f32), T([64, 16, 196, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+Operator: aten.new_empty.default
+cnt: 2, ((T([64, 16, 196, 128], f16), [64, 1, 1, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 4, ((T([64, 4, 196, 256], f16), [64, 1, 1, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 40, ((T([64, 1, 196, 512], f16), [64, 1, 1, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+Operator: aten.new_empty_strided.default
+cnt: 1, ((T([512, 256, 3, 3], f16, stride=(2304, 1, 768, 256)), [512, 256, 3, 3], [2304, 9, 3, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([256, 128, 3, 3], f16, stride=(1152, 1, 384, 128)), [256, 128, 3, 3], [1152, 9, 3, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.new_zeros.default
+cnt: 1, ((T([64, 512], f16), [32768]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.stack.default
+cnt: 20, (([T([64, 16, 1, 196, 32], f16), T([64, 16, 1, 196, 32], f16, stride=(100352, 6272, 6272, 1, 196)), T([64, 16, 1, 196, 32], f16)],), {})
+cnt: 2, (([T([64, 8, 4, 196, 32], f16), T([64, 8, 4, 196, 32], f16, stride=(200704, 25088, 6272, 1, 196)), T([64, 8, 4, 196, 32], f16)],), {})
+cnt: 2, (([T([64, 4, 16, 196, 32], f16), T([64, 4, 16, 196, 32], f16, stride=(401408, 100352, 6272, 1, 196)), T([64, 4, 16, 196, 32], f16)],), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 40, ((T([64, 1, 196, 512], f16, stride=(100352, 196, 1, 196)), [0, 1, 2], True), {})
+cnt: 20, ((T([64, 1, 196, 2048], f16), [0, 1, 2], True), {})
+cnt: 20, ((T([64, 1, 196, 1536], f16), [0, 1, 2], True), {})
+cnt: 1, ((T([64, 1, 196, 512], f16, stride=(100352, 196, 1, 196)), [0], True), {})
+cnt: 4, ((T([64, 4, 196, 256], f16), [0, 1, 2], True), {})
+cnt: 2, ((T([64, 4, 196, 1024], f16), [0, 1, 2], True), {})
+cnt: 2, ((T([64, 4, 196, 768], f16), [0, 1, 2], True), {})
+cnt: 1, ((T([64, 4, 196, 256], f16), [0], True), {})
+cnt: 4, ((T([64, 16, 196, 128], f16), [0, 1, 2], True), {})
+cnt: 2, ((T([64, 16, 196, 512], f16), [0, 1, 2], True), {})
+cnt: 2, ((T([64, 16, 196, 384], f16), [0, 1, 2], True), {})
+cnt: 1, ((T([64, 16, 196, 128], f16), [0], True), {})
+Operator: aten.unbind.int
+cnt: 2, ((T([3, 64, 4, 16, 196, 32], f16, stride=(128, 1204224, 32, 75264, 384, 1)),), {})
+cnt: 2, ((T([3, 64, 8, 4, 196, 32], f16, stride=(256, 602112, 32, 150528, 768, 1)),), {})
+cnt: 20, ((T([3, 64, 16, 1, 196, 32], f16, stride=(512, 301056, 32, 301056, 1536, 1)),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/lcnet_050_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/lcnet_050_training.txt
new file mode 100644
index 0000000000000..48f28c23f3f4c
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/lcnet_050_training.txt
@@ -0,0 +1,158 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 27, ((T([], i64), 1), {})
+cnt: 1, ((T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16)), {})
+cnt: 1, ((T([128, 128, 7, 7], f16), T([128, 128, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1280], f16), T([1280, 1000], f16, stride=(1, 1280))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+cnt: 2, ((T([128, 8, 112, 112], f16),), {})
+cnt: 1, ((T([128, 16, 112, 112], f16),), {})
+cnt: 1, ((T([128, 16, 56, 56], f16),), {})
+cnt: 3, ((T([128, 32, 56, 56], f16),), {})
+cnt: 1, ((T([128, 32, 28, 28], f16),), {})
+cnt: 3, ((T([128, 64, 28, 28], f16),), {})
+cnt: 1, ((T([128, 64, 14, 14], f16),), {})
+cnt: 11, ((T([128, 128, 14, 14], f16),), {})
+cnt: 1, ((T([128, 128, 7, 7], f16),), {})
+cnt: 3, ((T([128, 256, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1280, 1, 1], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([8, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 8, 112, 112], f16), T([8, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 1, ((T([128, 8, 112, 112], f16), T([16, 8, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([16, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 16), {})
+cnt: 1, ((T([128, 16, 56, 56], f16), T([32, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 56, 56], f16), T([32, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([128, 32, 56, 56], f16), T([32, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 56, 56], f16), T([32, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([128, 32, 28, 28], f16), T([64, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 28, 28], f16), T([64, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([128, 64, 28, 28], f16), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 28, 28], f16), T([64, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([128, 64, 14, 14], f16), T([128, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 128, 14, 14], f16), T([128, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 128), {})
+cnt: 5, ((T([128, 128, 14, 14], f16), T([128, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 14, 14], f16), T([128, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 128), {})
+cnt: 1, ((T([128, 128, 1, 1], f16), T([32, 128, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 1, 1], f16), T([128, 32, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 7, 7], f16), T([256, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 7, 7], f16), T([256, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 256), {})
+cnt: 1, ((T([128, 256, 1, 1], f16), T([64, 256, 1, 1], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 1, 1], f16), T([256, 64, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 7, 7], f16), T([256, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 1, 1], f16), T([1280, 256, 1, 1], f16), T([1280], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1280, 1, 1], f16), T([128, 256, 1, 1], f16), T([1280, 256, 1, 1], f16), [1280], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16), T([256, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 1, 1], f16), T([128, 64, 1, 1], f16), T([256, 64, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 1, 1], f16), T([128, 256, 1, 1], f16), T([64, 256, 1, 1], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16), T([256, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 256, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 7, 7], f16), T([128, 128, 7, 7], f16), T([256, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 1, 1], f16), T([128, 32, 1, 1], f16), T([128, 32, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 1, 1], f16), T([128, 128, 1, 1], f16), T([32, 128, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 7, 7], f16), T([128, 128, 14, 14], f16), T([128, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 128, [True, True, False]), {})
+cnt: 5, ((T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16), T([128, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16), T([128, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 128, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 14, 14], f16), T([128, 64, 14, 14], f16), T([128, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 14, 14], f16), T([128, 64, 28, 28], f16), T([64, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16), T([64, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16), T([64, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 28, 28], f16), T([128, 32, 28, 28], f16), T([64, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 28, 28], f16), T([128, 32, 56, 56], f16), T([32, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), T([32, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), T([32, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 56, 56], f16), T([128, 16, 56, 56], f16), T([32, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 56, 56], f16), T([128, 16, 112, 112], f16), T([16, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 16, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 8, 112, 112], f16), T([16, 8, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 8, 112, 112], f16), T([128, 8, 112, 112], f16), T([8, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 1, ((T([128, 8, 112, 112], f16), T([128, 3, 224, 224], f16), T([8, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 2, ((T([128, 256, 7, 7], f16, stride=(256, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 128, 7, 7], f16, stride=(128, 1, 0, 0)), 49), {})
+Operator: aten.hardsigmoid.default
+cnt: 1, ((T([128, 128, 1, 1], f16),), {})
+cnt: 1, ((T([128, 256, 1, 1], f16),), {})
+Operator: aten.hardsigmoid_backward.default
+cnt: 1, ((T([128, 256, 1, 1], f16), T([128, 256, 1, 1], f16)), {})
+cnt: 1, ((T([128, 128, 1, 1], f16), T([128, 128, 1, 1], f16)), {})
+Operator: aten.hardswish_.default
+cnt: 2, ((T([128, 8, 112, 112], f16),), {})
+cnt: 1, ((T([128, 16, 112, 112], f16),), {})
+cnt: 1, ((T([128, 16, 56, 56], f16),), {})
+cnt: 3, ((T([128, 32, 56, 56], f16),), {})
+cnt: 1, ((T([128, 32, 28, 28], f16),), {})
+cnt: 3, ((T([128, 64, 28, 28], f16),), {})
+cnt: 1, ((T([128, 64, 14, 14], f16),), {})
+cnt: 11, ((T([128, 128, 14, 14], f16),), {})
+cnt: 1, ((T([128, 128, 7, 7], f16),), {})
+cnt: 3, ((T([128, 256, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1280, 1, 1], f16),), {})
+Operator: aten.hardswish_backward.default
+cnt: 1, ((T([128, 1280, 1, 1], f16), T([128, 1280, 1, 1], f16)), {})
+cnt: 3, ((T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16)), {})
+cnt: 1, ((T([128, 128, 7, 7], f16), T([128, 128, 7, 7], f16)), {})
+cnt: 11, ((T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16)), {})
+cnt: 1, ((T([128, 64, 14, 14], f16), T([128, 64, 14, 14], f16)), {})
+cnt: 3, ((T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16)), {})
+cnt: 1, ((T([128, 32, 28, 28], f16), T([128, 32, 28, 28], f16)), {})
+cnt: 3, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16)), {})
+cnt: 1, ((T([128, 16, 56, 56], f16), T([128, 16, 56, 56], f16)), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16)), {})
+cnt: 2, ((T([128, 8, 112, 112], f16), T([128, 8, 112, 112], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 128, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 256, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 256, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 1280], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 1280], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([128, 128, 7, 7], f16), T([128, 128, 1, 1], f16)), {})
+cnt: 2, ((T([128, 256, 7, 7], f16), T([128, 256, 1, 1], f16)), {})
+cnt: 1, ((T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16)), {})
+cnt: 1, ((T([128, 128, 7, 7], f16), T([128, 128, 7, 7], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([128, 8, 112, 112], f16), T([8], f16), T([8], f16), T([8], f16), T([8], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 16, 56, 56], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 32, 56, 56], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 32, 28, 28], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 64, 28, 28], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 64, 14, 14], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 11, ((T([128, 128, 14, 14], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 128, 7, 7], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 256, 7, 7], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 3, ((T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 7, 7], f16), T([128, 128, 7, 7], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 11, ((T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 14, 14], f16), T([128, 64, 14, 14], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 28, 28], f16), T([128, 32, 28, 28], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 16, 56, 56], f16), T([128, 16, 56, 56], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 8, 112, 112], f16), T([128, 8, 112, 112], f16), T([8], f16), T([8], f16), T([8], f16), T([8], f32), T([8], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 32, 1, 1], f16),), {})
+cnt: 1, ((T([128, 64, 1, 1], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 1, ((T([128, 256, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 128, 7, 7], f16), [2, 3], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([128, 64, 1, 1], f16), T([128, 64, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 32, 1, 1], f16), T([128, 32, 1, 1], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/legacy_senet154_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/legacy_senet154_training.txt
new file mode 100644
index 0000000000000..c4895fad41ff9
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/legacy_senet154_training.txt
@@ -0,0 +1,183 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([32, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([32, 1000], f16), T([32, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 9, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16)), {})
+cnt: 24, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16)), {})
+cnt: 108, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16)), {})
+cnt: 8, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16)), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 157, ((T([], i64), 1), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([32, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([64, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([128, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([128, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 128, 56, 56], f16), T([256, 2, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 4, ((T([32, 256, 56, 56], f16), T([256, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([256, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 256, 1, 1], f16), T([16, 256, 1, 1], f16), T([16], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 16, 1, 1], f16), T([256, 16, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 256, 56, 56], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([512, 4, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 9, ((T([32, 512, 28, 28], f16), T([512, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([512, 256, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([32, 512, 1, 1], f16), T([32, 512, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([32, 32, 1, 1], f16), T([512, 32, 1, 1], f16), T([512], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([32, 512, 28, 28], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([32, 256, 28, 28], f16), T([512, 4, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([1024, 8, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 37, ((T([32, 1024, 14, 14], f16), T([1024, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([1024, 512, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 36, ((T([32, 1024, 1, 1], f16), T([64, 1024, 1, 1], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 36, ((T([32, 64, 1, 1], f16), T([1024, 64, 1, 1], f16), T([1024], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 35, ((T([32, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 35, ((T([32, 512, 14, 14], f16), T([1024, 8, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([2048, 16, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16), T([2048, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([2048, 1024, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 2048, 1, 1], f16), T([128, 2048, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 128, 1, 1], f16), T([2048, 128, 1, 1], f16), T([2048], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 2048, 7, 7], f16), T([1024, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 1024, 7, 7], f16), T([2048, 16, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 64), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([32, 2048, 1, 1], f16), T([32, 128, 1, 1], f16), T([2048, 128, 1, 1], f16), [2048], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 128, 1, 1], f16), T([32, 2048, 1, 1], f16), T([128, 2048, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16), T([2048, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 2048, 7, 7], f16), T([32, 1024, 7, 7], f16), T([2048, 16, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 2, ((T([32, 1024, 7, 7], f16), T([32, 2048, 7, 7], f16), T([1024, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16), T([32, 1024, 14, 14], f16), T([2048, 1024, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16), T([32, 1024, 14, 14], f16), T([2048, 16, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 37, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16), T([1024, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 36, ((T([32, 1024, 1, 1], f16), T([32, 64, 1, 1], f16), T([1024, 64, 1, 1], f16), [1024], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 36, ((T([32, 64, 1, 1], f16), T([32, 1024, 1, 1], f16), T([64, 1024, 1, 1], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 35, ((T([32, 1024, 14, 14], f16), T([32, 512, 14, 14], f16), T([1024, 8, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 35, ((T([32, 512, 14, 14], f16), T([32, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 512, 28, 28], f16), T([1024, 512, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 512, 28, 28], f16), T([1024, 8, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 9, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16), T([512, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 8, ((T([32, 512, 1, 1], f16), T([32, 32, 1, 1], f16), T([512, 32, 1, 1], f16), [512], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 8, ((T([32, 32, 1, 1], f16), T([32, 512, 1, 1], f16), T([32, 512, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 7, ((T([32, 512, 28, 28], f16), T([32, 256, 28, 28], f16), T([512, 4, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 7, ((T([32, 256, 28, 28], f16), T([32, 512, 28, 28], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 256, 56, 56], f16), T([512, 256, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 256, 56, 56], f16), T([512, 4, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 4, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16), T([256, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 256, 1, 1], f16), T([32, 16, 1, 1], f16), T([256, 16, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 16, 1, 1], f16), T([32, 256, 1, 1], f16), T([16, 256, 1, 1], f16), [16], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), T([32, 128, 56, 56], f16), T([256, 2, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 2, ((T([32, 128, 56, 56], f16), T([32, 256, 56, 56], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([32, 128, 56, 56], f16), T([256, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), T([128, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 112, 112], f16), T([32, 64, 112, 112], f16), T([128, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 3, 224, 224], f16), T([64, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 4, ((T([32, 2048, 7, 7], f16, stride=(2048, 1, 0, 0)), 49), {})
+cnt: 36, ((T([32, 1024, 14, 14], f16, stride=(1024, 1, 0, 0)), 196), {})
+cnt: 8, ((T([32, 512, 28, 28], f16, stride=(512, 1, 0, 0)), 784), {})
+cnt: 3, ((T([32, 256, 56, 56], f16, stride=(256, 1, 0, 0)), 3136), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([32], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([32, 128, 112, 112], f16), [3, 3], [2, 2], [0, 0], [1, 1], True), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([32, 128, 56, 56], f16), T([32, 128, 112, 112], f16), [3, 3], [2, 2], [0, 0], [1, 1], True, T([32, 128, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 3, ((T([32, 256, 56, 56], f16), [2, 3], True), {})
+cnt: 8, ((T([32, 512, 28, 28], f16), [2, 3], True), {})
+cnt: 36, ((T([32, 1024, 14, 14], f16), [2, 3], True), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([32, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 32], f16, stride=(1, 1000)), T([32, 2048], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 6, ((T([32, 256, 56, 56], f16), T([32, 256, 1, 1], f16)), {})
+cnt: 16, ((T([32, 512, 28, 28], f16), T([32, 512, 1, 1], f16)), {})
+cnt: 72, ((T([32, 1024, 14, 14], f16), T([32, 1024, 1, 1], f16)), {})
+cnt: 6, ((T([32, 2048, 7, 7], f16), T([32, 2048, 1, 1], f16)), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16)), {})
+cnt: 36, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16)), {})
+cnt: 8, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16)), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([32, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 128, 112, 112], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 8, ((T([32, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 18, ((T([32, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([32, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 74, ((T([32, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 35, ((T([32, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([32, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([32, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 7, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 1024, 7, 7], f16), T([32, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 74, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 35, ((T([32, 512, 14, 14], f16), T([32, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 18, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([32, 256, 28, 28], f16), T([32, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 128, 112, 112], f16), T([32, 128, 112, 112], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([32, 1000], f16), T([32], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([32, 1000], f16), T([32], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([32, 64, 112, 112], f16),), {})
+cnt: 1, ((T([32, 128, 112, 112], f16),), {})
+cnt: 3, ((T([32, 128, 56, 56], f16),), {})
+cnt: 7, ((T([32, 256, 56, 56], f16),), {})
+cnt: 3, ((T([32, 16, 1, 1], f16),), {})
+cnt: 17, ((T([32, 512, 28, 28], f16),), {})
+cnt: 8, ((T([32, 32, 1, 1], f16),), {})
+cnt: 7, ((T([32, 256, 28, 28], f16),), {})
+cnt: 73, ((T([32, 1024, 14, 14], f16),), {})
+cnt: 36, ((T([32, 64, 1, 1], f16),), {})
+cnt: 35, ((T([32, 512, 14, 14], f16),), {})
+cnt: 6, ((T([32, 2048, 7, 7], f16),), {})
+cnt: 3, ((T([32, 128, 1, 1], f16),), {})
+cnt: 2, ((T([32, 1024, 7, 7], f16),), {})
+Operator: aten.sigmoid.default
+cnt: 3, ((T([32, 256, 1, 1], f16),), {})
+cnt: 8, ((T([32, 512, 1, 1], f16),), {})
+cnt: 36, ((T([32, 1024, 1, 1], f16),), {})
+cnt: 3, ((T([32, 2048, 1, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 3, ((T([32, 2048, 1, 1], f16), T([32, 2048, 1, 1], f16)), {})
+cnt: 36, ((T([32, 1024, 1, 1], f16), T([32, 1024, 1, 1], f16)), {})
+cnt: 8, ((T([32, 512, 1, 1], f16), T([32, 512, 1, 1], f16)), {})
+cnt: 3, ((T([32, 256, 1, 1], f16), T([32, 256, 1, 1], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32, 1000], f16), [0], True), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16), [2, 3], True), {})
+cnt: 36, ((T([32, 1024, 14, 14], f16), [2, 3], True), {})
+cnt: 8, ((T([32, 512, 28, 28], f16), [2, 3], True), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), [2, 3], True), {})
+Operator: aten.threshold_backward.default
+cnt: 6, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16), 0), {})
+cnt: 3, ((T([32, 128, 1, 1], f16), T([32, 128, 1, 1], f16), 0), {})
+cnt: 2, ((T([32, 1024, 7, 7], f16), T([32, 1024, 7, 7], f16), 0), {})
+cnt: 73, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16), 0), {})
+cnt: 36, ((T([32, 64, 1, 1], f16), T([32, 64, 1, 1], f16), 0), {})
+cnt: 35, ((T([32, 512, 14, 14], f16), T([32, 512, 14, 14], f16), 0), {})
+cnt: 17, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16), 0), {})
+cnt: 8, ((T([32, 32, 1, 1], f16), T([32, 32, 1, 1], f16), 0), {})
+cnt: 7, ((T([32, 256, 28, 28], f16), T([32, 256, 28, 28], f16), 0), {})
+cnt: 7, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16), 0), {})
+cnt: 3, ((T([32, 16, 1, 1], f16), T([32, 16, 1, 1], f16), 0), {})
+cnt: 3, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 128, 112, 112], f16), T([32, 128, 112, 112], f16), 0), {})
+cnt: 2, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/levit_128_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/levit_128_training.txt
new file mode 100644
index 0000000000000..e24ac0ec6f74f
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/levit_128_training.txt
@@ -0,0 +1,295 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 4, ((T([128, 4, 196, 196], f16), -1, False), {})
+cnt: 1, ((T([128, 8, 49, 196], f16), -1, False), {})
+cnt: 4, ((T([128, 8, 49, 49], f16), -1, False), {})
+cnt: 1, ((T([128, 16, 16, 49], f16), -1, False), {})
+cnt: 4, ((T([128, 12, 16, 16], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 4, ((T([128, 12, 16, 16], f16), T([128, 12, 16, 16], f16), -1, f16), {})
+cnt: 1, ((T([128, 16, 16, 49], f16), T([128, 16, 16, 49], f16), -1, f16), {})
+cnt: 4, ((T([128, 8, 49, 49], f16), T([128, 8, 49, 49], f16), -1, f16), {})
+cnt: 1, ((T([128, 8, 49, 196], f16), T([128, 8, 49, 196], f16), -1, f16), {})
+cnt: 4, ((T([128, 4, 196, 196], f16), T([128, 4, 196, 196], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 8, ((T([128, 196, 256], f16), [128, 196, 256]), {})
+cnt: 4, ((T([128, 4, 196, 16], f16), [512, 196, 16]), {})
+cnt: 4, ((T([128, 4, 16, 196], f16), [512, 16, 196]), {})
+cnt: 4, ((T([512, 196, 196], f16), [128, 4, 196, 196]), {})
+cnt: 8, ((T([128, 4, 196, 32], f16), [512, 196, 32]), {})
+cnt: 4, ((T([512, 196, 32], f16), [128, 4, 196, 32]), {})
+cnt: 4, ((T([128, 196, 4, 32], f16), [128, 196, 128]), {})
+cnt: 8, ((T([25088, 128], f16), [128, 196, 128]), {})
+cnt: 1, ((T([128, 196, 640], f16), [128, 196, 640]), {})
+cnt: 1, ((T([128, 7, 7, 128], f16), [128, 49, 128]), {})
+cnt: 1, ((T([6272, 128], f16), [128, 49, 128]), {})
+cnt: 5, ((T([128, 8, 49, 16], f16), [1024, 49, 16]), {})
+cnt: 1, ((T([128, 8, 16, 196], f16), [1024, 16, 196]), {})
+cnt: 1, ((T([1024, 49, 196], f16), [128, 8, 49, 196]), {})
+cnt: 1, ((T([128, 8, 196, 64], f16), [1024, 196, 64]), {})
+cnt: 1, ((T([1024, 49, 64], f16), [128, 8, 49, 64]), {})
+cnt: 1, ((T([128, 49, 8, 64], f16), [128, 49, 512]), {})
+cnt: 10, ((T([6272, 256], f16), [128, 49, 256]), {})
+cnt: 9, ((T([6272, 512], f16), [128, 49, 512]), {})
+cnt: 4, ((T([128, 8, 16, 49], f16), [1024, 16, 49]), {})
+cnt: 4, ((T([1024, 49, 49], f16), [128, 8, 49, 49]), {})
+cnt: 8, ((T([128, 8, 49, 32], f16), [1024, 49, 32]), {})
+cnt: 4, ((T([1024, 49, 32], f16), [128, 8, 49, 32]), {})
+cnt: 4, ((T([128, 49, 8, 32], f16), [128, 49, 256]), {})
+cnt: 1, ((T([6272, 1280], f16), [128, 49, 1280]), {})
+cnt: 1, ((T([128, 4, 4, 256], f16), [128, 16, 256]), {})
+cnt: 1, ((T([2048, 256], f16), [128, 16, 256]), {})
+cnt: 1, ((T([128, 16, 16, 16], f16), [2048, 16, 16]), {})
+cnt: 1, ((T([128, 16, 16, 49], f16), [2048, 16, 49]), {})
+cnt: 1, ((T([2048, 16, 49], f16), [128, 16, 16, 49]), {})
+cnt: 1, ((T([128, 16, 49, 64], f16), [2048, 49, 64]), {})
+cnt: 1, ((T([2048, 16, 64], f16), [128, 16, 16, 64]), {})
+cnt: 1, ((T([128, 16, 16, 64], f16), [128, 16, 1024]), {})
+cnt: 10, ((T([2048, 384], f16), [128, 16, 384]), {})
+cnt: 9, ((T([2048, 768], f16), [128, 16, 768]), {})
+cnt: 8, ((T([128, 12, 16, 16], f16), [1536, 16, 16]), {})
+cnt: 4, ((T([1536, 16, 16], f16), [128, 12, 16, 16]), {})
+cnt: 8, ((T([128, 12, 16, 32], f16), [1536, 16, 32]), {})
+cnt: 4, ((T([1536, 16, 32], f16), [128, 12, 16, 32]), {})
+cnt: 4, ((T([128, 16, 12, 32], f16), [128, 16, 384]), {})
+cnt: 1, ((T([128, 16, 16, 64], f16), [2048, 16, 64]), {})
+cnt: 1, ((T([128, 16, 16, 16], f16), [128, 16, 256]), {})
+cnt: 1, ((T([128, 8, 49, 64], f16), [1024, 49, 64]), {})
+cnt: 1, ((T([128, 49, 8, 16], f16), [128, 49, 128]), {})
+Operator: aten.add.Tensor
+cnt: 4, ((T([128, 4, 196, 196], f16), T([4, 196, 196], f16)), {})
+cnt: 8, ((T([128, 196, 128], f16, stride=(25088, 1, 196)), T([128, 196, 128], f16)), {})
+cnt: 1, ((T([128, 8, 49, 196], f16), T([8, 49, 196], f16)), {})
+cnt: 19, ((T([128, 49, 256], f16), T([128, 49, 256], f16)), {})
+cnt: 4, ((T([128, 8, 49, 49], f16), T([8, 49, 49], f16)), {})
+cnt: 1, ((T([128, 16, 16, 49], f16), T([16, 16, 49], f16)), {})
+cnt: 18, ((T([128, 16, 384], f16), T([128, 16, 384], f16)), {})
+cnt: 4, ((T([128, 12, 16, 16], f16), T([12, 16, 16], f16)), {})
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16)), {})
+cnt: 1, ((T([128, 384], f16), T([128, 384], f16)), {})
+cnt: 9, ((T([128, 196, 128], f16), T([128, 196, 128], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 64, ((T([], i64), 1), {})
+Operator: aten.addmm.default
+cnt: 2, ((T([1000], f16), T([128, 384], f16), T([384, 1000], f16, stride=(1, 384))), {})
+Operator: aten.bmm.default
+cnt: 8, ((T([128, 196, 128], f16, stride=(25088, 1, 196)), T([128, 128, 256], f16, stride=(0, 1, 128))), {})
+cnt: 4, ((T([512, 196, 16], f16), T([512, 16, 196], f16)), {})
+cnt: 4, ((T([512, 196, 196], f16), T([512, 196, 32], f16)), {})
+cnt: 1, ((T([128, 196, 128], f16, stride=(25088, 1, 196)), T([128, 128, 640], f16, stride=(0, 1, 128))), {})
+cnt: 1, ((T([1024, 49, 16], f16), T([1024, 16, 196], f16)), {})
+cnt: 1, ((T([1024, 49, 196], f16), T([1024, 196, 64], f16)), {})
+cnt: 4, ((T([1024, 49, 16], f16), T([1024, 16, 49], f16)), {})
+cnt: 4, ((T([1024, 49, 49], f16), T([1024, 49, 32], f16)), {})
+cnt: 1, ((T([2048, 16, 16], f16), T([2048, 16, 49], f16)), {})
+cnt: 1, ((T([2048, 16, 49], f16), T([2048, 49, 64], f16)), {})
+cnt: 4, ((T([1536, 16, 16], f16), T([1536, 16, 16], f16)), {})
+cnt: 4, ((T([1536, 16, 16], f16), T([1536, 16, 32], f16)), {})
+cnt: 4, ((T([1536, 16, 16], f16, stride=(256, 1, 16)), T([1536, 16, 32], f16)), {})
+cnt: 4, ((T([1536, 16, 32], f16), T([1536, 32, 16], f16, stride=(512, 1, 32))), {})
+cnt: 4, ((T([1536, 16, 16], f16, stride=(256, 1, 16)), T([1536, 16, 16], f16)), {})
+cnt: 4, ((T([1536, 16, 16], f16), T([1536, 16, 16], f16, stride=(256, 1, 16))), {})
+cnt: 1, ((T([2048, 49, 16], f16, stride=(784, 1, 49)), T([2048, 16, 64], f16)), {})
+cnt: 1, ((T([2048, 16, 64], f16), T([2048, 64, 49], f16, stride=(3136, 1, 64))), {})
+cnt: 1, ((T([2048, 16, 16], f16, stride=(256, 1, 16)), T([2048, 16, 49], f16)), {})
+cnt: 1, ((T([2048, 16, 49], f16), T([2048, 49, 16], f16, stride=(784, 1, 49))), {})
+cnt: 4, ((T([1024, 49, 49], f16, stride=(2401, 1, 49)), T([1024, 49, 32], f16)), {})
+cnt: 4, ((T([1024, 49, 32], f16), T([1024, 32, 49], f16, stride=(1568, 1, 32))), {})
+cnt: 4, ((T([1024, 16, 49], f16, stride=(784, 1, 16)), T([1024, 49, 49], f16)), {})
+cnt: 4, ((T([1024, 49, 49], f16), T([1024, 49, 16], f16, stride=(784, 1, 49))), {})
+cnt: 1, ((T([1024, 196, 49], f16, stride=(9604, 1, 196)), T([1024, 49, 64], f16)), {})
+cnt: 1, ((T([1024, 49, 64], f16), T([1024, 64, 196], f16, stride=(12544, 1, 64))), {})
+cnt: 1, ((T([1024, 16, 49], f16, stride=(784, 1, 16)), T([1024, 49, 196], f16)), {})
+cnt: 1, ((T([1024, 49, 196], f16), T([1024, 196, 16], f16, stride=(3136, 1, 196))), {})
+cnt: 1, ((T([128, 128, 196], f16), T([128, 196, 640], f16)), {})
+cnt: 1, ((T([128, 196, 640], f16), T([128, 640, 128], f16, stride=(0, 128, 1))), {})
+cnt: 8, ((T([128, 128, 196], f16), T([128, 196, 256], f16)), {})
+cnt: 8, ((T([128, 196, 256], f16), T([128, 256, 128], f16, stride=(0, 128, 1))), {})
+cnt: 4, ((T([512, 196, 196], f16, stride=(38416, 1, 196)), T([512, 196, 32], f16)), {})
+cnt: 4, ((T([512, 196, 32], f16), T([512, 32, 196], f16, stride=(6272, 1, 32))), {})
+cnt: 4, ((T([512, 16, 196], f16, stride=(3136, 1, 16)), T([512, 196, 196], f16)), {})
+cnt: 4, ((T([512, 196, 196], f16), T([512, 196, 16], f16, stride=(3136, 1, 196))), {})
+Operator: aten.cat.default
+cnt: 4, (([T([128, 16, 12, 16], f16, stride=(3072, 16, 256, 1)), T([128, 16, 12, 16], f16, stride=(3072, 1, 256, 16)), T([128, 16, 12, 32], f16, stride=(6144, 32, 512, 1))], 3), {})
+cnt: 1, (([T([128, 49, 16, 16], f16, stride=(12544, 1, 784, 49)), T([128, 49, 16, 64], f16, stride=(50176, 64, 3136, 1))], 3), {})
+cnt: 4, (([T([128, 49, 8, 16], f16, stride=(6272, 16, 784, 1)), T([128, 49, 8, 16], f16, stride=(6272, 1, 784, 49)), T([128, 49, 8, 32], f16, stride=(12544, 32, 1568, 1))], 3), {})
+cnt: 1, (([T([128, 196, 8, 16], f16, stride=(25088, 1, 3136, 196)), T([128, 196, 8, 64], f16, stride=(100352, 64, 12544, 1))], 3), {})
+cnt: 4, (([T([128, 196, 4, 16], f16, stride=(12544, 16, 3136, 1)), T([128, 196, 4, 16], f16, stride=(12544, 1, 3136, 196)), T([128, 196, 4, 32], f16, stride=(25088, 32, 6272, 1))], 3), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([16, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([32, 16, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 56, 56], f16), T([64, 32, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 28, 28], f16), T([128, 64, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 128, 14, 14], f16, stride=(25088, 1, 1792, 128)), T([128, 64, 28, 28], f16), T([128, 64, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 28, 28], f16), T([128, 32, 56, 56], f16), T([64, 32, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 56, 56], f16), T([128, 16, 112, 112], f16), T([32, 16, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 3, 224, 224], f16), T([16, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+cnt: 1, ((T([640, 128], f16), T([640, 128], f16, stride=(1, 640))), {})
+cnt: 8, ((T([256, 128], f16), T([256, 128], f16, stride=(1, 256))), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 16, 384], f16, stride=(384, 0, 1)), 16), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([128, 1000], f16), 2), {})
+Operator: aten.hardswish.default
+cnt: 1, ((T([128, 16, 112, 112], f16),), {})
+cnt: 1, ((T([128, 32, 56, 56], f16),), {})
+cnt: 1, ((T([128, 64, 28, 28], f16),), {})
+cnt: 4, ((T([128, 196, 128], f16),), {})
+cnt: 4, ((T([128, 196, 256], f16),), {})
+cnt: 6, ((T([128, 49, 512], f16),), {})
+cnt: 4, ((T([128, 49, 256], f16),), {})
+cnt: 1, ((T([128, 16, 1024], f16),), {})
+cnt: 5, ((T([128, 16, 768], f16),), {})
+cnt: 4, ((T([128, 16, 384], f16),), {})
+Operator: aten.hardswish_backward.default
+cnt: 5, ((T([128, 16, 768], f16), T([128, 16, 768], f16)), {})
+cnt: 4, ((T([128, 16, 384], f16), T([128, 16, 384], f16)), {})
+cnt: 1, ((T([128, 16, 1024], f16), T([128, 16, 1024], f16)), {})
+cnt: 6, ((T([128, 49, 512], f16), T([128, 49, 512], f16)), {})
+cnt: 4, ((T([128, 49, 256], f16), T([128, 49, 256], f16)), {})
+cnt: 4, ((T([128, 196, 256], f16), T([128, 196, 256], f16)), {})
+cnt: 4, ((T([128, 196, 128], f16), T([128, 196, 128], f16)), {})
+cnt: 1, ((T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16)), {})
+cnt: 1, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16)), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16)), {})
+Operator: aten.index.Tensor
+cnt: 4, ((T([4, 196], f16), [None, T([196, 196], i64)]), {})
+cnt: 1, ((T([8, 196], f16), [None, T([49, 196], i64)]), {})
+cnt: 4, ((T([8, 49], f16), [None, T([49, 49], i64)]), {})
+cnt: 1, ((T([16, 49], f16), [None, T([16, 49], i64)]), {})
+cnt: 4, ((T([12, 16], f16), [None, T([16, 16], i64)]), {})
+Operator: aten.index_put.default
+cnt: 4, ((T([12, 16], f16), [None, T([16, 16], i64)], T([12, 16, 16], f16), True), {})
+cnt: 1, ((T([16, 49], f16), [None, T([16, 49], i64)], T([16, 16, 49], f16), True), {})
+cnt: 4, ((T([8, 49], f16), [None, T([49, 49], i64)], T([8, 49, 49], f16), True), {})
+cnt: 1, ((T([8, 196], f16), [None, T([49, 196], i64)], T([8, 49, 196], f16), True), {})
+cnt: 4, ((T([4, 196], f16), [None, T([196, 196], i64)], T([4, 196, 196], f16), True), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 16, 384], f16), [1]), {})
+Operator: aten.mm.default
+cnt: 4, ((T([25088, 128], f16), T([128, 128], f16, stride=(1, 128))), {})
+cnt: 4, ((T([25088, 256], f16), T([256, 128], f16, stride=(1, 256))), {})
+cnt: 1, ((T([6272, 128], f16), T([128, 128], f16, stride=(1, 128))), {})
+cnt: 6, ((T([6272, 512], f16), T([512, 256], f16, stride=(1, 512))), {})
+cnt: 9, ((T([6272, 256], f16), T([256, 512], f16, stride=(1, 256))), {})
+cnt: 4, ((T([6272, 256], f16), T([256, 256], f16, stride=(1, 256))), {})
+cnt: 1, ((T([6272, 256], f16), T([256, 1280], f16, stride=(1, 256))), {})
+cnt: 1, ((T([2048, 256], f16), T([256, 256], f16, stride=(1, 256))), {})
+cnt: 1, ((T([2048, 1024], f16), T([1024, 384], f16, stride=(1, 1024))), {})
+cnt: 9, ((T([2048, 384], f16), T([384, 768], f16, stride=(1, 384))), {})
+cnt: 5, ((T([2048, 768], f16), T([768, 384], f16, stride=(1, 768))), {})
+cnt: 4, ((T([2048, 384], f16), T([384, 384], f16, stride=(1, 384))), {})
+cnt: 2, ((T([128, 1000], f16), T([1000, 384], f16)), {})
+cnt: 2, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 384], f16)), {})
+cnt: 5, ((T([384, 2048], f16, stride=(1, 384)), T([2048, 768], f16)), {})
+cnt: 5, ((T([2048, 384], f16), T([384, 768], f16)), {})
+cnt: 9, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 384], f16)), {})
+cnt: 9, ((T([2048, 768], f16), T([768, 384], f16)), {})
+cnt: 4, ((T([384, 2048], f16, stride=(1, 384)), T([2048, 384], f16)), {})
+cnt: 4, ((T([2048, 384], f16), T([384, 384], f16)), {})
+cnt: 1, ((T([384, 2048], f16, stride=(1, 384)), T([2048, 1024], f16)), {})
+cnt: 1, ((T([2048, 384], f16), T([384, 1024], f16)), {})
+cnt: 1, ((T([256, 2048], f16, stride=(1, 256)), T([2048, 256], f16)), {})
+cnt: 1, ((T([2048, 256], f16), T([256, 256], f16)), {})
+cnt: 1, ((T([1280, 6272], f16, stride=(1, 1280)), T([6272, 256], f16)), {})
+cnt: 1, ((T([6272, 1280], f16), T([1280, 256], f16)), {})
+cnt: 6, ((T([256, 6272], f16, stride=(1, 256)), T([6272, 512], f16)), {})
+cnt: 6, ((T([6272, 256], f16), T([256, 512], f16)), {})
+cnt: 9, ((T([512, 6272], f16, stride=(1, 512)), T([6272, 256], f16)), {})
+cnt: 9, ((T([6272, 512], f16), T([512, 256], f16)), {})
+cnt: 4, ((T([256, 6272], f16, stride=(1, 256)), T([6272, 256], f16)), {})
+cnt: 4, ((T([6272, 256], f16), T([256, 256], f16)), {})
+cnt: 1, ((T([128, 6272], f16, stride=(1, 128)), T([6272, 128], f16)), {})
+cnt: 1, ((T([6272, 128], f16), T([128, 128], f16)), {})
+cnt: 4, ((T([128, 25088], f16, stride=(1, 128)), T([25088, 256], f16)), {})
+cnt: 4, ((T([25088, 128], f16), T([128, 256], f16)), {})
+cnt: 4, ((T([128, 25088], f16, stride=(1, 128)), T([25088, 128], f16)), {})
+cnt: 4, ((T([25088, 128], f16), T([128, 128], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 8, ((T([128, 4, 196, 196], f16), 0.25), {})
+cnt: 2, ((T([128, 8, 49, 196], f16), 0.25), {})
+cnt: 8, ((T([128, 8, 49, 49], f16), 0.25), {})
+cnt: 2, ((T([128, 16, 16, 49], f16), 0.25), {})
+cnt: 8, ((T([128, 12, 16, 16], f16), 0.25), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 32, 56, 56], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 64, 28, 28], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 128, 14, 14], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 8, ((T([25088, 256], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 8, ((T([25088, 128], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([25088, 640], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([6272, 128], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 10, ((T([6272, 256], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 9, ((T([6272, 512], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([6272, 1280], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([2048, 256], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 10, ((T([2048, 384], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 1e-05), {})
+cnt: 9, ((T([2048, 768], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 384], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 2, ((T([128, 384], f16), T([128, 384], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 1e-05, [True, True, True]), {})
+cnt: 10, ((T([2048, 384], f16), T([2048, 384], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 1e-05, [True, True, True]), {})
+cnt: 9, ((T([2048, 768], f16), T([2048, 768], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f32), T([768], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([2048, 256], f16), T([2048, 256], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([6272, 1280], f16), T([6272, 1280], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f32), T([1280], f32), True, 1e-05, [True, True, True]), {})
+cnt: 10, ((T([6272, 256], f16), T([6272, 256], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 9, ((T([6272, 512], f16), T([6272, 512], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([6272, 128], f16), T([6272, 128], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([25088, 640], f16), T([25088, 640], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f32), T([640], f32), True, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([25088, 128], f16), T([25088, 128], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([25088, 256], f16), T([25088, 256], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 14, 14], f16, stride=(25088, 1, 1792, 128)), T([128, 128, 14, 14], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.new_empty_strided.default
+cnt: 1, ((T([640, 128], f16, stride=(1, 640)), [640, 128], [128, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 8, ((T([256, 128], f16, stride=(1, 256)), [256, 128], [128, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.new_zeros.default
+cnt: 4, ((T([12, 16, 16], f16), [12, 16]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([16, 16, 49], f16), [16, 49]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 4, ((T([8, 49, 49], f16), [8, 49]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([8, 49, 196], f16), [8, 196]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 4, ((T([4, 196, 196], f16), [4, 196]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.slice_backward.default
+cnt: 4, ((T([12, 16], f16), [12, 16], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([16, 49], f16), [16, 49], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 4, 4, 256], f16), [128, 4, 7, 256], 2, 0, 9223372036854775807, 2), {})
+cnt: 1, ((T([128, 4, 7, 256], f16), [128, 7, 7, 256], 1, 0, 9223372036854775807, 2), {})
+cnt: 1, ((T([128, 7, 7, 256], f16), [128, 7, 7, 256], 0, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([8, 49], f16), [8, 49], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([8, 196], f16), [8, 196], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 7, 7, 128], f16), [128, 7, 14, 128], 2, 0, 9223372036854775807, 2), {})
+cnt: 1, ((T([128, 7, 14, 128], f16), [128, 14, 14, 128], 1, 0, 9223372036854775807, 2), {})
+cnt: 1, ((T([128, 14, 14, 128], f16), [128, 14, 14, 128], 0, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([4, 196], f16), [4, 196], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.split_with_sizes.default
+cnt: 4, ((T([128, 196, 4, 64], f16), [16, 16, 32], 3), {})
+cnt: 1, ((T([128, 196, 8, 80], f16), [16, 64], 3), {})
+cnt: 4, ((T([128, 49, 8, 64], f16), [16, 16, 32], 3), {})
+cnt: 1, ((T([128, 49, 16, 80], f16), [16, 64], 3), {})
+cnt: 4, ((T([128, 16, 12, 64], f16), [16, 16, 32], 3), {})
+Operator: aten.sum.SymInt
+cnt: 2, ((T([128, 1000], f16), [0], True), {})
+cnt: 4, ((T([128, 12, 16, 16], f16), [0], True), {})
+cnt: 1, ((T([128, 16, 16, 49], f16), [0], True), {})
+cnt: 4, ((T([128, 8, 49, 49], f16), [0], True), {})
+cnt: 1, ((T([128, 8, 49, 196], f16), [0], True), {})
+cnt: 1, ((T([128, 128, 640], f16), [0], True), {})
+cnt: 8, ((T([128, 128, 256], f16), [0], True), {})
+cnt: 4, ((T([128, 4, 196, 196], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mixer_b16_224_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mixer_b16_224_training.txt
new file mode 100644
index 0000000000000..483b2dad380ba
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mixer_b16_224_training.txt
@@ -0,0 +1,70 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 12, ((T([64, 768, 384], f16), [64, 768, 384]), {})
+cnt: 12, ((T([64, 768, 196], f16), [49152, 196]), {})
+Operator: aten.add.Tensor
+cnt: 12, ((T([64, 768, 384], f16), T([384], f16)), {})
+cnt: 12, ((T([64, 196, 768], f16, stride=(150528, 1, 196)), T([64, 196, 768], f16, stride=(150528, 1, 196))), {})
+cnt: 12, ((T([64, 196, 768], f16, stride=(150528, 1, 196)), T([64, 196, 768], f16)), {})
+cnt: 12, ((T([64, 196, 768], f16), T([64, 196, 768], f16)), {})
+cnt: 12, ((T([64, 196, 768], f16), T([64, 196, 768], f16, stride=(150528, 1, 196))), {})
+Operator: aten.addmm.default
+cnt: 12, ((T([196], f16), T([49152, 384], f16), T([384, 196], f16, stride=(1, 384))), {})
+cnt: 12, ((T([3072], f16), T([12544, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([12544, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([1000], f16), T([64, 768], f16), T([768, 1000], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([64, 768, 196], f16, stride=(150528, 1, 768)), T([64, 196, 384], f16, stride=(0, 1, 196))), {})
+cnt: 12, ((T([64, 196, 768], f16), T([64, 768, 384], f16)), {})
+cnt: 12, ((T([64, 768, 384], f16), T([64, 384, 196], f16, stride=(0, 196, 1))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([768, 3, 16, 16], f16), T([768], f16), [16, 16], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 768, 14, 14], f16, stride=(150528, 1, 10752, 768)), T([64, 3, 224, 224], f16), T([768, 3, 16, 16], f16), [768], [16, 16], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+cnt: 12, ((T([384, 196], f16), T([384, 196], f16, stride=(1, 384))), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([64, 196, 768], f16, stride=(768, 0, 1)), 196), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([64, 768, 384], f16),), {})
+cnt: 12, ((T([64, 196, 3072], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([64, 196, 3072], f16), T([64, 196, 3072], f16)), {})
+cnt: 12, ((T([64, 768, 384], f16), T([64, 768, 384], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([64, 196, 768], f16), [1]), {})
+Operator: aten.mm.default
+cnt: 1, ((T([64, 1000], f16), T([1000, 768], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 768], f16)), {})
+cnt: 12, ((T([12544, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 12544], f16, stride=(1, 768)), T([12544, 3072], f16)), {})
+cnt: 12, ((T([12544, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 12544], f16, stride=(1, 3072)), T([12544, 768], f16)), {})
+cnt: 12, ((T([49152, 196], f16), T([196, 384], f16)), {})
+cnt: 12, ((T([196, 49152], f16, stride=(1, 196)), T([49152, 384], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 25, ((T([64, 196, 768], f16, stride=(150528, 1, 196)), [768], T([768], f16), T([768], f16), 1e-06), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 13, ((T([64, 196, 768], f16), T([64, 196, 768], f16, stride=(150528, 1, 196)), [768], T([64, 196, 1], f32), T([64, 196, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+cnt: 12, ((T([64, 196, 768], f16, stride=(150528, 1, 196)), T([64, 196, 768], f16, stride=(150528, 1, 196)), [768], T([64, 196, 1], f32), T([64, 196, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.new_empty_strided.default
+cnt: 12, ((T([384, 196], f16, stride=(1, 384)), [384, 196], [196, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 12, ((T([12544, 768], f16), [0], True), {})
+cnt: 12, ((T([12544, 3072], f16), [0], True), {})
+cnt: 12, ((T([49152, 196], f16), [0], True), {})
+cnt: 12, ((T([64, 768, 384], f16), [0, 1], True), {})
+cnt: 12, ((T([64, 196, 384], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mixnet_l_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mixnet_l_training.txt
new file mode 100644
index 0000000000000..74b315457b93c
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mixnet_l_training.txt
@@ -0,0 +1,378 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 58, ((T([], i64), 1), {})
+cnt: 2, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16)), {})
+cnt: 2, ((T([64, 40, 56, 56], f16), T([64, 40, 56, 56], f16)), {})
+cnt: 6, ((T([64, 56, 28, 28], f16), T([64, 56, 28, 28], f16)), {})
+cnt: 6, ((T([64, 104, 14, 14], f16), T([64, 104, 14, 14], f16)), {})
+cnt: 6, ((T([64, 160, 14, 14], f16), T([64, 160, 14, 14], f16)), {})
+cnt: 6, ((T([64, 264, 7, 7], f16), T([64, 264, 7, 7], f16)), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16), T([64, 1584, 7, 7], f16)), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([64, 960, 7, 7], f16)), {})
+cnt: 3, ((T([64, 480, 14, 14], f16), T([64, 480, 14, 14], f16)), {})
+cnt: 4, ((T([64, 624, 14, 14], f16), T([64, 624, 14, 14], f16)), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), T([64, 336, 14, 14], f16)), {})
+cnt: 3, ((T([64, 336, 28, 28], f16), T([64, 336, 28, 28], f16)), {})
+cnt: 1, ((T([64, 240, 28, 28], f16), T([64, 240, 28, 28], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([64, 1536], f16), T([1536, 1000], f16, stride=(1, 1536))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 96, 112, 112], f16), T([64, 96, 112, 112], f16)], 1), {})
+cnt: 1, (([T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16)], 1), {})
+cnt: 3, (([T([64, 20, 56, 56], f16), T([64, 20, 56, 56], f16)], 1), {})
+cnt: 2, (([T([64, 60, 56, 56], f16), T([64, 60, 56, 56], f16)], 1), {})
+cnt: 1, (([T([64, 60, 28, 28], f16), T([64, 60, 28, 28], f16), T([64, 60, 28, 28], f16), T([64, 60, 28, 28], f16)], 1), {})
+cnt: 12, (([T([64, 168, 28, 28], f16), T([64, 168, 28, 28], f16)], 1), {})
+cnt: 6, (([T([64, 28, 28, 28], f16), T([64, 28, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 112, 14, 14], f16), T([64, 112, 14, 14], f16), T([64, 112, 14, 14], f16)], 1), {})
+cnt: 6, (([T([64, 312, 14, 14], f16), T([64, 312, 14, 14], f16)], 1), {})
+cnt: 6, (([T([64, 156, 14, 14], f16), T([64, 156, 14, 14], f16), T([64, 156, 14, 14], f16), T([64, 156, 14, 14], f16)], 1), {})
+cnt: 6, (([T([64, 52, 14, 14], f16), T([64, 52, 14, 14], f16)], 1), {})
+cnt: 6, (([T([64, 240, 14, 14], f16), T([64, 240, 14, 14], f16)], 1), {})
+cnt: 6, (([T([64, 120, 14, 14], f16), T([64, 120, 14, 14], f16), T([64, 120, 14, 14], f16), T([64, 120, 14, 14], f16)], 1), {})
+cnt: 6, (([T([64, 80, 14, 14], f16), T([64, 80, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 240, 7, 7], f16), T([64, 240, 7, 7], f16), T([64, 240, 7, 7], f16), T([64, 240, 7, 7], f16)], 1), {})
+cnt: 6, (([T([64, 396, 7, 7], f16), T([64, 396, 7, 7], f16), T([64, 396, 7, 7], f16), T([64, 396, 7, 7], f16)], 1), {})
+cnt: 3, (([T([64, 132, 7, 7], f16), T([64, 132, 7, 7], f16)], 1), {})
+cnt: 3, (([T([64, 792, 7, 7], f16), T([64, 792, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 240, 14, 14], f16), T([64, 240, 14, 14], f16), T([64, 240, 14, 14], f16), T([64, 240, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 112, 28, 28], f16), T([64, 112, 28, 28], f16), T([64, 112, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 60, 56, 56], f16), T([64, 60, 56, 56], f16), T([64, 60, 56, 56], f16), T([64, 60, 56, 56], f16)], 1), {})
+cnt: 1, (([T([64, 96, 56, 56], f16), T([64, 96, 56, 56], f16)], 1), {})
+cnt: 1, (([T([64, 64, 112, 112], f16), T([64, 64, 112, 112], f16), T([64, 64, 112, 112], f16)], 1), {})
+cnt: 1, (([T([64, 16, 112, 112], f16), T([64, 16, 112, 112], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+cnt: 1, ((T([64, 240, 56, 56], f16),), {})
+cnt: 1, ((T([64, 240, 28, 28], f16),), {})
+cnt: 1, ((T([64, 20, 1, 1], f16),), {})
+cnt: 7, ((T([64, 336, 28, 28], f16),), {})
+cnt: 3, ((T([64, 28, 1, 1], f16),), {})
+cnt: 1, ((T([64, 336, 14, 14], f16),), {})
+cnt: 1, ((T([64, 14, 1, 1], f16),), {})
+cnt: 8, ((T([64, 624, 14, 14], f16),), {})
+cnt: 3, ((T([64, 26, 1, 1], f16),), {})
+cnt: 1, ((T([64, 52, 1, 1], f16),), {})
+cnt: 6, ((T([64, 480, 14, 14], f16),), {})
+cnt: 4, ((T([64, 80, 1, 1], f16),), {})
+cnt: 1, ((T([64, 960, 14, 14], f16),), {})
+cnt: 1, ((T([64, 960, 7, 7], f16),), {})
+cnt: 6, ((T([64, 1584, 7, 7], f16),), {})
+cnt: 3, ((T([64, 132, 1, 1], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([32, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([32, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 16, 112, 112], f16, stride=(401408, 12544, 112, 1)), T([96, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 112, 112], f16, stride=(2408448, 12544, 112, 1)), T([64, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([64, 64, 112, 112], f16, stride=(2408448, 12544, 112, 1)), T([64, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([64, 64, 112, 112], f16, stride=(2408448, 12544, 112, 1)), T([64, 1, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 64), {})
+cnt: 2, ((T([64, 96, 56, 56], f16, stride=(602112, 3136, 56, 1)), T([20, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 20, 56, 56], f16, stride=(125440, 3136, 56, 1)), T([60, 20, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 120, 56, 56], f16), T([120, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 120), {})
+cnt: 2, ((T([64, 60, 56, 56], f16, stride=(376320, 3136, 56, 1)), T([20, 60, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 40, 56, 56], f16), T([240, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 60, 56, 56], f16, stride=(752640, 3136, 56, 1)), T([60, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 60), {})
+cnt: 1, ((T([64, 60, 56, 56], f16, stride=(752640, 3136, 56, 1)), T([60, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 60), {})
+cnt: 1, ((T([64, 60, 56, 56], f16, stride=(752640, 3136, 56, 1)), T([60, 1, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 60), {})
+cnt: 1, ((T([64, 60, 56, 56], f16, stride=(752640, 3136, 56, 1)), T([60, 1, 9, 9], f16), None, [2, 2], [4, 4], [1, 1], False, [0, 0], 60), {})
+cnt: 1, ((T([64, 240, 1, 1], f16), T([20, 240, 1, 1], f16), T([20], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 20, 1, 1], f16), T([240, 20, 1, 1], f16), T([240], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 240, 28, 28], f16), T([56, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 28, 28, 28], f16, stride=(43904, 784, 28, 1)), T([168, 28, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([168, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 168), {})
+cnt: 3, ((T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([168, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 168), {})
+cnt: 3, ((T([64, 336, 1, 1], f16), T([28, 336, 1, 1], f16), T([28], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 28, 1, 1], f16), T([336, 28, 1, 1], f16), T([336], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([28, 168, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 56, 28, 28], f16), T([336, 56, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 112, 28, 28], f16, stride=(263424, 784, 28, 1)), T([112, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 112), {})
+cnt: 1, ((T([64, 112, 28, 28], f16, stride=(263424, 784, 28, 1)), T([112, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 112), {})
+cnt: 1, ((T([64, 112, 28, 28], f16, stride=(263424, 784, 28, 1)), T([112, 1, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 112), {})
+cnt: 1, ((T([64, 336, 1, 1], f16), T([14, 336, 1, 1], f16), T([14], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 14, 1, 1], f16), T([336, 14, 1, 1], f16), T([336], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), T([104, 336, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 52, 14, 14], f16, stride=(20384, 196, 14, 1)), T([312, 52, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 156), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 156), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 156), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 9, 9], f16), None, [1, 1], [4, 4], [1, 1], False, [0, 0], 156), {})
+cnt: 3, ((T([64, 624, 1, 1], f16), T([26, 624, 1, 1], f16), T([26], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 26, 1, 1], f16), T([624, 26, 1, 1], f16), T([624], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 312, 14, 14], f16, stride=(122304, 196, 14, 1)), T([52, 312, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 104, 14, 14], f16), T([624, 104, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 624, 14, 14], f16), T([624, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 624), {})
+cnt: 1, ((T([64, 624, 1, 1], f16), T([52, 624, 1, 1], f16), T([52], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 52, 1, 1], f16), T([624, 52, 1, 1], f16), T([624], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 624, 14, 14], f16), T([160, 624, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 80, 14, 14], f16, stride=(31360, 196, 14, 1)), T([240, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 120), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 120), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 120), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 9, 9], f16), None, [1, 1], [4, 4], [1, 1], False, [0, 0], 120), {})
+cnt: 3, ((T([64, 480, 1, 1], f16), T([80, 480, 1, 1], f16), T([80], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 80, 1, 1], f16), T([480, 80, 1, 1], f16), T([480], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 240, 14, 14], f16, stride=(94080, 196, 14, 1)), T([80, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 160, 14, 14], f16), T([960, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 240, 14, 14], f16, stride=(188160, 196, 14, 1)), T([240, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([64, 240, 14, 14], f16, stride=(188160, 196, 14, 1)), T([240, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([64, 240, 14, 14], f16, stride=(188160, 196, 14, 1)), T([240, 1, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([64, 240, 14, 14], f16, stride=(188160, 196, 14, 1)), T([240, 1, 9, 9], f16), None, [2, 2], [4, 4], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([64, 960, 1, 1], f16), T([80, 960, 1, 1], f16), T([80], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 80, 1, 1], f16), T([960, 80, 1, 1], f16), T([960], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([264, 960, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 264, 7, 7], f16), T([1584, 264, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 396), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 396), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 396), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 9, 9], f16), None, [1, 1], [4, 4], [1, 1], False, [0, 0], 396), {})
+cnt: 3, ((T([64, 1584, 1, 1], f16), T([132, 1584, 1, 1], f16), T([132], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 132, 1, 1], f16), T([1584, 132, 1, 1], f16), T([1584], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 792, 7, 7], f16, stride=(77616, 49, 7, 1)), T([132, 792, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 264, 7, 7], f16), T([1536, 264, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 1536, 7, 7], f16), T([64, 264, 7, 7], f16), T([1536, 264, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([64, 132, 7, 7], f16, stride=(12936, 49, 7, 1)), T([64, 792, 7, 7], f16, stride=(77616, 49, 7, 1)), T([132, 792, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 1584, 1, 1], f16), T([64, 132, 1, 1], f16), T([1584, 132, 1, 1], f16), [1584], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 132, 1, 1], f16), T([64, 1584, 1, 1], f16), T([132, 1584, 1, 1], f16), [132], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 9, 9], f16), [0], [1, 1], [4, 4], [1, 1], False, [0, 0], 396, [True, True, False]), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 396, [True, True, False]), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 396, [True, True, False]), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 396, [True, True, False]), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16), T([64, 264, 7, 7], f16), T([1584, 264, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 264, 7, 7], f16), T([64, 960, 7, 7], f16), T([264, 960, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 960, 1, 1], f16), T([64, 80, 1, 1], f16), T([960, 80, 1, 1], f16), [960], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 80, 1, 1], f16), T([64, 960, 1, 1], f16), T([80, 960, 1, 1], f16), [80], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 240, 7, 7], f16, stride=(47040, 49, 7, 1)), T([64, 240, 14, 14], f16, stride=(188160, 196, 14, 1)), T([240, 1, 9, 9], f16), [0], [2, 2], [4, 4], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([64, 240, 7, 7], f16, stride=(47040, 49, 7, 1)), T([64, 240, 14, 14], f16, stride=(188160, 196, 14, 1)), T([240, 1, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([64, 240, 7, 7], f16, stride=(47040, 49, 7, 1)), T([64, 240, 14, 14], f16, stride=(188160, 196, 14, 1)), T([240, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([64, 240, 7, 7], f16, stride=(47040, 49, 7, 1)), T([64, 240, 14, 14], f16, stride=(188160, 196, 14, 1)), T([240, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([64, 960, 14, 14], f16), T([64, 160, 14, 14], f16), T([960, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([64, 80, 14, 14], f16, stride=(31360, 196, 14, 1)), T([64, 240, 14, 14], f16, stride=(94080, 196, 14, 1)), T([80, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 480, 1, 1], f16), T([64, 80, 1, 1], f16), T([480, 80, 1, 1], f16), [480], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 80, 1, 1], f16), T([64, 480, 1, 1], f16), T([80, 480, 1, 1], f16), [80], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 9, 9], f16), [0], [1, 1], [4, 4], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 6, ((T([64, 240, 14, 14], f16, stride=(94080, 196, 14, 1)), T([64, 80, 14, 14], f16, stride=(31360, 196, 14, 1)), T([240, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 160, 14, 14], f16), T([64, 624, 14, 14], f16), T([160, 624, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 624, 1, 1], f16), T([64, 52, 1, 1], f16), T([624, 52, 1, 1], f16), [624], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 52, 1, 1], f16), T([64, 624, 1, 1], f16), T([52, 624, 1, 1], f16), [52], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 624, 14, 14], f16), T([64, 624, 14, 14], f16), T([624, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 624, [True, True, False]), {})
+cnt: 1, ((T([64, 624, 14, 14], f16), T([64, 104, 14, 14], f16), T([624, 104, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([64, 52, 14, 14], f16, stride=(20384, 196, 14, 1)), T([64, 312, 14, 14], f16, stride=(122304, 196, 14, 1)), T([52, 312, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 624, 1, 1], f16), T([64, 26, 1, 1], f16), T([624, 26, 1, 1], f16), [624], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 26, 1, 1], f16), T([64, 624, 1, 1], f16), T([26, 624, 1, 1], f16), [26], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 9, 9], f16), [0], [1, 1], [4, 4], [1, 1], False, [0, 0], 156, [True, True, False]), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 156, [True, True, False]), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 156, [True, True, False]), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 156, [True, True, False]), {})
+cnt: 6, ((T([64, 312, 14, 14], f16, stride=(122304, 196, 14, 1)), T([64, 52, 14, 14], f16, stride=(20384, 196, 14, 1)), T([312, 52, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 104, 14, 14], f16), T([64, 336, 14, 14], f16), T([104, 336, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 336, 1, 1], f16), T([64, 14, 1, 1], f16), T([336, 14, 1, 1], f16), [336], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 14, 1, 1], f16), T([64, 336, 1, 1], f16), T([14, 336, 1, 1], f16), [14], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 112, 14, 14], f16, stride=(65856, 196, 14, 1)), T([64, 112, 28, 28], f16, stride=(263424, 784, 28, 1)), T([112, 1, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 112, [True, True, False]), {})
+cnt: 1, ((T([64, 112, 14, 14], f16, stride=(65856, 196, 14, 1)), T([64, 112, 28, 28], f16, stride=(263424, 784, 28, 1)), T([112, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 112, [True, True, False]), {})
+cnt: 1, ((T([64, 112, 14, 14], f16, stride=(65856, 196, 14, 1)), T([64, 112, 28, 28], f16, stride=(263424, 784, 28, 1)), T([112, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 112, [True, True, False]), {})
+cnt: 1, ((T([64, 336, 28, 28], f16), T([64, 56, 28, 28], f16), T([336, 56, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([64, 28, 28, 28], f16, stride=(43904, 784, 28, 1)), T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([28, 168, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 336, 1, 1], f16), T([64, 28, 1, 1], f16), T([336, 28, 1, 1], f16), [336], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 28, 1, 1], f16), T([64, 336, 1, 1], f16), T([28, 336, 1, 1], f16), [28], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([168, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 168, [True, True, False]), {})
+cnt: 3, ((T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([168, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 168, [True, True, False]), {})
+cnt: 6, ((T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([64, 28, 28, 28], f16, stride=(43904, 784, 28, 1)), T([168, 28, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 56, 28, 28], f16), T([64, 240, 28, 28], f16), T([56, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 240, 1, 1], f16), T([64, 20, 1, 1], f16), T([240, 20, 1, 1], f16), [240], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 20, 1, 1], f16), T([64, 240, 1, 1], f16), T([20, 240, 1, 1], f16), [20], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 60, 28, 28], f16, stride=(188160, 784, 28, 1)), T([64, 60, 56, 56], f16, stride=(752640, 3136, 56, 1)), T([60, 1, 9, 9], f16), [0], [2, 2], [4, 4], [1, 1], False, [0, 0], 60, [True, True, False]), {})
+cnt: 1, ((T([64, 60, 28, 28], f16, stride=(188160, 784, 28, 1)), T([64, 60, 56, 56], f16, stride=(752640, 3136, 56, 1)), T([60, 1, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 60, [True, True, False]), {})
+cnt: 1, ((T([64, 60, 28, 28], f16, stride=(188160, 784, 28, 1)), T([64, 60, 56, 56], f16, stride=(752640, 3136, 56, 1)), T([60, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 60, [True, True, False]), {})
+cnt: 1, ((T([64, 60, 28, 28], f16, stride=(188160, 784, 28, 1)), T([64, 60, 56, 56], f16, stride=(752640, 3136, 56, 1)), T([60, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 60, [True, True, False]), {})
+cnt: 1, ((T([64, 240, 56, 56], f16), T([64, 40, 56, 56], f16), T([240, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 20, 56, 56], f16, stride=(125440, 3136, 56, 1)), T([64, 60, 56, 56], f16, stride=(376320, 3136, 56, 1)), T([20, 60, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 120, 56, 56], f16), T([64, 120, 56, 56], f16), T([120, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 2, ((T([64, 60, 56, 56], f16, stride=(376320, 3136, 56, 1)), T([64, 20, 56, 56], f16, stride=(125440, 3136, 56, 1)), T([60, 20, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 20, 56, 56], f16, stride=(125440, 3136, 56, 1)), T([64, 96, 56, 56], f16, stride=(602112, 3136, 56, 1)), T([20, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 56, 56], f16, stride=(602112, 3136, 56, 1)), T([64, 64, 112, 112], f16, stride=(2408448, 12544, 112, 1)), T([64, 1, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 56, 56], f16, stride=(602112, 3136, 56, 1)), T([64, 64, 112, 112], f16, stride=(2408448, 12544, 112, 1)), T([64, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 56, 56], f16, stride=(602112, 3136, 56, 1)), T([64, 64, 112, 112], f16, stride=(2408448, 12544, 112, 1)), T([64, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 2, ((T([64, 96, 112, 112], f16, stride=(2408448, 12544, 112, 1)), T([64, 16, 112, 112], f16, stride=(401408, 12544, 112, 1)), T([96, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16), T([32, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16), T([32, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([64, 3, 224, 224], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([64, 1536, 7, 7], f16, stride=(1536, 1, 0, 0)), 49), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16, stride=(1584, 1, 0, 0)), 49), {})
+cnt: 1, ((T([64, 960, 7, 7], f16, stride=(960, 1, 0, 0)), 49), {})
+cnt: 3, ((T([64, 480, 14, 14], f16, stride=(480, 1, 0, 0)), 196), {})
+cnt: 4, ((T([64, 624, 14, 14], f16, stride=(624, 1, 0, 0)), 196), {})
+cnt: 1, ((T([64, 336, 14, 14], f16, stride=(336, 1, 0, 0)), 196), {})
+cnt: 3, ((T([64, 336, 28, 28], f16, stride=(336, 1, 0, 0)), 784), {})
+cnt: 1, ((T([64, 240, 28, 28], f16, stride=(240, 1, 0, 0)), 784), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([64, 240, 28, 28], f16), [2, 3], True), {})
+cnt: 3, ((T([64, 336, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), [2, 3], True), {})
+cnt: 4, ((T([64, 624, 14, 14], f16), [2, 3], True), {})
+cnt: 3, ((T([64, 480, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), [2, 3], True), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([64, 1536, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([64, 1000], f16), T([1000, 1536], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 1536], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([64, 240, 28, 28], f16), T([64, 240, 1, 1], f16)), {})
+cnt: 6, ((T([64, 336, 28, 28], f16), T([64, 336, 1, 1], f16)), {})
+cnt: 2, ((T([64, 336, 14, 14], f16), T([64, 336, 1, 1], f16)), {})
+cnt: 8, ((T([64, 624, 14, 14], f16), T([64, 624, 1, 1], f16)), {})
+cnt: 6, ((T([64, 480, 14, 14], f16), T([64, 480, 1, 1], f16)), {})
+cnt: 2, ((T([64, 960, 7, 7], f16), T([64, 960, 1, 1], f16)), {})
+cnt: 6, ((T([64, 1584, 7, 7], f16), T([64, 1584, 1, 1], f16)), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16), T([64, 1584, 7, 7], f16)), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([64, 960, 7, 7], f16)), {})
+cnt: 3, ((T([64, 480, 14, 14], f16), T([64, 480, 14, 14], f16)), {})
+cnt: 4, ((T([64, 624, 14, 14], f16), T([64, 624, 14, 14], f16)), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), T([64, 336, 14, 14], f16)), {})
+cnt: 3, ((T([64, 336, 28, 28], f16), T([64, 336, 28, 28], f16)), {})
+cnt: 1, ((T([64, 240, 28, 28], f16), T([64, 240, 28, 28], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 3, ((T([64, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 192, 112, 112], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 192, 56, 56], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([64, 40, 56, 56], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([64, 120, 56, 56], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 240, 56, 56], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([64, 56, 28, 28], f16), T([56], f16), T([56], f16), T([56], f16), T([56], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([64, 336, 28, 28], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([64, 104, 14, 14], f16), T([104], f16), T([104], f16), T([104], f16), T([104], f16), True, 0.1, 1e-05), {})
+cnt: 8, ((T([64, 624, 14, 14], f16), T([624], f16), T([624], f16), T([624], f16), T([624], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([64, 160, 14, 14], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), True, 0.1, 1e-05), {})
+cnt: 6, ((T([64, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 960, 14, 14], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([64, 264, 7, 7], f16), T([264], f16), T([264], f16), T([264], f16), T([264], f16), True, 0.1, 1e-05), {})
+cnt: 6, ((T([64, 1584, 7, 7], f16), T([1584], f16), T([1584], f16), T([1584], f16), T([1584], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 1536, 7, 7], f16), T([1536], f16), T([1536], f16), T([1536], f16), T([1536], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([64, 1536, 7, 7], f16), T([64, 1536, 7, 7], f16), T([1536], f16), T([1536], f16), T([1536], f16), T([1536], f32), T([1536], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([64, 264, 7, 7], f16), T([64, 264, 7, 7], f16), T([264], f16), T([264], f16), T([264], f16), T([264], f32), T([264], f32), True, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([64, 1584, 7, 7], f16), T([64, 1584, 7, 7], f16), T([1584], f16), T([1584], f16), T([1584], f16), T([1584], f32), T([1584], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([64, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 960, 14, 14], f16), T([64, 960, 14, 14], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([64, 160, 14, 14], f16), T([64, 160, 14, 14], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), True, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([64, 480, 14, 14], f16), T([64, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), True, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([64, 624, 14, 14], f16), T([64, 624, 14, 14], f16), T([624], f16), T([624], f16), T([624], f16), T([624], f32), T([624], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([64, 104, 14, 14], f16), T([64, 104, 14, 14], f16), T([104], f16), T([104], f16), T([104], f16), T([104], f32), T([104], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), T([64, 336, 14, 14], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f32), T([336], f32), True, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([64, 336, 28, 28], f16), T([64, 336, 28, 28], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f32), T([336], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([64, 56, 28, 28], f16), T([64, 56, 28, 28], f16), T([56], f16), T([56], f16), T([56], f16), T([56], f32), T([56], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 240, 28, 28], f16), T([64, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 240, 56, 56], f16), T([64, 240, 56, 56], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([64, 40, 56, 56], f16), T([64, 40, 56, 56], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([64, 120, 56, 56], f16), T([64, 120, 56, 56], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f32), T([120], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 192, 56, 56], f16), T([64, 192, 56, 56], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 192, 112, 112], f16), T([64, 192, 112, 112], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([64, 32, 112, 112], f16),), {})
+cnt: 1, ((T([64, 192, 112, 112], f16),), {})
+cnt: 1, ((T([64, 192, 56, 56], f16),), {})
+cnt: 2, ((T([64, 120, 56, 56], f16),), {})
+cnt: 1, ((T([64, 1536, 7, 7], f16),), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([64, 240, 1, 1], f16),), {})
+cnt: 4, ((T([64, 336, 1, 1], f16),), {})
+cnt: 4, ((T([64, 624, 1, 1], f16),), {})
+cnt: 3, ((T([64, 480, 1, 1], f16),), {})
+cnt: 1, ((T([64, 960, 1, 1], f16),), {})
+cnt: 3, ((T([64, 1584, 1, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 3, ((T([64, 1584, 1, 1], f16), T([64, 1584, 1, 1], f16)), {})
+cnt: 1, ((T([64, 960, 1, 1], f16), T([64, 960, 1, 1], f16)), {})
+cnt: 3, ((T([64, 480, 1, 1], f16), T([64, 480, 1, 1], f16)), {})
+cnt: 4, ((T([64, 624, 1, 1], f16), T([64, 624, 1, 1], f16)), {})
+cnt: 4, ((T([64, 336, 1, 1], f16), T([64, 336, 1, 1], f16)), {})
+cnt: 1, ((T([64, 240, 1, 1], f16), T([64, 240, 1, 1], f16)), {})
+Operator: aten.silu_.default
+cnt: 1, ((T([64, 240, 56, 56], f16),), {})
+cnt: 1, ((T([64, 240, 28, 28], f16),), {})
+cnt: 1, ((T([64, 20, 1, 1], f16),), {})
+cnt: 7, ((T([64, 336, 28, 28], f16),), {})
+cnt: 3, ((T([64, 28, 1, 1], f16),), {})
+cnt: 1, ((T([64, 336, 14, 14], f16),), {})
+cnt: 1, ((T([64, 14, 1, 1], f16),), {})
+cnt: 8, ((T([64, 624, 14, 14], f16),), {})
+cnt: 3, ((T([64, 26, 1, 1], f16),), {})
+cnt: 1, ((T([64, 52, 1, 1], f16),), {})
+cnt: 6, ((T([64, 480, 14, 14], f16),), {})
+cnt: 4, ((T([64, 80, 1, 1], f16),), {})
+cnt: 1, ((T([64, 960, 14, 14], f16),), {})
+cnt: 1, ((T([64, 960, 7, 7], f16),), {})
+cnt: 6, ((T([64, 1584, 7, 7], f16),), {})
+cnt: 3, ((T([64, 132, 1, 1], f16),), {})
+Operator: aten.silu_backward.default
+cnt: 3, ((T([64, 132, 1, 1], f16), T([64, 132, 1, 1], f16)), {})
+cnt: 6, ((T([64, 1584, 7, 7], f16), T([64, 1584, 7, 7], f16)), {})
+cnt: 4, ((T([64, 80, 1, 1], f16), T([64, 80, 1, 1], f16)), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([64, 960, 7, 7], f16)), {})
+cnt: 1, ((T([64, 960, 14, 14], f16), T([64, 960, 14, 14], f16)), {})
+cnt: 6, ((T([64, 480, 14, 14], f16), T([64, 480, 14, 14], f16)), {})
+cnt: 1, ((T([64, 52, 1, 1], f16), T([64, 52, 1, 1], f16)), {})
+cnt: 8, ((T([64, 624, 14, 14], f16), T([64, 624, 14, 14], f16)), {})
+cnt: 3, ((T([64, 26, 1, 1], f16), T([64, 26, 1, 1], f16)), {})
+cnt: 1, ((T([64, 14, 1, 1], f16), T([64, 14, 1, 1], f16)), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), T([64, 336, 14, 14], f16)), {})
+cnt: 7, ((T([64, 336, 28, 28], f16), T([64, 336, 28, 28], f16)), {})
+cnt: 3, ((T([64, 28, 1, 1], f16), T([64, 28, 1, 1], f16)), {})
+cnt: 1, ((T([64, 20, 1, 1], f16), T([64, 20, 1, 1], f16)), {})
+cnt: 1, ((T([64, 240, 28, 28], f16), T([64, 240, 28, 28], f16)), {})
+cnt: 1, ((T([64, 240, 56, 56], f16), T([64, 240, 56, 56], f16)), {})
+Operator: aten.split_with_sizes.default
+cnt: 1, ((T([64, 32, 112, 112], f16), [16, 16], 1), {})
+cnt: 1, ((T([64, 192, 112, 112], f16), [64, 64, 64], 1), {})
+cnt: 1, ((T([64, 192, 56, 56], f16), [96, 96], 1), {})
+cnt: 1, ((T([64, 40, 56, 56], f16), [20, 20], 1), {})
+cnt: 1, ((T([64, 120, 56, 56], f16), [60, 60], 1), {})
+cnt: 1, ((T([64, 240, 56, 56], f16), [60, 60, 60, 60], 1), {})
+cnt: 3, ((T([64, 56, 28, 28], f16), [28, 28], 1), {})
+cnt: 6, ((T([64, 336, 28, 28], f16), [168, 168], 1), {})
+cnt: 1, ((T([64, 336, 28, 28], f16), [112, 112, 112], 1), {})
+cnt: 3, ((T([64, 104, 14, 14], f16), [52, 52], 1), {})
+cnt: 3, ((T([64, 624, 14, 14], f16), [156, 156, 156, 156], 1), {})
+cnt: 3, ((T([64, 624, 14, 14], f16), [312, 312], 1), {})
+cnt: 3, ((T([64, 160, 14, 14], f16), [80, 80], 1), {})
+cnt: 3, ((T([64, 480, 14, 14], f16), [120, 120, 120, 120], 1), {})
+cnt: 3, ((T([64, 480, 14, 14], f16), [240, 240], 1), {})
+cnt: 1, ((T([64, 960, 14, 14], f16), [240, 240, 240, 240], 1), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16), [396, 396, 396, 396], 1), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16), [792, 792], 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), [2, 3], True), {})
+cnt: 3, ((T([64, 480, 14, 14], f16), [2, 3], True), {})
+cnt: 4, ((T([64, 624, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), [2, 3], True), {})
+cnt: 3, ((T([64, 336, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([64, 240, 28, 28], f16), [2, 3], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([64, 1536, 7, 7], f16), T([64, 1536, 7, 7], f16), 0), {})
+cnt: 2, ((T([64, 120, 56, 56], f16), T([64, 120, 56, 56], f16), 0), {})
+cnt: 1, ((T([64, 192, 56, 56], f16), T([64, 192, 56, 56], f16), 0), {})
+cnt: 1, ((T([64, 192, 112, 112], f16), T([64, 192, 112, 112], f16), 0), {})
+cnt: 2, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mnasnet_100_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mnasnet_100_training.txt
new file mode 100644
index 0000000000000..6524a78aafe0f
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mnasnet_100_training.txt
@@ -0,0 +1,170 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 52, ((T([], i64), 1), {})
+cnt: 4, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16)), {})
+cnt: 4, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16)), {})
+cnt: 4, ((T([128, 80, 14, 14], f16), T([128, 80, 14, 14], f16)), {})
+cnt: 2, ((T([128, 96, 14, 14], f16), T([128, 96, 14, 14], f16)), {})
+cnt: 6, ((T([128, 192, 7, 7], f16), T([128, 192, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1280], f16), T([1280, 1000], f16, stride=(1, 1280))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([32, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([16, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([48, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([48, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 48), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([24, 48, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 24, 56, 56], f16), T([72, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 72, 56, 56], f16), T([72, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 72), {})
+cnt: 2, ((T([128, 72, 56, 56], f16), T([24, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 72, 56, 56], f16), T([72, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 72), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([40, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), T([120, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 120, 28, 28], f16), T([120, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 120), {})
+cnt: 2, ((T([128, 120, 28, 28], f16), T([40, 120, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([240, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([240, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([80, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 80, 14, 14], f16), T([480, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([480, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 480), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([80, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([480, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 480), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([96, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 96, 14, 14], f16), T([576, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 576, 14, 14], f16), T([576, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 576), {})
+cnt: 1, ((T([128, 576, 14, 14], f16), T([96, 576, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 576, 14, 14], f16), T([576, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 576), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), T([192, 576, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 192, 7, 7], f16), T([1152, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 1152, 7, 7], f16), T([1152, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 1152), {})
+cnt: 3, ((T([128, 1152, 7, 7], f16), T([192, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), T([1152, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1152), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), T([320, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([1280, 320, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([128, 320, 7, 7], f16), T([1280, 320, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([128, 1152, 7, 7], f16), T([320, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16), T([1152, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1152, [True, True, False]), {})
+cnt: 4, ((T([128, 1152, 7, 7], f16), T([128, 192, 7, 7], f16), T([1152, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 192, 7, 7], f16), T([128, 1152, 7, 7], f16), T([192, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16), T([1152, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 1152, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 7, 7], f16), T([128, 576, 7, 7], f16), T([192, 576, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), T([128, 576, 14, 14], f16), T([576, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 576, [True, True, False]), {})
+cnt: 2, ((T([128, 576, 14, 14], f16), T([128, 96, 14, 14], f16), T([576, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 14, 14], f16), T([128, 576, 14, 14], f16), T([96, 576, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 576, 14, 14], f16), T([128, 576, 14, 14], f16), T([576, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 576, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 14, 14], f16), T([128, 480, 14, 14], f16), T([96, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), T([480, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 3, ((T([128, 480, 14, 14], f16), T([128, 80, 14, 14], f16), T([480, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 80, 14, 14], f16), T([128, 480, 14, 14], f16), T([80, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), T([480, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 1, ((T([128, 80, 14, 14], f16), T([128, 240, 14, 14], f16), T([80, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 28, 28], f16), T([240, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 40, 28, 28], f16), T([240, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), T([128, 120, 28, 28], f16), T([40, 120, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16), T([120, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 2, ((T([128, 120, 28, 28], f16), T([128, 40, 28, 28], f16), T([120, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([128, 72, 28, 28], f16), T([40, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([128, 72, 56, 56], f16), T([72, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 72, [True, True, False]), {})
+cnt: 3, ((T([128, 72, 56, 56], f16), T([128, 24, 56, 56], f16), T([72, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([128, 72, 56, 56], f16), T([24, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 72, 56, 56], f16), T([128, 72, 56, 56], f16), T([72, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 72, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 48, 56, 56], f16), T([24, 48, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([128, 48, 112, 112], f16), T([48, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 48, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([128, 16, 112, 112], f16), T([48, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 32, 112, 112], f16), T([16, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), T([32, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 3, 224, 224], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 1280, 7, 7], f16, stride=(1280, 1, 0, 0)), 49), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 1280, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 1280], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 1280], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 72, 56, 56], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 120, 28, 28], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f16), True, 0.1, 1e-05), {})
+cnt: 6, ((T([128, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 96, 14, 14], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 576, 14, 14], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 192, 7, 7], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 8, ((T([128, 1152, 7, 7], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([128, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f32), T([1280], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([128, 320, 7, 7], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), True, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f32), T([1152], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 192, 7, 7], f16), T([128, 192, 7, 7], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), T([128, 576, 7, 7], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f32), T([576], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 576, 14, 14], f16), T([128, 576, 14, 14], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f32), T([576], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 96, 14, 14], f16), T([128, 96, 14, 14], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 80, 14, 14], f16), T([128, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f32), T([120], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([128, 72, 28, 28], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 72, 56, 56], f16), T([128, 72, 56, 56], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([128, 48, 56, 56], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f32), T([48], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([128, 48, 112, 112], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f32), T([48], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([128, 32, 112, 112], f16),), {})
+cnt: 1, ((T([128, 48, 112, 112], f16),), {})
+cnt: 1, ((T([128, 48, 56, 56], f16),), {})
+cnt: 5, ((T([128, 72, 56, 56], f16),), {})
+cnt: 1, ((T([128, 72, 28, 28], f16),), {})
+cnt: 4, ((T([128, 120, 28, 28], f16),), {})
+cnt: 1, ((T([128, 240, 28, 28], f16),), {})
+cnt: 1, ((T([128, 240, 14, 14], f16),), {})
+cnt: 6, ((T([128, 480, 14, 14], f16),), {})
+cnt: 3, ((T([128, 576, 14, 14], f16),), {})
+cnt: 1, ((T([128, 576, 7, 7], f16),), {})
+cnt: 8, ((T([128, 1152, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1280, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([128, 1280, 7, 7], f16), 0), {})
+cnt: 8, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16), 0), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), T([128, 576, 7, 7], f16), 0), {})
+cnt: 3, ((T([128, 576, 14, 14], f16), T([128, 576, 14, 14], f16), 0), {})
+cnt: 6, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16), 0), {})
+cnt: 4, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([128, 72, 28, 28], f16), 0), {})
+cnt: 5, ((T([128, 72, 56, 56], f16), T([128, 72, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([128, 48, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([128, 48, 112, 112], f16), 0), {})
+cnt: 2, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mobilenetv2_100_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mobilenetv2_100_training.txt
new file mode 100644
index 0000000000000..4c6b5706f2741
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mobilenetv2_100_training.txt
@@ -0,0 +1,172 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 52, ((T([], i64), 1), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16)), {})
+cnt: 4, ((T([128, 32, 28, 28], f16), T([128, 32, 28, 28], f16)), {})
+cnt: 6, ((T([128, 64, 14, 14], f16), T([128, 64, 14, 14], f16)), {})
+cnt: 4, ((T([128, 96, 14, 14], f16), T([128, 96, 14, 14], f16)), {})
+cnt: 4, ((T([128, 160, 7, 7], f16), T([128, 160, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1280], f16), T([1280, 1000], f16, stride=(1, 1280))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+cnt: 2, ((T([128, 32, 112, 112], f16),), {})
+cnt: 1, ((T([128, 96, 112, 112], f16),), {})
+cnt: 1, ((T([128, 96, 56, 56], f16),), {})
+cnt: 3, ((T([128, 144, 56, 56], f16),), {})
+cnt: 1, ((T([128, 144, 28, 28], f16),), {})
+cnt: 5, ((T([128, 192, 28, 28], f16),), {})
+cnt: 1, ((T([128, 192, 14, 14], f16),), {})
+cnt: 8, ((T([128, 384, 14, 14], f16),), {})
+cnt: 5, ((T([128, 576, 14, 14], f16),), {})
+cnt: 1, ((T([128, 576, 7, 7], f16),), {})
+cnt: 6, ((T([128, 960, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1280, 7, 7], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([32, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([16, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([96, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([96, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 96), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([24, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([144, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([144, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 144), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([24, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([144, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 144), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([32, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 32, 28, 28], f16), T([192, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 192, 28, 28], f16), T([192, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 192), {})
+cnt: 2, ((T([128, 192, 28, 28], f16), T([32, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 28, 28], f16), T([192, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 192), {})
+cnt: 1, ((T([128, 192, 14, 14], f16), T([64, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 64, 14, 14], f16), T([384, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 384, 14, 14], f16), T([384, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 384), {})
+cnt: 3, ((T([128, 384, 14, 14], f16), T([64, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 384, 14, 14], f16), T([96, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 96, 14, 14], f16), T([576, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 576, 14, 14], f16), T([576, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 576), {})
+cnt: 2, ((T([128, 576, 14, 14], f16), T([96, 576, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 576, 14, 14], f16), T([576, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 576), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), T([160, 576, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 160, 7, 7], f16), T([960, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 960, 7, 7], f16), T([960, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 960), {})
+cnt: 2, ((T([128, 960, 7, 7], f16), T([160, 960, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 960, 7, 7], f16), T([320, 960, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([1280, 320, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([128, 320, 7, 7], f16), T([1280, 320, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([128, 960, 7, 7], f16), T([320, 960, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16), T([960, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 960, [True, True, False]), {})
+cnt: 3, ((T([128, 960, 7, 7], f16), T([128, 160, 7, 7], f16), T([960, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 160, 7, 7], f16), T([128, 960, 7, 7], f16), T([160, 960, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 160, 7, 7], f16), T([128, 576, 7, 7], f16), T([160, 576, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), T([128, 576, 14, 14], f16), T([576, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 576, [True, True, False]), {})
+cnt: 3, ((T([128, 576, 14, 14], f16), T([128, 96, 14, 14], f16), T([576, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 96, 14, 14], f16), T([128, 576, 14, 14], f16), T([96, 576, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 576, 14, 14], f16), T([128, 576, 14, 14], f16), T([576, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 576, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 14, 14], f16), T([128, 384, 14, 14], f16), T([96, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16), T([384, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 384, [True, True, False]), {})
+cnt: 4, ((T([128, 384, 14, 14], f16), T([128, 64, 14, 14], f16), T([384, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 64, 14, 14], f16), T([128, 384, 14, 14], f16), T([64, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 14, 14], f16), T([128, 192, 14, 14], f16), T([64, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 14, 14], f16), T([128, 192, 28, 28], f16), T([192, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 192, [True, True, False]), {})
+cnt: 3, ((T([128, 192, 28, 28], f16), T([128, 32, 28, 28], f16), T([192, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 32, 28, 28], f16), T([128, 192, 28, 28], f16), T([32, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 192, 28, 28], f16), T([128, 192, 28, 28], f16), T([192, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 192, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 28, 28], f16), T([128, 144, 28, 28], f16), T([32, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 144, 56, 56], f16), T([144, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 144, [True, True, False]), {})
+cnt: 2, ((T([128, 144, 56, 56], f16), T([128, 24, 56, 56], f16), T([144, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 144, 56, 56], f16), T([24, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([128, 144, 56, 56], f16), T([144, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 144, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 96, 56, 56], f16), T([24, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 112, 112], f16), T([96, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 96, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([128, 16, 112, 112], f16), T([96, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 32, 112, 112], f16), T([16, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), T([32, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 3, 224, 224], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 1280, 7, 7], f16, stride=(1280, 1, 0, 0)), 49), {})
+Operator: aten.hardtanh_.default
+cnt: 2, ((T([128, 32, 112, 112], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), 0.0, 6.0), {})
+cnt: 3, ((T([128, 144, 56, 56], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), 0.0, 6.0), {})
+cnt: 5, ((T([128, 192, 28, 28], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 192, 14, 14], f16), 0.0, 6.0), {})
+cnt: 8, ((T([128, 384, 14, 14], f16), 0.0, 6.0), {})
+cnt: 5, ((T([128, 576, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), 0.0, 6.0), {})
+cnt: 6, ((T([128, 960, 7, 7], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 1280, 7, 7], f16), 0.0, 6.0), {})
+Operator: aten.hardtanh_backward.default
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([128, 1280, 7, 7], f16), 0.0, 6.0), {})
+cnt: 6, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), T([128, 576, 7, 7], f16), 0.0, 6.0), {})
+cnt: 5, ((T([128, 576, 14, 14], f16), T([128, 576, 14, 14], f16), 0.0, 6.0), {})
+cnt: 8, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 192, 14, 14], f16), T([128, 192, 14, 14], f16), 0.0, 6.0), {})
+cnt: 5, ((T([128, 192, 28, 28], f16), T([128, 192, 28, 28], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 144, 28, 28], f16), 0.0, 6.0), {})
+cnt: 3, ((T([128, 144, 56, 56], f16), T([128, 144, 56, 56], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 56, 56], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([128, 96, 112, 112], f16), 0.0, 6.0), {})
+cnt: 2, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), 0.0, 6.0), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 1280, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 1280], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 1280], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 144, 56, 56], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 32, 28, 28], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 192, 28, 28], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 192, 14, 14], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 64, 14, 14], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 8, ((T([128, 384, 14, 14], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 96, 14, 14], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 576, 14, 14], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 160, 7, 7], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), True, 0.1, 1e-05), {})
+cnt: 6, ((T([128, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([128, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f32), T([1280], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([128, 320, 7, 7], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), True, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 160, 7, 7], f16), T([128, 160, 7, 7], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), T([128, 576, 7, 7], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f32), T([576], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 576, 14, 14], f16), T([128, 576, 14, 14], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f32), T([576], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 96, 14, 14], f16), T([128, 96, 14, 14], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 64, 14, 14], f16), T([128, 64, 14, 14], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 192, 14, 14], f16), T([128, 192, 14, 14], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 192, 28, 28], f16), T([128, 192, 28, 28], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 32, 28, 28], f16), T([128, 32, 28, 28], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 144, 28, 28], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 144, 56, 56], f16), T([128, 144, 56, 56], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([128, 96, 112, 112], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mobilenetv3_large_100_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mobilenetv3_large_100_training.txt
new file mode 100644
index 0000000000000..df2ab44bf9f78
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mobilenetv3_large_100_training.txt
@@ -0,0 +1,269 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 46, ((T([], i64), 1), {})
+cnt: 2, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16)), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16)), {})
+cnt: 4, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16)), {})
+cnt: 6, ((T([128, 80, 14, 14], f16), T([128, 80, 14, 14], f16)), {})
+cnt: 2, ((T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16)), {})
+cnt: 4, ((T([128, 160, 7, 7], f16), T([128, 160, 7, 7], f16)), {})
+cnt: 2, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16)), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16)), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16)), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16)), {})
+cnt: 2, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16)), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([128, 72, 28, 28], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1280], f16), T([1280, 1000], f16, stride=(1, 1280))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+cnt: 1, ((T([128, 16, 112, 112], f16),), {})
+cnt: 1, ((T([128, 240, 28, 28], f16),), {})
+cnt: 1, ((T([128, 240, 14, 14], f16),), {})
+cnt: 2, ((T([128, 200, 14, 14], f16),), {})
+cnt: 4, ((T([128, 184, 14, 14], f16),), {})
+cnt: 2, ((T([128, 480, 14, 14], f16),), {})
+cnt: 3, ((T([128, 672, 14, 14], f16),), {})
+cnt: 1, ((T([128, 672, 7, 7], f16),), {})
+cnt: 5, ((T([128, 960, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1280, 1, 1], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([16, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([16, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 16), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([16, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([64, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([64, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([24, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([72, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 72, 56, 56], f16), T([72, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 72), {})
+cnt: 1, ((T([128, 72, 56, 56], f16), T([24, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 72, 56, 56], f16), T([72, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 72), {})
+cnt: 1, ((T([128, 72, 1, 1], f16), T([24, 72, 1, 1], f16), T([24], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 24, 1, 1], f16), T([72, 24, 1, 1], f16), T([72], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([40, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), T([120, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 120, 28, 28], f16), T([120, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 120), {})
+cnt: 2, ((T([128, 120, 1, 1], f16), T([32, 120, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 32, 1, 1], f16), T([120, 32, 1, 1], f16), T([120], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 120, 28, 28], f16), T([40, 120, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([240, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([240, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([80, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 80, 14, 14], f16), T([200, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 200, 14, 14], f16), T([200, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 200), {})
+cnt: 1, ((T([128, 200, 14, 14], f16), T([80, 200, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 80, 14, 14], f16), T([184, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 184, 14, 14], f16), T([184, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 184), {})
+cnt: 2, ((T([128, 184, 14, 14], f16), T([80, 184, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 80, 14, 14], f16), T([480, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([480, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 480), {})
+cnt: 1, ((T([128, 480, 1, 1], f16), T([120, 480, 1, 1], f16), T([120], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 120, 1, 1], f16), T([480, 120, 1, 1], f16), T([480], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([112, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 112, 14, 14], f16), T([672, 112, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([672, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 672), {})
+cnt: 2, ((T([128, 672, 1, 1], f16), T([168, 672, 1, 1], f16), T([168], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 168, 1, 1], f16), T([672, 168, 1, 1], f16), T([672], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([112, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([672, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 672), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([160, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 160, 7, 7], f16), T([960, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 960, 7, 7], f16), T([960, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 960), {})
+cnt: 2, ((T([128, 960, 1, 1], f16), T([240, 960, 1, 1], f16), T([240], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 240, 1, 1], f16), T([960, 240, 1, 1], f16), T([960], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 960, 7, 7], f16), T([160, 960, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 960, 1, 1], f16), T([1280, 960, 1, 1], f16), T([1280], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1280, 1, 1], f16), T([128, 960, 1, 1], f16), T([1280, 960, 1, 1], f16), [1280], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([128, 960, 7, 7], f16), T([128, 160, 7, 7], f16), T([960, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 160, 7, 7], f16), T([128, 960, 7, 7], f16), T([160, 960, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 960, 1, 1], f16), T([128, 240, 1, 1], f16), T([960, 240, 1, 1], f16), [960], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 240, 1, 1], f16), T([128, 960, 1, 1], f16), T([240, 960, 1, 1], f16), [240], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16), T([960, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 960, [True, True, False]), {})
+cnt: 1, ((T([128, 160, 7, 7], f16), T([128, 672, 7, 7], f16), T([160, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 672, 1, 1], f16), T([128, 168, 1, 1], f16), T([672, 168, 1, 1], f16), [672], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 168, 1, 1], f16), T([128, 672, 1, 1], f16), T([168, 672, 1, 1], f16), [168], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 14, 14], f16), T([672, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), T([128, 112, 14, 14], f16), T([672, 112, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 112, 14, 14], f16), T([128, 672, 14, 14], f16), T([112, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16), T([672, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 1, ((T([128, 112, 14, 14], f16), T([128, 480, 14, 14], f16), T([112, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 480, 1, 1], f16), T([128, 120, 1, 1], f16), T([480, 120, 1, 1], f16), [480], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 120, 1, 1], f16), T([128, 480, 1, 1], f16), T([120, 480, 1, 1], f16), [120], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), T([480, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([128, 80, 14, 14], f16), T([480, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 80, 14, 14], f16), T([128, 184, 14, 14], f16), T([80, 184, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 184, 14, 14], f16), T([128, 184, 14, 14], f16), T([184, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 184, [True, True, False]), {})
+cnt: 2, ((T([128, 184, 14, 14], f16), T([128, 80, 14, 14], f16), T([184, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 80, 14, 14], f16), T([128, 200, 14, 14], f16), T([80, 200, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 200, 14, 14], f16), T([128, 200, 14, 14], f16), T([200, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 200, [True, True, False]), {})
+cnt: 1, ((T([128, 200, 14, 14], f16), T([128, 80, 14, 14], f16), T([200, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 80, 14, 14], f16), T([128, 240, 14, 14], f16), T([80, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 28, 28], f16), T([240, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 40, 28, 28], f16), T([240, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), T([128, 120, 28, 28], f16), T([40, 120, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 120, 1, 1], f16), T([128, 32, 1, 1], f16), T([120, 32, 1, 1], f16), [120], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 32, 1, 1], f16), T([128, 120, 1, 1], f16), T([32, 120, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16), T([120, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 2, ((T([128, 120, 28, 28], f16), T([128, 40, 28, 28], f16), T([120, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([128, 72, 28, 28], f16), T([40, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 72, 1, 1], f16), T([128, 24, 1, 1], f16), T([72, 24, 1, 1], f16), [72], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 24, 1, 1], f16), T([128, 72, 1, 1], f16), T([24, 72, 1, 1], f16), [24], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([128, 72, 56, 56], f16), T([72, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 72, [True, True, False]), {})
+cnt: 2, ((T([128, 72, 56, 56], f16), T([128, 24, 56, 56], f16), T([72, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 72, 56, 56], f16), T([24, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 72, 56, 56], f16), T([128, 72, 56, 56], f16), T([72, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 72, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 64, 56, 56], f16), T([24, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 112, 112], f16), T([64, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 16, 112, 112], f16), T([64, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 16, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 3, 224, 224], f16), T([16, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 3, ((T([128, 960, 7, 7], f16, stride=(960, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 672, 7, 7], f16, stride=(672, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 672, 14, 14], f16, stride=(672, 1, 0, 0)), 196), {})
+cnt: 1, ((T([128, 480, 14, 14], f16, stride=(480, 1, 0, 0)), 196), {})
+cnt: 2, ((T([128, 120, 28, 28], f16, stride=(120, 1, 0, 0)), 784), {})
+cnt: 1, ((T([128, 72, 28, 28], f16, stride=(72, 1, 0, 0)), 784), {})
+Operator: aten.hardsigmoid.default
+cnt: 1, ((T([128, 72, 1, 1], f16),), {})
+cnt: 2, ((T([128, 120, 1, 1], f16),), {})
+cnt: 1, ((T([128, 480, 1, 1], f16),), {})
+cnt: 2, ((T([128, 672, 1, 1], f16),), {})
+cnt: 2, ((T([128, 960, 1, 1], f16),), {})
+Operator: aten.hardsigmoid_backward.default
+cnt: 2, ((T([128, 960, 1, 1], f16), T([128, 960, 1, 1], f16)), {})
+cnt: 2, ((T([128, 672, 1, 1], f16), T([128, 672, 1, 1], f16)), {})
+cnt: 1, ((T([128, 480, 1, 1], f16), T([128, 480, 1, 1], f16)), {})
+cnt: 2, ((T([128, 120, 1, 1], f16), T([128, 120, 1, 1], f16)), {})
+cnt: 1, ((T([128, 72, 1, 1], f16), T([128, 72, 1, 1], f16)), {})
+Operator: aten.hardswish_.default
+cnt: 1, ((T([128, 16, 112, 112], f16),), {})
+cnt: 1, ((T([128, 240, 28, 28], f16),), {})
+cnt: 1, ((T([128, 240, 14, 14], f16),), {})
+cnt: 2, ((T([128, 200, 14, 14], f16),), {})
+cnt: 4, ((T([128, 184, 14, 14], f16),), {})
+cnt: 2, ((T([128, 480, 14, 14], f16),), {})
+cnt: 3, ((T([128, 672, 14, 14], f16),), {})
+cnt: 1, ((T([128, 672, 7, 7], f16),), {})
+cnt: 5, ((T([128, 960, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1280, 1, 1], f16),), {})
+Operator: aten.hardswish_backward.default
+cnt: 1, ((T([128, 1280, 1, 1], f16), T([128, 1280, 1, 1], f16)), {})
+cnt: 5, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16)), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16)), {})
+cnt: 3, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16)), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16)), {})
+cnt: 4, ((T([128, 184, 14, 14], f16), T([128, 184, 14, 14], f16)), {})
+cnt: 2, ((T([128, 200, 14, 14], f16), T([128, 200, 14, 14], f16)), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16)), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16)), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 72, 28, 28], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 120, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 960, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 960, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 1280], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 1280], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([128, 72, 28, 28], f16), T([128, 72, 1, 1], f16)), {})
+cnt: 4, ((T([128, 120, 28, 28], f16), T([128, 120, 1, 1], f16)), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([128, 480, 1, 1], f16)), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), T([128, 672, 1, 1], f16)), {})
+cnt: 2, ((T([128, 672, 7, 7], f16), T([128, 672, 1, 1], f16)), {})
+cnt: 4, ((T([128, 960, 7, 7], f16), T([128, 960, 1, 1], f16)), {})
+cnt: 2, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16)), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16)), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16)), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16)), {})
+cnt: 2, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16)), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([128, 72, 28, 28], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 3, ((T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 72, 56, 56], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 120, 28, 28], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 200, 14, 14], f16), T([200], f16), T([200], f16), T([200], f16), T([200], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 184, 14, 14], f16), T([184], f16), T([184], f16), T([184], f16), T([184], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 112, 14, 14], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 160, 7, 7], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 5, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 160, 7, 7], f16), T([128, 160, 7, 7], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f32), T([112], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 80, 14, 14], f16), T([128, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 184, 14, 14], f16), T([128, 184, 14, 14], f16), T([184], f16), T([184], f16), T([184], f16), T([184], f32), T([184], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 200, 14, 14], f16), T([128, 200, 14, 14], f16), T([200], f16), T([200], f16), T([200], f16), T([200], f32), T([200], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f32), T([120], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([128, 72, 28, 28], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 72, 56, 56], f16), T([128, 72, 56, 56], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 16, 112, 112], f16),), {})
+cnt: 1, ((T([128, 64, 112, 112], f16),), {})
+cnt: 1, ((T([128, 64, 56, 56], f16),), {})
+cnt: 3, ((T([128, 72, 56, 56], f16),), {})
+cnt: 1, ((T([128, 72, 28, 28], f16),), {})
+cnt: 1, ((T([128, 24, 1, 1], f16),), {})
+cnt: 4, ((T([128, 120, 28, 28], f16),), {})
+cnt: 2, ((T([128, 32, 1, 1], f16),), {})
+cnt: 1, ((T([128, 120, 1, 1], f16),), {})
+cnt: 2, ((T([128, 168, 1, 1], f16),), {})
+cnt: 2, ((T([128, 240, 1, 1], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 2, ((T([128, 960, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 120, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), [2, 3], True), {})
+Operator: aten.threshold_backward.default
+cnt: 2, ((T([128, 240, 1, 1], f16), T([128, 240, 1, 1], f16), 0), {})
+cnt: 2, ((T([128, 168, 1, 1], f16), T([128, 168, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 120, 1, 1], f16), T([128, 120, 1, 1], f16), 0), {})
+cnt: 2, ((T([128, 32, 1, 1], f16), T([128, 32, 1, 1], f16), 0), {})
+cnt: 4, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 24, 1, 1], f16), T([128, 24, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 72, 28, 28], f16), T([128, 72, 28, 28], f16), 0), {})
+cnt: 3, ((T([128, 72, 56, 56], f16), T([128, 72, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), 0), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mobilevit_s_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mobilevit_s_training.txt
new file mode 100644
index 0000000000000..ce3dba3ad0a77
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/mobilevit_s_training.txt
@@ -0,0 +1,313 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 2, ((T([256, 4, 256, 256], f16), -1, False), {})
+cnt: 4, ((T([256, 4, 64, 64], f16), -1, False), {})
+cnt: 3, ((T([256, 4, 16, 16], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 3, ((T([256, 4, 16, 16], f16), T([256, 4, 16, 16], f16), -1, f16), {})
+cnt: 4, ((T([256, 4, 64, 64], f16), T([256, 4, 64, 64], f16), -1, f16), {})
+cnt: 2, ((T([256, 4, 256, 256], f16), T([256, 4, 256, 256], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 2, ((T([147456, 16, 2, 2], f16), [64, 144, 256, 4]), {})
+cnt: 2, ((T([64, 4, 256, 144], f16), [256, 256, 144]), {})
+cnt: 6, ((T([256, 4, 256, 36], f16), [1024, 256, 36]), {})
+cnt: 2, ((T([256, 4, 36, 256], f16), [1024, 36, 256]), {})
+cnt: 2, ((T([1024, 256, 256], f16), [256, 4, 256, 256]), {})
+cnt: 2, ((T([1024, 256, 36], f16), [256, 4, 256, 36]), {})
+cnt: 2, ((T([256, 256, 4, 36], f16), [256, 256, 144]), {})
+cnt: 2, ((T([64, 144, 256, 4], f16), [147456, 16, 2, 2]), {})
+cnt: 2, ((T([147456, 2, 16, 2], f16), [64, 144, 32, 32]), {})
+cnt: 2, ((T([98304, 8, 2, 2], f16), [64, 192, 64, 4]), {})
+cnt: 2, ((T([64, 4, 64, 192], f16), [256, 64, 192]), {})
+cnt: 12, ((T([256, 4, 64, 48], f16), [1024, 64, 48]), {})
+cnt: 4, ((T([256, 4, 48, 64], f16), [1024, 48, 64]), {})
+cnt: 4, ((T([1024, 64, 64], f16), [256, 4, 64, 64]), {})
+cnt: 4, ((T([1024, 64, 48], f16), [256, 4, 64, 48]), {})
+cnt: 4, ((T([256, 64, 4, 48], f16), [256, 64, 192]), {})
+cnt: 2, ((T([64, 192, 64, 4], f16), [98304, 8, 2, 2]), {})
+cnt: 2, ((T([98304, 2, 8, 2], f16), [64, 192, 16, 16]), {})
+cnt: 2, ((T([61440, 4, 2, 2], f16), [64, 240, 16, 4]), {})
+cnt: 2, ((T([64, 4, 16, 240], f16), [256, 16, 240]), {})
+cnt: 9, ((T([256, 4, 16, 60], f16), [1024, 16, 60]), {})
+cnt: 3, ((T([256, 4, 60, 16], f16), [1024, 60, 16]), {})
+cnt: 3, ((T([1024, 16, 16], f16), [256, 4, 16, 16]), {})
+cnt: 3, ((T([1024, 16, 60], f16), [256, 4, 16, 60]), {})
+cnt: 3, ((T([256, 16, 4, 60], f16), [256, 16, 240]), {})
+cnt: 2, ((T([64, 240, 16, 4], f16), [61440, 4, 2, 2]), {})
+cnt: 2, ((T([61440, 2, 4, 2], f16), [64, 240, 8, 8]), {})
+cnt: 3, ((T([256, 16, 3, 4, 60], f16), [256, 16, 720]), {})
+cnt: 4, ((T([256, 64, 3, 4, 48], f16), [256, 64, 576]), {})
+cnt: 2, ((T([256, 256, 3, 4, 36], f16), [256, 256, 432]), {})
+Operator: aten.add.Tensor
+cnt: 32, ((T([], i64), 1), {})
+cnt: 4, ((T([64, 64, 64, 64], f16), T([64, 64, 64, 64], f16)), {})
+cnt: 8, ((T([256, 256, 144], f16), T([256, 256, 144], f16)), {})
+cnt: 16, ((T([256, 64, 192], f16), T([256, 64, 192], f16)), {})
+cnt: 12, ((T([256, 16, 240], f16), T([256, 16, 240], f16)), {})
+cnt: 1, ((T([64, 160, 8, 8], f16, stride=(20480, 64, 8, 1)), T([64, 160, 8, 8], f16)), {})
+cnt: 1, ((T([64, 128, 16, 16], f16, stride=(65536, 256, 16, 1)), T([64, 128, 16, 16], f16)), {})
+cnt: 1, ((T([64, 96, 32, 32], f16, stride=(196608, 1024, 32, 1)), T([64, 96, 32, 32], f16)), {})
+Operator: aten.addmm.default
+cnt: 2, ((T([432], f16), T([65536, 144], f16), T([144, 432], f16, stride=(1, 144))), {})
+cnt: 2, ((T([144], f16), T([65536, 144], f16), T([144, 144], f16, stride=(1, 144))), {})
+cnt: 2, ((T([288], f16), T([65536, 144], f16), T([144, 288], f16, stride=(1, 144))), {})
+cnt: 2, ((T([144], f16), T([65536, 288], f16), T([288, 144], f16, stride=(1, 288))), {})
+cnt: 4, ((T([576], f16), T([16384, 192], f16), T([192, 576], f16, stride=(1, 192))), {})
+cnt: 4, ((T([192], f16), T([16384, 192], f16), T([192, 192], f16, stride=(1, 192))), {})
+cnt: 4, ((T([384], f16), T([16384, 192], f16), T([192, 384], f16, stride=(1, 192))), {})
+cnt: 4, ((T([192], f16), T([16384, 384], f16), T([384, 192], f16, stride=(1, 384))), {})
+cnt: 3, ((T([720], f16), T([4096, 240], f16), T([240, 720], f16, stride=(1, 240))), {})
+cnt: 3, ((T([240], f16), T([4096, 240], f16), T([240, 240], f16, stride=(1, 240))), {})
+cnt: 3, ((T([480], f16), T([4096, 240], f16), T([240, 480], f16, stride=(1, 240))), {})
+cnt: 3, ((T([240], f16), T([4096, 480], f16), T([480, 240], f16, stride=(1, 480))), {})
+cnt: 1, ((T([1000], f16), T([64, 640], f16), T([640, 1000], f16, stride=(1, 640))), {})
+Operator: aten.bmm.default
+cnt: 2, ((T([1024, 256, 36], f16), T([1024, 36, 256], f16)), {})
+cnt: 2, ((T([1024, 256, 256], f16), T([1024, 256, 36], f16)), {})
+cnt: 4, ((T([1024, 64, 48], f16), T([1024, 48, 64], f16)), {})
+cnt: 4, ((T([1024, 64, 64], f16), T([1024, 64, 48], f16)), {})
+cnt: 3, ((T([1024, 16, 60], f16), T([1024, 60, 16], f16)), {})
+cnt: 3, ((T([1024, 16, 16], f16), T([1024, 16, 60], f16)), {})
+cnt: 3, ((T([1024, 16, 16], f16, stride=(256, 1, 16)), T([1024, 16, 60], f16)), {})
+cnt: 3, ((T([1024, 16, 60], f16), T([1024, 60, 16], f16, stride=(960, 1, 60))), {})
+cnt: 3, ((T([1024, 60, 16], f16, stride=(960, 1, 60)), T([1024, 16, 16], f16)), {})
+cnt: 3, ((T([1024, 16, 16], f16), T([1024, 16, 60], f16, stride=(960, 1, 16))), {})
+cnt: 4, ((T([1024, 64, 64], f16, stride=(4096, 1, 64)), T([1024, 64, 48], f16)), {})
+cnt: 4, ((T([1024, 64, 48], f16), T([1024, 48, 64], f16, stride=(3072, 1, 48))), {})
+cnt: 4, ((T([1024, 48, 64], f16, stride=(3072, 1, 48)), T([1024, 64, 64], f16)), {})
+cnt: 4, ((T([1024, 64, 64], f16), T([1024, 64, 48], f16, stride=(3072, 1, 64))), {})
+cnt: 2, ((T([1024, 256, 256], f16, stride=(65536, 1, 256)), T([1024, 256, 36], f16)), {})
+cnt: 2, ((T([1024, 256, 36], f16), T([1024, 36, 256], f16, stride=(9216, 1, 36))), {})
+cnt: 2, ((T([1024, 36, 256], f16, stride=(9216, 1, 36)), T([1024, 256, 256], f16)), {})
+cnt: 2, ((T([1024, 256, 256], f16), T([1024, 256, 36], f16, stride=(9216, 1, 256))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 96, 32, 32], f16), T([64, 96, 32, 32], f16)], 1), {})
+cnt: 1, (([T([64, 128, 16, 16], f16), T([64, 128, 16, 16], f16)], 1), {})
+cnt: 1, (([T([64, 160, 8, 8], f16), T([64, 160, 8, 8], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 256, 256], f16),), {})
+cnt: 1, ((T([64, 16, 128, 128], f16),), {})
+cnt: 2, ((T([64, 64, 128, 128], f16),), {})
+cnt: 1, ((T([64, 128, 128, 128], f16),), {})
+cnt: 1, ((T([64, 128, 64, 64], f16),), {})
+cnt: 5, ((T([64, 256, 64, 64], f16),), {})
+cnt: 1, ((T([64, 256, 32, 32], f16),), {})
+cnt: 3, ((T([64, 96, 32, 32], f16),), {})
+cnt: 1, ((T([64, 384, 32, 32], f16),), {})
+cnt: 1, ((T([64, 384, 16, 16], f16),), {})
+cnt: 3, ((T([64, 128, 16, 16], f16),), {})
+cnt: 1, ((T([64, 512, 16, 16], f16),), {})
+cnt: 1, ((T([64, 512, 8, 8], f16),), {})
+cnt: 3, ((T([64, 160, 8, 8], f16),), {})
+cnt: 1, ((T([64, 640, 8, 8], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 256, 256], f16), T([16, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 16, 128, 128], f16), T([64, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 128, 128], f16), T([64, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([64, 64, 128, 128], f16), T([32, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 32, 128, 128], f16), T([128, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 128, 128], f16), T([128, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 128), {})
+cnt: 1, ((T([64, 128, 64, 64], f16), T([64, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 64, 64, 64], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 256, 64, 64], f16), T([256, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 256), {})
+cnt: 2, ((T([64, 256, 64, 64], f16), T([64, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 64, 64], f16), T([256, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 256), {})
+cnt: 1, ((T([64, 256, 32, 32], f16), T([96, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 96, 32, 32], f16), T([96, 96, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 96, 32, 32], f16), T([144, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 144, 32, 32], f16), T([96, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 192, 32, 32], f16), T([96, 192, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 96, 32, 32], f16), T([384, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 384, 32, 32], f16), T([384, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 384), {})
+cnt: 1, ((T([64, 384, 16, 16], f16), T([128, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 16, 16], f16), T([128, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 16, 16], f16), T([192, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 192, 16, 16], f16), T([128, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 16, 16], f16), T([128, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 16, 16], f16), T([512, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 16, 16], f16), T([512, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 512), {})
+cnt: 1, ((T([64, 512, 8, 8], f16), T([160, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 160, 8, 8], f16), T([160, 160, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 160, 8, 8], f16), T([240, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 240, 8, 8], f16), T([160, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 320, 8, 8], f16), T([160, 320, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 160, 8, 8], f16), T([640, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 640, 8, 8], f16), T([64, 160, 8, 8], f16), T([640, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 160, 8, 8], f16), T([64, 320, 8, 8], f16), T([160, 320, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 160, 8, 8], f16), T([64, 240, 8, 8], f16), T([160, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 240, 8, 8], f16), T([64, 160, 8, 8], f16), T([240, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 160, 8, 8], f16), T([64, 160, 8, 8], f16), T([160, 160, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 160, 8, 8], f16), T([64, 512, 8, 8], f16), T([160, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 8, 8], f16), T([64, 512, 16, 16], f16), T([512, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 512, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 16, 16], f16), T([64, 128, 16, 16], f16), T([512, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 16, 16], f16), T([64, 256, 16, 16], f16), T([128, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 16, 16], f16), T([64, 192, 16, 16], f16), T([128, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 192, 16, 16], f16), T([64, 128, 16, 16], f16), T([192, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 16, 16], f16), T([64, 128, 16, 16], f16), T([128, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 16, 16], f16), T([64, 384, 16, 16], f16), T([128, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 384, 16, 16], f16), T([64, 384, 32, 32], f16), T([384, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 384, [True, True, False]), {})
+cnt: 1, ((T([64, 384, 32, 32], f16), T([64, 96, 32, 32], f16), T([384, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 96, 32, 32], f16), T([64, 192, 32, 32], f16), T([96, 192, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 96, 32, 32], f16), T([64, 144, 32, 32], f16), T([96, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 144, 32, 32], f16), T([64, 96, 32, 32], f16), T([144, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 96, 32, 32], f16), T([64, 96, 32, 32], f16), T([96, 96, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 96, 32, 32], f16), T([64, 256, 32, 32], f16), T([96, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 256, 32, 32], f16), T([64, 256, 64, 64], f16), T([256, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 256, [True, True, False]), {})
+cnt: 3, ((T([64, 256, 64, 64], f16), T([64, 64, 64, 64], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 64, 64, 64], f16), T([64, 256, 64, 64], f16), T([64, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 256, 64, 64], f16), T([64, 256, 64, 64], f16), T([256, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 256, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 64, 64], f16), T([64, 128, 64, 64], f16), T([64, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 64, 64], f16), T([64, 128, 128, 128], f16), T([128, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 128, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 128, 128], f16), T([64, 32, 128, 128], f16), T([128, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 128, 128], f16), T([64, 64, 128, 128], f16), T([32, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 128, 128], f16), T([64, 64, 128, 128], f16), T([64, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 128, 128], f16), T([64, 16, 128, 128], f16), T([64, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 16, 128, 128], f16), T([64, 3, 256, 256], f16), T([16, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 256, 256], f16), T([64, 3, 256, 256], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([64, 640, 8, 8], f16, stride=(640, 1, 0, 0)), 64), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([64, 640, 8, 8], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([64, 1000], f16), T([1000, 640], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 640], f16)), {})
+cnt: 3, ((T([4096, 240], f16), T([240, 480], f16)), {})
+cnt: 3, ((T([240, 4096], f16, stride=(1, 240)), T([4096, 480], f16)), {})
+cnt: 3, ((T([4096, 480], f16), T([480, 240], f16)), {})
+cnt: 3, ((T([480, 4096], f16, stride=(1, 480)), T([4096, 240], f16)), {})
+cnt: 3, ((T([4096, 240], f16), T([240, 240], f16)), {})
+cnt: 3, ((T([240, 4096], f16, stride=(1, 240)), T([4096, 240], f16)), {})
+cnt: 3, ((T([4096, 720], f16), T([720, 240], f16)), {})
+cnt: 3, ((T([720, 4096], f16, stride=(1, 720)), T([4096, 240], f16)), {})
+cnt: 4, ((T([16384, 192], f16), T([192, 384], f16)), {})
+cnt: 4, ((T([192, 16384], f16, stride=(1, 192)), T([16384, 384], f16)), {})
+cnt: 4, ((T([16384, 384], f16), T([384, 192], f16)), {})
+cnt: 4, ((T([384, 16384], f16, stride=(1, 384)), T([16384, 192], f16)), {})
+cnt: 4, ((T([16384, 192], f16), T([192, 192], f16)), {})
+cnt: 4, ((T([192, 16384], f16, stride=(1, 192)), T([16384, 192], f16)), {})
+cnt: 4, ((T([16384, 576], f16), T([576, 192], f16)), {})
+cnt: 4, ((T([576, 16384], f16, stride=(1, 576)), T([16384, 192], f16)), {})
+cnt: 2, ((T([65536, 144], f16), T([144, 288], f16)), {})
+cnt: 2, ((T([144, 65536], f16, stride=(1, 144)), T([65536, 288], f16)), {})
+cnt: 2, ((T([65536, 288], f16), T([288, 144], f16)), {})
+cnt: 2, ((T([288, 65536], f16, stride=(1, 288)), T([65536, 144], f16)), {})
+cnt: 2, ((T([65536, 144], f16), T([144, 144], f16)), {})
+cnt: 2, ((T([144, 65536], f16, stride=(1, 144)), T([65536, 144], f16)), {})
+cnt: 2, ((T([65536, 432], f16), T([432, 144], f16)), {})
+cnt: 2, ((T([432, 65536], f16, stride=(1, 432)), T([65536, 144], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 4, ((T([256, 4, 256, 256], f16), 0.16666666666666666), {})
+cnt: 8, ((T([256, 4, 64, 64], f16), 0.14433756729740643), {})
+cnt: 6, ((T([256, 4, 16, 16], f16), 0.12909944487358058), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([64, 16, 128, 128], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([64, 64, 128, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 32, 128, 128], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 128, 128, 128], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([64, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([64, 256, 64, 64], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 256, 32, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([64, 96, 32, 32], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 384, 32, 32], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 384, 16, 16], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([64, 128, 16, 16], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 512, 16, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 512, 8, 8], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([64, 160, 8, 8], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 640, 8, 8], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([64, 640, 8, 8], f16), T([64, 640, 8, 8], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f32), T([640], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([64, 160, 8, 8], f16), T([64, 160, 8, 8], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 512, 8, 8], f16), T([64, 512, 8, 8], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 512, 16, 16], f16), T([64, 512, 16, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([64, 128, 16, 16], f16), T([64, 128, 16, 16], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 384, 16, 16], f16), T([64, 384, 16, 16], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 384, 32, 32], f16), T([64, 384, 32, 32], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([64, 96, 32, 32], f16), T([64, 96, 32, 32], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 256, 32, 32], f16), T([64, 256, 32, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([64, 256, 64, 64], f16), T([64, 256, 64, 64], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([64, 64, 64, 64], f16), T([64, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 128, 64, 64], f16), T([64, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 128, 128, 128], f16), T([64, 128, 128, 128], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 32, 128, 128], f16), T([64, 32, 128, 128], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([64, 64, 128, 128], f16), T([64, 64, 128, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 16, 128, 128], f16), T([64, 16, 128, 128], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.native_layer_norm.default
+cnt: 5, ((T([256, 256, 144], f16), [144], T([144], f16), T([144], f16), 1e-05), {})
+cnt: 9, ((T([256, 64, 192], f16), [192], T([192], f16), T([192], f16), 1e-05), {})
+cnt: 7, ((T([256, 16, 240], f16), [240], T([240], f16), T([240], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 7, ((T([256, 16, 240], f16), T([256, 16, 240], f16), [240], T([256, 16, 1], f32), T([256, 16, 1], f32), T([240], f16), T([240], f16), [True, True, True]), {})
+cnt: 9, ((T([256, 64, 192], f16), T([256, 64, 192], f16), [192], T([256, 64, 1], f32), T([256, 64, 1], f32), T([192], f16), T([192], f16), [True, True, True]), {})
+cnt: 5, ((T([256, 256, 144], f16), T([256, 256, 144], f16), [144], T([256, 256, 1], f32), T([256, 256, 1], f32), T([144], f16), T([144], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.silu.default
+cnt: 2, ((T([256, 256, 288], f16),), {})
+cnt: 4, ((T([256, 64, 384], f16),), {})
+cnt: 3, ((T([256, 16, 480], f16),), {})
+Operator: aten.silu_.default
+cnt: 1, ((T([64, 16, 128, 128], f16),), {})
+cnt: 2, ((T([64, 64, 128, 128], f16),), {})
+cnt: 1, ((T([64, 128, 128, 128], f16),), {})
+cnt: 1, ((T([64, 128, 64, 64], f16),), {})
+cnt: 5, ((T([64, 256, 64, 64], f16),), {})
+cnt: 1, ((T([64, 256, 32, 32], f16),), {})
+cnt: 3, ((T([64, 96, 32, 32], f16),), {})
+cnt: 1, ((T([64, 384, 32, 32], f16),), {})
+cnt: 1, ((T([64, 384, 16, 16], f16),), {})
+cnt: 3, ((T([64, 128, 16, 16], f16),), {})
+cnt: 1, ((T([64, 512, 16, 16], f16),), {})
+cnt: 1, ((T([64, 512, 8, 8], f16),), {})
+cnt: 3, ((T([64, 160, 8, 8], f16),), {})
+cnt: 1, ((T([64, 640, 8, 8], f16),), {})
+Operator: aten.silu_backward.default
+cnt: 1, ((T([64, 640, 8, 8], f16), T([64, 640, 8, 8], f16)), {})
+cnt: 2, ((T([64, 160, 8, 8], f16), T([64, 160, 8, 8], f16)), {})
+cnt: 1, ((T([64, 160, 8, 8], f16, stride=(20480, 64, 8, 1)), T([64, 160, 8, 8], f16)), {})
+cnt: 3, ((T([256, 16, 480], f16), T([256, 16, 480], f16)), {})
+cnt: 1, ((T([64, 512, 8, 8], f16), T([64, 512, 8, 8], f16)), {})
+cnt: 1, ((T([64, 512, 16, 16], f16), T([64, 512, 16, 16], f16)), {})
+cnt: 2, ((T([64, 128, 16, 16], f16), T([64, 128, 16, 16], f16)), {})
+cnt: 1, ((T([64, 128, 16, 16], f16, stride=(65536, 256, 16, 1)), T([64, 128, 16, 16], f16)), {})
+cnt: 4, ((T([256, 64, 384], f16), T([256, 64, 384], f16)), {})
+cnt: 1, ((T([64, 384, 16, 16], f16), T([64, 384, 16, 16], f16)), {})
+cnt: 1, ((T([64, 384, 32, 32], f16), T([64, 384, 32, 32], f16)), {})
+cnt: 2, ((T([64, 96, 32, 32], f16), T([64, 96, 32, 32], f16)), {})
+cnt: 1, ((T([64, 96, 32, 32], f16, stride=(196608, 1024, 32, 1)), T([64, 96, 32, 32], f16)), {})
+cnt: 2, ((T([256, 256, 288], f16), T([256, 256, 288], f16)), {})
+cnt: 1, ((T([64, 256, 32, 32], f16), T([64, 256, 32, 32], f16)), {})
+cnt: 5, ((T([64, 256, 64, 64], f16), T([64, 256, 64, 64], f16)), {})
+cnt: 1, ((T([64, 128, 64, 64], f16), T([64, 128, 64, 64], f16)), {})
+cnt: 1, ((T([64, 128, 128, 128], f16), T([64, 128, 128, 128], f16)), {})
+cnt: 2, ((T([64, 64, 128, 128], f16), T([64, 64, 128, 128], f16)), {})
+cnt: 1, ((T([64, 16, 128, 128], f16), T([64, 16, 128, 128], f16)), {})
+Operator: aten.stack.default
+cnt: 3, (([T([256, 4, 16, 60], f16), T([256, 4, 16, 60], f16, stride=(3840, 960, 1, 16)), T([256, 4, 16, 60], f16)],), {})
+cnt: 4, (([T([256, 4, 64, 48], f16), T([256, 4, 64, 48], f16, stride=(12288, 3072, 1, 64)), T([256, 4, 64, 48], f16)],), {})
+cnt: 2, (([T([256, 4, 256, 36], f16), T([256, 4, 256, 36], f16, stride=(36864, 9216, 1, 256)), T([256, 4, 256, 36], f16)],), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 6, ((T([4096, 240], f16), [0], True), {})
+cnt: 3, ((T([4096, 480], f16), [0], True), {})
+cnt: 3, ((T([4096, 720], f16), [0], True), {})
+cnt: 8, ((T([16384, 192], f16), [0], True), {})
+cnt: 4, ((T([16384, 384], f16), [0], True), {})
+cnt: 4, ((T([16384, 576], f16), [0], True), {})
+cnt: 4, ((T([65536, 144], f16), [0], True), {})
+cnt: 2, ((T([65536, 288], f16), [0], True), {})
+cnt: 2, ((T([65536, 432], f16), [0], True), {})
+Operator: aten.unbind.int
+cnt: 2, ((T([3, 256, 4, 256, 36], f16, stride=(144, 110592, 36, 432, 1)),), {})
+cnt: 4, ((T([3, 256, 4, 64, 48], f16, stride=(192, 36864, 48, 576, 1)),), {})
+cnt: 3, ((T([3, 256, 4, 16, 60], f16, stride=(240, 11520, 60, 720, 1)),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/nasnetalarge_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/nasnetalarge_training.txt
new file mode 100644
index 0000000000000..908397ba8fd11
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/nasnetalarge_training.txt
@@ -0,0 +1,309 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([16, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([16, 1000], f16), T([16, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([], i64), 1), {})
+cnt: 6, ((T([16, 42, 83, 83], f16), T([16, 42, 83, 83], f16)), {})
+cnt: 6, ((T([16, 84, 42, 42], f16), T([16, 84, 42, 42], f16)), {})
+cnt: 66, ((T([16, 168, 42, 42], f16), T([16, 168, 42, 42], f16)), {})
+cnt: 72, ((T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16)), {})
+cnt: 72, ((T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16)), {})
+cnt: 12, ((T([16, 672, 11, 11], f16, stride=(487872, 121, 11, 1)), T([16, 672, 11, 11], f16)), {})
+cnt: 6, ((T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16, stride=(487872, 121, 11, 1))), {})
+cnt: 4, ((T([16, 4032, 11, 11], f16), T([16, 4032, 11, 11], f16)), {})
+cnt: 1, ((T([16, 2688, 11, 11], f16), T([16, 2688, 11, 11], f16)), {})
+cnt: 7, ((T([16, 2016, 21, 21], f16), T([16, 2016, 21, 21], f16)), {})
+cnt: 1, ((T([16, 672, 11, 11], f16, stride=(325248, 121, 11, 1)), T([16, 672, 11, 11], f16, stride=(325248, 121, 11, 1))), {})
+cnt: 5, ((T([16, 672, 21, 21], f16), T([16, 672, 21, 21], f16)), {})
+cnt: 12, ((T([16, 336, 21, 21], f16, stride=(889056, 441, 21, 1)), T([16, 336, 21, 21], f16)), {})
+cnt: 6, ((T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16, stride=(889056, 441, 21, 1))), {})
+cnt: 1, ((T([16, 1344, 21, 21], f16), T([16, 1344, 21, 21], f16)), {})
+cnt: 7, ((T([16, 1008, 42, 42], f16), T([16, 1008, 42, 42], f16)), {})
+cnt: 1, ((T([16, 336, 21, 21], f16, stride=(592704, 441, 21, 1)), T([16, 336, 21, 21], f16, stride=(592704, 441, 21, 1))), {})
+cnt: 6, ((T([16, 336, 42, 42], f16), T([16, 336, 42, 42], f16)), {})
+cnt: 12, ((T([16, 168, 42, 42], f16, stride=(1778112, 1764, 42, 1)), T([16, 168, 42, 42], f16)), {})
+cnt: 6, ((T([16, 168, 42, 42], f16), T([16, 168, 42, 42], f16, stride=(1778112, 1764, 42, 1))), {})
+cnt: 2, ((T([16, 168, 83, 83], f16), T([16, 168, 83, 83], f16)), {})
+cnt: 1, ((T([16, 84, 42, 42], f16, stride=(592704, 1764, 42, 1)), T([16, 84, 42, 42], f16, stride=(592704, 1764, 42, 1))), {})
+cnt: 5, ((T([16, 84, 83, 83], f16), T([16, 84, 83, 83], f16)), {})
+cnt: 5, ((T([16, 96, 165, 165], f16), T([16, 96, 165, 165], f16)), {})
+cnt: 1, ((T([16, 42, 83, 83], f16, stride=(1157352, 6889, 83, 1)), T([16, 42, 83, 83], f16, stride=(1157352, 6889, 83, 1))), {})
+cnt: 3, ((T([16, 42, 165, 165], f16), T([16, 42, 165, 165], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 263, ((T([], i64), 1), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([16, 4032], f16), T([4032, 1000], f16, stride=(1, 4032))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([16, 42, 167, 167], f16), [3, 3], [2, 2], [0, 0], False, False), {})
+cnt: 1, ((T([16, 42, 83, 83], f16), [3, 3], [1, 1], [1, 1], False, False), {})
+cnt: 2, ((T([16, 96, 165, 165], f16), [1, 1], [2, 2], [0, 0], False, False), {})
+cnt: 1, ((T([16, 84, 85, 85], f16), [3, 3], [2, 2], [0, 0], False, False), {})
+cnt: 1, ((T([16, 84, 42, 42], f16), [3, 3], [1, 1], [1, 1], False, False), {})
+cnt: 2, ((T([16, 168, 83, 83], f16), [1, 1], [2, 2], [0, 0], False, False), {})
+cnt: 18, ((T([16, 168, 42, 42], f16), [3, 3], [1, 1], [1, 1], False, False), {})
+cnt: 1, ((T([16, 336, 43, 43], f16), [3, 3], [2, 2], [0, 0], False, False), {})
+cnt: 19, ((T([16, 336, 21, 21], f16), [3, 3], [1, 1], [1, 1], False, False), {})
+cnt: 2, ((T([16, 1008, 42, 42], f16), [1, 1], [2, 2], [0, 0], False, False), {})
+cnt: 1, ((T([16, 672, 23, 23], f16), [3, 3], [2, 2], [0, 0], False, False), {})
+cnt: 19, ((T([16, 672, 11, 11], f16), [3, 3], [1, 1], [1, 1], False, False), {})
+cnt: 2, ((T([16, 2016, 21, 21], f16), [1, 1], [2, 2], [0, 0], False, False), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 18, ((T([16, 672, 11, 11], f16, stride=(487872, 121, 11, 1)), T([16, 672, 11, 11], f16), [3, 3], [1, 1], [1, 1], False, False, None), {})
+cnt: 2, ((T([16, 2016, 11, 11], f16), T([16, 2016, 21, 21], f16), [1, 1], [2, 2], [0, 0], False, False, None), {})
+cnt: 1, ((T([16, 672, 11, 11], f16, stride=(325248, 121, 11, 1)), T([16, 672, 11, 11], f16), [3, 3], [1, 1], [1, 1], False, False, None), {})
+cnt: 1, ((T([16, 672, 11, 11], f16, stride=(325248, 121, 11, 1)), T([16, 672, 23, 23], f16), [3, 3], [2, 2], [0, 0], False, False, None), {})
+cnt: 18, ((T([16, 336, 21, 21], f16, stride=(889056, 441, 21, 1)), T([16, 336, 21, 21], f16), [3, 3], [1, 1], [1, 1], False, False, None), {})
+cnt: 2, ((T([16, 1008, 21, 21], f16), T([16, 1008, 42, 42], f16), [1, 1], [2, 2], [0, 0], False, False, None), {})
+cnt: 1, ((T([16, 336, 21, 21], f16, stride=(592704, 441, 21, 1)), T([16, 336, 21, 21], f16), [3, 3], [1, 1], [1, 1], False, False, None), {})
+cnt: 1, ((T([16, 336, 21, 21], f16, stride=(592704, 441, 21, 1)), T([16, 336, 43, 43], f16), [3, 3], [2, 2], [0, 0], False, False, None), {})
+cnt: 18, ((T([16, 168, 42, 42], f16, stride=(1778112, 1764, 42, 1)), T([16, 168, 42, 42], f16), [3, 3], [1, 1], [1, 1], False, False, None), {})
+cnt: 2, ((T([16, 168, 42, 42], f16), T([16, 168, 83, 83], f16), [1, 1], [2, 2], [0, 0], False, False, None), {})
+cnt: 1, ((T([16, 84, 42, 42], f16, stride=(592704, 1764, 42, 1)), T([16, 84, 42, 42], f16), [3, 3], [1, 1], [1, 1], False, False, None), {})
+cnt: 1, ((T([16, 84, 42, 42], f16, stride=(592704, 1764, 42, 1)), T([16, 84, 85, 85], f16), [3, 3], [2, 2], [0, 0], False, False, None), {})
+cnt: 2, ((T([16, 96, 83, 83], f16), T([16, 96, 165, 165], f16), [1, 1], [2, 2], [0, 0], False, False, None), {})
+cnt: 1, ((T([16, 42, 83, 83], f16, stride=(1157352, 6889, 83, 1)), T([16, 42, 83, 83], f16), [3, 3], [1, 1], [1, 1], False, False, None), {})
+cnt: 1, ((T([16, 42, 83, 83], f16, stride=(1157352, 6889, 83, 1)), T([16, 42, 167, 167], f16), [3, 3], [2, 2], [0, 0], False, False, None), {})
+Operator: aten.cat.default
+cnt: 1, (([T([16, 42, 83, 83], f16), T([16, 42, 83, 83], f16), T([16, 42, 83, 83], f16), T([16, 42, 83, 83], f16)], 1), {})
+cnt: 1, (([T([16, 42, 83, 83], f16), T([16, 42, 83, 83], f16)], 1), {})
+cnt: 1, (([T([16, 84, 42, 42], f16), T([16, 84, 42, 42], f16), T([16, 84, 42, 42], f16), T([16, 84, 42, 42], f16)], 1), {})
+cnt: 1, (([T([16, 84, 42, 42], f16), T([16, 84, 42, 42], f16)], 1), {})
+cnt: 6, (([T([16, 168, 42, 42], f16), T([16, 168, 42, 42], f16), T([16, 168, 42, 42], f16), T([16, 168, 42, 42], f16), T([16, 168, 42, 42], f16), T([16, 168, 42, 42], f16)], 1), {})
+cnt: 1, (([T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16)], 1), {})
+cnt: 1, (([T([16, 168, 21, 21], f16), T([16, 168, 21, 21], f16)], 1), {})
+cnt: 6, (([T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16)], 1), {})
+cnt: 1, (([T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16)], 1), {})
+cnt: 1, (([T([16, 336, 11, 11], f16), T([16, 336, 11, 11], f16)], 1), {})
+cnt: 6, (([T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([16, 3, 331, 331], f16),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 1, ((T([16, 42, 165, 165], f16), [2, 2, 2, 2], 0.0), {})
+cnt: 2, ((T([16, 96, 165, 165], f16), [3, 3, 3, 3], 0.0), {})
+cnt: 2, ((T([16, 42, 165, 165], f16), [1, 1, 1, 1], -inf), {})
+cnt: 1, ((T([16, 42, 165, 165], f16), [1, 1, 1, 1], 0.0), {})
+cnt: 1, ((T([16, 96, 165, 165], f16), [2, 2, 2, 2], 0.0), {})
+cnt: 1, ((T([16, 96, 165, 165], f16), [-1, 1, -1, 1], 0.0), {})
+cnt: 2, ((T([16, 84, 83, 83], f16), [2, 2, 2, 2], 0.0), {})
+cnt: 2, ((T([16, 84, 83, 83], f16), [3, 3, 3, 3], 0.0), {})
+cnt: 2, ((T([16, 84, 83, 83], f16), [1, 1, 1, 1], -inf), {})
+cnt: 1, ((T([16, 84, 83, 83], f16), [1, 1, 1, 1], 0.0), {})
+cnt: 1, ((T([16, 168, 83, 83], f16), [-1, 1, -1, 1], 0.0), {})
+cnt: 2, ((T([16, 336, 42, 42], f16), [1, 2, 1, 2], 0.0), {})
+cnt: 2, ((T([16, 336, 42, 42], f16), [2, 3, 2, 3], 0.0), {})
+cnt: 2, ((T([16, 336, 42, 42], f16), [0, 1, 0, 1], -inf), {})
+cnt: 1, ((T([16, 336, 42, 42], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([16, 1008, 42, 42], f16), [-1, 1, -1, 1], 0.0), {})
+cnt: 2, ((T([16, 672, 21, 21], f16), [2, 2, 2, 2], 0.0), {})
+cnt: 2, ((T([16, 672, 21, 21], f16), [3, 3, 3, 3], 0.0), {})
+cnt: 2, ((T([16, 672, 21, 21], f16), [1, 1, 1, 1], -inf), {})
+cnt: 1, ((T([16, 672, 21, 21], f16), [1, 1, 1, 1], 0.0), {})
+cnt: 1, ((T([16, 2016, 21, 21], f16), [-1, 1, -1, 1], 0.0), {})
+cnt: 1, ((T([16, 2016, 21, 21], f16), [1, -1, 1, -1]), {})
+cnt: 3, ((T([16, 672, 23, 23], f16), [-1, -1, -1, -1]), {})
+cnt: 2, ((T([16, 672, 25, 25], f16), [-2, -2, -2, -2]), {})
+cnt: 2, ((T([16, 672, 27, 27], f16), [-3, -3, -3, -3]), {})
+cnt: 1, ((T([16, 1008, 42, 42], f16), [1, -1, 1, -1]), {})
+cnt: 3, ((T([16, 336, 43, 43], f16), [0, -1, 0, -1]), {})
+cnt: 2, ((T([16, 336, 45, 45], f16), [-1, -2, -1, -2]), {})
+cnt: 2, ((T([16, 336, 47, 47], f16), [-2, -3, -2, -3]), {})
+cnt: 1, ((T([16, 168, 83, 83], f16), [1, -1, 1, -1]), {})
+cnt: 3, ((T([16, 84, 85, 85], f16), [-1, -1, -1, -1]), {})
+cnt: 2, ((T([16, 84, 87, 87], f16), [-2, -2, -2, -2]), {})
+cnt: 2, ((T([16, 84, 89, 89], f16), [-3, -3, -3, -3]), {})
+cnt: 1, ((T([16, 96, 165, 165], f16), [1, -1, 1, -1]), {})
+cnt: 3, ((T([16, 42, 167, 167], f16), [-1, -1, -1, -1]), {})
+cnt: 1, ((T([16, 96, 169, 169], f16), [-2, -2, -2, -2]), {})
+cnt: 2, ((T([16, 96, 171, 171], f16), [-3, -3, -3, -3]), {})
+cnt: 1, ((T([16, 42, 169, 169], f16), [-2, -2, -2, -2]), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([16, 3, 331, 331], f16), T([96, 3, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([16, 96, 165, 165], f16), T([42, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([16, 42, 169, 169], f16), T([42, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 42), {})
+cnt: 7, ((T([16, 42, 83, 83], f16), T([42, 42, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 42, 83, 83], f16), T([42, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 42), {})
+cnt: 2, ((T([16, 96, 171, 171], f16), T([96, 1, 7, 7], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 96), {})
+cnt: 5, ((T([16, 96, 83, 83], f16), T([42, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 42, 83, 83], f16), T([42, 1, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 42), {})
+cnt: 1, ((T([16, 96, 169, 169], f16), T([96, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 96), {})
+cnt: 2, ((T([16, 42, 83, 83], f16), T([42, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 42), {})
+cnt: 1, ((T([16, 168, 83, 83], f16), T([84, 168, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 84, 87, 87], f16), T([84, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 84), {})
+cnt: 10, ((T([16, 84, 42, 42], f16), T([84, 84, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 84, 42, 42], f16), T([84, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 84), {})
+cnt: 2, ((T([16, 84, 89, 89], f16), T([84, 1, 7, 7], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 84), {})
+cnt: 2, ((T([16, 84, 42, 42], f16), T([84, 1, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 84), {})
+cnt: 2, ((T([16, 84, 42, 42], f16), T([84, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 84), {})
+cnt: 2, ((T([16, 168, 42, 42], f16), T([84, 168, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 336, 42, 42], f16), T([168, 336, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 24, ((T([16, 168, 42, 42], f16), T([168, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 168), {})
+cnt: 60, ((T([16, 168, 42, 42], f16), T([168, 168, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 36, ((T([16, 168, 42, 42], f16), T([168, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 168), {})
+cnt: 9, ((T([16, 1008, 42, 42], f16), T([168, 1008, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 1008, 42, 42], f16), T([336, 1008, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 336, 45, 45], f16), T([336, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 336), {})
+cnt: 70, ((T([16, 336, 21, 21], f16), T([336, 336, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 26, ((T([16, 336, 21, 21], f16), T([336, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 336), {})
+cnt: 2, ((T([16, 336, 47, 47], f16), T([336, 1, 7, 7], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 336), {})
+cnt: 2, ((T([16, 336, 21, 21], f16), T([336, 1, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 336), {})
+cnt: 38, ((T([16, 336, 21, 21], f16), T([336, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 336), {})
+cnt: 2, ((T([16, 1008, 21, 21], f16), T([168, 1008, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 1344, 21, 21], f16), T([336, 1344, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 9, ((T([16, 2016, 21, 21], f16), T([336, 2016, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 2016, 21, 21], f16), T([672, 2016, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 672, 25, 25], f16), T([672, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 672), {})
+cnt: 70, ((T([16, 672, 11, 11], f16), T([672, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 26, ((T([16, 672, 11, 11], f16), T([672, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 672), {})
+cnt: 2, ((T([16, 672, 27, 27], f16), T([672, 1, 7, 7], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 672), {})
+cnt: 2, ((T([16, 672, 11, 11], f16), T([672, 1, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 672), {})
+cnt: 38, ((T([16, 672, 11, 11], f16), T([672, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 672), {})
+cnt: 2, ((T([16, 2016, 11, 11], f16), T([336, 2016, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 2688, 11, 11], f16), T([672, 2688, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 9, ((T([16, 4032, 11, 11], f16), T([672, 4032, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 70, ((T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16), T([672, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 38, ((T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16), T([672, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 26, ((T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16), T([672, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 9, ((T([16, 672, 11, 11], f16), T([16, 4032, 11, 11], f16), T([672, 4032, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 672, 11, 11], f16), T([16, 2688, 11, 11], f16), T([672, 2688, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 336, 11, 11], f16, stride=(81312, 121, 11, 1)), T([16, 2016, 11, 11], f16), T([336, 2016, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 672, 11, 11], f16), T([16, 672, 25, 25], f16), T([672, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 2, ((T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16), T([672, 1, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 2, ((T([16, 672, 11, 11], f16), T([16, 672, 27, 27], f16), T([672, 1, 7, 7], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 2, ((T([16, 672, 21, 21], f16), T([16, 2016, 21, 21], f16), T([672, 2016, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 70, ((T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16), T([336, 336, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 38, ((T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16), T([336, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 336, [True, True, False]), {})
+cnt: 26, ((T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16), T([336, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 336, [True, True, False]), {})
+cnt: 9, ((T([16, 336, 21, 21], f16), T([16, 2016, 21, 21], f16), T([336, 2016, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 336, 21, 21], f16), T([16, 1344, 21, 21], f16), T([336, 1344, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 168, 21, 21], f16, stride=(148176, 441, 21, 1)), T([16, 1008, 21, 21], f16), T([168, 1008, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 336, 21, 21], f16), T([16, 336, 45, 45], f16), T([336, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 336, [True, True, False]), {})
+cnt: 2, ((T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16), T([336, 1, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 336, [True, True, False]), {})
+cnt: 2, ((T([16, 336, 21, 21], f16), T([16, 336, 47, 47], f16), T([336, 1, 7, 7], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 336, [True, True, False]), {})
+cnt: 2, ((T([16, 336, 42, 42], f16), T([16, 1008, 42, 42], f16), T([336, 1008, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 60, ((T([16, 168, 42, 42], f16), T([16, 168, 42, 42], f16), T([168, 168, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 36, ((T([16, 168, 42, 42], f16), T([16, 168, 42, 42], f16), T([168, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 168, [True, True, False]), {})
+cnt: 24, ((T([16, 168, 42, 42], f16), T([16, 168, 42, 42], f16), T([168, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 168, [True, True, False]), {})
+cnt: 9, ((T([16, 168, 42, 42], f16), T([16, 1008, 42, 42], f16), T([168, 1008, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 168, 42, 42], f16), T([16, 336, 42, 42], f16), T([168, 336, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 84, 42, 42], f16, stride=(296352, 1764, 42, 1)), T([16, 168, 42, 42], f16), T([84, 168, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 10, ((T([16, 84, 42, 42], f16), T([16, 84, 42, 42], f16), T([84, 84, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 84, 42, 42], f16), T([16, 84, 42, 42], f16), T([84, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 84, [True, True, False]), {})
+cnt: 2, ((T([16, 84, 42, 42], f16), T([16, 84, 42, 42], f16), T([84, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 84, [True, True, False]), {})
+cnt: 2, ((T([16, 84, 42, 42], f16), T([16, 84, 87, 87], f16), T([84, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 84, [True, True, False]), {})
+cnt: 2, ((T([16, 84, 42, 42], f16), T([16, 84, 42, 42], f16), T([84, 1, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 84, [True, True, False]), {})
+cnt: 2, ((T([16, 84, 42, 42], f16), T([16, 84, 89, 89], f16), T([84, 1, 7, 7], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 84, [True, True, False]), {})
+cnt: 2, ((T([16, 42, 83, 83], f16, stride=(578676, 6889, 83, 1)), T([16, 96, 83, 83], f16), T([42, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 84, 83, 83], f16), T([16, 168, 83, 83], f16), T([84, 168, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 7, ((T([16, 42, 83, 83], f16), T([16, 42, 83, 83], f16), T([42, 42, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 42, 83, 83], f16), T([16, 42, 83, 83], f16), T([42, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 42, [True, True, False]), {})
+cnt: 2, ((T([16, 42, 83, 83], f16), T([16, 42, 83, 83], f16), T([42, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 42, [True, True, False]), {})
+cnt: 3, ((T([16, 42, 83, 83], f16), T([16, 96, 83, 83], f16), T([42, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 96, 83, 83], f16), T([16, 96, 169, 169], f16), T([96, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 96, [True, True, False]), {})
+cnt: 2, ((T([16, 42, 83, 83], f16), T([16, 42, 83, 83], f16), T([42, 1, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 42, [True, True, False]), {})
+cnt: 2, ((T([16, 96, 83, 83], f16), T([16, 96, 171, 171], f16), T([96, 1, 7, 7], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 96, [True, True, False]), {})
+cnt: 1, ((T([16, 42, 83, 83], f16), T([16, 42, 169, 169], f16), T([42, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 42, [True, True, False]), {})
+cnt: 1, ((T([16, 42, 165, 165], f16), T([16, 96, 165, 165], f16), T([42, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 96, 165, 165], f16), T([16, 3, 331, 331], f16), T([96, 3, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([16, 3, 331, 331], f16), T([16, 3, 331, 331], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([16, 4032, 11, 11], f16, stride=(4032, 1, 0, 0)), 121), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([16], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 2, ((T([16, 42, 167, 167], f16), [3, 3], [2, 2]), {})
+cnt: 2, ((T([16, 84, 85, 85], f16), [3, 3], [2, 2]), {})
+cnt: 2, ((T([16, 336, 43, 43], f16), [3, 3], [2, 2]), {})
+cnt: 2, ((T([16, 672, 23, 23], f16), [3, 3], [2, 2]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([16, 672, 11, 11], f16, stride=(325248, 121, 11, 1)), T([16, 672, 23, 23], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([16, 672, 11, 11], i64)), {})
+cnt: 1, ((T([16, 672, 11, 11], f16), T([16, 672, 23, 23], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([16, 672, 11, 11], i64)), {})
+cnt: 1, ((T([16, 336, 21, 21], f16, stride=(592704, 441, 21, 1)), T([16, 336, 43, 43], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([16, 336, 21, 21], i64)), {})
+cnt: 1, ((T([16, 336, 21, 21], f16), T([16, 336, 43, 43], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([16, 336, 21, 21], i64)), {})
+cnt: 1, ((T([16, 84, 42, 42], f16, stride=(592704, 1764, 42, 1)), T([16, 84, 85, 85], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([16, 84, 42, 42], i64)), {})
+cnt: 1, ((T([16, 84, 42, 42], f16), T([16, 84, 85, 85], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([16, 84, 42, 42], i64)), {})
+cnt: 1, ((T([16, 42, 83, 83], f16, stride=(1157352, 6889, 83, 1)), T([16, 42, 167, 167], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([16, 42, 83, 83], i64)), {})
+cnt: 1, ((T([16, 42, 83, 83], f16), T([16, 42, 167, 167], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([16, 42, 83, 83], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([16, 4032, 11, 11], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([16, 1000], f16), T([1000, 4032], f16)), {})
+cnt: 1, ((T([1000, 16], f16, stride=(1, 1000)), T([16, 4032], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([16, 96, 165, 165], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([16, 42, 165, 165], f16), T([42], f16), T([42], f16), T([42], f16), T([42], f16), True, 0.1, 0.001), {})
+cnt: 10, ((T([16, 42, 83, 83], f16), T([42], f16), T([42], f16), T([42], f16), T([42], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([16, 84, 83, 83], f16), T([84], f16), T([84], f16), T([84], f16), T([84], f16), True, 0.1, 0.001), {})
+cnt: 10, ((T([16, 84, 42, 42], f16), T([84], f16), T([84], f16), T([84], f16), T([84], f16), True, 0.1, 0.001), {})
+cnt: 72, ((T([16, 168, 42, 42], f16), T([168], f16), T([168], f16), T([168], f16), T([168], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([16, 336, 42, 42], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f16), True, 0.1, 0.001), {})
+cnt: 82, ((T([16, 336, 21, 21], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([16, 672, 21, 21], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), True, 0.1, 0.001), {})
+cnt: 82, ((T([16, 672, 11, 11], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), True, 0.1, 0.001), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 30, ((T([16, 672, 11, 11], f16, stride=(487872, 121, 11, 1)), T([16, 672, 11, 11], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 0.001, [True, True, True]), {})
+cnt: 50, ((T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([16, 672, 11, 11], f16, stride=(325248, 121, 11, 1)), T([16, 672, 11, 11], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([16, 672, 21, 21], f16), T([16, 672, 21, 21], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 0.001, [True, True, True]), {})
+cnt: 30, ((T([16, 336, 21, 21], f16, stride=(889056, 441, 21, 1)), T([16, 336, 21, 21], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f32), T([336], f32), True, 0.001, [True, True, True]), {})
+cnt: 50, ((T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f32), T([336], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([16, 336, 21, 21], f16, stride=(592704, 441, 21, 1)), T([16, 336, 21, 21], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f32), T([336], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([16, 336, 42, 42], f16), T([16, 336, 42, 42], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f32), T([336], f32), True, 0.001, [True, True, True]), {})
+cnt: 30, ((T([16, 168, 42, 42], f16, stride=(1778112, 1764, 42, 1)), T([16, 168, 42, 42], f16), T([168], f16), T([168], f16), T([168], f16), T([168], f32), T([168], f32), True, 0.001, [True, True, True]), {})
+cnt: 42, ((T([16, 168, 42, 42], f16), T([16, 168, 42, 42], f16), T([168], f16), T([168], f16), T([168], f16), T([168], f32), T([168], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([16, 84, 42, 42], f16, stride=(592704, 1764, 42, 1)), T([16, 84, 42, 42], f16), T([84], f16), T([84], f16), T([84], f16), T([84], f32), T([84], f32), True, 0.001, [True, True, True]), {})
+cnt: 8, ((T([16, 84, 42, 42], f16), T([16, 84, 42, 42], f16), T([84], f16), T([84], f16), T([84], f16), T([84], f32), T([84], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([16, 84, 83, 83], f16), T([16, 84, 83, 83], f16), T([84], f16), T([84], f16), T([84], f16), T([84], f32), T([84], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([16, 42, 83, 83], f16, stride=(1157352, 6889, 83, 1)), T([16, 42, 83, 83], f16), T([42], f16), T([42], f16), T([42], f16), T([42], f32), T([42], f32), True, 0.001, [True, True, True]), {})
+cnt: 8, ((T([16, 42, 83, 83], f16), T([16, 42, 83, 83], f16), T([42], f16), T([42], f16), T([42], f16), T([42], f32), T([42], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([16, 42, 165, 165], f16), T([16, 42, 165, 165], f16), T([42], f16), T([42], f16), T([42], f16), T([42], f32), T([42], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([16, 96, 165, 165], f16), T([16, 96, 165, 165], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 0.001, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([16, 1000], f16), T([16], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([16, 1000], f16), T([16], i64), None, 1, -100), {})
+Operator: aten.relu.default
+cnt: 5, ((T([16, 96, 165, 165], f16),), {})
+cnt: 1, ((T([16, 42, 165, 165], f16),), {})
+cnt: 1, ((T([16, 42, 83, 83], f16),), {})
+cnt: 2, ((T([16, 168, 83, 83], f16),), {})
+cnt: 4, ((T([16, 84, 83, 83], f16),), {})
+cnt: 1, ((T([16, 84, 42, 42], f16),), {})
+cnt: 6, ((T([16, 336, 42, 42], f16),), {})
+cnt: 30, ((T([16, 168, 42, 42], f16),), {})
+cnt: 12, ((T([16, 1008, 42, 42], f16),), {})
+cnt: 31, ((T([16, 336, 21, 21], f16),), {})
+cnt: 2, ((T([16, 1344, 21, 21], f16),), {})
+cnt: 12, ((T([16, 2016, 21, 21], f16),), {})
+cnt: 4, ((T([16, 672, 21, 21], f16),), {})
+cnt: 31, ((T([16, 672, 11, 11], f16),), {})
+cnt: 2, ((T([16, 2688, 11, 11], f16),), {})
+cnt: 9, ((T([16, 4032, 11, 11], f16),), {})
+Operator: aten.relu_.default
+cnt: 5, ((T([16, 42, 83, 83], f16),), {})
+cnt: 5, ((T([16, 84, 42, 42], f16),), {})
+cnt: 30, ((T([16, 168, 42, 42], f16),), {})
+cnt: 35, ((T([16, 336, 21, 21], f16),), {})
+cnt: 35, ((T([16, 672, 11, 11], f16),), {})
+cnt: 1, ((T([16, 4032, 11, 11], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([16, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 10, ((T([16, 4032, 11, 11], f16), T([16, 4032, 11, 11], f16), 0), {})
+cnt: 66, ((T([16, 672, 11, 11], f16), T([16, 672, 11, 11], f16), 0), {})
+cnt: 2, ((T([16, 2688, 11, 11], f16), T([16, 2688, 11, 11], f16), 0), {})
+cnt: 12, ((T([16, 2016, 21, 21], f16), T([16, 2016, 21, 21], f16), 0), {})
+cnt: 4, ((T([16, 672, 21, 21], f16), T([16, 672, 21, 21], f16), 0), {})
+cnt: 66, ((T([16, 336, 21, 21], f16), T([16, 336, 21, 21], f16), 0), {})
+cnt: 2, ((T([16, 1344, 21, 21], f16), T([16, 1344, 21, 21], f16), 0), {})
+cnt: 12, ((T([16, 1008, 42, 42], f16), T([16, 1008, 42, 42], f16), 0), {})
+cnt: 6, ((T([16, 336, 42, 42], f16), T([16, 336, 42, 42], f16), 0), {})
+cnt: 60, ((T([16, 168, 42, 42], f16), T([16, 168, 42, 42], f16), 0), {})
+cnt: 2, ((T([16, 168, 83, 83], f16), T([16, 168, 83, 83], f16), 0), {})
+cnt: 6, ((T([16, 84, 42, 42], f16), T([16, 84, 42, 42], f16), 0), {})
+cnt: 4, ((T([16, 84, 83, 83], f16), T([16, 84, 83, 83], f16), 0), {})
+cnt: 5, ((T([16, 96, 165, 165], f16), T([16, 96, 165, 165], f16), 0), {})
+cnt: 6, ((T([16, 42, 83, 83], f16), T([16, 42, 83, 83], f16), 0), {})
+cnt: 1, ((T([16, 42, 165, 165], f16), T([16, 42, 165, 165], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/nfnet_l0_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/nfnet_l0_training.txt
new file mode 100644
index 0000000000000..ae315ada2dfb9
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/nfnet_l0_training.txt
@@ -0,0 +1,267 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 3, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16)), {})
+cnt: 6, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16)), {})
+cnt: 18, ((T([128, 1536, 14, 14], f16), T([128, 1536, 14, 14], f16)), {})
+cnt: 8, ((T([128, 1536, 7, 7], f16), T([128, 1536, 7, 7], f16)), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([128, 128, 56, 56], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 2304], f16), T([2304, 1000], f16, stride=(1, 2304))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([128, 256, 56, 56], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+cnt: 1, ((T([128, 1536, 14, 14], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([128, 1536, 7, 7], f16), T([128, 1536, 14, 14], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+cnt: 1, ((T([128, 512, 14, 14], f16), T([128, 512, 28, 28], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+cnt: 1, ((T([128, 256, 28, 28], f16), T([128, 256, 56, 56], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+cnt: 1, ((T([128, 16, 112, 112], f16),), {})
+cnt: 1, ((T([128, 32, 112, 112], f16),), {})
+cnt: 1, ((T([128, 64, 112, 112], f16),), {})
+cnt: 2, ((T([128, 64, 56, 56], f16),), {})
+cnt: 1, ((T([128, 128, 56, 56], f16),), {})
+cnt: 3, ((T([128, 128, 28, 28], f16),), {})
+cnt: 1, ((T([128, 384, 28, 28], f16),), {})
+cnt: 12, ((T([128, 384, 14, 14], f16),), {})
+cnt: 5, ((T([128, 384, 7, 7], f16),), {})
+cnt: 1, ((T([128, 2304, 7, 7], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([16, 3, 3, 3], f16), T([16], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([32, 16, 3, 3], f16), T([32], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([64, 32, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 3, 3], f16), T([128], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([256, 128, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([64, 128, 1, 1], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 64, 56, 56], f16), T([64, 64, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([256, 64, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 1, 1], f16), T([64, 256, 1, 1], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 1, 1], f16), T([256, 64, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 28, 28], f16), T([512, 256, 1, 1], f16), T([512], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([128, 256, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([128, 64, 3, 3], f16), T([128], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 3, ((T([128, 128, 28, 28], f16), T([128, 64, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 2, ((T([128, 128, 28, 28], f16), T([512, 128, 1, 1], f16), T([512], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 512, 1, 1], f16), T([128, 512, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 128, 1, 1], f16), T([512, 128, 1, 1], f16), T([512], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([128, 512, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 14, 14], f16), T([1536, 512, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([384, 512, 1, 1], f16), T([384], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 384, 28, 28], f16), T([384, 64, 3, 3], f16), T([384], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 6), {})
+cnt: 11, ((T([128, 384, 14, 14], f16), T([384, 64, 3, 3], f16), T([384], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 6), {})
+cnt: 6, ((T([128, 384, 14, 14], f16), T([1536, 384, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 9, ((T([128, 1536, 1, 1], f16), T([384, 1536, 1, 1], f16), T([384], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 9, ((T([128, 384, 1, 1], f16), T([1536, 384, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([128, 1536, 14, 14], f16), T([384, 1536, 1, 1], f16), T([384], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1536, 7, 7], f16), T([1536, 1536, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 384, 14, 14], f16), T([384, 64, 3, 3], f16), T([384], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 6), {})
+cnt: 5, ((T([128, 384, 7, 7], f16), T([384, 64, 3, 3], f16), T([384], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 6), {})
+cnt: 3, ((T([128, 384, 7, 7], f16), T([1536, 384, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 1536, 7, 7], f16), T([384, 1536, 1, 1], f16), T([384], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1536, 7, 7], f16), T([2304, 1536, 1, 1], f16), T([2304], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 2304, 7, 7], f16), T([128, 1536, 7, 7], f16), T([2304, 1536, 1, 1], f16), [2304], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 9, ((T([128, 1536, 1, 1], f16), T([128, 384, 1, 1], f16), T([1536, 384, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 9, ((T([128, 384, 1, 1], f16), T([128, 1536, 1, 1], f16), T([384, 1536, 1, 1], f16), [384], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([128, 1536, 7, 7], f16), T([128, 384, 7, 7], f16), T([1536, 384, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 5, ((T([128, 384, 7, 7], f16), T([128, 384, 7, 7], f16), T([384, 64, 3, 3], f16), [384], [1, 1], [1, 1], [1, 1], False, [0, 0], 6, [True, True, True]), {})
+cnt: 2, ((T([128, 384, 7, 7], f16), T([128, 1536, 7, 7], f16), T([384, 1536, 1, 1], f16), [384], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 384, 7, 7], f16), T([128, 384, 14, 14], f16), T([384, 64, 3, 3], f16), [384], [2, 2], [1, 1], [1, 1], False, [0, 0], 6, [True, True, True]), {})
+cnt: 6, ((T([128, 384, 14, 14], f16), T([128, 1536, 14, 14], f16), T([384, 1536, 1, 1], f16), [384], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 1536, 7, 7], f16), T([128, 1536, 7, 7], f16), T([1536, 1536, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 6, ((T([128, 1536, 14, 14], f16), T([128, 384, 14, 14], f16), T([1536, 384, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 11, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16), T([384, 64, 3, 3], f16), [384], [1, 1], [1, 1], [1, 1], False, [0, 0], 6, [True, True, True]), {})
+cnt: 1, ((T([128, 384, 14, 14], f16), T([128, 384, 28, 28], f16), T([384, 64, 3, 3], f16), [384], [2, 2], [1, 1], [1, 1], False, [0, 0], 6, [True, True, True]), {})
+cnt: 1, ((T([128, 384, 28, 28], f16), T([128, 512, 28, 28], f16), T([384, 512, 1, 1], f16), [384], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 1536, 14, 14], f16), T([128, 512, 14, 14], f16), T([1536, 512, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 512, 1, 1], f16), T([128, 128, 1, 1], f16), T([512, 128, 1, 1], f16), [512], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 128, 1, 1], f16), T([128, 512, 1, 1], f16), T([128, 512, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 512, 28, 28], f16), T([128, 128, 28, 28], f16), T([512, 128, 1, 1], f16), [512], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([128, 128, 28, 28], f16), T([128, 128, 28, 28], f16), T([128, 64, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 2, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 28, 28], f16), T([128, 512, 28, 28], f16), T([128, 512, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 28, 28], f16), T([128, 128, 56, 56], f16), T([128, 64, 3, 3], f16), [128], [2, 2], [1, 1], [1, 1], False, [0, 0], 2, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([128, 256, 56, 56], f16), T([128, 256, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([128, 256, 28, 28], f16), T([512, 256, 1, 1], f16), [512], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 1, 1], f16), T([128, 64, 1, 1], f16), T([256, 64, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 1, 1], f16), T([128, 256, 1, 1], f16), T([64, 256, 1, 1], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([128, 64, 56, 56], f16), T([256, 64, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), T([64, 64, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 128, 56, 56], f16), T([64, 128, 1, 1], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([128, 128, 56, 56], f16), T([256, 128, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([128, 64, 112, 112], f16), T([128, 64, 3, 3], f16), [128], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 32, 112, 112], f16), T([64, 32, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 16, 112, 112], f16), T([32, 16, 3, 3], f16), [32], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 3, 224, 224], f16), T([16, 3, 3, 3], f16), [16], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 2304, 7, 7], f16, stride=(2304, 1, 0, 0)), 49), {})
+cnt: 3, ((T([128, 1536, 7, 7], f16, stride=(1536, 1, 0, 0)), 49), {})
+cnt: 6, ((T([128, 1536, 14, 14], f16, stride=(1536, 1, 0, 0)), 196), {})
+cnt: 2, ((T([128, 512, 28, 28], f16, stride=(512, 1, 0, 0)), 784), {})
+cnt: 1, ((T([128, 256, 56, 56], f16, stride=(256, 1, 0, 0)), 3136), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 256, 56, 56], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 512, 28, 28], f16), [2, 3], True), {})
+cnt: 6, ((T([128, 1536, 14, 14], f16), [2, 3], True), {})
+cnt: 3, ((T([128, 1536, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 2304, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 2304], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 2304], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([16, 1, 1, 1], f16), 0.34412564994580647), {})
+cnt: 2, ((T([32, 1, 1, 1], f16), 0.1490107774734497), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.10536653122135592), {})
+cnt: 10, ((T([128, 1, 1, 1], f16), 0.07450538873672485), {})
+cnt: 2, ((T([128, 128, 56, 56], f16), 1.0), {})
+cnt: 2, ((T([256, 1, 1, 1], f16), 0.1580497968320339), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.1580497968320339), {})
+cnt: 4, ((T([64, 1, 1, 1], f16), 0.07450538873672485), {})
+cnt: 2, ((T([256, 1, 1, 1], f16), 0.22351616621017456), {})
+cnt: 2, ((T([128, 256, 56, 56], f16), T([128, 256, 1, 1], f16)), {})
+cnt: 2, ((T([128, 256, 56, 56], f16), 2.0), {})
+cnt: 2, ((T([128, 256, 56, 56], f16), 0.2), {})
+cnt: 2, ((T([128, 256, 56, 56], f16), 0.9805806756909201), {})
+cnt: 2, ((T([512, 1, 1, 1], f16), 0.11175808310508728), {})
+cnt: 2, ((T([128, 1, 1, 1], f16), 0.11175808310508728), {})
+cnt: 4, ((T([512, 1, 1, 1], f16), 0.1580497968320339), {})
+cnt: 4, ((T([128, 512, 28, 28], f16), T([128, 512, 1, 1], f16)), {})
+cnt: 4, ((T([128, 512, 28, 28], f16), 2.0), {})
+cnt: 4, ((T([128, 512, 28, 28], f16), 0.2), {})
+cnt: 2, ((T([128, 512, 28, 28], f16), 0.9805806756909201), {})
+cnt: 2, ((T([128, 1, 1, 1], f16), 0.07902489841601695), {})
+cnt: 2, ((T([128, 512, 28, 28], f16), 0.9622504486493761), {})
+cnt: 2, ((T([1536, 1, 1, 1], f16), 0.07902489841601695), {})
+cnt: 2, ((T([384, 1, 1, 1], f16), 0.07902489841601695), {})
+cnt: 36, ((T([384, 1, 1, 1], f16), 0.07450538873672485), {})
+cnt: 18, ((T([1536, 1, 1, 1], f16), 0.09125009274634042), {})
+cnt: 12, ((T([128, 1536, 14, 14], f16), T([128, 1536, 1, 1], f16)), {})
+cnt: 12, ((T([128, 1536, 14, 14], f16), 2.0), {})
+cnt: 12, ((T([128, 1536, 14, 14], f16), 0.2), {})
+cnt: 2, ((T([128, 1536, 14, 14], f16), 0.9805806756909201), {})
+cnt: 16, ((T([384, 1, 1, 1], f16), 0.04562504637317021), {})
+cnt: 2, ((T([128, 1536, 14, 14], f16), 0.9622504486493761), {})
+cnt: 2, ((T([128, 1536, 14, 14], f16), 0.9449111825230679), {})
+cnt: 2, ((T([128, 1536, 14, 14], f16), 0.9284766908852592), {})
+cnt: 2, ((T([128, 1536, 14, 14], f16), 0.9128709291752768), {})
+cnt: 2, ((T([128, 1536, 14, 14], f16), 0.8980265101338745), {})
+cnt: 2, ((T([1536, 1, 1, 1], f16), 0.04562504637317021), {})
+cnt: 6, ((T([128, 1536, 7, 7], f16), T([128, 1536, 1, 1], f16)), {})
+cnt: 6, ((T([128, 1536, 7, 7], f16), 2.0), {})
+cnt: 6, ((T([128, 1536, 7, 7], f16), 0.2), {})
+cnt: 2, ((T([128, 1536, 7, 7], f16), 0.9805806756909201), {})
+cnt: 2, ((T([128, 1536, 7, 7], f16), 0.9622504486493761), {})
+cnt: 2, ((T([2304, 1, 1, 1], f16), 0.04562504637317021), {})
+cnt: 3, ((T([128, 1536, 7, 7], f16), T([128, 1536, 7, 7], f16)), {})
+cnt: 6, ((T([128, 1536, 14, 14], f16), T([128, 1536, 14, 14], f16)), {})
+cnt: 2, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16)), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([1, 16, 27], f16), T([16], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 32, 144], f16), T([32], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 64, 288], f16), T([64], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 5, ((T([1, 128, 576], f16), T([128], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 256, 128], f16), T([256], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 64, 128], f16), T([64], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 2, ((T([1, 64, 576], f16), T([64], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 256, 64], f16), T([256], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 512, 256], f16), T([512], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 128, 256], f16), T([128], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 2, ((T([1, 512, 128], f16), T([512], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 128, 512], f16), T([128], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 1536, 512], f16), T([1536], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 384, 512], f16), T([384], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 18, ((T([1, 384, 576], f16), T([384], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 9, ((T([1, 1536, 384], f16), T([1536], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 8, ((T([1, 384, 1536], f16), T([384], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 1536, 1536], f16), T([1536], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 2304, 1536], f16), T([2304], f16), None, None, None, True, 0.0, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([1, 2304, 1536], f16), T([1, 2304, 1536], f16), T([2304], f16), None, None, T([2304], f32), T([2304], f32), True, 1e-05, [True, True, False]), {})
+cnt: 9, ((T([1, 1536, 384], f16), T([1, 1536, 384], f16), T([1536], f16), None, None, T([1536], f32), T([1536], f32), True, 1e-05, [True, True, False]), {})
+cnt: 18, ((T([1, 384, 576], f16), T([1, 384, 576], f16), T([384], f16), None, None, T([384], f32), T([384], f32), True, 1e-05, [True, True, False]), {})
+cnt: 8, ((T([1, 384, 1536], f16), T([1, 384, 1536], f16), T([384], f16), None, None, T([384], f32), T([384], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 1536, 1536], f16), T([1, 1536, 1536], f16), T([1536], f16), None, None, T([1536], f32), T([1536], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 384, 512], f16), T([1, 384, 512], f16), T([384], f16), None, None, T([384], f32), T([384], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 1536, 512], f16), T([1, 1536, 512], f16), T([1536], f16), None, None, T([1536], f32), T([1536], f32), True, 1e-05, [True, True, False]), {})
+cnt: 2, ((T([1, 512, 128], f16), T([1, 512, 128], f16), T([512], f16), None, None, T([512], f32), T([512], f32), True, 1e-05, [True, True, False]), {})
+cnt: 5, ((T([1, 128, 576], f16), T([1, 128, 576], f16), T([128], f16), None, None, T([128], f32), T([128], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 128, 512], f16), T([1, 128, 512], f16), T([128], f16), None, None, T([128], f32), T([128], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 128, 256], f16), T([1, 128, 256], f16), T([128], f16), None, None, T([128], f32), T([128], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 512, 256], f16), T([1, 512, 256], f16), T([512], f16), None, None, T([512], f32), T([512], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 256, 64], f16), T([1, 256, 64], f16), T([256], f16), None, None, T([256], f32), T([256], f32), True, 1e-05, [True, True, False]), {})
+cnt: 2, ((T([1, 64, 576], f16), T([1, 64, 576], f16), T([64], f16), None, None, T([64], f32), T([64], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 64, 128], f16), T([1, 64, 128], f16), T([64], f16), None, None, T([64], f32), T([64], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 256, 128], f16), T([1, 256, 128], f16), T([256], f16), None, None, T([256], f32), T([256], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 64, 288], f16), T([1, 64, 288], f16), T([64], f16), None, None, T([64], f32), T([64], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 32, 144], f16), T([1, 32, 144], f16), T([32], f16), None, None, T([32], f32), T([32], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 16, 27], f16), T([1, 16, 27], f16), T([16], f16), None, None, T([16], f32), T([16], f32), True, 1e-05, [True, True, False]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 64, 1, 1], f16),), {})
+cnt: 2, ((T([128, 128, 1, 1], f16),), {})
+cnt: 9, ((T([128, 384, 1, 1], f16),), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([128, 256, 1, 1], f16),), {})
+cnt: 2, ((T([128, 512, 1, 1], f16),), {})
+cnt: 9, ((T([128, 1536, 1, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 9, ((T([128, 1536, 1, 1], f16), T([128, 1536, 1, 1], f16)), {})
+cnt: 2, ((T([128, 512, 1, 1], f16), T([128, 512, 1, 1], f16)), {})
+cnt: 1, ((T([128, 256, 1, 1], f16), T([128, 256, 1, 1], f16)), {})
+Operator: aten.silu.default
+cnt: 1, ((T([128, 128, 56, 56], f16),), {})
+cnt: 1, ((T([128, 64, 56, 56], f16),), {})
+cnt: 1, ((T([128, 256, 56, 56], f16),), {})
+cnt: 2, ((T([128, 128, 28, 28], f16),), {})
+cnt: 2, ((T([128, 512, 28, 28], f16),), {})
+cnt: 6, ((T([128, 384, 14, 14], f16),), {})
+cnt: 6, ((T([128, 1536, 14, 14], f16),), {})
+cnt: 3, ((T([128, 384, 7, 7], f16),), {})
+cnt: 2, ((T([128, 1536, 7, 7], f16),), {})
+Operator: aten.silu_.default
+cnt: 1, ((T([128, 16, 112, 112], f16),), {})
+cnt: 1, ((T([128, 32, 112, 112], f16),), {})
+cnt: 1, ((T([128, 64, 112, 112], f16),), {})
+cnt: 2, ((T([128, 64, 56, 56], f16),), {})
+cnt: 1, ((T([128, 128, 56, 56], f16),), {})
+cnt: 3, ((T([128, 128, 28, 28], f16),), {})
+cnt: 1, ((T([128, 384, 28, 28], f16),), {})
+cnt: 12, ((T([128, 384, 14, 14], f16),), {})
+cnt: 5, ((T([128, 384, 7, 7], f16),), {})
+cnt: 1, ((T([128, 2304, 7, 7], f16),), {})
+Operator: aten.silu_backward.default
+cnt: 1, ((T([128, 2304, 7, 7], f16), T([128, 2304, 7, 7], f16)), {})
+cnt: 8, ((T([128, 384, 7, 7], f16), T([128, 384, 7, 7], f16)), {})
+cnt: 2, ((T([128, 1536, 7, 7], f16), T([128, 1536, 7, 7], f16)), {})
+cnt: 18, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16)), {})
+cnt: 6, ((T([128, 1536, 14, 14], f16), T([128, 1536, 14, 14], f16)), {})
+cnt: 1, ((T([128, 384, 28, 28], f16), T([128, 384, 28, 28], f16)), {})
+cnt: 2, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16)), {})
+cnt: 5, ((T([128, 128, 28, 28], f16), T([128, 128, 28, 28], f16)), {})
+cnt: 2, ((T([128, 128, 56, 56], f16), T([128, 128, 56, 56], f16)), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16)), {})
+cnt: 3, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16)), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16)), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16)), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 3, ((T([128, 1536, 7, 7], f16), [2, 3], True), {})
+cnt: 6, ((T([128, 1536, 14, 14], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 512, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), [2, 3], True), {})
+Operator: aten.threshold_backward.default
+cnt: 9, ((T([128, 384, 1, 1], f16), T([128, 384, 1, 1], f16), 0), {})
+cnt: 2, ((T([128, 128, 1, 1], f16), T([128, 128, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 64, 1, 1], f16), T([128, 64, 1, 1], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/pit_b_224_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/pit_b_224_training.txt
new file mode 100644
index 0000000000000..d26a9ef24d6f2
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/pit_b_224_training.txt
@@ -0,0 +1,185 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 3, ((T([64, 4, 962, 962], f16), -1, False), {})
+cnt: 6, ((T([64, 8, 257, 257], f16), -1, False), {})
+cnt: 4, ((T([64, 16, 65, 65], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 4, ((T([64, 16, 65, 65], f16), T([64, 16, 65, 65], f16), -1, f16), {})
+cnt: 6, ((T([64, 8, 257, 257], f16), T([64, 8, 257, 257], f16), -1, f16), {})
+cnt: 3, ((T([64, 4, 962, 962], f16), T([64, 4, 962, 962], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 9, ((T([64, 4, 962, 64], f16), [256, 962, 64]), {})
+cnt: 3, ((T([64, 4, 64, 962], f16), [256, 64, 962]), {})
+cnt: 3, ((T([256, 962, 962], f16), [64, 4, 962, 962]), {})
+cnt: 3, ((T([256, 962, 64], f16), [64, 4, 962, 64]), {})
+cnt: 3, ((T([64, 962, 4, 64], f16), [64, 962, 256]), {})
+cnt: 1, ((T([64, 512], f16), [64, 1, 512]), {})
+cnt: 18, ((T([64, 8, 257, 64], f16), [512, 257, 64]), {})
+cnt: 6, ((T([64, 8, 64, 257], f16), [512, 64, 257]), {})
+cnt: 6, ((T([512, 257, 257], f16), [64, 8, 257, 257]), {})
+cnt: 6, ((T([512, 257, 64], f16), [64, 8, 257, 64]), {})
+cnt: 6, ((T([64, 257, 8, 64], f16), [64, 257, 512]), {})
+cnt: 1, ((T([64, 1024], f16), [64, 1, 1024]), {})
+cnt: 12, ((T([64, 16, 65, 64], f16), [1024, 65, 64]), {})
+cnt: 4, ((T([64, 16, 64, 65], f16), [1024, 64, 65]), {})
+cnt: 4, ((T([1024, 65, 65], f16), [64, 16, 65, 65]), {})
+cnt: 4, ((T([1024, 65, 64], f16), [64, 16, 65, 64]), {})
+cnt: 4, ((T([64, 65, 16, 64], f16), [64, 65, 1024]), {})
+cnt: 4, ((T([64, 65, 3, 16, 64], f16), [64, 65, 3072]), {})
+cnt: 6, ((T([64, 257, 3, 8, 64], f16), [64, 257, 1536]), {})
+cnt: 3, ((T([64, 962, 3, 4, 64], f16), [64, 962, 768]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([64, 256, 31, 31], f16), T([1, 256, 31, 31], f16)), {})
+cnt: 13, ((T([64, 962, 256], f16), T([64, 962, 256], f16)), {})
+cnt: 1, ((T([64, 1, 512], f16), T([512], f16)), {})
+cnt: 25, ((T([64, 257, 512], f16), T([64, 257, 512], f16)), {})
+cnt: 1, ((T([64, 1, 1024], f16), T([1024], f16)), {})
+cnt: 16, ((T([64, 65, 1024], f16), T([64, 65, 1024], f16)), {})
+Operator: aten.addmm.default
+cnt: 3, ((T([768], f16), T([61568, 256], f16), T([256, 768], f16, stride=(1, 256))), {})
+cnt: 3, ((T([256], f16), T([61568, 256], f16), T([256, 256], f16, stride=(1, 256))), {})
+cnt: 3, ((T([1024], f16), T([61568, 256], f16), T([256, 1024], f16, stride=(1, 256))), {})
+cnt: 3, ((T([256], f16), T([61568, 1024], f16), T([1024, 256], f16, stride=(1, 1024))), {})
+cnt: 6, ((T([1536], f16), T([16448, 512], f16), T([512, 1536], f16, stride=(1, 512))), {})
+cnt: 6, ((T([512], f16), T([16448, 512], f16), T([512, 512], f16, stride=(1, 512))), {})
+cnt: 6, ((T([2048], f16), T([16448, 512], f16), T([512, 2048], f16, stride=(1, 512))), {})
+cnt: 6, ((T([512], f16), T([16448, 2048], f16), T([2048, 512], f16, stride=(1, 2048))), {})
+cnt: 4, ((T([3072], f16), T([4160, 1024], f16), T([1024, 3072], f16, stride=(1, 1024))), {})
+cnt: 4, ((T([1024], f16), T([4160, 1024], f16), T([1024, 1024], f16, stride=(1, 1024))), {})
+cnt: 4, ((T([4096], f16), T([4160, 1024], f16), T([1024, 4096], f16, stride=(1, 1024))), {})
+cnt: 4, ((T([1024], f16), T([4160, 4096], f16), T([4096, 1024], f16, stride=(1, 4096))), {})
+cnt: 1, ((T([1000], f16), T([64, 1024], f16), T([1024, 1000], f16, stride=(1, 1024))), {})
+Operator: aten.bmm.default
+cnt: 3, ((T([256, 962, 64], f16), T([256, 64, 962], f16)), {})
+cnt: 3, ((T([256, 962, 962], f16), T([256, 962, 64], f16)), {})
+cnt: 6, ((T([512, 257, 64], f16), T([512, 64, 257], f16)), {})
+cnt: 6, ((T([512, 257, 257], f16), T([512, 257, 64], f16)), {})
+cnt: 4, ((T([1024, 65, 64], f16), T([1024, 64, 65], f16)), {})
+cnt: 4, ((T([1024, 65, 65], f16), T([1024, 65, 64], f16)), {})
+cnt: 4, ((T([1024, 65, 65], f16, stride=(4225, 1, 65)), T([1024, 65, 64], f16)), {})
+cnt: 4, ((T([1024, 65, 64], f16), T([1024, 64, 65], f16, stride=(4160, 1, 64))), {})
+cnt: 4, ((T([1024, 64, 65], f16, stride=(4160, 1, 64)), T([1024, 65, 65], f16)), {})
+cnt: 4, ((T([1024, 65, 65], f16), T([1024, 65, 64], f16, stride=(4160, 1, 65))), {})
+cnt: 6, ((T([512, 257, 257], f16, stride=(66049, 1, 257)), T([512, 257, 64], f16)), {})
+cnt: 6, ((T([512, 257, 64], f16), T([512, 64, 257], f16, stride=(16448, 1, 64))), {})
+cnt: 6, ((T([512, 64, 257], f16, stride=(16448, 1, 64)), T([512, 257, 257], f16)), {})
+cnt: 6, ((T([512, 257, 257], f16), T([512, 257, 64], f16, stride=(16448, 1, 257))), {})
+cnt: 3, ((T([256, 962, 962], f16, stride=(925444, 1, 962)), T([256, 962, 64], f16)), {})
+cnt: 3, ((T([256, 962, 64], f16), T([256, 64, 962], f16, stride=(61568, 1, 64))), {})
+cnt: 3, ((T([256, 64, 962], f16, stride=(61568, 1, 64)), T([256, 962, 962], f16)), {})
+cnt: 3, ((T([256, 962, 962], f16), T([256, 962, 64], f16, stride=(61568, 1, 962))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 1, 256], f16, stride=(0, 256, 1)), T([64, 961, 256], f16, stride=(246016, 1, 961))], 1), {})
+cnt: 1, (([T([64, 1, 512], f16), T([64, 256, 512], f16, stride=(131072, 1, 256))], 1), {})
+cnt: 1, (([T([64, 1, 1024], f16), T([64, 64, 1024], f16, stride=(65536, 1, 64))], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([256, 3, 14, 14], f16), T([256], f16), [7, 7], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 31, 31], f16, stride=(246272, 1, 7936, 256)), T([512, 1, 3, 3], f16), T([512], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 256), {})
+cnt: 1, ((T([64, 512, 16, 16], f16, stride=(131584, 1, 8192, 512)), T([1024, 1, 3, 3], f16), T([1024], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 512), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 1024, 8, 8], f16, stride=(66560, 1, 8192, 1024)), T([64, 512, 16, 16], f16, stride=(131584, 1, 8192, 512)), T([1024, 1, 3, 3], f16), [1024], [2, 2], [1, 1], [1, 1], False, [0, 0], 512, [True, True, True]), {})
+cnt: 1, ((T([64, 512, 16, 16], f16, stride=(131584, 1, 8192, 512)), T([64, 256, 31, 31], f16, stride=(246272, 1, 7936, 256)), T([512, 1, 3, 3], f16), [512], [2, 2], [1, 1], [1, 1], False, [0, 0], 256, [True, True, True]), {})
+cnt: 1, ((T([64, 256, 31, 31], f16, stride=(246272, 1, 7936, 256)), T([64, 3, 224, 224], f16), T([256, 3, 14, 14], f16), [256], [7, 7], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+Operator: aten.gelu.default
+cnt: 3, ((T([64, 962, 1024], f16),), {})
+cnt: 6, ((T([64, 257, 2048], f16),), {})
+cnt: 4, ((T([64, 65, 4096], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 4, ((T([64, 65, 4096], f16), T([64, 65, 4096], f16)), {})
+cnt: 6, ((T([64, 257, 2048], f16), T([64, 257, 2048], f16)), {})
+cnt: 3, ((T([64, 962, 1024], f16), T([64, 962, 1024], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.mm.default
+cnt: 1, ((T([64, 256], f16, stride=(246272, 1)), T([256, 512], f16, stride=(1, 256))), {})
+cnt: 1, ((T([64, 512], f16, stride=(131584, 1)), T([512, 1024], f16, stride=(1, 512))), {})
+cnt: 1, ((T([64, 1000], f16), T([1000, 1024], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 1024], f16)), {})
+cnt: 4, ((T([4160, 1024], f16), T([1024, 4096], f16)), {})
+cnt: 4, ((T([1024, 4160], f16, stride=(1, 1024)), T([4160, 4096], f16)), {})
+cnt: 4, ((T([4160, 4096], f16), T([4096, 1024], f16)), {})
+cnt: 4, ((T([4096, 4160], f16, stride=(1, 4096)), T([4160, 1024], f16)), {})
+cnt: 4, ((T([4160, 1024], f16), T([1024, 1024], f16)), {})
+cnt: 4, ((T([1024, 4160], f16, stride=(1, 1024)), T([4160, 1024], f16)), {})
+cnt: 4, ((T([4160, 3072], f16), T([3072, 1024], f16)), {})
+cnt: 4, ((T([3072, 4160], f16, stride=(1, 3072)), T([4160, 1024], f16)), {})
+cnt: 1, ((T([1024, 64], f16, stride=(1, 66560)), T([64, 512], f16, stride=(131584, 1))), {})
+cnt: 1, ((T([64, 1024], f16, stride=(66560, 1)), T([1024, 512], f16)), {})
+cnt: 6, ((T([16448, 512], f16), T([512, 2048], f16)), {})
+cnt: 6, ((T([512, 16448], f16, stride=(1, 512)), T([16448, 2048], f16)), {})
+cnt: 6, ((T([16448, 2048], f16), T([2048, 512], f16)), {})
+cnt: 6, ((T([2048, 16448], f16, stride=(1, 2048)), T([16448, 512], f16)), {})
+cnt: 6, ((T([16448, 512], f16), T([512, 512], f16)), {})
+cnt: 6, ((T([512, 16448], f16, stride=(1, 512)), T([16448, 512], f16)), {})
+cnt: 6, ((T([16448, 1536], f16), T([1536, 512], f16)), {})
+cnt: 6, ((T([1536, 16448], f16, stride=(1, 1536)), T([16448, 512], f16)), {})
+cnt: 1, ((T([512, 64], f16, stride=(1, 131584)), T([64, 256], f16, stride=(246272, 1))), {})
+cnt: 1, ((T([64, 512], f16, stride=(131584, 1)), T([512, 256], f16)), {})
+cnt: 3, ((T([61568, 256], f16), T([256, 1024], f16)), {})
+cnt: 3, ((T([256, 61568], f16, stride=(1, 256)), T([61568, 1024], f16)), {})
+cnt: 3, ((T([61568, 1024], f16), T([1024, 256], f16)), {})
+cnt: 3, ((T([1024, 61568], f16, stride=(1, 1024)), T([61568, 256], f16)), {})
+cnt: 3, ((T([61568, 256], f16), T([256, 256], f16)), {})
+cnt: 3, ((T([256, 61568], f16, stride=(1, 256)), T([61568, 256], f16)), {})
+cnt: 3, ((T([61568, 768], f16), T([768, 256], f16)), {})
+cnt: 3, ((T([768, 61568], f16, stride=(1, 768)), T([61568, 256], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 6, ((T([64, 4, 962, 962], f16), 0.125), {})
+cnt: 12, ((T([64, 8, 257, 257], f16), 0.125), {})
+cnt: 8, ((T([64, 16, 65, 65], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 6, ((T([64, 962, 256], f16), [256], T([256], f16), T([256], f16), 1e-06), {})
+cnt: 12, ((T([64, 257, 512], f16), [512], T([512], f16), T([512], f16), 1e-06), {})
+cnt: 8, ((T([64, 65, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-06), {})
+cnt: 1, ((T([64, 1, 1024], f16, stride=(66560, 1024, 1)), [1024], T([1024], f16), T([1024], f16), 1e-06), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 1, ((T([64, 1, 1024], f16), T([64, 1, 1024], f16, stride=(66560, 1024, 1)), [1024], T([64, 1, 1], f32), T([64, 1, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+cnt: 8, ((T([64, 65, 1024], f16), T([64, 65, 1024], f16), [1024], T([64, 65, 1], f32), T([64, 65, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+cnt: 12, ((T([64, 257, 512], f16), T([64, 257, 512], f16), [512], T([64, 257, 1], f32), T([64, 257, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+cnt: 6, ((T([64, 962, 256], f16), T([64, 962, 256], f16), [256], T([64, 962, 1], f32), T([64, 962, 1], f32), T([256], f16), T([256], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.select_backward.default
+cnt: 1, ((T([64, 1024], f16), [64, 1, 1024], 1, 0), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([64, 1, 1024], f16), [64, 1, 1024], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([64, 1, 1024], f16), [64, 65, 1024], 1, 0, 1, 1), {})
+cnt: 1, ((T([64, 65, 1024], f16), [64, 65, 1024], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([64, 256, 512], f16), [64, 257, 512], 1, 1, 9223372036854775807, 1), {})
+cnt: 2, ((T([64, 257, 512], f16), [64, 257, 512], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([64, 1, 512], f16), [64, 257, 512], 1, 0, 1, 1), {})
+cnt: 1, ((T([64, 961, 256], f16), [64, 962, 256], 1, 1, 9223372036854775807, 1), {})
+cnt: 2, ((T([64, 962, 256], f16), [64, 962, 256], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([64, 1, 256], f16), [64, 962, 256], 1, 0, 1, 1), {})
+Operator: aten.stack.default
+cnt: 4, (([T([64, 16, 65, 64], f16), T([64, 16, 65, 64], f16, stride=(66560, 4160, 1, 65)), T([64, 16, 65, 64], f16)],), {})
+cnt: 6, (([T([64, 8, 257, 64], f16), T([64, 8, 257, 64], f16, stride=(131584, 16448, 1, 257)), T([64, 8, 257, 64], f16)],), {})
+cnt: 3, (([T([64, 4, 962, 64], f16), T([64, 4, 962, 64], f16, stride=(246272, 61568, 1, 962)), T([64, 4, 962, 64], f16)],), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 8, ((T([4160, 1024], f16), [0], True), {})
+cnt: 4, ((T([4160, 4096], f16), [0], True), {})
+cnt: 4, ((T([4160, 3072], f16), [0], True), {})
+cnt: 1, ((T([64, 1, 1024], f16, stride=(66560, 1024, 1)), [0, 1], True), {})
+cnt: 12, ((T([16448, 512], f16), [0], True), {})
+cnt: 6, ((T([16448, 2048], f16), [0], True), {})
+cnt: 6, ((T([16448, 1536], f16), [0], True), {})
+cnt: 1, ((T([64, 1, 512], f16, stride=(131584, 512, 1)), [0, 1], True), {})
+cnt: 6, ((T([61568, 256], f16), [0], True), {})
+cnt: 3, ((T([61568, 1024], f16), [0], True), {})
+cnt: 3, ((T([61568, 768], f16), [0], True), {})
+cnt: 1, ((T([64, 1, 256], f16, stride=(246272, 256, 1)), [0], True), {})
+cnt: 1, ((T([64, 256, 31, 31], f16, stride=(246272, 1, 7936, 256)), [0], True), {})
+Operator: aten.unbind.int
+cnt: 3, ((T([3, 64, 4, 962, 64], f16, stride=(256, 738816, 64, 768, 1)),), {})
+cnt: 6, ((T([3, 64, 8, 257, 64], f16, stride=(512, 394752, 64, 1536, 1)),), {})
+cnt: 4, ((T([3, 64, 16, 65, 64], f16, stride=(1024, 199680, 64, 3072, 1)),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/pnasnet5large_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/pnasnet5large_training.txt
new file mode 100644
index 0000000000000..c6d164aa51780
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/pnasnet5large_training.txt
@@ -0,0 +1,293 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([16, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([16, 1000], f16), T([16, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([], i64), 1), {})
+cnt: 5, ((T([16, 54, 83, 83], f16), T([16, 54, 83, 83], f16)), {})
+cnt: 5, ((T([16, 108, 42, 42], f16), T([16, 108, 42, 42], f16)), {})
+cnt: 44, ((T([16, 216, 42, 42], f16), T([16, 216, 42, 42], f16)), {})
+cnt: 38, ((T([16, 432, 21, 21], f16), T([16, 432, 21, 21], f16)), {})
+cnt: 38, ((T([16, 864, 11, 11], f16), T([16, 864, 11, 11], f16)), {})
+cnt: 7, ((T([16, 864, 11, 11], f16, stride=(522720, 121, 11, 1)), T([16, 864, 11, 11], f16)), {})
+cnt: 2, ((T([16, 4320, 11, 11], f16), T([16, 4320, 11, 11], f16)), {})
+cnt: 5, ((T([16, 2160, 21, 21], f16), T([16, 2160, 21, 21], f16)), {})
+cnt: 7, ((T([16, 864, 21, 21], f16), T([16, 864, 21, 21], f16)), {})
+cnt: 7, ((T([16, 432, 21, 21], f16, stride=(952560, 441, 21, 1)), T([16, 432, 21, 21], f16)), {})
+cnt: 5, ((T([16, 1080, 42, 42], f16), T([16, 1080, 42, 42], f16)), {})
+cnt: 7, ((T([16, 432, 42, 42], f16), T([16, 432, 42, 42], f16)), {})
+cnt: 8, ((T([16, 216, 42, 42], f16, stride=(1905120, 1764, 42, 1)), T([16, 216, 42, 42], f16)), {})
+cnt: 1, ((T([16, 540, 42, 42], f16), T([16, 540, 42, 42], f16)), {})
+cnt: 2, ((T([16, 270, 83, 83], f16), T([16, 270, 83, 83], f16)), {})
+cnt: 7, ((T([16, 108, 83, 83], f16), T([16, 108, 83, 83], f16)), {})
+cnt: 1, ((T([16, 108, 42, 42], f16, stride=(952560, 1764, 42, 1)), T([16, 108, 42, 42], f16)), {})
+cnt: 5, ((T([16, 96, 165, 165], f16), T([16, 96, 165, 165], f16)), {})
+cnt: 5, ((T([16, 54, 165, 165], f16), T([16, 54, 165, 165], f16)), {})
+cnt: 1, ((T([16, 54, 83, 83], f16, stride=(1860030, 6889, 83, 1)), T([16, 54, 83, 83], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 200, ((T([], i64), 1), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([16, 4320], f16), T([4320, 1000], f16, stride=(1, 4320))), {})
+Operator: aten.avg_pool2d.default
+cnt: 2, ((T([16, 96, 165, 165], f16), [1, 1], [2, 2], [0, 0], False, False), {})
+cnt: 2, ((T([16, 270, 83, 83], f16), [1, 1], [2, 2], [0, 0], False, False), {})
+cnt: 2, ((T([16, 1080, 42, 42], f16), [1, 1], [2, 2], [0, 0], False, False), {})
+cnt: 2, ((T([16, 2160, 21, 21], f16), [1, 1], [2, 2], [0, 0], False, False), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 2, ((T([16, 2160, 11, 11], f16), T([16, 2160, 21, 21], f16), [1, 1], [2, 2], [0, 0], False, False, None), {})
+cnt: 2, ((T([16, 1080, 21, 21], f16), T([16, 1080, 42, 42], f16), [1, 1], [2, 2], [0, 0], False, False, None), {})
+cnt: 2, ((T([16, 270, 42, 42], f16), T([16, 270, 83, 83], f16), [1, 1], [2, 2], [0, 0], False, False, None), {})
+cnt: 2, ((T([16, 96, 83, 83], f16), T([16, 96, 165, 165], f16), [1, 1], [2, 2], [0, 0], False, False, None), {})
+Operator: aten.cat.default
+cnt: 1, (([T([16, 54, 83, 83], f16), T([16, 54, 83, 83], f16), T([16, 54, 83, 83], f16), T([16, 54, 83, 83], f16), T([16, 54, 83, 83], f16)], 1), {})
+cnt: 1, (([T([16, 54, 83, 83], f16), T([16, 54, 83, 83], f16)], 1), {})
+cnt: 1, (([T([16, 108, 42, 42], f16), T([16, 108, 42, 42], f16), T([16, 108, 42, 42], f16), T([16, 108, 42, 42], f16), T([16, 108, 42, 42], f16)], 1), {})
+cnt: 1, (([T([16, 108, 42, 42], f16), T([16, 108, 42, 42], f16)], 1), {})
+cnt: 4, (([T([16, 216, 42, 42], f16), T([16, 216, 42, 42], f16), T([16, 216, 42, 42], f16), T([16, 216, 42, 42], f16), T([16, 216, 42, 42], f16)], 1), {})
+cnt: 4, (([T([16, 432, 21, 21], f16), T([16, 432, 21, 21], f16), T([16, 432, 21, 21], f16), T([16, 432, 21, 21], f16), T([16, 432, 21, 21], f16)], 1), {})
+cnt: 1, (([T([16, 216, 21, 21], f16), T([16, 216, 21, 21], f16)], 1), {})
+cnt: 4, (([T([16, 864, 11, 11], f16), T([16, 864, 11, 11], f16), T([16, 864, 11, 11], f16), T([16, 864, 11, 11], f16), T([16, 864, 11, 11], f16)], 1), {})
+cnt: 1, (([T([16, 432, 11, 11], f16), T([16, 432, 11, 11], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([16, 3, 331, 331], f16),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 1, ((T([16, 96, 165, 165], f16), [2, 2, 2, 2], 0.0), {})
+cnt: 1, ((T([16, 96, 165, 165], f16), [1, 1, 1, 1], -inf), {})
+cnt: 1, ((T([16, 54, 165, 165], f16), [3, 3, 3, 3], 0.0), {})
+cnt: 2, ((T([16, 54, 165, 165], f16), [1, 1, 1, 1], -inf), {})
+cnt: 1, ((T([16, 54, 165, 165], f16), [2, 2, 2, 2], 0.0), {})
+cnt: 1, ((T([16, 54, 165, 165], f16), [1, 1, 1, 1], 0.0), {})
+cnt: 1, ((T([16, 96, 165, 165], f16), [1, 1, 1, 1], 0.0), {})
+cnt: 1, ((T([16, 96, 165, 165], f16), [-1, 1, -1, 1], 0.0), {})
+cnt: 2, ((T([16, 108, 83, 83], f16), [2, 2, 2, 2], 0.0), {})
+cnt: 3, ((T([16, 108, 83, 83], f16), [1, 1, 1, 1], -inf), {})
+cnt: 1, ((T([16, 108, 83, 83], f16), [3, 3, 3, 3], 0.0), {})
+cnt: 2, ((T([16, 108, 83, 83], f16), [1, 1, 1, 1], 0.0), {})
+cnt: 1, ((T([16, 270, 83, 83], f16), [-1, 1, -1, 1], 0.0), {})
+cnt: 2, ((T([16, 432, 42, 42], f16), [1, 2, 1, 2], 0.0), {})
+cnt: 3, ((T([16, 432, 42, 42], f16), [0, 1, 0, 1], -inf), {})
+cnt: 1, ((T([16, 432, 42, 42], f16), [2, 3, 2, 3], 0.0), {})
+cnt: 2, ((T([16, 432, 42, 42], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([16, 1080, 42, 42], f16), [-1, 1, -1, 1], 0.0), {})
+cnt: 2, ((T([16, 864, 21, 21], f16), [2, 2, 2, 2], 0.0), {})
+cnt: 3, ((T([16, 864, 21, 21], f16), [1, 1, 1, 1], -inf), {})
+cnt: 1, ((T([16, 864, 21, 21], f16), [3, 3, 3, 3], 0.0), {})
+cnt: 2, ((T([16, 864, 21, 21], f16), [1, 1, 1, 1], 0.0), {})
+cnt: 1, ((T([16, 2160, 21, 21], f16), [-1, 1, -1, 1], 0.0), {})
+cnt: 1, ((T([16, 2160, 21, 21], f16), [1, -1, 1, -1]), {})
+cnt: 5, ((T([16, 864, 23, 23], f16), [-1, -1, -1, -1]), {})
+cnt: 2, ((T([16, 864, 25, 25], f16), [-2, -2, -2, -2]), {})
+cnt: 1, ((T([16, 864, 27, 27], f16), [-3, -3, -3, -3]), {})
+cnt: 1, ((T([16, 1080, 42, 42], f16), [1, -1, 1, -1]), {})
+cnt: 5, ((T([16, 432, 43, 43], f16), [0, -1, 0, -1]), {})
+cnt: 2, ((T([16, 432, 45, 45], f16), [-1, -2, -1, -2]), {})
+cnt: 1, ((T([16, 432, 47, 47], f16), [-2, -3, -2, -3]), {})
+cnt: 1, ((T([16, 270, 83, 83], f16), [1, -1, 1, -1]), {})
+cnt: 5, ((T([16, 108, 85, 85], f16), [-1, -1, -1, -1]), {})
+cnt: 2, ((T([16, 108, 87, 87], f16), [-2, -2, -2, -2]), {})
+cnt: 1, ((T([16, 108, 89, 89], f16), [-3, -3, -3, -3]), {})
+cnt: 1, ((T([16, 96, 165, 165], f16), [1, -1, 1, -1]), {})
+cnt: 2, ((T([16, 96, 167, 167], f16), [-1, -1, -1, -1]), {})
+cnt: 3, ((T([16, 54, 167, 167], f16), [-1, -1, -1, -1]), {})
+cnt: 1, ((T([16, 54, 169, 169], f16), [-2, -2, -2, -2]), {})
+cnt: 1, ((T([16, 54, 171, 171], f16), [-3, -3, -3, -3]), {})
+cnt: 1, ((T([16, 96, 169, 169], f16), [-2, -2, -2, -2]), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([16, 3, 331, 331], f16), T([96, 3, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([16, 96, 165, 165], f16), T([54, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([16, 96, 169, 169], f16), T([96, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 96), {})
+cnt: 5, ((T([16, 96, 83, 83], f16), T([54, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 54, 83, 83], f16), T([54, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 54), {})
+cnt: 10, ((T([16, 54, 83, 83], f16), T([54, 54, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([16, 54, 171, 171], f16), T([54, 1, 7, 7], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 54), {})
+cnt: 1, ((T([16, 54, 83, 83], f16), T([54, 1, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 54), {})
+cnt: 1, ((T([16, 54, 169, 169], f16), T([54, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 54), {})
+cnt: 1, ((T([16, 54, 167, 167], f16), T([54, 1, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 54), {})
+cnt: 4, ((T([16, 54, 83, 83], f16), T([54, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 54), {})
+cnt: 1, ((T([16, 96, 167, 167], f16), T([96, 1, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 96), {})
+cnt: 1, ((T([16, 54, 165, 165], f16), T([54, 54, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([16, 270, 83, 83], f16), T([108, 270, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 108, 87, 87], f16), T([108, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 108), {})
+cnt: 12, ((T([16, 108, 42, 42], f16), T([108, 108, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 108, 42, 42], f16), T([108, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 108), {})
+cnt: 1, ((T([16, 108, 89, 89], f16), T([108, 1, 7, 7], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 108), {})
+cnt: 1, ((T([16, 108, 42, 42], f16), T([108, 1, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 108), {})
+cnt: 2, ((T([16, 108, 85, 85], f16), T([108, 1, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 108), {})
+cnt: 4, ((T([16, 108, 42, 42], f16), T([108, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 108), {})
+cnt: 1, ((T([16, 108, 83, 83], f16), T([108, 108, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 270, 42, 42], f16), T([108, 270, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 540, 42, 42], f16), T([216, 540, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 16, ((T([16, 216, 42, 42], f16), T([216, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 216), {})
+cnt: 48, ((T([16, 216, 42, 42], f16), T([216, 216, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([16, 216, 42, 42], f16), T([216, 1, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 216), {})
+cnt: 24, ((T([16, 216, 42, 42], f16), T([216, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 216), {})
+cnt: 5, ((T([16, 1080, 42, 42], f16), T([216, 1080, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 1080, 42, 42], f16), T([432, 1080, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 432, 45, 45], f16), T([432, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 432), {})
+cnt: 48, ((T([16, 432, 21, 21], f16), T([432, 432, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 14, ((T([16, 432, 21, 21], f16), T([432, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 432), {})
+cnt: 1, ((T([16, 432, 47, 47], f16), T([432, 1, 7, 7], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 432), {})
+cnt: 7, ((T([16, 432, 21, 21], f16), T([432, 1, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 432), {})
+cnt: 2, ((T([16, 432, 43, 43], f16), T([432, 1, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 432), {})
+cnt: 22, ((T([16, 432, 21, 21], f16), T([432, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 432), {})
+cnt: 1, ((T([16, 432, 42, 42], f16), T([432, 432, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 1080, 21, 21], f16), T([216, 1080, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([16, 2160, 21, 21], f16), T([432, 2160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 2160, 21, 21], f16), T([864, 2160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 864, 25, 25], f16), T([864, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 864), {})
+cnt: 48, ((T([16, 864, 11, 11], f16), T([864, 864, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 14, ((T([16, 864, 11, 11], f16), T([864, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 864), {})
+cnt: 1, ((T([16, 864, 27, 27], f16), T([864, 1, 7, 7], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 864), {})
+cnt: 7, ((T([16, 864, 11, 11], f16), T([864, 1, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 864), {})
+cnt: 2, ((T([16, 864, 23, 23], f16), T([864, 1, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 864), {})
+cnt: 22, ((T([16, 864, 11, 11], f16), T([864, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 864), {})
+cnt: 1, ((T([16, 864, 21, 21], f16), T([864, 864, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([16, 2160, 11, 11], f16), T([432, 2160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([16, 4320, 11, 11], f16), T([864, 4320, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 48, ((T([16, 864, 11, 11], f16), T([16, 864, 11, 11], f16), T([864, 864, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 22, ((T([16, 864, 11, 11], f16), T([16, 864, 11, 11], f16), T([864, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 864, [True, True, False]), {})
+cnt: 14, ((T([16, 864, 11, 11], f16), T([16, 864, 11, 11], f16), T([864, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 864, [True, True, False]), {})
+cnt: 7, ((T([16, 864, 11, 11], f16), T([16, 864, 11, 11], f16), T([864, 1, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 864, [True, True, False]), {})
+cnt: 5, ((T([16, 864, 11, 11], f16), T([16, 4320, 11, 11], f16), T([864, 4320, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 432, 11, 11], f16, stride=(104544, 121, 11, 1)), T([16, 2160, 11, 11], f16), T([432, 2160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 864, 11, 11], f16), T([16, 864, 21, 21], f16), T([864, 864, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 864, 11, 11], f16), T([16, 864, 23, 23], f16), T([864, 1, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 864, [True, True, False]), {})
+cnt: 2, ((T([16, 864, 11, 11], f16), T([16, 864, 25, 25], f16), T([864, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 864, [True, True, False]), {})
+cnt: 1, ((T([16, 864, 11, 11], f16), T([16, 864, 27, 27], f16), T([864, 1, 7, 7], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 864, [True, True, False]), {})
+cnt: 2, ((T([16, 864, 21, 21], f16), T([16, 2160, 21, 21], f16), T([864, 2160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 48, ((T([16, 432, 21, 21], f16), T([16, 432, 21, 21], f16), T([432, 432, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 22, ((T([16, 432, 21, 21], f16), T([16, 432, 21, 21], f16), T([432, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 432, [True, True, False]), {})
+cnt: 14, ((T([16, 432, 21, 21], f16), T([16, 432, 21, 21], f16), T([432, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 432, [True, True, False]), {})
+cnt: 7, ((T([16, 432, 21, 21], f16), T([16, 432, 21, 21], f16), T([432, 1, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 432, [True, True, False]), {})
+cnt: 5, ((T([16, 432, 21, 21], f16), T([16, 2160, 21, 21], f16), T([432, 2160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 216, 21, 21], f16, stride=(190512, 441, 21, 1)), T([16, 1080, 21, 21], f16), T([216, 1080, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 432, 21, 21], f16), T([16, 432, 42, 42], f16), T([432, 432, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 432, 21, 21], f16), T([16, 432, 43, 43], f16), T([432, 1, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 432, [True, True, False]), {})
+cnt: 2, ((T([16, 432, 21, 21], f16), T([16, 432, 45, 45], f16), T([432, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 432, [True, True, False]), {})
+cnt: 1, ((T([16, 432, 21, 21], f16), T([16, 432, 47, 47], f16), T([432, 1, 7, 7], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 432, [True, True, False]), {})
+cnt: 2, ((T([16, 432, 42, 42], f16), T([16, 1080, 42, 42], f16), T([432, 1080, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 48, ((T([16, 216, 42, 42], f16), T([16, 216, 42, 42], f16), T([216, 216, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 24, ((T([16, 216, 42, 42], f16), T([16, 216, 42, 42], f16), T([216, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 216, [True, True, False]), {})
+cnt: 16, ((T([16, 216, 42, 42], f16), T([16, 216, 42, 42], f16), T([216, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 216, [True, True, False]), {})
+cnt: 8, ((T([16, 216, 42, 42], f16), T([16, 216, 42, 42], f16), T([216, 1, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 216, [True, True, False]), {})
+cnt: 5, ((T([16, 216, 42, 42], f16), T([16, 1080, 42, 42], f16), T([216, 1080, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 216, 42, 42], f16), T([16, 540, 42, 42], f16), T([216, 540, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 108, 42, 42], f16, stride=(381024, 1764, 42, 1)), T([16, 270, 42, 42], f16), T([108, 270, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 108, 42, 42], f16), T([16, 108, 83, 83], f16), T([108, 108, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 12, ((T([16, 108, 42, 42], f16), T([16, 108, 42, 42], f16), T([108, 108, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([16, 108, 42, 42], f16), T([16, 108, 42, 42], f16), T([108, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 108, [True, True, False]), {})
+cnt: 2, ((T([16, 108, 42, 42], f16), T([16, 108, 85, 85], f16), T([108, 1, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 108, [True, True, False]), {})
+cnt: 2, ((T([16, 108, 42, 42], f16), T([16, 108, 42, 42], f16), T([108, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 108, [True, True, False]), {})
+cnt: 2, ((T([16, 108, 42, 42], f16), T([16, 108, 87, 87], f16), T([108, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 108, [True, True, False]), {})
+cnt: 1, ((T([16, 108, 42, 42], f16), T([16, 108, 42, 42], f16), T([108, 1, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 108, [True, True, False]), {})
+cnt: 1, ((T([16, 108, 42, 42], f16), T([16, 108, 89, 89], f16), T([108, 1, 7, 7], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 108, [True, True, False]), {})
+cnt: 1, ((T([16, 108, 83, 83], f16), T([16, 270, 83, 83], f16), T([108, 270, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([16, 54, 83, 83], f16, stride=(744012, 6889, 83, 1)), T([16, 96, 83, 83], f16), T([54, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 54, 83, 83], f16), T([16, 54, 165, 165], f16), T([54, 54, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 10, ((T([16, 54, 83, 83], f16), T([16, 54, 83, 83], f16), T([54, 54, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([16, 54, 83, 83], f16), T([16, 54, 83, 83], f16), T([54, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 54, [True, True, False]), {})
+cnt: 3, ((T([16, 54, 83, 83], f16), T([16, 96, 83, 83], f16), T([54, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 96, 83, 83], f16), T([16, 96, 167, 167], f16), T([96, 1, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 96, [True, True, False]), {})
+cnt: 1, ((T([16, 54, 83, 83], f16), T([16, 54, 167, 167], f16), T([54, 1, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 54, [True, True, False]), {})
+cnt: 2, ((T([16, 54, 83, 83], f16), T([16, 54, 83, 83], f16), T([54, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 54, [True, True, False]), {})
+cnt: 1, ((T([16, 54, 83, 83], f16), T([16, 54, 169, 169], f16), T([54, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 54, [True, True, False]), {})
+cnt: 1, ((T([16, 54, 83, 83], f16), T([16, 54, 83, 83], f16), T([54, 1, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 54, [True, True, False]), {})
+cnt: 1, ((T([16, 54, 83, 83], f16), T([16, 54, 171, 171], f16), T([54, 1, 7, 7], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 54, [True, True, False]), {})
+cnt: 1, ((T([16, 96, 83, 83], f16), T([16, 96, 169, 169], f16), T([96, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 96, [True, True, False]), {})
+cnt: 1, ((T([16, 54, 165, 165], f16), T([16, 96, 165, 165], f16), T([54, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 96, 165, 165], f16), T([16, 3, 331, 331], f16), T([96, 3, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([16, 3, 331, 331], f16), T([16, 3, 331, 331], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([16, 4320, 11, 11], f16, stride=(4320, 1, 0, 0)), 121), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([16], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([16, 96, 167, 167], f16), [3, 3], [2, 2]), {})
+cnt: 2, ((T([16, 54, 167, 167], f16), [3, 3], [2, 2]), {})
+cnt: 3, ((T([16, 108, 85, 85], f16), [3, 3], [2, 2]), {})
+cnt: 12, ((T([16, 216, 42, 42], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 3, ((T([16, 432, 43, 43], f16), [3, 3], [2, 2]), {})
+cnt: 9, ((T([16, 432, 21, 21], f16), [3, 3], [1, 1], [1, 1]), {})
+cnt: 3, ((T([16, 864, 23, 23], f16), [3, 3], [2, 2]), {})
+cnt: 9, ((T([16, 864, 11, 11], f16), [3, 3], [1, 1], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 9, ((T([16, 864, 11, 11], f16, stride=(522720, 121, 11, 1)), T([16, 864, 11, 11], f16), [3, 3], [1, 1], [1, 1], [1, 1], False, T([16, 864, 11, 11], i64)), {})
+cnt: 3, ((T([16, 864, 11, 11], f16, stride=(522720, 121, 11, 1)), T([16, 864, 23, 23], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([16, 864, 11, 11], i64)), {})
+cnt: 9, ((T([16, 432, 21, 21], f16, stride=(952560, 441, 21, 1)), T([16, 432, 21, 21], f16), [3, 3], [1, 1], [1, 1], [1, 1], False, T([16, 432, 21, 21], i64)), {})
+cnt: 3, ((T([16, 432, 21, 21], f16, stride=(952560, 441, 21, 1)), T([16, 432, 43, 43], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([16, 432, 21, 21], i64)), {})
+cnt: 12, ((T([16, 216, 42, 42], f16, stride=(1905120, 1764, 42, 1)), T([16, 216, 42, 42], f16), [3, 3], [1, 1], [1, 1], [1, 1], False, T([16, 216, 42, 42], i64)), {})
+cnt: 3, ((T([16, 108, 42, 42], f16, stride=(952560, 1764, 42, 1)), T([16, 108, 85, 85], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([16, 108, 42, 42], i64)), {})
+cnt: 2, ((T([16, 54, 83, 83], f16, stride=(1860030, 6889, 83, 1)), T([16, 54, 167, 167], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([16, 54, 83, 83], i64)), {})
+cnt: 1, ((T([16, 96, 83, 83], f16), T([16, 96, 167, 167], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([16, 96, 83, 83], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([16, 4320, 11, 11], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([16, 1000], f16), T([1000, 4320], f16)), {})
+cnt: 1, ((T([1000, 16], f16, stride=(1, 1000)), T([16, 4320], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([16, 96, 165, 165], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([16, 54, 165, 165], f16), T([54], f16), T([54], f16), T([54], f16), T([54], f16), True, 0.1, 0.001), {})
+cnt: 14, ((T([16, 54, 83, 83], f16), T([54], f16), T([54], f16), T([54], f16), T([54], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([16, 108, 83, 83], f16), T([108], f16), T([108], f16), T([108], f16), T([108], f16), True, 0.1, 0.001), {})
+cnt: 13, ((T([16, 108, 42, 42], f16), T([108], f16), T([108], f16), T([108], f16), T([108], f16), True, 0.1, 0.001), {})
+cnt: 56, ((T([16, 216, 42, 42], f16), T([216], f16), T([216], f16), T([216], f16), T([216], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([16, 432, 42, 42], f16), T([432], f16), T([432], f16), T([432], f16), T([432], f16), True, 0.1, 0.001), {})
+cnt: 55, ((T([16, 432, 21, 21], f16), T([432], f16), T([432], f16), T([432], f16), T([432], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([16, 864, 21, 21], f16), T([864], f16), T([864], f16), T([864], f16), T([864], f16), True, 0.1, 0.001), {})
+cnt: 55, ((T([16, 864, 11, 11], f16), T([864], f16), T([864], f16), T([864], f16), T([864], f16), True, 0.1, 0.001), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 17, ((T([16, 864, 11, 11], f16, stride=(522720, 121, 11, 1)), T([16, 864, 11, 11], f16), T([864], f16), T([864], f16), T([864], f16), T([864], f32), T([864], f32), True, 0.001, [True, True, True]), {})
+cnt: 38, ((T([16, 864, 11, 11], f16), T([16, 864, 11, 11], f16), T([864], f16), T([864], f16), T([864], f16), T([864], f32), T([864], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([16, 864, 21, 21], f16), T([16, 864, 21, 21], f16), T([864], f16), T([864], f16), T([864], f16), T([864], f32), T([864], f32), True, 0.001, [True, True, True]), {})
+cnt: 17, ((T([16, 432, 21, 21], f16, stride=(952560, 441, 21, 1)), T([16, 432, 21, 21], f16), T([432], f16), T([432], f16), T([432], f16), T([432], f32), T([432], f32), True, 0.001, [True, True, True]), {})
+cnt: 38, ((T([16, 432, 21, 21], f16), T([16, 432, 21, 21], f16), T([432], f16), T([432], f16), T([432], f16), T([432], f32), T([432], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([16, 432, 42, 42], f16), T([16, 432, 42, 42], f16), T([432], f16), T([432], f16), T([432], f16), T([432], f32), T([432], f32), True, 0.001, [True, True, True]), {})
+cnt: 16, ((T([16, 216, 42, 42], f16, stride=(1905120, 1764, 42, 1)), T([16, 216, 42, 42], f16), T([216], f16), T([216], f16), T([216], f16), T([216], f32), T([216], f32), True, 0.001, [True, True, True]), {})
+cnt: 40, ((T([16, 216, 42, 42], f16), T([16, 216, 42, 42], f16), T([216], f16), T([216], f16), T([216], f16), T([216], f32), T([216], f32), True, 0.001, [True, True, True]), {})
+cnt: 5, ((T([16, 108, 42, 42], f16, stride=(952560, 1764, 42, 1)), T([16, 108, 42, 42], f16), T([108], f16), T([108], f16), T([108], f16), T([108], f32), T([108], f32), True, 0.001, [True, True, True]), {})
+cnt: 8, ((T([16, 108, 42, 42], f16), T([16, 108, 42, 42], f16), T([108], f16), T([108], f16), T([108], f16), T([108], f32), T([108], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([16, 108, 83, 83], f16), T([16, 108, 83, 83], f16), T([108], f16), T([108], f16), T([108], f16), T([108], f32), T([108], f32), True, 0.001, [True, True, True]), {})
+cnt: 6, ((T([16, 54, 83, 83], f16, stride=(1860030, 6889, 83, 1)), T([16, 54, 83, 83], f16), T([54], f16), T([54], f16), T([54], f16), T([54], f32), T([54], f32), True, 0.001, [True, True, True]), {})
+cnt: 8, ((T([16, 54, 83, 83], f16), T([16, 54, 83, 83], f16), T([54], f16), T([54], f16), T([54], f16), T([54], f32), T([54], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([16, 54, 165, 165], f16), T([16, 54, 165, 165], f16), T([54], f16), T([54], f16), T([54], f16), T([54], f32), T([54], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([16, 96, 165, 165], f16), T([16, 96, 165, 165], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 0.001, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([16, 1000], f16), T([16], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([16, 1000], f16), T([16], i64), None, 1, -100), {})
+Operator: aten.relu.default
+cnt: 4, ((T([16, 96, 165, 165], f16),), {})
+cnt: 7, ((T([16, 54, 83, 83], f16),), {})
+cnt: 4, ((T([16, 54, 165, 165], f16),), {})
+cnt: 2, ((T([16, 270, 83, 83], f16),), {})
+cnt: 6, ((T([16, 108, 83, 83], f16),), {})
+cnt: 7, ((T([16, 108, 42, 42], f16),), {})
+cnt: 2, ((T([16, 540, 42, 42], f16),), {})
+cnt: 48, ((T([16, 216, 42, 42], f16),), {})
+cnt: 8, ((T([16, 1080, 42, 42], f16),), {})
+cnt: 6, ((T([16, 432, 42, 42], f16),), {})
+cnt: 43, ((T([16, 432, 21, 21], f16),), {})
+cnt: 8, ((T([16, 2160, 21, 21], f16),), {})
+cnt: 6, ((T([16, 864, 21, 21], f16),), {})
+cnt: 43, ((T([16, 864, 11, 11], f16),), {})
+cnt: 6, ((T([16, 4320, 11, 11], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([16, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 6, ((T([16, 4320, 11, 11], f16), T([16, 4320, 11, 11], f16), 0), {})
+cnt: 43, ((T([16, 864, 11, 11], f16), T([16, 864, 11, 11], f16), 0), {})
+cnt: 8, ((T([16, 2160, 21, 21], f16), T([16, 2160, 21, 21], f16), 0), {})
+cnt: 6, ((T([16, 864, 21, 21], f16), T([16, 864, 21, 21], f16), 0), {})
+cnt: 43, ((T([16, 432, 21, 21], f16), T([16, 432, 21, 21], f16), 0), {})
+cnt: 8, ((T([16, 1080, 42, 42], f16), T([16, 1080, 42, 42], f16), 0), {})
+cnt: 6, ((T([16, 432, 42, 42], f16), T([16, 432, 42, 42], f16), 0), {})
+cnt: 48, ((T([16, 216, 42, 42], f16), T([16, 216, 42, 42], f16), 0), {})
+cnt: 2, ((T([16, 540, 42, 42], f16), T([16, 540, 42, 42], f16), 0), {})
+cnt: 2, ((T([16, 270, 83, 83], f16), T([16, 270, 83, 83], f16), 0), {})
+cnt: 6, ((T([16, 108, 83, 83], f16), T([16, 108, 83, 83], f16), 0), {})
+cnt: 7, ((T([16, 108, 42, 42], f16), T([16, 108, 42, 42], f16), 0), {})
+cnt: 4, ((T([16, 96, 165, 165], f16), T([16, 96, 165, 165], f16), 0), {})
+cnt: 4, ((T([16, 54, 165, 165], f16), T([16, 54, 165, 165], f16), 0), {})
+cnt: 7, ((T([16, 54, 83, 83], f16), T([16, 54, 83, 83], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/poolformer_m36_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/poolformer_m36_training.txt
new file mode 100644
index 0000000000000..2cbc4a779e5b8
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/poolformer_m36_training.txt
@@ -0,0 +1,111 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 30, ((T([64, 96, 56, 56], f16), T([64, 96, 56, 56], f16)), {})
+cnt: 30, ((T([64, 192, 28, 28], f16), T([64, 192, 28, 28], f16)), {})
+cnt: 90, ((T([64, 384, 14, 14], f16), T([64, 384, 14, 14], f16)), {})
+cnt: 30, ((T([64, 768, 7, 7], f16), T([64, 768, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([64, 768], f16), T([768, 1000], f16, stride=(1, 768))), {})
+Operator: aten.avg_pool2d.default
+cnt: 6, ((T([64, 96, 56, 56], f16), [3, 3], [1, 1], [1, 1], False, False), {})
+cnt: 6, ((T([64, 192, 28, 28], f16), [3, 3], [1, 1], [1, 1], False, False), {})
+cnt: 18, ((T([64, 384, 14, 14], f16), [3, 3], [1, 1], [1, 1], False, False), {})
+cnt: 6, ((T([64, 768, 7, 7], f16), [3, 3], [1, 1], [1, 1], False, False), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 6, ((T([64, 768, 7, 7], f16), T([64, 768, 7, 7], f16), [3, 3], [1, 1], [1, 1], False, False, None), {})
+cnt: 18, ((T([64, 384, 14, 14], f16), T([64, 384, 14, 14], f16), [3, 3], [1, 1], [1, 1], False, False, None), {})
+cnt: 6, ((T([64, 192, 28, 28], f16), T([64, 192, 28, 28], f16), [3, 3], [1, 1], [1, 1], False, False, None), {})
+cnt: 6, ((T([64, 96, 56, 56], f16), T([64, 96, 56, 56], f16), [3, 3], [1, 1], [1, 1], False, False, None), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([96, 3, 7, 7], f16), T([96], f16), [4, 4], [2, 2], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 96, 56, 56], f16), T([384, 96, 1, 1], f16), T([384], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 384, 56, 56], f16), T([96, 384, 1, 1], f16), T([96], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 96, 56, 56], f16), T([192, 96, 3, 3], f16), T([192], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 192, 28, 28], f16), T([768, 192, 1, 1], f16), T([768], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 768, 28, 28], f16), T([192, 768, 1, 1], f16), T([192], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 192, 28, 28], f16), T([384, 192, 3, 3], f16), T([384], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 18, ((T([64, 384, 14, 14], f16), T([1536, 384, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 18, ((T([64, 1536, 14, 14], f16), T([384, 1536, 1, 1], f16), T([384], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 384, 14, 14], f16), T([768, 384, 3, 3], f16), T([768], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 768, 7, 7], f16), T([3072, 768, 1, 1], f16), T([3072], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 3072, 7, 7], f16), T([768, 3072, 1, 1], f16), T([768], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 6, ((T([64, 768, 7, 7], f16), T([64, 3072, 7, 7], f16), T([768, 3072, 1, 1], f16), [768], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 6, ((T([64, 3072, 7, 7], f16), T([64, 768, 7, 7], f16), T([3072, 768, 1, 1], f16), [3072], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 768, 7, 7], f16), T([64, 384, 14, 14], f16), T([768, 384, 3, 3], f16), [768], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 18, ((T([64, 384, 14, 14], f16), T([64, 1536, 14, 14], f16), T([384, 1536, 1, 1], f16), [384], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 18, ((T([64, 1536, 14, 14], f16), T([64, 384, 14, 14], f16), T([1536, 384, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 384, 14, 14], f16), T([64, 192, 28, 28], f16), T([384, 192, 3, 3], f16), [384], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 6, ((T([64, 192, 28, 28], f16), T([64, 768, 28, 28], f16), T([192, 768, 1, 1], f16), [192], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 6, ((T([64, 768, 28, 28], f16), T([64, 192, 28, 28], f16), T([768, 192, 1, 1], f16), [768], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 192, 28, 28], f16), T([64, 96, 56, 56], f16), T([192, 96, 3, 3], f16), [192], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 6, ((T([64, 96, 56, 56], f16), T([64, 384, 56, 56], f16), T([96, 384, 1, 1], f16), [96], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 6, ((T([64, 384, 56, 56], f16), T([64, 96, 56, 56], f16), T([384, 96, 1, 1], f16), [384], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 96, 56, 56], f16), T([64, 3, 224, 224], f16), T([96, 3, 7, 7], f16), [96], [4, 4], [2, 2], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([64, 768, 7, 7], f16, stride=(768, 1, 0, 0)), 49), {})
+Operator: aten.gelu.default
+cnt: 6, ((T([64, 384, 56, 56], f16),), {})
+cnt: 6, ((T([64, 768, 28, 28], f16),), {})
+cnt: 18, ((T([64, 1536, 14, 14], f16),), {})
+cnt: 6, ((T([64, 3072, 7, 7], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 6, ((T([64, 3072, 7, 7], f16), T([64, 3072, 7, 7], f16)), {})
+cnt: 18, ((T([64, 1536, 14, 14], f16), T([64, 1536, 14, 14], f16)), {})
+cnt: 6, ((T([64, 768, 28, 28], f16), T([64, 768, 28, 28], f16)), {})
+cnt: 6, ((T([64, 384, 56, 56], f16), T([64, 384, 56, 56], f16)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([64, 768, 7, 7], f16), [-2, -1]), {})
+Operator: aten.mm.default
+cnt: 1, ((T([64, 1000], f16), T([1000, 768], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 12, ((T([96, 1, 1], f16), T([64, 96, 56, 56], f16)), {})
+cnt: 12, ((T([192, 1, 1], f16), T([64, 192, 28, 28], f16)), {})
+cnt: 36, ((T([384, 1, 1], f16), T([64, 384, 14, 14], f16)), {})
+cnt: 12, ((T([768, 1, 1], f16), T([64, 768, 7, 7], f16)), {})
+cnt: 12, ((T([64, 768, 7, 7], f16), T([768, 1, 1], f16)), {})
+cnt: 12, ((T([64, 768, 7, 7], f16), T([64, 768, 7, 7], f16)), {})
+cnt: 36, ((T([64, 384, 14, 14], f16), T([384, 1, 1], f16)), {})
+cnt: 36, ((T([64, 384, 14, 14], f16), T([64, 384, 14, 14], f16)), {})
+cnt: 12, ((T([64, 192, 28, 28], f16), T([192, 1, 1], f16)), {})
+cnt: 12, ((T([64, 192, 28, 28], f16), T([64, 192, 28, 28], f16)), {})
+cnt: 12, ((T([64, 96, 56, 56], f16), T([96, 1, 1], f16)), {})
+cnt: 12, ((T([64, 96, 56, 56], f16), T([64, 96, 56, 56], f16)), {})
+Operator: aten.native_group_norm.default
+cnt: 12, ((T([64, 96, 56, 56], f16), T([96], f16), T([96], f16), 64, 96, 3136, 1, 1e-05), {})
+cnt: 12, ((T([64, 192, 28, 28], f16), T([192], f16), T([192], f16), 64, 192, 784, 1, 1e-05), {})
+cnt: 36, ((T([64, 384, 14, 14], f16), T([384], f16), T([384], f16), 64, 384, 196, 1, 1e-05), {})
+cnt: 13, ((T([64, 768, 7, 7], f16), T([768], f16), T([768], f16), 64, 768, 49, 1, 1e-05), {})
+Operator: aten.native_group_norm_backward.default
+cnt: 13, ((T([64, 768, 7, 7], f16), T([64, 768, 7, 7], f16), T([64, 1], f16), T([64, 1], f16), T([768], f16), 64, 768, 49, 1, [True, True, True]), {})
+cnt: 36, ((T([64, 384, 14, 14], f16), T([64, 384, 14, 14], f16), T([64, 1], f16), T([64, 1], f16), T([384], f16), 64, 384, 196, 1, [True, True, True]), {})
+cnt: 12, ((T([64, 192, 28, 28], f16), T([64, 192, 28, 28], f16), T([64, 1], f16), T([64, 1], f16), T([192], f16), 64, 192, 784, 1, [True, True, True]), {})
+cnt: 12, ((T([64, 96, 56, 56], f16), T([64, 96, 56, 56], f16), T([64, 1], f16), T([64, 1], f16), T([96], f16), 64, 96, 3136, 1, [True, True, True]), {})
+Operator: aten.neg.default
+cnt: 6, ((T([64, 768, 7, 7], f16),), {})
+cnt: 18, ((T([64, 384, 14, 14], f16),), {})
+cnt: 6, ((T([64, 192, 28, 28], f16),), {})
+cnt: 6, ((T([64, 96, 56, 56], f16),), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.sub.Tensor
+cnt: 6, ((T([64, 96, 56, 56], f16), T([64, 96, 56, 56], f16)), {})
+cnt: 6, ((T([64, 192, 28, 28], f16), T([64, 192, 28, 28], f16)), {})
+cnt: 18, ((T([64, 384, 14, 14], f16), T([64, 384, 14, 14], f16)), {})
+cnt: 6, ((T([64, 768, 7, 7], f16), T([64, 768, 7, 7], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 12, ((T([64, 768, 7, 7], f16), [0, 2, 3], True), {})
+cnt: 36, ((T([64, 384, 14, 14], f16), [0, 2, 3], True), {})
+cnt: 12, ((T([64, 192, 28, 28], f16), [0, 2, 3], True), {})
+cnt: 12, ((T([64, 96, 56, 56], f16), [0, 2, 3], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/regnety_002_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/regnety_002_training.txt
new file mode 100644
index 0000000000000..99d7f8ac9b481
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/regnety_002_training.txt
@@ -0,0 +1,181 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 44, ((T([], i64), 1), {})
+cnt: 3, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16)), {})
+cnt: 3, ((T([128, 56, 28, 28], f16), T([128, 56, 28, 28], f16)), {})
+cnt: 12, ((T([128, 152, 14, 14], f16), T([128, 152, 14, 14], f16)), {})
+cnt: 20, ((T([128, 368, 7, 7], f16), T([128, 368, 7, 7], f16)), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 368], f16), T([368, 1000], f16, stride=(1, 368))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([24, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 24, 112, 112], f16), T([24, 8, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 3), {})
+cnt: 1, ((T([128, 24, 1, 1], f16), T([8, 24, 1, 1], f16), T([8], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 8, 1, 1], f16), T([24, 8, 1, 1], f16), T([24], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([24, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([24, 32, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([56, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 56, 56, 56], f16), T([56, 8, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 7), {})
+cnt: 1, ((T([128, 56, 1, 1], f16), T([6, 56, 1, 1], f16), T([6], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 6, 1, 1], f16), T([56, 6, 1, 1], f16), T([56], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 56, 28, 28], f16), T([56, 56, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([56, 24, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 56, 28, 28], f16), T([152, 56, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 152, 28, 28], f16), T([152, 8, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 19), {})
+cnt: 1, ((T([128, 152, 1, 1], f16), T([14, 152, 1, 1], f16), T([14], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 14, 1, 1], f16), T([152, 14, 1, 1], f16), T([152], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([128, 152, 14, 14], f16), T([152, 152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 56, 28, 28], f16), T([152, 56, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 152, 14, 14], f16), T([152, 8, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 19), {})
+cnt: 3, ((T([128, 152, 1, 1], f16), T([38, 152, 1, 1], f16), T([38], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 38, 1, 1], f16), T([152, 38, 1, 1], f16), T([152], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 152, 14, 14], f16), T([368, 152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 368, 14, 14], f16), T([368, 8, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 46), {})
+cnt: 1, ((T([128, 368, 1, 1], f16), T([38, 368, 1, 1], f16), T([38], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 38, 1, 1], f16), T([368, 38, 1, 1], f16), T([368], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 13, ((T([128, 368, 7, 7], f16), T([368, 368, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 152, 14, 14], f16), T([368, 152, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([128, 368, 7, 7], f16), T([368, 8, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 46), {})
+cnt: 6, ((T([128, 368, 1, 1], f16), T([92, 368, 1, 1], f16), T([92], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([128, 92, 1, 1], f16), T([368, 92, 1, 1], f16), T([368], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 13, ((T([128, 368, 7, 7], f16), T([128, 368, 7, 7], f16), T([368, 368, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([128, 368, 1, 1], f16), T([128, 92, 1, 1], f16), T([368, 92, 1, 1], f16), [368], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 6, ((T([128, 92, 1, 1], f16), T([128, 368, 1, 1], f16), T([92, 368, 1, 1], f16), [92], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 6, ((T([128, 368, 7, 7], f16), T([128, 368, 7, 7], f16), T([368, 8, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 46, [True, True, False]), {})
+cnt: 1, ((T([128, 368, 7, 7], f16), T([128, 152, 14, 14], f16), T([368, 152, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 368, 1, 1], f16), T([128, 38, 1, 1], f16), T([368, 38, 1, 1], f16), [368], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 38, 1, 1], f16), T([128, 368, 1, 1], f16), T([38, 368, 1, 1], f16), [38], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 368, 7, 7], f16), T([128, 368, 14, 14], f16), T([368, 8, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 46, [True, True, False]), {})
+cnt: 1, ((T([128, 368, 14, 14], f16), T([128, 152, 14, 14], f16), T([368, 152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 7, ((T([128, 152, 14, 14], f16), T([128, 152, 14, 14], f16), T([152, 152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 152, 1, 1], f16), T([128, 38, 1, 1], f16), T([152, 38, 1, 1], f16), [152], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([128, 38, 1, 1], f16), T([128, 152, 1, 1], f16), T([38, 152, 1, 1], f16), [38], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([128, 152, 14, 14], f16), T([128, 152, 14, 14], f16), T([152, 8, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 19, [True, True, False]), {})
+cnt: 1, ((T([128, 152, 14, 14], f16), T([128, 56, 28, 28], f16), T([152, 56, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 152, 1, 1], f16), T([128, 14, 1, 1], f16), T([152, 14, 1, 1], f16), [152], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 14, 1, 1], f16), T([128, 152, 1, 1], f16), T([14, 152, 1, 1], f16), [14], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 152, 14, 14], f16), T([128, 152, 28, 28], f16), T([152, 8, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 19, [True, True, False]), {})
+cnt: 1, ((T([128, 152, 28, 28], f16), T([128, 56, 28, 28], f16), T([152, 56, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 56, 28, 28], f16), T([128, 24, 56, 56], f16), T([56, 24, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 56, 28, 28], f16), T([128, 56, 28, 28], f16), T([56, 56, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 56, 1, 1], f16), T([128, 6, 1, 1], f16), T([56, 6, 1, 1], f16), [56], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 6, 1, 1], f16), T([128, 56, 1, 1], f16), T([6, 56, 1, 1], f16), [6], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 56, 28, 28], f16), T([128, 56, 56, 56], f16), T([56, 8, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 7, [True, True, False]), {})
+cnt: 1, ((T([128, 56, 56, 56], f16), T([128, 24, 56, 56], f16), T([56, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 32, 112, 112], f16), T([24, 32, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16), T([24, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 1, 1], f16), T([128, 8, 1, 1], f16), T([24, 8, 1, 1], f16), [24], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 8, 1, 1], f16), T([128, 24, 1, 1], f16), T([8, 24, 1, 1], f16), [8], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 24, 112, 112], f16), T([24, 8, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 3, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 112, 112], f16), T([128, 32, 112, 112], f16), T([24, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 3, 224, 224], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 8, ((T([128, 368, 7, 7], f16, stride=(368, 1, 0, 0)), 49), {})
+cnt: 4, ((T([128, 152, 14, 14], f16, stride=(152, 1, 0, 0)), 196), {})
+cnt: 1, ((T([128, 56, 28, 28], f16, stride=(56, 1, 0, 0)), 784), {})
+cnt: 1, ((T([128, 24, 56, 56], f16, stride=(24, 1, 0, 0)), 3136), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 24, 56, 56], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 56, 28, 28], f16), [2, 3], True), {})
+cnt: 4, ((T([128, 152, 14, 14], f16), [2, 3], True), {})
+cnt: 7, ((T([128, 368, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 368, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 368], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 368], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([128, 24, 56, 56], f16), T([128, 24, 1, 1], f16)), {})
+cnt: 2, ((T([128, 56, 28, 28], f16), T([128, 56, 1, 1], f16)), {})
+cnt: 8, ((T([128, 152, 14, 14], f16), T([128, 152, 1, 1], f16)), {})
+cnt: 14, ((T([128, 368, 7, 7], f16), T([128, 368, 1, 1], f16)), {})
+cnt: 7, ((T([128, 368, 7, 7], f16), T([128, 368, 7, 7], f16)), {})
+cnt: 4, ((T([128, 152, 14, 14], f16), T([128, 152, 14, 14], f16)), {})
+cnt: 1, ((T([128, 56, 28, 28], f16), T([128, 56, 28, 28], f16)), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 24, 112, 112], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 56, 56, 56], f16), T([56], f16), T([56], f16), T([56], f16), T([56], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 56, 28, 28], f16), T([56], f16), T([56], f16), T([56], f16), T([56], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 152, 28, 28], f16), T([152], f16), T([152], f16), T([152], f16), T([152], f16), True, 0.1, 1e-05), {})
+cnt: 12, ((T([128, 152, 14, 14], f16), T([152], f16), T([152], f16), T([152], f16), T([152], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 368, 14, 14], f16), T([368], f16), T([368], f16), T([368], f16), T([368], f16), True, 0.1, 1e-05), {})
+cnt: 21, ((T([128, 368, 7, 7], f16), T([368], f16), T([368], f16), T([368], f16), T([368], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 21, ((T([128, 368, 7, 7], f16), T([128, 368, 7, 7], f16), T([368], f16), T([368], f16), T([368], f16), T([368], f32), T([368], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 368, 14, 14], f16), T([128, 368, 14, 14], f16), T([368], f16), T([368], f16), T([368], f16), T([368], f32), T([368], f32), True, 1e-05, [True, True, True]), {})
+cnt: 12, ((T([128, 152, 14, 14], f16), T([128, 152, 14, 14], f16), T([152], f16), T([152], f16), T([152], f16), T([152], f32), T([152], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 152, 28, 28], f16), T([128, 152, 28, 28], f16), T([152], f16), T([152], f16), T([152], f16), T([152], f32), T([152], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 56, 28, 28], f16), T([128, 56, 28, 28], f16), T([56], f16), T([56], f16), T([56], f16), T([56], f32), T([56], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 56, 56, 56], f16), T([128, 56, 56, 56], f16), T([56], f16), T([56], f16), T([56], f16), T([56], f32), T([56], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 24, 112, 112], f16), T([128, 24, 112, 112], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu.default
+cnt: 1, ((T([128, 24, 56, 56], f16),), {})
+cnt: 1, ((T([128, 56, 28, 28], f16),), {})
+cnt: 4, ((T([128, 152, 14, 14], f16),), {})
+cnt: 7, ((T([128, 368, 7, 7], f16),), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 32, 112, 112], f16),), {})
+cnt: 1, ((T([128, 24, 112, 112], f16),), {})
+cnt: 1, ((T([128, 24, 56, 56], f16),), {})
+cnt: 1, ((T([128, 8, 1, 1], f16),), {})
+cnt: 1, ((T([128, 56, 56, 56], f16),), {})
+cnt: 1, ((T([128, 56, 28, 28], f16),), {})
+cnt: 1, ((T([128, 6, 1, 1], f16),), {})
+cnt: 1, ((T([128, 152, 28, 28], f16),), {})
+cnt: 7, ((T([128, 152, 14, 14], f16),), {})
+cnt: 1, ((T([128, 14, 1, 1], f16),), {})
+cnt: 4, ((T([128, 38, 1, 1], f16),), {})
+cnt: 1, ((T([128, 368, 14, 14], f16),), {})
+cnt: 13, ((T([128, 368, 7, 7], f16),), {})
+cnt: 6, ((T([128, 92, 1, 1], f16),), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([128, 24, 1, 1], f16),), {})
+cnt: 1, ((T([128, 56, 1, 1], f16),), {})
+cnt: 4, ((T([128, 152, 1, 1], f16),), {})
+cnt: 7, ((T([128, 368, 1, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 7, ((T([128, 368, 1, 1], f16), T([128, 368, 1, 1], f16)), {})
+cnt: 4, ((T([128, 152, 1, 1], f16), T([128, 152, 1, 1], f16)), {})
+cnt: 1, ((T([128, 56, 1, 1], f16), T([128, 56, 1, 1], f16)), {})
+cnt: 1, ((T([128, 24, 1, 1], f16), T([128, 24, 1, 1], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 7, ((T([128, 368, 7, 7], f16), [2, 3], True), {})
+cnt: 4, ((T([128, 152, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 56, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), [2, 3], True), {})
+Operator: aten.threshold_backward.default
+cnt: 20, ((T([128, 368, 7, 7], f16), T([128, 368, 7, 7], f16), 0), {})
+cnt: 6, ((T([128, 92, 1, 1], f16), T([128, 92, 1, 1], f16), 0), {})
+cnt: 4, ((T([128, 38, 1, 1], f16), T([128, 38, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 368, 14, 14], f16), T([128, 368, 14, 14], f16), 0), {})
+cnt: 11, ((T([128, 152, 14, 14], f16), T([128, 152, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 14, 1, 1], f16), T([128, 14, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 152, 28, 28], f16), T([128, 152, 28, 28], f16), 0), {})
+cnt: 2, ((T([128, 56, 28, 28], f16), T([128, 56, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 6, 1, 1], f16), T([128, 6, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 56, 56, 56], f16), T([128, 56, 56, 56], f16), 0), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 8, 1, 1], f16), T([128, 8, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 24, 112, 112], f16), T([128, 24, 112, 112], f16), 0), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/repvgg_a2_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/repvgg_a2_training.txt
new file mode 100644
index 0000000000000..ff6a44e15f6a2
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/repvgg_a2_training.txt
@@ -0,0 +1,90 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 61, ((T([], i64), 1), {})
+cnt: 2, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16)), {})
+cnt: 6, ((T([128, 96, 56, 56], f16), T([128, 96, 56, 56], f16)), {})
+cnt: 14, ((T([128, 192, 28, 28], f16), T([128, 192, 28, 28], f16)), {})
+cnt: 54, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16)), {})
+cnt: 1, ((T([128, 1408, 7, 7], f16), T([128, 1408, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1408], f16), T([1408, 1000], f16, stride=(1, 1408))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([64, 3, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 3, 224, 224], f16), T([64, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([96, 64, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([96, 64, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([96, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([96, 96, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([192, 96, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([192, 96, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 192, 28, 28], f16), T([192, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 192, 28, 28], f16), T([192, 192, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 28, 28], f16), T([384, 192, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 28, 28], f16), T([384, 192, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 13, ((T([128, 384, 14, 14], f16), T([384, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 13, ((T([128, 384, 14, 14], f16), T([384, 384, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 384, 14, 14], f16), T([1408, 384, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 384, 14, 14], f16), T([1408, 384, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1408, 7, 7], f16), T([128, 384, 14, 14], f16), T([1408, 384, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1408, 7, 7], f16), T([128, 384, 14, 14], f16), T([1408, 384, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 13, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16), T([384, 384, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 13, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16), T([384, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 384, 14, 14], f16), T([128, 192, 28, 28], f16), T([384, 192, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 384, 14, 14], f16), T([128, 192, 28, 28], f16), T([384, 192, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 192, 28, 28], f16), T([128, 192, 28, 28], f16), T([192, 192, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 192, 28, 28], f16), T([128, 192, 28, 28], f16), T([192, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 28, 28], f16), T([128, 96, 56, 56], f16), T([192, 96, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 28, 28], f16), T([128, 96, 56, 56], f16), T([192, 96, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 56, 56], f16), T([96, 96, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 56, 56], f16), T([96, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 64, 112, 112], f16), T([96, 64, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 64, 112, 112], f16), T([96, 64, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 3, 224, 224], f16), T([64, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 3, 224, 224], f16), T([64, 3, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 1408, 7, 7], f16, stride=(1408, 1, 0, 0)), 49), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 1408, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 1408], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 1408], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 1e-05), {})
+cnt: 11, ((T([128, 192, 28, 28], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 41, ((T([128, 384, 14, 14], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 1408, 7, 7], f16), T([1408], f16), T([1408], f16), T([1408], f16), T([1408], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 2, ((T([128, 1408, 7, 7], f16), T([128, 1408, 7, 7], f16), T([1408], f16), T([1408], f16), T([1408], f16), T([1408], f32), T([1408], f32), True, 1e-05, [True, True, True]), {})
+cnt: 41, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 1e-05, [True, True, True]), {})
+cnt: 11, ((T([128, 192, 28, 28], f16), T([128, 192, 28, 28], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 96, 56, 56], f16), T([128, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 64, 112, 112], f16),), {})
+cnt: 2, ((T([128, 96, 56, 56], f16),), {})
+cnt: 4, ((T([128, 192, 28, 28], f16),), {})
+cnt: 14, ((T([128, 384, 14, 14], f16),), {})
+cnt: 1, ((T([128, 1408, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([128, 1408, 7, 7], f16), T([128, 1408, 7, 7], f16), 0), {})
+cnt: 14, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16), 0), {})
+cnt: 4, ((T([128, 192, 28, 28], f16), T([128, 192, 28, 28], f16), 0), {})
+cnt: 2, ((T([128, 96, 56, 56], f16), T([128, 96, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/res2net101_26w_4s_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/res2net101_26w_4s_training.txt
new file mode 100644
index 0000000000000..c669ec35671a4
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/res2net101_26w_4s_training.txt
@@ -0,0 +1,209 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 4, ((T([64, 26, 56, 56], f16), T([64, 26, 56, 56], f16, stride=(326144, 3136, 56, 1))), {})
+cnt: 6, ((T([64, 52, 28, 28], f16), T([64, 52, 28, 28], f16, stride=(163072, 784, 28, 1))), {})
+cnt: 44, ((T([64, 104, 14, 14], f16), T([64, 104, 14, 14], f16, stride=(81536, 196, 14, 1))), {})
+cnt: 4, ((T([64, 208, 7, 7], f16), T([64, 208, 7, 7], f16, stride=(40768, 49, 7, 1))), {})
+cnt: 4, ((T([64, 208, 7, 7], f16, stride=(40768, 49, 7, 1)), T([64, 208, 7, 7], f16)), {})
+cnt: 2, ((T([64, 2048, 7, 7], f16), T([64, 2048, 7, 7], f16)), {})
+cnt: 23, ((T([64, 1024, 14, 14], f16), T([64, 1024, 14, 14], f16)), {})
+cnt: 44, ((T([64, 104, 14, 14], f16, stride=(81536, 196, 14, 1)), T([64, 104, 14, 14], f16)), {})
+cnt: 4, ((T([64, 512, 28, 28], f16), T([64, 512, 28, 28], f16)), {})
+cnt: 6, ((T([64, 52, 28, 28], f16, stride=(163072, 784, 28, 1)), T([64, 52, 28, 28], f16)), {})
+cnt: 3, ((T([64, 256, 56, 56], f16), T([64, 256, 56, 56], f16)), {})
+cnt: 4, ((T([64, 26, 56, 56], f16, stride=(326144, 3136, 56, 1)), T([64, 26, 56, 56], f16)), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 170, ((T([], i64), 1), {})
+cnt: 3, ((T([64, 256, 56, 56], f16), T([64, 256, 56, 56], f16)), {})
+cnt: 4, ((T([64, 512, 28, 28], f16), T([64, 512, 28, 28], f16)), {})
+cnt: 23, ((T([64, 1024, 14, 14], f16), T([64, 1024, 14, 14], f16)), {})
+cnt: 3, ((T([64, 2048, 7, 7], f16), T([64, 2048, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([64, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([64, 26, 56, 56], f16, stride=(326144, 3136, 56, 1)), [3, 3], [1, 1], [1, 1]), {})
+cnt: 1, ((T([64, 52, 56, 56], f16, stride=(652288, 3136, 56, 1)), [3, 3], [2, 2], [1, 1]), {})
+cnt: 1, ((T([64, 104, 28, 28], f16, stride=(326144, 784, 28, 1)), [3, 3], [2, 2], [1, 1]), {})
+cnt: 1, ((T([64, 208, 14, 14], f16, stride=(163072, 196, 14, 1)), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([64, 208, 7, 7], f16, stride=(40768, 49, 7, 1)), T([64, 208, 14, 14], f16, stride=(163072, 196, 14, 1)), [3, 3], [2, 2], [1, 1], False, True, None), {})
+cnt: 1, ((T([64, 104, 14, 14], f16, stride=(81536, 196, 14, 1)), T([64, 104, 28, 28], f16, stride=(326144, 784, 28, 1)), [3, 3], [2, 2], [1, 1], False, True, None), {})
+cnt: 1, ((T([64, 52, 28, 28], f16, stride=(163072, 784, 28, 1)), T([64, 52, 56, 56], f16, stride=(652288, 3136, 56, 1)), [3, 3], [2, 2], [1, 1], False, True, None), {})
+cnt: 1, ((T([64, 26, 56, 56], f16, stride=(326144, 3136, 56, 1)), T([64, 26, 56, 56], f16, stride=(326144, 3136, 56, 1)), [3, 3], [1, 1], [1, 1], False, True, None), {})
+Operator: aten.cat.default
+cnt: 2, (([T([64, 26, 56, 56], f16), T([64, 26, 56, 56], f16), T([64, 26, 56, 56], f16), T([64, 26, 56, 56], f16)], 1), {})
+cnt: 4, (([T([64, 26, 56, 56], f16), T([64, 26, 56, 56], f16), T([64, 26, 56, 56], f16), T([64, 26, 56, 56], f16, stride=(326144, 3136, 56, 1))], 1), {})
+cnt: 1, (([T([64, 52, 28, 28], f16), T([64, 52, 28, 28], f16), T([64, 52, 28, 28], f16), T([64, 52, 28, 28], f16)], 1), {})
+cnt: 6, (([T([64, 52, 28, 28], f16), T([64, 52, 28, 28], f16), T([64, 52, 28, 28], f16), T([64, 52, 28, 28], f16, stride=(163072, 784, 28, 1))], 1), {})
+cnt: 1, (([T([64, 104, 14, 14], f16), T([64, 104, 14, 14], f16), T([64, 104, 14, 14], f16), T([64, 104, 14, 14], f16)], 1), {})
+cnt: 44, (([T([64, 104, 14, 14], f16), T([64, 104, 14, 14], f16), T([64, 104, 14, 14], f16), T([64, 104, 14, 14], f16, stride=(81536, 196, 14, 1))], 1), {})
+cnt: 1, (([T([64, 208, 7, 7], f16), T([64, 208, 7, 7], f16), T([64, 208, 7, 7], f16), T([64, 208, 7, 7], f16)], 1), {})
+cnt: 4, (([T([64, 208, 7, 7], f16), T([64, 208, 7, 7], f16), T([64, 208, 7, 7], f16), T([64, 208, 7, 7], f16, stride=(40768, 49, 7, 1))], 1), {})
+cnt: 1, (([T([64, 208, 14, 14], f16), T([64, 208, 14, 14], f16), T([64, 208, 14, 14], f16), T([64, 208, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 104, 28, 28], f16), T([64, 104, 28, 28], f16), T([64, 104, 28, 28], f16), T([64, 104, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 52, 56, 56], f16), T([64, 52, 56, 56], f16), T([64, 52, 56, 56], f16), T([64, 52, 56, 56], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([104, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([64, 26, 56, 56], f16, stride=(326144, 3136, 56, 1)), T([26, 26, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 104, 56, 56], f16), T([256, 104, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 56, 56], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 256, 56, 56], f16), T([104, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([64, 26, 56, 56], f16), T([26, 26, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 56, 56], f16), T([208, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 52, 56, 56], f16, stride=(652288, 3136, 56, 1)), T([52, 52, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([64, 208, 28, 28], f16), T([512, 208, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 56, 56], f16), T([512, 256, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 512, 28, 28], f16), T([208, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 52, 28, 28], f16, stride=(163072, 784, 28, 1)), T([52, 52, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 52, 28, 28], f16), T([52, 52, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 28, 28], f16), T([416, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 104, 28, 28], f16, stride=(326144, 784, 28, 1)), T([104, 104, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 23, ((T([64, 416, 14, 14], f16), T([1024, 416, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 28, 28], f16), T([1024, 512, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 22, ((T([64, 1024, 14, 14], f16), T([416, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 22, ((T([64, 104, 14, 14], f16, stride=(81536, 196, 14, 1)), T([104, 104, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 44, ((T([64, 104, 14, 14], f16), T([104, 104, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 1024, 14, 14], f16), T([832, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 208, 14, 14], f16, stride=(163072, 196, 14, 1)), T([208, 208, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 832, 7, 7], f16), T([2048, 832, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 1024, 14, 14], f16), T([2048, 1024, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 2048, 7, 7], f16), T([832, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 208, 7, 7], f16, stride=(40768, 49, 7, 1)), T([208, 208, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([64, 208, 7, 7], f16), T([208, 208, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([64, 2048, 7, 7], f16), T([64, 832, 7, 7], f16), T([2048, 832, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([64, 208, 7, 7], f16), T([64, 208, 7, 7], f16), T([208, 208, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 208, 7, 7], f16), T([64, 208, 7, 7], f16, stride=(40768, 49, 7, 1)), T([208, 208, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 832, 7, 7], f16), T([64, 2048, 7, 7], f16), T([832, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 2048, 7, 7], f16), T([64, 1024, 14, 14], f16), T([2048, 1024, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 208, 7, 7], f16), T([64, 208, 14, 14], f16, stride=(163072, 196, 14, 1)), T([208, 208, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 832, 14, 14], f16), T([64, 1024, 14, 14], f16), T([832, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 23, ((T([64, 1024, 14, 14], f16), T([64, 416, 14, 14], f16), T([1024, 416, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 44, ((T([64, 104, 14, 14], f16), T([64, 104, 14, 14], f16), T([104, 104, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 22, ((T([64, 104, 14, 14], f16), T([64, 104, 14, 14], f16, stride=(81536, 196, 14, 1)), T([104, 104, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 22, ((T([64, 416, 14, 14], f16), T([64, 1024, 14, 14], f16), T([416, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 1024, 14, 14], f16), T([64, 512, 28, 28], f16), T([1024, 512, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 104, 14, 14], f16), T([64, 104, 28, 28], f16, stride=(326144, 784, 28, 1)), T([104, 104, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 416, 28, 28], f16), T([64, 512, 28, 28], f16), T([416, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([64, 512, 28, 28], f16), T([64, 208, 28, 28], f16), T([512, 208, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([64, 52, 28, 28], f16), T([64, 52, 28, 28], f16), T([52, 52, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 52, 28, 28], f16), T([64, 52, 28, 28], f16, stride=(163072, 784, 28, 1)), T([52, 52, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 208, 28, 28], f16), T([64, 512, 28, 28], f16), T([208, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 28, 28], f16), T([64, 256, 56, 56], f16), T([512, 256, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 52, 28, 28], f16), T([64, 52, 56, 56], f16, stride=(652288, 3136, 56, 1)), T([52, 52, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 208, 56, 56], f16), T([64, 256, 56, 56], f16), T([208, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 256, 56, 56], f16), T([64, 104, 56, 56], f16), T([256, 104, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([64, 26, 56, 56], f16), T([64, 26, 56, 56], f16), T([26, 26, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([64, 26, 56, 56], f16), T([64, 26, 56, 56], f16, stride=(326144, 3136, 56, 1)), T([26, 26, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 104, 56, 56], f16), T([64, 256, 56, 56], f16), T([104, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 256, 56, 56], f16), T([64, 64, 56, 56], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 104, 56, 56], f16), T([64, 64, 56, 56], f16), T([104, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64, 3, 224, 224], f16), T([64, 3, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([64, 2048, 7, 7], f16, stride=(2048, 1, 0, 0)), 49), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([64, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([64, 64, 56, 56], f16), T([64, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([64, 64, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([64, 2048, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([64, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 2048], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([64, 104, 56, 56], f16), T([104], f16), T([104], f16), T([104], f16), T([104], f16), True, 0.1, 1e-05), {})
+cnt: 9, ((T([64, 26, 56, 56], f16), T([26], f16), T([26], f16), T([26], f16), T([26], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([64, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 208, 56, 56], f16), T([208], f16), T([208], f16), T([208], f16), T([208], f16), True, 0.1, 1e-05), {})
+cnt: 12, ((T([64, 52, 28, 28], f16), T([52], f16), T([52], f16), T([52], f16), T([52], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([64, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([64, 208, 28, 28], f16), T([208], f16), T([208], f16), T([208], f16), T([208], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 416, 28, 28], f16), T([416], f16), T([416], f16), T([416], f16), T([416], f16), True, 0.1, 1e-05), {})
+cnt: 69, ((T([64, 104, 14, 14], f16), T([104], f16), T([104], f16), T([104], f16), T([104], f16), True, 0.1, 1e-05), {})
+cnt: 24, ((T([64, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 22, ((T([64, 416, 14, 14], f16), T([416], f16), T([416], f16), T([416], f16), T([416], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 832, 14, 14], f16), T([832], f16), T([832], f16), T([832], f16), T([832], f16), True, 0.1, 1e-05), {})
+cnt: 9, ((T([64, 208, 7, 7], f16), T([208], f16), T([208], f16), T([208], f16), T([208], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([64, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([64, 832, 7, 7], f16), T([832], f16), T([832], f16), T([832], f16), T([832], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 4, ((T([64, 2048, 7, 7], f16), T([64, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 1e-05, [True, True, True]), {})
+cnt: 9, ((T([64, 208, 7, 7], f16), T([64, 208, 7, 7], f16), T([208], f16), T([208], f16), T([208], f16), T([208], f32), T([208], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([64, 832, 7, 7], f16), T([64, 832, 7, 7], f16), T([832], f16), T([832], f16), T([832], f16), T([832], f32), T([832], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 832, 14, 14], f16), T([64, 832, 14, 14], f16), T([832], f16), T([832], f16), T([832], f16), T([832], f32), T([832], f32), True, 1e-05, [True, True, True]), {})
+cnt: 24, ((T([64, 1024, 14, 14], f16), T([64, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 69, ((T([64, 104, 14, 14], f16), T([64, 104, 14, 14], f16), T([104], f16), T([104], f16), T([104], f16), T([104], f32), T([104], f32), True, 1e-05, [True, True, True]), {})
+cnt: 22, ((T([64, 416, 14, 14], f16), T([64, 416, 14, 14], f16), T([416], f16), T([416], f16), T([416], f16), T([416], f32), T([416], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 416, 28, 28], f16), T([64, 416, 28, 28], f16), T([416], f16), T([416], f16), T([416], f16), T([416], f32), T([416], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([64, 512, 28, 28], f16), T([64, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 12, ((T([64, 52, 28, 28], f16), T([64, 52, 28, 28], f16), T([52], f16), T([52], f16), T([52], f16), T([52], f32), T([52], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([64, 208, 28, 28], f16), T([64, 208, 28, 28], f16), T([208], f16), T([208], f16), T([208], f16), T([208], f32), T([208], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 208, 56, 56], f16), T([64, 208, 56, 56], f16), T([208], f16), T([208], f16), T([208], f16), T([208], f32), T([208], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([64, 256, 56, 56], f16), T([64, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 9, ((T([64, 26, 56, 56], f16), T([64, 26, 56, 56], f16), T([26], f16), T([26], f16), T([26], f16), T([26], f32), T([26], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([64, 104, 56, 56], f16), T([64, 104, 56, 56], f16), T([104], f16), T([104], f16), T([104], f16), T([104], f32), T([104], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([64, 64, 112, 112], f16),), {})
+cnt: 3, ((T([64, 104, 56, 56], f16),), {})
+cnt: 9, ((T([64, 26, 56, 56], f16),), {})
+cnt: 3, ((T([64, 256, 56, 56], f16),), {})
+cnt: 1, ((T([64, 208, 56, 56], f16),), {})
+cnt: 12, ((T([64, 52, 28, 28], f16),), {})
+cnt: 4, ((T([64, 512, 28, 28], f16),), {})
+cnt: 3, ((T([64, 208, 28, 28], f16),), {})
+cnt: 1, ((T([64, 416, 28, 28], f16),), {})
+cnt: 69, ((T([64, 104, 14, 14], f16),), {})
+cnt: 23, ((T([64, 1024, 14, 14], f16),), {})
+cnt: 22, ((T([64, 416, 14, 14], f16),), {})
+cnt: 1, ((T([64, 832, 14, 14], f16),), {})
+cnt: 9, ((T([64, 208, 7, 7], f16),), {})
+cnt: 3, ((T([64, 2048, 7, 7], f16),), {})
+cnt: 2, ((T([64, 832, 7, 7], f16),), {})
+Operator: aten.split.Tensor
+cnt: 3, ((T([64, 104, 56, 56], f16), 26, 1), {})
+cnt: 1, ((T([64, 208, 56, 56], f16), 52, 1), {})
+cnt: 3, ((T([64, 208, 28, 28], f16), 52, 1), {})
+cnt: 1, ((T([64, 416, 28, 28], f16), 104, 1), {})
+cnt: 22, ((T([64, 416, 14, 14], f16), 104, 1), {})
+cnt: 1, ((T([64, 832, 14, 14], f16), 208, 1), {})
+cnt: 2, ((T([64, 832, 7, 7], f16), 208, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 3, ((T([64, 2048, 7, 7], f16), T([64, 2048, 7, 7], f16), 0), {})
+cnt: 5, ((T([64, 208, 7, 7], f16, stride=(40768, 49, 7, 1)), T([64, 208, 7, 7], f16), 0), {})
+cnt: 4, ((T([64, 208, 7, 7], f16), T([64, 208, 7, 7], f16), 0), {})
+cnt: 2, ((T([64, 832, 7, 7], f16), T([64, 832, 7, 7], f16), 0), {})
+cnt: 1, ((T([64, 832, 14, 14], f16), T([64, 832, 14, 14], f16), 0), {})
+cnt: 23, ((T([64, 1024, 14, 14], f16), T([64, 1024, 14, 14], f16), 0), {})
+cnt: 25, ((T([64, 104, 14, 14], f16, stride=(81536, 196, 14, 1)), T([64, 104, 14, 14], f16), 0), {})
+cnt: 44, ((T([64, 104, 14, 14], f16), T([64, 104, 14, 14], f16), 0), {})
+cnt: 22, ((T([64, 416, 14, 14], f16), T([64, 416, 14, 14], f16), 0), {})
+cnt: 1, ((T([64, 416, 28, 28], f16), T([64, 416, 28, 28], f16), 0), {})
+cnt: 4, ((T([64, 512, 28, 28], f16), T([64, 512, 28, 28], f16), 0), {})
+cnt: 6, ((T([64, 52, 28, 28], f16, stride=(163072, 784, 28, 1)), T([64, 52, 28, 28], f16), 0), {})
+cnt: 6, ((T([64, 52, 28, 28], f16), T([64, 52, 28, 28], f16), 0), {})
+cnt: 3, ((T([64, 208, 28, 28], f16), T([64, 208, 28, 28], f16), 0), {})
+cnt: 1, ((T([64, 208, 56, 56], f16), T([64, 208, 56, 56], f16), 0), {})
+cnt: 3, ((T([64, 256, 56, 56], f16), T([64, 256, 56, 56], f16), 0), {})
+cnt: 5, ((T([64, 26, 56, 56], f16, stride=(326144, 3136, 56, 1)), T([64, 26, 56, 56], f16), 0), {})
+cnt: 4, ((T([64, 26, 56, 56], f16), T([64, 26, 56, 56], f16), 0), {})
+cnt: 3, ((T([64, 104, 56, 56], f16), T([64, 104, 56, 56], f16), 0), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64, 64, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/res2net50_14w_8s_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/res2net50_14w_8s_training.txt
new file mode 100644
index 0000000000000..88b8cd46438ec
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/res2net50_14w_8s_training.txt
@@ -0,0 +1,209 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 12, ((T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16, stride=(351232, 3136, 56, 1))), {})
+cnt: 18, ((T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16, stride=(175616, 784, 28, 1))), {})
+cnt: 30, ((T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16, stride=(87808, 196, 14, 1))), {})
+cnt: 12, ((T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16, stride=(43904, 49, 7, 1))), {})
+cnt: 12, ((T([128, 112, 7, 7], f16, stride=(43904, 49, 7, 1)), T([128, 112, 7, 7], f16)), {})
+cnt: 2, ((T([128, 2048, 7, 7], f16), T([128, 2048, 7, 7], f16)), {})
+cnt: 6, ((T([128, 1024, 14, 14], f16), T([128, 1024, 14, 14], f16)), {})
+cnt: 30, ((T([128, 56, 14, 14], f16, stride=(87808, 196, 14, 1)), T([128, 56, 14, 14], f16)), {})
+cnt: 4, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16)), {})
+cnt: 18, ((T([128, 28, 28, 28], f16, stride=(175616, 784, 28, 1)), T([128, 28, 28, 28], f16)), {})
+cnt: 3, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16)), {})
+cnt: 12, ((T([128, 14, 56, 56], f16, stride=(351232, 3136, 56, 1)), T([128, 14, 56, 56], f16)), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 149, ((T([], i64), 1), {})
+cnt: 3, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16)), {})
+cnt: 4, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16)), {})
+cnt: 6, ((T([128, 1024, 14, 14], f16), T([128, 1024, 14, 14], f16)), {})
+cnt: 3, ((T([128, 2048, 7, 7], f16), T([128, 2048, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([128, 14, 56, 56], f16, stride=(351232, 3136, 56, 1)), [3, 3], [1, 1], [1, 1]), {})
+cnt: 1, ((T([128, 28, 56, 56], f16, stride=(702464, 3136, 56, 1)), [3, 3], [2, 2], [1, 1]), {})
+cnt: 1, ((T([128, 56, 28, 28], f16, stride=(351232, 784, 28, 1)), [3, 3], [2, 2], [1, 1]), {})
+cnt: 1, ((T([128, 112, 14, 14], f16, stride=(175616, 196, 14, 1)), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([128, 112, 7, 7], f16, stride=(43904, 49, 7, 1)), T([128, 112, 14, 14], f16, stride=(175616, 196, 14, 1)), [3, 3], [2, 2], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 56, 14, 14], f16, stride=(87808, 196, 14, 1)), T([128, 56, 28, 28], f16, stride=(351232, 784, 28, 1)), [3, 3], [2, 2], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 28, 28, 28], f16, stride=(175616, 784, 28, 1)), T([128, 28, 56, 56], f16, stride=(702464, 3136, 56, 1)), [3, 3], [2, 2], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 14, 56, 56], f16, stride=(351232, 3136, 56, 1)), T([128, 14, 56, 56], f16, stride=(351232, 3136, 56, 1)), [3, 3], [1, 1], [1, 1], False, True, None), {})
+Operator: aten.cat.default
+cnt: 2, (([T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16)], 1), {})
+cnt: 4, (([T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16, stride=(351232, 3136, 56, 1))], 1), {})
+cnt: 1, (([T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16)], 1), {})
+cnt: 6, (([T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16, stride=(175616, 784, 28, 1))], 1), {})
+cnt: 1, (([T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16)], 1), {})
+cnt: 10, (([T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16, stride=(87808, 196, 14, 1))], 1), {})
+cnt: 1, (([T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16)], 1), {})
+cnt: 4, (([T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16, stride=(43904, 49, 7, 1))], 1), {})
+cnt: 1, (([T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16)], 1), {})
+cnt: 1, (([T([128, 56, 28, 28], f16), T([128, 56, 28, 28], f16), T([128, 56, 28, 28], f16), T([128, 56, 28, 28], f16), T([128, 56, 28, 28], f16), T([128, 56, 28, 28], f16), T([128, 56, 28, 28], f16), T([128, 56, 28, 28], f16)], 1), {})
+cnt: 1, (([T([128, 28, 56, 56], f16), T([128, 28, 56, 56], f16), T([128, 28, 56, 56], f16), T([128, 28, 56, 56], f16), T([128, 28, 56, 56], f16), T([128, 28, 56, 56], f16), T([128, 28, 56, 56], f16), T([128, 28, 56, 56], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([64, 3, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([112, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 9, ((T([128, 14, 56, 56], f16, stride=(351232, 3136, 56, 1)), T([14, 14, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 112, 56, 56], f16), T([256, 112, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 256, 56, 56], f16), T([112, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 12, ((T([128, 14, 56, 56], f16), T([14, 14, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([224, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([128, 28, 56, 56], f16, stride=(702464, 3136, 56, 1)), T([28, 28, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 224, 28, 28], f16), T([512, 224, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([512, 256, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 512, 28, 28], f16), T([224, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 28, 28, 28], f16, stride=(175616, 784, 28, 1)), T([28, 28, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 18, ((T([128, 28, 28, 28], f16), T([28, 28, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([448, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([128, 56, 28, 28], f16, stride=(351232, 784, 28, 1)), T([56, 56, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([128, 448, 14, 14], f16), T([1024, 448, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([1024, 512, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 1024, 14, 14], f16), T([448, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 56, 14, 14], f16, stride=(87808, 196, 14, 1)), T([56, 56, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 30, ((T([128, 56, 14, 14], f16), T([56, 56, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1024, 14, 14], f16), T([896, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([128, 112, 14, 14], f16, stride=(175616, 196, 14, 1)), T([112, 112, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 896, 7, 7], f16), T([2048, 896, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1024, 14, 14], f16), T([2048, 1024, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 2048, 7, 7], f16), T([896, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 112, 7, 7], f16, stride=(43904, 49, 7, 1)), T([112, 112, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 12, ((T([128, 112, 7, 7], f16), T([112, 112, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([128, 2048, 7, 7], f16), T([128, 896, 7, 7], f16), T([2048, 896, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 12, ((T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), T([112, 112, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16, stride=(43904, 49, 7, 1)), T([112, 112, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 896, 7, 7], f16), T([128, 2048, 7, 7], f16), T([896, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 2048, 7, 7], f16), T([128, 1024, 14, 14], f16), T([2048, 1024, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 7, ((T([128, 112, 7, 7], f16), T([128, 112, 14, 14], f16, stride=(175616, 196, 14, 1)), T([112, 112, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 896, 14, 14], f16), T([128, 1024, 14, 14], f16), T([896, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([128, 1024, 14, 14], f16), T([128, 448, 14, 14], f16), T([1024, 448, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 30, ((T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([56, 56, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16, stride=(87808, 196, 14, 1)), T([56, 56, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([128, 448, 14, 14], f16), T([128, 1024, 14, 14], f16), T([448, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1024, 14, 14], f16), T([128, 512, 28, 28], f16), T([1024, 512, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 7, ((T([128, 56, 14, 14], f16), T([128, 56, 28, 28], f16, stride=(351232, 784, 28, 1)), T([56, 56, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 448, 28, 28], f16), T([128, 512, 28, 28], f16), T([448, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 512, 28, 28], f16), T([128, 224, 28, 28], f16), T([512, 224, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 18, ((T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16), T([28, 28, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16, stride=(175616, 784, 28, 1)), T([28, 28, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 224, 28, 28], f16), T([128, 512, 28, 28], f16), T([224, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([128, 256, 56, 56], f16), T([512, 256, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 7, ((T([128, 28, 28, 28], f16), T([128, 28, 56, 56], f16, stride=(702464, 3136, 56, 1)), T([28, 28, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 224, 56, 56], f16), T([128, 256, 56, 56], f16), T([224, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 256, 56, 56], f16), T([128, 112, 56, 56], f16), T([256, 112, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 12, ((T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16), T([14, 14, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 9, ((T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16, stride=(351232, 3136, 56, 1)), T([14, 14, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 112, 56, 56], f16), T([128, 256, 56, 56], f16), T([112, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([128, 64, 56, 56], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 112, 56, 56], f16), T([128, 64, 56, 56], f16), T([112, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 3, 224, 224], f16), T([64, 3, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 2048, 7, 7], f16, stride=(2048, 1, 0, 0)), 49), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([128, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([128, 64, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 2048, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 2048], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 112, 56, 56], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f16), True, 0.1, 1e-05), {})
+cnt: 21, ((T([128, 14, 56, 56], f16), T([14], f16), T([14], f16), T([14], f16), T([14], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 224, 56, 56], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f16), True, 0.1, 1e-05), {})
+cnt: 28, ((T([128, 28, 28, 28], f16), T([28], f16), T([28], f16), T([28], f16), T([28], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 224, 28, 28], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 448, 28, 28], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f16), True, 0.1, 1e-05), {})
+cnt: 42, ((T([128, 56, 14, 14], f16), T([56], f16), T([56], f16), T([56], f16), T([56], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([128, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 448, 14, 14], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 896, 14, 14], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f16), True, 0.1, 1e-05), {})
+cnt: 21, ((T([128, 112, 7, 7], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 896, 7, 7], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 4, ((T([128, 2048, 7, 7], f16), T([128, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 1e-05, [True, True, True]), {})
+cnt: 21, ((T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f32), T([112], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 896, 7, 7], f16), T([128, 896, 7, 7], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f32), T([896], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 896, 14, 14], f16), T([128, 896, 14, 14], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f32), T([896], f32), True, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([128, 1024, 14, 14], f16), T([128, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 42, ((T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), T([56], f16), T([56], f16), T([56], f16), T([56], f32), T([56], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 448, 14, 14], f16), T([128, 448, 14, 14], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f32), T([448], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 448, 28, 28], f16), T([128, 448, 28, 28], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f32), T([448], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 28, ((T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16), T([28], f16), T([28], f16), T([28], f16), T([28], f32), T([28], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 224, 28, 28], f16), T([128, 224, 28, 28], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f32), T([224], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 224, 56, 56], f16), T([128, 224, 56, 56], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f32), T([224], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 21, ((T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16), T([14], f16), T([14], f16), T([14], f16), T([14], f32), T([14], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 112, 56, 56], f16), T([128, 112, 56, 56], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f32), T([112], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 64, 112, 112], f16),), {})
+cnt: 3, ((T([128, 112, 56, 56], f16),), {})
+cnt: 21, ((T([128, 14, 56, 56], f16),), {})
+cnt: 3, ((T([128, 256, 56, 56], f16),), {})
+cnt: 1, ((T([128, 224, 56, 56], f16),), {})
+cnt: 28, ((T([128, 28, 28, 28], f16),), {})
+cnt: 4, ((T([128, 512, 28, 28], f16),), {})
+cnt: 3, ((T([128, 224, 28, 28], f16),), {})
+cnt: 1, ((T([128, 448, 28, 28], f16),), {})
+cnt: 42, ((T([128, 56, 14, 14], f16),), {})
+cnt: 6, ((T([128, 1024, 14, 14], f16),), {})
+cnt: 5, ((T([128, 448, 14, 14], f16),), {})
+cnt: 1, ((T([128, 896, 14, 14], f16),), {})
+cnt: 21, ((T([128, 112, 7, 7], f16),), {})
+cnt: 3, ((T([128, 2048, 7, 7], f16),), {})
+cnt: 2, ((T([128, 896, 7, 7], f16),), {})
+Operator: aten.split.Tensor
+cnt: 3, ((T([128, 112, 56, 56], f16), 14, 1), {})
+cnt: 1, ((T([128, 224, 56, 56], f16), 28, 1), {})
+cnt: 3, ((T([128, 224, 28, 28], f16), 28, 1), {})
+cnt: 1, ((T([128, 448, 28, 28], f16), 56, 1), {})
+cnt: 5, ((T([128, 448, 14, 14], f16), 56, 1), {})
+cnt: 1, ((T([128, 896, 14, 14], f16), 112, 1), {})
+cnt: 2, ((T([128, 896, 7, 7], f16), 112, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 3, ((T([128, 2048, 7, 7], f16), T([128, 2048, 7, 7], f16), 0), {})
+cnt: 9, ((T([128, 112, 7, 7], f16, stride=(43904, 49, 7, 1)), T([128, 112, 7, 7], f16), 0), {})
+cnt: 12, ((T([128, 112, 7, 7], f16), T([128, 112, 7, 7], f16), 0), {})
+cnt: 2, ((T([128, 896, 7, 7], f16), T([128, 896, 7, 7], f16), 0), {})
+cnt: 1, ((T([128, 896, 14, 14], f16), T([128, 896, 14, 14], f16), 0), {})
+cnt: 6, ((T([128, 1024, 14, 14], f16), T([128, 1024, 14, 14], f16), 0), {})
+cnt: 12, ((T([128, 56, 14, 14], f16, stride=(87808, 196, 14, 1)), T([128, 56, 14, 14], f16), 0), {})
+cnt: 30, ((T([128, 56, 14, 14], f16), T([128, 56, 14, 14], f16), 0), {})
+cnt: 5, ((T([128, 448, 14, 14], f16), T([128, 448, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 448, 28, 28], f16), T([128, 448, 28, 28], f16), 0), {})
+cnt: 4, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16), 0), {})
+cnt: 10, ((T([128, 28, 28, 28], f16, stride=(175616, 784, 28, 1)), T([128, 28, 28, 28], f16), 0), {})
+cnt: 18, ((T([128, 28, 28, 28], f16), T([128, 28, 28, 28], f16), 0), {})
+cnt: 3, ((T([128, 224, 28, 28], f16), T([128, 224, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 224, 56, 56], f16), T([128, 224, 56, 56], f16), 0), {})
+cnt: 3, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16), 0), {})
+cnt: 9, ((T([128, 14, 56, 56], f16, stride=(351232, 3136, 56, 1)), T([128, 14, 56, 56], f16), 0), {})
+cnt: 12, ((T([128, 14, 56, 56], f16), T([128, 14, 56, 56], f16), 0), {})
+cnt: 3, ((T([128, 112, 56, 56], f16), T([128, 112, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/res2next50_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/res2next50_training.txt
new file mode 100644
index 0000000000000..d498c8050f7d8
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/res2next50_training.txt
@@ -0,0 +1,197 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 4, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16, stride=(401408, 3136, 56, 1))), {})
+cnt: 6, ((T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16, stride=(200704, 784, 28, 1))), {})
+cnt: 10, ((T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16, stride=(100352, 196, 14, 1))), {})
+cnt: 4, ((T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16, stride=(50176, 49, 7, 1))), {})
+cnt: 4, ((T([128, 256, 7, 7], f16, stride=(50176, 49, 7, 1)), T([128, 256, 7, 7], f16)), {})
+cnt: 2, ((T([128, 2048, 7, 7], f16), T([128, 2048, 7, 7], f16)), {})
+cnt: 6, ((T([128, 1024, 14, 14], f16), T([128, 1024, 14, 14], f16)), {})
+cnt: 10, ((T([128, 128, 14, 14], f16, stride=(100352, 196, 14, 1)), T([128, 128, 14, 14], f16)), {})
+cnt: 4, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16)), {})
+cnt: 6, ((T([128, 64, 28, 28], f16, stride=(200704, 784, 28, 1)), T([128, 64, 28, 28], f16)), {})
+cnt: 3, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16)), {})
+cnt: 4, ((T([128, 32, 56, 56], f16, stride=(401408, 3136, 56, 1)), T([128, 32, 56, 56], f16)), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 85, ((T([], i64), 1), {})
+cnt: 3, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16)), {})
+cnt: 4, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16)), {})
+cnt: 6, ((T([128, 1024, 14, 14], f16), T([128, 1024, 14, 14], f16)), {})
+cnt: 3, ((T([128, 2048, 7, 7], f16), T([128, 2048, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([128, 32, 56, 56], f16, stride=(401408, 3136, 56, 1)), [3, 3], [1, 1], [1, 1]), {})
+cnt: 1, ((T([128, 64, 56, 56], f16, stride=(802816, 3136, 56, 1)), [3, 3], [2, 2], [1, 1]), {})
+cnt: 1, ((T([128, 128, 28, 28], f16, stride=(401408, 784, 28, 1)), [3, 3], [2, 2], [1, 1]), {})
+cnt: 1, ((T([128, 256, 14, 14], f16, stride=(200704, 196, 14, 1)), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([128, 256, 7, 7], f16, stride=(50176, 49, 7, 1)), T([128, 256, 14, 14], f16, stride=(200704, 196, 14, 1)), [3, 3], [2, 2], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 128, 14, 14], f16, stride=(100352, 196, 14, 1)), T([128, 128, 28, 28], f16, stride=(401408, 784, 28, 1)), [3, 3], [2, 2], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 64, 28, 28], f16, stride=(200704, 784, 28, 1)), T([128, 64, 56, 56], f16, stride=(802816, 3136, 56, 1)), [3, 3], [2, 2], [1, 1], False, True, None), {})
+cnt: 1, ((T([128, 32, 56, 56], f16, stride=(401408, 3136, 56, 1)), T([128, 32, 56, 56], f16, stride=(401408, 3136, 56, 1)), [3, 3], [1, 1], [1, 1], False, True, None), {})
+Operator: aten.cat.default
+cnt: 2, (([T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16)], 1), {})
+cnt: 4, (([T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16, stride=(401408, 3136, 56, 1))], 1), {})
+cnt: 1, (([T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16)], 1), {})
+cnt: 6, (([T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16, stride=(200704, 784, 28, 1))], 1), {})
+cnt: 1, (([T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16)], 1), {})
+cnt: 10, (([T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16, stride=(100352, 196, 14, 1))], 1), {})
+cnt: 1, (([T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16)], 1), {})
+cnt: 4, (([T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16, stride=(50176, 49, 7, 1))], 1), {})
+cnt: 1, (([T([128, 256, 14, 14], f16), T([128, 256, 14, 14], f16), T([128, 256, 14, 14], f16), T([128, 256, 14, 14], f16)], 1), {})
+cnt: 1, (([T([128, 128, 28, 28], f16), T([128, 128, 28, 28], f16), T([128, 128, 28, 28], f16), T([128, 128, 28, 28], f16)], 1), {})
+cnt: 1, (([T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([64, 3, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 32, 56, 56], f16, stride=(401408, 3136, 56, 1)), T([32, 4, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 3, ((T([128, 128, 56, 56], f16), T([256, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 256, 56, 56], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 32, 56, 56], f16), T([32, 4, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([256, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 64, 56, 56], f16, stride=(802816, 3136, 56, 1)), T([64, 8, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 4, ((T([128, 256, 28, 28], f16), T([512, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([512, 256, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 512, 28, 28], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 64, 28, 28], f16, stride=(200704, 784, 28, 1)), T([64, 8, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 6, ((T([128, 64, 28, 28], f16), T([64, 8, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([512, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 128, 28, 28], f16, stride=(401408, 784, 28, 1)), T([128, 16, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 6, ((T([128, 512, 14, 14], f16), T([1024, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([1024, 512, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 128, 14, 14], f16, stride=(100352, 196, 14, 1)), T([128, 16, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 10, ((T([128, 128, 14, 14], f16), T([128, 16, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 1, ((T([128, 1024, 14, 14], f16), T([1024, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 256, 14, 14], f16, stride=(200704, 196, 14, 1)), T([256, 32, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 3, ((T([128, 1024, 7, 7], f16), T([2048, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1024, 14, 14], f16), T([2048, 1024, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 2048, 7, 7], f16), T([1024, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 256, 7, 7], f16, stride=(50176, 49, 7, 1)), T([256, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 4, ((T([128, 256, 7, 7], f16), T([256, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 8), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([128, 2048, 7, 7], f16), T([128, 1024, 7, 7], f16), T([2048, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16), T([256, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 2, ((T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16, stride=(50176, 49, 7, 1)), T([256, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 2, ((T([128, 1024, 7, 7], f16), T([128, 2048, 7, 7], f16), T([1024, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 2048, 7, 7], f16), T([128, 1024, 14, 14], f16), T([2048, 1024, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 256, 7, 7], f16), T([128, 256, 14, 14], f16, stride=(200704, 196, 14, 1)), T([256, 32, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 1, ((T([128, 1024, 14, 14], f16), T([128, 1024, 14, 14], f16), T([1024, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([128, 1024, 14, 14], f16), T([128, 512, 14, 14], f16), T([1024, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 10, ((T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16), T([128, 16, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 5, ((T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16, stride=(100352, 196, 14, 1)), T([128, 16, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 5, ((T([128, 512, 14, 14], f16), T([128, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1024, 14, 14], f16), T([128, 512, 28, 28], f16), T([1024, 512, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 128, 14, 14], f16), T([128, 128, 28, 28], f16, stride=(401408, 784, 28, 1)), T([128, 16, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16), T([512, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 512, 28, 28], f16), T([128, 256, 28, 28], f16), T([512, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16), T([64, 8, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 3, ((T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16, stride=(200704, 784, 28, 1)), T([64, 8, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 3, ((T([128, 256, 28, 28], f16), T([128, 512, 28, 28], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), T([128, 256, 56, 56], f16), T([512, 256, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 64, 28, 28], f16), T([128, 64, 56, 56], f16, stride=(802816, 3136, 56, 1)), T([64, 8, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16), T([256, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 256, 56, 56], f16), T([128, 128, 56, 56], f16), T([256, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), T([32, 4, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 5, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16, stride=(401408, 3136, 56, 1)), T([32, 4, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 2, ((T([128, 128, 56, 56], f16), T([128, 256, 56, 56], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), T([128, 64, 56, 56], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([128, 64, 56, 56], f16), T([128, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 3, 224, 224], f16), T([64, 3, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 2048, 7, 7], f16, stride=(2048, 1, 0, 0)), 49), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([128, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([128, 64, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 2048, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 2048], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 9, ((T([128, 32, 56, 56], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 12, ((T([128, 64, 28, 28], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 6, ((T([128, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 18, ((T([128, 128, 14, 14], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 8, ((T([128, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 9, ((T([128, 256, 7, 7], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 4, ((T([128, 2048, 7, 7], f16), T([128, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 1e-05, [True, True, True]), {})
+cnt: 9, ((T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 1024, 7, 7], f16), T([128, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([128, 1024, 14, 14], f16), T([128, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 18, ((T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 512, 14, 14], f16), T([128, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 12, ((T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 256, 28, 28], f16), T([128, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 9, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 128, 56, 56], f16), T([128, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 64, 112, 112], f16),), {})
+cnt: 3, ((T([128, 128, 56, 56], f16),), {})
+cnt: 9, ((T([128, 32, 56, 56], f16),), {})
+cnt: 4, ((T([128, 256, 56, 56], f16),), {})
+cnt: 12, ((T([128, 64, 28, 28], f16),), {})
+cnt: 5, ((T([128, 512, 28, 28], f16),), {})
+cnt: 3, ((T([128, 256, 28, 28], f16),), {})
+cnt: 18, ((T([128, 128, 14, 14], f16),), {})
+cnt: 7, ((T([128, 1024, 14, 14], f16),), {})
+cnt: 5, ((T([128, 512, 14, 14], f16),), {})
+cnt: 9, ((T([128, 256, 7, 7], f16),), {})
+cnt: 3, ((T([128, 2048, 7, 7], f16),), {})
+cnt: 2, ((T([128, 1024, 7, 7], f16),), {})
+Operator: aten.split.Tensor
+cnt: 3, ((T([128, 128, 56, 56], f16), 32, 1), {})
+cnt: 1, ((T([128, 256, 56, 56], f16), 64, 1), {})
+cnt: 3, ((T([128, 256, 28, 28], f16), 64, 1), {})
+cnt: 1, ((T([128, 512, 28, 28], f16), 128, 1), {})
+cnt: 5, ((T([128, 512, 14, 14], f16), 128, 1), {})
+cnt: 1, ((T([128, 1024, 14, 14], f16), 256, 1), {})
+cnt: 2, ((T([128, 1024, 7, 7], f16), 256, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 3, ((T([128, 2048, 7, 7], f16), T([128, 2048, 7, 7], f16), 0), {})
+cnt: 5, ((T([128, 256, 7, 7], f16, stride=(50176, 49, 7, 1)), T([128, 256, 7, 7], f16), 0), {})
+cnt: 4, ((T([128, 256, 7, 7], f16), T([128, 256, 7, 7], f16), 0), {})
+cnt: 2, ((T([128, 1024, 7, 7], f16), T([128, 1024, 7, 7], f16), 0), {})
+cnt: 7, ((T([128, 1024, 14, 14], f16), T([128, 1024, 14, 14], f16), 0), {})
+cnt: 8, ((T([128, 128, 14, 14], f16, stride=(100352, 196, 14, 1)), T([128, 128, 14, 14], f16), 0), {})
+cnt: 10, ((T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16), 0), {})
+cnt: 5, ((T([128, 512, 14, 14], f16), T([128, 512, 14, 14], f16), 0), {})
+cnt: 5, ((T([128, 512, 28, 28], f16), T([128, 512, 28, 28], f16), 0), {})
+cnt: 6, ((T([128, 64, 28, 28], f16, stride=(200704, 784, 28, 1)), T([128, 64, 28, 28], f16), 0), {})
+cnt: 6, ((T([128, 64, 28, 28], f16), T([128, 64, 28, 28], f16), 0), {})
+cnt: 3, ((T([128, 256, 28, 28], f16), T([128, 256, 28, 28], f16), 0), {})
+cnt: 4, ((T([128, 256, 56, 56], f16), T([128, 256, 56, 56], f16), 0), {})
+cnt: 5, ((T([128, 32, 56, 56], f16, stride=(401408, 3136, 56, 1)), T([128, 32, 56, 56], f16), 0), {})
+cnt: 4, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), 0), {})
+cnt: 3, ((T([128, 128, 56, 56], f16), T([128, 128, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/resmlp_12_224_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/resmlp_12_224_training.txt
new file mode 100644
index 0000000000000..3c47d598f97f6
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/resmlp_12_224_training.txt
@@ -0,0 +1,75 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 12, ((T([128, 196, 1536], f16), [128, 196, 1536]), {})
+cnt: 12, ((T([128, 384, 196], f16), [49152, 196]), {})
+Operator: aten.add.Tensor
+cnt: 12, ((T([128, 196, 384], f16, stride=(75264, 1, 196)), T([128, 196, 384], f16, stride=(75264, 1, 196))), {})
+cnt: 12, ((T([128, 196, 1536], f16), T([1536], f16)), {})
+cnt: 12, ((T([128, 196, 384], f16, stride=(75264, 1, 196)), T([128, 196, 384], f16)), {})
+cnt: 12, ((T([128, 196, 384], f16), T([128, 196, 384], f16)), {})
+cnt: 12, ((T([128, 196, 384], f16), T([128, 196, 384], f16, stride=(75264, 1, 196))), {})
+Operator: aten.addcmul.default
+cnt: 25, ((T([1, 1, 384], f16), T([1, 1, 384], f16), T([128, 196, 384], f16, stride=(75264, 1, 196))), {})
+Operator: aten.addmm.default
+cnt: 12, ((T([196], f16), T([49152, 196], f16), T([196, 196], f16, stride=(1, 196))), {})
+cnt: 12, ((T([384], f16), T([25088, 1536], f16), T([1536, 384], f16, stride=(1, 1536))), {})
+cnt: 1, ((T([1000], f16), T([128, 384], f16), T([384, 1000], f16, stride=(1, 384))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([128, 196, 384], f16, stride=(75264, 1, 196)), T([128, 384, 1536], f16, stride=(0, 1, 384))), {})
+cnt: 12, ((T([128, 384, 196], f16), T([128, 196, 1536], f16)), {})
+cnt: 12, ((T([128, 196, 1536], f16), T([128, 1536, 384], f16, stride=(0, 384, 1))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([384, 3, 16, 16], f16), T([384], f16), [16, 16], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 384, 14, 14], f16, stride=(75264, 1, 5376, 384)), T([128, 3, 224, 224], f16), T([384, 3, 16, 16], f16), [384], [16, 16], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+cnt: 12, ((T([1536, 384], f16), T([1536, 384], f16, stride=(1, 1536))), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 196, 384], f16, stride=(384, 0, 1)), 196), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([128, 196, 1536], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([128, 196, 1536], f16), T([128, 196, 1536], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 196, 384], f16, stride=(75264, 1, 196)), [1]), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 384], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 384], f16)), {})
+cnt: 12, ((T([25088, 384], f16), T([384, 1536], f16)), {})
+cnt: 12, ((T([384, 25088], f16, stride=(1, 384)), T([25088, 1536], f16)), {})
+cnt: 12, ((T([49152, 196], f16), T([196, 196], f16)), {})
+cnt: 12, ((T([196, 49152], f16, stride=(1, 196)), T([49152, 196], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 25, ((T([128, 196, 384], f16, stride=(75264, 1, 196)), 1), {})
+cnt: 25, ((T([1, 1, 384], f16), 1), {})
+Operator: aten.mul.Tensor
+cnt: 12, ((T([384], f16), T([128, 196, 384], f16, stride=(75264, 1, 196))), {})
+cnt: 12, ((T([384], f16), T([128, 196, 384], f16)), {})
+cnt: 25, ((T([128, 196, 384], f16), T([128, 196, 384], f16, stride=(75264, 1, 196))), {})
+cnt: 13, ((T([128, 196, 384], f16), T([1, 1, 384], f16)), {})
+cnt: 24, ((T([128, 196, 384], f16), T([384], f16)), {})
+cnt: 12, ((T([128, 196, 384], f16), T([128, 196, 384], f16)), {})
+cnt: 12, ((T([128, 196, 384], f16, stride=(75264, 1, 196)), T([128, 196, 384], f16, stride=(75264, 1, 196))), {})
+cnt: 12, ((T([128, 196, 384], f16, stride=(75264, 1, 196)), T([1, 1, 384], f16)), {})
+Operator: aten.new_empty_strided.default
+cnt: 12, ((T([1536, 384], f16, stride=(1, 1536)), [1536, 384], [384, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 50, ((T([128, 196, 384], f16), [0, 1], True), {})
+cnt: 12, ((T([25088, 384], f16), [0], True), {})
+cnt: 12, ((T([128, 196, 1536], f16), [0, 1], True), {})
+cnt: 12, ((T([128, 384, 1536], f16), [0], True), {})
+cnt: 12, ((T([49152, 196], f16), [0], True), {})
+cnt: 24, ((T([128, 196, 384], f16, stride=(75264, 1, 196)), [0, 1], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/resnest101e_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/resnest101e_training.txt
new file mode 100644
index 0000000000000..03e1db4dc9c66
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/resnest101e_training.txt
@@ -0,0 +1,269 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([32, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([32, 1000], f16), T([32, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 3, ((T([32, 2, 1, 64], f16), 1, False), {})
+cnt: 4, ((T([32, 2, 1, 128], f16), 1, False), {})
+cnt: 23, ((T([32, 2, 1, 256], f16), 1, False), {})
+cnt: 3, ((T([32, 2, 1, 512], f16), 1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 3, ((T([32, 2, 1, 512], f16), T([32, 2, 1, 512], f16), 1, f16), {})
+cnt: 23, ((T([32, 2, 1, 256], f16), T([32, 2, 1, 256], f16), 1, f16), {})
+cnt: 4, ((T([32, 2, 1, 128], f16), T([32, 2, 1, 128], f16), 1, f16), {})
+cnt: 3, ((T([32, 2, 1, 64], f16), T([32, 2, 1, 64], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 2, ((T([32, 2, 512, 8, 8], f16), T([32, 2, 512, 8, 8], f16, stride=(32768, 0, 64, 8, 1))), {})
+cnt: 2, ((T([32, 2048, 8, 8], f16), T([32, 2048, 8, 8], f16)), {})
+cnt: 1, ((T([32, 2, 512, 16, 16], f16), T([32, 2, 512, 16, 16], f16, stride=(131072, 0, 256, 16, 1))), {})
+cnt: 23, ((T([32, 1024, 16, 16], f16), T([32, 1024, 16, 16], f16)), {})
+cnt: 22, ((T([32, 2, 256, 16, 16], f16), T([32, 2, 256, 16, 16], f16, stride=(65536, 0, 256, 16, 1))), {})
+cnt: 1, ((T([32, 2, 256, 32, 32], f16), T([32, 2, 256, 32, 32], f16, stride=(262144, 0, 1024, 32, 1))), {})
+cnt: 4, ((T([32, 512, 32, 32], f16), T([32, 512, 32, 32], f16)), {})
+cnt: 3, ((T([32, 2, 128, 32, 32], f16), T([32, 2, 128, 32, 32], f16, stride=(131072, 0, 1024, 32, 1))), {})
+cnt: 1, ((T([32, 2, 128, 64, 64], f16), T([32, 2, 128, 64, 64], f16, stride=(524288, 0, 4096, 64, 1))), {})
+cnt: 3, ((T([32, 256, 64, 64], f16), T([32, 256, 64, 64], f16)), {})
+cnt: 3, ((T([32, 2, 64, 64, 64], f16), T([32, 2, 64, 64, 64], f16, stride=(262144, 0, 4096, 64, 1))), {})
+cnt: 1, ((T([32, 128, 64, 64], f16), T([32, 128, 64, 64], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 139, ((T([], i64), 1), {})
+cnt: 3, ((T([32, 256, 64, 64], f16), T([32, 256, 64, 64], f16)), {})
+cnt: 4, ((T([32, 512, 32, 32], f16), T([32, 512, 32, 32], f16)), {})
+cnt: 23, ((T([32, 1024, 16, 16], f16), T([32, 1024, 16, 16], f16)), {})
+cnt: 3, ((T([32, 2048, 8, 8], f16), T([32, 2048, 8, 8], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([32, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([32, 128, 64, 64], f16), [3, 3], [2, 2], [1, 1]), {})
+cnt: 1, ((T([32, 256, 64, 64], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+cnt: 1, ((T([32, 256, 32, 32], f16), [3, 3], [2, 2], [1, 1]), {})
+cnt: 1, ((T([32, 512, 32, 32], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+cnt: 1, ((T([32, 512, 16, 16], f16), [3, 3], [2, 2], [1, 1]), {})
+cnt: 1, ((T([32, 1024, 16, 16], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([32, 1024, 8, 8], f16), T([32, 1024, 16, 16], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+cnt: 1, ((T([32, 512, 8, 8], f16), T([32, 512, 16, 16], f16), [3, 3], [2, 2], [1, 1], False, True, None), {})
+cnt: 1, ((T([32, 512, 16, 16], f16), T([32, 512, 32, 32], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+cnt: 1, ((T([32, 256, 16, 16], f16), T([32, 256, 32, 32], f16), [3, 3], [2, 2], [1, 1], False, True, None), {})
+cnt: 1, ((T([32, 256, 32, 32], f16), T([32, 256, 64, 64], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+cnt: 1, ((T([32, 128, 32, 32], f16), T([32, 128, 64, 64], f16), [3, 3], [2, 2], [1, 1], False, True, None), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 256, 256], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 256, 256], f16), T([64, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 128, 128], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 128, 128], f16), T([128, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 64, 64], f16), T([64, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 64, 64, 64], f16), T([128, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 3, ((T([32, 64, 1, 1], f16), T([32, 64, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 32, 1, 1], f16), T([128, 32, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 64, 64, 64], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 64, 64], f16), T([256, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 256, 64, 64], f16), T([64, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 64, 64], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 64, 64], f16), T([256, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 4, ((T([32, 128, 1, 1], f16), T([64, 128, 1, 1], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([32, 64, 1, 1], f16), T([256, 64, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([32, 128, 32, 32], f16), T([512, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 32, 32], f16), T([512, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 512, 32, 32], f16), T([128, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 128, 32, 32], f16), T([256, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 1, ((T([32, 512, 32, 32], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 32, 32], f16), T([512, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 23, ((T([32, 256, 1, 1], f16), T([128, 256, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 23, ((T([32, 128, 1, 1], f16), T([512, 128, 1, 1], f16), T([512], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 23, ((T([32, 256, 16, 16], f16), T([1024, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 16, 16], f16), T([1024, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 22, ((T([32, 1024, 16, 16], f16), T([256, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 22, ((T([32, 256, 16, 16], f16), T([512, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 1, ((T([32, 1024, 16, 16], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 16, 16], f16), T([1024, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 3, ((T([32, 512, 1, 1], f16), T([256, 512, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 256, 1, 1], f16), T([1024, 256, 1, 1], f16), T([1024], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 512, 8, 8], f16), T([2048, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1024, 8, 8], f16), T([2048, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 2048, 8, 8], f16), T([512, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 512, 8, 8], f16), T([1024, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 2), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([32, 2048, 8, 8], f16), T([32, 512, 8, 8], f16), T([2048, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 1024, 1, 1], f16), T([32, 256, 1, 1], f16), T([1024, 256, 1, 1], f16), [1024], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 256, 1, 1], f16), T([32, 512, 1, 1], f16), T([256, 512, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 1024, 8, 8], f16), T([32, 512, 8, 8], f16), T([1024, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 2, [True, True, False]), {})
+cnt: 2, ((T([32, 512, 8, 8], f16), T([32, 2048, 8, 8], f16), T([512, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 2048, 8, 8], f16), T([32, 1024, 8, 8], f16), T([2048, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 16, 16], f16), T([32, 512, 16, 16], f16), T([1024, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 2, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 16, 16], f16), T([32, 1024, 16, 16], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 23, ((T([32, 1024, 16, 16], f16), T([32, 256, 16, 16], f16), T([1024, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 23, ((T([32, 512, 1, 1], f16), T([32, 128, 1, 1], f16), T([512, 128, 1, 1], f16), [512], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 23, ((T([32, 128, 1, 1], f16), T([32, 256, 1, 1], f16), T([128, 256, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 22, ((T([32, 512, 16, 16], f16), T([32, 256, 16, 16], f16), T([512, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 2, [True, True, False]), {})
+cnt: 22, ((T([32, 256, 16, 16], f16), T([32, 1024, 16, 16], f16), T([256, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 16, 16], f16), T([32, 512, 16, 16], f16), T([1024, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 32, 32], f16), T([32, 256, 32, 32], f16), T([512, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 2, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 32, 32], f16), T([32, 512, 32, 32], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([32, 512, 32, 32], f16), T([32, 128, 32, 32], f16), T([512, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([32, 256, 1, 1], f16), T([32, 64, 1, 1], f16), T([256, 64, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([32, 64, 1, 1], f16), T([32, 128, 1, 1], f16), T([64, 128, 1, 1], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 256, 32, 32], f16), T([32, 128, 32, 32], f16), T([256, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 2, [True, True, False]), {})
+cnt: 3, ((T([32, 128, 32, 32], f16), T([32, 512, 32, 32], f16), T([128, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 32, 32], f16), T([32, 256, 32, 32], f16), T([512, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 64, 64], f16), T([32, 128, 64, 64], f16), T([256, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 2, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 64, 64], f16), T([32, 256, 64, 64], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 256, 64, 64], f16), T([32, 64, 64, 64], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 128, 1, 1], f16), T([32, 32, 1, 1], f16), T([128, 32, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 32, 1, 1], f16), T([32, 64, 1, 1], f16), T([32, 64, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 128, 64, 64], f16), T([32, 64, 64, 64], f16), T([128, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 2, [True, True, False]), {})
+cnt: 2, ((T([32, 64, 64, 64], f16), T([32, 256, 64, 64], f16), T([64, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 64, 64], f16), T([32, 128, 64, 64], f16), T([256, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 64, 64], f16), T([32, 128, 64, 64], f16), T([64, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 128, 128], f16), T([32, 64, 128, 128], f16), T([128, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 128, 128], f16), T([32, 64, 128, 128], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 128, 128], f16), T([32, 3, 256, 256], f16), T([64, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 256, 256], f16), T([32, 3, 256, 256], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([32, 2048, 8, 8], f16, stride=(2048, 1, 0, 0)), 64), {})
+cnt: 2, ((T([32, 512, 8, 8], f16, stride=(512, 1, 0, 0)), 64), {})
+cnt: 1, ((T([32, 512, 16, 16], f16, stride=(512, 1, 0, 0)), 256), {})
+cnt: 22, ((T([32, 256, 16, 16], f16, stride=(256, 1, 0, 0)), 256), {})
+cnt: 1, ((T([32, 256, 32, 32], f16, stride=(256, 1, 0, 0)), 1024), {})
+cnt: 3, ((T([32, 128, 32, 32], f16, stride=(128, 1, 0, 0)), 1024), {})
+cnt: 1, ((T([32, 128, 64, 64], f16, stride=(128, 1, 0, 0)), 4096), {})
+cnt: 3, ((T([32, 64, 64, 64], f16, stride=(64, 1, 0, 0)), 4096), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([32], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([32, 128, 128, 128], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([32, 128, 64, 64], f16), T([32, 128, 128, 128], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([32, 128, 64, 64], i64)), {})
+Operator: aten.mean.dim
+cnt: 3, ((T([32, 64, 64, 64], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 128, 64, 64], f16), [2, 3], True), {})
+cnt: 3, ((T([32, 128, 32, 32], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 256, 32, 32], f16), [2, 3], True), {})
+cnt: 22, ((T([32, 256, 16, 16], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 512, 16, 16], f16), [2, 3], True), {})
+cnt: 2, ((T([32, 512, 8, 8], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 2048, 8, 8], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([32, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 32], f16, stride=(1, 1000)), T([32, 2048], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 3, ((T([32, 2, 64, 64, 64], f16), T([32, 2, 64, 1, 1], f16)), {})
+cnt: 1, ((T([32, 2, 128, 64, 64], f16), T([32, 2, 128, 1, 1], f16)), {})
+cnt: 3, ((T([32, 2, 128, 32, 32], f16), T([32, 2, 128, 1, 1], f16)), {})
+cnt: 1, ((T([32, 2, 256, 32, 32], f16), T([32, 2, 256, 1, 1], f16)), {})
+cnt: 22, ((T([32, 2, 256, 16, 16], f16), T([32, 2, 256, 1, 1], f16)), {})
+cnt: 1, ((T([32, 2, 512, 16, 16], f16), T([32, 2, 512, 1, 1], f16)), {})
+cnt: 2, ((T([32, 2, 512, 8, 8], f16), T([32, 2, 512, 1, 1], f16)), {})
+cnt: 2, ((T([32, 2, 512, 8, 8], f16, stride=(32768, 0, 64, 8, 1)), T([32, 2, 512, 8, 8], f16)), {})
+cnt: 2, ((T([32, 2, 512, 8, 8], f16, stride=(32768, 0, 64, 8, 1)), T([32, 2, 512, 1, 1], f16)), {})
+cnt: 1, ((T([32, 2, 512, 16, 16], f16, stride=(131072, 0, 256, 16, 1)), T([32, 2, 512, 16, 16], f16)), {})
+cnt: 1, ((T([32, 2, 512, 16, 16], f16, stride=(131072, 0, 256, 16, 1)), T([32, 2, 512, 1, 1], f16)), {})
+cnt: 22, ((T([32, 2, 256, 16, 16], f16, stride=(65536, 0, 256, 16, 1)), T([32, 2, 256, 16, 16], f16)), {})
+cnt: 22, ((T([32, 2, 256, 16, 16], f16, stride=(65536, 0, 256, 16, 1)), T([32, 2, 256, 1, 1], f16)), {})
+cnt: 1, ((T([32, 2, 256, 32, 32], f16, stride=(262144, 0, 1024, 32, 1)), T([32, 2, 256, 32, 32], f16)), {})
+cnt: 1, ((T([32, 2, 256, 32, 32], f16, stride=(262144, 0, 1024, 32, 1)), T([32, 2, 256, 1, 1], f16)), {})
+cnt: 3, ((T([32, 2, 128, 32, 32], f16, stride=(131072, 0, 1024, 32, 1)), T([32, 2, 128, 32, 32], f16)), {})
+cnt: 3, ((T([32, 2, 128, 32, 32], f16, stride=(131072, 0, 1024, 32, 1)), T([32, 2, 128, 1, 1], f16)), {})
+cnt: 1, ((T([32, 2, 128, 64, 64], f16, stride=(524288, 0, 4096, 64, 1)), T([32, 2, 128, 64, 64], f16)), {})
+cnt: 1, ((T([32, 2, 128, 64, 64], f16, stride=(524288, 0, 4096, 64, 1)), T([32, 2, 128, 1, 1], f16)), {})
+cnt: 3, ((T([32, 2, 64, 64, 64], f16, stride=(262144, 0, 4096, 64, 1)), T([32, 2, 64, 64, 64], f16)), {})
+cnt: 3, ((T([32, 2, 64, 64, 64], f16, stride=(262144, 0, 4096, 64, 1)), T([32, 2, 64, 1, 1], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([32, 64, 128, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 128, 128, 128], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([32, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 32, 1, 1], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([32, 256, 64, 64], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([32, 64, 1, 1], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 6, ((T([32, 512, 32, 32], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 128, 32, 32], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([32, 256, 32, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 23, ((T([32, 128, 1, 1], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 25, ((T([32, 1024, 16, 16], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 22, ((T([32, 256, 16, 16], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 23, ((T([32, 512, 16, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 256, 1, 1], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([32, 2048, 8, 8], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([32, 512, 8, 8], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([32, 1024, 8, 8], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 4, ((T([32, 2048, 8, 8], f16), T([32, 2048, 8, 8], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 256, 1, 1], f16), T([32, 256, 1, 1], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 1024, 8, 8], f16), T([32, 1024, 8, 8], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 512, 8, 8], f16), T([32, 512, 8, 8], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 25, ((T([32, 1024, 16, 16], f16), T([32, 1024, 16, 16], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 23, ((T([32, 512, 16, 16], f16), T([32, 512, 16, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 23, ((T([32, 128, 1, 1], f16), T([32, 128, 1, 1], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 22, ((T([32, 256, 16, 16], f16), T([32, 256, 16, 16], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([32, 512, 32, 32], f16), T([32, 512, 32, 32], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([32, 256, 32, 32], f16), T([32, 256, 32, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([32, 64, 1, 1], f16), T([32, 64, 1, 1], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 128, 32, 32], f16), T([32, 128, 32, 32], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([32, 256, 64, 64], f16), T([32, 256, 64, 64], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([32, 128, 64, 64], f16), T([32, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 32, 1, 1], f16), T([32, 32, 1, 1], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 64, 64, 64], f16), T([32, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 128, 128, 128], f16), T([32, 128, 128, 128], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 64, 128, 128], f16), T([32, 64, 128, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([32, 1000], f16), T([32], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([32, 1000], f16), T([32], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([32, 64, 128, 128], f16),), {})
+cnt: 1, ((T([32, 128, 128, 128], f16),), {})
+cnt: 3, ((T([32, 64, 64, 64], f16),), {})
+cnt: 4, ((T([32, 128, 64, 64], f16),), {})
+cnt: 3, ((T([32, 32, 1, 1], f16),), {})
+cnt: 4, ((T([32, 256, 64, 64], f16),), {})
+cnt: 4, ((T([32, 64, 1, 1], f16),), {})
+cnt: 5, ((T([32, 512, 32, 32], f16),), {})
+cnt: 3, ((T([32, 128, 32, 32], f16),), {})
+cnt: 4, ((T([32, 256, 32, 32], f16),), {})
+cnt: 23, ((T([32, 128, 1, 1], f16),), {})
+cnt: 24, ((T([32, 1024, 16, 16], f16),), {})
+cnt: 22, ((T([32, 256, 16, 16], f16),), {})
+cnt: 23, ((T([32, 512, 16, 16], f16),), {})
+cnt: 3, ((T([32, 256, 1, 1], f16),), {})
+cnt: 3, ((T([32, 2048, 8, 8], f16),), {})
+cnt: 2, ((T([32, 512, 8, 8], f16),), {})
+cnt: 2, ((T([32, 1024, 8, 8], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32, 1000], f16), [0], True), {})
+cnt: 2, ((T([32, 2, 512, 8, 8], f16), [3, 4], True), {})
+cnt: 1, ((T([32, 2, 512, 16, 16], f16), [3, 4], True), {})
+cnt: 22, ((T([32, 2, 256, 16, 16], f16), [3, 4], True), {})
+cnt: 1, ((T([32, 2, 256, 32, 32], f16), [3, 4], True), {})
+cnt: 3, ((T([32, 2, 128, 32, 32], f16), [3, 4], True), {})
+cnt: 1, ((T([32, 2, 128, 64, 64], f16), [3, 4], True), {})
+cnt: 3, ((T([32, 2, 64, 64, 64], f16), [3, 4], True), {})
+Operator: aten.sum.dim_IntList
+cnt: 6, ((T([32, 2, 64, 64, 64], f16), [1]), {})
+cnt: 2, ((T([32, 2, 128, 64, 64], f16), [1]), {})
+cnt: 6, ((T([32, 2, 128, 32, 32], f16), [1]), {})
+cnt: 2, ((T([32, 2, 256, 32, 32], f16), [1]), {})
+cnt: 44, ((T([32, 2, 256, 16, 16], f16), [1]), {})
+cnt: 2, ((T([32, 2, 512, 16, 16], f16), [1]), {})
+cnt: 4, ((T([32, 2, 512, 8, 8], f16), [1]), {})
+Operator: aten.threshold_backward.default
+cnt: 3, ((T([32, 2048, 8, 8], f16), T([32, 2048, 8, 8], f16), 0), {})
+cnt: 3, ((T([32, 256, 1, 1], f16), T([32, 256, 1, 1], f16), 0), {})
+cnt: 2, ((T([32, 1024, 8, 8], f16), T([32, 1024, 8, 8], f16), 0), {})
+cnt: 2, ((T([32, 512, 8, 8], f16), T([32, 512, 8, 8], f16), 0), {})
+cnt: 24, ((T([32, 1024, 16, 16], f16), T([32, 1024, 16, 16], f16), 0), {})
+cnt: 23, ((T([32, 512, 16, 16], f16), T([32, 512, 16, 16], f16), 0), {})
+cnt: 23, ((T([32, 128, 1, 1], f16), T([32, 128, 1, 1], f16), 0), {})
+cnt: 22, ((T([32, 256, 16, 16], f16), T([32, 256, 16, 16], f16), 0), {})
+cnt: 5, ((T([32, 512, 32, 32], f16), T([32, 512, 32, 32], f16), 0), {})
+cnt: 4, ((T([32, 256, 32, 32], f16), T([32, 256, 32, 32], f16), 0), {})
+cnt: 4, ((T([32, 64, 1, 1], f16), T([32, 64, 1, 1], f16), 0), {})
+cnt: 3, ((T([32, 128, 32, 32], f16), T([32, 128, 32, 32], f16), 0), {})
+cnt: 4, ((T([32, 256, 64, 64], f16), T([32, 256, 64, 64], f16), 0), {})
+cnt: 4, ((T([32, 128, 64, 64], f16), T([32, 128, 64, 64], f16), 0), {})
+cnt: 3, ((T([32, 32, 1, 1], f16), T([32, 32, 1, 1], f16), 0), {})
+cnt: 3, ((T([32, 64, 64, 64], f16), T([32, 64, 64, 64], f16), 0), {})
+cnt: 1, ((T([32, 128, 128, 128], f16), T([32, 128, 128, 128], f16), 0), {})
+cnt: 2, ((T([32, 64, 128, 128], f16), T([32, 64, 128, 128], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/resnet18_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/resnet18_training.txt
new file mode 100644
index 0000000000000..ef201d6c179c5
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/resnet18_training.txt
@@ -0,0 +1,88 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([128, 512, 7, 7], f16), T([128, 512, 7, 7], f16)), {})
+cnt: 2, ((T([128, 256, 14, 14], f16), T([128, 256, 14, 14], f16)), {})
+cnt: 2, ((T([128, 128, 28, 28], f16), T([128, 128, 28, 28], f16)), {})
+cnt: 3, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 20, ((T([], i64), 1), {})
+cnt: 2, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16)), {})
+cnt: 2, ((T([128, 128, 28, 28], f16), T([128, 128, 28, 28], f16)), {})
+cnt: 2, ((T([128, 256, 14, 14], f16), T([128, 256, 14, 14], f16)), {})
+cnt: 2, ((T([128, 512, 7, 7], f16), T([128, 512, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 512], f16), T([512, 1000], f16, stride=(1, 512))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([64, 3, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 64, 56, 56], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 128, 28, 28], f16), T([128, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 28, 28], f16), T([256, 128, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 256, 14, 14], f16), T([256, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 28, 28], f16), T([256, 128, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 14, 14], f16), T([512, 256, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 512, 7, 7], f16), T([512, 512, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 14, 14], f16), T([512, 256, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([128, 512, 7, 7], f16), T([128, 512, 7, 7], f16), T([512, 512, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 7, 7], f16), T([128, 256, 14, 14], f16), T([512, 256, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 512, 7, 7], f16), T([128, 256, 14, 14], f16), T([512, 256, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 256, 14, 14], f16), T([128, 256, 14, 14], f16), T([256, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 14, 14], f16), T([128, 128, 28, 28], f16), T([256, 128, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 256, 14, 14], f16), T([128, 128, 28, 28], f16), T([256, 128, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 128, 28, 28], f16), T([128, 128, 28, 28], f16), T([128, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 28, 28], f16), T([128, 64, 56, 56], f16), T([128, 64, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 28, 28], f16), T([128, 64, 56, 56], f16), T([128, 64, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 3, 224, 224], f16), T([64, 3, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 512, 7, 7], f16, stride=(512, 1, 0, 0)), 49), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([128, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([128, 64, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 512, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 512], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 512], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 128, 28, 28], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 256, 14, 14], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 512, 7, 7], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 5, ((T([128, 512, 7, 7], f16), T([128, 512, 7, 7], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 256, 14, 14], f16), T([128, 256, 14, 14], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 128, 28, 28], f16), T([128, 128, 28, 28], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 64, 112, 112], f16),), {})
+cnt: 4, ((T([128, 64, 56, 56], f16),), {})
+cnt: 4, ((T([128, 128, 28, 28], f16),), {})
+cnt: 4, ((T([128, 256, 14, 14], f16),), {})
+cnt: 4, ((T([128, 512, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 4, ((T([128, 512, 7, 7], f16), T([128, 512, 7, 7], f16), 0), {})
+cnt: 4, ((T([128, 256, 14, 14], f16), T([128, 256, 14, 14], f16), 0), {})
+cnt: 4, ((T([128, 128, 28, 28], f16), T([128, 128, 28, 28], f16), 0), {})
+cnt: 4, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 64, 112, 112], f16), T([128, 64, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/rexnet_100_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/rexnet_100_training.txt
new file mode 100644
index 0000000000000..739188b28f291
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/rexnet_100_training.txt
@@ -0,0 +1,573 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 49, ((T([], i64), 1), {})
+cnt: 2, ((T([128, 27, 56, 56], f16, stride=(119168, 3136, 56, 1)), T([128, 27, 56, 56], f16)), {})
+cnt: 2, ((T([128, 50, 28, 28], f16, stride=(47824, 784, 28, 1)), T([128, 50, 28, 28], f16)), {})
+cnt: 2, ((T([128, 72, 14, 14], f16, stride=(16464, 196, 14, 1)), T([128, 72, 14, 14], f16)), {})
+cnt: 2, ((T([128, 84, 14, 14], f16, stride=(18620, 196, 14, 1)), T([128, 84, 14, 14], f16)), {})
+cnt: 2, ((T([128, 95, 14, 14], f16, stride=(20776, 196, 14, 1)), T([128, 95, 14, 14], f16)), {})
+cnt: 2, ((T([128, 106, 14, 14], f16, stride=(22932, 196, 14, 1)), T([128, 106, 14, 14], f16)), {})
+cnt: 2, ((T([128, 117, 14, 14], f16, stride=(25088, 196, 14, 1)), T([128, 117, 14, 14], f16)), {})
+cnt: 2, ((T([128, 140, 7, 7], f16, stride=(7399, 49, 7, 1)), T([128, 140, 7, 7], f16)), {})
+cnt: 2, ((T([128, 151, 7, 7], f16, stride=(7938, 49, 7, 1)), T([128, 151, 7, 7], f16)), {})
+cnt: 2, ((T([128, 162, 7, 7], f16, stride=(8526, 49, 7, 1)), T([128, 162, 7, 7], f16)), {})
+cnt: 2, ((T([128, 174, 7, 7], f16, stride=(9065, 49, 7, 1)), T([128, 174, 7, 7], f16)), {})
+cnt: 1, ((T([128, 185, 7, 7], f16), T([128, 185, 7, 7], f16)), {})
+cnt: 1, ((T([128, 1044, 7, 7], f16), T([128, 1044, 7, 7], f16)), {})
+cnt: 1, ((T([128, 174, 7, 7], f16), T([128, 174, 7, 7], f16)), {})
+cnt: 1, ((T([128, 972, 7, 7], f16), T([128, 972, 7, 7], f16)), {})
+cnt: 1, ((T([128, 162, 7, 7], f16), T([128, 162, 7, 7], f16)), {})
+cnt: 1, ((T([128, 906, 7, 7], f16), T([128, 906, 7, 7], f16)), {})
+cnt: 1, ((T([128, 151, 7, 7], f16), T([128, 151, 7, 7], f16)), {})
+cnt: 1, ((T([128, 840, 7, 7], f16), T([128, 840, 7, 7], f16)), {})
+cnt: 1, ((T([128, 768, 7, 7], f16), T([128, 768, 7, 7], f16)), {})
+cnt: 1, ((T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16)), {})
+cnt: 1, ((T([128, 702, 14, 14], f16), T([128, 702, 14, 14], f16)), {})
+cnt: 1, ((T([128, 117, 14, 14], f16), T([128, 117, 14, 14], f16)), {})
+cnt: 1, ((T([128, 636, 14, 14], f16), T([128, 636, 14, 14], f16)), {})
+cnt: 1, ((T([128, 106, 14, 14], f16), T([128, 106, 14, 14], f16)), {})
+cnt: 1, ((T([128, 570, 14, 14], f16), T([128, 570, 14, 14], f16)), {})
+cnt: 1, ((T([128, 95, 14, 14], f16), T([128, 95, 14, 14], f16)), {})
+cnt: 1, ((T([128, 504, 14, 14], f16), T([128, 504, 14, 14], f16)), {})
+cnt: 1, ((T([128, 84, 14, 14], f16), T([128, 84, 14, 14], f16)), {})
+cnt: 1, ((T([128, 432, 14, 14], f16), T([128, 432, 14, 14], f16)), {})
+cnt: 1, ((T([128, 366, 14, 14], f16), T([128, 366, 14, 14], f16)), {})
+cnt: 1, ((T([128, 61, 28, 28], f16), T([128, 61, 28, 28], f16)), {})
+cnt: 1, ((T([128, 300, 28, 28], f16), T([128, 300, 28, 28], f16)), {})
+cnt: 1, ((T([128, 228, 28, 28], f16), T([128, 228, 28, 28], f16)), {})
+cnt: 1, ((T([128, 38, 56, 56], f16), T([128, 38, 56, 56], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 13, ((T([], i64), 1), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1280], f16), T([1280, 1000], f16, stride=(1, 1280))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([128, 27, 56, 56], f16), T([128, 11, 56, 56], f16, stride=(119168, 3136, 56, 1))], 1), {})
+cnt: 1, (([T([128, 50, 28, 28], f16), T([128, 11, 28, 28], f16, stride=(47824, 784, 28, 1))], 1), {})
+cnt: 1, (([T([128, 72, 14, 14], f16), T([128, 12, 14, 14], f16, stride=(16464, 196, 14, 1))], 1), {})
+cnt: 1, (([T([128, 84, 14, 14], f16), T([128, 11, 14, 14], f16, stride=(18620, 196, 14, 1))], 1), {})
+cnt: 1, (([T([128, 95, 14, 14], f16), T([128, 11, 14, 14], f16, stride=(20776, 196, 14, 1))], 1), {})
+cnt: 1, (([T([128, 106, 14, 14], f16), T([128, 11, 14, 14], f16, stride=(22932, 196, 14, 1))], 1), {})
+cnt: 1, (([T([128, 117, 14, 14], f16), T([128, 11, 14, 14], f16, stride=(25088, 196, 14, 1))], 1), {})
+cnt: 1, (([T([128, 140, 7, 7], f16), T([128, 11, 7, 7], f16, stride=(7399, 49, 7, 1))], 1), {})
+cnt: 1, (([T([128, 151, 7, 7], f16), T([128, 11, 7, 7], f16, stride=(7938, 49, 7, 1))], 1), {})
+cnt: 1, (([T([128, 162, 7, 7], f16), T([128, 12, 7, 7], f16, stride=(8526, 49, 7, 1))], 1), {})
+cnt: 1, (([T([128, 174, 7, 7], f16), T([128, 11, 7, 7], f16, stride=(9065, 49, 7, 1))], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+cnt: 1, ((T([128, 32, 112, 112], f16),), {})
+cnt: 1, ((T([128, 96, 112, 112], f16),), {})
+cnt: 1, ((T([128, 162, 56, 56], f16),), {})
+cnt: 1, ((T([128, 228, 56, 56], f16),), {})
+cnt: 1, ((T([128, 300, 28, 28], f16),), {})
+cnt: 1, ((T([128, 366, 28, 28], f16),), {})
+cnt: 1, ((T([128, 432, 14, 14], f16),), {})
+cnt: 1, ((T([128, 504, 14, 14], f16),), {})
+cnt: 1, ((T([128, 570, 14, 14], f16),), {})
+cnt: 1, ((T([128, 636, 14, 14], f16),), {})
+cnt: 1, ((T([128, 702, 14, 14], f16),), {})
+cnt: 1, ((T([128, 768, 14, 14], f16),), {})
+cnt: 1, ((T([128, 840, 7, 7], f16),), {})
+cnt: 1, ((T([128, 906, 7, 7], f16),), {})
+cnt: 1, ((T([128, 972, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1044, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1280, 7, 7], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([32, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([16, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([96, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([96, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 96), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([27, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 27, 56, 56], f16), T([162, 27, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 162, 56, 56], f16), T([162, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 162), {})
+cnt: 1, ((T([128, 162, 56, 56], f16), T([38, 162, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 38, 56, 56], f16), T([228, 38, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 228, 56, 56], f16), T([228, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 228), {})
+cnt: 1, ((T([128, 228, 1, 1], f16), T([19, 228, 1, 1], f16), T([19], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 19, 1, 1], f16), T([228, 19, 1, 1], f16), T([228], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 228, 28, 28], f16), T([50, 228, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 50, 28, 28], f16), T([300, 50, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 300, 28, 28], f16), T([300, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 300), {})
+cnt: 1, ((T([128, 300, 1, 1], f16), T([25, 300, 1, 1], f16), T([25], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 25, 1, 1], f16), T([300, 25, 1, 1], f16), T([300], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 300, 28, 28], f16), T([61, 300, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 61, 28, 28], f16), T([366, 61, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 366, 28, 28], f16), T([366, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 366), {})
+cnt: 1, ((T([128, 366, 1, 1], f16), T([30, 366, 1, 1], f16), T([30], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 30, 1, 1], f16), T([366, 30, 1, 1], f16), T([366], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 366, 14, 14], f16), T([72, 366, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 72, 14, 14], f16), T([432, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 432, 14, 14], f16), T([432, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 432), {})
+cnt: 1, ((T([128, 432, 1, 1], f16), T([36, 432, 1, 1], f16), T([36], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 36, 1, 1], f16), T([432, 36, 1, 1], f16), T([432], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 432, 14, 14], f16), T([84, 432, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 84, 14, 14], f16), T([504, 84, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 504, 14, 14], f16), T([504, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 504), {})
+cnt: 1, ((T([128, 504, 1, 1], f16), T([42, 504, 1, 1], f16), T([42], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 42, 1, 1], f16), T([504, 42, 1, 1], f16), T([504], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 504, 14, 14], f16), T([95, 504, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 95, 14, 14], f16), T([570, 95, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 570, 14, 14], f16), T([570, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 570), {})
+cnt: 1, ((T([128, 570, 1, 1], f16), T([47, 570, 1, 1], f16), T([47], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 47, 1, 1], f16), T([570, 47, 1, 1], f16), T([570], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 570, 14, 14], f16), T([106, 570, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 106, 14, 14], f16), T([636, 106, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 636, 14, 14], f16), T([636, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 636), {})
+cnt: 1, ((T([128, 636, 1, 1], f16), T([53, 636, 1, 1], f16), T([53], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 53, 1, 1], f16), T([636, 53, 1, 1], f16), T([636], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 636, 14, 14], f16), T([117, 636, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 117, 14, 14], f16), T([702, 117, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 702, 14, 14], f16), T([702, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 702), {})
+cnt: 1, ((T([128, 702, 1, 1], f16), T([58, 702, 1, 1], f16), T([58], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 58, 1, 1], f16), T([702, 58, 1, 1], f16), T([702], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 702, 14, 14], f16), T([128, 702, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 14, 14], f16), T([768, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 768, 14, 14], f16), T([768, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 768), {})
+cnt: 1, ((T([128, 768, 1, 1], f16), T([64, 768, 1, 1], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 1, 1], f16), T([768, 64, 1, 1], f16), T([768], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 768, 7, 7], f16), T([140, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 140, 7, 7], f16), T([840, 140, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 840, 7, 7], f16), T([840, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 840), {})
+cnt: 1, ((T([128, 840, 1, 1], f16), T([70, 840, 1, 1], f16), T([70], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 70, 1, 1], f16), T([840, 70, 1, 1], f16), T([840], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 840, 7, 7], f16), T([151, 840, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 151, 7, 7], f16), T([906, 151, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 906, 7, 7], f16), T([906, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 906), {})
+cnt: 1, ((T([128, 906, 1, 1], f16), T([75, 906, 1, 1], f16), T([75], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 75, 1, 1], f16), T([906, 75, 1, 1], f16), T([906], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 906, 7, 7], f16), T([162, 906, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 162, 7, 7], f16), T([972, 162, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 972, 7, 7], f16), T([972, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 972), {})
+cnt: 1, ((T([128, 972, 1, 1], f16), T([81, 972, 1, 1], f16), T([81], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 81, 1, 1], f16), T([972, 81, 1, 1], f16), T([972], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 972, 7, 7], f16), T([174, 972, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 174, 7, 7], f16), T([1044, 174, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1044, 7, 7], f16), T([1044, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1044), {})
+cnt: 1, ((T([128, 1044, 1, 1], f16), T([87, 1044, 1, 1], f16), T([87], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 87, 1, 1], f16), T([1044, 87, 1, 1], f16), T([1044], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1044, 7, 7], f16), T([185, 1044, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 185, 7, 7], f16), T([1280, 185, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([128, 185, 7, 7], f16), T([1280, 185, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 185, 7, 7], f16), T([128, 1044, 7, 7], f16), T([185, 1044, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1044, 1, 1], f16), T([128, 87, 1, 1], f16), T([1044, 87, 1, 1], f16), [1044], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 87, 1, 1], f16), T([128, 1044, 1, 1], f16), T([87, 1044, 1, 1], f16), [87], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 1044, 7, 7], f16), T([128, 1044, 7, 7], f16), T([1044, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1044, [True, True, False]), {})
+cnt: 1, ((T([128, 1044, 7, 7], f16), T([128, 174, 7, 7], f16), T([1044, 174, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 174, 7, 7], f16), T([128, 972, 7, 7], f16), T([174, 972, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 972, 1, 1], f16), T([128, 81, 1, 1], f16), T([972, 81, 1, 1], f16), [972], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 81, 1, 1], f16), T([128, 972, 1, 1], f16), T([81, 972, 1, 1], f16), [81], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 972, 7, 7], f16), T([128, 972, 7, 7], f16), T([972, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 972, [True, True, False]), {})
+cnt: 1, ((T([128, 972, 7, 7], f16), T([128, 162, 7, 7], f16), T([972, 162, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 162, 7, 7], f16), T([128, 906, 7, 7], f16), T([162, 906, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 906, 1, 1], f16), T([128, 75, 1, 1], f16), T([906, 75, 1, 1], f16), [906], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 75, 1, 1], f16), T([128, 906, 1, 1], f16), T([75, 906, 1, 1], f16), [75], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 906, 7, 7], f16), T([128, 906, 7, 7], f16), T([906, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 906, [True, True, False]), {})
+cnt: 1, ((T([128, 906, 7, 7], f16), T([128, 151, 7, 7], f16), T([906, 151, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 151, 7, 7], f16), T([128, 840, 7, 7], f16), T([151, 840, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 840, 1, 1], f16), T([128, 70, 1, 1], f16), T([840, 70, 1, 1], f16), [840], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 70, 1, 1], f16), T([128, 840, 1, 1], f16), T([70, 840, 1, 1], f16), [70], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 840, 7, 7], f16), T([128, 840, 7, 7], f16), T([840, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 840, [True, True, False]), {})
+cnt: 1, ((T([128, 840, 7, 7], f16), T([128, 140, 7, 7], f16), T([840, 140, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 140, 7, 7], f16), T([128, 768, 7, 7], f16), T([140, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 768, 1, 1], f16), T([128, 64, 1, 1], f16), T([768, 64, 1, 1], f16), [768], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 1, 1], f16), T([128, 768, 1, 1], f16), T([64, 768, 1, 1], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 768, 7, 7], f16), T([128, 768, 14, 14], f16), T([768, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 768, [True, True, False]), {})
+cnt: 1, ((T([128, 768, 14, 14], f16), T([128, 128, 14, 14], f16), T([768, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 14, 14], f16), T([128, 702, 14, 14], f16), T([128, 702, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 702, 1, 1], f16), T([128, 58, 1, 1], f16), T([702, 58, 1, 1], f16), [702], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 58, 1, 1], f16), T([128, 702, 1, 1], f16), T([58, 702, 1, 1], f16), [58], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 702, 14, 14], f16), T([128, 702, 14, 14], f16), T([702, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 702, [True, True, False]), {})
+cnt: 1, ((T([128, 702, 14, 14], f16), T([128, 117, 14, 14], f16), T([702, 117, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 117, 14, 14], f16), T([128, 636, 14, 14], f16), T([117, 636, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 636, 1, 1], f16), T([128, 53, 1, 1], f16), T([636, 53, 1, 1], f16), [636], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 53, 1, 1], f16), T([128, 636, 1, 1], f16), T([53, 636, 1, 1], f16), [53], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 636, 14, 14], f16), T([128, 636, 14, 14], f16), T([636, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 636, [True, True, False]), {})
+cnt: 1, ((T([128, 636, 14, 14], f16), T([128, 106, 14, 14], f16), T([636, 106, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 106, 14, 14], f16), T([128, 570, 14, 14], f16), T([106, 570, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 570, 1, 1], f16), T([128, 47, 1, 1], f16), T([570, 47, 1, 1], f16), [570], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 47, 1, 1], f16), T([128, 570, 1, 1], f16), T([47, 570, 1, 1], f16), [47], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 570, 14, 14], f16), T([128, 570, 14, 14], f16), T([570, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 570, [True, True, False]), {})
+cnt: 1, ((T([128, 570, 14, 14], f16), T([128, 95, 14, 14], f16), T([570, 95, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 95, 14, 14], f16), T([128, 504, 14, 14], f16), T([95, 504, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 504, 1, 1], f16), T([128, 42, 1, 1], f16), T([504, 42, 1, 1], f16), [504], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 42, 1, 1], f16), T([128, 504, 1, 1], f16), T([42, 504, 1, 1], f16), [42], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 504, 14, 14], f16), T([128, 504, 14, 14], f16), T([504, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 504, [True, True, False]), {})
+cnt: 1, ((T([128, 504, 14, 14], f16), T([128, 84, 14, 14], f16), T([504, 84, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 84, 14, 14], f16), T([128, 432, 14, 14], f16), T([84, 432, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 432, 1, 1], f16), T([128, 36, 1, 1], f16), T([432, 36, 1, 1], f16), [432], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 36, 1, 1], f16), T([128, 432, 1, 1], f16), T([36, 432, 1, 1], f16), [36], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 432, 14, 14], f16), T([128, 432, 14, 14], f16), T([432, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 432, [True, True, False]), {})
+cnt: 1, ((T([128, 432, 14, 14], f16), T([128, 72, 14, 14], f16), T([432, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 72, 14, 14], f16), T([128, 366, 14, 14], f16), T([72, 366, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 366, 1, 1], f16), T([128, 30, 1, 1], f16), T([366, 30, 1, 1], f16), [366], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 30, 1, 1], f16), T([128, 366, 1, 1], f16), T([30, 366, 1, 1], f16), [30], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 366, 14, 14], f16), T([128, 366, 28, 28], f16), T([366, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 366, [True, True, False]), {})
+cnt: 1, ((T([128, 366, 28, 28], f16), T([128, 61, 28, 28], f16), T([366, 61, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 61, 28, 28], f16), T([128, 300, 28, 28], f16), T([61, 300, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 300, 1, 1], f16), T([128, 25, 1, 1], f16), T([300, 25, 1, 1], f16), [300], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 25, 1, 1], f16), T([128, 300, 1, 1], f16), T([25, 300, 1, 1], f16), [25], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 300, 28, 28], f16), T([128, 300, 28, 28], f16), T([300, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 300, [True, True, False]), {})
+cnt: 1, ((T([128, 300, 28, 28], f16), T([128, 50, 28, 28], f16), T([300, 50, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 50, 28, 28], f16), T([128, 228, 28, 28], f16), T([50, 228, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 228, 1, 1], f16), T([128, 19, 1, 1], f16), T([228, 19, 1, 1], f16), [228], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 19, 1, 1], f16), T([128, 228, 1, 1], f16), T([19, 228, 1, 1], f16), [19], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 228, 28, 28], f16), T([128, 228, 56, 56], f16), T([228, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 228, [True, True, False]), {})
+cnt: 1, ((T([128, 228, 56, 56], f16), T([128, 38, 56, 56], f16), T([228, 38, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 38, 56, 56], f16), T([128, 162, 56, 56], f16), T([38, 162, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 162, 56, 56], f16), T([128, 162, 56, 56], f16), T([162, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 162, [True, True, False]), {})
+cnt: 1, ((T([128, 162, 56, 56], f16), T([128, 27, 56, 56], f16), T([162, 27, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 27, 56, 56], f16), T([128, 96, 56, 56], f16), T([27, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 112, 112], f16), T([96, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 96, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([128, 16, 112, 112], f16), T([96, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 32, 112, 112], f16), T([16, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), T([32, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 3, 224, 224], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 1280, 7, 7], f16, stride=(1280, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 1044, 7, 7], f16, stride=(1044, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 972, 7, 7], f16, stride=(972, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 906, 7, 7], f16, stride=(906, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 840, 7, 7], f16, stride=(840, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 768, 7, 7], f16, stride=(768, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 702, 14, 14], f16, stride=(702, 1, 0, 0)), 196), {})
+cnt: 1, ((T([128, 636, 14, 14], f16, stride=(636, 1, 0, 0)), 196), {})
+cnt: 1, ((T([128, 570, 14, 14], f16, stride=(570, 1, 0, 0)), 196), {})
+cnt: 1, ((T([128, 504, 14, 14], f16, stride=(504, 1, 0, 0)), 196), {})
+cnt: 1, ((T([128, 432, 14, 14], f16, stride=(432, 1, 0, 0)), 196), {})
+cnt: 1, ((T([128, 366, 14, 14], f16, stride=(366, 1, 0, 0)), 196), {})
+cnt: 1, ((T([128, 300, 28, 28], f16, stride=(300, 1, 0, 0)), 784), {})
+cnt: 1, ((T([128, 228, 28, 28], f16, stride=(228, 1, 0, 0)), 784), {})
+Operator: aten.hardtanh.default
+cnt: 1, ((T([128, 32, 112, 112], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 162, 56, 56], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 228, 28, 28], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 300, 28, 28], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 366, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 432, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 504, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 570, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 636, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 702, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 768, 7, 7], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 840, 7, 7], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 906, 7, 7], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 972, 7, 7], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 1044, 7, 7], f16), 0.0, 6.0), {})
+Operator: aten.hardtanh_backward.default
+cnt: 1, ((T([128, 1044, 7, 7], f16), T([128, 1044, 7, 7], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 972, 7, 7], f16), T([128, 972, 7, 7], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 906, 7, 7], f16), T([128, 906, 7, 7], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 840, 7, 7], f16), T([128, 840, 7, 7], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 768, 7, 7], f16), T([128, 768, 7, 7], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 702, 14, 14], f16), T([128, 702, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 636, 14, 14], f16), T([128, 636, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 570, 14, 14], f16), T([128, 570, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 504, 14, 14], f16), T([128, 504, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 432, 14, 14], f16), T([128, 432, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 366, 14, 14], f16), T([128, 366, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 300, 28, 28], f16), T([128, 300, 28, 28], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 228, 28, 28], f16), T([128, 228, 28, 28], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 162, 56, 56], f16), T([128, 162, 56, 56], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 56, 56], f16), 0.0, 6.0), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), 0.0, 6.0), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 228, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 300, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 366, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 432, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 504, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 570, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 636, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 702, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 768, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 840, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 906, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 972, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 1044, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 1280, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 1280], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 1280], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([128, 228, 28, 28], f16), T([128, 228, 1, 1], f16)), {})
+cnt: 2, ((T([128, 300, 28, 28], f16), T([128, 300, 1, 1], f16)), {})
+cnt: 2, ((T([128, 366, 14, 14], f16), T([128, 366, 1, 1], f16)), {})
+cnt: 2, ((T([128, 432, 14, 14], f16), T([128, 432, 1, 1], f16)), {})
+cnt: 2, ((T([128, 504, 14, 14], f16), T([128, 504, 1, 1], f16)), {})
+cnt: 2, ((T([128, 570, 14, 14], f16), T([128, 570, 1, 1], f16)), {})
+cnt: 2, ((T([128, 636, 14, 14], f16), T([128, 636, 1, 1], f16)), {})
+cnt: 2, ((T([128, 702, 14, 14], f16), T([128, 702, 1, 1], f16)), {})
+cnt: 2, ((T([128, 768, 7, 7], f16), T([128, 768, 1, 1], f16)), {})
+cnt: 2, ((T([128, 840, 7, 7], f16), T([128, 840, 1, 1], f16)), {})
+cnt: 2, ((T([128, 906, 7, 7], f16), T([128, 906, 1, 1], f16)), {})
+cnt: 2, ((T([128, 972, 7, 7], f16), T([128, 972, 1, 1], f16)), {})
+cnt: 2, ((T([128, 1044, 7, 7], f16), T([128, 1044, 1, 1], f16)), {})
+cnt: 1, ((T([128, 1044, 7, 7], f16), T([128, 1044, 7, 7], f16)), {})
+cnt: 1, ((T([128, 972, 7, 7], f16), T([128, 972, 7, 7], f16)), {})
+cnt: 1, ((T([128, 906, 7, 7], f16), T([128, 906, 7, 7], f16)), {})
+cnt: 1, ((T([128, 840, 7, 7], f16), T([128, 840, 7, 7], f16)), {})
+cnt: 1, ((T([128, 768, 7, 7], f16), T([128, 768, 7, 7], f16)), {})
+cnt: 1, ((T([128, 702, 14, 14], f16), T([128, 702, 14, 14], f16)), {})
+cnt: 1, ((T([128, 636, 14, 14], f16), T([128, 636, 14, 14], f16)), {})
+cnt: 1, ((T([128, 570, 14, 14], f16), T([128, 570, 14, 14], f16)), {})
+cnt: 1, ((T([128, 504, 14, 14], f16), T([128, 504, 14, 14], f16)), {})
+cnt: 1, ((T([128, 432, 14, 14], f16), T([128, 432, 14, 14], f16)), {})
+cnt: 1, ((T([128, 366, 14, 14], f16), T([128, 366, 14, 14], f16)), {})
+cnt: 1, ((T([128, 300, 28, 28], f16), T([128, 300, 28, 28], f16)), {})
+cnt: 1, ((T([128, 228, 28, 28], f16), T([128, 228, 28, 28], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 27, 56, 56], f16), T([27], f16), T([27], f16), T([27], f16), T([27], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 162, 56, 56], f16), T([162], f16), T([162], f16), T([162], f16), T([162], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 38, 56, 56], f16), T([38], f16), T([38], f16), T([38], f16), T([38], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 228, 56, 56], f16), T([228], f16), T([228], f16), T([228], f16), T([228], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 228, 28, 28], f16), T([228], f16), T([228], f16), T([228], f16), T([228], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 19, 1, 1], f16), T([19], f16), T([19], f16), T([19], f16), T([19], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 50, 28, 28], f16), T([50], f16), T([50], f16), T([50], f16), T([50], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 300, 28, 28], f16), T([300], f16), T([300], f16), T([300], f16), T([300], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 25, 1, 1], f16), T([25], f16), T([25], f16), T([25], f16), T([25], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 61, 28, 28], f16), T([61], f16), T([61], f16), T([61], f16), T([61], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 366, 28, 28], f16), T([366], f16), T([366], f16), T([366], f16), T([366], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 366, 14, 14], f16), T([366], f16), T([366], f16), T([366], f16), T([366], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 30, 1, 1], f16), T([30], f16), T([30], f16), T([30], f16), T([30], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 72, 14, 14], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 432, 14, 14], f16), T([432], f16), T([432], f16), T([432], f16), T([432], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 36, 1, 1], f16), T([36], f16), T([36], f16), T([36], f16), T([36], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 84, 14, 14], f16), T([84], f16), T([84], f16), T([84], f16), T([84], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 504, 14, 14], f16), T([504], f16), T([504], f16), T([504], f16), T([504], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 42, 1, 1], f16), T([42], f16), T([42], f16), T([42], f16), T([42], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 95, 14, 14], f16), T([95], f16), T([95], f16), T([95], f16), T([95], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 570, 14, 14], f16), T([570], f16), T([570], f16), T([570], f16), T([570], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 47, 1, 1], f16), T([47], f16), T([47], f16), T([47], f16), T([47], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 106, 14, 14], f16), T([106], f16), T([106], f16), T([106], f16), T([106], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 636, 14, 14], f16), T([636], f16), T([636], f16), T([636], f16), T([636], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 53, 1, 1], f16), T([53], f16), T([53], f16), T([53], f16), T([53], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 117, 14, 14], f16), T([117], f16), T([117], f16), T([117], f16), T([117], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 702, 14, 14], f16), T([702], f16), T([702], f16), T([702], f16), T([702], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 58, 1, 1], f16), T([58], f16), T([58], f16), T([58], f16), T([58], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 128, 14, 14], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 768, 14, 14], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 768, 7, 7], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 64, 1, 1], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 140, 7, 7], f16), T([140], f16), T([140], f16), T([140], f16), T([140], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 840, 7, 7], f16), T([840], f16), T([840], f16), T([840], f16), T([840], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 70, 1, 1], f16), T([70], f16), T([70], f16), T([70], f16), T([70], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 151, 7, 7], f16), T([151], f16), T([151], f16), T([151], f16), T([151], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 906, 7, 7], f16), T([906], f16), T([906], f16), T([906], f16), T([906], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 75, 1, 1], f16), T([75], f16), T([75], f16), T([75], f16), T([75], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 162, 7, 7], f16), T([162], f16), T([162], f16), T([162], f16), T([162], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 972, 7, 7], f16), T([972], f16), T([972], f16), T([972], f16), T([972], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 81, 1, 1], f16), T([81], f16), T([81], f16), T([81], f16), T([81], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 174, 7, 7], f16), T([174], f16), T([174], f16), T([174], f16), T([174], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 1044, 7, 7], f16), T([1044], f16), T([1044], f16), T([1044], f16), T([1044], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 87, 1, 1], f16), T([87], f16), T([87], f16), T([87], f16), T([87], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 185, 7, 7], f16), T([185], f16), T([185], f16), T([185], f16), T([185], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([128, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f32), T([1280], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 185, 7, 7], f16), T([128, 185, 7, 7], f16), T([185], f16), T([185], f16), T([185], f16), T([185], f32), T([185], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 87, 1, 1], f16), T([128, 87, 1, 1], f16), T([87], f16), T([87], f16), T([87], f16), T([87], f32), T([87], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 1044, 7, 7], f16), T([128, 1044, 7, 7], f16), T([1044], f16), T([1044], f16), T([1044], f16), T([1044], f32), T([1044], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 174, 7, 7], f16), T([128, 174, 7, 7], f16), T([174], f16), T([174], f16), T([174], f16), T([174], f32), T([174], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 81, 1, 1], f16), T([128, 81, 1, 1], f16), T([81], f16), T([81], f16), T([81], f16), T([81], f32), T([81], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 972, 7, 7], f16), T([128, 972, 7, 7], f16), T([972], f16), T([972], f16), T([972], f16), T([972], f32), T([972], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 162, 7, 7], f16), T([128, 162, 7, 7], f16), T([162], f16), T([162], f16), T([162], f16), T([162], f32), T([162], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 75, 1, 1], f16), T([128, 75, 1, 1], f16), T([75], f16), T([75], f16), T([75], f16), T([75], f32), T([75], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 906, 7, 7], f16), T([128, 906, 7, 7], f16), T([906], f16), T([906], f16), T([906], f16), T([906], f32), T([906], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 151, 7, 7], f16), T([128, 151, 7, 7], f16), T([151], f16), T([151], f16), T([151], f16), T([151], f32), T([151], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 70, 1, 1], f16), T([128, 70, 1, 1], f16), T([70], f16), T([70], f16), T([70], f16), T([70], f32), T([70], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 840, 7, 7], f16), T([128, 840, 7, 7], f16), T([840], f16), T([840], f16), T([840], f16), T([840], f32), T([840], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 140, 7, 7], f16), T([128, 140, 7, 7], f16), T([140], f16), T([140], f16), T([140], f16), T([140], f32), T([140], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 1, 1], f16), T([128, 64, 1, 1], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 768, 7, 7], f16), T([128, 768, 7, 7], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f32), T([768], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 768, 14, 14], f16), T([128, 768, 14, 14], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f32), T([768], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 14, 14], f16), T([128, 128, 14, 14], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 58, 1, 1], f16), T([128, 58, 1, 1], f16), T([58], f16), T([58], f16), T([58], f16), T([58], f32), T([58], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 702, 14, 14], f16), T([128, 702, 14, 14], f16), T([702], f16), T([702], f16), T([702], f16), T([702], f32), T([702], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 117, 14, 14], f16), T([128, 117, 14, 14], f16), T([117], f16), T([117], f16), T([117], f16), T([117], f32), T([117], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 53, 1, 1], f16), T([128, 53, 1, 1], f16), T([53], f16), T([53], f16), T([53], f16), T([53], f32), T([53], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 636, 14, 14], f16), T([128, 636, 14, 14], f16), T([636], f16), T([636], f16), T([636], f16), T([636], f32), T([636], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 106, 14, 14], f16), T([128, 106, 14, 14], f16), T([106], f16), T([106], f16), T([106], f16), T([106], f32), T([106], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 47, 1, 1], f16), T([128, 47, 1, 1], f16), T([47], f16), T([47], f16), T([47], f16), T([47], f32), T([47], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 570, 14, 14], f16), T([128, 570, 14, 14], f16), T([570], f16), T([570], f16), T([570], f16), T([570], f32), T([570], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 95, 14, 14], f16), T([128, 95, 14, 14], f16), T([95], f16), T([95], f16), T([95], f16), T([95], f32), T([95], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 42, 1, 1], f16), T([128, 42, 1, 1], f16), T([42], f16), T([42], f16), T([42], f16), T([42], f32), T([42], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 504, 14, 14], f16), T([128, 504, 14, 14], f16), T([504], f16), T([504], f16), T([504], f16), T([504], f32), T([504], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 84, 14, 14], f16), T([128, 84, 14, 14], f16), T([84], f16), T([84], f16), T([84], f16), T([84], f32), T([84], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 36, 1, 1], f16), T([128, 36, 1, 1], f16), T([36], f16), T([36], f16), T([36], f16), T([36], f32), T([36], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 432, 14, 14], f16), T([128, 432, 14, 14], f16), T([432], f16), T([432], f16), T([432], f16), T([432], f32), T([432], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 72, 14, 14], f16), T([128, 72, 14, 14], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 30, 1, 1], f16), T([128, 30, 1, 1], f16), T([30], f16), T([30], f16), T([30], f16), T([30], f32), T([30], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 366, 14, 14], f16), T([128, 366, 14, 14], f16), T([366], f16), T([366], f16), T([366], f16), T([366], f32), T([366], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 366, 28, 28], f16), T([128, 366, 28, 28], f16), T([366], f16), T([366], f16), T([366], f16), T([366], f32), T([366], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 61, 28, 28], f16), T([128, 61, 28, 28], f16), T([61], f16), T([61], f16), T([61], f16), T([61], f32), T([61], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 25, 1, 1], f16), T([128, 25, 1, 1], f16), T([25], f16), T([25], f16), T([25], f16), T([25], f32), T([25], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 300, 28, 28], f16), T([128, 300, 28, 28], f16), T([300], f16), T([300], f16), T([300], f16), T([300], f32), T([300], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 50, 28, 28], f16), T([128, 50, 28, 28], f16), T([50], f16), T([50], f16), T([50], f16), T([50], f32), T([50], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 19, 1, 1], f16), T([128, 19, 1, 1], f16), T([19], f16), T([19], f16), T([19], f16), T([19], f32), T([19], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 228, 28, 28], f16), T([128, 228, 28, 28], f16), T([228], f16), T([228], f16), T([228], f16), T([228], f32), T([228], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 228, 56, 56], f16), T([128, 228, 56, 56], f16), T([228], f16), T([228], f16), T([228], f16), T([228], f32), T([228], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 38, 56, 56], f16), T([128, 38, 56, 56], f16), T([38], f16), T([38], f16), T([38], f16), T([38], f32), T([38], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 162, 56, 56], f16), T([128, 162, 56, 56], f16), T([162], f16), T([162], f16), T([162], f16), T([162], f32), T([162], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 27, 56, 56], f16), T([128, 27, 56, 56], f16), T([27], f16), T([27], f16), T([27], f16), T([27], f32), T([27], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([128, 96, 112, 112], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 19, 1, 1], f16),), {})
+cnt: 1, ((T([128, 25, 1, 1], f16),), {})
+cnt: 1, ((T([128, 30, 1, 1], f16),), {})
+cnt: 1, ((T([128, 36, 1, 1], f16),), {})
+cnt: 1, ((T([128, 42, 1, 1], f16),), {})
+cnt: 1, ((T([128, 47, 1, 1], f16),), {})
+cnt: 1, ((T([128, 53, 1, 1], f16),), {})
+cnt: 1, ((T([128, 58, 1, 1], f16),), {})
+cnt: 1, ((T([128, 64, 1, 1], f16),), {})
+cnt: 1, ((T([128, 70, 1, 1], f16),), {})
+cnt: 1, ((T([128, 75, 1, 1], f16),), {})
+cnt: 1, ((T([128, 81, 1, 1], f16),), {})
+cnt: 1, ((T([128, 87, 1, 1], f16),), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([128, 228, 1, 1], f16),), {})
+cnt: 1, ((T([128, 300, 1, 1], f16),), {})
+cnt: 1, ((T([128, 366, 1, 1], f16),), {})
+cnt: 1, ((T([128, 432, 1, 1], f16),), {})
+cnt: 1, ((T([128, 504, 1, 1], f16),), {})
+cnt: 1, ((T([128, 570, 1, 1], f16),), {})
+cnt: 1, ((T([128, 636, 1, 1], f16),), {})
+cnt: 1, ((T([128, 702, 1, 1], f16),), {})
+cnt: 1, ((T([128, 768, 1, 1], f16),), {})
+cnt: 1, ((T([128, 840, 1, 1], f16),), {})
+cnt: 1, ((T([128, 906, 1, 1], f16),), {})
+cnt: 1, ((T([128, 972, 1, 1], f16),), {})
+cnt: 1, ((T([128, 1044, 1, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 1, ((T([128, 1044, 1, 1], f16), T([128, 1044, 1, 1], f16)), {})
+cnt: 1, ((T([128, 972, 1, 1], f16), T([128, 972, 1, 1], f16)), {})
+cnt: 1, ((T([128, 906, 1, 1], f16), T([128, 906, 1, 1], f16)), {})
+cnt: 1, ((T([128, 840, 1, 1], f16), T([128, 840, 1, 1], f16)), {})
+cnt: 1, ((T([128, 768, 1, 1], f16), T([128, 768, 1, 1], f16)), {})
+cnt: 1, ((T([128, 702, 1, 1], f16), T([128, 702, 1, 1], f16)), {})
+cnt: 1, ((T([128, 636, 1, 1], f16), T([128, 636, 1, 1], f16)), {})
+cnt: 1, ((T([128, 570, 1, 1], f16), T([128, 570, 1, 1], f16)), {})
+cnt: 1, ((T([128, 504, 1, 1], f16), T([128, 504, 1, 1], f16)), {})
+cnt: 1, ((T([128, 432, 1, 1], f16), T([128, 432, 1, 1], f16)), {})
+cnt: 1, ((T([128, 366, 1, 1], f16), T([128, 366, 1, 1], f16)), {})
+cnt: 1, ((T([128, 300, 1, 1], f16), T([128, 300, 1, 1], f16)), {})
+cnt: 1, ((T([128, 228, 1, 1], f16), T([128, 228, 1, 1], f16)), {})
+Operator: aten.silu_.default
+cnt: 1, ((T([128, 32, 112, 112], f16),), {})
+cnt: 1, ((T([128, 96, 112, 112], f16),), {})
+cnt: 1, ((T([128, 162, 56, 56], f16),), {})
+cnt: 1, ((T([128, 228, 56, 56], f16),), {})
+cnt: 1, ((T([128, 300, 28, 28], f16),), {})
+cnt: 1, ((T([128, 366, 28, 28], f16),), {})
+cnt: 1, ((T([128, 432, 14, 14], f16),), {})
+cnt: 1, ((T([128, 504, 14, 14], f16),), {})
+cnt: 1, ((T([128, 570, 14, 14], f16),), {})
+cnt: 1, ((T([128, 636, 14, 14], f16),), {})
+cnt: 1, ((T([128, 702, 14, 14], f16),), {})
+cnt: 1, ((T([128, 768, 14, 14], f16),), {})
+cnt: 1, ((T([128, 840, 7, 7], f16),), {})
+cnt: 1, ((T([128, 906, 7, 7], f16),), {})
+cnt: 1, ((T([128, 972, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1044, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1280, 7, 7], f16),), {})
+Operator: aten.silu_backward.default
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([128, 1280, 7, 7], f16)), {})
+cnt: 1, ((T([128, 1044, 7, 7], f16), T([128, 1044, 7, 7], f16)), {})
+cnt: 1, ((T([128, 972, 7, 7], f16), T([128, 972, 7, 7], f16)), {})
+cnt: 1, ((T([128, 906, 7, 7], f16), T([128, 906, 7, 7], f16)), {})
+cnt: 1, ((T([128, 840, 7, 7], f16), T([128, 840, 7, 7], f16)), {})
+cnt: 1, ((T([128, 768, 14, 14], f16), T([128, 768, 14, 14], f16)), {})
+cnt: 1, ((T([128, 702, 14, 14], f16), T([128, 702, 14, 14], f16)), {})
+cnt: 1, ((T([128, 636, 14, 14], f16), T([128, 636, 14, 14], f16)), {})
+cnt: 1, ((T([128, 570, 14, 14], f16), T([128, 570, 14, 14], f16)), {})
+cnt: 1, ((T([128, 504, 14, 14], f16), T([128, 504, 14, 14], f16)), {})
+cnt: 1, ((T([128, 432, 14, 14], f16), T([128, 432, 14, 14], f16)), {})
+cnt: 1, ((T([128, 366, 28, 28], f16), T([128, 366, 28, 28], f16)), {})
+cnt: 1, ((T([128, 300, 28, 28], f16), T([128, 300, 28, 28], f16)), {})
+cnt: 1, ((T([128, 228, 56, 56], f16), T([128, 228, 56, 56], f16)), {})
+cnt: 1, ((T([128, 162, 56, 56], f16), T([128, 162, 56, 56], f16)), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([128, 96, 112, 112], f16)), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16)), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([128, 11, 7, 7], f16, stride=(9065, 49, 7, 1)), [128, 185, 7, 7], 1, 174, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 185, 7, 7], f16), [128, 185, 7, 7], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 174, 7, 7], f16, stride=(9065, 49, 7, 1)), [128, 185, 7, 7], 1, 0, 174, 1), {})
+cnt: 1, ((T([128, 12, 7, 7], f16, stride=(8526, 49, 7, 1)), [128, 174, 7, 7], 1, 162, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 174, 7, 7], f16), [128, 174, 7, 7], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 162, 7, 7], f16, stride=(8526, 49, 7, 1)), [128, 174, 7, 7], 1, 0, 162, 1), {})
+cnt: 1, ((T([128, 11, 7, 7], f16, stride=(7938, 49, 7, 1)), [128, 162, 7, 7], 1, 151, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 162, 7, 7], f16), [128, 162, 7, 7], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 151, 7, 7], f16, stride=(7938, 49, 7, 1)), [128, 162, 7, 7], 1, 0, 151, 1), {})
+cnt: 1, ((T([128, 11, 7, 7], f16, stride=(7399, 49, 7, 1)), [128, 151, 7, 7], 1, 140, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 151, 7, 7], f16), [128, 151, 7, 7], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 140, 7, 7], f16, stride=(7399, 49, 7, 1)), [128, 151, 7, 7], 1, 0, 140, 1), {})
+cnt: 1, ((T([128, 11, 14, 14], f16, stride=(25088, 196, 14, 1)), [128, 128, 14, 14], 1, 117, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 128, 14, 14], f16), [128, 128, 14, 14], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 117, 14, 14], f16, stride=(25088, 196, 14, 1)), [128, 128, 14, 14], 1, 0, 117, 1), {})
+cnt: 1, ((T([128, 11, 14, 14], f16, stride=(22932, 196, 14, 1)), [128, 117, 14, 14], 1, 106, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 117, 14, 14], f16), [128, 117, 14, 14], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 106, 14, 14], f16, stride=(22932, 196, 14, 1)), [128, 117, 14, 14], 1, 0, 106, 1), {})
+cnt: 1, ((T([128, 11, 14, 14], f16, stride=(20776, 196, 14, 1)), [128, 106, 14, 14], 1, 95, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 106, 14, 14], f16), [128, 106, 14, 14], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 95, 14, 14], f16, stride=(20776, 196, 14, 1)), [128, 106, 14, 14], 1, 0, 95, 1), {})
+cnt: 1, ((T([128, 11, 14, 14], f16, stride=(18620, 196, 14, 1)), [128, 95, 14, 14], 1, 84, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 95, 14, 14], f16), [128, 95, 14, 14], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 84, 14, 14], f16, stride=(18620, 196, 14, 1)), [128, 95, 14, 14], 1, 0, 84, 1), {})
+cnt: 1, ((T([128, 12, 14, 14], f16, stride=(16464, 196, 14, 1)), [128, 84, 14, 14], 1, 72, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 84, 14, 14], f16), [128, 84, 14, 14], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 72, 14, 14], f16, stride=(16464, 196, 14, 1)), [128, 84, 14, 14], 1, 0, 72, 1), {})
+cnt: 1, ((T([128, 11, 28, 28], f16, stride=(47824, 784, 28, 1)), [128, 61, 28, 28], 1, 50, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 61, 28, 28], f16), [128, 61, 28, 28], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 50, 28, 28], f16, stride=(47824, 784, 28, 1)), [128, 61, 28, 28], 1, 0, 50, 1), {})
+cnt: 1, ((T([128, 11, 56, 56], f16, stride=(119168, 3136, 56, 1)), [128, 38, 56, 56], 1, 27, 9223372036854775807, 1), {})
+cnt: 2, ((T([128, 38, 56, 56], f16), [128, 38, 56, 56], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([128, 27, 56, 56], f16, stride=(119168, 3136, 56, 1)), [128, 38, 56, 56], 1, 0, 27, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 1, ((T([128, 1044, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 972, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 906, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 840, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 768, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 702, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 636, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 570, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 504, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 432, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 366, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 300, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 228, 28, 28], f16), [2, 3], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([128, 87, 1, 1], f16), T([128, 87, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 81, 1, 1], f16), T([128, 81, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 75, 1, 1], f16), T([128, 75, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 70, 1, 1], f16), T([128, 70, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 64, 1, 1], f16), T([128, 64, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 58, 1, 1], f16), T([128, 58, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 53, 1, 1], f16), T([128, 53, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 47, 1, 1], f16), T([128, 47, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 42, 1, 1], f16), T([128, 42, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 36, 1, 1], f16), T([128, 36, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 30, 1, 1], f16), T([128, 30, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 25, 1, 1], f16), T([128, 25, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 19, 1, 1], f16), T([128, 19, 1, 1], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/sebotnet33ts_256_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/sebotnet33ts_256_training.txt
new file mode 100644
index 0000000000000..cdfa544bf9c0f
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/sebotnet33ts_256_training.txt
@@ -0,0 +1,334 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 1, ((T([256, 1024, 1024], f16), -1, False), {})
+cnt: 2, ((T([256, 256, 256], f16), -1, False), {})
+cnt: 1, ((T([256, 64, 64], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 1, ((T([256, 64, 64], f16), T([256, 64, 64], f16), -1, f16), {})
+cnt: 2, ((T([256, 256, 256], f16), T([256, 256, 256], f16), -1, f16), {})
+cnt: 1, ((T([256, 1024, 1024], f16), T([256, 1024, 1024], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 3, ((T([64, 128, 32, 32], f16), [256, 32, 1024]), {})
+cnt: 1, ((T([256, 1024, 1024], f16), [256, 1024, 1024]), {})
+cnt: 2, ((T([256, 32, 32, 32], f16), [262144, 32]), {})
+cnt: 2, ((T([262144, 63], f16), [256, 32, 32, 63]), {})
+cnt: 1, ((T([256, 32, 32, 32, 32], f16), [256, 1024, 1024]), {})
+cnt: 1, ((T([256, 1024, 32], f16), [256, 1024, 32]), {})
+cnt: 3, ((T([256, 32, 1024], f16), [64, 128, 32, 32]), {})
+cnt: 3, ((T([64, 256, 16, 16], f16), [256, 64, 256]), {})
+cnt: 2, ((T([256, 256, 256], f16), [256, 256, 256]), {})
+cnt: 2, ((T([256, 16, 16, 64], f16), [65536, 64]), {})
+cnt: 4, ((T([65536, 31], f16), [256, 16, 16, 31]), {})
+cnt: 2, ((T([256, 16, 16, 16, 16], f16), [256, 256, 256]), {})
+cnt: 1, ((T([256, 256, 64], f16), [256, 256, 64]), {})
+cnt: 3, ((T([256, 64, 256], f16), [64, 256, 16, 16]), {})
+cnt: 3, ((T([64, 512, 16, 16], f16), [256, 128, 256]), {})
+cnt: 2, ((T([256, 16, 16, 128], f16), [65536, 128]), {})
+cnt: 1, ((T([256, 256, 128], f16), [256, 256, 128]), {})
+cnt: 3, ((T([256, 128, 256], f16), [64, 512, 16, 16]), {})
+cnt: 3, ((T([64, 512, 8, 8], f16), [256, 128, 64]), {})
+cnt: 1, ((T([256, 64, 64], f16), [256, 64, 64]), {})
+cnt: 2, ((T([256, 8, 8, 128], f16), [16384, 128]), {})
+cnt: 2, ((T([16384, 15], f16), [256, 8, 8, 15]), {})
+cnt: 1, ((T([256, 8, 8, 8, 8], f16), [256, 64, 64]), {})
+cnt: 1, ((T([256, 64, 128], f16), [256, 64, 128]), {})
+cnt: 3, ((T([256, 128, 64], f16), [64, 512, 8, 8]), {})
+cnt: 1, ((T([256, 8, 8, 128], f16), [256, 64, 128]), {})
+cnt: 1, ((T([256, 16, 16, 128], f16), [256, 256, 128]), {})
+cnt: 1, ((T([256, 16, 16, 64], f16), [256, 256, 64]), {})
+cnt: 1, ((T([256, 32, 32, 32], f16), [256, 1024, 32]), {})
+Operator: aten.add.Tensor
+cnt: 38, ((T([], i64), 1), {})
+cnt: 4, ((T([64, 256, 64, 64], f16), T([64, 256, 64, 64], f16)), {})
+cnt: 6, ((T([64, 512, 32, 32], f16), T([64, 512, 32, 32], f16)), {})
+cnt: 1, ((T([256, 32, 32, 32, 32], f16, stride=(66528, 63, 2079, 1, 0)), T([256, 32, 32, 32, 32], f16, stride=(66528, 2079, 63, 0, 1))), {})
+cnt: 1, ((T([256, 1024, 1024], f16), T([256, 1024, 1024], f16)), {})
+cnt: 6, ((T([64, 1024, 16, 16], f16), T([64, 1024, 16, 16], f16)), {})
+cnt: 2, ((T([256, 16, 16, 16, 16], f16, stride=(8432, 31, 527, 1, 0)), T([256, 16, 16, 16, 16], f16, stride=(8432, 527, 31, 0, 1))), {})
+cnt: 2, ((T([256, 256, 256], f16), T([256, 256, 256], f16)), {})
+cnt: 3, ((T([64, 1536, 8, 8], f16), T([64, 1536, 8, 8], f16)), {})
+cnt: 1, ((T([256, 8, 8, 8, 8], f16, stride=(1080, 15, 135, 1, 0)), T([256, 8, 8, 8, 8], f16, stride=(1080, 135, 15, 0, 1))), {})
+cnt: 1, ((T([256, 64, 64], f16), T([256, 64, 64], f16)), {})
+cnt: 1, ((T([256, 8, 8, 128], f16, stride=(8192, 128, 1024, 1)), T([256, 8, 8, 128], f16)), {})
+cnt: 1, ((T([256, 64, 128], f16), T([256, 64, 128], f16)), {})
+cnt: 1, ((T([256, 16, 16, 128], f16, stride=(32768, 128, 2048, 1)), T([256, 16, 16, 128], f16)), {})
+cnt: 1, ((T([256, 256, 128], f16), T([256, 256, 128], f16)), {})
+cnt: 1, ((T([256, 16, 16, 64], f16, stride=(16384, 64, 1024, 1)), T([256, 16, 16, 64], f16)), {})
+cnt: 1, ((T([256, 256, 64], f16), T([256, 256, 64], f16)), {})
+cnt: 2, ((T([64, 256, 16, 16], f16), T([64, 256, 16, 16], f16)), {})
+cnt: 1, ((T([256, 32, 32, 32], f16, stride=(32768, 32, 1024, 1)), T([256, 32, 32, 32], f16)), {})
+cnt: 1, ((T([256, 1024, 32], f16), T([256, 1024, 32], f16)), {})
+cnt: 2, ((T([64, 128, 32, 32], f16), T([64, 128, 32, 32], f16)), {})
+cnt: 3, ((T([64, 64, 64, 64], f16), T([64, 64, 64, 64], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([64, 1280], f16), T([1280, 1000], f16, stride=(1, 1280))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([64, 512, 16, 16], f16), [2, 2], [2, 2]), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([64, 512, 8, 8], f16), T([64, 512, 16, 16], f16), [2, 2], [2, 2], [0, 0], False, True, None), {})
+Operator: aten.bmm.default
+cnt: 2, ((T([256, 1024, 32], f16, stride=(32768, 1, 1024)), T([256, 32, 1024], f16)), {})
+cnt: 2, ((T([256, 1024, 1024], f16), T([256, 1024, 32], f16, stride=(32768, 1, 1024))), {})
+cnt: 2, ((T([256, 256, 64], f16, stride=(16384, 1, 256)), T([256, 64, 256], f16)), {})
+cnt: 2, ((T([256, 256, 256], f16), T([256, 256, 64], f16, stride=(16384, 1, 256))), {})
+cnt: 2, ((T([256, 256, 128], f16, stride=(32768, 1, 256)), T([256, 128, 256], f16)), {})
+cnt: 2, ((T([256, 256, 256], f16), T([256, 256, 128], f16, stride=(32768, 1, 256))), {})
+cnt: 2, ((T([256, 64, 128], f16, stride=(8192, 1, 64)), T([256, 128, 64], f16)), {})
+cnt: 2, ((T([256, 64, 64], f16), T([256, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 1, ((T([256, 64, 64], f16, stride=(4096, 1, 64)), T([256, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 1, ((T([256, 128, 64], f16), T([256, 64, 64], f16)), {})
+cnt: 1, ((T([256, 256, 256], f16, stride=(65536, 1, 256)), T([256, 256, 128], f16, stride=(32768, 1, 256))), {})
+cnt: 1, ((T([256, 128, 256], f16), T([256, 256, 256], f16)), {})
+cnt: 1, ((T([256, 256, 256], f16, stride=(65536, 1, 256)), T([256, 256, 64], f16, stride=(16384, 1, 256))), {})
+cnt: 1, ((T([256, 64, 256], f16), T([256, 256, 256], f16)), {})
+cnt: 1, ((T([256, 1024, 1024], f16, stride=(1048576, 1, 1024)), T([256, 1024, 32], f16, stride=(32768, 1, 1024))), {})
+cnt: 1, ((T([256, 32, 1024], f16), T([256, 1024, 1024], f16)), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 512, 8, 8], f16), T([64, 512, 8, 8], f16), T([64, 512, 8, 8], f16)], 1), {})
+cnt: 1, (([T([64, 512, 16, 16], f16), T([64, 512, 16, 16], f16), T([64, 512, 16, 16], f16)], 1), {})
+cnt: 1, (([T([64, 256, 16, 16], f16), T([64, 256, 16, 16], f16), T([64, 256, 16, 16], f16)], 1), {})
+cnt: 1, (([T([64, 128, 32, 32], f16), T([64, 128, 32, 32], f16), T([64, 128, 32, 32], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 256, 256], f16),), {})
+cnt: 1, ((T([64, 24, 128, 128], f16),), {})
+cnt: 1, ((T([64, 32, 128, 128], f16),), {})
+cnt: 5, ((T([64, 64, 64, 64], f16),), {})
+cnt: 2, ((T([64, 256, 64, 64], f16),), {})
+cnt: 1, ((T([64, 128, 64, 64], f16),), {})
+cnt: 5, ((T([64, 128, 32, 32], f16),), {})
+cnt: 3, ((T([64, 512, 32, 32], f16),), {})
+cnt: 1, ((T([64, 256, 32, 32], f16),), {})
+cnt: 5, ((T([64, 256, 16, 16], f16),), {})
+cnt: 3, ((T([64, 1024, 16, 16], f16),), {})
+cnt: 1, ((T([64, 512, 16, 16], f16),), {})
+cnt: 3, ((T([64, 512, 8, 8], f16),), {})
+cnt: 2, ((T([64, 1536, 8, 8], f16),), {})
+cnt: 1, ((T([64, 1280, 8, 8], f16),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 2, ((T([8192, 32, 63], f16), [0, 1], 0.0), {})
+cnt: 2, ((T([8192, 2048], f16), [0, 31], 0.0), {})
+cnt: 4, ((T([4096, 16, 31], f16), [0, 1], 0.0), {})
+cnt: 4, ((T([4096, 512], f16), [0, 15], 0.0), {})
+cnt: 2, ((T([2048, 8, 15], f16), [0, 1], 0.0), {})
+cnt: 2, ((T([2048, 128], f16), [0, 7], 0.0), {})
+cnt: 2, ((T([2048, 135], f16), [0, -7]), {})
+cnt: 2, ((T([2048, 8, 16], f16), [0, -1]), {})
+cnt: 4, ((T([4096, 527], f16), [0, -15]), {})
+cnt: 4, ((T([4096, 16, 32], f16), [0, -1]), {})
+cnt: 2, ((T([8192, 2079], f16), [0, -31]), {})
+cnt: 2, ((T([8192, 32, 64], f16), [0, -1]), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 256, 256], f16), T([24, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 24, 128, 128], f16), T([32, 24, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 32, 128, 128], f16), T([64, 32, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 64, 64], f16), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 64, 64, 64], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 64, 1, 1], f16), T([8, 64, 1, 1], f16), T([8], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 8, 1, 1], f16), T([64, 8, 1, 1], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 64, 64, 64], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 64, 64], f16), T([64, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 64, 64], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 64, 64], f16), T([128, 128, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 128, 1, 1], f16), T([8, 128, 1, 1], f16), T([8], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 8, 1, 1], f16), T([128, 8, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 128, 32, 32], f16), T([512, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 64, 64], f16), T([512, 256, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 512, 32, 32], f16), T([128, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 32, 32], f16), T([128, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 32, 32], f16), T([384, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 32, 32], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 32, 32], f16), T([256, 256, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 256, 1, 1], f16), T([16, 256, 1, 1], f16), T([16], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 16, 1, 1], f16), T([256, 16, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 256, 16, 16], f16), T([1024, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 32, 32], f16), T([1024, 512, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 1024, 16, 16], f16), T([256, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 16, 16], f16), T([256, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 16, 16], f16), T([768, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 1024, 16, 16], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 512, 16, 16], f16), T([1536, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 512, 8, 8], f16), T([1536, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 1024, 16, 16], f16), T([1536, 1024, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 1536, 8, 8], f16), T([512, 1536, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 1536, 8, 8], f16), T([1280, 1536, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 1280, 8, 8], f16), T([64, 1536, 8, 8], f16), T([1280, 1536, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 1536, 8, 8], f16), T([64, 512, 8, 8], f16), T([1536, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 8, 8], f16), T([64, 1536, 8, 8], f16), T([512, 1536, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 1536, 8, 8], f16), T([64, 1024, 16, 16], f16), T([1536, 1024, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 1536, 16, 16], f16), T([64, 512, 16, 16], f16), T([1536, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 16, 16], f16), T([64, 1024, 16, 16], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 1024, 16, 16], f16), T([64, 256, 16, 16], f16), T([1024, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 768, 16, 16], f16), T([64, 256, 16, 16], f16), T([768, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 256, 16, 16], f16), T([64, 1024, 16, 16], f16), T([256, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 256, 1, 1], f16), T([64, 16, 1, 1], f16), T([256, 16, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([64, 16, 1, 1], f16), T([64, 256, 1, 1], f16), T([16, 256, 1, 1], f16), [16], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 256, 16, 16], f16), T([64, 256, 16, 16], f16), T([256, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 1024, 16, 16], f16), T([64, 512, 32, 32], f16), T([1024, 512, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 256, 16, 16], f16), T([64, 256, 32, 32], f16), T([256, 256, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 256, 32, 32], f16), T([64, 512, 32, 32], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 512, 32, 32], f16), T([64, 128, 32, 32], f16), T([512, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 384, 32, 32], f16), T([64, 128, 32, 32], f16), T([384, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 128, 32, 32], f16), T([64, 512, 32, 32], f16), T([128, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 128, 1, 1], f16), T([64, 8, 1, 1], f16), T([128, 8, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([64, 8, 1, 1], f16), T([64, 128, 1, 1], f16), T([8, 128, 1, 1], f16), [8], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 128, 32, 32], f16), T([64, 128, 32, 32], f16), T([128, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 512, 32, 32], f16), T([64, 256, 64, 64], f16), T([512, 256, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 32, 32], f16), T([64, 128, 64, 64], f16), T([128, 128, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 128, 64, 64], f16), T([64, 256, 64, 64], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 256, 64, 64], f16), T([64, 64, 64, 64], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 64, 1, 1], f16), T([64, 8, 1, 1], f16), T([64, 8, 1, 1], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([64, 8, 1, 1], f16), T([64, 64, 1, 1], f16), T([8, 64, 1, 1], f16), [8], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([64, 64, 64, 64], f16), T([64, 64, 64, 64], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 64, 64], f16), T([64, 256, 64, 64], f16), T([64, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 64, 64], f16), T([64, 64, 64, 64], f16), T([64, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 64, 64], f16), T([64, 32, 128, 128], f16), T([64, 32, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 128, 128], f16), T([64, 24, 128, 128], f16), T([32, 24, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 24, 128, 128], f16), T([64, 3, 256, 256], f16), T([24, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 256, 256], f16), T([64, 3, 256, 256], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([64, 1280, 8, 8], f16, stride=(1280, 1, 0, 0)), 64), {})
+cnt: 2, ((T([64, 256, 16, 16], f16, stride=(256, 1, 0, 0)), 256), {})
+cnt: 2, ((T([64, 128, 32, 32], f16, stride=(128, 1, 0, 0)), 1024), {})
+cnt: 2, ((T([64, 64, 64, 64], f16, stride=(64, 1, 0, 0)), 4096), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.mean.dim
+cnt: 2, ((T([64, 64, 64, 64], f16), [2, 3], True), {})
+cnt: 2, ((T([64, 128, 32, 32], f16), [2, 3], True), {})
+cnt: 2, ((T([64, 256, 16, 16], f16), [2, 3], True), {})
+cnt: 1, ((T([64, 1280, 8, 8], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 2, ((T([262144, 32], f16), T([32, 63], f16, stride=(1, 32))), {})
+cnt: 2, ((T([65536, 64], f16), T([64, 31], f16, stride=(1, 64))), {})
+cnt: 2, ((T([65536, 128], f16), T([128, 31], f16, stride=(1, 128))), {})
+cnt: 2, ((T([16384, 128], f16), T([128, 15], f16, stride=(1, 128))), {})
+cnt: 1, ((T([64, 1000], f16), T([1000, 1280], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 1280], f16)), {})
+cnt: 2, ((T([15, 16384], f16, stride=(1, 15)), T([16384, 128], f16)), {})
+cnt: 2, ((T([16384, 15], f16), T([15, 128], f16)), {})
+cnt: 2, ((T([31, 65536], f16, stride=(1, 31)), T([65536, 128], f16)), {})
+cnt: 2, ((T([65536, 31], f16), T([31, 128], f16)), {})
+cnt: 2, ((T([31, 65536], f16, stride=(1, 31)), T([65536, 64], f16)), {})
+cnt: 2, ((T([65536, 31], f16), T([31, 64], f16)), {})
+cnt: 2, ((T([63, 262144], f16, stride=(1, 63)), T([262144, 32], f16)), {})
+cnt: 2, ((T([262144, 63], f16), T([63, 32], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 4, ((T([64, 64, 64, 64], f16), T([64, 64, 1, 1], f16)), {})
+cnt: 4, ((T([64, 128, 32, 32], f16), T([64, 128, 1, 1], f16)), {})
+cnt: 2, ((T([256, 1024, 1024], f16), 0.1767766952966369), {})
+cnt: 4, ((T([64, 256, 16, 16], f16), T([64, 256, 1, 1], f16)), {})
+cnt: 2, ((T([256, 256, 256], f16), 0.125), {})
+cnt: 2, ((T([256, 256, 256], f16), 0.08838834764831845), {})
+cnt: 2, ((T([256, 64, 64], f16), 0.08838834764831845), {})
+cnt: 2, ((T([64, 256, 16, 16], f16), T([64, 256, 16, 16], f16)), {})
+cnt: 2, ((T([64, 128, 32, 32], f16), T([64, 128, 32, 32], f16)), {})
+cnt: 2, ((T([64, 64, 64, 64], f16), T([64, 64, 64, 64], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([64, 24, 128, 128], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 32, 128, 128], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([64, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([64, 256, 64, 64], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([64, 128, 32, 32], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([64, 512, 32, 32], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 256, 32, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([64, 256, 16, 16], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([64, 1024, 16, 16], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 512, 16, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([64, 512, 8, 8], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([64, 1536, 8, 8], f16), T([1536], f16), T([1536], f16), T([1536], f16), T([1536], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([64, 1280, 8, 8], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([64, 1280, 8, 8], f16), T([64, 1280, 8, 8], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f32), T([1280], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([64, 1536, 8, 8], f16), T([64, 1536, 8, 8], f16), T([1536], f16), T([1536], f16), T([1536], f16), T([1536], f32), T([1536], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([64, 512, 8, 8], f16), T([64, 512, 8, 8], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 512, 16, 16], f16), T([64, 512, 16, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([64, 1024, 16, 16], f16), T([64, 1024, 16, 16], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([64, 256, 16, 16], f16), T([64, 256, 16, 16], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 256, 32, 32], f16), T([64, 256, 32, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([64, 512, 32, 32], f16), T([64, 512, 32, 32], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([64, 128, 32, 32], f16), T([64, 128, 32, 32], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 128, 64, 64], f16), T([64, 128, 64, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([64, 256, 64, 64], f16), T([64, 256, 64, 64], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([64, 64, 64, 64], f16), T([64, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 32, 128, 128], f16), T([64, 32, 128, 128], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([64, 24, 128, 128], f16), T([64, 24, 128, 128], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 4, ((T([64, 8, 1, 1], f16),), {})
+cnt: 2, ((T([64, 16, 1, 1], f16),), {})
+Operator: aten.sigmoid.default
+cnt: 2, ((T([64, 64, 1, 1], f16),), {})
+cnt: 2, ((T([64, 128, 1, 1], f16),), {})
+cnt: 2, ((T([64, 256, 1, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 2, ((T([64, 256, 1, 1], f16), T([64, 256, 1, 1], f16)), {})
+cnt: 2, ((T([64, 128, 1, 1], f16), T([64, 128, 1, 1], f16)), {})
+cnt: 2, ((T([64, 64, 1, 1], f16), T([64, 64, 1, 1], f16)), {})
+Operator: aten.silu_.default
+cnt: 1, ((T([64, 24, 128, 128], f16),), {})
+cnt: 1, ((T([64, 32, 128, 128], f16),), {})
+cnt: 5, ((T([64, 64, 64, 64], f16),), {})
+cnt: 2, ((T([64, 256, 64, 64], f16),), {})
+cnt: 1, ((T([64, 128, 64, 64], f16),), {})
+cnt: 5, ((T([64, 128, 32, 32], f16),), {})
+cnt: 3, ((T([64, 512, 32, 32], f16),), {})
+cnt: 1, ((T([64, 256, 32, 32], f16),), {})
+cnt: 5, ((T([64, 256, 16, 16], f16),), {})
+cnt: 3, ((T([64, 1024, 16, 16], f16),), {})
+cnt: 1, ((T([64, 512, 16, 16], f16),), {})
+cnt: 3, ((T([64, 512, 8, 8], f16),), {})
+cnt: 2, ((T([64, 1536, 8, 8], f16),), {})
+cnt: 1, ((T([64, 1280, 8, 8], f16),), {})
+Operator: aten.silu_backward.default
+cnt: 1, ((T([64, 1280, 8, 8], f16), T([64, 1280, 8, 8], f16)), {})
+cnt: 2, ((T([64, 1536, 8, 8], f16), T([64, 1536, 8, 8], f16)), {})
+cnt: 3, ((T([64, 512, 8, 8], f16), T([64, 512, 8, 8], f16)), {})
+cnt: 1, ((T([64, 512, 16, 16], f16), T([64, 512, 16, 16], f16)), {})
+cnt: 3, ((T([64, 1024, 16, 16], f16), T([64, 1024, 16, 16], f16)), {})
+cnt: 5, ((T([64, 256, 16, 16], f16), T([64, 256, 16, 16], f16)), {})
+cnt: 1, ((T([64, 256, 32, 32], f16), T([64, 256, 32, 32], f16)), {})
+cnt: 3, ((T([64, 512, 32, 32], f16), T([64, 512, 32, 32], f16)), {})
+cnt: 5, ((T([64, 128, 32, 32], f16), T([64, 128, 32, 32], f16)), {})
+cnt: 1, ((T([64, 128, 64, 64], f16), T([64, 128, 64, 64], f16)), {})
+cnt: 2, ((T([64, 256, 64, 64], f16), T([64, 256, 64, 64], f16)), {})
+cnt: 5, ((T([64, 64, 64, 64], f16), T([64, 64, 64, 64], f16)), {})
+cnt: 1, ((T([64, 32, 128, 128], f16), T([64, 32, 128, 128], f16)), {})
+cnt: 1, ((T([64, 24, 128, 128], f16), T([64, 24, 128, 128], f16)), {})
+Operator: aten.slice_backward.default
+cnt: 2, ((T([2048, 8, 8], f16), [2048, 8, 15], 2, 7, 9223372036854775807, 1), {})
+cnt: 2, ((T([2048, 8, 15], f16), [2048, 9, 15], 1, 0, 8, 1), {})
+cnt: 2, ((T([2048, 9, 15], f16), [2048, 9, 15], 0, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([4096, 16, 16], f16), [4096, 16, 31], 2, 15, 9223372036854775807, 1), {})
+cnt: 4, ((T([4096, 16, 31], f16), [4096, 17, 31], 1, 0, 16, 1), {})
+cnt: 4, ((T([4096, 17, 31], f16), [4096, 17, 31], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([8192, 32, 32], f16), [8192, 32, 63], 2, 31, 9223372036854775807, 1), {})
+cnt: 2, ((T([8192, 32, 63], f16), [8192, 33, 63], 1, 0, 32, 1), {})
+cnt: 2, ((T([8192, 33, 63], f16), [8192, 33, 63], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.split_with_sizes.default
+cnt: 1, ((T([64, 384, 32, 32], f16), [128, 128, 128], 1), {})
+cnt: 1, ((T([64, 768, 16, 16], f16), [256, 256, 256], 1), {})
+cnt: 1, ((T([64, 1536, 16, 16], f16), [512, 512, 512], 1), {})
+cnt: 1, ((T([64, 1536, 8, 8], f16), [512, 512, 512], 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 1, ((T([256, 8, 8, 8, 8], f16, stride=(4096, 64, 1, 512, 8)), [2], True), {})
+cnt: 1, ((T([256, 8, 8, 8, 8], f16, stride=(4096, 512, 8, 64, 1)), [2], True), {})
+cnt: 2, ((T([256, 16, 16, 16, 16], f16, stride=(65536, 256, 1, 4096, 16)), [2], True), {})
+cnt: 2, ((T([256, 16, 16, 16, 16], f16, stride=(65536, 4096, 16, 256, 1)), [2], True), {})
+cnt: 2, ((T([64, 256, 16, 16], f16), [2, 3], True), {})
+cnt: 1, ((T([256, 32, 32, 32, 32], f16, stride=(1048576, 1024, 1, 32768, 32)), [2], True), {})
+cnt: 1, ((T([256, 32, 32, 32, 32], f16, stride=(1048576, 32768, 32, 1024, 1)), [2], True), {})
+cnt: 2, ((T([64, 128, 32, 32], f16), [2, 3], True), {})
+cnt: 2, ((T([64, 64, 64, 64], f16), [2, 3], True), {})
+Operator: aten.threshold_backward.default
+cnt: 2, ((T([64, 16, 1, 1], f16), T([64, 16, 1, 1], f16), 0), {})
+cnt: 4, ((T([64, 8, 1, 1], f16), T([64, 8, 1, 1], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/selecsls42b_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/selecsls42b_training.txt
new file mode 100644
index 0000000000000..bc42466c16d67
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/selecsls42b_training.txt
@@ -0,0 +1,167 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([128, 152, 14, 14], f16, stride=(178752, 196, 14, 1)), T([128, 152, 14, 14], f16)), {})
+cnt: 2, ((T([128, 304, 14, 14], f16, stride=(178752, 196, 14, 1)), T([128, 304, 14, 14], f16)), {})
+cnt: 1, ((T([128, 152, 14, 14], f16, stride=(119168, 196, 14, 1)), T([128, 152, 14, 14], f16)), {})
+cnt: 1, ((T([128, 304, 14, 14], f16, stride=(119168, 196, 14, 1)), T([128, 304, 14, 14], f16)), {})
+cnt: 1, ((T([128, 72, 28, 28], f16, stride=(338688, 784, 28, 1)), T([128, 72, 28, 28], f16)), {})
+cnt: 2, ((T([128, 144, 28, 28], f16, stride=(338688, 784, 28, 1)), T([128, 144, 28, 28], f16)), {})
+cnt: 1, ((T([128, 72, 28, 28], f16, stride=(225792, 784, 28, 1)), T([128, 72, 28, 28], f16)), {})
+cnt: 1, ((T([128, 144, 28, 28], f16, stride=(225792, 784, 28, 1)), T([128, 144, 28, 28], f16)), {})
+cnt: 1, ((T([128, 32, 56, 56], f16, stride=(602112, 3136, 56, 1)), T([128, 32, 56, 56], f16)), {})
+cnt: 2, ((T([128, 64, 56, 56], f16, stride=(602112, 3136, 56, 1)), T([128, 64, 56, 56], f16)), {})
+cnt: 1, ((T([128, 32, 56, 56], f16, stride=(401408, 3136, 56, 1)), T([128, 32, 56, 56], f16)), {})
+cnt: 1, ((T([128, 64, 56, 56], f16, stride=(401408, 3136, 56, 1)), T([128, 64, 56, 56], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 41, ((T([], i64), 1), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1024], f16), T([1024, 1000], f16, stride=(1, 1024))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([128, 64, 56, 56], f16), T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16)], 1), {})
+cnt: 1, (([T([128, 64, 56, 56], f16), T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), T([128, 64, 56, 56], f16)], 1), {})
+cnt: 1, (([T([128, 144, 28, 28], f16), T([128, 72, 28, 28], f16), T([128, 72, 28, 28], f16)], 1), {})
+cnt: 1, (([T([128, 144, 28, 28], f16), T([128, 72, 28, 28], f16), T([128, 72, 28, 28], f16), T([128, 144, 28, 28], f16)], 1), {})
+cnt: 1, (([T([128, 304, 14, 14], f16), T([128, 152, 14, 14], f16), T([128, 152, 14, 14], f16)], 1), {})
+cnt: 1, (([T([128, 304, 14, 14], f16), T([128, 152, 14, 14], f16), T([128, 152, 14, 14], f16), T([128, 304, 14, 14], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([64, 32, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 64, 56, 56], f16), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 64, 56, 56], f16), T([32, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 32, 56, 56], f16), T([64, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([64, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 56, 56], f16), T([128, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([144, 128, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 144, 28, 28], f16), T([144, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 144, 28, 28], f16), T([72, 144, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 72, 28, 28], f16), T([144, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 288, 28, 28], f16), T([144, 288, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([144, 144, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 432, 28, 28], f16), T([288, 432, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 288, 28, 28], f16), T([304, 288, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 304, 14, 14], f16), T([304, 304, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 304, 14, 14], f16), T([152, 304, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 152, 14, 14], f16), T([304, 152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 608, 14, 14], f16), T([304, 608, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 304, 14, 14], f16), T([304, 304, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 912, 14, 14], f16), T([480, 912, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([960, 480, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 960, 7, 7], f16), T([1024, 960, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([1280, 1024, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1280, 4, 4], f16), T([1024, 1280, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1024, 4, 4], f16), T([128, 1280, 4, 4], f16), T([1024, 1280, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1280, 4, 4], f16), T([128, 1024, 7, 7], f16), T([1280, 1024, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 960, 7, 7], f16), T([1024, 960, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 960, 7, 7], f16), T([128, 480, 14, 14], f16), T([960, 480, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([128, 912, 14, 14], f16), T([480, 912, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 152, 14, 14], f16), T([128, 304, 14, 14], f16), T([152, 304, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 304, 14, 14], f16), T([128, 152, 14, 14], f16), T([304, 152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 304, 14, 14], f16), T([128, 304, 14, 14], f16), T([304, 304, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 304, 14, 14], f16), T([128, 304, 14, 14], f16), T([304, 304, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 304, 14, 14], f16), T([128, 608, 14, 14], f16), T([304, 608, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 304, 14, 14], f16), T([128, 288, 28, 28], f16), T([304, 288, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 288, 28, 28], f16), T([128, 432, 28, 28], f16), T([288, 432, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 72, 28, 28], f16), T([128, 144, 28, 28], f16), T([72, 144, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 144, 28, 28], f16), T([128, 72, 28, 28], f16), T([144, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 144, 28, 28], f16), T([128, 144, 28, 28], f16), T([144, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 144, 28, 28], f16), T([144, 144, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 288, 28, 28], f16), T([144, 288, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 128, 56, 56], f16), T([144, 128, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([128, 192, 56, 56], f16), T([128, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 32, 56, 56], f16), T([128, 64, 56, 56], f16), T([32, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 64, 56, 56], f16), T([128, 32, 56, 56], f16), T([64, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), T([64, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 128, 56, 56], f16), T([64, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 64, 56, 56], f16), T([128, 32, 112, 112], f16), T([64, 32, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 3, 224, 224], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 1024, 4, 4], f16, stride=(1024, 1, 0, 0)), 16), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 1024, 4, 4], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 1024], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 1024], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([128, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 32, 56, 56], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([128, 144, 28, 28], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 72, 28, 28], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 288, 28, 28], f16), T([288], f16), T([288], f16), T([288], f16), T([288], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([128, 304, 14, 14], f16), T([304], f16), T([304], f16), T([304], f16), T([304], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 152, 14, 14], f16), T([152], f16), T([152], f16), T([152], f16), T([152], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 1280, 4, 4], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 1024, 4, 4], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([128, 1024, 4, 4], f16), T([128, 1024, 4, 4], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 1280, 4, 4], f16), T([128, 1280, 4, 4], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f32), T([1280], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 152, 14, 14], f16), T([128, 152, 14, 14], f16), T([152], f16), T([152], f16), T([152], f16), T([152], f32), T([152], f32), True, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([128, 304, 14, 14], f16), T([128, 304, 14, 14], f16), T([304], f16), T([304], f16), T([304], f16), T([304], f32), T([304], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 288, 28, 28], f16), T([128, 288, 28, 28], f16), T([288], f16), T([288], f16), T([288], f16), T([288], f32), T([288], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 72, 28, 28], f16), T([128, 72, 28, 28], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), True, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([128, 144, 28, 28], f16), T([128, 144, 28, 28], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([128, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 32, 112, 112], f16),), {})
+cnt: 7, ((T([128, 64, 56, 56], f16),), {})
+cnt: 4, ((T([128, 32, 56, 56], f16),), {})
+cnt: 1, ((T([128, 128, 56, 56], f16),), {})
+cnt: 7, ((T([128, 144, 28, 28], f16),), {})
+cnt: 4, ((T([128, 72, 28, 28], f16),), {})
+cnt: 1, ((T([128, 288, 28, 28], f16),), {})
+cnt: 7, ((T([128, 304, 14, 14], f16),), {})
+cnt: 4, ((T([128, 152, 14, 14], f16),), {})
+cnt: 1, ((T([128, 480, 14, 14], f16),), {})
+cnt: 1, ((T([128, 960, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1280, 4, 4], f16),), {})
+cnt: 1, ((T([128, 1024, 4, 4], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([128, 1024, 4, 4], f16), T([128, 1024, 4, 4], f16), 0), {})
+cnt: 1, ((T([128, 1280, 4, 4], f16), T([128, 1280, 4, 4], f16), 0), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 1024, 7, 7], f16), 0), {})
+cnt: 1, ((T([128, 960, 7, 7], f16), T([128, 960, 7, 7], f16), 0), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 152, 14, 14], f16, stride=(178752, 196, 14, 1)), T([128, 152, 14, 14], f16), 0), {})
+cnt: 7, ((T([128, 304, 14, 14], f16), T([128, 304, 14, 14], f16), 0), {})
+cnt: 2, ((T([128, 152, 14, 14], f16), T([128, 152, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 152, 14, 14], f16, stride=(119168, 196, 14, 1)), T([128, 152, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 288, 28, 28], f16), T([128, 288, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 72, 28, 28], f16, stride=(338688, 784, 28, 1)), T([128, 72, 28, 28], f16), 0), {})
+cnt: 7, ((T([128, 144, 28, 28], f16), T([128, 144, 28, 28], f16), 0), {})
+cnt: 2, ((T([128, 72, 28, 28], f16), T([128, 72, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 72, 28, 28], f16, stride=(225792, 784, 28, 1)), T([128, 72, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 128, 56, 56], f16), T([128, 128, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 32, 56, 56], f16, stride=(602112, 3136, 56, 1)), T([128, 32, 56, 56], f16), 0), {})
+cnt: 7, ((T([128, 64, 56, 56], f16), T([128, 64, 56, 56], f16), 0), {})
+cnt: 2, ((T([128, 32, 56, 56], f16), T([128, 32, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 32, 56, 56], f16, stride=(401408, 3136, 56, 1)), T([128, 32, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/spnasnet_100_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/spnasnet_100_training.txt
new file mode 100644
index 0000000000000..5ffc25e3d6e66
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/spnasnet_100_training.txt
@@ -0,0 +1,182 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 64, ((T([], i64), 1), {})
+cnt: 4, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16)), {})
+cnt: 6, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16)), {})
+cnt: 6, ((T([128, 80, 14, 14], f16), T([128, 80, 14, 14], f16)), {})
+cnt: 6, ((T([128, 96, 14, 14], f16), T([128, 96, 14, 14], f16)), {})
+cnt: 6, ((T([128, 192, 7, 7], f16), T([128, 192, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1280], f16), T([1280, 1000], f16, stride=(1, 1280))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([32, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([16, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([48, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([48, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 48), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([24, 48, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([72, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 72, 56, 56], f16), T([72, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 72), {})
+cnt: 2, ((T([128, 72, 56, 56], f16), T([24, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([144, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([144, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 144), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([40, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 40, 28, 28], f16), T([120, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 120, 28, 28], f16), T([120, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 120), {})
+cnt: 3, ((T([128, 120, 28, 28], f16), T([40, 120, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([240, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([240, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 240), {})
+cnt: 4, ((T([128, 240, 14, 14], f16), T([80, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 80, 14, 14], f16), T([240, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 240, 14, 14], f16), T([240, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([128, 80, 14, 14], f16), T([480, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([480, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 480), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([96, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 96, 14, 14], f16), T([288, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 288, 14, 14], f16), T([288, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 288), {})
+cnt: 3, ((T([128, 288, 14, 14], f16), T([96, 288, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 14, 14], f16), T([576, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 576, 14, 14], f16), T([576, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 576), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), T([192, 576, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 192, 7, 7], f16), T([1152, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 1152, 7, 7], f16), T([1152, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 1152), {})
+cnt: 3, ((T([128, 1152, 7, 7], f16), T([192, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), T([1152, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1152), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), T([320, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([1280, 320, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([128, 320, 7, 7], f16), T([1280, 320, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([128, 1152, 7, 7], f16), T([320, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16), T([1152, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1152, [True, True, False]), {})
+cnt: 4, ((T([128, 1152, 7, 7], f16), T([128, 192, 7, 7], f16), T([1152, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 192, 7, 7], f16), T([128, 1152, 7, 7], f16), T([192, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16), T([1152, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 1152, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 7, 7], f16), T([128, 576, 7, 7], f16), T([192, 576, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), T([128, 576, 14, 14], f16), T([576, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 576, [True, True, False]), {})
+cnt: 1, ((T([128, 576, 14, 14], f16), T([128, 96, 14, 14], f16), T([576, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 96, 14, 14], f16), T([128, 288, 14, 14], f16), T([96, 288, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 288, 14, 14], f16), T([128, 288, 14, 14], f16), T([288, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 288, [True, True, False]), {})
+cnt: 3, ((T([128, 288, 14, 14], f16), T([128, 96, 14, 14], f16), T([288, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 14, 14], f16), T([128, 480, 14, 14], f16), T([96, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), T([480, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([128, 80, 14, 14], f16), T([480, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 80, 14, 14], f16), T([128, 240, 14, 14], f16), T([80, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16), T([240, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 3, ((T([128, 240, 14, 14], f16), T([128, 80, 14, 14], f16), T([240, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 28, 28], f16), T([240, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 40, 28, 28], f16), T([240, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 40, 28, 28], f16), T([128, 120, 28, 28], f16), T([40, 120, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16), T([120, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 3, ((T([128, 120, 28, 28], f16), T([128, 40, 28, 28], f16), T([120, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([128, 144, 28, 28], f16), T([40, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 144, 56, 56], f16), T([144, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 144, [True, True, False]), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([128, 24, 56, 56], f16), T([144, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([128, 72, 56, 56], f16), T([24, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 72, 56, 56], f16), T([128, 72, 56, 56], f16), T([72, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 72, [True, True, False]), {})
+cnt: 2, ((T([128, 72, 56, 56], f16), T([128, 24, 56, 56], f16), T([72, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 48, 56, 56], f16), T([24, 48, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([128, 48, 112, 112], f16), T([48, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 48, [True, True, False]), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([128, 16, 112, 112], f16), T([48, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 32, 112, 112], f16), T([16, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), T([32, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 3, 224, 224], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 1280, 7, 7], f16, stride=(1280, 1, 0, 0)), 49), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 1280, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 1280], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 1280], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 72, 56, 56], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f16), True, 0.1, 1e-05), {})
+cnt: 6, ((T([128, 120, 28, 28], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([128, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 96, 14, 14], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 1e-05), {})
+cnt: 6, ((T([128, 288, 14, 14], f16), T([288], f16), T([288], f16), T([288], f16), T([288], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 576, 14, 14], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 192, 7, 7], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 8, ((T([128, 1152, 7, 7], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([128, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f32), T([1280], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([128, 320, 7, 7], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), True, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f32), T([1152], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 192, 7, 7], f16), T([128, 192, 7, 7], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), T([128, 576, 7, 7], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f32), T([576], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 576, 14, 14], f16), T([128, 576, 14, 14], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f32), T([576], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 96, 14, 14], f16), T([128, 96, 14, 14], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([128, 288, 14, 14], f16), T([128, 288, 14, 14], f16), T([288], f16), T([288], f16), T([288], f16), T([288], f32), T([288], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 80, 14, 14], f16), T([128, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), True, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), True, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f32), T([120], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 144, 28, 28], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([128, 144, 56, 56], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 72, 56, 56], f16), T([128, 72, 56, 56], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([128, 48, 56, 56], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f32), T([48], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([128, 48, 112, 112], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f32), T([48], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([128, 32, 112, 112], f16),), {})
+cnt: 1, ((T([128, 48, 112, 112], f16),), {})
+cnt: 1, ((T([128, 48, 56, 56], f16),), {})
+cnt: 4, ((T([128, 72, 56, 56], f16),), {})
+cnt: 1, ((T([128, 144, 56, 56], f16),), {})
+cnt: 1, ((T([128, 144, 28, 28], f16),), {})
+cnt: 6, ((T([128, 120, 28, 28], f16),), {})
+cnt: 1, ((T([128, 240, 28, 28], f16),), {})
+cnt: 7, ((T([128, 240, 14, 14], f16),), {})
+cnt: 2, ((T([128, 480, 14, 14], f16),), {})
+cnt: 6, ((T([128, 288, 14, 14], f16),), {})
+cnt: 1, ((T([128, 576, 14, 14], f16),), {})
+cnt: 1, ((T([128, 576, 7, 7], f16),), {})
+cnt: 8, ((T([128, 1152, 7, 7], f16),), {})
+cnt: 1, ((T([128, 1280, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([128, 1280, 7, 7], f16), 0), {})
+cnt: 8, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16), 0), {})
+cnt: 1, ((T([128, 576, 7, 7], f16), T([128, 576, 7, 7], f16), 0), {})
+cnt: 1, ((T([128, 576, 14, 14], f16), T([128, 576, 14, 14], f16), 0), {})
+cnt: 6, ((T([128, 288, 14, 14], f16), T([128, 288, 14, 14], f16), 0), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), 0), {})
+cnt: 7, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16), 0), {})
+cnt: 6, ((T([128, 120, 28, 28], f16), T([128, 120, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 144, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([128, 144, 56, 56], f16), 0), {})
+cnt: 4, ((T([128, 72, 56, 56], f16), T([128, 72, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 48, 56, 56], f16), T([128, 48, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 48, 112, 112], f16), T([128, 48, 112, 112], f16), 0), {})
+cnt: 2, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/swin_base_patch4_window7_224_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/swin_base_patch4_window7_224_training.txt
new file mode 100644
index 0000000000000..6076086ba3a59
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/swin_base_patch4_window7_224_training.txt
@@ -0,0 +1,341 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 2, ((T([4096, 4, 49, 49], f16), -1, False), {})
+cnt: 2, ((T([1024, 8, 49, 49], f16), -1, False), {})
+cnt: 18, ((T([256, 16, 49, 49], f16), -1, False), {})
+cnt: 2, ((T([64, 32, 49, 49], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 2, ((T([64, 32, 49, 49], f16), T([64, 32, 49, 49], f16), -1, f16), {})
+cnt: 18, ((T([256, 16, 49, 49], f16), T([256, 16, 49, 49], f16), -1, f16), {})
+cnt: 2, ((T([1024, 8, 49, 49], f16), T([1024, 8, 49, 49], f16), -1, f16), {})
+cnt: 2, ((T([4096, 4, 49, 49], f16), T([4096, 4, 49, 49], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 6, ((T([4096, 4, 49, 32], f16), [16384, 49, 32]), {})
+cnt: 2, ((T([4096, 4, 32, 49], f16), [16384, 32, 49]), {})
+cnt: 2, ((T([16384, 49, 49], f16), [4096, 4, 49, 49]), {})
+cnt: 2, ((T([16384, 49, 32], f16), [4096, 4, 49, 32]), {})
+cnt: 2, ((T([4096, 49, 4, 32], f16), [4096, 49, 128]), {})
+cnt: 1, ((T([50176, 256], f16), [64, 784, 256]), {})
+cnt: 6, ((T([1024, 8, 49, 32], f16), [8192, 49, 32]), {})
+cnt: 2, ((T([1024, 8, 32, 49], f16), [8192, 32, 49]), {})
+cnt: 2, ((T([8192, 49, 49], f16), [1024, 8, 49, 49]), {})
+cnt: 2, ((T([8192, 49, 32], f16), [1024, 8, 49, 32]), {})
+cnt: 2, ((T([1024, 49, 8, 32], f16), [1024, 49, 256]), {})
+cnt: 1, ((T([12544, 512], f16), [64, 196, 512]), {})
+cnt: 54, ((T([256, 16, 49, 32], f16), [4096, 49, 32]), {})
+cnt: 18, ((T([256, 16, 32, 49], f16), [4096, 32, 49]), {})
+cnt: 18, ((T([4096, 49, 49], f16), [256, 16, 49, 49]), {})
+cnt: 18, ((T([4096, 49, 32], f16), [256, 16, 49, 32]), {})
+cnt: 18, ((T([256, 49, 16, 32], f16), [256, 49, 512]), {})
+cnt: 1, ((T([3136, 1024], f16), [64, 49, 1024]), {})
+cnt: 6, ((T([64, 32, 49, 32], f16), [2048, 49, 32]), {})
+cnt: 2, ((T([64, 32, 32, 49], f16), [2048, 32, 49]), {})
+cnt: 2, ((T([2048, 49, 49], f16), [64, 32, 49, 49]), {})
+cnt: 2, ((T([2048, 49, 32], f16), [64, 32, 49, 32]), {})
+cnt: 2, ((T([64, 49, 32, 32], f16), [64, 49, 1024]), {})
+cnt: 2, ((T([64, 49, 3, 32, 32], f16), [64, 49, 3072]), {})
+cnt: 18, ((T([64, 2, 2, 7, 7, 512], f16), [256, 7, 7, 512]), {})
+cnt: 18, ((T([256, 49, 3, 16, 32], f16), [256, 49, 1536]), {})
+cnt: 18, ((T([64, 2, 7, 2, 7, 512], f16), [64, 14, 14, 512]), {})
+cnt: 2, ((T([64, 4, 4, 7, 7, 256], f16), [1024, 7, 7, 256]), {})
+cnt: 2, ((T([1024, 49, 3, 8, 32], f16), [1024, 49, 768]), {})
+cnt: 2, ((T([64, 4, 7, 4, 7, 256], f16), [64, 28, 28, 256]), {})
+cnt: 2, ((T([64, 8, 8, 7, 7, 128], f16), [4096, 7, 7, 128]), {})
+cnt: 2, ((T([4096, 49, 3, 4, 32], f16), [4096, 49, 384]), {})
+cnt: 2, ((T([64, 8, 7, 8, 7, 128], f16), [64, 56, 56, 128]), {})
+Operator: aten.add.Tensor
+cnt: 2, ((T([4096, 4, 49, 49], f16), T([1, 4, 49, 49], f16)), {})
+cnt: 8, ((T([64, 3136, 128], f16), T([64, 3136, 128], f16)), {})
+cnt: 1, ((T([64, 64, 4, 49, 49], f16), T([1, 64, 1, 49, 49], f16)), {})
+cnt: 2, ((T([1024, 8, 49, 49], f16), T([1, 8, 49, 49], f16)), {})
+cnt: 8, ((T([64, 784, 256], f16), T([64, 784, 256], f16)), {})
+cnt: 1, ((T([64, 16, 8, 49, 49], f16), T([1, 16, 1, 49, 49], f16)), {})
+cnt: 18, ((T([256, 16, 49, 49], f16), T([1, 16, 49, 49], f16)), {})
+cnt: 72, ((T([64, 196, 512], f16), T([64, 196, 512], f16)), {})
+cnt: 9, ((T([64, 4, 16, 49, 49], f16), T([1, 4, 1, 49, 49], f16)), {})
+cnt: 2, ((T([64, 32, 49, 49], f16), T([1, 32, 49, 49], f16)), {})
+cnt: 8, ((T([64, 49, 1024], f16), T([64, 49, 1024], f16)), {})
+cnt: 3, ((T([64, 14, 14, 512], f16), T([64, 14, 14, 512], f16)), {})
+cnt: 3, ((T([64, 28, 28, 256], f16), T([64, 28, 28, 256], f16)), {})
+cnt: 3, ((T([64, 56, 56, 128], f16), T([64, 56, 56, 128], f16)), {})
+Operator: aten.addmm.default
+cnt: 2, ((T([384], f16), T([200704, 128], f16), T([128, 384], f16, stride=(1, 128))), {})
+cnt: 2, ((T([128], f16), T([200704, 128], f16), T([128, 128], f16, stride=(1, 128))), {})
+cnt: 2, ((T([512], f16), T([200704, 128], f16), T([128, 512], f16, stride=(1, 128))), {})
+cnt: 2, ((T([128], f16), T([200704, 512], f16), T([512, 128], f16, stride=(1, 512))), {})
+cnt: 2, ((T([768], f16), T([50176, 256], f16), T([256, 768], f16, stride=(1, 256))), {})
+cnt: 2, ((T([256], f16), T([50176, 256], f16), T([256, 256], f16, stride=(1, 256))), {})
+cnt: 2, ((T([1024], f16), T([50176, 256], f16), T([256, 1024], f16, stride=(1, 256))), {})
+cnt: 2, ((T([256], f16), T([50176, 1024], f16), T([1024, 256], f16, stride=(1, 1024))), {})
+cnt: 18, ((T([1536], f16), T([12544, 512], f16), T([512, 1536], f16, stride=(1, 512))), {})
+cnt: 18, ((T([512], f16), T([12544, 512], f16), T([512, 512], f16, stride=(1, 512))), {})
+cnt: 18, ((T([2048], f16), T([12544, 512], f16), T([512, 2048], f16, stride=(1, 512))), {})
+cnt: 18, ((T([512], f16), T([12544, 2048], f16), T([2048, 512], f16, stride=(1, 2048))), {})
+cnt: 2, ((T([3072], f16), T([3136, 1024], f16), T([1024, 3072], f16, stride=(1, 1024))), {})
+cnt: 2, ((T([1024], f16), T([3136, 1024], f16), T([1024, 1024], f16, stride=(1, 1024))), {})
+cnt: 2, ((T([4096], f16), T([3136, 1024], f16), T([1024, 4096], f16, stride=(1, 1024))), {})
+cnt: 2, ((T([1024], f16), T([3136, 4096], f16), T([4096, 1024], f16, stride=(1, 4096))), {})
+cnt: 1, ((T([1000], f16), T([64, 1024], f16), T([1024, 1000], f16, stride=(1, 1024))), {})
+Operator: aten.bernoulli_.float
+cnt: 2, ((T([64, 1, 1], f16), 0.9956521736457944), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9913043472915888), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9869565209373832), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9826086945831776), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9782608672976494), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9739130418747663), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9695652164518833), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9652173891663551), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.960869561880827), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9565217345952988), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9521739110350609), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9478260837495327), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9434782564640045), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9391304329037666), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9347826093435287), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9304347857832909), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9260869547724724), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9217391312122345), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.917391300201416), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9130434766411781), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9086956530809402), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9043478220701218), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.8999999985098839), {})
+Operator: aten.bmm.default
+cnt: 2, ((T([16384, 49, 32], f16), T([16384, 32, 49], f16)), {})
+cnt: 2, ((T([16384, 49, 49], f16), T([16384, 49, 32], f16)), {})
+cnt: 2, ((T([8192, 49, 32], f16), T([8192, 32, 49], f16)), {})
+cnt: 2, ((T([8192, 49, 49], f16), T([8192, 49, 32], f16)), {})
+cnt: 18, ((T([4096, 49, 32], f16), T([4096, 32, 49], f16)), {})
+cnt: 18, ((T([4096, 49, 49], f16), T([4096, 49, 32], f16)), {})
+cnt: 2, ((T([2048, 49, 32], f16), T([2048, 32, 49], f16)), {})
+cnt: 2, ((T([2048, 49, 49], f16), T([2048, 49, 32], f16)), {})
+cnt: 2, ((T([2048, 49, 49], f16, stride=(2401, 1, 49)), T([2048, 49, 32], f16)), {})
+cnt: 2, ((T([2048, 49, 32], f16), T([2048, 32, 49], f16, stride=(1568, 1, 32))), {})
+cnt: 2, ((T([2048, 32, 49], f16, stride=(1568, 1, 32)), T([2048, 49, 49], f16)), {})
+cnt: 2, ((T([2048, 49, 49], f16), T([2048, 49, 32], f16, stride=(1568, 1, 49))), {})
+cnt: 18, ((T([4096, 49, 49], f16, stride=(2401, 1, 49)), T([4096, 49, 32], f16)), {})
+cnt: 18, ((T([4096, 49, 32], f16), T([4096, 32, 49], f16, stride=(1568, 1, 32))), {})
+cnt: 18, ((T([4096, 32, 49], f16, stride=(1568, 1, 32)), T([4096, 49, 49], f16)), {})
+cnt: 18, ((T([4096, 49, 49], f16), T([4096, 49, 32], f16, stride=(1568, 1, 49))), {})
+cnt: 2, ((T([8192, 49, 49], f16, stride=(2401, 1, 49)), T([8192, 49, 32], f16)), {})
+cnt: 2, ((T([8192, 49, 32], f16), T([8192, 32, 49], f16, stride=(1568, 1, 32))), {})
+cnt: 2, ((T([8192, 32, 49], f16, stride=(1568, 1, 32)), T([8192, 49, 49], f16)), {})
+cnt: 2, ((T([8192, 49, 49], f16), T([8192, 49, 32], f16, stride=(1568, 1, 49))), {})
+cnt: 2, ((T([16384, 49, 49], f16, stride=(2401, 1, 49)), T([16384, 49, 32], f16)), {})
+cnt: 2, ((T([16384, 49, 32], f16), T([16384, 32, 49], f16, stride=(1568, 1, 32))), {})
+cnt: 2, ((T([16384, 32, 49], f16, stride=(1568, 1, 32)), T([16384, 49, 49], f16)), {})
+cnt: 2, ((T([16384, 49, 49], f16), T([16384, 49, 32], f16, stride=(1568, 1, 49))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 28, 28, 128], f16, stride=(401408, 14336, 256, 1)), T([64, 28, 28, 128], f16, stride=(401408, 14336, 256, 1)), T([64, 28, 28, 128], f16, stride=(401408, 14336, 256, 1)), T([64, 28, 28, 128], f16, stride=(401408, 14336, 256, 1))], -1), {})
+cnt: 1, (([T([64, 14, 14, 256], f16, stride=(200704, 14336, 512, 1)), T([64, 14, 14, 256], f16, stride=(200704, 14336, 512, 1)), T([64, 14, 14, 256], f16, stride=(200704, 14336, 512, 1)), T([64, 14, 14, 256], f16, stride=(200704, 14336, 512, 1))], -1), {})
+cnt: 1, (([T([64, 7, 7, 512], f16, stride=(100352, 14336, 1024, 1)), T([64, 7, 7, 512], f16, stride=(100352, 14336, 1024, 1)), T([64, 7, 7, 512], f16, stride=(100352, 14336, 1024, 1)), T([64, 7, 7, 512], f16, stride=(100352, 14336, 1024, 1))], -1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([128, 3, 4, 4], f16), T([128], f16), [4, 4], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 128, 56, 56], f16, stride=(401408, 1, 7168, 128)), T([64, 3, 224, 224], f16), T([128, 3, 4, 4], f16), [128], [4, 4], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([64, 49, 1024], f16, stride=(1024, 0, 1)), 49), {})
+Operator: aten.div_.Tensor
+cnt: 2, ((T([64, 1, 1], f16), 0.9956521736457944), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9913043472915888), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9869565209373832), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9826086945831776), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9782608672976494), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9739130418747663), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9695652164518833), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9652173891663551), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.960869561880827), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9565217345952988), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9521739110350609), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9478260837495327), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9434782564640045), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9391304329037666), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9347826093435287), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9304347857832909), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9260869547724724), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9217391312122345), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.917391300201416), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9130434766411781), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9086956530809402), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.9043478220701218), {})
+cnt: 2, ((T([64, 1, 1], f16), 0.8999999985098839), {})
+Operator: aten.gelu.default
+cnt: 2, ((T([64, 3136, 512], f16),), {})
+cnt: 2, ((T([64, 784, 1024], f16),), {})
+cnt: 18, ((T([64, 196, 2048], f16),), {})
+cnt: 2, ((T([64, 49, 4096], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 2, ((T([64, 49, 4096], f16), T([64, 49, 4096], f16)), {})
+cnt: 18, ((T([64, 196, 2048], f16), T([64, 196, 2048], f16)), {})
+cnt: 2, ((T([64, 784, 1024], f16), T([64, 784, 1024], f16)), {})
+cnt: 2, ((T([64, 3136, 512], f16), T([64, 3136, 512], f16)), {})
+Operator: aten.index.Tensor
+cnt: 2, ((T([169, 4], f16), [T([2401], i64)]), {})
+cnt: 2, ((T([169, 8], f16), [T([2401], i64)]), {})
+cnt: 18, ((T([169, 16], f16), [T([2401], i64)]), {})
+cnt: 2, ((T([169, 32], f16), [T([2401], i64)]), {})
+Operator: aten.index_put.default
+cnt: 2, ((T([169, 32], f16), [T([2401], i64)], T([2401, 32], f16, stride=(1, 2401)), True), {})
+cnt: 18, ((T([169, 16], f16), [T([2401], i64)], T([2401, 16], f16, stride=(1, 2401)), True), {})
+cnt: 2, ((T([169, 8], f16), [T([2401], i64)], T([2401, 8], f16, stride=(1, 2401)), True), {})
+cnt: 2, ((T([169, 4], f16), [T([2401], i64)], T([2401, 4], f16, stride=(1, 2401)), True), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([64, 49, 1024], f16), [1]), {})
+Operator: aten.mm.default
+cnt: 1, ((T([50176, 512], f16), T([512, 256], f16, stride=(1, 512))), {})
+cnt: 1, ((T([12544, 1024], f16), T([1024, 512], f16, stride=(1, 1024))), {})
+cnt: 1, ((T([3136, 2048], f16), T([2048, 1024], f16, stride=(1, 2048))), {})
+cnt: 1, ((T([64, 1000], f16), T([1000, 1024], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 1024], f16)), {})
+cnt: 2, ((T([3136, 1024], f16), T([1024, 4096], f16)), {})
+cnt: 2, ((T([1024, 3136], f16, stride=(1, 1024)), T([3136, 4096], f16)), {})
+cnt: 2, ((T([3136, 4096], f16), T([4096, 1024], f16)), {})
+cnt: 2, ((T([4096, 3136], f16, stride=(1, 4096)), T([3136, 1024], f16)), {})
+cnt: 2, ((T([3136, 1024], f16), T([1024, 1024], f16)), {})
+cnt: 2, ((T([1024, 3136], f16, stride=(1, 1024)), T([3136, 1024], f16)), {})
+cnt: 2, ((T([3136, 3072], f16), T([3072, 1024], f16)), {})
+cnt: 2, ((T([3072, 3136], f16, stride=(1, 3072)), T([3136, 1024], f16)), {})
+cnt: 1, ((T([1024, 3136], f16, stride=(1, 1024)), T([3136, 2048], f16)), {})
+cnt: 1, ((T([3136, 1024], f16), T([1024, 2048], f16)), {})
+cnt: 18, ((T([12544, 512], f16), T([512, 2048], f16)), {})
+cnt: 18, ((T([512, 12544], f16, stride=(1, 512)), T([12544, 2048], f16)), {})
+cnt: 18, ((T([12544, 2048], f16), T([2048, 512], f16)), {})
+cnt: 18, ((T([2048, 12544], f16, stride=(1, 2048)), T([12544, 512], f16)), {})
+cnt: 18, ((T([12544, 512], f16), T([512, 512], f16)), {})
+cnt: 18, ((T([512, 12544], f16, stride=(1, 512)), T([12544, 512], f16)), {})
+cnt: 18, ((T([12544, 1536], f16), T([1536, 512], f16)), {})
+cnt: 18, ((T([1536, 12544], f16, stride=(1, 1536)), T([12544, 512], f16)), {})
+cnt: 1, ((T([512, 12544], f16, stride=(1, 512)), T([12544, 1024], f16)), {})
+cnt: 1, ((T([12544, 512], f16), T([512, 1024], f16)), {})
+cnt: 2, ((T([50176, 256], f16), T([256, 1024], f16)), {})
+cnt: 2, ((T([256, 50176], f16, stride=(1, 256)), T([50176, 1024], f16)), {})
+cnt: 2, ((T([50176, 1024], f16), T([1024, 256], f16)), {})
+cnt: 2, ((T([1024, 50176], f16, stride=(1, 1024)), T([50176, 256], f16)), {})
+cnt: 2, ((T([50176, 256], f16), T([256, 256], f16)), {})
+cnt: 2, ((T([256, 50176], f16, stride=(1, 256)), T([50176, 256], f16)), {})
+cnt: 2, ((T([50176, 768], f16), T([768, 256], f16)), {})
+cnt: 2, ((T([768, 50176], f16, stride=(1, 768)), T([50176, 256], f16)), {})
+cnt: 1, ((T([256, 50176], f16, stride=(1, 256)), T([50176, 512], f16)), {})
+cnt: 1, ((T([50176, 256], f16), T([256, 512], f16)), {})
+cnt: 2, ((T([200704, 128], f16), T([128, 512], f16)), {})
+cnt: 2, ((T([128, 200704], f16, stride=(1, 128)), T([200704, 512], f16)), {})
+cnt: 2, ((T([200704, 512], f16), T([512, 128], f16)), {})
+cnt: 2, ((T([512, 200704], f16, stride=(1, 512)), T([200704, 128], f16)), {})
+cnt: 2, ((T([200704, 128], f16), T([128, 128], f16)), {})
+cnt: 2, ((T([128, 200704], f16, stride=(1, 128)), T([200704, 128], f16)), {})
+cnt: 2, ((T([200704, 384], f16), T([384, 128], f16)), {})
+cnt: 2, ((T([384, 200704], f16, stride=(1, 384)), T([200704, 128], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([4096, 4, 49, 32], f16, stride=(18816, 32, 384, 1)), 0.1767766952966369), {})
+cnt: 4, ((T([64, 3136, 128], f16), T([64, 1, 1], f16)), {})
+cnt: 2, ((T([1024, 8, 49, 32], f16, stride=(37632, 32, 768, 1)), 0.1767766952966369), {})
+cnt: 8, ((T([64, 784, 256], f16), T([64, 1, 1], f16)), {})
+cnt: 18, ((T([256, 16, 49, 32], f16, stride=(75264, 32, 1536, 1)), 0.1767766952966369), {})
+cnt: 72, ((T([64, 196, 512], f16), T([64, 1, 1], f16)), {})
+cnt: 2, ((T([64, 32, 49, 32], f16, stride=(150528, 32, 3072, 1)), 0.1767766952966369), {})
+cnt: 8, ((T([64, 49, 1024], f16), T([64, 1, 1], f16)), {})
+cnt: 2, ((T([64, 32, 49, 32], f16), 0.1767766952966369), {})
+cnt: 18, ((T([256, 16, 49, 32], f16), 0.1767766952966369), {})
+cnt: 2, ((T([1024, 8, 49, 32], f16), 0.1767766952966369), {})
+cnt: 2, ((T([4096, 4, 49, 32], f16), 0.1767766952966369), {})
+Operator: aten.native_layer_norm.default
+cnt: 1, ((T([64, 3136, 128], f16, stride=(401408, 1, 3136)), [128], T([128], f16), T([128], f16), 1e-05), {})
+cnt: 4, ((T([64, 3136, 128], f16), [128], T([128], f16), T([128], f16), 1e-05), {})
+cnt: 1, ((T([64, 784, 512], f16), [512], T([512], f16), T([512], f16), 1e-05), {})
+cnt: 4, ((T([64, 784, 256], f16), [256], T([256], f16), T([256], f16), 1e-05), {})
+cnt: 1, ((T([64, 196, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-05), {})
+cnt: 36, ((T([64, 196, 512], f16), [512], T([512], f16), T([512], f16), 1e-05), {})
+cnt: 1, ((T([64, 49, 2048], f16), [2048], T([2048], f16), T([2048], f16), 1e-05), {})
+cnt: 5, ((T([64, 49, 1024], f16), [1024], T([1024], f16), T([1024], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 5, ((T([64, 49, 1024], f16), T([64, 49, 1024], f16), [1024], T([64, 49, 1], f32), T([64, 49, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+cnt: 1, ((T([64, 49, 2048], f16), T([64, 49, 2048], f16), [2048], T([64, 49, 1], f32), T([64, 49, 1], f32), T([2048], f16), T([2048], f16), [True, True, True]), {})
+cnt: 36, ((T([64, 196, 512], f16), T([64, 196, 512], f16), [512], T([64, 196, 1], f32), T([64, 196, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+cnt: 1, ((T([64, 196, 1024], f16), T([64, 196, 1024], f16), [1024], T([64, 196, 1], f32), T([64, 196, 1], f32), T([1024], f16), T([1024], f16), [True, True, True]), {})
+cnt: 4, ((T([64, 784, 256], f16), T([64, 784, 256], f16), [256], T([64, 784, 1], f32), T([64, 784, 1], f32), T([256], f16), T([256], f16), [True, True, True]), {})
+cnt: 1, ((T([64, 784, 512], f16), T([64, 784, 512], f16), [512], T([64, 784, 1], f32), T([64, 784, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+cnt: 4, ((T([64, 3136, 128], f16), T([64, 3136, 128], f16), [128], T([64, 3136, 1], f32), T([64, 3136, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+cnt: 1, ((T([64, 3136, 128], f16), T([64, 3136, 128], f16, stride=(401408, 1, 3136)), [128], T([64, 3136, 1], f32), T([64, 3136, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+Operator: aten.new_empty.default
+cnt: 2, ((T([64, 3136, 128], f16), [64, 1, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 4, ((T([64, 784, 256], f16), [64, 1, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 36, ((T([64, 196, 512], f16), [64, 1, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 4, ((T([64, 49, 1024], f16), [64, 1, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+Operator: aten.new_zeros.default
+cnt: 2, ((T([2401, 32], f16, stride=(1, 2401)), [169, 32]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 18, ((T([2401, 16], f16, stride=(1, 2401)), [169, 16]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 2, ((T([2401, 8], f16, stride=(1, 2401)), [169, 8]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 2, ((T([2401, 4], f16, stride=(1, 2401)), [169, 4]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.roll.default
+cnt: 1, ((T([64, 56, 56, 128], f16), [-3, -3], [1, 2]), {})
+cnt: 1, ((T([64, 56, 56, 128], f16), [3, 3], [1, 2]), {})
+cnt: 1, ((T([64, 28, 28, 256], f16), [-3, -3], [1, 2]), {})
+cnt: 1, ((T([64, 28, 28, 256], f16), [3, 3], [1, 2]), {})
+cnt: 9, ((T([64, 14, 14, 512], f16), [-3, -3], [1, 2]), {})
+cnt: 9, ((T([64, 14, 14, 512], f16), [3, 3], [1, 2]), {})
+cnt: 9, ((T([64, 14, 14, 512], f16), [-3, -3], [2, 1]), {})
+cnt: 9, ((T([64, 14, 14, 512], f16), [3, 3], [2, 1]), {})
+cnt: 1, ((T([64, 28, 28, 256], f16), [-3, -3], [2, 1]), {})
+cnt: 1, ((T([64, 28, 28, 256], f16), [3, 3], [2, 1]), {})
+cnt: 1, ((T([64, 56, 56, 128], f16), [-3, -3], [2, 1]), {})
+cnt: 1, ((T([64, 56, 56, 128], f16), [3, 3], [2, 1]), {})
+Operator: aten.slice_backward.default
+cnt: 4, ((T([64, 7, 7, 512], f16, stride=(100352, 14336, 2048, 1)), [64, 7, 7, 512], 3, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([64, 7, 7, 512], f16), [64, 7, 14, 512], 2, 1, 9223372036854775807, 2), {})
+cnt: 2, ((T([64, 7, 14, 512], f16), [64, 14, 14, 512], 1, 1, 9223372036854775807, 2), {})
+cnt: 4, ((T([64, 14, 14, 512], f16), [64, 14, 14, 512], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([64, 7, 14, 512], f16), [64, 14, 14, 512], 1, 0, 9223372036854775807, 2), {})
+cnt: 2, ((T([64, 7, 7, 512], f16), [64, 7, 14, 512], 2, 0, 9223372036854775807, 2), {})
+cnt: 4, ((T([64, 14, 14, 256], f16, stride=(200704, 14336, 1024, 1)), [64, 14, 14, 256], 3, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([64, 14, 14, 256], f16), [64, 14, 28, 256], 2, 1, 9223372036854775807, 2), {})
+cnt: 2, ((T([64, 14, 28, 256], f16), [64, 28, 28, 256], 1, 1, 9223372036854775807, 2), {})
+cnt: 4, ((T([64, 28, 28, 256], f16), [64, 28, 28, 256], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([64, 14, 28, 256], f16), [64, 28, 28, 256], 1, 0, 9223372036854775807, 2), {})
+cnt: 2, ((T([64, 14, 14, 256], f16), [64, 14, 28, 256], 2, 0, 9223372036854775807, 2), {})
+cnt: 4, ((T([64, 28, 28, 128], f16, stride=(401408, 14336, 512, 1)), [64, 28, 28, 128], 3, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([64, 28, 28, 128], f16), [64, 28, 56, 128], 2, 1, 9223372036854775807, 2), {})
+cnt: 2, ((T([64, 28, 56, 128], f16), [64, 56, 56, 128], 1, 1, 9223372036854775807, 2), {})
+cnt: 4, ((T([64, 56, 56, 128], f16), [64, 56, 56, 128], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([64, 28, 56, 128], f16), [64, 56, 56, 128], 1, 0, 9223372036854775807, 2), {})
+cnt: 2, ((T([64, 28, 28, 128], f16), [64, 28, 56, 128], 2, 0, 9223372036854775807, 2), {})
+Operator: aten.stack.default
+cnt: 2, (([T([64, 32, 49, 32], f16), T([64, 32, 49, 32], f16, stride=(50176, 1568, 1, 49)), T([64, 32, 49, 32], f16)],), {})
+cnt: 18, (([T([256, 16, 49, 32], f16), T([256, 16, 49, 32], f16, stride=(25088, 1568, 1, 49)), T([256, 16, 49, 32], f16)],), {})
+cnt: 2, (([T([1024, 8, 49, 32], f16), T([1024, 8, 49, 32], f16, stride=(12544, 1568, 1, 49)), T([1024, 8, 49, 32], f16)],), {})
+cnt: 2, (([T([4096, 4, 49, 32], f16), T([4096, 4, 49, 32], f16, stride=(6272, 1568, 1, 49)), T([4096, 4, 49, 32], f16)],), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 4, ((T([3136, 1024], f16), [0], True), {})
+cnt: 2, ((T([3136, 4096], f16), [0], True), {})
+cnt: 2, ((T([64, 32, 49, 49], f16), [0], True), {})
+cnt: 2, ((T([3136, 3072], f16), [0], True), {})
+cnt: 36, ((T([12544, 512], f16), [0], True), {})
+cnt: 18, ((T([12544, 2048], f16), [0], True), {})
+cnt: 18, ((T([256, 16, 49, 49], f16), [0], True), {})
+cnt: 18, ((T([12544, 1536], f16), [0], True), {})
+cnt: 4, ((T([50176, 256], f16), [0], True), {})
+cnt: 2, ((T([50176, 1024], f16), [0], True), {})
+cnt: 2, ((T([1024, 8, 49, 49], f16), [0], True), {})
+cnt: 2, ((T([50176, 768], f16), [0], True), {})
+cnt: 4, ((T([200704, 128], f16), [0], True), {})
+cnt: 2, ((T([200704, 512], f16), [0], True), {})
+cnt: 2, ((T([4096, 4, 49, 49], f16), [0], True), {})
+cnt: 2, ((T([200704, 384], f16), [0], True), {})
+Operator: aten.unbind.int
+cnt: 2, ((T([3, 4096, 4, 49, 32], f16, stride=(128, 18816, 32, 384, 1)),), {})
+cnt: 2, ((T([3, 1024, 8, 49, 32], f16, stride=(256, 37632, 32, 768, 1)),), {})
+cnt: 18, ((T([3, 256, 16, 49, 32], f16, stride=(512, 75264, 32, 1536, 1)),), {})
+cnt: 2, ((T([3, 64, 32, 49, 32], f16, stride=(1024, 150528, 32, 3072, 1)),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/swsl_resnext101_32x16d_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/swsl_resnext101_32x16d_training.txt
new file mode 100644
index 0000000000000..58d92f4b561ca
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/swsl_resnext101_32x16d_training.txt
@@ -0,0 +1,143 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([32, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([32, 1000], f16), T([32, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 2, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16)), {})
+cnt: 23, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16)), {})
+cnt: 4, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16)), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16)), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 64, 56, 56], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 104, ((T([], i64), 1), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16)), {})
+cnt: 4, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16)), {})
+cnt: 23, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16)), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([32, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([64, 3, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([512, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 512, 56, 56], f16), T([512, 16, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 3, ((T([32, 512, 56, 56], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 256, 56, 56], f16), T([512, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([1024, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1024, 56, 56], f16), T([1024, 32, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 4, ((T([32, 1024, 28, 28], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([512, 256, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 512, 28, 28], f16), T([1024, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 1024, 28, 28], f16), T([1024, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([2048, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 2048, 28, 28], f16), T([2048, 64, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 23, ((T([32, 2048, 14, 14], f16), T([1024, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([1024, 512, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 22, ((T([32, 1024, 14, 14], f16), T([2048, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 22, ((T([32, 2048, 14, 14], f16), T([2048, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([4096, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 4096, 14, 14], f16), T([4096, 128, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 3, ((T([32, 4096, 7, 7], f16), T([2048, 4096, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([2048, 1024, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 2048, 7, 7], f16), T([4096, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 4096, 7, 7], f16), T([4096, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([32, 2048, 7, 7], f16), T([32, 4096, 7, 7], f16), T([2048, 4096, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 4096, 7, 7], f16), T([32, 4096, 7, 7], f16), T([4096, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 2, ((T([32, 4096, 7, 7], f16), T([32, 2048, 7, 7], f16), T([4096, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16), T([32, 1024, 14, 14], f16), T([2048, 1024, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 4096, 7, 7], f16), T([32, 4096, 14, 14], f16), T([4096, 128, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([32, 4096, 14, 14], f16), T([32, 1024, 14, 14], f16), T([4096, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 23, ((T([32, 1024, 14, 14], f16), T([32, 2048, 14, 14], f16), T([1024, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 22, ((T([32, 2048, 14, 14], f16), T([32, 2048, 14, 14], f16), T([2048, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 22, ((T([32, 2048, 14, 14], f16), T([32, 1024, 14, 14], f16), T([2048, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 512, 28, 28], f16), T([1024, 512, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 2048, 14, 14], f16), T([32, 2048, 28, 28], f16), T([2048, 64, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([32, 2048, 28, 28], f16), T([32, 512, 28, 28], f16), T([2048, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([32, 512, 28, 28], f16), T([32, 1024, 28, 28], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 1024, 28, 28], f16), T([32, 1024, 28, 28], f16), T([1024, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 3, ((T([32, 1024, 28, 28], f16), T([32, 512, 28, 28], f16), T([1024, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 256, 56, 56], f16), T([512, 256, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 28, 28], f16), T([32, 1024, 56, 56], f16), T([1024, 32, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 56, 56], f16), T([32, 256, 56, 56], f16), T([1024, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), T([32, 512, 56, 56], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 512, 56, 56], f16), T([32, 512, 56, 56], f16), T([512, 16, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 2, ((T([32, 512, 56, 56], f16), T([32, 256, 56, 56], f16), T([512, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([32, 64, 56, 56], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 56, 56], f16), T([32, 64, 56, 56], f16), T([512, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 3, 224, 224], f16), T([64, 3, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([32, 2048, 7, 7], f16, stride=(2048, 1, 0, 0)), 49), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([32], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([32, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([32, 64, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([32, 2048, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([32, 1000], f16), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 32], f16, stride=(1, 1000)), T([32, 2048], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([32, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+cnt: 6, ((T([32, 512, 56, 56], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([32, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 1024, 56, 56], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([32, 1024, 28, 28], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([32, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 2048, 28, 28], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 1e-05), {})
+cnt: 45, ((T([32, 2048, 14, 14], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 1e-05), {})
+cnt: 24, ((T([32, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 4096, 14, 14], f16), T([4096], f16), T([4096], f16), T([4096], f16), T([4096], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([32, 4096, 7, 7], f16), T([4096], f16), T([4096], f16), T([4096], f16), T([4096], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([32, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 4, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([32, 4096, 7, 7], f16), T([32, 4096, 7, 7], f16), T([4096], f16), T([4096], f16), T([4096], f16), T([4096], f32), T([4096], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 4096, 14, 14], f16), T([32, 4096, 14, 14], f16), T([4096], f16), T([4096], f16), T([4096], f16), T([4096], f32), T([4096], f32), True, 1e-05, [True, True, True]), {})
+cnt: 24, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 45, ((T([32, 2048, 14, 14], f16), T([32, 2048, 14, 14], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 2048, 28, 28], f16), T([32, 2048, 28, 28], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([32, 1024, 28, 28], f16), T([32, 1024, 28, 28], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 1024, 56, 56], f16), T([32, 1024, 56, 56], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), True, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([32, 512, 56, 56], f16), T([32, 512, 56, 56], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([32, 1000], f16), T([32], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([32, 1000], f16), T([32], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([32, 64, 112, 112], f16),), {})
+cnt: 6, ((T([32, 512, 56, 56], f16),), {})
+cnt: 3, ((T([32, 256, 56, 56], f16),), {})
+cnt: 1, ((T([32, 1024, 56, 56], f16),), {})
+cnt: 7, ((T([32, 1024, 28, 28], f16),), {})
+cnt: 4, ((T([32, 512, 28, 28], f16),), {})
+cnt: 1, ((T([32, 2048, 28, 28], f16),), {})
+cnt: 45, ((T([32, 2048, 14, 14], f16),), {})
+cnt: 23, ((T([32, 1024, 14, 14], f16),), {})
+cnt: 1, ((T([32, 4096, 14, 14], f16),), {})
+cnt: 5, ((T([32, 4096, 7, 7], f16),), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32, 1000], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 3, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16), 0), {})
+cnt: 5, ((T([32, 4096, 7, 7], f16), T([32, 4096, 7, 7], f16), 0), {})
+cnt: 1, ((T([32, 4096, 14, 14], f16), T([32, 4096, 14, 14], f16), 0), {})
+cnt: 23, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16), 0), {})
+cnt: 45, ((T([32, 2048, 14, 14], f16), T([32, 2048, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 2048, 28, 28], f16), T([32, 2048, 28, 28], f16), 0), {})
+cnt: 4, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16), 0), {})
+cnt: 7, ((T([32, 1024, 28, 28], f16), T([32, 1024, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 1024, 56, 56], f16), T([32, 1024, 56, 56], f16), 0), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16), 0), {})
+cnt: 6, ((T([32, 512, 56, 56], f16), T([32, 512, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/tf_efficientnet_b0_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/tf_efficientnet_b0_training.txt
new file mode 100644
index 0000000000000..b606244e7f83f
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/tf_efficientnet_b0_training.txt
@@ -0,0 +1,312 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 49, ((T([], i64), 1), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16)), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16)), {})
+cnt: 4, ((T([128, 80, 14, 14], f16), T([128, 80, 14, 14], f16)), {})
+cnt: 4, ((T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16)), {})
+cnt: 6, ((T([128, 192, 7, 7], f16), T([128, 192, 7, 7], f16)), {})
+cnt: 4, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16)), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16)), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16)), {})
+cnt: 3, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16)), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16)), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16)), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 144, 28, 28], f16)), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([128, 144, 56, 56], f16)), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 56, 56], f16)), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1280], f16), T([1280, 1000], f16, stride=(1, 1280))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+cnt: 2, ((T([128, 32, 112, 112], f16),), {})
+cnt: 1, ((T([128, 8, 1, 1], f16),), {})
+cnt: 1, ((T([128, 96, 112, 112], f16),), {})
+cnt: 1, ((T([128, 96, 56, 56], f16),), {})
+cnt: 1, ((T([128, 4, 1, 1], f16),), {})
+cnt: 3, ((T([128, 144, 56, 56], f16),), {})
+cnt: 2, ((T([128, 6, 1, 1], f16),), {})
+cnt: 1, ((T([128, 144, 28, 28], f16),), {})
+cnt: 3, ((T([128, 240, 28, 28], f16),), {})
+cnt: 2, ((T([128, 10, 1, 1], f16),), {})
+cnt: 1, ((T([128, 240, 14, 14], f16),), {})
+cnt: 6, ((T([128, 480, 14, 14], f16),), {})
+cnt: 3, ((T([128, 20, 1, 1], f16),), {})
+cnt: 5, ((T([128, 672, 14, 14], f16),), {})
+cnt: 3, ((T([128, 28, 1, 1], f16),), {})
+cnt: 1, ((T([128, 672, 7, 7], f16),), {})
+cnt: 8, ((T([128, 1152, 7, 7], f16),), {})
+cnt: 4, ((T([128, 48, 1, 1], f16),), {})
+cnt: 1, ((T([128, 1280, 7, 7], f16),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 1, ((T([128, 3, 224, 224], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), [1, 2, 1, 2], 0.0), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([128, 672, 14, 14], f16), [1, 2, 1, 2], 0.0), {})
+cnt: 1, ((T([128, 672, 17, 17], f16), [-1, -2, -1, -2]), {})
+cnt: 1, ((T([128, 240, 29, 29], f16), [0, -1, 0, -1]), {})
+cnt: 1, ((T([128, 144, 59, 59], f16), [-1, -2, -1, -2]), {})
+cnt: 1, ((T([128, 96, 113, 113], f16), [0, -1, 0, -1]), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 225, 225], f16), T([32, 3, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([32, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([128, 32, 1, 1], f16), T([8, 32, 1, 1], f16), T([8], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 8, 1, 1], f16), T([32, 8, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([16, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([96, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 113, 113], f16), T([96, 1, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 96), {})
+cnt: 1, ((T([128, 96, 1, 1], f16), T([4, 96, 1, 1], f16), T([4], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 4, 1, 1], f16), T([96, 4, 1, 1], f16), T([96], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([24, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([144, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([144, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 144), {})
+cnt: 2, ((T([128, 144, 1, 1], f16), T([6, 144, 1, 1], f16), T([6], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 6, 1, 1], f16), T([144, 6, 1, 1], f16), T([144], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([24, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 144, 59, 59], f16), T([144, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 144), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([40, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), T([240, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([240, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 240), {})
+cnt: 2, ((T([128, 240, 1, 1], f16), T([10, 240, 1, 1], f16), T([10], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 10, 1, 1], f16), T([240, 10, 1, 1], f16), T([240], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([40, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 240, 29, 29], f16), T([240, 1, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([80, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 80, 14, 14], f16), T([480, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([480, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 480), {})
+cnt: 3, ((T([128, 480, 1, 1], f16), T([20, 480, 1, 1], f16), T([20], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 20, 1, 1], f16), T([480, 20, 1, 1], f16), T([480], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([80, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([480, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 480), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([112, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 112, 14, 14], f16), T([672, 112, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), T([672, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 672), {})
+cnt: 3, ((T([128, 672, 1, 1], f16), T([28, 672, 1, 1], f16), T([28], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 28, 1, 1], f16), T([672, 28, 1, 1], f16), T([672], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), T([112, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 672, 17, 17], f16), T([672, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 672), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([192, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 192, 7, 7], f16), T([1152, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 1152, 7, 7], f16), T([1152, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 1152), {})
+cnt: 4, ((T([128, 1152, 1, 1], f16), T([48, 1152, 1, 1], f16), T([48], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 48, 1, 1], f16), T([1152, 48, 1, 1], f16), T([1152], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 1152, 7, 7], f16), T([192, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), T([1152, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1152), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), T([320, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([1280, 320, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([128, 320, 7, 7], f16), T([1280, 320, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([128, 1152, 7, 7], f16), T([320, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 1152, 1, 1], f16), T([128, 48, 1, 1], f16), T([1152, 48, 1, 1], f16), [1152], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([128, 48, 1, 1], f16), T([128, 1152, 1, 1], f16), T([48, 1152, 1, 1], f16), [48], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16), T([1152, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1152, [True, True, False]), {})
+cnt: 4, ((T([128, 1152, 7, 7], f16), T([128, 192, 7, 7], f16), T([1152, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 192, 7, 7], f16), T([128, 1152, 7, 7], f16), T([192, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16), T([1152, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 1152, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 7, 7], f16), T([128, 672, 7, 7], f16), T([192, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 672, 1, 1], f16), T([128, 28, 1, 1], f16), T([672, 28, 1, 1], f16), [672], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([128, 28, 1, 1], f16), T([128, 672, 1, 1], f16), T([28, 672, 1, 1], f16), [28], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 17, 17], f16), T([672, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 3, ((T([128, 672, 14, 14], f16), T([128, 112, 14, 14], f16), T([672, 112, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 112, 14, 14], f16), T([128, 672, 14, 14], f16), T([112, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16), T([672, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 1, ((T([128, 112, 14, 14], f16), T([128, 480, 14, 14], f16), T([112, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 480, 1, 1], f16), T([128, 20, 1, 1], f16), T([480, 20, 1, 1], f16), [480], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([128, 20, 1, 1], f16), T([128, 480, 1, 1], f16), T([20, 480, 1, 1], f16), [20], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), T([480, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 3, ((T([128, 480, 14, 14], f16), T([128, 80, 14, 14], f16), T([480, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 80, 14, 14], f16), T([128, 480, 14, 14], f16), T([80, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), T([480, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 1, ((T([128, 80, 14, 14], f16), T([128, 240, 14, 14], f16), T([80, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 240, 1, 1], f16), T([128, 10, 1, 1], f16), T([240, 10, 1, 1], f16), [240], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 10, 1, 1], f16), T([128, 240, 1, 1], f16), T([10, 240, 1, 1], f16), [10], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 29, 29], f16), T([240, 1, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 2, ((T([128, 240, 28, 28], f16), T([128, 40, 28, 28], f16), T([240, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([128, 240, 28, 28], f16), T([40, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16), T([240, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([128, 40, 28, 28], f16), T([128, 144, 28, 28], f16), T([40, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 144, 1, 1], f16), T([128, 6, 1, 1], f16), T([144, 6, 1, 1], f16), [144], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 6, 1, 1], f16), T([128, 144, 1, 1], f16), T([6, 144, 1, 1], f16), [6], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 144, 59, 59], f16), T([144, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 144, [True, True, False]), {})
+cnt: 2, ((T([128, 144, 56, 56], f16), T([128, 24, 56, 56], f16), T([144, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 144, 56, 56], f16), T([24, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([128, 144, 56, 56], f16), T([144, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 144, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 96, 56, 56], f16), T([24, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 1, 1], f16), T([128, 4, 1, 1], f16), T([96, 4, 1, 1], f16), [96], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 4, 1, 1], f16), T([128, 96, 1, 1], f16), T([4, 96, 1, 1], f16), [4], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 113, 113], f16), T([96, 1, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 96, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([128, 16, 112, 112], f16), T([96, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 32, 112, 112], f16), T([16, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 1, 1], f16), T([128, 8, 1, 1], f16), T([32, 8, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 8, 1, 1], f16), T([128, 32, 1, 1], f16), T([8, 32, 1, 1], f16), [8], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), T([32, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 3, 225, 225], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 1280, 7, 7], f16, stride=(1280, 1, 0, 0)), 49), {})
+cnt: 4, ((T([128, 1152, 7, 7], f16, stride=(1152, 1, 0, 0)), 49), {})
+cnt: 1, ((T([128, 672, 7, 7], f16, stride=(672, 1, 0, 0)), 49), {})
+cnt: 2, ((T([128, 672, 14, 14], f16, stride=(672, 1, 0, 0)), 196), {})
+cnt: 3, ((T([128, 480, 14, 14], f16, stride=(480, 1, 0, 0)), 196), {})
+cnt: 1, ((T([128, 240, 14, 14], f16, stride=(240, 1, 0, 0)), 196), {})
+cnt: 1, ((T([128, 240, 28, 28], f16, stride=(240, 1, 0, 0)), 784), {})
+cnt: 1, ((T([128, 144, 28, 28], f16, stride=(144, 1, 0, 0)), 784), {})
+cnt: 1, ((T([128, 144, 56, 56], f16, stride=(144, 1, 0, 0)), 3136), {})
+cnt: 1, ((T([128, 96, 56, 56], f16, stride=(96, 1, 0, 0)), 3136), {})
+cnt: 1, ((T([128, 32, 112, 112], f16, stride=(32, 1, 0, 0)), 12544), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 32, 112, 112], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), [2, 3], True), {})
+cnt: 3, ((T([128, 480, 14, 14], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), [2, 3], True), {})
+cnt: 4, ((T([128, 1152, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 1280, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 1280], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 1280], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([128, 32, 112, 112], f16), T([128, 32, 1, 1], f16)), {})
+cnt: 2, ((T([128, 96, 56, 56], f16), T([128, 96, 1, 1], f16)), {})
+cnt: 2, ((T([128, 144, 56, 56], f16), T([128, 144, 1, 1], f16)), {})
+cnt: 2, ((T([128, 144, 28, 28], f16), T([128, 144, 1, 1], f16)), {})
+cnt: 2, ((T([128, 240, 28, 28], f16), T([128, 240, 1, 1], f16)), {})
+cnt: 2, ((T([128, 240, 14, 14], f16), T([128, 240, 1, 1], f16)), {})
+cnt: 6, ((T([128, 480, 14, 14], f16), T([128, 480, 1, 1], f16)), {})
+cnt: 4, ((T([128, 672, 14, 14], f16), T([128, 672, 1, 1], f16)), {})
+cnt: 2, ((T([128, 672, 7, 7], f16), T([128, 672, 1, 1], f16)), {})
+cnt: 8, ((T([128, 1152, 7, 7], f16), T([128, 1152, 1, 1], f16)), {})
+cnt: 4, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16)), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16)), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16)), {})
+cnt: 3, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16)), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16)), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16)), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 144, 28, 28], f16)), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), T([128, 144, 56, 56], f16)), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 56, 56], f16)), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 0.001), {})
+cnt: 3, ((T([128, 144, 56, 56], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f16), True, 0.1, 0.001), {})
+cnt: 3, ((T([128, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 0.001), {})
+cnt: 3, ((T([128, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f16), True, 0.1, 0.001), {})
+cnt: 6, ((T([128, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), True, 0.1, 0.001), {})
+cnt: 3, ((T([128, 112, 14, 14], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f16), True, 0.1, 0.001), {})
+cnt: 5, ((T([128, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), True, 0.1, 0.001), {})
+cnt: 4, ((T([128, 192, 7, 7], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 0.001), {})
+cnt: 8, ((T([128, 1152, 7, 7], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f16), True, 0.1, 0.001), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([128, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f32), T([1280], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 320, 7, 7], f16), T([128, 320, 7, 7], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), True, 0.001, [True, True, True]), {})
+cnt: 8, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f32), T([1152], f32), True, 0.001, [True, True, True]), {})
+cnt: 4, ((T([128, 192, 7, 7], f16), T([128, 192, 7, 7], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 0.001, [True, True, True]), {})
+cnt: 5, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 0.001, [True, True, True]), {})
+cnt: 3, ((T([128, 112, 14, 14], f16), T([128, 112, 14, 14], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f32), T([112], f32), True, 0.001, [True, True, True]), {})
+cnt: 6, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), True, 0.001, [True, True, True]), {})
+cnt: 3, ((T([128, 80, 14, 14], f16), T([128, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 0.001, [True, True, True]), {})
+cnt: 3, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([128, 40, 28, 28], f16), T([128, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 144, 28, 28], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), True, 0.001, [True, True, True]), {})
+cnt: 3, ((T([128, 144, 56, 56], f16), T([128, 144, 56, 56], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([128, 96, 112, 112], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([128, 16, 112, 112], f16), T([128, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 0.001, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([128, 32, 1, 1], f16),), {})
+cnt: 1, ((T([128, 96, 1, 1], f16),), {})
+cnt: 2, ((T([128, 144, 1, 1], f16),), {})
+cnt: 2, ((T([128, 240, 1, 1], f16),), {})
+cnt: 3, ((T([128, 480, 1, 1], f16),), {})
+cnt: 3, ((T([128, 672, 1, 1], f16),), {})
+cnt: 4, ((T([128, 1152, 1, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 4, ((T([128, 1152, 1, 1], f16), T([128, 1152, 1, 1], f16)), {})
+cnt: 3, ((T([128, 672, 1, 1], f16), T([128, 672, 1, 1], f16)), {})
+cnt: 3, ((T([128, 480, 1, 1], f16), T([128, 480, 1, 1], f16)), {})
+cnt: 2, ((T([128, 240, 1, 1], f16), T([128, 240, 1, 1], f16)), {})
+cnt: 2, ((T([128, 144, 1, 1], f16), T([128, 144, 1, 1], f16)), {})
+cnt: 1, ((T([128, 96, 1, 1], f16), T([128, 96, 1, 1], f16)), {})
+cnt: 1, ((T([128, 32, 1, 1], f16), T([128, 32, 1, 1], f16)), {})
+Operator: aten.silu_.default
+cnt: 2, ((T([128, 32, 112, 112], f16),), {})
+cnt: 1, ((T([128, 8, 1, 1], f16),), {})
+cnt: 1, ((T([128, 96, 112, 112], f16),), {})
+cnt: 1, ((T([128, 96, 56, 56], f16),), {})
+cnt: 1, ((T([128, 4, 1, 1], f16),), {})
+cnt: 3, ((T([128, 144, 56, 56], f16),), {})
+cnt: 2, ((T([128, 6, 1, 1], f16),), {})
+cnt: 1, ((T([128, 144, 28, 28], f16),), {})
+cnt: 3, ((T([128, 240, 28, 28], f16),), {})
+cnt: 2, ((T([128, 10, 1, 1], f16),), {})
+cnt: 1, ((T([128, 240, 14, 14], f16),), {})
+cnt: 6, ((T([128, 480, 14, 14], f16),), {})
+cnt: 3, ((T([128, 20, 1, 1], f16),), {})
+cnt: 5, ((T([128, 672, 14, 14], f16),), {})
+cnt: 3, ((T([128, 28, 1, 1], f16),), {})
+cnt: 1, ((T([128, 672, 7, 7], f16),), {})
+cnt: 8, ((T([128, 1152, 7, 7], f16),), {})
+cnt: 4, ((T([128, 48, 1, 1], f16),), {})
+cnt: 1, ((T([128, 1280, 7, 7], f16),), {})
+Operator: aten.silu_backward.default
+cnt: 1, ((T([128, 1280, 7, 7], f16), T([128, 1280, 7, 7], f16)), {})
+cnt: 4, ((T([128, 48, 1, 1], f16), T([128, 48, 1, 1], f16)), {})
+cnt: 8, ((T([128, 1152, 7, 7], f16), T([128, 1152, 7, 7], f16)), {})
+cnt: 3, ((T([128, 28, 1, 1], f16), T([128, 28, 1, 1], f16)), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), T([128, 672, 7, 7], f16)), {})
+cnt: 5, ((T([128, 672, 14, 14], f16), T([128, 672, 14, 14], f16)), {})
+cnt: 3, ((T([128, 20, 1, 1], f16), T([128, 20, 1, 1], f16)), {})
+cnt: 6, ((T([128, 480, 14, 14], f16), T([128, 480, 14, 14], f16)), {})
+cnt: 2, ((T([128, 10, 1, 1], f16), T([128, 10, 1, 1], f16)), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), T([128, 240, 14, 14], f16)), {})
+cnt: 3, ((T([128, 240, 28, 28], f16), T([128, 240, 28, 28], f16)), {})
+cnt: 2, ((T([128, 6, 1, 1], f16), T([128, 6, 1, 1], f16)), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), T([128, 144, 28, 28], f16)), {})
+cnt: 3, ((T([128, 144, 56, 56], f16), T([128, 144, 56, 56], f16)), {})
+cnt: 1, ((T([128, 4, 1, 1], f16), T([128, 4, 1, 1], f16)), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), T([128, 96, 56, 56], f16)), {})
+cnt: 1, ((T([128, 96, 112, 112], f16), T([128, 96, 112, 112], f16)), {})
+cnt: 1, ((T([128, 8, 1, 1], f16), T([128, 8, 1, 1], f16)), {})
+cnt: 2, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 4, ((T([128, 1152, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 7, 7], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 672, 14, 14], f16), [2, 3], True), {})
+cnt: 3, ((T([128, 480, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 240, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 240, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 144, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 144, 56, 56], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 96, 56, 56], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), [2, 3], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/tf_mixnet_l_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/tf_mixnet_l_training.txt
new file mode 100644
index 0000000000000..5612bc45879f8
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/tf_mixnet_l_training.txt
@@ -0,0 +1,408 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 58, ((T([], i64), 1), {})
+cnt: 2, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16)), {})
+cnt: 2, ((T([64, 40, 56, 56], f16), T([64, 40, 56, 56], f16)), {})
+cnt: 6, ((T([64, 56, 28, 28], f16), T([64, 56, 28, 28], f16)), {})
+cnt: 6, ((T([64, 104, 14, 14], f16), T([64, 104, 14, 14], f16)), {})
+cnt: 6, ((T([64, 160, 14, 14], f16), T([64, 160, 14, 14], f16)), {})
+cnt: 6, ((T([64, 264, 7, 7], f16), T([64, 264, 7, 7], f16)), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16), T([64, 1584, 7, 7], f16)), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([64, 960, 7, 7], f16)), {})
+cnt: 3, ((T([64, 480, 14, 14], f16), T([64, 480, 14, 14], f16)), {})
+cnt: 4, ((T([64, 624, 14, 14], f16), T([64, 624, 14, 14], f16)), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), T([64, 336, 14, 14], f16)), {})
+cnt: 3, ((T([64, 336, 28, 28], f16), T([64, 336, 28, 28], f16)), {})
+cnt: 1, ((T([64, 240, 28, 28], f16), T([64, 240, 28, 28], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([64, 1536], f16), T([1536, 1000], f16, stride=(1, 1536))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 96, 112, 112], f16), T([64, 96, 112, 112], f16)], 1), {})
+cnt: 1, (([T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16), T([64, 64, 56, 56], f16)], 1), {})
+cnt: 3, (([T([64, 20, 56, 56], f16), T([64, 20, 56, 56], f16)], 1), {})
+cnt: 2, (([T([64, 60, 56, 56], f16), T([64, 60, 56, 56], f16)], 1), {})
+cnt: 1, (([T([64, 60, 28, 28], f16), T([64, 60, 28, 28], f16), T([64, 60, 28, 28], f16), T([64, 60, 28, 28], f16)], 1), {})
+cnt: 12, (([T([64, 168, 28, 28], f16), T([64, 168, 28, 28], f16)], 1), {})
+cnt: 6, (([T([64, 28, 28, 28], f16), T([64, 28, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 112, 14, 14], f16), T([64, 112, 14, 14], f16), T([64, 112, 14, 14], f16)], 1), {})
+cnt: 6, (([T([64, 312, 14, 14], f16), T([64, 312, 14, 14], f16)], 1), {})
+cnt: 6, (([T([64, 156, 14, 14], f16), T([64, 156, 14, 14], f16), T([64, 156, 14, 14], f16), T([64, 156, 14, 14], f16)], 1), {})
+cnt: 6, (([T([64, 52, 14, 14], f16), T([64, 52, 14, 14], f16)], 1), {})
+cnt: 6, (([T([64, 240, 14, 14], f16), T([64, 240, 14, 14], f16)], 1), {})
+cnt: 6, (([T([64, 120, 14, 14], f16), T([64, 120, 14, 14], f16), T([64, 120, 14, 14], f16), T([64, 120, 14, 14], f16)], 1), {})
+cnt: 6, (([T([64, 80, 14, 14], f16), T([64, 80, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 240, 7, 7], f16), T([64, 240, 7, 7], f16), T([64, 240, 7, 7], f16), T([64, 240, 7, 7], f16)], 1), {})
+cnt: 6, (([T([64, 396, 7, 7], f16), T([64, 396, 7, 7], f16), T([64, 396, 7, 7], f16), T([64, 396, 7, 7], f16)], 1), {})
+cnt: 3, (([T([64, 132, 7, 7], f16), T([64, 132, 7, 7], f16)], 1), {})
+cnt: 3, (([T([64, 792, 7, 7], f16), T([64, 792, 7, 7], f16)], 1), {})
+cnt: 1, (([T([64, 240, 14, 14], f16), T([64, 240, 14, 14], f16), T([64, 240, 14, 14], f16), T([64, 240, 14, 14], f16)], 1), {})
+cnt: 1, (([T([64, 112, 28, 28], f16), T([64, 112, 28, 28], f16), T([64, 112, 28, 28], f16)], 1), {})
+cnt: 1, (([T([64, 60, 56, 56], f16), T([64, 60, 56, 56], f16), T([64, 60, 56, 56], f16), T([64, 60, 56, 56], f16)], 1), {})
+cnt: 1, (([T([64, 96, 56, 56], f16), T([64, 96, 56, 56], f16)], 1), {})
+cnt: 1, (([T([64, 64, 112, 112], f16), T([64, 64, 112, 112], f16), T([64, 64, 112, 112], f16)], 1), {})
+cnt: 1, (([T([64, 16, 112, 112], f16), T([64, 16, 112, 112], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+cnt: 1, ((T([64, 240, 56, 56], f16),), {})
+cnt: 1, ((T([64, 240, 28, 28], f16),), {})
+cnt: 1, ((T([64, 20, 1, 1], f16),), {})
+cnt: 7, ((T([64, 336, 28, 28], f16),), {})
+cnt: 3, ((T([64, 28, 1, 1], f16),), {})
+cnt: 1, ((T([64, 336, 14, 14], f16),), {})
+cnt: 1, ((T([64, 14, 1, 1], f16),), {})
+cnt: 8, ((T([64, 624, 14, 14], f16),), {})
+cnt: 3, ((T([64, 26, 1, 1], f16),), {})
+cnt: 1, ((T([64, 52, 1, 1], f16),), {})
+cnt: 6, ((T([64, 480, 14, 14], f16),), {})
+cnt: 4, ((T([64, 80, 1, 1], f16),), {})
+cnt: 1, ((T([64, 960, 14, 14], f16),), {})
+cnt: 1, ((T([64, 960, 7, 7], f16),), {})
+cnt: 6, ((T([64, 1584, 7, 7], f16),), {})
+cnt: 3, ((T([64, 132, 1, 1], f16),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 1, ((T([64, 3, 224, 224], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([64, 64, 112, 112], f16, stride=(2408448, 12544, 112, 1)), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([64, 64, 112, 112], f16, stride=(2408448, 12544, 112, 1)), [1, 2, 1, 2], 0.0), {})
+cnt: 1, ((T([64, 64, 112, 112], f16, stride=(2408448, 12544, 112, 1)), [2, 3, 2, 3], 0.0), {})
+cnt: 1, ((T([64, 60, 56, 56], f16, stride=(752640, 3136, 56, 1)), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([64, 60, 56, 56], f16, stride=(752640, 3136, 56, 1)), [1, 2, 1, 2], 0.0), {})
+cnt: 1, ((T([64, 60, 56, 56], f16, stride=(752640, 3136, 56, 1)), [2, 3, 2, 3], 0.0), {})
+cnt: 1, ((T([64, 60, 56, 56], f16, stride=(752640, 3136, 56, 1)), [3, 4, 3, 4], 0.0), {})
+cnt: 1, ((T([64, 112, 28, 28], f16, stride=(263424, 784, 28, 1)), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([64, 112, 28, 28], f16, stride=(263424, 784, 28, 1)), [1, 2, 1, 2], 0.0), {})
+cnt: 1, ((T([64, 112, 28, 28], f16, stride=(263424, 784, 28, 1)), [2, 3, 2, 3], 0.0), {})
+cnt: 1, ((T([64, 240, 14, 14], f16, stride=(188160, 196, 14, 1)), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([64, 240, 14, 14], f16, stride=(188160, 196, 14, 1)), [1, 2, 1, 2], 0.0), {})
+cnt: 1, ((T([64, 240, 14, 14], f16, stride=(188160, 196, 14, 1)), [2, 3, 2, 3], 0.0), {})
+cnt: 1, ((T([64, 240, 14, 14], f16, stride=(188160, 196, 14, 1)), [3, 4, 3, 4], 0.0), {})
+cnt: 1, ((T([64, 240, 21, 21], f16), [-3, -4, -3, -4]), {})
+cnt: 1, ((T([64, 240, 19, 19], f16), [-2, -3, -2, -3]), {})
+cnt: 1, ((T([64, 240, 17, 17], f16), [-1, -2, -1, -2]), {})
+cnt: 1, ((T([64, 240, 15, 15], f16), [0, -1, 0, -1]), {})
+cnt: 1, ((T([64, 112, 33, 33], f16), [-2, -3, -2, -3]), {})
+cnt: 1, ((T([64, 112, 31, 31], f16), [-1, -2, -1, -2]), {})
+cnt: 1, ((T([64, 112, 29, 29], f16), [0, -1, 0, -1]), {})
+cnt: 1, ((T([64, 60, 63, 63], f16), [-3, -4, -3, -4]), {})
+cnt: 1, ((T([64, 60, 61, 61], f16), [-2, -3, -2, -3]), {})
+cnt: 1, ((T([64, 60, 59, 59], f16), [-1, -2, -1, -2]), {})
+cnt: 1, ((T([64, 60, 57, 57], f16), [0, -1, 0, -1]), {})
+cnt: 1, ((T([64, 64, 117, 117], f16), [-2, -3, -2, -3]), {})
+cnt: 1, ((T([64, 64, 115, 115], f16), [-1, -2, -1, -2]), {})
+cnt: 1, ((T([64, 64, 113, 113], f16), [0, -1, 0, -1]), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 225, 225], f16), T([32, 3, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([32, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([32, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 16, 112, 112], f16, stride=(401408, 12544, 112, 1)), T([96, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 113, 113], f16), T([64, 1, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([64, 64, 115, 115], f16), T([64, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([64, 64, 117, 117], f16), T([64, 1, 7, 7], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 64), {})
+cnt: 2, ((T([64, 96, 56, 56], f16, stride=(602112, 3136, 56, 1)), T([20, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 20, 56, 56], f16, stride=(125440, 3136, 56, 1)), T([60, 20, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 120, 56, 56], f16), T([120, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 120), {})
+cnt: 2, ((T([64, 60, 56, 56], f16, stride=(376320, 3136, 56, 1)), T([20, 60, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 40, 56, 56], f16), T([240, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 60, 57, 57], f16), T([60, 1, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 60), {})
+cnt: 1, ((T([64, 60, 59, 59], f16), T([60, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 60), {})
+cnt: 1, ((T([64, 60, 61, 61], f16), T([60, 1, 7, 7], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 60), {})
+cnt: 1, ((T([64, 60, 63, 63], f16), T([60, 1, 9, 9], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 60), {})
+cnt: 1, ((T([64, 240, 1, 1], f16), T([20, 240, 1, 1], f16), T([20], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 20, 1, 1], f16), T([240, 20, 1, 1], f16), T([240], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 240, 28, 28], f16), T([56, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 28, 28, 28], f16, stride=(43904, 784, 28, 1)), T([168, 28, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([168, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 168), {})
+cnt: 3, ((T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([168, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 168), {})
+cnt: 3, ((T([64, 336, 1, 1], f16), T([28, 336, 1, 1], f16), T([28], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 28, 1, 1], f16), T([336, 28, 1, 1], f16), T([336], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([28, 168, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 56, 28, 28], f16), T([336, 56, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 112, 29, 29], f16), T([112, 1, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 112), {})
+cnt: 1, ((T([64, 112, 31, 31], f16), T([112, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 112), {})
+cnt: 1, ((T([64, 112, 33, 33], f16), T([112, 1, 7, 7], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 112), {})
+cnt: 1, ((T([64, 336, 1, 1], f16), T([14, 336, 1, 1], f16), T([14], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 14, 1, 1], f16), T([336, 14, 1, 1], f16), T([336], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), T([104, 336, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 52, 14, 14], f16, stride=(20384, 196, 14, 1)), T([312, 52, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 156), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 156), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 156), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 9, 9], f16), None, [1, 1], [4, 4], [1, 1], False, [0, 0], 156), {})
+cnt: 3, ((T([64, 624, 1, 1], f16), T([26, 624, 1, 1], f16), T([26], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 26, 1, 1], f16), T([624, 26, 1, 1], f16), T([624], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 312, 14, 14], f16, stride=(122304, 196, 14, 1)), T([52, 312, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 104, 14, 14], f16), T([624, 104, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 624, 14, 14], f16), T([624, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 624), {})
+cnt: 1, ((T([64, 624, 1, 1], f16), T([52, 624, 1, 1], f16), T([52], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 52, 1, 1], f16), T([624, 52, 1, 1], f16), T([624], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 624, 14, 14], f16), T([160, 624, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 80, 14, 14], f16, stride=(31360, 196, 14, 1)), T([240, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 120), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 120), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 120), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 9, 9], f16), None, [1, 1], [4, 4], [1, 1], False, [0, 0], 120), {})
+cnt: 3, ((T([64, 480, 1, 1], f16), T([80, 480, 1, 1], f16), T([80], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 80, 1, 1], f16), T([480, 80, 1, 1], f16), T([480], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 240, 14, 14], f16, stride=(94080, 196, 14, 1)), T([80, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 160, 14, 14], f16), T([960, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 240, 15, 15], f16), T([240, 1, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([64, 240, 17, 17], f16), T([240, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([64, 240, 19, 19], f16), T([240, 1, 7, 7], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([64, 240, 21, 21], f16), T([240, 1, 9, 9], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([64, 960, 1, 1], f16), T([80, 960, 1, 1], f16), T([80], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 80, 1, 1], f16), T([960, 80, 1, 1], f16), T([960], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([264, 960, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 264, 7, 7], f16), T([1584, 264, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 396), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 396), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 396), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 9, 9], f16), None, [1, 1], [4, 4], [1, 1], False, [0, 0], 396), {})
+cnt: 3, ((T([64, 1584, 1, 1], f16), T([132, 1584, 1, 1], f16), T([132], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 132, 1, 1], f16), T([1584, 132, 1, 1], f16), T([1584], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([64, 792, 7, 7], f16, stride=(77616, 49, 7, 1)), T([132, 792, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 264, 7, 7], f16), T([1536, 264, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 1536, 7, 7], f16), T([64, 264, 7, 7], f16), T([1536, 264, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([64, 132, 7, 7], f16, stride=(12936, 49, 7, 1)), T([64, 792, 7, 7], f16, stride=(77616, 49, 7, 1)), T([132, 792, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 1584, 1, 1], f16), T([64, 132, 1, 1], f16), T([1584, 132, 1, 1], f16), [1584], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 132, 1, 1], f16), T([64, 1584, 1, 1], f16), T([132, 1584, 1, 1], f16), [132], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 9, 9], f16), [0], [1, 1], [4, 4], [1, 1], False, [0, 0], 396, [True, True, False]), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 396, [True, True, False]), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 396, [True, True, False]), {})
+cnt: 3, ((T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([64, 396, 7, 7], f16, stride=(77616, 49, 7, 1)), T([396, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 396, [True, True, False]), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16), T([64, 264, 7, 7], f16), T([1584, 264, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 264, 7, 7], f16), T([64, 960, 7, 7], f16), T([264, 960, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 960, 1, 1], f16), T([64, 80, 1, 1], f16), T([960, 80, 1, 1], f16), [960], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 80, 1, 1], f16), T([64, 960, 1, 1], f16), T([80, 960, 1, 1], f16), [80], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 240, 7, 7], f16, stride=(47040, 49, 7, 1)), T([64, 240, 21, 21], f16), T([240, 1, 9, 9], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([64, 240, 7, 7], f16, stride=(47040, 49, 7, 1)), T([64, 240, 19, 19], f16), T([240, 1, 7, 7], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([64, 240, 7, 7], f16, stride=(47040, 49, 7, 1)), T([64, 240, 17, 17], f16), T([240, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([64, 240, 7, 7], f16, stride=(47040, 49, 7, 1)), T([64, 240, 15, 15], f16), T([240, 1, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([64, 960, 14, 14], f16), T([64, 160, 14, 14], f16), T([960, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([64, 80, 14, 14], f16, stride=(31360, 196, 14, 1)), T([64, 240, 14, 14], f16, stride=(94080, 196, 14, 1)), T([80, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 480, 1, 1], f16), T([64, 80, 1, 1], f16), T([480, 80, 1, 1], f16), [480], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 80, 1, 1], f16), T([64, 480, 1, 1], f16), T([80, 480, 1, 1], f16), [80], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 9, 9], f16), [0], [1, 1], [4, 4], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 3, ((T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([64, 120, 14, 14], f16, stride=(94080, 196, 14, 1)), T([120, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 6, ((T([64, 240, 14, 14], f16, stride=(94080, 196, 14, 1)), T([64, 80, 14, 14], f16, stride=(31360, 196, 14, 1)), T([240, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 160, 14, 14], f16), T([64, 624, 14, 14], f16), T([160, 624, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 624, 1, 1], f16), T([64, 52, 1, 1], f16), T([624, 52, 1, 1], f16), [624], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 52, 1, 1], f16), T([64, 624, 1, 1], f16), T([52, 624, 1, 1], f16), [52], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 624, 14, 14], f16), T([64, 624, 14, 14], f16), T([624, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 624, [True, True, False]), {})
+cnt: 1, ((T([64, 624, 14, 14], f16), T([64, 104, 14, 14], f16), T([624, 104, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([64, 52, 14, 14], f16, stride=(20384, 196, 14, 1)), T([64, 312, 14, 14], f16, stride=(122304, 196, 14, 1)), T([52, 312, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 624, 1, 1], f16), T([64, 26, 1, 1], f16), T([624, 26, 1, 1], f16), [624], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 26, 1, 1], f16), T([64, 624, 1, 1], f16), T([26, 624, 1, 1], f16), [26], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 9, 9], f16), [0], [1, 1], [4, 4], [1, 1], False, [0, 0], 156, [True, True, False]), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 156, [True, True, False]), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 156, [True, True, False]), {})
+cnt: 3, ((T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([64, 156, 14, 14], f16, stride=(122304, 196, 14, 1)), T([156, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 156, [True, True, False]), {})
+cnt: 6, ((T([64, 312, 14, 14], f16, stride=(122304, 196, 14, 1)), T([64, 52, 14, 14], f16, stride=(20384, 196, 14, 1)), T([312, 52, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 104, 14, 14], f16), T([64, 336, 14, 14], f16), T([104, 336, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 336, 1, 1], f16), T([64, 14, 1, 1], f16), T([336, 14, 1, 1], f16), [336], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 14, 1, 1], f16), T([64, 336, 1, 1], f16), T([14, 336, 1, 1], f16), [14], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 112, 14, 14], f16, stride=(65856, 196, 14, 1)), T([64, 112, 33, 33], f16), T([112, 1, 7, 7], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 112, [True, True, False]), {})
+cnt: 1, ((T([64, 112, 14, 14], f16, stride=(65856, 196, 14, 1)), T([64, 112, 31, 31], f16), T([112, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 112, [True, True, False]), {})
+cnt: 1, ((T([64, 112, 14, 14], f16, stride=(65856, 196, 14, 1)), T([64, 112, 29, 29], f16), T([112, 1, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 112, [True, True, False]), {})
+cnt: 1, ((T([64, 336, 28, 28], f16), T([64, 56, 28, 28], f16), T([336, 56, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([64, 28, 28, 28], f16, stride=(43904, 784, 28, 1)), T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([28, 168, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([64, 336, 1, 1], f16), T([64, 28, 1, 1], f16), T([336, 28, 1, 1], f16), [336], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 28, 1, 1], f16), T([64, 336, 1, 1], f16), T([28, 336, 1, 1], f16), [28], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([168, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 168, [True, True, False]), {})
+cnt: 3, ((T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([168, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 168, [True, True, False]), {})
+cnt: 6, ((T([64, 168, 28, 28], f16, stride=(263424, 784, 28, 1)), T([64, 28, 28, 28], f16, stride=(43904, 784, 28, 1)), T([168, 28, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 56, 28, 28], f16), T([64, 240, 28, 28], f16), T([56, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 240, 1, 1], f16), T([64, 20, 1, 1], f16), T([240, 20, 1, 1], f16), [240], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 20, 1, 1], f16), T([64, 240, 1, 1], f16), T([20, 240, 1, 1], f16), [20], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 60, 28, 28], f16, stride=(188160, 784, 28, 1)), T([64, 60, 63, 63], f16), T([60, 1, 9, 9], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 60, [True, True, False]), {})
+cnt: 1, ((T([64, 60, 28, 28], f16, stride=(188160, 784, 28, 1)), T([64, 60, 61, 61], f16), T([60, 1, 7, 7], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 60, [True, True, False]), {})
+cnt: 1, ((T([64, 60, 28, 28], f16, stride=(188160, 784, 28, 1)), T([64, 60, 59, 59], f16), T([60, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 60, [True, True, False]), {})
+cnt: 1, ((T([64, 60, 28, 28], f16, stride=(188160, 784, 28, 1)), T([64, 60, 57, 57], f16), T([60, 1, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 60, [True, True, False]), {})
+cnt: 1, ((T([64, 240, 56, 56], f16), T([64, 40, 56, 56], f16), T([240, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 20, 56, 56], f16, stride=(125440, 3136, 56, 1)), T([64, 60, 56, 56], f16, stride=(376320, 3136, 56, 1)), T([20, 60, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 120, 56, 56], f16), T([64, 120, 56, 56], f16), T([120, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 2, ((T([64, 60, 56, 56], f16, stride=(376320, 3136, 56, 1)), T([64, 20, 56, 56], f16, stride=(125440, 3136, 56, 1)), T([60, 20, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([64, 20, 56, 56], f16, stride=(125440, 3136, 56, 1)), T([64, 96, 56, 56], f16, stride=(602112, 3136, 56, 1)), T([20, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 56, 56], f16, stride=(602112, 3136, 56, 1)), T([64, 64, 117, 117], f16), T([64, 1, 7, 7], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 56, 56], f16, stride=(602112, 3136, 56, 1)), T([64, 64, 115, 115], f16), T([64, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 56, 56], f16, stride=(602112, 3136, 56, 1)), T([64, 64, 113, 113], f16), T([64, 1, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 2, ((T([64, 96, 112, 112], f16, stride=(2408448, 12544, 112, 1)), T([64, 16, 112, 112], f16, stride=(401408, 12544, 112, 1)), T([96, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16), T([32, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16), T([32, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([64, 32, 112, 112], f16), T([64, 3, 225, 225], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([64, 1536, 7, 7], f16, stride=(1536, 1, 0, 0)), 49), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16, stride=(1584, 1, 0, 0)), 49), {})
+cnt: 1, ((T([64, 960, 7, 7], f16, stride=(960, 1, 0, 0)), 49), {})
+cnt: 3, ((T([64, 480, 14, 14], f16, stride=(480, 1, 0, 0)), 196), {})
+cnt: 4, ((T([64, 624, 14, 14], f16, stride=(624, 1, 0, 0)), 196), {})
+cnt: 1, ((T([64, 336, 14, 14], f16, stride=(336, 1, 0, 0)), 196), {})
+cnt: 3, ((T([64, 336, 28, 28], f16, stride=(336, 1, 0, 0)), 784), {})
+cnt: 1, ((T([64, 240, 28, 28], f16, stride=(240, 1, 0, 0)), 784), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([64, 240, 28, 28], f16), [2, 3], True), {})
+cnt: 3, ((T([64, 336, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), [2, 3], True), {})
+cnt: 4, ((T([64, 624, 14, 14], f16), [2, 3], True), {})
+cnt: 3, ((T([64, 480, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), [2, 3], True), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([64, 1536, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([64, 1000], f16), T([1000, 1536], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 1536], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([64, 240, 28, 28], f16), T([64, 240, 1, 1], f16)), {})
+cnt: 6, ((T([64, 336, 28, 28], f16), T([64, 336, 1, 1], f16)), {})
+cnt: 2, ((T([64, 336, 14, 14], f16), T([64, 336, 1, 1], f16)), {})
+cnt: 8, ((T([64, 624, 14, 14], f16), T([64, 624, 1, 1], f16)), {})
+cnt: 6, ((T([64, 480, 14, 14], f16), T([64, 480, 1, 1], f16)), {})
+cnt: 2, ((T([64, 960, 7, 7], f16), T([64, 960, 1, 1], f16)), {})
+cnt: 6, ((T([64, 1584, 7, 7], f16), T([64, 1584, 1, 1], f16)), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16), T([64, 1584, 7, 7], f16)), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([64, 960, 7, 7], f16)), {})
+cnt: 3, ((T([64, 480, 14, 14], f16), T([64, 480, 14, 14], f16)), {})
+cnt: 4, ((T([64, 624, 14, 14], f16), T([64, 624, 14, 14], f16)), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), T([64, 336, 14, 14], f16)), {})
+cnt: 3, ((T([64, 336, 28, 28], f16), T([64, 336, 28, 28], f16)), {})
+cnt: 1, ((T([64, 240, 28, 28], f16), T([64, 240, 28, 28], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 3, ((T([64, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([64, 192, 112, 112], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([64, 192, 56, 56], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([64, 40, 56, 56], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f16), True, 0.1, 0.001), {})
+cnt: 2, ((T([64, 120, 56, 56], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([64, 240, 56, 56], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([64, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 0.001), {})
+cnt: 4, ((T([64, 56, 28, 28], f16), T([56], f16), T([56], f16), T([56], f16), T([56], f16), True, 0.1, 0.001), {})
+cnt: 7, ((T([64, 336, 28, 28], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f16), True, 0.1, 0.001), {})
+cnt: 4, ((T([64, 104, 14, 14], f16), T([104], f16), T([104], f16), T([104], f16), T([104], f16), True, 0.1, 0.001), {})
+cnt: 8, ((T([64, 624, 14, 14], f16), T([624], f16), T([624], f16), T([624], f16), T([624], f16), True, 0.1, 0.001), {})
+cnt: 4, ((T([64, 160, 14, 14], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), True, 0.1, 0.001), {})
+cnt: 6, ((T([64, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([64, 960, 14, 14], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), True, 0.1, 0.001), {})
+cnt: 4, ((T([64, 264, 7, 7], f16), T([264], f16), T([264], f16), T([264], f16), T([264], f16), True, 0.1, 0.001), {})
+cnt: 6, ((T([64, 1584, 7, 7], f16), T([1584], f16), T([1584], f16), T([1584], f16), T([1584], f16), True, 0.1, 0.001), {})
+cnt: 1, ((T([64, 1536, 7, 7], f16), T([1536], f16), T([1536], f16), T([1536], f16), T([1536], f16), True, 0.1, 0.001), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([64, 1536, 7, 7], f16), T([64, 1536, 7, 7], f16), T([1536], f16), T([1536], f16), T([1536], f16), T([1536], f32), T([1536], f32), True, 0.001, [True, True, True]), {})
+cnt: 4, ((T([64, 264, 7, 7], f16), T([64, 264, 7, 7], f16), T([264], f16), T([264], f16), T([264], f16), T([264], f32), T([264], f32), True, 0.001, [True, True, True]), {})
+cnt: 6, ((T([64, 1584, 7, 7], f16), T([64, 1584, 7, 7], f16), T([1584], f16), T([1584], f16), T([1584], f16), T([1584], f32), T([1584], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([64, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([64, 960, 14, 14], f16), T([64, 960, 14, 14], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), True, 0.001, [True, True, True]), {})
+cnt: 4, ((T([64, 160, 14, 14], f16), T([64, 160, 14, 14], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), True, 0.001, [True, True, True]), {})
+cnt: 6, ((T([64, 480, 14, 14], f16), T([64, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), True, 0.001, [True, True, True]), {})
+cnt: 8, ((T([64, 624, 14, 14], f16), T([64, 624, 14, 14], f16), T([624], f16), T([624], f16), T([624], f16), T([624], f32), T([624], f32), True, 0.001, [True, True, True]), {})
+cnt: 4, ((T([64, 104, 14, 14], f16), T([64, 104, 14, 14], f16), T([104], f16), T([104], f16), T([104], f16), T([104], f32), T([104], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), T([64, 336, 14, 14], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f32), T([336], f32), True, 0.001, [True, True, True]), {})
+cnt: 7, ((T([64, 336, 28, 28], f16), T([64, 336, 28, 28], f16), T([336], f16), T([336], f16), T([336], f16), T([336], f32), T([336], f32), True, 0.001, [True, True, True]), {})
+cnt: 4, ((T([64, 56, 28, 28], f16), T([64, 56, 28, 28], f16), T([56], f16), T([56], f16), T([56], f16), T([56], f32), T([56], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([64, 240, 28, 28], f16), T([64, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([64, 240, 56, 56], f16), T([64, 240, 56, 56], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([64, 40, 56, 56], f16), T([64, 40, 56, 56], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), True, 0.001, [True, True, True]), {})
+cnt: 2, ((T([64, 120, 56, 56], f16), T([64, 120, 56, 56], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f32), T([120], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([64, 192, 56, 56], f16), T([64, 192, 56, 56], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 0.001, [True, True, True]), {})
+cnt: 1, ((T([64, 192, 112, 112], f16), T([64, 192, 112, 112], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 0.001, [True, True, True]), {})
+cnt: 3, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 0.001, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([64, 32, 112, 112], f16),), {})
+cnt: 1, ((T([64, 192, 112, 112], f16),), {})
+cnt: 1, ((T([64, 192, 56, 56], f16),), {})
+cnt: 2, ((T([64, 120, 56, 56], f16),), {})
+cnt: 1, ((T([64, 1536, 7, 7], f16),), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([64, 240, 1, 1], f16),), {})
+cnt: 4, ((T([64, 336, 1, 1], f16),), {})
+cnt: 4, ((T([64, 624, 1, 1], f16),), {})
+cnt: 3, ((T([64, 480, 1, 1], f16),), {})
+cnt: 1, ((T([64, 960, 1, 1], f16),), {})
+cnt: 3, ((T([64, 1584, 1, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 3, ((T([64, 1584, 1, 1], f16), T([64, 1584, 1, 1], f16)), {})
+cnt: 1, ((T([64, 960, 1, 1], f16), T([64, 960, 1, 1], f16)), {})
+cnt: 3, ((T([64, 480, 1, 1], f16), T([64, 480, 1, 1], f16)), {})
+cnt: 4, ((T([64, 624, 1, 1], f16), T([64, 624, 1, 1], f16)), {})
+cnt: 4, ((T([64, 336, 1, 1], f16), T([64, 336, 1, 1], f16)), {})
+cnt: 1, ((T([64, 240, 1, 1], f16), T([64, 240, 1, 1], f16)), {})
+Operator: aten.silu_.default
+cnt: 1, ((T([64, 240, 56, 56], f16),), {})
+cnt: 1, ((T([64, 240, 28, 28], f16),), {})
+cnt: 1, ((T([64, 20, 1, 1], f16),), {})
+cnt: 7, ((T([64, 336, 28, 28], f16),), {})
+cnt: 3, ((T([64, 28, 1, 1], f16),), {})
+cnt: 1, ((T([64, 336, 14, 14], f16),), {})
+cnt: 1, ((T([64, 14, 1, 1], f16),), {})
+cnt: 8, ((T([64, 624, 14, 14], f16),), {})
+cnt: 3, ((T([64, 26, 1, 1], f16),), {})
+cnt: 1, ((T([64, 52, 1, 1], f16),), {})
+cnt: 6, ((T([64, 480, 14, 14], f16),), {})
+cnt: 4, ((T([64, 80, 1, 1], f16),), {})
+cnt: 1, ((T([64, 960, 14, 14], f16),), {})
+cnt: 1, ((T([64, 960, 7, 7], f16),), {})
+cnt: 6, ((T([64, 1584, 7, 7], f16),), {})
+cnt: 3, ((T([64, 132, 1, 1], f16),), {})
+Operator: aten.silu_backward.default
+cnt: 3, ((T([64, 132, 1, 1], f16), T([64, 132, 1, 1], f16)), {})
+cnt: 6, ((T([64, 1584, 7, 7], f16), T([64, 1584, 7, 7], f16)), {})
+cnt: 4, ((T([64, 80, 1, 1], f16), T([64, 80, 1, 1], f16)), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), T([64, 960, 7, 7], f16)), {})
+cnt: 1, ((T([64, 960, 14, 14], f16), T([64, 960, 14, 14], f16)), {})
+cnt: 6, ((T([64, 480, 14, 14], f16), T([64, 480, 14, 14], f16)), {})
+cnt: 1, ((T([64, 52, 1, 1], f16), T([64, 52, 1, 1], f16)), {})
+cnt: 8, ((T([64, 624, 14, 14], f16), T([64, 624, 14, 14], f16)), {})
+cnt: 3, ((T([64, 26, 1, 1], f16), T([64, 26, 1, 1], f16)), {})
+cnt: 1, ((T([64, 14, 1, 1], f16), T([64, 14, 1, 1], f16)), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), T([64, 336, 14, 14], f16)), {})
+cnt: 7, ((T([64, 336, 28, 28], f16), T([64, 336, 28, 28], f16)), {})
+cnt: 3, ((T([64, 28, 1, 1], f16), T([64, 28, 1, 1], f16)), {})
+cnt: 1, ((T([64, 20, 1, 1], f16), T([64, 20, 1, 1], f16)), {})
+cnt: 1, ((T([64, 240, 28, 28], f16), T([64, 240, 28, 28], f16)), {})
+cnt: 1, ((T([64, 240, 56, 56], f16), T([64, 240, 56, 56], f16)), {})
+Operator: aten.split_with_sizes.default
+cnt: 1, ((T([64, 32, 112, 112], f16), [16, 16], 1), {})
+cnt: 1, ((T([64, 192, 112, 112], f16), [64, 64, 64], 1), {})
+cnt: 1, ((T([64, 192, 56, 56], f16), [96, 96], 1), {})
+cnt: 1, ((T([64, 40, 56, 56], f16), [20, 20], 1), {})
+cnt: 1, ((T([64, 120, 56, 56], f16), [60, 60], 1), {})
+cnt: 1, ((T([64, 240, 56, 56], f16), [60, 60, 60, 60], 1), {})
+cnt: 3, ((T([64, 56, 28, 28], f16), [28, 28], 1), {})
+cnt: 6, ((T([64, 336, 28, 28], f16), [168, 168], 1), {})
+cnt: 1, ((T([64, 336, 28, 28], f16), [112, 112, 112], 1), {})
+cnt: 3, ((T([64, 104, 14, 14], f16), [52, 52], 1), {})
+cnt: 3, ((T([64, 624, 14, 14], f16), [156, 156, 156, 156], 1), {})
+cnt: 3, ((T([64, 624, 14, 14], f16), [312, 312], 1), {})
+cnt: 3, ((T([64, 160, 14, 14], f16), [80, 80], 1), {})
+cnt: 3, ((T([64, 480, 14, 14], f16), [120, 120, 120, 120], 1), {})
+cnt: 3, ((T([64, 480, 14, 14], f16), [240, 240], 1), {})
+cnt: 1, ((T([64, 960, 14, 14], f16), [240, 240, 240, 240], 1), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16), [396, 396, 396, 396], 1), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16), [792, 792], 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 3, ((T([64, 1584, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([64, 960, 7, 7], f16), [2, 3], True), {})
+cnt: 3, ((T([64, 480, 14, 14], f16), [2, 3], True), {})
+cnt: 4, ((T([64, 624, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([64, 336, 14, 14], f16), [2, 3], True), {})
+cnt: 3, ((T([64, 336, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([64, 240, 28, 28], f16), [2, 3], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([64, 1536, 7, 7], f16), T([64, 1536, 7, 7], f16), 0), {})
+cnt: 2, ((T([64, 120, 56, 56], f16), T([64, 120, 56, 56], f16), 0), {})
+cnt: 1, ((T([64, 192, 56, 56], f16), T([64, 192, 56, 56], f16), 0), {})
+cnt: 1, ((T([64, 192, 112, 112], f16), T([64, 192, 112, 112], f16), 0), {})
+cnt: 2, ((T([64, 32, 112, 112], f16), T([64, 32, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/tinynet_a_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/tinynet_a_training.txt
new file mode 100644
index 0000000000000..c3f1255f43ee6
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/tinynet_a_training.txt
@@ -0,0 +1,302 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 58, ((T([], i64), 1), {})
+cnt: 2, ((T([128, 24, 48, 48], f16), T([128, 24, 48, 48], f16)), {})
+cnt: 2, ((T([128, 40, 24, 24], f16), T([128, 40, 24, 24], f16)), {})
+cnt: 6, ((T([128, 80, 12, 12], f16), T([128, 80, 12, 12], f16)), {})
+cnt: 6, ((T([128, 112, 12, 12], f16), T([128, 112, 12, 12], f16)), {})
+cnt: 8, ((T([128, 192, 6, 6], f16), T([128, 192, 6, 6], f16)), {})
+cnt: 5, ((T([128, 1152, 6, 6], f16), T([128, 1152, 6, 6], f16)), {})
+cnt: 1, ((T([128, 672, 6, 6], f16), T([128, 672, 6, 6], f16)), {})
+cnt: 3, ((T([128, 672, 12, 12], f16), T([128, 672, 12, 12], f16)), {})
+cnt: 4, ((T([128, 480, 12, 12], f16), T([128, 480, 12, 12], f16)), {})
+cnt: 1, ((T([128, 240, 12, 12], f16), T([128, 240, 12, 12], f16)), {})
+cnt: 1, ((T([128, 240, 24, 24], f16), T([128, 240, 24, 24], f16)), {})
+cnt: 1, ((T([128, 144, 24, 24], f16), T([128, 144, 24, 24], f16)), {})
+cnt: 1, ((T([128, 144, 48, 48], f16), T([128, 144, 48, 48], f16)), {})
+cnt: 1, ((T([128, 96, 48, 48], f16), T([128, 96, 48, 48], f16)), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), T([128, 32, 96, 96], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1280], f16), T([1280, 1000], f16, stride=(1, 1280))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 192, 192], f16),), {})
+cnt: 2, ((T([128, 32, 96, 96], f16),), {})
+cnt: 1, ((T([128, 8, 1, 1], f16),), {})
+cnt: 1, ((T([128, 96, 96, 96], f16),), {})
+cnt: 1, ((T([128, 96, 48, 48], f16),), {})
+cnt: 1, ((T([128, 4, 1, 1], f16),), {})
+cnt: 3, ((T([128, 144, 48, 48], f16),), {})
+cnt: 2, ((T([128, 6, 1, 1], f16),), {})
+cnt: 1, ((T([128, 144, 24, 24], f16),), {})
+cnt: 3, ((T([128, 240, 24, 24], f16),), {})
+cnt: 2, ((T([128, 10, 1, 1], f16),), {})
+cnt: 1, ((T([128, 240, 12, 12], f16),), {})
+cnt: 8, ((T([128, 480, 12, 12], f16),), {})
+cnt: 4, ((T([128, 20, 1, 1], f16),), {})
+cnt: 7, ((T([128, 672, 12, 12], f16),), {})
+cnt: 4, ((T([128, 28, 1, 1], f16),), {})
+cnt: 1, ((T([128, 672, 6, 6], f16),), {})
+cnt: 10, ((T([128, 1152, 6, 6], f16),), {})
+cnt: 5, ((T([128, 48, 1, 1], f16),), {})
+cnt: 1, ((T([128, 1280, 6, 6], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 192, 192], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), T([32, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([128, 32, 1, 1], f16), T([8, 32, 1, 1], f16), T([8], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 8, 1, 1], f16), T([32, 8, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), T([16, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 96, 96], f16), T([96, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 96, 96], f16), T([96, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 96), {})
+cnt: 1, ((T([128, 96, 1, 1], f16), T([4, 96, 1, 1], f16), T([4], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 4, 1, 1], f16), T([96, 4, 1, 1], f16), T([96], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 96, 48, 48], f16), T([24, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 24, 48, 48], f16), T([144, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 144, 48, 48], f16), T([144, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 144), {})
+cnt: 2, ((T([128, 144, 1, 1], f16), T([6, 144, 1, 1], f16), T([6], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 6, 1, 1], f16), T([144, 6, 1, 1], f16), T([144], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 144, 48, 48], f16), T([24, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 144, 48, 48], f16), T([144, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 144), {})
+cnt: 1, ((T([128, 144, 24, 24], f16), T([40, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 40, 24, 24], f16), T([240, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 240, 24, 24], f16), T([240, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 240), {})
+cnt: 2, ((T([128, 240, 1, 1], f16), T([10, 240, 1, 1], f16), T([10], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 10, 1, 1], f16), T([240, 10, 1, 1], f16), T([240], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 240, 24, 24], f16), T([40, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 240, 24, 24], f16), T([240, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([128, 240, 12, 12], f16), T([80, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 80, 12, 12], f16), T([480, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 480, 12, 12], f16), T([480, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 480), {})
+cnt: 4, ((T([128, 480, 1, 1], f16), T([20, 480, 1, 1], f16), T([20], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 20, 1, 1], f16), T([480, 20, 1, 1], f16), T([480], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 480, 12, 12], f16), T([80, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 480, 12, 12], f16), T([480, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 480), {})
+cnt: 1, ((T([128, 480, 12, 12], f16), T([112, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 112, 12, 12], f16), T([672, 112, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 672, 12, 12], f16), T([672, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 672), {})
+cnt: 4, ((T([128, 672, 1, 1], f16), T([28, 672, 1, 1], f16), T([28], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 28, 1, 1], f16), T([672, 28, 1, 1], f16), T([672], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 672, 12, 12], f16), T([112, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 672, 12, 12], f16), T([672, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 672), {})
+cnt: 1, ((T([128, 672, 6, 6], f16), T([192, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 192, 6, 6], f16), T([1152, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 1152, 6, 6], f16), T([1152, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 1152), {})
+cnt: 5, ((T([128, 1152, 1, 1], f16), T([48, 1152, 1, 1], f16), T([48], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([128, 48, 1, 1], f16), T([1152, 48, 1, 1], f16), T([1152], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 1152, 6, 6], f16), T([192, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1152, 6, 6], f16), T([1152, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1152), {})
+cnt: 1, ((T([128, 1152, 6, 6], f16), T([320, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 320, 6, 6], f16), T([1280, 320, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1280, 6, 6], f16), T([128, 320, 6, 6], f16), T([1280, 320, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 320, 6, 6], f16), T([128, 1152, 6, 6], f16), T([320, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([128, 1152, 1, 1], f16), T([128, 48, 1, 1], f16), T([1152, 48, 1, 1], f16), [1152], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 5, ((T([128, 48, 1, 1], f16), T([128, 1152, 1, 1], f16), T([48, 1152, 1, 1], f16), [48], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 1152, 6, 6], f16), T([128, 1152, 6, 6], f16), T([1152, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1152, [True, True, False]), {})
+cnt: 5, ((T([128, 1152, 6, 6], f16), T([128, 192, 6, 6], f16), T([1152, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 192, 6, 6], f16), T([128, 1152, 6, 6], f16), T([192, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 1152, 6, 6], f16), T([128, 1152, 6, 6], f16), T([1152, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 1152, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 6, 6], f16), T([128, 672, 6, 6], f16), T([192, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 672, 1, 1], f16), T([128, 28, 1, 1], f16), T([672, 28, 1, 1], f16), [672], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([128, 28, 1, 1], f16), T([128, 672, 1, 1], f16), T([28, 672, 1, 1], f16), [28], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 672, 6, 6], f16), T([128, 672, 12, 12], f16), T([672, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 4, ((T([128, 672, 12, 12], f16), T([128, 112, 12, 12], f16), T([672, 112, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 112, 12, 12], f16), T([128, 672, 12, 12], f16), T([112, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 672, 12, 12], f16), T([128, 672, 12, 12], f16), T([672, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 1, ((T([128, 112, 12, 12], f16), T([128, 480, 12, 12], f16), T([112, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 480, 1, 1], f16), T([128, 20, 1, 1], f16), T([480, 20, 1, 1], f16), [480], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([128, 20, 1, 1], f16), T([128, 480, 1, 1], f16), T([20, 480, 1, 1], f16), [20], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 480, 12, 12], f16), T([128, 480, 12, 12], f16), T([480, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 4, ((T([128, 480, 12, 12], f16), T([128, 80, 12, 12], f16), T([480, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 80, 12, 12], f16), T([128, 480, 12, 12], f16), T([80, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 480, 12, 12], f16), T([128, 480, 12, 12], f16), T([480, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 1, ((T([128, 80, 12, 12], f16), T([128, 240, 12, 12], f16), T([80, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 240, 1, 1], f16), T([128, 10, 1, 1], f16), T([240, 10, 1, 1], f16), [240], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 10, 1, 1], f16), T([128, 240, 1, 1], f16), T([10, 240, 1, 1], f16), [10], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 240, 12, 12], f16), T([128, 240, 24, 24], f16), T([240, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 2, ((T([128, 240, 24, 24], f16), T([128, 40, 24, 24], f16), T([240, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 40, 24, 24], f16), T([128, 240, 24, 24], f16), T([40, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 240, 24, 24], f16), T([128, 240, 24, 24], f16), T([240, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([128, 40, 24, 24], f16), T([128, 144, 24, 24], f16), T([40, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 144, 1, 1], f16), T([128, 6, 1, 1], f16), T([144, 6, 1, 1], f16), [144], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 6, 1, 1], f16), T([128, 144, 1, 1], f16), T([6, 144, 1, 1], f16), [6], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 144, 24, 24], f16), T([128, 144, 48, 48], f16), T([144, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 144, [True, True, False]), {})
+cnt: 2, ((T([128, 144, 48, 48], f16), T([128, 24, 48, 48], f16), T([144, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 48, 48], f16), T([128, 144, 48, 48], f16), T([24, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 144, 48, 48], f16), T([128, 144, 48, 48], f16), T([144, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 144, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 48, 48], f16), T([128, 96, 48, 48], f16), T([24, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 1, 1], f16), T([128, 4, 1, 1], f16), T([96, 4, 1, 1], f16), [96], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 4, 1, 1], f16), T([128, 96, 1, 1], f16), T([4, 96, 1, 1], f16), [4], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 96, 48, 48], f16), T([128, 96, 96, 96], f16), T([96, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 96, [True, True, False]), {})
+cnt: 1, ((T([128, 96, 96, 96], f16), T([128, 16, 96, 96], f16), T([96, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 16, 96, 96], f16), T([128, 32, 96, 96], f16), T([16, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 1, 1], f16), T([128, 8, 1, 1], f16), T([32, 8, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 8, 1, 1], f16), T([128, 32, 1, 1], f16), T([8, 32, 1, 1], f16), [8], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), T([128, 32, 96, 96], f16), T([32, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), T([128, 3, 192, 192], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 192, 192], f16), T([128, 3, 192, 192], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 1280, 6, 6], f16, stride=(1280, 1, 0, 0)), 36), {})
+cnt: 5, ((T([128, 1152, 6, 6], f16, stride=(1152, 1, 0, 0)), 36), {})
+cnt: 1, ((T([128, 672, 6, 6], f16, stride=(672, 1, 0, 0)), 36), {})
+cnt: 3, ((T([128, 672, 12, 12], f16, stride=(672, 1, 0, 0)), 144), {})
+cnt: 4, ((T([128, 480, 12, 12], f16, stride=(480, 1, 0, 0)), 144), {})
+cnt: 1, ((T([128, 240, 12, 12], f16, stride=(240, 1, 0, 0)), 144), {})
+cnt: 1, ((T([128, 240, 24, 24], f16, stride=(240, 1, 0, 0)), 576), {})
+cnt: 1, ((T([128, 144, 24, 24], f16, stride=(144, 1, 0, 0)), 576), {})
+cnt: 1, ((T([128, 144, 48, 48], f16, stride=(144, 1, 0, 0)), 2304), {})
+cnt: 1, ((T([128, 96, 48, 48], f16, stride=(96, 1, 0, 0)), 2304), {})
+cnt: 1, ((T([128, 32, 96, 96], f16, stride=(32, 1, 0, 0)), 9216), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 32, 96, 96], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 96, 48, 48], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 144, 48, 48], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 144, 24, 24], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 240, 24, 24], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 240, 12, 12], f16), [2, 3], True), {})
+cnt: 4, ((T([128, 480, 12, 12], f16), [2, 3], True), {})
+cnt: 3, ((T([128, 672, 12, 12], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 6, 6], f16), [2, 3], True), {})
+cnt: 5, ((T([128, 1152, 6, 6], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 1280, 6, 6], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 1280], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 1280], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([128, 32, 96, 96], f16), T([128, 32, 1, 1], f16)), {})
+cnt: 2, ((T([128, 96, 48, 48], f16), T([128, 96, 1, 1], f16)), {})
+cnt: 2, ((T([128, 144, 48, 48], f16), T([128, 144, 1, 1], f16)), {})
+cnt: 2, ((T([128, 144, 24, 24], f16), T([128, 144, 1, 1], f16)), {})
+cnt: 2, ((T([128, 240, 24, 24], f16), T([128, 240, 1, 1], f16)), {})
+cnt: 2, ((T([128, 240, 12, 12], f16), T([128, 240, 1, 1], f16)), {})
+cnt: 8, ((T([128, 480, 12, 12], f16), T([128, 480, 1, 1], f16)), {})
+cnt: 6, ((T([128, 672, 12, 12], f16), T([128, 672, 1, 1], f16)), {})
+cnt: 2, ((T([128, 672, 6, 6], f16), T([128, 672, 1, 1], f16)), {})
+cnt: 10, ((T([128, 1152, 6, 6], f16), T([128, 1152, 1, 1], f16)), {})
+cnt: 5, ((T([128, 1152, 6, 6], f16), T([128, 1152, 6, 6], f16)), {})
+cnt: 1, ((T([128, 672, 6, 6], f16), T([128, 672, 6, 6], f16)), {})
+cnt: 3, ((T([128, 672, 12, 12], f16), T([128, 672, 12, 12], f16)), {})
+cnt: 4, ((T([128, 480, 12, 12], f16), T([128, 480, 12, 12], f16)), {})
+cnt: 1, ((T([128, 240, 12, 12], f16), T([128, 240, 12, 12], f16)), {})
+cnt: 1, ((T([128, 240, 24, 24], f16), T([128, 240, 24, 24], f16)), {})
+cnt: 1, ((T([128, 144, 24, 24], f16), T([128, 144, 24, 24], f16)), {})
+cnt: 1, ((T([128, 144, 48, 48], f16), T([128, 144, 48, 48], f16)), {})
+cnt: 1, ((T([128, 96, 48, 48], f16), T([128, 96, 48, 48], f16)), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), T([128, 32, 96, 96], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([128, 32, 96, 96], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 16, 96, 96], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 96, 96, 96], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 96, 48, 48], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 24, 48, 48], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 144, 48, 48], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 144, 24, 24], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), True, 0.1, 1e-05), {})
+cnt: 2, ((T([128, 40, 24, 24], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f16), True, 0.1, 1e-05), {})
+cnt: 3, ((T([128, 240, 24, 24], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 240, 12, 12], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 80, 12, 12], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f16), True, 0.1, 1e-05), {})
+cnt: 8, ((T([128, 480, 12, 12], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), True, 0.1, 1e-05), {})
+cnt: 4, ((T([128, 112, 12, 12], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f16), True, 0.1, 1e-05), {})
+cnt: 7, ((T([128, 672, 12, 12], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 672, 6, 6], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), True, 0.1, 1e-05), {})
+cnt: 5, ((T([128, 192, 6, 6], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 10, ((T([128, 1152, 6, 6], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 320, 6, 6], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), True, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 1280, 6, 6], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([128, 1280, 6, 6], f16), T([128, 1280, 6, 6], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f32), T([1280], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 320, 6, 6], f16), T([128, 320, 6, 6], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), True, 1e-05, [True, True, True]), {})
+cnt: 10, ((T([128, 1152, 6, 6], f16), T([128, 1152, 6, 6], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f32), T([1152], f32), True, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([128, 192, 6, 6], f16), T([128, 192, 6, 6], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 672, 6, 6], f16), T([128, 672, 6, 6], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([128, 672, 12, 12], f16), T([128, 672, 12, 12], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 112, 12, 12], f16), T([128, 112, 12, 12], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f32), T([112], f32), True, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([128, 480, 12, 12], f16), T([128, 480, 12, 12], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), True, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([128, 80, 12, 12], f16), T([128, 80, 12, 12], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 240, 12, 12], f16), T([128, 240, 12, 12], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 240, 24, 24], f16), T([128, 240, 24, 24], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 40, 24, 24], f16), T([128, 40, 24, 24], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 144, 24, 24], f16), T([128, 144, 24, 24], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), True, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([128, 144, 48, 48], f16), T([128, 144, 48, 48], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 24, 48, 48], f16), T([128, 24, 48, 48], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 96, 48, 48], f16), T([128, 96, 48, 48], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 96, 96, 96], f16), T([128, 96, 96, 96], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 16, 96, 96], f16), T([128, 16, 96, 96], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), True, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([128, 32, 96, 96], f16), T([128, 32, 96, 96], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([128, 32, 1, 1], f16),), {})
+cnt: 1, ((T([128, 96, 1, 1], f16),), {})
+cnt: 2, ((T([128, 144, 1, 1], f16),), {})
+cnt: 2, ((T([128, 240, 1, 1], f16),), {})
+cnt: 4, ((T([128, 480, 1, 1], f16),), {})
+cnt: 4, ((T([128, 672, 1, 1], f16),), {})
+cnt: 5, ((T([128, 1152, 1, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 5, ((T([128, 1152, 1, 1], f16), T([128, 1152, 1, 1], f16)), {})
+cnt: 4, ((T([128, 672, 1, 1], f16), T([128, 672, 1, 1], f16)), {})
+cnt: 4, ((T([128, 480, 1, 1], f16), T([128, 480, 1, 1], f16)), {})
+cnt: 2, ((T([128, 240, 1, 1], f16), T([128, 240, 1, 1], f16)), {})
+cnt: 2, ((T([128, 144, 1, 1], f16), T([128, 144, 1, 1], f16)), {})
+cnt: 1, ((T([128, 96, 1, 1], f16), T([128, 96, 1, 1], f16)), {})
+cnt: 1, ((T([128, 32, 1, 1], f16), T([128, 32, 1, 1], f16)), {})
+Operator: aten.silu_.default
+cnt: 2, ((T([128, 32, 96, 96], f16),), {})
+cnt: 1, ((T([128, 8, 1, 1], f16),), {})
+cnt: 1, ((T([128, 96, 96, 96], f16),), {})
+cnt: 1, ((T([128, 96, 48, 48], f16),), {})
+cnt: 1, ((T([128, 4, 1, 1], f16),), {})
+cnt: 3, ((T([128, 144, 48, 48], f16),), {})
+cnt: 2, ((T([128, 6, 1, 1], f16),), {})
+cnt: 1, ((T([128, 144, 24, 24], f16),), {})
+cnt: 3, ((T([128, 240, 24, 24], f16),), {})
+cnt: 2, ((T([128, 10, 1, 1], f16),), {})
+cnt: 1, ((T([128, 240, 12, 12], f16),), {})
+cnt: 8, ((T([128, 480, 12, 12], f16),), {})
+cnt: 4, ((T([128, 20, 1, 1], f16),), {})
+cnt: 7, ((T([128, 672, 12, 12], f16),), {})
+cnt: 4, ((T([128, 28, 1, 1], f16),), {})
+cnt: 1, ((T([128, 672, 6, 6], f16),), {})
+cnt: 10, ((T([128, 1152, 6, 6], f16),), {})
+cnt: 5, ((T([128, 48, 1, 1], f16),), {})
+cnt: 1, ((T([128, 1280, 6, 6], f16),), {})
+Operator: aten.silu_backward.default
+cnt: 1, ((T([128, 1280, 6, 6], f16), T([128, 1280, 6, 6], f16)), {})
+cnt: 5, ((T([128, 48, 1, 1], f16), T([128, 48, 1, 1], f16)), {})
+cnt: 10, ((T([128, 1152, 6, 6], f16), T([128, 1152, 6, 6], f16)), {})
+cnt: 4, ((T([128, 28, 1, 1], f16), T([128, 28, 1, 1], f16)), {})
+cnt: 1, ((T([128, 672, 6, 6], f16), T([128, 672, 6, 6], f16)), {})
+cnt: 7, ((T([128, 672, 12, 12], f16), T([128, 672, 12, 12], f16)), {})
+cnt: 4, ((T([128, 20, 1, 1], f16), T([128, 20, 1, 1], f16)), {})
+cnt: 8, ((T([128, 480, 12, 12], f16), T([128, 480, 12, 12], f16)), {})
+cnt: 2, ((T([128, 10, 1, 1], f16), T([128, 10, 1, 1], f16)), {})
+cnt: 1, ((T([128, 240, 12, 12], f16), T([128, 240, 12, 12], f16)), {})
+cnt: 3, ((T([128, 240, 24, 24], f16), T([128, 240, 24, 24], f16)), {})
+cnt: 2, ((T([128, 6, 1, 1], f16), T([128, 6, 1, 1], f16)), {})
+cnt: 1, ((T([128, 144, 24, 24], f16), T([128, 144, 24, 24], f16)), {})
+cnt: 3, ((T([128, 144, 48, 48], f16), T([128, 144, 48, 48], f16)), {})
+cnt: 1, ((T([128, 4, 1, 1], f16), T([128, 4, 1, 1], f16)), {})
+cnt: 1, ((T([128, 96, 48, 48], f16), T([128, 96, 48, 48], f16)), {})
+cnt: 1, ((T([128, 96, 96, 96], f16), T([128, 96, 96, 96], f16)), {})
+cnt: 1, ((T([128, 8, 1, 1], f16), T([128, 8, 1, 1], f16)), {})
+cnt: 2, ((T([128, 32, 96, 96], f16), T([128, 32, 96, 96], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 5, ((T([128, 1152, 6, 6], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 672, 6, 6], f16), [2, 3], True), {})
+cnt: 3, ((T([128, 672, 12, 12], f16), [2, 3], True), {})
+cnt: 4, ((T([128, 480, 12, 12], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 240, 12, 12], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 240, 24, 24], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 144, 24, 24], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 144, 48, 48], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 96, 48, 48], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), [2, 3], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/tnt_s_patch16_224_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/tnt_s_patch16_224_training.txt
new file mode 100644
index 0000000000000..d7622dd4d8ce7
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/tnt_s_patch16_224_training.txt
@@ -0,0 +1,146 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([12544, 4, 16, 16], f16), -1, False), {})
+cnt: 12, ((T([64, 6, 197, 197], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([64, 6, 197, 197], f16), T([64, 6, 197, 197], f16), -1, f16), {})
+cnt: 12, ((T([12544, 4, 16, 16], f16), T([12544, 4, 16, 16], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 1, ((T([64, 196, 384], f16), [12544, 24, 4, 4]), {})
+cnt: 1, ((T([12544, 16, 24], f16), [64, 196, 384]), {})
+cnt: 12, ((T([200704, 48], f16), [12544, 16, 48]), {})
+cnt: 12, ((T([200704, 24], f16), [12544, 16, 24]), {})
+cnt: 36, ((T([12544, 4, 16, 6], f16), [50176, 16, 6]), {})
+cnt: 12, ((T([12544, 4, 6, 16], f16), [50176, 6, 16]), {})
+cnt: 12, ((T([50176, 16, 16], f16), [12544, 4, 16, 16]), {})
+cnt: 12, ((T([50176, 16, 6], f16), [12544, 4, 16, 6]), {})
+cnt: 24, ((T([12544, 16, 4, 6], f16), [12544, 16, 24]), {})
+cnt: 12, ((T([12608, 768], f16), [64, 197, 768]), {})
+cnt: 12, ((T([12608, 384], f16), [64, 197, 384]), {})
+cnt: 36, ((T([64, 6, 197, 64], f16), [384, 197, 64]), {})
+cnt: 12, ((T([64, 6, 64, 197], f16), [384, 64, 197]), {})
+cnt: 12, ((T([384, 197, 197], f16), [64, 6, 197, 197]), {})
+cnt: 12, ((T([384, 197, 64], f16), [64, 6, 197, 64]), {})
+cnt: 24, ((T([64, 197, 6, 64], f16), [64, 197, 384]), {})
+cnt: 12, ((T([64, 197, 2, 6, 64], f16), [64, 197, 768]), {})
+cnt: 12, ((T([64, 196, 384], f16), [12544, 384]), {})
+cnt: 12, ((T([12544, 16, 2, 4, 6], f16), [12544, 16, 48]), {})
+cnt: 1, ((T([12544, 24, 4, 4], f16), [64, 196, 384]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([12544, 24, 4, 4], f16), T([1, 24, 4, 4], f16)), {})
+cnt: 1, ((T([64, 197, 384], f16), T([1, 197, 384], f16)), {})
+cnt: 24, ((T([12544, 16, 24], f16, stride=(384, 1, 16)), T([12544, 16, 24], f16)), {})
+cnt: 12, ((T([64, 196, 384], f16, stride=(75648, 384, 1)), T([64, 196, 384], f16)), {})
+cnt: 72, ((T([64, 197, 384], f16), T([64, 197, 384], f16)), {})
+cnt: 48, ((T([12544, 16, 24], f16), T([12544, 16, 24], f16)), {})
+Operator: aten.addmm.default
+cnt: 13, ((T([384], f16), T([12544, 384], f16), T([384, 384], f16, stride=(1, 384))), {})
+cnt: 12, ((T([24], f16), T([200704, 24], f16), T([24, 24], f16, stride=(1, 24))), {})
+cnt: 12, ((T([96], f16), T([200704, 24], f16), T([24, 96], f16, stride=(1, 24))), {})
+cnt: 12, ((T([24], f16), T([200704, 96], f16), T([96, 24], f16, stride=(1, 96))), {})
+cnt: 12, ((T([384], f16), T([12608, 384], f16), T([384, 384], f16, stride=(1, 384))), {})
+cnt: 12, ((T([1536], f16), T([12608, 384], f16), T([384, 1536], f16, stride=(1, 384))), {})
+cnt: 12, ((T([384], f16), T([12608, 1536], f16), T([1536, 384], f16, stride=(1, 1536))), {})
+cnt: 1, ((T([1000], f16), T([64, 384], f16, stride=(75648, 1)), T([384, 1000], f16, stride=(1, 384))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([50176, 16, 6], f16), T([50176, 6, 16], f16)), {})
+cnt: 12, ((T([50176, 16, 16], f16), T([50176, 16, 6], f16)), {})
+cnt: 12, ((T([384, 197, 64], f16), T([384, 64, 197], f16)), {})
+cnt: 12, ((T([384, 197, 197], f16), T([384, 197, 64], f16)), {})
+cnt: 12, ((T([384, 197, 197], f16, stride=(38809, 1, 197)), T([384, 197, 64], f16)), {})
+cnt: 12, ((T([384, 197, 64], f16), T([384, 64, 197], f16, stride=(12608, 1, 64))), {})
+cnt: 12, ((T([384, 64, 197], f16, stride=(12608, 1, 64)), T([384, 197, 197], f16)), {})
+cnt: 12, ((T([384, 197, 197], f16), T([384, 197, 64], f16, stride=(12608, 1, 197))), {})
+cnt: 12, ((T([50176, 16, 16], f16, stride=(256, 1, 16)), T([50176, 16, 6], f16)), {})
+cnt: 12, ((T([50176, 16, 6], f16), T([50176, 6, 16], f16, stride=(96, 1, 6))), {})
+cnt: 12, ((T([50176, 6, 16], f16, stride=(96, 1, 6)), T([50176, 16, 16], f16)), {})
+cnt: 12, ((T([50176, 16, 16], f16), T([50176, 16, 6], f16, stride=(96, 1, 16))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 1, 384], f16, stride=(0, 384, 1)), T([64, 196, 384], f16)], 1), {})
+cnt: 12, (([T([64, 1, 384], f16, stride=(75648, 384, 1)), T([64, 196, 384], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([24, 3, 7, 7], f16), T([24], f16), [4, 4], [3, 3], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 24, 56, 56], f16), T([64, 3, 224, 224], f16), T([24, 3, 7, 7], f16), [24], [4, 4], [3, 3], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([12544, 16, 96], f16),), {})
+cnt: 12, ((T([64, 197, 1536], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([64, 197, 1536], f16), T([64, 197, 1536], f16)), {})
+cnt: 12, ((T([12544, 16, 96], f16), T([12544, 16, 96], f16)), {})
+Operator: aten.im2col.default
+cnt: 1, ((T([64, 24, 56, 56], f16), [4, 4], [1, 1], [0, 0], [4, 4]), {})
+Operator: aten.im2col_backward.default
+cnt: 1, ((T([64, 384, 196], f16, stride=(75264, 1, 384)), [56, 56], [4, 4], [1, 1], [0, 0], [4, 4]), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.mm.default
+cnt: 12, ((T([200704, 24], f16), T([24, 48], f16, stride=(1, 24))), {})
+cnt: 12, ((T([200704, 24], f16), T([24, 24], f16, stride=(1, 24))), {})
+cnt: 12, ((T([12608, 384], f16), T([384, 768], f16, stride=(1, 384))), {})
+cnt: 12, ((T([12608, 384], f16), T([384, 384], f16, stride=(1, 384))), {})
+cnt: 1, ((T([64, 1000], f16), T([1000, 384], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 384], f16, stride=(75648, 1))), {})
+cnt: 12, ((T([12608, 384], f16), T([384, 1536], f16)), {})
+cnt: 12, ((T([384, 12608], f16, stride=(1, 384)), T([12608, 1536], f16)), {})
+cnt: 12, ((T([12608, 1536], f16), T([1536, 384], f16)), {})
+cnt: 12, ((T([1536, 12608], f16, stride=(1, 1536)), T([12608, 384], f16)), {})
+cnt: 24, ((T([12608, 384], f16), T([384, 384], f16)), {})
+cnt: 24, ((T([384, 12608], f16, stride=(1, 384)), T([12608, 384], f16)), {})
+cnt: 12, ((T([768, 12608], f16, stride=(1, 768)), T([12608, 384], f16)), {})
+cnt: 12, ((T([12608, 768], f16), T([768, 384], f16)), {})
+cnt: 13, ((T([12544, 384], f16), T([384, 384], f16)), {})
+cnt: 13, ((T([384, 12544], f16, stride=(1, 384)), T([12544, 384], f16)), {})
+cnt: 12, ((T([200704, 24], f16), T([24, 96], f16)), {})
+cnt: 12, ((T([24, 200704], f16, stride=(1, 24)), T([200704, 96], f16)), {})
+cnt: 12, ((T([200704, 96], f16), T([96, 24], f16)), {})
+cnt: 12, ((T([96, 200704], f16, stride=(1, 96)), T([200704, 24], f16)), {})
+cnt: 24, ((T([200704, 24], f16), T([24, 24], f16)), {})
+cnt: 24, ((T([24, 200704], f16, stride=(1, 24)), T([200704, 24], f16)), {})
+cnt: 12, ((T([48, 200704], f16, stride=(1, 48)), T([200704, 24], f16)), {})
+cnt: 12, ((T([200704, 48], f16), T([48, 24], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 24, ((T([12544, 4, 16, 16], f16), 0.408248290463863), {})
+cnt: 24, ((T([64, 6, 197, 197], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 2, ((T([64, 196, 384], f16), [384], T([384], f16), T([384], f16), 1e-05), {})
+cnt: 36, ((T([12544, 16, 24], f16, stride=(384, 1, 16)), [24], T([24], f16), T([24], f16), 1e-05), {})
+cnt: 25, ((T([64, 197, 384], f16), [384], T([384], f16), T([384], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 25, ((T([64, 197, 384], f16), T([64, 197, 384], f16), [384], T([64, 197, 1], f32), T([64, 197, 1], f32), T([384], f16), T([384], f16), [True, True, True]), {})
+cnt: 36, ((T([12544, 16, 24], f16), T([12544, 16, 24], f16, stride=(384, 1, 16)), [24], T([12544, 16, 1], f32), T([12544, 16, 1], f32), T([24], f16), T([24], f16), [True, True, True]), {})
+cnt: 1, ((T([64, 196, 384], f16, stride=(75648, 384, 1)), T([64, 196, 384], f16), [384], T([64, 196, 1], f32), T([64, 196, 1], f32), T([384], f16), T([384], f16), [True, True, True]), {})
+cnt: 1, ((T([64, 196, 384], f16), T([64, 196, 384], f16), [384], T([64, 196, 1], f32), T([64, 196, 1], f32), T([384], f16), T([384], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.select_backward.default
+cnt: 1, ((T([64, 384], f16), [64, 197, 384], 1, 0), {})
+Operator: aten.slice_backward.default
+cnt: 25, ((T([64, 197, 384], f16), [64, 197, 384], 0, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([64, 196, 384], f16, stride=(75648, 384, 1)), [64, 197, 384], 1, 1, 9223372036854775807, 1), {})
+cnt: 12, ((T([64, 1, 384], f16, stride=(75648, 384, 1)), [64, 197, 384], 1, 0, 1, 1), {})
+Operator: aten.stack.default
+cnt: 12, (([T([64, 6, 197, 64], f16), T([64, 6, 197, 64], f16, stride=(75648, 12608, 1, 197))],), {})
+cnt: 12, (([T([12544, 4, 16, 6], f16), T([12544, 4, 16, 6], f16, stride=(384, 96, 1, 16))],), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 24, ((T([12608, 384], f16), [0], True), {})
+cnt: 12, ((T([12608, 1536], f16), [0], True), {})
+cnt: 13, ((T([12544, 384], f16), [0], True), {})
+cnt: 24, ((T([200704, 24], f16), [0], True), {})
+cnt: 12, ((T([200704, 96], f16), [0], True), {})
+cnt: 1, ((T([64, 197, 384], f16), [0], True), {})
+cnt: 1, ((T([64, 1, 384], f16, stride=(75648, 384, 1)), [0], True), {})
+cnt: 1, ((T([12544, 24, 4, 4], f16, stride=(384, 1, 96, 24)), [0], True), {})
+Operator: aten.unbind.int
+cnt: 12, ((T([2, 12544, 4, 16, 6], f16, stride=(24, 768, 6, 48, 1)),), {})
+cnt: 12, ((T([2, 64, 6, 197, 64], f16, stride=(384, 151296, 64, 768, 1)),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/twins_pcpvt_base_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/twins_pcpvt_base_training.txt
new file mode 100644
index 0000000000000..f3a99cba2b649
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/twins_pcpvt_base_training.txt
@@ -0,0 +1,245 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([32, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([32, 1000], f16), T([32, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 3, ((T([32, 1, 3136, 49], f16), -1, False), {})
+cnt: 4, ((T([32, 2, 784, 49], f16), -1, False), {})
+cnt: 18, ((T([32, 5, 196, 49], f16), -1, False), {})
+cnt: 3, ((T([32, 8, 49, 49], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 3, ((T([32, 8, 49, 49], f16), T([32, 8, 49, 49], f16), -1, f16), {})
+cnt: 18, ((T([32, 5, 196, 49], f16), T([32, 5, 196, 49], f16), -1, f16), {})
+cnt: 4, ((T([32, 2, 784, 49], f16), T([32, 2, 784, 49], f16), -1, f16), {})
+cnt: 3, ((T([32, 1, 3136, 49], f16), T([32, 1, 3136, 49], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 3, ((T([32, 3136, 49], f16), [32, 1, 3136, 49]), {})
+cnt: 3, ((T([32, 3136, 64], f16), [32, 1, 3136, 64]), {})
+cnt: 8, ((T([32, 2, 784, 64], f16), [64, 784, 64]), {})
+cnt: 4, ((T([32, 2, 64, 49], f16), [64, 64, 49]), {})
+cnt: 4, ((T([64, 784, 49], f16), [32, 2, 784, 49]), {})
+cnt: 4, ((T([32, 2, 49, 64], f16), [64, 49, 64]), {})
+cnt: 4, ((T([64, 784, 64], f16), [32, 2, 784, 64]), {})
+cnt: 8, ((T([32, 784, 2, 64], f16), [32, 784, 128]), {})
+cnt: 36, ((T([32, 5, 196, 64], f16), [160, 196, 64]), {})
+cnt: 18, ((T([32, 5, 64, 49], f16), [160, 64, 49]), {})
+cnt: 18, ((T([160, 196, 49], f16), [32, 5, 196, 49]), {})
+cnt: 18, ((T([32, 5, 49, 64], f16), [160, 49, 64]), {})
+cnt: 18, ((T([160, 196, 64], f16), [32, 5, 196, 64]), {})
+cnt: 36, ((T([32, 196, 5, 64], f16), [32, 196, 320]), {})
+cnt: 9, ((T([32, 8, 49, 64], f16), [256, 49, 64]), {})
+cnt: 3, ((T([32, 8, 64, 49], f16), [256, 64, 49]), {})
+cnt: 3, ((T([256, 49, 49], f16), [32, 8, 49, 49]), {})
+cnt: 3, ((T([256, 49, 64], f16), [32, 8, 49, 64]), {})
+cnt: 6, ((T([32, 49, 8, 64], f16), [32, 49, 512]), {})
+cnt: 3, ((T([32, 49, 2, 8, 64], f16), [32, 49, 1024]), {})
+cnt: 36, ((T([32, 196, 320], f16), [6272, 320]), {})
+cnt: 18, ((T([32, 49, 2, 5, 64], f16), [32, 49, 640]), {})
+cnt: 8, ((T([32, 784, 128], f16), [25088, 128]), {})
+cnt: 4, ((T([32, 49, 2, 2, 64], f16), [32, 49, 256]), {})
+cnt: 6, ((T([32, 3136, 64], f16), [100352, 64]), {})
+cnt: 3, ((T([32, 49, 2, 1, 64], f16), [32, 49, 128]), {})
+Operator: aten.add.Tensor
+cnt: 9, ((T([32, 3136, 64], f16), T([32, 3136, 64], f16)), {})
+cnt: 12, ((T([32, 784, 128], f16), T([32, 784, 128], f16)), {})
+cnt: 54, ((T([32, 196, 320], f16), T([32, 196, 320], f16)), {})
+cnt: 15, ((T([32, 49, 512], f16), T([32, 49, 512], f16)), {})
+cnt: 3, ((T([2, 32, 8, 49, 64], f16), T([2, 32, 8, 49, 64], f16)), {})
+cnt: 1, ((T([32, 512, 7, 7], f16, stride=(25088, 1, 3584, 512)), T([32, 512, 7, 7], f16, stride=(25088, 1, 3584, 512))), {})
+cnt: 36, ((T([32, 196, 320], f16, stride=(62720, 1, 196)), T([32, 196, 320], f16)), {})
+cnt: 18, ((T([2, 32, 5, 49, 64], f16), T([2, 32, 5, 49, 64], f16)), {})
+cnt: 1, ((T([32, 320, 14, 14], f16), T([32, 320, 14, 14], f16, stride=(62720, 1, 4480, 320))), {})
+cnt: 8, ((T([32, 784, 128], f16, stride=(100352, 1, 784)), T([32, 784, 128], f16)), {})
+cnt: 4, ((T([2, 32, 2, 49, 64], f16), T([2, 32, 2, 49, 64], f16)), {})
+cnt: 1, ((T([32, 128, 28, 28], f16), T([32, 128, 28, 28], f16, stride=(100352, 1, 3584, 128))), {})
+cnt: 6, ((T([32, 3136, 64], f16, stride=(200704, 1, 3136)), T([32, 3136, 64], f16)), {})
+cnt: 3, ((T([2, 32, 1, 49, 64], f16), T([2, 32, 1, 49, 64], f16)), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 64, 56, 56], f16, stride=(200704, 1, 3584, 64))), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([32, 64, 56, 56], f16, stride=(200704, 1, 3584, 64)), T([32, 64, 56, 56], f16, stride=(200704, 1, 3584, 64))), {})
+cnt: 1, ((T([32, 128, 28, 28], f16, stride=(100352, 1, 3584, 128)), T([32, 128, 28, 28], f16, stride=(100352, 1, 3584, 128))), {})
+cnt: 1, ((T([32, 320, 14, 14], f16, stride=(62720, 1, 4480, 320)), T([32, 320, 14, 14], f16, stride=(62720, 1, 4480, 320))), {})
+cnt: 1, ((T([32, 512, 7, 7], f16, stride=(25088, 1, 3584, 512)), T([32, 512, 7, 7], f16, stride=(25088, 1, 3584, 512))), {})
+Operator: aten.addmm.default
+cnt: 6, ((T([64], f16), T([100352, 64], f16), T([64, 64], f16, stride=(1, 64))), {})
+cnt: 3, ((T([128], f16), T([1568, 64], f16), T([64, 128], f16, stride=(1, 64))), {})
+cnt: 3, ((T([512], f16), T([100352, 64], f16), T([64, 512], f16, stride=(1, 64))), {})
+cnt: 3, ((T([64], f16), T([100352, 512], f16), T([512, 64], f16, stride=(1, 512))), {})
+cnt: 8, ((T([128], f16), T([25088, 128], f16), T([128, 128], f16, stride=(1, 128))), {})
+cnt: 4, ((T([256], f16), T([1568, 128], f16), T([128, 256], f16, stride=(1, 128))), {})
+cnt: 4, ((T([1024], f16), T([25088, 128], f16), T([128, 1024], f16, stride=(1, 128))), {})
+cnt: 4, ((T([128], f16), T([25088, 1024], f16), T([1024, 128], f16, stride=(1, 1024))), {})
+cnt: 36, ((T([320], f16), T([6272, 320], f16), T([320, 320], f16, stride=(1, 320))), {})
+cnt: 18, ((T([640], f16), T([1568, 320], f16), T([320, 640], f16, stride=(1, 320))), {})
+cnt: 18, ((T([1280], f16), T([6272, 320], f16), T([320, 1280], f16, stride=(1, 320))), {})
+cnt: 18, ((T([320], f16), T([6272, 1280], f16), T([1280, 320], f16, stride=(1, 1280))), {})
+cnt: 6, ((T([512], f16), T([1568, 512], f16), T([512, 512], f16, stride=(1, 512))), {})
+cnt: 3, ((T([1024], f16), T([1568, 512], f16), T([512, 1024], f16, stride=(1, 512))), {})
+cnt: 3, ((T([2048], f16), T([1568, 512], f16), T([512, 2048], f16, stride=(1, 512))), {})
+cnt: 3, ((T([512], f16), T([1568, 2048], f16), T([2048, 512], f16, stride=(1, 2048))), {})
+cnt: 1, ((T([1000], f16), T([32, 512], f16), T([512, 1000], f16, stride=(1, 512))), {})
+Operator: aten.bmm.default
+cnt: 6, ((T([32, 3136, 64], f16), T([32, 64, 49], f16, stride=(6272, 1, 128))), {})
+cnt: 6, ((T([32, 3136, 49], f16), T([32, 49, 64], f16, stride=(6272, 128, 1))), {})
+cnt: 4, ((T([64, 784, 64], f16), T([64, 64, 49], f16)), {})
+cnt: 4, ((T([64, 784, 49], f16), T([64, 49, 64], f16)), {})
+cnt: 18, ((T([160, 196, 64], f16), T([160, 64, 49], f16)), {})
+cnt: 18, ((T([160, 196, 49], f16), T([160, 49, 64], f16)), {})
+cnt: 3, ((T([256, 49, 64], f16), T([256, 64, 49], f16)), {})
+cnt: 3, ((T([256, 49, 49], f16), T([256, 49, 64], f16)), {})
+cnt: 3, ((T([256, 49, 49], f16, stride=(2401, 1, 49)), T([256, 49, 64], f16)), {})
+cnt: 3, ((T([256, 49, 64], f16), T([256, 64, 49], f16, stride=(3136, 1, 64))), {})
+cnt: 3, ((T([256, 64, 49], f16, stride=(3136, 1, 64)), T([256, 49, 49], f16)), {})
+cnt: 3, ((T([256, 49, 49], f16), T([256, 49, 64], f16, stride=(3136, 1, 49))), {})
+cnt: 18, ((T([160, 49, 196], f16, stride=(9604, 1, 49)), T([160, 196, 64], f16)), {})
+cnt: 18, ((T([160, 196, 64], f16), T([160, 64, 49], f16, stride=(3136, 1, 64))), {})
+cnt: 18, ((T([160, 64, 196], f16, stride=(12544, 1, 64)), T([160, 196, 49], f16)), {})
+cnt: 18, ((T([160, 196, 49], f16), T([160, 49, 64], f16, stride=(3136, 1, 49))), {})
+cnt: 4, ((T([64, 49, 784], f16, stride=(38416, 1, 49)), T([64, 784, 64], f16)), {})
+cnt: 4, ((T([64, 784, 64], f16), T([64, 64, 49], f16, stride=(3136, 1, 64))), {})
+cnt: 4, ((T([64, 64, 784], f16, stride=(50176, 1, 64)), T([64, 784, 49], f16)), {})
+cnt: 4, ((T([64, 784, 49], f16), T([64, 49, 64], f16, stride=(3136, 1, 49))), {})
+cnt: 3, ((T([32, 49, 3136], f16, stride=(153664, 1, 49)), T([32, 3136, 64], f16)), {})
+cnt: 3, ((T([32, 64, 3136], f16, stride=(200704, 1, 64)), T([32, 3136, 49], f16)), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([64, 3, 4, 4], f16), T([64], f16), [4, 4], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 64, 56, 56], f16, stride=(200704, 1, 3584, 64)), T([64, 64, 8, 8], f16), T([64], f16), [8, 8], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 56, 56], f16, stride=(200704, 1, 3584, 64)), T([64, 1, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([128, 64, 2, 2], f16), T([128], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([32, 128, 28, 28], f16, stride=(100352, 1, 3584, 128)), T([128, 128, 4, 4], f16), T([128], f16), [4, 4], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 28, 28], f16, stride=(100352, 1, 3584, 128)), T([128, 1, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 128), {})
+cnt: 1, ((T([32, 128, 28, 28], f16), T([320, 128, 2, 2], f16), T([320], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 18, ((T([32, 320, 14, 14], f16, stride=(62720, 1, 4480, 320)), T([320, 320, 2, 2], f16), T([320], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 320, 14, 14], f16, stride=(62720, 1, 4480, 320)), T([320, 1, 3, 3], f16), T([320], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 320), {})
+cnt: 1, ((T([32, 320, 14, 14], f16), T([512, 320, 2, 2], f16), T([512], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 7, 7], f16, stride=(25088, 1, 3584, 512)), T([512, 1, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 512), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([32, 512, 7, 7], f16, stride=(25088, 1, 3584, 512)), T([32, 512, 7, 7], f16, stride=(25088, 1, 3584, 512)), T([512, 1, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 512, [True, True, True]), {})
+cnt: 1, ((T([32, 512, 7, 7], f16, stride=(25088, 1, 3584, 512)), T([32, 320, 14, 14], f16), T([512, 320, 2, 2], f16), [512], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 18, ((T([32, 320, 7, 7], f16, stride=(15680, 1, 2240, 320)), T([32, 320, 14, 14], f16, stride=(62720, 1, 4480, 320)), T([320, 320, 2, 2], f16), [320], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 320, 14, 14], f16), T([32, 320, 14, 14], f16, stride=(62720, 1, 4480, 320)), T([320, 1, 3, 3], f16), [320], [1, 1], [1, 1], [1, 1], False, [0, 0], 320, [True, True, True]), {})
+cnt: 1, ((T([32, 320, 14, 14], f16), T([32, 128, 28, 28], f16), T([320, 128, 2, 2], f16), [320], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([32, 128, 7, 7], f16, stride=(6272, 1, 896, 128)), T([32, 128, 28, 28], f16, stride=(100352, 1, 3584, 128)), T([128, 128, 4, 4], f16), [128], [4, 4], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 128, 28, 28], f16), T([32, 128, 28, 28], f16, stride=(100352, 1, 3584, 128)), T([128, 1, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 128, [True, True, True]), {})
+cnt: 1, ((T([32, 128, 28, 28], f16), T([32, 64, 56, 56], f16), T([128, 64, 2, 2], f16), [128], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 64, 7, 7], f16, stride=(3136, 1, 448, 64)), T([32, 64, 56, 56], f16, stride=(200704, 1, 3584, 64)), T([64, 64, 8, 8], f16), [64], [8, 8], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 64, 56, 56], f16, stride=(200704, 1, 3584, 64)), T([64, 1, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 64, [True, True, True]), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 3, 224, 224], f16), T([64, 3, 4, 4], f16), [64], [4, 4], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 224, 224], f16)), {})
+cnt: 18, ((T([320, 320, 2, 2], f16), T([320, 320, 2, 2], f16, stride=(1280, 1, 640, 320))), {})
+cnt: 4, ((T([128, 128, 4, 4], f16), T([128, 128, 4, 4], f16, stride=(2048, 1, 512, 128))), {})
+cnt: 3, ((T([64, 64, 8, 8], f16), T([64, 64, 8, 8], f16, stride=(4096, 1, 512, 64))), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([32, 49, 512], f16, stride=(512, 0, 1)), 49), {})
+Operator: aten.gelu.default
+cnt: 3, ((T([32, 3136, 512], f16),), {})
+cnt: 4, ((T([32, 784, 1024], f16),), {})
+cnt: 18, ((T([32, 196, 1280], f16),), {})
+cnt: 3, ((T([32, 49, 2048], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 3, ((T([32, 49, 2048], f16), T([32, 49, 2048], f16)), {})
+cnt: 18, ((T([32, 196, 1280], f16), T([32, 196, 1280], f16)), {})
+cnt: 4, ((T([32, 784, 1024], f16), T([32, 784, 1024], f16)), {})
+cnt: 3, ((T([32, 3136, 512], f16), T([32, 3136, 512], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([32], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([32, 49, 512], f16), [1]), {})
+Operator: aten.mm.default
+cnt: 1, ((T([32, 1000], f16), T([1000, 512], f16)), {})
+cnt: 1, ((T([1000, 32], f16, stride=(1, 1000)), T([32, 512], f16)), {})
+cnt: 3, ((T([1568, 512], f16), T([512, 2048], f16)), {})
+cnt: 3, ((T([512, 1568], f16, stride=(1, 512)), T([1568, 2048], f16)), {})
+cnt: 3, ((T([1568, 2048], f16), T([2048, 512], f16)), {})
+cnt: 3, ((T([2048, 1568], f16, stride=(1, 2048)), T([1568, 512], f16)), {})
+cnt: 6, ((T([1568, 512], f16), T([512, 512], f16)), {})
+cnt: 6, ((T([512, 1568], f16, stride=(1, 512)), T([1568, 512], f16)), {})
+cnt: 3, ((T([1568, 1024], f16), T([1024, 512], f16)), {})
+cnt: 3, ((T([1024, 1568], f16, stride=(1, 1024)), T([1568, 512], f16)), {})
+cnt: 18, ((T([6272, 320], f16), T([320, 1280], f16)), {})
+cnt: 18, ((T([320, 6272], f16, stride=(1, 320)), T([6272, 1280], f16)), {})
+cnt: 18, ((T([6272, 1280], f16), T([1280, 320], f16)), {})
+cnt: 18, ((T([1280, 6272], f16, stride=(1, 1280)), T([6272, 320], f16)), {})
+cnt: 36, ((T([6272, 320], f16), T([320, 320], f16)), {})
+cnt: 36, ((T([320, 6272], f16, stride=(1, 320)), T([6272, 320], f16)), {})
+cnt: 18, ((T([1568, 640], f16), T([640, 320], f16)), {})
+cnt: 18, ((T([640, 1568], f16, stride=(1, 640)), T([1568, 320], f16)), {})
+cnt: 4, ((T([25088, 128], f16), T([128, 1024], f16)), {})
+cnt: 4, ((T([128, 25088], f16, stride=(1, 128)), T([25088, 1024], f16)), {})
+cnt: 4, ((T([25088, 1024], f16), T([1024, 128], f16)), {})
+cnt: 4, ((T([1024, 25088], f16, stride=(1, 1024)), T([25088, 128], f16)), {})
+cnt: 8, ((T([25088, 128], f16), T([128, 128], f16)), {})
+cnt: 8, ((T([128, 25088], f16, stride=(1, 128)), T([25088, 128], f16)), {})
+cnt: 4, ((T([1568, 256], f16), T([256, 128], f16)), {})
+cnt: 4, ((T([256, 1568], f16, stride=(1, 256)), T([1568, 128], f16)), {})
+cnt: 3, ((T([100352, 64], f16), T([64, 512], f16)), {})
+cnt: 3, ((T([64, 100352], f16, stride=(1, 64)), T([100352, 512], f16)), {})
+cnt: 3, ((T([100352, 512], f16), T([512, 64], f16)), {})
+cnt: 3, ((T([512, 100352], f16, stride=(1, 512)), T([100352, 64], f16)), {})
+cnt: 6, ((T([100352, 64], f16), T([64, 64], f16)), {})
+cnt: 6, ((T([64, 100352], f16, stride=(1, 64)), T([100352, 64], f16)), {})
+cnt: 3, ((T([1568, 128], f16), T([128, 64], f16)), {})
+cnt: 3, ((T([128, 1568], f16, stride=(1, 128)), T([1568, 64], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 6, ((T([32, 1, 3136, 49], f16), 0.125), {})
+cnt: 8, ((T([32, 2, 784, 49], f16), 0.125), {})
+cnt: 36, ((T([32, 5, 196, 49], f16), 0.125), {})
+cnt: 6, ((T([32, 8, 49, 49], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 1, ((T([32, 3136, 64], f16, stride=(200704, 1, 3136)), [64], T([64], f16), T([64], f16), 1e-05), {})
+cnt: 6, ((T([32, 3136, 64], f16), [64], T([64], f16), T([64], f16), 1e-06), {})
+cnt: 3, ((T([32, 49, 64], f16), [64], T([64], f16), T([64], f16), 1e-05), {})
+cnt: 1, ((T([32, 784, 128], f16, stride=(100352, 1, 784)), [128], T([128], f16), T([128], f16), 1e-05), {})
+cnt: 8, ((T([32, 784, 128], f16), [128], T([128], f16), T([128], f16), 1e-06), {})
+cnt: 4, ((T([32, 49, 128], f16), [128], T([128], f16), T([128], f16), 1e-05), {})
+cnt: 1, ((T([32, 196, 320], f16, stride=(62720, 1, 196)), [320], T([320], f16), T([320], f16), 1e-05), {})
+cnt: 36, ((T([32, 196, 320], f16), [320], T([320], f16), T([320], f16), 1e-06), {})
+cnt: 18, ((T([32, 49, 320], f16), [320], T([320], f16), T([320], f16), 1e-05), {})
+cnt: 1, ((T([32, 49, 512], f16, stride=(25088, 1, 49)), [512], T([512], f16), T([512], f16), 1e-05), {})
+cnt: 7, ((T([32, 49, 512], f16), [512], T([512], f16), T([512], f16), 1e-06), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 7, ((T([32, 49, 512], f16), T([32, 49, 512], f16), [512], T([32, 49, 1], f32), T([32, 49, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+cnt: 1, ((T([32, 49, 512], f16), T([32, 49, 512], f16, stride=(25088, 1, 49)), [512], T([32, 49, 1], f32), T([32, 49, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+cnt: 36, ((T([32, 196, 320], f16), T([32, 196, 320], f16), [320], T([32, 196, 1], f32), T([32, 196, 1], f32), T([320], f16), T([320], f16), [True, True, True]), {})
+cnt: 18, ((T([32, 49, 320], f16), T([32, 49, 320], f16), [320], T([32, 49, 1], f32), T([32, 49, 1], f32), T([320], f16), T([320], f16), [True, True, True]), {})
+cnt: 1, ((T([32, 196, 320], f16, stride=(62720, 1, 196)), T([32, 196, 320], f16, stride=(62720, 1, 196)), [320], T([32, 196, 1], f32), T([32, 196, 1], f32), T([320], f16), T([320], f16), [True, True, True]), {})
+cnt: 8, ((T([32, 784, 128], f16), T([32, 784, 128], f16), [128], T([32, 784, 1], f32), T([32, 784, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+cnt: 4, ((T([32, 49, 128], f16), T([32, 49, 128], f16), [128], T([32, 49, 1], f32), T([32, 49, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+cnt: 1, ((T([32, 784, 128], f16, stride=(100352, 1, 784)), T([32, 784, 128], f16, stride=(100352, 1, 784)), [128], T([32, 784, 1], f32), T([32, 784, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+cnt: 6, ((T([32, 3136, 64], f16), T([32, 3136, 64], f16), [64], T([32, 3136, 1], f32), T([32, 3136, 1], f32), T([64], f16), T([64], f16), [True, True, True]), {})
+cnt: 3, ((T([32, 49, 64], f16), T([32, 49, 64], f16), [64], T([32, 49, 1], f32), T([32, 49, 1], f32), T([64], f16), T([64], f16), [True, True, True]), {})
+cnt: 1, ((T([32, 3136, 64], f16, stride=(200704, 1, 3136)), T([32, 3136, 64], f16, stride=(200704, 1, 3136)), [64], T([32, 3136, 1], f32), T([32, 3136, 1], f32), T([64], f16), T([64], f16), [True, True, True]), {})
+Operator: aten.new_empty_strided.default
+cnt: 18, ((T([320, 320, 2, 2], f16, stride=(1280, 1, 640, 320)), [320, 320, 2, 2], [1280, 4, 2, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 4, ((T([128, 128, 4, 4], f16, stride=(2048, 1, 512, 128)), [128, 128, 4, 4], [2048, 16, 4, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 3, ((T([64, 64, 8, 8], f16, stride=(4096, 1, 512, 64)), [64, 64, 8, 8], [4096, 64, 8, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([32, 1000], f16), T([32], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([32, 1000], f16), T([32], i64), None, 1, -100), {})
+Operator: aten.select_backward.default
+cnt: 3, ((T([32, 8, 49, 64], f16), [2, 32, 8, 49, 64], 0, 1), {})
+cnt: 3, ((T([32, 8, 49, 64], f16, stride=(25088, 3136, 1, 49)), [2, 32, 8, 49, 64], 0, 0), {})
+cnt: 18, ((T([32, 5, 49, 64], f16), [2, 32, 5, 49, 64], 0, 1), {})
+cnt: 18, ((T([32, 5, 49, 64], f16, stride=(15680, 3136, 1, 49)), [2, 32, 5, 49, 64], 0, 0), {})
+cnt: 4, ((T([32, 2, 49, 64], f16), [2, 32, 2, 49, 64], 0, 1), {})
+cnt: 4, ((T([32, 2, 49, 64], f16, stride=(6272, 3136, 1, 49)), [2, 32, 2, 49, 64], 0, 0), {})
+cnt: 3, ((T([32, 1, 49, 64], f16), [2, 32, 1, 49, 64], 0, 1), {})
+cnt: 3, ((T([32, 1, 49, 64], f16, stride=(3136, 3136, 1, 49)), [2, 32, 1, 49, 64], 0, 0), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32, 1000], f16), [0], True), {})
+cnt: 9, ((T([1568, 512], f16), [0], True), {})
+cnt: 3, ((T([1568, 2048], f16), [0], True), {})
+cnt: 3, ((T([1568, 1024], f16), [0], True), {})
+cnt: 54, ((T([6272, 320], f16), [0], True), {})
+cnt: 18, ((T([6272, 1280], f16), [0], True), {})
+cnt: 18, ((T([1568, 640], f16), [0], True), {})
+cnt: 12, ((T([25088, 128], f16), [0], True), {})
+cnt: 4, ((T([25088, 1024], f16), [0], True), {})
+cnt: 4, ((T([1568, 256], f16), [0], True), {})
+cnt: 9, ((T([100352, 64], f16), [0], True), {})
+cnt: 3, ((T([100352, 512], f16), [0], True), {})
+cnt: 3, ((T([1568, 128], f16), [0], True), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/visformer_small_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/visformer_small_training.txt
new file mode 100644
index 0000000000000..76ef9f17620e7
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/visformer_small_training.txt
@@ -0,0 +1,132 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([128, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([128, 1000], f16), T([128, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 4, ((T([128, 6, 196, 196], f16), -1, False), {})
+cnt: 4, ((T([128, 6, 49, 49], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 4, ((T([128, 6, 49, 49], f16), T([128, 6, 49, 49], f16), -1, f16), {})
+cnt: 4, ((T([128, 6, 196, 196], f16), T([128, 6, 196, 196], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 8, ((T([128, 6, 196, 64], f16), [768, 196, 64]), {})
+cnt: 4, ((T([128, 6, 64, 196], f16), [768, 64, 196]), {})
+cnt: 4, ((T([768, 196, 196], f16), [128, 6, 196, 196]), {})
+cnt: 4, ((T([768, 196, 64], f16), [128, 6, 196, 64]), {})
+cnt: 4, ((T([128, 6, 64, 196], f16), [128, 384, 14, 14]), {})
+cnt: 8, ((T([128, 6, 49, 128], f16), [768, 49, 128]), {})
+cnt: 4, ((T([128, 6, 128, 49], f16), [768, 128, 49]), {})
+cnt: 4, ((T([768, 49, 49], f16), [128, 6, 49, 49]), {})
+cnt: 4, ((T([768, 49, 128], f16), [128, 6, 49, 128]), {})
+cnt: 4, ((T([128, 6, 128, 49], f16), [128, 768, 7, 7]), {})
+cnt: 4, ((T([128, 3, 6, 128, 49], f16), [128, 2304, 7, 7]), {})
+cnt: 4, ((T([128, 3, 6, 64, 196], f16), [128, 1152, 14, 14]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([128, 192, 28, 28], f16), T([1, 192, 28, 28], f16)), {})
+cnt: 14, ((T([128, 192, 28, 28], f16), T([128, 192, 28, 28], f16)), {})
+cnt: 1, ((T([128, 384, 14, 14], f16), T([1, 384, 14, 14], f16)), {})
+cnt: 16, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16)), {})
+cnt: 1, ((T([128, 768, 7, 7], f16), T([1, 768, 7, 7], f16)), {})
+cnt: 16, ((T([128, 768, 7, 7], f16), T([128, 768, 7, 7], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 28, ((T([], i64), 1), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 768], f16), T([768, 1000], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 4, ((T([768, 196, 64], f16), T([768, 64, 196], f16)), {})
+cnt: 4, ((T([768, 196, 196], f16), T([768, 196, 64], f16)), {})
+cnt: 4, ((T([768, 49, 128], f16), T([768, 128, 49], f16)), {})
+cnt: 4, ((T([768, 49, 49], f16), T([768, 49, 128], f16)), {})
+cnt: 4, ((T([768, 49, 49], f16, stride=(2401, 1, 49)), T([768, 49, 128], f16, stride=(6272, 1, 49))), {})
+cnt: 4, ((T([768, 49, 128], f16, stride=(6272, 1, 49)), T([768, 128, 49], f16, stride=(6272, 1, 128))), {})
+cnt: 4, ((T([768, 128, 49], f16, stride=(6272, 1, 128)), T([768, 49, 49], f16)), {})
+cnt: 4, ((T([768, 49, 49], f16), T([768, 49, 128], f16, stride=(6272, 1, 49))), {})
+cnt: 4, ((T([768, 196, 196], f16, stride=(38416, 1, 196)), T([768, 196, 64], f16, stride=(12544, 1, 196))), {})
+cnt: 4, ((T([768, 196, 64], f16, stride=(12544, 1, 196)), T([768, 64, 196], f16, stride=(12544, 1, 64))), {})
+cnt: 4, ((T([768, 64, 196], f16, stride=(12544, 1, 64)), T([768, 196, 196], f16)), {})
+cnt: 4, ((T([768, 196, 196], f16), T([768, 196, 64], f16, stride=(12544, 1, 196))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([32, 3, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([192, 32, 4, 4], f16), T([192], f16), [4, 4], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([128, 192, 28, 28], f16), T([384, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([128, 384, 28, 28], f16), T([384, 48, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 7, ((T([128, 384, 28, 28], f16), T([192, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 28, 28], f16), T([384, 192, 2, 2], f16), T([384], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 384, 14, 14], f16), T([1152, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 384, 14, 14], f16), T([384, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 384, 14, 14], f16), T([1536, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 1536, 14, 14], f16), T([384, 1536, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 384, 14, 14], f16), T([768, 384, 2, 2], f16), T([768], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 768, 7, 7], f16), T([2304, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 768, 7, 7], f16), T([768, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 768, 7, 7], f16), T([3072, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([128, 3072, 7, 7], f16), T([768, 3072, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 4, ((T([128, 768, 7, 7], f16), T([128, 3072, 7, 7], f16), T([768, 3072, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 3072, 7, 7], f16), T([128, 768, 7, 7], f16), T([3072, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 768, 7, 7], f16), T([128, 768, 7, 7], f16), T([768, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 2304, 7, 7], f16), T([128, 768, 7, 7], f16), T([2304, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 768, 7, 7], f16), T([128, 384, 14, 14], f16), T([768, 384, 2, 2], f16), [768], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([128, 384, 14, 14], f16), T([128, 1536, 14, 14], f16), T([384, 1536, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 1536, 14, 14], f16), T([128, 384, 14, 14], f16), T([1536, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16), T([384, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 1152, 14, 14], f16), T([128, 384, 14, 14], f16), T([1152, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 384, 14, 14], f16), T([128, 192, 28, 28], f16), T([384, 192, 2, 2], f16), [384], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 7, ((T([128, 192, 28, 28], f16), T([128, 384, 28, 28], f16), T([192, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 7, ((T([128, 384, 28, 28], f16), T([128, 384, 28, 28], f16), T([384, 48, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 7, ((T([128, 384, 28, 28], f16), T([128, 192, 28, 28], f16), T([384, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 192, 28, 28], f16), T([128, 32, 112, 112], f16), T([192, 32, 4, 4], f16), [192], [4, 4], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 3, 224, 224], f16), T([32, 3, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 768, 7, 7], f16, stride=(768, 1, 0, 0)), 49), {})
+Operator: aten.gelu.default
+cnt: 14, ((T([128, 384, 28, 28], f16),), {})
+cnt: 4, ((T([128, 1536, 14, 14], f16),), {})
+cnt: 4, ((T([128, 3072, 7, 7], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 4, ((T([128, 3072, 7, 7], f16), T([128, 3072, 7, 7], f16)), {})
+cnt: 4, ((T([128, 1536, 14, 14], f16), T([128, 1536, 14, 14], f16)), {})
+cnt: 14, ((T([128, 384, 28, 28], f16), T([128, 384, 28, 28], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([128], i64),), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 768, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16), T([1000, 768], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(1, 1000)), T([128, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 8, ((T([128, 6, 196, 196], f16), 0.125), {})
+cnt: 8, ((T([128, 6, 49, 49], f16), 0.08838834764831845), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), True, 0.1, 1e-05), {})
+cnt: 8, ((T([128, 192, 28, 28], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), True, 0.1, 1e-05), {})
+cnt: 9, ((T([128, 384, 14, 14], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), True, 0.1, 1e-05), {})
+cnt: 10, ((T([128, 768, 7, 7], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 10, ((T([128, 768, 7, 7], f16), T([128, 768, 7, 7], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f32), T([768], f32), True, 1e-05, [True, True, True]), {})
+cnt: 9, ((T([128, 384, 14, 14], f16), T([128, 384, 14, 14], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), True, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([128, 192, 28, 28], f16), T([128, 192, 28, 28], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), True, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([128, 1000], f16), T([128], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([128, 1000], f16), T([128], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 32, 112, 112], f16),), {})
+Operator: aten.stack.default
+cnt: 4, (([T([128, 6, 49, 128], f16), T([128, 6, 49, 128], f16, stride=(37632, 6272, 1, 49)), T([128, 6, 49, 128], f16)],), {})
+cnt: 4, (([T([128, 6, 196, 64], f16), T([128, 6, 196, 64], f16, stride=(75264, 12544, 1, 196)), T([128, 6, 196, 64], f16)],), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16), [0], True), {})
+cnt: 1, ((T([128, 768, 7, 7], f16), [0], True), {})
+cnt: 1, ((T([128, 384, 14, 14], f16), [0], True), {})
+cnt: 1, ((T([128, 192, 28, 28], f16), [0], True), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([128, 32, 112, 112], f16), T([128, 32, 112, 112], f16), 0), {})
+Operator: aten.unbind.int
+cnt: 4, ((T([3, 128, 6, 196, 64], f16, stride=(75264, 225792, 12544, 1, 196)),), {})
+cnt: 4, ((T([3, 128, 6, 49, 128], f16, stride=(37632, 112896, 6272, 1, 49)),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/vit_base_patch16_224_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/vit_base_patch16_224_training.txt
new file mode 100644
index 0000000000000..8d2c7bd9a7409
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/vit_base_patch16_224_training.txt
@@ -0,0 +1,83 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([64, 12, 197, 197], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([64, 12, 197, 197], f16), T([64, 12, 197, 197], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([64, 12, 197, 64], f16), [768, 197, 64]), {})
+cnt: 12, ((T([64, 12, 64, 197], f16), [768, 64, 197]), {})
+cnt: 12, ((T([768, 197, 197], f16), [64, 12, 197, 197]), {})
+cnt: 12, ((T([768, 197, 64], f16), [64, 12, 197, 64]), {})
+cnt: 12, ((T([64, 197, 12, 64], f16), [64, 197, 768]), {})
+cnt: 12, ((T([64, 197, 3, 12, 64], f16), [64, 197, 2304]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([64, 197, 768], f16), T([1, 197, 768], f16)), {})
+cnt: 48, ((T([64, 197, 768], f16), T([64, 197, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 12, ((T([2304], f16), T([12608, 768], f16), T([768, 2304], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([12608, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([12608, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([12608, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([1000], f16), T([64, 768], f16, stride=(151296, 1)), T([768, 1000], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([768, 197, 64], f16), T([768, 64, 197], f16)), {})
+cnt: 12, ((T([768, 197, 197], f16), T([768, 197, 64], f16)), {})
+cnt: 12, ((T([768, 197, 197], f16, stride=(38809, 1, 197)), T([768, 197, 64], f16)), {})
+cnt: 12, ((T([768, 197, 64], f16), T([768, 64, 197], f16, stride=(12608, 1, 64))), {})
+cnt: 12, ((T([768, 64, 197], f16, stride=(12608, 1, 64)), T([768, 197, 197], f16)), {})
+cnt: 12, ((T([768, 197, 197], f16), T([768, 197, 64], f16, stride=(12608, 1, 197))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 1, 768], f16, stride=(0, 768, 1)), T([64, 196, 768], f16, stride=(150528, 1, 196))], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([768, 3, 16, 16], f16), T([768], f16), [16, 16], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 768, 14, 14], f16, stride=(151296, 1, 10752, 768)), T([64, 3, 224, 224], f16), T([768, 3, 16, 16], f16), [768], [16, 16], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([64, 197, 3072], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([64, 197, 3072], f16), T([64, 197, 3072], f16)), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.mm.default
+cnt: 1, ((T([64, 1000], f16), T([1000, 768], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 768], f16, stride=(151296, 1))), {})
+cnt: 12, ((T([12608, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 12608], f16, stride=(1, 768)), T([12608, 3072], f16)), {})
+cnt: 12, ((T([12608, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 12608], f16, stride=(1, 3072)), T([12608, 768], f16)), {})
+cnt: 12, ((T([12608, 768], f16), T([768, 768], f16)), {})
+cnt: 12, ((T([768, 12608], f16, stride=(1, 768)), T([12608, 768], f16)), {})
+cnt: 12, ((T([12608, 2304], f16), T([2304, 768], f16)), {})
+cnt: 12, ((T([2304, 12608], f16, stride=(1, 2304)), T([12608, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 24, ((T([64, 12, 197, 197], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 25, ((T([64, 197, 768], f16), [768], T([768], f16), T([768], f16), 1e-06), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 25, ((T([64, 197, 768], f16), T([64, 197, 768], f16), [768], T([64, 197, 1], f32), T([64, 197, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.select_backward.default
+cnt: 1, ((T([64, 768], f16), [64, 197, 768], 1, 0), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([64, 197, 768], f16), [64, 197, 768], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.stack.default
+cnt: 12, (([T([64, 12, 197, 64], f16), T([64, 12, 197, 64], f16, stride=(151296, 12608, 1, 197)), T([64, 12, 197, 64], f16)],), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 24, ((T([12608, 768], f16), [0], True), {})
+cnt: 12, ((T([12608, 3072], f16), [0], True), {})
+cnt: 12, ((T([12608, 2304], f16), [0], True), {})
+cnt: 1, ((T([64, 197, 768], f16), [0], True), {})
+cnt: 1, ((T([64, 1, 768], f16, stride=(151296, 768, 1)), [0], True), {})
+Operator: aten.unbind.int
+cnt: 12, ((T([3, 64, 12, 197, 64], f16, stride=(768, 453888, 64, 2304, 1)),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/volo_d1_224_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/volo_d1_224_training.txt
new file mode 100644
index 0000000000000..2f173f535c37b
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/timm_train/volo_d1_224_training.txt
@@ -0,0 +1,216 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([64, 1000], f16), 1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16), 1, f16), {})
+Operator: aten._softmax.default
+cnt: 4, ((T([64, 6, 196, 9, 9], f16, stride=(95256, 81, 486, 9, 1)), -1, False), {})
+cnt: 14, ((T([64, 12, 196, 196], f16), -1, False), {})
+cnt: 2, ((T([64, 12, 1, 197], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 2, ((T([64, 12, 1, 197], f16), T([64, 12, 1, 197], f16), -1, f16), {})
+cnt: 14, ((T([64, 12, 196, 196], f16), T([64, 12, 196, 196], f16), -1, f16), {})
+cnt: 4, ((T([64, 6, 196, 9, 9], f16), T([64, 6, 196, 9, 9], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 12, ((T([50176, 192], f16), [64, 28, 28, 192]), {})
+cnt: 4, ((T([12544, 486], f16), [64, 14, 14, 486]), {})
+cnt: 8, ((T([64, 6, 196, 9, 32], f16), [75264, 9, 32]), {})
+cnt: 4, ((T([75264, 9, 32], f16), [64, 6, 196, 9, 32]), {})
+cnt: 8, ((T([64, 6, 32, 9, 196], f16), [64, 1728, 196]), {})
+cnt: 16, ((T([64, 28, 28, 192], f16), [50176, 192]), {})
+cnt: 4, ((T([50176, 576], f16), [64, 28, 28, 576]), {})
+cnt: 28, ((T([12544, 1152], f16), [64, 14, 14, 1152]), {})
+cnt: 42, ((T([64, 12, 196, 32], f16), [768, 196, 32]), {})
+cnt: 14, ((T([64, 12, 32, 196], f16), [768, 32, 196]), {})
+cnt: 14, ((T([768, 196, 196], f16), [64, 12, 196, 196]), {})
+cnt: 14, ((T([768, 196, 32], f16), [64, 12, 196, 32]), {})
+cnt: 14, ((T([64, 196, 12, 32], f16), [64, 14, 14, 384]), {})
+cnt: 28, ((T([12544, 384], f16), [64, 14, 14, 384]), {})
+cnt: 2, ((T([12608, 768], f16), [64, 197, 768]), {})
+cnt: 2, ((T([64, 384], f16), [64, 1, 384]), {})
+cnt: 2, ((T([64, 12, 32, 197], f16), [768, 32, 197]), {})
+cnt: 2, ((T([768, 1, 197], f16), [64, 12, 1, 197]), {})
+cnt: 2, ((T([64, 12, 197, 32], f16), [768, 197, 32]), {})
+cnt: 2, ((T([768, 1, 32], f16), [64, 12, 1, 32]), {})
+cnt: 1, ((T([64, 196, 384], f16), [12544, 384]), {})
+cnt: 1, ((T([12544, 1000], f16), [64, 196, 1000]), {})
+cnt: 2, ((T([64, 197, 2, 12, 32], f16), [64, 197, 768]), {})
+cnt: 1, ((T([64, 14, 14, 384], f16), [12544, 384]), {})
+cnt: 14, ((T([64, 196, 3, 12, 32], f16), [64, 14, 14, 1152]), {})
+cnt: 4, ((T([64, 196, 6, 9, 9], f16), [64, 14, 14, 486]), {})
+Operator: aten.add.Tensor
+cnt: 4, ((T([64, 14, 14, 486], f16), T([486], f16)), {})
+cnt: 8, ((T([64, 28, 28, 192], f16), T([192], f16)), {})
+cnt: 16, ((T([64, 28, 28, 192], f16, stride=(150528, 28, 1, 784)), T([64, 28, 28, 192], f16)), {})
+cnt: 4, ((T([64, 28, 28, 576], f16), T([576], f16)), {})
+cnt: 1, ((T([64, 14, 14, 384], f16, stride=(75264, 14, 1, 196)), T([1, 14, 14, 384], f16)), {})
+cnt: 28, ((T([64, 14, 14, 384], f16), T([384], f16)), {})
+cnt: 28, ((T([64, 14, 14, 384], f16, stride=(75264, 14, 1, 196)), T([64, 14, 14, 384], f16)), {})
+cnt: 14, ((T([64, 14, 14, 1152], f16), T([1152], f16)), {})
+cnt: 4, ((T([64, 1, 384], f16, stride=(75648, 384, 1)), T([64, 1, 384], f16)), {})
+cnt: 2, ((T([64, 1, 384], f16), T([64, 1, 384], f16)), {})
+cnt: 1, ((T([64, 196, 1000], f16), T([1000], f16)), {})
+cnt: 1, ((T([64, 1000], f16), T([64, 1000], f16)), {})
+cnt: 7, ((T([64, 197, 384], f16), T([64, 197, 384], f16)), {})
+cnt: 1, ((T([64, 14, 14, 384], f16, stride=(75648, 5376, 384, 1)), T([64, 14, 14, 384], f16)), {})
+cnt: 27, ((T([64, 14, 14, 384], f16), T([64, 14, 14, 384], f16)), {})
+cnt: 4, ((T([64, 28, 28, 192], f16), T([64, 28, 28, 192], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 3, ((T([], i64), 1), {})
+Operator: aten.addmm.default
+cnt: 2, ((T([384], f16), T([64, 384], f16), T([384, 384], f16, stride=(1, 384))), {})
+cnt: 2, ((T([1152], f16), T([64, 384], f16), T([384, 1152], f16, stride=(1, 384))), {})
+cnt: 2, ((T([384], f16), T([64, 1152], f16), T([1152, 384], f16, stride=(1, 1152))), {})
+cnt: 1, ((T([1000], f16), T([64, 384], f16, stride=(75648, 1)), T([384, 1000], f16, stride=(1, 384))), {})
+Operator: aten.avg_pool2d.default
+cnt: 4, ((T([64, 192, 28, 28], f16, stride=(150528, 1, 5376, 192)), [2, 2], [2, 2], [0, 0], True), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 4, ((T([64, 192, 14, 14], f16, stride=(37632, 1, 2688, 192)), T([64, 192, 28, 28], f16, stride=(150528, 1, 5376, 192)), [2, 2], [2, 2], [0, 0], True, True, None), {})
+Operator: aten.bmm.default
+cnt: 4, ((T([75264, 9, 9], f16), T([75264, 9, 32], f16)), {})
+cnt: 14, ((T([768, 196, 32], f16), T([768, 32, 196], f16)), {})
+cnt: 14, ((T([768, 196, 196], f16), T([768, 196, 32], f16)), {})
+cnt: 2, ((T([768, 1, 32], f16), T([768, 32, 197], f16)), {})
+cnt: 2, ((T([768, 1, 197], f16), T([768, 197, 32], f16)), {})
+cnt: 2, ((T([768, 197, 1], f16), T([768, 1, 32], f16)), {})
+cnt: 2, ((T([768, 1, 32], f16), T([768, 32, 197], f16, stride=(6304, 1, 32))), {})
+cnt: 2, ((T([768, 32, 1], f16), T([768, 1, 197], f16)), {})
+cnt: 2, ((T([768, 1, 197], f16), T([768, 197, 32], f16, stride=(6304, 1, 197))), {})
+cnt: 14, ((T([768, 196, 196], f16, stride=(38416, 1, 196)), T([768, 196, 32], f16)), {})
+cnt: 14, ((T([768, 196, 32], f16), T([768, 32, 196], f16, stride=(6272, 1, 32))), {})
+cnt: 14, ((T([768, 32, 196], f16, stride=(6272, 1, 32)), T([768, 196, 196], f16)), {})
+cnt: 14, ((T([768, 196, 196], f16), T([768, 196, 32], f16, stride=(6272, 1, 196))), {})
+cnt: 4, ((T([75264, 9, 9], f16, stride=(81, 1, 9)), T([75264, 9, 32], f16)), {})
+cnt: 4, ((T([75264, 9, 32], f16), T([75264, 32, 9], f16, stride=(288, 1, 32))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([64, 1, 384], f16, stride=(0, 384, 1)), T([64, 196, 384], f16, stride=(75264, 1, 196))], 1), {})
+cnt: 2, (([T([64, 1, 384], f16), T([64, 196, 384], f16, stride=(75648, 384, 1))], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.col2im.default
+cnt: 4, ((T([64, 1728, 196], f16), [28, 28], [3, 3], [1, 1], [1, 1], [2, 2]), {})
+Operator: aten.col2im_backward.default
+cnt: 4, ((T([64, 192, 28, 28], f16, stride=(150528, 1, 5376, 192)), [3, 3], [1, 1], [1, 1], [2, 2]), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 64, 112, 112], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([192, 64, 4, 4], f16), T([192], f16), [4, 4], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 192, 28, 28], f16), T([384, 192, 2, 2], f16), T([384], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([64, 384, 14, 14], f16, stride=(75264, 1, 5376, 384)), T([64, 192, 28, 28], f16), T([384, 192, 2, 2], f16), [384], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 192, 28, 28], f16), T([64, 64, 112, 112], f16), T([192, 64, 4, 4], f16), [192], [4, 4], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([64, 64, 112, 112], f16), T([64, 64, 112, 112], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64, 3, 224, 224], f16), T([64, 3, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+Operator: aten.gelu.default
+cnt: 4, ((T([64, 28, 28, 576], f16),), {})
+cnt: 14, ((T([64, 14, 14, 1152], f16),), {})
+cnt: 2, ((T([64, 1, 1152], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 2, ((T([64, 1, 1152], f16), T([64, 1, 1152], f16)), {})
+cnt: 14, ((T([64, 14, 14, 1152], f16), T([64, 14, 14, 1152], f16)), {})
+cnt: 4, ((T([64, 28, 28, 576], f16), T([64, 28, 28, 576], f16)), {})
+Operator: aten.im2col.default
+cnt: 4, ((T([64, 192, 28, 28], f16, stride=(150528, 1, 5376, 192)), [3, 3], [1, 1], [1, 1], [2, 2]), {})
+Operator: aten.im2col_backward.default
+cnt: 4, ((T([64, 1728, 196], f16), [28, 28], [3, 3], [1, 1], [1, 1], [2, 2]), {})
+Operator: aten.lift_fresh_copy.default
+cnt: 1, ((T([64], i64),), {})
+Operator: aten.max.dim
+cnt: 1, ((T([64, 196, 1000], f16), 1), {})
+Operator: aten.mm.default
+cnt: 8, ((T([50176, 192], f16), T([192, 192], f16, stride=(1, 192))), {})
+cnt: 4, ((T([12544, 192], f16), T([192, 486], f16, stride=(1, 192))), {})
+cnt: 4, ((T([50176, 192], f16), T([192, 576], f16, stride=(1, 192))), {})
+cnt: 4, ((T([50176, 576], f16), T([576, 192], f16, stride=(1, 576))), {})
+cnt: 28, ((T([12544, 384], f16), T([384, 1152], f16, stride=(1, 384))), {})
+cnt: 14, ((T([12544, 384], f16), T([384, 384], f16, stride=(1, 384))), {})
+cnt: 14, ((T([12544, 1152], f16), T([1152, 384], f16, stride=(1, 1152))), {})
+cnt: 2, ((T([12608, 384], f16), T([384, 768], f16, stride=(1, 384))), {})
+cnt: 2, ((T([64, 384], f16, stride=(75648, 1)), T([384, 384], f16, stride=(1, 384))), {})
+cnt: 1, ((T([12544, 384], f16), T([384, 1000], f16, stride=(1, 384))), {})
+cnt: 1, ((T([1000, 12544], f16, stride=(1, 1000)), T([12544, 384], f16)), {})
+cnt: 1, ((T([12544, 1000], f16), T([1000, 384], f16)), {})
+cnt: 1, ((T([64, 1000], f16), T([1000, 384], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(1, 1000)), T([64, 384], f16, stride=(75648, 1))), {})
+cnt: 2, ((T([64, 384], f16, stride=(75648, 1)), T([384, 1152], f16)), {})
+cnt: 2, ((T([384, 64], f16, stride=(1, 75648)), T([64, 1152], f16)), {})
+cnt: 2, ((T([64, 1152], f16), T([1152, 384], f16)), {})
+cnt: 2, ((T([1152, 64], f16, stride=(1, 1152)), T([64, 384], f16)), {})
+cnt: 4, ((T([64, 384], f16), T([384, 384], f16)), {})
+cnt: 2, ((T([384, 64], f16, stride=(1, 384)), T([64, 384], f16)), {})
+cnt: 2, ((T([384, 64], f16, stride=(1, 384)), T([64, 384], f16, stride=(75648, 1))), {})
+cnt: 2, ((T([768, 12608], f16, stride=(1, 768)), T([12608, 384], f16)), {})
+cnt: 2, ((T([12608, 768], f16), T([768, 384], f16)), {})
+cnt: 14, ((T([384, 12544], f16, stride=(1, 384)), T([12544, 1152], f16)), {})
+cnt: 14, ((T([12544, 384], f16), T([384, 1152], f16)), {})
+cnt: 28, ((T([1152, 12544], f16, stride=(1, 1152)), T([12544, 384], f16)), {})
+cnt: 28, ((T([12544, 1152], f16), T([1152, 384], f16)), {})
+cnt: 14, ((T([384, 12544], f16, stride=(1, 384)), T([12544, 384], f16)), {})
+cnt: 14, ((T([12544, 384], f16), T([384, 384], f16)), {})
+cnt: 4, ((T([192, 50176], f16, stride=(1, 192)), T([50176, 576], f16)), {})
+cnt: 4, ((T([50176, 192], f16), T([192, 576], f16)), {})
+cnt: 4, ((T([576, 50176], f16, stride=(1, 576)), T([50176, 192], f16)), {})
+cnt: 4, ((T([50176, 576], f16), T([576, 192], f16)), {})
+cnt: 8, ((T([192, 50176], f16, stride=(1, 192)), T([50176, 192], f16)), {})
+cnt: 8, ((T([50176, 192], f16), T([192, 192], f16)), {})
+cnt: 4, ((T([486, 12544], f16, stride=(1, 486)), T([12544, 192], f16)), {})
+cnt: 4, ((T([12544, 486], f16), T([486, 192], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 4, ((T([64, 6, 196, 9, 9], f16, stride=(95256, 81, 486, 9, 1)), 0.1767766952966369), {})
+cnt: 28, ((T([64, 12, 196, 196], f16), 0.1767766952966369), {})
+cnt: 4, ((T([64, 12, 1, 32], f16), 0.1767766952966369), {})
+cnt: 2, ((T([64, 1000], f16), 0.5), {})
+cnt: 4, ((T([64, 6, 196, 9, 9], f16), 0.1767766952966369), {})
+Operator: aten.native_batch_norm.default
+cnt: 3, ((T([64, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 3, ((T([64, 64, 112, 112], f16), T([64, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), True, 1e-05, [True, True, True]), {})
+Operator: aten.native_layer_norm.default
+cnt: 8, ((T([64, 28, 28, 192], f16, stride=(150528, 28, 1, 784)), [192], T([192], f16), T([192], f16), 1e-05), {})
+cnt: 28, ((T([64, 14, 14, 384], f16, stride=(75264, 14, 1, 196)), [384], T([384], f16), T([384], f16), 1e-05), {})
+cnt: 3, ((T([64, 197, 384], f16), [384], T([384], f16), T([384], f16), 1e-05), {})
+cnt: 2, ((T([64, 1, 384], f16), [384], T([384], f16), T([384], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 3, ((T([64, 197, 384], f16), T([64, 197, 384], f16), [384], T([64, 197, 1], f32), T([64, 197, 1], f32), T([384], f16), T([384], f16), [True, True, True]), {})
+cnt: 2, ((T([64, 1, 384], f16), T([64, 1, 384], f16), [384], T([64, 1, 1], f32), T([64, 1, 1], f32), T([384], f16), T([384], f16), [True, True, True]), {})
+cnt: 28, ((T([64, 14, 14, 384], f16), T([64, 14, 14, 384], f16, stride=(75264, 14, 1, 196)), [384], T([64, 14, 14, 1], f32), T([64, 14, 14, 1], f32), T([384], f16), T([384], f16), [True, True, True]), {})
+cnt: 8, ((T([64, 28, 28, 192], f16), T([64, 28, 28, 192], f16, stride=(150528, 28, 1, 784)), [192], T([64, 28, 28, 1], f32), T([64, 28, 28, 1], f32), T([192], f16), T([192], f16), [True, True, True]), {})
+Operator: aten.nll_loss_backward.default
+cnt: 1, ((T([], f16), T([64, 1000], f16), T([64], i64), None, 1, -100, T([], f16)), {})
+Operator: aten.nll_loss_forward.default
+cnt: 1, ((T([64, 1000], f16), T([64], i64), None, 1, -100), {})
+Operator: aten.relu_.default
+cnt: 3, ((T([64, 64, 112, 112], f16),), {})
+Operator: aten.scatter.src
+cnt: 1, ((T([64, 196, 1000], f16), 1, T([64, 1, 1000], i64), T([64, 1, 1000], f16)), {})
+Operator: aten.select_backward.default
+cnt: 1, ((T([64, 384], f16), [64, 197, 384], 1, 0), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([64, 196, 384], f16), [64, 197, 384], 1, 1, 9223372036854775807, 1), {})
+cnt: 8, ((T([64, 197, 384], f16), [64, 197, 384], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([64, 196, 384], f16, stride=(75648, 384, 1)), [64, 197, 384], 1, 1, 9223372036854775807, 1), {})
+cnt: 2, ((T([64, 1, 384], f16), [64, 1, 384], 2, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([64, 1, 384], f16), [64, 197, 384], 1, 0, 1, 1), {})
+Operator: aten.stack.default
+cnt: 2, (([T([64, 12, 197, 32], f16, stride=(75648, 6304, 1, 197)), T([64, 12, 197, 32], f16)],), {})
+cnt: 14, (([T([64, 12, 196, 32], f16), T([64, 12, 196, 32], f16, stride=(75264, 6272, 1, 196)), T([64, 12, 196, 32], f16)],), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 196, 1000], f16), [0, 1], True), {})
+cnt: 1, ((T([64, 1000], f16), [0], True), {})
+cnt: 2, ((T([64, 384], f16, stride=(75648, 1)), [0], True), {})
+cnt: 2, ((T([64, 1152], f16), [0], True), {})
+cnt: 2, ((T([64, 384], f16), [0], True), {})
+cnt: 1, ((T([64, 1, 384], f16, stride=(75648, 384, 1)), [0], True), {})
+cnt: 1, ((T([64, 14, 14, 384], f16, stride=(75648, 5376, 384, 1)), [0, 1, 2], True), {})
+cnt: 14, ((T([64, 14, 14, 1152], f16), [0, 1, 2], True), {})
+cnt: 27, ((T([64, 14, 14, 384], f16), [0, 1, 2], True), {})
+cnt: 1, ((T([64, 14, 14, 384], f16), [0], True), {})
+cnt: 8, ((T([64, 28, 28, 192], f16, stride=(150528, 28, 1, 784)), [0, 1, 2], True), {})
+cnt: 4, ((T([64, 28, 28, 576], f16), [0, 1, 2], True), {})
+cnt: 4, ((T([64, 14, 14, 486], f16), [0, 1, 2], True), {})
+Operator: aten.threshold_backward.default
+cnt: 3, ((T([64, 64, 112, 112], f16), T([64, 64, 112, 112], f16), 0), {})
+Operator: aten.unbind.int
+cnt: 14, ((T([3, 64, 12, 196, 32], f16, stride=(384, 225792, 32, 1152, 1)),), {})
+cnt: 2, ((T([2, 64, 12, 197, 32], f16, stride=(384, 151296, 32, 768, 1)),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/BERT_pytorch_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/BERT_pytorch_training.txt
new file mode 100644
index 0000000000000..6c1b78ab6bfea
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/BERT_pytorch_training.txt
@@ -0,0 +1,94 @@
+Operator: aten._softmax.default
+cnt: 12, ((T([16, 12, 128, 128], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([16, 12, 128, 128], f16), T([16, 12, 128, 128], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([16, 12, 128, 64], f16), [192, 128, 64]), {})
+cnt: 12, ((T([16, 12, 64, 128], f16), [192, 64, 128]), {})
+cnt: 12, ((T([192, 128, 128], f16), [16, 12, 128, 128]), {})
+cnt: 12, ((T([192, 128, 64], f16), [16, 12, 128, 64]), {})
+cnt: 24, ((T([16, 128, 12, 64], f16), [16, 128, 768]), {})
+cnt: 12, ((T([16, 128, 768], f16), [2048, 768]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([16, 128, 768], f16), T([1, 128, 768], f16)), {})
+cnt: 120, ((T([16, 128, 768], f16), T([16, 128, 768], f16)), {})
+cnt: 24, ((T([16, 128, 1], f16), 1e-06), {})
+cnt: 24, ((T([16, 128, 768], f16), T([768], f16)), {})
+cnt: 1, ((T([16, 128, 768], f16, stride=(0, 0, 0)), T([16, 128, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 48, ((T([768], f16), T([2048, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([2048, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([2048, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([192, 128, 64], f16), T([192, 64, 128], f16)), {})
+cnt: 12, ((T([192, 128, 128], f16), T([192, 128, 64], f16)), {})
+cnt: 12, ((T([192, 128, 128], f16, stride=(16384, 1, 128)), T([192, 128, 64], f16)), {})
+cnt: 12, ((T([192, 128, 64], f16), T([192, 64, 128], f16, stride=(8192, 1, 64))), {})
+cnt: 12, ((T([192, 64, 128], f16, stride=(8192, 1, 64)), T([192, 128, 128], f16)), {})
+cnt: 12, ((T([192, 128, 128], f16), T([192, 128, 64], f16, stride=(8192, 1, 128))), {})
+Operator: aten.clone.default
+cnt: 2, ((T([16, 128], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([16, 128], i64), T([16, 128], i64)), {})
+Operator: aten.div.Scalar
+cnt: 24, ((T([16, 128, 768], f16, stride=(128, 1, 0)), 768), {})
+Operator: aten.div.Tensor
+cnt: 96, ((T([16, 128, 768], f16), T([16, 128, 1], f16)), {})
+cnt: 24, ((T([16, 12, 128, 128], f16), 8.0), {})
+cnt: 2, ((T([], f16), 1572864), {})
+cnt: 24, ((T([16, 128, 1], f16), T([16, 128, 1], f16)), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([20005, 768], f16), T([16, 128], i64), 0), {})
+cnt: 1, ((T([3, 768], f16), T([16, 128], i64), 0), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([16, 128, 768], f16), T([16, 128], i64), 3, 0, False), {})
+cnt: 1, ((T([16, 128, 768], f16), T([16, 128], i64), 20005, 0, False), {})
+Operator: aten.eq.Scalar
+cnt: 12, ((T([16, 1, 128, 128], b8), 0), {})
+cnt: 24, ((T([16, 128, 1], f16), 0), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([16, 128, 3072], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([16, 128, 3072], f16), T([16, 128, 3072], f16)), {})
+Operator: aten.gt.Scalar
+cnt: 1, ((T([16, 128], i64), 0), {})
+Operator: aten.masked_fill.Scalar
+cnt: 12, ((T([16, 12, 128, 128], f16), T([16, 1, 128, 128], b8), -65504.0), {})
+cnt: 12, ((T([16, 12, 128, 128], f16), T([16, 1, 128, 128], b8), 0), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 24, ((T([16, 128, 1], f16), T([16, 128, 1], b8), 0), {})
+Operator: aten.mean.dim
+cnt: 48, ((T([16, 128, 768], f16), [-1], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([2048, 768], f16, stride=(0, 0)), T([768, 3072], f16)), {})
+cnt: 1, ((T([768, 2048], f16, stride=(0, 0)), T([2048, 3072], f16)), {})
+cnt: 12, ((T([2048, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 2048], f16, stride=(1, 3072)), T([2048, 768], f16)), {})
+cnt: 48, ((T([2048, 768], f16), T([768, 768], f16)), {})
+cnt: 48, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 768], f16)), {})
+cnt: 11, ((T([2048, 768], f16), T([768, 3072], f16)), {})
+cnt: 11, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 3072], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 24, ((T([16, 128, 1], f16), 2), {})
+cnt: 24, ((T([16, 128, 1], f16), 0.002607561929595828), {})
+Operator: aten.mul.Tensor
+cnt: 24, ((T([768], f16), T([16, 128, 768], f16)), {})
+cnt: 48, ((T([16, 128, 768], f16), T([16, 128, 768], f16)), {})
+cnt: 24, ((T([16, 128, 768], f16), T([768], f16)), {})
+cnt: 24, ((T([16, 128, 1], f16), T([16, 128, 768], f16)), {})
+Operator: aten.neg.default
+cnt: 48, ((T([16, 128, 768], f16),), {})
+Operator: aten.repeat.default
+cnt: 1, ((T([16, 1, 128], b8), [1, 128, 1]), {})
+Operator: aten.std.correction
+cnt: 24, ((T([16, 128, 768], f16), [-1]), {'correction': 1, 'keepdim': True})
+Operator: aten.sub.Tensor
+cnt: 48, ((T([16, 128, 768], f16), T([16, 128, 1], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([2048, 768], f16, stride=(0, 0)), [0], True), {})
+cnt: 12, ((T([2048, 3072], f16), [0], True), {})
+cnt: 48, ((T([16, 128, 768], f16), [0, 1], True), {})
+cnt: 48, ((T([16, 128, 768], f16), [2], True), {})
+cnt: 59, ((T([2048, 768], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([16, 128, 768], f16),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/Background_Matting_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/Background_Matting_training.txt
new file mode 100644
index 0000000000000..fbc1f47d5c8fd
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/Background_Matting_training.txt
@@ -0,0 +1,119 @@
+Operator: aten.add.Tensor
+cnt: 27, ((T([3, 256, 128, 128], f16), T([3, 256, 128, 128], f16)), {})
+cnt: 1, ((T([], f16), 0), {})
+cnt: 1, ((T([], f16), T([], f16)), {})
+cnt: 1, ((T([3, 256, 128, 128], f16, stride=(7340032, 16384, 128, 1)), T([3, 256, 128, 128], f16, stride=(8388608, 16384, 128, 1))), {})
+cnt: 2, ((T([3, 256, 128, 128], f16), T([3, 256, 128, 128], f16, stride=(8388608, 16384, 128, 1))), {})
+cnt: 1, ((T([3, 256, 128, 128], f16, stride=(8388608, 16384, 128, 1)), T([3, 256, 128, 128], f16, stride=(8388608, 16384, 128, 1))), {})
+cnt: 1, ((T([3, 128, 256, 256], f16, stride=(16777216, 65536, 256, 1)), T([3, 128, 256, 256], f16)), {})
+Operator: aten.cat.default
+cnt: 2, (([T([3, 256, 128, 128], f16), T([3, 256, 128, 128], f16)], 1), {})
+cnt: 1, (([T([3, 256, 128, 128], f16), T([3, 256, 128, 128], f16, stride=(4194304, 1, 32768, 256))], 1), {})
+cnt: 1, (([T([3, 64, 128, 128], f16), T([3, 64, 128, 128], f16), T([3, 64, 128, 128], f16)], 1), {})
+cnt: 1, (([T([3, 256, 128, 128], f16), T([3, 192, 128, 128], f16)], 1), {})
+cnt: 1, (([T([3, 128, 256, 256], f16), T([3, 128, 256, 256], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 2, ((T([3, 3, 512, 512], f16),), {})
+cnt: 1, ((T([3, 1, 512, 512], f16),), {})
+cnt: 1, ((T([3, 4, 512, 512], f16),), {})
+Operator: aten.convolution.default
+cnt: 2, ((T([3, 3, 518, 518], f16), T([64, 3, 7, 7], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([3, 64, 512, 512], f16), T([128, 64, 3, 3], f16), T([128], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([3, 128, 256, 256], f16), T([256, 128, 3, 3], f16), T([256], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([3, 1, 518, 518], f16), T([64, 1, 7, 7], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([3, 64, 512, 512], f16, stride=(16777216, 1, 32768, 64)), T([128, 64, 3, 3], f16), T([128], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([3, 128, 256, 256], f16, stride=(8388608, 1, 32768, 128)), T([256, 128, 3, 3], f16), T([256], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([3, 4, 518, 518], f16), T([64, 4, 7, 7], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([3, 512, 128, 128], f16), T([64, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([3, 448, 128, 128], f16), T([256, 448, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 26, ((T([3, 256, 130, 130], f16), T([256, 256, 3, 3], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([3, 256, 256, 256], f16), T([128, 256, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([3, 128, 512, 512], f16), T([64, 128, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([3, 64, 518, 518], f16), T([1, 64, 7, 7], f16), T([1], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([3, 256, 512, 512], f16), T([64, 256, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([3, 64, 518, 518], f16), T([3, 64, 7, 7], f16), T([3], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([3, 3, 512, 512], f16, stride=(0, 0, 0, 0)), T([3, 64, 518, 518], f16), T([3, 64, 7, 7], f16), [3], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([3, 64, 512, 512], f16), T([3, 256, 512, 512], f16), T([64, 256, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([3, 128, 256, 256], f16), T([3, 256, 256, 256], f16), T([128, 256, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 26, ((T([3, 256, 128, 128], f16), T([3, 256, 130, 130], f16), T([256, 256, 3, 3], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([3, 1, 512, 512], f16), T([3, 64, 518, 518], f16), T([1, 64, 7, 7], f16), [1], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([3, 64, 512, 512], f16), T([3, 128, 512, 512], f16), T([64, 128, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([3, 256, 128, 128], f16), T([3, 448, 128, 128], f16), T([256, 448, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([3, 64, 128, 128], f16), T([3, 512, 128, 128], f16), T([64, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([3, 256, 128, 128], f16, stride=(4194304, 1, 32768, 256)), T([3, 128, 256, 256], f16, stride=(8388608, 1, 32768, 128)), T([256, 128, 3, 3], f16), [256], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([3, 128, 256, 256], f16, stride=(8388608, 1, 32768, 128)), T([3, 64, 512, 512], f16, stride=(16777216, 1, 32768, 64)), T([128, 64, 3, 3], f16), [128], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([3, 64, 512, 512], f16, stride=(16777216, 1, 32768, 64)), T([3, 1, 518, 518], f16), T([64, 1, 7, 7], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+cnt: 2, ((T([3, 256, 128, 128], f16), T([3, 128, 256, 256], f16), T([256, 128, 3, 3], f16), [256], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([3, 128, 256, 256], f16), T([3, 64, 512, 512], f16), T([128, 64, 3, 3], f16), [128], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([3, 64, 512, 512], f16), T([3, 3, 518, 518], f16), T([64, 3, 7, 7], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([3, 3, 512, 512], f16), T([3, 3, 512, 512], f16)), {})
+cnt: 1, ((T([3, 1, 512, 512], f16), T([3, 1, 512, 512], f16)), {})
+cnt: 1, ((T([3, 4, 512, 512], f16), T([3, 4, 512, 512], f16)), {})
+cnt: 1, ((T([256, 128, 3, 3], f16), T([256, 128, 3, 3], f16, stride=(1152, 1, 384, 128))), {})
+cnt: 1, ((T([128, 64, 3, 3], f16), T([128, 64, 3, 3], f16, stride=(576, 1, 192, 64))), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 786432), {})
+cnt: 2, ((T([], f16), 2359296), {})
+cnt: 2, ((T([], f16), 2), {})
+Operator: aten.native_batch_norm.default
+cnt: 5, ((T([3, 64, 512, 512], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([3, 128, 256, 256], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 30, ((T([3, 256, 128, 128], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([3, 64, 512, 512], f16, stride=(16777216, 1, 32768, 64)), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([3, 128, 256, 256], f16, stride=(8388608, 1, 32768, 128)), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([3, 256, 128, 128], f16, stride=(4194304, 1, 32768, 256)), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([3, 64, 128, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 4, ((T([3, 64, 512, 512], f16), T([3, 64, 512, 512], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([3, 128, 256, 256], f16), T([3, 128, 256, 256], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 29, ((T([3, 256, 128, 128], f16), T([3, 256, 128, 128], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([3, 64, 128, 128], f16), T([3, 64, 128, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([3, 256, 128, 128], f16, stride=(4194304, 1, 32768, 256)), T([3, 256, 128, 128], f16, stride=(4194304, 1, 32768, 256)), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([3, 128, 256, 256], f16, stride=(8388608, 1, 32768, 128)), T([3, 128, 256, 256], f16, stride=(8388608, 1, 32768, 128)), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([3, 64, 512, 512], f16, stride=(16777216, 1, 32768, 64)), T([3, 64, 512, 512], f16, stride=(16777216, 1, 32768, 64)), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.new_empty_strided.default
+cnt: 1, ((T([256, 128, 3, 3], f16, stride=(1152, 1, 384, 128)), [256, 128, 3, 3], [1152, 9, 3, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([128, 64, 3, 3], f16, stride=(576, 1, 192, 64)), [128, 64, 3, 3], [576, 9, 3, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.reflection_pad2d.default
+cnt: 2, ((T([3, 3, 512, 512], f16), [3, 3, 3, 3]), {})
+cnt: 1, ((T([3, 1, 512, 512], f16), [3, 3, 3, 3]), {})
+cnt: 1, ((T([3, 4, 512, 512], f16), [3, 3, 3, 3]), {})
+cnt: 26, ((T([3, 256, 128, 128], f16), [1, 1, 1, 1]), {})
+cnt: 2, ((T([3, 64, 512, 512], f16), [3, 3, 3, 3]), {})
+Operator: aten.reflection_pad2d_backward.default
+cnt: 2, ((T([3, 64, 518, 518], f16), T([3, 64, 512, 512], f16), [3, 3, 3, 3]), {})
+cnt: 26, ((T([3, 256, 130, 130], f16), T([3, 256, 128, 128], f16), [1, 1, 1, 1]), {})
+Operator: aten.relu_.default
+cnt: 5, ((T([3, 64, 512, 512], f16),), {})
+cnt: 5, ((T([3, 128, 256, 256], f16),), {})
+cnt: 17, ((T([3, 256, 128, 128], f16),), {})
+cnt: 1, ((T([3, 64, 512, 512], f16, stride=(16777216, 1, 32768, 64)),), {})
+cnt: 1, ((T([3, 128, 256, 256], f16, stride=(8388608, 1, 32768, 128)),), {})
+cnt: 1, ((T([3, 256, 128, 128], f16, stride=(4194304, 1, 32768, 256)),), {})
+cnt: 3, ((T([3, 64, 128, 128], f16),), {})
+Operator: aten.sum.default
+cnt: 1, ((T([3, 1, 512, 512], f16),), {})
+cnt: 1, ((T([3, 3, 512, 512], f16),), {})
+Operator: aten.tanh.default
+cnt: 1, ((T([3, 1, 512, 512], f16),), {})
+Operator: aten.tanh_backward.default
+cnt: 1, ((T([3, 1, 512, 512], f16, stride=(0, 0, 0, 0)), T([3, 1, 512, 512], f16)), {})
+Operator: aten.threshold_backward.default
+cnt: 4, ((T([3, 64, 512, 512], f16), T([3, 64, 512, 512], f16), 0), {})
+cnt: 1, ((T([3, 128, 256, 256], f16, stride=(16777216, 65536, 256, 1)), T([3, 128, 256, 256], f16), 0), {})
+cnt: 16, ((T([3, 256, 128, 128], f16), T([3, 256, 128, 128], f16), 0), {})
+cnt: 3, ((T([3, 128, 256, 256], f16), T([3, 128, 256, 256], f16), 0), {})
+cnt: 3, ((T([3, 64, 128, 128], f16, stride=(7340032, 16384, 128, 1)), T([3, 64, 128, 128], f16), 0), {})
+cnt: 1, ((T([3, 256, 128, 128], f16, stride=(8388608, 16384, 128, 1)), T([3, 256, 128, 128], f16, stride=(4194304, 1, 32768, 256)), 0), {})
+cnt: 1, ((T([3, 128, 256, 256], f16, stride=(8388608, 1, 32768, 128)), T([3, 128, 256, 256], f16, stride=(8388608, 1, 32768, 128)), 0), {})
+cnt: 1, ((T([3, 64, 512, 512], f16, stride=(16777216, 1, 32768, 64)), T([3, 64, 512, 512], f16, stride=(16777216, 1, 32768, 64)), 0), {})
+Operator: aten.upsample_bilinear2d.vec
+cnt: 2, ((T([3, 256, 128, 128], f16), None, True, [2.0, 2.0]), {})
+cnt: 1, ((T([3, 128, 256, 256], f16), None, True, [2.0, 2.0]), {})
+cnt: 1, ((T([3, 256, 256, 256], f16), None, True, [2.0, 2.0]), {})
+Operator: aten.upsample_bilinear2d_backward.vec
+cnt: 1, ((T([3, 256, 512, 512], f16), None, [3, 256, 256, 256], True, [2.0, 2.0]), {})
+cnt: 2, ((T([3, 256, 256, 256], f16), None, [3, 256, 128, 128], True, [2.0, 2.0]), {})
+cnt: 1, ((T([3, 128, 512, 512], f16), None, [3, 128, 256, 256], True, [2.0, 2.0]), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/LearningToPaint_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/LearningToPaint_training.txt
new file mode 100644
index 0000000000000..272e9fb338582
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/LearningToPaint_training.txt
@@ -0,0 +1,86 @@
+Operator: aten.add.Tensor
+cnt: 1, ((T([96, 512, 4, 4], f16), T([96, 512, 4, 4], f16)), {})
+cnt: 2, ((T([96, 256, 8, 8], f16), T([96, 256, 8, 8], f16)), {})
+cnt: 2, ((T([96, 128, 16, 16], f16), T([96, 128, 16, 16], f16)), {})
+cnt: 2, ((T([96, 64, 32, 32], f16), T([96, 64, 32, 32], f16)), {})
+cnt: 1, ((T([96, 64, 64, 64], f16), T([96, 64, 64, 64], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 2, ((T([96, 64, 32, 32], f16), T([96, 64, 32, 32], f16)), {})
+cnt: 2, ((T([96, 128, 16, 16], f16), T([96, 128, 16, 16], f16)), {})
+cnt: 2, ((T([96, 256, 8, 8], f16), T([96, 256, 8, 8], f16)), {})
+cnt: 2, ((T([96, 512, 4, 4], f16), T([96, 512, 4, 4], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([65], f16), T([96, 512], f16), T([512, 65], f16, stride=(1, 512))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([96, 512, 4, 4], f16), [4, 4]), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([96, 512, 1, 1], f16), T([96, 512, 4, 4], f16), [4, 4], [], [0, 0], False, True, None), {})
+Operator: aten.clone.default
+cnt: 1, ((T([96, 9, 128, 128], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([96, 9, 128, 128], f16), T([64, 9, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 64, 64, 64], f16), T([64, 64, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([96, 64, 32, 32], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 64, 64, 64], f16), T([64, 64, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 64, 32, 32], f16), T([128, 64, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([96, 128, 16, 16], f16), T([128, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 64, 32, 32], f16), T([128, 64, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 128, 16, 16], f16), T([256, 128, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([96, 256, 8, 8], f16), T([256, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 128, 16, 16], f16), T([256, 128, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 256, 8, 8], f16), T([512, 256, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([96, 512, 4, 4], f16), T([512, 512, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 256, 8, 8], f16), T([512, 256, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([96, 512, 4, 4], f16), T([96, 512, 4, 4], f16), T([512, 512, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 512, 4, 4], f16), T([96, 256, 8, 8], f16), T([512, 256, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 512, 4, 4], f16), T([96, 256, 8, 8], f16), T([512, 256, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([96, 256, 8, 8], f16), T([96, 256, 8, 8], f16), T([256, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 256, 8, 8], f16), T([96, 128, 16, 16], f16), T([256, 128, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 256, 8, 8], f16), T([96, 128, 16, 16], f16), T([256, 128, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([96, 128, 16, 16], f16), T([96, 128, 16, 16], f16), T([128, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 128, 16, 16], f16), T([96, 64, 32, 32], f16), T([128, 64, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 128, 16, 16], f16), T([96, 64, 32, 32], f16), T([128, 64, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([96, 64, 32, 32], f16), T([96, 64, 32, 32], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 64, 32, 32], f16), T([96, 64, 64, 64], f16), T([64, 64, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 64, 32, 32], f16), T([96, 64, 64, 64], f16), T([64, 64, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 64, 64, 64], f16), T([96, 9, 128, 128], f16), T([64, 9, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([96, 9, 128, 128], f16), T([96, 9, 128, 128], f16)), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 6240), {})
+Operator: aten.mm.default
+cnt: 1, ((T([96, 65], f16), T([65, 512], f16)), {})
+cnt: 1, ((T([65, 96], f16, stride=(1, 65)), T([96, 512], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([96, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([96, 64, 32, 32], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([96, 128, 16, 16], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([96, 256, 8, 8], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([96, 512, 4, 4], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 5, ((T([96, 512, 4, 4], f16), T([96, 512, 4, 4], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([96, 256, 8, 8], f16), T([96, 256, 8, 8], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([96, 128, 16, 16], f16), T([96, 128, 16, 16], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([96, 64, 32, 32], f16), T([96, 64, 32, 32], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([96, 64, 64, 64], f16), T([96, 64, 64, 64], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.relu.default
+cnt: 1, ((T([96, 64, 64, 64], f16),), {})
+cnt: 4, ((T([96, 64, 32, 32], f16),), {})
+cnt: 4, ((T([96, 128, 16, 16], f16),), {})
+cnt: 4, ((T([96, 256, 8, 8], f16),), {})
+cnt: 4, ((T([96, 512, 4, 4], f16),), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([96, 65], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 1, ((T([96, 65], f16, stride=(0, 0)), T([96, 65], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([96, 65], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([96, 65], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 4, ((T([96, 512, 4, 4], f16), T([96, 512, 4, 4], f16), 0), {})
+cnt: 4, ((T([96, 256, 8, 8], f16), T([96, 256, 8, 8], f16), 0), {})
+cnt: 4, ((T([96, 128, 16, 16], f16), T([96, 128, 16, 16], f16), 0), {})
+cnt: 4, ((T([96, 64, 32, 32], f16), T([96, 64, 32, 32], f16), 0), {})
+cnt: 1, ((T([96, 64, 64, 64], f16), T([96, 64, 64, 64], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/Super_SloMo_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/Super_SloMo_training.txt
new file mode 100644
index 0000000000000..ff432c07b7abf
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/Super_SloMo_training.txt
@@ -0,0 +1,255 @@
+Operator: aten._to_copy.default
+cnt: 12, ((T([6, 352, 352], i64, stride=(0, 352, 1)),), {'dtype': f16})
+Operator: aten.abs.default
+cnt: 5, ((T([6, 3, 352, 352], f16),), {})
+cnt: 2, ((T([6, 2, 352, 351], f16),), {})
+cnt: 2, ((T([6, 2, 351, 352], f16),), {})
+Operator: aten.add.Tensor
+cnt: 22, ((T([6, 2, 352, 352], f16), T([6, 2, 352, 352], f16)), {})
+cnt: 8, ((T([6, 352, 352], f16), T([6, 352, 352], f16, stride=(247808, 352, 1))), {})
+cnt: 2, ((T([6, 2, 352, 352], f16, stride=(619520, 123904, 352, 1)), T([6, 2, 352, 352], f16)), {})
+cnt: 2, ((T([6, 3, 352, 352], f16), T([6, 3, 352, 352], f16)), {})
+cnt: 4, ((T([6, 1, 352, 352], f16), T([6, 1, 352, 352], f16)), {})
+cnt: 10, ((T([], f16), T([], f16)), {})
+cnt: 4, ((T([6, 352, 352], f16), T([6, 352, 352], f16, stride=(495616, 352, 1))), {})
+cnt: 1, ((T([], f16), 0), {})
+cnt: 1, ((T([6, 3, 352, 352], f16, stride=(0, 0, 0, 0)), T([6, 3, 352, 352], f16)), {})
+cnt: 2, ((T([6, 5, 352, 352], f16), T([6, 5, 352, 352], f16)), {})
+cnt: 2, ((T([6, 512, 22, 22], f16, stride=(495616, 484, 22, 1)), T([6, 512, 22, 22], f16)), {})
+cnt: 2, ((T([6, 256, 44, 44], f16, stride=(991232, 1936, 44, 1)), T([6, 256, 44, 44], f16)), {})
+cnt: 2, ((T([6, 128, 88, 88], f16, stride=(1982464, 7744, 88, 1)), T([6, 128, 88, 88], f16)), {})
+cnt: 2, ((T([6, 64, 176, 176], f16, stride=(3964928, 30976, 176, 1)), T([6, 64, 176, 176], f16)), {})
+cnt: 2, ((T([6, 32, 352, 352], f16, stride=(7929856, 123904, 352, 1)), T([6, 32, 352, 352], f16)), {})
+cnt: 4, ((T([6, 2, 352, 352], f16), T([6, 2, 352, 352], f16, stride=(2478080, 123904, 352, 1))), {})
+cnt: 2, ((T([6, 3, 352, 352], f16), T([6, 3, 352, 352], f16, stride=(2478080, 123904, 352, 1))), {})
+cnt: 1, ((T([6, 4, 352, 352], f16), T([6, 4, 352, 352], f16)), {})
+Operator: aten.avg_pool2d.default
+cnt: 2, ((T([6, 32, 352, 352], f16), [2, 2]), {})
+cnt: 2, ((T([6, 64, 176, 176], f16), [2, 2]), {})
+cnt: 2, ((T([6, 128, 88, 88], f16), [2, 2]), {})
+cnt: 2, ((T([6, 256, 44, 44], f16), [2, 2]), {})
+cnt: 2, ((T([6, 512, 22, 22], f16), [2, 2]), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 2, ((T([6, 512, 11, 11], f16), T([6, 512, 22, 22], f16), [2, 2], [], [0, 0], False, True, None), {})
+cnt: 2, ((T([6, 256, 22, 22], f16), T([6, 256, 44, 44], f16), [2, 2], [], [0, 0], False, True, None), {})
+cnt: 2, ((T([6, 128, 44, 44], f16), T([6, 128, 88, 88], f16), [2, 2], [], [0, 0], False, True, None), {})
+cnt: 2, ((T([6, 64, 88, 88], f16), T([6, 64, 176, 176], f16), [2, 2], [], [0, 0], False, True, None), {})
+cnt: 2, ((T([6, 32, 176, 176], f16), T([6, 32, 352, 352], f16), [2, 2], [], [0, 0], False, True, None), {})
+Operator: aten.cat.default
+cnt: 1, (([T([6, 3, 352, 352], f16), T([6, 3, 352, 352], f16)], 1), {})
+cnt: 2, (([T([6, 512, 22, 22], f16), T([6, 512, 22, 22], f16)], 1), {})
+cnt: 2, (([T([6, 256, 44, 44], f16), T([6, 256, 44, 44], f16)], 1), {})
+cnt: 2, (([T([6, 128, 88, 88], f16), T([6, 128, 88, 88], f16)], 1), {})
+cnt: 2, (([T([6, 64, 176, 176], f16), T([6, 64, 176, 176], f16)], 1), {})
+cnt: 2, (([T([6, 32, 352, 352], f16), T([6, 32, 352, 352], f16)], 1), {})
+cnt: 1, (([T([6, 3, 352, 352], f16), T([6, 3, 352, 352], f16), T([6, 2, 352, 352], f16, stride=(495616, 123904, 352, 1)), T([6, 2, 352, 352], f16, stride=(495616, 123904, 352, 1)), T([6, 2, 352, 352], f16), T([6, 2, 352, 352], f16), T([6, 3, 352, 352], f16), T([6, 3, 352, 352], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([6], i64),), {})
+cnt: 3, ((T([6, 3, 352, 352], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([6, 6, 352, 352], f16), T([32, 6, 7, 7], f16), T([32], f16), [1, 1], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([6, 32, 352, 352], f16), T([32, 32, 7, 7], f16), T([32], f16), [1, 1], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([6, 32, 176, 176], f16), T([64, 32, 5, 5], f16), T([64], f16), [1, 1], [2, 2], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([6, 64, 176, 176], f16), T([64, 64, 5, 5], f16), T([64], f16), [1, 1], [2, 2], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([6, 64, 88, 88], f16), T([128, 64, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([6, 128, 88, 88], f16), T([128, 128, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([6, 128, 44, 44], f16), T([256, 128, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([6, 256, 44, 44], f16), T([256, 256, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([6, 256, 22, 22], f16), T([512, 256, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([6, 512, 22, 22], f16), T([512, 512, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([6, 512, 11, 11], f16), T([512, 512, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([6, 1024, 22, 22], f16), T([512, 1024, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([6, 512, 44, 44], f16), T([256, 512, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([6, 256, 88, 88], f16), T([128, 256, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([6, 128, 176, 176], f16), T([64, 128, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([6, 64, 352, 352], f16), T([32, 64, 3, 3], f16), T([32], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([6, 32, 352, 352], f16), T([4, 32, 3, 3], f16), T([4], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([6, 20, 352, 352], f16), T([32, 20, 7, 7], f16), T([32], f16), [1, 1], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([6, 32, 352, 352], f16), T([5, 32, 3, 3], f16), T([5], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([6, 3, 352, 352], f16), T([64, 3, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([6, 64, 352, 352], f16), T([64, 64, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([6, 64, 176, 176], f16), T([128, 64, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([6, 128, 176, 176], f16), T([128, 128, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([6, 128, 88, 88], f16), T([256, 128, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([6, 256, 88, 88], f16), T([256, 256, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([6, 256, 44, 44], f16), T([512, 256, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([6, 512, 44, 44], f16), T([512, 512, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 2, ((T([6, 512, 44, 44], f16), T([6, 512, 44, 44], f16), T([512, 512, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, False, False]), {})
+cnt: 1, ((T([6, 512, 44, 44], f16), T([6, 256, 44, 44], f16), T([512, 256, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, False, False]), {})
+cnt: 2, ((T([6, 256, 88, 88], f16), T([6, 256, 88, 88], f16), T([256, 256, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, False, False]), {})
+cnt: 1, ((T([6, 256, 88, 88], f16), T([6, 128, 88, 88], f16), T([256, 128, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, False, False]), {})
+cnt: 1, ((T([6, 128, 176, 176], f16), T([6, 128, 176, 176], f16), T([128, 128, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, False, False]), {})
+cnt: 1, ((T([6, 128, 176, 176], f16), T([6, 64, 176, 176], f16), T([128, 64, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, False, False]), {})
+cnt: 1, ((T([6, 64, 352, 352], f16), T([6, 64, 352, 352], f16), T([64, 64, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, False, False]), {})
+cnt: 1, ((T([6, 64, 352, 352], f16), T([6, 3, 352, 352], f16), T([64, 3, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, False, False]), {})
+cnt: 1, ((T([6, 5, 352, 352], f16), T([6, 32, 352, 352], f16), T([5, 32, 3, 3], f16), [5], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([6, 32, 352, 352], f16), T([6, 64, 352, 352], f16), T([32, 64, 3, 3], f16), [32], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([6, 64, 176, 176], f16), T([6, 128, 176, 176], f16), T([64, 128, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([6, 128, 88, 88], f16), T([6, 256, 88, 88], f16), T([128, 256, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([6, 256, 44, 44], f16), T([6, 512, 44, 44], f16), T([256, 512, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([6, 512, 22, 22], f16), T([6, 1024, 22, 22], f16), T([512, 1024, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([6, 512, 22, 22], f16), T([6, 512, 22, 22], f16), T([512, 512, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([6, 512, 11, 11], f16), T([6, 512, 11, 11], f16), T([512, 512, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([6, 512, 22, 22], f16), T([6, 256, 22, 22], f16), T([512, 256, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([6, 256, 44, 44], f16), T([6, 256, 44, 44], f16), T([256, 256, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([6, 256, 44, 44], f16), T([6, 128, 44, 44], f16), T([256, 128, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([6, 128, 88, 88], f16), T([6, 128, 88, 88], f16), T([128, 128, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([6, 128, 88, 88], f16), T([6, 64, 88, 88], f16), T([128, 64, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([6, 64, 176, 176], f16), T([6, 64, 176, 176], f16), T([64, 64, 5, 5], f16), [64], [1, 1], [2, 2], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([6, 64, 176, 176], f16), T([6, 32, 176, 176], f16), T([64, 32, 5, 5], f16), [64], [1, 1], [2, 2], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([6, 32, 352, 352], f16), T([6, 32, 352, 352], f16), T([32, 32, 7, 7], f16), [32], [1, 1], [3, 3], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([6, 32, 352, 352], f16), T([6, 20, 352, 352], f16), T([32, 20, 7, 7], f16), [32], [1, 1], [3, 3], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([6, 4, 352, 352], f16), T([6, 32, 352, 352], f16), T([4, 32, 3, 3], f16), [4], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([6, 32, 352, 352], f16), T([6, 6, 352, 352], f16), T([32, 6, 7, 7], f16), [32], [1, 1], [3, 3], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([6], i64), T([6], i64)), {})
+cnt: 3, ((T([6, 3, 352, 352], f16), T([6, 3, 352, 352], f16)), {})
+Operator: aten.div.Scalar
+cnt: 2, ((T([6, 2, 351, 352], f16, stride=(0, 0, 0, 0)), 1482624), {})
+cnt: 2, ((T([6, 2, 352, 351], f16, stride=(0, 0, 0, 0)), 1482624), {})
+cnt: 5, ((T([6, 3, 352, 352], f16, stride=(0, 0, 0, 0)), 2230272), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([6, 352, 352], f16), 352), {})
+cnt: 4, ((T([6, 3, 352, 352], f16), T([6, 1, 352, 352], f16)), {})
+cnt: 2, ((T([], f16), 2230272), {})
+cnt: 2, ((T([], f16), 1), {})
+cnt: 2, ((T([], f16), 2), {})
+Operator: aten.grid_sampler_2d.default
+cnt: 6, ((T([6, 3, 352, 352], f16), T([6, 352, 352, 2], f16), 0, 0, False), {})
+Operator: aten.grid_sampler_2d_backward.default
+cnt: 6, ((T([6, 3, 352, 352], f16), T([6, 3, 352, 352], f16), T([6, 352, 352, 2], f16), 0, 0, False, [False, True]), {})
+Operator: aten.index.Tensor
+cnt: 8, ((T([7], f16), [T([6], i64)]), {})
+Operator: aten.leaky_relu.default
+cnt: 8, ((T([6, 32, 352, 352], f16), 0.1), {})
+cnt: 8, ((T([6, 64, 176, 176], f16), 0.1), {})
+cnt: 8, ((T([6, 128, 88, 88], f16), 0.1), {})
+cnt: 8, ((T([6, 256, 44, 44], f16), 0.1), {})
+cnt: 8, ((T([6, 512, 22, 22], f16), 0.1), {})
+cnt: 4, ((T([6, 512, 11, 11], f16), 0.1), {})
+cnt: 1, ((T([6, 4, 352, 352], f16), 0.1), {})
+cnt: 1, ((T([6, 5, 352, 352], f16), 0.1), {})
+Operator: aten.leaky_relu_backward.default
+cnt: 1, ((T([6, 5, 352, 352], f16), T([6, 5, 352, 352], f16), 0.1, False), {})
+cnt: 6, ((T([6, 32, 352, 352], f16), T([6, 32, 352, 352], f16), 0.1, False), {})
+cnt: 2, ((T([6, 32, 352, 352], f16, stride=(7929856, 123904, 352, 1)), T([6, 32, 352, 352], f16), 0.1, False), {})
+cnt: 6, ((T([6, 64, 176, 176], f16), T([6, 64, 176, 176], f16), 0.1, False), {})
+cnt: 2, ((T([6, 64, 176, 176], f16, stride=(3964928, 30976, 176, 1)), T([6, 64, 176, 176], f16), 0.1, False), {})
+cnt: 6, ((T([6, 128, 88, 88], f16), T([6, 128, 88, 88], f16), 0.1, False), {})
+cnt: 2, ((T([6, 128, 88, 88], f16, stride=(1982464, 7744, 88, 1)), T([6, 128, 88, 88], f16), 0.1, False), {})
+cnt: 6, ((T([6, 256, 44, 44], f16), T([6, 256, 44, 44], f16), 0.1, False), {})
+cnt: 2, ((T([6, 256, 44, 44], f16, stride=(991232, 1936, 44, 1)), T([6, 256, 44, 44], f16), 0.1, False), {})
+cnt: 6, ((T([6, 512, 22, 22], f16), T([6, 512, 22, 22], f16), 0.1, False), {})
+cnt: 2, ((T([6, 512, 22, 22], f16, stride=(495616, 484, 22, 1)), T([6, 512, 22, 22], f16), 0.1, False), {})
+cnt: 4, ((T([6, 512, 11, 11], f16), T([6, 512, 11, 11], f16), 0.1, False), {})
+cnt: 1, ((T([6, 4, 352, 352], f16), T([6, 4, 352, 352], f16), 0.1, False), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 2, ((T([6, 64, 352, 352], f16), [2, 2], [2, 2]), {})
+cnt: 2, ((T([6, 128, 176, 176], f16), [2, 2], [2, 2]), {})
+cnt: 2, ((T([6, 256, 88, 88], f16), [2, 2], [2, 2]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([6, 256, 44, 44], f16), T([6, 256, 88, 88], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([6, 256, 44, 44], i64)), {})
+cnt: 1, ((T([6, 128, 88, 88], f16), T([6, 128, 176, 176], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([6, 128, 88, 88], i64)), {})
+cnt: 1, ((T([6, 64, 176, 176], f16), T([6, 64, 352, 352], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([6, 64, 176, 176], i64)), {})
+Operator: aten.mean.default
+cnt: 5, ((T([6, 3, 352, 352], f16),), {})
+cnt: 2, ((T([6, 2, 352, 351], f16),), {})
+cnt: 2, ((T([6, 2, 351, 352], f16),), {})
+Operator: aten.mse_loss.default
+cnt: 1, ((T([6, 512, 44, 44], f16), T([6, 512, 44, 44], f16)), {})
+Operator: aten.mse_loss_backward.default
+cnt: 1, ((T([], f16), T([6, 512, 44, 44], f16), T([6, 512, 44, 44], f16), 1), {})
+Operator: aten.mul.Tensor
+cnt: 3, ((T([6], f16), T([6], f16)), {})
+cnt: 4, ((T([6, 1, 1, 1], f16), T([6, 2, 352, 352], f16, stride=(495616, 123904, 352, 1))), {})
+cnt: 12, ((T([6, 352, 352], f16), 2), {})
+cnt: 4, ((T([6, 1, 1, 1], f16), T([6, 1, 352, 352], f16)), {})
+cnt: 2, ((T([6, 1, 352, 352], f16), T([6, 3, 352, 352], f16)), {})
+cnt: 2, ((T([], f16), 204), {})
+cnt: 2, ((T([], f16), 102), {})
+cnt: 2, ((T([], f16), 0.005), {})
+cnt: 2, ((T([6, 2, 351, 352], f16), T([6, 2, 351, 352], f16)), {})
+cnt: 2, ((T([6, 2, 352, 351], f16), T([6, 2, 352, 351], f16)), {})
+cnt: 8, ((T([6, 3, 352, 352], f16), T([6, 3, 352, 352], f16)), {})
+cnt: 12, ((T([6, 352, 352], f16, stride=(247808, 704, 2)), 2), {})
+cnt: 4, ((T([6, 1, 352, 352], f16), T([6, 1, 1, 1], f16)), {})
+cnt: 2, ((T([6, 3, 352, 352], f16), T([6, 1, 352, 352], f16)), {})
+cnt: 4, ((T([6, 2, 352, 352], f16), T([6, 1, 1, 1], f16)), {})
+Operator: aten.neg.default
+cnt: 1, ((T([6], f16),), {})
+cnt: 2, ((T([6, 2, 351, 352], f16),), {})
+cnt: 2, ((T([6, 2, 352, 351], f16),), {})
+cnt: 1, ((T([6, 3, 352, 352], f16),), {})
+cnt: 1, ((T([6, 1, 352, 352], f16),), {})
+Operator: aten.relu_.default
+cnt: 4, ((T([6, 64, 352, 352], f16),), {})
+cnt: 4, ((T([6, 128, 176, 176], f16),), {})
+cnt: 6, ((T([6, 256, 88, 88], f16),), {})
+cnt: 4, ((T([6, 512, 44, 44], f16),), {})
+Operator: aten.rsub.Scalar
+cnt: 4, ((T([6], f16), 1), {})
+cnt: 1, ((T([6, 1, 352, 352], f16), 1), {})
+Operator: aten.select_backward.default
+cnt: 6, ((T([6, 352, 352], f16), [6, 2, 352, 352], 1, 1), {})
+cnt: 6, ((T([6, 352, 352], f16), [6, 2, 352, 352], 1, 0), {})
+Operator: aten.sgn.default
+cnt: 2, ((T([6, 2, 351, 352], f16),), {})
+cnt: 2, ((T([6, 2, 352, 351], f16),), {})
+cnt: 5, ((T([6, 3, 352, 352], f16),), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([6, 1, 352, 352], f16, stride=(619520, 123904, 352, 1)),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 1, ((T([6, 1, 352, 352], f16), T([6, 1, 352, 352], f16)), {})
+Operator: aten.slice_backward.default
+cnt: 4, ((T([6, 2, 351, 352], f16), [6, 2, 351, 352], 3, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([6, 2, 351, 352], f16), [6, 2, 352, 352], 2, 1, 9223372036854775807, 1), {})
+cnt: 8, ((T([6, 2, 352, 352], f16), [6, 2, 352, 352], 1, 0, 9223372036854775807, 1), {})
+cnt: 20, ((T([6, 2, 352, 352], f16), [6, 2, 352, 352], 0, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([6, 2, 351, 352], f16), [6, 2, 352, 352], 2, 0, -1, 1), {})
+cnt: 2, ((T([6, 2, 352, 351], f16), [6, 2, 352, 352], 3, 1, 9223372036854775807, 1), {})
+cnt: 8, ((T([6, 2, 352, 352], f16), [6, 2, 352, 352], 2, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([6, 2, 352, 351], f16), [6, 2, 352, 352], 3, 0, -1, 1), {})
+cnt: 12, ((T([6, 352, 352], f16), [6, 352, 352], 2, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([6, 352, 352], f16), [6, 352, 352], 1, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([6, 1, 352, 352], f16), [6, 1, 352, 352], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([6, 1, 352, 352], f16), [6, 1, 352, 352], 2, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([6, 1, 352, 352], f16), [6, 5, 352, 352], 1, 4, 5, 1), {})
+cnt: 3, ((T([6, 5, 352, 352], f16), [6, 5, 352, 352], 0, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([6, 2, 352, 352], f16), [6, 2, 352, 352], 3, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([6, 2, 352, 352], f16), [6, 5, 352, 352], 1, 2, 4, 1), {})
+cnt: 1, ((T([6, 2, 352, 352], f16), [6, 5, 352, 352], 1, 0, 2, 1), {})
+cnt: 1, ((T([6, 2, 352, 352], f16), [6, 4, 352, 352], 1, 2, 9223372036854775807, 1), {})
+cnt: 2, ((T([6, 4, 352, 352], f16), [6, 4, 352, 352], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([6, 2, 352, 352], f16), [6, 4, 352, 352], 1, 0, 2, 1), {})
+Operator: aten.stack.default
+cnt: 6, (([T([6, 352, 352], f16), T([6, 352, 352], f16)], 3), {})
+Operator: aten.sub.Tensor
+cnt: 12, ((T([6, 352, 352], f16), 0.5), {})
+cnt: 5, ((T([6, 3, 352, 352], f16), T([6, 3, 352, 352], f16)), {})
+cnt: 2, ((T([6, 2, 352, 351], f16, stride=(495616, 123904, 352, 1)), T([6, 2, 352, 351], f16, stride=(495616, 123904, 352, 1))), {})
+cnt: 2, ((T([6, 2, 351, 352], f16, stride=(495616, 123904, 352, 1)), T([6, 2, 351, 352], f16, stride=(495616, 123904, 352, 1))), {})
+Operator: aten.sum.SymInt
+cnt: 3, ((T([6, 3, 352, 352], f16), [1], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([6, 3, 352, 352], f16),), {})
+cnt: 1, ((T([], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 2, ((T([6, 512, 44, 44], f16), T([6, 512, 44, 44], f16), 0), {})
+cnt: 3, ((T([6, 256, 88, 88], f16), T([6, 256, 88, 88], f16), 0), {})
+cnt: 2, ((T([6, 128, 176, 176], f16), T([6, 128, 176, 176], f16), 0), {})
+cnt: 2, ((T([6, 64, 352, 352], f16), T([6, 64, 352, 352], f16), 0), {})
+Operator: aten.unbind.int
+cnt: 6, ((T([6, 352, 352, 2], f16), 3), {})
+Operator: aten.upsample_bilinear2d.vec
+cnt: 2, ((T([6, 512, 11, 11], f16), None, False, [2.0, 2.0]), {})
+cnt: 2, ((T([6, 512, 22, 22], f16), None, False, [2.0, 2.0]), {})
+cnt: 2, ((T([6, 256, 44, 44], f16), None, False, [2.0, 2.0]), {})
+cnt: 2, ((T([6, 128, 88, 88], f16), None, False, [2.0, 2.0]), {})
+cnt: 2, ((T([6, 64, 176, 176], f16), None, False, [2.0, 2.0]), {})
+Operator: aten.upsample_bilinear2d_backward.vec
+cnt: 2, ((T([6, 64, 352, 352], f16), None, [6, 64, 176, 176], False, [2.0, 2.0]), {})
+cnt: 2, ((T([6, 128, 176, 176], f16), None, [6, 128, 88, 88], False, [2.0, 2.0]), {})
+cnt: 2, ((T([6, 256, 88, 88], f16), None, [6, 256, 44, 44], False, [2.0, 2.0]), {})
+cnt: 2, ((T([6, 512, 44, 44], f16), None, [6, 512, 22, 22], False, [2.0, 2.0]), {})
+cnt: 2, ((T([6, 512, 22, 22], f16), None, [6, 512, 11, 11], False, [2.0, 2.0]), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/alexnet_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/alexnet_training.txt
new file mode 100644
index 0000000000000..a235e1b0535ee
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/alexnet_training.txt
@@ -0,0 +1,58 @@
+Operator: aten._adaptive_avg_pool2d.default
+cnt: 1, ((T([128, 256, 6, 6], f16), [6, 6]), {})
+Operator: aten._adaptive_avg_pool2d_backward.default
+cnt: 1, ((T([128, 256, 6, 6], f16), T([128, 256, 6, 6], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([4096], f16), T([128, 9216], f16), T([9216, 4096], f16, stride=(1, 9216))), {})
+cnt: 1, ((T([4096], f16), T([128, 4096], f16), T([4096, 4096], f16, stride=(1, 4096))), {})
+cnt: 1, ((T([1000], f16), T([128, 4096], f16), T([4096, 1000], f16, stride=(1, 4096))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([64, 3, 11, 11], f16), T([64], f16), [4, 4], [2, 2], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 27, 27], f16), T([192, 64, 5, 5], f16), T([192], f16), [1, 1], [2, 2], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 192, 13, 13], f16), T([384, 192, 3, 3], f16), T([384], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 384, 13, 13], f16), T([256, 384, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 13, 13], f16), T([256, 256, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 256, 13, 13], f16), T([128, 256, 13, 13], f16), T([256, 256, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 13, 13], f16), T([128, 384, 13, 13], f16), T([256, 384, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 384, 13, 13], f16), T([128, 192, 13, 13], f16), T([384, 192, 3, 3], f16), [384], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 192, 27, 27], f16), T([128, 64, 27, 27], f16), T([192, 64, 5, 5], f16), [192], [1, 1], [2, 2], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 55, 55], f16), T([128, 3, 224, 224], f16), T([64, 3, 11, 11], f16), [64], [4, 4], [2, 2], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 128000), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([128, 64, 55, 55], f16), [3, 3], [2, 2]), {})
+cnt: 1, ((T([128, 192, 27, 27], f16), [3, 3], [2, 2]), {})
+cnt: 1, ((T([128, 256, 13, 13], f16), [3, 3], [2, 2]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([128, 256, 6, 6], f16), T([128, 256, 13, 13], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([128, 256, 6, 6], i64)), {})
+cnt: 1, ((T([128, 192, 13, 13], f16), T([128, 192, 27, 27], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([128, 192, 13, 13], i64)), {})
+cnt: 1, ((T([128, 64, 27, 27], f16), T([128, 64, 55, 55], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([128, 64, 27, 27], i64)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16, stride=(0, 0)), T([1000, 4096], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(0, 0)), T([128, 4096], f16)), {})
+cnt: 1, ((T([128, 4096], f16), T([4096, 4096], f16)), {})
+cnt: 1, ((T([4096, 128], f16, stride=(1, 4096)), T([128, 4096], f16)), {})
+cnt: 1, ((T([128, 4096], f16), T([4096, 9216], f16)), {})
+cnt: 1, ((T([4096, 128], f16, stride=(1, 4096)), T([128, 9216], f16)), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 64, 55, 55], f16),), {})
+cnt: 1, ((T([128, 192, 27, 27], f16),), {})
+cnt: 1, ((T([128, 384, 13, 13], f16),), {})
+cnt: 2, ((T([128, 256, 13, 13], f16),), {})
+cnt: 2, ((T([128, 4096], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16, stride=(0, 0)), [0], True), {})
+cnt: 2, ((T([128, 4096], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([128, 1000], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 2, ((T([128, 4096], f16), T([128, 4096], f16), 0), {})
+cnt: 2, ((T([128, 256, 13, 13], f16), T([128, 256, 13, 13], f16), 0), {})
+cnt: 1, ((T([128, 384, 13, 13], f16), T([128, 384, 13, 13], f16), 0), {})
+cnt: 1, ((T([128, 192, 27, 27], f16), T([128, 192, 27, 27], f16), 0), {})
+cnt: 1, ((T([128, 64, 55, 55], f16), T([128, 64, 55, 55], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/attention_is_all_you_need_pytorch_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/attention_is_all_you_need_pytorch_training.txt
new file mode 100644
index 0000000000000..16700c6bb7da4
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/attention_is_all_you_need_pytorch_training.txt
@@ -0,0 +1,148 @@
+Operator: aten._softmax.default
+cnt: 6, ((T([256, 8, 33, 33], f16), -1, False), {})
+cnt: 6, ((T([256, 8, 31, 31], f16), -1, False), {})
+cnt: 6, ((T([256, 8, 31, 33], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 6, ((T([256, 8, 31, 33], f16), T([256, 8, 31, 33], f16), -1, f16), {})
+cnt: 6, ((T([256, 8, 31, 31], f16), T([256, 8, 31, 31], f16), -1, f16), {})
+cnt: 6, ((T([256, 8, 33, 33], f16), T([256, 8, 33, 33], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([1, 31, 31], f32),), {'dtype': torch.bool})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([8448, 512], f16), [256, 33, 512]), {})
+cnt: 24, ((T([256, 8, 33, 64], f16), [2048, 33, 64]), {})
+cnt: 12, ((T([256, 8, 64, 33], f16), [2048, 64, 33]), {})
+cnt: 6, ((T([2048, 33, 33], f16), [256, 8, 33, 33]), {})
+cnt: 6, ((T([2048, 33, 64], f16), [256, 8, 33, 64]), {})
+cnt: 36, ((T([7936, 512], f16), [256, 31, 512]), {})
+cnt: 30, ((T([256, 8, 31, 64], f16), [2048, 31, 64]), {})
+cnt: 6, ((T([256, 8, 64, 31], f16), [2048, 64, 31]), {})
+cnt: 6, ((T([2048, 31, 31], f16), [256, 8, 31, 31]), {})
+cnt: 12, ((T([2048, 31, 64], f16), [256, 8, 31, 64]), {})
+cnt: 6, ((T([2048, 31, 33], f16), [256, 8, 31, 33]), {})
+cnt: 1, ((T([7936, 9521], f16), [256, 31, 9521]), {})
+cnt: 18, ((T([256, 33, 8, 64], f16), [256, 33, 512]), {})
+cnt: 12, ((T([256, 33, 512], f16), [8448, 512]), {})
+cnt: 18, ((T([256, 31, 8, 64], f16), [256, 31, 512]), {})
+cnt: 6, ((T([256, 31, 512], f16), [7936, 512]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([256, 33, 512], f16), T([1, 33, 512], f16)), {})
+cnt: 1, ((T([256, 31, 512], f16), T([1, 31, 512], f16)), {})
+cnt: 30, ((T([256, 31, 512], f16), T([256, 31, 512], f16)), {})
+cnt: 35, ((T([256, 33, 512], f16), T([256, 33, 512], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 12, ((T([256, 33, 512], f16), T([256, 33, 512], f16)), {})
+cnt: 18, ((T([256, 31, 512], f16), T([256, 31, 512], f16)), {})
+Operator: aten.addmm.default
+cnt: 6, ((T([2048], f16), T([8448, 512], f16), T([512, 2048], f16, stride=(1, 512))), {})
+cnt: 6, ((T([512], f16), T([8448, 2048], f16), T([2048, 512], f16, stride=(1, 2048))), {})
+cnt: 6, ((T([2048], f16), T([7936, 512], f16), T([512, 2048], f16, stride=(1, 512))), {})
+cnt: 6, ((T([512], f16), T([7936, 2048], f16), T([2048, 512], f16, stride=(1, 2048))), {})
+Operator: aten.bitwise_and.Tensor
+cnt: 1, ((T([256, 1, 31], b8, stride=(1, 7936, 256)), T([1, 31, 31], b8)), {})
+Operator: aten.bmm.default
+cnt: 6, ((T([2048, 33, 64], f16), T([2048, 64, 33], f16)), {})
+cnt: 6, ((T([2048, 33, 33], f16), T([2048, 33, 64], f16)), {})
+cnt: 6, ((T([2048, 31, 64], f16), T([2048, 64, 31], f16)), {})
+cnt: 6, ((T([2048, 31, 31], f16), T([2048, 31, 64], f16)), {})
+cnt: 6, ((T([2048, 31, 64], f16), T([2048, 64, 33], f16)), {})
+cnt: 6, ((T([2048, 31, 33], f16), T([2048, 33, 64], f16)), {})
+cnt: 6, ((T([2048, 33, 31], f16, stride=(1023, 1, 33)), T([2048, 31, 64], f16)), {})
+cnt: 6, ((T([2048, 31, 64], f16), T([2048, 64, 33], f16, stride=(2112, 1, 64))), {})
+cnt: 6, ((T([2048, 64, 31], f16, stride=(1984, 1, 64)), T([2048, 31, 33], f16)), {})
+cnt: 6, ((T([2048, 31, 33], f16), T([2048, 33, 64], f16, stride=(2112, 1, 33))), {})
+cnt: 6, ((T([2048, 31, 31], f16, stride=(961, 1, 31)), T([2048, 31, 64], f16)), {})
+cnt: 6, ((T([2048, 31, 64], f16), T([2048, 64, 31], f16, stride=(1984, 1, 64))), {})
+cnt: 6, ((T([2048, 64, 31], f16, stride=(1984, 1, 64)), T([2048, 31, 31], f16)), {})
+cnt: 6, ((T([2048, 31, 31], f16), T([2048, 31, 64], f16, stride=(1984, 1, 31))), {})
+cnt: 6, ((T([2048, 33, 33], f16, stride=(1089, 1, 33)), T([2048, 33, 64], f16)), {})
+cnt: 6, ((T([2048, 33, 64], f16), T([2048, 64, 33], f16, stride=(2112, 1, 64))), {})
+cnt: 6, ((T([2048, 64, 33], f16, stride=(2112, 1, 64)), T([2048, 33, 33], f16)), {})
+cnt: 6, ((T([2048, 33, 33], f16), T([2048, 33, 64], f16, stride=(2112, 1, 33))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([256, 33], i64, stride=(1, 256)),), {})
+cnt: 1, ((T([256, 31], i64, stride=(1, 256)),), {})
+cnt: 1, ((T([1, 33, 512], f16),), {})
+cnt: 1, ((T([1, 31, 512], f16),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([256, 33], i64, stride=(1, 256)), T([256, 33], i64, stride=(1, 256))), {})
+cnt: 1, ((T([256, 31], i64, stride=(1, 256)), T([256, 31], i64, stride=(1, 256))), {})
+cnt: 12, ((T([256, 31, 512], f16), T([256, 31, 512], f16)), {})
+cnt: 6, ((T([7936, 512], f16), T([7936, 512], f16)), {})
+cnt: 12, ((T([256, 33, 512], f16), T([256, 33, 512], f16)), {})
+cnt: 6, ((T([8448, 512], f16), T([8448, 512], f16)), {})
+Operator: aten.div.Tensor
+cnt: 6, ((T([256, 8, 33, 64], f16, stride=(16896, 64, 512, 1)), 8.0), {})
+cnt: 12, ((T([256, 8, 31, 64], f16, stride=(15872, 64, 512, 1)), 8.0), {})
+cnt: 2, ((T([], f16), 75558656), {})
+cnt: 12, ((T([256, 8, 31, 64], f16), 8.0), {})
+cnt: 6, ((T([256, 8, 33, 64], f16), 8.0), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([9521, 512], f16), T([256, 33], i64, stride=(1, 256)), 1), {})
+cnt: 1, ((T([9521, 512], f16), T([256, 31], i64, stride=(1, 256)), 1), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([256, 31, 512], f16), T([256, 31], i64, stride=(1, 256)), 9521, 1, False), {})
+cnt: 1, ((T([256, 33, 512], f16), T([256, 33], i64, stride=(1, 256)), 9521, 1, False), {})
+Operator: aten.eq.Scalar
+cnt: 12, ((T([256, 1, 1, 33], b8, stride=(1, 8448, 8448, 256)), 0), {})
+cnt: 6, ((T([256, 1, 31, 31], b8, stride=(1, 7936, 256, 7936)), 0), {})
+Operator: aten.masked_fill.Scalar
+cnt: 6, ((T([256, 8, 33, 33], f16), T([256, 1, 1, 33], b8, stride=(1, 8448, 8448, 256)), -65504.0), {})
+cnt: 6, ((T([256, 8, 31, 31], f16), T([256, 1, 31, 31], b8, stride=(1, 7936, 256, 7936)), -65504.0), {})
+cnt: 6, ((T([256, 8, 31, 33], f16), T([256, 1, 1, 33], b8, stride=(1, 8448, 8448, 256)), -65504.0), {})
+cnt: 6, ((T([256, 8, 31, 33], f16), T([256, 1, 1, 33], b8, stride=(1, 8448, 8448, 256)), 0), {})
+cnt: 6, ((T([256, 8, 31, 31], f16), T([256, 1, 31, 31], b8, stride=(1, 7936, 256, 7936)), 0), {})
+cnt: 6, ((T([256, 8, 33, 33], f16), T([256, 1, 1, 33], b8, stride=(1, 8448, 8448, 256)), 0), {})
+Operator: aten.mm.default
+cnt: 36, ((T([8448, 512], f16), T([512, 512], f16, stride=(1, 512))), {})
+cnt: 36, ((T([7936, 512], f16), T([512, 512], f16, stride=(1, 512))), {})
+cnt: 1, ((T([7936, 512], f16), T([512, 9521], f16, stride=(1, 512))), {})
+cnt: 1, ((T([9521, 7936], f16, stride=(1, 9521)), T([7936, 512], f16)), {})
+cnt: 1, ((T([7936, 9521], f16), T([9521, 512], f16)), {})
+cnt: 6, ((T([7936, 512], f16), T([512, 2048], f16)), {})
+cnt: 6, ((T([512, 7936], f16, stride=(1, 512)), T([7936, 2048], f16)), {})
+cnt: 6, ((T([7936, 2048], f16), T([2048, 512], f16)), {})
+cnt: 6, ((T([2048, 7936], f16, stride=(1, 2048)), T([7936, 512], f16)), {})
+cnt: 36, ((T([512, 7936], f16, stride=(1, 512)), T([7936, 512], f16)), {})
+cnt: 36, ((T([7936, 512], f16), T([512, 512], f16)), {})
+cnt: 36, ((T([512, 8448], f16, stride=(1, 512)), T([8448, 512], f16)), {})
+cnt: 36, ((T([8448, 512], f16), T([512, 512], f16)), {})
+cnt: 6, ((T([8448, 512], f16), T([512, 2048], f16)), {})
+cnt: 6, ((T([512, 8448], f16, stride=(1, 512)), T([8448, 2048], f16)), {})
+cnt: 6, ((T([8448, 2048], f16), T([2048, 512], f16)), {})
+cnt: 6, ((T([2048, 8448], f16, stride=(1, 2048)), T([8448, 512], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([256, 31, 9521], f16), 1.0), {})
+cnt: 1, ((T([256, 31, 9521], f16, stride=(0, 0, 0)), 1.0), {})
+Operator: aten.native_layer_norm.default
+cnt: 13, ((T([256, 33, 512], f16), [512], T([512], f16), T([512], f16), 1e-06), {})
+cnt: 19, ((T([256, 31, 512], f16), [512], T([512], f16), T([512], f16), 1e-06), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 19, ((T([256, 31, 512], f16), T([256, 31, 512], f16), [512], T([256, 31, 1], f32), T([256, 31, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+cnt: 13, ((T([256, 33, 512], f16), T([256, 33, 512], f16), [512], T([256, 33, 1], f32), T([256, 33, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+Operator: aten.ne.Scalar
+cnt: 1, ((T([256, 33], i64, stride=(1, 256)), 1), {})
+cnt: 1, ((T([256, 31], i64, stride=(1, 256)), 1), {})
+Operator: aten.new_empty_strided.default
+cnt: 6, ((T([7936, 512], f16), [7936, 512], [512, 1]), {})
+cnt: 6, ((T([8448, 512], f16), [8448, 512], [512, 1]), {})
+Operator: aten.new_zeros.default
+cnt: 6, ((T([256, 31, 512], f16), [4063232]), {})
+cnt: 6, ((T([256, 33, 512], f16), [4325376]), {})
+Operator: aten.relu.default
+cnt: 6, ((T([256, 33, 2048], f16),), {})
+cnt: 6, ((T([256, 31, 2048], f16),), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([1, 31, 31], f32), 1), {})
+Operator: aten.sum.SymInt
+cnt: 6, ((T([7936, 512], f16), [0], True), {})
+cnt: 6, ((T([7936, 2048], f16), [0], True), {})
+cnt: 6, ((T([8448, 512], f16), [0], True), {})
+cnt: 6, ((T([8448, 2048], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([7936, 9521], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 6, ((T([256, 31, 2048], f16), T([256, 31, 2048], f16), 0), {})
+cnt: 6, ((T([256, 33, 2048], f16), T([256, 33, 2048], f16), 0), {})
+Operator: aten.triu.default
+cnt: 1, ((T([1, 31, 31], f32), 1), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/dcgan_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/dcgan_training.txt
new file mode 100644
index 0000000000000..0adf5dcbf66d2
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/dcgan_training.txt
@@ -0,0 +1,42 @@
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 64, 64], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 64, 64], f16), T([64, 3, 4, 4], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 32, 32], f16), T([128, 64, 4, 4], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 16, 16], f16), T([256, 128, 4, 4], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 8, 8], f16), T([512, 256, 4, 4], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 4, 4], f16), T([1, 512, 4, 4], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([32, 1, 1, 1], f16), T([32, 512, 4, 4], f16), T([1, 512, 4, 4], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 4, 4], f16), T([32, 256, 8, 8], f16), T([512, 256, 4, 4], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 8, 8], f16), T([32, 128, 16, 16], f16), T([256, 128, 4, 4], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 16, 16], f16), T([32, 64, 32, 32], f16), T([128, 64, 4, 4], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 32, 32], f16), T([32, 3, 64, 64], f16), T([64, 3, 4, 4], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 64, 64], f16), T([32, 3, 64, 64], f16)), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 32), {})
+Operator: aten.leaky_relu_.default
+cnt: 1, ((T([32, 64, 32, 32], f16), 0.2), {})
+cnt: 1, ((T([32, 128, 16, 16], f16), 0.2), {})
+cnt: 1, ((T([32, 256, 8, 8], f16), 0.2), {})
+cnt: 1, ((T([32, 512, 4, 4], f16), 0.2), {})
+Operator: aten.leaky_relu_backward.default
+cnt: 1, ((T([32, 512, 4, 4], f16), T([32, 512, 4, 4], f16), 0.2, True), {})
+cnt: 1, ((T([32, 256, 8, 8], f16), T([32, 256, 8, 8], f16), 0.2, True), {})
+cnt: 1, ((T([32, 128, 16, 16], f16), T([32, 128, 16, 16], f16), 0.2, True), {})
+cnt: 1, ((T([32, 64, 32, 32], f16), T([32, 64, 32, 32], f16), 0.2, True), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([32, 128, 16, 16], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 256, 8, 8], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 512, 4, 4], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([32, 512, 4, 4], f16), T([32, 512, 4, 4], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 256, 8, 8], f16), T([32, 256, 8, 8], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 128, 16, 16], f16), T([32, 128, 16, 16], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([32, 1, 1, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 1, ((T([32, 1, 1, 1], f16, stride=(0, 0, 0, 0)), T([32, 1, 1, 1], f16)), {})
+Operator: aten.sum.default
+cnt: 1, ((T([32, 1, 1, 1], f16),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/densenet121_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/densenet121_training.txt
new file mode 100644
index 0000000000000..80f89b7834620
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/densenet121_training.txt
@@ -0,0 +1,609 @@
+Operator: aten.add.Tensor
+cnt: 1, ((T([4, 512, 7, 7], f16, stride=(50176, 49, 7, 1)), T([4, 512, 7, 7], f16, stride=(48608, 49, 7, 1))), {})
+cnt: 15, ((T([4, 32, 7, 7], f16, stride=(50176, 49, 7, 1)), T([4, 32, 7, 7], f16, stride=(48608, 49, 7, 1))), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16, stride=(47040, 49, 7, 1))), {})
+cnt: 14, ((T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16, stride=(47040, 49, 7, 1))), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16, stride=(45472, 49, 7, 1))), {})
+cnt: 13, ((T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16, stride=(45472, 49, 7, 1))), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16, stride=(43904, 49, 7, 1))), {})
+cnt: 12, ((T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16, stride=(43904, 49, 7, 1))), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16, stride=(42336, 49, 7, 1))), {})
+cnt: 11, ((T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16, stride=(42336, 49, 7, 1))), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16, stride=(40768, 49, 7, 1))), {})
+cnt: 10, ((T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16, stride=(40768, 49, 7, 1))), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16, stride=(39200, 49, 7, 1))), {})
+cnt: 9, ((T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16, stride=(39200, 49, 7, 1))), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16, stride=(37632, 49, 7, 1))), {})
+cnt: 8, ((T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16, stride=(37632, 49, 7, 1))), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16, stride=(36064, 49, 7, 1))), {})
+cnt: 7, ((T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16, stride=(36064, 49, 7, 1))), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16, stride=(34496, 49, 7, 1))), {})
+cnt: 6, ((T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16, stride=(34496, 49, 7, 1))), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16, stride=(32928, 49, 7, 1))), {})
+cnt: 5, ((T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16, stride=(32928, 49, 7, 1))), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16, stride=(31360, 49, 7, 1))), {})
+cnt: 4, ((T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16, stride=(31360, 49, 7, 1))), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16, stride=(29792, 49, 7, 1))), {})
+cnt: 3, ((T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16, stride=(29792, 49, 7, 1))), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16, stride=(28224, 49, 7, 1))), {})
+cnt: 2, ((T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16, stride=(28224, 49, 7, 1))), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16, stride=(26656, 49, 7, 1))), {})
+cnt: 1, ((T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16, stride=(26656, 49, 7, 1))), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16)), {})
+cnt: 1, ((T([4, 256, 14, 14], f16, stride=(200704, 196, 14, 1)), T([4, 256, 14, 14], f16, stride=(194432, 196, 14, 1))), {})
+cnt: 23, ((T([4, 32, 14, 14], f16, stride=(200704, 196, 14, 1)), T([4, 32, 14, 14], f16, stride=(194432, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(188160, 196, 14, 1))), {})
+cnt: 22, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(188160, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(181888, 196, 14, 1))), {})
+cnt: 21, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(181888, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(175616, 196, 14, 1))), {})
+cnt: 20, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(175616, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(169344, 196, 14, 1))), {})
+cnt: 19, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(169344, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(163072, 196, 14, 1))), {})
+cnt: 18, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(163072, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(156800, 196, 14, 1))), {})
+cnt: 17, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(156800, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(150528, 196, 14, 1))), {})
+cnt: 16, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(150528, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(144256, 196, 14, 1))), {})
+cnt: 15, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(144256, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(137984, 196, 14, 1))), {})
+cnt: 14, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(137984, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(131712, 196, 14, 1))), {})
+cnt: 13, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(131712, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(125440, 196, 14, 1))), {})
+cnt: 12, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(125440, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(119168, 196, 14, 1))), {})
+cnt: 11, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(119168, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(112896, 196, 14, 1))), {})
+cnt: 10, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(112896, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(106624, 196, 14, 1))), {})
+cnt: 9, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(106624, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(100352, 196, 14, 1))), {})
+cnt: 8, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(100352, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(94080, 196, 14, 1))), {})
+cnt: 7, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(94080, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(87808, 196, 14, 1))), {})
+cnt: 6, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(87808, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(81536, 196, 14, 1))), {})
+cnt: 5, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(81536, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(75264, 196, 14, 1))), {})
+cnt: 4, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(75264, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(68992, 196, 14, 1))), {})
+cnt: 3, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(68992, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(62720, 196, 14, 1))), {})
+cnt: 2, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(62720, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16, stride=(56448, 196, 14, 1))), {})
+cnt: 1, ((T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16, stride=(56448, 196, 14, 1))), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16)), {})
+cnt: 1, ((T([4, 128, 28, 28], f16, stride=(401408, 784, 28, 1)), T([4, 128, 28, 28], f16, stride=(376320, 784, 28, 1))), {})
+cnt: 11, ((T([4, 32, 28, 28], f16, stride=(401408, 784, 28, 1)), T([4, 32, 28, 28], f16, stride=(376320, 784, 28, 1))), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 128, 28, 28], f16, stride=(351232, 784, 28, 1))), {})
+cnt: 10, ((T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16, stride=(351232, 784, 28, 1))), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 128, 28, 28], f16, stride=(326144, 784, 28, 1))), {})
+cnt: 9, ((T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16, stride=(326144, 784, 28, 1))), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 128, 28, 28], f16, stride=(301056, 784, 28, 1))), {})
+cnt: 8, ((T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16, stride=(301056, 784, 28, 1))), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 128, 28, 28], f16, stride=(275968, 784, 28, 1))), {})
+cnt: 7, ((T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16, stride=(275968, 784, 28, 1))), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 128, 28, 28], f16, stride=(250880, 784, 28, 1))), {})
+cnt: 6, ((T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16, stride=(250880, 784, 28, 1))), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 128, 28, 28], f16, stride=(225792, 784, 28, 1))), {})
+cnt: 5, ((T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16, stride=(225792, 784, 28, 1))), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 128, 28, 28], f16, stride=(200704, 784, 28, 1))), {})
+cnt: 4, ((T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16, stride=(200704, 784, 28, 1))), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 128, 28, 28], f16, stride=(175616, 784, 28, 1))), {})
+cnt: 3, ((T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16, stride=(175616, 784, 28, 1))), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 128, 28, 28], f16, stride=(150528, 784, 28, 1))), {})
+cnt: 2, ((T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16, stride=(150528, 784, 28, 1))), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 128, 28, 28], f16, stride=(125440, 784, 28, 1))), {})
+cnt: 1, ((T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16, stride=(125440, 784, 28, 1))), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 128, 28, 28], f16)), {})
+cnt: 1, ((T([4, 64, 56, 56], f16, stride=(802816, 3136, 56, 1)), T([4, 64, 56, 56], f16, stride=(702464, 3136, 56, 1))), {})
+cnt: 5, ((T([4, 32, 56, 56], f16, stride=(802816, 3136, 56, 1)), T([4, 32, 56, 56], f16, stride=(702464, 3136, 56, 1))), {})
+cnt: 1, ((T([4, 64, 56, 56], f16), T([4, 64, 56, 56], f16, stride=(602112, 3136, 56, 1))), {})
+cnt: 4, ((T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16, stride=(602112, 3136, 56, 1))), {})
+cnt: 1, ((T([4, 64, 56, 56], f16), T([4, 64, 56, 56], f16, stride=(501760, 3136, 56, 1))), {})
+cnt: 3, ((T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16, stride=(501760, 3136, 56, 1))), {})
+cnt: 1, ((T([4, 64, 56, 56], f16), T([4, 64, 56, 56], f16, stride=(401408, 3136, 56, 1))), {})
+cnt: 2, ((T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16, stride=(401408, 3136, 56, 1))), {})
+cnt: 1, ((T([4, 64, 56, 56], f16), T([4, 64, 56, 56], f16, stride=(301056, 3136, 56, 1))), {})
+cnt: 1, ((T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16, stride=(301056, 3136, 56, 1))), {})
+cnt: 1, ((T([4, 64, 56, 56], f16), T([4, 64, 56, 56], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([4, 1024], f16), T([1024, 1000], f16, stride=(1, 1024))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([4, 128, 56, 56], f16), [2, 2], [2, 2]), {})
+cnt: 1, ((T([4, 256, 28, 28], f16), [2, 2], [2, 2]), {})
+cnt: 1, ((T([4, 512, 14, 14], f16), [2, 2], [2, 2]), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 14, 14], f16), [2, 2], [2, 2], [0, 0], False, True, None), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 28, 28], f16), [2, 2], [2, 2], [0, 0], False, True, None), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 128, 56, 56], f16), [2, 2], [2, 2], [0, 0], False, True, None), {})
+Operator: aten.cat.default
+cnt: 1, (([T([4, 64, 56, 56], f16)], 1), {})
+cnt: 1, (([T([4, 64, 56, 56], f16), T([4, 32, 56, 56], f16)], 1), {})
+cnt: 1, (([T([4, 64, 56, 56], f16), T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16)], 1), {})
+cnt: 1, (([T([4, 64, 56, 56], f16), T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16)], 1), {})
+cnt: 1, (([T([4, 64, 56, 56], f16), T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16)], 1), {})
+cnt: 1, (([T([4, 64, 56, 56], f16), T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16)], 1), {})
+cnt: 1, (([T([4, 64, 56, 56], f16), T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16), T([4, 32, 56, 56], f16)], 1), {})
+cnt: 1, (([T([4, 128, 28, 28], f16)], 1), {})
+cnt: 1, (([T([4, 128, 28, 28], f16), T([4, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([4, 128, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([4, 128, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([4, 128, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([4, 128, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([4, 128, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([4, 128, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([4, 128, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([4, 128, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([4, 128, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([4, 128, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([4, 128, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16), T([4, 32, 28, 28], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 256, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16), T([4, 32, 14, 14], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+cnt: 1, (([T([4, 512, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16), T([4, 32, 7, 7], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([4, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([4, 3, 224, 224], f16), T([64, 3, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 64, 56, 56], f16), T([128, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([4, 128, 56, 56], f16), T([32, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 96, 56, 56], f16), T([128, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 128, 56, 56], f16), T([128, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 160, 56, 56], f16), T([128, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 192, 56, 56], f16), T([128, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 224, 56, 56], f16), T([128, 224, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 56, 56], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([128, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 12, ((T([4, 128, 28, 28], f16), T([32, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 160, 28, 28], f16), T([128, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 192, 28, 28], f16), T([128, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 224, 28, 28], f16), T([128, 224, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 28, 28], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 288, 28, 28], f16), T([128, 288, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 320, 28, 28], f16), T([128, 320, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 352, 28, 28], f16), T([128, 352, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 384, 28, 28], f16), T([128, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 416, 28, 28], f16), T([128, 416, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 448, 28, 28], f16), T([128, 448, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 480, 28, 28], f16), T([128, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 512, 28, 28], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 24, ((T([4, 128, 14, 14], f16), T([32, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 288, 14, 14], f16), T([128, 288, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 320, 14, 14], f16), T([128, 320, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 352, 14, 14], f16), T([128, 352, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 384, 14, 14], f16), T([128, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 416, 14, 14], f16), T([128, 416, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 448, 14, 14], f16), T([128, 448, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 480, 14, 14], f16), T([128, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 512, 14, 14], f16), T([128, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 544, 14, 14], f16), T([128, 544, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 576, 14, 14], f16), T([128, 576, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 608, 14, 14], f16), T([128, 608, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 640, 14, 14], f16), T([128, 640, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 672, 14, 14], f16), T([128, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 704, 14, 14], f16), T([128, 704, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 736, 14, 14], f16), T([128, 736, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 768, 14, 14], f16), T([128, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 800, 14, 14], f16), T([128, 800, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 832, 14, 14], f16), T([128, 832, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 864, 14, 14], f16), T([128, 864, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 896, 14, 14], f16), T([128, 896, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 928, 14, 14], f16), T([128, 928, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 960, 14, 14], f16), T([128, 960, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 992, 14, 14], f16), T([128, 992, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([128, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 16, ((T([4, 128, 7, 7], f16), T([32, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 544, 7, 7], f16), T([128, 544, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 576, 7, 7], f16), T([128, 576, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 608, 7, 7], f16), T([128, 608, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 640, 7, 7], f16), T([128, 640, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 672, 7, 7], f16), T([128, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 704, 7, 7], f16), T([128, 704, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 736, 7, 7], f16), T([128, 736, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 768, 7, 7], f16), T([128, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 800, 7, 7], f16), T([128, 800, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 832, 7, 7], f16), T([128, 832, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 864, 7, 7], f16), T([128, 864, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 896, 7, 7], f16), T([128, 896, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 928, 7, 7], f16), T([128, 928, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 960, 7, 7], f16), T([128, 960, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 992, 7, 7], f16), T([128, 992, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([4, 32, 7, 7], f16, stride=(50176, 49, 7, 1)), T([4, 128, 7, 7], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 992, 7, 7], f16), T([128, 992, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 15, ((T([4, 32, 7, 7], f16), T([4, 128, 7, 7], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 960, 7, 7], f16), T([128, 960, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 928, 7, 7], f16), T([128, 928, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 896, 7, 7], f16), T([128, 896, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 864, 7, 7], f16), T([128, 864, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 832, 7, 7], f16), T([128, 832, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 800, 7, 7], f16), T([128, 800, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 768, 7, 7], f16), T([128, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 736, 7, 7], f16), T([128, 736, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 704, 7, 7], f16), T([128, 704, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 672, 7, 7], f16), T([128, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 640, 7, 7], f16), T([128, 640, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 608, 7, 7], f16), T([128, 608, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 576, 7, 7], f16), T([128, 576, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 544, 7, 7], f16), T([128, 544, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 7, 7], f16), T([4, 512, 7, 7], f16), T([128, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 512, 14, 14], f16), T([4, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 32, 14, 14], f16, stride=(200704, 196, 14, 1)), T([4, 128, 14, 14], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 992, 14, 14], f16), T([128, 992, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 23, ((T([4, 32, 14, 14], f16), T([4, 128, 14, 14], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 960, 14, 14], f16), T([128, 960, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 928, 14, 14], f16), T([128, 928, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 896, 14, 14], f16), T([128, 896, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 864, 14, 14], f16), T([128, 864, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 832, 14, 14], f16), T([128, 832, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 800, 14, 14], f16), T([128, 800, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 768, 14, 14], f16), T([128, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 736, 14, 14], f16), T([128, 736, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 704, 14, 14], f16), T([128, 704, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 672, 14, 14], f16), T([128, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 640, 14, 14], f16), T([128, 640, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 608, 14, 14], f16), T([128, 608, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 576, 14, 14], f16), T([128, 576, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 544, 14, 14], f16), T([128, 544, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 512, 14, 14], f16), T([128, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 480, 14, 14], f16), T([128, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 448, 14, 14], f16), T([128, 448, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 416, 14, 14], f16), T([128, 416, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 384, 14, 14], f16), T([128, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 352, 14, 14], f16), T([128, 352, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 320, 14, 14], f16), T([128, 320, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 288, 14, 14], f16), T([128, 288, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 14, 14], f16), T([4, 256, 14, 14], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 256, 28, 28], f16), T([4, 512, 28, 28], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 32, 28, 28], f16, stride=(401408, 784, 28, 1)), T([4, 128, 28, 28], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 480, 28, 28], f16), T([128, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 11, ((T([4, 32, 28, 28], f16), T([4, 128, 28, 28], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 448, 28, 28], f16), T([128, 448, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 416, 28, 28], f16), T([128, 416, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 384, 28, 28], f16), T([128, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 352, 28, 28], f16), T([128, 352, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 320, 28, 28], f16), T([128, 320, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 288, 28, 28], f16), T([128, 288, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 256, 28, 28], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 224, 28, 28], f16), T([128, 224, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 192, 28, 28], f16), T([128, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 160, 28, 28], f16), T([128, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 28, 28], f16), T([4, 128, 28, 28], f16), T([128, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 56, 56], f16), T([4, 256, 56, 56], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 32, 56, 56], f16, stride=(802816, 3136, 56, 1)), T([4, 128, 56, 56], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 56, 56], f16), T([4, 224, 56, 56], f16), T([128, 224, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([4, 32, 56, 56], f16), T([4, 128, 56, 56], f16), T([32, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 56, 56], f16), T([4, 192, 56, 56], f16), T([128, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 56, 56], f16), T([4, 160, 56, 56], f16), T([128, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 56, 56], f16), T([4, 128, 56, 56], f16), T([128, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 56, 56], f16), T([4, 96, 56, 56], f16), T([128, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 56, 56], f16), T([4, 64, 56, 56], f16), T([128, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 64, 112, 112], f16), T([4, 3, 224, 224], f16), T([64, 3, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([4, 3, 224, 224], f16), T([4, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([4, 1024, 7, 7], f16, stride=(1024, 1, 0, 0)), 49), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 4000), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([4, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([4, 64, 56, 56], f16), T([4, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([4, 64, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([4, 1024, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([4, 1000], f16, stride=(0, 0)), T([1000, 1024], f16)), {})
+cnt: 1, ((T([1000, 4], f16, stride=(0, 0)), T([4, 1024], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([4, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 7, ((T([4, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 160, 56, 56], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 192, 56, 56], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 224, 56, 56], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 13, ((T([4, 128, 28, 28], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 160, 28, 28], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 192, 28, 28], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 224, 28, 28], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 288, 28, 28], f16), T([288], f16), T([288], f16), T([288], f16), T([288], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 320, 28, 28], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 352, 28, 28], f16), T([352], f16), T([352], f16), T([352], f16), T([352], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 384, 28, 28], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 416, 28, 28], f16), T([416], f16), T([416], f16), T([416], f16), T([416], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 448, 28, 28], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 480, 28, 28], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 24, ((T([4, 128, 14, 14], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 288, 14, 14], f16), T([288], f16), T([288], f16), T([288], f16), T([288], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 320, 14, 14], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 352, 14, 14], f16), T([352], f16), T([352], f16), T([352], f16), T([352], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 384, 14, 14], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 416, 14, 14], f16), T([416], f16), T([416], f16), T([416], f16), T([416], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 448, 14, 14], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 544, 14, 14], f16), T([544], f16), T([544], f16), T([544], f16), T([544], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 576, 14, 14], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 608, 14, 14], f16), T([608], f16), T([608], f16), T([608], f16), T([608], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 640, 14, 14], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 704, 14, 14], f16), T([704], f16), T([704], f16), T([704], f16), T([704], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 736, 14, 14], f16), T([736], f16), T([736], f16), T([736], f16), T([736], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 768, 14, 14], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 800, 14, 14], f16), T([800], f16), T([800], f16), T([800], f16), T([800], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 832, 14, 14], f16), T([832], f16), T([832], f16), T([832], f16), T([832], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 864, 14, 14], f16), T([864], f16), T([864], f16), T([864], f16), T([864], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 896, 14, 14], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 928, 14, 14], f16), T([928], f16), T([928], f16), T([928], f16), T([928], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 960, 14, 14], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 992, 14, 14], f16), T([992], f16), T([992], f16), T([992], f16), T([992], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+cnt: 16, ((T([4, 128, 7, 7], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 544, 7, 7], f16), T([544], f16), T([544], f16), T([544], f16), T([544], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 576, 7, 7], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 608, 7, 7], f16), T([608], f16), T([608], f16), T([608], f16), T([608], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 640, 7, 7], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 704, 7, 7], f16), T([704], f16), T([704], f16), T([704], f16), T([704], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 736, 7, 7], f16), T([736], f16), T([736], f16), T([736], f16), T([736], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 768, 7, 7], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 800, 7, 7], f16), T([800], f16), T([800], f16), T([800], f16), T([800], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 832, 7, 7], f16), T([832], f16), T([832], f16), T([832], f16), T([832], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 864, 7, 7], f16), T([864], f16), T([864], f16), T([864], f16), T([864], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 896, 7, 7], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 928, 7, 7], f16), T([928], f16), T([928], f16), T([928], f16), T([928], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 992, 7, 7], f16), T([992], f16), T([992], f16), T([992], f16), T([992], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([4, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), False, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([4, 1024, 7, 7], f16), T([4, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), False, 1e-05, [True, True, True]), {})
+cnt: 16, ((T([4, 128, 7, 7], f16), T([4, 128, 7, 7], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 992, 7, 7], f16), T([4, 992, 7, 7], f16), T([992], f16), T([992], f16), T([992], f16), T([992], f32), T([992], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 960, 7, 7], f16), T([4, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 928, 7, 7], f16), T([4, 928, 7, 7], f16), T([928], f16), T([928], f16), T([928], f16), T([928], f32), T([928], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 896, 7, 7], f16), T([4, 896, 7, 7], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f32), T([896], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 864, 7, 7], f16), T([4, 864, 7, 7], f16), T([864], f16), T([864], f16), T([864], f16), T([864], f32), T([864], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 832, 7, 7], f16), T([4, 832, 7, 7], f16), T([832], f16), T([832], f16), T([832], f16), T([832], f32), T([832], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 800, 7, 7], f16), T([4, 800, 7, 7], f16), T([800], f16), T([800], f16), T([800], f16), T([800], f32), T([800], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 768, 7, 7], f16), T([4, 768, 7, 7], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f32), T([768], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 736, 7, 7], f16), T([4, 736, 7, 7], f16), T([736], f16), T([736], f16), T([736], f16), T([736], f32), T([736], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 704, 7, 7], f16), T([4, 704, 7, 7], f16), T([704], f16), T([704], f16), T([704], f16), T([704], f32), T([704], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 672, 7, 7], f16), T([4, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 640, 7, 7], f16), T([4, 640, 7, 7], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f32), T([640], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 608, 7, 7], f16), T([4, 608, 7, 7], f16), T([608], f16), T([608], f16), T([608], f16), T([608], f32), T([608], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 576, 7, 7], f16), T([4, 576, 7, 7], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f32), T([576], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 544, 7, 7], f16), T([4, 544, 7, 7], f16), T([544], f16), T([544], f16), T([544], f16), T([544], f32), T([544], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 1024, 14, 14], f16), T([4, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), False, 1e-05, [True, True, True]), {})
+cnt: 24, ((T([4, 128, 14, 14], f16), T([4, 128, 14, 14], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 992, 14, 14], f16), T([4, 992, 14, 14], f16), T([992], f16), T([992], f16), T([992], f16), T([992], f32), T([992], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 960, 14, 14], f16), T([4, 960, 14, 14], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 928, 14, 14], f16), T([4, 928, 14, 14], f16), T([928], f16), T([928], f16), T([928], f16), T([928], f32), T([928], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 896, 14, 14], f16), T([4, 896, 14, 14], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f32), T([896], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 864, 14, 14], f16), T([4, 864, 14, 14], f16), T([864], f16), T([864], f16), T([864], f16), T([864], f32), T([864], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 832, 14, 14], f16), T([4, 832, 14, 14], f16), T([832], f16), T([832], f16), T([832], f16), T([832], f32), T([832], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 800, 14, 14], f16), T([4, 800, 14, 14], f16), T([800], f16), T([800], f16), T([800], f16), T([800], f32), T([800], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 768, 14, 14], f16), T([4, 768, 14, 14], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f32), T([768], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 736, 14, 14], f16), T([4, 736, 14, 14], f16), T([736], f16), T([736], f16), T([736], f16), T([736], f32), T([736], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 704, 14, 14], f16), T([4, 704, 14, 14], f16), T([704], f16), T([704], f16), T([704], f16), T([704], f32), T([704], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 672, 14, 14], f16), T([4, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 640, 14, 14], f16), T([4, 640, 14, 14], f16), T([640], f16), T([640], f16), T([640], f16), T([640], f32), T([640], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 608, 14, 14], f16), T([4, 608, 14, 14], f16), T([608], f16), T([608], f16), T([608], f16), T([608], f32), T([608], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 576, 14, 14], f16), T([4, 576, 14, 14], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f32), T([576], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 544, 14, 14], f16), T([4, 544, 14, 14], f16), T([544], f16), T([544], f16), T([544], f16), T([544], f32), T([544], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 512, 14, 14], f16), T([4, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 480, 14, 14], f16), T([4, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 448, 14, 14], f16), T([4, 448, 14, 14], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f32), T([448], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 416, 14, 14], f16), T([4, 416, 14, 14], f16), T([416], f16), T([416], f16), T([416], f16), T([416], f32), T([416], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 384, 14, 14], f16), T([4, 384, 14, 14], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 352, 14, 14], f16), T([4, 352, 14, 14], f16), T([352], f16), T([352], f16), T([352], f16), T([352], f32), T([352], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 320, 14, 14], f16), T([4, 320, 14, 14], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 288, 14, 14], f16), T([4, 288, 14, 14], f16), T([288], f16), T([288], f16), T([288], f16), T([288], f32), T([288], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 512, 28, 28], f16), T([4, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+cnt: 13, ((T([4, 128, 28, 28], f16), T([4, 128, 28, 28], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 480, 28, 28], f16), T([4, 480, 28, 28], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 448, 28, 28], f16), T([4, 448, 28, 28], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f32), T([448], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 416, 28, 28], f16), T([4, 416, 28, 28], f16), T([416], f16), T([416], f16), T([416], f16), T([416], f32), T([416], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 384, 28, 28], f16), T([4, 384, 28, 28], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 352, 28, 28], f16), T([4, 352, 28, 28], f16), T([352], f16), T([352], f16), T([352], f16), T([352], f32), T([352], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 320, 28, 28], f16), T([4, 320, 28, 28], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 288, 28, 28], f16), T([4, 288, 28, 28], f16), T([288], f16), T([288], f16), T([288], f16), T([288], f32), T([288], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 256, 28, 28], f16), T([4, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 224, 28, 28], f16), T([4, 224, 28, 28], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f32), T([224], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 192, 28, 28], f16), T([4, 192, 28, 28], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 160, 28, 28], f16), T([4, 160, 28, 28], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 256, 56, 56], f16), T([4, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([4, 128, 56, 56], f16), T([4, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 224, 56, 56], f16), T([4, 224, 56, 56], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f32), T([224], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 192, 56, 56], f16), T([4, 192, 56, 56], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 160, 56, 56], f16), T([4, 160, 56, 56], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 96, 56, 56], f16), T([4, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 64, 56, 56], f16), T([4, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([4, 64, 112, 112], f16), T([4, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([4, 64, 112, 112], f16),), {})
+cnt: 1, ((T([4, 64, 56, 56], f16),), {})
+cnt: 7, ((T([4, 128, 56, 56], f16),), {})
+cnt: 1, ((T([4, 96, 56, 56], f16),), {})
+cnt: 1, ((T([4, 160, 56, 56], f16),), {})
+cnt: 1, ((T([4, 192, 56, 56], f16),), {})
+cnt: 1, ((T([4, 224, 56, 56], f16),), {})
+cnt: 1, ((T([4, 256, 56, 56], f16),), {})
+cnt: 13, ((T([4, 128, 28, 28], f16),), {})
+cnt: 1, ((T([4, 160, 28, 28], f16),), {})
+cnt: 1, ((T([4, 192, 28, 28], f16),), {})
+cnt: 1, ((T([4, 224, 28, 28], f16),), {})
+cnt: 1, ((T([4, 256, 28, 28], f16),), {})
+cnt: 1, ((T([4, 288, 28, 28], f16),), {})
+cnt: 1, ((T([4, 320, 28, 28], f16),), {})
+cnt: 1, ((T([4, 352, 28, 28], f16),), {})
+cnt: 1, ((T([4, 384, 28, 28], f16),), {})
+cnt: 1, ((T([4, 416, 28, 28], f16),), {})
+cnt: 1, ((T([4, 448, 28, 28], f16),), {})
+cnt: 1, ((T([4, 480, 28, 28], f16),), {})
+cnt: 1, ((T([4, 512, 28, 28], f16),), {})
+cnt: 1, ((T([4, 256, 14, 14], f16),), {})
+cnt: 24, ((T([4, 128, 14, 14], f16),), {})
+cnt: 1, ((T([4, 288, 14, 14], f16),), {})
+cnt: 1, ((T([4, 320, 14, 14], f16),), {})
+cnt: 1, ((T([4, 352, 14, 14], f16),), {})
+cnt: 1, ((T([4, 384, 14, 14], f16),), {})
+cnt: 1, ((T([4, 416, 14, 14], f16),), {})
+cnt: 1, ((T([4, 448, 14, 14], f16),), {})
+cnt: 1, ((T([4, 480, 14, 14], f16),), {})
+cnt: 1, ((T([4, 512, 14, 14], f16),), {})
+cnt: 1, ((T([4, 544, 14, 14], f16),), {})
+cnt: 1, ((T([4, 576, 14, 14], f16),), {})
+cnt: 1, ((T([4, 608, 14, 14], f16),), {})
+cnt: 1, ((T([4, 640, 14, 14], f16),), {})
+cnt: 1, ((T([4, 672, 14, 14], f16),), {})
+cnt: 1, ((T([4, 704, 14, 14], f16),), {})
+cnt: 1, ((T([4, 736, 14, 14], f16),), {})
+cnt: 1, ((T([4, 768, 14, 14], f16),), {})
+cnt: 1, ((T([4, 800, 14, 14], f16),), {})
+cnt: 1, ((T([4, 832, 14, 14], f16),), {})
+cnt: 1, ((T([4, 864, 14, 14], f16),), {})
+cnt: 1, ((T([4, 896, 14, 14], f16),), {})
+cnt: 1, ((T([4, 928, 14, 14], f16),), {})
+cnt: 1, ((T([4, 960, 14, 14], f16),), {})
+cnt: 1, ((T([4, 992, 14, 14], f16),), {})
+cnt: 1, ((T([4, 1024, 14, 14], f16),), {})
+cnt: 1, ((T([4, 512, 7, 7], f16),), {})
+cnt: 16, ((T([4, 128, 7, 7], f16),), {})
+cnt: 1, ((T([4, 544, 7, 7], f16),), {})
+cnt: 1, ((T([4, 576, 7, 7], f16),), {})
+cnt: 1, ((T([4, 608, 7, 7], f16),), {})
+cnt: 1, ((T([4, 640, 7, 7], f16),), {})
+cnt: 1, ((T([4, 672, 7, 7], f16),), {})
+cnt: 1, ((T([4, 704, 7, 7], f16),), {})
+cnt: 1, ((T([4, 736, 7, 7], f16),), {})
+cnt: 1, ((T([4, 768, 7, 7], f16),), {})
+cnt: 1, ((T([4, 800, 7, 7], f16),), {})
+cnt: 1, ((T([4, 832, 7, 7], f16),), {})
+cnt: 1, ((T([4, 864, 7, 7], f16),), {})
+cnt: 1, ((T([4, 896, 7, 7], f16),), {})
+cnt: 1, ((T([4, 928, 7, 7], f16),), {})
+cnt: 1, ((T([4, 960, 7, 7], f16),), {})
+cnt: 1, ((T([4, 992, 7, 7], f16),), {})
+cnt: 1, ((T([4, 1024, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([4, 1000], f16, stride=(0, 0)), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([4, 1000], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([4, 1024, 7, 7], f16), T([4, 1024, 7, 7], f16), 0), {})
+cnt: 16, ((T([4, 128, 7, 7], f16), T([4, 128, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 992, 7, 7], f16), T([4, 992, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 960, 7, 7], f16), T([4, 960, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 928, 7, 7], f16), T([4, 928, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 896, 7, 7], f16), T([4, 896, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 864, 7, 7], f16), T([4, 864, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 832, 7, 7], f16), T([4, 832, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 800, 7, 7], f16), T([4, 800, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 768, 7, 7], f16), T([4, 768, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 736, 7, 7], f16), T([4, 736, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 704, 7, 7], f16), T([4, 704, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 672, 7, 7], f16), T([4, 672, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 640, 7, 7], f16), T([4, 640, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 608, 7, 7], f16), T([4, 608, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 576, 7, 7], f16), T([4, 576, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 544, 7, 7], f16), T([4, 544, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 512, 7, 7], f16), T([4, 512, 7, 7], f16), 0), {})
+cnt: 1, ((T([4, 1024, 14, 14], f16), T([4, 1024, 14, 14], f16), 0), {})
+cnt: 24, ((T([4, 128, 14, 14], f16), T([4, 128, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 992, 14, 14], f16), T([4, 992, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 960, 14, 14], f16), T([4, 960, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 928, 14, 14], f16), T([4, 928, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 896, 14, 14], f16), T([4, 896, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 864, 14, 14], f16), T([4, 864, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 832, 14, 14], f16), T([4, 832, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 800, 14, 14], f16), T([4, 800, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 768, 14, 14], f16), T([4, 768, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 736, 14, 14], f16), T([4, 736, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 704, 14, 14], f16), T([4, 704, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 672, 14, 14], f16), T([4, 672, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 640, 14, 14], f16), T([4, 640, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 608, 14, 14], f16), T([4, 608, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 576, 14, 14], f16), T([4, 576, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 544, 14, 14], f16), T([4, 544, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 512, 14, 14], f16), T([4, 512, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 480, 14, 14], f16), T([4, 480, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 448, 14, 14], f16), T([4, 448, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 416, 14, 14], f16), T([4, 416, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 384, 14, 14], f16), T([4, 384, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 352, 14, 14], f16), T([4, 352, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 320, 14, 14], f16), T([4, 320, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 288, 14, 14], f16), T([4, 288, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 256, 14, 14], f16), T([4, 256, 14, 14], f16), 0), {})
+cnt: 1, ((T([4, 512, 28, 28], f16), T([4, 512, 28, 28], f16), 0), {})
+cnt: 13, ((T([4, 128, 28, 28], f16), T([4, 128, 28, 28], f16), 0), {})
+cnt: 1, ((T([4, 480, 28, 28], f16), T([4, 480, 28, 28], f16), 0), {})
+cnt: 1, ((T([4, 448, 28, 28], f16), T([4, 448, 28, 28], f16), 0), {})
+cnt: 1, ((T([4, 416, 28, 28], f16), T([4, 416, 28, 28], f16), 0), {})
+cnt: 1, ((T([4, 384, 28, 28], f16), T([4, 384, 28, 28], f16), 0), {})
+cnt: 1, ((T([4, 352, 28, 28], f16), T([4, 352, 28, 28], f16), 0), {})
+cnt: 1, ((T([4, 320, 28, 28], f16), T([4, 320, 28, 28], f16), 0), {})
+cnt: 1, ((T([4, 288, 28, 28], f16), T([4, 288, 28, 28], f16), 0), {})
+cnt: 1, ((T([4, 256, 28, 28], f16), T([4, 256, 28, 28], f16), 0), {})
+cnt: 1, ((T([4, 224, 28, 28], f16), T([4, 224, 28, 28], f16), 0), {})
+cnt: 1, ((T([4, 192, 28, 28], f16), T([4, 192, 28, 28], f16), 0), {})
+cnt: 1, ((T([4, 160, 28, 28], f16), T([4, 160, 28, 28], f16), 0), {})
+cnt: 1, ((T([4, 256, 56, 56], f16), T([4, 256, 56, 56], f16), 0), {})
+cnt: 7, ((T([4, 128, 56, 56], f16), T([4, 128, 56, 56], f16), 0), {})
+cnt: 1, ((T([4, 224, 56, 56], f16), T([4, 224, 56, 56], f16), 0), {})
+cnt: 1, ((T([4, 192, 56, 56], f16), T([4, 192, 56, 56], f16), 0), {})
+cnt: 1, ((T([4, 160, 56, 56], f16), T([4, 160, 56, 56], f16), 0), {})
+cnt: 1, ((T([4, 96, 56, 56], f16), T([4, 96, 56, 56], f16), 0), {})
+cnt: 1, ((T([4, 64, 56, 56], f16), T([4, 64, 56, 56], f16), 0), {})
+cnt: 1, ((T([4, 64, 112, 112], f16), T([4, 64, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/fambench_dlrm_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/fambench_dlrm_training.txt
new file mode 100644
index 0000000000000..89e383e39c3a7
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/fambench_dlrm_training.txt
@@ -0,0 +1,1063 @@
+Operator: aten._embedding_bag.default
+cnt: 2, ((T([965, 192], f16), T([54824], i64), T([1024], i64), False, 0, True, T([54824], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54798], i64), T([1024], i64), False, 0, True, T([54798], f16)), {})
+cnt: 5, ((T([965, 192], f16), T([54763], i64), T([1024], i64), False, 0, True, T([54763], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54783], i64), T([1024], i64), False, 0, True, T([54783], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54762], i64), T([1024], i64), False, 0, True, T([54762], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54862], i64), T([1024], i64), False, 0, True, T([54862], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54743], i64), T([1024], i64), False, 0, True, T([54743], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54750], i64), T([1024], i64), False, 0, True, T([54750], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54705], i64), T([1024], i64), False, 0, True, T([54705], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54735], i64), T([1024], i64), False, 0, True, T([54735], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54736], i64), T([1024], i64), False, 0, True, T([54736], f16)), {})
+cnt: 3, ((T([965, 192], f16), T([54775], i64), T([1024], i64), False, 0, True, T([54775], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54710], i64), T([1024], i64), False, 0, True, T([54710], f16)), {})
+cnt: 4, ((T([965, 192], f16), T([54753], i64), T([1024], i64), False, 0, True, T([54753], f16)), {})
+cnt: 4, ((T([965, 192], f16), T([54833], i64), T([1024], i64), False, 0, True, T([54833], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54767], i64), T([1024], i64), False, 0, True, T([54767], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54749], i64), T([1024], i64), False, 0, True, T([54749], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54795], i64), T([1024], i64), False, 0, True, T([54795], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54813], i64), T([1024], i64), False, 0, True, T([54813], f16)), {})
+cnt: 3, ((T([965, 192], f16), T([54730], i64), T([1024], i64), False, 0, True, T([54730], f16)), {})
+cnt: 3, ((T([965, 192], f16), T([54768], i64), T([1024], i64), False, 0, True, T([54768], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54826], i64), T([1024], i64), False, 0, True, T([54826], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54701], i64), T([1024], i64), False, 0, True, T([54701], f16)), {})
+cnt: 6, ((T([965, 192], f16), T([54761], i64), T([1024], i64), False, 0, True, T([54761], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54807], i64), T([1024], i64), False, 0, True, T([54807], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54744], i64), T([1024], i64), False, 0, True, T([54744], f16)), {})
+cnt: 3, ((T([965, 192], f16), T([54745], i64), T([1024], i64), False, 0, True, T([54745], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54723], i64), T([1024], i64), False, 0, True, T([54723], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54797], i64), T([1024], i64), False, 0, True, T([54797], f16)), {})
+cnt: 4, ((T([965, 192], f16), T([54786], i64), T([1024], i64), False, 0, True, T([54786], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54816], i64), T([1024], i64), False, 0, True, T([54816], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54725], i64), T([1024], i64), False, 0, True, T([54725], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54819], i64), T([1024], i64), False, 0, True, T([54819], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54855], i64), T([1024], i64), False, 0, True, T([54855], f16)), {})
+cnt: 3, ((T([965, 192], f16), T([54782], i64), T([1024], i64), False, 0, True, T([54782], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54712], i64), T([1024], i64), False, 0, True, T([54712], f16)), {})
+cnt: 3, ((T([965, 192], f16), T([54799], i64), T([1024], i64), False, 0, True, T([54799], f16)), {})
+cnt: 4, ((T([965, 192], f16), T([54801], i64), T([1024], i64), False, 0, True, T([54801], f16)), {})
+cnt: 5, ((T([965, 192], f16), T([54818], i64), T([1024], i64), False, 0, True, T([54818], f16)), {})
+cnt: 3, ((T([965, 192], f16), T([54779], i64), T([1024], i64), False, 0, True, T([54779], f16)), {})
+cnt: 4, ((T([965, 192], f16), T([54719], i64), T([1024], i64), False, 0, True, T([54719], f16)), {})
+cnt: 3, ((T([965, 192], f16), T([54778], i64), T([1024], i64), False, 0, True, T([54778], f16)), {})
+cnt: 6, ((T([965, 192], f16), T([54760], i64), T([1024], i64), False, 0, True, T([54760], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54802], i64), T([1024], i64), False, 0, True, T([54802], f16)), {})
+cnt: 5, ((T([965, 192], f16), T([54776], i64), T([1024], i64), False, 0, True, T([54776], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54828], i64), T([1024], i64), False, 0, True, T([54828], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54715], i64), T([1024], i64), False, 0, True, T([54715], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54843], i64), T([1024], i64), False, 0, True, T([54843], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54756], i64), T([1024], i64), False, 0, True, T([54756], f16)), {})
+cnt: 3, ((T([965, 192], f16), T([54766], i64), T([1024], i64), False, 0, True, T([54766], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54697], i64), T([1024], i64), False, 0, True, T([54697], f16)), {})
+cnt: 3, ((T([965, 192], f16), T([54792], i64), T([1024], i64), False, 0, True, T([54792], f16)), {})
+cnt: 5, ((T([965, 192], f16), T([54793], i64), T([1024], i64), False, 0, True, T([54793], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54727], i64), T([1024], i64), False, 0, True, T([54727], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54733], i64), T([1024], i64), False, 0, True, T([54733], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54692], i64), T([1024], i64), False, 0, True, T([54692], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54758], i64), T([1024], i64), False, 0, True, T([54758], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54820], i64), T([1024], i64), False, 0, True, T([54820], f16)), {})
+cnt: 4, ((T([965, 192], f16), T([54787], i64), T([1024], i64), False, 0, True, T([54787], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54815], i64), T([1024], i64), False, 0, True, T([54815], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54814], i64), T([1024], i64), False, 0, True, T([54814], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54759], i64), T([1024], i64), False, 0, True, T([54759], f16)), {})
+cnt: 3, ((T([965, 192], f16), T([54757], i64), T([1024], i64), False, 0, True, T([54757], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54821], i64), T([1024], i64), False, 0, True, T([54821], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54769], i64), T([1024], i64), False, 0, True, T([54769], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54842], i64), T([1024], i64), False, 0, True, T([54842], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54718], i64), T([1024], i64), False, 0, True, T([54718], f16)), {})
+cnt: 3, ((T([965, 192], f16), T([54771], i64), T([1024], i64), False, 0, True, T([54771], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54844], i64), T([1024], i64), False, 0, True, T([54844], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54838], i64), T([1024], i64), False, 0, True, T([54838], f16)), {})
+cnt: 5, ((T([965, 192], f16), T([54781], i64), T([1024], i64), False, 0, True, T([54781], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54804], i64), T([1024], i64), False, 0, True, T([54804], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54788], i64), T([1024], i64), False, 0, True, T([54788], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54774], i64), T([1024], i64), False, 0, True, T([54774], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54829], i64), T([1024], i64), False, 0, True, T([54829], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54738], i64), T([1024], i64), False, 0, True, T([54738], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54777], i64), T([1024], i64), False, 0, True, T([54777], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54811], i64), T([1024], i64), False, 0, True, T([54811], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54772], i64), T([1024], i64), False, 0, True, T([54772], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54800], i64), T([1024], i64), False, 0, True, T([54800], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54741], i64), T([1024], i64), False, 0, True, T([54741], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54794], i64), T([1024], i64), False, 0, True, T([54794], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54773], i64), T([1024], i64), False, 0, True, T([54773], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54803], i64), T([1024], i64), False, 0, True, T([54803], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54789], i64), T([1024], i64), False, 0, True, T([54789], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54707], i64), T([1024], i64), False, 0, True, T([54707], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54737], i64), T([1024], i64), False, 0, True, T([54737], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54722], i64), T([1024], i64), False, 0, True, T([54722], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54747], i64), T([1024], i64), False, 0, True, T([54747], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54770], i64), T([1024], i64), False, 0, True, T([54770], f16)), {})
+cnt: 4, ((T([965, 192], f16), T([54780], i64), T([1024], i64), False, 0, True, T([54780], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54731], i64), T([1024], i64), False, 0, True, T([54731], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54836], i64), T([1024], i64), False, 0, True, T([54836], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54839], i64), T([1024], i64), False, 0, True, T([54839], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54714], i64), T([1024], i64), False, 0, True, T([54714], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54785], i64), T([1024], i64), False, 0, True, T([54785], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54729], i64), T([1024], i64), False, 0, True, T([54729], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54812], i64), T([1024], i64), False, 0, True, T([54812], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54734], i64), T([1024], i64), False, 0, True, T([54734], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54791], i64), T([1024], i64), False, 0, True, T([54791], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54827], i64), T([1024], i64), False, 0, True, T([54827], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54717], i64), T([1024], i64), False, 0, True, T([54717], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54716], i64), T([1024], i64), False, 0, True, T([54716], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54830], i64), T([1024], i64), False, 0, True, T([54830], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54732], i64), T([1024], i64), False, 0, True, T([54732], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54835], i64), T([1024], i64), False, 0, True, T([54835], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54831], i64), T([1024], i64), False, 0, True, T([54831], f16)), {})
+cnt: 3, ((T([965, 192], f16), T([54748], i64), T([1024], i64), False, 0, True, T([54748], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54746], i64), T([1024], i64), False, 0, True, T([54746], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54711], i64), T([1024], i64), False, 0, True, T([54711], f16)), {})
+cnt: 3, ((T([965, 192], f16), T([54739], i64), T([1024], i64), False, 0, True, T([54739], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54713], i64), T([1024], i64), False, 0, True, T([54713], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54847], i64), T([1024], i64), False, 0, True, T([54847], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54809], i64), T([1024], i64), False, 0, True, T([54809], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54742], i64), T([1024], i64), False, 0, True, T([54742], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54704], i64), T([1024], i64), False, 0, True, T([54704], f16)), {})
+cnt: 3, ((T([965, 192], f16), T([54784], i64), T([1024], i64), False, 0, True, T([54784], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54796], i64), T([1024], i64), False, 0, True, T([54796], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54754], i64), T([1024], i64), False, 0, True, T([54754], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54751], i64), T([1024], i64), False, 0, True, T([54751], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54764], i64), T([1024], i64), False, 0, True, T([54764], f16)), {})
+cnt: 2, ((T([965, 192], f16), T([54687], i64), T([1024], i64), False, 0, True, T([54687], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54740], i64), T([1024], i64), False, 0, True, T([54740], f16)), {})
+cnt: 1, ((T([965, 192], f16), T([54765], i64), T([1024], i64), False, 0, True, T([54765], f16)), {})
+Operator: aten._embedding_bag_per_sample_weights_backward.default
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54765], i64), T([1024], i64), T([54765], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54704], i64), T([1024], i64), T([54704], i64), 0), {})
+cnt: 4, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54786], i64), T([1024], i64), T([54786], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54804], i64), T([1024], i64), T([54804], i64), 0), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54757], i64), T([1024], i64), T([54757], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54746], i64), T([1024], i64), T([54746], i64), 0), {})
+cnt: 5, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54781], i64), T([1024], i64), T([54781], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54687], i64), T([1024], i64), T([54687], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54738], i64), T([1024], i64), T([54738], i64), 0), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54784], i64), T([1024], i64), T([54784], i64), 0), {})
+cnt: 4, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54787], i64), T([1024], i64), T([54787], i64), 0), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54768], i64), T([1024], i64), T([54768], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54697], i64), T([1024], i64), T([54697], i64), 0), {})
+cnt: 4, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54833], i64), T([1024], i64), T([54833], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54809], i64), T([1024], i64), T([54809], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54713], i64), T([1024], i64), T([54713], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54814], i64), T([1024], i64), T([54814], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54802], i64), T([1024], i64), T([54802], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54789], i64), T([1024], i64), T([54789], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54743], i64), T([1024], i64), T([54743], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54731], i64), T([1024], i64), T([54731], i64), 0), {})
+cnt: 6, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54760], i64), T([1024], i64), T([54760], i64), 0), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54771], i64), T([1024], i64), T([54771], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54723], i64), T([1024], i64), T([54723], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54812], i64), T([1024], i64), T([54812], i64), 0), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54799], i64), T([1024], i64), T([54799], i64), 0), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54745], i64), T([1024], i64), T([54745], i64), 0), {})
+cnt: 4, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54753], i64), T([1024], i64), T([54753], i64), 0), {})
+cnt: 5, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54763], i64), T([1024], i64), T([54763], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54795], i64), T([1024], i64), T([54795], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54740], i64), T([1024], i64), T([54740], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54707], i64), T([1024], i64), T([54707], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54798], i64), T([1024], i64), T([54798], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54751], i64), T([1024], i64), T([54751], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54788], i64), T([1024], i64), T([54788], i64), 0), {})
+cnt: 4, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54780], i64), T([1024], i64), T([54780], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54824], i64), T([1024], i64), T([54824], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54764], i64), T([1024], i64), T([54764], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54797], i64), T([1024], i64), T([54797], i64), 0), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54739], i64), T([1024], i64), T([54739], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54791], i64), T([1024], i64), T([54791], i64), 0), {})
+cnt: 5, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54776], i64), T([1024], i64), T([54776], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54754], i64), T([1024], i64), T([54754], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54777], i64), T([1024], i64), T([54777], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54794], i64), T([1024], i64), T([54794], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54742], i64), T([1024], i64), T([54742], i64), 0), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54748], i64), T([1024], i64), T([54748], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54729], i64), T([1024], i64), T([54729], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54815], i64), T([1024], i64), T([54815], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54796], i64), T([1024], i64), T([54796], i64), 0), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54730], i64), T([1024], i64), T([54730], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54773], i64), T([1024], i64), T([54773], i64), 0), {})
+cnt: 4, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54801], i64), T([1024], i64), T([54801], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54744], i64), T([1024], i64), T([54744], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54847], i64), T([1024], i64), T([54847], i64), 0), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54766], i64), T([1024], i64), T([54766], i64), 0), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54778], i64), T([1024], i64), T([54778], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54711], i64), T([1024], i64), T([54711], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54826], i64), T([1024], i64), T([54826], i64), 0), {})
+cnt: 5, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54793], i64), T([1024], i64), T([54793], i64), 0), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54792], i64), T([1024], i64), T([54792], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54831], i64), T([1024], i64), T([54831], i64), 0), {})
+cnt: 6, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54761], i64), T([1024], i64), T([54761], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54835], i64), T([1024], i64), T([54835], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54732], i64), T([1024], i64), T([54732], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54830], i64), T([1024], i64), T([54830], i64), 0), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54775], i64), T([1024], i64), T([54775], i64), 0), {})
+cnt: 4, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54719], i64), T([1024], i64), T([54719], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54722], i64), T([1024], i64), T([54722], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54716], i64), T([1024], i64), T([54716], i64), 0), {})
+cnt: 5, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54818], i64), T([1024], i64), T([54818], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54783], i64), T([1024], i64), T([54783], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54717], i64), T([1024], i64), T([54717], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54827], i64), T([1024], i64), T([54827], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54734], i64), T([1024], i64), T([54734], i64), 0), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54779], i64), T([1024], i64), T([54779], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54785], i64), T([1024], i64), T([54785], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54714], i64), T([1024], i64), T([54714], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54772], i64), T([1024], i64), T([54772], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54839], i64), T([1024], i64), T([54839], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54836], i64), T([1024], i64), T([54836], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54774], i64), T([1024], i64), T([54774], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54803], i64), T([1024], i64), T([54803], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54770], i64), T([1024], i64), T([54770], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54747], i64), T([1024], i64), T([54747], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54737], i64), T([1024], i64), T([54737], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54741], i64), T([1024], i64), T([54741], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54800], i64), T([1024], i64), T([54800], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54811], i64), T([1024], i64), T([54811], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54758], i64), T([1024], i64), T([54758], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54829], i64), T([1024], i64), T([54829], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54838], i64), T([1024], i64), T([54838], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54759], i64), T([1024], i64), T([54759], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54733], i64), T([1024], i64), T([54733], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54844], i64), T([1024], i64), T([54844], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54718], i64), T([1024], i64), T([54718], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54842], i64), T([1024], i64), T([54842], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54769], i64), T([1024], i64), T([54769], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54821], i64), T([1024], i64), T([54821], i64), 0), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54782], i64), T([1024], i64), T([54782], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54710], i64), T([1024], i64), T([54710], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54820], i64), T([1024], i64), T([54820], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54692], i64), T([1024], i64), T([54692], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54727], i64), T([1024], i64), T([54727], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54767], i64), T([1024], i64), T([54767], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54819], i64), T([1024], i64), T([54819], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54756], i64), T([1024], i64), T([54756], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54843], i64), T([1024], i64), T([54843], i64), 0), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54735], i64), T([1024], i64), T([54735], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54715], i64), T([1024], i64), T([54715], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54828], i64), T([1024], i64), T([54828], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54712], i64), T([1024], i64), T([54712], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54855], i64), T([1024], i64), T([54855], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54725], i64), T([1024], i64), T([54725], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54816], i64), T([1024], i64), T([54816], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54807], i64), T([1024], i64), T([54807], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54701], i64), T([1024], i64), T([54701], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54813], i64), T([1024], i64), T([54813], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54749], i64), T([1024], i64), T([54749], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54736], i64), T([1024], i64), T([54736], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54705], i64), T([1024], i64), T([54705], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54750], i64), T([1024], i64), T([54750], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54862], i64), T([1024], i64), T([54862], i64), 0), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), T([965, 192], f16), T([54762], i64), T([1024], i64), T([54762], i64), 0), {})
+Operator: aten._sparse_coo_tensor_with_dims_and_tensors.default
+cnt: 2, ((1, 1, [965, 192], T([1, 54765], i64), T([54765, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54704], i64), T([54704, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 8, ((1, 1, [965, 192], T([1, 54786], i64), T([54786, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54804], i64), T([54804, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 6, ((1, 1, [965, 192], T([1, 54757], i64), T([54757, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54746], i64), T([54746, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 10, ((1, 1, [965, 192], T([1, 54781], i64), T([54781, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54687], i64), T([54687, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54738], i64), T([54738, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 6, ((1, 1, [965, 192], T([1, 54784], i64), T([54784, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 8, ((1, 1, [965, 192], T([1, 54787], i64), T([54787, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 6, ((1, 1, [965, 192], T([1, 54768], i64), T([54768, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54697], i64), T([54697, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 8, ((1, 1, [965, 192], T([1, 54833], i64), T([54833, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54809], i64), T([54809, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54713], i64), T([54713, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54814], i64), T([54814, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54802], i64), T([54802, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54789], i64), T([54789, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54743], i64), T([54743, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54731], i64), T([54731, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 12, ((1, 1, [965, 192], T([1, 54760], i64), T([54760, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 6, ((1, 1, [965, 192], T([1, 54771], i64), T([54771, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54723], i64), T([54723, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54812], i64), T([54812, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 6, ((1, 1, [965, 192], T([1, 54799], i64), T([54799, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 6, ((1, 1, [965, 192], T([1, 54745], i64), T([54745, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 8, ((1, 1, [965, 192], T([1, 54753], i64), T([54753, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 10, ((1, 1, [965, 192], T([1, 54763], i64), T([54763, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54795], i64), T([54795, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54740], i64), T([54740, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54707], i64), T([54707, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54798], i64), T([54798, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54751], i64), T([54751, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54788], i64), T([54788, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 8, ((1, 1, [965, 192], T([1, 54780], i64), T([54780, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54824], i64), T([54824, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54764], i64), T([54764, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54797], i64), T([54797, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 6, ((1, 1, [965, 192], T([1, 54739], i64), T([54739, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54791], i64), T([54791, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 10, ((1, 1, [965, 192], T([1, 54776], i64), T([54776, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54754], i64), T([54754, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54777], i64), T([54777, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54794], i64), T([54794, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54742], i64), T([54742, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 6, ((1, 1, [965, 192], T([1, 54748], i64), T([54748, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54729], i64), T([54729, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54815], i64), T([54815, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54796], i64), T([54796, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 6, ((1, 1, [965, 192], T([1, 54730], i64), T([54730, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54773], i64), T([54773, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 8, ((1, 1, [965, 192], T([1, 54801], i64), T([54801, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54744], i64), T([54744, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54847], i64), T([54847, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 6, ((1, 1, [965, 192], T([1, 54766], i64), T([54766, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 6, ((1, 1, [965, 192], T([1, 54778], i64), T([54778, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54711], i64), T([54711, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54826], i64), T([54826, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 10, ((1, 1, [965, 192], T([1, 54793], i64), T([54793, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 6, ((1, 1, [965, 192], T([1, 54792], i64), T([54792, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54831], i64), T([54831, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 12, ((1, 1, [965, 192], T([1, 54761], i64), T([54761, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54835], i64), T([54835, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54732], i64), T([54732, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54830], i64), T([54830, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 6, ((1, 1, [965, 192], T([1, 54775], i64), T([54775, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 8, ((1, 1, [965, 192], T([1, 54719], i64), T([54719, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54722], i64), T([54722, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54716], i64), T([54716, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 10, ((1, 1, [965, 192], T([1, 54818], i64), T([54818, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54783], i64), T([54783, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54717], i64), T([54717, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54827], i64), T([54827, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54734], i64), T([54734, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 6, ((1, 1, [965, 192], T([1, 54779], i64), T([54779, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54785], i64), T([54785, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54714], i64), T([54714, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54772], i64), T([54772, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54839], i64), T([54839, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54836], i64), T([54836, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54774], i64), T([54774, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54803], i64), T([54803, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54770], i64), T([54770, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54747], i64), T([54747, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54737], i64), T([54737, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54741], i64), T([54741, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54800], i64), T([54800, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54811], i64), T([54811, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54758], i64), T([54758, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54829], i64), T([54829, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54838], i64), T([54838, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54759], i64), T([54759, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54733], i64), T([54733, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54844], i64), T([54844, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54718], i64), T([54718, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54842], i64), T([54842, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54769], i64), T([54769, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54821], i64), T([54821, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 6, ((1, 1, [965, 192], T([1, 54782], i64), T([54782, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54710], i64), T([54710, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54820], i64), T([54820, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54692], i64), T([54692, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54727], i64), T([54727, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54767], i64), T([54767, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54819], i64), T([54819, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54756], i64), T([54756, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54843], i64), T([54843, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 4, ((1, 1, [965, 192], T([1, 54735], i64), T([54735, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54715], i64), T([54715, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54828], i64), T([54828, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54712], i64), T([54712, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54855], i64), T([54855, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54725], i64), T([54725, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54816], i64), T([54816, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54807], i64), T([54807, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54701], i64), T([54701, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54813], i64), T([54813, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54749], i64), T([54749, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54736], i64), T([54736, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54705], i64), T([54705, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54750], i64), T([54750, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54862], i64), T([54862, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+cnt: 2, ((1, 1, [965, 192], T([1, 54762], i64), T([54762, 192], f16)), {'dtype': f16, 'layout': torch.sparse_coo, 'device': 'cuda', 'pin_memory': None})
+Operator: aten.add.Tensor
+cnt: 1, ((T([1024, 249, 192], f16), T([1024, 249, 192], f16, stride=(47808, 1, 249))), {})
+cnt: 1, ((T([1024, 192], f16, stride=(31068, 1)), T([1024, 192], f16, stride=(47808, 1))), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1500], f16), T([1024, 2000], f16), T([2000, 1500], f16, stride=(1, 2000))), {})
+cnt: 2, ((T([1500], f16), T([1024, 1500], f16), T([1500, 1500], f16, stride=(1, 1500))), {})
+cnt: 1, ((T([192], f16), T([1024, 1500], f16), T([1500, 192], f16, stride=(1, 1500))), {})
+cnt: 1, ((T([4000], f16), T([1024, 31068], f16), T([31068, 4000], f16, stride=(1, 31068))), {})
+cnt: 8, ((T([4000], f16), T([1024, 4000], f16), T([4000, 4000], f16, stride=(1, 4000))), {})
+cnt: 1, ((T([1], f16), T([1024, 4000], f16), T([4000, 1], f16)), {})
+Operator: aten.bmm.default
+cnt: 1, ((T([1024, 249, 192], f16), T([1024, 192, 249], f16, stride=(47808, 1, 192))), {})
+cnt: 1, ((T([1024, 192, 249], f16, stride=(47808, 1, 192)), T([1024, 249, 249], f16)), {})
+cnt: 1, ((T([1024, 249, 249], f16), T([1024, 249, 192], f16)), {})
+Operator: aten.cat.default
+cnt: 1, (([T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16), T([1024, 192], f16)], 1), {})
+cnt: 1, (([T([1024, 192], f16), T([1024, 30876], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([1024, 2000], f16),), {})
+cnt: 1, ((T([248, 1024], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([1024, 2000], f16), T([1024, 2000], f16)), {})
+cnt: 1, ((T([248, 1024], i64), T([248, 1024], i64)), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 1024), {})
+Operator: aten.gather.default
+cnt: 2, ((T([965], f16), 0, T([54824], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54798], i64)), {})
+cnt: 5, ((T([965], f16), 0, T([54763], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54783], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54762], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54862], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54743], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54750], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54705], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54735], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54736], i64)), {})
+cnt: 3, ((T([965], f16), 0, T([54775], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54710], i64)), {})
+cnt: 4, ((T([965], f16), 0, T([54753], i64)), {})
+cnt: 4, ((T([965], f16), 0, T([54833], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54767], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54749], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54795], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54813], i64)), {})
+cnt: 3, ((T([965], f16), 0, T([54730], i64)), {})
+cnt: 3, ((T([965], f16), 0, T([54768], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54826], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54701], i64)), {})
+cnt: 6, ((T([965], f16), 0, T([54761], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54807], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54744], i64)), {})
+cnt: 3, ((T([965], f16), 0, T([54745], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54723], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54797], i64)), {})
+cnt: 4, ((T([965], f16), 0, T([54786], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54816], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54725], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54819], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54855], i64)), {})
+cnt: 3, ((T([965], f16), 0, T([54782], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54712], i64)), {})
+cnt: 3, ((T([965], f16), 0, T([54799], i64)), {})
+cnt: 4, ((T([965], f16), 0, T([54801], i64)), {})
+cnt: 5, ((T([965], f16), 0, T([54818], i64)), {})
+cnt: 3, ((T([965], f16), 0, T([54779], i64)), {})
+cnt: 4, ((T([965], f16), 0, T([54719], i64)), {})
+cnt: 3, ((T([965], f16), 0, T([54778], i64)), {})
+cnt: 6, ((T([965], f16), 0, T([54760], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54802], i64)), {})
+cnt: 5, ((T([965], f16), 0, T([54776], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54828], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54715], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54843], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54756], i64)), {})
+cnt: 3, ((T([965], f16), 0, T([54766], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54697], i64)), {})
+cnt: 3, ((T([965], f16), 0, T([54792], i64)), {})
+cnt: 5, ((T([965], f16), 0, T([54793], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54727], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54733], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54692], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54758], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54820], i64)), {})
+cnt: 4, ((T([965], f16), 0, T([54787], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54815], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54814], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54759], i64)), {})
+cnt: 3, ((T([965], f16), 0, T([54757], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54821], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54769], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54842], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54718], i64)), {})
+cnt: 3, ((T([965], f16), 0, T([54771], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54844], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54838], i64)), {})
+cnt: 5, ((T([965], f16), 0, T([54781], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54804], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54788], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54774], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54829], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54738], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54777], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54811], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54772], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54800], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54741], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54794], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54773], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54803], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54789], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54707], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54737], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54722], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54747], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54770], i64)), {})
+cnt: 4, ((T([965], f16), 0, T([54780], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54731], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54836], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54839], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54714], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54785], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54729], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54812], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54734], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54791], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54827], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54717], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54716], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54830], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54732], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54835], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54831], i64)), {})
+cnt: 3, ((T([965], f16), 0, T([54748], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54746], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54711], i64)), {})
+cnt: 3, ((T([965], f16), 0, T([54739], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54713], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54847], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54809], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54742], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54704], i64)), {})
+cnt: 3, ((T([965], f16), 0, T([54784], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54796], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54754], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54751], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54764], i64)), {})
+cnt: 2, ((T([965], f16), 0, T([54687], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54740], i64)), {})
+cnt: 1, ((T([965], f16), 0, T([54765], i64)), {})
+Operator: aten.index.Tensor
+cnt: 1, ((T([1024, 249, 249], f16), [None, T([30876], i64), T([30876], i64)]), {})
+Operator: aten.index_put.default
+cnt: 1, ((T([1024, 249, 249], f16), [None, T([30876], i64), T([30876], i64)], T([1024, 30876], f16, stride=(31068, 1)), True), {})
+Operator: aten.index_select.default
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54765], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54704], i64)), {})
+cnt: 4, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54786], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54804], i64)), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54757], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54746], i64)), {})
+cnt: 5, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54781], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54687], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54738], i64)), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54784], i64)), {})
+cnt: 4, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54787], i64)), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54768], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54697], i64)), {})
+cnt: 4, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54833], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54809], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54713], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54814], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54802], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54789], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54743], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54731], i64)), {})
+cnt: 6, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54760], i64)), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54771], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54723], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54812], i64)), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54799], i64)), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54745], i64)), {})
+cnt: 4, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54753], i64)), {})
+cnt: 5, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54763], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54795], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54740], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54707], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54798], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54751], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54788], i64)), {})
+cnt: 4, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54780], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54824], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54764], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54797], i64)), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54739], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54791], i64)), {})
+cnt: 5, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54776], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54754], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54777], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54794], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54742], i64)), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54748], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54729], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54815], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54796], i64)), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54730], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54773], i64)), {})
+cnt: 4, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54801], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54744], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54847], i64)), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54766], i64)), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54778], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54711], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54826], i64)), {})
+cnt: 5, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54793], i64)), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54792], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54831], i64)), {})
+cnt: 6, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54761], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54835], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54732], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54830], i64)), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54775], i64)), {})
+cnt: 4, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54719], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54722], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54716], i64)), {})
+cnt: 5, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54818], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54783], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54717], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54827], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54734], i64)), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54779], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54785], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54714], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54772], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54839], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54836], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54774], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54803], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54770], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54747], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54737], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54741], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54800], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54811], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54758], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54829], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54838], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54759], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54733], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54844], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54718], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54842], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54769], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54821], i64)), {})
+cnt: 3, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54782], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54710], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54820], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54692], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54727], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54767], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54819], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54756], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54843], i64)), {})
+cnt: 2, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54735], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54715], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54828], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54712], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54855], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54725], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54816], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54807], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54701], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54813], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54749], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54736], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54705], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54750], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54862], i64)), {})
+cnt: 1, ((T([1024, 192], f16, stride=(47808, 1)), 0, T([54762], i64)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([1024, 1], f16), T([1, 4000], f16)), {})
+cnt: 1, ((T([1, 1024], f16), T([1024, 4000], f16)), {})
+cnt: 8, ((T([1024, 4000], f16), T([4000, 4000], f16)), {})
+cnt: 8, ((T([4000, 1024], f16, stride=(1, 4000)), T([1024, 4000], f16)), {})
+cnt: 1, ((T([1024, 4000], f16), T([4000, 31068], f16)), {})
+cnt: 1, ((T([4000, 1024], f16, stride=(1, 4000)), T([1024, 31068], f16)), {})
+cnt: 1, ((T([1024, 192], f16), T([192, 1500], f16)), {})
+cnt: 1, ((T([192, 1024], f16, stride=(1, 192)), T([1024, 1500], f16)), {})
+cnt: 2, ((T([1024, 1500], f16), T([1500, 1500], f16)), {})
+cnt: 2, ((T([1500, 1024], f16, stride=(1, 1500)), T([1024, 1500], f16)), {})
+cnt: 1, ((T([1500, 1024], f16, stride=(1, 1500)), T([1024, 2000], f16)), {})
+Operator: aten.mul_.Tensor
+cnt: 1, ((T([54765, 192], f16), T([54765, 1], f16)), {})
+cnt: 2, ((T([54704, 192], f16), T([54704, 1], f16)), {})
+cnt: 4, ((T([54786, 192], f16), T([54786, 1], f16)), {})
+cnt: 2, ((T([54804, 192], f16), T([54804, 1], f16)), {})
+cnt: 3, ((T([54757, 192], f16), T([54757, 1], f16)), {})
+cnt: 2, ((T([54746, 192], f16), T([54746, 1], f16)), {})
+cnt: 5, ((T([54781, 192], f16), T([54781, 1], f16)), {})
+cnt: 2, ((T([54687, 192], f16), T([54687, 1], f16)), {})
+cnt: 2, ((T([54738, 192], f16), T([54738, 1], f16)), {})
+cnt: 3, ((T([54784, 192], f16), T([54784, 1], f16)), {})
+cnt: 4, ((T([54787, 192], f16), T([54787, 1], f16)), {})
+cnt: 3, ((T([54768, 192], f16), T([54768, 1], f16)), {})
+cnt: 2, ((T([54697, 192], f16), T([54697, 1], f16)), {})
+cnt: 4, ((T([54833, 192], f16), T([54833, 1], f16)), {})
+cnt: 2, ((T([54809, 192], f16), T([54809, 1], f16)), {})
+cnt: 2, ((T([54713, 192], f16), T([54713, 1], f16)), {})
+cnt: 2, ((T([54814, 192], f16), T([54814, 1], f16)), {})
+cnt: 2, ((T([54802, 192], f16), T([54802, 1], f16)), {})
+cnt: 2, ((T([54789, 192], f16), T([54789, 1], f16)), {})
+cnt: 2, ((T([54743, 192], f16), T([54743, 1], f16)), {})
+cnt: 2, ((T([54731, 192], f16), T([54731, 1], f16)), {})
+cnt: 6, ((T([54760, 192], f16), T([54760, 1], f16)), {})
+cnt: 3, ((T([54771, 192], f16), T([54771, 1], f16)), {})
+cnt: 2, ((T([54723, 192], f16), T([54723, 1], f16)), {})
+cnt: 2, ((T([54812, 192], f16), T([54812, 1], f16)), {})
+cnt: 3, ((T([54799, 192], f16), T([54799, 1], f16)), {})
+cnt: 3, ((T([54745, 192], f16), T([54745, 1], f16)), {})
+cnt: 4, ((T([54753, 192], f16), T([54753, 1], f16)), {})
+cnt: 5, ((T([54763, 192], f16), T([54763, 1], f16)), {})
+cnt: 2, ((T([54795, 192], f16), T([54795, 1], f16)), {})
+cnt: 1, ((T([54740, 192], f16), T([54740, 1], f16)), {})
+cnt: 2, ((T([54707, 192], f16), T([54707, 1], f16)), {})
+cnt: 2, ((T([54798, 192], f16), T([54798, 1], f16)), {})
+cnt: 2, ((T([54751, 192], f16), T([54751, 1], f16)), {})
+cnt: 2, ((T([54788, 192], f16), T([54788, 1], f16)), {})
+cnt: 4, ((T([54780, 192], f16), T([54780, 1], f16)), {})
+cnt: 2, ((T([54824, 192], f16), T([54824, 1], f16)), {})
+cnt: 1, ((T([54764, 192], f16), T([54764, 1], f16)), {})
+cnt: 2, ((T([54797, 192], f16), T([54797, 1], f16)), {})
+cnt: 3, ((T([54739, 192], f16), T([54739, 1], f16)), {})
+cnt: 2, ((T([54791, 192], f16), T([54791, 1], f16)), {})
+cnt: 5, ((T([54776, 192], f16), T([54776, 1], f16)), {})
+cnt: 1, ((T([54754, 192], f16), T([54754, 1], f16)), {})
+cnt: 2, ((T([54777, 192], f16), T([54777, 1], f16)), {})
+cnt: 2, ((T([54794, 192], f16), T([54794, 1], f16)), {})
+cnt: 2, ((T([54742, 192], f16), T([54742, 1], f16)), {})
+cnt: 3, ((T([54748, 192], f16), T([54748, 1], f16)), {})
+cnt: 2, ((T([54729, 192], f16), T([54729, 1], f16)), {})
+cnt: 2, ((T([54815, 192], f16), T([54815, 1], f16)), {})
+cnt: 1, ((T([54796, 192], f16), T([54796, 1], f16)), {})
+cnt: 3, ((T([54730, 192], f16), T([54730, 1], f16)), {})
+cnt: 2, ((T([54773, 192], f16), T([54773, 1], f16)), {})
+cnt: 4, ((T([54801, 192], f16), T([54801, 1], f16)), {})
+cnt: 2, ((T([54744, 192], f16), T([54744, 1], f16)), {})
+cnt: 1, ((T([54847, 192], f16), T([54847, 1], f16)), {})
+cnt: 3, ((T([54766, 192], f16), T([54766, 1], f16)), {})
+cnt: 3, ((T([54778, 192], f16), T([54778, 1], f16)), {})
+cnt: 1, ((T([54711, 192], f16), T([54711, 1], f16)), {})
+cnt: 2, ((T([54826, 192], f16), T([54826, 1], f16)), {})
+cnt: 5, ((T([54793, 192], f16), T([54793, 1], f16)), {})
+cnt: 3, ((T([54792, 192], f16), T([54792, 1], f16)), {})
+cnt: 1, ((T([54831, 192], f16), T([54831, 1], f16)), {})
+cnt: 6, ((T([54761, 192], f16), T([54761, 1], f16)), {})
+cnt: 1, ((T([54835, 192], f16), T([54835, 1], f16)), {})
+cnt: 1, ((T([54732, 192], f16), T([54732, 1], f16)), {})
+cnt: 1, ((T([54830, 192], f16), T([54830, 1], f16)), {})
+cnt: 3, ((T([54775, 192], f16), T([54775, 1], f16)), {})
+cnt: 4, ((T([54719, 192], f16), T([54719, 1], f16)), {})
+cnt: 2, ((T([54722, 192], f16), T([54722, 1], f16)), {})
+cnt: 1, ((T([54716, 192], f16), T([54716, 1], f16)), {})
+cnt: 5, ((T([54818, 192], f16), T([54818, 1], f16)), {})
+cnt: 2, ((T([54783, 192], f16), T([54783, 1], f16)), {})
+cnt: 1, ((T([54717, 192], f16), T([54717, 1], f16)), {})
+cnt: 1, ((T([54827, 192], f16), T([54827, 1], f16)), {})
+cnt: 1, ((T([54734, 192], f16), T([54734, 1], f16)), {})
+cnt: 3, ((T([54779, 192], f16), T([54779, 1], f16)), {})
+cnt: 1, ((T([54785, 192], f16), T([54785, 1], f16)), {})
+cnt: 1, ((T([54714, 192], f16), T([54714, 1], f16)), {})
+cnt: 2, ((T([54772, 192], f16), T([54772, 1], f16)), {})
+cnt: 1, ((T([54839, 192], f16), T([54839, 1], f16)), {})
+cnt: 1, ((T([54836, 192], f16), T([54836, 1], f16)), {})
+cnt: 2, ((T([54774, 192], f16), T([54774, 1], f16)), {})
+cnt: 2, ((T([54803, 192], f16), T([54803, 1], f16)), {})
+cnt: 1, ((T([54770, 192], f16), T([54770, 1], f16)), {})
+cnt: 1, ((T([54747, 192], f16), T([54747, 1], f16)), {})
+cnt: 1, ((T([54737, 192], f16), T([54737, 1], f16)), {})
+cnt: 1, ((T([54741, 192], f16), T([54741, 1], f16)), {})
+cnt: 1, ((T([54800, 192], f16), T([54800, 1], f16)), {})
+cnt: 1, ((T([54811, 192], f16), T([54811, 1], f16)), {})
+cnt: 2, ((T([54758, 192], f16), T([54758, 1], f16)), {})
+cnt: 1, ((T([54829, 192], f16), T([54829, 1], f16)), {})
+cnt: 1, ((T([54838, 192], f16), T([54838, 1], f16)), {})
+cnt: 2, ((T([54759, 192], f16), T([54759, 1], f16)), {})
+cnt: 2, ((T([54733, 192], f16), T([54733, 1], f16)), {})
+cnt: 1, ((T([54844, 192], f16), T([54844, 1], f16)), {})
+cnt: 1, ((T([54718, 192], f16), T([54718, 1], f16)), {})
+cnt: 1, ((T([54842, 192], f16), T([54842, 1], f16)), {})
+cnt: 1, ((T([54769, 192], f16), T([54769, 1], f16)), {})
+cnt: 1, ((T([54821, 192], f16), T([54821, 1], f16)), {})
+cnt: 3, ((T([54782, 192], f16), T([54782, 1], f16)), {})
+cnt: 2, ((T([54710, 192], f16), T([54710, 1], f16)), {})
+cnt: 1, ((T([54820, 192], f16), T([54820, 1], f16)), {})
+cnt: 1, ((T([54692, 192], f16), T([54692, 1], f16)), {})
+cnt: 1, ((T([54727, 192], f16), T([54727, 1], f16)), {})
+cnt: 2, ((T([54767, 192], f16), T([54767, 1], f16)), {})
+cnt: 2, ((T([54819, 192], f16), T([54819, 1], f16)), {})
+cnt: 1, ((T([54756, 192], f16), T([54756, 1], f16)), {})
+cnt: 1, ((T([54843, 192], f16), T([54843, 1], f16)), {})
+cnt: 2, ((T([54735, 192], f16), T([54735, 1], f16)), {})
+cnt: 1, ((T([54715, 192], f16), T([54715, 1], f16)), {})
+cnt: 1, ((T([54828, 192], f16), T([54828, 1], f16)), {})
+cnt: 1, ((T([54712, 192], f16), T([54712, 1], f16)), {})
+cnt: 1, ((T([54855, 192], f16), T([54855, 1], f16)), {})
+cnt: 1, ((T([54725, 192], f16), T([54725, 1], f16)), {})
+cnt: 1, ((T([54816, 192], f16), T([54816, 1], f16)), {})
+cnt: 1, ((T([54807, 192], f16), T([54807, 1], f16)), {})
+cnt: 1, ((T([54701, 192], f16), T([54701, 1], f16)), {})
+cnt: 1, ((T([54813, 192], f16), T([54813, 1], f16)), {})
+cnt: 1, ((T([54749, 192], f16), T([54749, 1], f16)), {})
+cnt: 1, ((T([54736, 192], f16), T([54736, 1], f16)), {})
+cnt: 1, ((T([54705, 192], f16), T([54705, 1], f16)), {})
+cnt: 1, ((T([54750, 192], f16), T([54750, 1], f16)), {})
+cnt: 1, ((T([54862, 192], f16), T([54862, 1], f16)), {})
+cnt: 1, ((T([54762, 192], f16), T([54762, 1], f16)), {})
+Operator: aten.new_zeros.default
+cnt: 1, ((T([1024, 30876], f16, stride=(31068, 1)), [1024, 249, 249]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([54765], f16), [965]), {})
+cnt: 2, ((T([54704], f16), [965]), {})
+cnt: 4, ((T([54786], f16), [965]), {})
+cnt: 2, ((T([54804], f16), [965]), {})
+cnt: 3, ((T([54757], f16), [965]), {})
+cnt: 2, ((T([54746], f16), [965]), {})
+cnt: 5, ((T([54781], f16), [965]), {})
+cnt: 2, ((T([54687], f16), [965]), {})
+cnt: 2, ((T([54738], f16), [965]), {})
+cnt: 3, ((T([54784], f16), [965]), {})
+cnt: 4, ((T([54787], f16), [965]), {})
+cnt: 3, ((T([54768], f16), [965]), {})
+cnt: 2, ((T([54697], f16), [965]), {})
+cnt: 4, ((T([54833], f16), [965]), {})
+cnt: 2, ((T([54809], f16), [965]), {})
+cnt: 2, ((T([54713], f16), [965]), {})
+cnt: 2, ((T([54814], f16), [965]), {})
+cnt: 2, ((T([54802], f16), [965]), {})
+cnt: 2, ((T([54789], f16), [965]), {})
+cnt: 2, ((T([54743], f16), [965]), {})
+cnt: 2, ((T([54731], f16), [965]), {})
+cnt: 6, ((T([54760], f16), [965]), {})
+cnt: 3, ((T([54771], f16), [965]), {})
+cnt: 2, ((T([54723], f16), [965]), {})
+cnt: 2, ((T([54812], f16), [965]), {})
+cnt: 3, ((T([54799], f16), [965]), {})
+cnt: 3, ((T([54745], f16), [965]), {})
+cnt: 4, ((T([54753], f16), [965]), {})
+cnt: 5, ((T([54763], f16), [965]), {})
+cnt: 2, ((T([54795], f16), [965]), {})
+cnt: 1, ((T([54740], f16), [965]), {})
+cnt: 2, ((T([54707], f16), [965]), {})
+cnt: 2, ((T([54798], f16), [965]), {})
+cnt: 2, ((T([54751], f16), [965]), {})
+cnt: 2, ((T([54788], f16), [965]), {})
+cnt: 4, ((T([54780], f16), [965]), {})
+cnt: 2, ((T([54824], f16), [965]), {})
+cnt: 1, ((T([54764], f16), [965]), {})
+cnt: 2, ((T([54797], f16), [965]), {})
+cnt: 3, ((T([54739], f16), [965]), {})
+cnt: 2, ((T([54791], f16), [965]), {})
+cnt: 5, ((T([54776], f16), [965]), {})
+cnt: 1, ((T([54754], f16), [965]), {})
+cnt: 2, ((T([54777], f16), [965]), {})
+cnt: 2, ((T([54794], f16), [965]), {})
+cnt: 2, ((T([54742], f16), [965]), {})
+cnt: 3, ((T([54748], f16), [965]), {})
+cnt: 2, ((T([54729], f16), [965]), {})
+cnt: 2, ((T([54815], f16), [965]), {})
+cnt: 1, ((T([54796], f16), [965]), {})
+cnt: 3, ((T([54730], f16), [965]), {})
+cnt: 2, ((T([54773], f16), [965]), {})
+cnt: 4, ((T([54801], f16), [965]), {})
+cnt: 2, ((T([54744], f16), [965]), {})
+cnt: 1, ((T([54847], f16), [965]), {})
+cnt: 3, ((T([54766], f16), [965]), {})
+cnt: 3, ((T([54778], f16), [965]), {})
+cnt: 1, ((T([54711], f16), [965]), {})
+cnt: 2, ((T([54826], f16), [965]), {})
+cnt: 5, ((T([54793], f16), [965]), {})
+cnt: 3, ((T([54792], f16), [965]), {})
+cnt: 1, ((T([54831], f16), [965]), {})
+cnt: 6, ((T([54761], f16), [965]), {})
+cnt: 1, ((T([54835], f16), [965]), {})
+cnt: 1, ((T([54732], f16), [965]), {})
+cnt: 1, ((T([54830], f16), [965]), {})
+cnt: 3, ((T([54775], f16), [965]), {})
+cnt: 4, ((T([54719], f16), [965]), {})
+cnt: 2, ((T([54722], f16), [965]), {})
+cnt: 1, ((T([54716], f16), [965]), {})
+cnt: 5, ((T([54818], f16), [965]), {})
+cnt: 2, ((T([54783], f16), [965]), {})
+cnt: 1, ((T([54717], f16), [965]), {})
+cnt: 1, ((T([54827], f16), [965]), {})
+cnt: 1, ((T([54734], f16), [965]), {})
+cnt: 3, ((T([54779], f16), [965]), {})
+cnt: 1, ((T([54785], f16), [965]), {})
+cnt: 1, ((T([54714], f16), [965]), {})
+cnt: 2, ((T([54772], f16), [965]), {})
+cnt: 1, ((T([54839], f16), [965]), {})
+cnt: 1, ((T([54836], f16), [965]), {})
+cnt: 2, ((T([54774], f16), [965]), {})
+cnt: 2, ((T([54803], f16), [965]), {})
+cnt: 1, ((T([54770], f16), [965]), {})
+cnt: 1, ((T([54747], f16), [965]), {})
+cnt: 1, ((T([54737], f16), [965]), {})
+cnt: 1, ((T([54741], f16), [965]), {})
+cnt: 1, ((T([54800], f16), [965]), {})
+cnt: 1, ((T([54811], f16), [965]), {})
+cnt: 2, ((T([54758], f16), [965]), {})
+cnt: 1, ((T([54829], f16), [965]), {})
+cnt: 1, ((T([54838], f16), [965]), {})
+cnt: 2, ((T([54759], f16), [965]), {})
+cnt: 2, ((T([54733], f16), [965]), {})
+cnt: 1, ((T([54844], f16), [965]), {})
+cnt: 1, ((T([54718], f16), [965]), {})
+cnt: 1, ((T([54842], f16), [965]), {})
+cnt: 1, ((T([54769], f16), [965]), {})
+cnt: 1, ((T([54821], f16), [965]), {})
+cnt: 3, ((T([54782], f16), [965]), {})
+cnt: 2, ((T([54710], f16), [965]), {})
+cnt: 1, ((T([54820], f16), [965]), {})
+cnt: 1, ((T([54692], f16), [965]), {})
+cnt: 1, ((T([54727], f16), [965]), {})
+cnt: 2, ((T([54767], f16), [965]), {})
+cnt: 2, ((T([54819], f16), [965]), {})
+cnt: 1, ((T([54756], f16), [965]), {})
+cnt: 1, ((T([54843], f16), [965]), {})
+cnt: 2, ((T([54735], f16), [965]), {})
+cnt: 1, ((T([54715], f16), [965]), {})
+cnt: 1, ((T([54828], f16), [965]), {})
+cnt: 1, ((T([54712], f16), [965]), {})
+cnt: 1, ((T([54855], f16), [965]), {})
+cnt: 1, ((T([54725], f16), [965]), {})
+cnt: 1, ((T([54816], f16), [965]), {})
+cnt: 1, ((T([54807], f16), [965]), {})
+cnt: 1, ((T([54701], f16), [965]), {})
+cnt: 1, ((T([54813], f16), [965]), {})
+cnt: 1, ((T([54749], f16), [965]), {})
+cnt: 1, ((T([54736], f16), [965]), {})
+cnt: 1, ((T([54705], f16), [965]), {})
+cnt: 1, ((T([54750], f16), [965]), {})
+cnt: 1, ((T([54862], f16), [965]), {})
+cnt: 1, ((T([54762], f16), [965]), {})
+Operator: aten.relu.default
+cnt: 3, ((T([1024, 1500], f16),), {})
+cnt: 1, ((T([1024, 192], f16),), {})
+cnt: 9, ((T([1024, 4000], f16),), {})
+Operator: aten.scatter_add.default
+cnt: 1, ((T([965], f16), 0, T([54765], i64), T([54765], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54704], i64), T([54704], f16)), {})
+cnt: 4, ((T([965], f16), 0, T([54786], i64), T([54786], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54804], i64), T([54804], f16)), {})
+cnt: 3, ((T([965], f16), 0, T([54757], i64), T([54757], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54746], i64), T([54746], f16)), {})
+cnt: 5, ((T([965], f16), 0, T([54781], i64), T([54781], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54687], i64), T([54687], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54738], i64), T([54738], f16)), {})
+cnt: 3, ((T([965], f16), 0, T([54784], i64), T([54784], f16)), {})
+cnt: 4, ((T([965], f16), 0, T([54787], i64), T([54787], f16)), {})
+cnt: 3, ((T([965], f16), 0, T([54768], i64), T([54768], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54697], i64), T([54697], f16)), {})
+cnt: 4, ((T([965], f16), 0, T([54833], i64), T([54833], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54809], i64), T([54809], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54713], i64), T([54713], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54814], i64), T([54814], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54802], i64), T([54802], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54789], i64), T([54789], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54743], i64), T([54743], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54731], i64), T([54731], f16)), {})
+cnt: 6, ((T([965], f16), 0, T([54760], i64), T([54760], f16)), {})
+cnt: 3, ((T([965], f16), 0, T([54771], i64), T([54771], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54723], i64), T([54723], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54812], i64), T([54812], f16)), {})
+cnt: 3, ((T([965], f16), 0, T([54799], i64), T([54799], f16)), {})
+cnt: 3, ((T([965], f16), 0, T([54745], i64), T([54745], f16)), {})
+cnt: 4, ((T([965], f16), 0, T([54753], i64), T([54753], f16)), {})
+cnt: 5, ((T([965], f16), 0, T([54763], i64), T([54763], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54795], i64), T([54795], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54740], i64), T([54740], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54707], i64), T([54707], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54798], i64), T([54798], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54751], i64), T([54751], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54788], i64), T([54788], f16)), {})
+cnt: 4, ((T([965], f16), 0, T([54780], i64), T([54780], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54824], i64), T([54824], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54764], i64), T([54764], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54797], i64), T([54797], f16)), {})
+cnt: 3, ((T([965], f16), 0, T([54739], i64), T([54739], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54791], i64), T([54791], f16)), {})
+cnt: 5, ((T([965], f16), 0, T([54776], i64), T([54776], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54754], i64), T([54754], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54777], i64), T([54777], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54794], i64), T([54794], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54742], i64), T([54742], f16)), {})
+cnt: 3, ((T([965], f16), 0, T([54748], i64), T([54748], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54729], i64), T([54729], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54815], i64), T([54815], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54796], i64), T([54796], f16)), {})
+cnt: 3, ((T([965], f16), 0, T([54730], i64), T([54730], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54773], i64), T([54773], f16)), {})
+cnt: 4, ((T([965], f16), 0, T([54801], i64), T([54801], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54744], i64), T([54744], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54847], i64), T([54847], f16)), {})
+cnt: 3, ((T([965], f16), 0, T([54766], i64), T([54766], f16)), {})
+cnt: 3, ((T([965], f16), 0, T([54778], i64), T([54778], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54711], i64), T([54711], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54826], i64), T([54826], f16)), {})
+cnt: 5, ((T([965], f16), 0, T([54793], i64), T([54793], f16)), {})
+cnt: 3, ((T([965], f16), 0, T([54792], i64), T([54792], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54831], i64), T([54831], f16)), {})
+cnt: 6, ((T([965], f16), 0, T([54761], i64), T([54761], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54835], i64), T([54835], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54732], i64), T([54732], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54830], i64), T([54830], f16)), {})
+cnt: 3, ((T([965], f16), 0, T([54775], i64), T([54775], f16)), {})
+cnt: 4, ((T([965], f16), 0, T([54719], i64), T([54719], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54722], i64), T([54722], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54716], i64), T([54716], f16)), {})
+cnt: 5, ((T([965], f16), 0, T([54818], i64), T([54818], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54783], i64), T([54783], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54717], i64), T([54717], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54827], i64), T([54827], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54734], i64), T([54734], f16)), {})
+cnt: 3, ((T([965], f16), 0, T([54779], i64), T([54779], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54785], i64), T([54785], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54714], i64), T([54714], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54772], i64), T([54772], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54839], i64), T([54839], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54836], i64), T([54836], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54774], i64), T([54774], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54803], i64), T([54803], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54770], i64), T([54770], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54747], i64), T([54747], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54737], i64), T([54737], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54741], i64), T([54741], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54800], i64), T([54800], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54811], i64), T([54811], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54758], i64), T([54758], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54829], i64), T([54829], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54838], i64), T([54838], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54759], i64), T([54759], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54733], i64), T([54733], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54844], i64), T([54844], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54718], i64), T([54718], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54842], i64), T([54842], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54769], i64), T([54769], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54821], i64), T([54821], f16)), {})
+cnt: 3, ((T([965], f16), 0, T([54782], i64), T([54782], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54710], i64), T([54710], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54820], i64), T([54820], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54692], i64), T([54692], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54727], i64), T([54727], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54767], i64), T([54767], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54819], i64), T([54819], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54756], i64), T([54756], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54843], i64), T([54843], f16)), {})
+cnt: 2, ((T([965], f16), 0, T([54735], i64), T([54735], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54715], i64), T([54715], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54828], i64), T([54828], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54712], i64), T([54712], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54855], i64), T([54855], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54725], i64), T([54725], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54816], i64), T([54816], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54807], i64), T([54807], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54701], i64), T([54701], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54813], i64), T([54813], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54749], i64), T([54749], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54736], i64), T([54736], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54705], i64), T([54705], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54750], i64), T([54750], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54862], i64), T([54862], f16)), {})
+cnt: 1, ((T([965], f16), 0, T([54762], i64), T([54762], f16)), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([1024, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 1, ((T([1024, 1], f16, stride=(0, 0)), T([1024, 1], f16)), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([1024, 249, 249], f16), [1024, 249, 249], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([1024, 1], f16), [0], True), {})
+cnt: 9, ((T([1024, 4000], f16), [0], True), {})
+cnt: 1, ((T([1024, 192], f16), [0], True), {})
+cnt: 3, ((T([1024, 1500], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([1024, 1], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 9, ((T([1024, 4000], f16), T([1024, 4000], f16), 0), {})
+cnt: 1, ((T([1024, 192], f16), T([1024, 192], f16), 0), {})
+cnt: 3, ((T([1024, 1500], f16), T([1024, 1500], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/fastNLP_Bert_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/fastNLP_Bert_training.txt
new file mode 100644
index 0000000000000..14639db6d7128
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/fastNLP_Bert_training.txt
@@ -0,0 +1,157 @@
+Operator: aten._index_put_impl_.default
+cnt: 1, ((T([6, 474, 768], f16), [T([6, 474], i64, stride=(1, 0)), T([6, 474], i64, stride=(475, 1))], T([6, 474, 768], f16), True, True), {})
+Operator: aten._softmax.default
+cnt: 12, ((T([6, 12, 476, 476], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([6, 12, 476, 476], f16), T([6, 12, 476, 476], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([6, 474], i64),), {'dtype': i64, 'layout': torch.strided, 'device': "torch.device('cpu')"})
+cnt: 1, ((T([6], i64),), {'dtype': i64, 'device': 'cuda'})
+cnt: 1, ((T([6, 476], b8),), {'dtype': i64})
+cnt: 1, ((T([6, 1, 1, 476], i64),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([6, 12, 476, 64], f16), [72, 476, 64]), {})
+cnt: 12, ((T([6, 12, 64, 476], f16), [72, 64, 476]), {})
+cnt: 12, ((T([72, 476, 476], f16), [6, 12, 476, 476]), {})
+cnt: 12, ((T([72, 476, 64], f16), [6, 12, 476, 64]), {})
+cnt: 24, ((T([6, 476, 12, 64], f16), [6, 476, 768]), {})
+cnt: 12, ((T([6, 476, 768], f16), [2856, 768]), {})
+Operator: aten.add.Tensor
+cnt: 6, ((T([], i64), 1), {})
+cnt: 6, ((T([], i64), 2), {})
+cnt: 1, ((T([6], i64), 1), {})
+cnt: 74, ((T([6, 476, 768], f16), T([6, 476, 768], f16)), {})
+cnt: 12, ((T([6, 12, 476, 476], f16), T([6, 1, 1, 476], f16)), {})
+cnt: 12, ((T([6, 476, 3072], f16), 1.0), {})
+cnt: 1, ((T([], f16), 0), {})
+cnt: 1, ((T([], f16), T([], f16)), {})
+cnt: 1, ((T([6, 474, 2], f16), T([6, 474, 2], f16)), {})
+cnt: 12, ((T([6, 476, 3072], f16), T([6, 476, 3072], f16)), {})
+Operator: aten.addmm.default
+cnt: 48, ((T([768], f16), T([2856, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([2856, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([2856, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([768], f16), T([6, 768], f16, stride=(365568, 1)), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 1, ((T([2], f16), T([2844, 768], f16), T([768, 2], f16, stride=(1, 768))), {})
+Operator: aten.bitwise_xor.Tensor
+cnt: 1, ((T([6, 1], i64, stride=(476, 1)), T([6, 476], i64)), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([72, 476, 64], f16), T([72, 64, 476], f16)), {})
+cnt: 12, ((T([72, 476, 476], f16), T([72, 476, 64], f16)), {})
+cnt: 12, ((T([72, 476, 476], f16, stride=(226576, 1, 476)), T([72, 476, 64], f16)), {})
+cnt: 12, ((T([72, 476, 64], f16), T([72, 64, 476], f16, stride=(30464, 1, 64))), {})
+cnt: 12, ((T([72, 64, 476], f16, stride=(30464, 1, 64)), T([72, 476, 476], f16)), {})
+cnt: 12, ((T([72, 476, 476], f16), T([72, 476, 64], f16, stride=(30464, 1, 476))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([6, 474, 768], f16)], -1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([6, 474], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([6, 474], i64), T([6, 474], i64)), {})
+cnt: 6, ((T([474], i64), T([474], i64)), {})
+cnt: 1, ((T([6, 474], i64, stride=(475, 1)), T([6, 474], i64)), {})
+cnt: 1, ((T([6, 474, 768], f16), T([6, 474, 768], f16)), {})
+cnt: 1, ((T([1, 6, 474, 768], f16), T([1, 6, 474, 768], f16)), {})
+Operator: aten.cumsum.default
+cnt: 1, ((T([6, 476], i64), -1), {})
+cnt: 1, ((T([6, 474], i64), -1), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([6, 12, 476, 476], f16), 8.0), {})
+cnt: 24, ((T([6, 476, 3072], f16), 1.4142135623730951), {})
+cnt: 4, ((T([], f16), 2844), {})
+cnt: 2, ((T([], f16), 2), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([21128, 768], f16), T([6, 476], i64), 0), {})
+cnt: 1, ((T([512, 768], f16), T([6, 476], i64, stride=(0, 1))), {})
+cnt: 1, ((T([2, 768], f16), T([6, 476], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([6, 476, 768], f16), T([6, 476], i64), 2, -1, False), {})
+cnt: 1, ((T([6, 476, 768], f16), T([6, 476], i64, stride=(0, 1)), 512, -1, False), {})
+cnt: 1, ((T([6, 476, 768], f16), T([6, 476], i64), 21128, 0, False), {})
+Operator: aten.eq.Scalar
+cnt: 1, ((T([6, 474], b8), False), {})
+cnt: 1, ((T([6, 476], i64), 511), {})
+cnt: 1, ((T([6, 474, 1], b8), False), {})
+Operator: aten.erf.default
+cnt: 12, ((T([6, 476, 3072], f16),), {})
+Operator: aten.exp.default
+cnt: 12, ((T([6, 476, 3072], f16),), {})
+Operator: aten.fill_.Scalar
+cnt: 6, ((T([476], i64), 1), {})
+cnt: 1, ((T([6], i64, stride=(476,)), 2057), {})
+Operator: aten.flip.default
+cnt: 2, ((T([6, 476], i64), [-1]), {})
+Operator: aten.fmod.Scalar
+cnt: 1, ((T([6, 476], i64), 2), {})
+Operator: aten.ge.Scalar
+cnt: 1, ((T([6, 474], i64, stride=(475, 1)), 474), {})
+Operator: aten.index.Tensor
+cnt: 1, ((T([2869], i64), [T([6, 474], i64)]), {})
+cnt: 1, ((T([6, 474, 768], f16, stride=(365568, 768, 1)), [T([6, 474], i64, stride=(1, 0)), T([6, 474], i64, stride=(475, 1))]), {})
+Operator: aten.index_put_.default
+cnt: 1, ((T([6, 476], i64), [T([6], i64), T([6], i64)], T([], i64)), {})
+Operator: aten.masked_fill.Scalar
+cnt: 1, ((T([6, 474], i64), T([6, 474], b8), 0), {})
+cnt: 2, ((T([6, 474, 768], f16), T([6, 474, 1], b8), 0), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([6, 474], i64, stride=(475, 1)), T([6, 474], b8), 0), {})
+Operator: aten.max.default
+cnt: 2, ((T([6], i64),), {})
+Operator: aten.mm.default
+cnt: 1, ((T([2844, 2], f16), T([2, 768], f16)), {})
+cnt: 1, ((T([2, 2844], f16, stride=(1, 2)), T([2844, 768], f16)), {})
+cnt: 12, ((T([2856, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 2856], f16, stride=(1, 768)), T([2856, 3072], f16)), {})
+cnt: 12, ((T([2856, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 2856], f16, stride=(1, 3072)), T([2856, 768], f16)), {})
+cnt: 48, ((T([2856, 768], f16), T([768, 768], f16)), {})
+cnt: 48, ((T([768, 2856], f16, stride=(1, 768)), T([2856, 768], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 12, ((T([6, 476, 3072], f16), 1.1283791670955126), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([6, 1, 1, 476], f16), -10000.0), {})
+cnt: 24, ((T([6, 476, 3072], f16), 0.5), {})
+cnt: 48, ((T([6, 476, 3072], f16), T([6, 476, 3072], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 25, ((T([6, 476, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 25, ((T([6, 476, 768], f16), T([6, 476, 768], f16), [768], T([6, 476, 1], f32), T([6, 476, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.ne.Scalar
+cnt: 1, ((T([6, 474], i64), 0), {})
+Operator: aten.neg.default
+cnt: 12, ((T([6, 476, 3072], f16),), {})
+Operator: aten.new_empty_strided.default
+cnt: 1, ((T([1, 6, 474, 768], f16), [1, 6, 474, 768], [2184192, 364032, 768, 1]), {})
+Operator: aten.new_full.default
+cnt: 1, ((T([6, 474], i64), [6, 476], 2457), {'dtype': i64, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+Operator: aten.new_zeros.default
+cnt: 1, ((T([6, 476, 768], f16), [1, 6, 474, 768]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 1, ((T([6, 474], i64), [6, 475]), {'dtype': i64, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 1, ((T([6, 474, 768], f16), [6, 474, 768]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.pow.Tensor_Scalar
+cnt: 12, ((T([6, 476, 3072], f16), 2), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([6, 1, 1, 476], f16), 1.0), {})
+Operator: aten.select_backward.default
+cnt: 1, ((T([6, 474], f16, stride=(0, 0)), [6, 474, 2], 2, 1), {})
+cnt: 1, ((T([6, 474], f16, stride=(0, 0)), [6, 474, 2], 2, 0), {})
+Operator: aten.slice_backward.default
+cnt: 2, ((T([6, 474, 2], f16), [6, 474, 2], 1, 0, 9223372036854775807, 1), {})
+cnt: 2, ((T([6, 474, 2], f16), [6, 474, 2], 0, 0, 9223372036854775807, 1), {})
+cnt: 1, ((T([6, 474, 768], f16), [6, 476, 768], 1, 1, -1, 1), {})
+cnt: 1, ((T([6, 476, 768], f16), [6, 476, 768], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.stack.default
+cnt: 1, (([T([6, 474, 768], f16)],), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([2844, 2], f16), [0], True), {})
+cnt: 60, ((T([2856, 768], f16), [0], True), {})
+cnt: 12, ((T([2856, 3072], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 2, ((T([6, 474], f16, stride=(948, 2)),), {})
+Operator: aten.sum.dim_IntList
+cnt: 1, ((T([6, 474], b8), [-1]), {})
+cnt: 2, ((T([6, 474], i64), [-1]), {})
+Operator: aten.tanh.default
+cnt: 1, ((T([6, 768], f16),), {})
+Operator: aten.unbind.int
+cnt: 1, ((T([1, 6, 474, 768], f16),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_Albert_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_Albert_training.txt
new file mode 100644
index 0000000000000..9dc41c8ff4684
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_Albert_training.txt
@@ -0,0 +1,110 @@
+Operator: aten._softmax.default
+cnt: 12, ((T([8, 12, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([8, 12, 512, 512], f16), T([8, 12, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([8, 1, 1, 512], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([8, 12, 512, 64], f16), [96, 512, 64]), {})
+cnt: 12, ((T([8, 12, 64, 512], f16), [96, 64, 512]), {})
+cnt: 12, ((T([96, 512, 512], f16), [8, 12, 512, 512]), {})
+cnt: 12, ((T([96, 512, 64], f16), [8, 12, 512, 64]), {})
+cnt: 36, ((T([8, 512, 12, 64], f16), [8, 512, 768]), {})
+cnt: 12, ((T([8, 512, 768], f16), [4096, 768]), {})
+Operator: aten.add.Tensor
+cnt: 4, ((T([8, 512, 128], f16), T([8, 512, 128], f16)), {})
+cnt: 12, ((T([8, 12, 512, 512], f16), T([8, 1, 1, 512], f16)), {})
+cnt: 72, ((T([8, 512, 768], f16), T([8, 512, 768], f16)), {})
+cnt: 36, ((T([8, 512, 3072], f16), T([8, 512, 3072], f16)), {})
+cnt: 12, ((T([8, 512, 3072], f16), 1.0), {})
+cnt: 1, ((T([8, 512, 128], f16), 1.0), {})
+cnt: 99, ((T([768], f16), T([768], f16)), {})
+cnt: 11, ((T([768, 3072], f16), T([768, 3072], f16)), {})
+cnt: 11, ((T([3072], f16), T([3072], f16)), {})
+cnt: 11, ((T([3072, 768], f16), T([3072, 768], f16)), {})
+cnt: 44, ((T([768, 768], f16), T([768, 768], f16)), {})
+cnt: 1, ((T([30000, 128], f16), T([30000, 128], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([8, 512, 128], f16), T([1, 512, 128], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([768], f16), T([4096, 128], f16), T([128, 768], f16, stride=(1, 128))), {})
+cnt: 48, ((T([768], f16), T([4096, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([4096, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([4096, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([128], f16), T([4096, 768], f16), T([768, 128], f16, stride=(1, 768))), {})
+cnt: 1, ((T([30000], f16), T([4096, 128], f16), T([128, 30000], f16, stride=(1, 128))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([96, 512, 64], f16), T([96, 64, 512], f16)), {})
+cnt: 12, ((T([96, 512, 512], f16), T([96, 512, 64], f16)), {})
+cnt: 12, ((T([96, 512, 512], f16, stride=(262144, 1, 512)), T([96, 512, 64], f16)), {})
+cnt: 12, ((T([96, 512, 64], f16), T([96, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 12, ((T([96, 64, 512], f16, stride=(32768, 1, 64)), T([96, 512, 512], f16)), {})
+cnt: 12, ((T([96, 512, 512], f16), T([96, 512, 64], f16, stride=(32768, 1, 512))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([8, 512], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([8, 512], i64), T([8, 512], i64)), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([8, 12, 512, 512], f16), 8.0), {})
+cnt: 2, ((T([], f16), 122880000), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30000, 128], f16), T([8, 512], i64), 0), {})
+cnt: 1, ((T([2, 128], f16), T([8, 512], i64, stride=(0, 1))), {})
+cnt: 1, ((T([512, 128], f16), T([1, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 128], f16), T([1, 512], i64), 512, -1, False), {})
+cnt: 1, ((T([8, 512, 128], f16), T([8, 512], i64, stride=(0, 1)), 2, -1, False), {})
+cnt: 1, ((T([8, 512, 128], f16), T([8, 512], i64), 30000, 0, False), {})
+Operator: aten.mm.default
+cnt: 1, ((T([4096, 30000], f16, stride=(0, 0)), T([30000, 128], f16)), {})
+cnt: 1, ((T([30000, 4096], f16, stride=(0, 0)), T([4096, 128], f16)), {})
+cnt: 1, ((T([4096, 128], f16), T([128, 768], f16)), {})
+cnt: 1, ((T([128, 4096], f16, stride=(1, 128)), T([4096, 768], f16)), {})
+cnt: 12, ((T([4096, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 4096], f16, stride=(1, 768)), T([4096, 3072], f16)), {})
+cnt: 12, ((T([4096, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 4096], f16, stride=(1, 3072)), T([4096, 768], f16)), {})
+cnt: 48, ((T([4096, 768], f16), T([768, 768], f16)), {})
+cnt: 48, ((T([768, 4096], f16, stride=(1, 768)), T([4096, 768], f16)), {})
+cnt: 1, ((T([4096, 768], f16), T([768, 128], f16)), {})
+cnt: 1, ((T([768, 4096], f16, stride=(1, 768)), T([4096, 128], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 1, ((T([8, 512, 128], f16), 3.0), {})
+cnt: 12, ((T([8, 512, 3072], f16), 3.0), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([8, 1, 1, 512], f16), -65504.0), {})
+cnt: 24, ((T([8, 512, 3072], f16), 0.5), {})
+cnt: 24, ((T([8, 512, 3072], f16), 0.044715), {})
+cnt: 24, ((T([8, 512, 3072], f16), 0.7978845608028654), {})
+cnt: 48, ((T([8, 512, 3072], f16), T([8, 512, 3072], f16)), {})
+cnt: 2, ((T([8, 512, 128], f16), 0.5), {})
+cnt: 2, ((T([8, 512, 128], f16), 0.044715), {})
+cnt: 2, ((T([8, 512, 128], f16), 0.7978845608028654), {})
+cnt: 4, ((T([8, 512, 128], f16), T([8, 512, 128], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 2, ((T([8, 512, 128], f16), [128], T([128], f16), T([128], f16), 1e-12), {})
+cnt: 24, ((T([8, 512, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 2, ((T([8, 512, 128], f16), T([8, 512, 128], f16), [128], T([8, 512, 1], f32), T([8, 512, 1], f32), T([128], f16), T([128], f16), [True, True, True]), {})
+cnt: 24, ((T([8, 512, 768], f16), T([8, 512, 768], f16), [768], T([8, 512, 1], f32), T([8, 512, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.pow.Tensor_Scalar
+cnt: 12, ((T([8, 512, 3072], f16), 3.0), {})
+cnt: 1, ((T([8, 512, 128], f16), 3.0), {})
+cnt: 1, ((T([8, 512, 128], f16), 2.0), {})
+cnt: 12, ((T([8, 512, 3072], f16), 2.0), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([8, 1, 1, 512], f16), 1.0), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([4096, 30000], f16, stride=(0, 0)), [0], True), {})
+cnt: 1, ((T([4096, 128], f16), [0], True), {})
+cnt: 61, ((T([4096, 768], f16), [0], True), {})
+cnt: 12, ((T([4096, 3072], f16), [0], True), {})
+cnt: 1, ((T([8, 512, 128], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([8, 512, 30000], f16),), {})
+Operator: aten.tanh.default
+cnt: 12, ((T([8, 512, 3072], f16),), {})
+cnt: 1, ((T([8, 512, 128], f16),), {})
+Operator: aten.tanh_backward.default
+cnt: 1, ((T([8, 512, 128], f16), T([8, 512, 128], f16)), {})
+cnt: 12, ((T([8, 512, 3072], f16), T([8, 512, 3072], f16)), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_Bart_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_Bart_training.txt
new file mode 100644
index 0000000000000..96ff5f455b082
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_Bart_training.txt
@@ -0,0 +1,76 @@
+Operator: aten._softmax.default
+cnt: 18, ((T([48, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 18, ((T([48, 512, 512], f16), T([48, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([512, 512], f32),), {'dtype': f16})
+cnt: 1, ((T([4, 1, 512, 512], f16, stride=(0, 262144, 512, 1)),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 54, ((T([4, 512, 12, 64], f16), [4, 512, 768]), {})
+cnt: 1, ((T([2048, 50265], f16), [4, 512, 50265]), {})
+cnt: 18, ((T([4, 12, 512, 64], f16), [48, 512, 64]), {})
+cnt: 18, ((T([4, 512, 768], f16), [2048, 768]), {})
+Operator: aten.add.Tensor
+cnt: 2, ((T([4, 512], i64, stride=(0, 1)), 2), {})
+cnt: 97, ((T([4, 512, 768], f16), T([4, 512, 768], f16)), {})
+cnt: 1, ((T([512], i64), 1), {})
+cnt: 6, ((T([4, 12, 512, 512], f16), T([4, 1, 512, 512], f16)), {})
+cnt: 1, ((T([4, 512, 50265], f16), T([1, 50265], f16)), {})
+cnt: 2, ((T([50265, 768], f16), T([50265, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 72, ((T([768], f16), T([2048, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([2048, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([2048, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+Operator: aten.any.default
+cnt: 12, ((T([4, 512, 768], b8),), {})
+Operator: aten.bmm.default
+cnt: 36, ((T([48, 512, 64], f16), T([48, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 36, ((T([48, 512, 512], f16), T([48, 512, 64], f16)), {})
+cnt: 18, ((T([48, 512, 512], f16, stride=(262144, 1, 512)), T([48, 512, 64], f16)), {})
+cnt: 18, ((T([48, 64, 512], f16, stride=(32768, 1, 64)), T([48, 512, 512], f16)), {})
+Operator: aten.clone.default
+cnt: 2, ((T([4, 512], i64),), {})
+Operator: aten.copy_.default
+cnt: 2, ((T([4, 512], i64), T([4, 512], i64)), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 102942720), {})
+Operator: aten.embedding.default
+cnt: 2, ((T([50265, 768], f16), T([4, 512], i64), 1), {})
+cnt: 2, ((T([1026, 768], f16), T([4, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 2, ((T([4, 512, 768], f16), T([4, 512], i64), 1026, -1, False), {})
+cnt: 2, ((T([4, 512, 768], f16), T([4, 512], i64), 50265, 1, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([4, 512, 3072], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([4, 512, 3072], f16), T([4, 512, 3072], f16)), {})
+Operator: aten.isinf.default
+cnt: 6, ((T([4, 512, 768], f16),), {})
+Operator: aten.isnan.default
+cnt: 6, ((T([4, 512, 768], f16),), {})
+Operator: aten.lt.Tensor
+cnt: 1, ((T([512], i64), T([512, 1], i64)), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([512, 512], f32), T([512, 512], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([2048, 768], f16), T([768, 50265], f16, stride=(1, 768))), {})
+cnt: 1, ((T([50265, 2048], f16, stride=(0, 0)), T([2048, 768], f16)), {})
+cnt: 1, ((T([2048, 50265], f16, stride=(0, 0)), T([50265, 768], f16)), {})
+cnt: 12, ((T([2048, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 3072], f16)), {})
+cnt: 12, ((T([2048, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 2048], f16, stride=(1, 3072)), T([2048, 768], f16)), {})
+cnt: 72, ((T([2048, 768], f16), T([768, 768], f16)), {})
+cnt: 72, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 4, ((T([4, 512, 768], f16), 1.0), {})
+cnt: 36, ((T([4, 512, 768], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 32, ((T([4, 512, 768], f16), [768], T([768], f16), T([768], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 32, ((T([4, 512, 768], f16), T([4, 512, 768], f16), [768], T([4, 512, 1], f32), T([4, 512, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.sum.SymInt
+cnt: 84, ((T([2048, 768], f16), [0], True), {})
+cnt: 12, ((T([2048, 3072], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([4, 512, 50265], f16),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_Bert_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_Bert_training.txt
new file mode 100644
index 0000000000000..59a786f127ce5
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_Bert_training.txt
@@ -0,0 +1,76 @@
+Operator: aten._softmax.default
+cnt: 12, ((T([4, 12, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([4, 12, 512, 512], f16), T([4, 12, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([4, 1, 1, 512], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([4, 12, 512, 64], f16), [48, 512, 64]), {})
+cnt: 12, ((T([4, 12, 64, 512], f16), [48, 64, 512]), {})
+cnt: 12, ((T([48, 512, 512], f16), [4, 12, 512, 512]), {})
+cnt: 12, ((T([48, 512, 64], f16), [4, 12, 512, 64]), {})
+cnt: 24, ((T([4, 512, 12, 64], f16), [4, 512, 768]), {})
+cnt: 12, ((T([4, 512, 768], f16), [2048, 768]), {})
+Operator: aten.add.Tensor
+cnt: 73, ((T([4, 512, 768], f16), T([4, 512, 768], f16)), {})
+cnt: 12, ((T([4, 12, 512, 512], f16), T([4, 1, 1, 512], f16)), {})
+cnt: 1, ((T([30522, 768], f16), T([30522, 768], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([4, 512, 768], f16), T([1, 512, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 49, ((T([768], f16), T([2048, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([2048, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([2048, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([30522], f16), T([2048, 768], f16), T([768, 30522], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([48, 512, 64], f16), T([48, 64, 512], f16)), {})
+cnt: 12, ((T([48, 512, 512], f16), T([48, 512, 64], f16)), {})
+cnt: 12, ((T([48, 512, 512], f16, stride=(262144, 1, 512)), T([48, 512, 64], f16)), {})
+cnt: 12, ((T([48, 512, 64], f16), T([48, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 12, ((T([48, 64, 512], f16, stride=(32768, 1, 64)), T([48, 512, 512], f16)), {})
+cnt: 12, ((T([48, 512, 512], f16), T([48, 512, 64], f16, stride=(32768, 1, 512))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([4, 512], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([4, 512], i64), T([4, 512], i64)), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([4, 12, 512, 512], f16), 8.0), {})
+cnt: 2, ((T([], f16), 62509056), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30522, 768], f16), T([4, 512], i64), 0), {})
+cnt: 1, ((T([2, 768], f16), T([4, 512], i64, stride=(0, 1))), {})
+cnt: 1, ((T([512, 768], f16), T([1, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 512, -1, False), {})
+cnt: 1, ((T([4, 512, 768], f16), T([4, 512], i64, stride=(0, 1)), 2, -1, False), {})
+cnt: 1, ((T([4, 512, 768], f16), T([4, 512], i64), 30522, 0, False), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([4, 512, 3072], f16),), {})
+cnt: 1, ((T([4, 512, 768], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([4, 512, 768], f16), T([4, 512, 768], f16)), {})
+cnt: 12, ((T([4, 512, 3072], f16), T([4, 512, 3072], f16)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([2048, 30522], f16, stride=(0, 0)), T([30522, 768], f16)), {})
+cnt: 1, ((T([30522, 2048], f16, stride=(0, 0)), T([2048, 768], f16)), {})
+cnt: 49, ((T([2048, 768], f16), T([768, 768], f16)), {})
+cnt: 49, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 768], f16)), {})
+cnt: 12, ((T([2048, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 3072], f16)), {})
+cnt: 12, ((T([2048, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 2048], f16, stride=(1, 3072)), T([2048, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([4, 1, 1, 512], f16), -65504.0), {})
+Operator: aten.native_layer_norm.default
+cnt: 26, ((T([4, 512, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 26, ((T([4, 512, 768], f16), T([4, 512, 768], f16), [768], T([4, 512, 1], f32), T([4, 512, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([4, 1, 1, 512], f16), 1.0), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([2048, 30522], f16, stride=(0, 0)), [0], True), {})
+cnt: 61, ((T([2048, 768], f16), [0], True), {})
+cnt: 12, ((T([2048, 3072], f16), [0], True), {})
+cnt: 1, ((T([4, 512, 768], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([4, 512, 30522], f16),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_BigBird_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_BigBird_training.txt
new file mode 100644
index 0000000000000..924d9eb843b35
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_BigBird_training.txt
@@ -0,0 +1,235 @@
+Operator: aten._softmax.default
+cnt: 24, ((T([2, 12, 64, 1024], f16), -1, False), {})
+cnt: 24, ((T([2, 12, 64, 448], f16), -1, False), {})
+cnt: 12, ((T([2, 12, 12, 64, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 24, ((T([2, 12, 64, 1024], f16), T([2, 12, 64, 1024], f16), -1, f16), {})
+cnt: 24, ((T([2, 12, 64, 448], f16), T([2, 12, 64, 448], f16), -1, f16), {})
+cnt: 12, ((T([2, 12, 12, 64, 512], f16), T([2, 12, 12, 64, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 12, ((T([2, 1, 12, 64, 192], f32),), {'dtype': f16})
+cnt: 12, ((T([2, 1, 1024, 1], f32),), {'dtype': f16})
+cnt: 12, ((T([2, 1, 1, 1024], f32),), {'dtype': f16})
+cnt: 12, ((T([12, 14, 3], i32),), {'dtype': i64, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 24, ((T([2, 12, 16, 64, 64], f16), [384, 64, 64]), {})
+cnt: 96, ((T([2, 12, 64, 64], f16), [24, 64, 64]), {})
+cnt: 48, ((T([2, 12, 1024, 64], f16), [24, 1024, 64]), {})
+cnt: 24, ((T([2, 12, 12, 64, 64], f16), [288, 64, 64]), {})
+cnt: 24, ((T([2, 12, 12, 192, 64], f16), [288, 192, 64]), {})
+cnt: 24, ((T([2, 12, 12, 64, 64, 1], f16), [24, 768, 64]), {})
+cnt: 48, ((T([2, 12, 64, 64, 1, 1], f16), [24, 64, 64]), {})
+cnt: 24, ((T([2, 1024, 12, 64], f16), [2, 1024, 768]), {})
+cnt: 12, ((T([2, 1024, 768], f16), [2048, 768]), {})
+Operator: aten.add.Tensor
+cnt: 76, ((T([2, 1024, 768], f16), T([2, 1024, 768], f16)), {})
+cnt: 24, ((T([1008], i64), T([1008], i64)), {})
+cnt: 36, ((T([2, 1024, 3072], f16), T([2, 1024, 3072], f16)), {})
+cnt: 12, ((T([2, 1024, 3072], f16), 1.0), {})
+cnt: 1, ((T([2, 1024, 768], f16), 1.0), {})
+cnt: 360, ((T([2, 12, 16, 64, 64], f16), T([2, 12, 16, 64, 64], f16)), {})
+cnt: 36, ((T([2, 12, 12, 64, 512], f16), T([2, 12, 12, 64, 512], f16)), {})
+cnt: 48, ((T([2, 12, 14, 192, 64], f16), T([2, 12, 14, 192, 64], f16)), {})
+cnt: 36, ((T([2, 12, 12, 64, 64], f16), T([2, 12, 12, 64, 64], f16)), {})
+cnt: 24, ((T([2, 12, 1024, 64], f16), T([2, 12, 1024, 64], f16)), {})
+cnt: 12, ((T([2, 12, 1024, 64], f16, stride=(786432, 65536, 1, 1024)), T([2, 12, 1024, 64], f16, stride=(786432, 65536, 1, 1024))), {})
+cnt: 12, ((T([2, 12, 1024, 64], f16, stride=(786432, 65536, 1, 1024)), T([2, 12, 1024, 64], f16)), {})
+cnt: 1, ((T([50358, 768], f16), T([50358, 768], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([2, 1024, 768], f16), T([1, 1024, 768], f16)), {})
+cnt: 24, ((T([2, 12, 64, 1024], f16), T([2, 1, 1, 1024], f16)), {})
+cnt: 24, ((T([2, 12, 64, 448], f16), T([2, 12, 64, 448], f32)), {})
+cnt: 12, ((T([2, 12, 12, 64, 192], f16), T([2, 1, 12, 64, 192], f16)), {})
+cnt: 24, ((T([2, 12, 12, 64, 64], f16), T([2, 1, 1, 1, 64], f16)), {})
+cnt: 12, ((T([2, 12, 12, 64, 192], f16), T([2, 12, 12, 64, 192], f32)), {})
+cnt: 36, ((T([2, 12, 12, 64, 64], f16), T([2, 12, 12, 64, 64], f16)), {})
+Operator: aten.addmm.default
+cnt: 49, ((T([768], f16), T([2048, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072], f16), T([2048, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([2048, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([768], f16), T([2, 768], f16, stride=(786432, 1)), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 1, ((T([50358], f16), T([2048, 768], f16), T([768, 50358], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 48, ((T([24, 64, 64], f16), T([24, 64, 1024], f16, stride=(65536, 1, 64))), {})
+cnt: 48, ((T([24, 64, 1024], f16), T([24, 1024, 64], f16)), {})
+cnt: 48, ((T([24, 64, 64], f16), T([24, 64, 448], f16, stride=(28672, 1, 64))), {})
+cnt: 48, ((T([24, 64, 448], f16), T([24, 448, 64], f16)), {})
+cnt: 48, ((T([288, 64, 64], f16), T([288, 64, 192], f16, stride=(12288, 1, 64))), {})
+cnt: 24, ((T([24, 768, 64], f16), T([24, 64, 64], f16)), {})
+cnt: 24, ((T([288, 64, 192], f16, stride=(32768, 512, 1)), T([288, 192, 64], f16)), {})
+cnt: 24, ((T([24, 768, 64], f16, stride=(393216, 512, 1)), T([24, 64, 64], f16)), {})
+cnt: 24, ((T([24, 1024, 64], f16, stride=(65536, 1, 1024)), T([24, 64, 64], f16)), {})
+cnt: 24, ((T([24, 64, 64], f16, stride=(4096, 1, 64)), T([24, 64, 1024], f16)), {})
+cnt: 24, ((T([24, 448, 64], f16, stride=(28672, 1, 448)), T([24, 64, 64], f16)), {})
+cnt: 24, ((T([24, 64, 64], f16, stride=(4096, 1, 64)), T([24, 64, 448], f16)), {})
+cnt: 24, ((T([24, 64, 768], f16, stride=(393216, 1, 512)), T([24, 768, 64], f16)), {})
+cnt: 48, ((T([24, 768, 64], f16), T([24, 64, 64], f16, stride=(4096, 1, 64))), {})
+cnt: 24, ((T([288, 192, 64], f16, stride=(32768, 1, 512)), T([288, 64, 64], f16)), {})
+cnt: 24, ((T([24, 64, 768], f16, stride=(49152, 1, 64)), T([24, 768, 64], f16)), {})
+cnt: 24, ((T([288, 64, 64], f16, stride=(4096, 1, 64)), T([288, 64, 192], f16)), {})
+cnt: 24, ((T([288, 64, 192], f16), T([288, 192, 64], f16)), {})
+Operator: aten.cat.default
+cnt: 1, (([T([2, 12, 64], f32, stride=(1024, 64, 1)), T([2, 12, 64], f32, stride=(1024, 64, 1)), T([2, 12, 64], f32, stride=(1024, 64, 1))], 2), {})
+cnt: 12, (([T([1, 12, 14, 3], i64), T([1, 12, 14, 3], i64)],), {})
+cnt: 48, (([T([2, 12, 64, 64], f16, stride=(786432, 64, 768, 1)), T([2, 12, 64, 64], f16, stride=(786432, 64, 768, 1)), T([2, 12, 64, 64], f16, stride=(786432, 64, 768, 1)), T([2, 12, 64, 64], f16, stride=(786432, 64, 768, 1)), T([2, 12, 192, 64], f16, stride=(2064384, 172032, 64, 1))], 2), {})
+cnt: 12, (([T([2, 1, 1, 192], f16, stride=(1024, 1024, 1024, 1)), T([2, 1, 1, 64], f16, stride=(1024, 1024, 1024, 1)), T([2, 1, 1, 192], f16)], 3), {})
+cnt: 24, (([T([2, 12, 64, 256], f32), T([2, 12, 64, 192], f32, stride=(2064384, 172032, 192, 1))], 3), {})
+cnt: 24, (([T([2, 12, 12, 64, 64], f16, stride=(786432, 64, 49152, 768, 1)), T([2, 12, 12, 64, 64], f16, stride=(786432, 64, 49152, 768, 1)), T([2, 12, 12, 64, 64], f16, stride=(786432, 64, 49152, 768, 1))], 3), {})
+cnt: 12, (([T([2, 12, 12, 64, 64], f16), T([2, 12, 12, 64, 192], f16), T([2, 12, 12, 64, 192], f16), T([2, 12, 12, 64, 64], f16)], -1), {})
+cnt: 12, (([T([2, 1, 1, 64], f16, stride=(1024, 1024, 1024, 1)), T([2, 1, 1, 192], f16, stride=(1024, 1024, 1024, 1)), T([2, 1, 1, 192], f16)], 3), {})
+cnt: 12, (([T([2, 12, 1, 64, 64], f16), T([2, 12, 1, 64, 64], f16), T([2, 12, 12, 64, 64], f16), T([2, 12, 1, 64, 64], f16), T([2, 12, 1, 64, 64], f16)], 2), {})
+Operator: aten.clone.default
+cnt: 1, ((T([2, 1024], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([2, 1024], i64), T([2, 1024], i64)), {})
+cnt: 12, ((T([2, 12, 12, 64, 64], f16), T([2, 12, 12, 64, 64], f16, stride=(786432, 64, 49152, 768, 1))), {})
+cnt: 36, ((T([288, 64, 64], f16), T([288, 64, 64], f16)), {})
+cnt: 36, ((T([2, 12, 12, 64, 64], f16), T([2, 12, 12, 64, 64], f16)), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 103133184), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50358, 768], f16), T([2, 1024], i64), 0), {})
+cnt: 1, ((T([2, 768], f16), T([2, 1024], i64, stride=(0, 1))), {})
+cnt: 1, ((T([4096, 768], f16), T([1, 1024], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 1024, 768], f16), T([1, 1024], i64), 4096, -1, False), {})
+cnt: 1, ((T([2, 1024, 768], f16), T([2, 1024], i64, stride=(0, 1)), 2, -1, False), {})
+cnt: 1, ((T([2, 1024, 768], f16), T([2, 1024], i64), 50358, 0, False), {})
+Operator: aten.floor_divide.default
+cnt: 24, ((T([1008], i64), 42), {})
+Operator: aten.index.Tensor
+cnt: 24, ((T([16, 64], f32), [T([504], i64)]), {})
+Operator: aten.index_add.default
+cnt: 24, ((T([384, 64, 64], f16), 0, T([1008], i64), T([1008, 64, 64], f16)), {})
+Operator: aten.index_select.default
+cnt: 24, ((T([384, 64, 64], f16), 0, T([1008], i64)), {})
+Operator: aten.minimum.default
+cnt: 24, ((T([2, 1, 1, 448], f16), T([2, 12, 64, 448], f32)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([2048, 50358], f16, stride=(0, 0)), T([50358, 768], f16)), {})
+cnt: 1, ((T([50358, 2048], f16, stride=(0, 0)), T([2048, 768], f16)), {})
+cnt: 49, ((T([2048, 768], f16), T([768, 768], f16)), {})
+cnt: 49, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 768], f16)), {})
+cnt: 12, ((T([2048, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 3072], f16)), {})
+cnt: 12, ((T([2048, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 2048], f16, stride=(1, 3072)), T([2048, 768], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 1, ((T([2, 1024, 768], f16), 3.0), {})
+cnt: 12, ((T([2, 1024, 3072], f16), 3.0), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([2, 12, 64, 1], f32, stride=(1024, 64, 1, 1)), T([2, 12, 1, 192], f32)), {})
+cnt: 12, ((T([2, 1, 14, 64, 1], f32, stride=(1024, 1, 64, 1, 1)), T([2, 12, 14, 1, 192], f32)), {})
+cnt: 24, ((T([1008], i64), 16), {})
+cnt: 48, ((T([2, 12, 64, 1024], f16), 0.125), {})
+cnt: 24, ((T([2, 1, 1, 1024], f16), -10000.0), {})
+cnt: 48, ((T([2, 12, 64, 448], f16), 0.125), {})
+cnt: 24, ((T([2, 12, 64, 448], f32), -10000.0), {})
+cnt: 24, ((T([2, 12, 12, 64, 192], f16), 0.125), {})
+cnt: 24, ((T([2, 12, 12, 64, 64], f16), 0.125), {})
+cnt: 12, ((T([2, 1, 12, 64, 192], f16), -10000.0), {})
+cnt: 24, ((T([2, 1, 1, 1, 64], f16), -10000.0), {})
+cnt: 12, ((T([2, 12, 12, 64, 192], f32), -10000.0), {})
+cnt: 12, ((T([2, 12, 1024, 64], f16), T([2, 1, 1024, 1], f16)), {})
+cnt: 24, ((T([2, 1024, 3072], f16), 0.5), {})
+cnt: 24, ((T([2, 1024, 3072], f16), 0.044715), {})
+cnt: 24, ((T([2, 1024, 3072], f16), 0.7978845608028654), {})
+cnt: 48, ((T([2, 1024, 3072], f16), T([2, 1024, 3072], f16)), {})
+cnt: 2, ((T([2, 1024, 768], f16), 0.5), {})
+cnt: 2, ((T([2, 1024, 768], f16), 0.044715), {})
+cnt: 2, ((T([2, 1024, 768], f16), 0.7978845608028654), {})
+cnt: 4, ((T([2, 1024, 768], f16), T([2, 1024, 768], f16)), {})
+cnt: 12, ((T([2, 12, 1024, 64], f16, stride=(786432, 64, 768, 1)), T([2, 1, 1024, 1], f16)), {})
+cnt: 24, ((T([2, 12, 12, 64, 64], f16, stride=(4718592, 393216, 32768, 512, 1)), 0.125), {})
+cnt: 24, ((T([2, 12, 12, 64, 192], f16, stride=(4718592, 393216, 32768, 512, 1)), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 26, ((T([2, 1024, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 26, ((T([2, 1024, 768], f16), T([2, 1024, 768], f16), [768], T([2, 1024, 1], f32), T([2, 1024, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.new_empty_strided.default
+cnt: 36, ((T([288, 64, 64], f16), [288, 64, 64], [4096, 64, 1]), {})
+Operator: aten.new_ones.default
+cnt: 24, ((T([2, 1, 1, 1024], f16), [2, 1, 1, 192]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 24, ((T([2, 12, 14, 64, 192], f32), [2, 12, 64, 256]), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+Operator: aten.new_zeros.default
+cnt: 12, ((T([2, 12, 12, 64, 64], f16, stride=(786432, 64, 49152, 768, 1)), [1179648]), {})
+cnt: 24, ((T([1008, 64, 64], f16), [384, 64, 64]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.pow.Tensor_Scalar
+cnt: 12, ((T([2, 1024, 3072], f16), 3.0), {})
+cnt: 1, ((T([2, 1024, 768], f16), 3.0), {})
+cnt: 1, ((T([2, 1024, 768], f16), 2.0), {})
+cnt: 12, ((T([2, 1024, 3072], f16), 2.0), {})
+Operator: aten.rsub.Scalar
+cnt: 24, ((T([2, 1, 1, 1024], f16), 1.0), {})
+cnt: 24, ((T([2, 12, 64, 448], f32), 1.0), {})
+cnt: 12, ((T([2, 1, 12, 64, 192], f16), 1.0), {})
+cnt: 24, ((T([2, 1, 1, 1, 64], f16, stride=(1024, 1024, 1024, 64, 1)), 1.0), {})
+cnt: 12, ((T([2, 12, 12, 64, 192], f32, stride=(2064384, 172032, 12288, 192, 1)), 1.0), {})
+Operator: aten.select_backward.default
+cnt: 24, ((T([2, 12, 64, 64], f16), [2, 12, 16, 64, 64], 2, -1), {})
+cnt: 12, ((T([2, 12, 64, 64], f16), [2, 12, 16, 64, 64], 2, -2), {})
+cnt: 12, ((T([2, 12, 192, 64], f16, stride=(344064, 28672, 64, 1)), [2, 12, 14, 192, 64], 2, -1), {})
+cnt: 24, ((T([2, 12, 64, 64], f16, stride=(344064, 28672, 64, 1)), [2, 12, 16, 64, 64], 2, -1), {})
+cnt: 12, ((T([2, 12, 64, 64], f16, stride=(344064, 28672, 64, 1)), [2, 12, 16, 64, 64], 2, -2), {})
+cnt: 12, ((T([2, 12, 64, 64], f16, stride=(344064, 28672, 64, 1)), [2, 12, 16, 64, 64], 2, -3), {})
+cnt: 24, ((T([2, 12, 64, 64], f16, stride=(344064, 28672, 64, 1)), [2, 12, 16, 64, 64], 2, 0), {})
+cnt: 12, ((T([2, 12, 192, 64], f16, stride=(344064, 28672, 1, 448)), [2, 12, 14, 192, 64], 2, -1), {})
+cnt: 24, ((T([2, 12, 64, 64], f16, stride=(344064, 28672, 1, 448)), [2, 12, 16, 64, 64], 2, -1), {})
+cnt: 12, ((T([2, 12, 64, 64], f16, stride=(344064, 28672, 1, 448)), [2, 12, 16, 64, 64], 2, -2), {})
+cnt: 12, ((T([2, 12, 64, 64], f16, stride=(344064, 28672, 1, 448)), [2, 12, 16, 64, 64], 2, -3), {})
+cnt: 24, ((T([2, 12, 64, 64], f16, stride=(344064, 28672, 1, 448)), [2, 12, 16, 64, 64], 2, 0), {})
+cnt: 24, ((T([2, 12, 64, 64], f16), [2, 12, 16, 64, 64], 2, 0), {})
+cnt: 12, ((T([2, 12, 64, 64], f16, stride=(49152, 4096, 1, 64)), [2, 12, 16, 64, 64], 2, -1), {})
+cnt: 12, ((T([2, 12, 64, 64], f16, stride=(49152, 4096, 1, 64)), [2, 12, 16, 64, 64], 2, 0), {})
+cnt: 12, ((T([2, 12, 64, 64], f16), [2, 12, 16, 64, 64], 2, 1), {})
+cnt: 12, ((T([2, 12, 192, 64], f16, stride=(344064, 28672, 64, 1)), [2, 12, 14, 192, 64], 2, 0), {})
+cnt: 12, ((T([2, 12, 64, 64], f16, stride=(344064, 28672, 64, 1)), [2, 12, 16, 64, 64], 2, 2), {})
+cnt: 12, ((T([2, 12, 64, 64], f16, stride=(344064, 28672, 64, 1)), [2, 12, 16, 64, 64], 2, 1), {})
+cnt: 12, ((T([2, 12, 192, 64], f16, stride=(344064, 28672, 1, 448)), [2, 12, 14, 192, 64], 2, 0), {})
+cnt: 12, ((T([2, 12, 64, 64], f16, stride=(344064, 28672, 1, 448)), [2, 12, 16, 64, 64], 2, 2), {})
+cnt: 12, ((T([2, 12, 64, 64], f16, stride=(344064, 28672, 1, 448)), [2, 12, 16, 64, 64], 2, 1), {})
+Operator: aten.slice_backward.default
+cnt: 372, ((T([2, 12, 16, 64, 64], f16), [2, 12, 16, 64, 64], 1, 0, 9223372036854775807, 1), {})
+cnt: 372, ((T([2, 12, 16, 64, 64], f16), [2, 12, 16, 64, 64], 0, 0, 9223372036854775807, 1), {})
+cnt: 72, ((T([2, 12, 14, 192, 64], f16), [2, 12, 14, 192, 64], 1, 0, 9223372036854775807, 1), {})
+cnt: 72, ((T([2, 12, 14, 192, 64], f16), [2, 12, 14, 192, 64], 0, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([2, 12, 12, 64, 64], f16), [2, 12, 12, 64, 512], 4, -64, 9223372036854775807, 1), {})
+cnt: 48, ((T([2, 12, 12, 64, 512], f16), [2, 12, 12, 64, 512], 3, 0, 9223372036854775807, 1), {})
+cnt: 48, ((T([2, 12, 12, 64, 512], f16), [2, 12, 12, 64, 512], 2, 0, 9223372036854775807, 1), {})
+cnt: 48, ((T([2, 12, 12, 64, 512], f16), [2, 12, 12, 64, 512], 1, 0, 9223372036854775807, 1), {})
+cnt: 48, ((T([2, 12, 12, 64, 512], f16), [2, 12, 12, 64, 512], 0, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([2, 12, 12, 64, 64], f16), [2, 12, 12, 64, 512], 4, 0, 64, 1), {})
+cnt: 12, ((T([2, 12, 12, 192, 64], f16), [2, 12, 14, 192, 64], 2, 1, -1, 1), {})
+cnt: 12, ((T([2, 12, 12, 64, 192], f16), [2, 12, 12, 64, 512], 4, 256, -64, 1), {})
+cnt: 12, ((T([2, 12, 12, 64, 192], f16), [2, 12, 12, 64, 512], 4, 64, 256, 1), {})
+cnt: 12, ((T([2, 12, 12, 192, 64], f16, stride=(1769472, 147456, 12288, 1, 192)), [2, 12, 14, 192, 64], 2, 1, -1, 1), {})
+cnt: 12, ((T([2, 12, 12, 64, 64], f16), [2, 12, 16, 64, 64], 2, 2, -2, 1), {})
+cnt: 12, ((T([2, 12, 12, 64, 64], f16, stride=(1769472, 147456, 12288, 64, 1)), [2, 12, 16, 64, 64], 2, 3, -1, 1), {})
+cnt: 12, ((T([2, 12, 12, 64, 64], f16, stride=(1769472, 147456, 12288, 64, 1)), [2, 12, 16, 64, 64], 2, 2, -2, 1), {})
+cnt: 12, ((T([2, 12, 12, 64, 64], f16, stride=(1769472, 147456, 12288, 64, 1)), [2, 12, 16, 64, 64], 2, 1, -3, 1), {})
+cnt: 12, ((T([2, 12, 12, 64, 64], f16, stride=(1769472, 147456, 12288, 1, 192)), [2, 12, 16, 64, 64], 2, 3, -1, 1), {})
+cnt: 12, ((T([2, 12, 12, 64, 64], f16, stride=(1769472, 147456, 12288, 1, 192)), [2, 12, 16, 64, 64], 2, 2, -2, 1), {})
+cnt: 12, ((T([2, 12, 12, 64, 64], f16, stride=(1769472, 147456, 12288, 1, 192)), [2, 12, 16, 64, 64], 2, 1, -3, 1), {})
+Operator: aten.stack.default
+cnt: 12, (([T([504, 64], f32), T([504, 64], f32)],), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([2048, 50358], f16, stride=(0, 0)), [0], True), {})
+cnt: 61, ((T([2048, 768], f16), [0], True), {})
+cnt: 12, ((T([2048, 3072], f16), [0], True), {})
+cnt: 1, ((T([2, 1024, 768], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([2, 1024, 50358], f16),), {})
+Operator: aten.tanh.default
+cnt: 12, ((T([2, 1024, 3072], f16),), {})
+cnt: 1, ((T([2, 768], f16),), {})
+cnt: 1, ((T([2, 1024, 768], f16),), {})
+Operator: aten.tanh_backward.default
+cnt: 1, ((T([2, 1024, 768], f16), T([2, 1024, 768], f16)), {})
+cnt: 12, ((T([2, 1024, 3072], f16), T([2, 1024, 3072], f16)), {})
+Operator: aten.unbind.int
+cnt: 12, ((T([2, 16, 64], f32),), {})
+cnt: 12, ((T([2, 12, 14, 3], i64),), {})
+Operator: aten.unsqueeze_.default
+cnt: 1, ((T([2, 12, 64, 192], f32), 1), {})
+cnt: 12, ((T([12, 14, 3], i64), 0), {})
+cnt: 48, ((T([2, 12, 64, 64], f16), 2), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_DistilBert_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_DistilBert_training.txt
new file mode 100644
index 0000000000000..225446dad9dd3
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_DistilBert_training.txt
@@ -0,0 +1,73 @@
+Operator: aten._softmax.default
+cnt: 6, ((T([8, 12, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 6, ((T([8, 12, 512, 512], f16), T([8, 12, 512, 512], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 18, ((T([8, 12, 512, 64], f16), [96, 512, 64]), {})
+cnt: 6, ((T([8, 12, 64, 512], f16), [96, 64, 512]), {})
+cnt: 6, ((T([96, 512, 512], f16), [8, 12, 512, 512]), {})
+cnt: 6, ((T([96, 512, 64], f16), [8, 12, 512, 64]), {})
+cnt: 12, ((T([8, 512, 12, 64], f16), [8, 512, 768]), {})
+cnt: 6, ((T([8, 512, 768], f16), [4096, 768]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([8, 512, 768], f16), T([1, 512, 768], f16)), {})
+cnt: 36, ((T([8, 512, 768], f16), T([8, 512, 768], f16)), {})
+cnt: 1, ((T([30522, 768], f16), T([30522, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 25, ((T([768], f16), T([4096, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 6, ((T([3072], f16), T([4096, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 6, ((T([768], f16), T([4096, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([30522], f16), T([4096, 768], f16), T([768, 30522], f16, stride=(1, 768))), {})
+Operator: aten.bmm.default
+cnt: 6, ((T([96, 512, 64], f16), T([96, 64, 512], f16)), {})
+cnt: 6, ((T([96, 512, 512], f16), T([96, 512, 64], f16)), {})
+cnt: 6, ((T([96, 512, 512], f16, stride=(262144, 1, 512)), T([96, 512, 64], f16)), {})
+cnt: 6, ((T([96, 512, 64], f16), T([96, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 6, ((T([96, 64, 512], f16, stride=(32768, 1, 64)), T([96, 512, 512], f16)), {})
+cnt: 6, ((T([96, 512, 512], f16), T([96, 512, 64], f16, stride=(32768, 1, 512))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([8, 512], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([8, 512], i64), T([8, 512], i64)), {})
+Operator: aten.div.Tensor
+cnt: 6, ((T([8, 12, 512, 64], f16, stride=(393216, 64, 768, 1)), 8.0), {})
+cnt: 2, ((T([], f16), 125018112), {})
+cnt: 6, ((T([8, 12, 512, 64], f16), 8.0), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([30522, 768], f16), T([8, 512], i64), 0), {})
+cnt: 1, ((T([512, 768], f16), T([1, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 512, -1, False), {})
+cnt: 1, ((T([8, 512, 768], f16), T([8, 512], i64), 30522, 0, False), {})
+Operator: aten.eq.Scalar
+cnt: 6, ((T([8, 512], f32), 0), {})
+Operator: aten.gelu.default
+cnt: 6, ((T([8, 512, 3072], f16),), {})
+cnt: 1, ((T([8, 512, 768], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([8, 512, 768], f16), T([8, 512, 768], f16)), {})
+cnt: 6, ((T([8, 512, 3072], f16), T([8, 512, 3072], f16)), {})
+Operator: aten.masked_fill.Scalar
+cnt: 6, ((T([8, 12, 512, 512], f16), T([8, 12, 512, 512], b8, stride=(512, 0, 0, 1)), 0), {})
+Operator: aten.masked_fill.Tensor
+cnt: 6, ((T([8, 12, 512, 512], f16), T([8, 12, 512, 512], b8, stride=(512, 0, 0, 1)), T([], f32)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([4096, 30522], f16, stride=(0, 0)), T([30522, 768], f16)), {})
+cnt: 1, ((T([30522, 4096], f16, stride=(0, 0)), T([4096, 768], f16)), {})
+cnt: 25, ((T([4096, 768], f16), T([768, 768], f16)), {})
+cnt: 25, ((T([768, 4096], f16, stride=(1, 768)), T([4096, 768], f16)), {})
+cnt: 6, ((T([4096, 768], f16), T([768, 3072], f16)), {})
+cnt: 6, ((T([768, 4096], f16, stride=(1, 768)), T([4096, 3072], f16)), {})
+cnt: 6, ((T([4096, 3072], f16), T([3072, 768], f16)), {})
+cnt: 6, ((T([3072, 4096], f16, stride=(1, 3072)), T([4096, 768], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 14, ((T([8, 512, 768], f16), [768], T([768], f16), T([768], f16), 1e-12), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 14, ((T([8, 512, 768], f16), T([8, 512, 768], f16), [768], T([8, 512, 1], f32), T([8, 512, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([4096, 30522], f16, stride=(0, 0)), [0], True), {})
+cnt: 31, ((T([4096, 768], f16), [0], True), {})
+cnt: 6, ((T([4096, 3072], f16), [0], True), {})
+cnt: 1, ((T([8, 512, 768], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([8, 512, 30522], f16),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_GPT2_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_GPT2_training.txt
new file mode 100644
index 0000000000000..7a2ca611a2ec2
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_GPT2_training.txt
@@ -0,0 +1,88 @@
+Operator: aten._softmax.default
+cnt: 12, ((T([4, 12, 512, 512], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([4, 12, 512, 512], f16), T([4, 12, 512, 512], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 12, ((T([1, 1, 512, 512], u8, stride=(1048576, 1048576, 1024, 1)),), {'dtype': torch.bool})
+cnt: 12, ((T([], f16),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([4, 12, 512, 64], f16), [48, 512, 64]), {})
+cnt: 12, ((T([4, 12, 64, 512], f16), [48, 64, 512]), {})
+cnt: 12, ((T([48, 512, 512], f16), [4, 12, 512, 512]), {})
+cnt: 12, ((T([48, 512, 64], f16), [4, 12, 512, 64]), {})
+cnt: 1, ((T([2048, 50257], f16), [4, 512, 50257]), {})
+cnt: 24, ((T([4, 512, 12, 64], f16), [4, 512, 768]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([4, 512, 768], f16), T([1, 512, 768], f16)), {})
+cnt: 48, ((T([4, 512, 768], f16), T([4, 512, 768], f16)), {})
+cnt: 36, ((T([4, 512, 3072], f16), T([4, 512, 3072], f16)), {})
+cnt: 12, ((T([4, 512, 3072], f16), 1.0), {})
+cnt: 1, ((T([50257, 768], f16), T([50257, 768], f16)), {})
+Operator: aten.addmm.default
+cnt: 12, ((T([2304], f16), T([2048, 768], f16), T([768, 2304], f16)), {})
+cnt: 12, ((T([768], f16), T([2048, 768], f16), T([768, 768], f16)), {})
+cnt: 12, ((T([3072], f16), T([2048, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768], f16), T([2048, 3072], f16), T([3072, 768], f16)), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([48, 512, 64], f16), T([48, 64, 512], f16)), {})
+cnt: 12, ((T([48, 512, 512], f16), T([48, 512, 64], f16)), {})
+cnt: 12, ((T([48, 512, 512], f16, stride=(262144, 1, 512)), T([48, 512, 64], f16)), {})
+cnt: 12, ((T([48, 512, 64], f16), T([48, 64, 512], f16, stride=(32768, 1, 64))), {})
+cnt: 12, ((T([48, 64, 512], f16, stride=(32768, 1, 64)), T([48, 512, 512], f16)), {})
+cnt: 12, ((T([48, 512, 512], f16), T([48, 512, 64], f16, stride=(32768, 1, 512))), {})
+Operator: aten.cat.default
+cnt: 12, (([T([4, 512, 768], f16), T([4, 512, 768], f16, stride=(393216, 1, 512)), T([4, 512, 768], f16)], 2), {})
+Operator: aten.clone.default
+cnt: 1, ((T([4, 512], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([4, 512], i64), T([4, 512], i64)), {})
+Operator: aten.div.Tensor
+cnt: 24, ((T([4, 12, 512, 512], f16), T([], f16)), {})
+cnt: 2, ((T([], f16), 102926336), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50257, 768], f16), T([4, 512], i64)), {})
+cnt: 1, ((T([1024, 768], f16), T([1, 512], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([1, 512, 768], f16), T([1, 512], i64), 1024, -1, False), {})
+cnt: 1, ((T([4, 512, 768], f16), T([4, 512], i64), 50257, -1, False), {})
+Operator: aten.mm.default
+cnt: 1, ((T([2048, 768], f16), T([768, 50257], f16, stride=(1, 768))), {})
+cnt: 1, ((T([50257, 2048], f16, stride=(0, 0)), T([2048, 768], f16)), {})
+cnt: 1, ((T([2048, 50257], f16, stride=(0, 0)), T([50257, 768], f16)), {})
+cnt: 12, ((T([2048, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([3072, 2048], f16, stride=(1, 3072)), T([2048, 768], f16)), {})
+cnt: 12, ((T([2048, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 12, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 3072], f16)), {})
+cnt: 12, ((T([2048, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 768], f16)), {})
+cnt: 12, ((T([2048, 2304], f16), T([2304, 768], f16, stride=(1, 2304))), {})
+cnt: 12, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 2304], f16)), {})
+Operator: aten.mul.Scalar
+cnt: 12, ((T([4, 512, 3072], f16), 3.0), {})
+Operator: aten.mul.Tensor
+cnt: 24, ((T([4, 512, 3072], f16), 0.5), {})
+cnt: 24, ((T([4, 512, 3072], f16), 0.044715), {})
+cnt: 24, ((T([4, 512, 3072], f16), 0.7978845608028654), {})
+cnt: 48, ((T([4, 512, 3072], f16), T([4, 512, 3072], f16)), {})
+Operator: aten.native_layer_norm.default
+cnt: 25, ((T([4, 512, 768], f16), [768], T([768], f16), T([768], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 25, ((T([4, 512, 768], f16), T([4, 512, 768], f16), [768], T([4, 512, 1], f32), T([4, 512, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.pow.Tensor_Scalar
+cnt: 12, ((T([4, 512, 3072], f16), 3.0), {})
+cnt: 12, ((T([4, 512, 3072], f16), 2.0), {})
+Operator: aten.split.Tensor
+cnt: 12, ((T([4, 512, 2304], f16), 768, 2), {})
+Operator: aten.sum.SymInt
+cnt: 24, ((T([2048, 768], f16), [0], True), {})
+cnt: 12, ((T([2048, 3072], f16), [0], True), {})
+cnt: 12, ((T([2048, 2304], f16), [0], True), {})
+cnt: 1, ((T([4, 512, 768], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([4, 512, 50257], f16),), {})
+Operator: aten.tanh.default
+cnt: 12, ((T([4, 512, 3072], f16),), {})
+Operator: aten.tanh_backward.default
+cnt: 12, ((T([4, 512, 3072], f16), T([4, 512, 3072], f16)), {})
+Operator: aten.where.self
+cnt: 24, ((T([1, 1, 512, 512], b8), T([4, 12, 512, 512], f16), T([], f16)), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_Longformer_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_Longformer_training.txt
new file mode 100644
index 0000000000000..23725d8af4314
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/hf_Longformer_training.txt
@@ -0,0 +1,189 @@
+Operator: aten._softmax.default
+cnt: 12, ((T([2, 1024, 12, 513], f16, stride=(6303744, 513, 525312, 1)), -1, True), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([2, 1024, 12, 513], f32), T([2, 1024, 12, 513], f32), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([2, 1, 1, 1024], f32),), {'dtype': f16})
+cnt: 1, ((T([2, 1024], b8),), {'dtype': i32})
+cnt: 1, ((T([2, 1024], i64),), {'dtype': i32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([2, 1024], i32),), {'dtype': i64})
+cnt: 12, ((T([2, 1024, 1, 1], b8),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 12, ((T([2, 1024, 12, 513], f32),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 12, ((T([2, 1024, 12, 513], f16, stride=(6303744, 513, 525312, 1)),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([1024, 2, 768], f16), [2048, 768]), {})
+cnt: 36, ((T([2048, 768], f16), [1024, 2, 768]), {})
+cnt: 12, ((T([24, 3, 512, 64, 1], f16), [72, 512, 64]), {})
+cnt: 12, ((T([24, 3, 64, 512, 1], f16), [72, 64, 512]), {})
+cnt: 12, ((T([2, 12, 1024, 513], f16), [24, 4, 256, 513]), {})
+cnt: 12, ((T([24, 4, 768, 64, 1], f16), [96, 768, 64]), {})
+cnt: 24, ((T([1024, 2, 12, 64], f16), [1024, 2, 768]), {})
+cnt: 12, ((T([2, 1024, 768], f16), [2048, 768]), {})
+cnt: 12, ((T([2048, 768], f16), [2, 1024, 768]), {})
+cnt: 12, ((T([2, 12, 1024, 64], f16), [24, 4, 256, 64]), {})
+cnt: 12, ((T([24, 4, 768, 64], i64), [4718592]), {})
+cnt: 12, ((T([24, 3, 512, 64], f16), [2359296]), {})
+cnt: 24, ((T([24, 3, 512, 64], i64), [2359296]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([2, 1024], i64), 1), {})
+cnt: 38, ((T([2, 1024, 768], f16), T([2, 1024, 768], f16)), {})
+cnt: 36, ((T([1024, 2, 768], f16), T([768], f16)), {})
+cnt: 12, ((T([2, 1024, 768], f16), T([768], f16)), {})
+cnt: 1, ((T([], f16), 0), {})
+cnt: 36, ((T([24, 3, 512, 513], f16), T([24, 3, 512, 513], f16)), {})
+cnt: 24, ((T([1024, 2, 768], f16), T([1024, 2, 768], f16)), {})
+cnt: 12, ((T([2, 1024, 768], f16), T([2, 1024, 768], f16, stride=(768, 1536, 1))), {})
+cnt: 1, ((T([50265, 768], f16), T([50265, 768], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 12, ((T([2, 1024, 12, 513], f16, stride=(6303744, 513, 525312, 1)), T([2, 1024, 1, 513], f16)), {})
+Operator: aten.addmm.default
+cnt: 12, ((T([3072], f16), T([2048, 768], f16), T([768, 3072], f16, stride=(1, 768))), {})
+cnt: 12, ((T([768], f16), T([2048, 3072], f16), T([3072, 768], f16, stride=(1, 3072))), {})
+cnt: 1, ((T([768], f16), T([2048, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 1, ((T([50265], f16), T([2048, 768], f16), T([768, 50265], f16, stride=(1, 768))), {})
+Operator: aten.any.default
+cnt: 1, ((T([2048], b8),), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([72, 512, 64], f16), T([72, 64, 512], f16)), {})
+cnt: 12, ((T([96, 256, 768], f16, stride=(197120, 769, 1)), T([96, 768, 64], f16)), {})
+cnt: 12, ((T([96, 768, 256], f16, stride=(197120, 1, 769)), T([96, 256, 64], f16)), {})
+cnt: 12, ((T([96, 256, 64], f16), T([96, 64, 768], f16, stride=(49152, 1, 64))), {})
+cnt: 12, ((T([72, 64, 512], f16, stride=(32768, 1, 64)), T([72, 512, 512], f16)), {})
+cnt: 12, ((T([72, 512, 512], f16), T([72, 512, 64], f16, stride=(32768, 1, 512))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([2, 1024], i64),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 12, ((T([24, 3, 512, 512], f16), [0, 0, 0, 1], 0.0), {})
+cnt: 12, ((T([2, 3, 512, 512], f16), [0, 0, 0, 1], 0.0), {})
+cnt: 12, ((T([24, 1024, 64], f16, stride=(64, 1536, 1)), [0, 0, 256, 256], -1.0), {})
+cnt: 12, ((T([24, 4, 256, 513], f16), [0, 257], 0.0), {})
+cnt: 12, ((T([24, 4, 256, 770], f16), [0, -257]), {})
+cnt: 12, ((T([24, 1536, 64], f16), [0, 0, -256, -256]), {})
+cnt: 12, ((T([24, 3, 513, 512], f16), [0, 0, 0, -1]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([2, 1024], i64), T([2, 1024], i64)), {})
+cnt: 12, ((T([24, 3, 256, 257], f16, stride=(525312, 131328, 513, 1)), T([24, 3, 256, 257], f16, stride=(787968, 262656, 513, 1))), {})
+cnt: 12, ((T([24, 256, 257], f16, stride=(525312, 513, 1)), T([24, 256, 257], f16, stride=(787968, 513, 1))), {})
+cnt: 12, ((T([24, 3, 256, 256], f16, stride=(525312, 131328, 513, 1)), T([24, 3, 256, 256], f16, stride=(787968, 262656, 513, 1))), {})
+cnt: 12, ((T([24, 255, 255], f16, stride=(525312, 513, 1)), T([24, 255, 255], f16, stride=(787968, 513, 1))), {})
+cnt: 12, ((T([2, 3, 256, 257], f16, stride=(525312, 131328, 513, 1)), T([2, 3, 256, 257], f16, stride=(787968, 262656, 513, 1))), {})
+cnt: 12, ((T([2, 256, 257], f16, stride=(525312, 513, 1)), T([2, 256, 257], f16, stride=(787968, 513, 1))), {})
+cnt: 12, ((T([2, 3, 256, 256], f16, stride=(525312, 131328, 513, 1)), T([2, 3, 256, 256], f16, stride=(787968, 262656, 513, 1))), {})
+cnt: 12, ((T([2, 255, 255], f16, stride=(525312, 513, 1)), T([2, 255, 255], f16, stride=(787968, 513, 1))), {})
+cnt: 24, ((T([2, 1024, 12, 513], f16, stride=(6303744, 513, 525312, 1)), T([2, 1024, 12, 513], f16)), {})
+cnt: 84, ((T([24, 4, 256, 513], f16), T([24, 4, 256, 513], f16)), {})
+cnt: 24, ((T([2, 256, 12, 257], f16, stride=(6303744, 513, 525312, 1)), T([2, 256, 12, 257], f16)), {})
+cnt: 12, ((T([24, 255, 255], f16, stride=(525312, 513, 1)), T([24, 255, 255], f16)), {})
+cnt: 12, ((T([24, 3, 256, 256], f16, stride=(525312, 131328, 513, 1)), T([24, 3, 256, 256], f16)), {})
+cnt: 12, ((T([24, 256, 257], f16, stride=(525312, 513, 1)), T([24, 256, 257], f16)), {})
+Operator: aten.cumsum.default
+cnt: 1, ((T([2, 1024], i32), 1), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 102942720), {})
+cnt: 2, ((T([], f16), 1), {})
+cnt: 12, ((T([1024, 2, 768], f16), 8.0), {})
+Operator: aten.div_.Tensor
+cnt: 12, ((T([1024, 2, 768], f16), 8.0), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([50265, 768], f16), T([2, 1024], i64), 1), {})
+cnt: 1, ((T([4098, 768], f16), T([2, 1024], i64), 1), {})
+cnt: 1, ((T([1, 768], f16), T([2, 1024], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([2, 1024, 768], f16), T([2, 1024], i64), 1, -1, False), {})
+cnt: 1, ((T([2, 1024, 768], f16), T([2, 1024], i64), 4098, 1, False), {})
+cnt: 1, ((T([2, 1024, 768], f16), T([2, 1024], i64), 50265, 1, False), {})
+Operator: aten.eq.Scalar
+cnt: 24, ((T([2, 256, 12, 257], f16, stride=(0, 257, 0, 1)), 1), {})
+cnt: 24, ((T([2, 256, 1, 257], f16, stride=(0, 257, 257, 1)), 1), {})
+Operator: aten.flip.default
+cnt: 24, ((T([256, 257], f16), [0]), {})
+cnt: 24, ((T([1, 256, 1, 257], f16), [1, 3]), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([2, 1024, 3072], f16),), {})
+cnt: 1, ((T([2, 1024, 768], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([2, 1024, 768], f16), T([2, 1024, 768], f16)), {})
+cnt: 12, ((T([2, 1024, 3072], f16), T([2, 1024, 3072], f16)), {})
+Operator: aten.gt.Scalar
+cnt: 1, ((T([2, 1024], f16), 0), {})
+Operator: aten.index_add_.default
+cnt: 12, ((T([2359296], f16), 0, T([4718592], i64), T([4718592], f16)), {})
+cnt: 24, ((T([1572864], f16), 0, T([2359296], i64), T([2359296], f16)), {})
+Operator: aten.lt.Scalar
+cnt: 1, ((T([2, 1024], f16), 0), {})
+Operator: aten.masked_fill.Scalar
+cnt: 12, ((T([2, 1024, 1, 1], f16), T([2, 1024, 1, 1], b8), -65504.0), {})
+cnt: 12, ((T([2, 1024, 12, 513], f32), T([2, 1024, 1, 1], b8), 0.0), {})
+cnt: 12, ((T([2, 1024, 12, 513], f32, stride=(6303744, 513, 525312, 1)), T([2, 1024, 1, 1], b8), 0), {})
+cnt: 24, ((T([2, 256, 12, 257], f16), T([2, 256, 12, 257], b8), 0), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 24, ((T([2, 256, 12, 257], f16, stride=(6303744, 513, 525312, 1)), T([2, 256, 12, 257], b8), -inf), {})
+cnt: 24, ((T([2, 256, 1, 257], f16, stride=(525312, 513, 525312, 1)), T([2, 256, 1, 257], b8), -inf), {})
+Operator: aten.mm.default
+cnt: 48, ((T([2048, 768], f16), T([768, 768], f16, stride=(1, 768))), {})
+cnt: 1, ((T([2048, 50265], f16, stride=(0, 0)), T([50265, 768], f16)), {})
+cnt: 1, ((T([50265, 2048], f16, stride=(0, 0)), T([2048, 768], f16)), {})
+cnt: 49, ((T([2048, 768], f16), T([768, 768], f16)), {})
+cnt: 49, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 768], f16)), {})
+cnt: 12, ((T([2048, 768], f16), T([768, 3072], f16)), {})
+cnt: 12, ((T([768, 2048], f16, stride=(1, 768)), T([2048, 3072], f16)), {})
+cnt: 12, ((T([2048, 3072], f16), T([3072, 768], f16)), {})
+cnt: 12, ((T([3072, 2048], f16, stride=(1, 3072)), T([2048, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([2, 1, 1, 1024], f16), -65504.0), {})
+cnt: 1, ((T([2, 1024], i32), T([2, 1024], i32)), {})
+cnt: 12, ((T([2, 3, 512, 1], f16, stride=(1024, 256, 1, 1)), T([2, 3, 1, 512], f16, stride=(1024, 256, 1, 1))), {})
+Operator: aten.native_layer_norm.default
+cnt: 26, ((T([2, 1024, 768], f16), [768], T([768], f16), T([768], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 26, ((T([2, 1024, 768], f16), T([2, 1024, 768], f16), [768], T([2, 1024, 1], f32), T([2, 1024, 1], f32), T([768], f16), T([768], f16), [True, True, True]), {})
+Operator: aten.ne.Scalar
+cnt: 1, ((T([2, 1024], i64), 1), {})
+cnt: 12, ((T([2, 1024], f16), 0), {})
+Operator: aten.new_empty.default
+cnt: 12, ((T([24, 3, 512, 513], f16), [24, 4, 256, 513]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 12, ((T([2, 3, 512, 513], f16), [2, 4, 256, 513]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+Operator: aten.new_empty_strided.default
+cnt: 84, ((T([24, 4, 256, 513], f16), [24, 4, 256, 513], [525312, 131328, 513, 1]), {})
+Operator: aten.new_ones.default
+cnt: 12, ((T([2, 1024, 12, 513], f16, stride=(6303744, 513, 525312, 1)), [256, 257]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 12, ((T([2, 1024, 1, 1], f16), [2, 1024, 1, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 12, ((T([2, 1024, 1, 513], f16), [256, 257]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+Operator: aten.new_zeros.default
+cnt: 12, ((T([24, 4, 768, 64], f16), [2359296]), {})
+cnt: 12, ((T([2, 1024, 12, 513], f16), [12607488]), {})
+cnt: 12, ((T([24, 3, 512, 64], f16, stride=(98304, 32768, 1, 512)), [1572864]), {})
+cnt: 12, ((T([24, 3, 512, 64], f16), [1572864]), {})
+Operator: aten.rsub.Scalar
+cnt: 1, ((T([2, 1, 1, 1024], f16), 1.0), {})
+Operator: aten.select_backward.default
+cnt: 12, ((T([24, 512, 513], f16), [24, 3, 512, 513], 1, 0), {})
+cnt: 12, ((T([24, 512, 513], f16), [24, 3, 512, 513], 1, -1), {})
+Operator: aten.slice_backward.default
+cnt: 12, ((T([24, 4, 256, 768], f16), [24, 4, 256, 769], 3, 0, -1, 1), {})
+cnt: 12, ((T([24, 4, 256, 769], f16), [24, 4, 256, 769], 2, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([24, 4, 256, 769], f16), [24, 4, 256, 769], 1, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([24, 4, 256, 769], f16), [24, 4, 256, 769], 0, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([24, 4, 196864], f16), [24, 4, 197120], 2, 0, -256, 1), {})
+cnt: 12, ((T([24, 4, 197120], f16), [24, 4, 197120], 1, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([24, 4, 197120], f16), [24, 4, 197120], 0, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([24, 255, 255], f16), [24, 255, 513], 2, -255, 9223372036854775807, 1), {})
+cnt: 12, ((T([24, 255, 513], f16), [24, 512, 513], 1, 0, 255, 1), {})
+cnt: 48, ((T([24, 3, 512, 513], f16), [24, 3, 512, 513], 0, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([24, 3, 256, 256], f16), [24, 3, 256, 513], 3, 257, 9223372036854775807, 1), {})
+cnt: 12, ((T([24, 3, 256, 513], f16), [24, 3, 512, 513], 2, -257, -1, 1), {})
+cnt: 24, ((T([24, 3, 512, 513], f16), [24, 3, 512, 513], 1, 0, 9223372036854775807, 1), {})
+cnt: 12, ((T([24, 256, 257], f16), [24, 256, 513], 2, 0, 257, 1), {})
+cnt: 12, ((T([24, 256, 513], f16), [24, 512, 513], 1, 256, 9223372036854775807, 1), {})
+cnt: 12, ((T([24, 3, 256, 257], f16), [24, 3, 256, 513], 3, 0, 257, 1), {})
+cnt: 12, ((T([24, 3, 256, 513], f16), [24, 3, 512, 513], 2, 0, 256, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([2048, 50265], f16, stride=(0, 0)), [0], True), {})
+cnt: 13, ((T([2048, 768], f16), [0], True), {})
+cnt: 12, ((T([2048, 3072], f16), [0], True), {})
+cnt: 12, ((T([2, 1024, 768], f16), [0, 1], True), {})
+cnt: 36, ((T([1024, 2, 768], f16), [0, 1], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([2, 1024, 50265], f16),), {})
+Operator: aten.tril.default
+cnt: 24, ((T([256, 257], f16),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/maml_omniglot_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/maml_omniglot_training.txt
new file mode 100644
index 0000000000000..3121d116ddddd
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/maml_omniglot_training.txt
@@ -0,0 +1,49 @@
+Operator: aten.addmm.default
+cnt: 1, ((T([5], f16), T([5, 64], f16), T([64, 5], f16, stride=(1, 64))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([5, 1, 28, 28], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([5, 1, 28, 28], f16), T([64, 1, 3, 3], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([5, 64, 13, 13], f16, stride=(10816, 1, 832, 64)), T([64, 64, 3, 3], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([5, 64, 5, 5], f16, stride=(1600, 1, 320, 64)), T([64, 64, 3, 3], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([5, 64, 3, 3], f16, stride=(576, 1, 192, 64)), T([5, 64, 5, 5], f16, stride=(1600, 1, 320, 64)), T([64, 64, 3, 3], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([5, 64, 11, 11], f16, stride=(7744, 1, 704, 64)), T([5, 64, 13, 13], f16, stride=(10816, 1, 832, 64)), T([64, 64, 3, 3], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([5, 64, 26, 26], f16, stride=(43264, 1, 1664, 64)), T([5, 1, 28, 28], f16), T([64, 1, 3, 3], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([5, 1, 28, 28], f16), T([5, 1, 28, 28], f16)), {})
+cnt: 2, ((T([64, 64, 3, 3], f16), T([64, 64, 3, 3], f16, stride=(576, 1, 192, 64))), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 25), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([5, 64, 26, 26], f16, stride=(43264, 1, 1664, 64)), [2, 2], [2, 2]), {})
+cnt: 1, ((T([5, 64, 11, 11], f16, stride=(7744, 1, 704, 64)), [2, 2], [2, 2]), {})
+cnt: 1, ((T([5, 64, 3, 3], f16, stride=(576, 1, 192, 64)), [2, 2], [2, 2]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([5, 64, 1, 1], f16), T([5, 64, 3, 3], f16, stride=(576, 1, 192, 64)), [2, 2], [2, 2], [0, 0], [1, 1], False, T([5, 64, 1, 1], i64)), {})
+cnt: 1, ((T([5, 64, 5, 5], f16, stride=(1600, 1, 320, 64)), T([5, 64, 11, 11], f16, stride=(7744, 1, 704, 64)), [2, 2], [2, 2], [0, 0], [1, 1], False, T([5, 64, 5, 5], i64, stride=(1600, 1, 320, 64))), {})
+cnt: 1, ((T([5, 64, 13, 13], f16, stride=(10816, 1, 832, 64)), T([5, 64, 26, 26], f16, stride=(43264, 1, 1664, 64)), [2, 2], [2, 2], [0, 0], [1, 1], False, T([5, 64, 13, 13], i64, stride=(10816, 1, 832, 64))), {})
+Operator: aten.mm.default
+cnt: 2, ((T([5, 5], f16, stride=(0, 0)), T([5, 64], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([5, 64, 26, 26], f16, stride=(43264, 1, 1664, 64)), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 1.0, 1e-05), {})
+cnt: 1, ((T([5, 64, 11, 11], f16, stride=(7744, 1, 704, 64)), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 1.0, 1e-05), {})
+cnt: 1, ((T([5, 64, 3, 3], f16, stride=(576, 1, 192, 64)), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 1.0, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([5, 64, 3, 3], f16, stride=(576, 1, 192, 64)), T([5, 64, 3, 3], f16, stride=(576, 1, 192, 64)), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([5, 64, 11, 11], f16, stride=(7744, 1, 704, 64)), T([5, 64, 11, 11], f16, stride=(7744, 1, 704, 64)), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([5, 64, 26, 26], f16, stride=(43264, 1, 1664, 64)), T([5, 64, 26, 26], f16, stride=(43264, 1, 1664, 64)), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.new_empty_strided.default
+cnt: 2, ((T([64, 64, 3, 3], f16, stride=(576, 1, 192, 64)), [64, 64, 3, 3], [576, 9, 3, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.relu_.default
+cnt: 1, ((T([5, 64, 26, 26], f16, stride=(43264, 1, 1664, 64)),), {})
+cnt: 1, ((T([5, 64, 11, 11], f16, stride=(7744, 1, 704, 64)),), {})
+cnt: 1, ((T([5, 64, 3, 3], f16, stride=(576, 1, 192, 64)),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([5, 5], f16, stride=(0, 0)), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([5, 5], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([5, 64, 3, 3], f16, stride=(576, 1, 192, 64)), T([5, 64, 3, 3], f16, stride=(576, 1, 192, 64)), 0), {})
+cnt: 1, ((T([5, 64, 11, 11], f16, stride=(7744, 1, 704, 64)), T([5, 64, 11, 11], f16, stride=(7744, 1, 704, 64)), 0), {})
+cnt: 1, ((T([5, 64, 26, 26], f16, stride=(43264, 1, 1664, 64)), T([5, 64, 26, 26], f16, stride=(43264, 1, 1664, 64)), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/mnasnet1_0_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/mnasnet1_0_training.txt
new file mode 100644
index 0000000000000..4f81a114632cc
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/mnasnet1_0_training.txt
@@ -0,0 +1,163 @@
+Operator: aten.add.Tensor
+cnt: 4, ((T([32, 24, 56, 56], f16), T([32, 24, 56, 56], f16)), {})
+cnt: 4, ((T([32, 40, 28, 28], f16), T([32, 40, 28, 28], f16)), {})
+cnt: 4, ((T([32, 80, 14, 14], f16), T([32, 80, 14, 14], f16)), {})
+cnt: 2, ((T([32, 96, 14, 14], f16), T([32, 96, 14, 14], f16)), {})
+cnt: 6, ((T([32, 192, 7, 7], f16), T([32, 192, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([32, 1280], f16), T([1280, 1000], f16, stride=(1, 1280))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([16, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([48, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 48, 112, 112], f16), T([48, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 48), {})
+cnt: 1, ((T([32, 48, 56, 56], f16), T([24, 48, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 24, 56, 56], f16), T([72, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 72, 56, 56], f16), T([72, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 72), {})
+cnt: 2, ((T([32, 72, 56, 56], f16), T([24, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 72, 56, 56], f16), T([72, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 72), {})
+cnt: 1, ((T([32, 72, 28, 28], f16), T([40, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 40, 28, 28], f16), T([120, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 120, 28, 28], f16), T([120, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 120), {})
+cnt: 2, ((T([32, 120, 28, 28], f16), T([40, 120, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 40, 28, 28], f16), T([240, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([240, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([80, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 80, 14, 14], f16), T([480, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 480, 14, 14], f16), T([480, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 480), {})
+cnt: 2, ((T([32, 480, 14, 14], f16), T([80, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 480, 14, 14], f16), T([480, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 480), {})
+cnt: 1, ((T([32, 480, 14, 14], f16), T([96, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 96, 14, 14], f16), T([576, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 576, 14, 14], f16), T([576, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 576), {})
+cnt: 1, ((T([32, 576, 14, 14], f16), T([96, 576, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 576, 14, 14], f16), T([576, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 576), {})
+cnt: 1, ((T([32, 576, 7, 7], f16), T([192, 576, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([32, 192, 7, 7], f16), T([1152, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 1152, 7, 7], f16), T([1152, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 1152), {})
+cnt: 3, ((T([32, 1152, 7, 7], f16), T([192, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1152, 7, 7], f16), T([1152, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1152), {})
+cnt: 1, ((T([32, 1152, 7, 7], f16), T([320, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 320, 7, 7], f16), T([1280, 320, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([32, 1280, 7, 7], f16), T([32, 320, 7, 7], f16), T([1280, 320, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 320, 7, 7], f16), T([32, 1152, 7, 7], f16), T([320, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1152, 7, 7], f16), T([32, 1152, 7, 7], f16), T([1152, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1152, [True, True, False]), {})
+cnt: 4, ((T([32, 1152, 7, 7], f16), T([32, 192, 7, 7], f16), T([1152, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 192, 7, 7], f16), T([32, 1152, 7, 7], f16), T([192, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 1152, 7, 7], f16), T([32, 1152, 7, 7], f16), T([1152, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 1152, [True, True, False]), {})
+cnt: 1, ((T([32, 192, 7, 7], f16), T([32, 576, 7, 7], f16), T([192, 576, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 576, 7, 7], f16), T([32, 576, 14, 14], f16), T([576, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 576, [True, True, False]), {})
+cnt: 2, ((T([32, 576, 14, 14], f16), T([32, 96, 14, 14], f16), T([576, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 96, 14, 14], f16), T([32, 576, 14, 14], f16), T([96, 576, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 576, 14, 14], f16), T([32, 576, 14, 14], f16), T([576, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 576, [True, True, False]), {})
+cnt: 1, ((T([32, 96, 14, 14], f16), T([32, 480, 14, 14], f16), T([96, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 480, 14, 14], f16), T([32, 480, 14, 14], f16), T([480, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 3, ((T([32, 480, 14, 14], f16), T([32, 80, 14, 14], f16), T([480, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 80, 14, 14], f16), T([32, 480, 14, 14], f16), T([80, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 480, 14, 14], f16), T([32, 480, 14, 14], f16), T([480, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 1, ((T([32, 80, 14, 14], f16), T([32, 240, 14, 14], f16), T([80, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([32, 240, 28, 28], f16), T([240, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([32, 40, 28, 28], f16), T([240, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 40, 28, 28], f16), T([32, 120, 28, 28], f16), T([40, 120, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 120, 28, 28], f16), T([32, 120, 28, 28], f16), T([120, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 2, ((T([32, 120, 28, 28], f16), T([32, 40, 28, 28], f16), T([120, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 40, 28, 28], f16), T([32, 72, 28, 28], f16), T([40, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 72, 28, 28], f16), T([32, 72, 56, 56], f16), T([72, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 72, [True, True, False]), {})
+cnt: 3, ((T([32, 72, 56, 56], f16), T([32, 24, 56, 56], f16), T([72, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 24, 56, 56], f16), T([32, 72, 56, 56], f16), T([24, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 72, 56, 56], f16), T([32, 72, 56, 56], f16), T([72, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 72, [True, True, False]), {})
+cnt: 1, ((T([32, 24, 56, 56], f16), T([32, 48, 56, 56], f16), T([24, 48, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 48, 56, 56], f16), T([32, 48, 112, 112], f16), T([48, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 48, [True, True, False]), {})
+cnt: 1, ((T([32, 48, 112, 112], f16), T([32, 16, 112, 112], f16), T([48, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([32, 32, 112, 112], f16), T([16, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32, 32, 112, 112], f16), T([32, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32, 3, 224, 224], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([32, 1280, 7, 7], f16, stride=(1280, 1, 0, 0)), 49), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 32000), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([32, 1280, 7, 7], f16), [2, 3]), {})
+Operator: aten.mm.default
+cnt: 1, ((T([32, 1000], f16, stride=(0, 0)), T([1000, 1280], f16)), {})
+cnt: 1, ((T([1000, 32], f16, stride=(0, 0)), T([32, 1280], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([32, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 1, ((T([32, 48, 112, 112], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 1, ((T([32, 48, 56, 56], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 3, ((T([32, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 5, ((T([32, 72, 56, 56], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 1, ((T([32, 72, 28, 28], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 3, ((T([32, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 4, ((T([32, 120, 28, 28], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 3, ((T([32, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 6, ((T([32, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 2, ((T([32, 96, 14, 14], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 3, ((T([32, 576, 14, 14], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 1, ((T([32, 576, 7, 7], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 4, ((T([32, 192, 7, 7], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 8, ((T([32, 1152, 7, 7], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 1, ((T([32, 320, 7, 7], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), False, 0.00029999999999996696, 1e-05), {})
+cnt: 1, ((T([32, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f16), False, 0.00029999999999996696, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([32, 1280, 7, 7], f16), T([32, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f32), T([1280], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 320, 7, 7], f16), T([32, 320, 7, 7], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), False, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([32, 1152, 7, 7], f16), T([32, 1152, 7, 7], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f32), T([1152], f32), False, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([32, 192, 7, 7], f16), T([32, 192, 7, 7], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 576, 7, 7], f16), T([32, 576, 7, 7], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f32), T([576], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 576, 14, 14], f16), T([32, 576, 14, 14], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f32), T([576], f32), False, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 96, 14, 14], f16), T([32, 96, 14, 14], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), False, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([32, 480, 14, 14], f16), T([32, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 80, 14, 14], f16), T([32, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([32, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([32, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 40, 28, 28], f16), T([32, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), False, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([32, 120, 28, 28], f16), T([32, 120, 28, 28], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f32), T([120], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 72, 28, 28], f16), T([32, 72, 28, 28], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), False, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([32, 72, 56, 56], f16), T([32, 72, 56, 56], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 24, 56, 56], f16), T([32, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 48, 56, 56], f16), T([32, 48, 56, 56], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f32), T([48], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 48, 112, 112], f16), T([32, 48, 112, 112], f16), T([48], f16), T([48], f16), T([48], f16), T([48], f32), T([48], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([32, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), False, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 32, 112, 112], f16), T([32, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([32, 32, 112, 112], f16),), {})
+cnt: 1, ((T([32, 48, 112, 112], f16),), {})
+cnt: 1, ((T([32, 48, 56, 56], f16),), {})
+cnt: 5, ((T([32, 72, 56, 56], f16),), {})
+cnt: 1, ((T([32, 72, 28, 28], f16),), {})
+cnt: 4, ((T([32, 120, 28, 28], f16),), {})
+cnt: 1, ((T([32, 240, 28, 28], f16),), {})
+cnt: 1, ((T([32, 240, 14, 14], f16),), {})
+cnt: 6, ((T([32, 480, 14, 14], f16),), {})
+cnt: 3, ((T([32, 576, 14, 14], f16),), {})
+cnt: 1, ((T([32, 576, 7, 7], f16),), {})
+cnt: 8, ((T([32, 1152, 7, 7], f16),), {})
+cnt: 1, ((T([32, 1280, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32, 1000], f16, stride=(0, 0)), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([32, 1000], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([32, 1280, 7, 7], f16), T([32, 1280, 7, 7], f16), 0), {})
+cnt: 8, ((T([32, 1152, 7, 7], f16), T([32, 1152, 7, 7], f16), 0), {})
+cnt: 1, ((T([32, 576, 7, 7], f16), T([32, 576, 7, 7], f16), 0), {})
+cnt: 3, ((T([32, 576, 14, 14], f16), T([32, 576, 14, 14], f16), 0), {})
+cnt: 6, ((T([32, 480, 14, 14], f16), T([32, 480, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([32, 240, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([32, 240, 28, 28], f16), 0), {})
+cnt: 4, ((T([32, 120, 28, 28], f16), T([32, 120, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 72, 28, 28], f16), T([32, 72, 28, 28], f16), 0), {})
+cnt: 5, ((T([32, 72, 56, 56], f16), T([32, 72, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 48, 56, 56], f16), T([32, 48, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 48, 112, 112], f16), T([32, 48, 112, 112], f16), 0), {})
+cnt: 2, ((T([32, 32, 112, 112], f16), T([32, 32, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/mobilenet_v2_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/mobilenet_v2_training.txt
new file mode 100644
index 0000000000000..185ce981ae35d
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/mobilenet_v2_training.txt
@@ -0,0 +1,165 @@
+Operator: aten.add.Tensor
+cnt: 2, ((T([96, 24, 56, 56], f16), T([96, 24, 56, 56], f16)), {})
+cnt: 4, ((T([96, 32, 28, 28], f16), T([96, 32, 28, 28], f16)), {})
+cnt: 6, ((T([96, 64, 14, 14], f16), T([96, 64, 14, 14], f16)), {})
+cnt: 4, ((T([96, 96, 14, 14], f16), T([96, 96, 14, 14], f16)), {})
+cnt: 4, ((T([96, 160, 7, 7], f16), T([96, 160, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([96, 1280], f16), T([1280, 1000], f16, stride=(1, 1280))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([96, 3, 224, 224], f16),), {})
+cnt: 2, ((T([96, 32, 112, 112], f16),), {})
+cnt: 1, ((T([96, 96, 112, 112], f16),), {})
+cnt: 1, ((T([96, 96, 56, 56], f16),), {})
+cnt: 3, ((T([96, 144, 56, 56], f16),), {})
+cnt: 1, ((T([96, 144, 28, 28], f16),), {})
+cnt: 5, ((T([96, 192, 28, 28], f16),), {})
+cnt: 1, ((T([96, 192, 14, 14], f16),), {})
+cnt: 8, ((T([96, 384, 14, 14], f16),), {})
+cnt: 5, ((T([96, 576, 14, 14], f16),), {})
+cnt: 1, ((T([96, 576, 7, 7], f16),), {})
+cnt: 6, ((T([96, 960, 7, 7], f16),), {})
+cnt: 1, ((T([96, 1280, 7, 7], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([96, 3, 224, 224], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 32, 112, 112], f16), T([32, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([96, 32, 112, 112], f16), T([16, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 16, 112, 112], f16), T([96, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 96, 112, 112], f16), T([96, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 96), {})
+cnt: 1, ((T([96, 96, 56, 56], f16), T([24, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([96, 24, 56, 56], f16), T([144, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 144, 56, 56], f16), T([144, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 144), {})
+cnt: 1, ((T([96, 144, 56, 56], f16), T([24, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 144, 56, 56], f16), T([144, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 144), {})
+cnt: 1, ((T([96, 144, 28, 28], f16), T([32, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([96, 32, 28, 28], f16), T([192, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([96, 192, 28, 28], f16), T([192, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 192), {})
+cnt: 2, ((T([96, 192, 28, 28], f16), T([32, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 192, 28, 28], f16), T([192, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 192), {})
+cnt: 1, ((T([96, 192, 14, 14], f16), T([64, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([96, 64, 14, 14], f16), T([384, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([96, 384, 14, 14], f16), T([384, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 384), {})
+cnt: 3, ((T([96, 384, 14, 14], f16), T([64, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 384, 14, 14], f16), T([96, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([96, 96, 14, 14], f16), T([576, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([96, 576, 14, 14], f16), T([576, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 576), {})
+cnt: 2, ((T([96, 576, 14, 14], f16), T([96, 576, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 576, 14, 14], f16), T([576, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 576), {})
+cnt: 1, ((T([96, 576, 7, 7], f16), T([160, 576, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([96, 160, 7, 7], f16), T([960, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([96, 960, 7, 7], f16), T([960, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 960), {})
+cnt: 2, ((T([96, 960, 7, 7], f16), T([160, 960, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 960, 7, 7], f16), T([320, 960, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([96, 320, 7, 7], f16), T([1280, 320, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([96, 1280, 7, 7], f16), T([96, 320, 7, 7], f16), T([1280, 320, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 320, 7, 7], f16), T([96, 960, 7, 7], f16), T([320, 960, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([96, 960, 7, 7], f16), T([96, 960, 7, 7], f16), T([960, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 960, [True, True, False]), {})
+cnt: 3, ((T([96, 960, 7, 7], f16), T([96, 160, 7, 7], f16), T([960, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([96, 160, 7, 7], f16), T([96, 960, 7, 7], f16), T([160, 960, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 160, 7, 7], f16), T([96, 576, 7, 7], f16), T([160, 576, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 576, 7, 7], f16), T([96, 576, 14, 14], f16), T([576, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 576, [True, True, False]), {})
+cnt: 3, ((T([96, 576, 14, 14], f16), T([96, 96, 14, 14], f16), T([576, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([96, 96, 14, 14], f16), T([96, 576, 14, 14], f16), T([96, 576, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([96, 576, 14, 14], f16), T([96, 576, 14, 14], f16), T([576, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 576, [True, True, False]), {})
+cnt: 1, ((T([96, 96, 14, 14], f16), T([96, 384, 14, 14], f16), T([96, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([96, 384, 14, 14], f16), T([96, 384, 14, 14], f16), T([384, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 384, [True, True, False]), {})
+cnt: 4, ((T([96, 384, 14, 14], f16), T([96, 64, 14, 14], f16), T([384, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([96, 64, 14, 14], f16), T([96, 384, 14, 14], f16), T([64, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 64, 14, 14], f16), T([96, 192, 14, 14], f16), T([64, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 192, 14, 14], f16), T([96, 192, 28, 28], f16), T([192, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 192, [True, True, False]), {})
+cnt: 3, ((T([96, 192, 28, 28], f16), T([96, 32, 28, 28], f16), T([192, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([96, 32, 28, 28], f16), T([96, 192, 28, 28], f16), T([32, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([96, 192, 28, 28], f16), T([96, 192, 28, 28], f16), T([192, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 192, [True, True, False]), {})
+cnt: 1, ((T([96, 32, 28, 28], f16), T([96, 144, 28, 28], f16), T([32, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 144, 28, 28], f16), T([96, 144, 56, 56], f16), T([144, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 144, [True, True, False]), {})
+cnt: 2, ((T([96, 144, 56, 56], f16), T([96, 24, 56, 56], f16), T([144, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 24, 56, 56], f16), T([96, 144, 56, 56], f16), T([24, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 144, 56, 56], f16), T([96, 144, 56, 56], f16), T([144, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 144, [True, True, False]), {})
+cnt: 1, ((T([96, 24, 56, 56], f16), T([96, 96, 56, 56], f16), T([24, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 96, 56, 56], f16), T([96, 96, 112, 112], f16), T([96, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 96, [True, True, False]), {})
+cnt: 1, ((T([96, 96, 112, 112], f16), T([96, 16, 112, 112], f16), T([96, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 16, 112, 112], f16), T([96, 32, 112, 112], f16), T([16, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([96, 32, 112, 112], f16), T([96, 32, 112, 112], f16), T([32, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([96, 32, 112, 112], f16), T([96, 3, 224, 224], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([96, 3, 224, 224], f16), T([96, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([96, 1280, 7, 7], f16, stride=(1280, 1, 0, 0)), 49), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 96000), {})
+Operator: aten.hardtanh_.default
+cnt: 2, ((T([96, 32, 112, 112], f16), 0.0, 6.0), {})
+cnt: 1, ((T([96, 96, 112, 112], f16), 0.0, 6.0), {})
+cnt: 1, ((T([96, 96, 56, 56], f16), 0.0, 6.0), {})
+cnt: 3, ((T([96, 144, 56, 56], f16), 0.0, 6.0), {})
+cnt: 1, ((T([96, 144, 28, 28], f16), 0.0, 6.0), {})
+cnt: 5, ((T([96, 192, 28, 28], f16), 0.0, 6.0), {})
+cnt: 1, ((T([96, 192, 14, 14], f16), 0.0, 6.0), {})
+cnt: 8, ((T([96, 384, 14, 14], f16), 0.0, 6.0), {})
+cnt: 5, ((T([96, 576, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([96, 576, 7, 7], f16), 0.0, 6.0), {})
+cnt: 6, ((T([96, 960, 7, 7], f16), 0.0, 6.0), {})
+cnt: 1, ((T([96, 1280, 7, 7], f16), 0.0, 6.0), {})
+Operator: aten.hardtanh_backward.default
+cnt: 1, ((T([96, 1280, 7, 7], f16), T([96, 1280, 7, 7], f16), 0.0, 6.0), {})
+cnt: 6, ((T([96, 960, 7, 7], f16), T([96, 960, 7, 7], f16), 0.0, 6.0), {})
+cnt: 1, ((T([96, 576, 7, 7], f16), T([96, 576, 7, 7], f16), 0.0, 6.0), {})
+cnt: 5, ((T([96, 576, 14, 14], f16), T([96, 576, 14, 14], f16), 0.0, 6.0), {})
+cnt: 8, ((T([96, 384, 14, 14], f16), T([96, 384, 14, 14], f16), 0.0, 6.0), {})
+cnt: 1, ((T([96, 192, 14, 14], f16), T([96, 192, 14, 14], f16), 0.0, 6.0), {})
+cnt: 5, ((T([96, 192, 28, 28], f16), T([96, 192, 28, 28], f16), 0.0, 6.0), {})
+cnt: 1, ((T([96, 144, 28, 28], f16), T([96, 144, 28, 28], f16), 0.0, 6.0), {})
+cnt: 3, ((T([96, 144, 56, 56], f16), T([96, 144, 56, 56], f16), 0.0, 6.0), {})
+cnt: 1, ((T([96, 96, 56, 56], f16), T([96, 96, 56, 56], f16), 0.0, 6.0), {})
+cnt: 1, ((T([96, 96, 112, 112], f16), T([96, 96, 112, 112], f16), 0.0, 6.0), {})
+cnt: 2, ((T([96, 32, 112, 112], f16), T([96, 32, 112, 112], f16), 0.0, 6.0), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([96, 1280, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([96, 1000], f16, stride=(0, 0)), T([1000, 1280], f16)), {})
+cnt: 1, ((T([1000, 96], f16, stride=(0, 0)), T([96, 1280], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([96, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([96, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([96, 96, 112, 112], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([96, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), False, 0.1, 1e-05), {})
+cnt: 2, ((T([96, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([96, 144, 56, 56], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([96, 144, 28, 28], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([96, 32, 28, 28], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([96, 192, 28, 28], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([96, 192, 14, 14], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), False, 0.1, 1e-05), {})
+cnt: 4, ((T([96, 64, 14, 14], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 8, ((T([96, 384, 14, 14], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([96, 96, 14, 14], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([96, 576, 14, 14], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([96, 576, 7, 7], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([96, 160, 7, 7], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), False, 0.1, 1e-05), {})
+cnt: 6, ((T([96, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([96, 320, 7, 7], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([96, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f16), False, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([96, 1280, 7, 7], f16), T([96, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f32), T([1280], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([96, 320, 7, 7], f16), T([96, 320, 7, 7], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), False, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([96, 960, 7, 7], f16), T([96, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([96, 160, 7, 7], f16), T([96, 160, 7, 7], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([96, 576, 7, 7], f16), T([96, 576, 7, 7], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f32), T([576], f32), False, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([96, 576, 14, 14], f16), T([96, 576, 14, 14], f16), T([576], f16), T([576], f16), T([576], f16), T([576], f32), T([576], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([96, 96, 14, 14], f16), T([96, 96, 14, 14], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), False, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([96, 384, 14, 14], f16), T([96, 384, 14, 14], f16), T([384], f16), T([384], f16), T([384], f16), T([384], f32), T([384], f32), False, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([96, 64, 14, 14], f16), T([96, 64, 14, 14], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([96, 192, 14, 14], f16), T([96, 192, 14, 14], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), False, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([96, 192, 28, 28], f16), T([96, 192, 28, 28], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([96, 32, 28, 28], f16), T([96, 32, 28, 28], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([96, 144, 28, 28], f16), T([96, 144, 28, 28], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([96, 144, 56, 56], f16), T([96, 144, 56, 56], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), False, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([96, 24, 56, 56], f16), T([96, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([96, 96, 56, 56], f16), T([96, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([96, 96, 112, 112], f16), T([96, 96, 112, 112], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([96, 16, 112, 112], f16), T([96, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), False, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([96, 32, 112, 112], f16), T([96, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([96, 1000], f16, stride=(0, 0)), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([96, 1000], f16),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/mobilenet_v3_large_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/mobilenet_v3_large_training.txt
new file mode 100644
index 0000000000000..07ba40cf12a53
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/mobilenet_v3_large_training.txt
@@ -0,0 +1,277 @@
+Operator: aten.add.Tensor
+cnt: 2, ((T([32, 960, 7, 7], f16), T([32, 960, 7, 7], f16)), {})
+cnt: 2, ((T([32, 160, 7, 7], f16), T([32, 160, 7, 7], f16)), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), T([32, 672, 7, 7], f16)), {})
+cnt: 1, ((T([32, 672, 14, 14], f16), T([32, 672, 14, 14], f16)), {})
+cnt: 1, ((T([32, 112, 14, 14], f16), T([32, 112, 14, 14], f16)), {})
+cnt: 1, ((T([32, 480, 14, 14], f16), T([32, 480, 14, 14], f16)), {})
+cnt: 3, ((T([32, 80, 14, 14], f16), T([32, 80, 14, 14], f16)), {})
+cnt: 2, ((T([32, 120, 28, 28], f16), T([32, 120, 28, 28], f16)), {})
+cnt: 2, ((T([32, 40, 28, 28], f16), T([32, 40, 28, 28], f16)), {})
+cnt: 1, ((T([32, 72, 28, 28], f16), T([32, 72, 28, 28], f16)), {})
+cnt: 1, ((T([32, 24, 56, 56], f16), T([32, 24, 56, 56], f16)), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([32, 16, 112, 112], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([32, 16, 112, 112], f16), T([32, 16, 112, 112], f16)), {})
+cnt: 1, ((T([32, 24, 56, 56], f16), T([32, 24, 56, 56], f16)), {})
+cnt: 2, ((T([32, 40, 28, 28], f16), T([32, 40, 28, 28], f16)), {})
+cnt: 3, ((T([32, 80, 14, 14], f16), T([32, 80, 14, 14], f16)), {})
+cnt: 1, ((T([32, 112, 14, 14], f16), T([32, 112, 14, 14], f16)), {})
+cnt: 2, ((T([32, 160, 7, 7], f16), T([32, 160, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1280], f16), T([32, 960], f16), T([960, 1280], f16, stride=(1, 960))), {})
+cnt: 1, ((T([1000], f16), T([32, 1280], f16), T([1280, 1000], f16, stride=(1, 1280))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 224, 224], f16),), {})
+cnt: 1, ((T([32, 16, 112, 112], f16),), {})
+cnt: 1, ((T([32, 240, 28, 28], f16),), {})
+cnt: 1, ((T([32, 240, 14, 14], f16),), {})
+cnt: 2, ((T([32, 200, 14, 14], f16),), {})
+cnt: 4, ((T([32, 184, 14, 14], f16),), {})
+cnt: 2, ((T([32, 480, 14, 14], f16),), {})
+cnt: 3, ((T([32, 672, 14, 14], f16),), {})
+cnt: 1, ((T([32, 672, 7, 7], f16),), {})
+cnt: 5, ((T([32, 960, 7, 7], f16),), {})
+cnt: 1, ((T([32, 1280], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([16, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([16, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 16), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([16, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([64, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([64, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 64), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([24, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 24, 56, 56], f16), T([72, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 72, 56, 56], f16), T([72, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 72), {})
+cnt: 1, ((T([32, 72, 56, 56], f16), T([24, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 72, 56, 56], f16), T([72, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 72), {})
+cnt: 1, ((T([32, 72, 1, 1], f16), T([24, 72, 1, 1], f16), T([24], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 24, 1, 1], f16), T([72, 24, 1, 1], f16), T([72], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 72, 28, 28], f16), T([40, 72, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 40, 28, 28], f16), T([120, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 120, 28, 28], f16), T([120, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 120), {})
+cnt: 2, ((T([32, 120, 1, 1], f16), T([32, 120, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 32, 1, 1], f16), T([120, 32, 1, 1], f16), T([120], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 120, 28, 28], f16), T([40, 120, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 40, 28, 28], f16), T([240, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([240, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([80, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 80, 14, 14], f16), T([200, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 200, 14, 14], f16), T([200, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 200), {})
+cnt: 1, ((T([32, 200, 14, 14], f16), T([80, 200, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 80, 14, 14], f16), T([184, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 184, 14, 14], f16), T([184, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 184), {})
+cnt: 2, ((T([32, 184, 14, 14], f16), T([80, 184, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 80, 14, 14], f16), T([480, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 480, 14, 14], f16), T([480, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 480), {})
+cnt: 1, ((T([32, 480, 1, 1], f16), T([120, 480, 1, 1], f16), T([120], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 120, 1, 1], f16), T([480, 120, 1, 1], f16), T([480], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 480, 14, 14], f16), T([112, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 112, 14, 14], f16), T([672, 112, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 672, 14, 14], f16), T([672, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 672), {})
+cnt: 2, ((T([32, 672, 1, 1], f16), T([168, 672, 1, 1], f16), T([168], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 168, 1, 1], f16), T([672, 168, 1, 1], f16), T([672], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 672, 14, 14], f16), T([112, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 672, 14, 14], f16), T([672, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 672), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), T([160, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 160, 7, 7], f16), T([960, 160, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 960, 7, 7], f16), T([960, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 960), {})
+cnt: 2, ((T([32, 960, 1, 1], f16), T([240, 960, 1, 1], f16), T([240], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 240, 1, 1], f16), T([960, 240, 1, 1], f16), T([960], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 960, 7, 7], f16), T([160, 960, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([32, 960, 7, 7], f16), T([32, 160, 7, 7], f16), T([960, 160, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 160, 7, 7], f16), T([32, 960, 7, 7], f16), T([160, 960, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 960, 1, 1], f16), T([32, 240, 1, 1], f16), T([960, 240, 1, 1], f16), [960], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 240, 1, 1], f16), T([32, 960, 1, 1], f16), T([240, 960, 1, 1], f16), [240], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 960, 7, 7], f16), T([32, 960, 7, 7], f16), T([960, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 960, [True, True, False]), {})
+cnt: 1, ((T([32, 160, 7, 7], f16), T([32, 672, 7, 7], f16), T([160, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 672, 1, 1], f16), T([32, 168, 1, 1], f16), T([672, 168, 1, 1], f16), [672], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 168, 1, 1], f16), T([32, 672, 1, 1], f16), T([168, 672, 1, 1], f16), [168], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), T([32, 672, 14, 14], f16), T([672, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 2, ((T([32, 672, 14, 14], f16), T([32, 112, 14, 14], f16), T([672, 112, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 112, 14, 14], f16), T([32, 672, 14, 14], f16), T([112, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 672, 14, 14], f16), T([32, 672, 14, 14], f16), T([672, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 1, ((T([32, 112, 14, 14], f16), T([32, 480, 14, 14], f16), T([112, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 480, 1, 1], f16), T([32, 120, 1, 1], f16), T([480, 120, 1, 1], f16), [480], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 120, 1, 1], f16), T([32, 480, 1, 1], f16), T([120, 480, 1, 1], f16), [120], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 480, 14, 14], f16), T([32, 480, 14, 14], f16), T([480, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 1, ((T([32, 480, 14, 14], f16), T([32, 80, 14, 14], f16), T([480, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 80, 14, 14], f16), T([32, 184, 14, 14], f16), T([80, 184, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 184, 14, 14], f16), T([32, 184, 14, 14], f16), T([184, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 184, [True, True, False]), {})
+cnt: 2, ((T([32, 184, 14, 14], f16), T([32, 80, 14, 14], f16), T([184, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 80, 14, 14], f16), T([32, 200, 14, 14], f16), T([80, 200, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 200, 14, 14], f16), T([32, 200, 14, 14], f16), T([200, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 200, [True, True, False]), {})
+cnt: 1, ((T([32, 200, 14, 14], f16), T([32, 80, 14, 14], f16), T([200, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 80, 14, 14], f16), T([32, 240, 14, 14], f16), T([80, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([32, 240, 28, 28], f16), T([240, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([32, 40, 28, 28], f16), T([240, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 40, 28, 28], f16), T([32, 120, 28, 28], f16), T([40, 120, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 120, 1, 1], f16), T([32, 32, 1, 1], f16), T([120, 32, 1, 1], f16), [120], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 32, 1, 1], f16), T([32, 120, 1, 1], f16), T([32, 120, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 120, 28, 28], f16), T([32, 120, 28, 28], f16), T([120, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 120, [True, True, False]), {})
+cnt: 2, ((T([32, 120, 28, 28], f16), T([32, 40, 28, 28], f16), T([120, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 40, 28, 28], f16), T([32, 72, 28, 28], f16), T([40, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 72, 1, 1], f16), T([32, 24, 1, 1], f16), T([72, 24, 1, 1], f16), [72], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 24, 1, 1], f16), T([32, 72, 1, 1], f16), T([24, 72, 1, 1], f16), [24], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 72, 28, 28], f16), T([32, 72, 56, 56], f16), T([72, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 72, [True, True, False]), {})
+cnt: 2, ((T([32, 72, 56, 56], f16), T([32, 24, 56, 56], f16), T([72, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 24, 56, 56], f16), T([32, 72, 56, 56], f16), T([24, 72, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 72, 56, 56], f16), T([32, 72, 56, 56], f16), T([72, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 72, [True, True, False]), {})
+cnt: 1, ((T([32, 24, 56, 56], f16), T([32, 64, 56, 56], f16), T([24, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 64, 112, 112], f16), T([64, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 64, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 16, 112, 112], f16), T([64, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([32, 16, 112, 112], f16), T([16, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([32, 16, 112, 112], f16), T([16, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 16, [True, True, False]), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([32, 3, 224, 224], f16), T([16, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 3, ((T([32, 960, 7, 7], f16, stride=(960, 1, 0, 0)), 49), {})
+cnt: 1, ((T([32, 672, 7, 7], f16, stride=(672, 1, 0, 0)), 49), {})
+cnt: 1, ((T([32, 672, 14, 14], f16, stride=(672, 1, 0, 0)), 196), {})
+cnt: 1, ((T([32, 480, 14, 14], f16, stride=(480, 1, 0, 0)), 196), {})
+cnt: 2, ((T([32, 120, 28, 28], f16, stride=(120, 1, 0, 0)), 784), {})
+cnt: 1, ((T([32, 72, 28, 28], f16, stride=(72, 1, 0, 0)), 784), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 32000), {})
+Operator: aten.hardsigmoid.default
+cnt: 1, ((T([32, 72, 1, 1], f16),), {})
+cnt: 2, ((T([32, 120, 1, 1], f16),), {})
+cnt: 1, ((T([32, 480, 1, 1], f16),), {})
+cnt: 2, ((T([32, 672, 1, 1], f16),), {})
+cnt: 2, ((T([32, 960, 1, 1], f16),), {})
+Operator: aten.hardsigmoid_backward.default
+cnt: 2, ((T([32, 960, 1, 1], f16), T([32, 960, 1, 1], f16)), {})
+cnt: 2, ((T([32, 672, 1, 1], f16), T([32, 672, 1, 1], f16)), {})
+cnt: 1, ((T([32, 480, 1, 1], f16), T([32, 480, 1, 1], f16)), {})
+cnt: 2, ((T([32, 120, 1, 1], f16), T([32, 120, 1, 1], f16)), {})
+cnt: 1, ((T([32, 72, 1, 1], f16), T([32, 72, 1, 1], f16)), {})
+Operator: aten.hardswish_.default
+cnt: 1, ((T([32, 16, 112, 112], f16),), {})
+cnt: 1, ((T([32, 240, 28, 28], f16),), {})
+cnt: 1, ((T([32, 240, 14, 14], f16),), {})
+cnt: 2, ((T([32, 200, 14, 14], f16),), {})
+cnt: 4, ((T([32, 184, 14, 14], f16),), {})
+cnt: 2, ((T([32, 480, 14, 14], f16),), {})
+cnt: 3, ((T([32, 672, 14, 14], f16),), {})
+cnt: 1, ((T([32, 672, 7, 7], f16),), {})
+cnt: 5, ((T([32, 960, 7, 7], f16),), {})
+cnt: 1, ((T([32, 1280], f16),), {})
+Operator: aten.hardswish_backward.default
+cnt: 1, ((T([32, 1280], f16), T([32, 1280], f16)), {})
+cnt: 5, ((T([32, 960, 7, 7], f16), T([32, 960, 7, 7], f16)), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), T([32, 672, 7, 7], f16)), {})
+cnt: 3, ((T([32, 672, 14, 14], f16), T([32, 672, 14, 14], f16)), {})
+cnt: 2, ((T([32, 480, 14, 14], f16), T([32, 480, 14, 14], f16)), {})
+cnt: 4, ((T([32, 184, 14, 14], f16), T([32, 184, 14, 14], f16)), {})
+cnt: 2, ((T([32, 200, 14, 14], f16), T([32, 200, 14, 14], f16)), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([32, 240, 14, 14], f16)), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([32, 240, 28, 28], f16)), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([32, 16, 112, 112], f16)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([32, 72, 28, 28], f16), [-1, -2], True), {})
+cnt: 2, ((T([32, 120, 28, 28], f16), [-1, -2], True), {})
+cnt: 1, ((T([32, 480, 14, 14], f16), [-1, -2], True), {})
+cnt: 1, ((T([32, 672, 14, 14], f16), [-1, -2], True), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), [-1, -2], True), {})
+cnt: 3, ((T([32, 960, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([32, 1000], f16, stride=(0, 0)), T([1000, 1280], f16)), {})
+cnt: 1, ((T([1000, 32], f16, stride=(0, 0)), T([32, 1280], f16)), {})
+cnt: 1, ((T([32, 1280], f16), T([1280, 960], f16)), {})
+cnt: 1, ((T([1280, 32], f16, stride=(1, 1280)), T([32, 960], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([32, 72, 1, 1], f16), T([32, 72, 28, 28], f16)), {})
+cnt: 2, ((T([32, 120, 1, 1], f16), T([32, 120, 28, 28], f16)), {})
+cnt: 1, ((T([32, 480, 1, 1], f16), T([32, 480, 14, 14], f16)), {})
+cnt: 1, ((T([32, 672, 1, 1], f16), T([32, 672, 14, 14], f16)), {})
+cnt: 1, ((T([32, 672, 1, 1], f16), T([32, 672, 7, 7], f16)), {})
+cnt: 2, ((T([32, 960, 1, 1], f16), T([32, 960, 7, 7], f16)), {})
+cnt: 2, ((T([32, 960, 7, 7], f16), T([32, 960, 1, 1], f16)), {})
+cnt: 2, ((T([32, 960, 7, 7], f16), T([32, 960, 7, 7], f16)), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), T([32, 672, 1, 1], f16)), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), T([32, 672, 7, 7], f16)), {})
+cnt: 1, ((T([32, 672, 14, 14], f16), T([32, 672, 1, 1], f16)), {})
+cnt: 1, ((T([32, 672, 14, 14], f16), T([32, 672, 14, 14], f16)), {})
+cnt: 1, ((T([32, 480, 14, 14], f16), T([32, 480, 1, 1], f16)), {})
+cnt: 1, ((T([32, 480, 14, 14], f16), T([32, 480, 14, 14], f16)), {})
+cnt: 2, ((T([32, 120, 28, 28], f16), T([32, 120, 1, 1], f16)), {})
+cnt: 2, ((T([32, 120, 28, 28], f16), T([32, 120, 28, 28], f16)), {})
+cnt: 1, ((T([32, 72, 28, 28], f16), T([32, 72, 1, 1], f16)), {})
+cnt: 1, ((T([32, 72, 28, 28], f16), T([32, 72, 28, 28], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 3, ((T([32, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), False, 0.01, 0.001), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.01, 0.001), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.01, 0.001), {})
+cnt: 2, ((T([32, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), False, 0.01, 0.001), {})
+cnt: 3, ((T([32, 72, 56, 56], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), False, 0.01, 0.001), {})
+cnt: 1, ((T([32, 72, 28, 28], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f16), False, 0.01, 0.001), {})
+cnt: 3, ((T([32, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f16), False, 0.01, 0.001), {})
+cnt: 4, ((T([32, 120, 28, 28], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f16), False, 0.01, 0.001), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), False, 0.01, 0.001), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), False, 0.01, 0.001), {})
+cnt: 4, ((T([32, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f16), False, 0.01, 0.001), {})
+cnt: 2, ((T([32, 200, 14, 14], f16), T([200], f16), T([200], f16), T([200], f16), T([200], f16), False, 0.01, 0.001), {})
+cnt: 4, ((T([32, 184, 14, 14], f16), T([184], f16), T([184], f16), T([184], f16), T([184], f16), False, 0.01, 0.001), {})
+cnt: 2, ((T([32, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), False, 0.01, 0.001), {})
+cnt: 2, ((T([32, 112, 14, 14], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f16), False, 0.01, 0.001), {})
+cnt: 3, ((T([32, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), False, 0.01, 0.001), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), False, 0.01, 0.001), {})
+cnt: 3, ((T([32, 160, 7, 7], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), False, 0.01, 0.001), {})
+cnt: 5, ((T([32, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f16), False, 0.01, 0.001), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 5, ((T([32, 960, 7, 7], f16), T([32, 960, 7, 7], f16), T([960], f16), T([960], f16), T([960], f16), T([960], f32), T([960], f32), False, 0.001, [True, True, True]), {})
+cnt: 3, ((T([32, 160, 7, 7], f16), T([32, 160, 7, 7], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), False, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), T([32, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), False, 0.001, [True, True, True]), {})
+cnt: 3, ((T([32, 672, 14, 14], f16), T([32, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), False, 0.001, [True, True, True]), {})
+cnt: 2, ((T([32, 112, 14, 14], f16), T([32, 112, 14, 14], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f32), T([112], f32), False, 0.001, [True, True, True]), {})
+cnt: 2, ((T([32, 480, 14, 14], f16), T([32, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), False, 0.001, [True, True, True]), {})
+cnt: 4, ((T([32, 80, 14, 14], f16), T([32, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), False, 0.001, [True, True, True]), {})
+cnt: 4, ((T([32, 184, 14, 14], f16), T([32, 184, 14, 14], f16), T([184], f16), T([184], f16), T([184], f16), T([184], f32), T([184], f32), False, 0.001, [True, True, True]), {})
+cnt: 2, ((T([32, 200, 14, 14], f16), T([32, 200, 14, 14], f16), T([200], f16), T([200], f16), T([200], f16), T([200], f32), T([200], f32), False, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([32, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), False, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([32, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), False, 0.001, [True, True, True]), {})
+cnt: 3, ((T([32, 40, 28, 28], f16), T([32, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), False, 0.001, [True, True, True]), {})
+cnt: 4, ((T([32, 120, 28, 28], f16), T([32, 120, 28, 28], f16), T([120], f16), T([120], f16), T([120], f16), T([120], f32), T([120], f32), False, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 72, 28, 28], f16), T([32, 72, 28, 28], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), False, 0.001, [True, True, True]), {})
+cnt: 3, ((T([32, 72, 56, 56], f16), T([32, 72, 56, 56], f16), T([72], f16), T([72], f16), T([72], f16), T([72], f32), T([72], f32), False, 0.001, [True, True, True]), {})
+cnt: 2, ((T([32, 24, 56, 56], f16), T([32, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), False, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 0.001, [True, True, True]), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 0.001, [True, True, True]), {})
+cnt: 3, ((T([32, 16, 112, 112], f16), T([32, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), False, 0.001, [True, True, True]), {})
+Operator: aten.relu.default
+cnt: 1, ((T([32, 24, 1, 1], f16),), {})
+cnt: 2, ((T([32, 32, 1, 1], f16),), {})
+cnt: 1, ((T([32, 120, 1, 1], f16),), {})
+cnt: 2, ((T([32, 168, 1, 1], f16),), {})
+cnt: 2, ((T([32, 240, 1, 1], f16),), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([32, 16, 112, 112], f16),), {})
+cnt: 1, ((T([32, 64, 112, 112], f16),), {})
+cnt: 1, ((T([32, 64, 56, 56], f16),), {})
+cnt: 3, ((T([32, 72, 56, 56], f16),), {})
+cnt: 1, ((T([32, 72, 28, 28], f16),), {})
+cnt: 4, ((T([32, 120, 28, 28], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32, 1000], f16, stride=(0, 0)), [0], True), {})
+cnt: 1, ((T([32, 1280], f16), [0], True), {})
+cnt: 2, ((T([32, 960, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 672, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 480, 14, 14], f16), [2, 3], True), {})
+cnt: 2, ((T([32, 120, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 72, 28, 28], f16), [2, 3], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([32, 1000], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 2, ((T([32, 240, 1, 1], f16), T([32, 240, 1, 1], f16), 0), {})
+cnt: 2, ((T([32, 168, 1, 1], f16), T([32, 168, 1, 1], f16), 0), {})
+cnt: 1, ((T([32, 120, 1, 1], f16), T([32, 120, 1, 1], f16), 0), {})
+cnt: 2, ((T([32, 32, 1, 1], f16), T([32, 32, 1, 1], f16), 0), {})
+cnt: 4, ((T([32, 120, 28, 28], f16), T([32, 120, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 24, 1, 1], f16), T([32, 24, 1, 1], f16), 0), {})
+cnt: 1, ((T([32, 72, 28, 28], f16), T([32, 72, 28, 28], f16), 0), {})
+cnt: 3, ((T([32, 72, 56, 56], f16), T([32, 72, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 64, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), 0), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([32, 16, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/nvidia_deeprecommender_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/nvidia_deeprecommender_training.txt
new file mode 100644
index 0000000000000..438f2289338e9
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/nvidia_deeprecommender_training.txt
@@ -0,0 +1,36 @@
+Operator: aten.addmm.default
+cnt: 1, ((T([512], f16), T([256, 197951], f16), T([197951, 512], f16, stride=(1, 197951))), {})
+cnt: 2, ((T([512], f16), T([256, 512], f16), T([512, 512], f16, stride=(1, 512))), {})
+cnt: 1, ((T([1024], f16), T([256, 512], f16), T([512, 1024], f16, stride=(1, 512))), {})
+cnt: 1, ((T([512], f16), T([256, 1024], f16), T([1024, 512], f16, stride=(1, 1024))), {})
+cnt: 1, ((T([197951], f16), T([256, 512], f16), T([512, 197951], f16, stride=(1, 512))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([256, 197951], f16),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([256, 197951], f16), T([256, 197951], f16)), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 50675456), {})
+Operator: aten.elu.default
+cnt: 4, ((T([256, 512], f16), 1.6732632423543772, 1.0507009873554805), {})
+cnt: 1, ((T([256, 1024], f16), 1.6732632423543772, 1.0507009873554805), {})
+cnt: 1, ((T([256, 197951], f16), 1.6732632423543772, 1.0507009873554805), {})
+Operator: aten.elu_backward.default
+cnt: 1, ((T([256, 197951], f16, stride=(0, 0)), 1.6732632423543772, 1.0507009873554805, 1, False, T([256, 197951], f16)), {})
+cnt: 4, ((T([256, 512], f16), 1.6732632423543772, 1.0507009873554805, 1, False, T([256, 512], f16)), {})
+cnt: 1, ((T([256, 1024], f16), 1.6732632423543772, 1.0507009873554805, 1, False, T([256, 1024], f16)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([256, 197951], f16), T([197951, 512], f16)), {})
+cnt: 1, ((T([197951, 256], f16, stride=(1, 197951)), T([256, 512], f16)), {})
+cnt: 2, ((T([256, 512], f16), T([512, 512], f16)), {})
+cnt: 2, ((T([512, 256], f16, stride=(1, 512)), T([256, 512], f16)), {})
+cnt: 1, ((T([256, 512], f16), T([512, 1024], f16)), {})
+cnt: 1, ((T([512, 256], f16, stride=(1, 512)), T([256, 1024], f16)), {})
+cnt: 1, ((T([256, 1024], f16), T([1024, 512], f16)), {})
+cnt: 1, ((T([1024, 256], f16, stride=(1, 1024)), T([256, 512], f16)), {})
+cnt: 1, ((T([512, 256], f16, stride=(1, 512)), T([256, 197951], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([256, 197951], f16), [0], True), {})
+cnt: 4, ((T([256, 512], f16), [0], True), {})
+cnt: 1, ((T([256, 1024], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([256, 197951], f16),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/pytorch_CycleGAN_and_pix2pix_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/pytorch_CycleGAN_and_pix2pix_training.txt
new file mode 100644
index 0000000000000..81c5a051ffe89
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/pytorch_CycleGAN_and_pix2pix_training.txt
@@ -0,0 +1,67 @@
+Operator: aten.add.Tensor
+cnt: 18, ((T([1, 256, 64, 64], f16), T([1, 256, 64, 64], f16)), {})
+Operator: aten.clone.default
+cnt: 1, ((T([1, 3, 256, 256], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([1, 3, 262, 262], f16), T([64, 3, 7, 7], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 64, 256, 256], f16), T([128, 64, 3, 3], f16), T([128], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 128, 128, 128], f16), T([256, 128, 3, 3], f16), T([256], f16), [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 18, ((T([1, 256, 66, 66], f16), T([256, 256, 3, 3], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 256, 64, 64], f16), T([256, 128, 3, 3], f16), T([128], f16), [2, 2], [1, 1], [1, 1], True, [1, 1], 1), {})
+cnt: 1, ((T([1, 128, 128, 128], f16), T([128, 64, 3, 3], f16), T([64], f16), [2, 2], [1, 1], [1, 1], True, [1, 1], 1), {})
+cnt: 1, ((T([1, 64, 262, 262], f16), T([3, 64, 7, 7], f16), T([3], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([1, 3, 256, 256], f16), T([1, 64, 262, 262], f16), T([3, 64, 7, 7], f16), [3], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 64, 256, 256], f16), T([1, 128, 128, 128], f16), T([128, 64, 3, 3], f16), [64], [2, 2], [1, 1], [1, 1], True, [1, 1], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 128, 128, 128], f16), T([1, 256, 64, 64], f16), T([256, 128, 3, 3], f16), [128], [2, 2], [1, 1], [1, 1], True, [1, 1], 1, [True, True, True]), {})
+cnt: 18, ((T([1, 256, 64, 64], f16), T([1, 256, 66, 66], f16), T([256, 256, 3, 3], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 256, 64, 64], f16), T([1, 128, 128, 128], f16), T([256, 128, 3, 3], f16), [256], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 128, 128, 128], f16), T([1, 64, 256, 256], f16), T([128, 64, 3, 3], f16), [128], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 64, 256, 256], f16), T([1, 3, 262, 262], f16), T([64, 3, 7, 7], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([1, 3, 256, 256], f16), T([1, 3, 256, 256], f16)), {})
+cnt: 2, ((T([64, 256, 256], f16), T([64, 256, 256], f16)), {})
+cnt: 4, ((T([1, 64, 256, 256], f16), T([1, 64, 256, 256], f16)), {})
+cnt: 2, ((T([128, 128, 128], f16), T([128, 128, 128], f16)), {})
+cnt: 4, ((T([1, 128, 128, 128], f16), T([1, 128, 128, 128], f16)), {})
+cnt: 10, ((T([256, 64, 64], f16), T([256, 64, 64], f16)), {})
+cnt: 20, ((T([1, 256, 64, 64], f16), T([1, 256, 64, 64], f16)), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 196608), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([1, 64, 256, 256], f16), None, None, None, None, True, 0.1, 1e-05), {})
+cnt: 2, ((T([1, 128, 128, 128], f16), None, None, None, None, True, 0.1, 1e-05), {})
+cnt: 19, ((T([1, 256, 64, 64], f16), None, None, None, None, True, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 2, ((T([1, 64, 256, 256], f16), T([1, 64, 256, 256], f16), None, None, None, T([64], f32), T([64], f32), True, 1e-05, [True, False, False]), {})
+cnt: 2, ((T([1, 128, 128, 128], f16), T([1, 128, 128, 128], f16), None, None, None, T([128], f32), T([128], f32), True, 1e-05, [True, False, False]), {})
+cnt: 19, ((T([1, 256, 64, 64], f16), T([1, 256, 64, 64], f16), None, None, None, T([256], f32), T([256], f32), True, 1e-05, [True, False, False]), {})
+Operator: aten.new_empty_strided.default
+cnt: 2, ((T([1, 64, 256, 256], f16), [1, 64, 256, 256], [4194304, 65536, 256, 1]), {})
+cnt: 2, ((T([1, 128, 128, 128], f16), [1, 128, 128, 128], [2097152, 16384, 128, 1]), {})
+cnt: 10, ((T([1, 256, 64, 64], f16), [1, 256, 64, 64], [1048576, 4096, 64, 1]), {})
+Operator: aten.new_zeros.default
+cnt: 2, ((T([64, 256, 256], f16), [4194304]), {})
+cnt: 2, ((T([128, 128, 128], f16), [2097152]), {})
+cnt: 10, ((T([256, 64, 64], f16), [1048576]), {})
+Operator: aten.reflection_pad2d.default
+cnt: 1, ((T([1, 3, 256, 256], f16), [3, 3, 3, 3]), {})
+cnt: 18, ((T([1, 256, 64, 64], f16), [1, 1, 1, 1]), {})
+cnt: 1, ((T([1, 64, 256, 256], f16), [3, 3, 3, 3]), {})
+Operator: aten.reflection_pad2d_backward.default
+cnt: 1, ((T([1, 64, 262, 262], f16), T([1, 64, 256, 256], f16), [3, 3, 3, 3]), {})
+cnt: 18, ((T([1, 256, 66, 66], f16), T([1, 256, 64, 64], f16), [1, 1, 1, 1]), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([1, 64, 256, 256], f16),), {})
+cnt: 2, ((T([1, 128, 128, 128], f16),), {})
+cnt: 10, ((T([1, 256, 64, 64], f16),), {})
+Operator: aten.sum.default
+cnt: 1, ((T([1, 3, 256, 256], f16),), {})
+Operator: aten.tanh.default
+cnt: 1, ((T([1, 3, 256, 256], f16),), {})
+Operator: aten.tanh_backward.default
+cnt: 1, ((T([1, 3, 256, 256], f16, stride=(0, 0, 0, 0)), T([1, 3, 256, 256], f16)), {})
+Operator: aten.threshold_backward.default
+cnt: 2, ((T([1, 64, 256, 256], f16), T([1, 64, 256, 256], f16), 0), {})
+cnt: 2, ((T([1, 128, 128, 128], f16), T([1, 128, 128, 128], f16), 0), {})
+cnt: 10, ((T([1, 256, 64, 64], f16), T([1, 256, 64, 64], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/pytorch_stargan_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/pytorch_stargan_training.txt
new file mode 100644
index 0000000000000..a2969693ef9b6
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/pytorch_stargan_training.txt
@@ -0,0 +1,80 @@
+Operator: aten.add.Tensor
+cnt: 12, ((T([16, 256, 32, 32], f16), T([16, 256, 32, 32], f16)), {})
+Operator: aten.cat.default
+cnt: 1, (([T([16, 3, 128, 128], f16), T([16, 5, 128, 128], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([16, 3, 128, 128], f16),), {})
+cnt: 1, ((T([16, 5], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([16, 8, 128, 128], f16), T([64, 8, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([16, 64, 128, 128], f16), T([128, 64, 4, 4], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([16, 128, 64, 64], f16), T([256, 128, 4, 4], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 12, ((T([16, 256, 32, 32], f16), T([256, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([16, 256, 32, 32], f16), T([256, 128, 4, 4], f16), None, [2, 2], [1, 1], [1, 1], True, [0, 0], 1), {})
+cnt: 1, ((T([16, 128, 64, 64], f16), T([128, 64, 4, 4], f16), None, [2, 2], [1, 1], [1, 1], True, [0, 0], 1), {})
+cnt: 1, ((T([16, 64, 128, 128], f16), T([3, 64, 7, 7], f16), None, [1, 1], [3, 3], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([16, 3, 128, 128], f16), T([16, 64, 128, 128], f16), T([3, 64, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 64, 128, 128], f16), T([16, 128, 64, 64], f16), T([128, 64, 4, 4], f16), [0], [2, 2], [1, 1], [1, 1], True, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 128, 64, 64], f16), T([16, 256, 32, 32], f16), T([256, 128, 4, 4], f16), [0], [2, 2], [1, 1], [1, 1], True, [0, 0], 1, [True, True, False]), {})
+cnt: 12, ((T([16, 256, 32, 32], f16), T([16, 256, 32, 32], f16), T([256, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 256, 32, 32], f16), T([16, 128, 64, 64], f16), T([256, 128, 4, 4], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 128, 64, 64], f16), T([16, 64, 128, 128], f16), T([128, 64, 4, 4], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 64, 128, 128], f16), T([16, 8, 128, 128], f16), T([64, 8, 7, 7], f16), [0], [1, 1], [3, 3], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([16, 3, 128, 128], f16), T([16, 3, 128, 128], f16)), {})
+cnt: 1, ((T([16, 5], f16), T([16, 5], f16)), {})
+cnt: 4, ((T([64], f16), T([64], f16)), {})
+cnt: 4, ((T([128], f16), T([128], f16)), {})
+cnt: 26, ((T([256], f16), T([256], f16)), {})
+cnt: 4, ((T([16, 64, 128, 128], f16), T([16, 64, 128, 128], f16)), {})
+cnt: 2, ((T([1, 1024, 128, 128], f16), T([1, 1024, 128, 128], f16)), {})
+cnt: 4, ((T([16, 128, 64, 64], f16), T([16, 128, 64, 64], f16)), {})
+cnt: 2, ((T([1, 2048, 64, 64], f16), T([1, 2048, 64, 64], f16)), {})
+cnt: 14, ((T([16, 256, 32, 32], f16), T([16, 256, 32, 32], f16)), {})
+cnt: 7, ((T([1, 4096, 32, 32], f16), T([1, 4096, 32, 32], f16)), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 786432), {})
+Operator: aten.mean.dim
+cnt: 4, ((T([16, 64], f16), [0]), {})
+cnt: 4, ((T([16, 128], f16), [0]), {})
+cnt: 26, ((T([16, 256], f16), [0]), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([1, 1024, 128, 128], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), False, 0.1, 1e-05), {})
+cnt: 2, ((T([1, 2048, 64, 64], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), False, 0.1, 1e-05), {})
+cnt: 13, ((T([1, 4096, 32, 32], f16), T([4096], f16), T([4096], f16), T([4096], f16), T([4096], f16), False, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 2, ((T([1, 1024, 128, 128], f16), T([1, 1024, 128, 128], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), False, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([1, 2048, 64, 64], f16), T([1, 2048, 64, 64], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), False, 1e-05, [True, True, True]), {})
+cnt: 13, ((T([1, 4096, 32, 32], f16), T([1, 4096, 32, 32], f16), T([4096], f16), T([4096], f16), T([4096], f16), T([4096], f32), T([4096], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.new_empty_strided.default
+cnt: 2, ((T([1, 1024, 128, 128], f16), [1, 1024, 128, 128], [16777216, 16384, 128, 1]), {})
+cnt: 2, ((T([1, 2048, 64, 64], f16), [1, 2048, 64, 64], [8388608, 4096, 64, 1]), {})
+cnt: 7, ((T([1, 4096, 32, 32], f16), [1, 4096, 32, 32], [4194304, 1024, 32, 1]), {})
+Operator: aten.new_zeros.default
+cnt: 2, ((T([16, 64, 128, 128], f16), [16777216]), {})
+cnt: 2, ((T([16, 128, 64, 64], f16), [8388608]), {})
+cnt: 7, ((T([16, 256, 32, 32], f16), [4194304]), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([16, 64, 128, 128], f16),), {})
+cnt: 2, ((T([16, 128, 64, 64], f16),), {})
+cnt: 7, ((T([16, 256, 32, 32], f16),), {})
+Operator: aten.repeat.default
+cnt: 1, ((T([16, 5, 1, 1], f16), [1, 1, 128, 128]), {})
+cnt: 8, ((T([64], f16), [16]), {})
+cnt: 8, ((T([128], f16), [16]), {})
+cnt: 52, ((T([256], f16), [16]), {})
+Operator: aten.sum.default
+cnt: 1, ((T([16, 3, 128, 128], f16),), {})
+Operator: aten.sum.dim_IntList
+cnt: 4, ((T([16, 64], f16), [0]), {})
+cnt: 4, ((T([16, 128], f16), [0]), {})
+cnt: 26, ((T([16, 256], f16), [0]), {})
+Operator: aten.tanh.default
+cnt: 1, ((T([16, 3, 128, 128], f16),), {})
+Operator: aten.tanh_backward.default
+cnt: 1, ((T([16, 3, 128, 128], f16, stride=(0, 0, 0, 0)), T([16, 3, 128, 128], f16)), {})
+Operator: aten.threshold_backward.default
+cnt: 2, ((T([16, 64, 128, 128], f16), T([16, 64, 128, 128], f16), 0), {})
+cnt: 2, ((T([16, 128, 64, 64], f16), T([16, 128, 64, 64], f16), 0), {})
+cnt: 7, ((T([16, 256, 32, 32], f16), T([16, 256, 32, 32], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/pytorch_struct_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/pytorch_struct_training.txt
new file mode 100644
index 0000000000000..3512fcd8ff066
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/pytorch_struct_training.txt
@@ -0,0 +1,63 @@
+Operator: aten._log_softmax.default
+cnt: 1, ((T([30, 4771], f16, stride=(1, 30)), -1, False), {})
+cnt: 1, ((T([30, 3600], f16), -1, False), {})
+cnt: 1, ((T([30], f16), -1, False), {})
+Operator: aten._log_softmax_backward_data.default
+cnt: 1, ((T([30], f16), T([30], f16), -1, f16), {})
+cnt: 1, ((T([30, 3600], f16), T([30, 3600], f16), -1, f16), {})
+cnt: 1, ((T([30, 4771], f16), T([30, 4771], f16), -1, f16), {})
+Operator: aten.add.Tensor
+cnt: 4, ((T([30, 256], f16), T([30, 256], f16)), {})
+cnt: 1, ((T([], f16), 0), {})
+cnt: 2, ((T([], f16), T([], f16)), {})
+cnt: 4, ((T([30, 256], f16, stride=(1, 30)), T([30, 256], f16)), {})
+Operator: aten.addmm.default
+cnt: 10, ((T([256], f16), T([30, 256], f16), T([256, 256], f16, stride=(1, 256))), {})
+Operator: aten.bmm.default
+cnt: 1, ((T([1, 4771, 256], f16), T([1, 256, 30], f16, stride=(256, 1, 256))), {})
+cnt: 1, ((T([1, 30, 256], f16), T([1, 256, 3600], f16, stride=(256, 1, 256))), {})
+cnt: 1, ((T([1, 1, 256], f16), T([1, 256, 30], f16, stride=(256, 1, 256))), {})
+cnt: 1, ((T([1, 256, 1], f16), T([1, 1, 30], f16)), {})
+cnt: 1, ((T([1, 1, 30], f16), T([1, 30, 256], f16)), {})
+cnt: 1, ((T([1, 256, 30], f16, stride=(7680, 1, 256)), T([1, 30, 3600], f16)), {})
+cnt: 1, ((T([1, 30, 3600], f16), T([1, 3600, 256], f16)), {})
+cnt: 1, ((T([1, 256, 4771], f16, stride=(1221376, 1, 256)), T([1, 4771, 30], f16, stride=(4771, 1, 4771))), {})
+cnt: 1, ((T([1, 4771, 30], f16, stride=(4771, 1, 4771)), T([1, 30, 256], f16)), {})
+Operator: aten.clone.default
+cnt: 1, ((T([40, 29], i64, stride=(1, 40)),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([40, 29], i64, stride=(1, 40)), T([40, 29], i64, stride=(1, 40))), {})
+cnt: 1, ((T([60, 60, 256], f16), T([60, 60, 256], f16, stride=(60, 1, 3600))), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 34800), {})
+cnt: 2, ((T([], f16), 4320000), {})
+cnt: 2, ((T([], f16), 1200), {})
+cnt: 2, ((T([], f16), 3), {})
+Operator: aten.gather.default
+cnt: 1, ((T([40, 29, 30, 4771], f16, stride=(0, 0, 4771, 1)), 3, T([40, 29, 30, 1], i64, stride=(1, 40, 0, 1))), {})
+Operator: aten.mm.default
+cnt: 8, ((T([30, 256], f16), T([256, 256], f16)), {})
+cnt: 8, ((T([256, 30], f16, stride=(1, 256)), T([30, 256], f16)), {})
+cnt: 2, ((T([30, 256], f16, stride=(1, 30)), T([256, 256], f16)), {})
+cnt: 2, ((T([256, 30], f16), T([30, 256], f16)), {})
+Operator: aten.new_empty_strided.default
+cnt: 1, ((T([60, 60, 256], f16, stride=(60, 1, 3600)), [60, 60, 256], [15360, 256, 1]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.new_zeros.default
+cnt: 1, ((T([40, 29, 30, 1], f16, stride=(0, 0, 0, 1)), [40, 29, 30, 4771]), {})
+Operator: aten.relu.default
+cnt: 8, ((T([30, 256], f16),), {})
+Operator: aten.scatter_add.default
+cnt: 1, ((T([40, 29, 30, 4771], f16), 3, T([40, 29, 30, 1], i64, stride=(1, 40, 0, 1)), T([40, 29, 30, 1], f16, stride=(0, 0, 0, 1))), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([40, 30], f16, stride=(0, 0)), [0], True), {})
+cnt: 8, ((T([30, 256], f16), [0], True), {})
+cnt: 2, ((T([30, 256], f16, stride=(1, 30)), [0], True), {})
+cnt: 1, ((T([40, 30, 60, 60], f16, stride=(0, 0, 0, 0)), [0], True), {})
+cnt: 1, ((T([40, 29, 30, 4771], f16), [0, 1], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([40, 29, 30], f16),), {})
+cnt: 1, ((T([40, 30, 60, 60], f16, stride=(0, 3600, 60, 1)),), {})
+cnt: 1, ((T([40, 30], f16, stride=(0, 1)),), {})
+Operator: aten.threshold_backward.default
+cnt: 4, ((T([30, 256], f16, stride=(1, 30)), T([30, 256], f16), 0), {})
+cnt: 4, ((T([30, 256], f16), T([30, 256], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/pytorch_unet_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/pytorch_unet_training.txt
new file mode 100644
index 0000000000000..e2e12ab9be692
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/pytorch_unet_training.txt
@@ -0,0 +1,119 @@
+Operator: aten.add.Tensor
+cnt: 1, ((T([1, 512, 80, 119], f16), T([1, 512, 80, 119], f16)), {})
+cnt: 1, ((T([1, 256, 160, 239], f16), T([1, 256, 160, 239], f16)), {})
+cnt: 1, ((T([1, 128, 320, 479], f16), T([1, 128, 320, 479], f16)), {})
+cnt: 1, ((T([1, 64, 640, 959], f16), T([1, 64, 640, 959], f16)), {})
+Operator: aten.cat.default
+cnt: 1, (([T([1, 512, 80, 119], f16), T([1, 512, 80, 119], f16)], 1), {})
+cnt: 1, (([T([1, 256, 160, 239], f16), T([1, 256, 160, 239], f16)], 1), {})
+cnt: 1, (([T([1, 128, 320, 479], f16), T([1, 128, 320, 479], f16)], 1), {})
+cnt: 1, (([T([1, 64, 640, 959], f16), T([1, 64, 640, 959], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([1, 3, 640, 959], f16),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 1, ((T([1, 512, 80, 118], f16), [0, 1, 0, 0], 0.0), {})
+cnt: 1, ((T([1, 256, 160, 238], f16), [0, 1, 0, 0], 0.0), {})
+cnt: 1, ((T([1, 128, 320, 478], f16), [0, 1, 0, 0], 0.0), {})
+cnt: 1, ((T([1, 64, 640, 958], f16), [0, 1, 0, 0], 0.0), {})
+cnt: 1, ((T([1, 64, 640, 959], f16), [0, -1, 0, 0]), {})
+cnt: 1, ((T([1, 128, 320, 479], f16), [0, -1, 0, 0]), {})
+cnt: 1, ((T([1, 256, 160, 239], f16), [0, -1, 0, 0]), {})
+cnt: 1, ((T([1, 512, 80, 119], f16), [0, -1, 0, 0]), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([1, 3, 640, 959], f16), T([64, 3, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([1, 64, 640, 959], f16), T([64, 64, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 64, 320, 479], f16), T([128, 64, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 128, 320, 479], f16), T([128, 128, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 128, 160, 239], f16), T([256, 128, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 256, 160, 239], f16), T([256, 256, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 256, 80, 119], f16), T([512, 256, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 512, 80, 119], f16), T([512, 512, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([1, 512, 40, 59], f16), T([512, 512, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 1024, 80, 119], f16), T([512, 1024, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 512, 80, 119], f16), T([256, 512, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 512, 160, 239], f16), T([256, 512, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 256, 160, 239], f16), T([128, 256, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 256, 320, 479], f16), T([128, 256, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 128, 320, 479], f16), T([64, 128, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 128, 640, 959], f16), T([64, 128, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 64, 640, 959], f16), T([2, 64, 1, 1], f16), T([2], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([1, 2, 640, 959], f16, stride=(0, 0, 0, 0)), T([1, 64, 640, 959], f16), T([2, 64, 1, 1], f16), [2], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([1, 64, 640, 959], f16), T([1, 64, 640, 959], f16), T([64, 64, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 64, 640, 959], f16), T([1, 128, 640, 959], f16), T([64, 128, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 64, 320, 479], f16), T([1, 128, 320, 479], f16), T([64, 128, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 128, 320, 479], f16), T([1, 256, 320, 479], f16), T([128, 256, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 128, 160, 239], f16), T([1, 256, 160, 239], f16), T([128, 256, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 256, 160, 239], f16), T([1, 512, 160, 239], f16), T([256, 512, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 256, 80, 119], f16), T([1, 512, 80, 119], f16), T([256, 512, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 512, 80, 119], f16), T([1, 1024, 80, 119], f16), T([512, 1024, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([1, 512, 40, 59], f16), T([1, 512, 40, 59], f16), T([512, 512, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 512, 80, 119], f16), T([1, 512, 80, 119], f16), T([512, 512, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 512, 80, 119], f16), T([1, 256, 80, 119], f16), T([512, 256, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 256, 160, 239], f16), T([1, 256, 160, 239], f16), T([256, 256, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 256, 160, 239], f16), T([1, 128, 160, 239], f16), T([256, 128, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 128, 320, 479], f16), T([1, 128, 320, 479], f16), T([128, 128, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 128, 320, 479], f16), T([1, 64, 320, 479], f16), T([128, 64, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 64, 640, 959], f16), T([1, 3, 640, 959], f16), T([64, 3, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([1, 3, 640, 959], f16), T([1, 3, 640, 959], f16)), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 1227520), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([1, 64, 640, 959], f16), [2, 2], [2, 2]), {})
+cnt: 1, ((T([1, 128, 320, 479], f16), [2, 2], [2, 2]), {})
+cnt: 1, ((T([1, 256, 160, 239], f16), [2, 2], [2, 2]), {})
+cnt: 1, ((T([1, 512, 80, 119], f16), [2, 2], [2, 2]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([1, 512, 40, 59], f16), T([1, 512, 80, 119], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([1, 512, 40, 59], i64)), {})
+cnt: 1, ((T([1, 256, 80, 119], f16), T([1, 256, 160, 239], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([1, 256, 80, 119], i64)), {})
+cnt: 1, ((T([1, 128, 160, 239], f16), T([1, 128, 320, 479], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([1, 128, 160, 239], i64)), {})
+cnt: 1, ((T([1, 64, 320, 479], f16), T([1, 64, 640, 959], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([1, 64, 320, 479], i64)), {})
+Operator: aten.native_batch_norm.default
+cnt: 4, ((T([1, 64, 640, 959], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([1, 128, 320, 479], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([1, 256, 160, 239], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([1, 512, 80, 119], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+cnt: 2, ((T([1, 512, 40, 59], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([1, 256, 80, 119], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([1, 128, 160, 239], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([1, 64, 320, 479], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 4, ((T([1, 64, 640, 959], f16), T([1, 64, 640, 959], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([1, 64, 320, 479], f16), T([1, 64, 320, 479], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([1, 128, 320, 479], f16), T([1, 128, 320, 479], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([1, 128, 160, 239], f16), T([1, 128, 160, 239], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([1, 256, 160, 239], f16), T([1, 256, 160, 239], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([1, 256, 80, 119], f16), T([1, 256, 80, 119], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([1, 512, 80, 119], f16), T([1, 512, 80, 119], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([1, 512, 40, 59], f16), T([1, 512, 40, 59], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.relu_.default
+cnt: 4, ((T([1, 64, 640, 959], f16),), {})
+cnt: 3, ((T([1, 128, 320, 479], f16),), {})
+cnt: 3, ((T([1, 256, 160, 239], f16),), {})
+cnt: 3, ((T([1, 512, 80, 119], f16),), {})
+cnt: 2, ((T([1, 512, 40, 59], f16),), {})
+cnt: 1, ((T([1, 256, 80, 119], f16),), {})
+cnt: 1, ((T([1, 128, 160, 239], f16),), {})
+cnt: 1, ((T([1, 64, 320, 479], f16),), {})
+Operator: aten.sum.default
+cnt: 1, ((T([1, 2, 640, 959], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 4, ((T([1, 64, 640, 959], f16), T([1, 64, 640, 959], f16), 0), {})
+cnt: 1, ((T([1, 64, 320, 479], f16), T([1, 64, 320, 479], f16), 0), {})
+cnt: 3, ((T([1, 128, 320, 479], f16), T([1, 128, 320, 479], f16), 0), {})
+cnt: 1, ((T([1, 128, 160, 239], f16), T([1, 128, 160, 239], f16), 0), {})
+cnt: 3, ((T([1, 256, 160, 239], f16), T([1, 256, 160, 239], f16), 0), {})
+cnt: 1, ((T([1, 256, 80, 119], f16), T([1, 256, 80, 119], f16), 0), {})
+cnt: 3, ((T([1, 512, 80, 119], f16), T([1, 512, 80, 119], f16), 0), {})
+cnt: 2, ((T([1, 512, 40, 59], f16), T([1, 512, 40, 59], f16), 0), {})
+Operator: aten.upsample_bilinear2d.vec
+cnt: 1, ((T([1, 512, 40, 59], f16), None, True, [2.0, 2.0]), {})
+cnt: 1, ((T([1, 256, 80, 119], f16), None, True, [2.0, 2.0]), {})
+cnt: 1, ((T([1, 128, 160, 239], f16), None, True, [2.0, 2.0]), {})
+cnt: 1, ((T([1, 64, 320, 479], f16), None, True, [2.0, 2.0]), {})
+Operator: aten.upsample_bilinear2d_backward.vec
+cnt: 1, ((T([1, 64, 640, 958], f16), None, [1, 64, 320, 479], True, [2.0, 2.0]), {})
+cnt: 1, ((T([1, 128, 320, 478], f16), None, [1, 128, 160, 239], True, [2.0, 2.0]), {})
+cnt: 1, ((T([1, 256, 160, 238], f16), None, [1, 256, 80, 119], True, [2.0, 2.0]), {})
+cnt: 1, ((T([1, 512, 80, 118], f16), None, [1, 512, 40, 59], True, [2.0, 2.0]), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/resnet18_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/resnet18_training.txt
new file mode 100644
index 0000000000000..f949353a358a6
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/resnet18_training.txt
@@ -0,0 +1,81 @@
+Operator: aten.add.Tensor
+cnt: 1, ((T([16, 512, 7, 7], f16), T([16, 512, 7, 7], f16)), {})
+cnt: 2, ((T([16, 256, 14, 14], f16), T([16, 256, 14, 14], f16)), {})
+cnt: 2, ((T([16, 128, 28, 28], f16), T([16, 128, 28, 28], f16)), {})
+cnt: 3, ((T([16, 64, 56, 56], f16), T([16, 64, 56, 56], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 2, ((T([16, 64, 56, 56], f16), T([16, 64, 56, 56], f16)), {})
+cnt: 2, ((T([16, 128, 28, 28], f16), T([16, 128, 28, 28], f16)), {})
+cnt: 2, ((T([16, 256, 14, 14], f16), T([16, 256, 14, 14], f16)), {})
+cnt: 2, ((T([16, 512, 7, 7], f16), T([16, 512, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([16, 512], f16), T([512, 1000], f16, stride=(1, 512))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([16, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([16, 3, 224, 224], f16), T([64, 3, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([16, 64, 56, 56], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([16, 64, 56, 56], f16), T([128, 64, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([16, 128, 28, 28], f16), T([128, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([16, 64, 56, 56], f16), T([128, 64, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([16, 128, 28, 28], f16), T([256, 128, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([16, 256, 14, 14], f16), T([256, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([16, 128, 28, 28], f16), T([256, 128, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([16, 256, 14, 14], f16), T([512, 256, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([16, 512, 7, 7], f16), T([512, 512, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([16, 256, 14, 14], f16), T([512, 256, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([16, 512, 7, 7], f16), T([16, 512, 7, 7], f16), T([512, 512, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 512, 7, 7], f16), T([16, 256, 14, 14], f16), T([512, 256, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 512, 7, 7], f16), T([16, 256, 14, 14], f16), T([512, 256, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([16, 256, 14, 14], f16), T([16, 256, 14, 14], f16), T([256, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 256, 14, 14], f16), T([16, 128, 28, 28], f16), T([256, 128, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 256, 14, 14], f16), T([16, 128, 28, 28], f16), T([256, 128, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([16, 128, 28, 28], f16), T([16, 128, 28, 28], f16), T([128, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 128, 28, 28], f16), T([16, 64, 56, 56], f16), T([128, 64, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 128, 28, 28], f16), T([16, 64, 56, 56], f16), T([128, 64, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([16, 64, 56, 56], f16), T([16, 64, 56, 56], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([16, 64, 112, 112], f16), T([16, 3, 224, 224], f16), T([64, 3, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([16, 3, 224, 224], f16), T([16, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([16, 512, 7, 7], f16, stride=(512, 1, 0, 0)), 49), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 16000), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([16, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([16, 64, 56, 56], f16), T([16, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([16, 64, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([16, 512, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([16, 1000], f16, stride=(0, 0)), T([1000, 512], f16)), {})
+cnt: 1, ((T([1000, 16], f16, stride=(0, 0)), T([16, 512], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([16, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 4, ((T([16, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([16, 128, 28, 28], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([16, 256, 14, 14], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([16, 512, 7, 7], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 5, ((T([16, 512, 7, 7], f16), T([16, 512, 7, 7], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([16, 256, 14, 14], f16), T([16, 256, 14, 14], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([16, 128, 28, 28], f16), T([16, 128, 28, 28], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([16, 64, 56, 56], f16), T([16, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([16, 64, 112, 112], f16), T([16, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([16, 64, 112, 112], f16),), {})
+cnt: 4, ((T([16, 64, 56, 56], f16),), {})
+cnt: 4, ((T([16, 128, 28, 28], f16),), {})
+cnt: 4, ((T([16, 256, 14, 14], f16),), {})
+cnt: 4, ((T([16, 512, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([16, 1000], f16, stride=(0, 0)), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([16, 1000], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 4, ((T([16, 512, 7, 7], f16), T([16, 512, 7, 7], f16), 0), {})
+cnt: 4, ((T([16, 256, 14, 14], f16), T([16, 256, 14, 14], f16), 0), {})
+cnt: 4, ((T([16, 128, 28, 28], f16), T([16, 128, 28, 28], f16), 0), {})
+cnt: 4, ((T([16, 64, 56, 56], f16), T([16, 64, 56, 56], f16), 0), {})
+cnt: 1, ((T([16, 64, 112, 112], f16), T([16, 64, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/resnet50_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/resnet50_training.txt
new file mode 100644
index 0000000000000..517a1e3f175db
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/resnet50_training.txt
@@ -0,0 +1,134 @@
+Operator: aten.add.Tensor
+cnt: 2, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16)), {})
+cnt: 6, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16)), {})
+cnt: 4, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16)), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16)), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 64, 56, 56], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 3, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16)), {})
+cnt: 4, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16)), {})
+cnt: 6, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16)), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([32, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([64, 3, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 64, 56, 56], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([32, 64, 56, 56], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 256, 56, 56], f16), T([64, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([128, 128, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([32, 128, 28, 28], f16), T([512, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([512, 256, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 512, 28, 28], f16), T([128, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 128, 28, 28], f16), T([128, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), T([256, 256, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([32, 256, 14, 14], f16), T([1024, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([1024, 512, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([32, 1024, 14, 14], f16), T([256, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([32, 256, 14, 14], f16), T([256, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), T([512, 512, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 512, 7, 7], f16), T([2048, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([2048, 1024, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 2048, 7, 7], f16), T([512, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 512, 7, 7], f16), T([512, 512, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([32, 2048, 7, 7], f16), T([32, 512, 7, 7], f16), T([2048, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 512, 7, 7], f16), T([32, 512, 7, 7], f16), T([512, 512, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 512, 7, 7], f16), T([32, 2048, 7, 7], f16), T([512, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16), T([32, 1024, 14, 14], f16), T([2048, 1024, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 7, 7], f16), T([32, 512, 14, 14], f16), T([512, 512, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), T([32, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([32, 1024, 14, 14], f16), T([32, 256, 14, 14], f16), T([1024, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([32, 256, 14, 14], f16), T([32, 256, 14, 14], f16), T([256, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([32, 256, 14, 14], f16), T([32, 1024, 14, 14], f16), T([256, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 512, 28, 28], f16), T([1024, 512, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 14, 14], f16), T([32, 256, 28, 28], f16), T([256, 256, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), T([32, 512, 28, 28], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([32, 512, 28, 28], f16), T([32, 128, 28, 28], f16), T([512, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 128, 28, 28], f16), T([32, 128, 28, 28], f16), T([128, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 128, 28, 28], f16), T([32, 512, 28, 28], f16), T([128, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 256, 56, 56], f16), T([512, 256, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 28, 28], f16), T([32, 128, 56, 56], f16), T([128, 128, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([32, 256, 56, 56], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([32, 256, 56, 56], f16), T([32, 64, 56, 56], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 64, 56, 56], f16), T([32, 64, 56, 56], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 64, 56, 56], f16), T([32, 256, 56, 56], f16), T([64, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 64, 56, 56], f16), T([64, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 3, 224, 224], f16), T([64, 3, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([32, 2048, 7, 7], f16, stride=(2048, 1, 0, 0)), 49), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 32000), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([32, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([32, 64, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([32, 2048, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([32, 1000], f16, stride=(0, 0)), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 32], f16, stride=(0, 0)), T([32, 2048], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([32, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 6, ((T([32, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 4, ((T([32, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 7, ((T([32, 128, 28, 28], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([32, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 11, ((T([32, 256, 14, 14], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 7, ((T([32, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([32, 512, 7, 7], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+cnt: 4, ((T([32, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), False, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 4, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), False, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([32, 512, 7, 7], f16), T([32, 512, 7, 7], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), T([32, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), False, 1e-05, [True, True, True]), {})
+cnt: 11, ((T([32, 256, 14, 14], f16), T([32, 256, 14, 14], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), T([32, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([32, 128, 28, 28], f16), T([32, 128, 28, 28], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([32, 64, 56, 56], f16), T([32, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([32, 64, 112, 112], f16),), {})
+cnt: 6, ((T([32, 64, 56, 56], f16),), {})
+cnt: 3, ((T([32, 256, 56, 56], f16),), {})
+cnt: 1, ((T([32, 128, 56, 56], f16),), {})
+cnt: 7, ((T([32, 128, 28, 28], f16),), {})
+cnt: 4, ((T([32, 512, 28, 28], f16),), {})
+cnt: 1, ((T([32, 256, 28, 28], f16),), {})
+cnt: 11, ((T([32, 256, 14, 14], f16),), {})
+cnt: 6, ((T([32, 1024, 14, 14], f16),), {})
+cnt: 1, ((T([32, 512, 14, 14], f16),), {})
+cnt: 5, ((T([32, 512, 7, 7], f16),), {})
+cnt: 3, ((T([32, 2048, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32, 1000], f16, stride=(0, 0)), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([32, 1000], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 3, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16), 0), {})
+cnt: 5, ((T([32, 512, 7, 7], f16), T([32, 512, 7, 7], f16), 0), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), T([32, 512, 14, 14], f16), 0), {})
+cnt: 6, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16), 0), {})
+cnt: 11, ((T([32, 256, 14, 14], f16), T([32, 256, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), T([32, 256, 28, 28], f16), 0), {})
+cnt: 4, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16), 0), {})
+cnt: 7, ((T([32, 128, 28, 28], f16), T([32, 128, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), 0), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16), 0), {})
+cnt: 6, ((T([32, 64, 56, 56], f16), T([32, 64, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/resnext50_32x4d_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/resnext50_32x4d_training.txt
new file mode 100644
index 0000000000000..256d8ac3242c9
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/resnext50_32x4d_training.txt
@@ -0,0 +1,124 @@
+Operator: aten.add.Tensor
+cnt: 2, ((T([8, 2048, 7, 7], f16), T([8, 2048, 7, 7], f16)), {})
+cnt: 6, ((T([8, 1024, 14, 14], f16), T([8, 1024, 14, 14], f16)), {})
+cnt: 4, ((T([8, 512, 28, 28], f16), T([8, 512, 28, 28], f16)), {})
+cnt: 3, ((T([8, 256, 56, 56], f16), T([8, 256, 56, 56], f16)), {})
+cnt: 1, ((T([8, 64, 56, 56], f16), T([8, 64, 56, 56], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 3, ((T([8, 256, 56, 56], f16), T([8, 256, 56, 56], f16)), {})
+cnt: 4, ((T([8, 512, 28, 28], f16), T([8, 512, 28, 28], f16)), {})
+cnt: 6, ((T([8, 1024, 14, 14], f16), T([8, 1024, 14, 14], f16)), {})
+cnt: 3, ((T([8, 2048, 7, 7], f16), T([8, 2048, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([8, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([8, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([8, 3, 224, 224], f16), T([64, 3, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 64, 56, 56], f16), T([128, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([8, 128, 56, 56], f16), T([128, 4, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 3, ((T([8, 128, 56, 56], f16), T([256, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 64, 56, 56], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([8, 256, 56, 56], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 256, 56, 56], f16), T([256, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 256, 56, 56], f16), T([256, 8, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 4, ((T([8, 256, 28, 28], f16), T([512, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 256, 56, 56], f16), T([512, 256, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([8, 512, 28, 28], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([8, 256, 28, 28], f16), T([256, 8, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([8, 512, 28, 28], f16), T([512, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 512, 28, 28], f16), T([512, 16, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 6, ((T([8, 512, 14, 14], f16), T([1024, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 512, 28, 28], f16), T([1024, 512, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([8, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([8, 512, 14, 14], f16), T([512, 16, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([8, 1024, 14, 14], f16), T([1024, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 1024, 14, 14], f16), T([1024, 32, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 3, ((T([8, 1024, 7, 7], f16), T([2048, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 1024, 14, 14], f16), T([2048, 1024, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([8, 2048, 7, 7], f16), T([1024, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([8, 1024, 7, 7], f16), T([1024, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([8, 2048, 7, 7], f16), T([8, 1024, 7, 7], f16), T([2048, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([8, 1024, 7, 7], f16), T([8, 1024, 7, 7], f16), T([1024, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 2, ((T([8, 1024, 7, 7], f16), T([8, 2048, 7, 7], f16), T([1024, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 2048, 7, 7], f16), T([8, 1024, 14, 14], f16), T([2048, 1024, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 1024, 7, 7], f16), T([8, 1024, 14, 14], f16), T([1024, 32, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([8, 1024, 14, 14], f16), T([8, 1024, 14, 14], f16), T([1024, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([8, 1024, 14, 14], f16), T([8, 512, 14, 14], f16), T([1024, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([8, 512, 14, 14], f16), T([8, 512, 14, 14], f16), T([512, 16, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 5, ((T([8, 512, 14, 14], f16), T([8, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 1024, 14, 14], f16), T([8, 512, 28, 28], f16), T([1024, 512, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 512, 14, 14], f16), T([8, 512, 28, 28], f16), T([512, 16, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([8, 512, 28, 28], f16), T([8, 512, 28, 28], f16), T([512, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([8, 512, 28, 28], f16), T([8, 256, 28, 28], f16), T([512, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([8, 256, 28, 28], f16), T([8, 256, 28, 28], f16), T([256, 8, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 3, ((T([8, 256, 28, 28], f16), T([8, 512, 28, 28], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 512, 28, 28], f16), T([8, 256, 56, 56], f16), T([512, 256, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 256, 28, 28], f16), T([8, 256, 56, 56], f16), T([256, 8, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([8, 256, 56, 56], f16), T([8, 256, 56, 56], f16), T([256, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([8, 256, 56, 56], f16), T([8, 128, 56, 56], f16), T([256, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([8, 128, 56, 56], f16), T([8, 128, 56, 56], f16), T([128, 4, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 2, ((T([8, 128, 56, 56], f16), T([8, 256, 56, 56], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 256, 56, 56], f16), T([8, 64, 56, 56], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 128, 56, 56], f16), T([8, 64, 56, 56], f16), T([128, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 64, 112, 112], f16), T([8, 3, 224, 224], f16), T([64, 3, 7, 7], f16), [0], [2, 2], [3, 3], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([8, 3, 224, 224], f16), T([8, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([8, 2048, 7, 7], f16, stride=(2048, 1, 0, 0)), 49), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 8000), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([8, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([8, 64, 56, 56], f16), T([8, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([8, 64, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([8, 2048, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([8, 1000], f16, stride=(0, 0)), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 8], f16, stride=(0, 0)), T([8, 2048], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([8, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 6, ((T([8, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([8, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 7, ((T([8, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 6, ((T([8, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+cnt: 11, ((T([8, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+cnt: 8, ((T([8, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([8, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), False, 0.1, 1e-05), {})
+cnt: 4, ((T([8, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), False, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 4, ((T([8, 2048, 7, 7], f16), T([8, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), False, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([8, 1024, 7, 7], f16), T([8, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), False, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([8, 1024, 14, 14], f16), T([8, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), False, 1e-05, [True, True, True]), {})
+cnt: 11, ((T([8, 512, 14, 14], f16), T([8, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([8, 512, 28, 28], f16), T([8, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+cnt: 7, ((T([8, 256, 28, 28], f16), T([8, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([8, 256, 56, 56], f16), T([8, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([8, 128, 56, 56], f16), T([8, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([8, 64, 112, 112], f16), T([8, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([8, 64, 112, 112], f16),), {})
+cnt: 6, ((T([8, 128, 56, 56], f16),), {})
+cnt: 4, ((T([8, 256, 56, 56], f16),), {})
+cnt: 7, ((T([8, 256, 28, 28], f16),), {})
+cnt: 5, ((T([8, 512, 28, 28], f16),), {})
+cnt: 11, ((T([8, 512, 14, 14], f16),), {})
+cnt: 7, ((T([8, 1024, 14, 14], f16),), {})
+cnt: 5, ((T([8, 1024, 7, 7], f16),), {})
+cnt: 3, ((T([8, 2048, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([8, 1000], f16, stride=(0, 0)), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([8, 1000], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 3, ((T([8, 2048, 7, 7], f16), T([8, 2048, 7, 7], f16), 0), {})
+cnt: 5, ((T([8, 1024, 7, 7], f16), T([8, 1024, 7, 7], f16), 0), {})
+cnt: 7, ((T([8, 1024, 14, 14], f16), T([8, 1024, 14, 14], f16), 0), {})
+cnt: 11, ((T([8, 512, 14, 14], f16), T([8, 512, 14, 14], f16), 0), {})
+cnt: 5, ((T([8, 512, 28, 28], f16), T([8, 512, 28, 28], f16), 0), {})
+cnt: 7, ((T([8, 256, 28, 28], f16), T([8, 256, 28, 28], f16), 0), {})
+cnt: 4, ((T([8, 256, 56, 56], f16), T([8, 256, 56, 56], f16), 0), {})
+cnt: 6, ((T([8, 128, 56, 56], f16), T([8, 128, 56, 56], f16), 0), {})
+cnt: 1, ((T([8, 64, 112, 112], f16), T([8, 64, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/shufflenet_v2_x1_0_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/shufflenet_v2_x1_0_training.txt
new file mode 100644
index 0000000000000..9b26d6a7b7c15
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/shufflenet_v2_x1_0_training.txt
@@ -0,0 +1,123 @@
+Operator: aten._unsafe_view.default
+cnt: 4, ((T([128, 2, 232, 7, 7], f16), [128, 464, 7, 7]), {})
+cnt: 8, ((T([128, 2, 116, 14, 14], f16), [128, 232, 14, 14]), {})
+cnt: 4, ((T([128, 2, 58, 28, 28], f16), [128, 116, 28, 28]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([128, 232, 14, 14], f16), T([128, 232, 14, 14], f16)), {})
+cnt: 1, ((T([128, 116, 28, 28], f16), T([128, 116, 28, 28], f16)), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 24, 56, 56], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 1024], f16), T([1024, 1000], f16, stride=(1, 1024))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([128, 58, 28, 28], f16), T([128, 58, 28, 28], f16)], 1), {})
+cnt: 6, (([T([128, 58, 28, 28], f16, stride=(90944, 784, 28, 1)), T([128, 58, 28, 28], f16)], 1), {})
+cnt: 1, (([T([128, 116, 14, 14], f16), T([128, 116, 14, 14], f16)], 1), {})
+cnt: 14, (([T([128, 116, 14, 14], f16, stride=(45472, 196, 14, 1)), T([128, 116, 14, 14], f16)], 1), {})
+cnt: 1, (([T([128, 232, 7, 7], f16), T([128, 232, 7, 7], f16)], 1), {})
+cnt: 6, (([T([128, 232, 7, 7], f16, stride=(22736, 49, 7, 1)), T([128, 232, 7, 7], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([24, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([24, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 24), {})
+cnt: 1, ((T([128, 24, 28, 28], f16), T([58, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 24, 56, 56], f16), T([58, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 58, 56, 56], f16), T([58, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 58), {})
+cnt: 4, ((T([128, 58, 28, 28], f16), T([58, 58, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 58, 28, 28], f16, stride=(90944, 784, 28, 1)), T([58, 58, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 58, 28, 28], f16), T([58, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 58), {})
+cnt: 2, ((T([128, 116, 28, 28], f16), T([116, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 116), {})
+cnt: 9, ((T([128, 116, 14, 14], f16), T([116, 116, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 116, 28, 28], f16), T([116, 116, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([128, 116, 14, 14], f16, stride=(45472, 196, 14, 1)), T([116, 116, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([128, 116, 14, 14], f16), T([116, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 116), {})
+cnt: 2, ((T([128, 232, 14, 14], f16), T([232, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 232), {})
+cnt: 5, ((T([128, 232, 7, 7], f16), T([232, 232, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 232, 14, 14], f16), T([232, 232, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 232, 7, 7], f16, stride=(22736, 49, 7, 1)), T([232, 232, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 232, 7, 7], f16), T([232, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 232), {})
+cnt: 1, ((T([128, 464, 7, 7], f16), T([1024, 464, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 464, 7, 7], f16), T([1024, 464, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([128, 232, 7, 7], f16), T([128, 232, 7, 7], f16), T([232, 232, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 232, 7, 7], f16), T([128, 232, 7, 7], f16), T([232, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 232, [True, True, False]), {})
+cnt: 3, ((T([128, 232, 7, 7], f16), T([128, 232, 7, 7], f16, stride=(22736, 49, 7, 1)), T([232, 232, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 232, 7, 7], f16), T([128, 232, 14, 14], f16), T([232, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 232, [True, True, False]), {})
+cnt: 1, ((T([128, 232, 14, 14], f16), T([128, 232, 14, 14], f16), T([232, 232, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 9, ((T([128, 116, 14, 14], f16), T([128, 116, 14, 14], f16), T([116, 116, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 7, ((T([128, 116, 14, 14], f16), T([128, 116, 14, 14], f16), T([116, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 116, [True, True, False]), {})
+cnt: 7, ((T([128, 116, 14, 14], f16), T([128, 116, 14, 14], f16, stride=(45472, 196, 14, 1)), T([116, 116, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([128, 116, 14, 14], f16), T([128, 116, 28, 28], f16), T([116, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 116, [True, True, False]), {})
+cnt: 1, ((T([128, 116, 28, 28], f16), T([128, 116, 28, 28], f16), T([116, 116, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([128, 58, 28, 28], f16), T([128, 58, 28, 28], f16), T([58, 58, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([128, 58, 28, 28], f16), T([128, 58, 28, 28], f16), T([58, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 58, [True, True, False]), {})
+cnt: 3, ((T([128, 58, 28, 28], f16), T([128, 58, 28, 28], f16, stride=(90944, 784, 28, 1)), T([58, 58, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 58, 28, 28], f16), T([128, 58, 56, 56], f16), T([58, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 58, [True, True, False]), {})
+cnt: 1, ((T([128, 58, 56, 56], f16), T([128, 24, 56, 56], f16), T([58, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 58, 28, 28], f16), T([128, 24, 28, 28], f16), T([58, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 28, 28], f16), T([128, 24, 56, 56], f16), T([24, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 24, [True, True, False]), {})
+cnt: 1, ((T([128, 24, 112, 112], f16), T([128, 3, 224, 224], f16), T([24, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 224, 224], f16), T([128, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 1024, 7, 7], f16, stride=(1024, 1, 0, 0)), 49), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 128000), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([128, 24, 112, 112], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([128, 24, 56, 56], f16), T([128, 24, 112, 112], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([128, 24, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 1024, 7, 7], f16), [2, 3]), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16, stride=(0, 0)), T([1000, 1024], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(0, 0)), T([128, 1024], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([128, 24, 112, 112], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 24, 28, 28], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), False, 0.1, 1e-05), {})
+cnt: 12, ((T([128, 58, 28, 28], f16), T([58], f16), T([58], f16), T([58], f16), T([58], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 58, 56, 56], f16), T([58], f16), T([58], f16), T([58], f16), T([58], f16), False, 0.1, 1e-05), {})
+cnt: 25, ((T([128, 116, 14, 14], f16), T([116], f16), T([116], f16), T([116], f16), T([116], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 116, 28, 28], f16), T([116], f16), T([116], f16), T([116], f16), T([116], f16), False, 0.1, 1e-05), {})
+cnt: 13, ((T([128, 232, 7, 7], f16), T([232], f16), T([232], f16), T([232], f16), T([232], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 232, 14, 14], f16), T([232], f16), T([232], f16), T([232], f16), T([232], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), False, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), False, 1e-05, [True, True, True]), {})
+cnt: 13, ((T([128, 232, 7, 7], f16), T([128, 232, 7, 7], f16), T([232], f16), T([232], f16), T([232], f16), T([232], f32), T([232], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 232, 14, 14], f16), T([128, 232, 14, 14], f16), T([232], f16), T([232], f16), T([232], f16), T([232], f32), T([232], f32), False, 1e-05, [True, True, True]), {})
+cnt: 25, ((T([128, 116, 14, 14], f16), T([128, 116, 14, 14], f16), T([116], f16), T([116], f16), T([116], f16), T([116], f32), T([116], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 116, 28, 28], f16), T([128, 116, 28, 28], f16), T([116], f16), T([116], f16), T([116], f16), T([116], f32), T([116], f32), False, 1e-05, [True, True, True]), {})
+cnt: 12, ((T([128, 58, 28, 28], f16), T([128, 58, 28, 28], f16), T([58], f16), T([58], f16), T([58], f16), T([58], f32), T([58], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 58, 56, 56], f16), T([128, 58, 56, 56], f16), T([58], f16), T([58], f16), T([58], f16), T([58], f32), T([58], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 24, 28, 28], f16), T([128, 24, 28, 28], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([128, 24, 112, 112], f16), T([128, 24, 112, 112], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 24, 112, 112], f16),), {})
+cnt: 8, ((T([128, 58, 28, 28], f16),), {})
+cnt: 1, ((T([128, 58, 56, 56], f16),), {})
+cnt: 16, ((T([128, 116, 14, 14], f16),), {})
+cnt: 1, ((T([128, 116, 28, 28], f16),), {})
+cnt: 8, ((T([128, 232, 7, 7], f16),), {})
+cnt: 1, ((T([128, 232, 14, 14], f16),), {})
+cnt: 1, ((T([128, 1024, 7, 7], f16),), {})
+Operator: aten.split.Tensor
+cnt: 3, ((T([128, 116, 28, 28], f16), 58, 1), {})
+cnt: 7, ((T([128, 232, 14, 14], f16), 116, 1), {})
+cnt: 3, ((T([128, 464, 7, 7], f16), 232, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16, stride=(0, 0)), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([128, 1000], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([128, 1024, 7, 7], f16), T([128, 1024, 7, 7], f16), 0), {})
+cnt: 5, ((T([128, 232, 7, 7], f16, stride=(22736, 49, 7, 1)), T([128, 232, 7, 7], f16), 0), {})
+cnt: 3, ((T([128, 232, 7, 7], f16), T([128, 232, 7, 7], f16), 0), {})
+cnt: 1, ((T([128, 232, 14, 14], f16), T([128, 232, 14, 14], f16), 0), {})
+cnt: 9, ((T([128, 116, 14, 14], f16, stride=(45472, 196, 14, 1)), T([128, 116, 14, 14], f16), 0), {})
+cnt: 7, ((T([128, 116, 14, 14], f16), T([128, 116, 14, 14], f16), 0), {})
+cnt: 1, ((T([128, 116, 28, 28], f16), T([128, 116, 28, 28], f16), 0), {})
+cnt: 5, ((T([128, 58, 28, 28], f16, stride=(90944, 784, 28, 1)), T([128, 58, 28, 28], f16), 0), {})
+cnt: 3, ((T([128, 58, 28, 28], f16), T([128, 58, 28, 28], f16), 0), {})
+cnt: 1, ((T([128, 58, 56, 56], f16), T([128, 58, 56, 56], f16), 0), {})
+cnt: 1, ((T([128, 24, 112, 112], f16), T([128, 24, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/speech_transformer_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/speech_transformer_training.txt
new file mode 100644
index 0000000000000..8431f307e34d0
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/speech_transformer_training.txt
@@ -0,0 +1,178 @@
+Operator: aten._softmax.default
+cnt: 6, ((T([80, 204, 204], f16), 2, False), {})
+cnt: 6, ((T([80, 22, 22], f16), 2, False), {})
+cnt: 6, ((T([80, 22, 204], f16), 2, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 6, ((T([80, 22, 204], f16), T([80, 22, 204], f16), 2, f16), {})
+cnt: 6, ((T([80, 22, 22], f16), T([80, 22, 22], f16), 2, f16), {})
+cnt: 6, ((T([80, 204, 204], f16), T([80, 204, 204], f16), 2, f16), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([10, 22], b8),), {'dtype': f32})
+cnt: 1, ((T([], f32),), {'dtype': f16})
+cnt: 18, ((T([10, 22, 512], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 1, ((T([220, 1014], f16), [10, 22, 1014]), {})
+cnt: 12, ((T([8, 10, 22, 64], f16), [80, 22, 64]), {})
+cnt: 30, ((T([10, 204, 8, 64], f16), [10, 204, 512]), {})
+cnt: 24, ((T([10, 22, 8, 64], f16), [10, 22, 512]), {})
+cnt: 6, ((T([8, 10, 204, 64], f16), [80, 204, 64]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([10, 204, 512], f16), T([1, 204, 512], f16)), {})
+cnt: 47, ((T([10, 204, 512], f16), T([10, 204, 512], f16)), {})
+cnt: 1, ((T([10, 22, 22], b8, stride=(22, 0, 1)), T([10, 22, 22], u8, stride=(0, 22, 1))), {})
+cnt: 1, ((T([10, 22, 512], f16), T([1, 22, 512], f16)), {})
+cnt: 48, ((T([10, 22, 512], f16), T([10, 22, 512], f16)), {})
+cnt: 1, ((T([], f16), 0), {})
+cnt: 1, ((T([], f16), T([], f32)), {})
+cnt: 1, ((T([1014, 512], f16), T([1014, 512], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([512], f16), T([2040, 320], f16), T([320, 512], f16, stride=(1, 320))), {})
+cnt: 36, ((T([512], f16), T([2040, 512], f16), T([512, 512], f16, stride=(1, 512))), {})
+cnt: 6, ((T([2048], f16), T([2040, 512], f16), T([512, 2048], f16, stride=(1, 512))), {})
+cnt: 6, ((T([512], f16), T([2040, 2048], f16), T([2048, 512], f16, stride=(1, 2048))), {})
+cnt: 36, ((T([512], f16), T([220, 512], f16), T([512, 512], f16, stride=(1, 512))), {})
+cnt: 6, ((T([2048], f16), T([220, 512], f16), T([512, 2048], f16, stride=(1, 512))), {})
+cnt: 6, ((T([512], f16), T([220, 2048], f16), T([2048, 512], f16, stride=(1, 2048))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([80, 204, 64], f16), T([80, 64, 204], f16, stride=(13056, 1, 64))), {})
+cnt: 12, ((T([80, 204, 204], f16), T([80, 204, 64], f16)), {})
+cnt: 12, ((T([80, 22, 64], f16), T([80, 64, 22], f16, stride=(1408, 1, 64))), {})
+cnt: 12, ((T([80, 22, 22], f16), T([80, 22, 64], f16)), {})
+cnt: 12, ((T([80, 22, 64], f16), T([80, 64, 204], f16, stride=(13056, 1, 64))), {})
+cnt: 12, ((T([80, 22, 204], f16), T([80, 204, 64], f16)), {})
+cnt: 6, ((T([80, 204, 22], f16, stride=(4488, 1, 204)), T([80, 22, 64], f16)), {})
+cnt: 6, ((T([80, 64, 22], f16, stride=(1408, 1, 64)), T([80, 22, 204], f16)), {})
+cnt: 6, ((T([80, 22, 22], f16, stride=(484, 1, 22)), T([80, 22, 64], f16)), {})
+cnt: 6, ((T([80, 64, 22], f16, stride=(1408, 1, 64)), T([80, 22, 22], f16)), {})
+cnt: 6, ((T([80, 204, 204], f16, stride=(41616, 1, 204)), T([80, 204, 64], f16)), {})
+cnt: 6, ((T([80, 64, 204], f16, stride=(13056, 1, 64)), T([80, 204, 204], f16)), {})
+Operator: aten.cat.default
+cnt: 1, (([T([1], i64), T([17], i64)],), {})
+cnt: 1, (([T([1], i64), T([15], i64)],), {})
+cnt: 1, (([T([1], i64), T([21], i64)],), {})
+cnt: 1, (([T([1], i64), T([18], i64)],), {})
+cnt: 3, (([T([1], i64), T([9], i64)],), {})
+cnt: 1, (([T([1], i64), T([12], i64)],), {})
+cnt: 1, (([T([1], i64), T([11], i64)],), {})
+cnt: 1, (([T([1], i64), T([10], i64)],), {})
+cnt: 1, (([T([17], i64), T([1], i64)],), {})
+cnt: 1, (([T([15], i64), T([1], i64)],), {})
+cnt: 1, (([T([21], i64), T([1], i64)],), {})
+cnt: 1, (([T([18], i64), T([1], i64)],), {})
+cnt: 3, (([T([9], i64), T([1], i64)],), {})
+cnt: 1, (([T([12], i64), T([1], i64)],), {})
+cnt: 1, (([T([11], i64), T([1], i64)],), {})
+cnt: 1, (([T([10], i64), T([1], i64)],), {})
+Operator: aten.clone.default
+cnt: 1, ((T([10, 204, 320], f16),), {})
+cnt: 1, ((T([10], i64),), {})
+cnt: 1, ((T([10, 21], i64),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([10, 204, 320], f16), T([10, 204, 320], f16)), {})
+cnt: 7, ((T([10], i64), T([10], i64)), {})
+cnt: 1, ((T([10, 21], i64), T([10, 21], i64)), {})
+cnt: 2, ((T([18], i64), T([18], i64)), {})
+cnt: 2, ((T([16], i64), T([16], i64)), {})
+cnt: 2, ((T([22], i64), T([22], i64)), {})
+cnt: 2, ((T([19], i64), T([19], i64)), {})
+cnt: 2, ((T([13], i64), T([13], i64)), {})
+cnt: 2, ((T([12], i64), T([12], i64)), {})
+cnt: 2, ((T([11], i64), T([11], i64)), {})
+Operator: aten.div.Tensor
+cnt: 12, ((T([80, 204, 204], f16), 8.0), {})
+cnt: 12, ((T([80, 22, 22], f16), 8.0), {})
+cnt: 12, ((T([80, 22, 204], f16), 8.0), {})
+cnt: 2, ((T([], f16), 223080), {})
+cnt: 1, ((T([], i64), 220), {})
+cnt: 2, ((T([], f32), 2), {})
+Operator: aten.embedding.default
+cnt: 1, ((T([1014, 512], f16), T([10, 22], i64)), {})
+Operator: aten.embedding_dense_backward.default
+cnt: 1, ((T([10, 22, 512], f16), T([10, 22], i64), 1014, -1, False), {})
+Operator: aten.eq.Scalar
+cnt: 1, ((T([10, 22], i64), 2), {})
+Operator: aten.fill_.Scalar
+cnt: 1, ((T([10, 22], i64), 2), {})
+cnt: 1, ((T([10, 22], i64), -1), {})
+Operator: aten.fill_.Tensor
+cnt: 3, ((T([0], f16), T([], f16)), {})
+cnt: 3, ((T([4], f16), T([], f16)), {})
+cnt: 3, ((T([8], f16), T([], f16)), {})
+cnt: 3, ((T([24], f16), T([], f16)), {})
+cnt: 3, ((T([57], f16), T([], f16)), {})
+cnt: 3, ((T([67], f16), T([], f16)), {})
+cnt: 3, ((T([75], f16), T([], f16)), {})
+cnt: 3, ((T([91], f16), T([], f16)), {})
+cnt: 3, ((T([99], f16), T([], f16)), {})
+cnt: 3, ((T([118], f16), T([], f16)), {})
+Operator: aten.gt.Scalar
+cnt: 1, ((T([10, 22, 22], u8), 0), {})
+Operator: aten.index.Tensor
+cnt: 10, ((T([21], i64), [T([21], b8)]), {})
+Operator: aten.lt.Scalar
+cnt: 2, ((T([10, 204], f16), 1), {})
+Operator: aten.masked_fill.Scalar
+cnt: 6, ((T([80, 204, 204], f16), T([80, 204, 204], b8), -inf), {})
+cnt: 6, ((T([80, 22, 22], f16), T([80, 22, 22], b8), -inf), {})
+cnt: 6, ((T([80, 22, 204], f16), T([80, 22, 204], b8), -inf), {})
+cnt: 6, ((T([80, 22, 204], f16), T([80, 22, 204], b8), 0), {})
+cnt: 6, ((T([80, 22, 22], f16), T([80, 22, 22], b8), 0), {})
+cnt: 6, ((T([80, 204, 204], f16), T([80, 204, 204], b8), 0), {})
+Operator: aten.mm.default
+cnt: 1, ((T([220, 512], f16), T([512, 1014], f16, stride=(1, 512))), {})
+cnt: 1, ((T([1014, 220], f16, stride=(0, 0)), T([220, 512], f16)), {})
+cnt: 1, ((T([220, 1014], f16, stride=(0, 0)), T([1014, 512], f16)), {})
+cnt: 6, ((T([220, 512], f16), T([512, 2048], f16)), {})
+cnt: 6, ((T([512, 220], f16, stride=(1, 512)), T([220, 2048], f16)), {})
+cnt: 6, ((T([220, 2048], f16), T([2048, 512], f16)), {})
+cnt: 6, ((T([2048, 220], f16, stride=(1, 2048)), T([220, 512], f16)), {})
+cnt: 36, ((T([220, 512], f16), T([512, 512], f16)), {})
+cnt: 36, ((T([512, 220], f16, stride=(1, 512)), T([220, 512], f16)), {})
+cnt: 36, ((T([2040, 512], f16), T([512, 512], f16)), {})
+cnt: 36, ((T([512, 2040], f16, stride=(1, 512)), T([2040, 512], f16)), {})
+cnt: 6, ((T([2040, 512], f16), T([512, 2048], f16)), {})
+cnt: 6, ((T([512, 2040], f16, stride=(1, 512)), T([2040, 2048], f16)), {})
+cnt: 6, ((T([2040, 2048], f16), T([2048, 512], f16)), {})
+cnt: 6, ((T([2048, 2040], f16, stride=(1, 2048)), T([2040, 512], f16)), {})
+cnt: 1, ((T([512, 2040], f16, stride=(1, 512)), T([2040, 320], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([10, 22, 512], f16), 22.627416997969522), {})
+cnt: 18, ((T([10, 22, 512], f16), T([10, 22, 1], f32)), {})
+cnt: 12, ((T([10, 204, 512], f16), T([10, 204, 1], f16)), {})
+Operator: aten.mul_.Tensor
+cnt: 12, ((T([10, 204, 512], f16), T([10, 204, 1], f16)), {})
+cnt: 18, ((T([10, 22, 512], f16), T([10, 22, 1], f32)), {})
+Operator: aten.native_layer_norm.default
+cnt: 13, ((T([10, 204, 512], f16), [512], T([512], f16), T([512], f16), 1e-05), {})
+cnt: 18, ((T([10, 22, 512], f16), [512], T([512], f16), T([512], f16), 1e-05), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 18, ((T([10, 22, 512], f16), T([10, 22, 512], f16), [512], T([10, 22, 1], f32), T([10, 22, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+cnt: 13, ((T([10, 204, 512], f16), T([10, 204, 512], f16), [512], T([10, 204, 1], f32), T([10, 204, 1], f32), T([512], f16), T([512], f16), [True, True, True]), {})
+Operator: aten.ne.Scalar
+cnt: 10, ((T([21], i64), -1), {})
+cnt: 1, ((T([10, 22], i64), 2), {})
+Operator: aten.new_ones.default
+cnt: 2, ((T([10, 204, 320], f16), [10, 204]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 1, ((T([10, 204, 512], f16), [10, 204]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+Operator: aten.relu.default
+cnt: 6, ((T([10, 204, 2048], f16),), {})
+cnt: 6, ((T([10, 22, 2048], f16),), {})
+Operator: aten.repeat.default
+cnt: 6, ((T([10, 204, 204], b8, stride=(204, 0, 1)), [8, 1, 1]), {})
+cnt: 6, ((T([10, 22, 22], b8), [8, 1, 1]), {})
+cnt: 6, ((T([10, 22, 204], b8, stride=(204, 0, 1)), [8, 1, 1]), {})
+Operator: aten.sum.SymInt
+cnt: 42, ((T([220, 512], f16), [0], True), {})
+cnt: 6, ((T([220, 2048], f16), [0], True), {})
+cnt: 43, ((T([2040, 512], f16), [0], True), {})
+cnt: 6, ((T([2040, 2048], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([10, 22, 1014], f16),), {})
+cnt: 1, ((T([10, 22], i64),), {})
+Operator: aten.threshold_backward.default
+cnt: 6, ((T([10, 22, 2048], f16), T([10, 22, 2048], f16), 0), {})
+cnt: 6, ((T([10, 204, 2048], f16), T([10, 204, 2048], f16), 0), {})
+Operator: aten.triu.default
+cnt: 1, ((T([22, 22], u8), 1), {})
+Operator: aten.unbind.int
+cnt: 1, ((T([10, 21], i64),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/squeezenet1_1_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/squeezenet1_1_training.txt
new file mode 100644
index 0000000000000..4e4da308b341b
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/squeezenet1_1_training.txt
@@ -0,0 +1,90 @@
+Operator: aten.add.Tensor
+cnt: 2, ((T([32, 64, 13, 13], f16), T([32, 64, 13, 13], f16)), {})
+cnt: 2, ((T([32, 48, 13, 13], f16), T([32, 48, 13, 13], f16)), {})
+cnt: 2, ((T([32, 32, 27, 27], f16), T([32, 32, 27, 27], f16)), {})
+cnt: 2, ((T([32, 16, 55, 55], f16), T([32, 16, 55, 55], f16)), {})
+Operator: aten.cat.default
+cnt: 2, (([T([32, 64, 55, 55], f16), T([32, 64, 55, 55], f16)], 1), {})
+cnt: 2, (([T([32, 128, 27, 27], f16), T([32, 128, 27, 27], f16)], 1), {})
+cnt: 2, (([T([32, 192, 13, 13], f16), T([32, 192, 13, 13], f16)], 1), {})
+cnt: 2, (([T([32, 256, 13, 13], f16), T([32, 256, 13, 13], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([64, 3, 3, 3], f16), T([64], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 55, 55], f16), T([16, 64, 1, 1], f16), T([16], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 16, 55, 55], f16), T([64, 16, 1, 1], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 16, 55, 55], f16), T([64, 16, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 55, 55], f16), T([16, 128, 1, 1], f16), T([16], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 27, 27], f16), T([32, 128, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 32, 27, 27], f16), T([128, 32, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 32, 27, 27], f16), T([128, 32, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 27, 27], f16), T([32, 256, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 13, 13], f16), T([48, 256, 1, 1], f16), T([48], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 48, 13, 13], f16), T([192, 48, 1, 1], f16), T([192], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 48, 13, 13], f16), T([192, 48, 3, 3], f16), T([192], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 384, 13, 13], f16), T([48, 384, 1, 1], f16), T([48], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 384, 13, 13], f16), T([64, 384, 1, 1], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 64, 13, 13], f16), T([256, 64, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 64, 13, 13], f16), T([256, 64, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 13, 13], f16), T([64, 512, 1, 1], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 13, 13], f16), T([1000, 512, 1, 1], f16), T([1000], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([32, 1000, 13, 13], f16), T([32, 512, 13, 13], f16), T([1000, 512, 1, 1], f16), [1000], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 256, 13, 13], f16), T([32, 64, 13, 13], f16), T([256, 64, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 256, 13, 13], f16), T([32, 64, 13, 13], f16), T([256, 64, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 64, 13, 13], f16), T([32, 512, 13, 13], f16), T([64, 512, 1, 1], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 64, 13, 13], f16), T([32, 384, 13, 13], f16), T([64, 384, 1, 1], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 192, 13, 13], f16), T([32, 48, 13, 13], f16), T([192, 48, 3, 3], f16), [192], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 192, 13, 13], f16), T([32, 48, 13, 13], f16), T([192, 48, 1, 1], f16), [192], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 48, 13, 13], f16), T([32, 384, 13, 13], f16), T([48, 384, 1, 1], f16), [48], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 48, 13, 13], f16), T([32, 256, 13, 13], f16), T([48, 256, 1, 1], f16), [48], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 128, 27, 27], f16), T([32, 32, 27, 27], f16), T([128, 32, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 128, 27, 27], f16), T([32, 32, 27, 27], f16), T([128, 32, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 32, 27, 27], f16), T([32, 256, 27, 27], f16), T([32, 256, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 32, 27, 27], f16), T([32, 128, 27, 27], f16), T([32, 128, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 64, 55, 55], f16), T([32, 16, 55, 55], f16), T([64, 16, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 64, 55, 55], f16), T([32, 16, 55, 55], f16), T([64, 16, 1, 1], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 16, 55, 55], f16), T([32, 128, 55, 55], f16), T([16, 128, 1, 1], f16), [16], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 16, 55, 55], f16), T([32, 64, 55, 55], f16), T([16, 64, 1, 1], f16), [16], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 64, 111, 111], f16), T([32, 3, 224, 224], f16), T([64, 3, 3, 3], f16), [64], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([32, 1000, 13, 13], f16, stride=(0, 0, 0, 0)), 169), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 32000), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([32, 64, 111, 111], f16), [3, 3], [2, 2], [0, 0], [1, 1], True), {})
+cnt: 1, ((T([32, 128, 55, 55], f16), [3, 3], [2, 2], [0, 0], [1, 1], True), {})
+cnt: 1, ((T([32, 256, 27, 27], f16), [3, 3], [2, 2], [0, 0], [1, 1], True), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([32, 256, 13, 13], f16), T([32, 256, 27, 27], f16), [3, 3], [2, 2], [0, 0], [1, 1], True, T([32, 256, 13, 13], i64)), {})
+cnt: 1, ((T([32, 128, 27, 27], f16), T([32, 128, 55, 55], f16), [3, 3], [2, 2], [0, 0], [1, 1], True, T([32, 128, 27, 27], i64)), {})
+cnt: 1, ((T([32, 64, 55, 55], f16), T([32, 64, 111, 111], f16), [3, 3], [2, 2], [0, 0], [1, 1], True, T([32, 64, 55, 55], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([32, 1000, 13, 13], f16), [-1, -2], True), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([32, 64, 111, 111], f16),), {})
+cnt: 2, ((T([32, 16, 55, 55], f16),), {})
+cnt: 4, ((T([32, 64, 55, 55], f16),), {})
+cnt: 2, ((T([32, 32, 27, 27], f16),), {})
+cnt: 4, ((T([32, 128, 27, 27], f16),), {})
+cnt: 2, ((T([32, 48, 13, 13], f16),), {})
+cnt: 4, ((T([32, 192, 13, 13], f16),), {})
+cnt: 2, ((T([32, 64, 13, 13], f16),), {})
+cnt: 4, ((T([32, 256, 13, 13], f16),), {})
+cnt: 1, ((T([32, 1000, 13, 13], f16),), {})
+Operator: aten.sum.default
+cnt: 1, ((T([32, 1000], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([32, 1000, 13, 13], f16), T([32, 1000, 13, 13], f16), 0), {})
+cnt: 4, ((T([32, 256, 13, 13], f16, stride=(86528, 169, 13, 1)), T([32, 256, 13, 13], f16), 0), {})
+cnt: 2, ((T([32, 64, 13, 13], f16), T([32, 64, 13, 13], f16), 0), {})
+cnt: 4, ((T([32, 192, 13, 13], f16, stride=(64896, 169, 13, 1)), T([32, 192, 13, 13], f16), 0), {})
+cnt: 2, ((T([32, 48, 13, 13], f16), T([32, 48, 13, 13], f16), 0), {})
+cnt: 4, ((T([32, 128, 27, 27], f16, stride=(186624, 729, 27, 1)), T([32, 128, 27, 27], f16), 0), {})
+cnt: 2, ((T([32, 32, 27, 27], f16), T([32, 32, 27, 27], f16), 0), {})
+cnt: 4, ((T([32, 64, 55, 55], f16, stride=(387200, 3025, 55, 1)), T([32, 64, 55, 55], f16), 0), {})
+cnt: 2, ((T([32, 16, 55, 55], f16), T([32, 16, 55, 55], f16), 0), {})
+cnt: 1, ((T([32, 64, 111, 111], f16), T([32, 64, 111, 111], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_efficientdet_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_efficientdet_training.txt
new file mode 100644
index 0000000000000..873f036593f0e
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_efficientdet_training.txt
@@ -0,0 +1,623 @@
+Operator: aten._index_put_impl_.default
+cnt: 1, ((T([5000, 1], f32), [T([100], i64)], T([100, 1], f32, stride=(0, 0)), True, True), {})
+cnt: 1, ((T([5000, 4], f32), [T([100], i64)], T([100, 4], f32), True, True), {})
+Operator: aten._to_copy.default
+cnt: 1, ((T([5000, 4], f16),), {'dtype': f32})
+cnt: 1, ((T([5000], f16),), {'dtype': f32})
+cnt: 1, ((T([5000], i64),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 1, ((T([], i64),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 1, ((T([100, 1], i64),), {'dtype': f32})
+cnt: 1, ((T([5000], f32),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([5000, 4], f32),), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten._unsafe_view.default
+cnt: 1, ((T([1, 80, 80, 810], f16), [1, 57600, 90]), {})
+cnt: 1, ((T([1, 40, 40, 810], f16), [1, 14400, 90]), {})
+cnt: 1, ((T([1, 20, 20, 810], f16), [1, 3600, 90]), {})
+cnt: 1, ((T([1, 10, 10, 810], f16), [1, 900, 90]), {})
+cnt: 1, ((T([1, 5, 5, 810], f16), [1, 225, 90]), {})
+cnt: 1, ((T([1, 80, 80, 36], f16), [1, 57600, 4]), {})
+cnt: 1, ((T([1, 40, 40, 36], f16), [1, 14400, 4]), {})
+cnt: 1, ((T([1, 20, 20, 36], f16), [1, 3600, 4]), {})
+cnt: 1, ((T([1, 10, 10, 36], f16), [1, 900, 4]), {})
+cnt: 1, ((T([1, 5, 5, 36], f16), [1, 225, 4]), {})
+Operator: aten.add.Scalar
+cnt: 1, ((T([100, 1], i64), 1), {})
+Operator: aten.add.Tensor
+cnt: 3, ((T([1, 16, 320, 320], f16), T([1, 16, 320, 320], f16)), {})
+cnt: 4, ((T([1, 24, 160, 160], f16), T([1, 24, 160, 160], f16)), {})
+cnt: 5, ((T([1, 40, 80, 80], f16), T([1, 40, 80, 80], f16)), {})
+cnt: 6, ((T([1, 80, 40, 40], f16), T([1, 80, 40, 40], f16)), {})
+cnt: 8, ((T([1, 112, 40, 40], f16), T([1, 112, 40, 40], f16)), {})
+cnt: 8, ((T([1, 192, 20, 20], f16), T([1, 192, 20, 20], f16)), {})
+cnt: 4, ((T([1, 320, 20, 20], f16), T([1, 320, 20, 20], f16)), {})
+cnt: 76, ((T([], f16), 0.0001), {})
+cnt: 2, ((T([5000], f16, stride=(4,)), T([5000], f16, stride=(4,))), {})
+cnt: 2, ((T([5000], f32), T([5000], f16)), {})
+cnt: 2, ((T([5000], f32), T([5000], f32)), {})
+cnt: 1, ((T([], f32), T([], f32)), {})
+cnt: 1, ((T([5000, 4], f32), T([5000, 1], f32)), {})
+cnt: 2, ((T([5000], f32, stride=(4,)), T([5000], f32, stride=(4,))), {})
+cnt: 2, ((T([5000], f32, stride=(4,)), T([5000], f32)), {})
+cnt: 4, ((T([36, 88, 1, 1], f16), T([36, 88, 1, 1], f16)), {})
+cnt: 4, ((T([36], f16), T([36], f16)), {})
+cnt: 32, ((T([88, 1, 3, 3], f16), T([88, 1, 3, 3], f16)), {})
+cnt: 24, ((T([88, 88, 1, 1], f16), T([88, 88, 1, 1], f16)), {})
+cnt: 24, ((T([88], f16), T([88], f16)), {})
+cnt: 5, ((T([1, 88, 5, 5], f16), T([1, 88, 5, 5], f16)), {})
+cnt: 4, ((T([810, 88, 1, 1], f16), T([810, 88, 1, 1], f16)), {})
+cnt: 4, ((T([810], f16), T([810], f16)), {})
+cnt: 14, ((T([1, 88, 10, 10], f16), T([1, 88, 10, 10], f16)), {})
+cnt: 12, ((T([1, 88, 20, 20], f16), T([1, 88, 20, 20], f16)), {})
+cnt: 12, ((T([1, 88, 40, 40], f16), T([1, 88, 40, 40], f16)), {})
+cnt: 5, ((T([1, 88, 80, 80], f16), T([1, 88, 80, 80], f16)), {})
+cnt: 44, ((T([], f16), T([], f16)), {})
+cnt: 20, ((T([2], f16), T([2], f16)), {})
+cnt: 20, ((T([2], f16), T([2], f16, stride=(0,))), {})
+cnt: 24, ((T([3], f16), T([3], f16)), {})
+cnt: 12, ((T([3], f16), T([3], f16, stride=(0,))), {})
+cnt: 1, ((T([1, 1920, 20, 20], f16), T([1, 1920, 20, 20], f16)), {})
+cnt: 5, ((T([1, 1152, 20, 20], f16), T([1, 1152, 20, 20], f16)), {})
+cnt: 1, ((T([1, 672, 20, 20], f16), T([1, 672, 20, 20], f16)), {})
+cnt: 3, ((T([1, 672, 40, 40], f16), T([1, 672, 40, 40], f16)), {})
+cnt: 4, ((T([1, 480, 40, 40], f16), T([1, 480, 40, 40], f16)), {})
+cnt: 1, ((T([1, 240, 40, 40], f16), T([1, 240, 40, 40], f16)), {})
+cnt: 2, ((T([1, 240, 80, 80], f16), T([1, 240, 80, 80], f16)), {})
+cnt: 1, ((T([1, 144, 80, 80], f16), T([1, 144, 80, 80], f16)), {})
+cnt: 2, ((T([1, 144, 160, 160], f16), T([1, 144, 160, 160], f16)), {})
+cnt: 1, ((T([1, 96, 160, 160], f16), T([1, 96, 160, 160], f16)), {})
+cnt: 1, ((T([1, 32, 320, 320], f16), T([1, 32, 320, 320], f16)), {})
+Operator: aten.cat.default
+cnt: 1, (([T([1, 57600, 90], f16), T([1, 14400, 90], f16), T([1, 3600, 90], f16), T([1, 900, 90], f16), T([1, 225, 90], f16)], 1), {})
+cnt: 1, (([T([1, 57600, 4], f16), T([1, 14400, 4], f16), T([1, 3600, 4], f16), T([1, 900, 4], f16), T([1, 225, 4], f16)], 1), {})
+cnt: 1, (([T([2], f16), T([2], f16)],), {})
+cnt: 1, (([T([100, 4], f32), T([100, 1], f32), T([100, 1], f32)], 1), {})
+Operator: aten.clamp.default
+cnt: 1, ((T([5000, 4], f32), 0), {})
+Operator: aten.clone.default
+cnt: 1, ((T([1, 3, 640, 640], f16),), {})
+cnt: 2, ((T([1, 32, 320, 320], f16),), {})
+cnt: 1, ((T([1, 8, 1, 1], f16),), {})
+cnt: 1, ((T([1, 16, 320, 320], f16),), {})
+cnt: 2, ((T([1, 4, 1, 1], f16),), {})
+cnt: 1, ((T([1, 96, 320, 320], f16),), {})
+cnt: 1, ((T([1, 96, 160, 160], f16),), {})
+cnt: 5, ((T([1, 144, 160, 160], f16),), {})
+cnt: 3, ((T([1, 6, 1, 1], f16),), {})
+cnt: 1, ((T([1, 144, 80, 80], f16),), {})
+cnt: 5, ((T([1, 240, 80, 80], f16),), {})
+cnt: 3, ((T([1, 10, 1, 1], f16),), {})
+cnt: 1, ((T([1, 240, 40, 40], f16),), {})
+cnt: 8, ((T([1, 480, 40, 40], f16),), {})
+cnt: 4, ((T([1, 20, 1, 1], f16),), {})
+cnt: 7, ((T([1, 672, 40, 40], f16),), {})
+cnt: 4, ((T([1, 28, 1, 1], f16),), {})
+cnt: 1, ((T([1, 672, 20, 20], f16),), {})
+cnt: 10, ((T([1, 1152, 20, 20], f16),), {})
+cnt: 5, ((T([1, 48, 1, 1], f16),), {})
+cnt: 2, ((T([1, 1920, 20, 20], f16),), {})
+cnt: 1, ((T([1, 80, 1, 1], f16),), {})
+cnt: 14, ((T([1, 88, 10, 10], f16),), {})
+cnt: 14, ((T([1, 88, 20, 20], f16),), {})
+cnt: 14, ((T([1, 88, 40, 40], f16),), {})
+cnt: 10, ((T([1, 88, 80, 80], f16),), {})
+cnt: 10, ((T([1, 88, 5, 5], f16),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 1, ((T([1, 3, 640, 640], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([1, 96, 320, 320], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([1, 144, 160, 160], f16), [1, 2, 1, 2], 0.0), {})
+cnt: 1, ((T([1, 240, 80, 80], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([1, 672, 40, 40], f16), [1, 2, 1, 2], 0.0), {})
+cnt: 5, ((T([1, 88, 20, 20], f16), [0, 1, 0, 1], -inf), {})
+cnt: 5, ((T([1, 88, 10, 10], f16), [0, 1, 0, 1], -inf), {})
+cnt: 4, ((T([1, 88, 80, 80], f16), [0, 1, 0, 1], -inf), {})
+cnt: 4, ((T([1, 88, 40, 40], f16), [0, 1, 0, 1], -inf), {})
+cnt: 5, ((T([1, 88, 11, 11], f16), [0, -1, 0, -1]), {})
+cnt: 5, ((T([1, 88, 21, 21], f16), [0, -1, 0, -1]), {})
+cnt: 4, ((T([1, 88, 41, 41], f16), [0, -1, 0, -1]), {})
+cnt: 4, ((T([1, 88, 81, 81], f16), [0, -1, 0, -1]), {})
+cnt: 1, ((T([1, 672, 43, 43], f16), [-1, -2, -1, -2]), {})
+cnt: 1, ((T([1, 240, 81, 81], f16), [0, -1, 0, -1]), {})
+cnt: 1, ((T([1, 144, 163, 163], f16), [-1, -2, -1, -2]), {})
+cnt: 1, ((T([1, 96, 321, 321], f16), [0, -1, 0, -1]), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([1, 3, 641, 641], f16), T([32, 3, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 32, 320, 320], f16), T([32, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([1, 32, 1, 1], f16), T([8, 32, 1, 1], f16), T([8], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 8, 1, 1], f16), T([32, 8, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 32, 320, 320], f16), T([16, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 16, 320, 320], f16), T([16, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 16), {})
+cnt: 1, ((T([1, 16, 1, 1], f16), T([4, 16, 1, 1], f16), T([4], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 4, 1, 1], f16), T([16, 4, 1, 1], f16), T([16], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 16, 320, 320], f16), T([16, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 16, 320, 320], f16), T([96, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 96, 321, 321], f16), T([96, 1, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 96), {})
+cnt: 1, ((T([1, 96, 1, 1], f16), T([4, 96, 1, 1], f16), T([4], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 4, 1, 1], f16), T([96, 4, 1, 1], f16), T([96], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 96, 160, 160], f16), T([24, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([1, 24, 160, 160], f16), T([144, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([1, 144, 160, 160], f16), T([144, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 144), {})
+cnt: 3, ((T([1, 144, 1, 1], f16), T([6, 144, 1, 1], f16), T([6], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([1, 6, 1, 1], f16), T([144, 6, 1, 1], f16), T([144], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([1, 144, 160, 160], f16), T([24, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 144, 163, 163], f16), T([144, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 144), {})
+cnt: 1, ((T([1, 144, 80, 80], f16), T([40, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([1, 40, 80, 80], f16), T([240, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([1, 240, 80, 80], f16), T([240, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 240), {})
+cnt: 3, ((T([1, 240, 1, 1], f16), T([10, 240, 1, 1], f16), T([10], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([1, 10, 1, 1], f16), T([240, 10, 1, 1], f16), T([240], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([1, 240, 80, 80], f16), T([40, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 240, 81, 81], f16), T([240, 1, 3, 3], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([1, 240, 40, 40], f16), T([80, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([1, 80, 40, 40], f16), T([480, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([1, 480, 40, 40], f16), T([480, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 480), {})
+cnt: 4, ((T([1, 480, 1, 1], f16), T([20, 480, 1, 1], f16), T([20], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([1, 20, 1, 1], f16), T([480, 20, 1, 1], f16), T([480], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([1, 480, 40, 40], f16), T([80, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 480, 40, 40], f16), T([480, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 480), {})
+cnt: 1, ((T([1, 480, 40, 40], f16), T([112, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([1, 112, 40, 40], f16), T([672, 112, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([1, 672, 40, 40], f16), T([672, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 672), {})
+cnt: 4, ((T([1, 672, 1, 1], f16), T([28, 672, 1, 1], f16), T([28], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([1, 28, 1, 1], f16), T([672, 28, 1, 1], f16), T([672], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([1, 672, 40, 40], f16), T([112, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 672, 43, 43], f16), T([672, 1, 5, 5], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 672), {})
+cnt: 1, ((T([1, 672, 20, 20], f16), T([192, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([1, 192, 20, 20], f16), T([1152, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([1, 1152, 20, 20], f16), T([1152, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 1152), {})
+cnt: 5, ((T([1, 1152, 1, 1], f16), T([48, 1152, 1, 1], f16), T([48], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([1, 48, 1, 1], f16), T([1152, 48, 1, 1], f16), T([1152], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([1, 1152, 20, 20], f16), T([192, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 1152, 20, 20], f16), T([1152, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1152), {})
+cnt: 1, ((T([1, 1152, 20, 20], f16), T([320, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 320, 20, 20], f16), T([1920, 320, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 1920, 20, 20], f16), T([1920, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1920), {})
+cnt: 1, ((T([1, 1920, 1, 1], f16), T([80, 1920, 1, 1], f16), T([80], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 80, 1, 1], f16), T([1920, 80, 1, 1], f16), T([1920], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 1920, 20, 20], f16), T([320, 1920, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([1, 320, 20, 20], f16), T([88, 320, 1, 1], f16), T([88], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 16, ((T([1, 88, 10, 10], f16), T([88, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 88), {})
+cnt: 14, ((T([1, 88, 10, 10], f16), T([88, 88, 1, 1], f16), T([88], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 16, ((T([1, 88, 20, 20], f16), T([88, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 88), {})
+cnt: 14, ((T([1, 88, 20, 20], f16), T([88, 88, 1, 1], f16), T([88], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([1, 112, 40, 40], f16), T([88, 112, 1, 1], f16), T([88], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 16, ((T([1, 88, 40, 40], f16), T([88, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 88), {})
+cnt: 14, ((T([1, 88, 40, 40], f16), T([88, 88, 1, 1], f16), T([88], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 40, 80, 80], f16), T([88, 40, 1, 1], f16), T([88], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 12, ((T([1, 88, 80, 80], f16), T([88, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 88), {})
+cnt: 10, ((T([1, 88, 80, 80], f16), T([88, 88, 1, 1], f16), T([88], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 12, ((T([1, 88, 5, 5], f16), T([88, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 88), {})
+cnt: 10, ((T([1, 88, 5, 5], f16), T([88, 88, 1, 1], f16), T([88], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 88, 80, 80], f16), T([810, 88, 1, 1], f16), T([810], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 88, 40, 40], f16), T([810, 88, 1, 1], f16), T([810], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 88, 20, 20], f16), T([810, 88, 1, 1], f16), T([810], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 88, 10, 10], f16), T([810, 88, 1, 1], f16), T([810], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 88, 5, 5], f16), T([810, 88, 1, 1], f16), T([810], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 88, 80, 80], f16), T([36, 88, 1, 1], f16), T([36], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 88, 40, 40], f16), T([36, 88, 1, 1], f16), T([36], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 88, 20, 20], f16), T([36, 88, 1, 1], f16), T([36], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 88, 10, 10], f16), T([36, 88, 1, 1], f16), T([36], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([1, 88, 5, 5], f16), T([36, 88, 1, 1], f16), T([36], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([1, 36, 5, 5], f16, stride=(900, 1, 180, 36)), T([1, 88, 5, 5], f16), T([36, 88, 1, 1], f16), [36], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 12, ((T([1, 88, 5, 5], f16), T([1, 88, 5, 5], f16), T([88, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 88, [True, True, False]), {})
+cnt: 10, ((T([1, 88, 5, 5], f16), T([1, 88, 5, 5], f16), T([88, 88, 1, 1], f16), [88], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 36, 10, 10], f16, stride=(3600, 1, 360, 36)), T([1, 88, 10, 10], f16), T([36, 88, 1, 1], f16), [36], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 16, ((T([1, 88, 10, 10], f16), T([1, 88, 10, 10], f16), T([88, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 88, [True, True, False]), {})
+cnt: 14, ((T([1, 88, 10, 10], f16), T([1, 88, 10, 10], f16), T([88, 88, 1, 1], f16), [88], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 36, 20, 20], f16, stride=(14400, 1, 720, 36)), T([1, 88, 20, 20], f16), T([36, 88, 1, 1], f16), [36], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 16, ((T([1, 88, 20, 20], f16), T([1, 88, 20, 20], f16), T([88, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 88, [True, True, False]), {})
+cnt: 14, ((T([1, 88, 20, 20], f16), T([1, 88, 20, 20], f16), T([88, 88, 1, 1], f16), [88], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 36, 40, 40], f16, stride=(57600, 1, 1440, 36)), T([1, 88, 40, 40], f16), T([36, 88, 1, 1], f16), [36], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 16, ((T([1, 88, 40, 40], f16), T([1, 88, 40, 40], f16), T([88, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 88, [True, True, False]), {})
+cnt: 14, ((T([1, 88, 40, 40], f16), T([1, 88, 40, 40], f16), T([88, 88, 1, 1], f16), [88], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 36, 80, 80], f16, stride=(230400, 1, 2880, 36)), T([1, 88, 80, 80], f16), T([36, 88, 1, 1], f16), [36], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 12, ((T([1, 88, 80, 80], f16), T([1, 88, 80, 80], f16), T([88, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 88, [True, True, False]), {})
+cnt: 10, ((T([1, 88, 80, 80], f16), T([1, 88, 80, 80], f16), T([88, 88, 1, 1], f16), [88], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 810, 5, 5], f16, stride=(20250, 1, 4050, 810)), T([1, 88, 5, 5], f16), T([810, 88, 1, 1], f16), [810], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 810, 10, 10], f16, stride=(81000, 1, 8100, 810)), T([1, 88, 10, 10], f16), T([810, 88, 1, 1], f16), [810], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 810, 20, 20], f16, stride=(324000, 1, 16200, 810)), T([1, 88, 20, 20], f16), T([810, 88, 1, 1], f16), [810], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 810, 40, 40], f16, stride=(1296000, 1, 32400, 810)), T([1, 88, 40, 40], f16), T([810, 88, 1, 1], f16), [810], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 810, 80, 80], f16, stride=(5184000, 1, 64800, 810)), T([1, 88, 80, 80], f16), T([810, 88, 1, 1], f16), [810], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([1, 88, 20, 20], f16), T([1, 320, 20, 20], f16), T([88, 320, 1, 1], f16), [88], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([1, 88, 40, 40], f16), T([1, 112, 40, 40], f16), T([88, 112, 1, 1], f16), [88], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 88, 80, 80], f16), T([1, 40, 80, 80], f16), T([88, 40, 1, 1], f16), [88], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 320, 20, 20], f16), T([1, 1920, 20, 20], f16), T([320, 1920, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([1, 1920, 1, 1], f16), T([1, 80, 1, 1], f16), T([1920, 80, 1, 1], f16), [1920], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 80, 1, 1], f16), T([1, 1920, 1, 1], f16), T([80, 1920, 1, 1], f16), [80], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 1920, 20, 20], f16), T([1, 1920, 20, 20], f16), T([1920, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1920, [True, True, False]), {})
+cnt: 1, ((T([1, 1920, 20, 20], f16), T([1, 320, 20, 20], f16), T([1920, 320, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([1, 320, 20, 20], f16), T([1, 1152, 20, 20], f16), T([320, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([1, 1152, 1, 1], f16), T([1, 48, 1, 1], f16), T([1152, 48, 1, 1], f16), [1152], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 5, ((T([1, 48, 1, 1], f16), T([1, 1152, 1, 1], f16), T([48, 1152, 1, 1], f16), [48], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 1152, 20, 20], f16), T([1, 1152, 20, 20], f16), T([1152, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1152, [True, True, False]), {})
+cnt: 5, ((T([1, 1152, 20, 20], f16), T([1, 192, 20, 20], f16), T([1152, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([1, 192, 20, 20], f16), T([1, 1152, 20, 20], f16), T([192, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([1, 1152, 20, 20], f16), T([1, 1152, 20, 20], f16), T([1152, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 1152, [True, True, False]), {})
+cnt: 1, ((T([1, 192, 20, 20], f16), T([1, 672, 20, 20], f16), T([192, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([1, 672, 1, 1], f16), T([1, 28, 1, 1], f16), T([672, 28, 1, 1], f16), [672], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([1, 28, 1, 1], f16), T([1, 672, 1, 1], f16), T([28, 672, 1, 1], f16), [28], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 672, 20, 20], f16), T([1, 672, 43, 43], f16), T([672, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 4, ((T([1, 672, 40, 40], f16), T([1, 112, 40, 40], f16), T([672, 112, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([1, 112, 40, 40], f16), T([1, 672, 40, 40], f16), T([112, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([1, 672, 40, 40], f16), T([1, 672, 40, 40], f16), T([672, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 1, ((T([1, 112, 40, 40], f16), T([1, 480, 40, 40], f16), T([112, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([1, 480, 1, 1], f16), T([1, 20, 1, 1], f16), T([480, 20, 1, 1], f16), [480], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([1, 20, 1, 1], f16), T([1, 480, 1, 1], f16), T([20, 480, 1, 1], f16), [20], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 480, 40, 40], f16), T([1, 480, 40, 40], f16), T([480, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 4, ((T([1, 480, 40, 40], f16), T([1, 80, 40, 40], f16), T([480, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([1, 80, 40, 40], f16), T([1, 480, 40, 40], f16), T([80, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([1, 480, 40, 40], f16), T([1, 480, 40, 40], f16), T([480, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 1, ((T([1, 80, 40, 40], f16), T([1, 240, 40, 40], f16), T([80, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([1, 240, 1, 1], f16), T([1, 10, 1, 1], f16), T([240, 10, 1, 1], f16), [240], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([1, 10, 1, 1], f16), T([1, 240, 1, 1], f16), T([10, 240, 1, 1], f16), [10], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 240, 40, 40], f16), T([1, 240, 81, 81], f16), T([240, 1, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 3, ((T([1, 240, 80, 80], f16), T([1, 40, 80, 80], f16), T([240, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([1, 40, 80, 80], f16), T([1, 240, 80, 80], f16), T([40, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([1, 240, 80, 80], f16), T([1, 240, 80, 80], f16), T([240, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([1, 40, 80, 80], f16), T([1, 144, 80, 80], f16), T([40, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([1, 144, 1, 1], f16), T([1, 6, 1, 1], f16), T([144, 6, 1, 1], f16), [144], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([1, 6, 1, 1], f16), T([1, 144, 1, 1], f16), T([6, 144, 1, 1], f16), [6], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 144, 80, 80], f16), T([1, 144, 163, 163], f16), T([144, 1, 5, 5], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 144, [True, True, False]), {})
+cnt: 3, ((T([1, 144, 160, 160], f16), T([1, 24, 160, 160], f16), T([144, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([1, 24, 160, 160], f16), T([1, 144, 160, 160], f16), T([24, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([1, 144, 160, 160], f16), T([1, 144, 160, 160], f16), T([144, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 144, [True, True, False]), {})
+cnt: 1, ((T([1, 24, 160, 160], f16), T([1, 96, 160, 160], f16), T([24, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([1, 96, 1, 1], f16), T([1, 4, 1, 1], f16), T([96, 4, 1, 1], f16), [96], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 4, 1, 1], f16), T([1, 96, 1, 1], f16), T([4, 96, 1, 1], f16), [4], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 96, 160, 160], f16), T([1, 96, 321, 321], f16), T([96, 1, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 96, [True, True, False]), {})
+cnt: 1, ((T([1, 96, 320, 320], f16), T([1, 16, 320, 320], f16), T([96, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([1, 16, 320, 320], f16), T([1, 16, 320, 320], f16), T([16, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([1, 16, 1, 1], f16), T([1, 4, 1, 1], f16), T([16, 4, 1, 1], f16), [16], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 4, 1, 1], f16), T([1, 16, 1, 1], f16), T([4, 16, 1, 1], f16), [4], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 16, 320, 320], f16), T([1, 16, 320, 320], f16), T([16, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 16, [True, True, False]), {})
+cnt: 1, ((T([1, 16, 320, 320], f16), T([1, 32, 320, 320], f16), T([16, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([1, 32, 1, 1], f16), T([1, 8, 1, 1], f16), T([32, 8, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 8, 1, 1], f16), T([1, 32, 1, 1], f16), T([8, 32, 1, 1], f16), [8], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([1, 32, 320, 320], f16), T([1, 32, 320, 320], f16), T([32, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([1, 32, 320, 320], f16), T([1, 3, 641, 641], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([1, 3, 640, 640], f16), T([1, 3, 640, 640], f16)), {})
+Operator: aten.div.Scalar
+cnt: 2, ((T([5000], f16), 2), {})
+cnt: 2, ((T([5000], f32), 2.0), {})
+cnt: 1, ((T([5000, 4], f32), 2), {})
+cnt: 1, ((T([1, 1920, 20, 20], f16, stride=(1920, 1, 0, 0)), 400), {})
+cnt: 5, ((T([1, 1152, 20, 20], f16, stride=(1152, 1, 0, 0)), 400), {})
+cnt: 1, ((T([1, 672, 20, 20], f16, stride=(672, 1, 0, 0)), 400), {})
+cnt: 3, ((T([1, 672, 40, 40], f16, stride=(672, 1, 0, 0)), 1600), {})
+cnt: 4, ((T([1, 480, 40, 40], f16, stride=(480, 1, 0, 0)), 1600), {})
+cnt: 1, ((T([1, 240, 40, 40], f16, stride=(240, 1, 0, 0)), 1600), {})
+cnt: 2, ((T([1, 240, 80, 80], f16, stride=(240, 1, 0, 0)), 6400), {})
+cnt: 1, ((T([1, 144, 80, 80], f16, stride=(144, 1, 0, 0)), 6400), {})
+cnt: 2, ((T([1, 144, 160, 160], f16, stride=(144, 1, 0, 0)), 25600), {})
+cnt: 1, ((T([1, 96, 160, 160], f16, stride=(96, 1, 0, 0)), 25600), {})
+cnt: 1, ((T([1, 16, 320, 320], f16, stride=(16, 1, 0, 0)), 102400), {})
+cnt: 1, ((T([1, 32, 320, 320], f16, stride=(32, 1, 0, 0)), 102400), {})
+Operator: aten.div.Tensor
+cnt: 80, ((T([1, 88, 10, 10], f16), T([], f16)), {})
+cnt: 80, ((T([1, 88, 20, 20], f16), T([], f16)), {})
+cnt: 80, ((T([1, 88, 40, 40], f16), T([], f16)), {})
+cnt: 32, ((T([1, 88, 80, 80], f16), T([], f16)), {})
+cnt: 32, ((T([1, 88, 5, 5], f16), T([], f16)), {})
+cnt: 1, ((T([2], i32), T([], f16)), {})
+cnt: 2, ((T([], f32), 600), {})
+cnt: 2, ((T([5000], f32), T([], f64)), {})
+Operator: aten.eq.Tensor
+cnt: 1, ((T([5000, 4], f32), T([4], f16)), {})
+Operator: aten.exp.default
+cnt: 2, ((T([5000], f32, stride=(4,)),), {})
+Operator: aten.floor_divide.default
+cnt: 1, ((T([1, 5000], i64), 90), {})
+Operator: aten.gather.default
+cnt: 1, ((T([1, 76725, 4], f16), 1, T([1, 5000, 4], i64, stride=(5000, 1, 0))), {})
+cnt: 1, ((T([1, 76725, 90], f16), 1, T([1, 5000, 90], i64, stride=(5000, 1, 0))), {})
+cnt: 1, ((T([1, 5000, 90], f16), 2, T([1, 5000, 1], i64)), {})
+Operator: aten.ge.Scalar
+cnt: 1, ((T([5000, 4], f32), 0), {})
+Operator: aten.gt.Tensor
+cnt: 1, ((T([5000, 4], f32), T([4], f16)), {})
+Operator: aten.index.Tensor
+cnt: 1, ((T([76725, 4], f16, stride=(1, 76725)), [T([5000], i64)]), {})
+cnt: 1, ((T([5000, 4], f32), [T([100], i64)]), {})
+cnt: 1, ((T([5000, 1], f32), [T([100], i64)]), {})
+cnt: 1, ((T([5000, 1], i64), [T([100], i64)]), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([5000, 4], f32), T([5000, 4], b8), 0), {})
+Operator: aten.max.default
+cnt: 1, ((T([5000, 4], f32),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 5, ((T([1, 88, 21, 21], f16), [3, 3], [2, 2]), {})
+cnt: 5, ((T([1, 88, 11, 11], f16), [3, 3], [2, 2]), {})
+cnt: 4, ((T([1, 88, 81, 81], f16), [3, 3], [2, 2]), {})
+cnt: 4, ((T([1, 88, 41, 41], f16), [3, 3], [2, 2]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 5, ((T([1, 88, 5, 5], f16), T([1, 88, 11, 11], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([1, 88, 5, 5], i64)), {})
+cnt: 5, ((T([1, 88, 10, 10], f16), T([1, 88, 21, 21], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([1, 88, 10, 10], i64)), {})
+cnt: 4, ((T([1, 88, 20, 20], f16), T([1, 88, 41, 41], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([1, 88, 20, 20], i64)), {})
+cnt: 4, ((T([1, 88, 40, 40], f16), T([1, 88, 81, 81], f16), [3, 3], [2, 2], [0, 0], [1, 1], False, T([1, 88, 40, 40], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([1, 32, 320, 320], f16), [2, 3], True), {})
+cnt: 1, ((T([1, 16, 320, 320], f16), [2, 3], True), {})
+cnt: 1, ((T([1, 96, 160, 160], f16), [2, 3], True), {})
+cnt: 2, ((T([1, 144, 160, 160], f16), [2, 3], True), {})
+cnt: 1, ((T([1, 144, 80, 80], f16), [2, 3], True), {})
+cnt: 2, ((T([1, 240, 80, 80], f16), [2, 3], True), {})
+cnt: 1, ((T([1, 240, 40, 40], f16), [2, 3], True), {})
+cnt: 4, ((T([1, 480, 40, 40], f16), [2, 3], True), {})
+cnt: 3, ((T([1, 672, 40, 40], f16), [2, 3], True), {})
+cnt: 1, ((T([1, 672, 20, 20], f16), [2, 3], True), {})
+cnt: 5, ((T([1, 1152, 20, 20], f16), [2, 3], True), {})
+cnt: 1, ((T([1, 1920, 20, 20], f16), [2, 3], True), {})
+Operator: aten.minimum.default
+cnt: 1, ((T([5000, 4], f32), T([4], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([1, 32, 320, 320], f16), T([1, 32, 1, 1], f16)), {})
+cnt: 2, ((T([1, 16, 320, 320], f16), T([1, 16, 1, 1], f16)), {})
+cnt: 2, ((T([1, 96, 160, 160], f16), T([1, 96, 1, 1], f16)), {})
+cnt: 4, ((T([1, 144, 160, 160], f16), T([1, 144, 1, 1], f16)), {})
+cnt: 2, ((T([1, 144, 80, 80], f16), T([1, 144, 1, 1], f16)), {})
+cnt: 4, ((T([1, 240, 80, 80], f16), T([1, 240, 1, 1], f16)), {})
+cnt: 2, ((T([1, 240, 40, 40], f16), T([1, 240, 1, 1], f16)), {})
+cnt: 8, ((T([1, 480, 40, 40], f16), T([1, 480, 1, 1], f16)), {})
+cnt: 6, ((T([1, 672, 40, 40], f16), T([1, 672, 1, 1], f16)), {})
+cnt: 2, ((T([1, 672, 20, 20], f16), T([1, 672, 1, 1], f16)), {})
+cnt: 10, ((T([1, 1152, 20, 20], f16), T([1, 1152, 1, 1], f16)), {})
+cnt: 2, ((T([1, 1920, 20, 20], f16), T([1, 1920, 1, 1], f16)), {})
+cnt: 40, ((T([1, 88, 10, 10], f16), T([], f16)), {})
+cnt: 40, ((T([1, 88, 20, 20], f16), T([], f16)), {})
+cnt: 40, ((T([1, 88, 40, 40], f16), T([], f16)), {})
+cnt: 16, ((T([1, 88, 80, 80], f16), T([], f16)), {})
+cnt: 16, ((T([1, 88, 5, 5], f16), T([], f16)), {})
+cnt: 6, ((T([5000], f32), T([5000], f16)), {})
+cnt: 2, ((T([5000], f32, stride=(4,)), T([5000], f16)), {})
+cnt: 1, ((T([5000], f32), T([], f32)), {})
+cnt: 1, ((T([100, 4], f32), T([], f16)), {})
+cnt: 1, ((T([100, 4], f32, stride=(0, 0)), T([], f16)), {})
+cnt: 2, ((T([5000], f32), T([5000], f32)), {})
+cnt: 16, ((T([1, 88, 5, 5], f16), T([1, 88, 5, 5], f16)), {})
+cnt: 40, ((T([1, 88, 10, 10], f16), T([1, 88, 10, 10], f16)), {})
+cnt: 40, ((T([1, 88, 20, 20], f16), T([1, 88, 20, 20], f16)), {})
+cnt: 40, ((T([1, 88, 40, 40], f16), T([1, 88, 40, 40], f16)), {})
+cnt: 16, ((T([1, 88, 80, 80], f16), T([1, 88, 80, 80], f16)), {})
+cnt: 1, ((T([1, 1920, 20, 20], f16), T([1, 1920, 20, 20], f16)), {})
+cnt: 5, ((T([1, 1152, 20, 20], f16), T([1, 1152, 20, 20], f16)), {})
+cnt: 1, ((T([1, 672, 20, 20], f16), T([1, 672, 20, 20], f16)), {})
+cnt: 3, ((T([1, 672, 40, 40], f16), T([1, 672, 40, 40], f16)), {})
+cnt: 4, ((T([1, 480, 40, 40], f16), T([1, 480, 40, 40], f16)), {})
+cnt: 1, ((T([1, 240, 40, 40], f16), T([1, 240, 40, 40], f16)), {})
+cnt: 2, ((T([1, 240, 80, 80], f16), T([1, 240, 80, 80], f16)), {})
+cnt: 1, ((T([1, 144, 80, 80], f16), T([1, 144, 80, 80], f16)), {})
+cnt: 2, ((T([1, 144, 160, 160], f16), T([1, 144, 160, 160], f16)), {})
+cnt: 1, ((T([1, 96, 160, 160], f16), T([1, 96, 160, 160], f16)), {})
+cnt: 1, ((T([1, 16, 320, 320], f16), T([1, 16, 320, 320], f16)), {})
+cnt: 1, ((T([1, 32, 320, 320], f16), T([1, 32, 320, 320], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([1, 32, 320, 320], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), False, 0.1, 0.001), {})
+cnt: 3, ((T([1, 16, 320, 320], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), False, 0.1, 0.001), {})
+cnt: 1, ((T([1, 96, 320, 320], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), False, 0.1, 0.001), {})
+cnt: 1, ((T([1, 96, 160, 160], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), False, 0.1, 0.001), {})
+cnt: 3, ((T([1, 24, 160, 160], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), False, 0.1, 0.001), {})
+cnt: 5, ((T([1, 144, 160, 160], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), False, 0.1, 0.001), {})
+cnt: 1, ((T([1, 144, 80, 80], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), False, 0.1, 0.001), {})
+cnt: 3, ((T([1, 40, 80, 80], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f16), False, 0.1, 0.001), {})
+cnt: 5, ((T([1, 240, 80, 80], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), False, 0.1, 0.001), {})
+cnt: 1, ((T([1, 240, 40, 40], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), False, 0.1, 0.001), {})
+cnt: 4, ((T([1, 80, 40, 40], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f16), False, 0.1, 0.001), {})
+cnt: 8, ((T([1, 480, 40, 40], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), False, 0.1, 0.001), {})
+cnt: 4, ((T([1, 112, 40, 40], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f16), False, 0.1, 0.001), {})
+cnt: 7, ((T([1, 672, 40, 40], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), False, 0.1, 0.001), {})
+cnt: 1, ((T([1, 672, 20, 20], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), False, 0.1, 0.001), {})
+cnt: 5, ((T([1, 192, 20, 20], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), False, 0.1, 0.001), {})
+cnt: 10, ((T([1, 1152, 20, 20], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f16), False, 0.1, 0.001), {})
+cnt: 2, ((T([1, 320, 20, 20], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), False, 0.1, 0.001), {})
+cnt: 2, ((T([1, 1920, 20, 20], f16), T([1920], f16), T([1920], f16), T([1920], f16), T([1920], f16), False, 0.1, 0.001), {})
+cnt: 17, ((T([1, 88, 20, 20], f16), T([88], f16), T([88], f16), T([88], f16), T([88], f16), False, 0.01, 0.001), {})
+cnt: 14, ((T([1, 88, 10, 10], f16), T([88], f16), T([88], f16), T([88], f16), T([88], f16), False, 0.01, 0.001), {})
+cnt: 16, ((T([1, 88, 40, 40], f16), T([88], f16), T([88], f16), T([88], f16), T([88], f16), False, 0.01, 0.001), {})
+cnt: 11, ((T([1, 88, 80, 80], f16), T([88], f16), T([88], f16), T([88], f16), T([88], f16), False, 0.01, 0.001), {})
+cnt: 10, ((T([1, 88, 5, 5], f16), T([88], f16), T([88], f16), T([88], f16), T([88], f16), False, 0.01, 0.001), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 10, ((T([1, 88, 5, 5], f16), T([1, 88, 5, 5], f16), T([88], f16), T([88], f16), T([88], f16), T([88], f32), T([88], f32), False, 0.001, [True, True, True]), {})
+cnt: 14, ((T([1, 88, 10, 10], f16), T([1, 88, 10, 10], f16), T([88], f16), T([88], f16), T([88], f16), T([88], f32), T([88], f32), False, 0.001, [True, True, True]), {})
+cnt: 17, ((T([1, 88, 20, 20], f16), T([1, 88, 20, 20], f16), T([88], f16), T([88], f16), T([88], f16), T([88], f32), T([88], f32), False, 0.001, [True, True, True]), {})
+cnt: 16, ((T([1, 88, 40, 40], f16), T([1, 88, 40, 40], f16), T([88], f16), T([88], f16), T([88], f16), T([88], f32), T([88], f32), False, 0.001, [True, True, True]), {})
+cnt: 11, ((T([1, 88, 80, 80], f16), T([1, 88, 80, 80], f16), T([88], f16), T([88], f16), T([88], f16), T([88], f32), T([88], f32), False, 0.001, [True, True, True]), {})
+cnt: 2, ((T([1, 320, 20, 20], f16), T([1, 320, 20, 20], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), False, 0.001, [True, True, True]), {})
+cnt: 2, ((T([1, 1920, 20, 20], f16), T([1, 1920, 20, 20], f16), T([1920], f16), T([1920], f16), T([1920], f16), T([1920], f32), T([1920], f32), False, 0.001, [True, True, True]), {})
+cnt: 10, ((T([1, 1152, 20, 20], f16), T([1, 1152, 20, 20], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f32), T([1152], f32), False, 0.001, [True, True, True]), {})
+cnt: 5, ((T([1, 192, 20, 20], f16), T([1, 192, 20, 20], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), False, 0.001, [True, True, True]), {})
+cnt: 1, ((T([1, 672, 20, 20], f16), T([1, 672, 20, 20], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), False, 0.001, [True, True, True]), {})
+cnt: 7, ((T([1, 672, 40, 40], f16), T([1, 672, 40, 40], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), False, 0.001, [True, True, True]), {})
+cnt: 4, ((T([1, 112, 40, 40], f16), T([1, 112, 40, 40], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f32), T([112], f32), False, 0.001, [True, True, True]), {})
+cnt: 8, ((T([1, 480, 40, 40], f16), T([1, 480, 40, 40], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), False, 0.001, [True, True, True]), {})
+cnt: 4, ((T([1, 80, 40, 40], f16), T([1, 80, 40, 40], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), False, 0.001, [True, True, True]), {})
+cnt: 1, ((T([1, 240, 40, 40], f16), T([1, 240, 40, 40], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), False, 0.001, [True, True, True]), {})
+cnt: 5, ((T([1, 240, 80, 80], f16), T([1, 240, 80, 80], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), False, 0.001, [True, True, True]), {})
+cnt: 3, ((T([1, 40, 80, 80], f16), T([1, 40, 80, 80], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), False, 0.001, [True, True, True]), {})
+cnt: 1, ((T([1, 144, 80, 80], f16), T([1, 144, 80, 80], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), False, 0.001, [True, True, True]), {})
+cnt: 5, ((T([1, 144, 160, 160], f16), T([1, 144, 160, 160], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), False, 0.001, [True, True, True]), {})
+cnt: 3, ((T([1, 24, 160, 160], f16), T([1, 24, 160, 160], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), False, 0.001, [True, True, True]), {})
+cnt: 1, ((T([1, 96, 160, 160], f16), T([1, 96, 160, 160], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), False, 0.001, [True, True, True]), {})
+cnt: 1, ((T([1, 96, 320, 320], f16), T([1, 96, 320, 320], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), False, 0.001, [True, True, True]), {})
+cnt: 3, ((T([1, 16, 320, 320], f16), T([1, 16, 320, 320], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), False, 0.001, [True, True, True]), {})
+cnt: 2, ((T([1, 32, 320, 320], f16), T([1, 32, 320, 320], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), False, 0.001, [True, True, True]), {})
+Operator: aten.neg.default
+cnt: 2, ((T([5000], f32, stride=(4,)),), {})
+cnt: 8, ((T([1, 88, 5, 5], f16),), {})
+cnt: 20, ((T([1, 88, 10, 10], f16),), {})
+cnt: 20, ((T([1, 88, 20, 20], f16),), {})
+cnt: 20, ((T([1, 88, 40, 40], f16),), {})
+cnt: 8, ((T([1, 88, 80, 80], f16),), {})
+Operator: aten.new_zeros.default
+cnt: 1, ((T([100, 1], f32, stride=(0, 0)), [5000, 1]), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([100, 4], f32), [5000, 4]), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([1, 5000, 1], f16), [1, 5000, 90]), {})
+cnt: 1, ((T([1, 5000, 90], f16), [1, 76725, 90]), {})
+cnt: 1, ((T([1, 5000, 4], f16), [1, 76725, 4]), {})
+Operator: aten.relu.default
+cnt: 20, ((T([2], f16),), {})
+cnt: 12, ((T([3], f16),), {})
+Operator: aten.remainder.Scalar
+cnt: 1, ((T([1, 5000], i64), 90), {})
+Operator: aten.scatter_add_.default
+cnt: 1, ((T([1, 5000, 90], f16), 2, T([1, 5000, 1], i64), T([1, 5000, 1], f16)), {})
+cnt: 1, ((T([1, 76725, 90], f16), 1, T([1, 5000, 90], i64, stride=(5000, 1, 0)), T([1, 5000, 90], f16)), {})
+cnt: 1, ((T([1, 76725, 4], f16), 1, T([1, 5000, 4], i64, stride=(5000, 1, 0)), T([1, 5000, 4], f16)), {})
+Operator: aten.select_backward.default
+cnt: 1, ((T([5000, 4], f16), [1, 5000, 4], 0, 0), {})
+cnt: 1, ((T([5000, 1], f16), [1, 5000, 1], 0, 0), {})
+cnt: 20, ((T([], f16), [2], 0, 1), {})
+cnt: 20, ((T([], f16), [2], 0, 0), {})
+cnt: 12, ((T([], f16), [3], 0, 2), {})
+cnt: 12, ((T([], f16), [3], 0, 1), {})
+cnt: 12, ((T([], f16), [3], 0, 0), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([1, 32, 1, 1], f16),), {})
+cnt: 1, ((T([1, 16, 1, 1], f16),), {})
+cnt: 1, ((T([1, 96, 1, 1], f16),), {})
+cnt: 3, ((T([1, 144, 1, 1], f16),), {})
+cnt: 3, ((T([1, 240, 1, 1], f16),), {})
+cnt: 4, ((T([1, 480, 1, 1], f16),), {})
+cnt: 4, ((T([1, 672, 1, 1], f16),), {})
+cnt: 5, ((T([1, 1152, 1, 1], f16),), {})
+cnt: 1, ((T([1, 1920, 1, 1], f16),), {})
+cnt: 1, ((T([5000, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 1, ((T([5000, 1], f16), T([5000, 1], f16)), {})
+cnt: 1, ((T([1, 1920, 1, 1], f16), T([1, 1920, 1, 1], f16)), {})
+cnt: 5, ((T([1, 1152, 1, 1], f16), T([1, 1152, 1, 1], f16)), {})
+cnt: 4, ((T([1, 672, 1, 1], f16), T([1, 672, 1, 1], f16)), {})
+cnt: 4, ((T([1, 480, 1, 1], f16), T([1, 480, 1, 1], f16)), {})
+cnt: 3, ((T([1, 240, 1, 1], f16), T([1, 240, 1, 1], f16)), {})
+cnt: 3, ((T([1, 144, 1, 1], f16), T([1, 144, 1, 1], f16)), {})
+cnt: 1, ((T([1, 96, 1, 1], f16), T([1, 96, 1, 1], f16)), {})
+cnt: 1, ((T([1, 16, 1, 1], f16), T([1, 16, 1, 1], f16)), {})
+cnt: 1, ((T([1, 32, 1, 1], f16), T([1, 32, 1, 1], f16)), {})
+Operator: aten.silu_.default
+cnt: 2, ((T([1, 32, 320, 320], f16),), {})
+cnt: 1, ((T([1, 8, 1, 1], f16),), {})
+cnt: 1, ((T([1, 16, 320, 320], f16),), {})
+cnt: 2, ((T([1, 4, 1, 1], f16),), {})
+cnt: 1, ((T([1, 96, 320, 320], f16),), {})
+cnt: 1, ((T([1, 96, 160, 160], f16),), {})
+cnt: 5, ((T([1, 144, 160, 160], f16),), {})
+cnt: 3, ((T([1, 6, 1, 1], f16),), {})
+cnt: 1, ((T([1, 144, 80, 80], f16),), {})
+cnt: 5, ((T([1, 240, 80, 80], f16),), {})
+cnt: 3, ((T([1, 10, 1, 1], f16),), {})
+cnt: 1, ((T([1, 240, 40, 40], f16),), {})
+cnt: 8, ((T([1, 480, 40, 40], f16),), {})
+cnt: 4, ((T([1, 20, 1, 1], f16),), {})
+cnt: 7, ((T([1, 672, 40, 40], f16),), {})
+cnt: 4, ((T([1, 28, 1, 1], f16),), {})
+cnt: 1, ((T([1, 672, 20, 20], f16),), {})
+cnt: 10, ((T([1, 1152, 20, 20], f16),), {})
+cnt: 5, ((T([1, 48, 1, 1], f16),), {})
+cnt: 2, ((T([1, 1920, 20, 20], f16),), {})
+cnt: 1, ((T([1, 80, 1, 1], f16),), {})
+cnt: 14, ((T([1, 88, 10, 10], f16),), {})
+cnt: 14, ((T([1, 88, 20, 20], f16),), {})
+cnt: 14, ((T([1, 88, 40, 40], f16),), {})
+cnt: 10, ((T([1, 88, 80, 80], f16),), {})
+cnt: 10, ((T([1, 88, 5, 5], f16),), {})
+Operator: aten.silu_backward.default
+cnt: 10, ((T([1, 88, 5, 5], f16), T([1, 88, 5, 5], f16)), {})
+cnt: 14, ((T([1, 88, 10, 10], f16), T([1, 88, 10, 10], f16)), {})
+cnt: 14, ((T([1, 88, 20, 20], f16), T([1, 88, 20, 20], f16)), {})
+cnt: 14, ((T([1, 88, 40, 40], f16), T([1, 88, 40, 40], f16)), {})
+cnt: 10, ((T([1, 88, 80, 80], f16), T([1, 88, 80, 80], f16)), {})
+cnt: 1, ((T([1, 80, 1, 1], f16), T([1, 80, 1, 1], f16)), {})
+cnt: 2, ((T([1, 1920, 20, 20], f16), T([1, 1920, 20, 20], f16)), {})
+cnt: 5, ((T([1, 48, 1, 1], f16), T([1, 48, 1, 1], f16)), {})
+cnt: 10, ((T([1, 1152, 20, 20], f16), T([1, 1152, 20, 20], f16)), {})
+cnt: 4, ((T([1, 28, 1, 1], f16), T([1, 28, 1, 1], f16)), {})
+cnt: 1, ((T([1, 672, 20, 20], f16), T([1, 672, 20, 20], f16)), {})
+cnt: 7, ((T([1, 672, 40, 40], f16), T([1, 672, 40, 40], f16)), {})
+cnt: 4, ((T([1, 20, 1, 1], f16), T([1, 20, 1, 1], f16)), {})
+cnt: 8, ((T([1, 480, 40, 40], f16), T([1, 480, 40, 40], f16)), {})
+cnt: 3, ((T([1, 10, 1, 1], f16), T([1, 10, 1, 1], f16)), {})
+cnt: 1, ((T([1, 240, 40, 40], f16), T([1, 240, 40, 40], f16)), {})
+cnt: 5, ((T([1, 240, 80, 80], f16), T([1, 240, 80, 80], f16)), {})
+cnt: 3, ((T([1, 6, 1, 1], f16), T([1, 6, 1, 1], f16)), {})
+cnt: 1, ((T([1, 144, 80, 80], f16), T([1, 144, 80, 80], f16)), {})
+cnt: 5, ((T([1, 144, 160, 160], f16), T([1, 144, 160, 160], f16)), {})
+cnt: 2, ((T([1, 4, 1, 1], f16), T([1, 4, 1, 1], f16)), {})
+cnt: 1, ((T([1, 96, 160, 160], f16), T([1, 96, 160, 160], f16)), {})
+cnt: 1, ((T([1, 96, 320, 320], f16), T([1, 96, 320, 320], f16)), {})
+cnt: 1, ((T([1, 16, 320, 320], f16), T([1, 16, 320, 320], f16)), {})
+cnt: 1, ((T([1, 8, 1, 1], f16), T([1, 8, 1, 1], f16)), {})
+cnt: 2, ((T([1, 32, 320, 320], f16), T([1, 32, 320, 320], f16)), {})
+Operator: aten.stack.default
+cnt: 4, (([T([1, 88, 10, 10], f16), T([1, 88, 10, 10], f16)], -1), {})
+cnt: 4, (([T([1, 88, 20, 20], f16), T([1, 88, 20, 20], f16)], -1), {})
+cnt: 4, (([T([1, 88, 40, 40], f16), T([1, 88, 40, 40], f16)], -1), {})
+cnt: 4, (([T([1, 88, 80, 80], f16), T([1, 88, 80, 80], f16)], -1), {})
+cnt: 4, (([T([1, 88, 40, 40], f16), T([1, 88, 40, 40], f16), T([1, 88, 40, 40], f16)], -1), {})
+cnt: 4, (([T([1, 88, 20, 20], f16), T([1, 88, 20, 20], f16), T([1, 88, 20, 20], f16)], -1), {})
+cnt: 4, (([T([1, 88, 10, 10], f16), T([1, 88, 10, 10], f16), T([1, 88, 10, 10], f16)], -1), {})
+cnt: 4, (([T([1, 88, 5, 5], f16), T([1, 88, 5, 5], f16)], -1), {})
+cnt: 2, (([T([5000], f32), T([5000], f32), T([5000], f32), T([5000], f32)], 1), {})
+cnt: 1, (([T([100, 6], f32)],), {})
+Operator: aten.sub.Tensor
+cnt: 2, ((T([5000], f16, stride=(4,)), T([5000], f16, stride=(4,))), {})
+cnt: 2, ((T([5000], f32), T([5000], f32)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([1, 1920, 20, 20], f16), [2, 3], True), {})
+cnt: 5, ((T([1, 1152, 20, 20], f16), [2, 3], True), {})
+cnt: 1, ((T([1, 672, 20, 20], f16), [2, 3], True), {})
+cnt: 3, ((T([1, 672, 40, 40], f16), [2, 3], True), {})
+cnt: 4, ((T([1, 480, 40, 40], f16), [2, 3], True), {})
+cnt: 1, ((T([1, 240, 40, 40], f16), [2, 3], True), {})
+cnt: 2, ((T([1, 240, 80, 80], f16), [2, 3], True), {})
+cnt: 1, ((T([1, 144, 80, 80], f16), [2, 3], True), {})
+cnt: 2, ((T([1, 144, 160, 160], f16), [2, 3], True), {})
+cnt: 1, ((T([1, 96, 160, 160], f16), [2, 3], True), {})
+cnt: 1, ((T([1, 16, 320, 320], f16), [2, 3], True), {})
+cnt: 1, ((T([1, 32, 320, 320], f16), [2, 3], True), {})
+Operator: aten.sum.default
+cnt: 20, ((T([2], f16),), {})
+cnt: 12, ((T([3], f16),), {})
+cnt: 1, ((T([1, 100, 6], f32),), {})
+cnt: 16, ((T([1, 88, 5, 5], f16),), {})
+cnt: 40, ((T([1, 88, 10, 10], f16),), {})
+cnt: 40, ((T([1, 88, 20, 20], f16),), {})
+cnt: 40, ((T([1, 88, 40, 40], f16),), {})
+cnt: 16, ((T([1, 88, 80, 80], f16),), {})
+Operator: aten.sum.dim_IntList
+cnt: 4, ((T([1, 88, 10, 10, 2], f16), [-1]), {})
+cnt: 4, ((T([1, 88, 20, 20, 2], f16), [-1]), {})
+cnt: 4, ((T([1, 88, 40, 40, 2], f16), [-1]), {})
+cnt: 4, ((T([1, 88, 80, 80, 2], f16), [-1]), {})
+cnt: 4, ((T([1, 88, 40, 40, 3], f16), [-1]), {})
+cnt: 4, ((T([1, 88, 20, 20, 3], f16), [-1]), {})
+cnt: 4, ((T([1, 88, 10, 10, 3], f16), [-1]), {})
+cnt: 4, ((T([1, 88, 5, 5, 2], f16), [-1]), {})
+Operator: aten.threshold_backward.default
+cnt: 20, ((T([2], f16), T([2], f16), 0), {})
+cnt: 12, ((T([3], f16), T([3], f16), 0), {})
+Operator: aten.topk.default
+cnt: 1, ((T([1, 6905250], f16), 5000, 1), {})
+Operator: aten.unbind.int
+cnt: 2, ((T([5000, 4], f32), 1), {})
+cnt: 1, ((T([1, 100, 6], f32, stride=(0, 0, 0)),), {})
+cnt: 4, ((T([1, 88, 5, 5, 2], f16, stride=(2200, 25, 5, 1, 0)), -1), {})
+cnt: 4, ((T([1, 88, 10, 10, 3], f16, stride=(8800, 100, 10, 1, 0)), -1), {})
+cnt: 4, ((T([1, 88, 20, 20, 3], f16, stride=(35200, 400, 20, 1, 0)), -1), {})
+cnt: 4, ((T([1, 88, 40, 40, 3], f16, stride=(140800, 1600, 40, 1, 0)), -1), {})
+cnt: 4, ((T([1, 88, 80, 80, 2], f16, stride=(563200, 6400, 80, 1, 0)), -1), {})
+cnt: 4, ((T([1, 88, 40, 40, 2], f16, stride=(140800, 1600, 40, 1, 0)), -1), {})
+cnt: 4, ((T([1, 88, 20, 20, 2], f16, stride=(35200, 400, 20, 1, 0)), -1), {})
+cnt: 4, ((T([1, 88, 10, 10, 2], f16, stride=(8800, 100, 10, 1, 0)), -1), {})
+Operator: aten.upsample_nearest2d.vec
+cnt: 4, ((T([1, 88, 5, 5], f16), [10, 10], None), {})
+cnt: 4, ((T([1, 88, 10, 10], f16), [20, 20], None), {})
+cnt: 4, ((T([1, 88, 20, 20], f16), [40, 40], None), {})
+cnt: 4, ((T([1, 88, 40, 40], f16), [80, 80], None), {})
+Operator: aten.upsample_nearest2d_backward.vec
+cnt: 4, ((T([1, 88, 80, 80], f16), [80, 80], [1, 88, 40, 40], None), {})
+cnt: 4, ((T([1, 88, 40, 40], f16), [40, 40], [1, 88, 20, 20], None), {})
+cnt: 4, ((T([1, 88, 20, 20], f16), [20, 20], [1, 88, 10, 10], None), {})
+cnt: 4, ((T([1, 88, 10, 10], f16), [10, 10], [1, 88, 5, 5], None), {})
+Operator: aten.where.self
+cnt: 1, ((T([5000, 4], b8), T([5000, 4], f32), T([5000, 4], f32)), {})
+cnt: 1, ((T([5000, 4], b8), T([5000, 4], f32), T([], f32)), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_efficientnet_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_efficientnet_training.txt
new file mode 100644
index 0000000000000..1f004ded91be3
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_efficientnet_training.txt
@@ -0,0 +1,295 @@
+Operator: aten.add.Tensor
+cnt: 2, ((T([32, 24, 56, 56], f16), T([32, 24, 56, 56], f16)), {})
+cnt: 2, ((T([32, 40, 28, 28], f16), T([32, 40, 28, 28], f16)), {})
+cnt: 4, ((T([32, 80, 14, 14], f16), T([32, 80, 14, 14], f16)), {})
+cnt: 4, ((T([32, 112, 14, 14], f16), T([32, 112, 14, 14], f16)), {})
+cnt: 6, ((T([32, 192, 7, 7], f16), T([32, 192, 7, 7], f16)), {})
+cnt: 4, ((T([32, 1152, 7, 7], f16), T([32, 1152, 7, 7], f16)), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), T([32, 672, 7, 7], f16)), {})
+cnt: 2, ((T([32, 672, 14, 14], f16), T([32, 672, 14, 14], f16)), {})
+cnt: 3, ((T([32, 480, 14, 14], f16), T([32, 480, 14, 14], f16)), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([32, 240, 14, 14], f16)), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([32, 240, 28, 28], f16)), {})
+cnt: 1, ((T([32, 144, 28, 28], f16), T([32, 144, 28, 28], f16)), {})
+cnt: 1, ((T([32, 144, 56, 56], f16), T([32, 144, 56, 56], f16)), {})
+cnt: 1, ((T([32, 96, 56, 56], f16), T([32, 96, 56, 56], f16)), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32, 32, 112, 112], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([32, 1280], f16), T([1280, 1000], f16, stride=(1, 1280))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 224, 224], f16),), {})
+cnt: 2, ((T([32, 32, 112, 112], f16),), {})
+cnt: 1, ((T([32, 8, 1, 1], f16),), {})
+cnt: 1, ((T([32, 96, 112, 112], f16),), {})
+cnt: 1, ((T([32, 96, 56, 56], f16),), {})
+cnt: 1, ((T([32, 4, 1, 1], f16),), {})
+cnt: 3, ((T([32, 144, 56, 56], f16),), {})
+cnt: 2, ((T([32, 6, 1, 1], f16),), {})
+cnt: 1, ((T([32, 144, 28, 28], f16),), {})
+cnt: 3, ((T([32, 240, 28, 28], f16),), {})
+cnt: 2, ((T([32, 10, 1, 1], f16),), {})
+cnt: 1, ((T([32, 240, 14, 14], f16),), {})
+cnt: 6, ((T([32, 480, 14, 14], f16),), {})
+cnt: 3, ((T([32, 20, 1, 1], f16),), {})
+cnt: 5, ((T([32, 672, 14, 14], f16),), {})
+cnt: 3, ((T([32, 28, 1, 1], f16),), {})
+cnt: 1, ((T([32, 672, 7, 7], f16),), {})
+cnt: 8, ((T([32, 1152, 7, 7], f16),), {})
+cnt: 4, ((T([32, 48, 1, 1], f16),), {})
+cnt: 1, ((T([32, 1280, 7, 7], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 32), {})
+cnt: 1, ((T([32, 32, 1, 1], f16), T([8, 32, 1, 1], f16), T([8], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 8, 1, 1], f16), T([32, 8, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([16, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([96, 16, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 96, 112, 112], f16), T([96, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 96), {})
+cnt: 1, ((T([32, 96, 1, 1], f16), T([4, 96, 1, 1], f16), T([4], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 4, 1, 1], f16), T([96, 4, 1, 1], f16), T([96], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 96, 56, 56], f16), T([24, 96, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 24, 56, 56], f16), T([144, 24, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 144, 56, 56], f16), T([144, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 144), {})
+cnt: 2, ((T([32, 144, 1, 1], f16), T([6, 144, 1, 1], f16), T([6], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 6, 1, 1], f16), T([144, 6, 1, 1], f16), T([144], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 144, 56, 56], f16), T([24, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 144, 56, 56], f16), T([144, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 144), {})
+cnt: 1, ((T([32, 144, 28, 28], f16), T([40, 144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 40, 28, 28], f16), T([240, 40, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([240, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 240), {})
+cnt: 2, ((T([32, 240, 1, 1], f16), T([10, 240, 1, 1], f16), T([10], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 10, 1, 1], f16), T([240, 10, 1, 1], f16), T([240], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([40, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([240, 1, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 240), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([80, 240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 80, 14, 14], f16), T([480, 80, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 480, 14, 14], f16), T([480, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 480), {})
+cnt: 3, ((T([32, 480, 1, 1], f16), T([20, 480, 1, 1], f16), T([20], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 20, 1, 1], f16), T([480, 20, 1, 1], f16), T([480], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 480, 14, 14], f16), T([80, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 480, 14, 14], f16), T([480, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 480), {})
+cnt: 1, ((T([32, 480, 14, 14], f16), T([112, 480, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 112, 14, 14], f16), T([672, 112, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 672, 14, 14], f16), T([672, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 672), {})
+cnt: 3, ((T([32, 672, 1, 1], f16), T([28, 672, 1, 1], f16), T([28], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 28, 1, 1], f16), T([672, 28, 1, 1], f16), T([672], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 672, 14, 14], f16), T([112, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 672, 14, 14], f16), T([672, 1, 5, 5], f16), None, [2, 2], [2, 2], [1, 1], False, [0, 0], 672), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), T([192, 672, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([32, 192, 7, 7], f16), T([1152, 192, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 1152, 7, 7], f16), T([1152, 1, 5, 5], f16), None, [1, 1], [2, 2], [1, 1], False, [0, 0], 1152), {})
+cnt: 4, ((T([32, 1152, 1, 1], f16), T([48, 1152, 1, 1], f16), T([48], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([32, 48, 1, 1], f16), T([1152, 48, 1, 1], f16), T([1152], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 1152, 7, 7], f16), T([192, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1152, 7, 7], f16), T([1152, 1, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1152), {})
+cnt: 1, ((T([32, 1152, 7, 7], f16), T([320, 1152, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 320, 7, 7], f16), T([1280, 320, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([32, 1280, 7, 7], f16), T([32, 320, 7, 7], f16), T([1280, 320, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 320, 7, 7], f16), T([32, 1152, 7, 7], f16), T([320, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([32, 1152, 1, 1], f16), T([32, 48, 1, 1], f16), T([1152, 48, 1, 1], f16), [1152], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([32, 48, 1, 1], f16), T([32, 1152, 1, 1], f16), T([48, 1152, 1, 1], f16), [48], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 1152, 7, 7], f16), T([32, 1152, 7, 7], f16), T([1152, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1152, [True, True, False]), {})
+cnt: 4, ((T([32, 1152, 7, 7], f16), T([32, 192, 7, 7], f16), T([1152, 192, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 192, 7, 7], f16), T([32, 1152, 7, 7], f16), T([192, 1152, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 1152, 7, 7], f16), T([32, 1152, 7, 7], f16), T([1152, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 1152, [True, True, False]), {})
+cnt: 1, ((T([32, 192, 7, 7], f16), T([32, 672, 7, 7], f16), T([192, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 672, 1, 1], f16), T([32, 28, 1, 1], f16), T([672, 28, 1, 1], f16), [672], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 28, 1, 1], f16), T([32, 672, 1, 1], f16), T([28, 672, 1, 1], f16), [28], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), T([32, 672, 14, 14], f16), T([672, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 3, ((T([32, 672, 14, 14], f16), T([32, 112, 14, 14], f16), T([672, 112, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 112, 14, 14], f16), T([32, 672, 14, 14], f16), T([112, 672, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 672, 14, 14], f16), T([32, 672, 14, 14], f16), T([672, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 672, [True, True, False]), {})
+cnt: 1, ((T([32, 112, 14, 14], f16), T([32, 480, 14, 14], f16), T([112, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 480, 1, 1], f16), T([32, 20, 1, 1], f16), T([480, 20, 1, 1], f16), [480], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([32, 20, 1, 1], f16), T([32, 480, 1, 1], f16), T([20, 480, 1, 1], f16), [20], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 480, 14, 14], f16), T([32, 480, 14, 14], f16), T([480, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 3, ((T([32, 480, 14, 14], f16), T([32, 80, 14, 14], f16), T([480, 80, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 80, 14, 14], f16), T([32, 480, 14, 14], f16), T([80, 480, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 480, 14, 14], f16), T([32, 480, 14, 14], f16), T([480, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 480, [True, True, False]), {})
+cnt: 1, ((T([32, 80, 14, 14], f16), T([32, 240, 14, 14], f16), T([80, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 240, 1, 1], f16), T([32, 10, 1, 1], f16), T([240, 10, 1, 1], f16), [240], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 10, 1, 1], f16), T([32, 240, 1, 1], f16), T([10, 240, 1, 1], f16), [10], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([32, 240, 28, 28], f16), T([240, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 2, ((T([32, 240, 28, 28], f16), T([32, 40, 28, 28], f16), T([240, 40, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 40, 28, 28], f16), T([32, 240, 28, 28], f16), T([40, 240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([32, 240, 28, 28], f16), T([240, 1, 5, 5], f16), [0], [1, 1], [2, 2], [1, 1], False, [0, 0], 240, [True, True, False]), {})
+cnt: 1, ((T([32, 40, 28, 28], f16), T([32, 144, 28, 28], f16), T([40, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 144, 1, 1], f16), T([32, 6, 1, 1], f16), T([144, 6, 1, 1], f16), [144], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([32, 6, 1, 1], f16), T([32, 144, 1, 1], f16), T([6, 144, 1, 1], f16), [6], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 144, 28, 28], f16), T([32, 144, 56, 56], f16), T([144, 1, 5, 5], f16), [0], [2, 2], [2, 2], [1, 1], False, [0, 0], 144, [True, True, False]), {})
+cnt: 2, ((T([32, 144, 56, 56], f16), T([32, 24, 56, 56], f16), T([144, 24, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 24, 56, 56], f16), T([32, 144, 56, 56], f16), T([24, 144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 144, 56, 56], f16), T([32, 144, 56, 56], f16), T([144, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 144, [True, True, False]), {})
+cnt: 1, ((T([32, 24, 56, 56], f16), T([32, 96, 56, 56], f16), T([24, 96, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 96, 1, 1], f16), T([32, 4, 1, 1], f16), T([96, 4, 1, 1], f16), [96], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 4, 1, 1], f16), T([32, 96, 1, 1], f16), T([4, 96, 1, 1], f16), [4], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 96, 56, 56], f16), T([32, 96, 112, 112], f16), T([96, 1, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 96, [True, True, False]), {})
+cnt: 1, ((T([32, 96, 112, 112], f16), T([32, 16, 112, 112], f16), T([96, 16, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([32, 32, 112, 112], f16), T([16, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 32, 1, 1], f16), T([32, 8, 1, 1], f16), T([32, 8, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 8, 1, 1], f16), T([32, 32, 1, 1], f16), T([8, 32, 1, 1], f16), [8], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32, 32, 112, 112], f16), T([32, 1, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 32, [True, True, False]), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32, 3, 224, 224], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([32, 1280, 7, 7], f16, stride=(1280, 1, 0, 0)), 49), {})
+cnt: 4, ((T([32, 1152, 7, 7], f16, stride=(1152, 1, 0, 0)), 49), {})
+cnt: 1, ((T([32, 672, 7, 7], f16, stride=(672, 1, 0, 0)), 49), {})
+cnt: 2, ((T([32, 672, 14, 14], f16, stride=(672, 1, 0, 0)), 196), {})
+cnt: 3, ((T([32, 480, 14, 14], f16, stride=(480, 1, 0, 0)), 196), {})
+cnt: 1, ((T([32, 240, 14, 14], f16, stride=(240, 1, 0, 0)), 196), {})
+cnt: 1, ((T([32, 240, 28, 28], f16, stride=(240, 1, 0, 0)), 784), {})
+cnt: 1, ((T([32, 144, 28, 28], f16, stride=(144, 1, 0, 0)), 784), {})
+cnt: 1, ((T([32, 144, 56, 56], f16, stride=(144, 1, 0, 0)), 3136), {})
+cnt: 1, ((T([32, 96, 56, 56], f16, stride=(96, 1, 0, 0)), 3136), {})
+cnt: 1, ((T([32, 32, 112, 112], f16, stride=(32, 1, 0, 0)), 12544), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 32000), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([32, 32, 112, 112], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 96, 56, 56], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 144, 56, 56], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 144, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), [2, 3], True), {})
+cnt: 3, ((T([32, 480, 14, 14], f16), [2, 3], True), {})
+cnt: 2, ((T([32, 672, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), [2, 3], True), {})
+cnt: 4, ((T([32, 1152, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 1280, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([32, 1000], f16, stride=(0, 0)), T([1000, 1280], f16)), {})
+cnt: 1, ((T([1000, 32], f16, stride=(0, 0)), T([32, 1280], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([32, 32, 112, 112], f16), T([32, 32, 1, 1], f16)), {})
+cnt: 2, ((T([32, 96, 56, 56], f16), T([32, 96, 1, 1], f16)), {})
+cnt: 2, ((T([32, 144, 56, 56], f16), T([32, 144, 1, 1], f16)), {})
+cnt: 2, ((T([32, 144, 28, 28], f16), T([32, 144, 1, 1], f16)), {})
+cnt: 2, ((T([32, 240, 28, 28], f16), T([32, 240, 1, 1], f16)), {})
+cnt: 2, ((T([32, 240, 14, 14], f16), T([32, 240, 1, 1], f16)), {})
+cnt: 6, ((T([32, 480, 14, 14], f16), T([32, 480, 1, 1], f16)), {})
+cnt: 4, ((T([32, 672, 14, 14], f16), T([32, 672, 1, 1], f16)), {})
+cnt: 2, ((T([32, 672, 7, 7], f16), T([32, 672, 1, 1], f16)), {})
+cnt: 8, ((T([32, 1152, 7, 7], f16), T([32, 1152, 1, 1], f16)), {})
+cnt: 4, ((T([32, 1152, 7, 7], f16), T([32, 1152, 7, 7], f16)), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), T([32, 672, 7, 7], f16)), {})
+cnt: 2, ((T([32, 672, 14, 14], f16), T([32, 672, 14, 14], f16)), {})
+cnt: 3, ((T([32, 480, 14, 14], f16), T([32, 480, 14, 14], f16)), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([32, 240, 14, 14], f16)), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), T([32, 240, 28, 28], f16)), {})
+cnt: 1, ((T([32, 144, 28, 28], f16), T([32, 144, 28, 28], f16)), {})
+cnt: 1, ((T([32, 144, 56, 56], f16), T([32, 144, 56, 56], f16)), {})
+cnt: 1, ((T([32, 96, 56, 56], f16), T([32, 96, 56, 56], f16)), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32, 32, 112, 112], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([32, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 96, 112, 112], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f16), False, 0.1, 1e-05), {})
+cnt: 2, ((T([32, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 144, 56, 56], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 144, 28, 28], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f16), False, 0.1, 1e-05), {})
+cnt: 2, ((T([32, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f16), False, 0.1, 1e-05), {})
+cnt: 6, ((T([32, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 112, 14, 14], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([32, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f16), False, 0.1, 1e-05), {})
+cnt: 4, ((T([32, 192, 7, 7], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), False, 0.1, 1e-05), {})
+cnt: 8, ((T([32, 1152, 7, 7], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 320, 7, 7], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f16), False, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([32, 1280, 7, 7], f16), T([32, 1280, 7, 7], f16), T([1280], f16), T([1280], f16), T([1280], f16), T([1280], f32), T([1280], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 320, 7, 7], f16), T([32, 320, 7, 7], f16), T([320], f16), T([320], f16), T([320], f16), T([320], f32), T([320], f32), False, 1e-05, [True, True, True]), {})
+cnt: 8, ((T([32, 1152, 7, 7], f16), T([32, 1152, 7, 7], f16), T([1152], f16), T([1152], f16), T([1152], f16), T([1152], f32), T([1152], f32), False, 1e-05, [True, True, True]), {})
+cnt: 4, ((T([32, 192, 7, 7], f16), T([32, 192, 7, 7], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), T([32, 672, 7, 7], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), False, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([32, 672, 14, 14], f16), T([32, 672, 14, 14], f16), T([672], f16), T([672], f16), T([672], f16), T([672], f32), T([672], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 112, 14, 14], f16), T([32, 112, 14, 14], f16), T([112], f16), T([112], f16), T([112], f16), T([112], f32), T([112], f32), False, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([32, 480, 14, 14], f16), T([32, 480, 14, 14], f16), T([480], f16), T([480], f16), T([480], f16), T([480], f32), T([480], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 80, 14, 14], f16), T([32, 80, 14, 14], f16), T([80], f16), T([80], f16), T([80], f16), T([80], f32), T([80], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([32, 240, 14, 14], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 240, 28, 28], f16), T([32, 240, 28, 28], f16), T([240], f16), T([240], f16), T([240], f16), T([240], f32), T([240], f32), False, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 40, 28, 28], f16), T([32, 40, 28, 28], f16), T([40], f16), T([40], f16), T([40], f16), T([40], f32), T([40], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 144, 28, 28], f16), T([32, 144, 28, 28], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 144, 56, 56], f16), T([32, 144, 56, 56], f16), T([144], f16), T([144], f16), T([144], f16), T([144], f32), T([144], f32), False, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 24, 56, 56], f16), T([32, 24, 56, 56], f16), T([24], f16), T([24], f16), T([24], f16), T([24], f32), T([24], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 96, 56, 56], f16), T([32, 96, 56, 56], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 96, 112, 112], f16), T([32, 96, 112, 112], f16), T([96], f16), T([96], f16), T([96], f16), T([96], f32), T([96], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 16, 112, 112], f16), T([32, 16, 112, 112], f16), T([16], f16), T([16], f16), T([16], f16), T([16], f32), T([16], f32), False, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 32, 112, 112], f16), T([32, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([32, 32, 1, 1], f16),), {})
+cnt: 1, ((T([32, 96, 1, 1], f16),), {})
+cnt: 2, ((T([32, 144, 1, 1], f16),), {})
+cnt: 2, ((T([32, 240, 1, 1], f16),), {})
+cnt: 3, ((T([32, 480, 1, 1], f16),), {})
+cnt: 3, ((T([32, 672, 1, 1], f16),), {})
+cnt: 4, ((T([32, 1152, 1, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 4, ((T([32, 1152, 1, 1], f16), T([32, 1152, 1, 1], f16)), {})
+cnt: 3, ((T([32, 672, 1, 1], f16), T([32, 672, 1, 1], f16)), {})
+cnt: 3, ((T([32, 480, 1, 1], f16), T([32, 480, 1, 1], f16)), {})
+cnt: 2, ((T([32, 240, 1, 1], f16), T([32, 240, 1, 1], f16)), {})
+cnt: 2, ((T([32, 144, 1, 1], f16), T([32, 144, 1, 1], f16)), {})
+cnt: 1, ((T([32, 96, 1, 1], f16), T([32, 96, 1, 1], f16)), {})
+cnt: 1, ((T([32, 32, 1, 1], f16), T([32, 32, 1, 1], f16)), {})
+Operator: aten.silu_.default
+cnt: 2, ((T([32, 32, 112, 112], f16),), {})
+cnt: 1, ((T([32, 8, 1, 1], f16),), {})
+cnt: 1, ((T([32, 96, 112, 112], f16),), {})
+cnt: 1, ((T([32, 96, 56, 56], f16),), {})
+cnt: 1, ((T([32, 4, 1, 1], f16),), {})
+cnt: 3, ((T([32, 144, 56, 56], f16),), {})
+cnt: 2, ((T([32, 6, 1, 1], f16),), {})
+cnt: 1, ((T([32, 144, 28, 28], f16),), {})
+cnt: 3, ((T([32, 240, 28, 28], f16),), {})
+cnt: 2, ((T([32, 10, 1, 1], f16),), {})
+cnt: 1, ((T([32, 240, 14, 14], f16),), {})
+cnt: 6, ((T([32, 480, 14, 14], f16),), {})
+cnt: 3, ((T([32, 20, 1, 1], f16),), {})
+cnt: 5, ((T([32, 672, 14, 14], f16),), {})
+cnt: 3, ((T([32, 28, 1, 1], f16),), {})
+cnt: 1, ((T([32, 672, 7, 7], f16),), {})
+cnt: 8, ((T([32, 1152, 7, 7], f16),), {})
+cnt: 4, ((T([32, 48, 1, 1], f16),), {})
+cnt: 1, ((T([32, 1280, 7, 7], f16),), {})
+Operator: aten.silu_backward.default
+cnt: 1, ((T([32, 1280, 7, 7], f16), T([32, 1280, 7, 7], f16)), {})
+cnt: 4, ((T([32, 48, 1, 1], f16), T([32, 48, 1, 1], f16)), {})
+cnt: 8, ((T([32, 1152, 7, 7], f16), T([32, 1152, 7, 7], f16)), {})
+cnt: 3, ((T([32, 28, 1, 1], f16), T([32, 28, 1, 1], f16)), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), T([32, 672, 7, 7], f16)), {})
+cnt: 5, ((T([32, 672, 14, 14], f16), T([32, 672, 14, 14], f16)), {})
+cnt: 3, ((T([32, 20, 1, 1], f16), T([32, 20, 1, 1], f16)), {})
+cnt: 6, ((T([32, 480, 14, 14], f16), T([32, 480, 14, 14], f16)), {})
+cnt: 2, ((T([32, 10, 1, 1], f16), T([32, 10, 1, 1], f16)), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), T([32, 240, 14, 14], f16)), {})
+cnt: 3, ((T([32, 240, 28, 28], f16), T([32, 240, 28, 28], f16)), {})
+cnt: 2, ((T([32, 6, 1, 1], f16), T([32, 6, 1, 1], f16)), {})
+cnt: 1, ((T([32, 144, 28, 28], f16), T([32, 144, 28, 28], f16)), {})
+cnt: 3, ((T([32, 144, 56, 56], f16), T([32, 144, 56, 56], f16)), {})
+cnt: 1, ((T([32, 4, 1, 1], f16), T([32, 4, 1, 1], f16)), {})
+cnt: 1, ((T([32, 96, 56, 56], f16), T([32, 96, 56, 56], f16)), {})
+cnt: 1, ((T([32, 96, 112, 112], f16), T([32, 96, 112, 112], f16)), {})
+cnt: 1, ((T([32, 8, 1, 1], f16), T([32, 8, 1, 1], f16)), {})
+cnt: 2, ((T([32, 32, 112, 112], f16), T([32, 32, 112, 112], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32, 1000], f16, stride=(0, 0)), [0], True), {})
+cnt: 4, ((T([32, 1152, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 672, 7, 7], f16), [2, 3], True), {})
+cnt: 2, ((T([32, 672, 14, 14], f16), [2, 3], True), {})
+cnt: 3, ((T([32, 480, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 240, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 240, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 144, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 144, 56, 56], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 96, 56, 56], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), [2, 3], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([32, 1000], f16),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_nfnet_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_nfnet_training.txt
new file mode 100644
index 0000000000000..c94aacd7fa2c9
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_nfnet_training.txt
@@ -0,0 +1,289 @@
+Operator: aten.add.Tensor
+cnt: 3, ((T([128, 256, 48, 48], f16), T([128, 256, 48, 48], f16)), {})
+cnt: 6, ((T([128, 512, 24, 24], f16), T([128, 512, 24, 24], f16)), {})
+cnt: 18, ((T([128, 1536, 12, 12], f16), T([128, 1536, 12, 12], f16)), {})
+cnt: 8, ((T([128, 1536, 6, 6], f16), T([128, 1536, 6, 6], f16)), {})
+cnt: 1, ((T([128, 128, 48, 48], f16), T([128, 128, 48, 48], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([128, 3072], f16), T([3072, 1000], f16, stride=(1, 3072))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([128, 256, 48, 48], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+cnt: 1, ((T([128, 512, 24, 24], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+cnt: 1, ((T([128, 1536, 12, 12], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([128, 1536, 6, 6], f16), T([128, 1536, 12, 12], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+cnt: 1, ((T([128, 512, 12, 12], f16), T([128, 512, 24, 24], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+cnt: 1, ((T([128, 256, 24, 24], f16), T([128, 256, 48, 48], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+Operator: aten.clone.default
+cnt: 1, ((T([128, 3, 192, 192], f16),), {})
+cnt: 1, ((T([128, 256, 48, 48], f16),), {})
+cnt: 2, ((T([128, 512, 24, 24], f16),), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16),), {})
+cnt: 3, ((T([128, 1536, 6, 6], f16),), {})
+Operator: aten.constant_pad_nd.default
+cnt: 1, ((T([128, 3, 192, 192], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([128, 64, 96, 96], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([128, 256, 48, 48], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([128, 768, 24, 24], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([128, 768, 12, 12], f16), [0, 1, 0, 1], 0.0), {})
+cnt: 1, ((T([128, 768, 13, 13], f16), [0, -1, 0, -1]), {})
+cnt: 1, ((T([128, 768, 25, 25], f16), [0, -1, 0, -1]), {})
+cnt: 1, ((T([128, 256, 49, 49], f16), [0, -1, 0, -1]), {})
+cnt: 1, ((T([128, 64, 97, 97], f16), [0, -1, 0, -1]), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([128, 3, 193, 193], f16), T([16, 3, 3, 3], f16), T([16], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 16, 96, 96], f16), T([32, 16, 3, 3], f16), T([32], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), T([64, 32, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 64, 97, 97], f16), T([128, 64, 3, 3], f16), T([128], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 128, 48, 48], f16), T([256, 128, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 48, 48], f16), T([128, 128, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 128, 48, 48], f16), T([128, 128, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 1, 1], f16), T([128, 256, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 128, 1, 1], f16), T([256, 128, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([128, 256, 24, 24], f16), T([512, 256, 1, 1], f16), T([512], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 48, 48], f16), T([256, 256, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 256, 49, 49], f16), T([256, 128, 3, 3], f16), T([256], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 2), {})
+cnt: 3, ((T([128, 256, 24, 24], f16), T([256, 128, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 2, ((T([128, 512, 1, 1], f16), T([256, 512, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 256, 1, 1], f16), T([512, 256, 1, 1], f16), T([512], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 24, 24], f16), T([256, 512, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 12, 12], f16), T([1536, 512, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 512, 24, 24], f16), T([768, 512, 1, 1], f16), T([768], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 768, 25, 25], f16), T([768, 128, 3, 3], f16), T([768], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 6), {})
+cnt: 11, ((T([128, 768, 12, 12], f16), T([768, 128, 3, 3], f16), T([768], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 6), {})
+cnt: 6, ((T([128, 768, 12, 12], f16), T([1536, 768, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 9, ((T([128, 1536, 1, 1], f16), T([768, 1536, 1, 1], f16), T([768], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 9, ((T([128, 768, 1, 1], f16), T([1536, 768, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), T([768, 1536, 1, 1], f16), T([768], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1536, 6, 6], f16), T([1536, 1536, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 768, 13, 13], f16), T([768, 128, 3, 3], f16), T([768], f16), [2, 2], [0, 0], [1, 1], False, [0, 0], 6), {})
+cnt: 5, ((T([128, 768, 6, 6], f16), T([768, 128, 3, 3], f16), T([768], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 6), {})
+cnt: 3, ((T([128, 768, 6, 6], f16), T([1536, 768, 1, 1], f16), T([1536], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([128, 1536, 6, 6], f16), T([768, 1536, 1, 1], f16), T([768], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([128, 1536, 6, 6], f16), T([3072, 1536, 1, 1], f16), T([3072], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([128, 3072, 6, 6], f16), T([128, 1536, 6, 6], f16), T([3072, 1536, 1, 1], f16), [3072], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 9, ((T([128, 1536, 1, 1], f16), T([128, 768, 1, 1], f16), T([1536, 768, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 9, ((T([128, 768, 1, 1], f16), T([128, 1536, 1, 1], f16), T([768, 1536, 1, 1], f16), [768], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([128, 1536, 6, 6], f16), T([128, 768, 6, 6], f16), T([1536, 768, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 5, ((T([128, 768, 6, 6], f16), T([128, 768, 6, 6], f16), T([768, 128, 3, 3], f16), [768], [1, 1], [1, 1], [1, 1], False, [0, 0], 6, [True, True, True]), {})
+cnt: 2, ((T([128, 768, 6, 6], f16), T([128, 1536, 6, 6], f16), T([768, 1536, 1, 1], f16), [768], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 768, 6, 6], f16), T([128, 768, 13, 13], f16), T([768, 128, 3, 3], f16), [768], [2, 2], [0, 0], [1, 1], False, [0, 0], 6, [True, True, True]), {})
+cnt: 6, ((T([128, 768, 12, 12], f16), T([128, 1536, 12, 12], f16), T([768, 1536, 1, 1], f16), [768], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 1536, 6, 6], f16), T([128, 1536, 6, 6], f16), T([1536, 1536, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), T([128, 768, 12, 12], f16), T([1536, 768, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 11, ((T([128, 768, 12, 12], f16), T([128, 768, 12, 12], f16), T([768, 128, 3, 3], f16), [768], [1, 1], [1, 1], [1, 1], False, [0, 0], 6, [True, True, True]), {})
+cnt: 1, ((T([128, 768, 12, 12], f16), T([128, 768, 25, 25], f16), T([768, 128, 3, 3], f16), [768], [2, 2], [0, 0], [1, 1], False, [0, 0], 6, [True, True, True]), {})
+cnt: 1, ((T([128, 768, 24, 24], f16), T([128, 512, 24, 24], f16), T([768, 512, 1, 1], f16), [768], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 1536, 12, 12], f16), T([128, 512, 12, 12], f16), T([1536, 512, 1, 1], f16), [1536], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 512, 1, 1], f16), T([128, 256, 1, 1], f16), T([512, 256, 1, 1], f16), [512], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 256, 1, 1], f16), T([128, 512, 1, 1], f16), T([256, 512, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([128, 512, 24, 24], f16), T([128, 256, 24, 24], f16), T([512, 256, 1, 1], f16), [512], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([128, 256, 24, 24], f16), T([128, 256, 24, 24], f16), T([256, 128, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 2, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 24, 24], f16), T([128, 512, 24, 24], f16), T([256, 512, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 24, 24], f16), T([128, 256, 49, 49], f16), T([256, 128, 3, 3], f16), [256], [2, 2], [0, 0], [1, 1], False, [0, 0], 2, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 48, 48], f16), T([128, 256, 48, 48], f16), T([256, 256, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 256, 1, 1], f16), T([128, 128, 1, 1], f16), T([256, 128, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 1, 1], f16), T([128, 256, 1, 1], f16), T([128, 256, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), T([128, 128, 48, 48], f16), T([256, 128, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([128, 128, 48, 48], f16), T([128, 128, 48, 48], f16), T([128, 128, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 48, 48], f16), T([128, 128, 48, 48], f16), T([128, 128, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 128, 48, 48], f16), T([128, 64, 97, 97], f16), T([128, 64, 3, 3], f16), [128], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 64, 96, 96], f16), T([128, 32, 96, 96], f16), T([64, 32, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), T([128, 16, 96, 96], f16), T([32, 16, 3, 3], f16), [32], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([128, 16, 96, 96], f16), T([128, 3, 193, 193], f16), T([16, 3, 3, 3], f16), [16], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([128, 3, 192, 192], f16), T([128, 3, 192, 192], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([128, 3072, 6, 6], f16, stride=(3072, 1, 0, 0)), 36), {})
+cnt: 3, ((T([128, 1536, 6, 6], f16, stride=(1536, 1, 0, 0)), 36), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16, stride=(1536, 1, 0, 0)), 144), {})
+cnt: 2, ((T([128, 512, 24, 24], f16, stride=(512, 1, 0, 0)), 576), {})
+cnt: 1, ((T([128, 256, 48, 48], f16, stride=(256, 1, 0, 0)), 2304), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 128000), {})
+Operator: aten.gelu.default
+cnt: 1, ((T([128, 16, 96, 96], f16),), {})
+cnt: 1, ((T([128, 32, 96, 96], f16),), {})
+cnt: 1, ((T([128, 64, 96, 96], f16),), {})
+cnt: 4, ((T([128, 128, 48, 48], f16),), {})
+cnt: 2, ((T([128, 256, 48, 48], f16),), {})
+cnt: 5, ((T([128, 256, 24, 24], f16),), {})
+cnt: 2, ((T([128, 512, 24, 24], f16),), {})
+cnt: 1, ((T([128, 768, 24, 24], f16),), {})
+cnt: 18, ((T([128, 768, 12, 12], f16),), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16),), {})
+cnt: 8, ((T([128, 768, 6, 6], f16),), {})
+cnt: 2, ((T([128, 1536, 6, 6], f16),), {})
+cnt: 1, ((T([128, 3072, 6, 6], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 1, ((T([128, 3072, 6, 6], f16), T([128, 3072, 6, 6], f16)), {})
+cnt: 8, ((T([128, 768, 6, 6], f16), T([128, 768, 6, 6], f16)), {})
+cnt: 2, ((T([128, 1536, 6, 6], f16), T([128, 1536, 6, 6], f16)), {})
+cnt: 18, ((T([128, 768, 12, 12], f16), T([128, 768, 12, 12], f16)), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), T([128, 1536, 12, 12], f16)), {})
+cnt: 1, ((T([128, 768, 24, 24], f16), T([128, 768, 24, 24], f16)), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), T([128, 512, 24, 24], f16)), {})
+cnt: 5, ((T([128, 256, 24, 24], f16), T([128, 256, 24, 24], f16)), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), T([128, 256, 48, 48], f16)), {})
+cnt: 4, ((T([128, 128, 48, 48], f16), T([128, 128, 48, 48], f16)), {})
+cnt: 1, ((T([128, 64, 96, 96], f16), T([128, 64, 96, 96], f16)), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), T([128, 32, 96, 96], f16)), {})
+cnt: 1, ((T([128, 16, 96, 96], f16), T([128, 16, 96, 96], f16)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([128, 256, 48, 48], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), [2, 3], True), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), [2, 3], True), {})
+cnt: 3, ((T([128, 1536, 6, 6], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 3072, 6, 6], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([128, 1000], f16, stride=(0, 0)), T([1000, 3072], f16)), {})
+cnt: 1, ((T([1000, 128], f16, stride=(0, 0)), T([128, 3072], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 2, ((T([16, 1, 1, 1], f16), 0.19245008972987526), {})
+cnt: 2, ((T([32, 1, 1, 1], f16), 0.08333333333333333), {})
+cnt: 2, ((T([64, 1, 1, 1], f16), 0.05892556509887896), {})
+cnt: 2, ((T([128, 1, 1, 1], f16), 0.041666666666666664), {})
+cnt: 2, ((T([128, 128, 48, 48], f16), 1.0), {})
+cnt: 4, ((T([256, 1, 1, 1], f16), 0.08838834764831845), {})
+cnt: 2, ((T([128, 1, 1, 1], f16), 0.08838834764831845), {})
+cnt: 4, ((T([128, 1, 1, 1], f16), 0.02946278254943948), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), T([128, 256, 1, 1], f16)), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), 2.0), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), 0.2), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), 0.9805806756909201), {})
+cnt: 6, ((T([512, 1, 1, 1], f16), 0.0625), {})
+cnt: 2, ((T([256, 1, 1, 1], f16), 0.0625), {})
+cnt: 8, ((T([256, 1, 1, 1], f16), 0.02946278254943948), {})
+cnt: 4, ((T([128, 512, 24, 24], f16), T([128, 512, 1, 1], f16)), {})
+cnt: 4, ((T([128, 512, 24, 24], f16), 2.0), {})
+cnt: 4, ((T([128, 512, 24, 24], f16), 0.2), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), 0.9805806756909201), {})
+cnt: 2, ((T([256, 1, 1, 1], f16), 0.04419417382415922), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), 0.9622504486493761), {})
+cnt: 2, ((T([1536, 1, 1, 1], f16), 0.04419417382415922), {})
+cnt: 2, ((T([768, 1, 1, 1], f16), 0.04419417382415922), {})
+cnt: 36, ((T([768, 1, 1, 1], f16), 0.02946278254943948), {})
+cnt: 18, ((T([1536, 1, 1, 1], f16), 0.03608439182435161), {})
+cnt: 12, ((T([128, 1536, 12, 12], f16), T([128, 1536, 1, 1], f16)), {})
+cnt: 12, ((T([128, 1536, 12, 12], f16), 2.0), {})
+cnt: 12, ((T([128, 1536, 12, 12], f16), 0.2), {})
+cnt: 2, ((T([128, 1536, 12, 12], f16), 0.9805806756909201), {})
+cnt: 16, ((T([768, 1, 1, 1], f16), 0.02551551815399144), {})
+cnt: 2, ((T([128, 1536, 12, 12], f16), 0.9622504486493761), {})
+cnt: 2, ((T([128, 1536, 12, 12], f16), 0.9449111825230679), {})
+cnt: 2, ((T([128, 1536, 12, 12], f16), 0.9284766908852592), {})
+cnt: 2, ((T([128, 1536, 12, 12], f16), 0.9128709291752768), {})
+cnt: 2, ((T([128, 1536, 12, 12], f16), 0.8980265101338745), {})
+cnt: 2, ((T([1536, 1, 1, 1], f16), 0.02551551815399144), {})
+cnt: 6, ((T([128, 1536, 6, 6], f16), T([128, 1536, 1, 1], f16)), {})
+cnt: 6, ((T([128, 1536, 6, 6], f16), 2.0), {})
+cnt: 6, ((T([128, 1536, 6, 6], f16), 0.2), {})
+cnt: 2, ((T([128, 1536, 6, 6], f16), 0.9805806756909201), {})
+cnt: 2, ((T([128, 1536, 6, 6], f16), 0.9622504486493761), {})
+cnt: 2, ((T([3072, 1, 1, 1], f16), 0.02551551815399144), {})
+cnt: 1, ((T([128, 3072, 6, 6], f16), 1.7015043497085571), {})
+cnt: 6, ((T([128, 1536, 6, 6], f16), T([128, 1536, 6, 6], f16)), {})
+cnt: 3, ((T([128, 1536, 6, 6], f16), T([], f16)), {})
+cnt: 8, ((T([128, 768, 6, 6], f16), 1.7015043497085571), {})
+cnt: 2, ((T([128, 1536, 6, 6], f16), 1.7015043497085571), {})
+cnt: 18, ((T([128, 768, 12, 12], f16), 1.7015043497085571), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), 1.7015043497085571), {})
+cnt: 12, ((T([128, 1536, 12, 12], f16), T([128, 1536, 12, 12], f16)), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), T([], f16)), {})
+cnt: 1, ((T([128, 768, 24, 24], f16), 1.7015043497085571), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), 1.7015043497085571), {})
+cnt: 4, ((T([128, 512, 24, 24], f16), T([128, 512, 24, 24], f16)), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), T([], f16)), {})
+cnt: 5, ((T([128, 256, 24, 24], f16), 1.7015043497085571), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), 1.7015043497085571), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), T([128, 256, 48, 48], f16)), {})
+cnt: 1, ((T([128, 256, 48, 48], f16), T([], f16)), {})
+cnt: 4, ((T([128, 128, 48, 48], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 64, 96, 96], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 16, 96, 96], f16), 1.7015043497085571), {})
+Operator: aten.mul_.Tensor
+cnt: 1, ((T([128, 16, 96, 96], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 32, 96, 96], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 64, 96, 96], f16), 1.7015043497085571), {})
+cnt: 4, ((T([128, 128, 48, 48], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 256, 48, 48], f16), T([], f16)), {})
+cnt: 2, ((T([128, 256, 48, 48], f16), 1.7015043497085571), {})
+cnt: 5, ((T([128, 256, 24, 24], f16), 1.7015043497085571), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), T([], f16)), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 768, 24, 24], f16), 1.7015043497085571), {})
+cnt: 18, ((T([128, 768, 12, 12], f16), 1.7015043497085571), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), T([], f16)), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), 1.7015043497085571), {})
+cnt: 8, ((T([128, 768, 6, 6], f16), 1.7015043497085571), {})
+cnt: 3, ((T([128, 1536, 6, 6], f16), T([], f16)), {})
+cnt: 2, ((T([128, 1536, 6, 6], f16), 1.7015043497085571), {})
+cnt: 1, ((T([128, 3072, 6, 6], f16), 1.7015043497085571), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([1, 16, 27], f16), T([16], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 32, 144], f16), T([32], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 64, 288], f16), T([64], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 128, 576], f16), T([128], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 2, ((T([1, 256, 128], f16), T([256], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 128, 128], f16), T([128], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 2, ((T([1, 128, 1152], f16), T([128], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 3, ((T([1, 512, 256], f16), T([512], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 256, 256], f16), T([256], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 4, ((T([1, 256, 1152], f16), T([256], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 256, 512], f16), T([256], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 1536, 512], f16), T([1536], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 768, 512], f16), T([768], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 18, ((T([1, 768, 1152], f16), T([768], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 9, ((T([1, 1536, 768], f16), T([1536], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 8, ((T([1, 768, 1536], f16), T([768], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 1536, 1536], f16), T([1536], f16), None, None, None, True, 0.0, 1e-05), {})
+cnt: 1, ((T([1, 3072, 1536], f16), T([3072], f16), None, None, None, True, 0.0, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 1, ((T([1, 3072, 1536], f16), T([1, 3072, 1536], f16), T([3072], f16), None, None, T([3072], f32), T([3072], f32), True, 1e-05, [True, True, False]), {})
+cnt: 9, ((T([1, 1536, 768], f16), T([1, 1536, 768], f16), T([1536], f16), None, None, T([1536], f32), T([1536], f32), True, 1e-05, [True, True, False]), {})
+cnt: 18, ((T([1, 768, 1152], f16), T([1, 768, 1152], f16), T([768], f16), None, None, T([768], f32), T([768], f32), True, 1e-05, [True, True, False]), {})
+cnt: 8, ((T([1, 768, 1536], f16), T([1, 768, 1536], f16), T([768], f16), None, None, T([768], f32), T([768], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 1536, 1536], f16), T([1, 1536, 1536], f16), T([1536], f16), None, None, T([1536], f32), T([1536], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 768, 512], f16), T([1, 768, 512], f16), T([768], f16), None, None, T([768], f32), T([768], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 1536, 512], f16), T([1, 1536, 512], f16), T([1536], f16), None, None, T([1536], f32), T([1536], f32), True, 1e-05, [True, True, False]), {})
+cnt: 3, ((T([1, 512, 256], f16), T([1, 512, 256], f16), T([512], f16), None, None, T([512], f32), T([512], f32), True, 1e-05, [True, True, False]), {})
+cnt: 4, ((T([1, 256, 1152], f16), T([1, 256, 1152], f16), T([256], f16), None, None, T([256], f32), T([256], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 256, 512], f16), T([1, 256, 512], f16), T([256], f16), None, None, T([256], f32), T([256], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 256, 256], f16), T([1, 256, 256], f16), T([256], f16), None, None, T([256], f32), T([256], f32), True, 1e-05, [True, True, False]), {})
+cnt: 2, ((T([1, 256, 128], f16), T([1, 256, 128], f16), T([256], f16), None, None, T([256], f32), T([256], f32), True, 1e-05, [True, True, False]), {})
+cnt: 2, ((T([1, 128, 1152], f16), T([1, 128, 1152], f16), T([128], f16), None, None, T([128], f32), T([128], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 128, 128], f16), T([1, 128, 128], f16), T([128], f16), None, None, T([128], f32), T([128], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 128, 576], f16), T([1, 128, 576], f16), T([128], f16), None, None, T([128], f32), T([128], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 64, 288], f16), T([1, 64, 288], f16), T([64], f16), None, None, T([64], f32), T([64], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 32, 144], f16), T([1, 32, 144], f16), T([32], f16), None, None, T([32], f32), T([32], f32), True, 1e-05, [True, True, False]), {})
+cnt: 1, ((T([1, 16, 27], f16), T([1, 16, 27], f16), T([16], f16), None, None, T([16], f32), T([16], f32), True, 1e-05, [True, True, False]), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([128, 128, 1, 1], f16),), {})
+cnt: 2, ((T([128, 256, 1, 1], f16),), {})
+cnt: 9, ((T([128, 768, 1, 1], f16),), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([128, 256, 1, 1], f16),), {})
+cnt: 2, ((T([128, 512, 1, 1], f16),), {})
+cnt: 9, ((T([128, 1536, 1, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 9, ((T([128, 1536, 1, 1], f16), T([128, 1536, 1, 1], f16)), {})
+cnt: 2, ((T([128, 512, 1, 1], f16), T([128, 512, 1, 1], f16)), {})
+cnt: 1, ((T([128, 256, 1, 1], f16), T([128, 256, 1, 1], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([128, 1000], f16, stride=(0, 0)), [0], True), {})
+cnt: 3, ((T([128, 1536, 6, 6], f16), [2, 3], True), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16), [2, 3], True), {})
+cnt: 2, ((T([128, 512, 24, 24], f16), [2, 3], True), {})
+cnt: 1, ((T([128, 256, 48, 48], f16), [2, 3], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([128, 1000], f16),), {})
+cnt: 3, ((T([128, 1536, 6, 6], f16),), {})
+cnt: 6, ((T([128, 1536, 12, 12], f16),), {})
+cnt: 2, ((T([128, 512, 24, 24], f16),), {})
+cnt: 1, ((T([128, 256, 48, 48], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 9, ((T([128, 768, 1, 1], f16), T([128, 768, 1, 1], f16), 0), {})
+cnt: 2, ((T([128, 256, 1, 1], f16), T([128, 256, 1, 1], f16), 0), {})
+cnt: 1, ((T([128, 128, 1, 1], f16), T([128, 128, 1, 1], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_regnet_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_regnet_training.txt
new file mode 100644
index 0000000000000..e67c9e94a87a7
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_regnet_training.txt
@@ -0,0 +1,178 @@
+Operator: aten.add.Tensor
+cnt: 6, ((T([32, 224, 56, 56], f16), T([32, 224, 56, 56], f16)), {})
+cnt: 15, ((T([32, 448, 28, 28], f16), T([32, 448, 28, 28], f16)), {})
+cnt: 33, ((T([32, 896, 14, 14], f16), T([32, 896, 14, 14], f16)), {})
+cnt: 2, ((T([32, 2240, 7, 7], f16), T([32, 2240, 7, 7], f16)), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32, 32, 112, 112], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([32, 2240], f16), T([2240, 1000], f16, stride=(1, 2240))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([224, 32, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 224, 112, 112], f16), T([224, 112, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 1, ((T([32, 224, 1, 1], f16), T([8, 224, 1, 1], f16), T([8], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 8, 1, 1], f16), T([224, 8, 1, 1], f16), T([224], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([32, 224, 56, 56], f16), T([224, 224, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([224, 32, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 224, 56, 56], f16), T([224, 112, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 1, ((T([32, 224, 1, 1], f16), T([56, 224, 1, 1], f16), T([56], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 56, 1, 1], f16), T([224, 56, 1, 1], f16), T([224], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 224, 56, 56], f16), T([448, 224, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 448, 56, 56], f16), T([448, 112, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 4), {})
+cnt: 1, ((T([32, 448, 1, 1], f16), T([56, 448, 1, 1], f16), T([56], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 56, 1, 1], f16), T([448, 56, 1, 1], f16), T([448], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 9, ((T([32, 448, 28, 28], f16), T([448, 448, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 224, 56, 56], f16), T([448, 224, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([32, 448, 28, 28], f16), T([448, 112, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 4), {})
+cnt: 4, ((T([32, 448, 1, 1], f16), T([112, 448, 1, 1], f16), T([112], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([32, 112, 1, 1], f16), T([448, 112, 1, 1], f16), T([448], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 448, 28, 28], f16), T([896, 448, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 896, 28, 28], f16), T([896, 112, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 1, ((T([32, 896, 1, 1], f16), T([112, 896, 1, 1], f16), T([112], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 112, 1, 1], f16), T([896, 112, 1, 1], f16), T([896], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 21, ((T([32, 896, 14, 14], f16), T([896, 896, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 448, 28, 28], f16), T([896, 448, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 10, ((T([32, 896, 14, 14], f16), T([896, 112, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 8), {})
+cnt: 10, ((T([32, 896, 1, 1], f16), T([224, 896, 1, 1], f16), T([224], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 10, ((T([32, 224, 1, 1], f16), T([896, 224, 1, 1], f16), T([896], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 896, 14, 14], f16), T([2240, 896, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 2240, 14, 14], f16), T([2240, 112, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 20), {})
+cnt: 1, ((T([32, 2240, 1, 1], f16), T([224, 2240, 1, 1], f16), T([224], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 224, 1, 1], f16), T([2240, 224, 1, 1], f16), T([2240], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 2240, 7, 7], f16), T([2240, 2240, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 896, 14, 14], f16), T([2240, 896, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([32, 2240, 7, 7], f16), T([32, 896, 14, 14], f16), T([2240, 896, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 2240, 7, 7], f16), T([32, 2240, 7, 7], f16), T([2240, 2240, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 2240, 1, 1], f16), T([32, 224, 1, 1], f16), T([2240, 224, 1, 1], f16), [2240], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 224, 1, 1], f16), T([32, 2240, 1, 1], f16), T([224, 2240, 1, 1], f16), [224], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 2240, 7, 7], f16), T([32, 2240, 14, 14], f16), T([2240, 112, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 20, [True, True, False]), {})
+cnt: 1, ((T([32, 2240, 14, 14], f16), T([32, 896, 14, 14], f16), T([2240, 896, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 21, ((T([32, 896, 14, 14], f16), T([32, 896, 14, 14], f16), T([896, 896, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 10, ((T([32, 896, 1, 1], f16), T([32, 224, 1, 1], f16), T([896, 224, 1, 1], f16), [896], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 10, ((T([32, 224, 1, 1], f16), T([32, 896, 1, 1], f16), T([224, 896, 1, 1], f16), [224], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 10, ((T([32, 896, 14, 14], f16), T([32, 896, 14, 14], f16), T([896, 112, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 1, ((T([32, 896, 14, 14], f16), T([32, 448, 28, 28], f16), T([896, 448, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 896, 1, 1], f16), T([32, 112, 1, 1], f16), T([896, 112, 1, 1], f16), [896], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 112, 1, 1], f16), T([32, 896, 1, 1], f16), T([112, 896, 1, 1], f16), [112], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 896, 14, 14], f16), T([32, 896, 28, 28], f16), T([896, 112, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 8, [True, True, False]), {})
+cnt: 1, ((T([32, 896, 28, 28], f16), T([32, 448, 28, 28], f16), T([896, 448, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 9, ((T([32, 448, 28, 28], f16), T([32, 448, 28, 28], f16), T([448, 448, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([32, 448, 1, 1], f16), T([32, 112, 1, 1], f16), T([448, 112, 1, 1], f16), [448], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([32, 112, 1, 1], f16), T([32, 448, 1, 1], f16), T([112, 448, 1, 1], f16), [112], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 4, ((T([32, 448, 28, 28], f16), T([32, 448, 28, 28], f16), T([448, 112, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, False]), {})
+cnt: 1, ((T([32, 448, 28, 28], f16), T([32, 224, 56, 56], f16), T([448, 224, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 448, 1, 1], f16), T([32, 56, 1, 1], f16), T([448, 56, 1, 1], f16), [448], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 56, 1, 1], f16), T([32, 448, 1, 1], f16), T([56, 448, 1, 1], f16), [56], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 448, 28, 28], f16), T([32, 448, 56, 56], f16), T([448, 112, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 4, [True, True, False]), {})
+cnt: 1, ((T([32, 448, 56, 56], f16), T([32, 224, 56, 56], f16), T([448, 224, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([32, 224, 56, 56], f16), T([32, 224, 56, 56], f16), T([224, 224, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 224, 1, 1], f16), T([32, 56, 1, 1], f16), T([224, 56, 1, 1], f16), [224], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 56, 1, 1], f16), T([32, 224, 1, 1], f16), T([56, 224, 1, 1], f16), [56], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 224, 56, 56], f16), T([32, 224, 56, 56], f16), T([224, 112, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 2, [True, True, False]), {})
+cnt: 1, ((T([32, 224, 56, 56], f16), T([32, 32, 112, 112], f16), T([224, 32, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 224, 1, 1], f16), T([32, 8, 1, 1], f16), T([224, 8, 1, 1], f16), [224], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 8, 1, 1], f16), T([32, 224, 1, 1], f16), T([8, 224, 1, 1], f16), [8], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 224, 56, 56], f16), T([32, 224, 112, 112], f16), T([224, 112, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 2, [True, True, False]), {})
+cnt: 1, ((T([32, 224, 112, 112], f16), T([32, 32, 112, 112], f16), T([224, 32, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32, 3, 224, 224], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 2, ((T([32, 2240, 7, 7], f16, stride=(2240, 1, 0, 0)), 49), {})
+cnt: 11, ((T([32, 896, 14, 14], f16, stride=(896, 1, 0, 0)), 196), {})
+cnt: 5, ((T([32, 448, 28, 28], f16, stride=(448, 1, 0, 0)), 784), {})
+cnt: 2, ((T([32, 224, 56, 56], f16, stride=(224, 1, 0, 0)), 3136), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 32000), {})
+Operator: aten.mean.dim
+cnt: 2, ((T([32, 224, 56, 56], f16), [2, 3], True), {})
+cnt: 5, ((T([32, 448, 28, 28], f16), [2, 3], True), {})
+cnt: 11, ((T([32, 896, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 2240, 7, 7], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 2240, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([32, 1000], f16, stride=(0, 0)), T([1000, 2240], f16)), {})
+cnt: 1, ((T([1000, 32], f16, stride=(0, 0)), T([32, 2240], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 4, ((T([32, 224, 56, 56], f16), T([32, 224, 1, 1], f16)), {})
+cnt: 10, ((T([32, 448, 28, 28], f16), T([32, 448, 1, 1], f16)), {})
+cnt: 22, ((T([32, 896, 14, 14], f16), T([32, 896, 1, 1], f16)), {})
+cnt: 2, ((T([32, 2240, 7, 7], f16), T([32, 2240, 1, 1], f16)), {})
+cnt: 1, ((T([32, 2240, 7, 7], f16), T([32, 2240, 7, 7], f16)), {})
+cnt: 11, ((T([32, 896, 14, 14], f16), T([32, 896, 14, 14], f16)), {})
+cnt: 5, ((T([32, 448, 28, 28], f16), T([32, 448, 28, 28], f16)), {})
+cnt: 2, ((T([32, 224, 56, 56], f16), T([32, 224, 56, 56], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 224, 112, 112], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f16), False, 0.1, 1e-05), {})
+cnt: 6, ((T([32, 224, 56, 56], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 448, 56, 56], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f16), False, 0.1, 1e-05), {})
+cnt: 15, ((T([32, 448, 28, 28], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 896, 28, 28], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f16), False, 0.1, 1e-05), {})
+cnt: 33, ((T([32, 896, 14, 14], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 2240, 14, 14], f16), T([2240], f16), T([2240], f16), T([2240], f16), T([2240], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 2240, 7, 7], f16), T([2240], f16), T([2240], f16), T([2240], f16), T([2240], f16), False, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 3, ((T([32, 2240, 7, 7], f16), T([32, 2240, 7, 7], f16), T([2240], f16), T([2240], f16), T([2240], f16), T([2240], f32), T([2240], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 2240, 14, 14], f16), T([32, 2240, 14, 14], f16), T([2240], f16), T([2240], f16), T([2240], f16), T([2240], f32), T([2240], f32), False, 1e-05, [True, True, True]), {})
+cnt: 33, ((T([32, 896, 14, 14], f16), T([32, 896, 14, 14], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f32), T([896], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 896, 28, 28], f16), T([32, 896, 28, 28], f16), T([896], f16), T([896], f16), T([896], f16), T([896], f32), T([896], f32), False, 1e-05, [True, True, True]), {})
+cnt: 15, ((T([32, 448, 28, 28], f16), T([32, 448, 28, 28], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f32), T([448], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 448, 56, 56], f16), T([32, 448, 56, 56], f16), T([448], f16), T([448], f16), T([448], f16), T([448], f32), T([448], f32), False, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([32, 224, 56, 56], f16), T([32, 224, 56, 56], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f32), T([224], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 224, 112, 112], f16), T([32, 224, 112, 112], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f32), T([224], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.relu.default
+cnt: 2, ((T([32, 224, 56, 56], f16),), {})
+cnt: 5, ((T([32, 448, 28, 28], f16),), {})
+cnt: 11, ((T([32, 896, 14, 14], f16),), {})
+cnt: 1, ((T([32, 2240, 7, 7], f16),), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([32, 32, 112, 112], f16),), {})
+cnt: 1, ((T([32, 224, 112, 112], f16),), {})
+cnt: 3, ((T([32, 224, 56, 56], f16),), {})
+cnt: 1, ((T([32, 8, 1, 1], f16),), {})
+cnt: 2, ((T([32, 56, 1, 1], f16),), {})
+cnt: 1, ((T([32, 448, 56, 56], f16),), {})
+cnt: 9, ((T([32, 448, 28, 28], f16),), {})
+cnt: 5, ((T([32, 112, 1, 1], f16),), {})
+cnt: 1, ((T([32, 896, 28, 28], f16),), {})
+cnt: 21, ((T([32, 896, 14, 14], f16),), {})
+cnt: 11, ((T([32, 224, 1, 1], f16),), {})
+cnt: 1, ((T([32, 2240, 14, 14], f16),), {})
+cnt: 1, ((T([32, 2240, 7, 7], f16),), {})
+Operator: aten.sigmoid.default
+cnt: 2, ((T([32, 224, 1, 1], f16),), {})
+cnt: 5, ((T([32, 448, 1, 1], f16),), {})
+cnt: 11, ((T([32, 896, 1, 1], f16),), {})
+cnt: 1, ((T([32, 2240, 1, 1], f16),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 1, ((T([32, 2240, 1, 1], f16), T([32, 2240, 1, 1], f16)), {})
+cnt: 11, ((T([32, 896, 1, 1], f16), T([32, 896, 1, 1], f16)), {})
+cnt: 5, ((T([32, 448, 1, 1], f16), T([32, 448, 1, 1], f16)), {})
+cnt: 2, ((T([32, 224, 1, 1], f16), T([32, 224, 1, 1], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32, 1000], f16, stride=(0, 0)), [0], True), {})
+cnt: 1, ((T([32, 2240, 7, 7], f16), [2, 3], True), {})
+cnt: 11, ((T([32, 896, 14, 14], f16), [2, 3], True), {})
+cnt: 5, ((T([32, 448, 28, 28], f16), [2, 3], True), {})
+cnt: 2, ((T([32, 224, 56, 56], f16), [2, 3], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([32, 1000], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 2, ((T([32, 2240, 7, 7], f16), T([32, 2240, 7, 7], f16), 0), {})
+cnt: 11, ((T([32, 224, 1, 1], f16), T([32, 224, 1, 1], f16), 0), {})
+cnt: 1, ((T([32, 2240, 14, 14], f16), T([32, 2240, 14, 14], f16), 0), {})
+cnt: 32, ((T([32, 896, 14, 14], f16), T([32, 896, 14, 14], f16), 0), {})
+cnt: 5, ((T([32, 112, 1, 1], f16), T([32, 112, 1, 1], f16), 0), {})
+cnt: 1, ((T([32, 896, 28, 28], f16), T([32, 896, 28, 28], f16), 0), {})
+cnt: 14, ((T([32, 448, 28, 28], f16), T([32, 448, 28, 28], f16), 0), {})
+cnt: 2, ((T([32, 56, 1, 1], f16), T([32, 56, 1, 1], f16), 0), {})
+cnt: 1, ((T([32, 448, 56, 56], f16), T([32, 448, 56, 56], f16), 0), {})
+cnt: 5, ((T([32, 224, 56, 56], f16), T([32, 224, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 8, 1, 1], f16), T([32, 8, 1, 1], f16), 0), {})
+cnt: 1, ((T([32, 224, 112, 112], f16), T([32, 224, 112, 112], f16), 0), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32, 32, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_resnest_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_resnest_training.txt
new file mode 100644
index 0000000000000..31d5de6bf2879
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_resnest_training.txt
@@ -0,0 +1,205 @@
+Operator: aten._softmax.default
+cnt: 1, ((T([32, 2, 1, 64], f16), 1, False), {})
+cnt: 1, ((T([32, 2, 1, 128], f16), 1, False), {})
+cnt: 1, ((T([32, 2, 1, 256], f16), 1, False), {})
+cnt: 1, ((T([32, 2, 1, 512], f16), 1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 1, ((T([32, 2, 1, 512], f16), T([32, 2, 1, 512], f16), 1, f16), {})
+cnt: 1, ((T([32, 2, 1, 256], f16), T([32, 2, 1, 256], f16), 1, f16), {})
+cnt: 1, ((T([32, 2, 1, 128], f16), T([32, 2, 1, 128], f16), 1, f16), {})
+cnt: 1, ((T([32, 2, 1, 64], f16), T([32, 2, 1, 64], f16), 1, f16), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([32, 2, 512, 14, 14], f16), T([32, 2, 512, 14, 14], f16, stride=(100352, 0, 196, 14, 1))), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16)), {})
+cnt: 1, ((T([32, 2, 256, 28, 28], f16), T([32, 2, 256, 28, 28], f16, stride=(200704, 0, 784, 28, 1))), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16)), {})
+cnt: 1, ((T([32, 2, 128, 56, 56], f16), T([32, 2, 128, 56, 56], f16, stride=(401408, 0, 3136, 56, 1))), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16)), {})
+cnt: 1, ((T([32, 2, 64, 56, 56], f16), T([32, 2, 64, 56, 56], f16, stride=(200704, 0, 3136, 56, 1))), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 64, 56, 56], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 1, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16)), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16)), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16)), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([32, 2048], f16), T([2048, 1000], f16, stride=(1, 2048))), {})
+Operator: aten.avg_pool2d.default
+cnt: 1, ((T([32, 128, 56, 56], f16), [3, 3], [2, 2], [1, 1]), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), [3, 3], [2, 2], [1, 1]), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), [3, 3], [2, 2], [1, 1]), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), [2, 2], [2, 2], [0, 0], True, False), {})
+Operator: aten.avg_pool2d_backward.default
+cnt: 1, ((T([32, 1024, 7, 7], f16), T([32, 1024, 14, 14], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+cnt: 1, ((T([32, 512, 7, 7], f16), T([32, 512, 14, 14], f16), [3, 3], [2, 2], [1, 1], False, True, None), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), T([32, 512, 28, 28], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+cnt: 1, ((T([32, 256, 14, 14], f16), T([32, 256, 28, 28], f16), [3, 3], [2, 2], [1, 1], False, True, None), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), T([32, 256, 56, 56], f16), [2, 2], [2, 2], [0, 0], True, False, None), {})
+cnt: 1, ((T([32, 128, 28, 28], f16), T([32, 128, 56, 56], f16), [3, 3], [2, 2], [1, 1], False, True, None), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([64, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([128, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 1, ((T([32, 64, 1, 1], f16), T([32, 64, 1, 1], f16), T([32], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 32, 1, 1], f16), T([128, 32, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([32, 64, 56, 56], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([256, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 1, ((T([32, 128, 1, 1], f16), T([64, 128, 1, 1], f16), T([64], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 1, 1], f16), T([256, 64, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 28, 28], f16), T([512, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), T([512, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), T([512, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 1, ((T([32, 256, 1, 1], f16), T([128, 256, 1, 1], f16), T([128], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 128, 1, 1], f16), T([512, 128, 1, 1], f16), T([512], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 14, 14], f16), T([1024, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), T([1024, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), T([1024, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 2), {})
+cnt: 1, ((T([32, 512, 1, 1], f16), T([256, 512, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 1, 1], f16), T([1024, 256, 1, 1], f16), T([1024], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 7, 7], f16), T([2048, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1024, 7, 7], f16), T([2048, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([32, 2048, 7, 7], f16), T([32, 1024, 7, 7], f16), T([2048, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16), T([32, 512, 7, 7], f16), T([2048, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 1, 1], f16), T([32, 256, 1, 1], f16), T([1024, 256, 1, 1], f16), [1024], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 256, 1, 1], f16), T([32, 512, 1, 1], f16), T([256, 512, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 512, 14, 14], f16), T([1024, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 2, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), T([32, 1024, 14, 14], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 512, 14, 14], f16), T([1024, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 14, 14], f16), T([32, 256, 14, 14], f16), T([1024, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 1, 1], f16), T([32, 128, 1, 1], f16), T([512, 128, 1, 1], f16), [512], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 128, 1, 1], f16), T([32, 256, 1, 1], f16), T([128, 256, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 256, 28, 28], f16), T([512, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 2, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), T([32, 512, 28, 28], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 256, 28, 28], f16), T([512, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 128, 28, 28], f16), T([512, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 1, 1], f16), T([32, 64, 1, 1], f16), T([256, 64, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 64, 1, 1], f16), T([32, 128, 1, 1], f16), T([64, 128, 1, 1], f16), [64], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([32, 128, 56, 56], f16), T([256, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 2, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([32, 256, 56, 56], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([32, 256, 56, 56], f16), T([32, 64, 56, 56], f16), T([256, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 1, 1], f16), T([32, 32, 1, 1], f16), T([128, 32, 1, 1], f16), [128], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 32, 1, 1], f16), T([32, 64, 1, 1], f16), T([32, 64, 1, 1], f16), [32], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([32, 64, 56, 56], f16), T([128, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 2, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 64, 56, 56], f16), T([64, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 32, 112, 112], f16), T([64, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32, 32, 112, 112], f16), T([32, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 32, 112, 112], f16), T([32, 3, 224, 224], f16), T([32, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([32, 2048, 7, 7], f16, stride=(2048, 1, 0, 0)), 49), {})
+cnt: 1, ((T([32, 512, 14, 14], f16, stride=(512, 1, 0, 0)), 196), {})
+cnt: 1, ((T([32, 256, 28, 28], f16, stride=(256, 1, 0, 0)), 784), {})
+cnt: 1, ((T([32, 128, 56, 56], f16, stride=(128, 1, 0, 0)), 3136), {})
+cnt: 1, ((T([32, 64, 56, 56], f16, stride=(64, 1, 0, 0)), 3136), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 32000), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([32, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 64, 112, 112], f16), [3, 3], [2, 2], [1, 1], [1, 1], False, T([32, 64, 56, 56], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([32, 64, 56, 56], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), [2, 3], True), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([32, 1000], f16, stride=(0, 0)), T([1000, 2048], f16)), {})
+cnt: 1, ((T([1000, 32], f16, stride=(0, 0)), T([32, 2048], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([32, 2, 64, 56, 56], f16), T([32, 2, 64, 1, 1], f16)), {})
+cnt: 1, ((T([32, 2, 128, 56, 56], f16), T([32, 2, 128, 1, 1], f16)), {})
+cnt: 1, ((T([32, 2, 256, 28, 28], f16), T([32, 2, 256, 1, 1], f16)), {})
+cnt: 1, ((T([32, 2, 512, 14, 14], f16), T([32, 2, 512, 1, 1], f16)), {})
+cnt: 1, ((T([32, 2, 512, 14, 14], f16, stride=(100352, 0, 196, 14, 1)), T([32, 2, 512, 14, 14], f16)), {})
+cnt: 1, ((T([32, 2, 512, 14, 14], f16, stride=(100352, 0, 196, 14, 1)), T([32, 2, 512, 1, 1], f16)), {})
+cnt: 1, ((T([32, 2, 256, 28, 28], f16, stride=(200704, 0, 784, 28, 1)), T([32, 2, 256, 28, 28], f16)), {})
+cnt: 1, ((T([32, 2, 256, 28, 28], f16, stride=(200704, 0, 784, 28, 1)), T([32, 2, 256, 1, 1], f16)), {})
+cnt: 1, ((T([32, 2, 128, 56, 56], f16, stride=(401408, 0, 3136, 56, 1)), T([32, 2, 128, 56, 56], f16)), {})
+cnt: 1, ((T([32, 2, 128, 56, 56], f16, stride=(401408, 0, 3136, 56, 1)), T([32, 2, 128, 1, 1], f16)), {})
+cnt: 1, ((T([32, 2, 64, 56, 56], f16, stride=(200704, 0, 3136, 56, 1)), T([32, 2, 64, 56, 56], f16)), {})
+cnt: 1, ((T([32, 2, 64, 56, 56], f16, stride=(200704, 0, 3136, 56, 1)), T([32, 2, 64, 1, 1], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([32, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 2, ((T([32, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 32, 1, 1], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 64, 1, 1], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 128, 1, 1], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 3, ((T([32, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 256, 1, 1], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 2, ((T([32, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f16), False, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 2, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16), T([2048], f16), T([2048], f16), T([2048], f16), T([2048], f32), T([2048], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 256, 1, 1], f16), T([32, 256, 1, 1], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), T([32, 512, 14, 14], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 128, 1, 1], f16), T([32, 128, 1, 1], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), T([32, 256, 28, 28], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 64, 1, 1], f16), T([32, 64, 1, 1], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+cnt: 3, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 32, 1, 1], f16), T([32, 32, 1, 1], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 64, 56, 56], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 32, 112, 112], f16), T([32, 32, 112, 112], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([32, 32, 112, 112], f16),), {})
+cnt: 1, ((T([32, 64, 112, 112], f16),), {})
+cnt: 1, ((T([32, 64, 56, 56], f16),), {})
+cnt: 2, ((T([32, 128, 56, 56], f16),), {})
+cnt: 1, ((T([32, 32, 1, 1], f16),), {})
+cnt: 2, ((T([32, 256, 56, 56], f16),), {})
+cnt: 1, ((T([32, 64, 1, 1], f16),), {})
+cnt: 2, ((T([32, 512, 28, 28], f16),), {})
+cnt: 1, ((T([32, 256, 28, 28], f16),), {})
+cnt: 1, ((T([32, 128, 1, 1], f16),), {})
+cnt: 2, ((T([32, 1024, 14, 14], f16),), {})
+cnt: 1, ((T([32, 512, 14, 14], f16),), {})
+cnt: 1, ((T([32, 256, 1, 1], f16),), {})
+cnt: 1, ((T([32, 2048, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32, 1000], f16, stride=(0, 0)), [0], True), {})
+cnt: 1, ((T([32, 2, 512, 14, 14], f16), [3, 4], True), {})
+cnt: 1, ((T([32, 2, 256, 28, 28], f16), [3, 4], True), {})
+cnt: 1, ((T([32, 2, 128, 56, 56], f16), [3, 4], True), {})
+cnt: 1, ((T([32, 2, 64, 56, 56], f16), [3, 4], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([32, 1000], f16),), {})
+Operator: aten.sum.dim_IntList
+cnt: 2, ((T([32, 2, 64, 56, 56], f16), [1]), {})
+cnt: 2, ((T([32, 2, 128, 56, 56], f16), [1]), {})
+cnt: 2, ((T([32, 2, 256, 28, 28], f16), [1]), {})
+cnt: 2, ((T([32, 2, 512, 14, 14], f16), [1]), {})
+Operator: aten.threshold_backward.default
+cnt: 1, ((T([32, 2048, 7, 7], f16), T([32, 2048, 7, 7], f16), 0), {})
+cnt: 1, ((T([32, 256, 1, 1], f16), T([32, 256, 1, 1], f16), 0), {})
+cnt: 2, ((T([32, 1024, 14, 14], f16), T([32, 1024, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), T([32, 512, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 128, 1, 1], f16), T([32, 128, 1, 1], f16), 0), {})
+cnt: 2, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), T([32, 256, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 64, 1, 1], f16), T([32, 64, 1, 1], f16), 0), {})
+cnt: 2, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16), 0), {})
+cnt: 2, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 32, 1, 1], f16), T([32, 32, 1, 1], f16), 0), {})
+cnt: 1, ((T([32, 64, 56, 56], f16), T([32, 64, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), 0), {})
+cnt: 2, ((T([32, 32, 112, 112], f16), T([32, 32, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_vision_transformer_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_vision_transformer_training.txt
new file mode 100644
index 0000000000000..ed9e7bf694f66
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_vision_transformer_training.txt
@@ -0,0 +1,77 @@
+Operator: aten._softmax.default
+cnt: 12, ((T([8, 6, 197, 197], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 12, ((T([8, 6, 197, 197], f16), T([8, 6, 197, 197], f16), -1, f16), {})
+Operator: aten._unsafe_view.default
+cnt: 36, ((T([8, 6, 197, 64], f16), [48, 197, 64]), {})
+cnt: 12, ((T([8, 6, 64, 197], f16), [48, 64, 197]), {})
+cnt: 12, ((T([48, 197, 197], f16), [8, 6, 197, 197]), {})
+cnt: 12, ((T([48, 197, 64], f16), [8, 6, 197, 64]), {})
+cnt: 12, ((T([8, 197, 6, 64], f16), [8, 197, 384]), {})
+cnt: 12, ((T([8, 197, 3, 6, 64], f16), [8, 197, 1152]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([8, 197, 384], f16), T([1, 197, 384], f16)), {})
+cnt: 48, ((T([8, 197, 384], f16), T([8, 197, 384], f16)), {})
+Operator: aten.addmm.default
+cnt: 12, ((T([1152], f16), T([1576, 384], f16), T([384, 1152], f16, stride=(1, 384))), {})
+cnt: 12, ((T([384], f16), T([1576, 384], f16), T([384, 384], f16, stride=(1, 384))), {})
+cnt: 12, ((T([1536], f16), T([1576, 384], f16), T([384, 1536], f16, stride=(1, 384))), {})
+cnt: 12, ((T([384], f16), T([1576, 1536], f16), T([1536, 384], f16, stride=(1, 1536))), {})
+cnt: 1, ((T([1000], f16), T([8, 384], f16, stride=(75648, 1)), T([384, 1000], f16, stride=(1, 384))), {})
+Operator: aten.bmm.default
+cnt: 12, ((T([48, 197, 64], f16), T([48, 64, 197], f16)), {})
+cnt: 12, ((T([48, 197, 197], f16), T([48, 197, 64], f16)), {})
+cnt: 12, ((T([48, 197, 197], f16, stride=(38809, 1, 197)), T([48, 197, 64], f16)), {})
+cnt: 12, ((T([48, 197, 64], f16), T([48, 64, 197], f16, stride=(12608, 1, 64))), {})
+cnt: 12, ((T([48, 64, 197], f16, stride=(12608, 1, 64)), T([48, 197, 197], f16)), {})
+cnt: 12, ((T([48, 197, 197], f16), T([48, 197, 64], f16, stride=(12608, 1, 197))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([8, 1, 384], f16, stride=(0, 384, 1)), T([8, 196, 384], f16, stride=(75264, 1, 196))], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([8, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([8, 3, 224, 224], f16), T([384, 3, 16, 16], f16), T([384], f16), [16, 16], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([8, 384, 14, 14], f16, stride=(75648, 1, 5376, 384)), T([8, 3, 224, 224], f16), T([384, 3, 16, 16], f16), [384], [16, 16], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([8, 3, 224, 224], f16), T([8, 3, 224, 224], f16)), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 8000), {})
+Operator: aten.gelu.default
+cnt: 12, ((T([8, 197, 1536], f16),), {})
+Operator: aten.gelu_backward.default
+cnt: 12, ((T([8, 197, 1536], f16), T([8, 197, 1536], f16)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([8, 1000], f16, stride=(0, 0)), T([1000, 384], f16)), {})
+cnt: 1, ((T([1000, 8], f16, stride=(0, 0)), T([8, 384], f16, stride=(75648, 1))), {})
+cnt: 12, ((T([1576, 384], f16), T([384, 1536], f16)), {})
+cnt: 12, ((T([384, 1576], f16, stride=(1, 384)), T([1576, 1536], f16)), {})
+cnt: 12, ((T([1576, 1536], f16), T([1536, 384], f16)), {})
+cnt: 12, ((T([1536, 1576], f16, stride=(1, 1536)), T([1576, 384], f16)), {})
+cnt: 12, ((T([1576, 384], f16), T([384, 384], f16)), {})
+cnt: 12, ((T([384, 1576], f16, stride=(1, 384)), T([1576, 384], f16)), {})
+cnt: 12, ((T([1576, 1152], f16), T([1152, 384], f16)), {})
+cnt: 12, ((T([1152, 1576], f16, stride=(1, 1152)), T([1576, 384], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 24, ((T([8, 6, 197, 197], f16), 0.125), {})
+Operator: aten.native_layer_norm.default
+cnt: 25, ((T([8, 197, 384], f16), [384], T([384], f16), T([384], f16), 1e-06), {})
+Operator: aten.native_layer_norm_backward.default
+cnt: 25, ((T([8, 197, 384], f16), T([8, 197, 384], f16), [384], T([8, 197, 1], f32), T([8, 197, 1], f32), T([384], f16), T([384], f16), [True, True, True]), {})
+Operator: aten.select_backward.default
+cnt: 1, ((T([8, 384], f16), [8, 197, 384], 1, 0), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([8, 197, 384], f16), [8, 197, 384], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.stack.default
+cnt: 12, (([T([8, 6, 197, 64], f16), T([8, 6, 197, 64], f16, stride=(75648, 12608, 1, 197)), T([8, 6, 197, 64], f16)],), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([8, 1000], f16, stride=(0, 0)), [0], True), {})
+cnt: 24, ((T([1576, 384], f16), [0], True), {})
+cnt: 12, ((T([1576, 1536], f16), [0], True), {})
+cnt: 12, ((T([1576, 1152], f16), [0], True), {})
+cnt: 1, ((T([8, 197, 384], f16), [0], True), {})
+cnt: 1, ((T([8, 1, 384], f16, stride=(75648, 384, 1)), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([8, 1000], f16),), {})
+Operator: aten.unbind.int
+cnt: 12, ((T([3, 8, 6, 197, 64], f16, stride=(384, 226944, 64, 1152, 1)),), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_vovnet_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_vovnet_training.txt
new file mode 100644
index 0000000000000..0ff92b240c675
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/timm_vovnet_training.txt
@@ -0,0 +1,130 @@
+Operator: aten.add.Tensor
+cnt: 4, ((T([32, 224, 7, 7], f16, stride=(105056, 49, 7, 1)), T([32, 224, 7, 7], f16)), {})
+cnt: 1, ((T([32, 1024, 7, 7], f16, stride=(105056, 49, 7, 1)), T([32, 1024, 7, 7], f16)), {})
+cnt: 4, ((T([32, 224, 7, 7], f16, stride=(92512, 49, 7, 1)), T([32, 224, 7, 7], f16)), {})
+cnt: 1, ((T([32, 768, 7, 7], f16, stride=(92512, 49, 7, 1)), T([32, 768, 7, 7], f16)), {})
+cnt: 4, ((T([32, 192, 14, 14], f16, stride=(338688, 196, 14, 1)), T([32, 192, 14, 14], f16)), {})
+cnt: 1, ((T([32, 768, 14, 14], f16, stride=(338688, 196, 14, 1)), T([32, 768, 14, 14], f16)), {})
+cnt: 4, ((T([32, 192, 14, 14], f16, stride=(288512, 196, 14, 1)), T([32, 192, 14, 14], f16)), {})
+cnt: 1, ((T([32, 512, 14, 14], f16, stride=(288512, 196, 14, 1)), T([32, 512, 14, 14], f16)), {})
+cnt: 4, ((T([32, 160, 28, 28], f16, stride=(827904, 784, 28, 1)), T([32, 160, 28, 28], f16)), {})
+cnt: 1, ((T([32, 256, 28, 28], f16, stride=(827904, 784, 28, 1)), T([32, 256, 28, 28], f16)), {})
+cnt: 5, ((T([32, 128, 56, 56], f16, stride=(2408448, 3136, 56, 1)), T([32, 128, 56, 56], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1000], f16), T([32, 1024], f16), T([1024, 1000], f16, stride=(1, 1024))), {})
+Operator: aten.cat.default
+cnt: 1, (([T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16)], 1), {})
+cnt: 1, (([T([32, 256, 28, 28], f16), T([32, 160, 28, 28], f16), T([32, 160, 28, 28], f16), T([32, 160, 28, 28], f16), T([32, 160, 28, 28], f16), T([32, 160, 28, 28], f16)], 1), {})
+cnt: 1, (([T([32, 512, 14, 14], f16), T([32, 192, 14, 14], f16), T([32, 192, 14, 14], f16), T([32, 192, 14, 14], f16), T([32, 192, 14, 14], f16), T([32, 192, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 768, 14, 14], f16), T([32, 192, 14, 14], f16), T([32, 192, 14, 14], f16), T([32, 192, 14, 14], f16), T([32, 192, 14, 14], f16), T([32, 192, 14, 14], f16)], 1), {})
+cnt: 1, (([T([32, 768, 7, 7], f16), T([32, 224, 7, 7], f16), T([32, 224, 7, 7], f16), T([32, 224, 7, 7], f16), T([32, 224, 7, 7], f16), T([32, 224, 7, 7], f16)], 1), {})
+cnt: 1, (([T([32, 1024, 7, 7], f16), T([32, 224, 7, 7], f16), T([32, 224, 7, 7], f16), T([32, 224, 7, 7], f16), T([32, 224, 7, 7], f16), T([32, 224, 7, 7], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([32, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([64, 3, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([128, 64, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([32, 128, 56, 56], f16), T([128, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 768, 56, 56], f16), T([256, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), T([160, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([32, 160, 28, 28], f16), T([160, 160, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1056, 28, 28], f16), T([512, 1056, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), T([192, 512, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([32, 192, 14, 14], f16), T([192, 192, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1472, 14, 14], f16), T([768, 1472, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 768, 14, 14], f16), T([192, 768, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1728, 14, 14], f16), T([768, 1728, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 768, 7, 7], f16), T([224, 768, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 8, ((T([32, 224, 7, 7], f16), T([224, 224, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1888, 7, 7], f16), T([1024, 1888, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 1024, 7, 7], f16), T([224, 1024, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([32, 2144, 7, 7], f16), T([1024, 2144, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([32, 1024, 7, 7], f16), T([32, 2144, 7, 7], f16), T([1024, 2144, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 8, ((T([32, 224, 7, 7], f16), T([32, 224, 7, 7], f16), T([224, 224, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 224, 7, 7], f16), T([32, 1024, 7, 7], f16), T([224, 1024, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 1024, 7, 7], f16), T([32, 1888, 7, 7], f16), T([1024, 1888, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 224, 7, 7], f16), T([32, 768, 7, 7], f16), T([224, 768, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 768, 14, 14], f16), T([32, 1728, 14, 14], f16), T([768, 1728, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 8, ((T([32, 192, 14, 14], f16), T([32, 192, 14, 14], f16), T([192, 192, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 192, 14, 14], f16), T([32, 768, 14, 14], f16), T([192, 768, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 768, 14, 14], f16), T([32, 1472, 14, 14], f16), T([768, 1472, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 192, 14, 14], f16), T([32, 512, 14, 14], f16), T([192, 512, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 1056, 28, 28], f16), T([512, 1056, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([32, 160, 28, 28], f16), T([32, 160, 28, 28], f16), T([160, 160, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 160, 28, 28], f16), T([32, 256, 28, 28], f16), T([160, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([32, 768, 56, 56], f16), T([256, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), T([128, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 128, 56, 56], f16), T([32, 64, 112, 112], f16), T([128, 64, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), T([64, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([32, 64, 112, 112], f16), T([32, 3, 224, 224], f16), T([64, 3, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([32, 3, 224, 224], f16), T([32, 3, 224, 224], f16)), {})
+Operator: aten.div.Scalar
+cnt: 1, ((T([32, 1024, 7, 7], f16, stride=(1024, 1, 0, 0)), 49), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 32000), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([32, 256, 56, 56], f16), [3, 3], [2, 2], [0, 0], [1, 1], True), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), [3, 3], [2, 2], [0, 0], [1, 1], True), {})
+cnt: 1, ((T([32, 768, 14, 14], f16), [3, 3], [2, 2], [0, 0], [1, 1], True), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([32, 768, 7, 7], f16), T([32, 768, 14, 14], f16), [3, 3], [2, 2], [0, 0], [1, 1], True, T([32, 768, 7, 7], i64)), {})
+cnt: 1, ((T([32, 512, 14, 14], f16), T([32, 512, 28, 28], f16), [3, 3], [2, 2], [0, 0], [1, 1], True, T([32, 512, 14, 14], i64)), {})
+cnt: 1, ((T([32, 256, 28, 28], f16), T([32, 256, 56, 56], f16), [3, 3], [2, 2], [0, 0], [1, 1], True, T([32, 256, 28, 28], i64)), {})
+Operator: aten.mean.dim
+cnt: 1, ((T([32, 1024, 7, 7], f16), [-1, -2], True), {})
+Operator: aten.mm.default
+cnt: 1, ((T([32, 1000], f16, stride=(0, 0)), T([1000, 1024], f16)), {})
+cnt: 1, ((T([1000, 32], f16, stride=(0, 0)), T([32, 1024], f16)), {})
+Operator: aten.native_batch_norm.default
+cnt: 2, ((T([32, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.1, 1e-05), {})
+cnt: 6, ((T([32, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.1, 1e-05), {})
+cnt: 5, ((T([32, 160, 28, 28], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f16), False, 0.1, 1e-05), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.1, 1e-05), {})
+cnt: 10, ((T([32, 192, 14, 14], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f16), False, 0.1, 1e-05), {})
+cnt: 2, ((T([32, 768, 14, 14], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f16), False, 0.1, 1e-05), {})
+cnt: 10, ((T([32, 224, 7, 7], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f16), False, 0.1, 1e-05), {})
+cnt: 2, ((T([32, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), False, 0.1, 1e-05), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 2, ((T([32, 1024, 7, 7], f16), T([32, 1024, 7, 7], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), False, 1e-05, [True, True, True]), {})
+cnt: 10, ((T([32, 224, 7, 7], f16), T([32, 224, 7, 7], f16), T([224], f16), T([224], f16), T([224], f16), T([224], f32), T([224], f32), False, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 768, 14, 14], f16), T([32, 768, 14, 14], f16), T([768], f16), T([768], f16), T([768], f16), T([768], f32), T([768], f32), False, 1e-05, [True, True, True]), {})
+cnt: 10, ((T([32, 192, 14, 14], f16), T([32, 192, 14, 14], f16), T([192], f16), T([192], f16), T([192], f16), T([192], f32), T([192], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 1e-05, [True, True, True]), {})
+cnt: 5, ((T([32, 160, 28, 28], f16), T([32, 160, 28, 28], f16), T([160], f16), T([160], f16), T([160], f16), T([160], f32), T([160], f32), False, 1e-05, [True, True, True]), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 1e-05, [True, True, True]), {})
+cnt: 6, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 1e-05, [True, True, True]), {})
+cnt: 2, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 1e-05, [True, True, True]), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([32, 64, 112, 112], f16),), {})
+cnt: 6, ((T([32, 128, 56, 56], f16),), {})
+cnt: 1, ((T([32, 256, 56, 56], f16),), {})
+cnt: 5, ((T([32, 160, 28, 28], f16),), {})
+cnt: 1, ((T([32, 512, 28, 28], f16),), {})
+cnt: 10, ((T([32, 192, 14, 14], f16),), {})
+cnt: 2, ((T([32, 768, 14, 14], f16),), {})
+cnt: 10, ((T([32, 224, 7, 7], f16),), {})
+cnt: 2, ((T([32, 1024, 7, 7], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([32, 1000], f16, stride=(0, 0)), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([32, 1000], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 2, ((T([32, 1024, 7, 7], f16), T([32, 1024, 7, 7], f16), 0), {})
+cnt: 1, ((T([32, 224, 7, 7], f16, stride=(105056, 49, 7, 1)), T([32, 224, 7, 7], f16), 0), {})
+cnt: 8, ((T([32, 224, 7, 7], f16), T([32, 224, 7, 7], f16), 0), {})
+cnt: 1, ((T([32, 224, 7, 7], f16, stride=(92512, 49, 7, 1)), T([32, 224, 7, 7], f16), 0), {})
+cnt: 2, ((T([32, 768, 14, 14], f16), T([32, 768, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 192, 14, 14], f16, stride=(338688, 196, 14, 1)), T([32, 192, 14, 14], f16), 0), {})
+cnt: 8, ((T([32, 192, 14, 14], f16), T([32, 192, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 192, 14, 14], f16, stride=(288512, 196, 14, 1)), T([32, 192, 14, 14], f16), 0), {})
+cnt: 1, ((T([32, 512, 28, 28], f16), T([32, 512, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 160, 28, 28], f16, stride=(827904, 784, 28, 1)), T([32, 160, 28, 28], f16), 0), {})
+cnt: 4, ((T([32, 160, 28, 28], f16), T([32, 160, 28, 28], f16), 0), {})
+cnt: 1, ((T([32, 256, 56, 56], f16), T([32, 256, 56, 56], f16), 0), {})
+cnt: 1, ((T([32, 128, 56, 56], f16, stride=(2408448, 3136, 56, 1)), T([32, 128, 56, 56], f16), 0), {})
+cnt: 5, ((T([32, 128, 56, 56], f16), T([32, 128, 56, 56], f16), 0), {})
+cnt: 2, ((T([32, 64, 112, 112], f16), T([32, 64, 112, 112], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/tts_angular_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/tts_angular_training.txt
new file mode 100644
index 0000000000000..847934aa9e1fa
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/tts_angular_training.txt
@@ -0,0 +1,51 @@
+Operator: aten._cudnn_rnn.default
+cnt: 1, ((T([64, 50, 40], f16), [T([3072, 40], f16), T([3072, 768], f16), T([3072], f16), T([3072], f16)], 4, None, T([1, 64, 768], f16), T([1, 64, 768], f16), 2, 768, 0, 1, True, 0.0, True, False, [], None), {})
+cnt: 2, ((T([64, 50, 256], f16), [T([3072, 256], f16), T([3072, 768], f16), T([3072], f16), T([3072], f16)], 4, None, T([1, 64, 768], f16), T([1, 64, 768], f16), 2, 768, 0, 1, True, 0.0, True, False, [], None), {})
+Operator: aten._cudnn_rnn_backward.default
+cnt: 2, ((T([64, 50, 256], f16), [T([3072, 256], f16), T([3072, 768], f16), T([3072], f16), T([3072], f16)], 4, T([3151872], f16), T([1, 64, 768], f16), T([1, 64, 768], f16), T([64, 50, 768], f16, stride=(768, 49152, 1)), T([64, 50, 768], f16), None, None, 2, 768, 0, 1, True, 0.0, True, False, [], None, T([24576016], u8), [True, False, False, True]), {})
+cnt: 1, ((T([64, 50, 40], f16), [T([3072, 40], f16), T([3072, 768], f16), T([3072], f16), T([3072], f16)], 4, T([2488320], f16), T([1, 64, 768], f16), T([1, 64, 768], f16), T([64, 50, 768], f16, stride=(768, 49152, 1)), T([64, 50, 768], f16), None, None, 2, 768, 0, 1, True, 0.0, True, False, [], None, T([24576016], u8), [False, False, False, True]), {})
+Operator: aten._unsafe_view.default
+cnt: 3, ((T([64, 50, 768], f16), [3200, 768]), {})
+cnt: 3, ((T([3200, 256], f16), [64, 50, 256]), {})
+cnt: 2, ((T([64, 50, 256], f16), [3200, 256]), {})
+Operator: aten.add.Tensor
+cnt: 1, ((T([64, 256], f16), T([64, 256], f16)), {})
+Operator: aten.clamp_min.default
+cnt: 1, ((T([64, 1], f16), 1e-12), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 50, 40], f16),), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 50, 40], f16), T([64, 50, 40], f16)), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([64, 256], f16, stride=(12800, 1)), T([64, 256], f16, stride=(1, 0))), {})
+cnt: 2, ((T([], f16), 16384), {})
+cnt: 1, ((T([64, 256], f16), T([64, 256], f16, stride=(1, 0))), {})
+cnt: 1, ((T([64, 256], f16, stride=(0, 0)), T([64, 256], f16, stride=(1, 0))), {})
+cnt: 1, ((T([64, 256], f16, stride=(12800, 1)), T([64, 1], f16)), {})
+Operator: aten.eq.Scalar
+cnt: 1, ((T([64, 1], f16), 0), {})
+Operator: aten.ge.Scalar
+cnt: 1, ((T([64, 1], f16), 1e-12), {})
+Operator: aten.masked_fill_.Scalar
+cnt: 1, ((T([64, 256], f16), T([64, 1], b8), 0), {})
+Operator: aten.mm.default
+cnt: 3, ((T([3200, 768], f16), T([768, 256], f16, stride=(1, 768))), {})
+cnt: 3, ((T([256, 3200], f16, stride=(1, 256)), T([3200, 768], f16)), {})
+cnt: 3, ((T([3200, 256], f16), T([256, 768], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([64, 256], f16), T([64, 256], f16)), {})
+cnt: 1, ((T([64, 1], f16), T([64, 256], f16)), {})
+Operator: aten.neg.default
+cnt: 1, ((T([64, 256], f16, stride=(0, 0)),), {})
+Operator: aten.norm.ScalarOpt_dim
+cnt: 1, ((T([64, 256], f16, stride=(12800, 1)), 2, [1], True), {})
+Operator: aten.select_backward.default
+cnt: 1, ((T([64, 256], f16), [64, 50, 256], 1, -1), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([64, 50, 256], f16), [64, 50, 256], 0, 0, 9223372036854775807, 1), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 256], f16), [1], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([64, 256], f16),), {})
+Operator: aten.where.self
+cnt: 1, ((T([64, 1], b8), T([64, 1], f16), T([], f16)), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/vgg16_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/vgg16_training.txt
new file mode 100644
index 0000000000000..cc96188bb03f5
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/vgg16_training.txt
@@ -0,0 +1,72 @@
+Operator: aten._adaptive_avg_pool2d.default
+cnt: 1, ((T([64, 512, 7, 7], f16), [7, 7]), {})
+Operator: aten._adaptive_avg_pool2d_backward.default
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 7, 7], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([4096], f16), T([64, 25088], f16), T([25088, 4096], f16, stride=(1, 25088))), {})
+cnt: 1, ((T([4096], f16), T([64, 4096], f16), T([4096, 4096], f16, stride=(1, 4096))), {})
+cnt: 1, ((T([1000], f16), T([64, 4096], f16), T([4096, 1000], f16, stride=(1, 4096))), {})
+Operator: aten.clone.default
+cnt: 1, ((T([64, 3, 224, 224], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 224, 224], f16), T([64, 64, 3, 3], f16), T([64], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([128, 64, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 112, 112], f16), T([128, 128, 3, 3], f16), T([128], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([256, 128, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 256, 56, 56], f16), T([256, 256, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([512, 256, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([64, 512, 28, 28], f16), T([512, 512, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([64, 512, 14, 14], f16), T([512, 512, 3, 3], f16), T([512], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 3, ((T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16), T([512, 512, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([64, 512, 28, 28], f16), T([64, 512, 28, 28], f16), T([512, 512, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 512, 28, 28], f16), T([64, 256, 28, 28], f16), T([512, 256, 3, 3], f16), [512], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 2, ((T([64, 256, 56, 56], f16), T([64, 256, 56, 56], f16), T([256, 256, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 256, 56, 56], f16), T([64, 128, 56, 56], f16), T([256, 128, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 128, 112, 112], f16), T([64, 128, 112, 112], f16), T([128, 128, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 128, 112, 112], f16), T([64, 64, 112, 112], f16), T([128, 64, 3, 3], f16), [128], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 64, 224, 224], f16), T([64, 64, 224, 224], f16), T([64, 64, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([64, 64, 224, 224], f16), T([64, 3, 224, 224], f16), T([64, 3, 3, 3], f16), [64], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([64, 3, 224, 224], f16), T([64, 3, 224, 224], f16)), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 64000), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([64, 64, 224, 224], f16), [2, 2], [2, 2]), {})
+cnt: 1, ((T([64, 128, 112, 112], f16), [2, 2], [2, 2]), {})
+cnt: 1, ((T([64, 256, 56, 56], f16), [2, 2], [2, 2]), {})
+cnt: 1, ((T([64, 512, 28, 28], f16), [2, 2], [2, 2]), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), [2, 2], [2, 2]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([64, 512, 7, 7], f16), T([64, 512, 14, 14], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([64, 512, 7, 7], i64)), {})
+cnt: 1, ((T([64, 512, 14, 14], f16), T([64, 512, 28, 28], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([64, 512, 14, 14], i64)), {})
+cnt: 1, ((T([64, 256, 28, 28], f16), T([64, 256, 56, 56], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([64, 256, 28, 28], i64)), {})
+cnt: 1, ((T([64, 128, 56, 56], f16), T([64, 128, 112, 112], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([64, 128, 56, 56], i64)), {})
+cnt: 1, ((T([64, 64, 112, 112], f16), T([64, 64, 224, 224], f16), [2, 2], [2, 2], [0, 0], [1, 1], False, T([64, 64, 112, 112], i64)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([64, 1000], f16, stride=(0, 0)), T([1000, 4096], f16)), {})
+cnt: 1, ((T([1000, 64], f16, stride=(0, 0)), T([64, 4096], f16)), {})
+cnt: 1, ((T([64, 4096], f16), T([4096, 4096], f16)), {})
+cnt: 1, ((T([4096, 64], f16, stride=(1, 4096)), T([64, 4096], f16)), {})
+cnt: 1, ((T([64, 4096], f16), T([4096, 25088], f16)), {})
+cnt: 1, ((T([4096, 64], f16, stride=(1, 4096)), T([64, 25088], f16)), {})
+Operator: aten.relu_.default
+cnt: 2, ((T([64, 64, 224, 224], f16),), {})
+cnt: 2, ((T([64, 128, 112, 112], f16),), {})
+cnt: 3, ((T([64, 256, 56, 56], f16),), {})
+cnt: 3, ((T([64, 512, 28, 28], f16),), {})
+cnt: 3, ((T([64, 512, 14, 14], f16),), {})
+cnt: 2, ((T([64, 4096], f16),), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([64, 1000], f16, stride=(0, 0)), [0], True), {})
+cnt: 2, ((T([64, 4096], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 1, ((T([64, 1000], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 2, ((T([64, 4096], f16), T([64, 4096], f16), 0), {})
+cnt: 3, ((T([64, 512, 14, 14], f16), T([64, 512, 14, 14], f16), 0), {})
+cnt: 3, ((T([64, 512, 28, 28], f16), T([64, 512, 28, 28], f16), 0), {})
+cnt: 3, ((T([64, 256, 56, 56], f16), T([64, 256, 56, 56], f16), 0), {})
+cnt: 2, ((T([64, 128, 112, 112], f16), T([64, 128, 112, 112], f16), 0), {})
+cnt: 2, ((T([64, 64, 224, 224], f16), T([64, 64, 224, 224], f16), 0), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/vision_maskrcnn_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/vision_maskrcnn_training.txt
new file mode 100644
index 0000000000000..a88dbc3aec300
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/vision_maskrcnn_training.txt
@@ -0,0 +1,477 @@
+Operator: aten._index_put_impl_.default
+cnt: 12, ((T([0], f16), [T([0], i64)], T([0], f16), True, True), {})
+cnt: 12, ((T([0, 4], f16), [T([0], i64)], T([0, 4], f16), True, True), {})
+Operator: aten._softmax.default
+cnt: 1, ((T([0, 91], f16), -1, False), {})
+Operator: aten._softmax_backward_data.default
+cnt: 1, ((T([0, 91], f16), T([0, 91], f16), -1, f16), {})
+Operator: aten._to_copy.default
+cnt: 8, ((T([], i64),), {'dtype': f32})
+cnt: 5, ((T([3, 4], f32),), {'dtype': f16, 'device': 'cuda'})
+cnt: 8, ((T([0, 4], f16),), {'dtype': f32})
+cnt: 2, ((T([0], f32),), {'dtype': i64})
+cnt: 4, ((T([0, 4], f16),), {'dtype': i64})
+cnt: 8, ((T([], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 2, ((T([296, 304], i32), [89984]), {})
+cnt: 2, ((T([148, 152], i32), [22496]), {})
+cnt: 2, ((T([74, 76], i32), [5624]), {})
+cnt: 2, ((T([37, 38], i32), [1406]), {})
+cnt: 2, ((T([19, 19], i32), [361]), {})
+cnt: 1, ((T([4, 296, 304, 3, 1], f16), [4, 269952, 1]), {})
+cnt: 1, ((T([4, 296, 304, 3, 4], f16), [4, 269952, 4]), {})
+cnt: 1, ((T([4, 148, 152, 3, 1], f16), [4, 67488, 1]), {})
+cnt: 1, ((T([4, 148, 152, 3, 4], f16), [4, 67488, 4]), {})
+cnt: 1, ((T([4, 74, 76, 3, 1], f16), [4, 16872, 1]), {})
+cnt: 1, ((T([4, 74, 76, 3, 4], f16), [4, 16872, 4]), {})
+cnt: 1, ((T([4, 37, 38, 3, 1], f16), [4, 4218, 1]), {})
+cnt: 1, ((T([4, 37, 38, 3, 4], f16), [4, 4218, 4]), {})
+cnt: 1, ((T([4, 19, 19, 3, 1], f16), [4, 1083, 1]), {})
+cnt: 1, ((T([4, 19, 19, 3, 4], f16), [4, 1083, 4]), {})
+Operator: aten.add.Tensor
+cnt: 7, ((T([1, 64, 1, 1], f16), 0.0), {})
+cnt: 1, ((T([4, 64, 592, 608], f16), T([1, 64, 1, 1], f16)), {})
+cnt: 6, ((T([4, 64, 296, 304], f16), T([1, 64, 1, 1], f16)), {})
+cnt: 16, ((T([1, 256, 1, 1], f16), 0.0), {})
+cnt: 4, ((T([4, 256, 296, 304], f16), T([1, 256, 1, 1], f16)), {})
+cnt: 8, ((T([1, 128, 1, 1], f16), 0.0), {})
+cnt: 1, ((T([4, 128, 296, 304], f16), T([1, 128, 1, 1], f16)), {})
+cnt: 7, ((T([4, 128, 148, 152], f16), T([1, 128, 1, 1], f16)), {})
+cnt: 11, ((T([1, 512, 1, 1], f16), 0.0), {})
+cnt: 5, ((T([4, 512, 148, 152], f16), T([1, 512, 1, 1], f16)), {})
+cnt: 1, ((T([4, 256, 148, 152], f16), T([1, 256, 1, 1], f16)), {})
+cnt: 11, ((T([4, 256, 74, 76], f16), T([1, 256, 1, 1], f16)), {})
+cnt: 7, ((T([1, 1024, 1, 1], f16), 0.0), {})
+cnt: 7, ((T([4, 1024, 74, 76], f16), T([1, 1024, 1, 1], f16)), {})
+cnt: 1, ((T([4, 512, 74, 76], f16), T([1, 512, 1, 1], f16)), {})
+cnt: 5, ((T([4, 512, 37, 38], f16), T([1, 512, 1, 1], f16)), {})
+cnt: 4, ((T([1, 2048, 1, 1], f16), 0.0), {})
+cnt: 4, ((T([4, 2048, 37, 38], f16), T([1, 2048, 1, 1], f16)), {})
+cnt: 2, ((T([4, 256, 74, 76], f16), T([4, 256, 74, 76], f16)), {})
+cnt: 2, ((T([4, 256, 148, 152], f16), T([4, 256, 148, 152], f16)), {})
+cnt: 1, ((T([4, 256, 296, 304], f16), T([4, 256, 296, 304], f16)), {})
+cnt: 1, ((T([89984, 1, 4], i32), T([1, 3, 4], f16)), {})
+cnt: 1, ((T([22496, 1, 4], i32), T([1, 3, 4], f16)), {})
+cnt: 1, ((T([5624, 1, 4], i32), T([1, 3, 4], f16)), {})
+cnt: 1, ((T([1406, 1, 4], i32), T([1, 3, 4], f16)), {})
+cnt: 1, ((T([361, 1, 4], i32), T([1, 3, 4], f16)), {})
+cnt: 2, ((T([1438452], f16, stride=(4,)), T([1438452], f16)), {})
+cnt: 4, ((T([1438452, 1], f16), T([1438452, 1], f16)), {})
+cnt: 1, ((T([4, 1000], i64), 0), {})
+cnt: 1, ((T([4, 1000], i64), 269952), {})
+cnt: 1, ((T([4, 1000], i64), 337440), {})
+cnt: 1, ((T([4, 1000], i64), 354312), {})
+cnt: 1, ((T([4, 1000], i64), 358530), {})
+cnt: 2, ((T([0], f32), 4), {})
+cnt: 2, ((T([0], f32), T([], f32)), {})
+cnt: 18, ((T([0], f16), T([0], f16)), {})
+cnt: 2, ((T([0, 91], f16), T([0, 1], f16)), {})
+cnt: 6, ((T([0, 91], f16), T([0, 91], f16)), {})
+cnt: 4, ((T([], f16), 0), {})
+cnt: 4, ((T([], f16), T([], f32)), {})
+cnt: 8, ((T([], f32), T([], f16)), {})
+cnt: 1, ((T([], f32), 0), {})
+cnt: 3, ((T([], f32), T([], f32)), {})
+cnt: 7, ((T([0, 364], f16), T([0, 364], f16)), {})
+cnt: 1, ((T([0, 1024], f16), T([0, 1024], f16)), {})
+cnt: 1, ((T([4, 256, 37, 38], f16), T([4, 256, 37, 38], f16)), {})
+cnt: 2, ((T([4, 2048, 37, 38], f16), T([4, 2048, 37, 38], f16)), {})
+cnt: 7, ((T([4, 1024, 74, 76], f16), T([4, 1024, 74, 76], f16)), {})
+cnt: 5, ((T([4, 512, 148, 152], f16), T([4, 512, 148, 152], f16)), {})
+Operator: aten.add_.Tensor
+cnt: 3, ((T([4, 256, 296, 304], f16), T([4, 256, 296, 304], f16)), {})
+cnt: 4, ((T([4, 512, 148, 152], f16), T([4, 512, 148, 152], f16)), {})
+cnt: 6, ((T([4, 1024, 74, 76], f16), T([4, 1024, 74, 76], f16)), {})
+cnt: 3, ((T([4, 2048, 37, 38], f16), T([4, 2048, 37, 38], f16)), {})
+Operator: aten.addmm.default
+cnt: 1, ((T([1024], f16), T([0, 12544], f16), T([12544, 1024], f16, stride=(1, 12544))), {})
+cnt: 1, ((T([1024], f16), T([0, 1024], f16), T([1024, 1024], f16, stride=(1, 1024))), {})
+cnt: 1, ((T([91], f16), T([0, 1024], f16), T([1024, 91], f16, stride=(1, 1024))), {})
+cnt: 1, ((T([364], f16), T([0, 1024], f16), T([1024, 364], f16, stride=(1, 1024))), {})
+Operator: aten.bitwise_and.Tensor
+cnt: 4, ((T([5000], b8), T([5000], b8)), {})
+cnt: 4, ((T([0], b8), T([0], b8)), {})
+Operator: aten.cat.default
+cnt: 4, (([T([269952, 4], f16), T([67488, 4], f16), T([16872, 4], f16), T([4218, 4], f16), T([1083, 4], f16)],), {})
+cnt: 1, (([T([4, 269952, 1], f16), T([4, 67488, 1], f16), T([4, 16872, 1], f16), T([4, 4218, 1], f16), T([4, 1083, 1], f16)], 1), {})
+cnt: 1, (([T([4, 269952, 4], f16), T([4, 67488, 4], f16), T([4, 16872, 4], f16), T([4, 4218, 4], f16), T([4, 1083, 4], f16)], 1), {})
+cnt: 1, (([T([359613, 4], f16), T([359613, 4], f16), T([359613, 4], f16), T([359613, 4], f16)],), {})
+cnt: 1, (([T([269952], i64), T([67488], i64), T([16872], i64), T([4218], i64), T([1083], i64)],), {})
+cnt: 1, (([T([4, 1000], i64), T([4, 1000], i64), T([4, 1000], i64), T([4, 1000], i64), T([4, 1000], i64)], 1), {})
+cnt: 3, (([T([0, 4], f16), T([0, 4], f16), T([0, 4], f16), T([0, 4], f16)],), {})
+cnt: 2, (([T([0, 1], f16), T([0, 1], f16), T([0, 1], f16), T([0, 1], f16)],), {})
+cnt: 2, (([T([0, 1], f16), T([0, 4], f16)], 1), {})
+cnt: 2, (([T([0], f32), T([0], f32), T([0], f32), T([0], f32)],), {})
+cnt: 1, (([T([0], i64), T([0], i64), T([0], i64), T([0], i64)],), {})
+cnt: 1, (([T([0, 91], f16), T([0, 91], f16), T([0, 91], f16), T([0, 91], f16)],), {})
+cnt: 1, (([T([0, 364], f16), T([0, 364], f16), T([0, 364], f16), T([0, 364], f16)],), {})
+Operator: aten.clamp.default
+cnt: 2, ((T([1438452, 1], f16), None, 4.135166556742356), {})
+cnt: 1, ((T([5000, 2], f16, stride=(4, 2)), 0, 1199), {})
+cnt: 2, ((T([5000, 2], f16, stride=(4, 2)), 0, 799), {})
+cnt: 3, ((T([5000, 2], f16, stride=(4, 2)), 0, 800), {})
+cnt: 1, ((T([5000, 2], f16, stride=(4, 2)), 0, 1155), {})
+cnt: 1, ((T([5000, 2], f16, stride=(4, 2)), 0, 1115), {})
+cnt: 2, ((T([0], f32), 2, 5), {})
+cnt: 2, ((T([0, 91], f16), None, 4.135166556742356), {})
+cnt: 1, ((T([0, 182], f16), 0, 1199), {})
+cnt: 2, ((T([0, 182], f16), 0, 799), {})
+cnt: 3, ((T([0, 182], f16), 0, 800), {})
+cnt: 1, ((T([0, 182], f16), 0, 1155), {})
+cnt: 1, ((T([0, 182], f16), 0, 1115), {})
+Operator: aten.constant_pad_nd.default
+cnt: 4, ((T([0, 1, 28, 28], f16), [1, 1, 1, 1], 0.0), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([4, 3, 1184, 1216], f16), T([64, 3, 7, 7], f16), None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 64, 296, 304], f16), T([64, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([4, 64, 296, 304], f16), T([64, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([4, 64, 296, 304], f16), T([256, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([4, 256, 296, 304], f16), T([64, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 296, 304], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 128, 296, 304], f16), T([128, 128, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([4, 128, 148, 152], f16), T([512, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 296, 304], f16), T([512, 256, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([4, 512, 148, 152], f16), T([128, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([4, 128, 148, 152], f16), T([128, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 512, 148, 152], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 148, 152], f16), T([256, 256, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 6, ((T([4, 256, 74, 76], f16), T([1024, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 512, 148, 152], f16), T([1024, 512, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([4, 1024, 74, 76], f16), T([256, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 5, ((T([4, 256, 74, 76], f16), T([256, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 1024, 74, 76], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 512, 74, 76], f16), T([512, 512, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 3, ((T([4, 512, 37, 38], f16), T([2048, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 1024, 74, 76], f16), T([2048, 1024, 1, 1], f16), None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([4, 2048, 37, 38], f16), T([512, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([4, 512, 37, 38], f16), T([512, 512, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 2048, 37, 38], f16), T([256, 2048, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([4, 256, 37, 38], f16), T([256, 256, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 1024, 74, 76], f16), T([256, 1024, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([4, 256, 74, 76], f16), T([256, 256, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 512, 148, 152], f16), T([256, 512, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([4, 256, 148, 152], f16), T([256, 256, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 296, 304], f16), T([256, 256, 1, 1], f16), T([256], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([4, 256, 296, 304], f16), T([256, 256, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 296, 304], f16), T([3, 256, 1, 1], f16), T([3], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 296, 304], f16), T([12, 256, 1, 1], f16), T([12], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 148, 152], f16), T([3, 256, 1, 1], f16), T([3], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 148, 152], f16), T([12, 256, 1, 1], f16), T([12], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 74, 76], f16), T([3, 256, 1, 1], f16), T([3], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 74, 76], f16), T([12, 256, 1, 1], f16), T([12], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 37, 38], f16), T([3, 256, 1, 1], f16), T([3], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 37, 38], f16), T([12, 256, 1, 1], f16), T([12], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 19, 19], f16), T([256, 256, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 19, 19], f16), T([3, 256, 1, 1], f16), T([3], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([4, 256, 19, 19], f16), T([12, 256, 1, 1], f16), T([12], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 4, ((T([0, 256, 14, 14], f16), T([256, 256, 3, 3], f16), T([256], f16), [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([0, 256, 14, 14], f16), T([256, 256, 2, 2], f16), T([256], f16), [2, 2], [0, 0], [1, 1], True, [0, 0], 1), {})
+cnt: 1, ((T([0, 256, 28, 28], f16), T([91, 256, 1, 1], f16), T([91], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([4, 256, 296, 304], f16), T([4, 256, 296, 304], f16), T([256, 256, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([4, 256, 296, 304], f16), T([4, 256, 296, 304], f16), T([256, 256, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]), {})
+cnt: 1, ((T([4, 256, 148, 152], f16), T([4, 256, 148, 152], f16), T([256, 256, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([4, 256, 148, 152], f16), T([4, 512, 148, 152], f16), T([256, 512, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([4, 256, 74, 76], f16), T([4, 256, 74, 76], f16), T([256, 256, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([4, 256, 74, 76], f16), T([4, 1024, 74, 76], f16), T([256, 1024, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([4, 256, 37, 38], f16), T([4, 256, 37, 38], f16), T([256, 256, 3, 3], f16), [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 1, ((T([4, 256, 37, 38], f16), T([4, 2048, 37, 38], f16), T([256, 2048, 1, 1], f16), [256], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 3, ((T([4, 2048, 37, 38], f16), T([4, 512, 37, 38], f16), T([2048, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([4, 512, 37, 38], f16), T([4, 512, 37, 38], f16), T([512, 512, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([4, 512, 37, 38], f16), T([4, 2048, 37, 38], f16), T([512, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 2048, 37, 38], f16), T([4, 1024, 74, 76], f16), T([2048, 1024, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 512, 37, 38], f16), T([4, 512, 74, 76], f16), T([512, 512, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 512, 74, 76], f16), T([4, 1024, 74, 76], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 6, ((T([4, 1024, 74, 76], f16), T([4, 256, 74, 76], f16), T([1024, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([4, 256, 74, 76], f16), T([4, 256, 74, 76], f16), T([256, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 5, ((T([4, 256, 74, 76], f16), T([4, 1024, 74, 76], f16), T([256, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 1024, 74, 76], f16), T([4, 512, 148, 152], f16), T([1024, 512, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 256, 74, 76], f16), T([4, 256, 148, 152], f16), T([256, 256, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 256, 148, 152], f16), T([4, 512, 148, 152], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 4, ((T([4, 512, 148, 152], f16), T([4, 128, 148, 152], f16), T([512, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([4, 128, 148, 152], f16), T([4, 128, 148, 152], f16), T([128, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 3, ((T([4, 128, 148, 152], f16), T([4, 512, 148, 152], f16), T([128, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 512, 148, 152], f16), T([4, 256, 296, 304], f16), T([512, 256, 1, 1], f16), [0], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+cnt: 1, ((T([4, 128, 148, 152], f16), T([4, 128, 296, 304], f16), T([128, 128, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([4, 128, 296, 304], f16), T([4, 256, 296, 304], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([3, 799, 1199], f16, stride=(1439744, 1216, 1)), T([3, 799, 1199], f16)), {})
+cnt: 1, ((T([3, 800, 800], f16, stride=(1439744, 1216, 1)), T([3, 800, 800], f16)), {})
+cnt: 1, ((T([3, 1155, 800], f16, stride=(1439744, 1216, 1)), T([3, 1155, 800], f16)), {})
+cnt: 1, ((T([3, 799, 1115], f16, stride=(1439744, 1216, 1)), T([3, 799, 1115], f16)), {})
+cnt: 16, ((T([0], f16), T([0], f16)), {})
+Operator: aten.div.Tensor
+cnt: 1, ((T([3, 427, 640], f16, stride=(1, 1920, 3)), T([3, 1, 1], f16)), {})
+cnt: 1, ((T([3, 612, 612], f16, stride=(1, 1836, 3)), T([3, 1, 1], f16)), {})
+cnt: 1, ((T([3, 640, 443], f16, stride=(1, 1329, 3)), T([3, 1, 1], f16)), {})
+cnt: 1, ((T([3, 459, 640], f16, stride=(1, 1920, 3)), T([3, 1, 1], f16)), {})
+cnt: 4, ((T([1438452, 1], f16, stride=(4, 4)), 1.0), {})
+cnt: 2, ((T([0], f32), 224), {})
+cnt: 4, ((T([0, 91], f16), 10.0), {})
+cnt: 4, ((T([0, 91], f16), 5.0), {})
+cnt: 8, ((T([], f32), T([], f32)), {})
+cnt: 20, ((T([], f16), 0), {})
+cnt: 4, ((T([], i64), 0), {})
+cnt: 10, ((T([], f32), 4), {})
+Operator: aten.eq.Scalar
+cnt: 2, ((T([0], i64), 0), {})
+cnt: 2, ((T([0], i64), 1), {})
+cnt: 2, ((T([0], i64), 2), {})
+cnt: 2, ((T([0], i64), 3), {})
+Operator: aten.exp.default
+cnt: 2, ((T([1438452, 1], f16),), {})
+cnt: 2, ((T([0, 91], f16),), {})
+Operator: aten.fill_.Scalar
+cnt: 2, ((T([], i64), 4), {})
+cnt: 2, ((T([], i64), 8), {})
+cnt: 2, ((T([], i64), 16), {})
+cnt: 2, ((T([], i64), 32), {})
+cnt: 1, ((T([], i64), 62), {})
+cnt: 1, ((T([], i64), 64), {})
+Operator: aten.floor.default
+cnt: 2, ((T([0], f32),), {})
+Operator: aten.ge.Scalar
+cnt: 8, ((T([5000], f16), 0.001), {})
+cnt: 4, ((T([0], f16), 0.0), {})
+cnt: 8, ((T([0], f16), 0.01), {})
+cnt: 8, ((T([0, 182], f16), 0), {})
+Operator: aten.gt.Scalar
+cnt: 4, ((T([0], f16), 0.05), {})
+Operator: aten.index.Tensor
+cnt: 1, ((T([4, 359613], f16), [T([4, 1], i64), T([4, 5000], i64)]), {})
+cnt: 1, ((T([4, 359613], i64, stride=(0, 1)), [T([4, 1], i64), T([4, 5000], i64)]), {})
+cnt: 1, ((T([4, 359613, 4], f16), [T([4, 1], i64), T([4, 5000], i64)]), {})
+cnt: 4, ((T([5000, 4], f16), [T([0], i64)]), {})
+cnt: 4, ((T([5000], f16), [T([0], i64)]), {})
+cnt: 4, ((T([5000], i64), [T([0], i64)]), {})
+cnt: 20, ((T([0, 4], f16), [T([0], i64)]), {})
+cnt: 20, ((T([0], f16), [T([0], i64)]), {})
+cnt: 16, ((T([0], i64), [T([0], i64)]), {})
+cnt: 8, ((T([0, 5], f16), [T([0], i64)]), {})
+cnt: 1, ((T([0, 91, 28, 28], f16), [T([0], i64), T([0], i64)]), {})
+cnt: 4, ((T([0, 256, 7, 7], f16), [T([0], i64)]), {})
+Operator: aten.index_put.default
+cnt: 3, ((T([0, 256, 7, 7], f16), [T([0], i64)], T([0, 256, 7, 7], f16)), {})
+Operator: aten.index_put_.default
+cnt: 4, ((T([0, 256, 7, 7], f16), [T([0], i64)], T([0, 256, 7, 7], f16)), {})
+cnt: 4, ((T([0, 256, 14, 14], f16), [T([0], i64)], T([0, 256, 14, 14], f16)), {})
+Operator: aten.le.Scalar
+cnt: 2, ((T([0, 182], f16), 799), {})
+cnt: 1, ((T([0, 182], f16), 1115), {})
+cnt: 1, ((T([0, 182], f16), 1155), {})
+cnt: 3, ((T([0, 182], f16), 800), {})
+cnt: 1, ((T([0, 182], f16), 1199), {})
+cnt: 2, ((T([0, 91], f16), 4.135166556742356), {})
+Operator: aten.log2.default
+cnt: 20, ((T([], f32),), {})
+cnt: 2, ((T([0], f32),), {})
+Operator: aten.logical_and_.default
+cnt: 8, ((T([0, 182], b8), T([0, 182], b8)), {})
+Operator: aten.max.default
+cnt: 4, ((T([2], i64),), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([4, 64, 592, 608], f16), [3, 3], [2, 2], [1, 1]), {})
+cnt: 1, ((T([4, 256, 37, 38], f16), [1, 1], [2, 2]), {})
+Operator: aten.min.default
+cnt: 4, ((T([2], i64),), {})
+Operator: aten.minimum.default
+cnt: 4, ((T([], f32), T([], f32)), {})
+Operator: aten.mm.default
+cnt: 1, ((T([0, 364], f16), T([364, 1024], f16)), {})
+cnt: 1, ((T([364, 0], f16), T([0, 1024], f16)), {})
+cnt: 1, ((T([0, 91], f16), T([91, 1024], f16)), {})
+cnt: 1, ((T([91, 0], f16), T([0, 1024], f16)), {})
+cnt: 1, ((T([0, 1024], f16), T([1024, 1024], f16)), {})
+cnt: 1, ((T([1024, 0], f16), T([0, 1024], f16)), {})
+cnt: 1, ((T([0, 1024], f16), T([1024, 12544], f16)), {})
+cnt: 1, ((T([1024, 0], f16), T([0, 12544], f16)), {})
+Operator: aten.mul.Tensor
+cnt: 4, ((T([], f32), 800.0), {})
+cnt: 4, ((T([], f32), 1333.0), {})
+cnt: 14, ((T([1, 64, 1, 1], f16), T([1, 64, 1, 1], f16)), {})
+cnt: 1, ((T([4, 64, 592, 608], f16), T([1, 64, 1, 1], f16)), {})
+cnt: 6, ((T([4, 64, 296, 304], f16), T([1, 64, 1, 1], f16)), {})
+cnt: 32, ((T([1, 256, 1, 1], f16), T([1, 256, 1, 1], f16)), {})
+cnt: 4, ((T([4, 256, 296, 304], f16), T([1, 256, 1, 1], f16)), {})
+cnt: 16, ((T([1, 128, 1, 1], f16), T([1, 128, 1, 1], f16)), {})
+cnt: 2, ((T([4, 128, 296, 304], f16), T([1, 128, 1, 1], f16)), {})
+cnt: 14, ((T([4, 128, 148, 152], f16), T([1, 128, 1, 1], f16)), {})
+cnt: 22, ((T([1, 512, 1, 1], f16), T([1, 512, 1, 1], f16)), {})
+cnt: 10, ((T([4, 512, 148, 152], f16), T([1, 512, 1, 1], f16)), {})
+cnt: 2, ((T([4, 256, 148, 152], f16), T([1, 256, 1, 1], f16)), {})
+cnt: 22, ((T([4, 256, 74, 76], f16), T([1, 256, 1, 1], f16)), {})
+cnt: 14, ((T([1, 1024, 1, 1], f16), T([1, 1024, 1, 1], f16)), {})
+cnt: 14, ((T([4, 1024, 74, 76], f16), T([1, 1024, 1, 1], f16)), {})
+cnt: 2, ((T([4, 512, 74, 76], f16), T([1, 512, 1, 1], f16)), {})
+cnt: 10, ((T([4, 512, 37, 38], f16), T([1, 512, 1, 1], f16)), {})
+cnt: 8, ((T([1, 2048, 1, 1], f16), T([1, 2048, 1, 1], f16)), {})
+cnt: 8, ((T([4, 2048, 37, 38], f16), T([1, 2048, 1, 1], f16)), {})
+cnt: 1, ((T([304], i32), T([], i64)), {})
+cnt: 1, ((T([296], i32), T([], i64)), {})
+cnt: 1, ((T([152], i32), T([], i64)), {})
+cnt: 1, ((T([148], i32), T([], i64)), {})
+cnt: 1, ((T([76], i32), T([], i64)), {})
+cnt: 1, ((T([74], i32), T([], i64)), {})
+cnt: 1, ((T([38], i32), T([], i64)), {})
+cnt: 1, ((T([37], i32), T([], i64)), {})
+cnt: 2, ((T([19], i32), T([], i64)), {})
+cnt: 2, ((T([1438452], f16), 0.5), {})
+cnt: 4, ((T([1438452, 1], f16), T([1438452, 1], f16)), {})
+cnt: 2, ((T([], f16), T([1438452, 1], f16)), {})
+cnt: 8, ((T([0], f32), T([0], f32)), {})
+cnt: 18, ((T([0], f16), 0.5), {})
+cnt: 8, ((T([0, 91], f16), T([0, 1], f16)), {})
+cnt: 2, ((T([], f16), T([0, 91], f16)), {})
+cnt: 32, ((T([0], f16), T([], f32)), {})
+cnt: 2, ((T([0, 91], f16), T([], f16)), {})
+cnt: 2, ((T([0, 91], f16), T([0, 91], f16)), {})
+Operator: aten.mul_.Tensor
+cnt: 8, ((T([0], f16), 1.0714285714285714), {})
+Operator: aten.neg.default
+cnt: 2, ((T([0, 91], f16),), {})
+Operator: aten.new_empty.default
+cnt: 1, ((T([0, 1, 30, 30], f16), [0, 1, 427, 640]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 1, ((T([0, 1, 30, 30], f16), [0, 1, 612, 612]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 1, ((T([0, 1, 30, 30], f16), [0, 1, 640, 443]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+cnt: 1, ((T([0, 1, 30, 30], f16), [0, 1, 459, 640]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+Operator: aten.new_full.default
+cnt: 1, ((T([3, 799, 1199], f16), [4, 3, 1184, 1216], 0), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda', 'pin_memory': False})
+Operator: aten.new_zeros.default
+cnt: 12, ((T([0], f16), [0]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 12, ((T([0, 4], f16), [0, 4]), {'dtype': f16, 'layout': torch.strided, 'device': 'cuda'})
+Operator: aten.nonzero.default
+cnt: 4, ((T([5000], b8),), {})
+cnt: 20, ((T([0], b8),), {})
+Operator: aten.reciprocal.default
+cnt: 8, ((T([], f32),), {})
+Operator: aten.relu.default
+cnt: 2, ((T([0, 1024], f16),), {})
+Operator: aten.relu_.default
+cnt: 1, ((T([4, 64, 592, 608], f16),), {})
+cnt: 6, ((T([4, 64, 296, 304], f16),), {})
+cnt: 4, ((T([4, 256, 296, 304], f16),), {})
+cnt: 1, ((T([4, 128, 296, 304], f16),), {})
+cnt: 7, ((T([4, 128, 148, 152], f16),), {})
+cnt: 4, ((T([4, 512, 148, 152], f16),), {})
+cnt: 2, ((T([4, 256, 148, 152], f16),), {})
+cnt: 12, ((T([4, 256, 74, 76], f16),), {})
+cnt: 6, ((T([4, 1024, 74, 76], f16),), {})
+cnt: 1, ((T([4, 512, 74, 76], f16),), {})
+cnt: 5, ((T([4, 512, 37, 38], f16),), {})
+cnt: 3, ((T([4, 2048, 37, 38], f16),), {})
+cnt: 1, ((T([4, 256, 37, 38], f16),), {})
+cnt: 1, ((T([4, 256, 19, 19], f16),), {})
+cnt: 4, ((T([0, 256, 14, 14], f16),), {})
+cnt: 1, ((T([0, 256, 28, 28], f16),), {})
+Operator: aten.round.default
+cnt: 16, ((T([], f32),), {})
+Operator: aten.rsqrt.default
+cnt: 7, ((T([1, 64, 1, 1], f16),), {})
+cnt: 16, ((T([1, 256, 1, 1], f16),), {})
+cnt: 8, ((T([1, 128, 1, 1], f16),), {})
+cnt: 11, ((T([1, 512, 1, 1], f16),), {})
+cnt: 7, ((T([1, 1024, 1, 1], f16),), {})
+cnt: 4, ((T([1, 2048, 1, 1], f16),), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([4, 5000], f16),), {})
+cnt: 1, ((T([0, 91, 28, 28], f16),), {})
+Operator: aten.slice_backward.default
+cnt: 4, ((T([0, 90], f16), [0, 91], 1, 1, 9223372036854775807, 1), {})
+cnt: 4, ((T([0, 91], f16), [0, 91], 0, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([0, 363], f16), [0, 364], 1, 1, 9223372036854775807, 1), {})
+cnt: 8, ((T([0, 364], f16), [0, 364], 0, 0, 9223372036854775807, 1), {})
+cnt: 4, ((T([0, 182], f16), [0, 364], 1, 1, 9223372036854775807, 2), {})
+cnt: 4, ((T([0, 182], f16), [0, 364], 1, 0, 9223372036854775807, 2), {})
+cnt: 1, ((T([0, 91], f16), [0, 364], 1, 3, 9223372036854775807, 4), {})
+cnt: 1, ((T([0, 91], f16), [0, 364], 1, 2, 9223372036854775807, 4), {})
+cnt: 1, ((T([0, 91], f16), [0, 364], 1, 1, 9223372036854775807, 4), {})
+cnt: 1, ((T([0, 91], f16), [0, 364], 1, 0, 9223372036854775807, 4), {})
+Operator: aten.split_with_sizes.default
+cnt: 1, ((T([4, 359613], f16), [269952, 67488, 16872, 4218, 1083], 1), {})
+cnt: 1, ((T([0, 364], f16), [0, 0, 0, 0]), {})
+cnt: 1, ((T([0, 91], f16), [0, 0, 0, 0]), {})
+cnt: 1, ((T([0, 1, 28, 28], f16), [0, 0, 0, 0]), {})
+Operator: aten.sqrt.default
+cnt: 2, ((T([0], f32),), {})
+Operator: aten.stack.default
+cnt: 1, (([T([89984], i32), T([89984], i32), T([89984], i32), T([89984], i32)], 1), {})
+cnt: 1, (([T([22496], i32), T([22496], i32), T([22496], i32), T([22496], i32)], 1), {})
+cnt: 1, (([T([5624], i32), T([5624], i32), T([5624], i32), T([5624], i32)], 1), {})
+cnt: 1, (([T([1406], i32), T([1406], i32), T([1406], i32), T([1406], i32)], 1), {})
+cnt: 1, (([T([361], i32), T([361], i32), T([361], i32), T([361], i32)], 1), {})
+cnt: 1, (([T([1438452, 1], f16), T([1438452, 1], f16), T([1438452, 1], f16), T([1438452, 1], f16)], 2), {})
+cnt: 4, (([T([5000, 2], f16), T([5000, 2], f16)], 2), {})
+cnt: 1, (([T([0, 91], f16), T([0, 91], f16), T([0, 91], f16), T([0, 91], f16)], 2), {})
+cnt: 4, (([T([0, 182], f16), T([0, 182], f16)], 2), {})
+cnt: 8, (([T([0], f16), T([0], f16), T([0], f16), T([0], f16)], 1), {})
+Operator: aten.sub.Tensor
+cnt: 1, ((T([3, 427, 640], f16, stride=(1, 1920, 3)), T([3, 1, 1], f16)), {})
+cnt: 1, ((T([3, 612, 612], f16, stride=(1, 1836, 3)), T([3, 1, 1], f16)), {})
+cnt: 1, ((T([3, 640, 443], f16, stride=(1, 1329, 3)), T([3, 1, 1], f16)), {})
+cnt: 1, ((T([3, 459, 640], f16, stride=(1, 1920, 3)), T([3, 1, 1], f16)), {})
+cnt: 7, ((T([1, 64, 1, 1], f16), T([1, 64, 1, 1], f16)), {})
+cnt: 16, ((T([1, 256, 1, 1], f16), T([1, 256, 1, 1], f16)), {})
+cnt: 8, ((T([1, 128, 1, 1], f16), T([1, 128, 1, 1], f16)), {})
+cnt: 11, ((T([1, 512, 1, 1], f16), T([1, 512, 1, 1], f16)), {})
+cnt: 7, ((T([1, 1024, 1, 1], f16), T([1, 1024, 1, 1], f16)), {})
+cnt: 4, ((T([1, 2048, 1, 1], f16), T([1, 2048, 1, 1], f16)), {})
+cnt: 2, ((T([1438452], f16, stride=(4,)), T([1438452], f16, stride=(4,))), {})
+cnt: 2, ((T([1438452, 1], f16), T([1438452, 1], f16)), {})
+cnt: 8, ((T([5000], f16, stride=(4,)), T([5000], f16, stride=(4,))), {})
+cnt: 16, ((T([0], f32), T([0], f32)), {})
+cnt: 2, ((T([0], i64), 2), {})
+cnt: 26, ((T([0], f16), T([0], f16)), {})
+cnt: 2, ((T([0, 91], f16), T([0, 91], f16)), {})
+Operator: aten.sum.SymInt
+cnt: 1, ((T([0, 364], f16), [0], True), {})
+cnt: 1, ((T([0, 91], f16), [0], True), {})
+cnt: 2, ((T([0, 1024], f16), [0], True), {})
+Operator: aten.sum.default
+cnt: 4, ((T([0, 4], f16),), {})
+cnt: 4, ((T([0], i64),), {})
+cnt: 4, ((T([0], f16),), {})
+cnt: 1, ((T([0, 1, 427, 640], f16),), {})
+cnt: 1, ((T([0, 1, 612, 612], f16),), {})
+cnt: 1, ((T([0, 1, 640, 443], f16),), {})
+cnt: 1, ((T([0, 1, 459, 640], f16),), {})
+Operator: aten.threshold_backward.default
+cnt: 2, ((T([0, 1024], f16), T([0, 1024], f16), 0), {})
+cnt: 3, ((T([4, 2048, 37, 38], f16), T([4, 2048, 37, 38], f16), 0), {})
+cnt: 5, ((T([4, 512, 37, 38], f16), T([4, 512, 37, 38], f16), 0), {})
+cnt: 1, ((T([4, 512, 74, 76], f16), T([4, 512, 74, 76], f16), 0), {})
+cnt: 6, ((T([4, 1024, 74, 76], f16), T([4, 1024, 74, 76], f16), 0), {})
+cnt: 11, ((T([4, 256, 74, 76], f16), T([4, 256, 74, 76], f16), 0), {})
+cnt: 1, ((T([4, 256, 148, 152], f16), T([4, 256, 148, 152], f16), 0), {})
+cnt: 4, ((T([4, 512, 148, 152], f16), T([4, 512, 148, 152], f16), 0), {})
+cnt: 7, ((T([4, 128, 148, 152], f16), T([4, 128, 148, 152], f16), 0), {})
+cnt: 1, ((T([4, 128, 296, 304], f16), T([4, 128, 296, 304], f16), 0), {})
+Operator: aten.topk.default
+cnt: 1, ((T([4, 269952], f16, stride=(359613, 1)), 1000, 1), {})
+cnt: 1, ((T([4, 67488], f16, stride=(359613, 1)), 1000, 1), {})
+cnt: 1, ((T([4, 16872], f16, stride=(359613, 1)), 1000, 1), {})
+cnt: 1, ((T([4, 4218], f16, stride=(359613, 1)), 1000, 1), {})
+cnt: 1, ((T([4, 1083], f16, stride=(359613, 1)), 1000, 1), {})
+Operator: aten.unbind.int
+cnt: 1, ((T([4, 5000, 4], f16),), {})
+cnt: 1, ((T([4, 5000], f16),), {})
+cnt: 1, ((T([4, 5000], i64),), {})
+cnt: 24, ((T([0, 1], i64), 1), {})
+cnt: 8, ((T([0, 4], f16), 1), {})
+cnt: 4, ((T([0, 182, 2], f16), 2), {})
+cnt: 1, ((T([0, 91, 4], f16), 2), {})
+Operator: aten.upsample_bilinear2d.vec
+cnt: 1, ((T([1, 3, 427, 640], f16, stride=(3, 1, 1920, 3)), [799, 1199], False, None), {})
+cnt: 1, ((T([1, 3, 612, 612], f16, stride=(3, 1, 1836, 3)), [800, 800], False, None), {})
+cnt: 1, ((T([1, 3, 640, 443], f16, stride=(3, 1, 1329, 3)), [1155, 800], False, None), {})
+cnt: 1, ((T([1, 3, 459, 640], f16, stride=(3, 1, 1920, 3)), [799, 1115], False, None), {})
+Operator: aten.upsample_nearest2d.vec
+cnt: 1, ((T([4, 256, 37, 38], f16), [74, 76], None), {})
+cnt: 1, ((T([4, 256, 74, 76], f16), [148, 152], None), {})
+cnt: 1, ((T([4, 256, 148, 152], f16), [296, 304], None), {})
+Operator: aten.upsample_nearest2d_backward.vec
+cnt: 1, ((T([4, 256, 296, 304], f16), [296, 304], [4, 256, 148, 152], None), {})
+cnt: 1, ((T([4, 256, 148, 152], f16), [148, 152], [4, 256, 74, 76], None), {})
+cnt: 1, ((T([4, 256, 74, 76], f16), [74, 76], [4, 256, 37, 38], None), {})
+Operator: aten.where.self
+cnt: 8, ((T([0, 182], b8), T([0, 182], f16), T([], f16)), {})
+cnt: 2, ((T([0, 91], b8), T([0, 91], f16), T([], f16)), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/yolov3_training.txt b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/yolov3_training.txt
new file mode 100644
index 0000000000000..c8ad368382fc8
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/yolov3_training.txt
@@ -0,0 +1,261 @@
+Operator: aten._to_copy.default
+cnt: 1, ((T([1, 1, 12, 16, 2], i64),), {'dtype': f32})
+cnt: 3, ((T([3, 2], f32),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 3, ((T([1, 3, 1, 1, 2], f32),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 1, ((T([1, 1, 24, 32, 2], i64),), {'dtype': f32})
+cnt: 1, ((T([1, 1, 48, 64, 2], i64),), {'dtype': f32})
+cnt: 2, ((T([8, 3, 48, 64, 2], f16),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 2, ((T([8, 3, 48, 64, 2], f32),), {'dtype': f16})
+cnt: 2, ((T([8, 3, 24, 32, 2], f16),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 2, ((T([8, 3, 24, 32, 2], f32),), {'dtype': f16})
+cnt: 2, ((T([8, 3, 12, 16, 2], f16),), {'dtype': f32, 'layout': torch.strided, 'device': 'cuda'})
+cnt: 2, ((T([8, 3, 12, 16, 2], f32),), {'dtype': f16})
+Operator: aten._unsafe_view.default
+cnt: 1, ((T([8, 3, 85, 48, 64], f16), [8, 255, 48, 64]), {})
+cnt: 1, ((T([8, 3, 85, 24, 32], f16), [8, 255, 24, 32]), {})
+cnt: 1, ((T([8, 3, 85, 12, 16], f16), [8, 255, 12, 16]), {})
+Operator: aten.add.Tensor
+cnt: 2, ((T([8, 64, 192, 256], f16), T([8, 64, 192, 256], f16)), {})
+cnt: 4, ((T([8, 128, 96, 128], f16), T([8, 128, 96, 128], f16)), {})
+cnt: 16, ((T([8, 256, 48, 64], f16), T([8, 256, 48, 64], f16)), {})
+cnt: 16, ((T([8, 512, 24, 32], f16), T([8, 512, 24, 32], f16)), {})
+cnt: 8, ((T([8, 1024, 12, 16], f16), T([8, 1024, 12, 16], f16)), {})
+cnt: 1, ((T([8, 3, 12, 16, 2], f16), T([1, 1, 12, 16, 2], f32)), {})
+cnt: 1, ((T([8, 3, 24, 32, 2], f16), T([1, 1, 24, 32, 2], f32)), {})
+cnt: 1, ((T([8, 3, 48, 64, 2], f16), T([1, 1, 48, 64, 2], f32)), {})
+cnt: 2, ((T([], f16), 0), {})
+cnt: 3, ((T([], f16), T([], f16)), {})
+cnt: 3, ((T([8, 3, 48, 64, 85], f16), T([8, 3, 48, 64, 85], f16)), {})
+cnt: 1, ((T([8, 3, 48, 64, 85], f16, stride=(0, 0, 0, 0, 0)), T([8, 3, 48, 64, 85], f16)), {})
+cnt: 3, ((T([8, 3, 24, 32, 85], f16), T([8, 3, 24, 32, 85], f16)), {})
+cnt: 1, ((T([8, 3, 24, 32, 85], f16, stride=(0, 0, 0, 0, 0)), T([8, 3, 24, 32, 85], f16)), {})
+cnt: 1, ((T([8, 256, 24, 32], f16), T([8, 256, 24, 32], f16)), {})
+cnt: 3, ((T([8, 3, 12, 16, 85], f16), T([8, 3, 12, 16, 85], f16)), {})
+cnt: 1, ((T([8, 3, 12, 16, 85], f16, stride=(0, 0, 0, 0, 0)), T([8, 3, 12, 16, 85], f16)), {})
+cnt: 3, ((T([8, 512, 12, 16], f16), T([8, 512, 12, 16], f16)), {})
+cnt: 1, ((T([8, 512, 12, 16], f16, stride=(393216, 192, 16, 1)), T([8, 512, 12, 16], f16)), {})
+cnt: 1, ((T([8, 512, 24, 32], f16, stride=(589824, 768, 32, 1)), T([8, 512, 24, 32], f16)), {})
+cnt: 1, ((T([8, 256, 48, 64], f16, stride=(1179648, 3072, 64, 1)), T([8, 256, 48, 64], f16)), {})
+Operator: aten.cat.default
+cnt: 1, (([T([8, 512, 12, 16], f16), T([8, 512, 12, 16], f16), T([8, 512, 12, 16], f16), T([8, 512, 12, 16], f16)], 1), {})
+cnt: 1, (([T([8, 256, 24, 32], f16), T([8, 512, 24, 32], f16)], 1), {})
+cnt: 1, (([T([8, 128, 48, 64], f16), T([8, 256, 48, 64], f16)], 1), {})
+cnt: 1, (([T([8, 576, 85], f16), T([8, 2304, 85], f16), T([8, 9216, 85], f16)], 1), {})
+Operator: aten.clone.default
+cnt: 1, ((T([8, 3, 384, 512], f16),), {})
+cnt: 1, ((T([8, 3, 12, 16, 85], f16),), {})
+cnt: 1, ((T([8, 3, 24, 32, 85], f16),), {})
+cnt: 1, ((T([8, 3, 48, 64, 85], f16),), {})
+Operator: aten.convolution.default
+cnt: 1, ((T([8, 3, 384, 512], f16), T([32, 3, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 32, 384, 512], f16), T([64, 32, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 64, 192, 256], f16), T([32, 64, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 32, 192, 256], f16), T([64, 32, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 64, 192, 256], f16), T([128, 64, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([8, 128, 96, 128], f16), T([64, 128, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 2, ((T([8, 64, 96, 128], f16), T([128, 64, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 128, 96, 128], f16), T([256, 128, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 10, ((T([8, 256, 48, 64], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 11, ((T([8, 128, 48, 64], f16), T([256, 128, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 256, 48, 64], f16), T([512, 256, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 10, ((T([8, 512, 24, 32], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 11, ((T([8, 256, 24, 32], f16), T([512, 256, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 512, 24, 32], f16), T([1024, 512, 3, 3], f16), None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([8, 1024, 12, 16], f16), T([512, 1024, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 7, ((T([8, 512, 12, 16], f16), T([1024, 512, 3, 3], f16), None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 2048, 12, 16], f16), T([512, 2048, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 1024, 12, 16], f16), T([255, 1024, 1, 1], f16), T([255], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 512, 12, 16], f16), T([256, 512, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 768, 24, 32], f16), T([256, 768, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 512, 24, 32], f16), T([255, 512, 1, 1], f16), T([255], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 256, 24, 32], f16), T([128, 256, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 384, 48, 64], f16), T([128, 384, 1, 1], f16), None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+cnt: 1, ((T([8, 256, 48, 64], f16), T([255, 256, 1, 1], f16), T([255], f16), [1, 1], [0, 0], [1, 1], False, [0, 0], 1), {})
+Operator: aten.convolution_backward.default
+cnt: 1, ((T([8, 255, 48, 64], f16), T([8, 256, 48, 64], f16), T([255, 256, 1, 1], f16), [255], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 11, ((T([8, 256, 48, 64], f16), T([8, 128, 48, 64], f16), T([256, 128, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 10, ((T([8, 128, 48, 64], f16), T([8, 256, 48, 64], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 128, 48, 64], f16), T([8, 384, 48, 64], f16), T([128, 384, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 128, 24, 32], f16), T([8, 256, 24, 32], f16), T([128, 256, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 255, 24, 32], f16), T([8, 512, 24, 32], f16), T([255, 512, 1, 1], f16), [255], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 11, ((T([8, 512, 24, 32], f16), T([8, 256, 24, 32], f16), T([512, 256, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 10, ((T([8, 256, 24, 32], f16), T([8, 512, 24, 32], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 256, 24, 32], f16), T([8, 768, 24, 32], f16), T([256, 768, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 256, 12, 16], f16), T([8, 512, 12, 16], f16), T([256, 512, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 255, 12, 16], f16), T([8, 1024, 12, 16], f16), T([255, 1024, 1, 1], f16), [255], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, True]), {})
+cnt: 7, ((T([8, 1024, 12, 16], f16), T([8, 512, 12, 16], f16), T([1024, 512, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 7, ((T([8, 512, 12, 16], f16), T([8, 1024, 12, 16], f16), T([512, 1024, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 512, 12, 16], f16), T([8, 2048, 12, 16], f16), T([512, 2048, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 1024, 12, 16], f16), T([8, 512, 24, 32], f16), T([1024, 512, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 512, 24, 32], f16), T([8, 256, 48, 64], f16), T([512, 256, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 256, 48, 64], f16), T([8, 128, 96, 128], f16), T([256, 128, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([8, 128, 96, 128], f16), T([8, 64, 96, 128], f16), T([128, 64, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 2, ((T([8, 64, 96, 128], f16), T([8, 128, 96, 128], f16), T([64, 128, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 128, 96, 128], f16), T([8, 64, 192, 256], f16), T([128, 64, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 64, 192, 256], f16), T([8, 32, 192, 256], f16), T([64, 32, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 32, 192, 256], f16), T([8, 64, 192, 256], f16), T([32, 64, 1, 1], f16), [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 64, 192, 256], f16), T([8, 32, 384, 512], f16), T([64, 32, 3, 3], f16), [0], [2, 2], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), {})
+cnt: 1, ((T([8, 32, 384, 512], f16), T([8, 3, 384, 512], f16), T([32, 3, 3, 3], f16), [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [False, True, False]), {})
+Operator: aten.copy_.default
+cnt: 1, ((T([8, 3, 384, 512], f16), T([8, 3, 384, 512], f16)), {})
+cnt: 2, ((T([8, 3, 12, 16, 2], f16, stride=(48960, 16320, 1360, 85, 1)), T([8, 3, 12, 16, 2], f32)), {})
+cnt: 1, ((T([8, 3, 12, 16, 4], f16, stride=(48960, 16320, 1360, 85, 1)), T([8, 3, 12, 16, 4], f16, stride=(48960, 16320, 1360, 85, 1))), {})
+cnt: 2, ((T([8, 3, 24, 32, 2], f16, stride=(195840, 65280, 2720, 85, 1)), T([8, 3, 24, 32, 2], f32)), {})
+cnt: 1, ((T([8, 3, 24, 32, 4], f16, stride=(195840, 65280, 2720, 85, 1)), T([8, 3, 24, 32, 4], f16, stride=(195840, 65280, 2720, 85, 1))), {})
+cnt: 2, ((T([8, 3, 48, 64, 2], f16, stride=(783360, 261120, 5440, 85, 1)), T([8, 3, 48, 64, 2], f32)), {})
+cnt: 1, ((T([8, 3, 48, 64, 4], f16, stride=(783360, 261120, 5440, 85, 1)), T([8, 3, 48, 64, 4], f16, stride=(783360, 261120, 5440, 85, 1))), {})
+cnt: 1, ((T([8, 3, 48, 64, 85], f16), T([8, 3, 48, 64, 85], f16, stride=(0, 0, 0, 0, 0))), {})
+cnt: 1, ((T([8, 3, 48, 64, 81], f16, stride=(783360, 261120, 5440, 85, 1)), T([8, 3, 48, 64, 81], f16)), {})
+cnt: 4, ((T([8, 3, 48, 64, 85], f16), T([8, 3, 48, 64, 85], f16)), {})
+cnt: 3, ((T([8, 3, 48, 64, 4], f16, stride=(783360, 261120, 5440, 85, 1)), T([8, 3, 48, 64, 4], f16)), {})
+cnt: 2, ((T([8, 3, 48, 64, 2], f16, stride=(783360, 261120, 5440, 85, 1)), T([8, 3, 48, 64, 2], f16)), {})
+cnt: 1, ((T([8, 3, 24, 32, 85], f16), T([8, 3, 24, 32, 85], f16, stride=(0, 0, 0, 0, 0))), {})
+cnt: 1, ((T([8, 3, 24, 32, 81], f16, stride=(195840, 65280, 2720, 85, 1)), T([8, 3, 24, 32, 81], f16)), {})
+cnt: 4, ((T([8, 3, 24, 32, 85], f16), T([8, 3, 24, 32, 85], f16)), {})
+cnt: 3, ((T([8, 3, 24, 32, 4], f16, stride=(195840, 65280, 2720, 85, 1)), T([8, 3, 24, 32, 4], f16)), {})
+cnt: 2, ((T([8, 3, 24, 32, 2], f16, stride=(195840, 65280, 2720, 85, 1)), T([8, 3, 24, 32, 2], f16)), {})
+cnt: 1, ((T([8, 3, 12, 16, 85], f16), T([8, 3, 12, 16, 85], f16, stride=(0, 0, 0, 0, 0))), {})
+cnt: 1, ((T([8, 3, 12, 16, 81], f16, stride=(48960, 16320, 1360, 85, 1)), T([8, 3, 12, 16, 81], f16)), {})
+cnt: 4, ((T([8, 3, 12, 16, 85], f16), T([8, 3, 12, 16, 85], f16)), {})
+cnt: 3, ((T([8, 3, 12, 16, 4], f16, stride=(48960, 16320, 1360, 85, 1)), T([8, 3, 12, 16, 4], f16)), {})
+cnt: 2, ((T([8, 3, 12, 16, 2], f16, stride=(48960, 16320, 1360, 85, 1)), T([8, 3, 12, 16, 2], f16)), {})
+Operator: aten.div.Tensor
+cnt: 2, ((T([], f16), 8225280), {})
+cnt: 2, ((T([], f16), 391680), {})
+cnt: 2, ((T([], f16), 1566720), {})
+cnt: 2, ((T([], f16), 6266880), {})
+cnt: 2, ((T([], f16), 3), {})
+cnt: 2, ((T([], f16), 2), {})
+Operator: aten.exp.default
+cnt: 1, ((T([8, 3, 12, 16, 2], f16, stride=(48960, 16320, 1360, 85, 1)),), {})
+cnt: 1, ((T([8, 3, 24, 32, 2], f16, stride=(195840, 65280, 2720, 85, 1)),), {})
+cnt: 1, ((T([8, 3, 48, 64, 2], f16, stride=(783360, 261120, 5440, 85, 1)),), {})
+Operator: aten.leaky_relu_.default
+cnt: 1, ((T([8, 32, 384, 512], f16), 0.1), {})
+cnt: 2, ((T([8, 64, 192, 256], f16), 0.1), {})
+cnt: 1, ((T([8, 32, 192, 256], f16), 0.1), {})
+cnt: 3, ((T([8, 128, 96, 128], f16), 0.1), {})
+cnt: 2, ((T([8, 64, 96, 128], f16), 0.1), {})
+cnt: 12, ((T([8, 256, 48, 64], f16), 0.1), {})
+cnt: 11, ((T([8, 128, 48, 64], f16), 0.1), {})
+cnt: 12, ((T([8, 512, 24, 32], f16), 0.1), {})
+cnt: 11, ((T([8, 256, 24, 32], f16), 0.1), {})
+cnt: 8, ((T([8, 1024, 12, 16], f16), 0.1), {})
+cnt: 8, ((T([8, 512, 12, 16], f16), 0.1), {})
+cnt: 1, ((T([8, 256, 12, 16], f16), 0.1), {})
+cnt: 1, ((T([8, 128, 24, 32], f16), 0.1), {})
+Operator: aten.leaky_relu_backward.default
+cnt: 12, ((T([8, 256, 48, 64], f16), T([8, 256, 48, 64], f16), 0.1, True), {})
+cnt: 11, ((T([8, 128, 48, 64], f16), T([8, 128, 48, 64], f16), 0.1, True), {})
+cnt: 1, ((T([8, 128, 24, 32], f16), T([8, 128, 24, 32], f16), 0.1, True), {})
+cnt: 12, ((T([8, 512, 24, 32], f16), T([8, 512, 24, 32], f16), 0.1, True), {})
+cnt: 11, ((T([8, 256, 24, 32], f16), T([8, 256, 24, 32], f16), 0.1, True), {})
+cnt: 1, ((T([8, 256, 12, 16], f16), T([8, 256, 12, 16], f16), 0.1, True), {})
+cnt: 8, ((T([8, 1024, 12, 16], f16), T([8, 1024, 12, 16], f16), 0.1, True), {})
+cnt: 8, ((T([8, 512, 12, 16], f16), T([8, 512, 12, 16], f16), 0.1, True), {})
+cnt: 3, ((T([8, 128, 96, 128], f16), T([8, 128, 96, 128], f16), 0.1, True), {})
+cnt: 2, ((T([8, 64, 96, 128], f16), T([8, 64, 96, 128], f16), 0.1, True), {})
+cnt: 2, ((T([8, 64, 192, 256], f16), T([8, 64, 192, 256], f16), 0.1, True), {})
+cnt: 1, ((T([8, 32, 192, 256], f16), T([8, 32, 192, 256], f16), 0.1, True), {})
+cnt: 1, ((T([8, 32, 384, 512], f16), T([8, 32, 384, 512], f16), 0.1, True), {})
+Operator: aten.max_pool2d_with_indices.default
+cnt: 1, ((T([8, 512, 12, 16], f16), [5, 5], [1, 1], [2, 2]), {})
+cnt: 1, ((T([8, 512, 12, 16], f16), [9, 9], [1, 1], [4, 4]), {})
+cnt: 1, ((T([8, 512, 12, 16], f16), [13, 13], [1, 1], [6, 6]), {})
+Operator: aten.max_pool2d_with_indices_backward.default
+cnt: 1, ((T([8, 512, 12, 16], f16, stride=(393216, 192, 16, 1)), T([8, 512, 12, 16], f16), [13, 13], [1, 1], [6, 6], [1, 1], False, T([8, 512, 12, 16], i64)), {})
+cnt: 1, ((T([8, 512, 12, 16], f16, stride=(393216, 192, 16, 1)), T([8, 512, 12, 16], f16), [9, 9], [1, 1], [4, 4], [1, 1], False, T([8, 512, 12, 16], i64)), {})
+cnt: 1, ((T([8, 512, 12, 16], f16, stride=(393216, 192, 16, 1)), T([8, 512, 12, 16], f16), [5, 5], [1, 1], [2, 2], [1, 1], False, T([8, 512, 12, 16], i64)), {})
+Operator: aten.mul.Tensor
+cnt: 1, ((T([8, 3, 12, 16, 2], f16), T([1, 3, 1, 1, 2], f32)), {})
+cnt: 1, ((T([8, 3, 24, 32, 2], f16), T([1, 3, 1, 1, 2], f32)), {})
+cnt: 1, ((T([8, 3, 48, 64, 2], f16), T([1, 3, 1, 1, 2], f32)), {})
+cnt: 1, ((T([8, 3, 48, 64, 4], f16), 8), {})
+cnt: 1, ((T([8, 3, 48, 64, 2], f32), T([1, 3, 1, 1, 2], f32)), {})
+cnt: 1, ((T([8, 3, 48, 64, 2], f16), T([8, 3, 48, 64, 2], f16)), {})
+cnt: 1, ((T([8, 3, 24, 32, 4], f16), 16), {})
+cnt: 1, ((T([8, 3, 24, 32, 2], f32), T([1, 3, 1, 1, 2], f32)), {})
+cnt: 1, ((T([8, 3, 24, 32, 2], f16), T([8, 3, 24, 32, 2], f16)), {})
+cnt: 1, ((T([8, 3, 12, 16, 4], f16), 32), {})
+cnt: 1, ((T([8, 3, 12, 16, 2], f32), T([1, 3, 1, 1, 2], f32)), {})
+cnt: 1, ((T([8, 3, 12, 16, 2], f16), T([8, 3, 12, 16, 2], f16)), {})
+Operator: aten.mul_.Tensor
+cnt: 1, ((T([8, 3, 12, 16, 4], f16, stride=(48960, 16320, 1360, 85, 1)), 32), {})
+cnt: 1, ((T([8, 3, 24, 32, 4], f16, stride=(195840, 65280, 2720, 85, 1)), 16), {})
+cnt: 1, ((T([8, 3, 48, 64, 4], f16, stride=(783360, 261120, 5440, 85, 1)), 8), {})
+Operator: aten.native_batch_norm.default
+cnt: 1, ((T([8, 32, 384, 512], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), False, 0.03, 0.0001), {})
+cnt: 2, ((T([8, 64, 192, 256], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.03, 0.0001), {})
+cnt: 1, ((T([8, 32, 192, 256], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f16), False, 0.03, 0.0001), {})
+cnt: 3, ((T([8, 128, 96, 128], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.03, 0.0001), {})
+cnt: 2, ((T([8, 64, 96, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f16), False, 0.03, 0.0001), {})
+cnt: 12, ((T([8, 256, 48, 64], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.03, 0.0001), {})
+cnt: 11, ((T([8, 128, 48, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.03, 0.0001), {})
+cnt: 12, ((T([8, 512, 24, 32], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.03, 0.0001), {})
+cnt: 11, ((T([8, 256, 24, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.03, 0.0001), {})
+cnt: 8, ((T([8, 1024, 12, 16], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f16), False, 0.03, 0.0001), {})
+cnt: 8, ((T([8, 512, 12, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f16), False, 0.03, 0.0001), {})
+cnt: 1, ((T([8, 256, 12, 16], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f16), False, 0.03, 0.0001), {})
+cnt: 1, ((T([8, 128, 24, 32], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f16), False, 0.03, 0.0001), {})
+Operator: aten.native_batch_norm_backward.default
+cnt: 12, ((T([8, 256, 48, 64], f16), T([8, 256, 48, 64], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 0.0001, [True, True, True]), {})
+cnt: 11, ((T([8, 128, 48, 64], f16), T([8, 128, 48, 64], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 0.0001, [True, True, True]), {})
+cnt: 1, ((T([8, 128, 24, 32], f16), T([8, 128, 24, 32], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 0.0001, [True, True, True]), {})
+cnt: 12, ((T([8, 512, 24, 32], f16), T([8, 512, 24, 32], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 0.0001, [True, True, True]), {})
+cnt: 11, ((T([8, 256, 24, 32], f16), T([8, 256, 24, 32], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 0.0001, [True, True, True]), {})
+cnt: 1, ((T([8, 256, 12, 16], f16), T([8, 256, 12, 16], f16), T([256], f16), T([256], f16), T([256], f16), T([256], f32), T([256], f32), False, 0.0001, [True, True, True]), {})
+cnt: 8, ((T([8, 1024, 12, 16], f16), T([8, 1024, 12, 16], f16), T([1024], f16), T([1024], f16), T([1024], f16), T([1024], f32), T([1024], f32), False, 0.0001, [True, True, True]), {})
+cnt: 8, ((T([8, 512, 12, 16], f16), T([8, 512, 12, 16], f16), T([512], f16), T([512], f16), T([512], f16), T([512], f32), T([512], f32), False, 0.0001, [True, True, True]), {})
+cnt: 3, ((T([8, 128, 96, 128], f16), T([8, 128, 96, 128], f16), T([128], f16), T([128], f16), T([128], f16), T([128], f32), T([128], f32), False, 0.0001, [True, True, True]), {})
+cnt: 2, ((T([8, 64, 96, 128], f16), T([8, 64, 96, 128], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 0.0001, [True, True, True]), {})
+cnt: 2, ((T([8, 64, 192, 256], f16), T([8, 64, 192, 256], f16), T([64], f16), T([64], f16), T([64], f16), T([64], f32), T([64], f32), False, 0.0001, [True, True, True]), {})
+cnt: 1, ((T([8, 32, 192, 256], f16), T([8, 32, 192, 256], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), False, 0.0001, [True, True, True]), {})
+cnt: 1, ((T([8, 32, 384, 512], f16), T([8, 32, 384, 512], f16), T([32], f16), T([32], f16), T([32], f16), T([32], f32), T([32], f32), False, 0.0001, [True, True, True]), {})
+Operator: aten.new_empty_strided.default
+cnt: 1, ((T([8, 3, 48, 64, 85], f16, stride=(0, 0, 0, 0, 0)), [8, 3, 48, 64, 85], [783360, 261120, 5440, 85, 1]), {})
+cnt: 4, ((T([8, 3, 48, 64, 85], f16), [8, 3, 48, 64, 85], [783360, 261120, 5440, 85, 1]), {})
+cnt: 1, ((T([8, 3, 24, 32, 85], f16, stride=(0, 0, 0, 0, 0)), [8, 3, 24, 32, 85], [195840, 65280, 2720, 85, 1]), {})
+cnt: 4, ((T([8, 3, 24, 32, 85], f16), [8, 3, 24, 32, 85], [195840, 65280, 2720, 85, 1]), {})
+cnt: 1, ((T([8, 3, 12, 16, 85], f16, stride=(0, 0, 0, 0, 0)), [8, 3, 12, 16, 85], [48960, 16320, 1360, 85, 1]), {})
+cnt: 4, ((T([8, 3, 12, 16, 85], f16), [8, 3, 12, 16, 85], [48960, 16320, 1360, 85, 1]), {})
+Operator: aten.new_zeros.default
+cnt: 1, ((T([8, 3, 48, 64, 4], f16), [6266880]), {})
+cnt: 1, ((T([8, 3, 24, 32, 4], f16), [1566720]), {})
+cnt: 1, ((T([8, 3, 12, 16, 4], f16), [391680]), {})
+Operator: aten.sigmoid.default
+cnt: 1, ((T([8, 3, 12, 16, 2], f16, stride=(48960, 16320, 1360, 85, 1)),), {})
+cnt: 1, ((T([8, 3, 24, 32, 2], f16, stride=(195840, 65280, 2720, 85, 1)),), {})
+cnt: 1, ((T([8, 3, 48, 64, 2], f16, stride=(783360, 261120, 5440, 85, 1)),), {})
+Operator: aten.sigmoid_.default
+cnt: 1, ((T([8, 3, 12, 16, 81], f16, stride=(48960, 16320, 1360, 85, 1)),), {})
+cnt: 1, ((T([8, 3, 24, 32, 81], f16, stride=(195840, 65280, 2720, 85, 1)),), {})
+cnt: 1, ((T([8, 3, 48, 64, 81], f16, stride=(783360, 261120, 5440, 85, 1)),), {})
+Operator: aten.sigmoid_backward.default
+cnt: 1, ((T([8, 3, 48, 64, 81], f16), T([8, 3, 48, 64, 81], f16, stride=(783360, 261120, 5440, 85, 1))), {})
+cnt: 1, ((T([8, 3, 48, 64, 2], f16), T([8, 3, 48, 64, 2], f16)), {})
+cnt: 1, ((T([8, 3, 24, 32, 81], f16), T([8, 3, 24, 32, 81], f16, stride=(195840, 65280, 2720, 85, 1))), {})
+cnt: 1, ((T([8, 3, 24, 32, 2], f16), T([8, 3, 24, 32, 2], f16)), {})
+cnt: 1, ((T([8, 3, 12, 16, 81], f16), T([8, 3, 12, 16, 81], f16, stride=(48960, 16320, 1360, 85, 1))), {})
+cnt: 1, ((T([8, 3, 12, 16, 2], f16), T([8, 3, 12, 16, 2], f16)), {})
+Operator: aten.slice_backward.default
+cnt: 1, ((T([8, 3, 48, 64, 2], f16), [8, 3, 48, 64, 85], 4, 2, 4, 1), {})
+cnt: 1, ((T([8, 3, 48, 64, 2], f16), [8, 3, 48, 64, 85], 4, 0, 2, 1), {})
+cnt: 1, ((T([8, 3, 24, 32, 2], f16), [8, 3, 24, 32, 85], 4, 2, 4, 1), {})
+cnt: 1, ((T([8, 3, 24, 32, 2], f16), [8, 3, 24, 32, 85], 4, 0, 2, 1), {})
+cnt: 1, ((T([8, 3, 12, 16, 2], f16), [8, 3, 12, 16, 85], 4, 2, 4, 1), {})
+cnt: 1, ((T([8, 3, 12, 16, 2], f16), [8, 3, 12, 16, 85], 4, 0, 2, 1), {})
+Operator: aten.stack.default
+cnt: 1, (([T([12, 16], i64, stride=(0, 1)), T([12, 16], i64, stride=(1, 0))], 2), {})
+cnt: 1, (([T([24, 32], i64, stride=(0, 1)), T([24, 32], i64, stride=(1, 0))], 2), {})
+cnt: 1, (([T([48, 64], i64, stride=(0, 1)), T([48, 64], i64, stride=(1, 0))], 2), {})
+Operator: aten.sum.default
+cnt: 1, ((T([8, 12096, 85], f16),), {})
+cnt: 1, ((T([8, 3, 12, 16, 85], f16),), {})
+cnt: 1, ((T([8, 3, 24, 32, 85], f16),), {})
+cnt: 1, ((T([8, 3, 48, 64, 85], f16),), {})
+Operator: aten.upsample_nearest2d.vec
+cnt: 1, ((T([8, 256, 12, 16], f16), None, [2.0, 2.0]), {})
+cnt: 1, ((T([8, 128, 24, 32], f16), None, [2.0, 2.0]), {})
+Operator: aten.upsample_nearest2d_backward.vec
+cnt: 1, ((T([8, 128, 48, 64], f16, stride=(1179648, 3072, 64, 1)), None, [8, 128, 24, 32], [2.0, 2.0]), {})
+cnt: 1, ((T([8, 256, 24, 32], f16, stride=(589824, 768, 32, 1)), None, [8, 256, 12, 16], [2.0, 2.0]), {})
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py
new file mode 100644
index 0000000000000..15037d70a0d16
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py
@@ -0,0 +1,327 @@
+import functools
+import logging
+import math
+import os
+from collections import Counter, defaultdict
+from functools import partial
+from typing import Any, Dict, Generator, Iterable, Tuple
+
+import torch
+from torch.testing import make_tensor
+from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils._pytree import tree_flatten, tree_map
+
+log = logging.getLogger(__name__)
+
+OP_INP_DIRECTORY = os.path.join(os.path.dirname(__file__), "operator_inp_logs")
+
+TIMM_DIR = os.path.join(OP_INP_DIRECTORY, "timm_train")
+HF_DIR = os.path.join(OP_INP_DIRECTORY, "hf_train")
+TORCHBENCH_DIR = os.path.join(OP_INP_DIRECTORY, "torchbench_train")
+
+aten = torch.ops.aten
+tensor_type = torch._C.TensorType.get()
+
+dtype_abbrs = {
+ torch.bfloat16: "bf16",
+ torch.float64: "f64",
+ torch.float32: "f32",
+ torch.float16: "f16",
+ torch.complex32: "c32",
+ torch.complex64: "c64",
+ torch.complex128: "c128",
+ torch.int8: "i8",
+ torch.int16: "i16",
+ torch.int32: "i32",
+ torch.int64: "i64",
+ torch.bool: "b8",
+ torch.uint8: "u8",
+}
+
+dtype_abbrs_parsing = {value: key for key, value in dtype_abbrs.items()}
+
+
+def truncate_inp(arg):
+ if arg in dtype_abbrs:
+ return dtype_abbrs[arg]
+ elif isinstance(arg, torch.device):
+ return arg.type
+ else:
+ return arg
+
+
+# Serialize Function Call
+class FuncCallWrapper:
+ def __init__(self, call, *args, **kwargs):
+ self.call = call
+ self.args = tree_map(truncate_inp, args)
+ self.kwargs = tree_map(truncate_inp, kwargs) if kwargs is not None else {}
+
+ def __repr__(self):
+ args = ", ".join([repr(arg) for arg in self.args])
+ kwargs = "".join(
+ [f", {str(key)}={value}" for key, value in self.kwargs.items()]
+ )
+ out = f"{self.call}({args}{kwargs})".strip('"')
+ # f strings introduce quotations we dont want
+ for key in dtype_abbrs_parsing:
+ out = out.replace(f"'{key}'", key)
+ return out
+
+
+def serialize_sparse_tensor(e):
+ if isinstance(e, torch._subclasses.FakeTensor):
+ return FuncCallWrapper("ST", list(e.shape), e.dtype, e.layout, e.is_coalesced())
+ else:
+ return FuncCallWrapper(
+ "ST", list(e.shape), e.dtype, e.layout, e.is_coalesced(), e._nnz()
+ )
+
+
+def deserialize_sparse_tensor(size, dtype, layout, is_coalesced, nnz=None):
+ raise NotImplementedError()
+
+
+def deserialize_tensor(size, dtype, stride=None):
+ if stride is not None:
+ out = torch.empty_strided(size, stride, dtype=dtype)
+ else:
+ out = torch.empty(size, dtype=dtype)
+ try:
+ out.copy_(make_tensor(size, dtype=dtype, device="cpu"))
+ except Exception as e:
+ print(e)
+ return out
+ return out
+
+
+def serialize_tensor(e):
+ if not e.is_contiguous():
+ return FuncCallWrapper("T", list(e.shape), e.dtype, stride=e.stride())
+ else:
+ return FuncCallWrapper("T", list(e.shape), e.dtype)
+
+
+def serialize_torch_args(e):
+ if isinstance(e, torch.Tensor):
+ if e.is_sparse:
+ return serialize_sparse_tensor(e)
+ return serialize_tensor(e)
+ else:
+ return truncate_inp(e)
+
+
+def contains_tensor(elems):
+ for elem in tree_flatten(elems)[0]:
+ if isinstance(elem, torch.Tensor):
+ return True
+ return False
+
+
+def skip_args(elems):
+ for i in tree_flatten(elems)[0]:
+ # only shows up in constructors and ops like that
+ if isinstance(i, (torch.memory_format, torch.storage.UntypedStorage)):
+ return True
+ return False
+
+
+def contains_tensor_types(type):
+ return type.isSubtypeOf(tensor_type) or any(
+ contains_tensor_types(e) for e in type.containedTypes()
+ )
+
+
+@functools.lru_cache(None)
+def non_compute_operator(op):
+ schema = op._schema
+
+ # skip constructors
+ if not any(contains_tensor_types(arg.type) for arg in schema.arguments):
+ return True
+ if "_like" in op.name:
+ return True
+
+ # allow in place writes
+ if schema.is_mutable:
+ return False
+
+ tensor_inps = [arg for arg in schema.arguments if arg.type is tensor_type]
+ tensor_outputs = [ret for ret in schema.returns if ret.type is tensor_type]
+
+ # skip aliasing unless there are multiple outputs
+ if len(tensor_outputs) != 1:
+ return False
+
+ for inp in tensor_inps:
+ if inp.alias_info and tensor_outputs[0].alias_info:
+ if inp.alias_info.before_set.intersection(
+ tensor_outputs[0].alias_info.after_set
+ ):
+ return True
+
+ return False
+
+
+class OperatorInputsMode(TorchDispatchMode):
+ def __init__(self, func_db=None):
+ self.func_db = defaultdict(Counter) if func_db is None else func_db
+
+ def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None):
+ kwargs = kwargs if kwargs else {}
+ arg_meta, kwarg_meta = tree_map(serialize_torch_args, (args, kwargs))
+
+ out = func_overload(*args, **kwargs)
+
+ inps = (args, kwargs)
+ if contains_tensor(inps) and not skip_args(inps) and contains_tensor(out):
+ serialized_str = repr((arg_meta, kwarg_meta))
+ self.func_db[str(func_overload)][serialized_str] += 1
+
+ return out
+
+ def log_to_file(self, output_filename, *, skip_non_compute_operators=True):
+ sorted_operators = sorted(list(self.func_db.keys()))
+ with open(output_filename, "w") as f:
+ for operator in sorted_operators:
+ if skip_non_compute_operators and non_compute_operator(eval(operator)):
+ continue
+ f.write(f"Operator: {operator}\n")
+ operator_inputs = self.func_db[operator]
+ for inps, count in operator_inputs.items():
+ f.write(f"cnt: {count}, ")
+ # repr will add quotation marks around the dtype strings
+ for dtype_abbr in dtype_abbrs.values():
+ inps = inps.replace("'" + dtype_abbr + "'", dtype_abbr)
+ f.write(inps)
+ f.write("\n")
+
+
+def map_to_device(e, device):
+ return e.to(device) if isinstance(e, torch.Tensor) else e
+
+
+def map_to_dtype(e, dtype):
+ if isinstance(e, torch.Tensor) and e.is_floating_point():
+ return e.to(dtype)
+ else:
+ return e
+
+
+def deserialize_args(inps):
+ inps = inps.strip().strip("'")
+ global_vals = {
+ **{
+ "T": deserialize_tensor,
+ "ST": deserialize_sparse_tensor,
+ "th": torch,
+ "inf": math.inf,
+ "torch": torch,
+ },
+ **dtype_abbrs_parsing,
+ }
+ # f strings introduce quotations we dont want
+ for key in dtype_abbrs_parsing:
+ inps = inps.replace(f"'{key}'", key)
+ return eval(inps.strip().strip("'").strip('"'), global_vals)
+
+
+class OperatorInputsLoader:
+ def __init__(self, json_file_path):
+ self.operator_db = defaultdict(Counter)
+
+ with open(json_file_path, "r") as f:
+ lines = f.readlines()
+
+ i = 0
+ while i < len(lines):
+ op_line = lines[i].strip("\n")
+ assert "Operator: " in op_line, op_line
+ operator = op_line[len("Operator: ") :]
+ operator = (
+ operator if operator != "aten.sum.SymInt" else "aten.sum.dim_IntList"
+ )
+ op_inps = Counter()
+ i += 1
+ while i < len(lines) and "Operator: " not in lines[i]:
+ line = lines[i]
+ cnt = eval(line[len("cnt: ") : line.find(",")])
+ inps = line[line.find(",") + 2 :].strip("'")
+ op_inps[inps] += cnt
+ i += 1
+ self.operator_db[operator] = op_inps
+
+ def get_inputs_for_operator(
+ self, operator, dtype=None, device="cuda"
+ ) -> Generator[Tuple[Iterable[Any], Dict[str, Any]], None, None]:
+ assert (
+ str(operator) in self.operator_db
+ ), f"Could not find {operator}, must provide overload"
+
+ if "embedding" in str(operator):
+ log.warning("Embedding inputs NYI, input data cannot be randomized")
+ yield
+ return
+
+ # line[1] represents number of times these inputs occured, ignored for now
+ for line in self.operator_db[str(operator)].items():
+ inps = line[0]
+
+ args, kwargs = deserialize_args(inps)
+
+ # Backwards require some inputs to be float16 and some to be float32
+ # So we record on half and upcast to float when specified
+ if dtype and dtype != torch.float16:
+ to_dtype = partial(map_to_dtype, dtype=dtype)
+ args, kwargs = tree_map(to_dtype, (args, kwargs))
+
+ if device:
+ to_device = partial(map_to_device, device=torch.device(device))
+ args, kwargs = tree_map(to_device, (args, kwargs))
+
+ yield args, kwargs
+
+ def get_all_ops(self):
+ for key in self.operator_db.keys():
+ yield eval(key)
+
+ def get_call_frequency(self, op):
+ assert (
+ str(op) in self.operator_db
+ ), f"Could not find {op}, must provide overload"
+
+ count = 0
+ for _, counter in self.operator_db[str(op)].items():
+ count += counter
+ return count
+
+ def merge(self, other):
+ for operator, counter_dict in other.operator_db.items():
+ for inps, cnt in counter_dict.items():
+ self.operator_db[operator][inps] += cnt
+
+ @staticmethod
+ def get_timm_loader():
+ return OperatorInputsLoader._load_directory(TIMM_DIR)
+
+ @staticmethod
+ def get_huggingface_loader():
+ return OperatorInputsLoader._load_directory(HF_DIR)
+
+ @staticmethod
+ def get_torchbench_loader():
+ return OperatorInputsLoader._load_directory(TORCHBENCH_DIR)
+
+ @staticmethod
+ def _load_directory(inp_dir):
+ assert os.path.isdir(inp_dir), inp_dir
+ union = None
+ for inp in os.listdir(inp_dir):
+ if inp[-4:] != ".txt":
+ continue
+ path = os.path.join(inp_dir, inp)
+ if union is None:
+ union = OperatorInputsLoader(path)
+ else:
+ union.merge(OperatorInputsLoader(path))
+ return union
diff --git a/benchmarks/dynamo/microbenchmarks/operatorbench.py b/benchmarks/dynamo/microbenchmarks/operatorbench.py
new file mode 100644
index 0000000000000..fcc15bf5d9326
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/operatorbench.py
@@ -0,0 +1,226 @@
+#!/usr/bin/env python3
+import click
+import numpy as np
+import torch
+import triton
+from operator_inp_utils import OperatorInputsLoader
+
+from torch._dynamo.optimizations.backends import cudagraphs_inner
+from torch._dynamo.testing import same
+from torch._inductor import config as inductor_config
+from torch._inductor.compile_fx import compile_fx
+from torch._inductor.decomposition import decompositions
+from torch._inductor.lowering import fallbacks, lowerings
+from torch._inductor.utils import gen_gm_and_inputs
+
+aten = torch.ops.aten
+
+
+def compute_speedups(
+ operator, models, example_inputs, repeats, accuracy_checking=False
+):
+ expected = models[0](*example_inputs)
+ if accuracy_checking:
+ for model in models[1:]:
+ actual = model(*example_inputs)
+ # change to assert later
+ try:
+ same(actual, expected, cos_similarity=True, equal_nan=True)
+ except AssertionError as e:
+ print(e)
+ print(f"Accuracy check failed: {operator}")
+ print((expected[0] - actual[0]).abs().max())
+
+ timings = np.zeros((repeats, len(models)), np.float64)
+ for rep in range(repeats):
+ # interleave the runs to handle frequency scaling and load changes
+ for m, model in enumerate(models):
+ # do_bench() clears L2 cache to hide the latency of CPU launch time
+ # along with cuda synchronization
+ median_ms, _, _ = triton.testing.do_bench(lambda: model(*example_inputs))
+ timings[rep, m] = median_ms
+ return np.median(timings, axis=0)
+
+
+def strip_overloads(gm):
+ """
+ Modifies the target of graph nodes in :attr:`gm` to strip overloads.
+ Args:
+ gm(fx.GraphModule): The input Fx graph module to be modified
+ """
+ for node in gm.graph.nodes:
+ if isinstance(node.target, torch._ops.OpOverload):
+ node.target = node.target.overloadpacket
+ gm.recompile()
+
+
+def convert_to_jit(gm, gm_args):
+ strip_overloads(gm)
+ try:
+ return torch.jit.script(gm)
+ except Exception:
+ pass
+ return torch.jit.trace(gm, gm_args)
+
+
+def microbenchmark(
+ operator, args, kwargs, dtype, accuracy_checking, repeats, measure_nvfuser
+):
+ gm, gm_args = gen_gm_and_inputs(operator, args, kwargs)
+ torch.jit._builtins._register_builtin(
+ torch.ops.aten.convolution_backward.default, "aten::convolution_backward"
+ )
+ cudagraphs_eager = cudagraphs_inner(gm, gm_args, copy_outputs=False)
+ compiled_fn = compile_fx(gm, gm_args)
+ compiled = [cudagraphs_eager, compiled_fn]
+ if measure_nvfuser:
+ g = convert_to_jit(gm, gm_args)
+ cudagraphs_jit = cudagraphs_inner(g, gm_args, copy_outputs=False)
+ compiled += [cudagraphs_jit]
+ if accuracy_checking:
+ repeats = 1
+
+ medians = compute_speedups(operator, compiled, gm_args, repeats, accuracy_checking)
+ return medians
+
+
+def skip_operator(operator):
+ nyi_strings = (
+ "aten.gather.default",
+ "nll_loss",
+ "aten.index",
+ "aten.scatter_",
+ "masked_fill_.Scalar",
+ )
+
+ if any(nyi_string in str(operator) for nyi_string in nyi_strings):
+ # maybe disable aten.native_layer_norm.default
+ # TODO - inputs cannot be randomly initialized, causes cyda failures
+ print(f"Skipping {operator}, input generator nyi")
+ return True
+
+ # not covered by other non-compute operator heuristics
+ if operator == torch.ops.aten._unsafe_view.default:
+ print(f"Skipping {operator}, non compute operator")
+ return True
+
+ # some of inductor registered to the OpOverload, some registered to OpOverloadPacket
+ op_impls = [operator]
+ if isinstance(operator, torch._ops.OpOverload):
+ op_impls.append(operator.overloadpacket)
+
+ if any(op in fallbacks for op in op_impls):
+ print(f"Skipping {operator}, no inductor impl")
+ return True
+
+ if all(op not in decompositions and op not in lowerings for op in op_impls):
+ print(f"Skipping {operator}, no inductor impl")
+ return True
+
+ if inductor_config.triton.convolution == "aten" and "convolution" in str(operator):
+ return True
+
+ if inductor_config.triton.mm == "aten" and operator in (
+ aten.mm.default,
+ aten.bmm.default,
+ aten.addmm.default,
+ aten.matmul.default,
+ ):
+ return True
+
+ return False
+
+
+@click.command()
+@click.option(
+ "--suite",
+ help="suite to load inps from: options: timm, huggingface, torchbench",
+ default="torchbench",
+)
+@click.option("--op", help="operator overload to benchmark")
+@click.option("--dtype", help="dtype to benchmark")
+@click.option("--max-samples", help="max samples per op", default=15)
+@click.option("--accuracy-checking", help="check accuracy", default=False)
+@click.option(
+ "--repeats", help="how many times to repeat for perf measurement", default=3
+)
+@click.option(
+ "--measure-nvfuser", help="default we only measure inductor", default=False
+)
+def benchmark(
+ suite, op, dtype, max_samples, accuracy_checking, repeats, measure_nvfuser
+):
+ assert suite in ("timm", "huggingface", "torchbench"), f"got {suite}"
+ if suite == "timm":
+ loader = OperatorInputsLoader.get_timm_loader()
+ elif suite == "huggingface":
+ loader = OperatorInputsLoader.get_huggingface_loader()
+ else:
+ loader = OperatorInputsLoader.get_torchbench_loader()
+
+ assert dtype in ("float16", "float32"), f"got {dtype}"
+
+ if op == "all":
+ filename = f"timings_{suite}_{op.replace('.', '_')}{dtype}.txt"
+ f = open(filename, "a")
+
+ dtype = torch.float16 if dtype == "float16" else torch.float32
+
+ if op == "all":
+ ops = loader.get_all_ops()
+ else:
+ ops = [eval(op)]
+
+ for operator in ops:
+ if skip_operator(operator):
+ continue
+
+ print(f"Running {operator}")
+ inp_gen = loader.get_inputs_for_operator(operator, dtype=dtype)
+ timings = []
+
+ for i in range(min(max_samples, 1000000)):
+ print(f"Iter {i}")
+ try:
+ inps = next(inp_gen)
+ if inps is None:
+ break
+ args, kwargs = inps
+ except StopIteration:
+ break
+ try:
+ # aten, nvfuser, inductor
+ timings.append(
+ microbenchmark(
+ operator,
+ args,
+ kwargs,
+ dtype,
+ accuracy_checking,
+ repeats,
+ measure_nvfuser,
+ )
+ )
+ except Exception as e:
+ print(f"error {operator}")
+ print(e)
+ raise e
+
+ if not timings:
+ continue
+
+ timings = torch.tensor(timings).T
+ q = torch.tensor([0.2, 0.5, 0.8], dtype=torch.float64)
+ output = f"{operator}:\nInductor Speedups : {(torch.quantile(timings[0] / timings[1], q)).tolist()}\n"
+ if measure_nvfuser:
+ output += f"NVFUSER Speedups :{(torch.quantile(timings[0] / timings[2], q)).tolist()}\n"
+ if op == "all":
+ f.write(output)
+ print(output)
+
+ if op == "all":
+ f.close()
+
+
+if __name__ == "__main__":
+ benchmark()
diff --git a/benchmarks/dynamo/microbenchmarks/profile_conv.py b/benchmarks/dynamo/microbenchmarks/profile_conv.py
new file mode 100644
index 0000000000000..1d57414d94210
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/profile_conv.py
@@ -0,0 +1,107 @@
+import torch
+
+import torch._inductor.triton_ops
+from torch.profiler import profile, ProfilerActivity, record_function
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to True.
+torch.backends.cuda.matmul.allow_tf32 = True
+# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+torch.backends.cudnn.allow_tf32 = True
+
+
+(
+ BATCH,
+ IN_C,
+ IN_H,
+ IN_W,
+ KERNEL_N,
+ KERNEL_H,
+ KERNEL_W,
+ stride,
+ padding,
+ dilation,
+ groups,
+ dtype,
+) = (32, 56, 56, 64, 3, 3, 64, (1, 1), (0, 0), (1, 1), 1, torch.float32)
+
+
+def profile_op(
+ # provider
+ provider,
+ # Tensor dimensions
+ BATCH,
+ IN_C,
+ IN_H,
+ IN_W,
+ KERNEL_N,
+ KERNEL_H,
+ KERNEL_W,
+ # parameters of conv
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ groups=1,
+ dtype=torch.float16,
+ layout="nhwc",
+ warmup=25,
+ rep=50,
+):
+
+ # allocate inputs, nchw
+ x = torch.randn((BATCH, IN_C, IN_H, IN_W), dtype=dtype, device="cuda")
+ w = torch.randn(
+ (KERNEL_N, IN_C // groups, KERNEL_H, KERNEL_W), dtype=dtype, device="cuda"
+ )
+ bias = torch.randn((KERNEL_N), dtype=dtype, device="cuda")
+ if layout == "nhwc":
+ x = x.to(memory_format=torch.channels_last)
+ w = w.to(memory_format=torch.channels_last)
+
+ if provider == "cublas":
+
+ def fn():
+ return torch.conv2d(x, w, bias, stride, padding, dilation, groups)
+
+ elif provider == "triton":
+
+ def fn():
+ return torch._inductor.triton_ops.conv(
+ x, w, bias, stride, padding, dilation, False, (0, 0), groups
+ )
+
+ else:
+ raise ValueError(f"{provider} not supported")
+ # warm up
+ for _ in range(warmup):
+ fn()
+ with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
+ with record_function("model_inference"):
+ for _ in range(rep):
+ fn()
+
+ print("Profiling ", provider)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+
+for provider in ["cublas", "triton"]:
+ profile_op(
+ # provider
+ provider,
+ # Tensor dimensions
+ BATCH,
+ IN_C,
+ IN_H,
+ IN_W,
+ KERNEL_N,
+ KERNEL_H,
+ KERNEL_W,
+ # parameters of conv
+ stride,
+ padding,
+ dilation,
+ groups,
+ dtype=dtype,
+ layout="nhwc",
+ warmup=25,
+ rep=50,
+ )
diff --git a/benchmarks/dynamo/microbenchmarks/utils.py b/benchmarks/dynamo/microbenchmarks/utils.py
new file mode 100644
index 0000000000000..18972ba09ae62
--- /dev/null
+++ b/benchmarks/dynamo/microbenchmarks/utils.py
@@ -0,0 +1,19 @@
+import math
+
+import torch
+
+
+def rounded_linspace(low, high, steps, div):
+ ret = torch.linspace(low, high, steps)
+ ret = (ret.int() + div - 1) // div * div
+ ret = torch.unique(ret)
+ return list(map(int, ret))
+
+
+def powspace(start, stop, pow, step):
+ start = math.log(start, pow)
+ stop = math.log(stop, pow)
+ steps = int((stop - start + 1) // step)
+ ret = torch.pow(pow, torch.linspace(start, stop, steps))
+ ret = torch.unique(ret)
+ return list(map(int, ret))
diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py
new file mode 100755
index 0000000000000..7dcb51b78d951
--- /dev/null
+++ b/benchmarks/dynamo/runner.py
@@ -0,0 +1,870 @@
+#!/usr/bin/env python3
+
+"""
+A wrapper over the benchmark infrastructure to generate commonly used commands,
+parse results and generate csv/graphs.
+
+The script works on manually written TABLE (see below). We can add more commands
+in the future.
+
+One example usage is
+-> python benchmarks/runner.py --suites=torchbench --inference
+This command will generate the commands for the default compilers (see DEFAULTS
+below) for inference, run them and visualize the logs.
+
+If you want to just print the commands, you could use the following command
+-> python benchmarks/runner.py --print_run_commands --suites=torchbench --inference
+
+Similarly, if you want to just visualize the already finished logs
+-> python benchmarks/runner.py --visualize_logs --suites=torchbench --inference
+
+If you want to test float16
+-> python benchmarks/runner.py --suites=torchbench --inference --dtypes=float16
+
+"""
+
+
+import argparse
+import dataclasses
+import glob
+import importlib
+import io
+import itertools
+import logging
+import os
+import shutil
+import subprocess
+from collections import defaultdict
+from datetime import datetime
+from os.path import abspath, exists
+from random import randint
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import torch
+
+import torch._dynamo
+from matplotlib import rcParams
+from scipy.stats import gmean
+from tabulate import tabulate
+
+rcParams.update({"figure.autolayout": True})
+plt.rc("axes", axisbelow=True)
+
+DEFAULT_OUTPUT_DIR = "benchmark_logs"
+
+
+log = logging.getLogger(__name__)
+
+TABLE = {
+ "training": {
+ "ts_nnc": "--training --speedup-ts ",
+ "ts_nvfuser": "--training --nvfuser --speedup-dynamo-ts ",
+ "eager": "--training --backend=eager ",
+ "aot_eager": "--training --backend=aot_eager ",
+ "aot_cudagraphs": "--training --backend=aot_cudagraphs ",
+ "aot_nvfuser": "--training --nvfuser --backend=aot_nvfuser ",
+ "inductor": "--training --inductor ",
+ },
+ "inference": {
+ "ts_nnc": "--speedup-ts",
+ "ts_nvfuser": "-n100 --speedup-ts --nvfuser",
+ "trt": "-n100 --speedup-trt",
+ "ts_nvfuser_cudagraphs": "--inductor-settings --float32 -n50 --backend=cudagraphs_ts",
+ "inductor": "--inductor-settings --float32 -n50 --inductor",
+ },
+}
+
+INFERENCE_COMPILERS = tuple(TABLE["inference"].keys())
+TRAINING_COMPILERS = tuple(TABLE["training"].keys())
+
+DEFAULTS = {
+ "training": [
+ "eager",
+ "aot_eager",
+ "aot_cudagraphs",
+ "aot_nvfuser",
+ "inductor",
+ ],
+ "inference": ["ts_nvfuser_cudagraphs", "inductor"],
+ "dtypes": [
+ "float32",
+ ],
+ "suites": ["torchbench", "huggingface", "timm_models"],
+ "devices": [
+ "cuda",
+ ],
+ "quick": {
+ "torchbench": '-k "resnet..$"',
+ "huggingface": "-k Albert",
+ "timm_models": ' -k "^resnet" -k "^inception"',
+ },
+}
+
+
+DASHBOARD_DEFAULTS = {
+ "dashboard_image_uploader": "/fsx/users/anijain/bin/imgur.sh",
+ "dashboard_archive_path": "/data/home/anijain/cluster/cron_logs",
+ "dashboard_gh_cli_path": "/data/home/anijain/miniconda/bin/gh",
+}
+
+
+def percentage(part, whole, decimals=2):
+ if whole == 0:
+ return 0
+ return round(100 * float(part) / float(whole), decimals)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--devices", action="append", help="cpu or cuda")
+ parser.add_argument("--dtypes", action="append", help="float16/float32/amp")
+ parser.add_argument("--suites", action="append", help="huggingface/torchbench/timm")
+ parser.add_argument(
+ "--compilers",
+ action="append",
+ help=f"For --inference, options are {INFERENCE_COMPILERS}. For --training, options are {TRAINING_COMPILERS}",
+ )
+ parser.add_argument(
+ "--quick", action="store_true", help="Just runs one model. Helps in debugging"
+ )
+ parser.add_argument(
+ "--output-dir",
+ help="Choose the output directory to save the logs",
+ default=DEFAULT_OUTPUT_DIR,
+ )
+
+ # Choose either generation of commands, pretty parsing or e2e runs
+ group = parser.add_mutually_exclusive_group(required=False)
+ group.add_argument(
+ "--print_run_commands",
+ action="store_true",
+ help="Generate commands and saves them to run.sh",
+ )
+ group.add_argument(
+ "--visualize_logs",
+ action="store_true",
+ help="Pretty print the log files and draw graphs",
+ )
+ group.add_argument(
+ "--run",
+ action="store_true",
+ default=True,
+ help="Generate commands, run and parses the files",
+ )
+
+ parser.add_argument(
+ "--log-operator-inputs",
+ action="store_true",
+ default=False,
+ help="Log operator inputs",
+ )
+
+ # Choose either inference or training
+ group_mode = parser.add_mutually_exclusive_group(required=True)
+ group_mode.add_argument(
+ "--inference", action="store_true", help="Only run inference related tasks"
+ )
+ group_mode.add_argument(
+ "--training", action="store_true", help="Only run training related tasks"
+ )
+
+ parser.add_argument(
+ "--update-dashboard",
+ action="store_true",
+ default=False,
+ help="Updates to dashboard",
+ )
+ parser.add_argument(
+ "--dashboard-image-uploader",
+ default=DASHBOARD_DEFAULTS["dashboard_image_uploader"],
+ help="Image uploader command",
+ )
+ parser.add_argument(
+ "--dashboard-archive-path",
+ default=DASHBOARD_DEFAULTS["dashboard_archive_path"],
+ help="Archived directory path",
+ )
+ parser.add_argument(
+ "--dashboard-gh-cli-path",
+ default=DASHBOARD_DEFAULTS["dashboard_gh_cli_path"],
+ help="Github CLI path",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def get_mode(args):
+ if args.inference:
+ return "inference"
+ return "training"
+
+
+def get_skip_tests(suite):
+ """
+ Generate -x seperated string to skip the unusual setup training tests
+ """
+ skip_tests = set()
+ original_dir = abspath(os.getcwd())
+ module = importlib.import_module(suite)
+ os.chdir(original_dir)
+
+ if hasattr(module, "SKIP"):
+ skip_tests.update(module.SKIP)
+ if hasattr(module, "SKIP_TRAIN"):
+ skip_tests.update(module.SKIP_TRAIN)
+
+ skip_tests = map(lambda name: f"-x {name}", skip_tests)
+ skip_str = " ".join(skip_tests)
+ return skip_str
+
+
+def generate_commands(args, dtypes, suites, devices, compilers, output_dir):
+ mode = get_mode(args)
+ with open("run.sh", "w") as runfile:
+ lines = []
+
+ lines.append("# Setup the output directory")
+ lines.append(f"rm -rf {output_dir}")
+ lines.append(f"mkdir {output_dir}")
+ lines.append("")
+
+ for testing in ["performance", "accuracy"]:
+ for iter in itertools.product(suites, devices, dtypes):
+ suite, device, dtype = iter
+ lines.append(
+ f"# Commands for {suite} for device={device}, dtype={dtype} for {mode} and for {testing} testing"
+ )
+ info = TABLE[mode]
+ for compiler in compilers:
+ base_cmd = info[compiler]
+ output_filename = f"{output_dir}/{compiler}_{suite}_{dtype}_{mode}_{device}_{testing}.csv"
+ cmd = f"python benchmarks/{suite}.py --{testing} --{dtype} -d{device} --output={output_filename}"
+ cmd = f"{cmd} {base_cmd} --no-skip --dashboard"
+
+ skip_tests_str = get_skip_tests(suite)
+ cmd = f"{cmd} {skip_tests_str}"
+
+ if args.log_operator_inputs:
+ cmd = f"{cmd} --log-operator-inputs"
+
+ if args.quick:
+ filters = DEFAULTS["quick"][suite]
+ cmd = f"{cmd} {filters}"
+ lines.append(cmd)
+ lines.append("")
+ runfile.writelines([line + "\n" for line in lines])
+
+
+def generate_dropdown_comment(title, body):
+ str_io = io.StringIO()
+ str_io.write(f"{title}\n")
+ str_io.write("\n")
+ str_io.write("see more
\n")
+ str_io.write(f"{body}")
+ str_io.write("\n")
+ str_io.write(" \n\n")
+ return str_io.getvalue()
+
+
+def build_summary():
+ import git
+
+ out_io = io.StringIO()
+
+ def print_commit_hash(path, name):
+ if exists(path):
+ repo = git.Repo(path, search_parent_directories=True)
+ sha = repo.head.object.hexsha
+ out_io.write(f"{name} commit: {sha}\n")
+ else:
+ out_io.write(f"{name} Absent\n")
+
+ def env_var(name):
+ out_io.write(f"{name} = {os.environ[name]}\n")
+
+ out_io.write("## Commit hashes ##\n")
+ print_commit_hash(".", "torch._dynamo")
+ print_commit_hash("../pytorch", "pytorch")
+ print_commit_hash("../functorch", "functorch")
+ print_commit_hash("../torchbenchmark", "torchbench")
+
+ out_io.write("\n")
+ out_io.write("## TorchDynamo config flags ##\n")
+ for key in dir(torch._dynamo.config):
+ val = getattr(torch._dynamo.config, key)
+ if not key.startswith("__") and isinstance(val, bool):
+ out_io.write(f"torch._dynamo.config.{key} = {val}\n")
+
+ out_io.write("\n")
+ out_io.write("## Torch version ##\n")
+ out_io.write(f"torch: {torch.__version__}\n")
+
+ out_io.write("\n")
+ out_io.write("## Environment variables ##\n")
+ env_var("TORCH_CUDA_ARCH_LIST")
+ env_var("CUDA_HOME")
+ env_var("USE_LLVM")
+
+ out_io.write("\n")
+ out_io.write("## GPU details ##\n")
+ out_io.write(f"CUDNN VERSION: {torch.backends.cudnn.version()}\n")
+ out_io.write(f"Number CUDA Devices: {torch.cuda.device_count()}\n")
+ out_io.write(f"Device Name: {torch.cuda.get_device_name(0)}\n")
+ out_io.write(
+ f"Device Memory [GB]: {torch.cuda.get_device_properties(0).total_memory/1e9}\n"
+ )
+
+ title = "## Build Summary"
+ comment = generate_dropdown_comment(title, out_io.getvalue())
+ with open(f"{output_dir}/gh_build_summary.txt", "w") as gh_fh:
+ gh_fh.write(comment)
+
+
+class Parser:
+ def __init__(self, suites, devices, dtypes, compilers, mode, output_dir):
+ self.suites = suites
+ self.devices = devices
+ self.dtypes = dtypes
+ self.compilers = compilers
+ self.output_dir = output_dir
+ self.mode = mode
+
+ def has_header(self, output_filename):
+ header_present = False
+ with open(output_filename, "r") as f:
+ line = f.readline()
+ if "dev" in line:
+ header_present = True
+ return header_present
+
+
+class ParsePerformanceLogs(Parser):
+ def __init__(self, suites, devices, dtypes, compilers, mode, output_dir):
+ super().__init__(suites, devices, dtypes, compilers, mode, output_dir)
+ self.parsed_frames = defaultdict(lambda: defaultdict(None))
+ self.untouched_parsed_frames = defaultdict(lambda: defaultdict(None))
+ self.metrics = ["speedup", "compilation_latency", "compression_ratio"]
+ self.bottom_k = 50
+ self.parse()
+
+ def plot_graph(self, df, title):
+ labels = df.columns.values.tolist()
+ labels = labels[3:]
+ df.plot(
+ x="name",
+ y=labels,
+ kind="bar",
+ width=0.65,
+ title=title,
+ ylabel="Speedup over eager",
+ xlabel="",
+ grid=True,
+ figsize=(max(len(df.index) / 4, 5), 10),
+ edgecolor="black",
+ )
+ plt.tight_layout()
+ plt.savefig(f"{self.output_dir}/{title}.png")
+
+ def read_csv(self, output_filename):
+ if self.has_header(output_filename):
+ return pd.read_csv(output_filename)
+ else:
+ return pd.read_csv(
+ output_filename,
+ names=[
+ "dev",
+ "name",
+ "batch_size",
+ "speedup",
+ "compilation_latency",
+ "compression_ratio",
+ ],
+ header=None,
+ engine="python",
+ )
+
+ def parse(self):
+ self.extract_df("accuracy", "accuracy")
+ for metric in self.metrics:
+ self.extract_df(metric, "performance")
+ self.generate_executive_summary()
+ for suite in self.suites:
+ self.plot_graph(
+ self.untouched_parsed_frames[suite]["speedup"],
+ f"{suite}_{self.dtypes[0]}",
+ )
+
+ def clean_batch_sizes(self, frames):
+ # Clean up batch sizes when its 0
+ if len(frames) == 1:
+ return frames
+ batch_sizes = frames[0]["batch_size"].to_list()
+ for frame in frames[1:]:
+ frame_batch_sizes = frame["batch_size"].to_list()
+ for idx, (batch_a, batch_b) in enumerate(
+ zip(batch_sizes, frame_batch_sizes)
+ ):
+ assert batch_a == batch_b or batch_a == 0 or batch_b == 0, print(
+ f"a={batch_a}, b={batch_b}"
+ )
+ batch_sizes[idx] = max(batch_a, batch_b)
+ for frame in frames:
+ frame["batch_size"] = batch_sizes
+ return frames
+
+ def extract_df(self, metric, testing):
+ for iter in itertools.product(self.suites, self.devices, self.dtypes):
+ suite, device, dtype = iter
+ frames = []
+ for compiler in self.compilers:
+ output_filename = f"{self.output_dir}/{compiler}_{suite}_{dtype}_{self.mode}_{device}_{testing}.csv"
+ df = self.read_csv(output_filename)
+ df = df[["dev", "name", "batch_size", metric]]
+ df.rename(columns={metric: compiler}, inplace=True)
+ df["batch_size"] = df["batch_size"].astype(int)
+ frames.append(df)
+
+ # Merge the results
+ frames = self.clean_batch_sizes(frames)
+ if len(self.compilers) == 1:
+ df = frames[0]
+ else:
+ # Merge data frames
+ df = pd.merge(frames[0], frames[1], on=["dev", "name", "batch_size"])
+ for idx in range(2, len(frames)):
+ df = pd.merge(df, frames[idx], on=["dev", "name", "batch_size"])
+
+ df_copy = df.copy()
+ df_copy = df_copy.sort_values(
+ by=list(reversed(self.compilers)), ascending=False
+ )
+ self.untouched_parsed_frames[suite][metric] = df_copy
+
+ if testing == "performance":
+ df_accuracy = self.parsed_frames[suite]["accuracy"]
+ perf_rows = []
+ for model_name in df["name"]:
+ perf_row = df[df["name"] == model_name]
+ acc_row = df_accuracy[df_accuracy["name"] == model_name]
+ for compiler in self.compilers:
+ if not perf_row.empty:
+ if acc_row.empty:
+ perf_row[compiler].iloc[0] = 0.0
+ elif acc_row[compiler].iloc[0] not in (
+ "pass",
+ "pass_due_to_skip",
+ ):
+ perf_row[compiler].iloc[0] = 0.0
+ perf_rows.append(perf_row)
+ df = pd.concat(perf_rows)
+ df = df.sort_values(by=list(reversed(self.compilers)), ascending=False)
+ self.parsed_frames[suite][metric] = df
+
+ def get_passing_entries(self, compiler, df):
+ return df[compiler][df[compiler] > 0]
+
+ def comp_time(self, compiler, df):
+ df = self.get_passing_entries(compiler, df)
+ # df = df.sort_values(by=compiler, ascending=False)[compiler][: self.bottom_k]
+ if df.empty:
+ return "0.0"
+
+ return f"{df.mean():.2f}"
+
+ def geomean(self, compiler, df):
+ cleaned_df = self.get_passing_entries(compiler, df).clip(1)
+ if cleaned_df.empty:
+ return "0.0x"
+ return f"{gmean(cleaned_df):.2f}x"
+
+ def passrate(self, compiler, df):
+ total = len(df.index)
+ passing = df[df[compiler] > 0.0][compiler].count()
+ perc = int(percentage(passing, total, decimals=0))
+ return f"{perc}%, {passing}/{total}"
+
+ def memory(self, compiler, df):
+ df = self.get_passing_entries(compiler, df)
+ df = df.fillna(0)
+ df = df[df > 0]
+ if df.empty:
+ return "0.0x"
+ return f"{df.mean():.2f}x"
+
+ def exec_summary_df(self, fn, metric):
+ """
+ Generate a table with passrate and geomean perf
+ """
+ cols = {}
+ cols["Compiler"] = self.compilers
+ for suite in self.suites:
+ df = self.parsed_frames[suite][metric]
+ # speedups = [self.geomean(compiler, df) for compiler in self.compilers]
+ speedups = [fn(compiler, df) for compiler in self.compilers]
+ col = pd.Series(data=speedups, index=self.compilers)
+ cols[suite] = col
+ df = pd.DataFrame(cols)
+ df = df.fillna(0)
+ df.to_csv(os.path.join(self.output_dir, f"{fn.__name__}.csv"))
+ return df
+
+ def exec_summary_text(self, caption, fn, metric):
+ df = self.exec_summary_df(fn, metric)
+ tabform = tabulate(df, headers="keys", tablefmt="pretty", showindex="never")
+
+ str_io = io.StringIO()
+ str_io.write(f"{caption}")
+ str_io.write("~~~\n")
+ str_io.write(f"{tabform}\n")
+ str_io.write("~~~\n")
+ return str_io.getvalue()
+
+ def generate_executive_summary(self):
+ description = (
+ "We evaluate different backends "
+ "across three benchmark suites - torchbench, huggingface and timm. We run "
+ "these experiments on A100 GPUs. Each experiment runs one iteration of forward "
+ "and backward pass. For accuracy, we check the numerical correctness of forward "
+ "pass outputs and gradients by comparing with native pytorch. We measure speedup "
+ "by normalizing against the performance of native pytorch. We report mean "
+ "compilation latency numbers and peak memory footprint reduction ratio. \n\n"
+ "Caveats\n"
+ "1) Batch size has been reduced to workaround OOM errors. Work is in progress to "
+ "reduce peak memory footprint.\n"
+ "2) Experiments do not cover dynamic shapes.\n"
+ "3) Experimental setup does not have optimizer.\n\n"
+ )
+
+ comment = generate_dropdown_comment("", description)
+ str_io = io.StringIO()
+ str_io.write("\n")
+ str_io.write("## Executive Summary ##\n")
+ str_io.write(comment)
+
+ speedup_caption = "Geometric mean speedup \n"
+ speedup_summary = self.exec_summary_text(
+ speedup_caption, self.geomean, "speedup"
+ )
+
+ passrate_caption = "Passrate\n"
+ passrate_summary = self.exec_summary_text(
+ passrate_caption, self.passrate, "speedup"
+ )
+
+ comp_time_caption = "Mean compilation time (seconds)\n"
+ comp_time_summary = self.exec_summary_text(
+ comp_time_caption, self.comp_time, "compilation_latency"
+ )
+
+ peak_memory_caption = (
+ "Peak memory footprint compression ratio (higher is better)\n"
+ )
+ peak_memory_summary = self.exec_summary_text(
+ peak_memory_caption, self.memory, "compression_ratio"
+ )
+
+ str_io.write(
+ "To measure performance, compilation latency and memory footprint reduction, "
+ "we remove the models that fail accuracy checks.\n\n"
+ )
+ str_io.write(passrate_summary)
+ str_io.write(speedup_summary)
+ str_io.write(comp_time_summary)
+ str_io.write(peak_memory_summary)
+ self.executive_summary = str_io.getvalue()
+
+ def prepare_message(self, suite):
+ title = f"## {suite} suite with {self.dtypes[0]} precision ##"
+ body = ""
+ for metric in [
+ "speedup",
+ "accuracy",
+ "compilation_latency",
+ "compression_ratio",
+ ]:
+ df = self.untouched_parsed_frames[suite][metric]
+ df = df.drop("dev", axis=1)
+ df = df.rename(columns={"batch_size": "bs"})
+ tabform = tabulate(df, headers="keys", tablefmt="pretty", showindex="never")
+ str_io = io.StringIO()
+ str_io.write("\n")
+ if metric == "speedup":
+ str_io.write("Performance speedup\n")
+ elif metric == "accuracy":
+ str_io.write("Accuracy\n")
+ elif metric == "compilation_latency":
+ str_io.write("Compilation latency (sec)\n")
+ elif metric == "compression_ratio":
+ str_io.write("Peak Memory Compression Ratio\n")
+ str_io.write("~~~\n")
+ str_io.write(f"{tabform}\n")
+ str_io.write("~~~\n")
+ body += str_io.getvalue()
+
+ comment = generate_dropdown_comment(title, body)
+ return comment
+
+ def gen_summary_files(self):
+ with open(f"{self.output_dir}/gh_title.txt", "w") as gh_fh:
+ str_io = io.StringIO()
+ str_io.write("\n")
+ str_io.write(f"# Performance Dashboard for {self.dtypes[0]} precision ##\n")
+ str_io.write("\n")
+ gh_fh.write(str_io.getvalue())
+
+ with open(f"{self.output_dir}/gh_executive_summary.txt", "w") as gh_fh:
+ gh_fh.write(self.executive_summary)
+ print(self.executive_summary)
+
+ str_io = io.StringIO()
+ for suite in self.suites:
+ str_io.write(self.prepare_message(suite))
+ str_io.write("\n")
+ print(str_io.getvalue())
+ with open(f"{self.output_dir}/gh_{self.mode}.txt", "w") as gh_fh:
+ gh_fh.write(str_io.getvalue())
+
+
+def parse_logs(args, dtypes, suites, devices, compilers, output_dir):
+ mode = get_mode(args)
+ build_summary()
+
+ parser_class = ParsePerformanceLogs
+ parser = parser_class(suites, devices, dtypes, compilers, mode, output_dir)
+ parser.gen_summary_files()
+ return
+
+
+@dataclasses.dataclass
+class LogInfo:
+ # Day of the year this log was generated
+ day: str
+
+ # Directory path where all logs are present
+ dir_path: str
+
+
+def get_date(log_info):
+ return datetime.strptime(f"{log_info.day}", "%j").strftime("%m-%d")
+
+
+class RegressionTracker:
+ """
+ Plots progress of different metrics over time to detect regressions.
+ """
+
+ def __init__(self, args):
+ self.args = args
+ self.suites = self.args.suites
+ self.lookup_file = os.path.join(self.args.dashboard_archive_path, "lookup.csv")
+ assert os.path.exists(self.lookup_file)
+ self.k = 10
+
+ def find_last_k(self):
+ """
+ Find the last k pairs of (day number, log_path)
+ """
+ dtype = self.args.dtypes[0]
+ df = pd.read_csv(self.lookup_file, names=("day", "mode", "prec", "path"))
+ df = df[df["mode"] == "performance"]
+ df = df[df["prec"] == dtype]
+ log_infos = []
+ for day, path in zip(df["day"], df["path"]):
+ log_infos.append(LogInfo(day, path))
+
+ assert len(log_infos) >= self.k
+ log_infos = log_infos[len(log_infos) - self.k :]
+ return log_infos
+
+ def generate_comment(self):
+ title = "## Metrics over time ##\n"
+ str_io = io.StringIO()
+ for name in glob.glob(self.args.output_dir + "/*over_time.png"):
+ output = (
+ subprocess.check_output([self.args.dashboard_image_uploader, name])
+ .decode("ascii")
+ .rstrip()
+ )
+ str_io.write(f"\n{name} : ![]({output})\n")
+ comment = generate_dropdown_comment(title, str_io.getvalue())
+
+ with open(f"{self.args.output_dir}/gh_regression.txt", "w") as gh_fh:
+ gh_fh.write(comment)
+
+ def diff(self):
+ log_infos = self.find_last_k()
+
+ for metric in ["geomean", "passrate"]:
+ fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
+ for idx, suite in enumerate(self.suites):
+ dfs = []
+ for log_info in log_infos:
+ dir_path = os.path.join(
+ self.args.dashboard_archive_path, log_info.dir_path
+ )
+ assert os.path.exists(dir_path)
+ gmean_filename = os.path.join(dir_path, f"{metric}.csv")
+ if not os.path.exists(gmean_filename):
+ continue
+ df = pd.read_csv(gmean_filename)
+ if metric == "geomean":
+ df[suite] = df[suite].str.replace("x", "").astype(float)
+ elif metric == "passrate":
+ df[suite] = df[suite].str.split("%").str[0].astype(float)
+ df.insert(0, "day", get_date(log_info))
+ df = df.pivot(index="day", columns="Compiler", values=suite)
+
+ # Interim stage when both inductor_cudagraphs and inductor exist
+ df = df.rename(columns={"inductor_cudagraphs": "inductor"})
+ for col_name in df.columns:
+ if col_name not in self.args.compilers:
+ df = df.drop(columns=[col_name])
+ dfs.append(df)
+
+ df = pd.concat(dfs)
+ ax = df.plot(
+ ax=axes[idx],
+ kind="line",
+ ylabel=metric,
+ xlabel="Date",
+ grid=True,
+ ylim=0 if metric == "passrate" else 0.8,
+ title=suite,
+ style=".-",
+ legend=False,
+ )
+ ax.legend(loc="lower right", ncol=2)
+
+ plt.tight_layout()
+ plt.savefig(os.path.join(output_dir, f"{metric}_over_time.png"))
+
+ self.generate_comment()
+
+
+class DashboardUpdater:
+ """
+ Aggregates the information and makes a comment to Performance Dashboard.
+ https://github.com/pytorch/torchdynamo/issues/681
+ """
+
+ def __init__(self, args):
+ self.args = args
+ self.output_dir = args.output_dir
+ self.lookup_file = os.path.join(self.args.dashboard_archive_path, "lookup.csv")
+ assert os.path.exists(self.lookup_file)
+ self.archive()
+
+ def archive(self):
+ # Copy the folder to archived location
+ src = self.output_dir
+ day = datetime.today().strftime("%j")
+ prefix = datetime.today().strftime(f"day_{day}_%d_%m_%y")
+ target_dir = f"{prefix}_performance_{self.args.dtypes[0]}_{randint(100, 999)}"
+ target = os.path.join(self.args.dashboard_archive_path, target_dir)
+ shutil.copytree(src, target)
+
+ # Update lookup csv the folder to arhived logs
+ dtype = self.args.dtypes[0]
+ subprocess.check_call(
+ f'echo "{day},performance,{dtype},{target_dir}" >> {self.lookup_file}',
+ shell=True,
+ )
+
+ def upload_graphs(self):
+ title = "## Performance graphs ##\n"
+ str_io = io.StringIO()
+ for name in glob.glob(self.output_dir + "/*png"):
+ if "over_time" not in name:
+ output = (
+ subprocess.check_output([self.args.dashboard_image_uploader, name])
+ .decode("ascii")
+ .rstrip()
+ )
+ str_io.write(f"\n{name} : ![]({output})\n")
+ comment = generate_dropdown_comment(title, str_io.getvalue())
+
+ with open(f"{self.output_dir}/gh_graphs.txt", "w") as gh_fh:
+ gh_fh.write(comment)
+
+ def gen_comment(self):
+ files = [
+ "gh_title.txt",
+ "gh_executive_summary.txt",
+ "gh_regression.txt",
+ "gh_training.txt",
+ "gh_graphs.txt",
+ ]
+ all_lines = []
+ for f in files:
+ with open(os.path.join(self.output_dir, f), "r") as fh:
+ all_lines.extend(fh.readlines())
+
+ return "\n".join([x.rstrip() for x in all_lines])
+
+ def comment_on_gh(self, comment):
+ """
+ Send a commment to dashboard
+ """
+ subprocess.check_call(
+ [
+ self.args.dashboard_gh_cli_path,
+ "issue",
+ "comment",
+ "681",
+ "-b",
+ comment,
+ ]
+ )
+
+ def update(self):
+ self.upload_graphs()
+ try:
+ RegressionTracker(self.args).diff()
+ except Exception:
+ with open(f"{self.args.output_dir}/gh_regression.txt", "w") as gh_fh:
+ gh_fh.write("")
+
+ comment = self.gen_comment()
+ self.comment_on_gh(comment)
+
+
+if __name__ == "__main__":
+ args = parse_args()
+
+ def extract(key):
+ return DEFAULTS[key] if getattr(args, key, None) is None else getattr(args, key)
+
+ dtypes = extract("dtypes")
+ suites = extract("suites")
+ devices = extract("devices")
+
+ if args.inference:
+ compilers = DEFAULTS["inference"] if args.compilers is None else args.compilers
+ else:
+ assert args.training
+ compilers = DEFAULTS["training"] if args.compilers is None else args.compilers
+
+ output_dir = args.output_dir
+ args.compilers = compilers
+ args.suites = suites
+
+ if args.print_run_commands:
+ generate_commands(args, dtypes, suites, devices, compilers, output_dir)
+ elif args.visualize_logs:
+ parse_logs(args, dtypes, suites, devices, compilers, output_dir)
+ elif args.run:
+ generate_commands(args, dtypes, suites, devices, compilers, output_dir)
+ # TODO - Do we need to worry about segfaults
+ try:
+ os.system("bash run.sh")
+ except Exception as e:
+ print(
+ "Running commands failed. Please run manually (bash run.sh) and inspect the errors."
+ )
+ raise e
+ if not args.log_operator_inputs:
+ parse_logs(args, dtypes, suites, devices, compilers, output_dir)
+
+ if args.update_dashboard:
+ DashboardUpdater(args).update()
diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py
new file mode 100755
index 0000000000000..ae9200d0b8b28
--- /dev/null
+++ b/benchmarks/dynamo/timm_models.py
@@ -0,0 +1,334 @@
+#!/usr/bin/env python3
+import importlib
+import logging
+import os
+import re
+import subprocess
+import sys
+import time
+import warnings
+
+import torch
+from common import BenchmarkRunner, main
+
+from torch._dynamo.testing import collect_results
+from torch._dynamo.utils import clone_inputs
+
+
+def pip_install(package):
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
+
+
+try:
+ importlib.import_module("timm")
+except ModuleNotFoundError:
+ print("Installing Pytorch Image Models...")
+ pip_install("git+https://github.com/rwightman/pytorch-image-models")
+finally:
+ from timm.data import resolve_data_config
+ from timm.models import create_model
+
+TIMM_MODELS = dict()
+filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")
+
+with open(filename, "r") as fh:
+ lines = fh.readlines()
+ lines = [line.rstrip() for line in lines]
+ for line in lines:
+ model_name, batch_size = line.split(" ")
+ TIMM_MODELS[model_name] = int(batch_size)
+
+
+# TODO - Figure out the reason of cold start memory spike
+BATCH_SIZE_DIVISORS = {
+ "beit_base_patch16_224": 2,
+ "cait_m36_384": 4,
+ "convit_base": 4,
+ "convmixer_768_32": 2,
+ "convnext_base": 4,
+ "crossvit_9_240": 2,
+ "cspdarknet53": 2,
+ "deit_base_distilled_patch16_224": 2,
+ "dla102": 2,
+ "dpn107": 2,
+ "eca_botnext26ts_256": 2,
+ "eca_halonext26ts": 2,
+ "gluon_senet154": 2,
+ "gluon_xception65": 2,
+ "gmixer_24_224": 2,
+ "gmlp_s16_224": 2,
+ "hrnet_w18": 64,
+ "jx_nest_base": 4,
+ "mixer_b16_224": 2,
+ "mixnet_l": 2,
+ "mobilevit_s": 4,
+ "nfnet_l0": 2,
+ "pit_b_224": 2,
+ "pnasnet5large": 2,
+ "poolformer_m36": 2,
+ "res2net101_26w_4s": 2,
+ "res2net50_14w_8s": 64,
+ "res2next50": 64,
+ "resnest101e": 4,
+ "sebotnet33ts_256": 2,
+ "swin_base_patch4_window7_224": 2,
+ "swsl_resnext101_32x16d": 2,
+ "tf_mixnet_l": 2,
+ "tnt_s_patch16_224": 2,
+ "twins_pcpvt_base": 4,
+ "vit_base_patch16_224": 2,
+ "volo_d1_224": 2,
+ "xcit_large_24_p8_224": 4,
+}
+
+REQUIRE_HIGHER_TOLERANCE = set()
+
+SKIP = {
+ # Unusual training setup
+ "levit_128",
+}
+
+
+def refresh_model_names():
+ import glob
+
+ from timm.models import list_models
+
+ def read_models_from_docs():
+ models = set()
+ # TODO - set the path to pytorch-image-models repo
+ for fn in glob.glob("../pytorch-image-models/docs/models/*.md"):
+ with open(fn, "r") as f:
+ while True:
+ line = f.readline()
+ if not line:
+ break
+ if not line.startswith("model = timm.create_model("):
+ continue
+
+ model = line.split("'")[1]
+ # print(model)
+ models.add(model)
+ return models
+
+ def get_family_name(name):
+ known_families = [
+ "darknet",
+ "densenet",
+ "dla",
+ "dpn",
+ "ecaresnet",
+ "halo",
+ "regnet",
+ "efficientnet",
+ "deit",
+ "mobilevit",
+ "mnasnet",
+ "convnext",
+ "resnet",
+ "resnest",
+ "resnext",
+ "selecsls",
+ "vgg",
+ "xception",
+ ]
+
+ for known_family in known_families:
+ if known_family in name:
+ return known_family
+
+ if name.startswith("gluon_"):
+ return "gluon_" + name.split("_")[1]
+ return name.split("_")[0]
+
+ def populate_family(models):
+ family = dict()
+ for model_name in models:
+ family_name = get_family_name(model_name)
+ if family_name not in family:
+ family[family_name] = []
+ family[family_name].append(model_name)
+ return family
+
+ docs_models = read_models_from_docs()
+ all_models = list_models(pretrained=True, exclude_filters=["*in21k"])
+
+ all_models_family = populate_family(all_models)
+ docs_models_family = populate_family(docs_models)
+
+ # print(docs_models_family.keys())
+ for key in docs_models_family:
+ del all_models_family[key]
+
+ chosen_models = set()
+ for value in docs_models_family.values():
+ chosen_models.add(value[0])
+
+ for key, value in all_models_family.items():
+ chosen_models.add(value[0])
+
+ filename = "timm_models_list.txt"
+ if os.path.exists("benchmarks"):
+ filename = "benchmarks/" + filename
+ with open(filename, "w") as fw:
+ for model_name in sorted(chosen_models):
+ fw.write(model_name + "\n")
+
+
+class TimmRunnner(BenchmarkRunner):
+ def __init__(self):
+ super(TimmRunnner, self).__init__()
+ self.suite_name = "timm_models"
+
+ def load_model(
+ self,
+ device,
+ model_name,
+ batch_size=None,
+ ):
+
+ is_training = self.args.training
+ use_eval_mode = self.args.use_eval_mode
+
+ # _, model_dtype, data_dtype = self.resolve_precision()
+ channels_last = self._args.channels_last
+
+ retries = 1
+ success = False
+ while not success and retries < 4:
+ try:
+ model = create_model(
+ model_name,
+ in_chans=3,
+ scriptable=False,
+ num_classes=None,
+ drop_rate=0.0,
+ drop_path_rate=None,
+ drop_block_rate=None,
+ pretrained=True,
+ # global_pool=kwargs.pop('gp', 'fast'),
+ # num_classes=kwargs.pop('num_classes', None),
+ # drop_rate=kwargs.pop('drop', 0.),
+ # drop_path_rate=kwargs.pop('drop_path', None),
+ # drop_block_rate=kwargs.pop('drop_block', None),
+ )
+ success = True
+ except Exception:
+ wait = retries * 30
+ time.sleep(wait)
+ retries += 1
+
+ model.to(
+ device=device,
+ memory_format=torch.channels_last if channels_last else None,
+ )
+
+ self.num_classes = model.num_classes
+
+ data_config = resolve_data_config(
+ self._args, model=model, use_test_size=not is_training
+ )
+ input_size = data_config["input_size"]
+ recorded_batch_size = TIMM_MODELS[model_name]
+ recorded_batch_size = max(
+ int(recorded_batch_size / BATCH_SIZE_DIVISORS.get(model_name, 1)), 1
+ )
+ batch_size = batch_size or recorded_batch_size
+
+ # example_inputs = torch.randn(
+ # (batch_size,) + input_size, device=device, dtype=data_dtype
+ # )
+ torch.manual_seed(1337)
+ input_tensor = torch.randint(
+ 256, size=(batch_size,) + input_size, device=device
+ ).to(dtype=torch.float32)
+ mean = torch.mean(input_tensor)
+ std_dev = torch.std(input_tensor)
+ example_inputs = (input_tensor - mean) / std_dev
+
+ if channels_last:
+ example_inputs = example_inputs.contiguous(
+ memory_format=torch.channels_last
+ )
+ example_inputs = [
+ example_inputs,
+ ]
+ self.target = self._gen_target(batch_size, device)
+
+ self.loss = torch.nn.CrossEntropyLoss().to(device)
+ if is_training and not use_eval_mode:
+ model.train()
+ else:
+ model.eval()
+
+ self.init_optimizer(device, model.parameters())
+
+ self.validate_model(model, example_inputs)
+
+ return device, model_name, model, example_inputs, batch_size
+
+ def iter_model_names(self, args):
+ # for model_name in list_models(pretrained=True, exclude_filters=["*in21k"]):
+ model_names = sorted(TIMM_MODELS.keys())
+ start, end = self.get_benchmark_indices(len(model_names))
+ for index, model_name in enumerate(model_names):
+ if index < start or index >= end:
+ continue
+ if (
+ not re.search("|".join(args.filter), model_name, re.I)
+ or re.search("|".join(args.exclude), model_name, re.I)
+ or model_name in self.skip_models
+ ):
+ continue
+
+ yield model_name
+
+ def pick_grad(self, name, is_training):
+ if is_training:
+ return torch.enable_grad()
+ else:
+ return torch.no_grad()
+
+ def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
+ cosine = self.args.cosine
+ tolerance = 1e-3
+ if is_training:
+ if REQUIRE_HIGHER_TOLERANCE:
+ tolerance = 2 * 1e-2
+ else:
+ tolerance = 1e-2
+ return tolerance, cosine
+
+ def _gen_target(self, batch_size, device):
+ # return torch.ones((batch_size,) + (), device=device, dtype=torch.long)
+ return torch.empty((batch_size,) + (), device=device, dtype=torch.long).random_(
+ self.num_classes
+ )
+
+ def compute_loss(self, pred):
+ # High loss values make gradient checking harder, as small changes in
+ # accumulation order upsets accuracy checks.
+ return self.loss(pred, self.target) / 10.0
+
+ def forward_pass(self, mod, inputs, collect_outputs=True):
+ return mod(*inputs)
+
+ def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
+ cloned_inputs = clone_inputs(inputs)
+ mod.zero_grad(True)
+ with self.autocast():
+ pred = mod(*cloned_inputs)
+ if isinstance(pred, tuple):
+ pred = pred[0]
+ loss = self.compute_loss(pred)
+ self.grad_scaler.scale(loss).backward()
+ self.optimizer_step()
+ if collect_outputs:
+ return collect_results(mod, pred, loss, cloned_inputs)
+ return None
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.WARNING)
+ warnings.filterwarnings("ignore")
+ main(TimmRunnner())
diff --git a/benchmarks/dynamo/timm_models_list.txt b/benchmarks/dynamo/timm_models_list.txt
new file mode 100644
index 0000000000000..d8c40edd7da9f
--- /dev/null
+++ b/benchmarks/dynamo/timm_models_list.txt
@@ -0,0 +1,62 @@
+adv_inception_v3 128
+beit_base_patch16_224 128
+botnet26t_256 128
+cait_m36_384 8
+coat_lite_mini 128
+convit_base 128
+convmixer_768_32 64
+convnext_base 128
+crossvit_9_240 128
+cspdarknet53 128
+deit_base_distilled_patch16_224 128
+dla102 128
+dm_nfnet_f0 128
+dpn107 64
+eca_botnext26ts_256 128
+eca_halonext26ts 128
+ese_vovnet19b_dw 128
+fbnetc_100 128
+fbnetv3_b 128
+gernet_l 128
+ghostnet_100 128
+gluon_inception_v3 128
+gluon_xception65 64
+gmixer_24_224 128
+gmlp_s16_224 128
+hrnet_w18 128
+inception_v3 128
+jx_nest_base 128
+lcnet_050 128
+levit_128 128
+mixer_b16_224 128
+mixnet_l 128
+mnasnet_100 128
+mobilenetv2_100 128
+mobilenetv3_large_100 128
+mobilevit_s 128
+nfnet_l0 128
+pit_b_224 128
+pnasnet5large 32
+poolformer_m36 128
+regnety_002 128
+repvgg_a2 128
+res2net101_26w_4s 128
+res2net50_14w_8s 128
+res2next50 128
+resmlp_12_224 128
+resnest101e 128
+rexnet_100 128
+sebotnet33ts_256 128
+selecsls42b 128
+spnasnet_100 128
+swin_base_patch4_window7_224 128
+swsl_resnext101_32x16d 64
+tf_efficientnet_b0 128
+tf_mixnet_l 128
+tinynet_a 128
+tnt_s_patch16_224 128
+twins_pcpvt_base 128
+visformer_small 128
+vit_base_patch16_224 128
+volo_d1_224 128
+xcit_large_24_p8_224 23
diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py
new file mode 100755
index 0000000000000..9b1297a129aea
--- /dev/null
+++ b/benchmarks/dynamo/torchbench.py
@@ -0,0 +1,338 @@
+#!/usr/bin/env python3
+import gc
+import importlib
+import logging
+import os
+import re
+import sys
+import warnings
+from os.path import abspath, exists
+
+import torch
+from common import BenchmarkRunner, main
+
+from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
+from torch._dynamo.utils import clone_inputs
+
+# We are primarily interested in tf32 datatype
+torch.backends.cuda.matmul.allow_tf32 = True
+
+os.environ["KALDI_ROOT"] = "/tmp" # avoids some spam
+for torchbench_dir in (
+ "./torchbenchmark",
+ "../torchbenchmark",
+ "../torchbench",
+ "../benchmark",
+ "../../torchbenchmark",
+ "../../torchbench",
+ "../../benchmark",
+):
+ if exists(torchbench_dir):
+ break
+
+assert exists(torchbench_dir), "../../torchbenchmark does not exist"
+original_dir = abspath(os.getcwd())
+torchbench_dir = abspath(torchbench_dir)
+
+os.chdir(torchbench_dir)
+sys.path.append(torchbench_dir)
+
+
+# Some models have large dataset that doesn't fit in memory. Lower the batch
+# size to test the accuracy.
+USE_SMALL_BATCH_SIZE = {
+ "demucs": 4,
+ "densenet121": 4,
+ "hf_Reformer": 4,
+ "timm_efficientdet": 1,
+}
+
+DETECTRON2_MODELS = {
+ "detectron2_fasterrcnn_r_101_c4",
+ "detectron2_fasterrcnn_r_101_dc5",
+ "detectron2_fasterrcnn_r_101_fpn",
+ "detectron2_fasterrcnn_r_50_c4",
+ "detectron2_fasterrcnn_r_50_dc5",
+ "detectron2_fasterrcnn_r_50_fpn",
+ "detectron2_maskrcnn_r_101_c4",
+ "detectron2_maskrcnn_r_101_fpn",
+ "detectron2_maskrcnn_r_50_fpn",
+}
+
+SKIP = {
+ # https://github.com/pytorch/torchdynamo/issues/101
+ "detectron2_maskrcnn",
+ # https://github.com/pytorch/torchdynamo/issues/145
+ "fambench_xlmr",
+}
+
+# Additional models that are skipped in training
+SKIP_TRAIN = {
+ # not designed for training
+ "pyhpc_equation_of_state",
+ "pyhpc_isoneutral_mixing",
+ "pyhpc_turbulent_kinetic_energy",
+ # Unusual training setup
+ "opacus_cifar10",
+ "maml",
+}
+SKIP_TRAIN.update(DETECTRON2_MODELS)
+
+# These models support only train mode. So accuracy checking can't be done in
+# eval mode.
+ONLY_TRAINING_MODE = {
+ "tts_angular",
+ "tacotron2",
+ "demucs",
+ "hf_Reformer",
+ "pytorch_struct",
+ "yolov3",
+}
+ONLY_TRAINING_MODE.update(DETECTRON2_MODELS)
+
+# Need lower tolerance on GPU. GPU kernels have non deterministic kernels for these models.
+REQUIRE_HIGHER_TOLERANCE = {
+ "alexnet",
+ "attention_is_all_you_need_pytorch",
+ "densenet121",
+ "hf_Albert",
+ "vgg16",
+ "mobilenet_v3_large",
+ "nvidia_deeprecommender",
+ "timm_efficientdet",
+ "vision_maskrcnn",
+}
+
+# These models need >1e-3 tolerance
+REQUIRE_EVEN_HIGHER_TOLERANCE = {
+ "soft_actor_critic",
+ "tacotron2",
+}
+
+REQUIRE_COSINE_TOLERACE = {
+ # https://github.com/pytorch/torchdynamo/issues/556
+ "resnet50_quantized_qat",
+}
+
+# non-deterministic output / cant check correctness
+NONDETERMINISTIC = set()
+
+# These benchmarks took >600s on an i9-11900K CPU
+VERY_SLOW_BENCHMARKS = {
+ "hf_BigBird", # 3339s
+ "hf_Longformer", # 3062s
+ "hf_T5", # 930s
+}
+
+# These benchmarks took >60s on an i9-11900K CPU
+SLOW_BENCHMARKS = {
+ *VERY_SLOW_BENCHMARKS,
+ "BERT_pytorch", # 137s
+ "demucs", # 116s
+ "fastNLP_Bert", # 242s
+ "hf_Albert", # 221s
+ "hf_Bart", # 400s
+ "hf_Bert", # 334s
+ "hf_DistilBert", # 187s
+ "hf_GPT2", # 470s
+ "hf_Reformer", # 141s
+ "speech_transformer", # 317s
+ "vision_maskrcnn", # 99s
+}
+
+TRT_NOT_YET_WORKING = {
+ "alexnet",
+ "resnet18",
+ "resnet50",
+ "mobilenet_v2",
+ "mnasnet1_0",
+ "squeezenet1_1",
+ "shufflenetv2_x1_0",
+ "vgg16",
+ "resnext50_32x4d",
+}
+
+DYNAMIC_SHAPES_NOT_YET_WORKING = {
+ "demucs",
+ "timm_nfnet",
+}
+
+DONT_CHANGE_BATCH_SIZE = {
+ "demucs",
+ "pytorch_struct",
+ "pyhpc_turbulent_kinetic_energy",
+}
+
+
+SKIP_ACCURACY_CHECK_MODELS = {
+ # Models too large to have eager, dynamo and fp64_numbers simultaneosuly
+ # even for 40 GB machine. We have tested accuracy for smaller version of
+ # these models
+ "hf_GPT2_large",
+ "hf_T5_large",
+ "timm_vision_transformer_large",
+}
+
+
+class TorchBenchmarkRunner(BenchmarkRunner):
+ def __init__(self):
+ super(TorchBenchmarkRunner, self).__init__()
+ self.suite_name = "torchbench"
+
+ @property
+ def skip_models(self):
+ return SKIP
+
+ @property
+ def slow_models(self):
+ return SLOW_BENCHMARKS
+
+ @property
+ def very_slow_models(self):
+ return VERY_SLOW_BENCHMARKS
+
+ @property
+ def non_deterministic_models(self):
+ return NONDETERMINISTIC
+
+ @property
+ def skip_not_suitable_for_training_models(self):
+ return SKIP_TRAIN
+
+ @property
+ def failing_fx2trt_models(self):
+ return TRT_NOT_YET_WORKING
+
+ @property
+ def failing_dynamic_shape_models(self):
+ return DYNAMIC_SHAPES_NOT_YET_WORKING
+
+ @property
+ def skip_accuracy_checks_large_models_dashboard(self):
+ if self.args.dashboard:
+ return SKIP_ACCURACY_CHECK_MODELS
+ return set()
+
+ def load_model(
+ self,
+ device,
+ model_name,
+ batch_size=None,
+ ):
+
+ is_training = self.args.training
+ use_eval_mode = self.args.use_eval_mode
+ dynamic_shapes = self.args.dynamic_shapes
+ module = importlib.import_module(f"torchbenchmark.models.{model_name}")
+ benchmark_cls = getattr(module, "Model", None)
+ if not hasattr(benchmark_cls, "name"):
+ benchmark_cls.name = model_name
+
+ cant_change_batch_size = (
+ not getattr(benchmark_cls, "ALLOW_CUSTOMIZE_BSIZE", True)
+ or model_name in DONT_CHANGE_BATCH_SIZE
+ )
+ if cant_change_batch_size:
+ batch_size = None
+ if batch_size is None and is_training and model_name in USE_SMALL_BATCH_SIZE:
+ batch_size = USE_SMALL_BATCH_SIZE[model_name]
+
+ if is_training:
+ benchmark = benchmark_cls(
+ test="train", device=device, jit=False, batch_size=batch_size
+ )
+ else:
+ benchmark = benchmark_cls(
+ test="eval", device=device, jit=False, batch_size=batch_size
+ )
+ if dynamic_shapes:
+ if not hasattr(benchmark, "get_dynamic_shapes_module"):
+ raise NotImplementedError("Dynamic Shapes not supported")
+ model, example_inputs = benchmark.get_dynamic_shapes_module()
+ else:
+ model, example_inputs = benchmark.get_module()
+
+ # Models that must be in train mode while training
+ if is_training and (not use_eval_mode or model_name in ONLY_TRAINING_MODE):
+ model.train()
+ else:
+ model.eval()
+ gc.collect()
+ batch_size = benchmark.batch_size
+
+ self.init_optimizer(device, model.parameters())
+
+ # Torchbench has quite different setup for yolov3, so directly passing
+ # the right example_inputs
+ if model_name == "yolov3":
+ example_inputs = (torch.rand(batch_size, 3, 384, 512).to(device),)
+ # global current_name, current_device
+ # current_device = device
+ # current_name = benchmark.name
+ self.validate_model(model, example_inputs)
+ return device, benchmark.name, model, example_inputs, batch_size
+
+ def iter_model_names(self, args):
+ from torchbenchmark import _list_model_paths
+
+ models = _list_model_paths()
+ start, end = self.get_benchmark_indices(len(models))
+ for index, model_path in enumerate(models):
+ if index < start or index >= end:
+ continue
+
+ model_name = os.path.basename(model_path)
+ if (
+ not re.search("|".join(args.filter), model_name, re.I)
+ or re.search("|".join(args.exclude), model_name, re.I)
+ or model_name in SKIP
+ ):
+ continue
+
+ yield model_name
+
+ def pick_grad(self, name, is_training):
+ if is_training or name in ("maml",):
+ return torch.enable_grad()
+ else:
+ return torch.no_grad()
+
+ def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
+ tolerance = 1e-4
+ cosine = self.args.cosine
+ # Increase the tolerance for torch allclose
+ if self.args.float16:
+ return 1e-3, cosine
+ if is_training and current_device == "cuda":
+ if name in REQUIRE_COSINE_TOLERACE:
+ cosine = True
+ elif name in REQUIRE_HIGHER_TOLERANCE:
+ tolerance = 1e-3
+ elif name in REQUIRE_EVEN_HIGHER_TOLERANCE:
+ tolerance = 8 * 1e-2
+ return tolerance, cosine
+
+ def compute_loss(self, pred):
+ return reduce_to_scalar_loss(pred)
+
+ def forward_pass(self, mod, inputs, collect_outputs=True):
+ return mod(*inputs)
+
+ def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
+ cloned_inputs = clone_inputs(inputs)
+ mod.zero_grad(True)
+ with self.autocast():
+ pred = mod(*cloned_inputs)
+ loss = self.compute_loss(pred)
+ self.grad_scaler.scale(loss).backward()
+ self.optimizer_step()
+ if collect_outputs:
+ return collect_results(mod, pred, loss, cloned_inputs)
+ return None
+
+
+if __name__ == "__main__":
+
+ logging.basicConfig(level=logging.WARNING)
+ warnings.filterwarnings("ignore")
+ main(TorchBenchmarkRunner(), original_dir)
diff --git a/benchmarks/dynamo/torchbench_models_list.txt b/benchmarks/dynamo/torchbench_models_list.txt
new file mode 100644
index 0000000000000..04947c4a6a301
--- /dev/null
+++ b/benchmarks/dynamo/torchbench_models_list.txt
@@ -0,0 +1,28 @@
+BERT_pytorch,128
+Background_Matting, 16
+LearningToPaint,1024
+alexnet,1024
+dcgan,1024
+densenet121,64
+hf_Albert,32
+hf_Bart,16
+hf_Bert,16
+hf_GPT2,16
+hf_T5,4
+mnasnet1_0,256
+mobilenet_v2,128
+mobilenet_v3_large,256
+nvidia_deeprecommender,1024
+pytorch_unet,8
+resnet18,512
+resnet50,128
+resnext50_32x4d,128
+shufflenet_v2_x1_0,512
+squeezenet1_1,512
+timm_nfnet,256
+timm_efficientnet,128
+timm_regnet,128
+timm_resnest,256
+timm_vision_transformer,256
+timm_vovnet,128
+vgg16,128
diff --git a/benchmarks/dynamo/training_loss.py b/benchmarks/dynamo/training_loss.py
new file mode 100644
index 0000000000000..2ec7945403348
--- /dev/null
+++ b/benchmarks/dynamo/training_loss.py
@@ -0,0 +1,205 @@
+import argparse
+import inspect
+import os
+import sys
+import time
+from datetime import timedelta
+
+import torch
+
+import torch._dynamo
+from datasets import load_dataset, load_metric
+from torch.utils.data import DataLoader
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+torch.backends.cuda.matmul.allow_tf32 = True
+
+# You will download around 84G dataset if you run this end to end training/evaluation example.
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+
+def data_processing(num_samples, batch_size):
+ dataset = load_dataset("yelp_review_full")
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
+
+ def tokenize_function(examples):
+ return tokenizer(examples["text"], padding="max_length", truncation=True)
+
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
+
+ tokenized_datasets = tokenized_datasets.remove_columns(["text"])
+ tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
+ tokenized_datasets.set_format("torch")
+
+ small_train_dataset = tokenized_datasets["train"].select(range(num_samples))
+ small_eval_dataset = tokenized_datasets["test"].select(range(num_samples))
+
+ train_dataloader = DataLoader(small_train_dataset, batch_size=batch_size)
+ eval_dataloader = DataLoader(small_eval_dataset, batch_size=batch_size)
+
+ return train_dataloader, eval_dataloader
+
+
+def training_iter_fn(batch, model, optimizer):
+ outputs = model(**batch)
+ loss = outputs.loss
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ return loss
+
+
+def model_training_evaluation(
+ backend, train_dataloader, eval_dataloader, model, optimizer, num_epochs, evaluation
+):
+ model.to(device)
+ model.train()
+ loss_history = []
+ if not backend:
+ # Run with native Pytorch
+ opt_training_iter_fn = training_iter_fn
+ else:
+ # Support backends: eager, aot_eager, aot_nvfuser and inductor
+ opt_training_iter_fn = torch._dynamo.optimize(backend)(training_iter_fn)
+ for epoch in range(num_epochs):
+ running_loss = 0.0
+ for i, batch in enumerate(train_dataloader, 0):
+ batch = {k: v.to(device) for k, v in batch.items()}
+ loss = opt_training_iter_fn(batch, model, optimizer)
+ running_loss += loss.item()
+ if i % 100 == 99:
+ loss_history.append(running_loss / 100)
+ running_loss = 0.0
+
+ if evaluation:
+ metric = load_metric("accuracy")
+ model.eval()
+ if not backend:
+ opt_model = model
+ else:
+ opt_model = torch._dynamo.optimize(backend)(model)
+ for batch in eval_dataloader:
+ batch = {k: v.to(device) for k, v in batch.items()}
+ with torch.no_grad():
+ outputs = opt_model(**batch)
+
+ logits = outputs.logits
+ predictions = torch.argmax(logits, dim=-1)
+ metric.add_batch(predictions=predictions, references=batch["labels"])
+
+ return loss_history, metric.compute()
+ else:
+ return loss_history, None
+
+
+def check_loss(ref_loss, res_loss):
+ assert len(ref_loss) == len(res_loss)
+ length = len(ref_loss)
+ x = min(length, 10)
+ if sum(res_loss[-x:]) / 10 <= sum(ref_loss[-x:]) / 10 + 1e-1:
+ return True
+ else:
+ return False
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="TorchDynamo end to end training/evaluation benchmark"
+ )
+ parser.add_argument(
+ "--epochs", type=int, default=10, help="number of epochs to train (default: 10)"
+ )
+ parser.add_argument(
+ "--num-samples",
+ type=int,
+ default=1000,
+ help="number of samples to train/eval (default: 1000)",
+ )
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=8,
+ help="input batch size for training (default: 8)",
+ )
+ parser.add_argument(
+ "--lr", type=float, default=5e-5, help="learning rate (default: 5e-5)"
+ )
+ parser.add_argument(
+ "--backend",
+ choices=torch._dynamo.list_backends(),
+ default="inductor",
+ help="train/evaluate model with a given backend (default: inductor)",
+ )
+ parser.add_argument(
+ "--optimizer",
+ default="Adam",
+ help="train model using a given optimizer (default: Adam)",
+ )
+ parser.add_argument(
+ "--evaluation",
+ action="store_true",
+ help="running evaluation after model training",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+ train_dataloader, eval_dataloader = data_processing(
+ args.num_samples, args.batch_size
+ )
+ model = AutoModelForSequenceClassification.from_pretrained(
+ "bert-base-cased", num_labels=5
+ )
+ optimizer_cls = getattr(sys.modules["torch.optim"], args.optimizer)
+ if "capturable" in inspect.signature(optimizer_cls).parameters.keys():
+ optimizer = optimizer_cls(model.parameters(), lr=args.lr, capturable=True)
+ else:
+ optimizer = optimizer_cls(model.parameters(), lr=args.lr)
+ native_start = time.time()
+ ref_loss, accuracy = model_training_evaluation(
+ None,
+ train_dataloader,
+ eval_dataloader,
+ model,
+ optimizer,
+ args.epochs,
+ args.evaluation,
+ )
+ native_end = time.time()
+ res_loss, accuracy = model_training_evaluation(
+ args.backend,
+ train_dataloader,
+ eval_dataloader,
+ model,
+ optimizer,
+ args.epochs,
+ args.evaluation,
+ )
+ dynamo_end = time.time()
+ if check_loss(ref_loss, res_loss):
+ print(
+ "[PASSED] TorchDynamo end to end training loss is less than or equal to native PyTorch"
+ )
+ else:
+ print(
+ "[FAILED] TorchDynamo end to end training loss is greater than native Pytorch"
+ )
+ if args.evaluation:
+ print(f"Model accuracy: {accuracy}")
+ native_elapsed = native_end - native_start
+ dynamo_elapsed = dynamo_end - native_end
+ print(
+ f"Train model on {args.epochs} epochs with backend {args.backend} and optimizer {args.optimizer}:"
+ )
+ print(f"PyTorch spent {timedelta(seconds=native_elapsed/args.epochs)} per epoch")
+ print(
+ f"TorchDynamo spent {timedelta(seconds=dynamo_elapsed/args.epochs)} per epoch"
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/requirements.txt b/requirements.txt
index 64808a00d60f4..573b7a08a568b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,3 +12,6 @@ six
types-dataclasses
typing_extensions
sympy
+filelock
+networkx
+jinja2
diff --git a/setup.py b/setup.py
index 4f3f86d8cb9ac..e464a43255960 100644
--- a/setup.py
+++ b/setup.py
@@ -968,6 +968,8 @@ def main():
# the list of runtime dependencies required by this built package
install_requires = [
'typing_extensions',
+ 'sympy',
+ 'networkx',
]
extras_require = {
diff --git a/test/dynamo/__init__.py b/test/dynamo/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/test/dynamo/mock_modules/__init__.py b/test/dynamo/mock_modules/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/test/dynamo/mock_modules/mock_module1.py b/test/dynamo/mock_modules/mock_module1.py
new file mode 100644
index 0000000000000..c4bd2bf4f9deb
--- /dev/null
+++ b/test/dynamo/mock_modules/mock_module1.py
@@ -0,0 +1,2 @@
+def method1(a, b):
+ return a + b
diff --git a/test/dynamo/mock_modules/mock_module2.py b/test/dynamo/mock_modules/mock_module2.py
new file mode 100644
index 0000000000000..7fe8979709c35
--- /dev/null
+++ b/test/dynamo/mock_modules/mock_module2.py
@@ -0,0 +1,19 @@
+# from . import mock_module3
+import torch
+
+from . import mock_module3
+
+
+class Class1:
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def method2(self, x):
+ return mock_module3.method1([], x)
+
+
+def method1(x, y):
+ torch.ones(1, 1)
+ x.append(y)
+ return x
diff --git a/test/dynamo/mock_modules/mock_module3.py b/test/dynamo/mock_modules/mock_module3.py
new file mode 100644
index 0000000000000..8af77a237a89b
--- /dev/null
+++ b/test/dynamo/mock_modules/mock_module3.py
@@ -0,0 +1,7 @@
+import torch
+
+
+def method1(x, y):
+ torch.ones(1, 1)
+ x.append(y)
+ return x
diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py
new file mode 100644
index 0000000000000..b185313f8b142
--- /dev/null
+++ b/test/dynamo/test_aot_autograd.py
@@ -0,0 +1,79 @@
+# Owner(s): ["module: dynamo"]
+import functools
+
+import torch
+
+import torch._dynamo
+from torch._dynamo.optimizations.training import is_aot_autograd_safe_to_run
+from torch._dynamo.testing import rand_strided
+
+
+def compiler_safe_fn(gm, example_inputs, is_safe):
+ is_safe[0] = is_aot_autograd_safe_to_run(gm, example_inputs)
+ return gm.forward
+
+
+class AotAutogradFallbackTests(torch._dynamo.testing.TestCase):
+ def test_LSTM(self):
+ # https://github.com/pytorch/torchdynamo/issues/1147
+ class Repro(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.self_mod_model_lstm_lstm = torch.nn.LSTM(
+ 64, 64, num_layers=2, bidirectional=True
+ )
+
+ def forward(self, permute: torch.Tensor):
+ self_mod_model_lstm_lstm = self.self_mod_model_lstm_lstm(permute)
+ return (self_mod_model_lstm_lstm,)
+
+ is_safe = [True]
+ mod = Repro()
+ compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
+ aot_mod = torch._dynamo.optimize(compiler_fn)(mod)
+
+ args = [((92, 4, 64), (1, 5888, 92), torch.float32, "cpu", False)]
+ args = [
+ rand_strided(sh, st, dt, dev).requires_grad_(rg)
+ for (sh, st, dt, dev, rg) in args
+ ]
+
+ aot_mod(*args)
+ self.assertTrue(not is_safe[0])
+
+ def test_mutation(self):
+ # https://github.com/pytorch/torchdynamo/issues/1301
+ def fn(param, y):
+ prev_grad = torch.is_grad_enabled()
+ try:
+ torch.set_grad_enabled(False)
+ param.add_(y)
+ finally:
+ torch.set_grad_enabled(prev_grad)
+ return y
+
+ y = torch.randn(4)
+ x = torch.nn.Parameter(torch.randn(4))
+ is_safe = [True]
+ compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
+ aot_fn = torch._dynamo.optimize(compiler_fn)(fn)
+ aot_fn(x, y)
+ self.assertTrue(not is_safe[0])
+
+ def test_negative_testing(self):
+ def fn(x, y):
+ return torch.sin(x).add_(y)
+
+ y = torch.randn(4)
+ x = torch.randn(4)
+ is_safe = [True]
+ compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
+ aot_fn = torch._dynamo.optimize(compiler_fn)(fn)
+ aot_fn(x, y)
+ self.assertTrue(is_safe[0])
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_aot_cudagraphs.py b/test/dynamo/test_aot_cudagraphs.py
new file mode 100644
index 0000000000000..37eeb6af3b305
--- /dev/null
+++ b/test/dynamo/test_aot_cudagraphs.py
@@ -0,0 +1,206 @@
+# Owner(s): ["module: cuda graphs"]
+
+import functools
+import unittest
+from unittest.mock import patch
+
+import torch
+
+import torch._dynamo
+import torch._dynamo.testing
+from torch._dynamo.testing import same
+
+
+def composed(*decs):
+ def deco(f):
+ for dec in reversed(decs):
+ f = dec(f)
+ return f
+
+ return deco
+
+
+def assert_aot_autograd_counter(ok=True):
+ def deco(f):
+ @functools.wraps(f)
+ def wrap(self, *args, **kwargs):
+ torch._dynamo.utils.counters.clear()
+ r = f(self, *args, **kwargs)
+ c_ok = torch._dynamo.utils.counters["aot_autograd"]["ok"]
+ c_not_ok = torch._dynamo.utils.counters["aot_autograd"]["not_ok"]
+ if ok:
+ self.assertGreater(c_ok, 0)
+ self.assertEqual(c_not_ok, 0)
+ else:
+ self.assertEqual(c_ok, 0)
+ self.assertGreater(c_not_ok, 0)
+ return r
+
+ return wrap
+
+ return deco
+
+
+def patch_all(ok=True):
+ return composed(
+ patch("torch._dynamo.config.verify_correctness", True),
+ assert_aot_autograd_counter(ok),
+ )
+
+
+N_ITERS = 5
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "these tests require cuda")
+class TestAotCudagraphs(torch._dynamo.testing.TestCase):
+ @patch_all()
+ def test_basic(self):
+ def model(x, y):
+ return (x + y) * y
+
+ @torch._dynamo.optimize("aot_cudagraphs")
+ def fn(x, y):
+ for i in range(N_ITERS):
+ loss = model(x, y).sum()
+ loss.backward()
+
+ x = torch.randn(3, device="cuda", requires_grad=True)
+ y = torch.randn(3, device="cuda")
+ fn(x, y)
+
+ @patch_all()
+ def test_dtoh(self):
+ def model(x, y):
+ a = x + y
+ b = a.cpu() * 3
+ return b
+
+ @torch._dynamo.optimize("aot_cudagraphs")
+ def fn(x, y):
+ for i in range(N_ITERS):
+ loss = model(x, y).sum()
+ loss.backward()
+
+ x = torch.randn(3, device="cuda", requires_grad=True)
+ y = torch.randn(3, device="cuda")
+ fn(x, y)
+
+ @patch_all()
+ def test_htod(self):
+ def model(x, y):
+ a = x + y
+ return a * 3
+
+ @torch._dynamo.optimize("aot_cudagraphs")
+ def fn(x, y):
+ for i in range(N_ITERS):
+ loss = model(x, y).sum()
+ loss.backward()
+
+ x = torch.randn(3, device="cuda", requires_grad=True)
+ y = torch.randn((), device="cpu")
+ fn(x, y)
+
+ @patch("functorch._src.config.use_functionalize", True)
+ @patch_all(ok=False) # input mutation not supported yet
+ def test_mutate_input(self):
+ def model(x, y):
+ y.add_(3)
+ return x * y
+
+ @torch._dynamo.optimize("aot_cudagraphs")
+ def fn(x, y):
+ for i in range(N_ITERS):
+ with self.subTest(i):
+ y_orig = y.clone()
+ loss = model(x, y).sum()
+ self.assertTrue(same(y, y_orig + 3))
+ loss.backward()
+
+ x = torch.randn(3, device="cuda", requires_grad=True)
+ y = torch.randn(3, device="cuda")
+ fn(x, y)
+
+ @patch_all()
+ def test_mutate_constant(self):
+ def model(x, y):
+ c = torch.tensor(1)
+ c.add_(2)
+ return x * y * 0 + c
+
+ @torch._dynamo.optimize("aot_cudagraphs")
+ def fn(x, y):
+ for i in range(N_ITERS):
+ with self.subTest(i):
+ loss = model(x, y).sum()
+ self.assertTrue(same(loss, torch.tensor(3.0, device="cuda")))
+ loss.backward()
+
+ x = torch.randn(1, device="cuda", requires_grad=True)
+ y = torch.randn(1, device="cuda")
+ fn(x, y)
+
+ @patch_all()
+ def test_factory(self):
+ def model(y):
+ x = torch.zeros(3, device="cuda:0")
+ x.add_(3)
+ return x * y
+
+ @torch._dynamo.optimize("aot_cudagraphs")
+ def fn(y):
+ for i in range(N_ITERS):
+ with self.subTest(i):
+ loss = model(y).sum()
+ loss.backward()
+
+ y = torch.randn(3, device="cuda:0", requires_grad=True)
+ fn(y)
+
+ @patch("functorch._src.config.use_functionalize", True)
+ @patch_all()
+ def test_mutated_metadata(self):
+ # more tortured example at
+ # https://github.com/pytorch/pytorch/issues/81385
+ def model(x):
+ x = x.clone()
+ x.resize_(20)
+ x.fill_(2)
+ return x
+
+ @torch._dynamo.optimize("aot_cudagraphs")
+ def fn(x):
+ for i in range(N_ITERS):
+ with self.subTest(i):
+ rx = model(x)
+ self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0")))
+
+ x = torch.empty(0, device="cuda:0")
+ fn(x)
+
+ @patch("functorch._src.config.use_functionalize", True)
+ @patch_all()
+ def test_dead_fill(self):
+ def model(x):
+ x = x.clone()
+ y = x[0:0]
+ x.fill_(2)
+ y.fill_(3)
+ return x, y
+
+ @torch._dynamo.optimize("aot_cudagraphs")
+ def fn(x):
+ for i in range(N_ITERS):
+ with self.subTest(i):
+ rx, ry = model(x)
+ self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0")))
+ self.assertTrue(same(ry, torch.empty(0, device="cuda:0")))
+
+ x = torch.empty(20, device="cuda:0")
+ fn(x)
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_distributed.py b/test/dynamo/test_distributed.py
new file mode 100644
index 0000000000000..c1684a013d713
--- /dev/null
+++ b/test/dynamo/test_distributed.py
@@ -0,0 +1,229 @@
+# Owner(s): ["module: dynamo"]
+import os
+import unittest
+from unittest.mock import patch
+
+import pytest
+import torch
+
+import torch._dynamo
+import torch.distributed as dist
+from torch import nn
+from torch._dynamo import config
+from torch._dynamo.testing import same
+
+
+class ToyModel(nn.Module):
+ def __init__(self, in_feat=10, hidden_feat=5000, num_hidden=2, out_feat=5):
+ super().__init__()
+ self.net = nn.Sequential(
+ *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
+ + [nn.Linear(5000, 5000), nn.ReLU()] * num_hidden
+ + [nn.Linear(5000, 5), nn.ReLU()]
+ )
+
+ def forward(self, inputs):
+ return self.net(inputs)
+
+
+class CheckSplitsCompiler:
+ def __init__(self):
+ self.compiler_called = 0
+
+ def compile_fn(self, gm, example_inputs):
+ self.compiler_called += 1
+ return gm
+
+
+def skip_if_no_active_ddp():
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ if not hasattr(DDP, "_get_active_ddp_module"):
+ raise unittest.SkipTest("requires pytorch landing in parallel")
+
+
+@pytest.mark.skip("Module hangs in PyTorch CI")
+class TestDistributed(torch._dynamo.testing.TestCase):
+ """
+ Test harness initializes dist process group
+ """
+
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ # _exit_stack is set up in TestCase
+ cls._exit_stack.enter_context(
+ patch.dict(
+ os.environ,
+ {
+ "MASTER_ADDR": "localhost",
+ "MASTER_PORT": "12355",
+ },
+ )
+ )
+ cls.rank = 0
+ cls.device = f"cpu:{cls.rank}"
+ cls.device_ids = None if "cpu" in cls.device else [cls.rank]
+ dist.init_process_group("gloo", rank=cls.rank, world_size=1)
+
+ @classmethod
+ def tearDownClass(cls):
+ dist.destroy_process_group()
+ super().tearDownClass()
+
+ def get_model(self):
+ m = ToyModel().to(self.device)
+ inputs = torch.randn(20, 10).to(self.device)
+ outputs = m(inputs)
+ return m, inputs, outputs
+
+ @patch.object(config, "optimize_ddp", False)
+ def test_ddp_baseline_aot_eager(self):
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ m, inputs, correct_outputs = self.get_model()
+ ddp_m = DDP(m, device_ids=self.device_ids)
+ ddp_m = torch._dynamo.optimize("aot_eager")(ddp_m)
+ outputs = ddp_m(inputs)
+ self.assertTrue(same(correct_outputs, outputs))
+
+ @patch.object(config, "optimize_ddp", False)
+ def test_ddp_baseline_inductor(self):
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ m, inputs, correct_outputs = self.get_model()
+ ddp_m = DDP(m, device_ids=self.device_ids)
+ ddp_m = torch._dynamo.optimize("inductor")(ddp_m)
+ outputs = ddp_m(inputs)
+ self.assertTrue(same(correct_outputs, outputs))
+
+ # can't run with gloo (no support for _allgather_base) and nccl not available in CI
+ @pytest.mark.xfail
+ @patch.object(config, "optimize_ddp", False)
+ def test_fsdp_baseline_aot_eager(self):
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+ m, inputs, correct_outputs = self.get_model()
+ fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None)
+ fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m)
+ outputs = fsdp_m(inputs)
+ self.assertTrue(same(correct_outputs, outputs))
+
+ # hangs/crashes with inductor currently
+ @pytest.mark.skip
+ @patch.object(config, "optimize_ddp", False)
+ def test_fsdp_baseline_inductor(self):
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+ m, inputs, correct_outputs = self.get_model()
+ fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None)
+ fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m)
+ outputs = fsdp_m(inputs)
+ self.assertTrue(same(correct_outputs, outputs))
+
+ @patch.object(config, "optimize_ddp", True)
+ def test_graph_split(self):
+ """
+ Just ensures that the appropriate number of splits happen (based on
+ bucket size and model parameters) - verifies the number of times
+ the user-provided compiler is called by the DDPOptimizer which is
+ doing the graph splitting
+ """
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ skip_if_no_active_ddp()
+
+ m, inputs, correct_outputs = self.get_model()
+ ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
+
+ check_splits_compiler = CheckSplitsCompiler()
+
+ @torch._dynamo.optimize(check_splits_compiler.compile_fn)
+ def opt_fn(inputs):
+ return ddp_m(inputs)
+
+ opt_outputs = opt_fn(inputs)
+ self.assertTrue(same(correct_outputs, opt_outputs))
+ self.assertEqual(check_splits_compiler.compiler_called, 3)
+
+ # hangs/crashes with inductor currently
+ @pytest.mark.skip
+ @patch.object(config, "optimize_ddp", True)
+ def test_graph_split_inductor(self):
+ """
+ Same as above, but using inductor backend.
+ We observed issues with inductor/fx interface in the past.
+ """
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ skip_if_no_active_ddp()
+ m, inputs, correct_outputs = self.get_model()
+ ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
+
+ @torch._dynamo.optimize("inductor")
+ def opt_fn(inputs):
+ return ddp_m(inputs)
+
+ opt_outputs = opt_fn(inputs)
+ self.assertTrue(same(correct_outputs, opt_outputs))
+
+ @patch.object(config, "optimize_ddp", True)
+ def test_no_split(self):
+ """
+ Ensures the DDPOptimizer returns a correct, compiled module without
+ introducing graph splits. (Based on model parmeters fitting in the bucket)
+ """
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ skip_if_no_active_ddp()
+ m, inputs, correct_outputs = self.get_model()
+ ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250)
+
+ check_splits_compiler = CheckSplitsCompiler()
+
+ @torch._dynamo.optimize(check_splits_compiler.compile_fn)
+ def opt_fn(inputs):
+ return ddp_m(inputs)
+
+ opt_outputs = opt_fn(inputs)
+ self.assertTrue(same(correct_outputs, opt_outputs))
+ self.assertEqual(check_splits_compiler.compiler_called, 1)
+
+ @patch.object(config, "optimize_ddp", True)
+ def test_aot_autograd(self):
+ """
+ Explicitly check AotAutograd family of compilers work,
+ since they require example inputs propagated between graph splits.
+ """
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ skip_if_no_active_ddp()
+ m, inputs, correct_outputs = self.get_model()
+ ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
+
+ @torch._dynamo.optimize("aot_eager")
+ def opt_fn(inputs):
+ return ddp_m(inputs)
+
+ opt_outputs = opt_fn(inputs)
+ opt_outputs.sum().backward()
+ self.assertTrue(same(correct_outputs, opt_outputs))
+
+ def test_empty_graph(self):
+ def fn():
+ get_world_size = torch.distributed.distributed_c10d.get_world_size()
+ return (get_world_size,)
+
+ opt_fn = torch._dynamo.optimize("inductor")(fn)
+ res = None
+ try:
+ res = opt_fn()[0]
+ except Exception:
+ pass
+ self.assertEqual(res, 1)
+
+
+# TODO(jansel): debug issues running this in CI
+# if __name__ == "__main__":
+# from torch._dynamo.testing import run_tests
+# run_tests()
diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py
new file mode 100644
index 0000000000000..2c9c90df19e05
--- /dev/null
+++ b/test/dynamo/test_dynamic_shapes.py
@@ -0,0 +1,30 @@
+# Owner(s): ["module: dynamo"]
+
+from torch._dynamo.testing import make_test_cls_with_patches
+
+try:
+ from . import test_functions, test_misc, test_modules, test_repros, test_unspec
+except ImportError:
+ import test_functions
+ import test_misc
+ import test_modules
+ import test_repros
+ import test_unspec
+
+
+def make_dynamic_cls(cls):
+ return make_test_cls_with_patches(
+ cls, "DynamicShapes", "_dynamic_shapes", ("dynamic_shapes", True)
+ )
+
+
+DynamicShapesFunctionTests = make_dynamic_cls(test_functions.FunctionTests)
+DynamicShapesMiscTests = make_dynamic_cls(test_misc.MiscTests)
+DynamicShapesReproTests = make_dynamic_cls(test_repros.ReproTests)
+DynamicShapesNNModuleTests = make_dynamic_cls(test_modules.NNModuleTests)
+DynamicShapesUnspecTests = make_dynamic_cls(test_unspec.UnspecTests)
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
new file mode 100644
index 0000000000000..9365535c73bc3
--- /dev/null
+++ b/test/dynamo/test_export.py
@@ -0,0 +1,1428 @@
+# Owner(s): ["module: dynamo"]
+from unittest.mock import patch
+
+import torch
+
+import torch._dynamo.testing
+import torch.utils._pytree as pytree
+from torch.fx.experimental.proxy_tensor import make_fx
+
+
+class ExportTests(torch._dynamo.testing.TestCase):
+ # TODO(voz): Refactor to a shared test function.
+ # The tests in this file are a little redundant,
+ # They all take a func, run it with eager, then export it, then compare
+ def test_export(self):
+ def pre_attention_state_ops(input, mems, state):
+ lc_key = state[0]
+ lc_val = state[1]
+ bar = []
+ for i in range(0, 4):
+ bar2 = []
+ for j in range(0, 3):
+ bar2.append(
+ lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
+ )
+ bar.append(bar2)
+
+ return bar
+
+ def func():
+ mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
+ state = [
+ torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
+ torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
+ ]
+ i = torch.tensor(
+ [
+ [0.0313, -0.1487, -0.3846, -0.5321],
+ [-1.7073, 1.3331, -0.0890, -1.4935],
+ [-0.8314, -0.1862, -0.5935, 1.5232],
+ ]
+ )
+ return pre_attention_state_ops(i, mems, state)
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func()
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func)
+ out_graph = exported[0]
+
+ dynamo_result = out_graph()
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_export_mismatched_out(self):
+ def func(x):
+ y = x + 1
+ return ([x, x], (y, y))
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, torch.tensor([[[1.3737, 0.1]]]))
+ out_graph = exported[0]
+
+ dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_export_graph_bypass(self):
+ inp = [
+ torch.tensor([0.1, 0.1]),
+ torch.tensor([0.2, 0.2]),
+ torch.tensor([0.3, 0.3]),
+ ]
+
+ def func(x):
+ first = x[2]
+ second = x[2]
+ return first * second
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(inp)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, inp)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inp)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_list_unpack(self):
+ inp = [
+ torch.tensor([0.1, 0.1]),
+ torch.tensor([0.2, 0.2]),
+ torch.tensor([0.3, 0.3]),
+ ]
+
+ def func(x):
+ first = x[2]
+ second = x[2]
+ return x[0], first * second, x[1], x[2]
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(inp)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, inp)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inp)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_export_mismatched_out_2(self):
+ def func(x):
+ y = x + 1
+ return ([x, x], (y, y))
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, torch.tensor([[[1.3737, 0.1]]]))
+ out_graph = exported[0]
+
+ dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_export_graph_with_list(self):
+ inp = [
+ torch.tensor([0.1, 0.1]),
+ torch.tensor([0.2, 0.2]),
+ torch.tensor([0.3, 0.3]),
+ torch.tensor([0.4, 0.4]),
+ ]
+
+ def func(x):
+ first = x[2]
+ second = x[2]
+ return first * second, x
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(inp)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, inp)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inp)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_export_graph_with_complex_reorder(self):
+ inp = [
+ torch.tensor([0.1, 0.1]),
+ torch.tensor([0.2, 0.2]),
+ torch.tensor([0.3, 0.3]),
+ torch.tensor([0.4, 0.4]),
+ ]
+
+ def func(x):
+ first = x[0]
+ second = x[1]
+ third = x[2]
+ return third, first, second, first * second, first * third
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(inp)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, inp)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inp)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_dupes(self):
+ inp = torch.tensor([0.1, 0.1])
+
+ def func(x):
+ y = x + 1
+ return y, y
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(inp)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, inp)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inp)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_dupes_2(self):
+ inp = torch.tensor([0.1, 0.1])
+
+ def func(x):
+ y = x + 1
+ return y, y
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(inp)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, inp)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inp)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_dupes_and_bypass(self):
+ inp = torch.tensor([0.1, 0.1])
+ inp2 = torch.tensor([0.4, 0.4])
+ inps = [inp, inp2]
+
+ def func(x, z):
+ y = x + 1
+ return y, y, z
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_dupes_and_bypass_with_non_tensor_arg(self):
+ inp = torch.tensor([0.1, 0.1])
+ inp2 = torch.tensor([0.1, 0.1])
+ inp3 = 4
+ inps = [inp, inp2, inp3]
+
+ def func(x, z, k):
+ y = x + k
+ return y, y, z
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_dupes_and_bypass_reorder_with_non_tensor_arg(self):
+ inp = torch.tensor([0.1, 0.1])
+ inp2 = torch.tensor([0.1, 0.1])
+ inp3 = 4
+ inps = [inp, inp2, inp3]
+
+ def func(x, z, k):
+ y = x + k
+ return z, y, y
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
+ def test_dupes_and_bypass_with_non_tensor_output(self):
+ inp = torch.tensor([0.1, 0.1])
+ inp2 = torch.tensor([0.1, 0.1])
+ inp3 = 4
+ inps = [inp, inp2, inp3]
+
+ def func(x, z, k):
+ y = x + k
+ return y[0].item(), y, z
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_zeroes_in_and_out_different_shape_on_test(self):
+ inp = torch.zeros(10)
+ inp2 = torch.zeros(10)
+ inp3 = torch.zeros(10)
+ inps = [inp, inp2, inp3]
+
+ inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
+
+ def func(a, b, c):
+ return [[a], [b, c], [a + b], [[c + c]]]
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps_rand)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps_rand)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
+ def test_zeroes_in_new_shape_scalar_out(self):
+ inp = torch.zeros(10)
+ inp2 = torch.zeros(10)
+ inp3 = torch.zeros(10)
+ inps = [inp, inp2, inp3]
+
+ inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
+
+ def func(a, b, c):
+ return a[0].item() + b[0].item() + c[0].item()
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps_rand)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps_rand)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
+ def test_zeroes_in_new_shape_scalar_out_permute(self):
+ inp = torch.zeros(10)
+ inp2 = torch.zeros(10)
+ inp3 = torch.zeros(10)
+ inps = [inp, inp2, inp3]
+
+ inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
+
+ def func(a, b, c):
+ return b[0].item() + c[0].item() + a[0].item() + a[0].item()
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps_rand)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps_rand)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
+ def test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass(self):
+ inp = torch.zeros(10)
+ inp2 = torch.zeros(10)
+ inp3 = torch.zeros(10)
+ inps = [inp, inp2, inp3]
+
+ inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
+
+ def func(a, b, c):
+ return a, b[0].item() + c[0].item() + a[0].item() + a[0].item(), a
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps_rand)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps_rand)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_func_return(self):
+ inp = torch.zeros(10)
+ inp2 = torch.zeros(10)
+ inp3 = torch.zeros(10)
+ inps = [inp, inp2, inp3]
+
+ inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
+
+ def func(a, b, c):
+ x = a + b + c
+
+ def func2(y):
+ return x * y
+
+ return func2(x)
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps_rand)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps_rand)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_dict_return(self):
+ inp = torch.zeros(10)
+ inp2 = torch.zeros(10)
+ inp3 = torch.zeros(10)
+ inps = [inp, inp2, inp3]
+
+ inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
+
+ def func(a, b, c):
+ x = a + b + c
+ return {"a": x}
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps_rand)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps_rand)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_export_with_aten_graph(self):
+ def pre_attention_state_ops(input, mems, state):
+ lc_key = state[0]
+ lc_val = state[1]
+ bar = []
+ for i in range(0, 4):
+ bar2 = []
+ for j in range(0, 3):
+ bar2.append(
+ lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
+ )
+ bar.append(bar2)
+
+ return bar
+
+ def func():
+ mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
+ state = [
+ torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
+ torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
+ ]
+ i = torch.tensor(
+ [
+ [0.0313, -0.1487, -0.3846, -0.5321],
+ [-1.7073, 1.3331, -0.0890, -1.4935],
+ [-0.8314, -0.1862, -0.5935, 1.5232],
+ ]
+ )
+ return pre_attention_state_ops(i, mems, state)
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func()
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, aten_graph=True)
+ out_graph = exported[0]
+
+ dynamo_result = out_graph()
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_export_mismatched_out_with_aten_graph(self):
+ def func(x):
+ y = x + 1
+ return ([x, x], (y, y))
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(
+ func, torch.tensor([[[1.3737, 0.1]]]), aten_graph=True
+ )
+ out_graph = exported[0]
+
+ dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_export_graph_bypass_with_aten_graph(self):
+ inp = [
+ torch.tensor([0.1, 0.1]),
+ torch.tensor([0.2, 0.2]),
+ torch.tensor([0.3, 0.3]),
+ ]
+
+ def func(x):
+ first = x[2]
+ second = x[2]
+ return first * second
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(inp)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, inp, aten_graph=True)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inp)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_list_unpack_with_aten_graph(self):
+ inp = [
+ torch.tensor([0.1, 0.1]),
+ torch.tensor([0.2, 0.2]),
+ torch.tensor([0.3, 0.3]),
+ ]
+
+ def func(x):
+ first = x[2]
+ second = x[2]
+ return x[0], first * second, x[1], x[2]
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(inp)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, inp, aten_graph=True)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inp)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_export_mismatched_out_2_with_aten_graph(self):
+ def func(x):
+ y = x + 1
+ return ([x, x], (y, y))
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(
+ func, torch.tensor([[[1.3737, 0.1]]]), aten_graph=True
+ )
+ out_graph = exported[0]
+
+ dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_export_graph_with_list_with_aten_graph(self):
+ inp = [
+ torch.tensor([0.1, 0.1]),
+ torch.tensor([0.2, 0.2]),
+ torch.tensor([0.3, 0.3]),
+ torch.tensor([0.4, 0.4]),
+ ]
+
+ def func(x):
+ first = x[2]
+ second = x[2]
+ return first * second, x
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(inp)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, inp, aten_graph=True)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inp)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_export_graph_with_complex_reorder_with_aten_graph(self):
+ inp = [
+ torch.tensor([0.1, 0.1]),
+ torch.tensor([0.2, 0.2]),
+ torch.tensor([0.3, 0.3]),
+ torch.tensor([0.4, 0.4]),
+ ]
+
+ def func(x):
+ first = x[0]
+ second = x[1]
+ third = x[2]
+ return third, first, second, first * second, first * third
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(inp)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, inp, aten_graph=True)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inp)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_dupes_with_aten_graph(self):
+ inp = torch.tensor([0.1, 0.1])
+
+ def func(x):
+ y = x + 1
+ return y, y
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(inp)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, inp, aten_graph=True)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inp)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_dupes_2_with_aten_graph(self):
+ inp = torch.tensor([0.1, 0.1])
+
+ def func(x):
+ y = x + 1
+ return y, y
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(inp)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, inp, aten_graph=True)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inp)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_dupes_and_bypass_with_aten_graph(self):
+ inp = torch.tensor([0.1, 0.1])
+ inp2 = torch.tensor([0.4, 0.4])
+ inps = [inp, inp2]
+
+ def func(x, z):
+ y = x + 1
+ return y, y, z
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps, aten_graph=True)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_dupes_and_bypass_with_non_tensor_arg_with_aten_graph(self):
+ inp = torch.tensor([0.1, 0.1])
+ inp2 = torch.tensor([0.1, 0.1])
+ inp3 = 4
+ inps = [inp, inp2, inp3]
+
+ def func(x, z, k):
+ y = x + k
+ return y, y, z
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps, aten_graph=True)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_dupes_and_bypass_reorder_with_non_tensor_arg_with_aten_graph(self):
+ inp = torch.tensor([0.1, 0.1])
+ inp2 = torch.tensor([0.1, 0.1])
+ inp3 = 4
+ inps = [inp, inp2, inp3]
+
+ def func(x, z, k):
+ y = x + k
+ return z, y, y
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps, aten_graph=True)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
+ def test_dupes_and_bypass_with_non_tensor_output_with_aten_graph(self):
+ inp = torch.tensor([0.1, 0.1])
+ inp2 = torch.tensor([0.1, 0.1])
+ inp3 = 4
+ inps = [inp, inp2, inp3]
+
+ def func(x, z, k):
+ y = x + k
+ return y[0].item(), y, z
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_zeroes_in_and_out_different_shape_on_test_with_aten_graph(self):
+ inp = torch.zeros(10)
+ inp2 = torch.zeros(10)
+ inp3 = torch.zeros(10)
+ inps = [inp, inp2, inp3]
+
+ inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
+
+ def func(a, b, c):
+ return [[a], [b, c], [a + b], [[c + c]]]
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps_rand)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps, aten_graph=True)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps_rand)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_func_return_with_aten_graph(self):
+ inp = torch.zeros(10)
+ inp2 = torch.zeros(10)
+ inp3 = torch.zeros(10)
+ inps = [inp, inp2, inp3]
+
+ inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
+
+ def func(a, b, c):
+ x = a + b + c
+
+ def func2(y):
+ return x * y
+
+ return func2(x)
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps_rand)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps, aten_graph=True)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps_rand)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_dict_return_with_aten_graph(self):
+ inp = torch.zeros(10)
+ inp2 = torch.zeros(10)
+ inp3 = torch.zeros(10)
+ inps = [inp, inp2, inp3]
+
+ inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
+
+ def func(a, b, c):
+ x = a + b + c
+ return {"a": x}
+
+ opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
+ real_result = opt_func(*inps_rand)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, *inps, aten_graph=True)
+ out_graph = exported[0]
+ flat_input, _ = pytree.tree_flatten(inps_rand)
+
+ dynamo_result = out_graph(*flat_input)
+
+ self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
+
+ def test_export_with_stack_trace(self):
+ inp = torch.tensor([0.1, 0.1])
+ linear = torch.nn.Linear(2, 2)
+
+ def func(x):
+ x = x + 1
+ y = x.t()
+ y = y.relu()
+ y = linear(y)
+ return y
+
+ exported = torch._dynamo.export(func, inp, aten_graph=False)
+ out_graph = exported[0]
+
+ for node in out_graph.graph.nodes:
+ if node.op not in {"placeholder", "output"}:
+ self.assertTrue(node.stack_trace is not None)
+
+ torch._dynamo.reset()
+
+ exported = torch._dynamo.export(func, inp, aten_graph=True)
+ out_graph = exported[0]
+ for node in out_graph.graph.nodes:
+ if node.op == "call_function":
+ self.assertTrue(node.stack_trace is not None)
+
+ def test_export_compare_optimize_with_make_fx(self):
+ inp = torch.tensor([0.1, 0.1])
+ linear = torch.nn.Linear(2, 2)
+
+ def func(x):
+ x = x + 1
+ y = x.t()
+ y = y.relu()
+ y = linear(y)
+ return y
+
+ exported = torch._dynamo.export(func, inp, aten_graph=True)
+ out_graph = exported[0]
+ export_result = out_graph(inp)
+
+ torch._dynamo.reset()
+
+ def compiler(gm, sample_inputs):
+ aten_gm = make_fx(gm)(*sample_inputs)
+
+ self.assertEqual(len(aten_gm.graph.nodes), len(out_graph.graph.nodes))
+ for node1, node2 in zip(aten_gm.graph.nodes, out_graph.graph.nodes):
+ self.assertEqual(node1.op, node2.op)
+ if node1.op == "call_function":
+ self.assertEqual(node1.target, node2.target)
+ self.assertEqual(len(node1.args), len(node2.args))
+ for arg1, arg2 in zip(node1.args, node2.args):
+ self.assertEqual(type(arg1), type(arg2))
+
+ return aten_gm.forward
+
+ opt_func = torch._dynamo.optimize(compiler, nopython=True)(func)
+ make_fx_result = opt_func(inp)
+
+ self.assertTrue(torch._dynamo.utils.same(make_fx_result, export_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_method_on_module(self):
+ class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 2))
+ self.linear = torch.nn.Linear(2, 2)
+
+ @torch._dynamo.assume_constant_result
+ def helper_fn(self, x):
+ return torch.nonzero(x)
+
+ def forward(self, x):
+ y = torch.sin(x)
+ x = self.linear(x)
+ y = self.helper_fn(x)
+ return y
+
+ module = MyModule()
+ real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
+ module = MyModule()
+ graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
+ result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+ result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_method_on_module_invoke_twice(self):
+ class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 2))
+ self.linear = torch.nn.Linear(2, 2)
+
+ @torch._dynamo.assume_constant_result
+ def helper_fn(self, x):
+ return torch.nonzero(x)
+
+ def forward(self, x):
+ y = torch.sin(x)
+ x = self.linear(x)
+ y = self.helper_fn(x) + self.helper_fn(x)
+ return y
+
+ module = MyModule()
+ real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
+ module = MyModule()
+ graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
+ result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+ result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_free_function(self):
+ @torch._dynamo.assume_constant_result
+ def helper_fn(x):
+ return torch.nonzero(x)
+
+ class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 2))
+ self.linear = torch.nn.Linear(2, 2)
+
+ @torch._dynamo.assume_constant_result
+ def helper_fn(self, x):
+ return torch.nonzero(x)
+
+ def forward(self, x):
+ y = torch.sin(x)
+ x = self.linear(x)
+ y = helper_fn(x) + self.helper_fn(x)
+ return y
+
+ module = MyModule()
+ real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
+ module = MyModule()
+ graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
+ result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+ result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_free_function_and_class_method(self):
+ @torch._dynamo.assume_constant_result
+ def helper_fn(x):
+ return torch.nonzero(x)
+
+ class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 2))
+ self.linear = torch.nn.Linear(2, 2)
+
+ def forward(self, x):
+ y = torch.sin(x)
+ x = self.linear(x)
+ y = helper_fn(x)
+ return y
+
+ module = MyModule()
+ real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
+ module = MyModule()
+ graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
+ result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+ result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_free_function_and_class_method_multiarg(self):
+ @torch._dynamo.assume_constant_result
+ def helper_fn(x):
+ return torch.nonzero(x)
+
+ class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 2))
+ self.linear = torch.nn.Linear(2, 2)
+
+ def forward(self, x, z):
+ y = torch.sin(x)
+ x = self.linear(x)
+ y = helper_fn(x) + helper_fn(z)
+ return y
+
+ module = MyModule()
+ real_result = module(
+ torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
+ )
+ module = MyModule()
+ graph, _ = torch._dynamo.export(
+ module, torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
+ )
+ result = graph(
+ torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[1.0, 0.0], [0, 0]])
+ )
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+ result = graph(
+ torch.tensor([[1, 0], [0.25, 0.25]]), torch.tensor([[1, 0], [0.25, 0.25]])
+ )
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_free_function_and_class_method_multiarg_diff(self):
+ @torch._dynamo.assume_constant_result
+ def helper_fn(x):
+ return torch.nonzero(x)
+
+ class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, z):
+ y = helper_fn(x) + helper_fn(z)
+ return y
+
+ module = MyModule()
+ real_result = module(
+ torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
+ )
+ module = MyModule()
+ graph, _ = torch._dynamo.export(
+ module, torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[0.0, 0], [0.5, 0]])
+ )
+ result = graph(
+ torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[0.0, 1.0], [0, 0]])
+ )
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+ result = graph(
+ torch.tensor([[1, 0], [0.25, 0.25]]),
+ torch.tensor([[0.33, 0.33], [0.25, 0.25]]),
+ )
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_tuple_nonzero(self):
+ class MyModule(torch.nn.Module):
+ @torch._dynamo.assume_constant_result
+ def helper_fn(self, x):
+ return (torch.nonzero(x), torch.nonzero(x))
+
+ def forward(self, x):
+ y = torch.tensor([0.5])
+ elements = self.helper_fn(x)
+ all_y = []
+ for element in elements:
+ for item in element:
+ all_y.append(y * item)
+ return all_y
+
+ module = MyModule()
+ real_result = module(torch.tensor([1.0, 1.0]))
+ graph, guards = torch._dynamo.export(module, torch.tensor([1.0, 1.0]))
+
+ # Tensor input can be almost anything here, and the result will capture what we
+ # made constant at compile time.
+ result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_list_nonzero(self):
+ class MyModule(torch.nn.Module):
+ @torch._dynamo.assume_constant_result
+ def helper_fn(self, x):
+ return [torch.nonzero(x), torch.nonzero(x)]
+
+ def forward(self, x):
+ y = torch.tensor([0.5])
+ elements = self.helper_fn(x)
+ all_y = []
+ for element in elements:
+ for item in element:
+ all_y.append(y * item)
+ return all_y
+
+ module = MyModule()
+ real_result = module(torch.tensor([1.0, 1.0]))
+ graph, guards = torch._dynamo.export(module, torch.tensor([1.0, 1.0]))
+
+ # Tensor input can be almost anything here, and the result will capture what we
+ # made constant at compile time.
+ result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_list_nonzero_free_function(self):
+ @torch._dynamo.assume_constant_result
+ def helper_fn(x):
+ return [torch.nonzero(x), torch.nonzero(x)]
+
+ class MyModule(torch.nn.Module):
+ def forward(self, x):
+ y = torch.tensor([0.5])
+ elements = helper_fn(x)
+ all_y = []
+ for element in elements:
+ for item in element:
+ all_y.append(y * item)
+ return all_y
+
+ module = MyModule()
+ real_result = module(torch.tensor([1.0, 1.0]))
+ graph, guards = torch._dynamo.export(module, torch.tensor([1.0, 1.0]))
+
+ # Tensor input can be almost anything here, and the result will capture what we
+ # made constant at compile time.
+ result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_dict_values(self):
+ class MyModule(torch.nn.Module):
+ @torch._dynamo.assume_constant_result
+ def helper_fn(self, x):
+ return {"x": x, "x^2": x * x}
+
+ def forward(self, x):
+ y = torch.tensor([0.5])
+ elements = self.helper_fn(x)
+ y = y * elements["x"]
+ y = y * elements["x^2"]
+ return y
+
+ module = MyModule()
+ real_result = module(torch.tensor([2.0, 2.0]))
+ graph, guards = torch._dynamo.export(module, torch.tensor([2.0, 2.0]))
+
+ # Tensor input can be almost anything here, and the result will capture what we
+ # made constant at compile time.
+ result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_none_control_flow(self):
+ class MyModule(torch.nn.Module):
+ @torch._dynamo.assume_constant_result
+ def helper_fn(self, x):
+ if x.item() < 0:
+ return None
+ else:
+ return x
+
+ def forward(self, x):
+ y = torch.tensor([0.5])
+ x = self.helper_fn(x)
+ if x is None:
+ return y
+ return y * x
+
+ module = MyModule()
+ real_result = module(torch.tensor([-1]))
+
+ # X is negative, so .item() < 0, which means we return y
+ self.assertEqual(real_result, torch.tensor([0.5]))
+
+ graph, guards = torch._dynamo.export(module, torch.tensor([-1]))
+ result = graph(torch.tensor([2]))
+ # X is positive, but we compiled helper_fn to return None, so it will still return y
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_not_none_control_flow(self):
+ class MyModule(torch.nn.Module):
+ @torch._dynamo.assume_constant_result
+ def helper_fn(self, x):
+ if x.item() < 0:
+ return None
+ else:
+ return x
+
+ def forward(self, x):
+ y = torch.tensor([0.5])
+ x = self.helper_fn(x)
+ if x is None:
+ return y
+ return y * x
+
+ module = MyModule()
+ real_result = module(torch.tensor([2]))
+
+ # X is positive, so .item() > 0, which means we return y * x
+ self.assertEqual(real_result, torch.tensor([1.0]))
+
+ graph, guards = torch._dynamo.export(module, torch.tensor([2]))
+ result = graph(torch.tensor([-0.5]))
+ # X is negative, but we compiled helper_fn to return x, so it will still return y * x
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_none_control_flow_free_func(self):
+ @torch._dynamo.assume_constant_result
+ def helper_fn(x):
+ if x.item() < 0:
+ return None
+ else:
+ return x
+
+ class MyModule(torch.nn.Module):
+ def forward(self, x):
+ y = torch.tensor([0.5])
+ x = helper_fn(x)
+ if x is None:
+ return y
+ return y * x
+
+ module = MyModule()
+ real_result = module(torch.tensor([-1]))
+
+ # X is negative, so .item() < 0, which means we return y
+ self.assertEqual(real_result, torch.tensor([0.5]))
+
+ graph, guards = torch._dynamo.export(module, torch.tensor([-1]))
+ result = graph(torch.tensor([2]))
+ # X is positive, but we compiled helper_fn to return None, so it will still return y
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_not_none_control_flow_pos(self):
+ class MyModule(torch.nn.Module):
+ @torch._dynamo.assume_constant_result
+ def helper_fn(self, x):
+ if x.item() < 0:
+ return None
+ else:
+ return x
+
+ def forward(self, x):
+ y = torch.tensor([0.5])
+ x = self.helper_fn(x)
+ if x is None:
+ return y
+ return y * x
+
+ module = MyModule()
+ real_result = module(torch.tensor([2]))
+
+ # X is positive, so .item() > 0, which means we return y * x
+ self.assertEqual(real_result, torch.tensor([1.0]))
+
+ graph, guards = torch._dynamo.export(module, torch.tensor([2]))
+ result = graph(torch.tensor([-0.5]))
+ # X is negative, but we compiled helper_fn to return x, so it will still return y * x
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_not_none_control_flow_free_func(self):
+ @torch._dynamo.assume_constant_result
+ def helper_fn(x):
+ if x.item() < 0:
+ return None
+ else:
+ return x
+
+ class MyModule(torch.nn.Module):
+ def forward(self, x):
+ y = torch.tensor([0.5])
+ x = helper_fn(x)
+ if x is None:
+ return y
+ return y * x
+
+ module = MyModule()
+ real_result = module(torch.tensor([2]))
+
+ # X is positive, so .item() > 0, which means we return y * x
+ self.assertEqual(real_result, torch.tensor([1.0]))
+
+ graph, guards = torch._dynamo.export(module, torch.tensor([2]))
+ result = graph(torch.tensor([-0.5]))
+ # X is negative, but we compiled helper_fn to return x, so it will still return y * x
+ self.assertTrue(torch._dynamo.utils.same(result, real_result))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_export_with_constant_not_return_const(self):
+ class MyModule(torch.nn.Module):
+ @torch._dynamo.assume_constant_result
+ def helper_fn(self, x):
+ return self.val
+
+ def forward(self, x):
+ y = torch.tensor([0.5])
+ x = self.helper_fn(x)
+ if x == "A":
+ return y
+ return -1
+
+ module = MyModule()
+ module.val = "A"
+ resA = module(torch.tensor([2]))
+ graph, guards = torch._dynamo.export(module, torch.tensor([2]))
+ module.val = "B"
+ resB = graph(torch.tensor([2]))
+ self.assertTrue(torch._dynamo.utils.same(resA, resB))
+
+ def test_export_decomp(self):
+ def f(x):
+ return x.t() + x.t()
+
+ def nop(x):
+ return x.cos()
+
+ graph, _ = torch._dynamo.export(
+ f,
+ (torch.randn(5)),
+ aten_graph=True,
+ decomposition_table={torch.ops.aten.t.default: nop},
+ )
+ self.assertEqual(
+ len([n for n in graph.graph.nodes if n.target == torch.ops.aten.t.default]),
+ 0,
+ )
+
+ graph, _ = torch._dynamo.export(
+ f, (torch.randn(5)), aten_graph=True, decomposition_table=None
+ )
+ self.assertEqual(
+ len([n for n in graph.graph.nodes if n.target == torch.ops.aten.t.default]),
+ 2,
+ )
+
+ def test_export_decomp_asserts_bad_args(self):
+ def f(x):
+ return x.t() + x.t()
+
+ def nop(x):
+ return x.cos()
+
+ with self.assertRaises(AssertionError):
+ graph, _ = torch._dynamo.export(
+ f,
+ (torch.randn(5)),
+ aten_graph=False,
+ decomposition_table={torch.ops.aten.t.default: nop},
+ )
+
+ def test_export_decomp_asserts_bad_args_mode(self):
+ def f(x):
+ return x.t() + x.t()
+
+ def nop(x):
+ return x.cos()
+
+ with self.assertRaises(AssertionError):
+ graph, _ = torch._dynamo.export(
+ f, (torch.randn(5)), aten_graph=False, tracing_mode="symbolic"
+ )
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
new file mode 100644
index 0000000000000..e2004430f4186
--- /dev/null
+++ b/test/dynamo/test_functions.py
@@ -0,0 +1,675 @@
+# Owner(s): ["module: dynamo"]
+# flake8: noqa
+import collections
+import functools
+import inspect
+import itertools
+import operator
+from typing import Any
+
+import torch
+
+import torch._dynamo.testing
+from torch import sub
+from torch._dynamo.testing import requires_static_shapes
+from torch.nn import functional as F
+
+tensor_for_import_testing = torch.ones(10, 10)
+d = torch.ones(10, 10)
+e = torch.nn.Linear(10, 10)
+flag = True
+
+
+def constant3(a, b):
+ return a - b + (1.0 + 2)
+
+
+def func_with_default(a, b, some_default_arg=True):
+ if some_default_arg:
+ return a - b
+
+
+def make_test(fn):
+ nargs = len(inspect.signature(fn).parameters)
+
+ def test_fn(self):
+ return torch._dynamo.testing.standard_test(self, fn=fn, nargs=nargs)
+
+ return test_fn
+
+
+@torch.jit.script_if_tracing
+def inline_script_if_tracing(x):
+ return x + 1.2
+
+
+@torch.jit.ignore
+def inline_ignore(x):
+ return x + 3.4
+
+
+@torch.jit.unused
+def inline_unused(x):
+ return x + 5.6
+
+
+class FunctionTests(torch._dynamo.testing.TestCase):
+ @make_test
+ def test_inline_jit_annotations(x):
+ x = inline_script_if_tracing(x)
+ x = inline_ignore(x)
+ x = inline_unused(x)
+ return
+
+ @make_test
+ def test_add(a, b):
+ return a + b
+
+ @make_test
+ def test_is_not_null(a, b):
+ if a is not None and b is not None:
+ return a + b
+
+ @make_test
+ def test_constant1(a, b, c):
+ return a - b * c + 1.0
+
+ @make_test
+ def test_constant2(a, b, c):
+ return a - b * c + 1
+
+ @make_test
+ def test_constant3(a):
+ b = 1
+ c = 2
+ d = 3
+ return b + c - d + a
+
+ @make_test
+ def test_constant4(a, b):
+ c = 2
+ d = 3
+ if c > d:
+ return a - b
+ return b - a
+
+ @make_test
+ def test_finfo(a, b):
+ if torch.iinfo(torch.int32).bits == 32:
+ return torch.finfo(a.dtype).min * b
+
+ @make_test
+ def test_globalfn(a, b):
+ return sub(a, b)
+
+ @make_test
+ def test_viatorch(a, b):
+ return torch.sub(a, b)
+
+ @make_test
+ def test_viamethod(a, b):
+ return a.sub(b)
+
+ @make_test
+ def test_indirect1(a, b):
+ t = a.sub
+ return t(b)
+
+ @make_test
+ def test_indirect2(a, b):
+ t = a.sub
+ args = (b,)
+ return t(*args)
+
+ @make_test
+ def test_indirect3(a, b):
+ t = a.sub
+ args = (b,)
+ kwargs = {}
+ return t(*args, **kwargs)
+
+ @make_test
+ def test_methodcall1(a, b, c):
+ return constant3(a, b) * c
+
+ @make_test
+ def test_methodcall2(a, b):
+ return constant3(a=b, b=a) + 1
+
+ @make_test
+ def test_methodcall3(a, b):
+ return constant3(a, b=1.0) + b
+
+ @make_test
+ def test_device_constant(a):
+ return a + torch.ones(1, device=torch.device("cpu"))
+
+ @make_test
+ def test_tuple1(a, b):
+ args = (a, b)
+ return sub(*args)
+
+ @make_test
+ def test_tuple2(a, b):
+ args = [a, b]
+ return sub(*args)
+
+ @make_test
+ def test_is_in_onnx_export(x, y):
+ if torch.onnx.is_in_onnx_export():
+ return x - 1
+ else:
+ return y + 1
+
+ @make_test
+ def test_is_fx_tracing(x, y):
+ if torch.fx._symbolic_trace.is_fx_tracing():
+ return x - 1
+ else:
+ return y + 1
+
+ @make_test
+ def test_listarg1(a, b):
+ return torch.cat([a, b])
+
+ @make_test
+ def test_listarg2(a, b):
+ return torch.cat((a, b), dim=0)
+
+ @make_test
+ def test_listarg3(a, b):
+ kwargs = {"tensors": (a, b), "dim": 0}
+ return torch.cat(**kwargs)
+
+ @make_test
+ def test_listarg4(a, b):
+ return torch.cat(tensors=[a, b], dim=0)
+
+ @make_test
+ def test_listarg5(a, b):
+ args = [(a, b)]
+ kwargs = {"dim": 0}
+ return torch.cat(*args, **kwargs)
+
+ @make_test
+ def test_slice1(a):
+ return a[5]
+
+ @make_test
+ def test_slice2(a):
+ return a[:5]
+
+ @make_test
+ def test_slice3(a):
+ return a[5:]
+
+ @make_test
+ def test_slice4(a):
+ return a[2:5]
+
+ @make_test
+ def test_slice5(a):
+ return a[::2]
+
+ @make_test
+ def test_slice6(a):
+ return torch.unsqueeze(a, 0)[:, 2:]
+
+ @make_test
+ def test_unpack1(a):
+ a, b = a[:5], a[5:]
+ return a - b
+
+ @make_test
+ def test_unpack2(a):
+ packed = [a[:5], a[5:]]
+ a, b = packed
+ return a - b
+
+ @make_test
+ def test_unpack3(a):
+ packed = (a[:5], a[5:])
+ a, b = packed
+ return a - b
+
+ @make_test
+ def test_fn_with_self_set(a, b):
+ # avg_pool2d is an odd one with __self__ set
+ return F.avg_pool2d(
+ torch.unsqueeze(a, 0) * torch.unsqueeze(b, 1), kernel_size=2, padding=1
+ )
+
+ @make_test
+ def test_return_tuple1(a, b):
+ return (a - b, b - a, a, b)
+
+ @make_test
+ def test_globalvar(a, b):
+ return a - b + d
+
+ @make_test
+ def test_globalmodule(x):
+ return e(x)
+
+ @make_test
+ def test_inline_with_default(a, b, c):
+ return func_with_default(a, b) * c
+
+ @make_test
+ def test_inner_function(x):
+ def fn(x):
+ return torch.add(x, x)
+
+ return fn(x)
+
+ @make_test
+ def test_transpose_for_scores(x):
+ new_x_shape = x.size()[:-1] + (2, 5)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1)
+
+ @make_test
+ def test_return_tuple2(x):
+ return (torch.add(x, x), x)
+
+ @make_test
+ def test_load_global_bool(x):
+ if flag:
+ return torch.add(x, x)
+ else:
+ return x
+
+ @make_test
+ def test_len_tensor(x):
+ z = len(x)
+ return torch.add(x, z)
+
+ @make_test
+ def test_len_constant_list(x):
+ z = len([1, 2, 3])
+ return torch.add(x, z)
+
+ @make_test
+ def test_len_constant_dict(x):
+ z = len({"foo": "bar"})
+ return torch.add(x, z)
+
+ @make_test
+ def test_dict_copy(x):
+ z = dict({"foo": x + 1})
+ return z
+
+ @make_test
+ def test_len_constant_misc_iterables(x):
+ a = len((1, 2, 3))
+ b = len("test str")
+ c = a + b
+ return torch.add(x, c)
+
+ @make_test
+ def test_float(x):
+ y = float(1.2)
+ y += float("1.2")
+ return torch.add(x, y)
+
+ @make_test
+ def test_dtype(x):
+ if x.dtype == torch.float32:
+ return x + 1
+
+ @make_test
+ def test_device(x):
+ if not x.is_cuda:
+ return x + 1
+
+ @make_test
+ def test_ndim(x):
+ if x.ndim == 2 and x.ndimension() == 2 and x.dim() == 2:
+ return x + 1
+
+ @make_test
+ def test_is_sparse(x):
+ if not x.is_sparse:
+ return x + 1
+
+ @requires_static_shapes
+ @make_test
+ def test_shape1(x):
+ if x.shape[0] == 10:
+ return x + 1
+
+ @requires_static_shapes
+ @make_test
+ def test_shape2(x):
+ if x.size(1) == 10:
+ return x + 1
+
+ @make_test
+ def test_del(a, b):
+ c = a + 1
+ d = c + 2
+ del c, a
+ return b + d
+
+ @requires_static_shapes
+ @make_test
+ def test_chunks1(x):
+ chunk_size = 5
+ assert x.shape[0] % chunk_size == 0
+ assert x.shape[0] // chunk_size == 2
+ return x[:chunk_size] - x[chunk_size:]
+
+ @make_test
+ def test_import1(x, y):
+ import torch
+ from torch import sub
+
+ return sub(torch.add(x, y), y)
+
+ @make_test
+ def test_return_dict(x, y):
+ z = [x + y, y, False]
+ return {"x": x, "z": z, "a": x, "b": z, "c": x}
+
+ @make_test
+ def test_return_dict2(x, y):
+ tmp = {"x": x}
+ tmp["z"] = [x + y, y]
+ tmp["y"] = y
+ tmp["z"].append(False)
+ return tmp
+
+ @make_test
+ def test_funcdef_closure(x, y):
+ x = x + y + 1.0
+
+ def inner(z):
+ nonlocal x, y
+ y = x + z + 20.0
+ x = y + z + 10.0
+
+ inner(2.0)
+ inner(3.0)
+
+ return x, y
+
+ @make_test
+ def test_module_constant(x, y):
+ r = x + y
+ for i in range(torch._dynamo.testing.three):
+ r = r / y
+ return r
+
+ @make_test
+ def test_inline_softmax(x, y):
+ # This is common in sme huggingface models
+ return torch.nn.Softmax(dim=-1)(x + y * 2)
+
+ @make_test
+ def test_dtype_compare(a, b):
+ if a.dtype == torch.float16:
+ return a + 10
+ if a.dtype == torch.float32:
+ return a - b * 32
+
+ @make_test
+ def test_build_list_unpack(a, b):
+ it1 = (x + 1 for x in (a, b))
+ it2 = (x - 1 for x in (a, b))
+ return torch.cat([*it1, *it2], dim=-1)
+
+ @make_test
+ def test_tensor_len(a, b):
+ return a + b + len(a) + b.__len__()
+
+ @make_test
+ def test_pop(a, b):
+ ll = [a, b]
+ ll.append(a + 1)
+ ll.extend(
+ [
+ b + 2,
+ a + b,
+ ]
+ )
+ ll.pop(-1)
+ ll.pop(0)
+ ll.pop()
+ v1, v2 = ll
+ return v1 - v2
+
+ @make_test
+ def test_list_convert(a, b):
+ ll = [a + 2, b]
+ ll = tuple(ll)
+ tmp = b + 3
+ ll = list(ll)
+ v1, v2 = ll
+ return v1 - v2 + tmp
+
+ @make_test
+ def test_list_add(a, b):
+ l1 = (a, b)
+ l2 = () # being a LOAD_CONST in the bytecode
+ l3 = l1 + l2
+ return l3[0] + l3[1]
+
+ @make_test
+ def test_startswith(a, b):
+ x = a + b
+ if "foobar".startswith("foo") and "test" in constant3.__module__:
+ x = x + 1
+ return x
+
+ @make_test
+ def test_dict_ops(a, b):
+ tmp = {"a": a + 1, "b": b + 2}
+ v = tmp.pop("b") + tmp.get("a") + tmp.get("missing", 3) + tmp.pop("missing", 4)
+ tmp.update({"d": 3})
+ tmp["c"] = v + tmp["d"]
+ if "c" in tmp and "missing" not in tmp:
+ return tmp["c"] - tmp["a"] + len(tmp)
+
+ def test_dict_param_keys(self):
+ a_param = torch.nn.Parameter(torch.ones([4, 4]))
+
+ def fn(a):
+ tmp = {"a": a, a_param: 3}
+ return tmp["a"] + tmp[a_param]
+
+ test = make_test(fn)
+ test(self)
+
+ def test_default_dict(self):
+ dd = collections.defaultdict(dict)
+ param = torch.nn.Parameter(torch.ones([2, 2]))
+
+ def fn(x):
+ dd["a"] = x + 1
+ dd[param] = 123
+ dd["c"] = x * 2
+ return dd["b"], dd
+
+ test = make_test(fn)
+ test(self)
+
+ @make_test
+ def test_min_max(a, b):
+ c = a + b
+ a = a.sum()
+ b = b.sum()
+ a = min(max(a, 0), 1)
+ b = max(0, min(1, b))
+ return max(a, b) - min(a, b) + c
+
+ @make_test
+ def test_map_sum(a, b, c, d):
+ return sum(map(lambda x: x + 1, [a, b, c, d]))
+
+ @make_test
+ def test_reduce(a, b, c, d):
+ return functools.reduce(operator.add, [a, b, c, d])
+
+ @make_test
+ def test_tuple_contains(a, b):
+ v1 = "a"
+ v2 = "b"
+ v3 = "c"
+ vals1 = (v1, v2, v3)
+ vals2 = ("d", "e", "f")
+ if "a" in vals1 and "b" not in vals2:
+ return a + b
+ return a - b
+
+ @make_test
+ def test_tuple_iadd(a, b):
+ output = (a, b)
+ output += (a + b, a - b)
+ return output
+
+ @make_test
+ def test_unpack_ex1(x):
+ output = (x, x + 1, x + 2, x + 3)
+ a, b, *cd = output
+ return a - b / cd[0]
+
+ @make_test
+ def test_unpack_ex2(x):
+ output = (x, x + 1, x + 2, x + 3)
+ *ab, c, d = output
+ return c - d / ab[0]
+
+ @make_test
+ def test_unpack_ex3(x):
+ output = (x, x + 1, x + 2, x + 3)
+ a, *bc, d = output
+ return a - d / bc[0]
+
+ @make_test
+ def test_const_tuple_add1(x):
+ output = (x, x + 1, x + 2, x + 3)
+ output = () + output + ()
+ return output[2] + output[3]
+
+ @make_test
+ def test_const_tuple_add2(x):
+ output = (x, x + 1, x + 2, x + 3)
+ output = (None,) + output + (None,)
+ return output[2] + output[3]
+
+ @make_test
+ def test_list_truth(a, b):
+ tmp = [1, 2, 3]
+ if tmp:
+ return a + b
+ else:
+ return a - b
+
+ @make_test
+ def test_list_reversed(a, b):
+ tmp = [a + 1, a + 2, a + 3]
+ return a + b + next(iter(reversed(tmp)))
+
+ @make_test
+ def test_list_clear(a, b):
+ tmp = [a + 1, a + 2]
+ tmp.clear()
+ tmp.append(a + b)
+ return tmp
+
+ @make_test
+ def test_islice_chain(a, b):
+ tmp1 = [a + 1, a + 2]
+ tmp2 = [a + 3, a + 4]
+ a, b = list(itertools.islice(itertools.chain(tmp1, tmp2), 1, 3))
+ c = next(itertools.islice(tmp1, 1, None))
+ return a - b / c
+
+ @make_test
+ def test_is_quantized(a, b):
+ if not a.is_quantized:
+ return a + b
+
+ @make_test
+ def test_fstrings1(a, b):
+ x = 1.229
+ tmp = f"{x:.2f} bar"
+ if tmp.startswith("1.23"):
+ return a + b
+
+ @requires_static_shapes
+ @make_test
+ def test_fstrings2(x):
+ tmp = f"{x.shape[0]} bar"
+ if tmp.startswith("10"):
+ return x + 1
+
+ @make_test
+ def test_fstrings3(x):
+ tmp = f"{x.__class__.__name__} foo"
+ if tmp.startswith("Tensor"):
+ return x + 1
+
+ @requires_static_shapes
+ @make_test
+ def test_tensor_new_with_size(x):
+ y = torch.rand(5, 8)
+ z = x.new(y.size())
+ assert z.size() == y.size()
+
+ @requires_static_shapes
+ @make_test
+ def test_tensor_new_with_shape(x):
+ y = torch.rand(5, 8)
+ z = x.new(y.shape)
+ assert z.size() == y.size()
+
+ @make_test
+ def test_jit_annotate(x):
+ y = torch.jit.annotate(Any, x + 1)
+ return y + 2
+
+ @requires_static_shapes
+ @make_test
+ def test_is_contiguous_memory_format(tensor):
+ if torch.jit.is_scripting():
+ return None
+ elif tensor.is_contiguous(memory_format=torch.contiguous_format):
+ return tensor + 1
+
+ @make_test
+ def test_list_slice_assignment(x):
+ m = [1, 2, 3, 4]
+ m[1:] = [6] * (len(m) - 1)
+ return x + 1
+
+ # # This is to test the new syntax for pattern matching
+ # # ("match ... case ...") added on python 3.10.
+ # # Uncomment these test cases if you run on 3.10+
+ # @make_test
+ # def test_match_sequence(a):
+ # point = (5, 8)
+ # match point:
+ # case (0, 0):
+ # return a
+ # case (0, y):
+ # return a - y
+ # case (x, 0):
+ # return a + x
+ # case (x, y):
+ # return a + x - y
+
+ # @make_test
+ # def test_match_mapping_and_match_keys(x):
+ # param = {"a": 0.5}
+ # match param:
+ # case {"a": param}:
+ # return x * param
+ # case {"b": param}:
+ # return x / param
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_global.py b/test/dynamo/test_global.py
new file mode 100644
index 0000000000000..5e3d975d7bc87
--- /dev/null
+++ b/test/dynamo/test_global.py
@@ -0,0 +1,232 @@
+# Owner(s): ["module: dynamo"]
+import torch
+
+import torch._dynamo.testing
+from torch._dynamo.testing import same
+
+try:
+ from . import test_global_declaration
+except ImportError:
+ import test_global_declaration
+
+
+class Pair(object): # noqa: B903
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+
+def Foo():
+ return Pair(1, 1)
+
+
+g_counter = 1
+g_list = [0, 1, 2]
+g_dict = {"a": 0, "b": 1}
+g_object = Foo()
+g_tensor = torch.zeros(10)
+
+
+_name: int = 0
+
+
+def fresh_name() -> str:
+ """create a new unique name for a variable: v0, v1, v2"""
+ global _name
+ r = f"v{_name}"
+ _name += 1
+ return r
+
+
+def reset_name():
+ global _name
+ _name = 0
+
+
+class TestGlobals(torch._dynamo.testing.TestCase):
+ def test_store_global_1(self):
+ def fn(x):
+ global g_counter
+ val = x + g_counter
+ g_counter += 1
+ return val
+
+ x = torch.randn(10)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res1 = opt_fn(x)
+ res2 = fn(x)
+ self.assertTrue(same(res2 - res1, torch.ones(10)))
+
+ def test_store_global_2(self):
+ def fn(x):
+ global g_counter
+ val = x + g_counter
+ g_counter += 1
+ g_counter += 1
+ return val
+
+ x = torch.randn(10)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res1 = opt_fn(x)
+ """Wrap the second call with torch._dynamo as well"""
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res2 = opt_fn(x)
+ self.assertTrue(same(res2 - res1, 2 * torch.ones(10)))
+
+ def test_store_global_new(self):
+ def fn(x):
+ # Test create a new global
+ global g_counter_new
+ g_counter_new = x + 1
+ return x + g_counter_new
+
+ x = torch.randn(10)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res1 = opt_fn(x)
+ self.assertTrue(same(res1, x + x + 1))
+
+ def test_store_global_list(self):
+ def fn(x):
+ global g_list
+ val = x + g_list[1]
+ """
+ Strictly speaking, we are not testing STORE_GLOBAL
+ here, since STORE_SUBSCR is actually used to store.
+ """
+ g_list[1] += 1
+ return val
+
+ x = torch.randn(10)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res1 = opt_fn(x)
+ res2 = fn(x)
+ self.assertTrue(same(res2 - res1, torch.ones(10)))
+
+ def test_store_global_list_2(self):
+ def fn(x):
+ global g_list
+ val = x + g_list[1]
+ g_list = [x + 1 for x in g_list]
+ return val
+
+ x = torch.randn(10)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res1 = opt_fn(x)
+ res2 = fn(x)
+ self.assertTrue(same(res2 - res1, torch.ones(10)))
+
+ def test_store_global_dict(self):
+ def fn(x):
+ global g_dict
+ val = x + g_dict["b"]
+ """
+ Strictly speaking, we are not testing STORE_GLOBAL
+ here, since STORE_SUBSCR is actually used to store.
+ """
+ g_dict["b"] += 1
+ return val
+
+ x = torch.randn(10)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res1 = opt_fn(x)
+ res2 = fn(x)
+ self.assertTrue(same(res2 - res1, torch.ones(10)))
+
+ def test_store_global_dict_2(self):
+ def fn(x):
+ global g_dict
+ g_dict = {key: value + 1 for key, value in g_dict.items()}
+ val = x + g_dict["b"]
+ return val
+
+ x = torch.randn(10)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res1 = opt_fn(x)
+ res2 = fn(x)
+ self.assertTrue(same(res2 - res1, torch.ones(10)))
+
+ def test_store_global_object(self):
+ def fn(x):
+ global g_object
+ val = x + g_object.y
+ g_object.y += 1
+ return val
+
+ x = torch.randn(10)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res1 = opt_fn(x)
+ res2 = fn(x)
+ self.assertTrue(same(res2 - res1, torch.ones(10)))
+
+ def test_store_global_cross_file(self):
+ def fn(x):
+ val = x + test_global_declaration.g_tensor_export
+ test_global_declaration.g_tensor_export = (
+ test_global_declaration.g_tensor_export + 1
+ )
+ return val
+
+ x = torch.randn(10)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res1 = opt_fn(x)
+ res2 = fn(x)
+ self.assertTrue(same(res2 - res1, torch.ones(10)))
+
+ def test_store_global_inline_1(self):
+ # Borrowed from test_python_autograd.py
+ class Variable:
+ def __init__(self, value: torch.Tensor, name: str = None):
+ self.value = value
+ self.name = name or fresh_name()
+
+ def fn(a, b):
+ a = Variable(a)
+ b = Variable(b)
+ return a.value + b.value, a.name + b.name
+
+ a = torch.randn(10)
+ b = torch.randn(10)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ v0, s0 = opt_fn(a, b)
+ self.assertEqual(s0, "v0v1")
+ reset_name()
+
+ def test_store_global_inline_2(self):
+ # Borrowed from test_python_autograd.py
+ class Variable:
+ def __init__(self, value: torch.Tensor, name: str = None):
+ self.value = value
+ self.name = name or fresh_name()
+
+ @staticmethod
+ def constant(value: torch.Tensor, name: str = None):
+ return Variable(value, name)
+
+ def fn(a, b):
+ a = Variable.constant(a)
+ b = Variable.constant(b)
+ return a.value + b.value, a.name + b.name
+
+ a = torch.randn(10)
+ b = torch.randn(10)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ v0, s0 = opt_fn(a, b)
+ self.assertEqual(s0, "v0v1")
+ reset_name()
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_global_declaration.py b/test/dynamo/test_global_declaration.py
new file mode 100644
index 0000000000000..95995ca80a22f
--- /dev/null
+++ b/test/dynamo/test_global_declaration.py
@@ -0,0 +1,4 @@
+# Owner(s): ["module: dynamo"]
+import torch
+
+g_tensor_export = torch.ones(10)
diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py
new file mode 100644
index 0000000000000..030b9f73ecf30
--- /dev/null
+++ b/test/dynamo/test_minifier.py
@@ -0,0 +1,97 @@
+# Owner(s): ["module: dynamo"]
+import os
+import shutil
+from unittest.mock import patch
+
+import torch
+
+import torch._dynamo
+import torch._dynamo.testing
+from torch._dynamo.optimizations.backends import create_backend
+
+
+class MockModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ for _ in range(10):
+ x = torch.sin(x)
+ x = torch._foobar(x)
+ for _ in range(10):
+ x = torch.cos(x)
+ return x
+
+
+class MinfierTests(torch._dynamo.testing.TestCase):
+ def test_after_dynamo(self):
+ @create_backend
+ def bad_dynamo_backend(subgraph):
+ import sys
+
+ def f(*args):
+ # Shifted the forced exception to runtime as this is more common
+ # in JIT compilers.
+ for node in subgraph.model.graph.nodes:
+ if node.op == "call_function" and node.target is torch._foobar:
+ sys.stdout.write("Dynamo compiled failed\n")
+ raise NotImplementedError("foobar is not implemented")
+ return subgraph.model(*args)
+
+ return f
+
+ mod = MockModule()
+ opt_mod = torch._dynamo.optimize("bad_dynamo_backend")(mod)
+ repro_dir = "/tmp/test_minifier"
+ repro_file = os.path.join(repro_dir, "minifier_launcher.py")
+ shutil.rmtree(repro_dir, ignore_errors=True)
+
+ @patch.object(torch._dynamo.config, "repro_after", "dynamo")
+ @patch.object(torch._dynamo.config, "repro_dir", repro_dir)
+ def inner():
+ x = torch.randn(4)
+ try:
+ opt_mod(x)
+ except Exception:
+ pass
+
+ inner()
+ self.assertTrue(os.path.exists(repro_file))
+
+ # If error_at_aot is True, an error will be produced when AOTAutograd
+ # attempts to generate the backward graph.
+ # If error_after_aot is False, an error will be produced in inductor.
+ def _test_around_aot(self, error_at_aot):
+ mod = MockModule()
+ opt_mod = torch._dynamo.optimize("inductor")(mod)
+ repro_dir = "/tmp/test_minifier"
+ repro_file = os.path.join(repro_dir, "minifier_launcher.py")
+ shutil.rmtree(repro_dir, ignore_errors=True)
+
+ repro_after = "dynamo" if error_at_aot else "aot"
+
+ @patch.object(torch._dynamo.config, "repro_after", repro_after)
+ @patch.object(torch._dynamo.config, "repro_dir", repro_dir)
+ def inner():
+ x = torch.randn(4)
+ x.requires_grad = error_at_aot
+ try:
+ opt_mod(x)
+ except Exception:
+ pass
+
+ inner()
+
+ self.assertTrue(os.path.exists(repro_file))
+
+ def test_at_aot(self):
+ self._test_around_aot(True)
+
+ def test_after_aot(self):
+ self._test_around_aot(False)
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
new file mode 100644
index 0000000000000..e3e05059230fd
--- /dev/null
+++ b/test/dynamo/test_misc.py
@@ -0,0 +1,2724 @@
+# Owner(s): ["module: dynamo"]
+import collections
+import copy
+import dataclasses
+import dis
+import enum
+import logging
+import math
+import os
+import sys
+import typing
+import unittest
+import weakref
+from unittest.mock import patch
+
+import numpy as np
+import torch
+
+import torch._dynamo.testing
+import torch.onnx.operators
+from torch._dynamo import bytecode_transformation
+from torch._dynamo.testing import (
+ CompileCounter,
+ requires_static_shapes,
+ same,
+ unsupported,
+)
+from torch.testing._internal.jit_utils import JitTestCase
+
+mytuple = collections.namedtuple("mytuple", ["a", "b", "ab"])
+
+
+def my_custom_function(x):
+ return x + 1
+
+
+class MiscTests(torch._dynamo.testing.TestCase):
+ def test_boolarg(self):
+ def boolarg(aa, bb, flag):
+ if flag:
+ return aa - bb
+ else:
+ return bb - aa
+
+ a = torch.randn(10, 10)
+ b = torch.randn(10, 10)
+ correct1 = boolarg(a, b, True)
+ correct2 = boolarg(a, b, False)
+ correct3 = boolarg(a, b, None)
+ counter = CompileCounter()
+ opt_boolarg = torch._dynamo.optimize_assert(counter)(boolarg)
+ val1 = opt_boolarg(a, b, True)
+ val2 = opt_boolarg(a, b, False)
+ val3 = opt_boolarg(a, b, None)
+ val4 = opt_boolarg(a, b, True)
+ self.assertTrue(same(val1, correct1))
+ self.assertTrue(same(val2, correct2))
+ self.assertTrue(same(val3, correct3))
+ self.assertTrue(same(val4, correct1))
+ self.assertEqual(counter.frame_count, 3)
+
+ def test_callpacked(self):
+ def call_packed(args):
+ a, b, c = args
+ return a - b * c
+
+ counter = CompileCounter()
+ a = torch.randn(10, 10)
+ b = torch.randn(10, 10)
+ c = torch.randn(10, 10)
+ correct = call_packed([a, b, c])
+ opt_call_packed = torch._dynamo.optimize_assert(counter)(call_packed)
+ val1 = opt_call_packed([a, b, c])
+ val2 = opt_call_packed((a, b, c))
+ val3 = opt_call_packed([a, b, c])
+ val4 = opt_call_packed((a, b, c))
+ self.assertTrue(same(val1, correct))
+ self.assertTrue(same(val2, correct))
+ self.assertTrue(same(val3, correct))
+ self.assertTrue(same(val4, correct))
+ self.assertEqual(counter.frame_count, 2)
+
+ def test_raises(self):
+ def fn(a, b, c, cls):
+ x = a + b - c * 10
+ raise cls(str(x))
+
+ counter = CompileCounter()
+ a = torch.randn(10, 10)
+ b = torch.randn(10, 10)
+ c = torch.randn(10, 10)
+ opt_fn = torch._dynamo.optimize(counter)(fn)
+ self.assertRaises(AssertionError, lambda: opt_fn(a, b, c, AssertionError))
+ self.assertEqual(counter.frame_count, 1)
+ self.assertEqual(counter.op_count, 3)
+
+ def test_inplace(self):
+ def inplace1(a, b):
+ o = torch.empty((10, 10))
+ o.copy_(a)
+ o -= b
+ return o
+
+ torch._dynamo.testing.standard_test(self, inplace1, 2, expected_ops=3)
+
+ def test_unpack4(self):
+ def unpack4(a, b):
+ a = a[:5, :]
+ b = b[:5, :]
+ x, y = a.size()
+ o = torch.empty((x, y))
+ o.copy_(a / b)
+ return o
+
+ torch._dynamo.testing.standard_test(
+ self, unpack4, 2, expected_ops=5, expected_ops_dynamic=8
+ )
+
+ def test_unpack5(self):
+ def unpack5(a, b):
+ a = a[:5, :]
+ b = b[:5, :]
+ x, y = a.shape
+ o = torch.empty((x, y))
+ o.copy_(a / b)
+ return o
+
+ torch._dynamo.testing.standard_test(
+ self, unpack5, 2, expected_ops=5, expected_ops_dynamic=8
+ )
+
+ def test_matmul1(self):
+ def matmul_op1(a, b):
+ return a @ b
+
+ # TODO(jansel): FX doesn't support this, should add upstream support
+ torch._dynamo.testing.standard_test(self, matmul_op1, 2, expected_ops=1)
+
+ def test_builtin_isinstance(self):
+ def fn(x):
+ t = torch.arange(1, 3)
+ a = isinstance(x, torch.Tensor)
+ b = isinstance(t, torch.Tensor)
+ c = isinstance(x, int)
+ d = isinstance(3, int)
+ e = isinstance([1, 2, 3], list)
+ f = isinstance({"foo": 1, "bar": 2}, dict)
+ res = [a, b, c, d, e, f]
+ # Can't run yet due to other unimplemented instructions
+ # res += [isinstance(torch.nn.LazyLinear(2, 3), torch.nn.Linear)]
+ return res
+
+ torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)
+
+ def test_fold(self):
+ def fn(a):
+ return a + math.sqrt(63)
+
+ torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)
+
+ def test_shape_unpack(self):
+ def fn(x):
+ a, b = x.size()
+ return x * b
+
+ i = torch.randn(5, 10)
+ r1 = fn(i)
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ r2 = opt_fn(i)
+ self.assertTrue(same(r1, r2))
+
+ def test_empty_list(self):
+ def fn(x, ll):
+ if len(ll) == 0 and not ll and ll is not None:
+ return x + 1
+
+ i = torch.randn(5, 10)
+ r1 = fn(i, [])
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ r2 = opt_fn(i, [])
+ r3 = opt_fn(i, tuple())
+ self.assertTrue(same(r1, r2))
+ self.assertTrue(same(r1, r3))
+
+ def test_config_obj(self):
+ class Cfg:
+ def __init__(self):
+ self.val = 0.5
+ self.count = 3
+
+ def fn(x, cfg):
+ for i in range(cfg.count):
+ x = x + cfg.val
+ return x
+
+ cfg1 = Cfg()
+ cfg1.val = 1.0
+ cfg2 = Cfg()
+ v = torch.zeros(1)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ v = opt_fn(v, cfg1) # 3
+ v = opt_fn(v, cfg2) # 4.5
+ cfg2.count = 1
+ v = opt_fn(v, cfg2) # 5
+ cfg2.val = 2.0
+ v = opt_fn(v, cfg2) # 7
+ self.assertEqual(v[0], 7)
+ self.assertEqual(cnts.op_count, 8)
+
+ def test_config_getattr_default(self):
+ class Cfg:
+ def __init__(self):
+ self.val = 0.5
+ self.count = 10
+
+ def fn(x, cfg):
+ if getattr(cfg, "just_add_7", False):
+ return x + 7
+ for i in range(cfg.count):
+ x = x + cfg.val
+ return x
+
+ cfg1 = Cfg()
+ v = torch.zeros(1)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertEqual(opt_fn(v, cfg1)[0], 5)
+ self.assertEqual(opt_fn(v, cfg1)[0], 5)
+ cfg1.just_add_7 = True
+ self.assertEqual(opt_fn(v, cfg1)[0], 7)
+ self.assertEqual(opt_fn(v, cfg1)[0], 7)
+ cfg1.just_add_7 = False
+ self.assertEqual(opt_fn(v, cfg1)[0], 5)
+ self.assertEqual(opt_fn(v, cfg1)[0], 5)
+ self.assertEqual(cnts.frame_count, 3)
+
+ def test_size_input(self):
+ def fn(x, s):
+ a, b = s
+ return x + (a - b)
+
+ v = torch.zeros(10, 20)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertEqual(opt_fn(v, v.size())[0, 0], -10)
+ self.assertEqual(opt_fn(v, (10, 20))[0, 0], -10)
+ self.assertEqual(opt_fn(v, [10, 20])[0, 0], -10)
+ self.assertEqual(cnts.op_count, 2)
+
+ def test_cell_output1(self):
+ out = None
+
+ def fn(a, b):
+ nonlocal out
+ out = a + b * 10
+
+ v = torch.Tensor([100])
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertIsNone(opt_fn(v, v))
+ self.assertEqual(out[0], 1100)
+ self.assertEqual(cnts.op_count, 2)
+
+ def test_cell_output2(self):
+ out = None
+
+ def fn(a, b):
+ nonlocal out
+ c = unsupported(a, b)
+ out = a + b * 10 + c
+
+ v = torch.Tensor([100])
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertIsNone(opt_fn(v, v))
+ self.assertEqual(out[0], 1200)
+ self.assertEqual(cnts.op_count, 3)
+
+ def test_return_nested_function(self):
+ out = None
+
+ def fn(a, b):
+ nonlocal out
+ c = a + b
+ d = a + 1.0
+
+ def fn2(f: int = 7, g: float = 9.0):
+ nonlocal out
+ out = a + b * 10
+ return c * f - d * g
+
+ return fn2
+
+ v1 = torch.Tensor([100])
+ v2 = torch.Tensor([200])
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ opt_fn_ret = torch._dynamo.optimize(cnts)(opt_fn(v1, v2))
+ self.assertEqual(opt_fn_ret(1.5)[0], -459)
+ self.assertEqual(out[0], 2100)
+ self.assertEqual(cnts.frame_count, 2)
+ self.assertEqual(cnts.op_count, 7)
+
+ def test_tensor_dict1(self):
+ def fn(inputs):
+ return inputs["a"] - inputs["b"] * 1.5
+
+ v1 = torch.Tensor([100])
+ v2 = torch.Tensor([200])
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertEqual(opt_fn({"a": v1, "b": v2})[0], -200)
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 2)
+
+ def test_tensor_dict2(self):
+ def fn1(inputs):
+ total = torch.zeros(1)
+ for k, v in inputs.items():
+ total += v
+ return total
+
+ def fn2(inputs):
+ total = torch.zeros(1)
+ for v in inputs.values():
+ total += v
+ return total
+
+ def fn3(inputs):
+ total = torch.zeros(1)
+ for k in inputs.keys():
+ total += inputs[k]
+ return total
+
+ v1 = torch.Tensor([100])
+ v2 = torch.Tensor([200])
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn1 = torch._dynamo.optimize(cnts)(fn1)
+ opt_fn2 = torch._dynamo.optimize(cnts)(fn2)
+ opt_fn3 = torch._dynamo.optimize(cnts)(fn3)
+ self.assertEqual(opt_fn1({"a": v1, "b": v2})[0], 300)
+ self.assertEqual(opt_fn2({"a": v1, "b": v2})[0], 300)
+ self.assertEqual(opt_fn3({"a": v1, "b": v2})[0], 300)
+ self.assertEqual(cnts.frame_count, 3)
+ self.assertEqual(cnts.op_count, 9)
+
+ def test_dictcomp(self):
+ def fn1(inputs):
+ return {k: v + 1 for k, v in inputs.items()}
+
+ v1 = torch.Tensor([100])
+ v2 = torch.Tensor([200])
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn1 = torch._dynamo.optimize(cnts)(fn1)
+ self.assertEqual(opt_fn1({"a": v1, "b": v2})["a"], 101)
+ self.assertEqual(opt_fn1({"a": v1, "b": v2})["b"], 201)
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 2)
+
+ def test_listcomp(self):
+ def fn2(inputs):
+ return torch.sum(torch.cat([v + 1 for k, v in inputs.items()], 0))
+
+ v1 = torch.Tensor([100])
+ v2 = torch.Tensor([200])
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn2 = torch._dynamo.optimize(cnts)(fn2)
+ self.assertEqual(opt_fn2({"a": v1, "b": v2}), 302)
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 4)
+
+ def test_is_floating_point(self):
+ def fn(a, b):
+ x = a + 1.0
+ if torch.is_floating_point(b):
+ x = x + b
+ return x + 2.0
+
+ return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)
+
+ def test_is_floating_point2(self):
+ def fn(a, b):
+ x = a + 1.0
+ if b.is_floating_point():
+ x = x + b
+ return x + 2.0
+
+ return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)
+
+ def test_is_tensor(self):
+ def fn(a, b):
+ x = a + 1.0
+ if torch.is_tensor(b):
+ x = x + b
+ return x + 2.0
+
+ return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)
+
+ def test_numel(self):
+ def fn(a):
+ return a + a.numel() + torch.numel(a)
+
+ return torch._dynamo.testing.standard_test(
+ self, fn=fn, nargs=1, expected_ops=2, expected_ops_dynamic=4
+ )
+
+ def test_pair(self):
+ def fn(a):
+ return (
+ torch.zeros(torch.nn.modules.utils._pair(a.size()))
+ + a
+ + torch.ones(torch.nn.modules.utils._ntuple(3)(3)).sum()
+ )
+
+ return torch._dynamo.testing.standard_test(
+ self, fn=fn, nargs=1, expected_ops=5, expected_ops_dynamic=8
+ )
+
+ @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
+ def test_tensor_item_capture(self):
+ def fn(a, b):
+ return (a + b).sum().item()
+
+ v1 = torch.randn((10, 10))
+ v2 = torch.randn((10, 10))
+ correct = fn(v1, v2)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize((cnts))(fn)
+ self.assertEqual(opt_fn(v1, v2), correct)
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 3)
+
+ @patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
+ def test_tensor_item_no_capture(self):
+ def fn(a, b):
+ return (a + b).sum().item()
+
+ v1 = torch.randn((10, 10))
+ v2 = torch.randn((10, 10))
+ correct = fn(v1, v2)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize((cnts))(fn)
+ self.assertEqual(opt_fn(v1, v2), correct)
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 2)
+
+ def test_namedtuple1(self):
+ def fn(a, b):
+ tmp = mytuple(a, b, a + b)
+ return mytuple(tmp.a, tmp[1], tmp.ab + b)
+
+ v1 = torch.Tensor([10])
+ v2 = torch.Tensor([20])
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertEqual(opt_fn(v1, v2).ab, 50)
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 2)
+
+ def test_namedtuple2(self):
+ def fn(packed):
+ a, b, c = packed
+ if hasattr(packed, "b"):
+ b = packed.b + 1
+ c = packed[2]
+ return a + b + c
+
+ v1 = torch.Tensor([1])
+ v2 = torch.Tensor([2])
+ v3 = torch.Tensor([3])
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertEqual(opt_fn(mytuple(v1, v2, v3))[0], 7)
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 3)
+
+ def test_range_input(self):
+ def fn(a, rng):
+ x = a
+ for i in rng:
+ x = x + i
+ return x
+
+ def fn1(a):
+ return fn(a, rng=range(3))
+
+ return torch._dynamo.testing.standard_test(
+ self, fn=fn1, nargs=1, expected_ops=3
+ )
+
+ def test_no_grad(self):
+ def fn1(a, b):
+ x = a + 1
+ # redundant no_grad should get ignored
+ with torch.no_grad():
+ x = x + b
+ x = x + 2
+ return x
+
+ def fn2(a, b):
+ x = a + 1
+ with torch.set_grad_enabled(False):
+ x = x + b
+ x = x + 2
+ return x
+
+ def fn3(a, b):
+ x = a + 1
+ with torch.enable_grad():
+ x = x + b
+ x = x + 2
+ return x
+
+ def fn4(a, b):
+ x = a + 1
+ with torch.set_grad_enabled(True):
+ if torch.is_grad_enabled():
+ x = x + b
+ x = x + 2
+ return x
+
+ with torch.no_grad():
+ torch._dynamo.testing.standard_test(self, fn=fn1, nargs=2, expected_ops=5)
+ torch._dynamo.testing.standard_test(self, fn=fn2, nargs=2, expected_ops=5)
+ torch._dynamo.testing.standard_test(self, fn=fn3, nargs=2, expected_ops=5)
+ torch._dynamo.testing.standard_test(self, fn=fn4, nargs=2, expected_ops=5)
+ with torch.enable_grad():
+ torch._dynamo.testing.standard_test(self, fn=fn1, nargs=2, expected_ops=5)
+ torch._dynamo.testing.standard_test(self, fn=fn2, nargs=2, expected_ops=5)
+ torch._dynamo.testing.standard_test(self, fn=fn3, nargs=2, expected_ops=5)
+ torch._dynamo.testing.standard_test(self, fn=fn4, nargs=2, expected_ops=5)
+
+ def test_grad_mode_guard(self):
+ def fn(a, b):
+ prev_grad = torch.is_grad_enabled()
+ torch.set_grad_enabled(False)
+ a = a + 1
+ a.tolist() # graph break
+ ret = a + b
+ torch.set_grad_enabled(prev_grad)
+ return ret
+
+ a = torch.randn([3, 4])
+ b = torch.randn([3, 4])
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ for _ in range(10):
+ opt_fn(a, b)
+ self.assertEqual(cnts.frame_count, 2)
+
+ def test_build_tuple_unpack(self):
+ def fn1(a, b, c):
+ return a - b / c
+
+ def fn2(a, b, c):
+ tmp1 = (a,)
+ tmp2 = (b, c)
+ args = (*tmp1, *tmp2)
+ return fn1(*args)
+
+ def fn3(a, *args):
+ return fn1(a, *args)
+
+ torch._dynamo.testing.standard_test(self, fn=fn2, nargs=3, expected_ops=2)
+ torch._dynamo.testing.standard_test(self, fn=fn3, nargs=3, expected_ops=2)
+
+ def test_list_mul(self):
+ def fn(count):
+ head_mask = count * [None] * count
+ return head_mask
+
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertEqual(opt_fn(2), [None] * 4)
+ self.assertEqual(cnts.frame_count, 0)
+ self.assertEqual(cnts.op_count, 0)
+
+ def test_user_getattr1(self):
+ class MyConfig(dict):
+ def __getattr__(self, name):
+ return self[name]
+
+ def fn(cfg, x, y):
+ return x + y + cfg.offset
+
+ x = torch.randn(10)
+ cfg = MyConfig(offset=5)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 2)
+
+ def test_user_getattr2(self):
+ class MyConfig:
+ defined_on_class = 1
+
+ def __init__(self):
+ self.defined_on_object = 2
+
+ def __getattr__(self, name):
+ return 3
+
+ def fn(cfg, x):
+ return x + cfg.defined_on_class - cfg.defined_on_object + cfg.not_defined
+
+ x = torch.randn(10)
+ cfg = MyConfig()
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertTrue(same(opt_fn(cfg, x), x + 1 - 2 + 3))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 3)
+
+ def test_user_property(self):
+ class MyConfig:
+ @property
+ def prop5(self):
+ return 5
+
+ def fn(cfg, x, y):
+ return x + y + cfg.prop5
+
+ x = torch.randn(10)
+ cfg = MyConfig()
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 2)
+
+ def test_dataclass_fields(self):
+ @dataclasses.dataclass
+ class MyDataClass:
+ a: torch.Tensor
+ b: torch.Tensor = None
+ c: torch.Tensor = None
+ d: torch.Tensor = None
+ e: torch.Tensor = None
+
+ def fn(obj):
+ class_fields = dataclasses.fields(obj)
+ assert len(class_fields)
+ assert all(field.default is None for field in class_fields[1:])
+ other_fields_are_none = all(
+ getattr(obj, field.name) is None for field in class_fields[1:]
+ )
+ assert not other_fields_are_none
+
+ total = getattr(obj, class_fields[0].name)
+ for field in class_fields[1:]:
+ v = getattr(obj, field.name)
+ if v is not None:
+ total += v
+
+ return total
+
+ obj1 = MyDataClass(torch.randn(10), torch.randn(10), torch.randn(10))
+ obj2 = MyDataClass(torch.randn(10), e=torch.randn(10))
+ correct1 = fn(obj1)
+ correct2 = fn(obj2)
+
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertTrue(same(opt_fn(obj1), correct1))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 2)
+
+ torch._dynamo.reset()
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertTrue(same(opt_fn(obj2), correct2))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 1)
+
+ @requires_static_shapes
+ def test_tensor_build_list_unpack(self):
+ def fn(x):
+ # seen in fastNLP_Bert
+ return torch.cat([*x], dim=-1)
+
+ val = torch.randn([1, 1, 473, 768])
+ correct = fn(val)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertTrue(same(opt_fn(val), correct))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 2)
+
+ def test_numpy_int_constant(self):
+ def fn(x, a, b):
+ return x + (a % b)
+
+ args = [torch.randn(10), 4096, np.int64(8)]
+ correct = fn(*args)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertTrue(same(opt_fn(*args), correct))
+ self.assertTrue(same(opt_fn(*args), correct))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 2)
+
+ def test_dict_mutation_side_effect(self):
+ def fn(d):
+ d["c"] = d["a"] + d.pop("b")
+ return d
+
+ args1 = {"a": torch.randn(10), "b": torch.randn(10)}
+ args2 = dict(args1)
+ assert fn(args1) is args1
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertIs(opt_fn(args2), args2)
+ self.assertTrue(same(args1, args2))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 1)
+
+ def test_module_deepcopy(self):
+ m1 = torch.nn.Sequential(
+ torch.nn.Linear(10, 10),
+ torch.nn.ReLU(),
+ torch.nn.Linear(10, 10),
+ torch.nn.ReLU(),
+ )
+ m2 = torch.nn.Sequential(
+ torch.nn.Linear(10, 10),
+ torch.nn.ReLU(),
+ torch.nn.Linear(10, 10),
+ torch.nn.ReLU(),
+ )
+
+ def fn(m, x):
+ m_copy = copy.deepcopy(m)
+ return m_copy(x)
+
+ v = torch.randn(10)
+ correct1 = fn(m1, v)
+ correct2 = fn(m2, v)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ for _ in range(10):
+ self.assertTrue(same(opt_fn(m1, v), correct1))
+ for _ in range(10):
+ self.assertTrue(same(opt_fn(m2, v), correct2))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 4)
+
+ def test_type_copy(self):
+ def fn(seq):
+ a, b = seq
+ return type(seq)([a + 1, b + 2, a + b])
+
+ args1 = [torch.randn(10), torch.randn(10)]
+ args2 = (torch.randn(10), torch.randn(10))
+ correct1 = fn(args1)
+ correct2 = fn(args2)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertTrue(same(opt_fn(args1), correct1))
+ self.assertTrue(same(opt_fn(args2), correct2))
+ self.assertIsInstance(opt_fn(args1), list)
+ self.assertIsInstance(opt_fn(args2), tuple)
+ self.assertEqual(cnts.frame_count, 2)
+ self.assertEqual(cnts.op_count, 6)
+
+ def test_setattr_mutation1(self):
+ class MyObj: # noqa: B903
+ def __init__(self, a, b):
+ self.a = a
+ self.b = b
+
+ def fn(obj):
+ obj.c = obj.a * obj.b + 1
+ obj.b = obj.a * obj.c + 2
+ obj.a = obj.b * obj.c + 3
+ obj.c = obj.a * obj.b + 4
+ obj.b = obj.a * obj.c + 5
+ obj.a = obj.b * obj.c + 6
+ return obj
+
+ x1 = torch.randn(10)
+ x2 = torch.randn(10)
+ obj1 = MyObj(x1, x2)
+ obj2 = MyObj(x1, x2)
+ fn(obj2)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertIs(opt_fn(obj1), obj1)
+ self.assertTrue(same(obj1.a, obj2.a))
+ self.assertTrue(same(obj1.b, obj2.b))
+ self.assertTrue(same(obj1.c, obj2.c))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 12)
+
+ def test_setattr_mutation2(self):
+ class MyObj:
+ def __init__(self, x):
+ self.a = x + 1
+ self.b = x + 2
+
+ def fn(x):
+ x = x / 3.0
+ obj = MyObj(x)
+ obj.c = obj.a * obj.b + 1
+ obj.b = obj.a * obj.c + 2
+ obj.a = obj.b * obj.c + 3
+ return obj
+
+ x1 = torch.randn(10)
+ obj2 = fn(x1)
+
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ obj1 = opt_fn(x1)
+ self.assertTrue(same(obj1.a, obj2.a))
+ self.assertTrue(same(obj1.b, obj2.b))
+ self.assertTrue(same(obj1.c, obj2.c))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 9)
+
+ def test_setattr_mutation3(self):
+ # TODO(jansel): dead code eliminate the object creation
+ class MyObj:
+ def __init__(self, x):
+ super().__init__()
+ self.a = x + 1
+ self.b = x + 2
+
+ def fn(x):
+ x = x / 3.0
+ obj = MyObj(x)
+ obj.c = obj.a * obj.b + 1
+ obj.b = obj.a * obj.c + 2
+ obj.a = obj.b * obj.c + 3
+ return obj.a, obj.b, obj.c
+
+ x1 = torch.randn(10)
+ obj2 = fn(x1)
+
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ obj1 = opt_fn(x1)
+ self.assertTrue(same(obj1, obj2))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 9)
+
+ def test_user_defined_class_name(self):
+ class MyClassFoo:
+ pass
+
+ def fn1(a, b, c):
+ tmp = MyClassFoo()
+ if tmp.__class__.__name__ == "MyClassFoo":
+ return a - b / c
+
+ torch._dynamo.testing.standard_test(self, fn=fn1, nargs=3)
+
+ def test_manual_seed(self):
+ def fn(a, b):
+ x = a + b
+ torch.manual_seed(9000)
+ return x + 1
+
+ torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)
+
+ def test_usr_cls_staticmethod(self):
+ class Foo:
+ @staticmethod
+ def bar(a, b):
+ return a + b
+
+ def fn(a, b):
+ return Foo.bar(a, b) - 1
+
+ torch._dynamo.testing.standard_test(self, fn=fn, nargs=2)
+
+ def test_usr_cls_classmethod(self):
+ class Foo:
+ @classmethod
+ def bar(cls, a, b):
+ return a + b
+
+ def fn(a, b):
+ return Foo.bar(a, b) - 1
+
+ torch._dynamo.testing.standard_test(self, fn=fn, nargs=2)
+
+ def test_dunder_methods(self):
+ class Foo:
+ def __init__(self, val):
+ super().__init__()
+ self.val = val
+
+ def __add__(self, other):
+ return Foo(self.val + other.val)
+
+ def __mul__(self, other):
+ return Foo(self.val * other.val)
+
+ def __truediv__(self, other):
+ return Foo(self.val / other.val)
+
+ def __sub__(self, other):
+ return Foo(self.val - other.val)
+
+ def fn(a, b, c):
+ return Foo(a) + Foo(b) * Foo(c) / Foo(a) - Foo(b)
+
+ torch._dynamo.testing.standard_test(self, fn=fn, nargs=3, expected_ops=4)
+
+ def test_function_annotation(self):
+ class Variable:
+ pass
+
+ def fn(x):
+ x = x / 3.0
+
+ def inner(y: typing.List[Variable]):
+ return x + 1
+
+ return inner
+
+ x1 = torch.randn(10)
+ obj2 = fn(x1)([])
+
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
+ opt_fn_inner = torch._dynamo.optimize_assert(cnts)(opt_fn(x1))
+ obj1 = opt_fn_inner([])
+ self.assertTrue(same(obj1, obj2))
+ self.assertEqual(cnts.frame_count, 2)
+ self.assertEqual(cnts.op_count, 2)
+
+ def test_nested_closure(self):
+ v0 = torch.randn(10)
+
+ def fn1():
+ v1 = torch.randn(10)
+
+ def fn2(*args, **kwargs):
+ assert len(args) == 1
+ assert len(kwargs) == 1
+ v2 = torch.randn(10) + args[0] + kwargs["b"]
+
+ def fn3(v3=torch.randn(10)):
+ def fn4():
+ return v0 + v1 + v2 + v3 + 1
+
+ return fn4
+
+ return fn3
+
+ return fn2(1, b=2)()
+
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1)
+ tmp1 = torch._dynamo.optimize_assert(cnts)(opt_fn1())
+ tmp2 = torch._dynamo.optimize_assert(cnts)(opt_fn1())
+ self.assertTrue(tmp1().shape, (10,))
+ self.assertTrue(same(tmp1(), tmp1()))
+ self.assertFalse(same(tmp1(), tmp2()))
+ self.assertEqual(cnts.frame_count, 2)
+ self.assertEqual(cnts.op_count, 9)
+
+ def test_nested_closure_mutation(self):
+ def fn1():
+ v1 = torch.randn(10)
+
+ def fn2():
+ v2 = torch.randn(10)
+
+ def fn3():
+ nonlocal v1, v2
+ v1 += 1
+ v2 += 2
+ return v1 + v2
+
+ return fn3
+
+ rv = fn2()
+ rv()
+ rv()
+ return rv
+
+ torch.manual_seed(9000)
+ counter1 = fn1()
+ result1 = [counter1(), counter1(), counter1()]
+
+ torch.manual_seed(9000)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1)
+ counter2 = torch._dynamo.optimize_assert(cnts)(opt_fn1())
+ result2 = [counter2(), counter2(), counter2()]
+ result1.append(counter1())
+ result2.append(counter2())
+
+ self.assertTrue(same(result1, result2))
+ self.assertEqual(cnts.frame_count, 2)
+ self.assertEqual(cnts.op_count, 11)
+
+ def test_write_to_closures_in_inlining(self):
+ out = []
+ for use_dynamo in [False, True]:
+
+ def make_counter():
+ x = torch.randn(10)
+
+ def counter():
+ nonlocal x
+ x = x + 1
+ return x
+
+ return counter
+
+ torch.manual_seed(0)
+ counter = make_counter()
+ if not use_dynamo:
+ out.append(counter() + counter())
+ else:
+ cnts = torch._dynamo.testing.CompileCounter()
+
+ @torch._dynamo.optimize(cnts, nopython=True)
+ def fn(counter):
+ return counter() + counter()
+
+ out.append(fn(counter))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 3)
+ self.assertFalse(same(counter() + counter(), out[-1]))
+
+ self.assertTrue(same(out[0], out[1]))
+
+ def test_top_package_import(self):
+ def fn(x):
+ import torch.fx
+
+ assert not isinstance(x, torch.fx.Proxy)
+ return torch.sin(x)
+
+ x = torch.randn(4, 5)
+ ref = fn(x)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
+ res = opt_fn(x)
+ self.assertTrue(same(ref, res))
+
+ def test_optimize_on_module(self):
+ class MockModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.relu = torch.nn.ReLU()
+
+ def custom_member(self):
+ # Just for checking that Dynamo returned mod object can redirect
+ # to this method
+ pass
+
+ def forward(self, x):
+ return self.relu(x)
+
+ cnts1 = torch._dynamo.testing.CompileCounter()
+ mod = MockModule()
+ optimized_mod = torch._dynamo.optimize(cnts1, nopython=True)(mod)
+
+ a = torch.randn(10)
+ ref = mod(a)
+ res = optimized_mod(a)
+
+ optimized_mod.custom_member()
+
+ self.assertTrue(same(ref, res))
+
+ def test_nested_optimize_decorator(self):
+ cnts2 = torch._dynamo.testing.CompileCounter()
+ cnts3 = torch._dynamo.testing.CompileCounter()
+
+ @torch._dynamo.run()
+ def fn1(x):
+ return torch.sin(x) * 10
+
+ @torch._dynamo.optimize(cnts2, nopython=True)
+ def fn2(x):
+ return fn1(x) + 1
+
+ @torch._dynamo.optimize(cnts3, nopython=True)
+ def fn3(x):
+ return torch.relu(fn2(x))
+
+ fn3(torch.randn(4, 5))
+ self.assertEqual(cnts2.frame_count, 0)
+ self.assertEqual(cnts3.frame_count, 1)
+ self.assertEqual(cnts3.op_count, 4)
+
+ def test_nested_optimize_run(self):
+ cnts = torch._dynamo.testing.CompileCounter()
+
+ @torch._dynamo.optimize(cnts, nopython=True)
+ def fn(x):
+ return torch.relu(torch.cos(x) + torch.sin(x))
+
+ fn(torch.randn(4))
+ self.assertEqual(cnts.frame_count, 1)
+
+ fn(torch.randn(4, 4))
+ self.assertEqual(cnts.frame_count, 2)
+
+ # Test that run works on a decorated fn
+ fn = torch._dynamo.run(fn)
+ fn(torch.randn(4, 4, 4))
+ self.assertEqual(cnts.frame_count, 2)
+
+ def test_nested_optimize(self):
+ cnts1 = torch._dynamo.testing.CompileCounter()
+ cnts2 = torch._dynamo.testing.CompileCounter()
+
+ def fn(x):
+ return torch.relu(torch.cos(x) + torch.sin(x))
+
+ fn1 = torch._dynamo.optimize(cnts1, nopython=True)(fn)
+ fn2 = torch._dynamo.optimize(cnts2, nopython=True)(fn1)
+
+ # The first optimize in the nesting should be ignored
+ fn2(torch.randn(4))
+ self.assertEqual(cnts2.frame_count, 1)
+ self.assertEqual(cnts1.frame_count, 0)
+
+ # Since the fn code object is already compiled, calling fn1 should
+ # directly call the compiled_fn callable.
+ torch._dynamo.run()(fn1)(torch.randn(4))
+ self.assertEqual(cnts1.frame_count, 0)
+
+ # Test same behavior by reversing the calls
+ torch._dynamo.reset()
+ cnts1 = torch._dynamo.testing.CompileCounter()
+ cnts2 = torch._dynamo.testing.CompileCounter()
+ fn1 = torch._dynamo.optimize(cnts1, nopython=True)(fn)
+ fn2 = torch._dynamo.optimize(cnts2, nopython=True)(fn1)
+ fn1(torch.randn(4))
+ self.assertEqual(cnts1.frame_count, 1)
+ torch._dynamo.run()(fn2)(torch.randn(4))
+ self.assertEqual(cnts2.frame_count, 0)
+
+ def test_nested_disable_decorator(self):
+ cnts = torch._dynamo.testing.CompileCounter()
+
+ @torch._dynamo.disable()
+ def fn1(x):
+ return torch.sin(x) * 10
+
+ @torch._dynamo.optimize(cnts)
+ def fn2(x):
+ x = x + 1
+ x = x + 1
+ x = fn1(x) # graph break
+ x = x + 1
+ x = x + 1
+ return x
+
+ @torch._dynamo.optimize(cnts, nopython=True)
+ def fn3(x):
+ return fn2(x)
+
+ fn2(torch.randn(4, 5))
+ self.assertEqual(cnts.frame_count, 2)
+ self.assertEqual(cnts.op_count, 4)
+
+ try:
+ fn3(torch.randn(4, 5))
+ self.assertFalse(True)
+ except torch._dynamo.exc.Unsupported as e:
+ self.assertIn("call torch._dynamo.disable() wrapped function", str(e))
+
+ def test_torch_size(self):
+ cnts = torch._dynamo.testing.CompileCounter()
+
+ def fn(x):
+ output_size = torch.Size([10, 10])
+ x = x.view(*output_size)
+ return (x,)
+
+ x = torch.randn(100, requires_grad=True)
+ x_clone = x.clone()
+ ref = fn(x)
+
+ opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+ res = opt_fn(x_clone)
+
+ self.assertTrue(same(ref, res))
+
+ def test_torch_seed(self):
+ cnts = torch._dynamo.testing.CompileCounter()
+
+ def fn(x):
+ attention_seed = int(torch.seed() % sys.maxsize)
+ torch.manual_seed(attention_seed)
+ return (x,)
+
+ x = torch.randn(100, requires_grad=True)
+ ref = fn(x)
+
+ opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+ res = opt_fn(x)
+
+ self.assertTrue(same(ref, res))
+
+ def test_is_tensor_like(self):
+ cnts = torch._dynamo.testing.CompileCounter()
+
+ def f(x):
+ if torch.overrides.is_tensor_like(x):
+ return (x * 2,)
+ return (torch.ones(10) + x,)
+
+ x = torch.randn(10)
+ ref0 = f(x)
+ ref1 = f(4)
+ opt_f = torch._dynamo.optimize(cnts, nopython=True)(f)
+ res0 = opt_f(x)
+ res1 = opt_f(4)
+ self.assertTrue(same(ref0, res0))
+ self.assertTrue(same(ref1, res1))
+
+ def test_version_ci(self):
+ # temporary test to check that the ci torch version is set correctly
+ self.assertTrue(hasattr(torch, "_subclasses"))
+
+ @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+ def test_rand(self):
+ cnts = torch._dynamo.testing.CompileCounter()
+ device = "cuda"
+
+ def fn():
+ return torch.randn(10, device=device)
+
+ torch.manual_seed(10)
+ ref_run1 = fn()
+
+ torch.manual_seed(10)
+ ref_run2 = fn()
+ self.assertTrue(same(ref_run1, ref_run2))
+
+ torch.manual_seed(10)
+ opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+ res = opt_fn()
+
+ self.assertTrue(same(res, ref_run1))
+
+ def test_slice_input(self):
+ cnts = torch._dynamo.testing.CompileCounter()
+
+ def getitem(a, idx):
+ if isinstance(idx, slice):
+ return (
+ torch.zeros(1),
+ a[idx]
+ + [
+ 100,
+ ],
+ )
+ else:
+ return (torch.zeros(1), a[idx])
+
+ layers = list(range(10))
+ ref0 = getitem(layers, slice(0, 2, 1))
+ ref1 = getitem(layers, 2)
+ ref2 = getitem(layers, slice(3, 8, 2))
+ opt_getitem = torch._dynamo.optimize(cnts, nopython=True)(getitem)
+ res0 = opt_getitem(layers, slice(0, 2, 1))
+ res1 = opt_getitem(layers, 2)
+ res2 = opt_getitem(layers, slice(3, 8, 2))
+
+ self.assertTrue(ref0 == res0)
+ self.assertTrue(ref1 == res1)
+ self.assertTrue(ref2 == res2)
+
+ def test_grad(self):
+ cnts = torch._dynamo.testing.CompileCounter()
+
+ def fn(a, b):
+ out = a * b
+ out.sum().backward()
+ real_out = torch.sigmoid(a.grad + b)
+ return real_out
+
+ inps = [torch.randn(4, requires_grad=True) for _ in range(2)]
+ for inp in inps:
+ inp.grad = None
+ ref = fn(*inps)
+
+ for inp in inps:
+ inp.grad = None
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res = opt_fn(*inps)
+
+ self.assertTrue(same(ref, res))
+
+ @unittest.skipIf(sys.version_info < (3, 10), "use linetable when python >= 3.10")
+ def test_linetable_writer(self):
+ def fn():
+ a = 10
+ b = 20
+ c = a + b
+ f = "linetable_writer"
+ return f"Test if {f} generates correct co_linetable: {c}"
+
+ inst = dis.get_instructions(fn)
+ result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno)
+ self.assertTrue(result[1] == fn.__code__.co_linetable)
+
+ @unittest.skipIf(sys.version_info >= (3, 10), "use lnotab when python < 3.10")
+ def test_lnotab_writer(self):
+ def fn():
+ a = 10
+ b = 20
+ c = a + b
+ f = "lnotab_writer"
+ return f"Test if {f} generates correct co_lnotab: {c}"
+
+ inst = dis.get_instructions(fn)
+ result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno)
+ self.assertTrue(result[1] == fn.__code__.co_lnotab)
+
+ def test_torch_profiler(self):
+ # wrap torch.profiler.* as ProfilerContextWrapperVariable and do nothing
+ def fn(x):
+ y = x**2
+ with torch.profiler.profile():
+ y = y + 2
+ with torch.profiler.record_function("my_function"):
+ z = y**3
+ z.tolist() # graph break
+ z = z + 1
+ return z
+
+ x = torch.randn((2, 2), requires_grad=True)
+ ref = fn(x)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res = opt_fn(x)
+ self.assertTrue(same(ref, res))
+ self.assertEqual(cnts.frame_count, 2)
+
+ def test_autograd_profiler(self):
+ # wrap torch.autograd.profiler.* as ProfilerContextWrapperVariable and do nothing
+ def fn(x):
+ y = x**2
+ with torch.autograd.profiler.profile():
+ y = y + 2
+ with torch.autograd.profiler.record_function("my_function"):
+ z = y**3
+ z.tolist() # graph break
+ z = z + 1
+ return z
+
+ x = torch.randn((2, 2), requires_grad=True)
+ ref = fn(x)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res = opt_fn(x)
+ self.assertTrue(same(ref, res))
+ self.assertEqual(cnts.frame_count, 2)
+
+ def test_python_slice(self):
+ def f1(input):
+ y = 0
+ for i, x in enumerate(input[2:], 1):
+ y = y + x
+ return y
+
+ def f2(input):
+ y = 0
+ for i, x in enumerate(input.shape[2:], 1):
+ y = y + x
+ return y
+
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_f1 = torch._dynamo.optimize(cnts)(f1)
+ opt_f2 = torch._dynamo.optimize(cnts)(f2)
+ res1 = opt_f1([1, 2, 3, 5])
+ res2 = opt_f2(torch.rand([2, 3, 4, 5]))
+
+ self.assertEqual(res1, 8)
+ self.assertEqual(res2, 9)
+
+ def test_const_dict_variable_python_type(self):
+ from torch._dynamo.variables import ConstDictVariable
+
+ d1 = {"a": 10, "b": 20}
+ d2 = collections.OrderedDict([("x", 12), ("y", 22)])
+ self.assertEqual(ConstDictVariable(d1, dict).python_type(), dict)
+ self.assertEqual(
+ ConstDictVariable(d2, collections.OrderedDict).python_type(),
+ collections.OrderedDict,
+ )
+
+ def test_builtin_subclasses_as_method_on_class_type(self):
+ class Foo:
+ def __init__(self, name):
+ self.ame_ = name
+
+ def get_name(self):
+ return "Foo " + self.name_
+
+ class Bar(Foo):
+ def __init__(self, name):
+ self.name_ = name
+
+ def get_name(self):
+ return "Bar " + self.name_
+
+ class Baz(Foo):
+ def __init__(self, name): # noqa: B903
+ self.name_ = name
+
+ def get_name(self):
+ return "Baz " + self.name_
+
+ subs_of_foo_reg = Foo.__subclasses__()
+
+ counter = CompileCounter()
+
+ @torch._dynamo.optimize_assert(counter)
+ def fn():
+ return Foo.__subclasses__()
+
+ subs_of_foo_optim = fn()
+
+ self.assertEqual(len(subs_of_foo_reg), 2)
+ self.assertEqual(subs_of_foo_reg, subs_of_foo_optim)
+
+ def test_builtin_subclasses_as_method_on_var(self):
+ class Foo:
+ def __init__(self, name):
+ self.name_ = name
+
+ def get_name(self):
+ return "Foo " + self.name_
+
+ class Bar(Foo):
+ def __init__(self, name):
+ self.name_ = name
+
+ def get_name(self):
+ return "Bar " + self.name_
+
+ class Baz(Bar):
+ def __init__(self, name):
+ self.name_ = name
+
+ def get_name(self):
+ return "Baz " + self.name_
+
+ subs_of_foo_reg = Foo.__subclasses__()
+ sub_of_foo_subclass_var_reg = subs_of_foo_reg[0].__subclasses__()
+
+ sub_of_foo_subclass_var_optim = list()
+ counter = CompileCounter()
+
+ @torch._dynamo.optimize_assert(counter)
+ def fn():
+ return Foo.__subclasses__()
+
+ @torch._dynamo.optimize_assert(counter)
+ def fn_single(subs_of_foo_optim):
+ return subs_of_foo_optim[0].__subclasses__()
+
+ subs_of_foo_optim = fn()
+ sub_of_foo_subclass_var_optim = fn_single(subs_of_foo_optim)
+
+ self.assertEqual(len(sub_of_foo_subclass_var_optim), 1)
+ self.assertEqual(sub_of_foo_subclass_var_optim, sub_of_foo_subclass_var_reg)
+
+ def test_enum_no_graphbreaks(self):
+ class Foo(enum.Enum):
+ FOO = 0
+ BAR = 1
+
+ def fn(x, foo):
+ if foo is Foo.FOO:
+ x = torch.add(x, 1.0)
+ x = torch.mul(x, 1.0)
+ return x
+
+ x = torch.randn(1)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+ opt_fn(x, Foo.FOO)
+ self.assertEqual(cnts.op_count, 2)
+
+ torch._dynamo.reset()
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+ opt_fn(x, Foo.BAR)
+ self.assertEqual(cnts.op_count, 1)
+
+ def test_id_of_nn_module(self):
+ class M(torch.nn.Module):
+ def forward(self, x, ref_id):
+ self_id = id(self)
+ if self_id == ref_id:
+ x = torch.mul(x, 1.0)
+ x = torch.add(x, 1.0)
+ return x
+
+ m = M().eval()
+ data = torch.randn(1)
+ cnts = torch._dynamo.testing.CompileCounter()
+ correct_ref_id = id(m)
+ opt_m = torch._dynamo.optimize(cnts, nopython=True)(m)
+ opt_m(data, correct_ref_id)
+ self.assertEqual(cnts.op_count, 2)
+
+ torch._dynamo.reset()
+ cnts = torch._dynamo.testing.CompileCounter()
+ incorrect_ref_id = id(m) + 1
+ opt_m = torch._dynamo.optimize(cnts, nopython=True)(m)
+ opt_m(data, incorrect_ref_id)
+ self.assertEqual(cnts.op_count, 1)
+
+ def test_inline_func_jump_on_tensor_condition(self):
+ def f1(input):
+ if input == 0:
+ return input + 1
+ else:
+ return input + 2
+
+ def f2(input):
+ return f1(input)
+
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_f2 = torch._dynamo.optimize(cnts)(f2)
+ res1 = opt_f2(torch.tensor([1.0]))
+ res2 = opt_f2(torch.tensor([0.0]))
+
+ self.assertEqual(res1, 3)
+ self.assertEqual(res2, 1)
+
+ def test_frozenset_torch_func_contains(self):
+ funcs = frozenset([torch.add])
+
+ def fn(x, func):
+ if func in funcs:
+ x = torch.add(x, 1.0)
+ x = torch.mul(x, 1.0)
+ return x
+
+ x = torch.randn(1)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+ opt_fn(x, torch.add)
+ self.assertEqual(cnts.op_count, 2)
+
+ torch._dynamo.reset()
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+ opt_fn(x, torch.mul)
+ self.assertEqual(cnts.op_count, 1)
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", True)
+ def test_unsupported_fake_tensor(self):
+ def f(x):
+ return torch.quantize_per_tensor(x, 0.1, 10, torch.quint8)
+
+ x = torch.randn(2, 2)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_f = torch._dynamo.optimize(cnts)(f)
+ opt_f(x)
+ self.assertEqual(cnts.op_count, 0)
+
+ torch._dynamo.reset()
+ with patch.object(torch._dynamo.config, "fake_tensor_propagation", False):
+ opt_f = torch._dynamo.optimize_assert(
+ torch._dynamo.testing.CompileCounter()
+ )(f)
+ opt_f(x)
+
+ def test_inline_list_mutation(self):
+ def f1(x):
+ x.append(torch.ones(8))
+ return x
+
+ def f2():
+ x = [torch.ones(6)]
+ f1(x)
+ return x
+
+ res1 = f2()
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_f2 = torch._dynamo.optimize(cnts)(f2)
+ res2 = opt_f2()
+ self.assertTrue(same(res1, res2))
+
+ def test_inline_dict_mutation(self):
+ def f1(d):
+ d["c"] = d["a"] + d.pop("b")
+ return d
+
+ def f2():
+ d = {"a": torch.ones(5), "b": torch.ones(5)}
+ f1(d)
+ return d
+
+ res1 = f2()
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_f2 = torch._dynamo.optimize(cnts)(f2)
+ res2 = opt_f2()
+ self.assertTrue(same(res1, res2))
+
+ def test_recursive_inline_list_mutation(self):
+ def f1(x, y):
+ x.append(torch.tensor([1.1]))
+ y.append(torch.tensor([1.2]))
+ return x, y
+
+ def f2(x, y):
+ x.append(torch.tensor([2.1]))
+ y.append(torch.tensor([2.2]))
+ f1(x, y)
+ return x, y
+
+ def f3(x):
+ x.append(torch.tensor([3.1]))
+ y = [torch.tensor([3.2])]
+ f2(x, y)
+ return x, y
+
+ def f4():
+ x = [torch.tensor([4.1])]
+ return f3(x)
+
+ res1 = f4()
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_f4 = torch._dynamo.optimize(cnts)(f4)
+ res2 = opt_f4()
+ self.assertTrue(same(res1, res2))
+
+ def test_disallow_in_graph(self):
+ cnts = torch._dynamo.testing.CompileCounter()
+
+ @torch._dynamo.optimize(cnts)
+ def fn(a):
+ x = torch.add(a, 1)
+ x = torch.add(x, 1)
+ x = torch.sub(x, 1)
+ x = torch.add(x, 1)
+ x = torch.add(x, 1)
+ return x
+
+ torch._dynamo.disallow_in_graph(torch.sub)
+ fn(torch.randn(10))
+ torch._dynamo.allow_in_graph(torch.sub)
+
+ # check for graph break on sub
+ self.assertEqual(cnts.frame_count, 2)
+ self.assertEqual(cnts.op_count, 4)
+
+ def test_allow_in_graph(self):
+ cnts = torch._dynamo.testing.CompileCounter()
+
+ @torch._dynamo.optimize(cnts)
+ def fn(a):
+ x = torch.add(a, 1)
+ x = torch.add(x, 1)
+ x = my_custom_function(x)
+ x = torch.add(x, 1)
+ x = torch.add(x, 1)
+ return x
+
+ torch._dynamo.allow_in_graph(my_custom_function)
+ fn(torch.randn(10))
+ torch._dynamo.disallow_in_graph(my_custom_function)
+
+ # check for no graph break
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 5)
+
+ def test_sample_input(self):
+ from torch.testing._internal.common_methods_invocations import SampleInput
+
+ def fn(sample):
+ if isinstance(sample.input, torch.Tensor):
+ return sample.input * 2
+ return torch.zeros(())
+
+ sample = SampleInput(torch.ones(2))
+ ref = fn(sample)
+
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ res = opt_fn(sample)
+
+ self.assertTrue(same(ref, res))
+
+ def test_release_input_memory(self):
+ x = torch.rand([4])
+ x_ref = weakref.ref(x)
+
+ cnts = torch._dynamo.testing.CompileCounter()
+
+ @torch._dynamo.optimize(cnts)
+ def foo(x):
+ return x + x
+
+ out = foo(x)
+ self.assertTrue(same(out, x + x))
+ del x
+ self.assertIs(x_ref(), None)
+
+ def test_release_module_memory(self):
+
+ mod = torch.nn.Linear(10, 10)
+ x = torch.rand([10, 10])
+ mod_weight_ref = weakref.ref(mod.weight)
+ mod_ref = weakref.ref(mod)
+
+ # Modules that are passed into torch._dynamo optimized functions
+ # will normally be held onto through the generated GraphModule,
+ # which contains the modules. remove the reference in this backend
+ # and test that no additional references are being held.
+ class NoLeakBackend:
+ def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+ gm.mod = None
+
+ def foo(*args, **kwargs):
+ return (1,)
+
+ return foo
+
+ no_leak_backend = NoLeakBackend()
+
+ @torch._dynamo.optimize(no_leak_backend)
+ def foo(mod, x):
+ return mod(x)
+
+ foo(mod, x)
+ del mod
+ del x
+ self.assertIsNone(mod_ref(), None)
+ self.assertIsNone(mod_weight_ref(), None)
+
+ def test_update_locals_and_stack_uses_shared_cache(self):
+ def fn(x):
+ perm = [0, 3, 5]
+ perm = list(range(min(perm))) + perm
+ perm.extend(i for i in range(x.dim()) if i not in perm)
+ return perm
+
+ x = torch.rand([2, 2, 2, 2, 2, 2])
+ res1 = fn(x)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res2 = opt_fn(x)
+ self.assertTrue(same(res1, res2))
+
+ def test_dict_reconstruct_keeps_original_order(self):
+ def fn():
+ modules = collections.OrderedDict([("act", torch.nn.ReLU())])
+ module_dict = torch.nn.ModuleDict(modules)
+
+ next_modules = {"fc4": torch.nn.Linear(5, 6), "act3": torch.nn.Sigmoid()}
+ modules.update(next_modules.items())
+ module_dict.update(next_modules)
+ return modules, module_dict
+
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ modules, module_dict = opt_fn()
+
+ self.assertEqual(len(module_dict), len(modules))
+ for k1, m2 in zip(modules, module_dict.children()):
+ self.assertTrue(modules[k1] is m2)
+
+ def test_side_effects_codegen_update_mutated(self):
+ # codegen to update mutated variables with side effect
+ # should after stack value's codegen
+ def f1(x):
+ alist = [x]
+ alist.append(x + 1)
+ alist[0].sum().item() # graph break
+ res = alist.pop()
+ res.sum().item() # graph break
+ return res
+
+ def f2(a, b):
+ d = {"a": a + 1, "b": b + 2}
+ x = d.pop("b")
+ x.sum().item() # graph break
+ y = d["a"] + x
+ y.sum().item() # graph break
+ d["c"] = y
+ return d
+
+ x = torch.rand([2, 3])
+ a = torch.rand([5, 6])
+ b = torch.rand([5, 6])
+ res11 = f1(x)
+ res21 = f2(a, b)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_f1 = torch._dynamo.optimize(cnts)(f1)
+ opt_f2 = torch._dynamo.optimize(cnts)(f2)
+ res12 = opt_f1(x)
+ res22 = opt_f2(a, b)
+ self.assertTrue(same(res11, res12))
+ self.assertTrue(same(res21, res22))
+
+ def test_list_append_return_none(self):
+ def fn(x):
+ alist = []
+ blist = alist.append(x + 1)
+ return alist, blist
+
+ x = torch.tensor([2.3])
+ res = fn(x)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res2 = opt_fn(x)
+ self.assertEqual(res, res2)
+
+ def test_tensor_types(self):
+ def fn(dtype, tensor_type):
+ x = torch.empty(4, dtype=dtype)
+ assert isinstance(x, tensor_type)
+
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ opt_fn(torch.float32, torch.FloatTensor)
+ opt_fn(torch.float64, torch.DoubleTensor)
+ opt_fn(torch.float16, torch.HalfTensor)
+ opt_fn(torch.bfloat16, torch.BFloat16Tensor)
+ opt_fn(torch.uint8, torch.ByteTensor)
+ opt_fn(torch.int8, torch.CharTensor)
+ opt_fn(torch.int64, torch.LongTensor)
+ opt_fn(torch.int, torch.IntTensor)
+ opt_fn(torch.int16, torch.ShortTensor)
+ opt_fn(torch.bool, torch.BoolTensor)
+
+ def test_nan(self):
+ def f(x, n):
+ return x * 2 + n
+
+ x = torch.randn(4)
+ n = float("nan")
+
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_f = torch._dynamo.optimize(cnts)(f)
+ opt_f(x, n)
+ opt_f(x, n)
+ self.assertEqual(cnts.frame_count, 1)
+
+ @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
+ def test_item(self):
+ class MyMod(torch.nn.Module):
+ def forward(self, x):
+ z = torch.max(x)
+ return z.int().item()
+
+ x = torch.tensor([[10.6763, 11.7445, -2.2369]])
+ model = MyMod()
+ y = torch._dynamo.optimize("eager", nopython=True)(model)(x)
+
+ self.assertEqual(y, 11)
+
+ @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
+ def test_item_changes(self):
+ class MyMod(torch.nn.Module):
+ def forward(self, x):
+ z = torch.max(x)
+ return z.int().item()
+
+ x = torch.tensor([[10.6763, 11.7445, -2.2369]])
+ model = MyMod()
+ opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
+ y = opt_model(x)
+ z = opt_model(torch.tensor([[y - 5, y + 10, y + 50]]))
+
+ self.assertEqual(y, 11)
+ self.assertEqual(z, 61)
+
+ @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
+ def test_item_changes_new_shape(self):
+ class MyMod(torch.nn.Module):
+ def forward(self, x):
+ z = torch.max(x)
+ return z.int().item()
+
+ x = torch.tensor([[10.6763, 11.7445, -2.2369]])
+ model = MyMod()
+ opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
+ y = opt_model(x)
+ z = opt_model(torch.tensor([[y - 5, y + 50], [y + 5, y - 50]]))
+
+ self.assertEqual(y, 11)
+ self.assertEqual(z, 61)
+
+ def test_cross_entropy_loss_fancy_ctor(self):
+ output = None
+ rand_5 = torch.randn(5)
+ rand_3_5 = torch.randn(3, 5)
+ target = torch.empty(3, dtype=torch.long).random_(5)
+
+ loss = torch.nn.CrossEntropyLoss(
+ weight=rand_5, reduce=False, label_smoothing=0.5
+ )
+ opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss)
+ input = rand_3_5
+ dynamo_output = opt_loss(input, target)
+
+ loss = torch.nn.CrossEntropyLoss(
+ weight=rand_5, reduce=False, label_smoothing=0.5
+ )
+ input = rand_3_5
+ output = loss(input, target)
+
+ self.assertTrue(torch.allclose(dynamo_output, output))
+
+ def test_cross_entropy_loss_simple_ctor(self):
+ output = None
+ rand_3_5 = torch.randn(3, 5)
+ target = torch.empty(3, dtype=torch.long).random_(5)
+
+ loss = torch.nn.CrossEntropyLoss()
+ opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss)
+ input = rand_3_5
+ dynamo_output = opt_loss(input, target)
+
+ loss = torch.nn.CrossEntropyLoss()
+ input = rand_3_5
+ output = loss(input, target)
+
+ self.assertTrue(torch.allclose(dynamo_output, output))
+
+ def test_large_reduction_list(self):
+ dtype = torch.float32
+ device = "cpu"
+
+ def check_sum_all(tensor: torch.Tensor) -> None:
+ pylist = tensor.reshape(-1).tolist()
+ self.assertTrue(same(tensor.sum(), torch.tensor(sum(pylist))))
+
+ check_sum_all(torch.randn(200000, dtype=dtype, device=device))
+
+ @patch.object(torch._dynamo.config, "raise_on_backend_error", True)
+ def test_raise_on_backend_error(self):
+ def my_compiler(gm, _):
+ raise RuntimeError("duck!")
+
+ @torch._dynamo.optimize(my_compiler)
+ def fn(a, b):
+ return a + b / (a - b)
+
+ self.assertRaises(
+ torch._dynamo.exc.BackendCompilerFailed,
+ lambda: fn(torch.randn(10), torch.randn(10)),
+ )
+
+ def test_named_parameters(self):
+ n_embd = 768
+ block_size = 128
+ vocab_size = 65
+ embd_pdrop = 0.1
+
+ class MyModel2(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.tok_emb = torch.nn.Embedding(vocab_size, n_embd)
+ self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd))
+ self.drop = torch.nn.Dropout(embd_pdrop)
+
+ def forward(self, x):
+ return x
+
+ class MyModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.tok_emb = torch.nn.Embedding(vocab_size, n_embd)
+ self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd))
+ self.drop = torch.nn.Dropout(embd_pdrop)
+ self.submod2 = MyModel2()
+
+ def forward(self, x):
+ return x
+
+ # Regular
+ params = []
+ mod = MyModel()
+ actual_params = list(mod.named_parameters())
+
+ @torch._dynamo.optimize("eager", nopython=True)
+ def fn():
+ return list(mod.named_parameters())
+
+ params = fn()
+
+ self.assertEqual(len(actual_params), len(params))
+ for idx in range(len(params)):
+ k_a, v_a = actual_params[idx]
+ k, v = params[idx]
+ self.assertEqual(k_a, k)
+ self.assertTrue(torch.allclose(v_a, v))
+
+ # Prefix
+ params = []
+ mod = MyModel()
+ actual_params = list(mod.named_parameters(prefix="foo"))
+
+ @torch._dynamo.optimize("eager", nopython=True)
+ def fn1():
+ return list(mod.named_parameters(prefix="foo"))
+
+ params = fn1()
+
+ self.assertEqual(len(actual_params), len(params))
+ for idx in range(len(params)):
+ k_a, v_a = actual_params[idx]
+ k, v = params[idx]
+ self.assertEqual(k_a, k)
+ self.assertTrue(torch.allclose(v_a, v))
+
+ def test_module_complex_iter(self):
+ n_embd = 768
+ block_size = 128
+ vocab_size = 65
+ embd_pdrop = 0.1
+
+ class FakeGPT(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.tok_emb = torch.nn.Embedding(vocab_size, n_embd)
+ self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd))
+ self.drop = torch.nn.Dropout(embd_pdrop)
+ self.ln_f = torch.nn.LayerNorm(n_embd)
+ self.head = torch.nn.Linear(n_embd, vocab_size, bias=False)
+
+ self.block_size = block_size
+ self.names = []
+
+ def forward(self, idx, targets=None):
+ from torch.nn import functional as F
+
+ b, t = idx.size()
+ assert (
+ t <= self.block_size
+ ), "Cannot forward, model block size is exhausted."
+
+ # forward the GPT model
+ token_embeddings = self.tok_emb(
+ idx
+ ) # each index maps to a (learnable) vector
+ position_embeddings = self.pos_emb[
+ :, :t, :
+ ] # each position maps to a (learnable) vector
+ x = self.drop(token_embeddings + position_embeddings)
+ x = self.blocks(x)
+ x = self.ln_f(x)
+ logits = self.head(x)
+
+ # if we are given some desired targets also calculate the loss
+ loss = None
+ if targets is not None:
+ loss = F.cross_entropy(
+ logits.view(-1, logits.size(-1)), targets.view(-1)
+ )
+
+ return logits, loss
+
+ def foo(self, memo=None, prefix="", remove_duplicate=False):
+ for mn, m in self.named_modules(
+ memo=memo, prefix=prefix, remove_duplicate=remove_duplicate
+ ):
+ for pn, p in self.named_parameters():
+ fpn = "%s.%s" % (mn, pn) if mn else pn
+ self.names.append(fpn)
+
+ # Test plain recurse
+ model_a = FakeGPT()
+ model_a.foo()
+ a_names = model_a.names
+
+ model_b = FakeGPT()
+ opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b)
+ opt_model_b.foo()
+
+ self.assertEqual(a_names, model_b.names)
+
+ # Test with prefix
+ model_a = FakeGPT()
+ model_a.foo(prefix="abc")
+ a_names = model_a.names
+
+ model_b = FakeGPT()
+ opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b)
+ opt_model_b.foo(prefix="abc")
+
+ self.assertEqual(a_names, model_b.names)
+
+ def test_numpy_variable_isinstance(self):
+ def fn(x, m):
+ if isinstance(m, np.ndarray):
+ return x + 1
+ else:
+ return x - 1
+
+ x = torch.tensor([2.3])
+ m = np.array([1, 2, 3])
+ ref = fn(x, m)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res = opt_fn(x, m)
+ self.assertEqual(ref, res)
+
+ def test_tensor_dot_grad_no_graph_break(self):
+ def fn(a, b):
+ y = 3 * a**3 - b**2
+ y.backward(gradient=torch.tensor([1.0, 1.0]))
+ b.grad.zero_()
+ return a.grad, b.grad
+
+ a = torch.tensor([2.0, 3.0], requires_grad=True)
+ b = torch.tensor([6.0, 4.0], requires_grad=True)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ _, b_grad = opt_fn(a, b)
+ self.assertTrue(same(b_grad, torch.tensor([0.0, 0.0])))
+ self.assertEqual(cnts.frame_count, 2)
+
+ def test_torch_nn_parameter_isinstance(self):
+ def fn(x):
+ a = torch.nn.Parameter(torch.rand(2, 3))
+ if isinstance(a, torch.Tensor):
+ return x + 1
+ else:
+ return x - 1
+
+ x = torch.tensor([2.5])
+ ref = fn(x)
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ res = opt_fn(x)
+ self.assertEqual(ref, res)
+
+ def test_change_backends(self):
+ @torch._dynamo.optimize("eager", nopython=True)
+ def fn1():
+ return x + 1
+
+ @torch._dynamo.optimize("ts")
+ def fn2():
+ return x + 2
+
+ @torch._dynamo.optimize("eager", nopython=False)
+ def fn3():
+ return x + 1
+
+ x = torch.tensor([3, 5])
+
+ fn1()
+ fn1()
+ fn3()
+ self.assertRaises(torch._dynamo.exc.ResetRequired, fn2)
+ fn1()
+ torch._dynamo.reset()
+ fn2()
+ fn2()
+ self.assertRaises(torch._dynamo.exc.ResetRequired, fn1)
+ self.assertRaises(torch._dynamo.exc.ResetRequired, fn3)
+ fn2()
+
+ def test_dynamo_min_operator_with_shape(self):
+ @torch._dynamo.optimize("eager", nopython=True)
+ def f(x, a):
+ return min(x.shape[0], a)
+
+ result = f(torch.ones(6), 3)
+ self.assertEqual(result, 3)
+
+ @patch.object(torch._dynamo.config, "dynamic_shapes", True)
+ def test_onnx_shape_as_tensor(self):
+ @torch._dynamo.optimize("eager", nopython=True)
+ def f(x):
+ return 1 + torch._shape_as_tensor(x)[0]
+
+ gm, _ = torch._dynamo.export(f, torch.ones(6))
+
+ input_one_dim = torch.ones(6)
+ input_two_dims = torch.ones(7, 4)
+ self.assertEqual(f(input_one_dim), 7)
+ self.assertEqual(f(input_two_dims), 8)
+ self.assertEqual(f(input_two_dims), 8)
+
+ @torch._dynamo.optimize("eager", nopython=True)
+ def f_onnx(x):
+ return 1 + torch.onnx.operators.shape_as_tensor(x)[0]
+
+ self.assertEqual(f_onnx(input_one_dim), 7)
+ self.assertEqual(f_onnx(input_two_dims), 8)
+ self.assertEqual(f_onnx(input_two_dims), 8)
+
+ def test_cond(self):
+ from functorch.experimental.cond import cond
+
+ def true_fn(x):
+ return x.sin()
+
+ def false_fn(x):
+ return x.cos()
+
+ def f(pred, x):
+ return cond(pred, true_fn, false_fn, [x])
+
+ opt_fn = torch._dynamo.optimize("eager")(f)
+ a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25]))
+ self.assertTrue(same(torch.cos(torch.tensor([0.25, 0.25])), a))
+ b = opt_fn(torch.tensor(True), torch.tensor([0.25, 0.25]))
+ self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), b))
+
+ def test_cond_nested(self):
+ from functorch.experimental.cond import cond
+
+ def true_fn_nested(x):
+ return x * 10
+
+ def false_fn_nested(x):
+ return x * -1
+
+ def true_fn(pred2, x):
+ return x.sin()
+
+ def false_fn(pred2, x):
+ return x + cond(pred2, true_fn_nested, false_fn_nested, [x])
+
+ def f(pred, pred2, x):
+ return cond(pred, true_fn, false_fn, [pred2, x])
+
+ cc = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cc)(f)
+ true_true_sin = opt_fn(
+ torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25])
+ )
+ self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin))
+
+ true_false_sin = opt_fn(
+ torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25])
+ )
+ self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin))
+
+ false_true_sum_mult = opt_fn(
+ torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
+ )
+ self.assertTrue(
+ same(torch.tensor([2.75, 2.75]), false_true_sum_mult)
+ ) # * 10 then add x
+
+ false_false_sum_neg = opt_fn(
+ torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25])
+ )
+ self.assertTrue(
+ same(torch.tensor([0.0, 0.0]), false_false_sum_neg)
+ ) # * -1 then add x
+ self.assertTrue(cc.frame_count, 2)
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_cond_nested_fake_tensor_off(self):
+ from functorch.experimental.cond import cond
+
+ def true_fn_nested(x):
+ return x * 10
+
+ def false_fn_nested(x):
+ return x * -1
+
+ def true_fn(pred2, x):
+ return x.sin()
+
+ def false_fn(pred2, x):
+ return x + cond(pred2, true_fn_nested, false_fn_nested, [x])
+
+ def f(pred, pred2, x):
+ return cond(pred, true_fn, false_fn, [pred2, x])
+
+ cc = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cc)(f)
+ true_true_sin = opt_fn(
+ torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25])
+ )
+ self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin))
+
+ true_false_sin = opt_fn(
+ torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25])
+ )
+ self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin))
+
+ false_true_sum_mult = opt_fn(
+ torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
+ )
+ self.assertTrue(
+ same(torch.tensor([2.75, 2.75]), false_true_sum_mult)
+ ) # * 10 then add x
+
+ false_false_sum_neg = opt_fn(
+ torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25])
+ )
+ self.assertTrue(
+ same(torch.tensor([0.0, 0.0]), false_false_sum_neg)
+ ) # * -1 then add x
+ self.assertTrue(cc.frame_count, 1)
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_cond_export(self):
+ from functorch.experimental.cond import cond
+
+ def true_fn_nested(x):
+ return x * 10
+
+ def false_fn_nested(x):
+ return x * -1
+
+ def true_fn(pred2, x):
+ return x.sin()
+
+ def false_fn(pred2, x):
+ return x + cond(pred2, true_fn_nested, false_fn_nested, [x])
+
+ def f(pred, pred2, x):
+ return cond(pred, true_fn, false_fn, [pred2, x])
+
+ graph, guard = torch._dynamo.export(
+ f, torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
+ )
+ true_true_sin = graph(
+ torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25])
+ )
+ self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin))
+
+ true_false_sin = graph(
+ torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25])
+ )
+ self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin))
+
+ false_true_sum_mult = graph(
+ torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
+ )
+ self.assertTrue(
+ same(torch.tensor([2.75, 2.75]), false_true_sum_mult)
+ ) # * 10 then add x
+
+ false_false_sum_neg = graph(
+ torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25])
+ )
+ self.assertTrue(
+ same(torch.tensor([0.0, 0.0]), false_false_sum_neg)
+ ) # * -1 then add x
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_cond_export_single_arg(self):
+ from functorch.experimental.cond import cond
+
+ def true_fn(x):
+ return x
+
+ def false_fn(x):
+ return x.sin()
+
+ def f(pred, x):
+ return cond(pred, true_fn, false_fn, [x])
+
+ graph, guard = torch._dynamo.export(
+ f, torch.tensor(False), torch.tensor([0.25, 0.25])
+ )
+ true_mirror = graph(torch.tensor(True), torch.tensor([0.25, 0.25]))
+ self.assertTrue(same(torch.tensor([0.25, 0.25]), true_mirror))
+ true_mirror_2 = graph(torch.tensor(True), torch.tensor([0.33, 0.33, 0.33]))
+ self.assertTrue(same(torch.tensor([0.33, 0.33, 0.33]), true_mirror_2))
+
+ false_sin = graph(torch.tensor(False), torch.tensor([0.5, 0.5]))
+ self.assertTrue(same(torch.sin(torch.tensor([0.5, 0.5])), false_sin))
+
+ def test_disable_optimize(self):
+ cnt = torch._dynamo.testing.CompileCounter()
+
+ @torch._dynamo.optimize(cnt, disable=True)
+ def f1(x):
+ return x + 1
+
+ f1(torch.ones(6))
+ self.assertEqual(cnt.frame_count, 0)
+
+ @torch._dynamo.optimize(cnt, disable=True)
+ def f2(x):
+ return x + 1
+
+ f2(torch.ones(6))
+ self.assertEqual(cnt.frame_count, 0)
+
+ with patch.dict(os.environ, {"TORCHDYNAMO_DISABLE": "1"}):
+
+ @torch._dynamo.optimize(cnt)
+ def f3(x):
+ return x + 1
+
+ f3(torch.ones(6))
+ self.assertEqual(cnt.frame_count, 0)
+
+ def test_config_log_level(self):
+ @torch._dynamo.optimize("eager")
+ def fn(a, b):
+ return a + b
+
+ with self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log:
+ torch._dynamo.config.log_level = logging.DEBUG
+ fn(torch.randn(10), torch.randn(10))
+ cur_len = len(log)
+ self.assertGreater(cur_len, 0)
+
+ torch._dynamo.config.log_level = logging.WARNING
+ fn(torch.randn(10), torch.randn(10))
+ self.assertEqual(cur_len, len(log))
+
+ def test_duplicate_graph_break_warning(self):
+ @torch._dynamo.optimize("eager")
+ def f1(a, b):
+ f2(a, b)
+
+ def f2(a, b):
+ c = a + b
+ print("break")
+ return a + b + c
+
+ @torch._dynamo.optimize("eager")
+ def g1(a, b):
+ g2(a, b)
+
+ def g2(a, b):
+ c = a + b
+ print("break")
+ return a + b + c
+
+ def count_graph_break_msgs(msgs):
+ return sum(msg.find("Graph break") != -1 for msg in msgs)
+
+ with self.assertLogs(logger="torch._dynamo", level=logging.WARNING) as log:
+ torch._dynamo.config.verbose = True
+ f1(torch.randn(10), torch.randn(10))
+ self.assertGreater(count_graph_break_msgs(log.output), 1)
+
+ with self.assertLogs(logger="torch._dynamo", level=logging.WARNING) as log:
+ torch._dynamo.config.verbose = False
+ g1(torch.randn(10), torch.randn(10))
+ self.assertEqual(count_graph_break_msgs(log.output), 1)
+
+ def test_inplace_param_update(self):
+ def fn(param, y):
+ prev_grad = torch.is_grad_enabled()
+ try:
+ torch.set_grad_enabled(False)
+ torch.set_grad_enabled(True)
+ torch.set_grad_enabled(False)
+ param.add_(y)
+ finally:
+ torch.set_grad_enabled(prev_grad)
+
+ y = torch.randn(4)
+ x = torch.nn.Parameter(torch.randn(4))
+ fn(x, y)
+
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+ opt_fn(x, y)
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 5)
+
+ @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+ def test_autocast(self):
+ if not torch.cuda.is_bf16_supported():
+ raise unittest.SkipTest("requires bf16")
+
+ class MyModule(torch.nn.Module):
+ def forward(self, x):
+ a_float32 = torch.rand((8, 8), device="cuda")
+ b_float32 = torch.rand((8, 8), device="cuda")
+ d_float32 = torch.rand((8, 8), device="cuda")
+
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ e_float16 = torch.mm(a_float32, b_float32)
+ f_float16 = torch.mm(d_float32, e_float16)
+ return f_float16
+
+ module = MyModule()
+ real = module(torch.tensor([0.5]))
+ real_device = real.device
+ real_dtype = real.dtype
+
+ graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
+ exported = graph(torch.tensor([0.5]))
+ self.assertEqual(exported.device, real_device)
+ self.assertEqual(exported.dtype, real_dtype)
+
+ self.assertEqual(exported.device.type, "cuda")
+ self.assertEqual(exported.device.index, 0)
+ self.assertEqual(exported.dtype, torch.bfloat16)
+
+ def test_autocast_cpu(self):
+ class MyModule(torch.nn.Module):
+ def forward(self, x):
+ a_float32 = torch.rand((8, 8), device="cpu")
+ b_float32 = torch.rand((8, 8), device="cpu")
+ d_float32 = torch.rand((8, 8), device="cpu")
+
+ with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
+ e_float16 = torch.mm(a_float32, b_float32)
+ f_float16 = torch.mm(d_float32, e_float16)
+ return f_float16
+
+ module = MyModule()
+ real = module(torch.tensor([0.5]))
+ real_device = real.device
+ real_dtype = real.dtype
+
+ graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
+ exported = graph(torch.tensor([0.5]))
+ self.assertEqual(exported.device, real_device)
+ self.assertEqual(exported.dtype, real_dtype)
+
+ self.assertEqual(exported.device.type, "cpu")
+ self.assertEqual(exported.dtype, torch.bfloat16)
+
+ @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+ def test_autocast_float64(self):
+ class MyModule(torch.nn.Module):
+ def forward(self, x):
+ a_float32 = torch.rand((8, 8), device="cuda")
+ b_float32 = torch.rand((8, 8), device="cuda")
+ d_float32 = torch.rand((8, 8), device="cuda")
+
+ with torch.autocast(device_type="cuda", dtype=torch.float64):
+ e_float64 = torch.mm(a_float32, b_float32)
+ f_float64 = torch.mm(d_float32, e_float64)
+ return f_float64
+
+ module = MyModule()
+ real = module(torch.tensor([0.5]))
+ real_device = real.device
+ real_dtype = real.dtype
+
+ graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
+ exported = graph(torch.tensor([0.5]))
+ self.assertEqual(exported.device, real_device)
+ self.assertEqual(exported.dtype, real_dtype)
+
+ self.assertEqual(exported.device.index, 0)
+ self.assertEqual(exported.dtype, torch.float64)
+
+ @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+ def test_autocast_device(self):
+ class MyModule(torch.nn.Module):
+ def forward(self, x):
+ a_float32 = torch.rand((8, 8), device="cuda")
+ b_float32 = torch.rand((8, 8), device="cuda")
+ d_float32 = torch.rand((8, 8), device="cuda")
+
+ with torch.autocast(device_type="cuda"):
+ e_float64 = torch.mm(a_float32, b_float32)
+ f_float64 = torch.mm(d_float32, e_float64)
+ return f_float64
+
+ module = MyModule()
+ real = module(torch.tensor([0.5]))
+ real_device = real.device
+ real_dtype = real.dtype
+
+ graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
+ exported = graph(torch.tensor([0.5]))
+ self.assertEqual(exported.device, real_device)
+ self.assertEqual(exported.dtype, real_dtype)
+
+ self.assertEqual(exported.device.index, 0)
+ self.assertEqual(exported.dtype, torch.torch.float16)
+
+ def test_generate_tensor_from_list_of_numpy_primitive_type(self):
+ # Test sth like torch.LongTensor(list(np.int64, np.int64, ...))
+ def fn():
+ x = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64)
+ y = [x[0], x[2], x[4]]
+ z = torch.LongTensor(y)
+ return z
+
+ ref = fn()
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ res = opt_fn()
+ self.assertTrue(same(ref, res))
+
+ def test_autograd_function_equivalence(self):
+ m1 = Module1()
+
+ @torch._dynamo.optimize("eager", nopython=True)
+ def f1():
+ return m1(torch.ones(2, 3))
+
+ self.assertTrue(torch.allclose(f1(), torch.tensor([2.0])))
+
+ m2 = Module2()
+
+ @torch._dynamo.optimize("eager", nopython=True)
+ def f2():
+ return m2(torch.ones(2, 3))
+
+ self.assertTrue(torch.allclose(f2(), torch.tensor([2.0])))
+
+ def test_object_classmethod(self):
+ class C:
+ @classmethod
+ def fn(cls, x):
+ return x + x
+
+ @torch._dynamo.optimize("eager", nopython=True)
+ def f():
+ return C().fn(torch.ones(2, 3))
+
+ self.assertTrue(torch.allclose(f(), torch.tensor([2.0])))
+
+ def test_object_staticmethod(self):
+ class C:
+ @staticmethod
+ def fn(x):
+ return x + x
+
+ @torch._dynamo.optimize("eager", nopython=True)
+ def f():
+ return C().fn(torch.ones(2, 3))
+
+ self.assertTrue(torch.allclose(f(), torch.tensor([2.0])))
+
+ def test_user_function_variable_supports_enum_argument(self):
+ class Foo(enum.Enum):
+ FOO = 0
+ BAR = 1
+
+ def gn(x, y=Foo.FOO):
+ if y is Foo.FOO:
+ return x
+ else:
+ return x + 1
+
+ def fn(x):
+ return gn(x)
+
+ x = torch.randn(2, 3)
+ ref = fn(x)
+ opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
+ res = opt_fn(x)
+ self.assertTrue(torch.allclose(ref, res))
+
+ def test_repro_graph_breaks_in__get_item_by_idx(self):
+ class Mod(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.mod = torch.nn.Sequential(
+ torch.nn.Linear(3, 3), torch.nn.Linear(3, 3)
+ )
+
+ def forward(self, x):
+ return self.mod[0](x)
+
+ m = Mod()
+ graph, _ = torch._dynamo.export(m, torch.randn(3, 3))
+
+
+class CustomFunc(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, foo):
+ return foo + foo
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output
+
+
+class Module1(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, foo):
+ return CustomFunc().apply(foo)
+
+
+class Module2(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.fn = CustomFunc.apply
+
+ def forward(self, foo):
+ return self.fn(foo)
+
+
+class TestTracer(JitTestCase):
+ def test_jit_save(self):
+ def fn():
+ class Foo(torch.nn.Module):
+ def __init__(self):
+ super(Foo, self).__init__()
+ self.a = 3
+
+ @torch.jit.export
+ def __getstate__(self):
+ return (3, self.training)
+
+ @torch.jit.export
+ def __setstate__(self, state):
+ self.a = state[0]
+ self.training = state[1]
+
+ def forward(self, x):
+ return x + self.a
+
+ f = Foo()
+
+ return torch.jit.trace(f, (torch.rand(3, 4),))
+
+ fn()
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ opt_fn()
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py
new file mode 100644
index 0000000000000..28fdbbb8e5963
--- /dev/null
+++ b/test/dynamo/test_model_output.py
@@ -0,0 +1,165 @@
+# Owner(s): ["module: dynamo"]
+import dataclasses
+import unittest.mock
+
+import torch
+
+import torch._dynamo.testing
+from torch._dynamo.testing import same
+
+try:
+ from transformers import modeling_outputs
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.file_utils import ModelOutput
+ from transformers.modeling_outputs import BaseModelOutput
+except ImportError:
+ modeling_outputs = None
+
+
+def maybe_skip(fn):
+ if modeling_outputs is None:
+ return unittest.skip("requires HuggingFace")(fn)
+ return fn
+
+
+class TestHFPretrained(torch._dynamo.testing.TestCase):
+ @maybe_skip
+ def test_pretrained(self):
+ def fn(a, tmp):
+ if tmp.return_dict:
+ return a + torch.ones(2) * tmp.max_length
+ return a
+
+ x = torch.randn(2)
+ tmp = PretrainedConfig(return_dict=True, max_length=20)
+ ref = fn(x, tmp)
+ opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
+ res = opt_fn(x, tmp)
+ self.assertTrue(same(ref, res))
+
+
+class TestModelOutput(torch._dynamo.testing.TestCase):
+ @maybe_skip
+ def test_mo_create(self):
+ def fn(a, b):
+ tmp = BaseModelOutput(a + 1, attentions=b + 3)
+ return tmp
+
+ torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=2)
+
+ @maybe_skip
+ def test_mo_assign(self):
+ def fn(a, b):
+ tmp = BaseModelOutput(last_hidden_state=b + 3)
+ tmp.hidden_states = a + 7
+ tmp["attentions"] = a + b + 6
+ return tmp
+
+ args = [torch.randn(10), torch.randn(10)]
+ obj1 = fn(*args)
+
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
+ obj2 = opt_fn(*args)
+ self.assertTrue(same(obj1.last_hidden_state, obj2.last_hidden_state))
+ self.assertTrue(same(obj1.hidden_states, obj2.hidden_states))
+ self.assertTrue(same(obj1.attentions, obj2.attentions))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 4)
+
+ def _common(self, fn, op_count):
+ args = [
+ BaseModelOutput(
+ last_hidden_state=torch.randn(10), attentions=torch.randn(10)
+ )
+ ]
+ obj1 = fn(*args)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
+ obj2 = opt_fn(*args)
+ self.assertTrue(same(obj1, obj2))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, op_count)
+
+ @maybe_skip
+ def test_mo_getattr(self):
+ def fn(obj: BaseModelOutput):
+ x = obj.last_hidden_state * 10
+ if obj.hidden_states is not None:
+ x += obj.hidden_states
+ if obj.attentions is not None:
+ x += obj.attentions
+ return x
+
+ self._common(fn, 2)
+
+ @maybe_skip
+ def test_mo_getitem(self):
+ def fn(obj: BaseModelOutput):
+ x = obj["last_hidden_state"] * 10
+ if "hidden_stats" in obj:
+ x += obj["hidden_states"]
+ if "attentions" in obj:
+ x += obj["attentions"]
+ return x
+
+ self._common(fn, 2)
+
+ @maybe_skip
+ def test_mo_tuple(self):
+ def fn(obj: BaseModelOutput):
+ a, b = obj.to_tuple()
+ return a + b * 10
+
+ self._common(fn, 2)
+
+ @maybe_skip
+ def test_mo_index(self):
+ def fn(obj: BaseModelOutput):
+ return obj[0] * 10 + obj[1]
+
+ self._common(fn, 2)
+
+ @maybe_skip
+ def test_mo_init(self):
+ @dataclasses.dataclass
+ class MyDataClass(ModelOutput):
+ a: torch.Tensor
+ b: torch.Tensor = None
+ c: torch.Tensor = None
+ d: torch.Tensor = None
+ e: torch.Tensor = None
+
+ def fn(obj):
+ class_fields = dataclasses.fields(obj)
+ assert len(class_fields)
+ assert all(field.default is None for field in class_fields[1:])
+ other_fields_are_none = all(
+ getattr(obj, field.name) is None for field in class_fields[1:]
+ )
+ assert not other_fields_are_none
+
+ total = getattr(obj, class_fields[0].name)
+ for field in class_fields[1:]:
+ v = getattr(obj, field.name)
+ if v is not None:
+ total += v
+
+ return total
+
+ tensors = [torch.randn(10), torch.randn(10), torch.randn(10)]
+ obj1 = MyDataClass(*tensors)
+ correct1 = fn(obj1)
+
+ obj2 = MyDataClass(*tensors)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ self.assertTrue(same(opt_fn(obj2), correct1))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 2)
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py
new file mode 100644
index 0000000000000..6d05026499a7d
--- /dev/null
+++ b/test/dynamo/test_modules.py
@@ -0,0 +1,889 @@
+# Owner(s): ["module: dynamo"]
+
+from copy import deepcopy
+from unittest.mock import patch
+
+import torch
+
+import torch._dynamo.testing
+from torch._dynamo.eval_frame import unsupported
+from torch._dynamo.mutation_guard import GenerationTracker
+from torch._dynamo.testing import same
+from torch.nn import functional as F
+from torch.nn.modules.lazy import LazyModuleMixin
+from torch.nn.parameter import Parameter, UninitializedParameter
+
+try:
+ from . import test_functions
+except ImportError:
+ import test_functions
+
+
+class BasicModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear1 = torch.nn.Linear(10, 10)
+ self.scale = torch.randn(1, 10)
+
+ def forward(self, x):
+ return F.relu(self.linear1(x)) * self.scale
+
+
+class FnMember(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear1 = torch.nn.Linear(10, 10)
+ self.activation = F.relu
+
+ def forward(self, x):
+ x = self.linear1(x)
+ if self.activation:
+ x = self.activation(x)
+ return x
+
+
+class FnMemberCmp(torch.nn.Module):
+ def __init__(self, activation):
+ super().__init__()
+ self.linear1 = torch.nn.Linear(10, 10)
+ self.activation = activation
+
+ def forward(self, x):
+ x = self.linear1(x)
+ if self.activation is not None:
+ x = self.activation(x)
+ if self.activation is None:
+ x = torch.sigmoid(x)
+ return x
+
+
+class SubmoduleExample(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layer1 = BasicModule()
+ self.layer2 = BasicModule()
+ self.scale = torch.randn(1, 10)
+
+ def forward(self, x):
+ x = self.layer1(x)
+ x = self.layer2(x)
+ return x * self.scale
+
+
+class IsTrainingCheck(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear1 = torch.nn.Linear(10, 10)
+ self.linear2 = torch.nn.Linear(10, 10)
+ self.train(True)
+
+ def forward(self, x):
+ if self.training:
+ mod = self.linear1
+ else:
+ mod = self.linear2
+ return F.relu(mod(x))
+
+
+class IsEvalCheck(IsTrainingCheck):
+ def __init__(self):
+ super().__init__()
+ self.train(False)
+
+
+class ModuleMethodCall(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layer1 = BasicModule()
+ self.layer2 = BasicModule()
+ self.scale = torch.randn(1, 10)
+
+ def call_and_scale(self, mod, x):
+ x = mod(x)
+ return x * self.scale
+
+ def forward(self, x):
+ x1 = self.call_and_scale(self.layer1, x)
+ x2 = self.call_and_scale(self.layer2, x)
+ return x1 + x2
+
+
+class UnsupportedMethodCall(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layer1 = BasicModule()
+ self.scale = torch.randn(1, 10)
+
+ def call_and_scale(self, mod, x):
+ x = mod(x)
+ x = x * self.scale
+ return unsupported(x, x)
+
+ def forward(self, x):
+ x1 = self.call_and_scale(self.layer1, x)
+ return x + x1
+
+
+class UnsupportedModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layer1 = BasicModule()
+ self.scale = torch.randn(1, 10)
+
+ def forward(self, x):
+ x = self.layer1(x) * self.scale
+ return unsupported(x, x)
+
+
+class UnsupportedModuleCall(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.mod = UnsupportedModule()
+
+ def forward(self, x):
+ return 1 + self.mod(x * 1.5)
+
+
+class ModuleStaticMethodCall(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layer1 = BasicModule()
+ self.layer2 = BasicModule()
+ self.scale = torch.randn(1, 10)
+
+ @staticmethod
+ def call_and_scale(scale, mod, x):
+ x = mod(x)
+ return x * scale
+
+ def forward(self, x):
+ x1 = self.call_and_scale(self.scale, self.layer1, x)
+ x2 = self.call_and_scale(self.scale, self.layer2, x)
+ return x1 + x2
+
+
+class ModuleClassMethodCall(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layer1 = BasicModule()
+ self.layer2 = BasicModule()
+ self.scale = torch.randn(1, 10)
+
+ @classmethod
+ def call_and_scale(cls, scale, mod, x):
+ x = mod(x)
+ return x * scale
+
+ def forward(self, x):
+ x1 = self.call_and_scale(self.scale, self.layer1, x)
+ x2 = self.call_and_scale(self.scale, self.layer2, x)
+ return x1 + x2
+
+
+class ModuleProperty(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.scale = torch.randn(1, 10)
+
+ @property
+ def scale_alias(self):
+ return self.scale
+
+ def forward(self, x):
+ return x * self.scale_alias
+
+
+class ConstLoop(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear1 = torch.nn.Linear(10, 10)
+ self.count = 3
+
+ def forward(self, x):
+ for i in range(self.count):
+ x = torch.sigmoid(self.linear1(x))
+ return x
+
+
+class ViaModuleCall(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear1 = torch.nn.Linear(10, 10)
+
+ def forward(self, x):
+ return test_functions.constant3(torch.sigmoid(self.linear1(x)), x)
+
+
+class IsNoneLayer(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layer1 = torch.nn.Linear(10, 10)
+ self.layer2 = None
+ self.train(True)
+
+ def forward(self, x):
+ if self.layer1 is not None:
+ x = self.layer1(x)
+ if self.layer2 is not None:
+ x = self.layer2(x)
+ return x
+
+
+class LayerList(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layers = [
+ torch.nn.Linear(10, 10),
+ torch.nn.ReLU(),
+ torch.nn.Linear(10, 10),
+ ]
+
+ def forward(self, x):
+ for layer in self.layers:
+ x = layer(x)
+ return x
+
+
+class ModuleList(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layers = torch.nn.ModuleList(
+ [
+ torch.nn.Linear(10, 10),
+ torch.nn.ReLU(),
+ torch.nn.Linear(10, 10),
+ torch.nn.ReLU(),
+ ]
+ )
+
+ def forward(self, x):
+ for i in range(len(self.layers)):
+ x = self.layers[i](x)
+
+ for layer in self.layers:
+ x = layer(x)
+
+ for layer, val in zip(self.layers, (x, x, x, x)):
+ x = layer(x) + val
+
+ for layer, val in zip(self.layers, (1, 2, 3, 4)):
+ x = layer(x) + val
+
+ for idx, layer in enumerate(self.layers):
+ x = layer(x) * idx
+
+ for idx, layer in enumerate(self.layers[::-1]):
+ x = layer(x) * idx
+
+ return x
+
+
+class ModuleDict(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layers = torch.nn.ModuleDict(
+ {
+ "0": torch.nn.Linear(10, 10),
+ }
+ )
+
+ def forward(self, x):
+ # TODO(future PR): handle more logic
+ x = self.layers["0"](x)
+ return x
+
+
+class TensorList(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layers = (
+ torch.randn((1, 10)),
+ torch.randn((10, 1)),
+ torch.randn((1, 10)),
+ torch.randn((10, 1)),
+ )
+
+ def forward(self, x):
+ for layer in self.layers:
+ x = x * layer
+ return x
+
+
+class Children(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.l1 = torch.nn.Linear(10, 10)
+ self.l2 = torch.nn.ReLU()
+ self.l3 = torch.nn.Linear(10, 10)
+ self.l4 = torch.nn.ReLU()
+
+ def forward(self, x):
+ for block in self.children():
+ x = block(x)
+ return x
+
+
+class IntArg(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layer1 = torch.nn.Linear(10, 10)
+
+ def forward(self, x, offset=1):
+ x = F.relu(self.layer1(x)) + offset
+ return x
+
+
+class Seq(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layers = torch.nn.Sequential(
+ torch.nn.Linear(10, 10),
+ torch.nn.ReLU(),
+ torch.nn.Linear(10, 10),
+ torch.nn.ReLU(),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class Cfg:
+ def __init__(self):
+ self.val = 0.5
+ self.count = 3
+
+
+class CfgModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.cfg = Cfg()
+ self.layer = torch.nn.Linear(10, 10)
+
+ def forward(self, x):
+ for i in range(self.cfg.count):
+ x = self.layer(x + self.cfg.val)
+ return x
+
+
+class StringMember(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear1 = torch.nn.Linear(10, 10)
+ self.mode = "some_string"
+
+ def forward(self, x):
+ if self.mode == "some_string":
+ return F.relu(self.linear1(x))
+
+
+class _Block(torch.nn.Module):
+ def forward(self, x):
+ return 1.5 * torch.cat(x, 1)
+
+
+class _DenseBlock(torch.nn.ModuleDict):
+ _version = 2
+
+ def __init__(
+ self,
+ num_layers: int = 3,
+ ) -> None:
+ super().__init__()
+ for i in range(num_layers):
+ self.add_module("denselayer%d" % (i + 1), _Block())
+
+ def forward(self, init_features):
+ features = [init_features]
+ for name, layer in self.items():
+ new_features = layer(features)
+ features.append(new_features)
+ return torch.cat(features, 1)
+
+
+class DenseNetBlocks(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layers = _DenseBlock()
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class MaterializedModule(torch.nn.Module):
+ """Once the below lazy module is initialized with its first input,
+ it is transformed into this module."""
+
+ param: Parameter
+
+ def __init__(self):
+ super().__init__()
+ self.register_parameter("param", None)
+
+ def forward(self, x):
+ return x
+
+
+class LazyModule(LazyModuleMixin, MaterializedModule):
+ param: UninitializedParameter
+ cls_to_become = MaterializedModule
+
+ def __init__(self):
+ super().__init__()
+ self.param = UninitializedParameter()
+
+ def initialize_parameters(self, x):
+ self.param.materialize(x.shape)
+
+
+def requires_grad1(module: torch.nn.Module, recurse: bool = False) -> bool:
+ requires_grad = any([p.requires_grad for p in module.parameters(recurse)])
+ return requires_grad
+
+
+def requires_grad2(module: torch.nn.Module, recurse: bool = False) -> bool:
+ requires_grad = any(p.requires_grad for p in module.parameters(recurse))
+ return requires_grad
+
+
+class ParametersModule1(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear1 = torch.nn.Linear(10, 10)
+ self.scale = torch.nn.Parameter(torch.randn(1, 10))
+
+ def forward(self, x):
+ if not requires_grad1(self):
+ return F.relu(self.linear1(x)) * self.scale
+ else:
+ return x + 1
+
+
+class ParametersModule2(ParametersModule1):
+ def forward(self, x):
+ if not requires_grad2(self):
+ return F.relu(self.linear1(x)) * self.scale
+ else:
+ return x + 1
+
+
+class ParametersModule3(ParametersModule1):
+ def forward(self, x):
+ ones = torch.ones(10, dtype=next(self.parameters()).dtype)
+ return F.relu(self.linear1(x)) * self.scale + ones
+
+
+class SuperModule(BasicModule):
+ def forward(self, x):
+ x = super().forward(x)
+ return x + 10.0
+
+
+class ComplicatedSuperParent(torch.nn.Module):
+ @classmethod
+ def custom_add(cls, x):
+ x = x + x
+ return x
+
+
+class SuperChildCallsClassMethod(ComplicatedSuperParent):
+ @classmethod
+ def child_func(cls, x):
+ x = super().custom_add(x)
+ return x
+
+ def forward(self, x):
+ x = self.child_func(x)
+ return x
+
+
+class HasAttrModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.scale = torch.nn.Parameter(torch.randn(1, 10))
+
+ def forward(self, x):
+ x = F.relu(x)
+ if hasattr(self, "scale"):
+ x *= self.scale
+ if hasattr(self, "scale2"):
+ x *= self.scale2
+ return x
+
+
+class EnumValues(torch.nn.ModuleDict):
+ def __init__(
+ self,
+ num_layers: int = 3,
+ ) -> None:
+ super().__init__()
+ for i in range(num_layers):
+ self.add_module("denselayer%d" % (i + 1), _Block())
+
+ def forward(self, init_features):
+ features = [init_features]
+ for idx, layer in enumerate(self.values()):
+ new_features = layer(features)
+ features.append(new_features)
+ return torch.cat(features, 1)
+
+
+class CallForwardDirectly(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layer1 = BasicModule()
+ self.layer2 = torch.nn.Linear(10, 10)
+
+ def forward(self, x):
+ x = self.layer1.forward(x)
+ x = self.layer2.forward(x)
+ return x
+
+
+class ModuleNameString(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear1 = torch.nn.Linear(10, 10)
+
+ def forward(self, x):
+ if self.__class__.__name__ == "ABC":
+ return 10
+ if self.linear1.__class__.__name__ == "Linear":
+ return F.relu(self.linear1(x) + 10)
+ return 11
+
+
+class SelfMutatingModule(torch.nn.Module):
+ def __init__(self, layer):
+ super().__init__()
+ self.layer = layer
+ self.counter = 0
+
+ def forward(self, x):
+ result = self.layer(x) + self.counter
+ self.counter += 1
+ return F.relu(result)
+
+
+class ModuleAttributePrecedenceBase(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def linear(self, x):
+ return x * 2.0
+
+
+class ModuleAttributePrecedence(ModuleAttributePrecedenceBase):
+ def __init__(self):
+ super().__init__()
+ self.activation = torch.nn.ReLU()
+ self.linear = torch.nn.Linear(10, 10)
+ self.initializer = torch.ones([10, 10])
+ self.scale = 0.5
+
+ def activation(self, x):
+ return x * 1.2
+
+ def initializer(self):
+ return torch.zeros([10, 10])
+
+ def scale(self):
+ return 2.0
+
+ def forward(self, x):
+ # object attribute takes precedence unless it's a nn.Module
+ return self.activation(self.linear(self.initializer + x)) * self.scale
+
+
+def make_test(fn, expected_ops=None):
+ def test_fn(self):
+ return torch._dynamo.testing.standard_test(
+ self, fn=fn, nargs=1, expected_ops=expected_ops
+ )
+
+ fn.eval()
+ return test_fn
+
+
+class NNModuleTests(torch._dynamo.testing.TestCase):
+ test_seq = make_test(Seq())
+ test_basicmodule1 = make_test(BasicModule())
+ test_basicmodule2 = make_test(BasicModule())
+ test_submodules1 = make_test(SubmoduleExample())
+ test_submodules2 = make_test(SubmoduleExample())
+ test_modulemethod1 = make_test(ModuleMethodCall())
+ test_modulemethod2 = make_test(ModuleMethodCall())
+ test_module_static_method = make_test(ModuleStaticMethodCall())
+ test_fnmember = make_test(FnMember())
+ test_fnmembercmp1 = make_test(FnMemberCmp(F.relu))
+ test_fnmembercmp2 = make_test(FnMemberCmp(None))
+ test_constloop = make_test(ConstLoop())
+ test_istraining1 = make_test(IsTrainingCheck())
+ test_istraining2 = make_test(IsTrainingCheck())
+ test_iseval1 = make_test(IsEvalCheck())
+ test_iseval2 = make_test(IsEvalCheck())
+ test_viamodulecall = make_test(ViaModuleCall())
+ test_isnonelayer = make_test(IsNoneLayer())
+ test_layerlist = make_test(LayerList())
+ test_tensorlist = make_test(TensorList())
+ test_intarg = make_test(IntArg())
+ test_cfgmod = make_test(CfgModule())
+ test_stringmember = make_test(StringMember())
+ test_modulelist = make_test(ModuleList())
+ test_moduledict = make_test(ModuleDict())
+ test_super1 = make_test(SuperModule())
+ test_super_class_method = make_test(SuperChildCallsClassMethod())
+ test_children = make_test(Children())
+ test_densenet = make_test(DenseNetBlocks())
+ test_parameters1 = make_test(ParametersModule1())
+ test_parameters2 = make_test(ParametersModule2())
+ test_parameters3 = make_test(ParametersModule3(), expected_ops=5)
+ test_hasattr = make_test(HasAttrModule())
+ test_enumvalues = make_test(EnumValues())
+ test_module_class_method = make_test(ModuleClassMethodCall())
+ test_module_property = make_test(ModuleProperty())
+ test_forward_directly = make_test(CallForwardDirectly())
+ test_module_name_string = make_test(ModuleNameString())
+ test_module_attribute_precedence = make_test(ModuleAttributePrecedence())
+
+ def test_unsupportedmethod(self):
+ m = UnsupportedMethodCall()
+ i = torch.randn(10)
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_m = torch._dynamo.optimize(cnt)(m)
+ r = opt_m(i)
+ self.assertTrue(torch._dynamo.testing.same(r, m(i)))
+ self.assertEqual(cnt.op_count, 5)
+
+ def test_unsupportedmodule(self):
+ m = UnsupportedModuleCall()
+ i = torch.randn(10)
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_m = torch._dynamo.optimize(cnt)(m)
+ r = opt_m(i)
+ self.assertTrue(torch._dynamo.testing.same(r, m(i)))
+ self.assertEqual(cnt.op_count, 6)
+
+ def test_self_mutating1(self):
+ m1 = torch.nn.Linear(10, 10)
+ m2 = SelfMutatingModule(m1)
+ m3 = SelfMutatingModule(m1)
+ m4 = SelfMutatingModule(m1)
+ i = torch.randn(10)
+ out2 = [m2(i), m2(i), m2(i)]
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_m3 = torch._dynamo.optimize_assert(cnt)(m3)
+ opt_m4 = torch._dynamo.optimize_assert(cnt)(m4)
+ out3 = [opt_m3(i), opt_m3(i), opt_m3(i)]
+ out4 = [opt_m4(i), opt_m4(i), opt_m4(i)]
+ self.assertTrue(torch._dynamo.testing.same(out2, out3))
+ self.assertTrue(torch._dynamo.testing.same(out2, out4))
+ self.assertEqual(cnt.frame_count, 3)
+
+ @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
+ def test_generation_tag(self):
+ cnt = torch._dynamo.testing.CompileCounter()
+
+ # guarantee that we have installed
+ # the generation tagging function
+ with torch._dynamo.optimize_assert(cnt):
+ pass
+
+ m1 = torch.nn.Linear(10, 10)
+ prev_generation = GenerationTracker.get_generation_value(m1)
+ cur_generation = prev_generation + 1
+
+ with torch._dynamo.optimize_assert(cnt):
+ m2 = torch.nn.Linear(10, 10)
+
+ self.assertEqual(GenerationTracker.get_generation_value(m1), prev_generation)
+ self.assertEqual(GenerationTracker.get_generation_value(m2), cur_generation)
+ # check that newly constructed instances
+ # also have the same generation (even if copied from an old instance)
+ m3 = deepcopy(m1)
+ self.assertEqual(GenerationTracker.get_generation_value(m3), cur_generation)
+
+ def test_simple_torch_function(self):
+ def foo(x):
+ # function call, twice to test wrapping
+ x = F.sigmoid(x)
+ x = F.sigmoid(x)
+ # method call, twice to test wrapping
+ x = x.sigmoid()
+ x = x.sigmoid()
+ return x
+
+ class TensorProxy(torch.Tensor):
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ return super().__torch_function__(func, types, args, kwargs)
+
+ torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy)
+
+ x = torch.randn(1).as_subclass(TensorProxy)
+ cnt = torch._dynamo.testing.CompileCounter()
+ out1 = foo(x)
+ opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
+ out2 = opt_foo(x)
+
+ self.assertEqual(cnt.op_count, 4)
+ self.assertTrue(torch._dynamo.testing.same(out1, out2))
+
+ torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy)
+
+ def test_torch_function_with_closure(self):
+ def run():
+
+ counter = 0
+
+ def foo(x):
+ # function call, twice to test wrapping
+ x = F.sigmoid(x)
+ x = F.sigmoid(x)
+ # method call, twice to test wrapping
+ x = x.sigmoid()
+ x = x.sigmoid()
+ return x
+
+ class TensorProxy(torch.Tensor):
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ nonlocal counter
+ # for now, only support reads from closure cells
+ # TODO(future PR): support writes as well
+ counter + 1
+ return super().__torch_function__(func, types, args, kwargs)
+
+ torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy)
+
+ x = torch.randn(1).as_subclass(TensorProxy)
+ x = torch.randn(1)
+ cnt = torch._dynamo.testing.CompileCounter()
+ out1 = foo(x)
+ opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
+ out2 = opt_foo(x)
+
+ self.assertEqual(cnt.op_count, 4)
+ self.assertTrue(torch._dynamo.testing.same(out1, out2))
+
+ torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy)
+
+ run()
+
+ @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
+ def test_nn_moduledict_contains(self):
+ class M(torch.nn.Module):
+ def __init__(self, module_dict):
+ super().__init__()
+ self.module_dict = module_dict
+
+ def forward(self, x):
+ if "foo" in self.module_dict:
+ x = torch.mul(x, 1.0)
+ x = torch.add(x, 1.0)
+ return x
+
+ module_dict = torch.nn.ModuleDict({"foo": torch.nn.Conv2d(1, 1, 1)})
+ m = M(module_dict)
+ data = torch.randn(1)
+ out1 = m(data)
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
+ out2 = opt_m(data)
+ self.assertEqual(cnt.op_count, 2)
+ self.assertTrue(torch._dynamo.testing.same(out1, out2))
+
+ module_dict = torch.nn.ModuleDict({"bar": torch.nn.Conv2d(1, 1, 1)})
+ m = M(module_dict)
+ data = torch.randn(1)
+ out1 = m(data)
+ cnt = torch._dynamo.testing.CompileCounter()
+ torch._dynamo.reset()
+ opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
+ out2 = opt_m(data)
+
+ self.assertEqual(cnt.op_count, 1)
+ self.assertTrue(torch._dynamo.testing.same(out1, out2))
+
+ module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)})
+ pre = m(data)
+ cnt.clear()
+
+ with torch._dynamo.optimize(cnt, nopython=False):
+ opt_pre = m(data)
+ m = M(module_dict)
+ data = torch.randn(1)
+ out1 = m(data)
+
+ out_post = m(data)
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, 1)
+ self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
+ self.assertTrue(torch._dynamo.testing.same(out1, out_post))
+
+ def test_lazy_module(self):
+ input_shape = (16, 3, 6, 7, 8)
+
+ cnt = torch._dynamo.testing.CompileCounter()
+ module = LazyModule()
+
+ def test_static_module():
+ input = torch.ones(*input_shape)
+ module(input)
+
+ opt_test_static_module = torch._dynamo.optimize(cnt)(test_static_module)
+ opt_test_static_module()
+
+ self.assertTrue(
+ isinstance(module, MaterializedModule),
+ "Module should be transformed to an instance of MaterializedModule.",
+ )
+ self.assertEqual(module.param.shape, input_shape)
+
+ # test when mapped to UnspecializedNNModule
+ module = LazyModule()
+
+ def test_unspecialized():
+ nonlocal module
+ module = LazyModule()
+ input = torch.ones(*input_shape)
+ module(input)
+
+ opt_test_unspecialized = torch._dynamo.optimize(cnt)(test_unspecialized)
+ opt_test_unspecialized()
+
+ self.assertTrue(
+ isinstance(module, MaterializedModule),
+ "Module should be transformed to an instance of MaterializedModule.",
+ )
+ self.assertEqual(module.param.shape, input_shape)
+
+ # test with a static module in torch.*
+ module = torch.nn.modules.LazyBatchNorm3d(
+ affine=False, track_running_stats=False
+ )
+
+ cnt = torch._dynamo.testing.CompileCounter()
+
+ torch._dynamo.reset()
+
+ def test_torch_static():
+ input = torch.ones(*input_shape)
+ return module(input) # fully materialized
+
+ opt_test_torch_static = torch._dynamo.optimize(cnt)(test_torch_static)
+ opt_test_torch_static()
+ out = opt_test_torch_static()
+
+ self.assertTrue(same(out, module(torch.ones(*input_shape))))
+
+ self.assertTrue(
+ isinstance(module, torch.nn.modules.batchnorm.BatchNorm3d),
+ "Module should be transformed to an instance of BatchNorm3d.",
+ )
+ self.assertEqual(cnt.frame_count, 1, "No guards should have triggered.")
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_no_fake_tensors.py b/test/dynamo/test_no_fake_tensors.py
new file mode 100644
index 0000000000000..6b2faec3d1d54
--- /dev/null
+++ b/test/dynamo/test_no_fake_tensors.py
@@ -0,0 +1,29 @@
+# Owner(s): ["module: dynamo"]
+from torch._dynamo.testing import make_test_cls_with_patches
+
+try:
+ from . import test_functions, test_misc, test_modules, test_repros, test_unspec
+except ImportError:
+ import test_functions
+ import test_misc
+ import test_modules
+ import test_repros
+ import test_unspec
+
+
+def make_no_fake_cls(cls):
+ return make_test_cls_with_patches(
+ cls, "NoFakeTensors", "_no_fake_tensors", ("fake_tensor_propagation", False)
+ )
+
+
+NoFakeTensorsFunctionTests = make_no_fake_cls(test_functions.FunctionTests)
+NoFakeTensorsMiscTests = make_no_fake_cls(test_misc.MiscTests)
+NoFakeTensorsReproTests = make_no_fake_cls(test_repros.ReproTests)
+NoFakeTensorsNNModuleTests = make_no_fake_cls(test_modules.NNModuleTests)
+NoFakeTensorsUnspecTests = make_no_fake_cls(test_unspec.UnspecTests)
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_nops.py b/test/dynamo/test_nops.py
new file mode 100644
index 0000000000000..de52315e12efd
--- /dev/null
+++ b/test/dynamo/test_nops.py
@@ -0,0 +1,71 @@
+# Owner(s): ["module: dynamo"]
+import torch
+
+import torch._dynamo.testing
+from torch._dynamo import eval_frame
+
+c = 10
+
+
+def fn1(a, b):
+ return a + b - c
+
+
+def fn2(a, b):
+ x = 0
+ y = 1
+
+ def modify():
+ nonlocal x
+ x += a + b + c
+
+ for _ in range(2):
+ modify()
+
+ return x + y
+
+
+def fn3():
+ yield 1
+ yield 2
+
+
+with_debug_nops = eval_frame._optimize_catch_errors(
+ torch._dynamo.testing.debug_insert_nops
+)
+
+
+class NopTests(torch._dynamo.testing.TestCase):
+ @with_debug_nops
+ def test1(self):
+ self.assertEqual(fn1(1, 2), -7)
+ self.assertEqual(fn1(1, 2), -7)
+
+ @with_debug_nops
+ def test2(self):
+ self.assertEqual(fn2(1, 2), 27)
+ self.assertEqual(fn2(1, 2), 27)
+
+ @with_debug_nops
+ def test3(self):
+ t = fn3()
+ self.assertEqual(next(t), 1)
+ self.assertEqual(next(t), 2)
+ self.assertRaises(StopIteration, lambda: next(t))
+
+ def test_extended_args(self):
+ too_many_adds = "+".join(["a", "b"] * 256)
+ source = (
+ f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)"
+ )
+ fn = eval(source)
+ a = torch.ones(1)
+ b = torch.ones(1)
+ fn = with_debug_nops(fn)
+ self.assertEqual(fn(a, b).sum(), 513)
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_optimizations.py b/test/dynamo/test_optimizations.py
new file mode 100644
index 0000000000000..b58d7a44e5990
--- /dev/null
+++ b/test/dynamo/test_optimizations.py
@@ -0,0 +1,207 @@
+# Owner(s): ["module: dynamo"]
+import importlib
+import json
+import os
+import unittest
+from unittest.mock import patch
+
+import torch
+
+import torch._dynamo
+from torch._dynamo.optimizations import backends
+from torch._dynamo.optimizations.analysis import has_mutation
+from torch._dynamo.optimizations.log_args import conv_args_analysis
+from torch._dynamo.optimizations.normalize import Inplacifier, normalize
+from torch._dynamo.testing import same
+
+
+def has_onnxruntime():
+ try:
+ importlib.import_module("onnxruntime")
+ return True
+ except ImportError:
+ return False
+
+
+def has_ipex():
+ try:
+ importlib.import_module("intel_extension_for_pytorch")
+ return True
+ except ImportError:
+ return False
+
+
+def has_functorch():
+ try:
+ importlib.import_module("functorch")
+ return True
+ except ImportError:
+ return False
+
+
+class Seq(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layers = torch.nn.Sequential(
+ torch.nn.Linear(10, 10),
+ torch.nn.ReLU(),
+ torch.nn.Linear(10, 10),
+ torch.nn.Sigmoid(),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class Conv_Bn_Relu(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(Conv_Bn_Relu, self).__init__()
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+ self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, x):
+ return self.relu(self.bn(self.conv(x)))
+
+
+class TestOptimizations(torch._dynamo.testing.TestCase):
+ def test_inplacifier(self):
+ gm = torch.fx.symbolic_trace(Seq())
+ normalize(gm)
+ Inplacifier(gm).inplacify()
+ gm.recompile()
+ code = gm.code.replace(" ", "")
+ self.assertIn("inplace=True", code)
+ self.assertIn("out=linear_1", code)
+
+ def test_has_mutation(self):
+ gm = torch.fx.symbolic_trace(Seq())
+ self.assertFalse(has_mutation(gm, torch.rand([10, 10])))
+
+ class Mutating(torch.nn.Module):
+ def __init__(self):
+ super(Mutating, self).__init__()
+
+ def forward(self, arg):
+ return arg.add_(1)
+
+ gm = torch.fx.symbolic_trace(Mutating())
+ self.assertTrue(has_mutation(gm, torch.rand([10, 1, 1, 1])))
+
+ def test_has_mutation_factory(self):
+ def fn():
+ x = torch.empty(2)
+ x.fill_(2)
+ return x
+
+ def compiler_fn(graph, example_inputs):
+ self.assertTrue(has_mutation(graph, example_inputs))
+ return graph
+
+ opt_fn = torch._dynamo.optimize(compiler_fn)(fn)
+ opt_fn()
+
+ def test_example_inputs(self):
+ def fn(a, bc, d):
+ b, c = bc
+ return a / d - b / c
+
+ def compiler_fn(graph, example_inputs):
+ nonlocal r1
+ r1 = graph(*example_inputs)[0]
+ return graph.forward
+
+ a = torch.empty(2).fill_(1)
+ b = torch.empty(2).fill_(2)
+ c = torch.empty(2).fill_(3)
+ d = 4
+ r1 = None
+ r2 = fn(a, (b, c), d)
+ opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn)
+ r3 = opt_fn(a, (b, c), d)
+
+ self.assertIsNotNone(r1)
+ self.assertTrue(same(r1, r2))
+ self.assertTrue(same(r1, r3))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ @unittest.skipIf(not has_functorch(), "requires functorch")
+ def test_log_conv_args(self):
+ model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1)
+ model = model.to(memory_format=torch.channels_last)
+ model = model.eval()
+ input = torch.randn(8, 3, 64, 64).contiguous(memory_format=torch.channels_last)
+ r1 = model(input)
+ # check tmp/conv_args.json exists and has keys as arg names
+ filename = "tmp/conv_args.json"
+ if os.path.exists(filename):
+ os.remove(filename)
+ opt_model = torch._dynamo.optimize(conv_args_analysis)(model)
+ with torch.no_grad():
+ r2 = opt_model(input)
+ self.assertTrue(same(r1, r2.float(), tol=0.1))
+ self.assertTrue(os.path.exists(filename))
+ with open(filename) as f:
+ args_dict = json.load(f)
+ self.assertIn("convolution", args_dict.keys())
+ conv_args_dict = args_dict["convolution"]
+ self.assertIn("input", conv_args_dict.keys())
+ self.assertIn("weight", conv_args_dict.keys())
+ self.assertIn("bias", conv_args_dict.keys())
+ self.assertIn("stride", conv_args_dict.keys())
+ self.assertIn("padding", conv_args_dict.keys())
+ self.assertIn("dilation", conv_args_dict.keys())
+ self.assertIn("transposed", conv_args_dict.keys())
+ self.assertIn("output_padding", conv_args_dict.keys())
+ self.assertIn("groups", conv_args_dict.keys())
+ os.remove(filename)
+
+ @unittest.skipIf(not has_ipex(), "requires ipex")
+ def test_ipex_fp32(self):
+ model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1)
+ model = model.to(memory_format=torch.channels_last)
+ model = model.eval()
+ input = torch.randn(8, 3, 64, 64).contiguous(memory_format=torch.channels_last)
+ r1 = model(input)
+ opt_model = torch._dynamo.optimize(backends.ipex_fp32)(model)
+ with torch.no_grad():
+ r2 = opt_model(input)
+ self.assertTrue(same(r1, r2))
+ self.assertEqual(r2.dtype, torch.float32)
+
+ @unittest.skipIf(not has_ipex(), "requires ipex")
+ def test_ipex_bf16(self):
+ model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1)
+ model = model.to(memory_format=torch.channels_last)
+ model = model.eval()
+ input = torch.randn(8, 3, 64, 64).contiguous(memory_format=torch.channels_last)
+ r1 = model(input)
+ opt_model = torch._dynamo.optimize(backends.ipex_bf16)(model)
+ with torch.no_grad(), torch.cpu.amp.autocast():
+ r2 = opt_model(input)
+ self.assertTrue(same(r1, r2.float(), tol=0.1))
+ self.assertEqual(r2.dtype, torch.bfloat16)
+
+
+class NormalizeIRTests(torch._dynamo.testing.TestCase):
+ @unittest.skipIf(not has_functorch(), "requires functorch")
+ def test_inplace_normalize(self):
+ def fn(a, b):
+ x = torch.cos(a)
+ x += b
+ return torch.sin(x)
+
+ a = torch.randn(10)
+ b = torch.randn(10).to(torch.float64)
+
+ ref = fn(a, b)
+
+ optimized_fn = torch._dynamo.optimize("aot_eager")(fn)
+ res = optimized_fn(a, b)
+ self.assertTrue(same(ref, res))
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py
new file mode 100644
index 0000000000000..122c5c06b069f
--- /dev/null
+++ b/test/dynamo/test_optimizers.py
@@ -0,0 +1,102 @@
+# Owner(s): ["module: dynamo"]
+
+import inspect
+import unittest
+
+import torch
+
+import torch._dynamo
+import torch._dynamo.testing
+
+input = torch.ones([10, 10])
+model = torch.nn.Sequential(*[torch.nn.Linear(10, 10) for _ in range(2)])
+model(input).sum().backward()
+
+
+def make_test(optim_cls, exp_frame_cnt=1, closure=None, **kwargs):
+ opt = optim_cls(model.parameters(), **kwargs)
+
+ def test_fn(self):
+ nonlocal opt
+
+ counter = torch._dynamo.testing.CompileCounter()
+
+ if closure is not None:
+
+ def fn():
+ opt.step(closure)
+
+ else:
+ fn = opt.step
+
+ opt_fn = torch._dynamo.optimize(counter)(fn)
+ opt_fn()
+
+ self.assertEqual(counter.frame_count, exp_frame_cnt)
+
+ return test_fn
+
+
+class OptimizerTests(torch._dynamo.testing.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ # needed until pytorch assertion is changed to enable Adam
+ # to be called with capturable=True
+ cls._exit_stack.enter_context(
+ unittest.mock.patch.object(
+ torch._dynamo.config, "capture_scalar_outputs", True
+ )
+ )
+ cls._exit_stack.enter_context(
+ unittest.mock.patch.object(
+ torch._dynamo.config, "fake_tensor_propagation", False
+ )
+ )
+ cls._exit_stack.enter_context(
+ unittest.mock.patch.object(
+ torch._dynamo.config, "raise_on_assertion_error", True
+ )
+ )
+
+ test_sgd = make_test(torch.optim.SGD, lr=0.01)
+ # lgbfs has data-dependent control and internally iterates
+ # calling the closure
+ # TODO mlazos: re-enable once we have latest pytorch with FakeTensor fix #497
+ # test_lbfgs = make_test(
+ # torch.optim.LBFGS, exp_frame_cnt=3, closure=lambda: model(input).sum()
+ # )
+ # RAdam has data-dependent control which breaks the graph
+ test_radam = make_test(torch.optim.RAdam, exp_frame_cnt=1)
+
+ # ASGD has a small optimization that avoids averaging
+ # This will fully capture the graph once that optimization is removed
+ # NB: in python versions < 3.8, we don't capture graphs when breaks
+ # occur in a loop
+
+ # Fails without fake tensor:
+ # TypeError: clamp() received an invalid combination of arguments - got (float, min=int)
+ # test_asgd = make_test(
+ # torch.optim.ASGD, exp_frame_cnt=(0 if sys.version_info < (3, 8) else 6)
+ # )
+
+
+# exclude SparseAdam because other areas of the stack don't support it yet
+# the others are handled specially above
+exclude = set(["SGD", "Optimizer", "SparseAdam", "LBFGS", "RAdam", "ASGD"])
+optimizers = [
+ opt
+ for opt in torch.optim.__dict__.values()
+ if inspect.isclass(opt)
+ and issubclass(opt, torch.optim.Optimizer)
+ and opt.__name__ not in exclude
+]
+
+
+for opt in optimizers:
+ setattr(OptimizerTests, "test_" + opt.__name__.lower(), make_test(opt))
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_python_autograd.py b/test/dynamo/test_python_autograd.py
new file mode 100644
index 0000000000000..fe2f2819f20d5
--- /dev/null
+++ b/test/dynamo/test_python_autograd.py
@@ -0,0 +1,288 @@
+# Owner(s): ["module: dynamo"]
+from typing import Callable, Dict, List, NamedTuple, Optional
+
+import torch
+
+import torch._dynamo
+from torch._dynamo.testing import CompileCounter, same, TestCase
+
+"""
+This is an example of a pure-python version of autograd implemented by
+@zdevito. It represents a rather challenging test case for TorchDynamo
+to push the limits of what it can do.
+"""
+
+
+_name: int = 0
+
+
+def fresh_name() -> str:
+ """create a new unique name for a variable: v0, v1, v2"""
+ global _name
+ r = f"v{_name}"
+ _name += 1
+ return r
+
+
+class Variable:
+ def __init__(self, value: torch.Tensor, name: str = None):
+ self.value = value
+ self.name = name or fresh_name()
+
+ # We need to start with some tensors whose values were not computed
+ # inside the autograd. This function constructs leaf nodes.
+ @staticmethod
+ def constant(value: torch.Tensor, name: str = None):
+ return Variable(value, name)
+
+ def __repr__(self):
+ return repr(self.value)
+
+ # This performs a pointwise multiplication of a Variable, tracking gradients
+ def __mul__(self, rhs: "Variable") -> "Variable":
+ # defined later in the notebook
+ return operator_mul(self, rhs)
+
+ def __add__(self, rhs: "Variable") -> "Variable":
+ return operator_add(self, rhs)
+
+ def sum(self, name: Optional[str] = None) -> "Variable":
+ return operator_sum(self, name)
+
+ def expand(self, sizes: List[int]) -> "Variable":
+ return operator_expand(self, sizes)
+
+
+class TapeEntry(NamedTuple):
+ # names of the inputs to the original computation
+ inputs: List[str]
+ # names of the outputs of the original computation
+ outputs: List[str]
+ # apply chain rule
+ propagate: "Callable[List[Variable], List[Variable]]"
+
+
+gradient_tape: List[TapeEntry] = []
+
+
+def reset_tape():
+ gradient_tape.clear()
+ global _name
+ _name = 0
+
+
+def grad(L, desired_results: List[Variable]) -> List[Variable]:
+ # this map holds dL/dX for all values X
+ dL_d: Dict[str, Variable] = {}
+ # It starts by initializing the 'seed' dL/dL, which is 1
+ dL_d[L.name] = Variable(torch.ones(()))
+ # print(f'd{L.name} ------------------------')
+
+ # look up dL_dentries. If a variable is never used to compute the loss,
+ # we consider its gradient None, see the note below about zeros for more information.
+ def gather_grad(entries: List[str]):
+ return [dL_d[entry] if entry in dL_d else None for entry in entries]
+
+ # propagate the gradient information backward
+ for entry in reversed(gradient_tape):
+ dL_doutputs = gather_grad(entry.outputs)
+ if all(dL_doutput is None for dL_doutput in dL_doutputs):
+ # optimize for the case where some gradient pathways are zero. See
+ # The note below for more details.
+ continue
+
+ # perform chain rule propagation specific to each compute
+ dL_dinputs = entry.propagate(dL_doutputs)
+
+ # Accululate the gradient produced for each input.
+ # Each use of a variable produces some gradient dL_dinput for that
+ # use. The multivariate chain rule tells us it is safe to sum
+ # all the contributions together.
+ for input, dL_dinput in zip(entry.inputs, dL_dinputs):
+ if input not in dL_d:
+ dL_d[input] = dL_dinput
+ else:
+ dL_d[input].value += dL_dinput.value
+
+ # print some information to understand the values of each intermediate
+ # for name, value in dL_d.items():
+ # print(f'd{L.name}_d{name} = {value.name}')
+ # print(f'------------------------')
+
+ return gather_grad(desired.name for desired in desired_results)
+
+
+def operator_mul(self: Variable, rhs: Variable) -> Variable:
+ if isinstance(rhs, float) and rhs == 1.0:
+ # peephole optimization
+ return self
+
+ # define forward
+ r = Variable(self.value * rhs.value)
+ # print(f'{r.name} = {self.name} * {rhs.name}')
+
+ # record what the inputs and outputs of the op were
+ inputs = [self.name, rhs.name]
+ outputs = [r.name]
+
+ # define backprop
+ def propagate(dL_doutputs: List[Variable]):
+ (dL_dr,) = dL_doutputs
+
+ dr_dself = rhs # partial derivative of r = self*rhs
+ dr_drhs = self # partial derivative of r = self*rhs
+
+ # chain rule propagation from outputs to inputs of multiply
+ dL_dself = dL_dr * dr_dself
+ dL_drhs = dL_dr * dr_drhs
+ dL_dinputs = [dL_dself, dL_drhs]
+ return dL_dinputs
+
+ # finally, we record the compute we did on the tape
+ gradient_tape.append(TapeEntry(inputs=inputs, outputs=outputs, propagate=propagate))
+ return r
+
+
+def operator_add(self: Variable, rhs: Variable) -> Variable:
+ # Add follows a similar pattern to Mul, but it doesn't end up
+ # capturing any variables.
+ r = Variable(self.value + rhs.value)
+ # print(f'{r.name} = {self.name} + {rhs.name}')
+
+ def propagate(dL_doutputs: List[Variable]):
+ (dL_dr,) = dL_doutputs
+ dr_dself = 1.0
+ dr_drhs = 1.0
+ dL_dself = dL_dr * dr_dself
+ dL_drhs = dL_dr * dr_drhs
+ return [dL_dself, dL_drhs]
+
+ gradient_tape.append(
+ TapeEntry(inputs=[self.name, rhs.name], outputs=[r.name], propagate=propagate)
+ )
+ return r
+
+
+def operator_sum(self: Variable, name: Optional[str]) -> "Variable":
+ r = Variable(torch.sum(self.value), name=name)
+ # print(f'{r.name} = {self.name}.sum()')
+
+ def propagate(dL_doutputs: List[Variable]):
+ (dL_dr,) = dL_doutputs
+ size = self.value.size()
+ return [dL_dr.expand(*size)]
+
+ gradient_tape.append(
+ TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate)
+ )
+ return r
+
+
+def operator_expand(self: Variable, sizes: List[int]) -> "Variable":
+ assert self.value.dim() == 0 # only works for scalars
+ r = Variable(self.value.expand(sizes))
+ # print(f'{r.name} = {self.name}.expand({sizes})')
+
+ def propagate(dL_doutputs: List[Variable]):
+ (dL_dr,) = dL_doutputs
+ return [dL_dr.sum()]
+
+ gradient_tape.append(
+ TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate)
+ )
+ return r
+
+
+def simple(a, b):
+ t = a + b
+ return t * b
+
+
+class TestPythonAutograd(TestCase):
+ def _common(self, fn, expected_ops):
+ args1 = [torch.randn(10), torch.randn(10)]
+ args2 = [torch.randn(10), torch.randn(10)]
+ cnt = CompileCounter()
+ fn_dynamo = torch._dynamo.optimize_assert(cnt)(fn)
+ reset_tape()
+ res1 = fn_dynamo(*args1)
+ reset_tape()
+ res2 = fn_dynamo(*args2)
+ reset_tape()
+ self.assertTrue(same(res1, fn(*args1)))
+ reset_tape()
+ self.assertTrue(same(res2, fn(*args2)))
+ reset_tape()
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, expected_ops)
+
+ def test_forwards1(self):
+ def fn(a, b):
+ a = Variable.constant(a, name="a")
+ b = Variable.constant(b, name="b")
+ loss = simple(a, b).sum()
+ return loss
+
+ self._common(fn, 3)
+
+ def test_forwards2(self):
+ def fn(a, b):
+ reset_tape()
+ a = Variable.constant(a, name="a")
+ b = Variable.constant(b, name="b")
+ loss = simple(a, b).sum()
+ reset_tape()
+ return loss
+
+ self._common(fn, 3)
+
+ def test_backwards1(self):
+ def fn(a, b):
+ a = Variable.constant(a, name="a")
+ b = Variable.constant(b, name="b")
+ loss = simple(a, b).sum()
+ return grad(loss, [a, b])
+
+ self._common(fn, 8)
+
+ def test_backwards2(self):
+ def fn(a, b):
+ reset_tape()
+ a = Variable.constant(a, name="a")
+ b = Variable.constant(b, name="b")
+ loss = simple(a, b).sum()
+ res = grad(loss, [a, b])
+ reset_tape()
+ return res
+
+ self._common(fn, 8)
+
+ def test_split(self):
+ v1 = Variable.constant(torch.randn(10), name="a")
+ v2 = Variable.constant(torch.randn(10), name="b")
+ cnt = CompileCounter()
+
+ def forward(a, b):
+ return simple(a, b).sum()
+
+ reset_tape()
+ loss1 = forward(v1, v2)
+ grad1 = grad(loss1, [v1, v2])
+
+ reset_tape()
+ opt_forward = torch._dynamo.optimize_assert(cnt)(forward)
+ opt_grad = torch._dynamo.optimize_assert(cnt)(grad)
+ loss2 = opt_forward(v1, v2)
+ # force two frames
+ grad2 = opt_grad(loss2, [v1, v2])
+
+ self.assertTrue(same(loss1, loss2))
+ self.assertTrue(same(grad1, grad2))
+ self.assertEqual(cnt.frame_count, 2)
+ self.assertEqual(cnt.op_count, 8)
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_recompile_ux.py b/test/dynamo/test_recompile_ux.py
new file mode 100644
index 0000000000000..00e99ab3f2024
--- /dev/null
+++ b/test/dynamo/test_recompile_ux.py
@@ -0,0 +1,204 @@
+# Owner(s): ["module: dynamo"]
+import unittest
+import weakref
+
+import torch
+
+import torch._dynamo
+import torch._dynamo.config
+import torch._dynamo.testing
+
+
+class RecompileUxTests(torch._dynamo.testing.TestCase):
+ # TODO(whc) dynamo actualy recompiles one more time than the cache limit
+ cache_limit = 1
+
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ cls._exit_stack.enter_context(
+ unittest.mock.patch.object(
+ torch._dynamo.config, "cache_size_limit", cls.cache_limit
+ )
+ )
+
+ def test_drop_cache_on_skip(self):
+ def model(x, i):
+ return x + i
+
+ attached = False
+ triggered = False
+
+ def trigger():
+ nonlocal triggered
+ triggered = True
+
+ def compiler(gm, input):
+ nonlocal attached
+ f = gm.forward
+ assert not attached
+ # NB: making this a weakref.ref causes the cycle to no
+ # longer be promptly GC'ed
+ weakref.finalize(f, trigger)
+ attached = True
+ return f
+
+ x = torch.randn(2)
+ for i in range(2):
+ opt_model = torch._dynamo.optimize(compiler)(model)
+ opt_model(x, i)
+
+ self.assertTrue(triggered)
+
+ def test_loop_torture(self):
+ def loop_torture(input, iters):
+ out = input
+ # randint itself causes one graph break
+ for _ in range(iters):
+ out += input
+ return out
+
+ compile_counter = torch._dynamo.testing.CompileCounter()
+ for _ in range(10):
+ x = torch.randn(3)
+ iters = torch.randint(low=0, high=1000, size=())
+ opt_loop_torture = torch._dynamo.optimize(compile_counter)(loop_torture)
+ opt_loop_torture(x, iters)
+
+ # Currently, we recompile each time,
+ # We'd probably like to bail out quickly and warn
+ # TODO(whc) these checks fail on py37. Why?
+ # self.assertEqual(counters["frames"]["total"], 2 + self.cache_limit)
+ # self.assertEqual(counters["frames"]["ok"], 1 + self.cache_limit)
+
+ # compile_counter only sees frames that were fed to the backend compiler,
+ # which is a subset of counters["frames"]["ok"] -- probably becuase
+ # counters["frames"]["ok"] includes frames not containing torch ops?
+ self.assertEqual(compile_counter.frame_count, self.cache_limit)
+
+ def test_dynamic_input(self):
+ def model(input):
+ return input + input
+
+ expected_recompiles = 2
+ compile_counter = torch._dynamo.testing.CompileCounter()
+ with unittest.mock.patch.object(
+ torch._dynamo.config, "cache_size_limit", expected_recompiles
+ ):
+ with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
+ for _ in range(10):
+ bsz = torch.randint(low=0, high=1000, size=())
+ x = torch.randn((bsz, 3, 4))
+ opt_model = torch._dynamo.optimize(compile_counter)(model)
+ opt_model(x)
+
+ self.assertEqual(compile_counter.frame_count, expected_recompiles)
+ self.assertEqual(len(logs.records), 1)
+ print(logs.records[0])
+ self.assertTrue(
+ logs.records[0]
+ .getMessage()
+ .startswith("torch._dynamo hit config.cache_size_limit")
+ )
+
+ @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+ def test_nvfuser_guards(self):
+ # we may want to model dynamo's guards sufficiently after nvfuser's ProfilingExecutor guards
+ # such that we ensure dynamo is in charge of all the recompilations at the top level,
+ # and we could thus simplfy the underlying torchscript executor
+ def func(a, b, c):
+ return a + b * c
+
+ a = torch.rand(3, 4, 5, device="cuda")
+ b = torch.rand(3, 4, 5, device="cuda")
+ b_v = torch.rand(3, 5, 4, device="cuda").view(3, 4, 5)
+ b_p = torch.rand(3, 5, 4, device="cuda").permute(0, 2, 1)
+ c = torch.rand(3, 4, 5, device="cuda")
+ compile_counter = torch._dynamo.testing.CompileCounter()
+
+ with unittest.mock.patch.object(torch._dynamo.config, "cache_size_limit", 2):
+ opt_func = torch._dynamo.optimize(compile_counter)(func)
+ opt_func(a, b, c) # warmup
+ self.assertEqual(compile_counter.frame_count, 1)
+
+ opt_func(a, b, c) # no guard fail or recompile
+ self.assertEqual(compile_counter.frame_count, 1)
+
+ opt_func(a, b_v, c) # a view should not cause nvfuser recompile
+ self.assertEqual(compile_counter.frame_count, 1)
+
+ opt_func(a, b_p, c) # a permutation should cause recompile
+ self.assertEqual(compile_counter.frame_count, 2)
+
+ def assert_single_log_contains(self, logs, contains_str):
+ self.assertEqual(len(logs.records), 1)
+ self.assertTrue(
+ logs.records[0].getMessage().find(contains_str) > 0,
+ msg=f'Expected to find "{contains_str}" in log "{logs.records[0].getMessage()}"',
+ )
+
+ def test_verbose_tensor_check(self):
+ def func(a):
+ # Warning: choose a function here whose meta implementation lives
+ # entirely in C++. If you do a Python one, Dynamo will dive into
+ # torch._refs which is OK but it will muddy up the warnings
+ return torch.add(a, 4)
+
+ def cache_fail_test(cached_input, missed_input, expected_failure):
+ # TODO(whc) maybe its hacky to have a 'test within a test' but this seemed convenient
+ torch._dynamo.reset()
+ torch._dynamo.utils.counters.clear()
+ opt_func = torch._dynamo.optimize("eager")(func)
+ # warmup
+ opt_func(cached_input)
+
+ with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
+ opt_func = torch._dynamo.optimize("eager")(func)
+ opt_func(missed_input)
+ self.assert_single_log_contains(logs, expected_failure)
+
+ a = torch.rand(3, 4, 5)
+ cache_fail_test(
+ a, a[0:2, :, :], "tensor 'a' size mismatch at index 0. expected 3, actual 2"
+ )
+ cache_fail_test(
+ a,
+ a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)),
+ "tensor 'a' strides mismatch at index 0. expected 20, actual 1",
+ )
+ cache_fail_test(a, a[0, :, :], "tensor 'a' rank mismatch. expected 3, actual 2")
+ cache_fail_test(a, a.to("meta"), "tensor 'a' dispatch key set mismatch.")
+ cache_fail_test(
+ a,
+ a.to(torch.float16),
+ "tensor 'a' dtype mismatch. expected Float, actual Half",
+ )
+ a_grad = a.clone()
+ a_grad.requires_grad = True
+ cache_fail_test(
+ a, a_grad, "tensor 'a' requires_grad mismatch. expected requires_grad=0"
+ )
+
+ def test_mismatched_type(self):
+ a = torch.rand(3, 4, 5)
+ b = torch.rand(3, 4, 5)
+
+ def func(a, b):
+ return a + b
+
+ opt_func = torch._dynamo.optimize("eager")(func)
+ # warmup
+ opt_func(a, b)
+
+ with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
+ opt_func = torch._dynamo.optimize("eager")(func)
+ opt_func(a, 1)
+ self.assert_single_log_contains(
+ logs, "expected type of 'b' to be a tensor type, ' but found "
+ )
+
+
+# TODO(jansel): these pass with pytest, but not with pytorch CI
+# if __name__ == "__main__":
+# from torch._dynamo.testing import run_tests
+# run_tests()
diff --git a/test/dynamo/test_replay_record.py b/test/dynamo/test_replay_record.py
new file mode 100644
index 0000000000000..f2586b7db37ef
--- /dev/null
+++ b/test/dynamo/test_replay_record.py
@@ -0,0 +1,186 @@
+# Owner(s): ["module: dynamo"]
+import logging
+import re
+import shutil
+import unittest
+
+import torch
+
+import torch._dynamo.testing
+
+try:
+ import dill
+except ImportError:
+ dill = None
+
+requires_dill = unittest.skipIf(dill is None, "requires dill")
+
+
+class ReplayRecordTests(torch._dynamo.testing.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ cls._exit_stack.enter_context(
+ unittest.mock.patch.object(
+ torch._dynamo.config, "replay_record_enabled", True
+ )
+ )
+ cls._exit_stack.enter_context(
+ unittest.mock.patch.object(
+ torch._dynamo.config,
+ "replay_record_dir_name",
+ "/tmp/torch._dynamo_error_records/",
+ )
+ )
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(torch._dynamo.config.replay_record_dir_name, ignore_errors=True)
+ cls._exit_stack.close()
+
+ def check_replay(self, fn, *args, exp_exc_name=None):
+ fn_opt = torch._dynamo.optimize("eager")(fn)
+ with self.assertLogs(logger="torch._dynamo", level=logging.ERROR) as log_orig:
+ try:
+ fn_opt(*args)
+ except Exception:
+ pass # we'll check the logs for the raised exception
+
+ with self.assertLogs(
+ logger="torch._dynamo", level=logging.ERROR
+ ) as log_replayed:
+ file_name_match = re.search(
+ r"torch._dynamo\.replay\('(.*)'\)", log_orig.output[-1]
+ )
+ self.assertTrue(
+ file_name_match is not None,
+ "No record file name found in generated logs.",
+ )
+
+ torch._dynamo.replay(file_name_match.groups()[0])
+
+ def get_error_name(log):
+ error_name = re.search(r"\w+Error", log.output[-1])
+ self.assertIsNotNone(error_name, "No error name found in logs.")
+ return error_name[0]
+
+ orig_error = get_error_name(log_orig)
+ replayed_error = get_error_name(log_replayed)
+ if exp_exc_name is not None:
+ self.assertEqual(orig_error, exp_exc_name)
+
+ self.assertEqual(
+ orig_error,
+ replayed_error,
+ "Error logs for recorded execution and replayed execution should match.",
+ )
+
+ @requires_dill
+ def test_unsuccessful_inline(self):
+ def level2():
+ z = torch.ones(2, 2)
+ a = {z: 10} # Error here, tensor as key to dict
+ return a[z] * torch.ones(1)
+
+ def level1():
+ y = torch.ones(1, 1)
+ return level2() + y
+
+ def level0():
+ x = torch.ones(1, 1)
+ return level1() + x
+
+ self.check_replay(level0, exp_exc_name="AssertionError")
+
+ @requires_dill
+ def test_successful_inline(self):
+ def test_fn():
+ x = torch.ones(2, 2)
+
+ def level1(a):
+ return a + torch.ones(2, 2)
+
+ y = level1(x)
+
+ return y + torch.ones(3, 3) # dimension mismatch
+
+ self.check_replay(test_fn, exp_exc_name="RuntimeError")
+
+ @requires_dill
+ def test_nonlocal_fn_call(self):
+ def nonlocal_fn(x):
+ return x + torch.ones(2, 2)
+
+ def test_fn():
+ z = torch.ones(2, 2)
+ x = nonlocal_fn(z)
+ return x + torch.ones(3, 3)
+
+ self.check_replay(test_fn, exp_exc_name="RuntimeError")
+
+ @requires_dill
+ def test_nonlocal_module_fn_call(self):
+ # replay when we use a module
+ # not defined in the replay env
+ try:
+ from . import mock_modules
+ except ImportError:
+ import mock_modules
+
+ def test_fn():
+ z = mock_modules.mock_module2.method1([], 2)
+ z = torch.ones(2, 2) + z[0]
+ return z + torch.zeros(3, 3)
+
+ self.check_replay(test_fn, exp_exc_name="RuntimeError")
+
+ @requires_dill
+ def test_nonlocal_module_class(self):
+ try:
+ from .mock_modules import mock_module2
+ except ImportError:
+ from mock_modules import mock_module2
+
+ def test_fn():
+ z = mock_module2.Class1(1, 2)
+ y = z.method2(torch.ones(3, 3))
+ return y + torch.zeros(3, 5)
+
+ self.check_replay(test_fn, exp_exc_name="TypeError")
+
+ @requires_dill
+ def test_local_module(self):
+ try:
+ from .mock_modules import mock_module3 as _ # noqa: F401
+
+ def test_fn(x):
+ from .mock_modules import mock_module3
+
+ z = mock_module3.method1([], torch.ones(5, 1))
+ return torch.ones(2, 2) + x + z[0]
+
+ except ImportError:
+
+ def test_fn(x):
+ from mock_modules import mock_module3
+
+ z = mock_module3.method1([], torch.ones(5, 1))
+ return torch.ones(2, 2) + x + z[0]
+
+ self.check_replay(test_fn, torch.ones(1, 1), exp_exc_name="RuntimeError")
+
+ # Verfiy that we replay when we have tensor arguments to the frame being replayed
+ @requires_dill
+ def test_fn_call_args(self):
+ def test_fn(x, y):
+ return x + y + torch.zeros(2, 2)
+
+ self.check_replay(
+ test_fn, torch.ones(3, 3), torch.ones(2, 2), exp_exc_name="RuntimeError"
+ )
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
new file mode 100644
index 0000000000000..db44b20cfd315
--- /dev/null
+++ b/test/dynamo/test_repros.py
@@ -0,0 +1,1717 @@
+# Owner(s): ["module: dynamo"]
+import collections
+import copy
+import inspect
+import itertools
+import random
+import unittest
+from abc import ABC
+from collections import namedtuple
+from copy import deepcopy
+from typing import List
+from unittest.mock import patch
+
+import numpy as np
+import torch
+
+import torch._dynamo.testing
+import torch._dynamo.utils
+from torch import nn
+from torch._dynamo.debug_utils import same_two_models
+from torch._dynamo.testing import rand_strided, requires_static_shapes, same
+from torch.nn import functional as F
+
+try:
+ import torch._refs
+
+ HAS_REFS = True
+except ImportError:
+ HAS_REFS = False
+
+
+def ifdyn(count1, count2):
+ if torch._dynamo.config.dynamic_shapes:
+ return count1
+ else:
+ return count2
+
+
+def has_detectron2():
+ try:
+ from detectron2.layers.mask_ops import _paste_masks_tensor_shape
+
+ return _paste_masks_tensor_shape is not None
+ except ImportError:
+ return False
+
+
+def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True):
+ # from detectron2 mask_ops.py
+
+ device = masks.device
+
+ if skip_empty and not torch.jit.is_scripting():
+ x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
+ dtype=torch.int32
+ )
+ x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(
+ dtype=torch.int32
+ )
+ y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(
+ dtype=torch.int32
+ )
+ else:
+ x0_int, y0_int = 0, 0
+ x1_int, y1_int = img_w, img_h
+ x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1
+
+ N = masks.shape[0]
+
+ img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
+ img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
+ img_y = (img_y - y0) / (y1 - y0) * 2 - 1
+ img_x = (img_x - x0) / (x1 - x0) * 2 - 1
+ # img_x, img_y have shapes (N, w), (N, h)
+
+ gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
+ gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
+ grid = torch.stack([gx, gy], dim=3)
+
+ if not torch.jit.is_scripting():
+ if not masks.dtype.is_floating_point:
+ masks = masks.float()
+ img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)
+
+ if skip_empty and not torch.jit.is_scripting():
+ return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
+ else:
+ return img_masks[:, 0], ()
+
+
+def cat(tensors, dim=0):
+ # from detectron2 wrappers.py
+ assert isinstance(tensors, (list, tuple))
+ if len(tensors) == 1:
+ return tensors[0]
+ return torch.cat(tensors, dim)
+
+
+def shapes_to_tensor(x, device=None):
+ # from detectron2 wrappers.py
+ if torch.jit.is_scripting():
+ return torch.as_tensor(x, device=device)
+ if torch.jit.is_tracing():
+ assert all(
+ [isinstance(t, torch.Tensor) for t in x]
+ ), "Shape should be tensor during tracing!"
+ # as_tensor should not be used in tracing because it records a constant
+ ret = torch.stack(x)
+ if ret.device != device: # avoid recording a hard-coded device if not necessary
+ ret = ret.to(device=device)
+ return ret
+ return torch.as_tensor(x, device=device)
+
+
+class Boxes:
+ # from detectron2 poolers.py
+ def __init__(self, tensor: torch.Tensor):
+ """
+ Args:
+ tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2).
+ """
+ device = (
+ tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
+ )
+ tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
+ if tensor.numel() == 0:
+ # Use reshape, so we don't end up creating a new tensor that does not depend on
+ # the inputs (and consequently confuses jit)
+ tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32, device=device)
+ assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
+ self.tensor = tensor
+
+ def __len__(self) -> int:
+ return self.tensor.shape[0]
+
+ @property
+ def device(self):
+ return self.tensor.device
+
+
+def convert_boxes_to_pooler_format(box_lists):
+ # from detectron2 structures.py
+ boxes = torch.cat([x.tensor for x in box_lists], dim=0)
+ # __len__ returns Tensor in tracing.
+ sizes = shapes_to_tensor([x.__len__() for x in box_lists], device=boxes.device)
+ indices = torch.repeat_interleave(
+ torch.arange(len(box_lists), dtype=boxes.dtype, device=boxes.device), sizes
+ )
+ return cat([indices[:, None], boxes], dim=1)
+
+
+ReformerBackwardOutput = namedtuple(
+ "ReformerBackwardOutput",
+ ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"],
+)
+ReformerEncoderOutput = namedtuple(
+ "ReformerEncoderOutput",
+ ["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"],
+)
+
+
+class _ReversibleFunction(torch.autograd.Function):
+ # taken from modeling_reformer.py in huggingface
+ @staticmethod
+ def forward(
+ ctx,
+ hidden_states,
+ layers,
+ attention_mask,
+ head_mask,
+ num_hashes,
+ all_hidden_states,
+ all_attentions,
+ past_buckets_states,
+ use_cache,
+ orig_sequence_length,
+ output_hidden_states,
+ output_attentions,
+ ):
+ all_buckets = ()
+
+ # split duplicated tensor
+ hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1)
+
+ for layer_id, (layer, layer_head_mask) in enumerate(zip(layers, head_mask)):
+ if output_hidden_states is True:
+ all_hidden_states.append(hidden_states)
+
+ attn_output = layer(attn_output)
+
+ # Add last layer
+ if output_hidden_states is True:
+ all_hidden_states.append(hidden_states)
+
+ # attach params to ctx for backward
+ ctx.save_for_backward(attn_output.detach(), hidden_states.detach())
+ ctx.layers = layers
+ ctx.all_buckets = all_buckets
+ ctx.head_mask = head_mask
+ ctx.attention_mask = attention_mask
+
+ # Concatenate 2 RevNet outputs
+ return torch.cat([attn_output, hidden_states], dim=-1)
+
+ @staticmethod
+ def backward(ctx, grad_hidden_states):
+ grad_attn_output, grad_hidden_states = torch.chunk(
+ grad_hidden_states, 2, dim=-1
+ )
+
+ # retrieve params from ctx for backward
+ attn_output, hidden_states = ctx.saved_tensors
+
+ # create tuple
+ output = ReformerBackwardOutput(
+ attn_output=attn_output,
+ hidden_states=hidden_states,
+ grad_attn_output=grad_attn_output,
+ grad_hidden_states=grad_hidden_states,
+ )
+
+ # free memory
+ del grad_attn_output, grad_hidden_states, attn_output, hidden_states
+
+ layers = ctx.layers
+ all_buckets = ctx.all_buckets
+ head_mask = ctx.head_mask
+ attention_mask = ctx.attention_mask
+
+ for idx, layer in enumerate(layers[::-1]):
+ # pop last buckets from stack
+ buckets = all_buckets[-1]
+ all_buckets = all_buckets[:-1]
+
+ # backprop
+ output = layer.backward_pass(
+ next_attn_output=output.attn_output,
+ hidden_states=output.hidden_states,
+ grad_attn_output=output.grad_attn_output,
+ grad_hidden_states=output.grad_hidden_states,
+ head_mask=head_mask[len(layers) - idx - 1],
+ attention_mask=attention_mask,
+ buckets=buckets,
+ )
+
+ assert all_buckets == (), "buckets have to be empty after backpropagation"
+ grad_hidden_states = torch.cat(
+ [output.grad_attn_output, output.grad_hidden_states], dim=-1
+ )
+
+ # num of return vars has to match num of forward() args
+ # return gradient for hidden_states arg and None for other args
+ return (
+ grad_hidden_states,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class ReformerEncoder(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.dropout = 0.5
+ self.layer_norm = torch.nn.LayerNorm(512, eps=1.0e-12)
+ self.layers = [torch.nn.Linear(256, 256)]
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=[None] * 6,
+ num_hashes=None,
+ use_cache=False,
+ orig_sequence_length=64,
+ output_hidden_states=False,
+ output_attentions=False,
+ ):
+ # hidden_states and attention lists to be filled if wished
+ all_hidden_states = []
+ all_attentions = []
+ past_buckets_states = [((None), (None)) for i in range(len(self.layers))]
+
+ # concat same tensor for reversible ResNet
+ hidden_states = torch.cat([hidden_states, hidden_states], dim=-1)
+ hidden_states = _ReversibleFunction.apply(
+ hidden_states,
+ self.layers,
+ attention_mask,
+ head_mask,
+ num_hashes,
+ all_hidden_states,
+ all_attentions,
+ past_buckets_states,
+ use_cache,
+ orig_sequence_length,
+ output_hidden_states,
+ output_attentions,
+ )
+
+ # Apply layer norm to concatenated hidden states
+ hidden_states = self.layer_norm(hidden_states)
+
+ # Apply dropout
+ hidden_states = torch.nn.functional.dropout(
+ hidden_states, p=self.dropout, training=self.training
+ )
+
+ return ReformerEncoderOutput(
+ hidden_states=hidden_states,
+ all_hidden_states=all_hidden_states,
+ all_attentions=all_attentions,
+ past_buckets_states=past_buckets_states,
+ )
+
+
+def longformer_chunk(hidden_states, window_overlap=256):
+ """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
+
+ # non-overlapping chunks of size = 2w
+ hidden_states = hidden_states.view(
+ hidden_states.size(0),
+ hidden_states.size(1) // (window_overlap * 2),
+ window_overlap * 2,
+ hidden_states.size(2),
+ )
+
+ # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
+ chunk_size = list(hidden_states.size())
+ chunk_size[1] = chunk_size[1] * 2 - 1
+
+ chunk_stride = list(hidden_states.stride())
+ chunk_stride[1] = chunk_stride[1] // 2
+ return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
+
+
+class PartialT5(torch.nn.Module):
+ # Highly simplified T5Attention prefix
+ def __init__(self):
+ super(PartialT5, self).__init__()
+ self.q = torch.nn.Linear(512, 512)
+ self.k = torch.nn.Linear(512, 512)
+ self.v = torch.nn.Linear(512, 512)
+
+ def forward(
+ self,
+ hidden_states,
+ key_value_states=None,
+ past_key_value=None,
+ query_length=None,
+ ):
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ real_seq_length = seq_length
+
+ if past_key_value is not None:
+ assert (
+ len(past_key_value) == 2
+ ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+ real_seq_length += (
+ past_key_value[0].shape[2] if query_length is None else query_length
+ )
+
+ def shape(states):
+ """projection"""
+ return states.view(batch_size, -1, 8, 64).transpose(1, 2)
+
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
+ """projects hidden states correctly to key/query states"""
+ if key_value_states is None:
+ # self-attn
+ # (batch_size, n_heads, seq_length, dim_per_head)
+ hidden_states = shape(proj_layer(hidden_states))
+ elif past_key_value is None:
+ # cross-attn
+ # (batch_size, n_heads, seq_length, dim_per_head)
+ hidden_states = shape(proj_layer(key_value_states))
+
+ if past_key_value is not None:
+ if key_value_states is None:
+ # self-attn
+ # (batch_size, n_heads, key_length, dim_per_head)
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+ else:
+ # cross-attn
+ hidden_states = past_key_value
+ return hidden_states
+
+ # get query states
+ query_states = shape(
+ self.q(hidden_states)
+ ) # (batch_size, n_heads, seq_length, dim_per_head)
+
+ # get key/value states
+ key_states = project(
+ hidden_states,
+ self.k,
+ key_value_states,
+ past_key_value[0] if past_key_value is not None else None,
+ )
+ value_states = project(
+ hidden_states,
+ self.v,
+ key_value_states,
+ past_key_value[1] if past_key_value is not None else None,
+ )
+
+ # compute scores
+ scores = torch.matmul(query_states, key_states.transpose(3, 2))
+
+ # (truncated here )
+ return scores, value_states
+
+
+class ChunkReformerFeedForward(torch.nn.Module):
+ # simplified from HF modeling_reformer.py
+ def __init__(self):
+ super().__init__()
+ self.layer_norm = torch.nn.LayerNorm(256, eps=1e-12)
+ self.dense = torch.nn.Linear(256, 256)
+ self.output = torch.nn.Linear(256, 256)
+
+ def forward(self, attention_output):
+ return apply_chunking_to_forward(
+ self.forward_chunk,
+ attention_output + 1,
+ )
+
+ def forward_chunk(self, hidden_states):
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.dense(hidden_states)
+ return self.output(hidden_states)
+
+
+def apply_chunking_to_forward(forward_fn, *input_tensors):
+ # simplified from HF model_utils.py
+ assert len(input_tensors) > 0
+ tensor_shape = input_tensors[0].shape[1]
+ assert all(input_tensor.shape[1] == tensor_shape for input_tensor in input_tensors)
+ num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
+ if num_args_in_forward_chunk_fn != len(input_tensors):
+ raise ValueError()
+
+ return forward_fn(*input_tensors)
+
+
+class FakeMamlInner(torch.nn.Module):
+ def __init__(self):
+ super(FakeMamlInner, self).__init__()
+ self.linear = torch.nn.Linear(784, 5)
+
+ def forward(self, x, ignored=None, bn_training=False):
+ return self.linear(x.view(x.shape[0], -1))
+
+
+class PartialMaml(torch.nn.Module):
+ # Highly simplified version of maml.meta.Meta.finetuning
+ def __init__(self):
+ super(PartialMaml, self).__init__()
+ self.net = FakeMamlInner()
+ self.update_step_test = 10
+ self.update_lr = 0.4
+
+ def forward(self, x_spt, y_spt, x_qry, y_qry):
+ querysz = x_qry.size(0)
+
+ corrects = [0 for _ in range(self.update_step_test + 1)]
+
+ # in order to not ruin the state of running_mean/variance and bn_weight/bias
+ # we finetunning on the copied model instead of self.net
+ net = deepcopy(self.net)
+
+ # 1. run the i-th task and compute loss for k=0
+ logits = net(x_spt)
+ loss = F.cross_entropy(logits, y_spt)
+ grad = torch.autograd.grad(loss, net.parameters())
+ fast_weights = list(
+ map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters()))
+ )
+
+ # this is the loss and accuracy before first update
+ with torch.no_grad():
+ # [setsz, nway]
+ logits_q = net(x_qry, net.parameters(), bn_training=True)
+ # [setsz]
+ pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
+ # scalar
+ correct = torch.eq(pred_q, y_qry).sum().item()
+ corrects[0] = corrects[0] + correct
+
+ # this is the loss and accuracy after the first update
+ with torch.no_grad():
+ # [setsz, nway]
+ logits_q = net(x_qry, fast_weights, bn_training=True)
+ # [setsz]
+ pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
+ # scalar
+ correct = torch.eq(pred_q, y_qry).sum().item()
+ corrects[1] = corrects[1] + correct
+
+ del net
+
+ accs = torch.tensor(corrects) / querysz
+
+ return accs
+
+
+class ModelOutput(collections.OrderedDict):
+ """based on file_utils.py in HuggingFace"""
+
+ def __getitem__(self, k):
+ if isinstance(k, str):
+ inner_dict = {k: v for (k, v) in self.items()}
+ return inner_dict[k]
+ else:
+ return self.to_tuple()[k]
+
+ def __setattr__(self, name, value):
+ if name in self.keys() and value is not None:
+ # Don't call self.__setitem__ to avoid recursion errors
+ super().__setitem__(name, value)
+ super().__setattr__(name, value)
+
+ def __setitem__(self, key, value):
+ # Will raise a KeyException if needed
+ super().__setitem__(key, value)
+ # Don't call self.__setattr__ to avoid recursion errors
+ super().__setattr__(key, value)
+
+ def to_tuple(self):
+ return tuple(self[k] for k in self.keys())
+
+
+def create_rand_mask_from_inputs(
+ from_blocked_mask,
+ to_blocked_mask,
+ rand_attn,
+ num_attention_heads,
+ num_rand_blocks,
+ batch_size,
+ from_seq_length,
+ from_block_size,
+):
+ """taken from HF modeling_big_bird.py"""
+ num_windows = from_seq_length // from_block_size - 2
+ rand_mask = torch.stack(
+ [p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)]
+ )
+ rand_mask = rand_mask.view(
+ batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size
+ )
+ rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask)
+ return rand_mask
+
+
+class SequentialAppendList(torch.nn.Sequential):
+ """from timm/models/vovnet.py"""
+
+ def __init__(self, *args):
+ super(SequentialAppendList, self).__init__(*args)
+
+ def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor:
+ for i, module in enumerate(self):
+ if i == 0:
+ concat_list.append(module(x))
+ else:
+ concat_list.append(module(concat_list[-1]))
+ x = torch.cat(concat_list, dim=1)
+ return x, concat_list
+
+
+class BatchNormAct2d(torch.nn.BatchNorm2d):
+ """Taken from timm"""
+
+ def __init__(
+ self,
+ num_features,
+ eps=1e-5,
+ momentum=0.1,
+ affine=True,
+ track_running_stats=True,
+ act_layer=torch.nn.ReLU,
+ inplace=True,
+ ):
+ super(BatchNormAct2d, self).__init__(
+ num_features,
+ eps=eps,
+ momentum=momentum,
+ affine=affine,
+ track_running_stats=track_running_stats,
+ )
+ self.act = act_layer(inplace=inplace)
+
+ @torch.jit.ignore
+ def _forward_python(self, x):
+ return super().forward(x)
+
+ def forward(self, x):
+ if torch.jit.is_scripting():
+ x = self._forward_jit(x)
+ else:
+ x = self._forward_python(x)
+ x = self.act(x)
+ return x
+
+
+def get_parameter_dtype(parameter):
+ """from huggingface model_utils.py"""
+ try:
+ return next(parameter.parameters()).dtype
+ except StopIteration:
+ # For nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module):
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].dtype
+
+
+class DummyConfig:
+ attn_layers = ["local", "lsh", "local", "lsh", "local", "lsh"]
+ lsh_attn_chunk_length = 64
+ local_attn_chunk_length = 64
+
+
+def _get_min_chunk_len(config):
+ """from hf_Reformer"""
+ attn_types = config.attn_layers
+ attn_types_set = set(attn_types)
+ if len(attn_types_set) == 1 and attn_types[0] == "lsh":
+ return config.lsh_attn_chunk_length
+ elif len(attn_types_set) == 1 and attn_types[0] == "local":
+ return config.local_attn_chunk_length
+ elif len(attn_types_set) == 2 and attn_types_set == set(["lsh", "local"]):
+ return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length)
+ else:
+ raise NotImplementedError(
+ f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select "
+ "attn layer types from ['lsh', 'local'] only."
+ )
+
+
+def _stable_argsort(vector, dim):
+ """from hf_Reformer"""
+ # this function scales the vector so that torch.argsort is stable.
+ # torch.argsort is not stable on its own
+ scale_offset = torch.arange(vector.shape[dim], device=vector.device).view(1, 1, -1)
+ scale_offset = scale_offset.expand(vector.shape)
+ scaled_vector = vector.shape[dim] * vector + (scale_offset % vector.shape[dim])
+ return torch.argsort(scaled_vector, dim=dim)
+
+
+def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(buckets):
+ """from hf_Reformer"""
+ # no gradients are needed
+ with torch.no_grad():
+ # hash-based sort
+ sorted_bucket_idx = _stable_argsort(buckets, dim=-1)
+
+ # create simple indices to scatter to, to have undo sort
+ indices = (
+ torch.arange(sorted_bucket_idx.shape[-1], device=buckets.device)
+ .view(1, 1, -1)
+ .expand(sorted_bucket_idx.shape)
+ )
+
+ # get undo sort
+ undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size())
+ undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices)
+
+ return sorted_bucket_idx, undo_sorted_bucket_idx
+
+
+class FeedForwardLayer(nn.Module):
+ def __init__(self, d_model, dim_feedforward, activation, dropout) -> None:
+ super(FeedForwardLayer, self).__init__()
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.activation = activation
+ self.dropout1 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+ self.dropout2 = nn.Dropout(dropout)
+
+ def forward(self, x):
+ return self.dropout2(
+ self.linear2(self.dropout1(self.activation(self.linear1(x))))
+ )
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation=nn.ReLU(),
+ layer_norm_eps=1e-5,
+ ):
+ super(TransformerEncoderLayer, self).__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.dropout = nn.Dropout(dropout)
+ self.ff_block = FeedForwardLayer(d_model, dim_feedforward, activation, dropout)
+
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
+ x = src
+ x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
+ x = self.norm2(x + self._ff_block(x))
+ return x
+
+ # self-attention block
+ def _sa_block(self, x, attn_mask, key_padding_mask):
+ x = self.self_attn(
+ x,
+ x,
+ x,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask,
+ need_weights=False,
+ )[0]
+ return self.dropout(x)
+
+ # feed forward block
+ def _ff_block(self, x):
+ return self.ff_block(x)
+
+
+class TestModule(torch.nn.Module):
+ def inner_fn(self, left, right):
+ return tuple(left) == tuple(right)
+
+ def fn(self, tensor):
+ if type(tensor) is int:
+ return False
+
+ torch.add(tensor, tensor)
+ return self.inner_fn(tensor.shape, (1, 2, 3))
+
+
+class ReproTests(torch._dynamo.testing.TestCase):
+ def test_do_paste_mask(self):
+ torch._dynamo.utils.counters.clear()
+ opt__do_paste_mask = torch._dynamo.optimize(
+ torch._dynamo.testing.CompileCounter()
+ )(_do_paste_mask)
+ opt__do_paste_mask(
+ torch.randn(1, 1, 28, 28),
+ torch.tensor([[0.0, 1, 2, 4]]) * 1,
+ 427,
+ 640,
+ True,
+ )
+ opt__do_paste_mask(
+ torch.randn(1, 1, 28, 28),
+ torch.tensor([[0.0, 1, 2, 4]]) * 2,
+ 427,
+ 640,
+ True,
+ )
+ opt__do_paste_mask(
+ torch.randn(1, 1, 28, 28),
+ torch.tensor([[0.0, 1, 2, 4]]) * 3,
+ 612,
+ 612,
+ True,
+ )
+ opt__do_paste_mask(
+ torch.randn(1, 1, 28, 28),
+ torch.tensor([[0.0, 1, 2, 4]]) * 4,
+ 612,
+ 612,
+ True,
+ )
+ opt__do_paste_mask(
+ torch.randn(1, 1, 28, 28),
+ torch.tensor([[0.0, 1, 2, 4]]) * 2,
+ 427,
+ 640,
+ False,
+ )
+
+ self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 3)
+ # Graph break because of dynamic slicing
+ self.assertEqual(
+ torch._dynamo.utils.counters["frames"]["total"],
+ torch._dynamo.utils.counters["frames"]["ok"] + 1,
+ )
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", True)
+ def test_convert_boxes_to_pooler_format(self):
+ boxes1 = [
+ Boxes(torch.arange(0, 8).reshape((2, 4))),
+ Boxes(torch.arange(8, 16).reshape((2, 4))),
+ ]
+ boxes2 = [
+ Boxes(torch.arange(16, 20).reshape((1, 4))),
+ Boxes(torch.arange(20, 24).reshape((1, 4))),
+ ]
+ correct1 = convert_boxes_to_pooler_format(boxes1)
+ correct2 = convert_boxes_to_pooler_format(boxes2)
+ fn = convert_boxes_to_pooler_format
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnt)(fn)
+ self.assertTrue(same(opt_fn(boxes1), correct1))
+ self.assertTrue(same(opt_fn(boxes2), correct2))
+
+ # repeat_interleave is a dynamic shape operator we do not execute/
+ # In the future, we could reduce the frame_count down to 1
+ # by guarding on the exact values of `Tensor repeats` arg
+ self.assertEqual(cnt.frame_count, ifdyn(2, 4))
+ self.assertEqual(cnt.op_count, ifdyn(9, 10))
+
+ def test_boxes_len(self):
+ def fn(boxes):
+ return len(boxes) + boxes.__len__() + boxes.tensor
+
+ boxes1 = Boxes(torch.arange(0, 8).reshape((2, 4)))
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
+ self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0))
+
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, ifdyn(6, 1))
+
+ def _reformer(self, nopython):
+ input = torch.randn([1, 64, 256])
+ model = ReformerEncoder()
+ torch.manual_seed(1337)
+ correct = copy.deepcopy(model)(input)
+ cnt = torch._dynamo.testing.CompileCounter()
+ torch.manual_seed(1337)
+ opt_model = torch._dynamo.optimize(cnt, nopython=nopython)(model)
+ self.assertTrue(same(opt_model(input), correct))
+ return cnt
+
+ def test_reformer_eval(self):
+ with torch.no_grad():
+ cnt = self._reformer(nopython=True)
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, 10)
+
+ def test_reformer_train(self):
+ with torch.enable_grad():
+ cnt = self._reformer(nopython=False)
+ # cant inline torch.autograd.Function means graph break
+ self.assertEqual(cnt.frame_count, 4)
+ self.assertEqual(cnt.op_count, 10)
+
+ def test_longformer_chunk(self):
+ input1 = torch.randn([1, 4096, 1])
+ input2 = torch.randn([12, 4096, 64])
+ correct1 = longformer_chunk(input1)
+ correct2 = longformer_chunk(input2)
+ fn = longformer_chunk
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
+ self.assertTrue(same(opt_fn(input1), correct1))
+ self.assertTrue(same(opt_fn(input2), correct2))
+ self.assertTrue(same(opt_fn(input1), correct1))
+ self.assertTrue(same(opt_fn(input2), correct2))
+
+ self.assertEqual(cnt.frame_count, ifdyn(1, 2))
+ self.assertEqual(cnt.op_count, ifdyn(19, 4))
+
+ def test_hf_t5_forward(self):
+ input = torch.randn([1, 2048, 512])
+ model = PartialT5()
+ correct = model(input)
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_model = torch._dynamo.optimize_assert(cnt)(model)
+ self.assertTrue(same(opt_model(input), correct))
+
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, ifdyn(13, 11))
+
+ def test_slicing_dynamic_shape(self):
+ def fn(y):
+ x = torch.ones(8)
+ idx = y[0]
+ out = x[idx:]
+ return (out + 3) * 5
+
+ counter = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(counter)(fn)
+ out = opt_fn(torch.ones(10, dtype=torch.long))
+ # idx should be 1 -> slicing off [1:] of 8 elem tensor
+ self.assertEqual(list(out.shape), [7])
+
+ expected_ops = ifdyn(5, 4)
+ expected_frame = ifdyn(1, 2)
+
+ self.assertEqual(expected_ops, expected_ops)
+ self.assertEqual(expected_frame, expected_frame)
+
+ self.assertEqual(list(opt_fn(torch.tensor([4])).shape), [4])
+
+ def test_slicing_dynamic_shape_setitem(self):
+ def fn(input_lengths: torch.Tensor, new_ones_1):
+ getitem_13 = input_lengths[3]
+ new_ones_1[(3, slice(getitem_13, None, None))] = 0
+ setitem_13 = new_ones_1
+ return (setitem_13,)
+
+ x = torch.randn(10).to(dtype=torch.int64)
+ y = torch.randn(10, 204)
+ ref = fn(x, y)
+ opt_fn = torch._dynamo.optimize("aot_eager")(fn)
+ res = opt_fn(x, y)
+ self.assertTrue(same(ref, res))
+
+ @requires_static_shapes
+ def test_chunk_reformer_ff(self):
+ input = torch.randn([1, 4096, 256])
+ model = ChunkReformerFeedForward()
+ correct = model(input)
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_model = torch._dynamo.optimize_assert(cnt)(model)
+ self.assertTrue(same(opt_model(input), correct))
+
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, 4)
+
+ # see: https://github.com/pytorch/pytorch/issues/80067
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
+ def test_maml_item_capture(self):
+ a = torch.randn(5, 1, 28, 28)
+ b = torch.zeros(5, dtype=torch.int64)
+ c = torch.randn(75, 1, 28, 28)
+ d = torch.zeros(75, dtype=torch.int64)
+ model = PartialMaml()
+ correct = model(a, b, c, d)
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_model = torch._dynamo.optimize(cnt)(model)
+ for _ in range(10):
+ self.assertTrue(same(opt_model(a, b, c, d), correct))
+
+ self.assertEqual(cnt.frame_count, ifdyn(3, 2))
+ # TODO(jansel): figure out why op count depends on imports
+ self.assertIn(cnt.op_count, (36, 35, 29, 28))
+
+ # see: https://github.com/pytorch/pytorch/issues/80067
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ @patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
+ def test_maml_no_item_capture(self):
+ a = torch.randn(5, 1, 28, 28)
+ b = torch.zeros(5, dtype=torch.int64)
+ c = torch.randn(75, 1, 28, 28)
+ d = torch.zeros(75, dtype=torch.int64)
+ model = PartialMaml()
+ correct = model(a, b, c, d)
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_model = torch._dynamo.optimize(cnt)(model)
+ for _ in range(10):
+ self.assertTrue(same(opt_model(a, b, c, d), correct))
+
+ self.assertEqual(cnt.frame_count, ifdyn(5, 4))
+ # TODO(jansel): figure out why op count depends on imports
+ self.assertIn(cnt.op_count, (31, 36, 35, 29, 28))
+
+ def test_hf_model_output(self):
+ ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10))
+
+ def fn1(x):
+ return x["a"] + 1
+
+ def fn2(x):
+ return x.a + 1
+
+ def fn3(x):
+ return x.to_tuple()[0] + 1
+
+ def fn4(x):
+ return x[0] + 1
+
+ cnt = torch._dynamo.testing.CompileCounter()
+ for fn in (fn1, fn2, fn3, fn4):
+ cnt.clear()
+ opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
+ self.assertTrue(same(opt_fn(ex), ex.a + 1))
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, 1)
+
+ @requires_static_shapes
+ def test_create_rand_mask_from_inputs(self):
+ args = [
+ torch.randn([1, 64, 64]),
+ torch.randn([1, 64, 64]),
+ torch.zeros([1, 12, 62, 3], dtype=torch.int64),
+ 12,
+ 3,
+ 1,
+ 4096,
+ 64,
+ ]
+ correct = create_rand_mask_from_inputs(*args)
+ fn = create_rand_mask_from_inputs
+
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
+ self.assertTrue(same(opt_fn(*args), correct))
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, 8)
+
+ def test_rng_state(self):
+ def fn():
+ state = torch.get_rng_state()
+ before = torch.rand(1000)
+ torch.set_rng_state(state)
+ after = torch.rand(1000)
+ return before, after
+
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnt)(fn)
+
+ before, after = opt_fn()
+ self.assertTrue(same(before, after))
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, 4) # rand, rand
+ graph, _ = torch._dynamo.export(fn)
+
+ def test_seq_append_list(self):
+ x = torch.randn(4, 10)
+ model = SequentialAppendList(
+ torch.nn.Linear(10, 10),
+ torch.nn.ReLU(),
+ torch.nn.Linear(10, 10),
+ torch.nn.ReLU(),
+ )
+ # this one is tricky because it mutates the list provided as an input
+ l1 = [x]
+ l2 = [x]
+ correct, _ = model(x, l1)
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_model = torch._dynamo.optimize_assert(cnt)(model)
+ result, l3 = opt_model(x, l2)
+ self.assertTrue(same(result, correct))
+ self.assertTrue(same(l1, l2))
+ self.assertIs(l2, l3)
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, 5)
+
+ def test_batch_norm_act(self):
+ a = torch.randn(5, 1, 28, 28)
+ model = BatchNormAct2d(1).eval()
+ correct = model(a)
+ cnt = torch._dynamo.testing.CompileCounter()
+ if not torch._dynamo.config.specialize_int_float:
+ # _local_scalar_dense causes graph break w 0-dim tensor
+ opt_model = torch._dynamo.optimize(cnt)(model)
+ self.assertTrue(same(opt_model(a), correct))
+ return
+
+ opt_model = torch._dynamo.optimize_assert(cnt)(model)
+ self.assertTrue(same(opt_model(a), correct))
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, 2)
+
+ def test_get_parameter_dtype(self):
+ model = SequentialAppendList(
+ torch.nn.Linear(10, 10),
+ torch.nn.ReLU(),
+ )
+
+ def fn(model, x):
+ return x + torch.randn(10, dtype=get_parameter_dtype(model))
+
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
+ self.assertEqual(opt_fn(model, torch.randn(10)).dtype, torch.float32)
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, 2)
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", True)
+ def test_nn_parameter(self):
+ def test_fn():
+ a = torch.nn.Parameter(torch.randn(5, 5))
+ # Checks that TensorVariable stores the type information correctly
+ self.assertTrue(isinstance(a, torch.nn.Parameter))
+ return a
+
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_test_fn = torch._dynamo.optimize(cnt)(test_fn)
+ out = opt_test_fn()
+ self.assertTrue(isinstance(out, torch.nn.Parameter))
+
+ def test_Size(self):
+ def test_fn():
+ a = torch.randn(4)
+ x = torch.Size([1, 2, 3])
+ # Checks that SizeVariable return torch.Size object
+ assert isinstance(x, torch.Size)
+ # Causes graph breaks and checks reconstruction of SizeVariable
+ # object
+ self.assertIsInstance(x, torch.Size)
+ return a
+
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_test_fn = torch._dynamo.optimize(cnt)(test_fn)
+ opt_test_fn()
+
+ def test_indexing_with_list(self):
+ def test_fn():
+ def run_test(tensor, *idx):
+ npt = tensor.numpy()
+ assert npt[idx].shape == tensor[idx].shape
+
+ x = torch.arange(0, 10)
+ cases = [
+ [None, None],
+ [1, None],
+ ]
+
+ for case in cases:
+ run_test(x, *case)
+
+ return torch.randn(4)
+
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_test_fn = torch._dynamo.optimize(cnt)(test_fn)
+ opt_test_fn()
+
+ def test_reformer_min_chunk_len(self):
+ def fn(cfg):
+ t = torch.empty(10)
+ t.fill_(_get_min_chunk_len(cfg))
+ return t[0]
+
+ cfg = DummyConfig()
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
+ self.assertEqual(opt_fn(cfg), 64)
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, 3)
+
+ def test_reformer_sorting(self):
+ x = torch.zeros([1, 12, 4096], dtype=torch.int64)
+ correct = _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(x)
+ fn = _get_sorted_bucket_idx_and_undo_sorted_bucket_idx
+
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
+ self.assertTrue(same(opt_fn(x), correct))
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, ifdyn(28, 14))
+
+ def test_recursive_map(self):
+ # https://github.com/pytorch/torchdynamo/issues/132
+ def _recursive_map(struct, batch_dim=0):
+ for k, v in struct.items():
+ if v is not None:
+ if isinstance(v, dict):
+ _recursive_map(v)
+ else:
+ struct[k] = v
+
+ def toy_example(a, b, v):
+ x = a / (torch.abs(a) + 1)
+ if v is not None:
+ _recursive_map(v)
+ return x * b
+
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_toy_example = torch._dynamo.optimize(cnt)(toy_example)
+ opt_toy_example(
+ torch.randn(10),
+ torch.randn(10),
+ {"layer0": {"memory_keys": torch.randn(10)}},
+ )
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, 4)
+
+ def test_issue175(self):
+ n_heads = 2
+ d_model = 64
+ model = TransformerEncoderLayer(d_model, n_heads)
+ inp = torch.randn(1, d_model)
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_model = torch._dynamo.optimize(cnt, nopython=True)(model)
+ opt_model(inp)
+ opt_model(inp)
+ self.assertEqual(cnt.frame_count, 1)
+ self.assertEqual(cnt.op_count, 12)
+
+ def test_exec_import(self):
+ def fn1():
+ exec("import math")
+
+ def fn2():
+ try:
+ math.sqrt(4)
+ return False
+ except NameError:
+ return True
+
+ def fn3():
+ fn1()
+ return fn2()
+
+ self.assertTrue(fn3())
+ opt_fn3 = torch._dynamo.optimize("eager")(fn3)
+ self.assertTrue(opt_fn3())
+
+ def test_exec_wildcard_import(self):
+ # Test that globals are not carried over from frame to frame
+ def fn1():
+ exec("from torch import *")
+
+ def fn2():
+ x = torch.zeros(4)
+ for i in range(5):
+ x = x + i
+ return x
+
+ def fn3():
+ fn1()
+ return fn2()
+
+ ref = fn3()
+ opt_fn3 = torch._dynamo.optimize("eager")(fn3)
+ res = opt_fn3()
+ self.assertTrue(same(ref, res))
+
+ def test_with_on_graph_break_inst(self):
+ def reversible(x):
+ print("Hello world") # Cause graph break so inline fails
+ return torch.sin(torch.cos(x))
+
+ def fn(x):
+ with torch.enable_grad():
+ a = torch.sin(x)
+ b = reversible(a)
+ c = torch.sigmoid(b)
+ c.sum().backward()
+ return x.grad
+
+ x = torch.randn(3, requires_grad=True)
+ x.grad = None
+ with torch.no_grad():
+ ref = fn(x)
+
+ x.grad = None
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ with torch.no_grad():
+ res = opt_fn(x)
+ self.assertTrue(same(ref, res))
+
+ def test_abc_setattr(self):
+ # tests that we correctly bail out of __setattr__ calls
+
+ # TODO: does not ensure ABC classes are correctly inferred as ClassVariables
+ # (doesn't test the fix for 'super()')
+
+ class BaseModule(torch.nn.Module, ABC):
+ def blah(self, x):
+ return x + 1
+
+ class Derived(BaseModule):
+ def __setattr__(self, name, value) -> None:
+ super().__setattr__(name, value)
+
+ def forward(self, x):
+ # expect a graph break on __setattr__
+ self.foo = 0
+ return self.blah(x)
+
+ def blah(self, x):
+ return super().blah(x)
+
+ x = torch.randn(3, requires_grad=True)
+ mod = Derived()
+ opt_mod = torch._dynamo.optimize("eager")(mod)
+ opt_mod(x)
+
+ self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 3)
+ self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["total"], 3)
+
+ def test_guard_fail_tensor_bool(self):
+ @torch._dynamo.skip
+ def fn():
+ condition_shape = (5, 5)
+ dtypes = (torch.bool,)
+ shapes = (
+ (),
+ (5,),
+ (1, 5),
+ )
+
+ tensors = list(
+ [
+ torch.empty(shape, dtype=dtype).fill_(17)
+ for shape, dtype in itertools.product(shapes, dtypes)
+ ]
+ )
+
+ x_vals = (5.0, *tensors)
+ y_vals = (6.0, *tensors)
+
+ @torch._dynamo.disable
+ def get_expected(condition, x, y):
+ x_np = x.cpu().numpy() if isinstance(x, torch.Tensor) else x
+ y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y
+ return torch.from_numpy(
+ np.where(condition.cpu().numpy(), x_np, y_np)
+ ).to(common_dtype)
+
+ for x, y in zip(x_vals, y_vals):
+ condition = torch.empty(*condition_shape, dtype=torch.bool).bernoulli_()
+ common_dtype = torch.result_type(x, y)
+
+ def check_equal(condition, x, y):
+ # NumPy aggressively promotes to double, hence cast to output to correct dtype
+ expected = get_expected(condition, x, y)
+ result = torch.where(condition, x, y)
+ assert torch.allclose(expected, result)
+
+ check_equal(condition, x, y)
+ check_equal(condition, y, x)
+
+ fn()
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ opt_fn()
+
+ def test_guard_fail_nested_tuple(self):
+ def fn(args):
+ return torch.ones(()), args[0] * 2
+
+ # This adds a tensor check on args[1][0] and args[1][1]
+ args1 = (torch.ones(1), (torch.ones(1), torch.ones(1)))
+ args2 = (torch.ones(1), torch.ones(1))
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ ref = opt_fn(args1)
+ res = opt_fn(args2)
+
+ self.assertTrue(same(ref, res))
+
+ def test_numpy_list(self):
+ @torch._dynamo.disable
+ def rand_gen():
+ return list(np.array([random.randint(5, 10) for _ in range(10)]))
+
+ def fn(x):
+ random_list = rand_gen()
+ z = torch.LongTensor(random_list)
+ return x * z
+
+ x = torch.ones(10) * 2
+
+ random.seed(0)
+ ref0 = fn(x)
+ ref1 = fn(x)
+
+ random.seed(0)
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ res0 = opt_fn(x)
+ res1 = opt_fn(x)
+
+ self.assertTrue(same(ref0, res0))
+ self.assertTrue(same(ref1, res1))
+
+ @unittest.skipIf(not HAS_REFS, "requires recent PT version")
+ @unittest.expectedFailure
+ def test_primtorch(self):
+ @torch._dynamo.optimize("eager", nopython=True)
+ def fn(x):
+ torch._refs.abs(x)
+
+ fn(torch.randn(3))
+
+ @unittest.skipIf(
+ not isinstance(torch.ops.aten.abs, torch._ops.OpOverloadPacket),
+ "old pt doesn't work",
+ )
+ def test_torch_ops_aten(self):
+ # Picked an op that doesn't show up in the default list
+ @torch._dynamo.optimize("eager", nopython=True)
+ def fn(x):
+ return torch.ops.aten.absolute(x)
+
+ fn(torch.randn(3))
+
+ def test_guard_ordering_shape_fail(self):
+ # If a function which takes a tensor has an inner function which
+ # is compiled and generates a guard on its shape,
+ # they are evaluated in the wrong order. So if on a subsequent call
+ # an int is passed instead of a tensor, guard evaluation will crash
+ # with a "no attribute: shape" error
+ m = TestModule()
+ opt_m = torch._dynamo.optimize("eager")(m)
+ opt_m.fn(torch.ones((5, 5)))
+ opt_m.fn(-3)
+
+ def test_tensor_isinstance_tuple(self):
+ @torch._dynamo.optimize("eager")
+ def fn():
+ t = torch.ones(5, 5)
+ if not isinstance(t, (int, torch.Tensor)):
+ msg = str.format(
+ "{0} is not an instance of {1}",
+ type(t),
+ (int, torch.Tensor),
+ )
+ raise ValueError(msg)
+ return True
+
+ fn()
+
+ def test_isinstance_dtype(self):
+ @torch._dynamo.optimize("eager", nopython=True)
+ def fn(x):
+ isinstance(torch.bfloat16, torch.dtype)
+ return x
+
+ fn(torch.randn(3))
+
+ def test_isinstance_storage(self):
+ @torch._dynamo.optimize("eager")
+ def fn(x):
+ f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40])
+ bools = torch.BoolStorage.from_buffer(f, "big")
+ self.assertTrue(isinstance(bools, torch.BoolStorage))
+ return x
+
+ fn(torch.randn(3))
+
+ def test_dict_list_values(self):
+ def inner_fn(args):
+ return [x[1].shape for x in args]
+
+ @torch._dynamo.optimize("eager")
+ def fn(tensors):
+ return inner_fn(zip(itertools.count(), tensors["args"]))
+
+ fn({"args": [torch.ones(5, 5), torch.ones(5, 6), torch.ones(5, 7)]})
+ fn({"args": [torch.ones(5, 5)]})
+
+ def test_dict_iter(self):
+ class MyMod(torch.nn.Module):
+ def forward(self, x):
+ z = {"my": 1, "const": 2, "dict": 3, "variable": 4}
+ tot = 0
+ for key in z:
+ tot += z[key]
+
+ return tot
+
+ x = torch.tensor([0])
+ model = MyMod()
+ opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
+ y = opt_model(x)
+
+ self.assertEqual(y, 10)
+
+ def test_sort_out(self):
+
+ dtype = torch.float32
+ device = "cpu"
+
+ def fn():
+ tensor = torch.randn((3, 5), dtype=dtype, device=device)[:, 0]
+ values1 = torch.tensor(0, dtype=dtype, device=device)
+ indices1 = torch.tensor(0, dtype=torch.long, device=device)
+ torch.sort(tensor, out=(values1, indices1))
+ self.assertEqual(values1.stride(), (1,))
+ self.assertEqual(indices1.stride(), (1,))
+
+ fn()
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ opt_fn()
+
+ def test_sigmoid_out(self):
+
+ dtype = torch.float32
+ device = "cpu"
+
+ def fn():
+ inp = torch.randn((3, 5), dtype=dtype, device=device)
+ out1 = torch.tensor(0, dtype=dtype, device=device)
+ torch.sigmoid(inp, out=out1)
+ self.assertEqual(out1.numel(), 15)
+
+ fn()
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ opt_fn()
+
+ def test_slice_into_list_mutable(self):
+ class Mod(torch.nn.Module):
+ def forward(self, listy):
+ x = listy[3:5]
+ for i in range(10):
+ z = torch.abs(torch.randn(10)) + 1
+ x[0] = z
+ return x
+
+ m = Mod()
+ listy = [torch.randn(10)] * 10
+
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
+ opt_m.forward(listy)
+
+ self.assertEqual(cnt.frame_count, 1)
+
+ def test_vdd_duplicate_error(self):
+ def fn(a, dt):
+ keys = list(dt._jt_dict.keys())
+ p = torch.cos(dt._jt_dict[keys[0]]._value)
+ q = torch.sin(a)
+ r = torch.sigmoid(dt._jt_dict[keys[0]]._value)
+ return p + q + r
+
+ class Value:
+ def __init__(self):
+ self._value = torch.randn(4)
+
+ class Sample:
+ def __init__(self):
+ self._jt_dict = {}
+ self._jt_dict["POSITION_ID"] = Value()
+
+ a = torch.randn(4)
+ sample = Sample()
+
+ ref = fn(a, sample)
+
+ optimized_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
+ res = optimized_fn(a, sample)
+
+ self.assertTrue(same(ref, res))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_specialized_stride(self):
+ def f():
+ e = torch.empty(4)
+ x = e[::2]
+ return x.stride()
+
+ self.assertEqual(f(), torch._dynamo.optimize("eager")(f)())
+
+ @unittest.skipIf(not has_detectron2(), "requires detectron2")
+ def test_multi_import(self):
+ @torch._dynamo.optimize("eager", nopython=True)
+ def to_bitmasks(boxes):
+ from detectron2.layers.mask_ops import (
+ _paste_masks_tensor_shape,
+ paste_masks_in_image,
+ )
+
+ if (
+ paste_masks_in_image is not None
+ and _paste_masks_tensor_shape is not None
+ ):
+ return boxes + 1
+
+ self.assertTrue((to_bitmasks(torch.zeros(10)) == torch.ones(10)).all())
+
+ def test_multi_dot_import(self):
+ def fn1(x):
+ return torch.sin(x)
+
+ def fn(x):
+ import torch.fx
+
+ _ = torch.fx.symbolic_trace(fn1)
+ return x * 2
+
+ x = torch.randn(10)
+ fn(x)
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnt)(fn)
+ opt_fn(x)
+ self.assertEqual(cnt.frame_count, 1)
+
+ def test_relative_import(self):
+ try:
+ from . import test_functions as _ # noqa: F401
+
+ def fn(x):
+ from .test_functions import tensor_for_import_testing
+
+ return x * 2 * tensor_for_import_testing
+
+ except ImportError:
+
+ def fn(x):
+ from test_functions import tensor_for_import_testing
+
+ return x * 2 * tensor_for_import_testing
+
+ x = torch.randn(10)
+ fn(x)
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
+ opt_fn(x)
+ self.assertEqual(cnt.frame_count, 1)
+
+ def test_relative_import_no_modulename(self):
+ try:
+ from . import test_functions as _ # noqa: F401
+
+ def fn(x):
+ from . import test_functions
+
+ return x * 2 * test_functions.tensor_for_import_testing
+
+ except ImportError:
+
+ def fn(x):
+ import test_functions
+
+ return x * 2 * test_functions.tensor_for_import_testing
+
+ x = torch.randn(10)
+ fn(x)
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
+ opt_fn(x)
+ self.assertEqual(cnt.frame_count, 1)
+
+ # This doesn't work without fake tensors but I don't care
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", True)
+ def test_issue1466_size_aot_autograd(self):
+ def fn(x):
+ # do a tensor op and a size compute
+ y = x * 2
+ x_size = x.size()
+ # trigger a graph break
+ print("arf")
+ # use the tensor op and size compute
+ z = y.view(x_size) + 1
+ return z
+
+ x = torch.randn(2, 3, requires_grad=True)
+ ref = fn(x)
+ opt_fn = torch._dynamo.optimize("aot_eager")(fn)
+ res = opt_fn(x)
+ self.assertTrue(same(ref, res))
+
+ def test_ellipsis(self):
+ class Repro(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.lnorm = torch.nn.LayerNorm(
+ (256,), eps=1e-06, elementwise_affine=True
+ )
+ self.linear = torch.nn.Linear(
+ in_features=256, out_features=256, bias=True
+ )
+
+ def forward(self, cat_10):
+ lnorm = self.lnorm(cat_10)
+ getitem_64 = lnorm[
+ (slice(None, None, None), slice(0, 1, None), Ellipsis)
+ ]
+ linear = self.linear(getitem_64)
+ return (linear,)
+
+ args = [torch.randn(2, 197, 256)]
+
+ mod = Repro()
+ opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)
+
+ self.assertTrue(same(mod(*args), opt_mod(*args)))
+
+ def test_reinplacing(self):
+ class MockModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.self_layoutlm_embeddings_x_position_embeddings = (
+ torch.nn.Embedding(1024, 768)
+ )
+ self.self_layoutlm_embeddings_y_position_embeddings = (
+ torch.nn.Embedding(1024, 768)
+ )
+
+ def forward(self, getitem_1, getitem_2, add):
+ self_layoutlm_embeddings_x_position_embeddings = (
+ self.self_layoutlm_embeddings_x_position_embeddings(getitem_1)
+ )
+ self_layoutlm_embeddings_y_position_embeddings = (
+ self.self_layoutlm_embeddings_y_position_embeddings(getitem_2)
+ )
+ add_1 = add + self_layoutlm_embeddings_x_position_embeddings
+ add_2 = add_1 + self_layoutlm_embeddings_y_position_embeddings
+ return (add_2,)
+
+ mod = MockModule()
+ opt_mod = torch._dynamo.optimize("aot_inductor_debug")(mod)
+
+ args = [
+ ((2, 512), (2048, 4), torch.int64, "cpu", False),
+ ((2, 512), (2048, 4), torch.int64, "cpu", False),
+ ((2, 512, 768), (393216, 768, 1), torch.float32, "cpu", True),
+ ]
+ args = [
+ rand_strided(sh, st, dt, dev).requires_grad_(rg)
+ for (sh, st, dt, dev, rg) in args
+ ]
+ self.assertTrue(same_two_models(mod, opt_mod, args))
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_skip_non_tensor.py b/test/dynamo/test_skip_non_tensor.py
new file mode 100644
index 0000000000000..a2338c60af8bb
--- /dev/null
+++ b/test/dynamo/test_skip_non_tensor.py
@@ -0,0 +1,112 @@
+# Owner(s): ["module: dynamo"]
+from unittest.mock import patch
+
+import torch
+
+import torch._dynamo
+from torch._dynamo.testing import CompileCounter
+
+
+class SkipNonTensorTests(torch._dynamo.testing.TestCase):
+ def test_add_tensor1(self):
+ def fn(a, b):
+ return a + b
+
+ counter = CompileCounter()
+ x = torch.randn(4)
+ y = 5
+ opt_fn = torch._dynamo.optimize_assert(counter)(fn)
+ opt_fn(x, y)
+
+ assert counter.op_count == 1
+
+ def test_add_tensor2(self):
+ def fn(a, b):
+ return torch.add(a, b)
+
+ counter = CompileCounter()
+
+ x = torch.randn(4)
+ y = 5
+ opt_fn = torch._dynamo.optimize_assert(counter)(fn)
+ opt_fn(x, y)
+
+ assert counter.op_count == 1
+
+ def test_add_tensor_list(self):
+ def fn(lst):
+ return lst[0] + lst[1]
+
+ counter = CompileCounter()
+ x = torch.randn(4)
+ y = 5
+ opt_fn = torch._dynamo.optimize_assert(counter)(fn)
+ opt_fn([x, y])
+
+ assert counter.op_count == 1
+
+ def test_add_tensor_dict(self):
+ def fn(dt):
+ return dt["a"] + dt["b"]
+
+ counter = CompileCounter()
+ x = torch.randn(4)
+ y = 5
+ opt_fn = torch._dynamo.optimize_assert(counter)(fn)
+ opt_fn({"a": x, "b": y})
+
+ assert counter.op_count == 1
+
+ def test_add_skip(self):
+ def fn(a, b):
+ return a + b
+
+ counter = CompileCounter()
+ opt_fn = torch._dynamo.optimize_assert(counter)(fn)
+ x = 4
+ y = 5
+ opt_fn(x, y)
+
+ assert counter.op_count == 0
+
+ @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
+ def test_recursive_list(self):
+ def fn(x):
+ return x
+
+ counter = CompileCounter()
+
+ x = []
+ x.append(x)
+ with torch._dynamo.optimize_assert(counter):
+ fn(x)
+
+ assert counter.op_count == 0
+
+ @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
+ def test_custom_list(self):
+ def fn(x):
+ return x[0] + x[1]
+
+ counter = CompileCounter()
+
+ class Foo(list):
+ def __iter__(self):
+ raise Exception()
+
+ def __len__(self):
+ raise Exception()
+
+ x = Foo()
+ x.append(torch.randn(4))
+ x.append(torch.randn(4))
+ with torch._dynamo.optimize_assert(counter):
+ fn(x)
+
+ assert counter.op_count == 0
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py
new file mode 100644
index 0000000000000..f7d601c82b70f
--- /dev/null
+++ b/test/dynamo/test_subgraphs.py
@@ -0,0 +1,533 @@
+# Owner(s): ["module: dynamo"]
+import unittest
+from unittest.mock import patch
+
+import torch
+
+import torch._dynamo.testing
+from torch._dynamo import config
+from torch._dynamo.testing import unsupported
+from torch._dynamo.utils import disable_cache_limit
+
+globalmod = torch.nn.ReLU()
+
+
+def indirectly_unsupported(a, b):
+ c = a + b
+ return unsupported(a, c)
+
+
+class SubGraphTests(torch._dynamo.testing.TestCase):
+ def _common(self, fn, frame_count, op_count):
+ torch._dynamo.reset()
+ v1 = torch.ones(10)
+ v2 = torch.ones(10) * -2.0
+ correct1 = fn(v1, v2)
+ correct2 = fn(v2, v1)
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnt)(fn)
+ r1 = opt_fn(v1, v2)
+ r2 = opt_fn(v2, v1)
+ self.assertTrue(torch._dynamo.testing.same(r1, correct1))
+ self.assertTrue(torch._dynamo.testing.same(r2, correct2))
+ self.assertEqual(cnt.frame_count, frame_count)
+ self.assertEqual(cnt.op_count, op_count)
+
+ def test_control_flow1(self):
+ def fn(a, b):
+ c1 = a - b
+ c2 = b - a
+ if c1.sum() > c2.sum():
+ return c1
+ else:
+ return c2
+
+ self._common(fn, 1, 5)
+
+ def test_control_flow2(self):
+ def fn(a, b):
+ if a.sum() > b.sum():
+ return 1
+ else:
+ return 2
+
+ self._common(fn, 1, 3)
+
+ def test_control_flow3(self):
+ def fn(a, b):
+ c1 = a - b
+ c2 = b - a
+ m = globalmod
+ if c1.sum() > c2.sum():
+ return m(c1)
+ else:
+ return m(c2)
+
+ self._common(fn, 3, 7)
+
+ def test_control_flow4(self):
+ def fn(a, b):
+ tmp1 = a.sum() > b.sum() and a.sum() > 0
+ if tmp1:
+ return 1
+ else:
+ return 2
+
+ self._common(fn, 3, 5)
+
+ def test_control_flow5(self):
+ def fn(a, b):
+ tmp1 = a.sum() > b.sum() and a.sum() > 0
+ tmp2 = a.sum() < b.sum() or b.sum() > 0
+ if tmp1 and tmp2:
+ return 1, tmp1, tmp2
+ else:
+ return 2, tmp1, tmp2
+
+ self._common(fn, 6, 13)
+
+ def test_capi_call1(self):
+ def fn(a, b):
+ c1 = a - b
+ c2 = b - a
+ return unsupported(c1, c2)
+
+ self._common(fn, 1, 2)
+
+ def test_capi_call2(self):
+ def fn(a, b):
+ c1 = a - b
+ c2 = b - a
+ return a - (b - unsupported(c1, c2))
+
+ self._common(fn, 2, 4)
+
+ def test_capi_call3(self):
+ def fn(a, b):
+ c1 = a - b
+ c2 = b - a
+ return torch._dynamo.testing.unsupported(c1, c2)
+
+ self._common(fn, 1, 2)
+
+ def test_indirect_unsupported1(self):
+ def fn(a, b):
+ c1 = a - b
+ c2 = b - a
+ return indirectly_unsupported(c1, c2)
+
+ self._common(fn, 2, 3)
+
+ def test_indirect_unsupported2(self):
+ def fn(a, b):
+ local_const1 = 7
+ local_const2 = 22
+ c1 = a - b
+ c2 = b - a
+ return local_const1 / (local_const2 - indirectly_unsupported(c1, c2))
+
+ self._common(fn, 3, 5)
+
+ def test_indirect_unsupported3(self):
+ def fn(a, b):
+ args = [a - b, b - a]
+ return indirectly_unsupported(*args)
+
+ self._common(fn, 2, 3)
+
+ def test_stack_state1(self):
+ def fn(a, b):
+ t1 = 1.23 * a
+ t2 = 4.56 * a
+ c1 = a - b
+ c2 = b - a
+ return t1 / (t2 - unsupported(c1, c2))
+
+ self._common(fn, 2, 6)
+
+ def test_stack_state2(self):
+ def fn(a, b):
+ t1 = 1.23 * a
+ t2 = 4.56 * a
+ c1 = a - b
+ c2 = b - a
+ return t1 / (t2 - indirectly_unsupported(c1, c2))
+
+ self._common(fn, 3, 7)
+
+ def test_multigraph(self):
+ def fn(a, b):
+ x = a + b
+ x = x / 2.0
+ if x.sum() < 0:
+ return x * -1.0
+ return x
+
+ self._common(fn, 2, 5)
+
+ def test_extended_args(self):
+ too_many_adds = "+".join(["a", "b"] * 256)
+ source = (
+ f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)"
+ )
+ self._common(eval(source), 3, 1026)
+
+ def test_resume1(self):
+ def fn(a, b):
+ x = a + b
+ x = x / 2.0
+ x = x + 2.0
+ x = unsupported(x, a)
+ x = x + 2.0
+ x = x + 2.0
+ x = x + 2.0
+ return x
+
+ self._common(fn, 2, 6)
+
+ def test_resume2(self):
+ def fn(a, b):
+ x = a + b
+ x = x / 2.0
+ x = x + 2.0
+ x = indirectly_unsupported(x, a)
+ x = x + 2.0
+ x = x + 2.0
+ x = x + 2.0
+ return x
+
+ self._common(fn, 3, 7)
+
+ def test_resume3(self):
+ def fn(a, b):
+ x = a + b
+ x = x / 2.0
+ x = x + 2.0
+ x = indirectly_unsupported(x, b=a)
+ x = x + 2.0
+ x = x + 2.0
+ x = x + 2.0
+ return x
+
+ self._common(fn, 3, 7)
+
+ def test_resume4(self):
+ def fn(a, b):
+ x = a + b
+ x = x / 2.0
+ x = x + 2.0
+ x = indirectly_unsupported(a=x, b=a)
+ x = x + 2.0
+ x = x + 2.0
+ x = x + 2.0
+ return x
+
+ self._common(fn, 3, 7)
+
+ def test_resume5(self):
+ def fn(a, b):
+ x = a + b
+ x = x / 2.0
+ x = x + 2.0
+ print(x)
+ x = x + 2.0
+ x = x + 2.0
+ x = x + 2.0
+ return x
+
+ self._common(fn, 2, 6)
+
+ def test_start1(self):
+ def fn(a, b):
+ print(a)
+ x = a + b
+ x = x + 2.0
+ x = x + 2.0
+ return x
+
+ self._common(fn, 1, 3)
+
+ def test_start2(self):
+ def fn(a, b):
+ x = indirectly_unsupported(a, b)
+ x = x + 2.0
+ x = x + 2.0
+ x = x + 2.0
+ return x
+
+ self._common(fn, 2, 4)
+
+ def test_start3(self):
+ def fn(a, b):
+ x = unsupported(a, b)
+ x = x + 2.0
+ x = x + 2.0
+ x = x + 2.0
+ return x
+
+ self._common(fn, 1, 3)
+
+ def test_start4(self):
+ def fn(a, b, check):
+ if check:
+ return a + b + 10
+ else:
+ return a + b - 10
+
+ v1 = torch.randn(10)
+ v2 = torch.randn(10)
+ f = torch.zeros(1, dtype=torch.int32)
+ t = torch.ones(1, dtype=torch.int32)
+ correct1 = fn(v1, v2, t)
+ correct2 = fn(v1, v2, f)
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnt)(fn)
+ r1 = opt_fn(v1, v2, t)
+ r2 = opt_fn(v1, v2, f)
+ self.assertTrue(torch._dynamo.testing.same(r1, correct1))
+ self.assertTrue(torch._dynamo.testing.same(r2, correct2))
+ self.assertEqual(cnt.frame_count, 3)
+ self.assertEqual(cnt.op_count, 4)
+
+ def test_resume_freevars(self):
+ c1 = torch.randn(10)
+ c2 = torch.randn(10)
+
+ def fn(a, b):
+ x = a + b + (c1 - c2)
+ x = unsupported(x, x)
+ return x + (c1 - c2)
+
+ self._common(fn, 2, 5)
+
+ def test_restore_state(self):
+ def fn(a, b):
+ len_ = len
+ x = a + b
+ x = torch.add(unsupported(x, x), 1)
+ return a * x + len_(b)
+
+ if config.dynamic_shapes:
+ self._common(fn, 2, 5)
+ else:
+ self._common(fn, 2, 4)
+
+ def test_restore_range(self):
+ def fn(a, b):
+ x = a + b
+ rng = range(3, 8, 2)
+ x = unsupported(x, x)
+ for i in rng:
+ x = x + i
+ return x
+
+ self._common(fn, 2, 4)
+
+ def test_restore_range_iter(self):
+ def fn(a, b):
+ x = a + b
+ rng = iter(range(3, 8, 2))
+ x = unsupported(x, x)
+ x += next(rng)
+ return x, list(rng)
+
+ self._common(fn, 2, 2)
+
+ def test_pop_after_resume(self):
+ def fn(a, b):
+ tmp = [a + 1, b + 2, a + b]
+ x = a
+ x = unsupported(x, x)
+ for i in range(3):
+ x += tmp.pop(-1)
+ return x
+
+ self._common(fn, 2, 6)
+
+ @disable_cache_limit()
+ def test_dynamic_shapes(self):
+ def fn(a, b):
+ return a - b * 10
+
+ torch._dynamo.reset()
+ cnt_static = torch._dynamo.testing.CompileCounter()
+ with patch("torch._dynamo.config.dynamic_shapes", False):
+ opt_fn = torch._dynamo.optimize(cnt_static)(fn)
+ for i in range(10):
+ opt_fn(torch.randn(i), torch.randn(i))
+ self.assertEqual(cnt_static.frame_count, 10)
+
+ torch._dynamo.reset()
+ cnt_dynamic = torch._dynamo.testing.CompileCounter()
+ with patch("torch._dynamo.config.dynamic_shapes", True):
+ opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
+ for i in range(10):
+ opt_fn(torch.randn(i), torch.randn(i))
+ # just one graph now rather than 10
+ self.assertEqual(cnt_dynamic.frame_count, 1)
+
+ @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
+ def test_no_graph_break_on_item(self):
+ def fn(a, b):
+ x = a + b - 1.5
+ x = x.sum()
+ x.item()
+ x = x / (a + b)
+ return x
+
+ self._common(fn, 1, 6)
+
+ @patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
+ def test_graph_break_on_item(self):
+ def fn(a, b):
+ x = a + b - 1.5
+ x = x.sum()
+ x.item()
+ x = x / (a + b)
+ return x
+
+ self._common(fn, 2, 5)
+
+ def test_resume_paths_join(self):
+ def fn(x, c1, c2, c3):
+ x = x + 1
+ if c1:
+ x = x + 2
+ x = x + 3
+ if c2:
+ x = x + 4
+ x = x + 5
+ if c3:
+ x = x + 6
+ return x + 7
+
+ v1 = torch.randn(10)
+ t = torch.Tensor([True])
+ f = torch.Tensor([False])
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnt)(fn)
+ for a in (t, f):
+ for b in (t, f):
+ for c in (t, f):
+ opt_fn(v1, a, b, c)
+
+ # checking here we don't create 2^n graphs
+ self.assertEqual(cnt.frame_count, 7)
+ self.assertEqual(cnt.op_count, 10)
+
+ def test_resume_with_no_grad1(self):
+ def fn(a, b):
+ x = a + b
+ with torch.no_grad():
+ x = x + 1
+ x.sum().tolist() # graph break
+ x = x + 2
+ x = x + 3
+ return x
+
+ self._common(fn, 2, 9)
+ torch._dynamo.reset()
+ with torch.no_grad():
+ self._common(fn, 2, 9)
+
+ def test_resume_with_no_grad2(self):
+ def fn(a, b):
+ x = a + b
+ with torch.no_grad():
+ x = x + 1
+ x.sum().tolist() # graph break
+ x = x + 2
+ x.sum().tolist() # graph break
+ x = x + 3
+ x = x + 4
+ return x
+
+ self._common(fn, 3, 13)
+
+ def test_resume_with_no_grad3(self):
+ def fn(a, b):
+ x = a + b
+ with torch.no_grad():
+ with torch.no_grad():
+ x = x + 1
+ with torch.enable_grad():
+ x.sum().tolist() # graph break
+ x = x[0] + 2
+ x = x + 3
+ x = x + 4
+ return x
+
+ self._common(fn, 2, 19)
+
+ def test_resume_tuple_iterator(self):
+ def fn(a, b):
+ x = a + b
+ it = iter(tuple(range(10)))
+ x = x + next(it)
+ x = x + next(it)
+ x = x + next(it)
+ x = unsupported(x, x)
+ x = x + next(it)
+ x = x + next(it)
+ x = x + next(it)
+ x = x + next(it)
+ return x
+
+ self._common(fn, 2, 8)
+
+ def test_tuple_iterator_return(self):
+ def fn(x):
+ it = iter(tuple(range(10)))
+ x = x + next(it)
+ x = x + next(it)
+ x = unsupported(x, x)
+ x = x + next(it)
+ x = x + next(it)
+ x = unsupported(x, x)
+ x = x + next(it)
+ x = x + next(it)
+ return x, it
+
+ v1 = torch.randn(10)
+ v2, it2 = fn(v1)
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnt)(fn)
+ v3, it3 = opt_fn(v1)
+ v4, it4 = opt_fn(v1)
+ self.assertEqual(v2.tolist(), v3.tolist())
+ self.assertEqual(v2.tolist(), v4.tolist())
+ self.assertEqual(list(it2), list(it3))
+ self.assertEqual(cnt.frame_count, 3)
+ self.assertEqual(cnt.op_count, 6)
+
+ @unittest.skip("not working yet")
+ def test_tuple_iterator_mutate(self):
+ def fn(x, it):
+ x = x + next(it)
+ x = x + next(it)
+ x = x + next(it)
+ x = x + next(it)
+ return x
+
+ v1 = torch.randn(10)
+ it1 = iter(tuple(range(10)))
+ cnt = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnt)(fn)
+ self.assertEqual(opt_fn(v1, it1).tolist(), (v1 + 1 + 2 + 3).tolist())
+ self.assertEqual(list(it1), [4, 5, 6, 7, 8, 9])
+
+ def test_enumerate_not_break_graph(self):
+ def fn(a, b):
+ for i, x in enumerate(a.shape):
+ b = b + x
+ for i, x in enumerate(b.shape, 8):
+ b = b + x * i
+ return b
+
+ self._common(fn, 1, 2)
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py
new file mode 100644
index 0000000000000..5f184834418db
--- /dev/null
+++ b/test/dynamo/test_unspec.py
@@ -0,0 +1,226 @@
+# Owner(s): ["module: dynamo"]
+import functools
+import random
+import unittest
+from unittest.mock import patch
+
+import numpy as np
+import torch
+
+import torch._dynamo.testing
+from torch._dynamo.testing import same
+
+try:
+ from . import test_modules, test_repros
+except ImportError:
+ import test_modules
+ import test_repros
+
+
+def make_unspec_fn(fn):
+ @functools.wraps(fn)
+ def _fn(*args, **kwargs):
+ with patch.object(torch._dynamo.config, "specialize_int_float", False):
+ return fn(*args, **kwargs)
+
+ return _fn
+
+
+def make_unspec_cls(cls):
+ class UnspecTest(cls):
+ pass
+
+ UnspecTest.__name__ = f"Unspec{cls.__name__}"
+
+ for name in dir(cls):
+ if name.startswith("test_"):
+ fn = getattr(cls, name)
+ if not callable(fn):
+ continue
+ new_name = f"{name}_unspec"
+ fn = make_unspec_fn(fn)
+ fn.__name__ = new_name
+ setattr(UnspecTest, name, None)
+ setattr(UnspecTest, new_name, fn)
+
+ return UnspecTest
+
+
+UnspecReproTests = make_unspec_cls(test_repros.ReproTests)
+UnspecNNModuleTests = make_unspec_cls(test_modules.NNModuleTests)
+
+
+@patch.object(torch._dynamo.config, "specialize_int_float", False)
+class UnspecTests(torch._dynamo.testing.TestCase):
+ def test_numpy_correctness(self):
+ def fn(x, y, z):
+ xy = [x + y, y, False]
+ np_x = x.numpy()
+ np_y = y.numpy()
+ return {
+ "x": x,
+ "z": z,
+ "a": np_y.sum(),
+ "b": xy,
+ "c": np_y[0][0] / 68,
+ "d": np_x.sum(),
+ }, x + np_y.sum() + z
+
+ x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
+ y = torch.ones([2, 2], dtype=torch.int64)
+ z = np.int64(12)
+ res1 = fn(x, y, z)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res2 = opt_fn(x, y, z)
+ self.assertTrue(same(res1, res2))
+
+ def test_no_recompilations(self):
+ # no recompilations if passing on different numpy int values
+ def fn(x, y):
+ return {"a": x + 1, "b": y / 2}
+
+ x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ for i in range(10):
+ opt_fn(x, np.int64(i))
+ self.assertEqual(cnts.frame_count, 1)
+ self.assertEqual(cnts.op_count, 2)
+
+ def test_builtin_max_min(self):
+ # test unspecialized primitive max/min
+ def fn(x, y, z):
+ return z + 1, max(x, y), min(x - 4, y)
+
+ x = np.int64(12)
+ y = 10
+ z = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
+ res1 = fn(x, y, z)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res2 = opt_fn(x, y, z)
+ self.assertTrue(same(res1, res2))
+
+ def test_feed_random_values_into_graph_only(self):
+ def fn(shape):
+ torch.manual_seed(123)
+ x = torch.randn(shape, device="cpu") * random.randint(30, 100)
+ return x
+
+ shape = [2, 3]
+ random.seed(1)
+ res1 = fn(shape)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ random.seed(1)
+ res2 = opt_fn(shape)
+
+ self.assertTrue(same(res1, res2))
+
+ def test_random_values_with_graph_break(self):
+ def fn(x):
+ r1 = random.random()
+ y = x + random.uniform(10, 20)
+ y.sum().item()
+ r2 = random.randint(2, 18) # no graph output in this frame
+ y.sum().item()
+ return y + r1, r2
+
+ x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
+ random.seed(1)
+ res1 = fn(x)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ random.seed(1)
+ res2 = opt_fn(x)
+ self.assertTrue(same(res1, res2))
+
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_multiple_consecutive_random_calls_before_graph(self):
+ def fn(x):
+ dim1 = random.randrange(start=0, stop=5)
+ dim2 = random.randrange(start=0, stop=5)
+ dim3 = random.randrange(start=0, stop=5)
+ y = torch.rand(dim1, dim2, dim3)
+ return x + 2, y
+
+ x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
+ random.seed(1)
+ res1 = fn(x)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ random.seed(1)
+ res2 = opt_fn(x)
+ self.assertTrue(same(res1, res2))
+
+ def test_random_call_with_while_loop(self):
+ def fn(x):
+ dim1 = random.randrange(start=0, stop=3)
+ dim2 = dim1
+ while dim1 == dim2:
+ dim2 = random.randrange(start=0, stop=3)
+ return x * 2
+
+ x = torch.randn(4)
+ random.seed(1)
+ res1 = fn(x)
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ random.seed(1)
+ res2 = opt_fn(x)
+ self.assertTrue(same(res1, res2))
+
+ def test_builtin_getitem(self):
+ # builtin getitem args[0] is python list and args[1] is unspec
+ def fn(x, idx):
+ return (torch.zeros(idx), x[idx], x[idx:])
+
+ x = list(range(50))
+ ref = fn(x, 48) # 48 is unspecialized
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res = opt_fn(x, 48)
+ self.assertTrue(same(ref, res))
+
+ @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+ def test_builtin_functions_on_cuda(self):
+ def fn(x, scaler):
+ m = torch.nn.ReLU()
+ y = m(x) * scaler
+ return y
+
+ x = torch.randn([3, 6], device="cuda")
+ scaler = 0.23 # 0.23 is unspecialized
+ ref = fn(x, scaler)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res = opt_fn(x, scaler)
+ self.assertTrue(same(ref, res))
+ self.assertEqual(ref.device, res.device)
+
+ def test_unspec_float_precision(self):
+ def fn(image, scale_factor):
+ image = torch.nn.functional.interpolate(
+ image[None],
+ size=None,
+ scale_factor=scale_factor,
+ mode="bilinear",
+ recompute_scale_factor=True,
+ align_corners=False,
+ )[0]
+
+ return image.shape
+
+ x = torch.rand([3, 427, 640])
+ scale_factor = 1.873536229133606
+ ref = fn(x, scale_factor)
+ cnts = torch._dynamo.testing.CompileCounter()
+ opt_fn = torch._dynamo.optimize(cnts)(fn)
+ res = opt_fn(x, scale_factor)
+ self.assertTrue(same(ref, res))
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/dynamo/test_verify_correctness.py b/test/dynamo/test_verify_correctness.py
new file mode 100644
index 0000000000000..f9d820f44c299
--- /dev/null
+++ b/test/dynamo/test_verify_correctness.py
@@ -0,0 +1,174 @@
+# Owner(s): ["module: dynamo"]
+import importlib
+import operator
+import unittest
+from unittest.mock import patch
+
+import torch
+
+import torch._dynamo
+import torch._dynamo.config as config
+from torch._dynamo.optimizations import backends
+from torch._dynamo.testing import same
+
+
+def has_onnxruntime():
+ try:
+ importlib.import_module("onnxruntime")
+ return True
+ except ImportError:
+ return False
+
+
+def has_ipex():
+ try:
+ importlib.import_module("intel_extension_for_pytorch")
+ return True
+ except ImportError:
+ return False
+
+
+class Seq(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layers = torch.nn.Sequential(
+ torch.nn.Linear(10, 10),
+ torch.nn.ReLU(),
+ torch.nn.Linear(10, 10),
+ torch.nn.Sigmoid(),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class Conv_Bn_Relu(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(Conv_Bn_Relu, self).__init__()
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+ self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, x):
+ return self.relu(self.bn(self.conv(x)))
+
+
+def toy_example(a, b):
+ x = a / (torch.abs(a) + 1)
+ if b.sum() < 0:
+ b = b * -1
+ return x * b
+
+
+def transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+ for node in gm.graph.nodes:
+ # Checks if we're calling a function (i.e:
+ # operator.add)
+ if node.op == "call_function":
+ # The target attribute is the function
+ # that call_function calls.
+ if node.target == operator.mul:
+ node.target = operator.add
+
+ gm.graph.lint() # Does some checks to make sure the
+ # Graph is well-formed.
+
+ gm.recompile()
+ return gm
+
+
+class TestVerifyCorrectness(torch._dynamo.testing.TestCase):
+ @patch.object(config, "verify_correctness", True)
+ def test_example_inputs(self):
+ def fn(a, bc, d):
+ b, c = bc
+ return a / d - b / c
+
+ def compiler_fn(graph, example_inputs):
+ nonlocal r1
+ r1 = graph(*example_inputs)[0]
+ return graph.forward
+
+ a = torch.empty(2).fill_(1)
+ b = torch.empty(2).fill_(2)
+ c = torch.empty(2).fill_(3)
+ d = 4
+ r1 = None
+ r2 = fn(a, (b, c), d)
+ opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn)
+ r3 = opt_fn(a, (b, c), d)
+
+ self.assertIsNotNone(r1)
+ self.assertTrue(same(r1, r2))
+ self.assertTrue(same(r1, r3))
+
+ @patch.object(config, "verify_correctness", True)
+ def test_nnc(self):
+ s = Seq()
+ i = torch.randn(10)
+ r1 = s(i)
+ opt_s = torch._dynamo.optimize("nnc")(s)
+ r2 = opt_s(i)
+ self.assertTrue(same(r1, r2))
+
+ @patch.object(config, "verify_correctness", True)
+ def test_incorrect_verify_true(self):
+ """
+ If a bad optimization return a graph that
+ is not functionally equal to the original graph;
+ When config.verify_correctness=True, it will
+ check the correctness of outputs and raise an error
+ """
+ i1 = torch.randn(10)
+ i2 = torch.randn(10)
+
+ def incorrect_compile_fn(gm, example_inputs):
+ return transform(gm).forward
+
+ toy_example(i1, i2)
+ try:
+ opt_toy_example = torch._dynamo.optimize(incorrect_compile_fn)(toy_example)
+ opt_toy_example(i1, i2)
+ except RuntimeError:
+ pass
+ else:
+ self.fail("expected failure")
+
+ @patch.object(config, "verify_correctness", False)
+ def test_incorrect_verify_false(self):
+ """
+ The bad optimization return a graph that
+ is not functionally equal to the original graph;
+ When config.verify_correctness=False, wrong outputs
+ will return
+ """
+ i1 = torch.randn(10)
+ i2 = torch.randn(10)
+
+ def incorrect_compile_fn(gm, example_inputs):
+ return transform(gm).forward
+
+ r1 = toy_example(i1, i2)
+ opt_toy_example = torch._dynamo.optimize(incorrect_compile_fn)(toy_example)
+ r2 = opt_toy_example(i1, i2)
+ self.assertTrue(not same(r1, r2))
+
+ @unittest.skipIf(not has_ipex(), "requires ipex")
+ @patch.object(config, "verify_correctness", True)
+ def test_ipex_fp32(self):
+ model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1)
+ model = model.to(memory_format=torch.channels_last)
+ model = model.eval()
+ input = torch.randn(8, 3, 64, 64).contiguous(memory_format=torch.channels_last)
+ r1 = model(input)
+ opt_model = torch._dynamo.optimize(backends.ipex_fp32)(model)
+ with torch.no_grad():
+ r2 = opt_model(input)
+ self.assertTrue(same(r1, r2))
+ self.assertEqual(r2.dtype, torch.float32)
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests()
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index bc7e82bda6cdf..8d1c0dba70131 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -7,7 +7,7 @@
# LICENSE file in the root directory of this source tree.
from unittest.mock import patch
-from torch.testing._internal.common_utils import TestCase, run_tests, IS_ARM64
+from torch.testing._internal.common_utils import TestCase, run_tests, IS_ARM64, IS_WINDOWS
import torch
import torch.nn as nn
import torch.utils._pytree as pytree
@@ -60,7 +60,8 @@
try:
import sympy # noqa: F401
- HAS_SYMPY = True
+ # TODO(jansel): these tests fail on windows
+ HAS_SYMPY = not IS_WINDOWS
except ImportError:
HAS_SYMPY = False
skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
diff --git a/test/inductor/__init__.py b/test/inductor/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/test/inductor/cpp/.gitignore b/test/inductor/cpp/.gitignore
new file mode 100644
index 0000000000000..37b0b62a96b87
--- /dev/null
+++ b/test/inductor/cpp/.gitignore
@@ -0,0 +1,13 @@
+CMakeLists.txt.user
+CMakeCache.txt
+CMakeFiles
+CMakeScripts
+Testing
+Makefile
+cmake_install.cmake
+install_manifest.txt
+compile_commands.json
+CTestTestfile.cmake
+_deps
+lib
+bin
diff --git a/test/inductor/cpp/CMakeLists.txt b/test/inductor/cpp/CMakeLists.txt
new file mode 100644
index 0000000000000..cc4954fc895ad
--- /dev/null
+++ b/test/inductor/cpp/CMakeLists.txt
@@ -0,0 +1,47 @@
+project(my-project LANGUAGES C CXX)
+
+# Build output setup
+set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/test/lib)
+set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/test/lib)
+set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/test/bin)
+
+# TODO(voz): Fix hack below
+# Start hack
+list(APPEND policies_new CMP0079)
+
+foreach(policy ${policies_new})
+ if(POLICY ${policy})
+ cmake_policy(SET ${policy} NEW)
+ endif()
+endforeach()
+# End hack
+
+################################
+# GTest
+################################
+project(googletest-git NONE)
+
+include(FetchContent)
+FetchContent_Declare(
+ googletest
+ GIT_REPOSITORY https://github.com/google/googletest.git
+ GIT_TAG release-1.12.1
+)
+
+set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
+set(BUILD_GMOCK OFF CACHE BOOL "" FORCE)
+set(BUILD_GTEST ON CACHE BOOL "" FORCE)
+
+FetchContent_MakeAvailable(googletest)
+
+
+
+################################
+# Tests
+################################
+
+# TODO(voz): This is a little assumptive of just this one test, rewrite with real dir includes
+include_directories(${ATEN_INCLUDE})
+add_executable(test_cpp_prefix test_cpp_prefix.cpp ../../torchinductor/codegen/cpp_prefix.h)
+target_link_libraries(test_cpp_prefix gtest gtest_main)
+add_test(NAME test_cpp_prefix COMMAND test_cpp_prefix)
diff --git a/test/inductor/cpp/test.sh b/test/inductor/cpp/test.sh
new file mode 100755
index 0000000000000..055b740cc1e3e
--- /dev/null
+++ b/test/inductor/cpp/test.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+set -euo pipefail
+IFS=$'\n\t'
+
+cmake . -DATEN_INCLUDE:PATH=$(python -c "import torch; from torch.utils import cpp_extension; print(cpp_extension.include_paths()[0])")
+make
+./test/bin/test_cpp_prefix
diff --git a/test/inductor/cpp/test_cpp_prefix.cpp b/test/inductor/cpp/test_cpp_prefix.cpp
new file mode 100644
index 0000000000000..08d379fe3a05b
--- /dev/null
+++ b/test/inductor/cpp/test_cpp_prefix.cpp
@@ -0,0 +1,21 @@
+#include "../../torchinductor/codegen/cpp_prefix.h"
+#include
+
+TEST(testCppPrefix, testAtomicAddInt) {
+ int x = 0;
+ atomic_add(&x, 100);
+ EXPECT_EQ(x, 100);
+}
+
+TEST(testCppPrefix, testAtomicAddFloat) {
+ float x = 0.0f;
+ atomic_add(&x, 100.0f);
+ EXPECT_EQ(x, 100.0f);
+}
+
+TEST(testCppPrefix, testAtomicAddI64) {
+ int64_t x = 0.0;
+ int64_t y = 100.0;
+ atomic_add(&x, y);
+ EXPECT_EQ(x, 100);
+}
diff --git a/test/inductor/opinfo_harness.py b/test/inductor/opinfo_harness.py
new file mode 100644
index 0000000000000..86077582134dc
--- /dev/null
+++ b/test/inductor/opinfo_harness.py
@@ -0,0 +1,25 @@
+import os
+import subprocess
+
+from torch.testing._internal.common_methods_invocations import op_db
+
+if __name__ == "__main__":
+ i = 0
+ while i < len(op_db):
+ start = i
+ end = i + 20
+ os.environ["PYTORCH_TEST_RANGE_START"] = f"{start}"
+ os.environ["PYTORCH_TEST_RANGE_END"] = f"{end}"
+ popen = subprocess.Popen(
+ ["pytest", "test/inductor/test_torchinductor_opinfo.py"],
+ stdout=subprocess.PIPE,
+ )
+ for line in popen.stdout:
+ print(line.decode(), end="")
+ popen.stdout.close()
+ return_code = popen.wait()
+ if return_code:
+ raise subprocess.CalledProcessError(
+ return_code, ["pytest", "test/inductor/test_torchinductor_opinfo.py"]
+ )
+ i = end + 1
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
new file mode 100644
index 0000000000000..47e7e4c417220
--- /dev/null
+++ b/test/inductor/test_torchinductor.py
@@ -0,0 +1,3957 @@
+# Owner(s): ["module: inductor"]
+import contextlib
+import dataclasses
+import functools
+import importlib
+import random
+import sys
+import unittest
+from unittest.mock import patch
+
+import torch
+
+import torch._dynamo
+from torch._dynamo.debug_utils import same_two_models
+from torch._dynamo.testing import rand_strided, same
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch.nn import functional as F
+from torch.testing._internal.common_utils import (
+ TEST_WITH_ASAN,
+ TestCase as TorchTestCase,
+)
+from torch.utils._pytree import tree_flatten, tree_unflatten
+
+try:
+ import sympy
+
+ importlib.import_module("functorch")
+ importlib.import_module("filelock")
+
+ import torch._inductor.config
+ from functorch.compile import config as functorch_config
+ from torch._decomp import get_decompositions
+ from torch._inductor import config
+ from torch._inductor.compile_fx import compile_fx
+ from torch._inductor.ir import IndexingDiv, ModularIndexing
+ from torch._inductor.sizevars import SizeVarAllocator
+ from torch._inductor.utils import has_torchvision_roi_align, timed
+
+ # This will only pass on pytorch builds newer than roughly 5/15/2022
+ assert get_decompositions([torch.ops.aten.trace])
+ # Requires functorch
+ from torch._inductor.compile_fx import compile_fx_inner
+except (ImportError, AssertionError) as e:
+ sys.stderr.write(f"{type(e)}: {e}\n")
+ if __name__ == "__main__":
+ sys.exit(0)
+ raise unittest.SkipTest("requires sympy/functorch/filelock")
+
+
+HAS_CPU = False
+try:
+ from subprocess import CalledProcessError
+
+ from torch._inductor.codecache import CppCodeCache
+
+ CppCodeCache.load("")
+ HAS_CPU = True
+except (
+ CalledProcessError,
+ OSError,
+ torch._inductor.exc.InvalidCxxCompiler,
+ torch._inductor.exc.CppCompileError,
+):
+ pass
+
+aten = torch.ops.aten
+
+HAS_CUDA = False
+if torch.cuda.is_available():
+ try:
+ importlib.import_module("triton")
+ HAS_CUDA = True
+ except ImportError:
+ pass
+
+requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
+
+torch._inductor.config.triton.autotune = False # too slow
+
+
+def requires_decomp(fn):
+ """Decorator to disable test if a decomp is missing"""
+
+ def wrap_test(test):
+ @functools.wraps(test)
+ def maybe_test(*args, **kwargs):
+ if len(get_decompositions([fn])) == 0:
+ raise unittest.SkipTest(f"requires decomp for {fn.__name__}")
+ return test(*args, **kwargs)
+
+ return maybe_test
+
+ return wrap_test
+
+
+class TestCase(TorchTestCase):
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ cls._stack = contextlib.ExitStack()
+ cls._stack.enter_context(patch.object(config, "debug", True))
+ cls._stack.enter_context(patch.object(config.cpp, "min_chunk_size", 1))
+
+ @classmethod
+ def tearDownClass(cls):
+ cls._stack.close()
+ super().tearDownClass()
+
+
+class ToTuple(torch.nn.Module):
+ def forward(self, x):
+ return (x,)
+
+
+@dataclasses.dataclass
+class InputGen:
+ n: int
+ device: str
+
+ def dense(self):
+ return torch.randn((self.n, self.n), device=self.device)
+
+ def transposed(self):
+ return self.dense().transpose(0, 1)
+
+ def strided(self):
+ return torch.randn((self.n * 2, self.n * 3), device=self.device)[
+ self.n :, self.n :: 2
+ ]
+
+ def broadcast1(self):
+ return torch.randn((self.n,), device=self.device)
+
+ def broadcast2(self):
+ return torch.randn((1, self.n, 1), device=self.device)
+
+ def broadcast3(self):
+ return torch.randn((1,), device=self.device)
+
+ def double(self):
+ return torch.randn((self.n, self.n), device=self.device, dtype=torch.double)
+
+ def int(self):
+ return torch.arange(self.n, device=self.device, dtype=torch.int32)
+
+
+def compute_grads(args, kwrags, results, grads):
+ def gather_leaf_tensors(args, kwargs):
+ args, _ = tree_flatten(args)
+ kwargs, _ = tree_flatten(kwargs)
+ args = args + kwargs
+ leaf_tensors = [
+ arg for arg in args if isinstance(arg, torch.Tensor) and arg.requires_grad
+ ]
+ return leaf_tensors
+
+ flat_results, _ = tree_flatten(results)
+ flat_diff_results = [r for r in flat_results if r.requires_grad]
+ assert len(flat_diff_results) > 0
+
+ leaf_tensors = gather_leaf_tensors(args, kwrags)
+ assert len(leaf_tensors) > 0
+ return torch.autograd.grad(
+ flat_diff_results,
+ leaf_tensors,
+ grads,
+ allow_unused=True,
+ retain_graph=True,
+ )
+
+
+@patch.object(torch._inductor.config.triton, "cudagraphs", False)
+@patch("torch._dynamo.config.raise_on_backend_error", True)
+def check_model(
+ self: TestCase,
+ model,
+ example_inputs,
+ kwargs=None,
+ *,
+ atol=None,
+ rtol=None,
+ check_lowp=True,
+ exact_dtype=True,
+ nopython=True,
+ copy_to_cuda=True,
+ reference_in_float=True,
+ assert_equal=True,
+ check_gradient=False,
+):
+ kwargs = kwargs or {}
+ torch._dynamo.reset()
+
+ ref_inputs = example_inputs
+ ref_kwargs = kwargs
+ has_lowp_args = False
+
+ if reference_in_float:
+ # check_lowp is ignored here, it's kept just to be able to call `common` with extra arg
+ def upcast_fn(x):
+ nonlocal has_lowp_args
+ if isinstance(x, torch.Tensor) and (
+ x.dtype == torch.float16 or x.dtype == torch.bfloat16
+ ):
+ has_lowp_args = True
+ return x.float()
+ else:
+ return x
+
+ ref_inputs = list(map(upcast_fn, example_inputs))
+ ref_kwargs = {k: upcast_fn(v) for k, v in kwargs.items()}
+ if has_lowp_args:
+ if hasattr(model, "to"):
+ model = model.to(torch.float)
+
+ torch.manual_seed(0)
+
+ correct = model(*ref_inputs, **ref_kwargs)
+ # downcast the model back if needed
+ if reference_in_float and has_lowp_args:
+ if hasattr(model, "to"):
+ model = model.to(torch.half)
+
+ torch._inductor.metrics.reset()
+
+ called = False
+
+ def compile_fx_wrapper(model_, example_inputs_):
+ nonlocal called
+ called = True
+ return compile_fx(model_, example_inputs_)
+
+ def run(*ex, **kwargs):
+ return model(*ex, **kwargs)
+
+ run = torch._dynamo.optimize(compile_fx_wrapper, nopython=nopython)(run)
+
+ torch.manual_seed(0)
+ actual = run(*example_inputs, **kwargs)
+ # if not called:
+ # exp = torch._dynamo.explain(run, *example_inputs)
+ # print("Explain:", exp[0])
+ # for graph in exp[2]:
+ # print("Graph", graph)
+ assert called, "Ran graph without calling compile_fx"
+
+ assert type(actual) == type(correct)
+
+ correct_flat, correct_spec = tree_flatten(correct)
+ actual_flat, _ = tree_flatten(actual)
+ if reference_in_float:
+ correct_flat = tuple(
+ y.to(x.dtype)
+ if isinstance(y, torch.Tensor) and y.dtype.is_floating_point
+ else y
+ for x, y in zip(actual_flat, correct_flat)
+ )
+ correct = tree_unflatten(correct_flat, correct_spec)
+
+ if assert_equal:
+ self.assertEqual(
+ actual,
+ correct,
+ atol=atol,
+ rtol=rtol,
+ equal_nan=True,
+ exact_dtype=exact_dtype,
+ )
+ else:
+ for correct_val, actual_val in zip(correct_flat, actual_flat):
+ if isinstance(correct_val, torch.Tensor):
+ assert correct_val.device == actual_val.device
+ assert correct_val.size() == actual_val.size()
+ assert correct_val.stride() == actual_val.stride()
+ assert correct_val.layout == actual_val.layout
+ if exact_dtype:
+ assert correct_val.dtype == actual_val.dtype
+
+ if check_gradient:
+
+ # generate random unit norm gradients
+ grads = [
+ torch.rand(r.shape, device=r.device, dtype=r.dtype)
+ for r in correct_flat
+ if r.requires_grad
+ ]
+ for g in grads:
+ g /= g.norm()
+
+ correct_grad = compute_grads(ref_inputs, ref_kwargs, correct, grads)
+ actual_grad = compute_grads(example_inputs, kwargs, actual, grads)
+
+ self.assertEqual(
+ actual_grad,
+ correct_grad,
+ atol=atol,
+ rtol=rtol,
+ equal_nan=True,
+ exact_dtype=exact_dtype,
+ )
+
+ torch._dynamo.reset()
+
+
+@patch.object(torch._inductor.config.triton, "cudagraphs", False)
+def check_model_cuda(
+ self: TestCase,
+ model,
+ example_inputs,
+ kwargs=None,
+ *,
+ atol=None,
+ rtol=None,
+ check_lowp=True,
+ exact_dtype=True,
+ nopython=True,
+ copy_to_cuda=True,
+ reference_in_float=True,
+ assert_equal=True,
+ check_gradient=False,
+):
+ kwargs = kwargs or {}
+ if hasattr(model, "to"):
+ model = model.to("cuda")
+
+ def copy_fn(x):
+ # preserve strides of the input on the device
+ if not isinstance(x, torch.Tensor):
+ return x
+ return torch.empty_strided(
+ x.size(), x.stride(), device="cuda", dtype=x.dtype
+ ).copy_(x)
+
+ if copy_to_cuda:
+ example_inputs = tuple(copy_fn(x) for x in example_inputs)
+
+ check_model(
+ self,
+ model,
+ example_inputs,
+ kwargs,
+ atol=atol,
+ rtol=rtol,
+ exact_dtype=exact_dtype,
+ nopython=nopython,
+ reference_in_float=reference_in_float,
+ assert_equal=assert_equal,
+ check_gradient=check_gradient,
+ )
+
+ if check_lowp:
+
+ def downcast_fn(x):
+ if not isinstance(x, torch.Tensor) or not x.dtype == torch.float:
+ return x
+ return torch.empty_strided(
+ x.size(), x.stride(), device="cuda", dtype=torch.half
+ ).copy_(x)
+
+ example_inputs = list(map(downcast_fn, example_inputs))
+ if hasattr(model, "to"):
+ model = model.to(torch.half)
+ check_model(
+ self,
+ model,
+ example_inputs,
+ kwargs,
+ atol=atol,
+ rtol=rtol,
+ exact_dtype=exact_dtype,
+ nopython=nopython,
+ reference_in_float=reference_in_float,
+ assert_equal=assert_equal,
+ check_gradient=check_gradient,
+ )
+
+
+class SweepInputs2:
+ input_gen_types1 = [
+ "dense",
+ "transposed",
+ "strided",
+ "broadcast1",
+ "broadcast2",
+ "broadcast3",
+ "double",
+ "int",
+ ]
+ input_gen_types2 = input_gen_types1
+ gen = None
+
+ @staticmethod
+ def kernel(a, b):
+ return (a + b,)
+
+ @classmethod
+ def gen_template(cls, name1, name2):
+ def test(self):
+ check_model(
+ self,
+ cls.kernel,
+ (
+ getattr(cls.gen, name1)(),
+ getattr(cls.gen, name2)(),
+ ),
+ )
+
+ test.__name__ = f"test_{cls.gen.device}_{name1}_{name2}"
+ setattr(cls, test.__name__, test)
+
+ @classmethod
+ def populate(cls):
+ for name1 in cls.input_gen_types1:
+ for name2 in cls.input_gen_types2:
+ cls.gen_template(name1, name2)
+
+
+class SweepInputsCpuTest(SweepInputs2, TestCase):
+ gen = InputGen(10, "cpu")
+
+
+SweepInputsCpuTest.populate()
+
+
+class TestIndexingSimplification(TorchTestCase):
+ def test_indexing_simplification(self):
+ sizevars = SizeVarAllocator()
+ i0 = sympy.Symbol("i0")
+ i1 = sympy.Symbol("i1")
+ i2 = sympy.Symbol("i2")
+ r3 = sympy.Symbol("r3")
+
+ var_ranges = {i0: 3136, i1: 64, i2: 32, r3: 3}
+ expr = (
+ 128 * i2
+ + ModularIndexing(i1, 1, 64)
+ + 64 * ModularIndexing(i1 + 64 * r3, 64, 2)
+ )
+ # check that `i1//64` is removed when i1 is always less than 64,
+ # and the next simplificaton doesn't happen
+ self.assertEqual(
+ sizevars.simplify_with_ranges(expr, var_ranges),
+ i1 + 128 * i2 + 64 * ModularIndexing(r3, 1, 2),
+ )
+ # all the modular indexing should be removed when the body cant be larger than the modulus
+ var_ranges[r3] = 2
+ self.assertEqual(
+ sizevars.simplify_with_ranges(expr, var_ranges), i1 + 128 * i2 + 64 * r3
+ )
+
+ # small terms should be kept if the rest is not guaranteed to be divisible
+ self.assertEqual(
+ sizevars.simplify_with_ranges(IndexingDiv(r3 + i2 + i1, 32), var_ranges),
+ IndexingDiv(r3 + i2 + i1, 32),
+ )
+
+ expr = ModularIndexing(2 * i2 + r3, 1, 64)
+ # modular indexing is removed if base is smaller than modulo
+ self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), 2 * i2 + r3)
+
+ # check the same thing but with symbolic divisor
+ self.assertEqual(IndexingDiv(r3 * i0, r3), i0)
+ self.assertEqual(ModularIndexing(r3 * i0, r3, 10), ModularIndexing(i0, 1, 10))
+
+ # (10*i) % 10 is always zero and should get optimized away
+ self.assertEqual(
+ ModularIndexing(i0 + i1 * 10, 1, 10), ModularIndexing(i0, 1, 10)
+ )
+
+ # ((20*i)//2) % 10 is always zero and should get optimized away
+ self.assertEqual(
+ ModularIndexing(i0 + i1 * 20, 2, 10), ModularIndexing(i0, 2, 10)
+ )
+
+ # the same things happens with symbolic divisor
+ self.assertEqual(
+ ModularIndexing(i0 + i1 * i2 * r3, i2, r3), ModularIndexing(i0, i2, r3)
+ )
+
+ # Constant fold from divisor into base
+ self.assertEqual(ModularIndexing(i0 * 4, 2, 10), ModularIndexing(i0 * 2, 1, 10))
+ self.assertEqual(IndexingDiv(i0 * 4, 2), i0 * 2)
+
+ # Nested modular indexing is correctly simplified
+ var_ranges = {"i1": 13, "i2": 121}
+ expr = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784), 1, 28)
+ self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)
+ expr = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784) + 1, 1, 28)
+ self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)
+ var_ranges = {"i2": 784}
+ expr = ModularIndexing(ModularIndexing(i2, 1, 28), 7, 4)
+ expected = IndexingDiv(ModularIndexing(i2, 1, 28), 7)
+ self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expected)
+ expr = ModularIndexing(ModularIndexing(i2, 1, 28) + 1, 7, 4)
+ self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)
+
+ def test_indexing_join(self):
+ sizevars = SizeVarAllocator()
+ i0 = sympy.Symbol("i0")
+ i1 = sympy.Symbol("i1")
+ i2 = sympy.Symbol("i2")
+
+ # join two ModularIndexing calls into one larger one when possible
+ expr1 = ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
+ self.assertEqual(
+ sizevars.simplify_with_ranges(expr1, {}), ModularIndexing(i0, 1, 128)
+ )
+
+ # it should also work with a scale
+ self.assertEqual(
+ sizevars.simplify_with_ranges(2 * expr1, {}),
+ 2 * ModularIndexing(i0, 1, 128),
+ )
+
+ # it should work when divisor is not 1
+ expr2 = ModularIndexing(i0, 3, 32) + 32 * ModularIndexing(i0, 32 * 3, 4)
+ simplified = sizevars.simplify_with_ranges(expr2, {})
+ self.assertEqual(simplified, ModularIndexing(i0, 3, 128))
+ self.assertEqual(expr2.subs({i0: 39485}), simplified.subs({i0: 39485}))
+
+ # it should not happen in this case as the modulus is wrong
+ expr3 = ModularIndexing(i0, 1, 30) + 32 * ModularIndexing(i0, 32, 4)
+ self.assertEqual(sizevars.simplify_with_ranges(expr3, {}), expr3)
+
+ # check that it also works with a modulus>1
+ expr4 = ModularIndexing(i0, 10, i1) + i1 * ModularIndexing(i0, i1 * 10, i2)
+ res0 = expr4.subs({i0: 24056, i1: 13, i2: 19})
+ simplified = sizevars.simplify_with_ranges(expr4, {})
+ res1 = simplified.subs({i0: 24056, i1: 13, i2: 19})
+ self.assertEqual(res0, res1)
+ self.assertEqual(simplified, ModularIndexing(i0, 10, i1 * i2))
+
+ # and also works with an offset
+ self.assertEqual(
+ sizevars.simplify_with_ranges(expr4 + 10, {}),
+ ModularIndexing(i0, 10, i1 * i2) + 10,
+ )
+
+ # works for ModularIndexing + IndexingDiv
+ expr5 = 197 * IndexingDiv(i0, 197) + ModularIndexing(i0, 1, 197)
+ simplified = sizevars.simplify_with_ranges(expr5, {})
+ self.assertEqual(simplified, i0)
+ self.assertEqual(expr5.subs({i0: 39485}), simplified.subs({i0: 39485}))
+
+ # works with a scale
+ self.assertEqual(
+ sizevars.simplify_with_ranges(2 * expr5, {}),
+ 2 * i0,
+ )
+
+ # divisor != 1
+ expr6 = 197 * IndexingDiv(i0, 197 * 3) + ModularIndexing(i0, 3, 197)
+ simplified = sizevars.simplify_with_ranges(expr6, {})
+ self.assertEqual(simplified, IndexingDiv(i0, 3))
+ self.assertEqual(expr6.subs({i0: 39485}), simplified.subs({i0: 39485}))
+
+
+class CommonTemplate:
+ @classmethod
+ def install(my_cls, other_cls, suffix): # noqa: B902
+ for name, value in my_cls.__dict__.items():
+ if name.startswith("test_"):
+ setattr(other_cls, f"{name}_{suffix}", value)
+
+ def test_bool(self):
+ def fn(a, b):
+ return (
+ a + b,
+ a * b,
+ a & b,
+ a | b,
+ a ^ b,
+ torch.logical_and(a, b),
+ torch.logical_or(a, b),
+ torch.logical_not(a),
+ torch.sign(b),
+ )
+
+ self.common(
+ fn,
+ (
+ torch.tensor([True, False, True, False]),
+ torch.tensor([False, False, True, True]),
+ ),
+ )
+
+ def test_add_const_int(self):
+ def fn(a):
+ return (a + 1,)
+
+ self.common(fn, (torch.randn(32),))
+
+ def test_add_const_float(self):
+ def fn(a):
+ return (a + 1.5,)
+
+ self.common(fn, (torch.randn(32),))
+
+ def test_add_inplace_permuted(self):
+ def fn(x, y):
+ return x.add_(y)
+
+ x = torch.ones([2, 12, 13, 17]).transpose(1, 2)
+ y = torch.randn([2, 13, 1, 17])
+
+ self.common(fn, (x, y))
+
+ def test_abs(self):
+ def fn(a):
+ return (a / (torch.abs(a) + 1),)
+
+ self.common(fn, (torch.randn(17),))
+
+ def test_sgn(self):
+ def fn(a):
+ return torch.sgn(a), torch.sgn(a + 1) - 1
+
+ self.common(fn, [torch.linspace(-10, 10, 41)])
+
+ def test_max_min(self):
+ def fn(a, b):
+ return (torch.maximum(a, b), torch.minimum(a, b))
+
+ self.common(fn, (torch.randn(8), torch.randn(8)))
+
+ def test_horizonal_fusion1(self):
+ def fn(a, b, c):
+ return (a + b, a - c, b * c)
+
+ self.common(
+ fn, (torch.randn(8, 16, 16), torch.randn(8, 16, 16), torch.randn(1, 16, 1))
+ )
+
+ def test_horizonal_fusion2(self):
+ def fn(a, b, c):
+ return a + 1, b + 2, c + 3
+
+ self.common(fn, (torch.randn(8, 16, 8), torch.randn(8, 16), torch.randn(16, 8)))
+
+ def test_vertical_fusion1(self):
+ def fn(sa, ct, p):
+ # From torchbench.pyhpc_equation_of_state
+ v17 = -3.087032500374211e-7
+ v18 = -1.988366587925593e-8
+ v19 = -1.061519070296458e-11
+ v20 = 1.550932729220080e-10
+ t15 = v19 * ct
+ t19 = v17 + ct * (v18 + t15) + v20 * sa
+ t20 = 1.0 / t19
+ t128 = t19 * p
+ return t20 + t128
+
+ self.common(
+ fn,
+ (
+ torch.randn(204, 204, 26),
+ torch.randn(204, 204, 26),
+ torch.randn(26),
+ ),
+ )
+ self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+
+ def test_sum1(self):
+ def fn(a, b):
+ return ((a + b).sum(-1),)
+
+ self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
+
+ def test_sum2(self):
+ def fn(a, b):
+ return ((a + b).sum([1, 2]), (a + b).sum(-1))
+
+ self.common(fn, (torch.randn(8, 9, 3, 21), torch.randn(8, 9, 3, 21)))
+
+ def test_sum3(self):
+ def fn(a, b):
+ r1 = a + b
+ r2 = r1.sum(-1)
+ r3 = torch.squeeze(b) + 10
+ return (r1, r2, r3)
+
+ # Mismatched elements: 2 / 10 (20.0%)
+ # Greatest absolute difference: 0.0029296875 at index (8,) (up to 1e-05 allowed)
+ # Greatest relative difference: 0.0017482517482517483 at index (6,) (up to 0.001 allowed)
+ self.common(fn, (torch.randn(10, 10), torch.randn(1, 10)), atol=1e-5, rtol=2e-3)
+
+ def test_sum4(self):
+ def fn(a):
+ b = a + 1
+ c = b.sum(-1)
+ d = c + 3
+ e = d.sum(-1)
+ f = e + 5
+ return (f, e, d, c, b)
+
+ self.common(fn, (torch.randn(1, 16, 8, 8),))
+
+ def test_sum5(self):
+ def fn(a):
+ b = a + 1
+ c = b.sum(-1)
+ d = c + 3
+ e = d.sum(-1)
+ f = e + 5
+ return (f,)
+
+ self.common(fn, (torch.randn(1, 17, 8, 9),))
+
+ def test_reduction1(self):
+ def fn(a):
+ return (a.sum(), a.max(), a.min(), a.argmax(), a.argmin())
+
+ self.common(fn, (torch.tensor([float("-inf"), 0.0, float("inf")]),))
+
+ def test_reduction2(self):
+ def fn(a):
+ # FIXME: a.argmax
+ return (a.sum(), a.max(), a.min(), a.argmin())
+
+ self.common(fn, (torch.full((4,), float("inf")),))
+
+ def test_reduction3(self):
+ def fn(a):
+ # FIXME: a.argmin
+ return (a.sum(), a.max(), a.min(), a.argmax())
+
+ self.common(fn, (torch.full((4,), float("-inf")),))
+
+ @patch.object(config, "dynamic_shapes", False)
+ def test_unroll_small_reduction(self):
+ def fn(x):
+ val1, index1 = x.min(-1)
+ val2, index2 = x.max(-1)
+ return (
+ val1,
+ index1,
+ val2,
+ index2,
+ x.sum(-1),
+ (x > 1).any(-1),
+ (x > 0).all(-1),
+ x.argmin(-1),
+ x.argmax(-1),
+ x.amin(-1),
+ x.amax(-1),
+ )
+
+ with patch.object(config, "unroll_reductions_threshold", 8):
+ # small sized reductions will get unrolled
+ self.common(fn, (torch.randn(8, 3),))
+ torch._dynamo.reset()
+ with patch.object(config, "unroll_reductions_threshold", 1):
+ # make sure things also work if they aren't unrolled
+ self.common(fn, (torch.randn(8, 3),))
+
+ def test_multilayer_low_prec(self):
+ # fp16 nyi for cpu
+ if self.device == "cpu":
+ raise unittest.SkipTest("requires CUDA")
+
+ def fn(a):
+ return torch.mean(a)
+
+ self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float16),)))
+
+ def test_expanded_reduction(self):
+ def fn(x, y):
+ z = x * y
+ return z.sum((0, 1))
+
+ self.common(fn, (torch.randn(2, 197, 256), torch.randn(2, 1, 256)))
+
+ def test_min_max_reduction(self):
+ def fn(a, b):
+ return ((a + b).max(), (a + b).min(), torch.amax(a + 1, keepdim=True))
+
+ self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
+
+ def test_sum_int(self):
+ def fn(x):
+ return 2 * x.sum(-1) + x.sum()
+
+ dtypes = torch.bool, torch.uint8, torch.int
+ inps = [torch.randint(2, (64,), dtype=dtype) for dtype in dtypes]
+ for i in inps:
+ self.common(fn, (i,), check_lowp=False)
+
+ def test_sum_dtype(self):
+ def fn(x):
+ return x * x.sum(-1, dtype=torch.double) + x.sum(dtype=torch.double)
+
+ self.common(fn, (torch.ones(32, 32) * 70,))
+
+ def test_clamp(self):
+ def fn(a, b):
+ return (a.clamp(-0.1, 0.1), b.clamp(0), torch.clamp(a + b, max=0))
+
+ self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
+
+ def test_arange1(self):
+ def fn(x):
+ rng1 = torch.arange(8 * 8, dtype=torch.float32, device=x.device).view(8, 8)
+ rng2 = torch.arange(10, 18, device=x.device)
+ tmp = x * rng1
+ return tmp, tmp + rng2
+
+ self.common(fn, (torch.randn(8, 8),))
+
+ def test_arange2(self):
+ def fn(x):
+ rng1 = torch.arange(8, device=x.device)
+ return (x + rng1,)
+
+ self.common(fn, (torch.randint(4, (8, 8)),), check_lowp=False)
+
+ def test_arange3(self):
+ def fn(x):
+ return x + torch.ops.aten.arange.start_step(
+ 0, 53, 4, dtype=torch.int64, device=x.device
+ )
+
+ self.common(fn, (torch.randn(14),))
+
+ def test_arange4(self):
+ def fn(x):
+ return x - torch.arange(512, -512, -1.0, device=x.device)
+
+ self.common(fn, (torch.randn(1024),))
+
+ def test_linspace(self):
+ def fn(x):
+ return torch.linspace(0.125, 0.875, 7, device=x.device) + x
+
+ self.common(fn, (torch.randn(1, 7),))
+
+ def test_tensor1(self):
+ def fn(x):
+ return torch.tensor([1], device=x.device) + x, torch.tensor(
+ 5, device=x.device
+ )
+
+ self.common(fn, (torch.randn(10),))
+
+ def test_tensor2(self):
+ def fn(x):
+ return torch.tensor(list(range(2, 40, 2)), device=x.device) + x
+
+ self.common(fn, (torch.randn(1),))
+
+ def test_tensor3(self):
+ def fn(x):
+ return (
+ torch.tensor([], device=x.device),
+ torch.tensor([1, 2], device=x.device) + 1,
+ torch.tensor([1, 2, 3], device=x.device) + 2,
+ torch.tensor([1, 2, 3, 4], device=x.device) + x,
+ )
+
+ self.common(fn, [torch.randn(4)])
+
+ def test_views1(self):
+ def fn1(x, y):
+ return (x.view(size2) + y,)
+
+ def fn2(x, y):
+ return ((x + 1).view(size2) + y,)
+
+ views = [
+ ([5 * 7], [5, 7]),
+ ([2 * 3 * 4 * 5 * 6 * 7], [2, 3, 4, 5, 6, 7]),
+ ([2 * 3, 4, 5, 6 * 7], [2, 3, 4, 5, 6, 7]),
+ ([10 * 5, 20], [10, 5, 20]),
+ ([1, 10, 1], [10]),
+ ([10, 1, 10, 1, 10], [10, 100]),
+ ([2, 2, 2, 2], [4, 4]),
+ ]
+ for size1, size2 in views:
+ self.common(fn1, (torch.randn(size1), torch.randn(size2)))
+ self.common(fn2, (torch.randn(size1), torch.randn(size2)))
+
+ for size2, size1 in views:
+ self.common(fn1, (torch.randn(size1), torch.randn(size2)))
+ self.common(fn2, (torch.randn(size1), torch.randn(size2)))
+
+ def test_views2(self):
+ def fn1(x):
+ return (x.view(size2) + 1,)
+
+ def fn2(x):
+ return ((x * 2).view(size2) + 1,)
+
+ for size1, size2 in [
+ ([2, 2, 2, 2], [4, -1]),
+ ([10, 1, 10, 1, 10], [-1, 100]),
+ ([10 * 5, 20], [10, -1, 20]),
+ ]:
+ self.common(fn1, (torch.randn(size1),))
+ self.common(fn2, (torch.randn(size1),))
+
+ def test_views3(self):
+ # example taken from hf_BigBird
+ def forward(arg1, arg2):
+ index = torch.ops.aten.index(arg1, [arg2])
+ view_1 = torch.ops.aten.view(index, [1, 2232, 64])
+ view_2 = torch.ops.aten.view(view_1, [1, 12, 62, 192])
+ return view_2
+
+ self.common(
+ forward,
+ (
+ rand_strided((64, 64), (64, 1), torch.float32),
+ rand_strided((2232,), (1,), torch.int64),
+ ),
+ )
+
+ def test_relu(self):
+ def fn(a, b):
+ return (torch.relu(a), torch.relu(a + b) / 10)
+
+ self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
+
+ def test_exp(self):
+ def fn(a, b):
+ return (torch.exp(a), torch.exp(a + b))
+
+ self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
+
+ def test_sigmoid(self):
+ def fn(a, b):
+ return (torch.sigmoid(a), torch.sigmoid(a + b))
+
+ self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
+
+ def test_round(self):
+ def fn(a, b):
+ return torch.round(a), torch.round(b + 1), torch.round(a, decimals=2)
+
+ # without manual_seed, there is some chance this test fails due to:
+ # https://github.com/openai/triton/issues/530
+ torch.manual_seed(0)
+
+ # with *100 we are always getting a number exactly at .5 which we don't do right in half
+ self.common(fn, (torch.randn(8, 8) * 100, torch.randn(8, 8) * 10))
+
+ def test_round_correctness(self):
+ if self.device == "cuda":
+ raise unittest.SkipTest("need to debug tl.libdevice on A100/V100")
+
+ def fn(a):
+ return torch.round(a)
+
+ self.common(
+ fn,
+ [torch.arange(-10, 10, 0.1, dtype=torch.float64)],
+ check_lowp=False,
+ )
+
+ def test_silu(self):
+ def fn(a):
+ return (torch.nn.functional.silu(a),)
+
+ self.common(fn, (torch.randn(8, 8),))
+
+ # TODO(voz): Re-enable this test ASAP https://github.com/pytorch/pytorch/issues/82763
+ @unittest.skip("Skipping due to op bugs")
+ def test_nan_to_num(self):
+ def fn(a):
+ return (
+ torch.nan_to_num(a),
+ torch.nan_to_num(a, nan=3.0),
+ torch.nan_to_num(a, nan=None),
+ torch.nan_to_num(a, posinf=4.0),
+ torch.nan_to_num(a, neginf=5.0),
+ torch.nan_to_num(a, nan=3.0, posinf=4.0, neginf=5.0),
+ )
+
+ self.common(
+ fn,
+ (torch.tensor((float("nan"), float("inf"), float("-inf"), 1.0)),),
+ check_lowp=False, # a much more elaborate test is required to match finfo max's for float and half
+ )
+
+ def test_div1(self):
+ def fn(a, b):
+ return (
+ aten.div(a, b, rounding_mode=None),
+ aten.div(a, b, rounding_mode="floor"),
+ aten.div(a, b, rounding_mode="trunc"),
+ a / b,
+ a // b,
+ )
+
+ self.common(fn, (torch.randn(8, 8) * 100, torch.randn(8, 8) * 100))
+
+ def test_div2(self):
+ def fn(a, b):
+ return (
+ aten.div(a, b, rounding_mode=None),
+ aten.div(a, b, rounding_mode="floor"),
+ aten.div(a, b, rounding_mode="trunc"),
+ a / b,
+ a // b,
+ )
+
+ self.common(fn, (torch.randint(-100, 100, [8, 8]), 100 * torch.randn(8, 8)))
+
+ def test_div3(self):
+ def fn(a, b):
+ return (
+ aten.div(a, b, rounding_mode=None),
+ aten.div(a, b, rounding_mode="floor"),
+ aten.div(a, b, rounding_mode="trunc"),
+ a / b,
+ a // b,
+ )
+
+ a = torch.randint(1, 100, [8, 8])
+ self.common(fn, (a * 2, a))
+
+ def test_div4(self):
+ def fn(a, b):
+ return (
+ aten.div(a, b, rounding_mode=None),
+ aten.div(a, b, rounding_mode="floor"),
+ aten.div(a, b, rounding_mode="trunc"),
+ a / b,
+ a // b,
+ )
+
+ self.common(
+ fn,
+ (torch.randint(-100, 0, [8, 8]), torch.randint(1, 10, [8, 8])),
+ )
+
+ def test_div5(self):
+ def fn(a, b):
+ return (
+ aten.div(a, b, rounding_mode=None),
+ aten.div(a, b, rounding_mode="floor"),
+ aten.div(a, b, rounding_mode="trunc"),
+ a / b,
+ a // b,
+ )
+
+ # divide a scalar
+ self.common(fn, (torch.randint(-100, 0, [8, 8]), 16))
+
+ def test_div6(self):
+ def fn(a, b):
+ return (
+ aten.div(a, b, rounding_mode=None),
+ aten.div(a, b, rounding_mode="floor"),
+ aten.div(a, b, rounding_mode="trunc"),
+ a / b,
+ a // b,
+ )
+
+ # treat boolean as integer
+ self.common(
+ fn,
+ (torch.ones([8, 8], dtype=torch.bool), torch.randint(-100, -1, [8, 8])),
+ )
+
+ def test_div7(self):
+ def fn(a, b):
+ return (
+ aten.div(a, b, rounding_mode=None),
+ aten.div(a, b, rounding_mode="floor"),
+ aten.div(a, b, rounding_mode="trunc"),
+ a / b,
+ a // b,
+ )
+
+ self.common(
+ fn,
+ (
+ torch.randint(2**32, 2**40, [100, 100]),
+ torch.randint(-10, -1, [100, 100]),
+ ),
+ )
+
+ def test_div8(self):
+ def fn(a, b):
+ return (
+ aten.div(a, b, rounding_mode=None),
+ aten.div(a, b, rounding_mode="floor"),
+ aten.div(a, b, rounding_mode="trunc"),
+ a / b,
+ a // b,
+ )
+
+ self.common(fn, (1024, 100))
+
+ def test_both_scalars(self):
+ def fn(a, b):
+ return (
+ aten.add(a, b),
+ aten.add(b, a),
+ aten.sub(a, b),
+ aten.sub(b, a),
+ aten.mul(a, b),
+ aten.mul(b, a),
+ )
+
+ self.common(fn, (4, 3.3), reference_in_float=False)
+
+ def test_sum_keepdims(self):
+ def fn(a, b):
+ return (torch.sum(a + b, -1, keepdim=True),)
+
+ self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
+
+ def test_softmax(self):
+ def fn(a, b):
+ return (torch.softmax(a + b, -1), torch.softmax(a, 0), torch.softmax(b, 1))
+
+ self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
+
+ def test_log_softmax(self):
+ def fn(a, b):
+ return (F.log_softmax(a + b, -1), F.log_softmax(a, 0), F.log_softmax(b, 1))
+
+ self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
+
+ def test_transpose(self):
+ def fn(a, b):
+ return (
+ torch.t(a) + b,
+ torch.transpose(b * 2, 0, 1) + 10,
+ )
+
+ self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
+
+ def test_permute(self):
+ def fn(a):
+ return (
+ torch.permute(a + 1, [2, 1, 4, 0, 3]) + 2,
+ torch.permute(a, [2, 1, 4, 0, 3]) + 2,
+ )
+
+ self.common(fn, (torch.randn(2, 2, 2, 2, 2),))
+
+ def test_expand(self):
+ def fn(a):
+ return (
+ (a + 1).expand(3, 4, 2, 3, 2) + 2,
+ a.expand(2, 1, 2, 3, 2) + 2,
+ ), a.expand(2, -1, 5, -1)
+
+ self.common(fn, (torch.randn(2, 1, 2),))
+
+ def test_squeeze1(self):
+ def fn(a):
+ return ((a + 1).squeeze() + 2, a.squeeze() + 2)
+
+ self.common(fn, (torch.randn(1, 2, 1, 2, 2, 1, 1),))
+
+ def test_squeeze2(self):
+ def fn(a):
+ return ((a + 1).squeeze(-1).squeeze(2) + 2, a.squeeze(0) + 2)
+
+ self.common(fn, (torch.randn(1, 2, 1, 2, 2, 2, 1),))
+
+ def test_simplify_loops(self):
+ def fn(a, b):
+ return a + b
+
+ self.common(
+ fn,
+ (
+ torch.randn(2, 3, 4, 5, 6),
+ torch.randn(4, 2, 3, 5, 6).permute(1, 2, 0, 3, 4),
+ ),
+ )
+
+ def test_unsqueeze(self):
+ def fn(a):
+ return (
+ torch.unsqueeze(a + 1, -1) + 2,
+ torch.unsqueeze(a, 2) + 2,
+ torch.unsqueeze(a + 1, 0) + 2,
+ torch.unsqueeze(a, -2) + 2,
+ )
+
+ self.common(
+ fn,
+ (
+ torch.randn(
+ 2,
+ 2,
+ 2,
+ 2,
+ ),
+ ),
+ )
+
+ def test_unsqueeze_inplace(self):
+ def fn(a):
+ tmp1 = a + 1
+ aten.unsqueeze_(tmp1, 2)
+ tmp2 = aten.unsqueeze_(a + 1, 0) + 2
+ return (tmp1, tmp2)
+
+ self.common(
+ fn,
+ (
+ torch.randn(
+ 2,
+ 2,
+ 2,
+ 2,
+ ),
+ ),
+ )
+
+ def test_addmm(self):
+ def fn(a, b, c):
+ return (torch.addmm(a + 1, b + 2, c + 3) + 4,)
+
+ self.common(
+ fn,
+ (
+ torch.randn(8, 8),
+ torch.randn(8, 8),
+ torch.randn(8, 8),
+ ),
+ )
+
+ def test_linear1(self):
+ mod = torch.nn.Sequential(
+ torch.nn.Linear(8, 16),
+ torch.nn.Sigmoid(),
+ ToTuple(),
+ )
+ self.common(mod, (torch.randn(2, 8),))
+
+ def test_linear2(self):
+ mod = torch.nn.Sequential(
+ torch.nn.Linear(8, 8),
+ torch.nn.ReLU(),
+ torch.nn.Linear(8, 8),
+ torch.nn.ReLU(),
+ torch.nn.Linear(8, 8),
+ torch.nn.ReLU(),
+ torch.nn.Linear(8, 8),
+ torch.nn.ReLU(),
+ )
+ self.common(mod, (torch.randn(2, 8),))
+
+ def test_bmm1(self):
+ def fn(a, b):
+ return (
+ torch.bmm(a, b),
+ torch.bmm(a + 1, b + 2) + 3,
+ )
+
+ self.common(
+ fn,
+ (
+ torch.randn(2, 8, 8),
+ torch.randn(2, 8, 8),
+ ),
+ check_lowp=False,
+ )
+ self.common(
+ fn,
+ (
+ torch.randn(1, 16, 8),
+ torch.randn(1, 8, 10),
+ ),
+ check_lowp=False,
+ )
+
+ def test_bmm2(self):
+ def fn(a, b):
+ return torch.bmm(a.permute(0, 2, 1), b)
+
+ self.common(
+ fn,
+ (
+ torch.randn(1, 8, 8),
+ torch.randn(1, 8, 8),
+ ),
+ check_lowp=False,
+ )
+
+ def test_gather1(self):
+ def fn(a, b):
+ return (
+ torch.gather(a.expand([4, 5, 10, 6]), 3, b + 1),
+ torch.gather(a.expand([4, 5, 10, 6]), -1, b + 1),
+ )
+
+ self.common(
+ fn,
+ (
+ torch.randn([1, 1, 10, 6]),
+ torch.randint(5, [4, 5, 10, 1], dtype=torch.int64),
+ ),
+ )
+
+ def test_gather2(self):
+ # 0d tensor
+ def fn(a, b):
+ return torch.gather(a, 0, b) + torch.gather(a, -1, b)
+
+ x = torch.tensor(123)
+ y = torch.tensor(0)
+ self.assertEqual(fn(x, y), x + x)
+
+ def test_slice1(self):
+ def fn(a):
+ return (
+ a[:, :10, 0] + a[:, 10:, 0],
+ (a + 1)[:, :10, 0] + (a + 1)[:, 10:, 0],
+ )
+
+ self.common(
+ fn,
+ (torch.randn([2, 20, 2]),),
+ )
+
+ def test_slice2(self):
+ def fn(a):
+ return (
+ a[:-1, ::2, -1] + a[-1:, 1::2, -2],
+ (a + 1)[:-1, ::2, -1] + (a + 2)[-1:, 1::2, -2],
+ )
+
+ self.common(
+ fn,
+ (torch.randn([2, 20, 2]),),
+ )
+
+ def test_split_with_sizes(self):
+ def fn(a, sizes):
+ return [t + 1.0 for t in torch.split(a * 2.0, sizes, -1)]
+
+ self.common(fn, (torch.randn(2, 2, 10), [3, 3, 4]))
+ self.common(fn, (torch.randn(2, 2, 10), [4, 3, 3]))
+ self.common(fn, (torch.randn(2, 2, 10), [1, 2, 3, 4]))
+
+ def test_split(self):
+ def fn(a):
+ t = torch.split(a, 3, -1)
+ return (t[0], t[1], t[2], t[3])
+
+ def fn2(a):
+ return fn(a + 1)
+
+ self.common(
+ fn,
+ (torch.randn([2, 2, 10]),),
+ )
+
+ self.common(
+ fn2,
+ (torch.randn([2, 2, 10]),),
+ )
+
+ def test_to_dtype(self):
+ def fn(a, b):
+ return (
+ aten._to_copy(a, dtype=6),
+ aten._to_copy(b + 1, dtype=6),
+ aten.to(b, torch.float64),
+ aten.to(b, torch.bool),
+ )
+
+ self.common(
+ fn,
+ (
+ torch.randn([2, 2, 10]),
+ torch.randn([2, 2, 10], dtype=torch.float64),
+ ),
+ )
+
+ @requires_cuda()
+ def test_to_device(self):
+ def fn(a):
+ if a.device.type == "cpu":
+ return aten._to_copy(a, device=torch.device("cuda"), dtype=6, layout=0)
+ else:
+ return aten._to_copy(a, device=torch.device("cpu"), dtype=6, layout=0)
+
+ self.common(
+ fn,
+ (torch.randn([2, 2, 10]),),
+ )
+
+ @requires_cuda()
+ def test_to_device_constant(self):
+ def fn(a):
+ d1 = a.device.type
+ if d1 == "cpu":
+ d2 = "cuda"
+ else:
+ d2 = "cpu"
+
+ const1 = torch.as_tensor(list(range(64)), device=d2)
+ return (
+ torch.arange(10, device=d2).to(d1) + a,
+ const1.to(d1),
+ (const1 + 1).to(d1),
+ )
+
+ self.common(
+ fn,
+ (torch.randn([10]),),
+ )
+
+ @requires_cuda()
+ def test_multi_device(self):
+ def fn(x):
+ x = x + 1
+ x = x + 2
+ x = x.cuda()
+ x = x + 3
+ x = x + 4
+ x = x.cpu()
+ x = x + 5
+ x = x + 6
+ x = x.cuda()
+ x = x + 7
+ x = x + 8
+ x = x.cpu()
+ x = x + 9
+ x = x + 10
+ return x
+
+ self.common(
+ fn,
+ (torch.randn([2, 2, 10]),),
+ check_lowp=False, # cpu doesn't understand fp16, and there are explicit .cpu() calls
+ )
+
+ def test_unbind(self):
+ def fn(a):
+ return torch.unbind(a), torch.unbind(a, -1)
+
+ self.common(
+ fn,
+ (torch.randn([4, 4, 4]),),
+ )
+
+ def test_convolution1(self):
+ m = torch.nn.Sequential(
+ torch.nn.Conv2d(5, 6, [3, 3]),
+ torch.nn.ReLU(),
+ ToTuple(),
+ )
+
+ self.common(
+ m,
+ (torch.randn([2, 5, 16, 16]),),
+ # Mismatched elements: 10 / 2352 (0.4%)
+ # Greatest absolute difference: 5.7220458984375e-05 at index (0, 3, 12, 12) (up to 1e-05 allowed)
+ # Greatest relative difference: 0.06512477175897748 at index (0, 4, 11, 9) (up to 0.001 allowed)
+ atol=6e-5,
+ rtol=0.001,
+ )
+
+ def test_convolution2(self):
+ def fn(x, w, b):
+ # transposed conv
+ return (aten.convolution(x, w, b, [4], [0], [1], True, [0], 1),)
+
+ self.common(
+ fn,
+ (
+ torch.randn([2, 32, 90]),
+ torch.randn([32, 16, 8]),
+ torch.randn([16]),
+ ),
+ check_lowp=False,
+ )
+
+ @unittest.skipIf(HAS_CUDA, "only support cpu channels_last")
+ def test_conv2d_channels_last(self):
+ m = torch.nn.Sequential(
+ torch.nn.Conv2d(3, 3, 1, 1),
+ ToTuple(),
+ )
+ # only weight is channels_last
+ self.common(
+ m.to(memory_format=torch.channels_last),
+ (torch.randn([2, 3, 16, 16]),),
+ )
+ # only activation is channels_last
+ self.common(
+ m,
+ (torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last),),
+ )
+ # activation and weight are all channels_last
+ self.common(
+ m.to(memory_format=torch.channels_last),
+ (torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last),),
+ )
+
+ @unittest.skipIf(HAS_CUDA, "only support cpu channels_last")
+ def test_conv3d_channels_last(self):
+ m = torch.nn.Sequential(
+ torch.nn.Conv3d(3, 3, 1, 1),
+ ToTuple(),
+ )
+ # only weight is channels_last
+ self.common(
+ m.to(memory_format=torch.channels_last_3d),
+ (torch.randn([2, 3, 16, 16, 16]),),
+ )
+ # only activation is channels_last
+ self.common(
+ m,
+ (torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),),
+ )
+ # activation and weight are all channels_last
+ self.common(
+ m.to(memory_format=torch.channels_last_3d),
+ (torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),),
+ )
+
+ def test_adaptive_avg_pool2d1(self):
+ def fn(x):
+ return aten._adaptive_avg_pool2d(x, (6, 6)), aten._adaptive_avg_pool2d(
+ x + 1, (2, 5)
+ )
+
+ self.common(
+ fn,
+ (torch.randn(2, 4, 16, 16),),
+ )
+
+ # lowering to avg_pool2d case
+ self.common(
+ fn,
+ (torch.randn(2, 4, 3, 3),),
+ )
+
+ # no-op case
+ self.common(
+ fn,
+ (torch.randn(2, 4, 6, 6),),
+ )
+
+ def test_max_pool2d1(self):
+ def fn(x):
+ return aten.max_pool2d_with_indices(x, [3, 3], [2, 2])
+
+ self.common(
+ fn,
+ (torch.randn(2, 4, 16, 16),),
+ )
+
+ def test_max_pool2d2(self):
+ def fn(x):
+ return aten.max_pool2d_with_indices(x, [3, 3], [2, 2])
+
+ self.common(
+ fn,
+ (torch.randn([16, 64, 55, 55]),),
+ )
+
+ def test_max_pool2d3(self):
+ def fn(x):
+ # with padding
+ return aten.max_pool2d_with_indices(x, [3, 3], [2, 2], [1, 1])
+
+ self.common(
+ fn,
+ (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),),
+ )
+
+ def test_max_pool2d4(self):
+ def fn(x):
+ # with padding
+ return aten.max_pool2d_with_indices(x, [3, 3], [2, 2], [0, 0], [1, 1], True)
+
+ self.common(
+ fn,
+ (torch.randn([2, 8, 111, 111]),),
+ )
+
+ def test_max_pool2d5(self):
+ def fn(x):
+ return aten.max_pool2d_with_indices(x, [3, 3], [])
+
+ self.common(
+ fn,
+ (torch.randn([16, 64, 55, 55]),),
+ )
+
+ def test_avg_pool2d1(self):
+ def fn(x):
+ return aten.avg_pool2d(x, [3, 3], [2, 2])
+
+ self.common(
+ fn,
+ (torch.randn(2, 4, 16, 16),),
+ )
+
+ def test_avg_pool2d2(self):
+ def fn(x):
+ return aten.avg_pool2d(x, [3, 3], [2, 2])
+
+ self.common(
+ fn,
+ (torch.randn([16, 64, 55, 55]),),
+ )
+
+ def test_avg_pool2d3(self):
+ def fn(x):
+ return aten.avg_pool2d(x, [3, 3], [2, 2], [1, 1])
+
+ self.common(
+ fn,
+ (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),),
+ )
+
+ def test_avg_pool2d4(self):
+ def fn(x):
+ return aten.avg_pool2d(x, [3, 3], [2, 2], [0, 0], True)
+
+ self.common(
+ fn,
+ (torch.randn([2, 8, 111, 111]),),
+ )
+
+ def test_avg_pool2d5(self):
+ def fn(x):
+ return aten.avg_pool2d(x, [3, 3], [2, 2], [1, 1], count_include_pad=False)
+
+ self.common(
+ fn,
+ (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),),
+ )
+
+ def test_avg_pool2d6(self):
+ def fn(x):
+ return aten.avg_pool2d(x, [3, 3], [2, 2], [1, 1], divisor_override=3)
+
+ self.common(
+ fn,
+ (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),),
+ )
+
+ def test_alexnet_prefix(self):
+ def forward(arg6, arg7, arg16):
+ convolution = torch.ops.aten.convolution(
+ arg16, arg7, arg6, [4, 4], [2, 2], [1, 1], False, [0, 0], 1
+ )
+ relu = torch.ops.aten.relu(convolution)
+ max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices(
+ relu, [3, 3], [2, 2]
+ )
+ getitem = max_pool2d_with_indices[0]
+ return (getitem,)
+
+ self.common(
+ forward,
+ (
+ rand_strided((64,), (1,), torch.float32, "cpu"),
+ rand_strided((64, 3, 11, 11), (363, 121, 11, 1), torch.float32, "cpu"),
+ rand_strided(
+ (16, 3, 224, 224), (150528, 50176, 224, 1), torch.float32, "cpu"
+ ),
+ ),
+ # Mismatched elements: 127 / 746496 (0.0%)
+ # Greatest absolute difference: 0.0009765625 at index (1, 62, 7, 16) (up to 1e-05 allowed)
+ # Greatest relative difference: 0.05187467899332306 at index (14, 18, 11, 0) (up to 0.001 allowed)
+ atol=1e-3,
+ rtol=0.001,
+ )
+
+ def test_elu(self):
+ def fn(x):
+ return aten.elu(x, 1.6732632423543772, 1.0507009873554805) + 2, aten.elu(
+ x + 1, 2, 3, 4
+ )
+
+ self.common(
+ fn,
+ (torch.randn([16, 16]),),
+ )
+
+ def test_tanh(self):
+ def fn(x):
+ return aten.tanh(x) + 2, aten.tanh(x + 1)
+
+ self.common(
+ fn,
+ (torch.randn([16, 16]),),
+ )
+
+ def test_lgamma(self):
+ def fn(x):
+ return aten.lgamma(x) + 2, aten.cos(x + 1)
+
+ self.common(
+ fn,
+ (torch.randn([16, 16]),),
+ )
+
+ def test_cos(self):
+ def fn(x):
+ return aten.cos(x) + 2, aten.cos(x + 1)
+
+ self.common(
+ fn,
+ (torch.randn([16, 16]),),
+ )
+
+ def test_sin(self):
+ def fn(x):
+ return aten.sin(x) + 2, aten.sin(x + 1)
+
+ self.common(
+ fn,
+ (torch.randn([16, 16]),),
+ )
+
+ def test_repeat(self):
+ def fn(x):
+ return (
+ x.repeat(2, 2, 3, 1),
+ x.repeat(8, 1, 1, 1),
+ x.repeat(2, 1, 1, 1, 1, 1),
+ )
+
+ self.common(
+ fn,
+ (torch.randn([1, 2, 4, 8]),),
+ )
+
+ def test_embedding(self):
+ m = torch.nn.Sequential(
+ torch.nn.Embedding(10, 4, padding_idx=0),
+ torch.nn.ReLU(),
+ ToTuple(),
+ )
+
+ self.common(
+ m,
+ (torch.randint(10, [2, 8]),),
+ )
+
+ def test_mean(self):
+ def fn(x):
+ return (
+ x.mean(),
+ x.mean(-1),
+ torch.mean(x, -2, keepdim=True),
+ x.mean([0, 1]),
+ )
+
+ self.common(
+ fn,
+ (torch.randn([1, 2, 4, 8]),),
+ )
+
+ def test_var_mean(self):
+ def fn(x):
+ return (
+ *torch.var_mean(x, -1),
+ *torch.var_mean(x, [1, 3]),
+ )
+
+ self.common(
+ fn,
+ (torch.randn([1, 2, 4, 8]),),
+ )
+
+ @patch.object(config, "pick_loop_orders", True)
+ def test_transposed_propagates(self):
+ @torch._dynamo.optimize("inductor", nopython=True)
+ def fn(x, y):
+ return x + y
+
+ a = torch.randn(1, 4, 4, 4, device=self.device).permute(0, 2, 3, 1)
+ b = torch.randn(4, 4, 4, device=self.device).permute(1, 2, 0)
+ c = fn(a, b)
+ self.assertEqual(a.stride(), c.stride())
+ self.assertEqual(c.stride()[2], 1)
+
+ @requires_cuda()
+ @patch.object(config.triton, "convolution", "triton")
+ @patch.object(config.triton, "dense_indexing", "True")
+ def test_triton_conv(self):
+ @torch._dynamo.optimize("inductor", nopython=True)
+ def triton_conv(
+ x,
+ w,
+ bias,
+ stride,
+ padding,
+ dilation,
+ groups,
+ ):
+ y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
+ return y
+
+ stride, padding, dilation, groups = (1, 1), (0, 0), (1, 1), 1
+ dtype = torch.float32
+ x = torch.randn((32, 128, 32, 32), dtype=dtype, device=self.device)
+ w = torch.randn((32, 128, 1, 1), dtype=dtype, device=self.device)
+ bias = torch.randn((32), dtype=dtype, device=self.device)
+
+ y = triton_conv(x, w, bias, stride, padding, dilation, groups)
+ y_correct = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
+ self.assertTrue(same(y, y_correct, cos_similarity=True, tol=0.1))
+
+ @requires_cuda()
+ @patch.object(config.triton, "convolution", "autotune")
+ @patch.object(config.triton, "dense_indexing", "True")
+ def test_conv_autotune(self):
+ @torch._dynamo.optimize("inductor", nopython=True)
+ def triton_conv(
+ x,
+ w,
+ bias,
+ stride,
+ padding,
+ dilation,
+ groups,
+ ):
+ y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
+ return y
+
+ stride, padding, dilation, groups = (1, 1), (0, 0), (1, 1), 1
+ dtype = torch.float32
+ x = torch.randn((32, 128, 32, 32), dtype=dtype, device=self.device)
+ w = torch.randn((32, 128, 1, 1), dtype=dtype, device=self.device)
+ bias = torch.randn((32), dtype=dtype, device=self.device)
+
+ y = triton_conv(x, w, bias, stride, padding, dilation, groups)
+ y_correct = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
+ self.assertTrue(same(y, y_correct, cos_similarity=True, tol=0.1))
+
+ @patch.object(config.triton, "mm", "triton")
+ def test_triton_mm2(self):
+ @torch._dynamo.optimize("inductor", nopython=True)
+ def fn(x, y):
+ return torch.relu(torch.mm(x, y))
+
+ N = 1024
+ a = torch.randn([N, N], device=self.device, dtype=torch.float32)
+ b = torch.randn([N, N], device=self.device, dtype=torch.float32)
+ c1 = torch.relu(torch.mm(a, b))
+ torch._inductor.metrics.reset()
+ c = fn(a, b)
+ assert torch.allclose(c1, c, atol=1e-3, rtol=1e-3)
+ if self.device == "cuda":
+ assert torch._inductor.metrics.generated_kernel_count == 1
+
+ def test_std(self):
+ def fn(x):
+ return (
+ torch.var(x, True),
+ torch.var(x, False),
+ torch.var(x, -1, True),
+ torch.var(x, -1, False),
+ torch.std(x, False),
+ torch.std(x, [0, 1], True),
+ torch.std(x, [0, 1], False),
+ torch.std(x, -2, True, keepdim=True),
+ )
+
+ self.common(
+ fn,
+ (torch.randn([2, 4, 4, 8]),),
+ )
+
+ def test_embedding_bag(self):
+ def fn(w, i, o):
+ return aten._embedding_bag(w, i, o, False, 0, False, None)
+
+ self.common(
+ fn,
+ (torch.randn([10, 4]), torch.randint(10, [8]), torch.tensor([0, 2, 6])),
+ )
+
+ def test_batch_norm_2d(self):
+ m = torch.nn.Sequential(
+ torch.nn.BatchNorm2d(10),
+ torch.nn.ReLU(),
+ )
+ m.eval()
+ self.common(m, (torch.randn([2, 10, 8, 8]),), check_lowp=False)
+ self.common(
+ m,
+ (torch.randn([3, 10, 16, 16]),),
+ check_lowp=False, # too painful to match types of bn model
+ )
+
+ def test_layer_norm(self):
+ m = torch.nn.Sequential(
+ torch.nn.LayerNorm(32),
+ torch.nn.ReLU(),
+ )
+ m.eval()
+ self.common(m, (torch.randn([16, 32]),), check_lowp=False)
+ if self.device != "cpu":
+ self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+
+ def test_move_arange(self):
+ def fn(x):
+ return torch.arange(len(x), device="cpu").to(x.device) + x
+
+ self.common(fn, (torch.randn([32]),), check_lowp=False)
+ # if we have a copy there will be more than 1 kernel
+ self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+
+ def test_leaky_relu(self):
+ def fn(x):
+ return aten.leaky_relu(x, 0.2) + 2, aten.leaky_relu(x + 1)
+
+ self.common(
+ fn,
+ (torch.randn([16, 16]),),
+ )
+
+ def test_gelu(self):
+ def fn(x):
+ return aten.gelu(x) + 2, aten.gelu(x + 1)
+
+ self.common(
+ fn,
+ (torch.randn([16, 16]),),
+ )
+
+ def test_clone(self):
+ def fn(x):
+ return aten.clone(x) + 2, aten.clone(x + 1)
+
+ self.common(
+ fn,
+ (torch.randn([16, 16]),),
+ )
+
+ def test_masked_fill(self):
+ def fn(mask, value):
+ return aten.masked_fill(value, mask, -10000.0) + 2, aten.masked_fill(
+ value / 2.0, torch.logical_not(mask), 667
+ )
+
+ self.common(
+ fn,
+ (
+ torch.randint(0, 1, [1, 16], dtype=torch.bool),
+ torch.randn([16, 16]),
+ ),
+ )
+
+ def test_fill1(self):
+ def fn(x):
+ tmp = torch.ones_like(x)
+ return tmp, aten.fill.Scalar(tmp, 2)
+
+ self.common(
+ fn,
+ (torch.randn([16, 16]),),
+ )
+
+ def test_fill2(self):
+ def fn(x):
+ tmp = torch.ones_like(x)
+ return tmp, aten.fill.Tensor(tmp, torch.tensor(3.0))
+
+ self.common(
+ fn,
+ (torch.randn([16, 16]),),
+ )
+
+ def test_pow1(self):
+ def fn(x):
+ return [aten.pow(x, e) for e in range(-8, 9)]
+
+ self.common(
+ fn,
+ (torch.randn([16, 16]),),
+ )
+
+ def test_pow2(self):
+ def fn(x):
+ return aten.pow(1000, x), aten.pow(x, 1000)
+
+ self.common(
+ fn,
+ (torch.randn([16, 16]),),
+ # Mismatched elements: 9 / 256 (3.5%)
+ # Greatest absolute difference: 2.491354329061828e+28 at index (6, 6) (up to 1e-05 allowed)
+ # Greatest relative difference: 2.9793410720160818e-05 at index (4, 5) (up to 1.3e-06 allowed)
+ atol=1e-5,
+ rtol=3e-05,
+ )
+
+ def test_glu(self):
+ def fn(x):
+ return aten.glu(x, -1), aten.glu(x, 1), aten.glu(x, 2)
+
+ self.common(
+ fn,
+ (torch.randn([8, 16, 8, 8]),),
+ )
+
+ def test_cat(self):
+ def fn(a):
+ tmp = a * 2
+ return torch.cat((a, a[:, :4] + 1, a + 2), -1), torch.cat((tmp, tmp), 0)
+
+ self.common(
+ fn,
+ (torch.randn([8, 16]),),
+ )
+
+ def test_cat_extern_kernel(self):
+ def fn(x1, x2, x3, x4):
+ x = torch.mm(x2, x3)
+ s = torch.narrow(x, 1, 0, 100)
+ x = torch.mm(s, x4)
+ c = torch.cat((x, x1), 1)
+ return (c,)
+
+ self.common(
+ fn,
+ (
+ torch.randn(256, 256),
+ torch.randn(256, 1024),
+ torch.randn(1024, 1600),
+ torch.randn(100, 256),
+ ),
+ check_lowp=False, # accuracy issues with relatively large matmuls
+ )
+
+ def test_stack(self):
+ def fn(a, b):
+ return torch.stack(
+ [
+ a.expand(12, 16),
+ b.expand(12, 16),
+ ],
+ 2,
+ )
+
+ self.common(fn, (torch.randn([1, 16]), torch.randn([12, 1])))
+
+ def test_hardtanh(self):
+ def fn(x):
+ return F.hardtanh(x), F.hardtanh(x + 1), F.hardtanh(x - 1)
+
+ self.common(
+ fn,
+ (torch.randn([64]),),
+ )
+
+ def test_hardsigmoid(self):
+ def fn(x):
+ return F.hardsigmoid(x), F.hardsigmoid(x + 3), F.hardsigmoid(x - 3)
+
+ self.common(
+ fn,
+ (torch.randn([64]),),
+ )
+
+ def test_hardswish(self):
+ def fn(x):
+ return F.hardswish(x), F.hardswish(x + 3), F.hardswish(x - 3)
+
+ self.common(
+ fn,
+ (torch.randn([64]),),
+ )
+
+ def test_rsqrt(self):
+ def fn(x):
+ return torch.rsqrt(x), torch.rsqrt(x + 1) - 2
+
+ self.common(
+ fn,
+ (torch.randn([64]),),
+ )
+
+ def test_flip(self):
+ def fn(x):
+ return torch.flip(x, (-1,)), torch.flip(x, (0, 2)) - 2
+
+ self.common(
+ fn,
+ (torch.randn([1, 2, 6, 6]),),
+ )
+
+ def test_signbit(self):
+ def fn(x):
+ return torch.signbit(x), ~torch.signbit(-x) & 1
+
+ self.common(
+ fn,
+ (torch.randn([1, 2, 6, 6]),),
+ )
+
+ def test_fmod(self):
+ def fn(a, b):
+ return torch.fmod(a, b), torch.fmod(3.0 * a, b) - 2.0
+
+ shape = [1, 2, 6, 6]
+ self.common(fn, (torch.randn(shape), torch.randn(shape)))
+
+ def test_log2(self):
+ def fn(x):
+ return torch.log2(x), torch.log2(x + 1) - 2
+
+ self.common(
+ fn,
+ (torch.randn([64]) + 10,),
+ )
+
+ def test_logsumexp(self):
+ def fn(x):
+ return torch.logsumexp(x, -1), torch.logsumexp(x, 0) - 2
+
+ self.common(
+ fn,
+ (torch.randn([8, 8]) + 10,),
+ )
+
+ def test_log_fp64(self):
+ def fn(x):
+ return torch.log(x), torch.log2(x)
+
+ self.common(
+ fn,
+ (torch.randn([1024], dtype=torch.float64) + 10,),
+ )
+
+ def test_bitwise(self):
+ def fn(x, y):
+ return (
+ torch.bitwise_not(x),
+ torch.bitwise_or(x, y),
+ torch.bitwise_xor(x, y),
+ torch.bitwise_and(x, y),
+ )
+
+ self.common(
+ fn,
+ (
+ torch.randint(0, 2**30, [64], dtype=torch.int32),
+ torch.randint(0, 2**30, [64], dtype=torch.int32),
+ ),
+ )
+
+ def test_bitwise2(self):
+ # again with bool types
+ def fn(x, y):
+ return (
+ torch.bitwise_not(x),
+ torch.bitwise_or(x, y),
+ torch.bitwise_xor(x, y),
+ torch.bitwise_and(x, y),
+ )
+
+ self.common(
+ fn,
+ (
+ torch.randint(0, 2, (2, 20), dtype=torch.bool),
+ torch.randint(0, 2, (2, 20), dtype=torch.bool),
+ ),
+ )
+
+ def test_inf(self):
+ def fn(a):
+ return a + float("inf"), a + float("-inf"), a * -float("inf")
+
+ self.common(fn, (torch.randn(8),))
+
+ def test_remainder(self):
+ def fn(a, b):
+ return (
+ torch.remainder(a, b),
+ torch.remainder(a + 1, b - 1),
+ torch.remainder(a - 1, b + 1),
+ )
+
+ self.common(fn, (torch.randn(64), torch.randn(64)))
+
+ def test_zeros(self):
+ def fn(a):
+ return (
+ a + 1,
+ torch.zeros(
+ (1, 8, 64, 64),
+ dtype=torch.float32,
+ device=a.device,
+ ),
+ torch.zeros(
+ 1,
+ 8,
+ 64,
+ 64,
+ dtype=torch.float32,
+ device=a.device,
+ ),
+ torch.zeros(2, 3, names=None),
+ a + torch.ones(8, device=a.device),
+ torch.full((2, 3), 3.1416, device=a.device),
+ )
+
+ self.common(fn, (torch.randn(8),))
+
+ def test_new_ones(self):
+ def fn(a):
+ return (
+ aten.new_ones(
+ a, [], device=a.device, dtype=6, layout=0, pin_memory=False
+ ),
+ aten.new_zeros(
+ a, [], device=a.device, dtype=6, layout=0, pin_memory=False
+ ),
+ )
+
+ self.common(fn, (torch.randn(8),))
+
+ def test_full_like(self):
+ def fn(a):
+ return torch.full_like(a, 7.777) - 1
+
+ self.common(fn, (torch.randn(8),))
+
+ def test_index1(self):
+ def fn(a, b, c):
+ return aten.index(a, [b, c])
+
+ self.common(
+ fn,
+ (
+ torch.randn(8, 8, 12),
+ torch.tensor([0, 0, 2, 2], dtype=torch.int64),
+ torch.tensor([3, 4, 4, 3], dtype=torch.int64),
+ ),
+ )
+ self.common(
+ fn,
+ (
+ torch.randn(8, 8, 12),
+ torch.tensor([[0, 0, 2, 2]], dtype=torch.int64),
+ torch.tensor([[3], [4], [4], [3]], dtype=torch.int64),
+ ),
+ )
+
+ def test_index2(self):
+ def fn(a, b):
+ return (
+ aten.index(a, [b]),
+ aten.index(a, [None, b]),
+ )
+
+ self.common(
+ fn,
+ (
+ torch.randn(8, 8, 8),
+ torch.tensor([[0, 0, 2, 2]], dtype=torch.int64),
+ ),
+ )
+
+ def test_index_select(self):
+ def fn(a, b):
+ return (
+ torch.index_select(a, 0, b),
+ torch.index_select(a, 1, b),
+ torch.index_select(torch.index_select(a, 2, b), 1, b),
+ )
+
+ for ind_dtype in (torch.int32, torch.int64):
+ self.common(
+ fn,
+ (
+ torch.randn(8, 8, 8),
+ torch.tensor([0, 0, 2, 1], dtype=ind_dtype),
+ ),
+ )
+
+ # https://github.com/pytorch/torchdynamo/issues/467
+ @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
+ def test_cudnn_rnn(self):
+ if self.device == "cpu":
+ raise unittest.SkipTest("requires CUDA")
+
+ def fn(
+ a0,
+ b0,
+ b1,
+ b2,
+ b3,
+ b4,
+ b5,
+ b6,
+ b7,
+ b8,
+ b9,
+ b10,
+ b11,
+ b12,
+ b13,
+ b14,
+ b15,
+ a3,
+ a4,
+ a5,
+ ):
+ a1 = [
+ b0,
+ b1,
+ b2,
+ b3,
+ b4,
+ b5,
+ b6,
+ b7,
+ b8,
+ b9,
+ b10,
+ b11,
+ b12,
+ b13,
+ b14,
+ b15,
+ ]
+ return aten._cudnn_rnn(
+ a0,
+ a1,
+ 4,
+ a3,
+ a4,
+ a5,
+ 2,
+ 2048,
+ 0,
+ 2,
+ False,
+ 0.0,
+ False,
+ True,
+ [],
+ None,
+ )
+
+ self.common(
+ fn,
+ (
+ torch.randn([92, 8, 2048]),
+ torch.randn([8192, 2048]),
+ torch.randn([8192, 2048]),
+ torch.randn([8192]),
+ torch.randn([8192]),
+ torch.randn([8192, 2048]),
+ torch.randn([8192, 2048]),
+ torch.randn([8192]),
+ torch.randn([8192]),
+ torch.randn([8192, 4096]),
+ torch.randn([8192, 2048]),
+ torch.randn([8192]),
+ torch.randn([8192]),
+ torch.randn([8192, 4096]),
+ torch.randn([8192, 2048]),
+ torch.randn([8192]),
+ torch.randn([8192]),
+ torch.randn([167837696]),
+ torch.randn([4, 8, 2048]),
+ torch.randn([4, 8, 2048]),
+ ),
+ check_lowp=False, # difference in rnn is too large between half and float inputs
+ )
+
+ def test_upsample_nearest2d(self):
+ def fn(a):
+ return (
+ aten.upsample_nearest2d(a, [74, 76], None),
+ aten.upsample_nearest2d(a, [70, 75], None),
+ aten.upsample_nearest2d(a, [45, 74], None),
+ aten.upsample_nearest2d(a, [36, 39], None),
+ aten.upsample_nearest2d(a, None, [2.0, 2.0]),
+ )
+
+ self.common(fn, (torch.randn([2, 4, 37, 38]),))
+
+ def test_upsample_nearest2d_backward(self):
+ func = torch.ops.aten.upsample_nearest2d_backward.vec
+
+ def fn(a):
+ return (
+ func(
+ a, output_size=[6, 12], input_size=[3, 3, 3, 6], scale_factors=None
+ ),
+ func(
+ a, output_size=[6, 12], input_size=[3, 3, 4, 5], scale_factors=None
+ ),
+ func(
+ a, output_size=[6, 12], input_size=[3, 3, 2, 8], scale_factors=None
+ ),
+ func(
+ a, output_size=[6, 12], input_size=[3, 3, 2, 8], scale_factors=None
+ ),
+ func(
+ a, output_size=[6, 12], input_size=[3, 3, 4, 7], scale_factors=None
+ ),
+ )
+
+ self.common(fn, (torch.randn([3, 3, 6, 12]),))
+
+ def test_upsample_bilinear2d_a(self):
+ def fn(a):
+ return (
+ aten.upsample_bilinear2d(a, [45, 45], False, None),
+ aten.upsample_bilinear2d(a, None, True, [2.0, 2.0]),
+ )
+
+ self.common(fn, (torch.randn([2, 4, 37, 38]),))
+
+ def test_upsample_bilinear2d_b(self):
+ def fn(a):
+ return aten.upsample_bilinear2d(a, None, True, [2.0, 2.0])
+
+ self.common(
+ fn,
+ [
+ torch.randn([1, 2, 40, 59]),
+ ],
+ )
+
+ def test_reflection_pad2d(self):
+ def fn(a):
+ return (
+ aten.reflection_pad2d(a, [1, 1, 1, 1]),
+ aten.reflection_pad2d(a, [1, 2, 3, 4]),
+ )
+
+ self.common(
+ fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),)
+ )
+
+ def test_reflection_pad2d_backward(self):
+ def template(size, padding):
+ def fn(grad_output, x):
+ return aten.reflection_pad2d_backward(grad_output, x, padding)
+
+ x = torch.randint(0, 999, size=size, dtype=torch.float32)
+ result = aten.reflection_pad2d(x, padding)
+ grad_output = torch.randn_like(result)
+
+ self.common(fn, (grad_output, x))
+
+ template([1, 1, 8, 8], [0, 0, 0, 0])
+ template([1, 1, 8, 8], [1, 1, 1, 1])
+ template([1, 1, 8, 8], [1, 2, 3, 4])
+
+ def test_grid_sampler_2d(self):
+ def fn(a, b):
+ return (
+ aten.grid_sampler_2d(a, b, 0, 0, True),
+ aten.grid_sampler_2d(a, b, 0, 1, False),
+ )
+
+ self.common(
+ fn,
+ (
+ torch.randn([4, 3, 352, 352], dtype=torch.float32),
+ torch.rand([4, 352, 352, 2], dtype=torch.float32) * 2 - 1,
+ ),
+ check_lowp=False,
+ # Mismatched elements: 154697 / 1486848 (10.4%)
+ # Greatest absolute difference: 0.0001976490020751953 at index (0, 0, 101, 243) (up to 1e-05 allowed)
+ # Greatest relative difference: 7.332530120481928 at index (1, 1, 258, 301) (up to 1.3e-06 allowed)
+ atol=0.0002,
+ rtol=1.3e-06,
+ )
+
+ def test_upsample_bicubic2d(self):
+ def fn(a):
+ return (
+ aten.upsample_bicubic2d(a, (128, 128), True),
+ aten.upsample_bicubic2d(a, (128, 256), False),
+ )
+
+ # Mismatched elements: 10 / 196608 (0.0%)
+ # Greatest absolute difference: 1.3869255781173706e-05 at index (2, 1, 88, 65) (up to 1e-05 allowed)
+ # Greatest relative difference: 0.0033082996811011046 at index (3, 1, 88, 91) (up to 1.3e-06 allowed)
+ self.common(
+ fn,
+ (torch.randn([4, 3, 64, 32], dtype=torch.float32),),
+ atol=2e-5,
+ rtol=1e-3,
+ )
+
+ def test_sort(self):
+ def fn(a):
+ return torch.sort(a)
+
+ self.common(
+ fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),)
+ )
+
+ def test_topk(self):
+ def fn(a):
+ return torch.topk(a, 2, -1)
+
+ self.common(
+ fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),)
+ )
+
+ def test_long_tensor(self):
+ def fn(a):
+ return (
+ torch.LongTensor([294]).to(a.device) - a,
+ torch.as_tensor([295]).to(a.device) + a,
+ )
+
+ self.common(fn, (torch.randint(0, 999, size=[8, 8]),))
+
+ def test_constant_pad_1d(self):
+ def fn(a):
+ return (
+ aten.constant_pad_nd(a, [0, 1], 6.0),
+ aten.constant_pad_nd(a, [2, 3], 99.0),
+ )
+
+ self.common(fn, (torch.randint(0, 999, size=[2, 16, 31], dtype=torch.float32),))
+
+ def test_constant_pad_2d(self):
+ def fn(a):
+ return (
+ aten.constant_pad_nd(a, [1, 1, 1, 1], 6.0),
+ aten.constant_pad_nd(a, [1, 2, 3, 4], 99.0),
+ )
+
+ self.common(
+ fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),)
+ )
+
+ def test_constant_pad_3d(self):
+ def fn(a):
+ return (
+ aten.constant_pad_nd(a, [1, 2, 3, 4, 5, 6], 6.0),
+ aten.constant_pad_nd(a, [0, 0, 3, 4, 0, 0], 6.0),
+ )
+
+ self.common(
+ fn, (torch.randint(0, 999, size=[2, 4, 4, 4], dtype=torch.float32),)
+ )
+
+ def test_l1_loss(self):
+ def fn(a, b):
+ return torch.nn.functional.l1_loss(a, b), torch.nn.functional.mse_loss(a, b)
+
+ self.common(
+ fn,
+ (
+ torch.randn([2, 3, 16, 16]),
+ torch.randn([2, 3, 16, 16]),
+ ),
+ check_lowp=False,
+ )
+
+ def test_triu(self):
+ def fn(a):
+ return aten.triu(a, 1), aten.triu(a, 0), aten.triu(a, 2)
+
+ self.common(fn, (torch.randn([2, 10, 10]),))
+
+ def test_no_op_reduction(self):
+ def fn(a):
+ return a.sum(-1), torch.amax(a + 1, 1, keepdim=True)
+
+ self.common(fn, (torch.randn([8, 1, 1]),))
+
+ def test_inplace_add(self):
+ @torch._dynamo.optimize("inductor")
+ def fn(x, y):
+ return x.add_(y)
+
+ inputs = (
+ rand_strided((4, 4), (4, 1), device=self.device),
+ rand_strided((4, 4), (4, 1), device=self.device),
+ )
+ inp_clone = inputs[0].clone()
+ out = fn(*inputs)
+ self.assertTrue(same(out, inp_clone + inputs[1]))
+ self.assertTrue(out is inputs[0])
+
+ def test_inplace_mixed_dtype_ops(self):
+ @torch._dynamo.optimize("inductor")
+ def fn(x, y):
+ z = x + y.float()
+ w = z.add_(y)
+ return w.mul_(y)
+
+ inputs = (
+ rand_strided((4, 4), (4, 1), device=self.device, dtype=torch.float),
+ rand_strided((4, 4), (4, 1), device=self.device, dtype=torch.double),
+ )
+ out = fn(*inputs)
+ out_eager = (inputs[0] + inputs[1].float()).add_(inputs[1]).mul_(inputs[1])
+ self.assertTrue(same(out, out_eager))
+
+ @patch.object(config.triton, "cudagraphs", True)
+ def test_strided_inputs(self):
+ @torch._dynamo.optimize("inductor")
+ def fn(x, y):
+ return x + y
+
+ inputs = (
+ rand_strided((8, 16), (32, 2), device=self.device),
+ rand_strided((8, 16), (16, 1), device=self.device),
+ )
+ self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
+
+ @patch.object(config.triton, "cudagraphs", True)
+ @patch.object(functorch_config, "use_fake_tensor", True)
+ def test_input_mutation1(self):
+ def fn(a):
+ b = a + 1
+ a.copy_(b)
+ c = a + 2
+ return a * b / c
+
+ arg1 = torch.randn(64, device=self.device)
+ arg2 = arg1.clone()
+ arg3 = torch.randn(64, device=self.device)
+ arg4 = arg3.clone()
+ correct1 = fn(arg1)
+ correct2 = fn(arg3)
+ opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn)
+ actual1 = opt_fn(arg2)
+ actual2 = opt_fn(arg4)
+
+ self.assertTrue(same(actual1, correct1))
+ self.assertTrue(same(actual2, correct2))
+ self.assertTrue(same(arg1, arg2))
+ self.assertTrue(same(arg3, arg4))
+
+ @patch.object(functorch_config, "use_fake_tensor", True)
+ def test_input_mutation2(self):
+ def fn(a):
+ b = a + 1
+ a.view(64).copy_(torch.tensor([66.0], device=a.device))
+ c = a + 2
+ return b, c
+
+ arg1 = torch.randn([1, 64], device=self.device)
+ arg2 = arg1.clone()
+ correct1 = fn(arg1)
+ opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn)
+ actual1 = opt_fn(arg2)
+
+ self.assertTrue(same(actual1, correct1))
+ self.assertTrue(same(arg1, arg2))
+
+ @patch.object(functorch_config, "use_fake_tensor", True)
+ def test_input_mutation3(self):
+ def fn(a):
+ a += 1
+ a *= 2
+ aten.sigmoid_(a)
+ a = a.view(64)
+ a += 3
+ a *= 4
+ aten.relu_(a)
+ return a
+
+ arg1 = torch.randn([1, 64], device=self.device)
+ arg2 = arg1.clone()
+ correct1 = fn(arg1)
+ opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn)
+ actual1 = opt_fn(arg2)
+
+ self.assertTrue(same(actual1, correct1))
+ self.assertTrue(same(arg1, arg2))
+
+ def test_input_mutation4(self):
+ def fn(a):
+ torch.relu_(a)
+ return a
+
+ arg1 = torch.randn([1, 64], device=self.device)
+ arg2 = arg1.clone()
+ correct1 = fn(arg1)
+ opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn)
+ actual1 = opt_fn(arg2)
+
+ self.assertTrue(same(actual1, correct1))
+ self.assertTrue(same(arg1, arg2))
+
+ @patch.object(functorch_config, "use_fake_tensor", True)
+ def test_slice_mutation1(self):
+ def fn(a):
+ x = torch.zeros_like(a)
+ b = x + 1
+ x[:, 3] = 3.0
+ c = torch.clone(x)
+ x[4, :] = 4.0
+ d = x + 1
+ return x, b, c, d
+
+ self.common(fn, (torch.randn([8, 8]),))
+
+ @patch.object(functorch_config, "use_fake_tensor", True)
+ def test_slice_mutation2(self):
+ def fn(a):
+ a[:, 20:40] = a[:, 20:40] + 1
+ a[:, 2:11] = a[:, 1:10] + 2
+
+ arg1 = torch.randn([1, 64], device=self.device)
+ arg2 = arg1.clone()
+ fn(arg1)
+ opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn)
+ opt_fn(arg2)
+
+ self.assertTrue(same(arg1, arg2))
+
+ def test_indirect_load_broadcast(self):
+ def fn(in_ptr0, in_ptr1, in_ptr2):
+ return torch.gather(in_ptr1, 0, in_ptr2) + in_ptr0
+
+ arg190 = rand_strided((32, 21), (1, 32), device=self.device, dtype=torch.int64)
+ arg190.fill_(0)
+ arg111 = rand_strided(
+ (9521, 512), (512, 1), device=self.device, dtype=torch.float32
+ )
+ self.common(
+ fn,
+ (
+ torch.randn(32, 1),
+ arg111,
+ arg190,
+ ),
+ )
+
+ @unittest.skipIf(not has_torchvision_roi_align(), "requirs torchvision")
+ def test_roi_align(self):
+ def fn(a, b):
+ return torch.ops.torchvision.roi_align(a, b, 0.25, 7, 7, 2, False)
+
+ self.common(fn, (torch.zeros([4, 256, 296, 304]), torch.zeros([2292, 5])))
+
+ @requires_decomp(aten.nll_loss_forward)
+ def test_nll_loss_forward(self):
+ def fn(a, b):
+ return aten.nll_loss_forward(a, b, None, 1, -100)
+
+ self.common(
+ fn,
+ (
+ torch.randn([5, 5]),
+ torch.zeros([5], dtype=torch.int64),
+ ),
+ )
+
+ def test_isinf(self):
+ def fn(x):
+ return x.isinf(), x.isnan()
+
+ self.common(
+ fn, [torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")])]
+ )
+ self.common(
+ fn,
+ [
+ torch.tensor(
+ [1, float("inf"), 2, float("-inf"), float("nan")],
+ dtype=torch.float64,
+ )
+ ],
+ )
+
+ def test_any(self):
+ def fn(x):
+ return (
+ x.any(-1),
+ x.isinf().any(),
+ torch.all(x.isinf(), dim=0),
+ torch.all(torch.logical_not(x.isinf())),
+ )
+
+ self.common(fn, [-torch.rand(64)])
+ tmp = torch.randn(16, 8)
+ tmp[1, 1] = float("inf")
+ self.common(fn, [tmp])
+
+ def test_inplace_activations(self):
+ def fn(x):
+ a = aten.hardswish_(x + 1)
+ b = aten.hardtanh_(x + 1)
+ c = aten.leaky_relu_(x + 1)
+ d = aten.silu_(x + 1)
+ e = aten.log1p(x + 1)
+ f = aten.masked_fill_(x + 1, torch.zeros_like(x, dtype=torch.bool), 99.0)
+ h = aten.masked_fill_(x + 1, torch.ones_like(x, dtype=torch.bool), 99.0)
+ return (a, b, c, d, e, f, h)
+
+ self.common(fn, [torch.randn(64) * 10])
+
+ def test_baddbmm(self):
+ def fn(a, b, c):
+ return aten.baddbmm(a, b, c)
+
+ self.common(
+ fn,
+ [
+ torch.randn(6, 1, 100),
+ torch.randn(6, 128, 64),
+ torch.randn(6, 64, 100),
+ ],
+ # Mismatched elements: 1212 / 76800 (1.6%)
+ # Greatest absolute difference: 0.001953125 at index (0, 0, 93) (up to 1e-05 allowed)
+ # Greatest relative difference: 1.0 at index (3, 19, 4) (up to 0.001 allowed)
+ atol=0.002,
+ rtol=0.001,
+ )
+
+ @patch.object(config.triton, "max_tiles", 2)
+ def test_fuse_tiled(self):
+ def fn(a, b, c):
+ return a + b, c + 1
+
+ self.common(
+ fn, [torch.randn(128, 1), torch.randn(1, 128), torch.randn(128, 128)]
+ )
+
+ def test_expand_as(self):
+ def fn(a, b):
+ return aten.expand_as(a, b), aten.expand_as(a + 1, b + 1) + 1
+
+ self.common(
+ fn,
+ [
+ torch.randn(6, 1, 100),
+ torch.randn(6, 128, 100),
+ ],
+ )
+
+ def test_index_put1(self):
+ def fn(a, b, c):
+ return (
+ torch.index_put(a, [b], c),
+ torch.index_put_(a + 1, [b + 1], c + 1) + 1,
+ )
+
+ self.common(
+ fn,
+ [
+ torch.randn([800, 256, 7, 7]),
+ torch.randperm(601),
+ torch.randn([601, 256, 7, 7]),
+ ],
+ )
+ self.common(
+ fn, [torch.randn(1024, 4, 2), torch.arange(4), torch.randn(4, 1, 1)]
+ )
+
+ def test_index_put2(self):
+ def fn(a, b, c):
+ return torch.index_put(a, [b], c, True)
+
+ self.common(
+ fn,
+ [
+ torch.randn([100, 256, 7, 7]),
+ torch.randint(0, 100, size=[600], dtype=torch.int64),
+ torch.randn([600, 256, 7, 7]),
+ ],
+ # workaround for https://github.com/openai/triton/issues/558
+ check_lowp=False,
+ )
+
+ def test_index_put3(self):
+ def fn(a, b, c):
+ torch.ops.aten.index_put_(a, (None, b, None), c)
+ a1 = a + 1
+ torch.ops.aten.index_put_(a1, (None, b + 1, None), c + 1)
+ return (a, a1)
+
+ self.common(
+ fn,
+ [
+ torch.randn([1024, 4, 2]),
+ torch.arange(3),
+ torch.randn([1024, 1, 2]),
+ ],
+ )
+
+ def test_index_put_as_masked_fill(self):
+ def fn(a, b, c, d):
+ a = a.clone()
+ torch.ops.aten.index_put_(a, [b], c, d)
+ return a
+
+ self.common(
+ fn,
+ (
+ torch.randn([1024, 4, 2]),
+ torch.randn([1024, 4, 2]) > 0,
+ torch.randn([]),
+ False,
+ ),
+ )
+
+ self.common(
+ fn,
+ (
+ torch.randn([1024, 4, 2]),
+ torch.randn([1024, 4, 2]) > 0,
+ torch.randn([]),
+ True,
+ ),
+ )
+
+ def test_index_put_fallback1(self):
+ def fn(a, b, c, d):
+ a = a.clone()
+ torch.ops.aten.index_put_(a, [b], c, d)
+ return a
+
+ self.common(
+ fn,
+ (
+ torch.randn([3]),
+ torch.as_tensor([True, True, False]),
+ torch.randn([2]),
+ False,
+ ),
+ )
+
+ self.common(
+ fn,
+ (
+ torch.randn([3]),
+ torch.as_tensor([True, True, False]),
+ torch.randn([2]),
+ True,
+ ),
+ )
+
+ def test_index_put_fallback2(self):
+ def fn(a, b, c, d, e):
+ a = a.clone()
+ torch.ops.aten.index_put_(a, [None, b, c], d, e)
+ return a
+
+ self.common(
+ fn,
+ (
+ torch.randn([1, 2, 3]),
+ torch.as_tensor([0, 1]),
+ torch.as_tensor([True, True, False]),
+ torch.randn([]),
+ False,
+ ),
+ )
+ self.common(
+ fn,
+ (
+ torch.randn([1, 2, 3]),
+ torch.as_tensor([0, 1]),
+ torch.as_tensor([True, True, False]),
+ torch.randn([]),
+ True,
+ ),
+ )
+
+ @patch.object(config, "fallback_random", True)
+ def test_bernoulli1(self):
+ def fn(a):
+ b = torch.empty_like(a)
+ return aten.bernoulli_(b), b
+
+ self.common(
+ fn,
+ [
+ torch.randn([100]),
+ ],
+ )
+
+ def test_bernoulli2(self):
+ def fn(a):
+ return aten.bernoulli(a)
+
+ self.common(
+ fn,
+ [torch.tensor([1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0])],
+ )
+
+ def test_narrow(self):
+ def fn(x):
+ return aten.narrow(x, 1, 10, 16), aten.narrow(x + 2, 0, 10, 16) + 1
+
+ self.common(fn, [torch.randn(64, 64)])
+
+ def test_as_strided(self):
+ def fn(x):
+ return (
+ aten.as_strided(x, (8, 8, 64), (8 * 64, 64, 1), 0),
+ aten.as_strided(x + 1, (8, 8, 64), (8 * 64, 64, 1), 0) + 2,
+ )
+
+ self.common(fn, [torch.randn(64, 64)])
+
+ def test_select_scatter(self):
+ def fn(x, a, b):
+ return (
+ aten.select_scatter(x, a, 1, 0),
+ aten.select_scatter(x, b, 0, 1),
+ )
+
+ self.common(
+ fn,
+ [
+ torch.randn(8, 197, 38),
+ torch.randn(8, 38),
+ torch.randn(197, 38),
+ ],
+ )
+
+ def test_slice_scatter(self):
+ def fn(x, a):
+ return (
+ aten.slice_scatter(x, a, 2, 10, -10),
+ aten.slice_scatter(x, a[:, :, :40], 2, 10, -10, 2),
+ )
+
+ self.common(
+ fn,
+ [
+ torch.randn(4, 8, 100),
+ torch.randn(4, 8, 80),
+ ],
+ )
+
+ def test_slice_scatter2(self):
+ def fn(a, b):
+ return aten.slice_scatter(a, b, 0, 0, 9223372036854775807)
+
+ self.common(
+ fn,
+ [
+ torch.randn([8, 197, 384]),
+ torch.randn([8, 197, 384]),
+ ],
+ )
+
+ def test_scatter1(self):
+ def fn(a, dim, index, b):
+ return aten.scatter(a, dim, index, b)
+
+ self.common(
+ fn,
+ [
+ torch.zeros(2, 3),
+ -1,
+ torch.tensor([[0]]),
+ torch.ones(2, 3),
+ ],
+ )
+
+ def test_scatter2(self):
+ def fn(a, dim, index, b):
+ return aten.scatter.reduce(a, dim, index, b, reduce="add")
+
+ self.common(
+ fn,
+ [
+ torch.zeros(64, 512),
+ 0,
+ torch.zeros((64, 512), dtype=torch.int64),
+ torch.ones(64, 512),
+ ],
+ )
+
+ def test_scatter3(self):
+ def fn(a, dim, index, b):
+ return aten.scatter(a, dim, index, b, reduce="add")
+
+ self.common(
+ fn,
+ [
+ torch.randn(5, 29, 13),
+ 2,
+ torch.tensor([[[3, 5, 7, 9]]]),
+ 0.8, # src can be a scalar
+ ],
+ # Mismatched elements: 1 / 1885 (0.1%)
+ # Greatest absolute difference: 0.00018310546875 at index (0, 0, 3) (up to 1e-05 allowed)
+ # Greatest relative difference: 0.0022371364653243847 at index (0, 0, 3) (up to 0.001 allowed)
+ atol=2e-4,
+ rtol=1e-3,
+ )
+
+ def test_scatter4(self):
+ def fn(x, ind, src):
+ return torch.scatter(x, 0, ind, src)
+
+ self.common(
+ fn,
+ (torch.randn(196, 992), torch.randint(196, (1, 992)), torch.randn(1, 992)),
+ )
+
+ @unittest.skip("Flaky test, needs debugging")
+ def test_scatter_add1(self):
+ def fn(a, dim, index, b):
+ return aten.scatter_add(a, dim, index, b)
+
+ self.common(
+ fn,
+ [
+ torch.randn(2, 3),
+ 0,
+ torch.tensor([[0]]),
+ torch.randn(2, 3),
+ ],
+ )
+
+ def test_scatter_add2(self):
+ def fn(a, dim, index, b):
+ return aten.scatter_add(a, dim, index, b)
+
+ self.common(
+ fn,
+ [
+ torch.randn(2, 3),
+ 0,
+ torch.tensor([[0, 0, 0], [1, 1, 1]]),
+ torch.randn(2, 3),
+ ],
+ )
+
+ def test_scatter_add3(self):
+ def fn(a, dim, index, b):
+ return aten.scatter_add(a, dim, index, b)
+
+ self.common(
+ fn,
+ [
+ torch.randn(5, 29, 13),
+ 2,
+ torch.tensor([[[3, 5, 7, 9]]]),
+ torch.randn(1, 1, 10),
+ ],
+ )
+
+ def test_scatter_reduce1(self):
+ def fn(a, dim, index, b):
+ return aten.scatter_reduce(a, dim, index, b, "sum")
+
+ self.common(
+ fn,
+ [
+ torch.randn(5, 29, 13),
+ 2,
+ torch.tensor([[[3, 5, 7, 9]]]),
+ torch.randn(1, 1, 10),
+ ],
+ )
+
+ def test_scatter_reduce2(self):
+ def fn(a, dim, index, b):
+ return aten.scatter_reduce(a, dim, index, b, "sum", include_self=False)
+
+ self.common(
+ fn,
+ [
+ torch.randn(2, 3),
+ 0,
+ torch.zeros((2, 3), dtype=torch.int64),
+ torch.randn(2, 3),
+ ],
+ )
+
+ def test_new_empty_strided(self):
+ def fn(a):
+ return aten.new_empty_strided(a, [1, 128, 128], [16384, 128, 1]).fill_(123)
+
+ self.common(fn, [torch.randn(55)])
+
+ @patch.object(torch._inductor.config.triton, "cudagraphs", True)
+ def test_dropout(self):
+ random.seed(1234)
+ torch.manual_seed(1234)
+
+ @torch._dynamo.optimize("inductor")
+ def fn(a):
+ return torch.nn.functional.dropout(a, 0.5, True)
+
+ x = torch.ones(1000, device=self.device, dtype=torch.float32)
+ result = fn(x)
+ self.assertTrue(400 < result.nonzero().shape[0] < 600)
+ self.assertTrue(0.9 < result.mean().item() < 1.1)
+
+ def test_dropout_deterministic(self):
+ @torch._dynamo.optimize("inductor")
+ def fn(a):
+ return torch.nn.functional.dropout(a, 0.55, True)
+
+ for cg in (False, True):
+ with patch.object(torch._inductor.config.triton, "cudagraphs", cg):
+ torch._dynamo.reset()
+
+ x = torch.ones(1024, device=self.device, dtype=torch.float32)
+
+ torch.manual_seed(1234)
+ a0 = fn(x).clone()
+ a1 = fn(x).clone()
+ a2 = fn(x).clone()
+
+ torch.manual_seed(1234)
+ b0 = fn(x).clone()
+ b1 = fn(x).clone()
+ b2 = fn(x).clone()
+
+ # same seed, same values
+ self.assertTrue(torch.allclose(a0, b0))
+ self.assertTrue(torch.allclose(a1, b1))
+ self.assertTrue(torch.allclose(a2, b2))
+
+ # different calls, different values
+ self.assertFalse(torch.allclose(a0, a1))
+ self.assertFalse(torch.allclose(a1, a2))
+
+ def test_rand_like_deterministic(self):
+ @torch._dynamo.optimize("inductor")
+ def fn(a):
+ return torch.rand_like(a), torch.rand_like(a)
+
+ x = torch.ones(1024, device=self.device, dtype=torch.float32)
+
+ torch.manual_seed(1234)
+ a0 = fn(x)[0].clone()
+ a1 = fn(x)[0].clone()
+ a2 = fn(x)[0].clone()
+
+ torch.manual_seed(1234)
+ b0 = fn(x)[0].clone()
+ b1 = fn(x)[0].clone()
+ b2 = fn(x)[0].clone()
+
+ # same seed, same values
+ self.assertTrue(torch.allclose(a0, b0))
+ self.assertTrue(torch.allclose(a1, b1))
+ self.assertTrue(torch.allclose(a2, b2))
+
+ # different calls, different values
+ self.assertFalse(torch.allclose(a0, a1))
+ self.assertFalse(torch.allclose(a1, a2))
+
+ c, d = fn(x)
+ self.assertFalse(torch.allclose(c, d))
+ self.assertTrue((c >= 0).all())
+ self.assertTrue((c < 1).all())
+ self.assertTrue((d >= 0).all())
+ self.assertTrue((d < 1).all())
+
+ def test_max_pool2d_with_indices_backward(self):
+ def fn(a, b, c):
+ return aten.max_pool2d_with_indices_backward(
+ a, b, [2, 2], [2, 2], [0, 0], [1, 1], False, c
+ )
+
+ x = torch.randn([2, 4, 18, 14])
+ result, indices = aten.max_pool2d_with_indices(
+ x,
+ [2, 2],
+ [2, 2],
+ [0, 0],
+ [1, 1],
+ False,
+ )
+
+ self.common(
+ fn,
+ [
+ torch.randn_like(result),
+ x,
+ indices,
+ ],
+ )
+
+ def test_max_pool2d_with_indices_backward2(self):
+ def fn(a, b, c):
+ return aten.max_pool2d_with_indices_backward(
+ a, b, [3, 3], [2, 2], [1, 1], [1, 1], True, c
+ )
+
+ x = torch.randn([2, 4, 40, 56])
+ result, indices = aten.max_pool2d_with_indices(
+ x,
+ [3, 3],
+ [2, 2],
+ [1, 1],
+ [1, 1],
+ True,
+ )
+
+ self.common(
+ fn,
+ [
+ torch.randn_like(result),
+ x,
+ indices,
+ ],
+ )
+
+ # From https://github.com/pytorch/torchdynamo/issues/1200
+ def test_max_pool2d_with_indices_backward3(self):
+ def fn(a, b, c):
+ return aten.max_pool2d_with_indices_backward(
+ a, b, [1, 1], [2, 2], [0, 0], [1, 1], False, c
+ )
+
+ x = torch.randn([32, 256, 37, 38])
+ result, indices = aten.max_pool2d_with_indices(
+ x,
+ [1, 1],
+ [2, 2],
+ 0,
+ 1,
+ False,
+ )
+ self.common(
+ fn,
+ [
+ torch.randn_like(result),
+ x,
+ indices,
+ ],
+ )
+
+ def test_avg_pool2d_backward(self):
+ def fn(a, b):
+ return aten.avg_pool2d_backward(
+ a,
+ b,
+ [2, 2],
+ [2, 2],
+ [0, 0],
+ True,
+ False,
+ None,
+ )
+
+ self.common(
+ fn,
+ [
+ torch.randn([2, 4, 7, 7]),
+ torch.randn([2, 4, 14, 14]),
+ ],
+ )
+
+ def test_avg_pool2d_backward2(self):
+ def fn(a, b):
+ return aten.avg_pool2d_backward(
+ a,
+ b,
+ [3, 3],
+ [1, 1],
+ [1, 1],
+ True,
+ False,
+ None,
+ )
+
+ self.common(
+ fn,
+ [
+ torch.randn([1, 1, 20, 15]),
+ torch.randn([1, 1, 20, 15]),
+ ],
+ )
+
+ def test_avg_pool2d_backward3(self):
+ def fn(a, b):
+ return aten.avg_pool2d_backward(
+ a,
+ b,
+ [1, 1],
+ [2, 2],
+ [0, 0],
+ False,
+ False,
+ None,
+ )
+
+ self.common(
+ fn,
+ [
+ torch.randn([1, 2016, 11, 11]),
+ torch.randn([1, 2016, 21, 21]),
+ ],
+ )
+
+ def test_mm_views(self):
+ def fn(a, b):
+ return torch.mm(a.view(32, 32), b.view(32, 32))
+
+ self.common(
+ fn,
+ (
+ torch.randn([32, 32]).transpose(0, 1),
+ torch.randn([1, 32, 32]).transpose(0, 1),
+ ),
+ check_lowp=False,
+ )
+ expected_kernel = 0
+ # codegen mm kernel from template
+ if config.triton.mm != "aten" and self.device == "cuda":
+ expected_kernel = 1
+ if config.triton.mm == "autotune":
+ self.assertLessEqual(
+ torch._inductor.metrics.generated_kernel_count, expected_kernel
+ )
+ self.assertEqual(
+ torch._inductor.metrics.generated_kernel_count, expected_kernel
+ )
+
+ @patch.object(config.triton, "cudagraphs", False)
+ def test_lowmem_dropout1(self):
+ n = 100000
+ weight = torch.ones(
+ n, device=self.device, dtype=torch.float32, requires_grad=True
+ )
+ ones = torch.ones(n, device=self.device, dtype=torch.float32)
+
+ @torch._dynamo.optimize_assert("inductor")
+ def run(x, train=True):
+ return F.dropout(x * weight, 0.33, train)
+
+ def check(r, g):
+ rmean = r.mean().item()
+ gmean = g.mean().item()
+ rcount = len(r.nonzero())
+ gcount = len(g.nonzero())
+
+ # dropped elements should match
+ self.assertTrue(same(r.nonzero(), g.nonzero()))
+ self.assertEqual(rcount, gcount)
+
+ # dropped should be close to 0.33
+ self.assertGreater(rcount, 0.64 * n)
+ self.assertGreater(0.68 * n, rcount)
+
+ self.assertAlmostEqual(rmean, gmean)
+ self.assertAlmostEqual(rmean, 1.0, places=2)
+
+ r1 = run(ones, train=False)
+ r1.sum().backward()
+ g1 = weight.grad.clone()
+ # eval mode should be all ones
+ self.assertTrue(same(r1, torch.ones_like(r1)))
+ self.assertTrue(same(g1, torch.ones_like(g1)))
+
+ torch.manual_seed(1234)
+ weight.grad.zero_()
+ r2 = run(ones)
+ r2.sum().backward()
+ g2 = weight.grad.clone()
+ check(r2, g2)
+
+ torch.manual_seed(1234)
+ weight.grad.zero_()
+ r3 = run(ones)
+ r3.sum().backward()
+ g3 = weight.grad.clone()
+ check(r3, g3)
+
+ # second run is same result as first
+ self.assertTrue(same(r2, r3))
+ self.assertTrue(same(g2, g3))
+
+ def test_lowmem_dropout2(self):
+ m = torch.nn.Sequential(
+ torch.nn.Linear(32, 32, bias=False),
+ torch.nn.Dropout(),
+ torch.nn.Linear(32, 32, bias=False),
+ torch.nn.Dropout(),
+ ).to(self.device)
+
+ @torch._dynamo.optimize_assert("inductor")
+ def run(x):
+ return m(x)
+
+ torch._inductor.metrics.generated_kernel_count = 0
+ result = run(torch.randn([8, 32], device=self.device))
+ result.sum().backward()
+
+ expected_kernel = 4
+ if config.triton.mm != "aten" and self.device == "cuda":
+ # fwd: 2 * (mm+dropout) kernels = 2 kernels
+ # bwd: dropout + (mm) + 2 * (mm+dropout) kernels = 4 kernels
+ # expect 2 + 4 = 6 kernels
+ expected_kernel = 6
+ if config.triton.mm == "autotune":
+ self.assertLessEqual(
+ torch._inductor.metrics.generated_kernel_count, expected_kernel
+ )
+ self.assertEqual(
+ torch._inductor.metrics.generated_kernel_count, expected_kernel
+ )
+
+ def test_roll(self):
+ def fn(a):
+ return (
+ aten.roll(a, [-3, 10], [1, 2]),
+ aten.roll(a, [5]),
+ )
+
+ self.common(
+ fn,
+ [
+ torch.randn([2, 56, 56, 16]),
+ ],
+ )
+
+ def test_argmax_argmin1(self):
+ def fn(x):
+ return (aten.argmax(x), aten.argmin(x))
+
+ self.common(
+ fn,
+ [
+ torch.randn([8, 256, 256]),
+ ],
+ )
+
+ def test_argmax_argmin2(self):
+ def fn(x):
+ return (
+ aten.argmax(x, 0),
+ aten.argmin(x, 0),
+ aten.argmax(x, 1),
+ aten.argmin(x, 1),
+ )
+
+ self.common(
+ fn,
+ [
+ torch.randn([144, 144]),
+ ],
+ # Mismatched elements: 1 / 144 (0.7%)
+ # Greatest absolute difference: 26 at index (71,)
+ # Greatest relative difference: 0.4126984179019928 at index (71,)
+ atol=1e-5,
+ rtol=0.5,
+ )
+
+ @unittest.skip(
+ """
+ FIXME: In the case of having equally max/min elements, our implementation returns
+ the last index instead of the first one
+ """
+ )
+ def test_argmax_argmin3(self):
+ def fn(x):
+ return (
+ aten.argmax(x, 0),
+ aten.argmin(x, 0),
+ aten.argmax(x, -1),
+ aten.argmin(x, -1),
+ )
+
+ self.common(
+ fn,
+ [torch.randint(0, 5, [10, 10])],
+ )
+
+ def test_vdd_clamp(self):
+ def fn(x):
+ return torch.clamp_min(x, 3)
+
+ self.common(
+ fn,
+ [
+ torch.randn([16], requires_grad=True) * 10,
+ ],
+ )
+
+ def test_tmp_not_defined_issue1(self):
+ def forward(
+ primals_3,
+ primals_4,
+ add_tensor,
+ convert_element_type_default,
+ div_default,
+ reciprocal_default,
+ ):
+ var_default = torch.ops.prims.var.default(
+ convert_element_type_default, [2], correction=0
+ )
+ sub_tensor = torch.ops.aten.sub.Tensor(add_tensor, div_default)
+ mul_tensor_1 = torch.ops.aten.mul.Tensor(sub_tensor, reciprocal_default)
+ mul_tensor_2 = torch.ops.aten.mul.Tensor(mul_tensor_1, primals_3)
+ add_tensor_2 = torch.ops.aten.add.Tensor(mul_tensor_2, primals_4)
+ convert_element_type_default_1 = (
+ torch.ops.prims.convert_element_type.default(
+ add_tensor_2, torch.float32
+ )
+ )
+ convert_element_type_default_2 = (
+ torch.ops.prims.convert_element_type.default(
+ convert_element_type_default_1, torch.float32
+ )
+ )
+ var_default_1 = torch.ops.prims.var.default(
+ convert_element_type_default_2, [2], correction=0
+ )
+ broadcast_in_dim_default_2 = torch.ops.prims.broadcast_in_dim.default(
+ var_default_1, [1, 512, 1], [0, 1]
+ )
+ sum_default_1 = torch.ops.prims.sum.default(
+ convert_element_type_default_2, [2]
+ )
+ add_tensor_3 = torch.ops.aten.add.Tensor(broadcast_in_dim_default_2, 1e-05)
+ return (var_default, sum_default_1, add_tensor_3)
+
+ inps = [
+ (torch.Size([1024]), torch.float32),
+ (torch.Size([1024]), torch.float32),
+ (torch.Size([1, 512, 1024]), torch.float32),
+ (torch.Size([1, 512, 1024]), torch.float32),
+ (torch.Size([1, 512, 1]), torch.float32),
+ (torch.Size([1, 512, 1]), torch.float32),
+ ]
+ inps = [torch.randn(shape, dtype=dtype) for (shape, dtype) in inps]
+ self.common(forward, inps, atol=1e-05, rtol=2e-05)
+
+ @unittest.skipIf(TEST_WITH_ASAN, "TODO: debug this with asan")
+ def test_tmp_not_defined_issue2(self):
+ def forward(arg38_1, arg81_1, getitem_17, new_zeros_default_4):
+ div_tensor_7 = torch.ops.aten.div.Tensor(getitem_17, arg81_1)
+ mul_tensor_24 = torch.ops.aten.mul.Tensor(div_tensor_7, arg38_1)
+ sum_default_7 = torch.ops.aten.sum.default(mul_tensor_24)
+ return (new_zeros_default_4, sum_default_7)
+
+ args = [
+ ((1, 88, 40, 40), (140800, 1600, 40, 1), torch.float32),
+ ((), (), torch.float32),
+ ((1, 88, 40, 40), (140800, 1600, 40, 1), torch.float32),
+ ((3,), (1,), torch.float32),
+ ]
+ args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args]
+ self.common(forward, args)
+
+ def test_misaligned_address_issue1(self):
+ def forward(sub_tensor_1, unsqueeze_default):
+ gather_default = torch.ops.aten.gather.default(
+ sub_tensor_1, 1, unsqueeze_default
+ )
+ return gather_default
+
+ args = [
+ ((1, 1000), (1000, 1), torch.float32),
+ ((1, 1), (1, 1), torch.int64),
+ ]
+ args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args]
+ self.common(forward, args)
+
+ def test_invalid_operand_issue1(self):
+ def forward(arg0_1, arg1_1, arg3_1, squeeze, view_1, slice_1):
+ slice_scatter = torch.ops.aten.slice_scatter.default(
+ slice_1, arg3_1, 1, 1, 9223372036854775807
+ )
+ slice_scatter_1 = torch.ops.aten.slice_scatter.default(
+ arg1_1, slice_scatter, 0, 0, 9223372036854775807
+ )
+ slice_2 = torch.ops.aten.slice.Tensor(
+ slice_scatter_1, 0, 0, 9223372036854775807
+ )
+ select_scatter = torch.ops.aten.select_scatter.default(
+ slice_2, squeeze, 1, 0
+ )
+ slice_scatter_2 = torch.ops.aten.slice_scatter.default(
+ slice_scatter_1, select_scatter, 0, 0, 9223372036854775807
+ )
+ view = torch.ops.aten.view.default(slice_scatter_2, [-1, 128])
+ embedding = torch.ops.aten.embedding.default(arg0_1, view, 1)
+ return [embedding, view_1]
+
+ args = [
+ ((50005, 768), (768, 1), torch.float32),
+ ((8, 128), (128, 1), torch.int64),
+ ((8, 127), (127, 1), torch.int64),
+ ((8,), (1,), torch.int64),
+ ((1024,), (1,), torch.int64),
+ ((8, 128), (128, 1), torch.int64),
+ ]
+ args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args]
+ self.common(forward, args)
+
+ def test_sizehint_issue1(self):
+ def forward(x):
+ return torch.nn.functional.unfold(
+ x, kernel_size=[4, 4], dilation=1, padding=0, stride=[4, 4]
+ )
+
+ args = [((2, 24, 56, 56), (75264, 3136, 56, 1), torch.float32, False)]
+ args = [
+ rand_strided(sh, st, dt).requires_grad_(rg) for (sh, st, dt, rg) in args
+ ]
+ self.common(forward, args)
+
+ @unittest.skip("https://github.com/pytorch/torchdynamo/issues/1297")
+ @patch.object(torch._inductor.config.triton, "cudagraphs", False)
+ def test_symbolic(self):
+ def f(x):
+ x = x.cos()
+ x = x.view(x.shape[0] * 2, -1)
+ return (x,)
+
+ traced = make_fx(f, tracing_mode="symbolic")(
+ torch.randn(8, 4, device=self.device)
+ )
+ compiled = compile_fx_inner(traced, [torch.randn(8, 4, device=self.device)])
+
+ out = compiled(torch.randn(8, 4, device=self.device))
+ self.assertEqual(out[0].shape, (16, 2))
+
+ out = compiled(torch.randn(12, 4, device=self.device))
+ self.assertEqual(out[0].shape, (24, 2))
+
+ @requires_cuda()
+ @patch.object(config.triton, "cudagraphs", False)
+ def test_unspec_inputs(self):
+ def fn(x, y):
+ return x + y
+
+ inputs = (
+ rand_strided((2, 3), (3, 1), device="cuda"),
+ rand_strided((), (), device="cpu"),
+ )
+ self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
+
+ @requires_cuda()
+ @patch.object(config.triton, "cudagraphs", True)
+ def test_unspec_inputs_cudagraphs(self):
+ def fn(x, y):
+ return x + y
+
+ inputs = (
+ rand_strided((2, 3), (3, 1), device="cuda"),
+ rand_strided((), (), device="cpu"),
+ )
+ self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
+
+
+if HAS_CPU:
+
+ class CpuTests(TestCase):
+ common = check_model
+ device = "cpu"
+
+ CommonTemplate.install(CpuTests, "cpu")
+
+ class CPUReproTests(TestCase):
+ def test_inplace_squeeze_needed(self):
+ mod = torch.nn.Sequential(
+ torch.nn.Linear(10, 10),
+ torch.nn.LayerNorm(10),
+ torch.nn.ReLU(),
+ ).eval()
+
+ @torch._dynamo.optimize("inductor")
+ def fn(x):
+ return mod(x)
+
+ v = torch.randn(10)
+ result = fn(v)
+ assert same(result, mod(v))
+
+ def test_inplace_add_alpha(self):
+ def fn(x, y):
+ aten.add_.Tensor(x, y, alpha=0.55)
+ return (x,)
+
+ x1 = torch.zeros(10)
+ x2 = torch.zeros(10)
+ x3 = torch.zeros(10)
+ y = torch.randn(10)
+ fn_fx = make_fx(fn)(x1, y)
+ fn_compiled = compile_fx_inner(fn_fx, [x1, y])
+ fn(x2, y)
+ fn_compiled(x3, y)
+ assert same(x2, x3)
+
+ def test_no_op_squeeze(self):
+ @torch._dynamo.optimize("inductor")
+ def forward(arg0_1):
+ return torch.ops.aten.squeeze.dim(arg0_1, 1)
+
+ x = torch.randn((10, 20))
+ assert same(x, forward(x))
+
+ def test_parallel_num_threads(self):
+ @torch._dynamo.optimize("inductor")
+ def fn(x1, x2):
+ return x1 + x2
+
+ @contextlib.contextmanager
+ def set_num_threads(num_threads):
+ orig_num_threads = torch.get_num_threads()
+ torch.set_num_threads(num_threads)
+ yield
+ torch.set_num_threads(orig_num_threads)
+
+ x1 = torch.randn((10, 20))
+ x2 = torch.randn((10, 20))
+ with set_num_threads(1):
+ assert same(x1 + x2, fn(x1, x2))
+ with set_num_threads(4):
+ assert same(x1 + x2, fn(x1, x2))
+
+ @patch("torch.cuda.is_available", lambda: False)
+ def test_timed_cpu_only(self):
+ timed(lambda: torch.randn(10), ())
+
+
+if HAS_CUDA:
+
+ class SweepInputsCudaTest(SweepInputs2, TestCase):
+ gen = InputGen(10, "cuda")
+
+ SweepInputsCudaTest.populate()
+
+ class CudaTests(TestCase):
+ common = check_model_cuda
+ device = "cuda"
+
+ def test_simplify_dims(self):
+ def fn(a):
+ return (a + 1,)
+
+ self.common(
+ fn, (torch.randn(2, 3, 10, 5, 6, device="cuda")[:, :, 2::2, :, :],)
+ )
+
+ CommonTemplate.install(CudaTests, "cuda")
+
+ class CudaReproTests(TestCase):
+ def test_index_put_issue(self):
+ def forward(
+ self,
+ arg76_1,
+ expand_default,
+ full_like_default,
+ _to_copy_default_67,
+ zeros,
+ ):
+ sum_sym_int_19 = torch.ops.aten.sum(_to_copy_default_67, [0], True)
+ view_default_57 = torch.ops.aten.view.default(
+ sum_sym_int_19, [512, 768]
+ )
+ where_self = torch.ops.aten.where.self(
+ expand_default, view_default_57, full_like_default
+ )
+ clone_default_12 = torch.ops.aten.clone.default(zeros)
+ index_put__default = torch.ops.aten.index_put_.default(
+ clone_default_12, [arg76_1], where_self, True
+ )
+ return (index_put__default,)
+
+ inps = [
+ (torch.Size([512]), torch.int64),
+ (torch.Size([512, 768]), torch.bool),
+ (torch.Size([512, 768]), torch.float16),
+ (torch.Size([4, 512, 768]), torch.float16),
+ (torch.Size([512, 768]), torch.float16),
+ ]
+ inps = [torch.zeros(())] + [
+ torch.ones(shape, dtype=dtype, device="cuda") for (shape, dtype) in inps
+ ]
+ mod = make_fx(forward)(*inps)
+ compiled = compile_fx_inner(mod, inps)
+ compiled(*inps)
+
+ @patch.object(config, "fallback_random", True)
+ def test_dtype_factory_issue(self):
+ def forward():
+ randn = torch.ops.aten.randn.default(
+ [12, 64, 1, 64],
+ dtype=torch.float32,
+ device=torch.device(type="cuda", index=0),
+ pin_memory=False,
+ )
+ unsqueeze_default_2 = torch.ops.aten.unsqueeze.default(randn, -1)
+ return (unsqueeze_default_2,)
+
+ mod = make_fx(forward)()
+ compiled = compile_fx_inner(mod, ())
+ assert compiled()[0].device.type == "cuda"
+
+ @patch.object(config.triton, "cudagraphs", True)
+ def test_expanded_inputs_cudagraphs(self):
+ @torch._dynamo.optimize("inductor")
+ def fn(x, y):
+ return x + y
+
+ inputs = (
+ rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
+ rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
+ )
+ self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
+
+ @patch.object(config, "size_asserts", False)
+ @patch.object(config.triton, "cudagraphs", True)
+ def test_expanded_inputs_cudagraphs_no_size_asserts(self):
+ @torch._dynamo.optimize("inductor")
+ def fn(x, y):
+ return x + y
+
+ inputs = (
+ rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
+ rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
+ )
+ self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
+
+ def test_accuracy_issue1(self):
+ class Repro(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(
+ in_features=768, out_features=2, bias=True
+ )
+
+ def forward(self, start_positions: torch.Tensor, x: torch.Tensor):
+ linear = self.linear(x)
+ split = linear.split(1, dim=-1)
+ getitem = split[0]
+ squeeze = getitem.squeeze(-1)
+ clamp = start_positions.clamp(0, 128)
+ cross_entropy = torch.nn.functional.cross_entropy(
+ squeeze, clamp, None, None, 128, None, "mean", 0.0
+ )
+ return cross_entropy
+
+ mod = Repro().cuda()
+ opt_mod = torch._dynamo.optimize("inductor")(mod)
+ mod.eval()
+ opt_mod.eval()
+
+ args = [
+ ((1,), (1,), torch.int64, "cuda", False),
+ ((1, 128, 768), (98304, 768, 1), torch.float32, "cuda", True),
+ ]
+ args = [
+ rand_strided(sh, st, dt, dev).requires_grad_(rg)
+ for (sh, st, dt, dev, rg) in args
+ ]
+ with torch.cuda.amp.autocast(enabled=False):
+ assert same_two_models(mod, opt_mod, args), "Dynamo failed"
+
+
+if __name__ == "__main__":
+ from torch._dynamo.testing import run_tests
+
+ run_tests(needs="filelock")
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
new file mode 100644
index 0000000000000..2b8b166a35d51
--- /dev/null
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -0,0 +1,622 @@
+# Owner(s): ["module: inductor"]
+import atexit
+import os
+import sys
+import unittest
+from collections import defaultdict
+from enum import Enum
+from functools import partial
+from unittest.mock import patch
+
+import torch
+
+import torch._dynamo
+from torch.testing._internal.common_device_type import (
+ instantiate_device_type_tests,
+ onlyNativeDeviceTypes,
+ OpDTypes,
+ ops,
+)
+from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.common_utils import (
+ dtype_abbrs,
+ run_tests,
+ skipCUDAMemoryLeakCheckIf,
+ suppress_warnings,
+ TestCase,
+)
+
+try:
+ from torch._inductor.utils import has_triton
+
+ try:
+ from .test_torchinductor import check_model, check_model_cuda
+ except ImportError:
+ from test_torchinductor import check_model, check_model_cuda
+except (unittest.SkipTest, ImportError) as e:
+ sys.stderr.write(f"{type(e)}: {e}\n")
+ if __name__ == "__main__":
+ sys.exit(0)
+ raise
+
+bf16 = torch.bfloat16 # not tested
+f64 = torch.float64
+f32 = torch.float32
+f16 = torch.float16
+i8 = torch.int8 # not tested
+i16 = torch.int16 # not tested
+i32 = torch.int32
+i64 = torch.int64
+b8 = torch.bool
+u8 = torch.uint8 # not tested
+
+_ops = partial(
+ ops, dtypes=OpDTypes.supported, allowed_dtypes=[f16, f32, f64, i32, i64, b8]
+)
+
+# Success forces pass; failure forces fail; skip unconditionally skips testing
+TestExpect = Enum("TestExpect", ("SUCCESS", "XFAILURE", "SKIP"))
+
+COLLECT_EXPECT = os.getenv("PYTORCH_COLLECT_EXPECT", "0") == "1"
+FAIL_ON_SUCCESS = os.getenv("PYTORCH_FAIL_ON_SUCCESS", "1") == "1"
+ALL_SAMPLES = os.getenv("PYTORCH_ALL_SAMPLES", "0") == "1"
+START = os.getenv("PYTORCH_TEST_RANGE_START", None)
+END = os.getenv("PYTORCH_TEST_RANGE_END", None)
+
+if START is not None or END is not None:
+ assert END is not None
+ assert START is not None
+ START = int(START)
+ END = int(END)
+ assert START < END
+else:
+ START = 0
+ END = len(op_db)
+
+seen_succeeded = defaultdict(dict)
+seen_failed = defaultdict(dict)
+failed_reasons = defaultdict(set)
+
+
+def print_seen():
+ expected_failures = defaultdict(list)
+
+ def fmt_dtypes(dtypes):
+ r = ", ".join(sorted(dtype_abbrs[d] for d in dtypes))
+ return "{" + r + "}"
+
+ def process(device_type):
+ for op, failed_dtypes in seen_failed[device_type].items():
+ succeeded_dtypes = seen_succeeded.get(op, set())
+ expected_failures_dtypes = failed_dtypes - succeeded_dtypes
+
+ reasons = ""
+ if failed_reasons[op]:
+ reasons = " # " + ", ".join(sorted(failed_reasons[op]))
+ if expected_failures_dtypes:
+ expected_failures[device_type].append(
+ f' "{op}": {fmt_dtypes(expected_failures_dtypes)},{reasons}'
+ )
+
+ expected_failures[device_type].sort()
+ nl = "\n"
+ print(
+ f"""
+inductor_expected_failures_single_sample[\"{device_type}\"] = {{
+{nl.join(expected_failures[device_type])}
+}}
+"""
+ )
+
+ process("cpu")
+ process("cuda")
+
+
+if COLLECT_EXPECT:
+ atexit.register(print_seen)
+
+inductor_skips = defaultdict(dict)
+
+inductor_skips["cpu"] = {
+ "linalg.ldl_solve": {b8, f16, f32, f64, i32, i64}, # segfault
+ "linalg.lu_solve": {b8, f16, f32, f64, i32, i64}, # segfault
+ "reciprocal": {b8, i32, i64}, # segfault
+ "lu_solve": {b8, f16, f32, f64, i32, i64}, # segfault
+ "lu_unpack": {b8, f16, f32, f64, i32, i64}, # segfault
+ "__rdiv__": {b8, f16, f32, f64, i32, i64}, # flaky
+}
+
+inductor_skips["cuda"] = {
+ # flaky
+ "__rdiv__": {b8, f16, f32, f64, i32, i64},
+ "masked.prod": {f16, f32, f64},
+ "linalg.vander": {f32, f64},
+ "sparse.sampled_addmm": {f32, f64},
+ "broadcast_tensors": {f16, f32, f64},
+ "dsplit": {f16, f32, f64},
+ # Call parameter type does not match function signature!
+ "masked.logsumexp": {f64},
+ "erf": {f64},
+ "logsumexp": {f64},
+ "lu_unpack": {f32, f64}, # RuntimeError: CUDA error
+ "nn.functional.binary_cross_entropy_with_logits": {f64},
+ "nn.functional.gelu": {f64},
+ "nn.functional.glu": {f64},
+ "nn.functional.poisson_nll_loss": {f64},
+ "nn.functional.tanhshrink": {f16, f64},
+ "nn.functional.conv_transpose3d": {f16, f64},
+ "nn.functional._scaled_dot_product_attention": {f64},
+ "nn.functional.triplet_margin_loss": {f16},
+ "special.ndtr": {f64},
+ # Jiterator kernel is not expected to work with inductor
+ "jiterator_2inputs_2outputs": {b8, f16, f32, f64, i32, i64},
+ "jiterator_4inputs_with_extra_args": {b8, f16, f32, f64, i32, i64},
+ "jiterator_binary": {b8, f16, f32, f64, i32, i64},
+ "jiterator_binary_return_by_ref": {b8, f16, f32, f64, i32, i64},
+ "jiterator_unary": {b8, f16, f32, f64, i32, i64},
+}
+
+inductor_expected_failures_single_sample = defaultdict(dict)
+
+inductor_expected_failures_single_sample["cpu"] = {
+ "T": {b8, f16, f32, f64, i32, i64},
+ "H": {b8, f16, f32, f64, i32, i64},
+ "mH": {b8, f16, f32, f64, i32, i64},
+ "mT": {b8, f16, f32, f64, i32, i64},
+ "__getitem__": {b8, f16, f32, f64, i32, i64},
+ "addr": {f16},
+ "allclose": {f16, f32, f64},
+ "angle": {f16, f32, f64},
+ "argwhere": {b8, f16, f32, f64, i32, i64},
+ "bernoulli": {f32, f64},
+ "bincount": {i32, i64},
+ "chalf": {b8, f16, f32, f64, i32, i64},
+ "cholesky": {f32, f64},
+ "combinations": {b8, f16, f32, f64, i32, i64},
+ "complex": {f16, f32, f64},
+ "constant_pad_nd": {f16, f32, f64},
+ "copysign": {f16},
+ "corrcoef": {f32, f64, i32, i64},
+ "cov": {f32, f64, i32, i64},
+ "equal": {b8, f16, f32, f64, i32, i64},
+ "erf": {b8, f64},
+ "fft.fft": {f32, f64},
+ "fft.fft2": {b8, f32, f64, i32, i64},
+ "fft.fftn": {b8, f32, f64, i32, i64},
+ "fft.hfft": {b8, f32, f64, i32, i64},
+ "fft.hfft2": {b8, f32, f64, i32, i64},
+ "fft.hfftn": {b8, f32, f64, i32, i64},
+ "fft.ifft": {b8, f16, f32, f64, i32, i64},
+ "fft.ifft2": {b8, f32, f64, i32, i64},
+ "fft.ifftn": {b8, f32, f64, i32, i64},
+ "fft.ihfft": {b8, f16, f32, f64, i32, i64},
+ "fft.ihfft2": {f32, f64},
+ "fft.ihfftn": {f32, f64},
+ "fft.irfft": {b8, f32, f64, i32, i64},
+ "fft.irfft2": {b8, f32, f64, i32, i64},
+ "fft.irfftn": {b8, f32, f64, i32, i64},
+ "fft.rfft": {f32, f64},
+ "fft.rfft2": {f32, f64},
+ "fft.rfftn": {f32, f64},
+ "index_add": {f16},
+ "index_put": {f16, f32, f64},
+ "index_reduce": {f16, f32, f64},
+ "istft": {f32, f64},
+ "linalg.cholesky": {f32, f64},
+ "linalg.cholesky_ex": {f32, f64},
+ "linalg.eig": {f32, f64},
+ "linalg.eigh": {f32, f64},
+ "linalg.eigvals": {f32, f64},
+ "linalg.eigvalsh": {f32, f64},
+ "linalg.ldl_factor": {f32, f64},
+ "linalg.lstsq": {f32, f64},
+ "linalg.lstsq.grad_oriented": {f32, f64},
+ "linalg.matrix_rank": {f32, f64},
+ "linalg.matrix_rank.hermitian": {f32, f64},
+ "linalg.svd": {f32, f64},
+ "logdet": {f32, f64},
+ "masked.norm": {f16},
+ "masked_fill": {f16},
+ "masked_scatter": {f16, f32, f64},
+ "masked_select": {b8, f16, f32, f64, i32, i64},
+ "max.reduction_no_dim": {f16},
+ "max.reduction_with_dim": {b8, f16},
+ "min.reduction_no_dim": {f16},
+ "min.reduction_with_dim": {b8, f16},
+ "multinomial": {f32, f64},
+ "nan_to_num": {f16},
+ "nanquantile": {f32, f64},
+ "nn.functional.avg_pool1d": {i64},
+ "nn.functional.avg_pool2d": {i64},
+ "nn.functional.adaptive_avg_pool2d": {f16},
+ "nn.functional.ctc_loss": {f32, f64},
+ "nn.functional.gaussian_nll_loss": {f32, f64},
+ "nn.functional.gelu": {f64},
+ "nn.functional.local_response_norm": {i64},
+ "nn.functional.one_hot": {i64},
+ "nn.functional.pairwise_distance": {f16},
+ "nn.functional.rrelu": {f32, f64},
+ "nn.functional.triplet_margin_with_distance_loss": {f32, f64, i32, i64},
+ "nonzero": {b8, f16, f32, f64, i32, i64},
+ "normal": {f16, f32, f64},
+ "normal.number_mean": {f16, f32, f64},
+ "pca_lowrank": {f32, f64},
+ "pinverse": {f32, f64},
+ "polar": {f32, f64},
+ "quantile": {f32, f64},
+ "rand_like": {f16, f32, f64},
+ "randint_like": {f16, f32, f64, i32, i64},
+ "randn_like": {f16, f32, f64},
+ "repeat_interleave": {b8, f16, f32, f64, i32, i64},
+ "scatter_add": {f16},
+ "scatter_reduce.sum": {f16},
+ "scatter_reduce.prod": {f16, f32, f64},
+ "segment_reduce.lengths": {f16, f32, f64},
+ "segment_reduce.offsets": {f16, f32, f64},
+ "sgn": {f16, f32, f64},
+ "sparse.sampled_addmm": {f32, f64},
+ "stft": {f32, f64},
+ "svd": {f32, f64},
+ "svd_lowrank": {f32, f64},
+ "tensor_split": {b8, f16, f32, f64, i32, i64},
+ "to": {b8, f16, f32, f64, i32, i64},
+ "to_sparse": {f32, f64},
+ "tril": {f16},
+ "triu": {f16},
+ "uniform": {f16, f32, f64},
+ "unique": {b8, f32, f64, i32, i64},
+ "unique_consecutive": {b8, f32, f64, i32, i64},
+ "var": {f16},
+ "var_mean": {f16},
+ "view_as_complex": {f16, f32, f64},
+}
+
+
+inductor_expected_failures_single_sample["cuda"] = {
+ "T": {b8, f16, f32, f64, i32, i64},
+ "H": {b8, f16, f32, f64, i32, i64},
+ "mH": {b8, f16, f32, f64, i32, i64},
+ "mT": {b8, f16, f32, f64, i32, i64},
+ "__getitem__": {b8, f16, f32, f64, i32, i64},
+ "allclose": {f16, f32, f64},
+ "angle": {f32, f64},
+ "argwhere": {b8, f16, f32, f64, i32, i64},
+ "baddbmm": {f16},
+ "bernoulli": {f16, f32, f64},
+ "bincount": {i32, i64},
+ "chalf": {b8, f16, f32, f64, i32, i64},
+ "cholesky": {f32, f64},
+ "combinations": {b8, f16, f32, f64, i32, i64},
+ "complex": {f16, f32, f64},
+ "corrcoef": {f16, f32, f64, i32, i64},
+ "cov": {f16, f32, f64, i32, i64},
+ "equal": {b8, f16, f32, f64, i32, i64},
+ "erf": {b8},
+ "fft.fft": {f16, f32, f64},
+ "fft.fft2": {b8, f16, f32, f64, i32, i64},
+ "fft.fftn": {b8, f16, f32, f64, i32, i64},
+ "fft.hfft": {b8, f16, f32, f64, i32, i64},
+ "fft.hfft2": {b8, f16, f32, f64, i32, i64},
+ "fft.hfftn": {b8, f16, f32, f64, i32, i64},
+ "fft.ifft": {b8, f16, f32, f64, i32, i64},
+ "fft.ifft2": {b8, f16, f32, f64, i32, i64},
+ "fft.ifftn": {b8, f16, f32, f64, i32, i64},
+ "fft.ihfft": {b8, f16, f32, f64, i32, i64},
+ "fft.ihfft2": {f16, f32, f64},
+ "fft.ihfftn": {f16, f32, f64},
+ "fft.irfft": {b8, f16, f32, f64, i32, i64},
+ "fft.irfft2": {b8, f16, f32, f64, i32, i64},
+ "fft.irfftn": {b8, f16, f32, f64, i32, i64},
+ "fft.rfft": {f16, f32, f64},
+ "fft.rfft2": {f16, f32, f64},
+ "fft.rfftn": {f16, f32, f64},
+ "index_put": {f16, f32, f64},
+ "index_reduce": {f16, f32, f64},
+ "istft": {f32, f64},
+ "linalg.cholesky": {f32, f64},
+ "linalg.cholesky_ex": {f32, f64},
+ "linalg.eig": {f32, f64},
+ "linalg.eigh": {f32, f64},
+ "linalg.eigvals": {f32, f64},
+ "linalg.eigvalsh": {f32, f64},
+ "linalg.ldl_factor": {f32, f64},
+ "linalg.lstsq": {f32, f64},
+ "linalg.lstsq.grad_oriented": {f32, f64},
+ "linalg.matrix_rank": {f32, f64},
+ "linalg.matrix_rank.hermitian": {f32, f64},
+ "linalg.pinv.hermitian": {f32, f64},
+ "linalg.svd": {f32, f64},
+ "masked.argmax": {f16, f32, f64, i32},
+ "masked.argmin": {f16, f32, f64, i32},
+ "masked_scatter": {f16, f32, f64},
+ "masked_select": {b8, f16, f32, f64, i32, i64},
+ "max.reduction_with_dim": {b8, i32, i64},
+ "min.reduction_with_dim": {b8, i32, i64},
+ "multinomial": {f16, f32, f64},
+ "nn.functional.adaptive_avg_pool2d": {f16},
+ "nn.functional._scaled_dot_product_attention": {f64},
+ "nn.functional.ctc_loss": {f32, f64},
+ "nn.functional.grid_sample": {f16},
+ "nn.functional.gaussian_nll_loss": {f16, f32, f64},
+ "nn.functional.one_hot": {i64},
+ "nn.functional.rrelu": {f16, f32, f64},
+ "nn.functional.triplet_margin_with_distance_loss": {f16, f32, f64, i32, i64},
+ "nonzero": {b8, f16, f32, f64, i32, i64},
+ "normal": {f16, f32, f64},
+ "normal.number_mean": {f16, f32, f64},
+ "pca_lowrank": {f32, f64},
+ "pinverse": {f32, f64},
+ "polar": {f32, f64},
+ "pow": {i32, i64},
+ "rand_like": {f16, f32, f64},
+ "randint_like": {f16, f32, f64, i32, i64},
+ "randn_like": {f16, f32, f64},
+ "repeat_interleave": {b8, f16, f32, f64, i32, i64},
+ "round.decimals_3": {f16},
+ "scatter_reduce.prod": {f16, f32, f64},
+ "segment_reduce.lengths": {f16, f32, f64},
+ "segment_reduce.offsets": {f16, f32, f64},
+ "sgn": {f16, f32, f64},
+ "stft": {f32, f64},
+ "svd": {f32, f64},
+ "svd_lowrank": {f32, f64},
+ "tensor_split": {b8, f16, f32, f64, i32, i64},
+ "to": {b8, f16, f32, f64, i32, i64},
+ "to_sparse": {f16, f32, f64},
+ "uniform": {f16, f32, f64},
+ "unique": {b8, f16, f32, f64, i32, i64},
+ "unique_consecutive": {b8, f16, f32, f64, i32, i64},
+ "view_as_complex": {f16, f32, f64},
+}
+
+inductor_gradient_expected_failures_single_sample = defaultdict(dict)
+
+inductor_gradient_expected_failures_single_sample["cuda"] = {
+ "amax": {f16, f32, f64},
+ "amin": {f16, f32, f64},
+ "asin": {f16},
+ "cumprod": {f16},
+ "linalg.vector_norm": {f64, f64},
+ "linalg.householder_product": {f32},
+ "linalg.lu": {f32, f64},
+ "kron": {f16},
+ "masked.amax": {f16, f32, f64},
+ "masked.amin": {f16, f32, f64},
+ "max.reduction_no_dim": {f16, f32, f64},
+ "median": {f16, f32, f64},
+ "min.reduction_no_dim": {f16, f32, f64},
+ "nan_to_num": {f16, f32, f64},
+ "nanmean": {f16, f32, f64},
+ "nanmedian": {f16, f32, f64},
+ "nanquantile": {f32, f64},
+ "nansum": {f16, f32, f64},
+ "native_batch_norm": {f16, f32, f64},
+ "native_layer_norm": {f16, f32, f64},
+ "nn.functional._scaled_dot_product_attention": {f16},
+ "nn.functional.avg_pool2d": {f16, f32, f64},
+ "nn.functional.batch_norm.without_cudnn": {f16},
+ "nn.functional.batch_norm": {f16},
+ "nn.functional.cosine_similarity": {f16},
+ "nn.functional.instance_norm": {f16},
+ "nn.functional.normalize": {f16},
+ "nn.functional.softsign": {f16},
+ "nn.functional.local_response_norm": {f16},
+ "norm.inf": {f64},
+ "outer": {f16},
+ "quantile": {f32, f64},
+ "scatter_reduce.amax": {f16, f32, f64},
+ "scatter_reduce.amin": {f16, f32, f64},
+ "tanh": {f16},
+}
+
+inductor_should_fail_with_exception = defaultdict(dict)
+
+inductor_should_fail_with_exception["cpu"] = {}
+
+
+inductor_should_fail_with_exception["cuda"] = {
+ "__rpow__": {
+ i32: "Pow input must be floating point.",
+ i64: "Pow input must be floating point.",
+ }
+}
+
+
+def wrapper_set_seed(op, *args, **kwargs):
+ """Wrapper to set seed manually for some functions like dropout
+ See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
+ """
+ torch.manual_seed(42)
+ return op(*args, **kwargs)
+
+
+torch.testing._internal.common_methods_invocations.wrapper_set_seed = wrapper_set_seed
+
+# This file does a global patch to `disable_global_flags()` - which we should not invoke in non testing cases.
+torch._dynamo.variables.torch.tensor_dunder_fns.append(
+ torch.testing._internal.common_utils.disable_functorch
+)
+
+# key can be either op_name, or (op_name, deivce_type), or (op_name, device_type, dtype)
+inductor_override_kwargs = {
+ # the return value of empty is undefined
+ "empty": {"assert_equal": False},
+ "empty_like": {"assert_equal": False},
+ "new_empty": {"assert_equal": False},
+ "new_empty_strided": {"assert_equal": False},
+ "randn": {"assert_equal": False},
+ ("nn.functional.tanhshrink", "cuda", f16): {"atol": 3e-4, "rtol": 0.001},
+ "gradient": {"check_gradient": False}, # segfault on check_gradient
+ # Following tests failed, and causing subsequent tests failing with unrecoverable CUDA error
+ "linalg.solve_triangular": {"check_gradient": False},
+ "linalg.lu_factor": {"check_gradient": False},
+ "linalg.lu_factor_ex": {"check_gradient": False},
+}
+
+# Always test with all sample for following ops
+inductor_all_samples = {
+ "softmax.with_dtype",
+ "index_add",
+ "index_put",
+ "index_copy",
+ "scatter_reduce.sum",
+ "select_scatter",
+}
+
+
+class TestInductorOpInfo(TestCase):
+ check_model = check_model
+ check_model_cuda = check_model_cuda
+
+ @onlyNativeDeviceTypes
+ @suppress_warnings
+ @skipCUDAMemoryLeakCheckIf(
+ True
+ ) # inductor kernels failing this test intermittently
+ @_ops(op_db[START:END])
+ @patch("torch._dynamo.config.raise_on_unsafe_aot_autograd", True)
+ def test_comprehensive(self, device, dtype, op):
+ torch._dynamo.reset()
+ with torch.no_grad():
+ torch.cuda.empty_cache()
+ op_name = op.name
+ if op.variant_test_name:
+ op_name += f".{op.variant_test_name}"
+
+ device_type = torch.device(device).type
+
+ assert device_type in ("cuda", "cpu")
+
+ # with open("test_output.txt", "a") as f:
+ # print(f"CONSIDERING OP {op_name} on {device_type} with {dtype} |
+ # {inductor_skips[device_type].get(op_name, set())}", flush=True, file=f)
+ # print(f"CONSIDERING OP {op_name} on {device_type} with {dtype} |
+ # {inductor_skips[device_type].get(op_name, set())}", flush=True)
+ if dtype in inductor_skips[device_type].get(op_name, set()):
+ test_expect = TestExpect.SKIP
+ # with open("test_output.txt", "a") as f:
+ # print(f"SKIPPING OP {op_name} on {device_type}", flush=True, file=f)
+ # print(f"SKIPPING OP {op_name} on {device_type}", flush=True)
+ self.skipTest(f"{op_name} in {dtype} not supported")
+ elif dtype in inductor_expected_failures_single_sample[device_type].get(
+ op_name, set()
+ ) or dtype in inductor_gradient_expected_failures_single_sample[
+ device_type
+ ].get(
+ op_name, set()
+ ):
+ test_expect = TestExpect.XFAILURE
+ else:
+ test_expect = TestExpect.SUCCESS
+
+ overridden_kwargs = {}
+ if op_name in inductor_override_kwargs:
+ overridden_kwargs = inductor_override_kwargs[op_name]
+ elif (op_name, device_type) in inductor_override_kwargs:
+ overridden_kwargs = inductor_override_kwargs[(op_name, device_type)]
+ elif (op_name, device_type, dtype) in inductor_override_kwargs:
+ overridden_kwargs = inductor_override_kwargs[(op_name, device_type, dtype)]
+
+ func = op.get_op()
+
+ def fn(*args, **kwargs):
+ return func(*args, **kwargs)
+
+ requires_grad = (
+ op.supports_autograd
+ and dtype in op.supported_backward_dtypes(device_type)
+ # TODO: OpInfo really ought to error out for this case, but it's
+ # not exercised in test_ops_gradients atm. The problem is not
+ # complex32 per-se (which is supported by data movement only ops)
+ # but that when we do backwards we expect other ops like add to work
+ and not dtype == torch.complex32
+ )
+ samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)
+
+ if op_name not in inductor_all_samples and not ALL_SAMPLES:
+ if isinstance(samples, (list, tuple)):
+ samples = [samples[0]]
+ else:
+ samples = [next(samples)]
+
+ try:
+ for sample_input in samples:
+ args = [sample_input.input] + list(sample_input.args)
+ kwargs = sample_input.kwargs
+ # UNCOMMENT TO DEBUG SEGFAULTS
+ # with open("test_output.txt", "a") as f:
+ # print(f"RUNNING OP {op_name} on {device_type} with {dtype}", flush=True, file=f)
+ # print(f"RUNNING OP {op_name} on {device_type} with {dtype}", flush=True)
+ if device_type == "cuda":
+ # opinfo test case have already place the input on the correct device
+ # so we don't need do additional copy by setting copy_to_cuda=False
+ adjusted_kwargs = {
+ "check_lowp": False,
+ "nopython": True,
+ "copy_to_cuda": False,
+ "reference_in_float": False,
+ "check_gradient": requires_grad,
+ }
+ adjusted_kwargs.update(overridden_kwargs)
+
+ self.check_model_cuda(
+ fn,
+ args,
+ kwargs,
+ **adjusted_kwargs,
+ )
+ elif device_type == "cpu":
+ adjusted_kwargs = {
+ "check_lowp": False,
+ "nopython": True,
+ # skip checking gradient on CPU for now
+ "check_gradient": False,
+ }
+ adjusted_kwargs.update(overridden_kwargs)
+
+ self.check_model(
+ fn,
+ args,
+ kwargs,
+ **adjusted_kwargs,
+ )
+
+ except Exception as e:
+
+ if test_expect is TestExpect.XFAILURE:
+ return
+
+ seen_failed[device_type].setdefault(op_name, set()).add(dtype)
+
+ if COLLECT_EXPECT:
+ return
+
+ known_failure = False
+ if dtype in inductor_should_fail_with_exception[device_type].get(
+ op_name, set()
+ ):
+ failure = inductor_should_fail_with_exception[device_type][op_name][
+ dtype
+ ]
+ if failure in str(e):
+ known_failure = True
+ if not known_failure:
+ raise e
+
+ # with open("test_output.txt", "a") as f:
+ # print(f"SUCCEEDED OP {op_name} on {device_type} with {dtype}", flush=True, file=f)
+ seen_succeeded[device_type].setdefault(op_name, set()).add(dtype)
+
+ if test_expect is TestExpect.XFAILURE and not COLLECT_EXPECT:
+ if FAIL_ON_SUCCESS:
+ raise RuntimeError(
+ f"unexpected success {op_name}, {dtype}, {device_type}"
+ )
+
+
+instantiate_device_type_tests(TestInductorOpInfo, globals())
+
+if __name__ == "__main__":
+ torch._dynamo.config.raise_on_assertion_error = True
+ if has_triton():
+ run_tests()
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index be35db2b38942..0f12734ffd668 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -4,7 +4,7 @@
from torch._C import _disabled_torch_function_impl
import torch.fx
import torch.nn.functional as F
-from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo
+from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo, IS_WINDOWS
import unittest
import torch
import operator
@@ -19,7 +19,8 @@
try:
import sympy
- HAS_SYMPY = True
+ # TODO(jansel): these tests fail on windows
+ HAS_SYMPY = not IS_WINDOWS
except ImportError:
HAS_SYMPY = False
skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 241c9f72154e0..d736a2c453aca 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1,6 +1,6 @@
# Owner(s): ["module: ProxyTensor"]
-from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
import torch
import unittest
import warnings
@@ -28,7 +28,8 @@
try:
import sympy # noqa: F401
- HAS_SYMPY = True
+ # TODO(jansel): these tests fail on windows
+ HAS_SYMPY = not IS_WINDOWS
except ImportError:
HAS_SYMPY = False
skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py
new file mode 100644
index 0000000000000..22a974e7afb9b
--- /dev/null
+++ b/torch/_dynamo/__init__.py
@@ -0,0 +1,106 @@
+from . import allowed_functions, convert_frame, eval_frame, resume_execution
+from .convert_frame import replay
+from .eval_frame import (
+ assume_constant_result,
+ disable,
+ explain,
+ export,
+ optimize,
+ optimize_assert,
+ reset_code,
+ run,
+ skip,
+)
+from .utils import compilation_metrics, guard_failures, orig_code_map
+
+__all__ = [
+ "assume_constant_result",
+ "optimize",
+ "optimize_assert",
+ "export",
+ "explain",
+ "run",
+ "replay",
+ "disable",
+ "reset",
+ "list_backends",
+ "skip",
+]
+
+
+def reset():
+ """Clear all compile caches and restore initial state"""
+ for weak_code in convert_frame.input_codes.seen + convert_frame.output_codes.seen:
+ code = weak_code()
+ if code:
+ reset_code(code)
+ convert_frame.input_codes.clear()
+ convert_frame.output_codes.clear()
+ orig_code_map.clear()
+ guard_failures.clear()
+ resume_execution.ContinueExecutionCache.cache.clear()
+ eval_frame.most_recent_backend = None
+ compilation_metrics.clear()
+
+
+def list_backends():
+ """
+ Return valid strings that can be passed to:
+ @torchdynamo.optimize()
+ def foo(...):
+ ....
+ """
+ from .optimizations import BACKENDS
+
+ return [*sorted([*BACKENDS.keys(), "inductor"])]
+
+
+def allow_in_graph(fn):
+ """
+ Customize which functions TorchDynamo will include in the generated
+ graph. Similar to torch.fx.wrap().
+
+ torchdynamo.allow_in_graph(my_custom_function)
+
+ @torchdynamo.optimize(...)
+ def fn(a):
+ x = torch.add(x, 1)
+ x = my_custom_function(x)
+ x = torch.add(x, 1)
+ return x
+
+ fn(...)
+
+ Will capture a single graph containing my_custom_function().
+ """
+ if isinstance(fn, (list, tuple)):
+ return [allow_in_graph(x) for x in fn]
+ assert callable(fn), "allow_in_graph expects a callable"
+ allowed_functions._allowed_function_ids.add(id(fn))
+ allowed_functions._disallowed_function_ids.remove(id(fn))
+
+
+def disallow_in_graph(fn):
+ """
+ Customize which functions TorchDynamo will exclude in the generated
+ graph and force a graph break on.
+
+ torchdynamo.disallow_in_graph(torch.sub)
+
+ @torchdynamo.optimize(...)
+ def fn(a):
+ x = torch.add(x, 1)
+ x = torch.sub(x, 1)
+ x = torch.add(x, 1)
+ return x
+
+ fn(...)
+
+ Will break the graph on torch.sub, and give two graphs each with a
+ single torch.add() op.
+ """
+ if isinstance(fn, (list, tuple)):
+ return [disallow_in_graph(x) for x in fn]
+ assert callable(fn), "disallow_in_graph expects a callable"
+ allowed_functions._allowed_function_ids.remove(id(fn))
+ allowed_functions._disallowed_function_ids.add(id(fn))
diff --git a/torch/_dynamo/allowed_functions.py b/torch/_dynamo/allowed_functions.py
new file mode 100644
index 0000000000000..56740bcb3b6a1
--- /dev/null
+++ b/torch/_dynamo/allowed_functions.py
@@ -0,0 +1,255 @@
+import builtins
+import collections
+import copy
+import functools
+import inspect
+import itertools
+import math
+import operator
+import types
+import warnings
+from typing import Dict, Optional, Set
+
+import numpy
+
+import torch
+from torch.fx._symbolic_trace import is_fx_tracing
+
+from . import config
+from .utils import is_safe_constant
+
+
+def make_function_id_set(lazy_initializer):
+ """
+ Track a set of `id()`s of objects which are either allowed or not
+ allowed to go into the generated FX graph. Use to test for torch.*,
+ numpy.*, builtins.*, etc.
+
+ Support user modification to permit customization of what can be
+ added to the graph and what will cause a graph break.
+ """
+
+ class FunctionIdSet:
+ function_ids: Optional[Set[int]] = None
+ function_names: Optional[Dict[int, str]] = None
+
+ def __call__(self):
+ if self.function_ids is None:
+ value = lazy_initializer()
+ if isinstance(value, dict):
+ self.function_ids = set(value.keys())
+ self.function_names = value
+ else:
+ assert isinstance(value, set)
+ self.function_ids = value
+ return self.function_ids
+
+ def get_name(self, idx: int, default: str):
+ self() # lazy init
+ return self.function_names.get(idx, default)
+
+ def add(self, idx: int):
+ self() # lazy init
+ self.function_ids.add(idx)
+
+ def remove(self, idx: int):
+ if idx in self():
+ self.function_ids.remove(idx)
+
+ def __contains__(self, idx: int):
+ return idx in self()
+
+ return FunctionIdSet()
+
+
+@make_function_id_set
+def _disallowed_function_ids():
+ remove = [
+ True,
+ False,
+ None,
+ collections.OrderedDict,
+ copy.copy,
+ copy.deepcopy,
+ inspect.signature,
+ math.__package__,
+ torch.__builtins__,
+ torch.autocast_decrement_nesting,
+ torch.autocast_increment_nesting,
+ torch.autograd.grad,
+ torch.clear_autocast_cache,
+ torch.cuda.current_device,
+ torch.cuda.amp.autocast_mode.autocast,
+ torch.distributions.constraints.is_dependent,
+ torch.distributions.normal.Normal,
+ torch.get_rng_state,
+ torch.inference_mode,
+ torch.set_anomaly_enabled,
+ torch.set_autocast_cache_enabled,
+ torch.set_autocast_cpu_dtype,
+ torch.set_autocast_cpu_enabled,
+ torch.set_autocast_enabled,
+ torch.set_autocast_gpu_dtype,
+ torch.set_rng_state,
+ torch.autograd.profiler.profile,
+ warnings.warn,
+ torch._C._dynamo.eval_frame.unsupported,
+ ]
+ # extract all dtypes from torch
+ dtypes = [
+ obj for obj in torch.__dict__.values() if isinstance(obj, type(torch.float32))
+ ]
+ remove += dtypes
+ storage = [
+ obj
+ for obj in torch.__dict__.values()
+ if isinstance(obj, type(torch.FloatStorage))
+ ]
+ remove += storage
+ return {id(x) for x in remove}
+
+
+@make_function_id_set
+def _allowed_function_ids():
+ """
+ Walk torch.* and get the ids of all the stuff in it
+ """
+ warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
+ torch_object_ids = dict()
+
+ def _is_allowed_module_prefix(obj):
+ allowed_modules = ("torch", "math")
+ # torch.nn.modules.rnn is disallowed because these modules internally
+ # flatten their parameters. This flattening process will call
+ # Tensor.set_ with a Storage, and Storages cannot be traced with
+ # AOTAutograd; so we need to graph-break. To ensure this, we inline
+ # these functions, rather than keep them opaque-ly in the graph.
+ disallowed_modules = (
+ "torch.optim.",
+ "torch.nn.modules.rnn.",
+ "torch._dynamo.",
+ "torch._C._dynamo.",
+ "torch._inductor.",
+ "torch._C.inductor.",
+ "torch.fx.",
+ )
+ allowed_modules_dot = tuple([x + "." for x in allowed_modules])
+ module = inspect.getmodule(obj)
+ if module is None:
+ return False
+
+ mod_name = module.__name__
+
+ if any(mod_name.startswith(m) for m in disallowed_modules):
+ return False
+
+ return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)
+
+ def _find_torch_objects(module):
+ if any(
+ module.__name__.startswith(mod_name)
+ for mod_name in config.allowed_functions_module_string_ignorelist
+ ):
+ return
+ torch_object_ids[id(module)] = module.__name__
+ for name, obj in list(module.__dict__.items()):
+ if id(obj) not in torch_object_ids:
+ if isinstance(obj, types.ModuleType):
+ if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
+ obj
+ ):
+ torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
+ _find_torch_objects(obj)
+ elif _is_allowed_module_prefix(obj):
+ torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
+ elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
+ torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
+
+ _find_torch_objects(torch)
+ _find_torch_objects(math)
+
+ for idx in _disallowed_function_ids():
+ if idx in torch_object_ids:
+ del torch_object_ids[idx]
+
+ for extra in (is_fx_tracing,):
+ torch_object_ids[id(extra)] = f"{extra.__module__}.{extra.__name__}"
+
+ return torch_object_ids
+
+
+@make_function_id_set
+def _builtin_function_ids():
+ rv = {
+ id(v): f"builtins.{k}"
+ for k, v in builtins.__dict__.items()
+ if not k.startswith("_") and callable(v)
+ }
+ rv.update(
+ {
+ id(v): f"operator.{k}"
+ for k, v in operator.__dict__.items()
+ if not k.startswith("_") and callable(v)
+ }
+ )
+ rv.update(
+ {id(v): f"functools.{v.__name__}" for v in (itertools.chain, itertools.islice)}
+ )
+ rv[id(functools.reduce)] = "functools.reduce"
+ return rv
+
+
+@make_function_id_set
+def _numpy_function_ids():
+ rv = dict()
+ for mod in (numpy, numpy.random):
+ rv.update(
+ {
+ id(v): f"{mod.__name__}.{k}"
+ for k, v in mod.__dict__.items()
+ if callable(v)
+ and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__
+ }
+ )
+ return rv
+
+
+@make_function_id_set
+def _builtin_constant_ids():
+ """
+ Collects constant builtins by eliminating callable items.
+ """
+ rv = {
+ id(v): f"builtins.{k}"
+ for k, v in builtins.__dict__.items()
+ if not k.startswith("_") and not callable(v)
+ }
+ return rv
+
+
+def is_allowed(obj):
+ """Is this safe to trace like torch.add ?"""
+ # torch.ops is populated lazily so we don't necessarily have them in
+ # _allowed_function_ids. Figure it out by testing the type instead
+ # in those cases
+ return id(obj) in _allowed_function_ids or isinstance(
+ obj,
+ (torch._ops.OpOverloadPacket, torch._ops.OpOverload, torch._ops._OpNamespace),
+ )
+
+
+def torch_get_name(obj, default):
+ """Convert a torch.* funcion to a string"""
+ return _allowed_function_ids.get_name(id(obj), default)
+
+
+def is_builtin_callable(obj):
+ return id(obj) in _builtin_function_ids
+
+
+def is_builtin_constant(obj):
+ return id(obj) in _builtin_constant_ids
+
+
+def is_numpy(obj):
+ return isinstance(obj, numpy.ndarray) or id(obj) in _numpy_function_ids
diff --git a/torch/_dynamo/bytecode_analysis.py b/torch/_dynamo/bytecode_analysis.py
new file mode 100644
index 0000000000000..541336ba483c8
--- /dev/null
+++ b/torch/_dynamo/bytecode_analysis.py
@@ -0,0 +1,164 @@
+import dataclasses
+import dis
+import sys
+from numbers import Real
+
+TERMINAL_OPCODES = {
+ dis.opmap["RETURN_VALUE"],
+ dis.opmap["JUMP_ABSOLUTE"],
+ dis.opmap["JUMP_FORWARD"],
+ dis.opmap["RAISE_VARARGS"],
+ # TODO(jansel): double check exception handling
+}
+if sys.version_info >= (3, 9):
+ TERMINAL_OPCODES.add(dis.opmap["RERAISE"])
+JUMP_OPCODES = set(dis.hasjrel + dis.hasjabs)
+HASLOCAL = set(dis.haslocal)
+HASFREE = set(dis.hasfree)
+
+if sys.version_info < (3, 8):
+
+ def stack_effect(opcode, arg, jump=None):
+ # jump= was added in python 3.8, we just ingore it here
+ if dis.opname[opcode] in ("NOP", "EXTENDED_ARG"):
+ # for some reason NOP isn't supported in python 3.7
+ return 0
+ return dis.stack_effect(opcode, arg)
+
+else:
+ stack_effect = dis.stack_effect
+
+
+def remove_dead_code(instructions):
+ """Dead code elimination"""
+ indexof = {id(inst): i for i, inst in enumerate(instructions)}
+ live_code = set()
+
+ def find_live_code(start):
+ for i in range(start, len(instructions)):
+ if i in live_code:
+ return
+ live_code.add(i)
+ inst = instructions[i]
+ if inst.opcode in JUMP_OPCODES:
+ find_live_code(indexof[id(inst.target)])
+ if inst.opcode in TERMINAL_OPCODES:
+ return
+
+ find_live_code(0)
+ return [inst for i, inst in enumerate(instructions) if i in live_code]
+
+
+def remove_pointless_jumps(instructions):
+ """Eliminate jumps to the next instruction"""
+ pointless_jumps = {
+ id(a)
+ for a, b in zip(instructions, instructions[1:])
+ if a.opname == "JUMP_ABSOLUTE" and a.target is b
+ }
+ return [inst for inst in instructions if id(inst) not in pointless_jumps]
+
+
+@dataclasses.dataclass
+class ReadsWrites:
+ reads: set
+ writes: set
+ visited: set
+
+
+def livevars_analysis(instructions, instruction):
+ indexof = {id(inst): i for i, inst in enumerate(instructions)}
+ must = ReadsWrites(set(), set(), set())
+ may = ReadsWrites(set(), set(), set())
+
+ def walk(state, start):
+ if start in state.visited:
+ return
+ state.visited.add(start)
+
+ for i in range(start, len(instructions)):
+ inst = instructions[i]
+ if inst.opcode in HASLOCAL or inst.opcode in HASFREE:
+ if "LOAD" in inst.opname or "DELETE" in inst.opname:
+ if inst.argval not in must.writes:
+ state.reads.add(inst.argval)
+ elif "STORE" in inst.opname:
+ state.writes.add(inst.argval)
+ else:
+ raise NotImplementedError(f"unhandled {inst.opname}")
+ if inst.opcode in JUMP_OPCODES:
+ walk(may, indexof[id(inst.target)])
+ state = may
+ if inst.opcode in TERMINAL_OPCODES:
+ return
+
+ walk(must, indexof[id(instruction)])
+ return must.reads | may.reads
+
+
+@dataclasses.dataclass
+class FixedPointBox:
+ value: bool = True
+
+
+@dataclasses.dataclass
+class StackSize:
+ low: Real
+ high: Real
+ fixed_point: FixedPointBox
+
+ def zero(self):
+ self.low = 0
+ self.high = 0
+ self.fixed_point.value = False
+
+ def offset_of(self, other, n):
+ prior = (self.low, self.high)
+ self.low = min(self.low, other.low + n)
+ self.high = max(self.high, other.high + n)
+ if (self.low, self.high) != prior:
+ self.fixed_point.value = False
+
+
+def stacksize_analysis(instructions):
+ assert instructions
+ fixed_point = FixedPointBox()
+ stack_sizes = {
+ inst: StackSize(float("inf"), float("-inf"), fixed_point)
+ for inst in instructions
+ }
+ stack_sizes[instructions[0]].zero()
+
+ for _ in range(100):
+ if fixed_point.value:
+ break
+ fixed_point.value = True
+
+ for inst, next_inst in zip(instructions, instructions[1:] + [None]):
+ stack_size = stack_sizes[inst]
+ if inst.opcode not in TERMINAL_OPCODES:
+ assert next_inst is not None, f"missing next inst: {inst}"
+ stack_sizes[next_inst].offset_of(
+ stack_size, stack_effect(inst.opcode, inst.arg, jump=False)
+ )
+ if inst.opcode in JUMP_OPCODES:
+ stack_sizes[inst.target].offset_of(
+ stack_size, stack_effect(inst.opcode, inst.arg, jump=True)
+ )
+
+ if False:
+ for inst in instructions:
+ stack_size = stack_sizes[inst]
+ print(stack_size.low, stack_size.high, inst)
+
+ low = min([x.low for x in stack_sizes.values()])
+ high = max([x.high for x in stack_sizes.values()])
+
+ if sys.version_info < (3, 8) and not fixed_point.value:
+ # This is a rare issue in python 3.7 that still needs debugging
+ # see test/test_nops.py::NopTests::test3
+ return low + 32
+
+ assert fixed_point.value, "failed to reach fixed point"
+ assert low >= 0
+ return high
diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py
new file mode 100644
index 0000000000000..75d30e0655196
--- /dev/null
+++ b/torch/_dynamo/bytecode_transformation.py
@@ -0,0 +1,382 @@
+import dataclasses
+import dis
+import itertools
+import sys
+import types
+from typing import Any, List, Optional
+
+from .bytecode_analysis import stacksize_analysis
+
+
+@dataclasses.dataclass
+class Instruction:
+ """A mutable version of dis.Instruction"""
+
+ opcode: int
+ opname: str
+ arg: int
+ argval: Any
+ offset: Optional[int] = None
+ starts_line: Optional[int] = None
+ is_jump_target: bool = False
+ # extra fields to make modification easier:
+ target: Optional["Instruction"] = None
+
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ return id(self) == id(other)
+
+
+def convert_instruction(i: dis.Instruction):
+ return Instruction(
+ i.opcode,
+ i.opname,
+ i.arg,
+ i.argval,
+ i.offset,
+ i.starts_line,
+ i.is_jump_target,
+ )
+
+
+class _NotProvided:
+ pass
+
+
+def create_instruction(name, arg=None, argval=_NotProvided, target=None):
+ if argval is _NotProvided:
+ argval = arg
+ return Instruction(
+ opcode=dis.opmap[name], opname=name, arg=arg, argval=argval, target=target
+ )
+
+
+def lnotab_writer(lineno, byteno=0):
+ """
+ Used to create typing.CodeType.co_lnotab
+ See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt
+ This is the internal format of the line number table if Python < 3.10
+ """
+ assert sys.version_info < (3, 10)
+ lnotab = []
+
+ def update(lineno_new, byteno_new):
+ nonlocal byteno, lineno
+ while byteno_new != byteno or lineno_new != lineno:
+ byte_offset = max(0, min(byteno_new - byteno, 255))
+ line_offset = max(-128, min(lineno_new - lineno, 127))
+ assert byte_offset != 0 or line_offset != 0
+ byteno += byte_offset
+ lineno += line_offset
+ lnotab.extend((byte_offset, line_offset & 0xFF))
+
+ return lnotab, update
+
+
+def linetable_writer(first_lineno):
+ """
+ Used to create typing.CodeType.co_linetable
+ See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt
+ This is the internal format of the line number table if Python >= 3.10
+ """
+ assert sys.version_info >= (3, 10)
+ linetable = []
+ lineno = first_lineno
+ lineno_delta = 0
+ byteno = 0
+
+ def _update(byteno_delta, lineno_delta):
+ while byteno_delta != 0 or lineno_delta != 0:
+ byte_offset = max(0, min(byteno_delta, 254))
+ line_offset = max(-127, min(lineno_delta, 127))
+ assert byte_offset != 0 or line_offset != 0
+ byteno_delta -= byte_offset
+ lineno_delta -= line_offset
+ linetable.extend((byte_offset, line_offset & 0xFF))
+
+ def update(lineno_new, byteno_new):
+ nonlocal lineno, lineno_delta, byteno
+ byteno_delta = byteno_new - byteno
+ byteno = byteno_new
+ _update(byteno_delta, lineno_delta)
+ lineno_delta = lineno_new - lineno
+ lineno = lineno_new
+
+ def end(total_bytes):
+ _update(total_bytes - byteno, lineno_delta)
+
+ return linetable, update, end
+
+
+def assemble(instructions: List[dis.Instruction], firstlineno):
+ """Do the opposite of dis.get_instructions()"""
+ code = []
+ if sys.version_info < (3, 10):
+ lnotab, update_lineno = lnotab_writer(firstlineno)
+ else:
+ lnotab, update_lineno, end = linetable_writer(firstlineno)
+
+ for inst in instructions:
+ if inst.starts_line is not None:
+ update_lineno(inst.starts_line, len(code))
+ arg = inst.arg or 0
+ code.extend((inst.opcode, arg & 0xFF))
+
+ if sys.version_info >= (3, 10):
+ end(len(code))
+
+ return bytes(code), bytes(lnotab)
+
+
+def virtualize_jumps(instructions):
+ """Replace jump targets with pointers to make editing easier"""
+ jump_targets = {inst.offset: inst for inst in instructions}
+
+ for inst in instructions:
+ if inst.opcode in dis.hasjabs or inst.opcode in dis.hasjrel:
+ for offset in (0, 2, 4, 6):
+ if jump_targets[inst.argval + offset].opcode != dis.EXTENDED_ARG:
+ inst.target = jump_targets[inst.argval + offset]
+ break
+
+
+def devirtualize_jumps(instructions):
+ """Fill in args for virtualized jump target after instructions may have moved"""
+ indexof = {id(inst): i for i, inst, in enumerate(instructions)}
+ jumps = set(dis.hasjabs).union(set(dis.hasjrel))
+
+ for inst in instructions:
+ if inst.opcode in jumps:
+ target = inst.target
+ target_index = indexof[id(target)]
+ for offset in (1, 2, 3):
+ if (
+ target_index >= offset
+ and instructions[target_index - offset].opcode == dis.EXTENDED_ARG
+ ):
+ target = instructions[target_index - offset]
+ else:
+ break
+
+ if inst.opcode in dis.hasjabs:
+ if sys.version_info < (3, 10):
+ inst.arg = target.offset
+ else:
+ # arg is offset of the instruction line rather than the bytecode
+ # for all jabs/jrel since python 3.10
+ inst.arg = int(target.offset / 2)
+ else: # relative jump
+ if sys.version_info < (3, 10):
+ inst.arg = target.offset - inst.offset - instruction_size(inst)
+ else:
+ inst.arg = int(
+ (target.offset - inst.offset - instruction_size(inst)) / 2
+ )
+ inst.argval = target.offset
+ inst.argrepr = f"to {target.offset}"
+
+
+def strip_extended_args(instructions: List[Instruction]):
+ instructions[:] = [i for i in instructions if i.opcode != dis.EXTENDED_ARG]
+
+
+def remove_load_call_method(instructions: List[Instruction]):
+ """LOAD_METHOD puts a NULL on the stack which causes issues, so remove it"""
+ rewrites = {"LOAD_METHOD": "LOAD_ATTR", "CALL_METHOD": "CALL_FUNCTION"}
+ for inst in instructions:
+ if inst.opname in rewrites:
+ inst.opname = rewrites[inst.opname]
+ inst.opcode = dis.opmap[inst.opname]
+ return instructions
+
+
+def explicit_super(code: types.CodeType, instructions: List[Instruction]):
+ """convert super() with no args into explict arg form"""
+ cell_and_free = (code.co_cellvars or tuple()) + (code.co_freevars or tuple())
+ output = []
+ for idx, inst in enumerate(instructions):
+ output.append(inst)
+ if inst.opname == "LOAD_GLOBAL" and inst.argval == "super":
+ nexti = instructions[idx + 1]
+ if nexti.opname == "CALL_FUNCTION" and nexti.arg == 0:
+ assert "__class__" in cell_and_free
+ output.append(
+ create_instruction(
+ "LOAD_DEREF", cell_and_free.index("__class__"), "__class__"
+ )
+ )
+ first_var = code.co_varnames[0]
+ if first_var in cell_and_free:
+ output.append(
+ create_instruction(
+ "LOAD_DEREF", cell_and_free.index(first_var), first_var
+ )
+ )
+ else:
+ output.append(create_instruction("LOAD_FAST", 0, first_var))
+ nexti.arg = 2
+ nexti.argval = 2
+
+ instructions[:] = output
+
+
+def fix_extended_args(instructions: List[Instruction]):
+ """Fill in correct argvals for EXTENDED_ARG ops"""
+ output = []
+
+ def maybe_pop_n(n):
+ for _ in range(n):
+ if output and output[-1].opcode == dis.EXTENDED_ARG:
+ output.pop()
+
+ for i, inst in enumerate(instructions):
+ if inst.opcode == dis.EXTENDED_ARG:
+ # Leave this instruction alone for now so we never shrink code
+ inst.arg = 0
+ elif inst.arg and inst.arg > 0xFFFFFF:
+ maybe_pop_n(3)
+ output.append(create_instruction("EXTENDED_ARG", inst.arg >> 24))
+ output.append(create_instruction("EXTENDED_ARG", inst.arg >> 16))
+ output.append(create_instruction("EXTENDED_ARG", inst.arg >> 8))
+ elif inst.arg and inst.arg > 0xFFFF:
+ maybe_pop_n(2)
+ output.append(create_instruction("EXTENDED_ARG", inst.arg >> 16))
+ output.append(create_instruction("EXTENDED_ARG", inst.arg >> 8))
+ elif inst.arg and inst.arg > 0xFF:
+ maybe_pop_n(1)
+ output.append(create_instruction("EXTENDED_ARG", inst.arg >> 8))
+ output.append(inst)
+
+ added = len(output) - len(instructions)
+ assert added >= 0
+ instructions[:] = output
+ return added
+
+
+def instruction_size(inst):
+ return 2
+
+
+def check_offsets(instructions):
+ offset = 0
+ for inst in instructions:
+ assert inst.offset == offset
+ offset += instruction_size(inst)
+
+
+def update_offsets(instructions):
+ offset = 0
+ for inst in instructions:
+ inst.offset = offset
+ offset += instruction_size(inst)
+
+
+def debug_bytes(*args):
+ index = range(max(map(len, args)))
+ result = []
+ for arg in (
+ [index] + list(args) + [[int(a != b) for a, b in zip(args[-1], args[-2])]]
+ ):
+ result.append(" ".join(f"{x:03}" for x in arg))
+
+ return "bytes mismatch\n" + "\n".join(result)
+
+
+def debug_checks(code):
+ """Make sure our assembler produces same bytes as we start with"""
+ dode = transform_code_object(code, lambda x, y: None, safe=True)
+ assert code.co_code == dode.co_code, debug_bytes(code.co_code, dode.co_code)
+ assert code.co_lnotab == dode.co_lnotab, debug_bytes(code.co_lnotab, dode.co_lnotab)
+
+
+HAS_LOCAL = set(dis.haslocal)
+HAS_NAME = set(dis.hasname)
+
+
+def fix_vars(instructions: List[Instruction], code_options):
+ varnames = {name: idx for idx, name in enumerate(code_options["co_varnames"])}
+ names = {name: idx for idx, name in enumerate(code_options["co_names"])}
+ for i in range(len(instructions)):
+ if instructions[i].opcode in HAS_LOCAL:
+ instructions[i].arg = varnames[instructions[i].argval]
+ elif instructions[i].opcode in HAS_NAME:
+ instructions[i].arg = names[instructions[i].argval]
+
+
+def transform_code_object(code, transformations, safe=False):
+ keys = [
+ "co_argcount",
+ "co_posonlyargcount", # python 3.8+
+ "co_kwonlyargcount",
+ "co_nlocals",
+ "co_stacksize",
+ "co_flags",
+ "co_code",
+ "co_consts",
+ "co_names",
+ "co_varnames",
+ "co_filename",
+ "co_name",
+ "co_firstlineno",
+ "co_lnotab", # changed to "co_linetable" if python 3.10+
+ "co_freevars",
+ "co_cellvars",
+ ]
+ if sys.version_info < (3, 8):
+ keys.pop(1)
+ if sys.version_info >= (3, 10):
+ keys = list(map(lambda x: x.replace("co_lnotab", "co_linetable"), keys))
+ code_options = {k: getattr(code, k) for k in keys}
+ assert len(code_options["co_varnames"]) == code_options["co_nlocals"]
+
+ instructions = cleaned_instructions(code, safe)
+
+ transformations(instructions, code_options)
+
+ fix_vars(instructions, code_options)
+
+ dirty = True
+ while dirty:
+ update_offsets(instructions)
+ devirtualize_jumps(instructions)
+ # this pass might change offsets, if so we need to try again
+ dirty = fix_extended_args(instructions)
+
+ bytecode, lnotab = assemble(instructions, code_options["co_firstlineno"])
+ if sys.version_info < (3, 10):
+ code_options["co_lnotab"] = lnotab
+ else:
+ code_options["co_linetable"] = lnotab
+
+ code_options["co_code"] = bytecode
+ code_options["co_nlocals"] = len(code_options["co_varnames"])
+ code_options["co_stacksize"] = stacksize_analysis(instructions)
+ assert set(keys) - {"co_posonlyargcount"} == set(code_options.keys()) - {
+ "co_posonlyargcount"
+ }
+ return types.CodeType(*[code_options[k] for k in keys])
+
+
+def cleaned_instructions(code, safe=False):
+ instructions = list(map(convert_instruction, dis.get_instructions(code)))
+ check_offsets(instructions)
+ virtualize_jumps(instructions)
+ strip_extended_args(instructions)
+ if not safe:
+ remove_load_call_method(instructions)
+ explicit_super(code, instructions)
+ return instructions
+
+
+_unique_id_counter = itertools.count()
+
+
+def unique_id(name):
+ return f"{name}_{next(_unique_id_counter)}"
+
+
+def is_generator(code: types.CodeType):
+ co_generator = 0x20
+ return (code.co_flags & co_generator) > 0
diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py
new file mode 100644
index 0000000000000..2ba29981c3668
--- /dev/null
+++ b/torch/_dynamo/codegen.py
@@ -0,0 +1,362 @@
+import collections
+import dataclasses
+import re
+import sys
+import types
+from typing import List
+
+import torch.nn
+
+from .bytecode_transformation import create_instruction, Instruction
+from .exc import unimplemented
+from .source import AttrSource, Source
+from .utils import is_safe_constant, istype, rot_n_helper
+from .variables.base import VariableTracker
+from .variables.nn_module import NNModuleVariable
+from .variables.tensor import (
+ TensorVariable,
+ TensorWithTFOverrideVariable,
+ UnspecializedNumpyVariable,
+ UnspecializedPythonVariable,
+)
+
+
+@dataclasses.dataclass
+class GraphOutputEntry:
+ index: int
+ variable: VariableTracker
+
+ def merge(self, other: VariableTracker):
+ # merge in any extra guards
+ self.variable = self.variable.add_options(other)
+
+
+class PyCodegen(object):
+ """
+ Helper class uses for constructing Python bytecode
+ """
+
+ def __init__(
+ self,
+ tx=None,
+ root: torch.nn.Module = None,
+ graph_output_var: str = None,
+ tempvars=None,
+ ):
+ self.root = root
+ self.top_of_stack = None
+ self.uses = collections.Counter()
+ self.graph_outputs = collections.OrderedDict()
+ self._output: List[Instruction] = []
+ self.tempvars = tempvars or {}
+ self.tx = tx
+ self.graph_output_var = graph_output_var
+ self.code_options = self.tx.output.code_options
+ self.cell_and_freevars = self.tx.cell_and_freevars
+ self.new_var = self.tx.output.new_var
+
+ def graph_output_vars(self):
+ return [x.variable for x in self.graph_outputs.values()]
+
+ def __call__(self, value, allow_cache=True):
+ """Generate code such that top-of-stack (TOS) is set to value"""
+ if isinstance(value, Source):
+ self._output.extend(value.reconstruct(self))
+ self.clear_tos()
+ return
+
+ self.tx.output.guards.update(value.guards)
+
+ assert isinstance(value, VariableTracker)
+ output = self._output
+ graph_outputs = self.graph_outputs
+
+ if self.top_of_stack is value:
+ output.append(create_instruction("DUP_TOP"))
+ return
+
+ if allow_cache:
+ if value.mutable_local and value.mutable_local in self.tempvars:
+ output.append(self.create_load(self.tempvars[value.mutable_local]))
+ self.top_of_stack = value
+ return
+ if self.tempvars.get(value) is not None:
+ output.append(self.create_load(self.tempvars[value]))
+ self.top_of_stack = value
+ return
+
+ if value.source is not None and allow_cache:
+ output.extend(value.source.reconstruct(self))
+ elif value.is_python_constant() and is_safe_constant(
+ value.as_python_constant()
+ ):
+ output.append(self.create_load_const(value.as_python_constant()))
+ elif isinstance(
+ value,
+ (
+ TensorVariable,
+ TensorWithTFOverrideVariable,
+ UnspecializedNumpyVariable,
+ UnspecializedPythonVariable,
+ ),
+ ):
+ if isinstance(value, TensorWithTFOverrideVariable):
+ # unwrap back to tensor
+ value = value.tensor_variable
+ graph_outputs_key = id(value.proxy)
+ if graph_outputs_key not in graph_outputs:
+ graph_outputs[graph_outputs_key] = GraphOutputEntry(
+ len(graph_outputs), value
+ )
+ else:
+ graph_outputs[graph_outputs_key].merge(value)
+
+ output.append(self.create_load(self.graph_output_var))
+ output.append(
+ self._create_load_const(graph_outputs[graph_outputs_key].index)
+ )
+ output.append(create_instruction("BINARY_SUBSCR"))
+
+ if isinstance(value, UnspecializedNumpyVariable):
+ unspec_var = self.tx.output.new_var("unspec")
+ raw_type = type(value.raw_value)
+ output.extend(
+ [
+ self.create_load_attr("item"),
+ create_instruction("CALL_FUNCTION", 0),
+ self.create_store(unspec_var),
+ self.create_load_const(raw_type),
+ self.create_load(unspec_var),
+ create_instruction("CALL_FUNCTION", 1),
+ ]
+ )
+ if isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:
+ output.extend(
+ [
+ self.create_load_attr("item"),
+ create_instruction("CALL_FUNCTION", 0),
+ ]
+ )
+ elif isinstance(value, NNModuleVariable):
+ parts = value.module_key.split(".")
+ if parts[0] in self.code_options["co_varnames"]:
+ output.append(self.create_load(parts[0]))
+ parts = parts[1:]
+ else:
+ assert self.root is not None
+ output.append(self.create_load_output(self.root))
+ for part in parts:
+ output.append(self.create_load_attr(part))
+ else:
+ self.uses[value] += 1
+ try:
+ output.extend(value.reconstruct(self))
+ except NotImplementedError:
+ unimplemented(f"reconstruct: {value}")
+ if allow_cache and value in self.tempvars:
+ self._output.append(create_instruction("DUP_TOP"))
+ self.add_cache(value)
+
+ self.top_of_stack = value
+
+ def add_cache(self, value):
+ var = self.new_var()
+ self.tempvars[value] = var
+ if value.mutable_local:
+ self.tempvars[value.mutable_local] = var
+ self._output.append(self.create_store(var))
+
+ def foreach(self, items):
+ for i in items:
+ self(i)
+
+ def setup_globally_cached(self, name, value):
+ """Store value in a new global"""
+ name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
+ f_globals = self.tx.f_globals
+ if name in f_globals:
+ assert id(f_globals[name]) == id(value)
+ else:
+ f_globals[name] = value
+ return [self.create_load_global(name, add=True)]
+
+ def clear_tos(self):
+ self.top_of_stack = None
+
+ def append_output(self, inst):
+ assert isinstance(inst, Instruction)
+ self._output.append(inst)
+ self.clear_tos()
+
+ def extend_output(self, insts):
+ assert all(isinstance(x, Instruction) for x in insts)
+ self._output.extend(insts)
+ self.clear_tos()
+
+ def get_instructions(self):
+ return self._output
+
+ def create_load(self, name):
+ if name in self.cell_and_freevars():
+ return create_instruction(
+ "LOAD_DEREF", self.cell_and_freevars().index(name), name
+ )
+ assert name in self.code_options["co_varnames"], f"{name} missing"
+ return create_instruction(
+ "LOAD_FAST", self.code_options["co_varnames"].index(name), name
+ )
+
+ def create_load_closure(self, name):
+ assert name in self.cell_and_freevars()
+ return create_instruction(
+ "LOAD_CLOSURE", self.cell_and_freevars().index(name), name
+ )
+
+ def create_store(self, name):
+ if name in self.cell_and_freevars():
+ return create_instruction(
+ "STORE_DEREF", self.cell_and_freevars().index(name), name
+ )
+ assert name in self.code_options["co_varnames"]
+ return create_instruction(
+ "STORE_FAST", self.code_options["co_varnames"].index(name), name
+ )
+
+ def create_load_global(self, name, add=False):
+ if add:
+ self.tx.output.update_co_names(name)
+ assert name in self.code_options["co_names"], f"{name} not in co_names"
+ return create_instruction(
+ "LOAD_GLOBAL", self.code_options["co_names"].index(name), name
+ )
+
+ def create_load_const(self, value):
+ assert is_safe_constant(value), f"unsafe constant {value}"
+ return self._create_load_const(value)
+
+ @staticmethod
+ def get_const_index(code_options, value):
+ co_consts = code_options["co_consts"]
+ assert istype(co_consts, tuple)
+ index = None
+ for i, v in enumerate(co_consts):
+ if type(v) is type(value) and v == value:
+ index = i
+ break
+ if index is None:
+ index = len(co_consts)
+ co_consts = co_consts + (value,)
+ code_options["co_consts"] = co_consts
+ return index
+
+ def _create_load_const(self, value):
+ index = self.get_const_index(self.code_options, value)
+ return create_instruction("LOAD_CONST", index, value)
+
+ create_load_output = _create_load_const
+
+ def create_load_attr(self, name):
+ if name not in self.code_options["co_names"]:
+ self.code_options["co_names"] = self.code_options["co_names"] + (name,)
+ return create_instruction(
+ "LOAD_ATTR", self.code_options["co_names"].index(name), name
+ )
+
+ def create_load_attrs(self, names):
+ return [self.create_load_attr(name) for name in names.split(".")]
+
+ def load_function_name(self, fn_name, num_on_stack=0):
+ """Load the global fn_name on the stack num_on_stack down"""
+ return [self.create_load_global(fn_name, add=True)] + self.rot_n(
+ num_on_stack + 1
+ )
+
+ def rot_n(self, n):
+ if n == 0 or n == 1:
+ return []
+ elif n == 2:
+ return [create_instruction("ROT_TWO")]
+ elif n == 3:
+ return [create_instruction("ROT_THREE")]
+ elif n == 4 and sys.version_info >= (3, 8):
+ return [create_instruction("ROT_FOUR")]
+ elif sys.version_info >= (3, 10):
+ return [create_instruction("ROT_N", n)]
+ else:
+ return [
+ create_instruction("BUILD_TUPLE", n),
+ self._create_load_const(rot_n_helper(n)),
+ create_instruction("ROT_TWO"),
+ create_instruction("CALL_FUNCTION_EX", 0),
+ create_instruction("UNPACK_SEQUENCE", n),
+ ]
+
+ def make_function_with_closure(
+ self, fn_name: str, code: types.CodeType, num_on_stack=0
+ ):
+ freevars = code.co_freevars
+ assert freevars
+ output = self._output
+ for var in freevars:
+ assert var in self.cell_and_freevars()
+ output.append(
+ create_instruction(
+ "LOAD_CLOSURE", self.cell_and_freevars().index(var), var
+ )
+ )
+ output.append(create_instruction("BUILD_TUPLE", len(freevars)))
+ output.append(self.create_load_const(code))
+ output.append(self.create_load_const(fn_name))
+ output.append(create_instruction("MAKE_FUNCTION", 0x08))
+ output.extend(self.rot_n(num_on_stack + 1))
+ self.clear_tos()
+
+ def create_load_python_module(self, mod):
+ """
+ Generate a LOAD_GLOBAL instruction to fetch a given python module.
+ """
+ root_globals = self.tx.output.root_globals
+ name = re.sub(r"^.*[.]", "", mod.__name__)
+ if root_globals.get(name, None) is mod:
+ return self.create_load_global(name, add=True)
+ mangled_name = f"___module_{name}_{id(mod)}"
+ if mangled_name not in root_globals:
+ self.tx.output.install_global(mangled_name, mod)
+ return self.create_load_global(mangled_name, add=True)
+
+ def make_call_generated_code(self, fn_name: str) -> List[Instruction]:
+ """Call the generated code function stored in fn_name"""
+ self.extend_output(self.load_function_name(fn_name))
+
+ graphargs = self.tx.output.graphargs
+ for arg in graphargs:
+ if arg.is_unspecialized:
+ self.extend_output(
+ [
+ self.create_load_python_module(torch),
+ self.create_load_attr("tensor"),
+ ]
+ )
+ self.extend_output(arg.load(self))
+ self.extend_output(
+ [
+ create_instruction("CALL_FUNCTION", 1),
+ ]
+ )
+ else:
+ self.extend_output(arg.load(self))
+
+ self.append_output(create_instruction("CALL_FUNCTION", len(graphargs)))
+
+ def load_import_from(self, module_name, object_name):
+ self.extend_output(
+ AttrSource(self.tx.import_source(module_name), object_name).reconstruct(
+ self
+ )
+ )
+
+ def create_begin_finally(self):
+ if sys.version_info < (3, 8):
+ return self.create_load_const(None)
+ else:
+ return create_instruction("BEGIN_FINALLY")
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
new file mode 100644
index 0000000000000..40933d7f120d2
--- /dev/null
+++ b/torch/_dynamo/config.py
@@ -0,0 +1,169 @@
+import logging
+import os
+import sys
+from os.path import abspath, dirname
+from types import ModuleType
+
+import torch
+
+try:
+ import torch._prims
+ import torch._refs
+
+ HAS_REFS_PRIMS = True
+except ImportError:
+ HAS_REFS_PRIMS = False
+
+
+# log level (levels print what it says + all levels listed below it)
+# logging.DEBUG print full traces <-- lowest level + print tracing of every instruction
+# torchdynamo.logging.CODE print compiled functions + graphs
+# logging.INFO print the steps that dynamo is running
+# logging.WARN print warnings (including graph breaks)
+# logging.ERROR print exceptions (and what user code was being processed when it occurred)
+# NOTE: changing log_level will automatically update the levels of all torchdynamo loggers
+log_level = logging.WARNING
+
+# the name of a file to write the logs to
+log_file_name = None
+
+# Verbose will print full stack traces on warnings and errors
+verbose = False
+
+# verify the correctness of optimized backend
+verify_correctness = False
+
+# need this many ops to create an FX graph
+minimum_call_count = 1
+
+# turn on/off DCE pass
+dead_code_elimination = True
+
+# disable (for a function) when cache reaches this size
+cache_size_limit = 64
+
+# specializing int/float by default
+specialize_int_float = True
+
+# Assume these functions return constants
+constant_functions = {
+ torch.jit.is_scripting: False,
+ torch.jit.is_tracing: False,
+ torch._C._get_tracing_state: None,
+ torch.fx._symbolic_trace.is_fx_tracing: False,
+ torch.onnx.is_in_onnx_export: False,
+}
+
+# root folder of the project
+base_dir = dirname(dirname(dirname(abspath(__file__))))
+
+# don't specialize on shapes and strides and put shape ops in graph
+dynamic_shapes = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1"
+
+# Set this to False to assume nn.Modules() contents are immutable (similar assumption as freezing)
+guard_nn_modules = False
+
+# Run the FX graph as it is created to get better type information
+dynamic_propagation = True
+
+# Run the FX graph with FakeTensors
+fake_tensor_propagation = True
+
+# run FX normalization passes in optimizer
+normalize_ir = False
+
+# If a tensor subclass type is in this set, torchdynamo will inline the
+# __torch_function__ logic of the subclass.
+traceable_tensor_subclasses = set()
+
+# Raise torchdynamo internal assertions
+raise_on_assertion_error = False
+
+# Propagate backend exceptions up to torchdynamo.optimize
+raise_on_backend_error = True
+
+# Record and write an execution record of the current frame to a file
+# if an exception is encountered
+replay_record_enabled = False
+replay_record_dir_name = "./torchdynamo_error_records"
+
+# If a PyTorch module is in this allowlist, torchdynamo will be allowed
+# to inline objects from it or its children.
+skipfiles_inline_module_allowlist = {
+ torch.nn,
+ torch.distributions,
+ torch.testing,
+}
+if HAS_REFS_PRIMS:
+ skipfiles_inline_module_allowlist |= {
+ torch._refs,
+ torch._prims,
+ torch._decomp,
+ }
+
+# If a string representing a PyTorch module is in this ignorelist,
+# the `allowed_functions.is_allowed` function will not consider it
+# when creating a list of PyTorch functions that will appear in
+# FX IR.
+allowed_functions_module_string_ignorelist = {
+ "torch.distributions",
+ "torch.testing",
+ "torch._refs",
+ "torch._prims",
+ "torch._decomp",
+}
+
+# Debug Flag to try minifier at different stages. Possible values are {None, "aot", "dynamo"}
+# None - Minifier is switched off
+# dynamo - Runs minifier on the TorchDynamo produced graphs, if compilation fails
+# aot - Runs minifier on the Aot Autograd produced graphs, if compilation fails
+repro_after = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None)
+# Compiler compilation debug info
+# 1: Dumps the original graph out to repro.py if compilation fails
+# 2: Dumps a minifier_launcher.py if compilation fails.
+# 3: Always dumps a minifier_laucher.py. Good for segfaults.
+# 4: Dumps a minifier_launcher.py if the accuracy fails.
+repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))
+
+# Specify the directory where to save the repro artifacts
+repro_dir = os.environ.get("TORCHDYNAMO_REPRO_DIR", None)
+
+# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type.
+# When this flag is set to False, we introduce a graph break instead of capturing.
+capture_scalar_outputs = False
+
+# Should almost always be true in prod. This relaxes the requirement that cond's true_fn and
+# false_fn produces code with identical guards.
+enforce_cond_guards_match = True
+
+# Automatically split model graph into pieces to match DDP bucket sizes
+# to allow DDP comm/compute overlap
+optimize_ddp = False
+
+# If True, raises exception if TorchDynamo is called with a context manager
+raise_on_ctx_manager_usage = True
+
+# If True, raise when aot autograd is unsafe to use
+raise_on_unsafe_aot_autograd = False
+
+# How to import torchdynamo, either torchdynamo or torch.dynamo
+dynamo_import = __name__.replace(".config", "")
+
+# How to import torchinductor, either torchinductor or torch.inductor
+inductor_import = dynamo_import.replace("dynamo", "inductor")
+
+
+class _AccessLimitingConfig(ModuleType):
+ def __setattr__(self, name, value):
+ if name not in _allowed_config_names:
+ raise AttributeError(f"{__name__}.{name} does not exist")
+ # automatically set logger level whenever config.log_level is modified
+ if name == "log_level":
+ from .logging import set_loggers_level
+
+ set_loggers_level(value)
+ return object.__setattr__(self, name, value)
+
+
+_allowed_config_names = {*globals().keys()}
+sys.modules[__name__].__class__ = _AccessLimitingConfig
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
new file mode 100644
index 0000000000000..d4afed9f63e37
--- /dev/null
+++ b/torch/_dynamo/convert_frame.py
@@ -0,0 +1,496 @@
+import functools
+import itertools
+import logging
+import os
+import traceback
+import types
+import typing
+import weakref
+from typing import Callable
+
+import torch
+from torch.fx.graph_module import _forward_from_src as original_forward_from_src
+
+from . import config, exc, logging as torchdynamo_logging
+from .allowed_functions import is_allowed
+from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
+from .bytecode_transformation import is_generator, transform_code_object
+from .eval_frame import (
+ always_optimize_code_objects,
+ skip_code,
+ TorchPatcher,
+ WrapperBackend,
+)
+from .exc import (
+ BackendCompilerFailed,
+ InternalTorchDynamoError,
+ TorchRuntimeError,
+ unimplemented,
+ Unsupported,
+)
+from .guards import CheckFunctionManager, GuardedCode
+from .replay_record import ExecutionRecord
+from .symbolic_convert import InstructionTranslator
+from .utils import (
+ CleanupManager,
+ counters,
+ dynamo_timed,
+ filter_stack,
+ format_bytecode,
+ gen_record_file_name,
+ guard_failures,
+ init_logging,
+ is_namedtuple,
+ istype,
+ orig_code_map,
+ troubleshooting_url,
+ write_record_to_file,
+)
+
+log = logging.getLogger(__name__)
+
+
+class Tracker:
+ def __init__(self):
+ self.seen = []
+ self.seen_ids = set()
+
+ def add(self, strong_obj):
+ idx = id(strong_obj)
+ if idx not in self.seen_ids:
+ obj = weakref.ref(strong_obj, lambda _: self.seen_ids.remove(idx))
+ self.seen.append(obj)
+ self.seen_ids.add(idx)
+
+ def __contains__(self, item):
+ return id(item) in self.seen_ids
+
+ def clear(self):
+ self.seen.clear()
+ self.seen_ids.clear()
+
+
+input_codes = Tracker()
+output_codes = Tracker()
+
+
+initial_grad_state = None
+
+
+@functools.wraps(original_forward_from_src)
+def fx_forward_from_src_skip_result(*args, **kwargs):
+ # we monkey patch FX to prevent infinite loop of trying to convert
+ # our generated code
+ result: types.FunctionType = original_forward_from_src(*args, **kwargs)
+ skip_code(result.__code__)
+ return result
+
+
+def wrap_compiler_fn(compiler_fn):
+ """WrapperBackend if config.verify_correctness is True"""
+ if config.verify_correctness:
+ # wrap backend if verify_correctness is True
+ wrapper_backend_compiler_fn = WrapperBackend(compiler_fn)
+
+ wrapper_backend_compiler_fn._torchdynamo_orig_callable = compiler_fn
+ return wrapper_backend_compiler_fn
+
+ return compiler_fn
+
+
+def wrap_convert_context(fn):
+ """
+ Context manager to:
+ 1) Save/restore torch random state
+ 2) Save/restore torch.is_grad_enabled() state
+ 3) Monkey patch torch.fx.graph_module._forward_from_src
+ """
+
+ @functools.wraps(fn)
+ def _fn(*args, **kwargs):
+ prior_grad_mode = torch.is_grad_enabled()
+ rng_state = torch.random.get_rng_state()
+ if torch.cuda.is_available():
+ cuda_rng_state = torch.cuda.get_rng_state()
+ prior_fwd_from_src = torch.fx.graph_module._forward_from_src
+ torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
+ try:
+ return fn(*args, **kwargs)
+ finally:
+ torch._C._set_grad_enabled(prior_grad_mode)
+ torch.random.set_rng_state(rng_state)
+ if torch.cuda.is_available():
+ torch.cuda.set_rng_state(cuda_rng_state)
+ torch.fx.graph_module._forward_from_src = prior_fwd_from_src
+
+ _fn._torchdynamo_orig_callable = fn
+ return _fn
+
+
+@TorchPatcher.suppress_torch_distributed_warnings
+def has_tensor_in_frame(frame):
+ """Check if the frame has torch.* related bits"""
+ # Check if the function was decorated using torchdynamo.optimize
+ if frame.f_code in always_optimize_code_objects:
+ return True
+
+ # Check if there is global import of torch.*
+ for co_name in frame.f_code.co_names:
+ if co_name in frame.f_globals:
+ if is_allowed(frame.f_globals[co_name]):
+ return True
+
+ seen_ids = dict()
+
+ def has_tensor(obj):
+ """Recursively check if the obj has a tensor"""
+ obj_id = id(obj)
+ if obj_id in seen_ids:
+ return seen_ids[obj_id]
+ seen_ids[obj_id] = False
+
+ if isinstance(obj, (torch.Tensor, torch.nn.Module)):
+ seen_ids[obj_id] = True
+ return seen_ids[obj_id]
+ elif istype(obj, (list, tuple)):
+ seen_ids[obj_id] = any([has_tensor(v) for v in obj])
+ return seen_ids[obj_id]
+ elif istype(obj, dict):
+ seen_ids[obj_id] = any([has_tensor(v) for v in obj.values()])
+ return seen_ids[obj_id]
+ elif istype(obj, (str, int, float, type(None), bool)):
+ seen_ids[obj_id] = False
+ return seen_ids[obj_id]
+ elif is_namedtuple(obj):
+ seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields])
+ return seen_ids[obj_id]
+ elif not is_allowed(obj) and hasattr(obj, "__dict__") and len(obj.__dict__):
+ seen_ids[obj_id] = any([has_tensor(v) for v in obj.__dict__.values()])
+ return seen_ids[obj_id]
+ else:
+ # if config.debug:
+ # print(
+ # f"Assuming that object of type {type(obj)} does not have a tensor"
+ # )
+ return False
+
+ # Check if the passed arguments are of type Tensor
+ for value in frame.f_locals.values():
+ if has_tensor(value):
+ return True
+
+ log.debug(
+ f"skipping because no torch.* {frame.f_code.co_name} \
+ {frame.f_code.co_filename} {frame.f_code.co_firstlineno}"
+ )
+
+ return False
+
+
+def format_error_msg(exc, code, record_filename=None, frame=None):
+ msg = os.linesep * 2
+
+ def replay_record_msg():
+ if (
+ config.replay_record_enabled
+ and hasattr(exc, "exec_record")
+ and record_filename is not None
+ ):
+ return f"\nLast frame execution written to {record_filename}. To run only this frame while debugging, run\
+ {config.dynamo_import}.replay('{record_filename}').\n"
+ else:
+ return ""
+
+ if config.verbose:
+ msg = format_bytecode(
+ "WON'T CONVERT", code.co_name, code.co_filename, code.co_firstlineno, code
+ )
+ msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
+ msg += traceback.format_exc()
+ if hasattr(exc, "real_stack"):
+ msg += (
+ "\n"
+ + "=" * 10
+ + " The above exception occurred while processing the following code "
+ + "=" * 10
+ + "\n\n"
+ )
+ stack_above_dynamo = []
+ if frame is not None:
+ stack_above_dynamo = filter_stack(traceback.extract_stack(frame))
+
+ msg += "".join(
+ traceback.format_list(
+ stack_above_dynamo + list(reversed(exc.real_stack))
+ )
+ )
+
+ msg += replay_record_msg()
+
+ else:
+ msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\
+ line {code.co_firstlineno} \ndue to: \n{traceback.format_exc(limit=-1)}"
+
+ if hasattr(exc, "real_stack"):
+ msg += f"\nfrom user code:\n {''.join(traceback.format_list([exc.real_stack[-1]]))}"
+
+ msg += replay_record_msg()
+
+ msg += (
+ f"\nSet {config.dynamo_import}.config.verbose=True for more information\n"
+ )
+ msg += "=" * 10
+ return msg
+
+
+def exception_handler(e, code, frame=None):
+ record_filename = None
+ if hasattr(e, "exec_record"):
+ record_filename = gen_record_file_name(e, code)
+ write_record_to_file(record_filename, e.exec_record)
+
+ log.error(format_error_msg(e, code, record_filename, frame))
+
+
+def convert_frame_assert(
+ compiler_fn: Callable, guard_export_fn=None, one_graph=True, export=False
+):
+ """Fully convert a frame into an FX graph"""
+ init_logging()
+
+ compiler_fn = wrap_compiler_fn(compiler_fn)
+
+ @dynamo_timed
+ def _convert_frame_assert(frame: types.FrameType, cache_size: int):
+ code = frame.f_code
+ input_codes.add(code)
+ if code in output_codes:
+ return None
+ if (
+ os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION")
+ and os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") != code.co_name
+ ):
+ return None
+ if code.co_name == "" and code.co_filename.endswith(
+ ("transformers/file_utils.py", "transformers/utils/generic.py")
+ ):
+ # not needed, but cleans up torchbench error stats
+ return None
+ if code.co_name == "__setattr__":
+ # setattr could be tricky to handle generally,
+ # but also not likely useful to compile- skip the whole frame
+ return None
+ # Check if the frame is generated by an exec builtin call
+ # TODO - Running exec generated frame seems propagates f_globals to the
+ # next frames.
+ if code.co_name == "" and code.co_filename == "":
+ return None
+
+ if (
+ code.co_name == ""
+ and code.co_filename == ""
+ and not bool(frame.f_builtins)
+ ):
+ # namedtuple subclass constructor. Empty builtins cause issue with
+ # len keyword in LIST_LEN guard.
+ return None
+
+ if is_generator(code):
+ unimplemented("generator")
+ if cache_size >= config.cache_size_limit:
+
+ def format_func_info(code):
+ return f"'{code.co_name}' ({code.co_filename}:{code.co_firstlineno})"
+
+ def format_guard_failures(code):
+ # For the common case, it's sufficient to see just the most recent failure.
+ # We could add a verbose mode if needed
+ return f"{str(guard_failures[code][-1])}"
+
+ assert code in guard_failures, "TODO(whc) any other recompile reasons?"
+ log.warning(
+ f"{config.dynamo_import} hit config.cache_size_limit ({config.cache_size_limit})\n"
+ + f" function: {format_func_info(code)}\n"
+ + f" reasons: {format_guard_failures(code)}\n"
+ + f"to diagnose recompilation issues, see {troubleshooting_url}."
+ )
+ unimplemented("cache_size_limit reached")
+
+ if not has_tensor_in_frame(frame):
+ return None
+
+ global initial_grad_state
+ initial_grad_state = torch.is_grad_enabled()
+
+ return _compile(
+ frame.f_code,
+ frame.f_globals,
+ frame.f_locals,
+ frame.f_builtins,
+ compiler_fn,
+ one_graph,
+ export,
+ guard_export_fn,
+ frame,
+ )
+
+ _convert_frame_assert._torchdynamo_orig_callable = compiler_fn
+ return wrap_convert_context(_convert_frame_assert)
+
+
+def _compile(
+ code,
+ globals,
+ locals,
+ builtins,
+ compiler_fn,
+ one_graph,
+ export,
+ guard_export_fn=None,
+ frame=None,
+):
+ output = None
+
+ # from .utils import print_once; print_once(code.co_filename)
+ def transform(instructions, code_options):
+ nonlocal output
+ tracer = InstructionTranslator(
+ instructions,
+ code,
+ locals,
+ globals,
+ builtins,
+ code_options,
+ compiler_fn,
+ one_graph,
+ export,
+ )
+ tracer.run()
+ output = tracer.output
+ assert output.output_instructions
+ instructions[:] = output.output_instructions
+ code_options.update(output.code_options)
+
+ if config.dead_code_elimination:
+ instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))
+
+ try:
+ for attempt in itertools.count():
+ try:
+ out_code = transform_code_object(code, transform)
+ orig_code_map[out_code] = code
+ break
+ except exc.RestartAnalysis:
+ log.debug("Restarting analysis ...")
+ if attempt > 100:
+ unimplemented("100+ RestartAnalysis() calls")
+ except exc.SkipFrame:
+ log.debug(
+ f"Skipping frame {code.co_name} \
+ {code.co_filename} {code.co_firstlineno}"
+ )
+ if one_graph:
+ log.debug("No graph captured with one_graph=True")
+ return None
+ output_codes.add(out_code)
+
+ log.log(
+ torchdynamo_logging.CODE,
+ format_bytecode(
+ "ORIGINAL BYTECODE",
+ code.co_name,
+ code.co_filename,
+ code.co_firstlineno,
+ code,
+ ),
+ )
+ log.log(
+ torchdynamo_logging.CODE,
+ format_bytecode(
+ "MODIFIED BYTECODE",
+ code.co_name,
+ code.co_filename,
+ code.co_firstlineno,
+ out_code,
+ ),
+ )
+
+ assert output.guards is not None
+ CleanupManager.instance[out_code] = output.cleanups
+ check_fn = CheckFunctionManager(output.guards, locals, globals)
+
+ guarded_code = GuardedCode(out_code, check_fn.check_fn)
+ guard_str = "GUARDS:\n"
+ guard_str += "\n".join([f" - {str(guard)}" for guard in sorted(output.guards)])
+
+ log.log(torchdynamo_logging.CODE, guard_str)
+
+ if guard_export_fn is not None:
+ guard_export_fn(output.guards)
+
+ return guarded_code
+ except (
+ Unsupported,
+ TorchRuntimeError,
+ BackendCompilerFailed,
+ AssertionError,
+ ) as e:
+ exception_handler(e, code, frame)
+ raise
+ except Exception as e:
+ exception_handler(e, code, frame)
+ raise InternalTorchDynamoError()
+
+
+def convert_frame(compiler_fn: typing.Callable, guard_export_fn=None):
+ """Try to convert a frame into an FX graph, if error leave frame unmodified"""
+ inner_convert = convert_frame_assert(compiler_fn, guard_export_fn, one_graph=False)
+
+ def _convert_frame(frame: types.FrameType, cache_size: int):
+ counters["frames"]["total"] += 1
+ try:
+ result = inner_convert(frame, cache_size)
+ counters["frames"]["ok"] += 1
+ return result
+ except AssertionError:
+ if config.raise_on_assertion_error:
+ raise
+ except BackendCompilerFailed:
+ raise
+ except Exception:
+ pass
+ return None
+
+ _convert_frame._torchdynamo_orig_callable = compiler_fn
+ return _convert_frame
+
+
+# TODO mlazos: add support for same args, or record them
+def replay(filename):
+ from .optimizations.backends import eager
+
+ original_replay_val = config.replay_record_enabled
+ config.replay_record_enabled = False
+ init_logging()
+ with open(filename, "rb") as in_file:
+ record = ExecutionRecord.load(in_file)
+ record.globals = {
+ k: v for k, v in itertools.chain(record.globals.items(), globals().items())
+ }
+
+ try:
+ _compile(
+ record.code,
+ record.globals,
+ record.locals,
+ record.builtins,
+ eager,
+ False, # one_graph
+ None, # export_fn
+ None, # frame
+ False, # Export
+ )
+ except Exception:
+ pass
+ finally:
+ config.replay_record_enabled = original_replay_val
diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py
new file mode 100644
index 0000000000000..ac56c0e262046
--- /dev/null
+++ b/torch/_dynamo/debug_utils.py
@@ -0,0 +1,879 @@
+import copy
+import functools
+import getpass
+import logging
+import os
+import shutil
+import subprocess
+import textwrap
+import uuid
+from collections import Counter
+from importlib import import_module
+
+import torch
+import torch.fx as fx
+
+from . import config
+from .optimizations.backends import register_backend
+from .utils import clone_inputs
+
+log = logging.getLogger(__name__)
+
+
+def minifier_dir():
+ path = config.repro_dir
+ if path is None:
+ path = f"/tmp/minifier_{getpass.getuser()}"
+ if not os.path.exists(path):
+ os.makedirs(path, exist_ok=True)
+ return path
+
+
+class NNModuleToString:
+ safe_reprs = [
+ torch.nn.Linear,
+ torch.nn.Conv1d,
+ torch.nn.Conv2d,
+ torch.nn.Conv3d,
+ torch.nn.BatchNorm1d,
+ torch.nn.BatchNorm2d,
+ torch.nn.BatchNorm3d,
+ torch.nn.LayerNorm,
+ torch.nn.Dropout,
+ torch.nn.Softmax,
+ torch.nn.ReLU,
+ torch.nn.GELU,
+ torch.nn.Identity,
+ torch.nn.MaxPool2d,
+ torch.nn.Embedding,
+ torch.nn.Tanh,
+ torch.nn.ConvTranspose1d,
+ torch.nn.GLU,
+ torch.nn.LSTM,
+ torch.nn.Flatten,
+ torch.nn.AdaptiveAvgPool2d,
+ ]
+
+ @staticmethod
+ def can_convert_to_string(gm):
+ cant_convert = set()
+ for _, module in gm.named_children():
+ if type(module) not in NNModuleToString.safe_reprs:
+ cant_convert.add(module)
+
+ if len(cant_convert) > 0:
+ log.warning(f"We have not tested reprs of some modules - {cant_convert}")
+ # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct.
+ return True
+
+ @staticmethod
+ def convert(gm):
+ from torch.nn.modules.module import _addindent
+
+ tab = " " * 4
+
+ model_str = textwrap.dedent(
+ """
+ from torch.nn import *
+ class Repro(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ """
+ )
+
+ for module_name, module in gm.named_children():
+ module_str = f"{module.__repr__()}"
+ model_str += f"{tab*2}self.{module_name} = {module_str}\n"
+
+ for buffer_name, buffer in gm._buffers.items():
+ if buffer is None:
+ continue
+ if torch.is_floating_point(buffer):
+ tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
+ else:
+ tensor_str = (
+ f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})"
+ )
+ model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n"
+
+ for param_name, param in gm._parameters.items():
+ if param is None:
+ continue
+ tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}))"
+ model_str += f"{tab*2}self.{param_name} = {tensor_str}\n"
+
+ # TODO - Keep this code for now. But, I don't think we will need this.
+ # attrs = dir(gm)
+ # for attr in attrs:
+ # if "_tensor_constant" in attr:
+ # val = getattr(gm, attr)
+ # model_str += f" {attr} = {val!r}\n"
+
+ model_str += f"{_addindent(gm.code, 4)}\n"
+ return model_str
+
+
+@functools.lru_cache(None) # subprocess is expensive
+def _cuda_system_info_comment():
+ if not torch.cuda.is_available():
+ return "# torch.cuda.is_available()==False, no GPU info collected\n"
+
+ model_str = "# CUDA Info: \n"
+ try:
+ cuda_version_out = subprocess.run(["nvcc", "--version"], stdout=subprocess.PIPE)
+ cuda_version_lines = cuda_version_out.stdout.decode().split("\n")
+ cuda_version_out = "".join(
+ [f"# {s} \n" for s in cuda_version_lines if s not in [""]]
+ )
+ model_str += f"{cuda_version_out}\n"
+ except FileNotFoundError:
+ model_str += "nvcc not found\n"
+
+ gpu_names = subprocess.run(
+ ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"],
+ stdout=subprocess.PIPE,
+ )
+ gpu_names = gpu_names.stdout.decode().split("\n")
+ gpu_names = [name for name in gpu_names if name not in ("", "name")]
+ gpu_names = Counter(gpu_names)
+
+ model_str += "# GPU Hardware Info: \n"
+ for name, count in gpu_names.items():
+ model_str += f"# {name} : {count} \n"
+ model_str += "\n"
+ return model_str
+
+
+def generate_compiler_repro_string(gm, args):
+ model_str = textwrap.dedent(
+ f"""
+ import torch
+ from torch import tensor, device
+ import torch.fx as fx
+ from {config.dynamo_import}.testing import rand_strided
+ from math import inf
+ from torch.fx.experimental.proxy_tensor import make_fx
+
+ """
+ )
+ model_str += f"# torch version: {torch.version.__version__}\n"
+ if hasattr(torch.version, "cuda"):
+ model_str += f"# torch cuda version: {torch.version.cuda}\n"
+ if hasattr(torch.version, "git_version"):
+ model_str += f"# torch git version: {torch.version.git_version}\n\n\n"
+ model_str += _cuda_system_info_comment()
+
+ model_str += NNModuleToString.convert(gm)
+
+ model_str += f"args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type) for a in args]!r}\n"
+ model_str += (
+ "args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]\n"
+ )
+ model_str += 'mod = make_fx(Repro().to(device="cuda"))(*args)\n'
+ return model_str
+
+
+INDUCTOR_IMPORT = f"""
+from {config.inductor_import}.compile_fx import compile_fx_inner
+from {config.dynamo_import}.debug_utils import same_two_models
+"""
+
+NVFUSER_IMPORT = """
+from torch.fx.passes.backends.nvfuser import NvFuserBackend
+nvfuser = NvFuserBackend()
+"""
+
+COMPILER_REPRO_OPTIONS = {
+ "inductor": (INDUCTOR_IMPORT, "compile_fx_inner", "inductor_fails"),
+ "inductor_accuracy": (
+ INDUCTOR_IMPORT,
+ "compile_fx_inner",
+ "inductor_accuracy_fails",
+ ),
+ "nvfuser": (NVFUSER_IMPORT, "nvfuser", "nvfuser_fails"),
+}
+
+
+def dump_compiler_graph_state(gm, args, compiler_name):
+ subdir = os.path.join(minifier_dir(), "checkpoints")
+ if not os.path.exists(subdir):
+ os.makedirs(subdir, exist_ok=True)
+ file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py")
+ log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
+ with open(file_name, "w") as fd:
+ save_graph_repro(fd, gm, args, compiler_name)
+ repro_path = os.path.join(config.base_dir, "repro.py")
+ try:
+ shutil.copyfile(file_name, repro_path)
+ log.warning(f"Copying repro file for convenience to {repro_path}")
+ except OSError:
+ log.warning(f"No write permissions for {repro_path}")
+ pass
+
+
+def save_graph_repro(fd, gm, args, compiler_name):
+ fd.write(generate_compiler_repro_string(gm, args))
+ fd.write(COMPILER_REPRO_OPTIONS[compiler_name][0])
+ if "_accuracy" in compiler_name:
+ fd.write(
+ textwrap.dedent(
+ f"""
+ compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args)
+ assert same_two_models(mod, compiled, args, only_fwd=True), "Accuracy failed"
+ """
+ )
+ )
+ else:
+ fd.write(
+ textwrap.dedent(
+ f"""
+ compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args)
+ compiled(*args)
+ """
+ )
+ )
+
+
+def isolate_fails(fx_g, args, compiler_name: str, env=None):
+ if env is None:
+ env = {}
+ subdir = f"{minifier_dir()}/isolate"
+ if not os.path.exists(subdir):
+ os.makedirs(subdir, exist_ok=True)
+ file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py")
+ with open(file_name, "w") as fd:
+ fd.write(generate_compiler_repro_string(fx_g, args))
+ fail_fn = COMPILER_REPRO_OPTIONS[compiler_name][2]
+ fd.write(
+ textwrap.dedent(
+ f"""
+ from {__name__} import {fail_fn}
+ """
+ )
+ )
+ fd.write(
+ textwrap.dedent(
+ f"""
+ if {fail_fn}(mod, args):
+ exit(1)
+ else:
+ exit(0)
+ """
+ )
+ )
+ new_env = os.environ.copy()
+ new_env = {**new_env, **env}
+ p = subprocess.Popen(
+ ["python", file_name],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ env=new_env,
+ )
+ out, err = p.communicate()
+ if p.returncode != 0:
+ print(textwrap.indent(out.decode("utf-8"), prefix=">> "))
+ print(textwrap.indent(err.decode("utf-8"), prefix=">> "))
+ return True
+ return False
+
+
+def inductor_fails(fx_g, args, check_str=None):
+ compile_fx_inner = import_module(
+ f"{config.inductor_import}.compile_fx"
+ ).compile_fx_inner
+
+ import_module(f"{config.inductor_import}.config").triton.autotune = False
+
+ try:
+ result = fx_g(*args)
+ assert isinstance(result, (tuple, list))
+ assert not any([isinstance(x, (tuple, list)) for x in result])
+ except Exception:
+ return False
+
+ try:
+ compile_mod = compile_fx_inner(fx_g, args)
+ compile_mod(*args)
+ except Exception as e:
+ if check_str is not None and check_str not in repr(e):
+ return False
+ print(repr(e))
+ return True
+ return False
+
+
+def nvfuser_fails(fx_g, args, check_str=None):
+ from torch.fx.passes.backends.nvfuser import NvFuserBackend
+
+ nvfuser = NvFuserBackend()
+
+ try:
+ compile_mod = nvfuser(fx_g, args)
+ compile_mod = compile_mod(*args)
+ except Exception as e:
+ if check_str is not None and check_str not in repr(e):
+ return False
+ print(repr(e))
+ return True
+ return False
+
+
+def inductor_accuracy_fails(fx_g, args, check_str=None):
+ from torchinductor.compile_fx import compile_fx_inner
+
+ return backend_aot_accuracy_fails(fx_g, args, compile_fx_inner)
+
+
+def helper_for_dump_minify(contents):
+ minified_repro_path = os.path.join(minifier_dir(), "minifier_launcher.py")
+ log.warning(f"Writing minified repro to {minified_repro_path}")
+ try:
+ with open(minified_repro_path, "w") as fd:
+ fd.write(contents)
+ except OSError as e:
+ log.exception(e)
+ raise NotImplementedError("Could not write to {minified_repro_path}")
+
+ local_path = os.path.join(config.base_dir, "minifier_launcher.py")
+ try:
+ shutil.copyfile(minified_repro_path, local_path)
+ log.warning(
+ f"Copying minified repro from {minified_repro_path} to {local_path} for convenience"
+ )
+ except OSError:
+ log.warning(f"Don't have write permissions for {local_path}")
+
+
+def dump_to_minify(gm, args, compiler_name: str):
+ favored_device = 1 if torch.cuda.device_count() >= 2 else 0
+
+ contents = textwrap.dedent(
+ f"""
+{generate_compiler_repro_string(gm, args)}
+
+from functools import partial
+from {__name__} import (
+ isolate_fails,
+ dump_compiler_graph_state,
+)
+from functorch.compile import minifier
+
+env_variables = {{"CUDA_VISIBLE_DEVICES": "{favored_device}"}}
+
+minifier(
+ mod,
+ args,
+ module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}"),
+ dump_state=partial(dump_compiler_graph_state, compiler_name="{compiler_name}"),
+)
+ """
+ )
+ return helper_for_dump_minify(contents)
+
+
+def wrap_compiler_debug(compiler_fn, compiler_name: str):
+ """
+ Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both
+ forward and backward call separately with the backend compiler_fn - like
+ inductor or nvfuser. Intercepting after Aot Autograd presents neat
+ abstration, where all the params are lifted as graph inputs, making it easy
+ to save the graph as a string.
+ """
+
+ @functools.wraps(compiler_fn)
+ def debug_wrapper(gm, example_inputs, **kwargs):
+ orig_graph = copy.deepcopy(gm.graph)
+ assert config.repro_after in ("dynamo", "aot", None)
+
+ def deferred_for_real_inputs(*real_inputs):
+ """
+ Aot Autograd fw_compiler and bw_compiler can have fake tensors. So,
+ example_inputs can be fake tensors. We can call compiler_fn (which is
+ inductor or nvfuser) with fake tensors but the actualy compiled_fn
+ should be called with real tensors. Therefore, the actual invocation
+ is deffered.
+ """
+ if config.repro_level == 3:
+ # Always dump the original module in case we have segfaults
+ dump_to_minify(
+ fx.GraphModule(gm, orig_graph), real_inputs, compiler_name
+ )
+
+ if config.repro_level == 4:
+ if compiler_name != "inductor":
+ raise NotImplementedError(
+ "Accuracy minification is supported for inductor only"
+ )
+ compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
+ if backend_aot_accuracy_fails(gm, real_inputs, compiler_fn):
+ log.warning("Accuracy failed for the AOT Autograd graph")
+ dump_compiler_graph_state(
+ fx.GraphModule(gm, orig_graph),
+ real_inputs,
+ f"{compiler_name}_accuracy",
+ )
+ dump_to_minify(
+ fx.GraphModule(gm, orig_graph),
+ real_inputs,
+ f"{compiler_name}_accuracy",
+ )
+ raise ValueError("Bad accuracy detected")
+ else:
+ # Call the compiled function with real inputs
+ return compiled_fn(*real_inputs)
+ else:
+ try:
+ # Call the compiler_fn - which is either aot_autograd or inductor
+ # with fake inputs
+ compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
+ # Call the compiled function with real inputs
+ return compiled_fn(*real_inputs)
+ except Exception as e:
+ if config.repro_level == 1:
+ dump_compiler_graph_state(
+ fx.GraphModule(gm, orig_graph), real_inputs, compiler_name
+ )
+ elif config.repro_level == 2:
+ dump_to_minify(
+ fx.GraphModule(gm, orig_graph), real_inputs, compiler_name
+ )
+ raise e
+
+ if config.repro_after == "aot":
+ compiled_fn = deferred_for_real_inputs
+ else:
+ compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
+
+ return compiled_fn
+
+ return debug_wrapper
+
+
+def run_fwd_maybe_bwd(gm, args, only_fwd=False):
+ """
+ Runs a forward and possibly backward iteration for a given mod and args.
+ """
+ from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass
+
+ gm = copy.deepcopy(gm)
+ new_args = clone_inputs(args)
+ # Set the requires_grad field explicitly because clone_inputs only sets
+ # requires_grad for leaf tensors.
+ for narg, arg in zip(new_args, args):
+ narg.requires_grad_(arg.requires_grad)
+ args = new_args
+
+ if hasattr(gm, "zero_grad"):
+ gm.zero_grad(True)
+ out = gm(*args)
+ if only_fwd:
+ return out
+ if requires_bwd_pass(out):
+ loss = reduce_to_scalar_loss(out)
+ loss.backward()
+ return collect_results(gm, out, None, [])
+
+
+def same_two_models(gm, opt_gm, example_inputs, only_fwd=False):
+ """
+ Check two models have same accuracy.
+ """
+ from .utils import same
+
+ ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)
+
+ try:
+ fp64_model, fp64_examples = cast_to_fp64(
+ copy.deepcopy(gm), clone_inputs(example_inputs)
+ )
+ fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd)
+ except Exception:
+ log.warning("Could not generate fp64 outputs")
+ fp64_ref = None
+
+ res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd)
+
+ passing = same(ref, res, fp64_ref, tol=0.001, equal_nan=True)
+ return passing
+
+
+def cast_to(dtype, model, inputs):
+ from torch.utils._pytree import tree_map
+
+ # cast model and inputs to fp16
+ model = model.to(dtype)
+
+ inputs = tree_map(
+ lambda x: x.to(dtype)
+ if isinstance(x, torch.Tensor) and x.is_floating_point()
+ else x,
+ inputs,
+ )
+ return model, inputs
+
+
+def cast_to_fp64(model, inputs):
+ return cast_to(torch.float64, model, inputs)
+
+
+def generate_dynamo_fx_repro_string(
+ model_str, args, compiler_name, check_accuracy=False
+):
+ """
+ Generate a repro string for backend-agnostic minified version.
+ """
+
+ run_code = textwrap.dedent(
+ f"""
+with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}):
+ ref = run_fwd_maybe_bwd(mod, args)
+ res = run_fwd_maybe_bwd(opt_mod, args)
+ """
+ )
+
+ if config.repro_level == 4 or check_accuracy:
+ run_code = textwrap.dedent(
+ f"""
+mod.eval()
+opt_mod.eval()
+with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}):
+ assert same_two_models(mod, mod, args), "Eager itself failed"
+ assert same_two_models(mod, opt_mod, args), "Dynamo failed"
+ """
+ )
+
+ return textwrap.dedent(
+ f"""
+from math import inf
+import torch
+from torch import tensor, device
+import torch.fx as fx
+import {config.dynamo_import}
+from {config.dynamo_import}.testing import rand_strided
+from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd
+from {config.dynamo_import}.debug_utils import same_two_models
+
+args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]}
+args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]
+
+{model_str}
+
+mod = Repro().cuda()
+opt_mod = {config.dynamo_import}.optimize("{compiler_name}")(mod)
+
+{run_code}
+ """
+ )
+
+
+def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False):
+ """
+ Saves the repro to a repro.py file
+ """
+ subdir = os.path.join(minifier_dir())
+ if not os.path.exists(subdir):
+ os.makedirs(subdir, exist_ok=True)
+ file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py")
+ log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
+
+ model_str = NNModuleToString.convert(gm)
+ with open(file_name, "w") as fd:
+ fd.write(
+ generate_dynamo_fx_repro_string(
+ model_str, args, compiler_name, check_accuracy
+ )
+ )
+ latest_repro = os.path.join(subdir, "repro.py")
+ log.warning(f"Copying {file_name} to {latest_repro} for convenience")
+ shutil.copyfile(file_name, latest_repro)
+
+ local_path = os.path.join(config.base_dir, "repro.py")
+ try:
+ shutil.copyfile(file_name, local_path)
+ log.warning(
+ f"Copying minified repro from {file_name} to {local_path} for convenience"
+ )
+ except OSError:
+ log.warning("No write permissions for {local_path}")
+
+
+# TODO - Commented because we are assuming that nn.Modules can be safely repr'd
+# If that does not work, we might have to bring this code back. So, keeping it
+# as it is for now.
+# def dump_backend_repro_as_tarfile(gm, args, compiler_name):
+# """
+# Saves the repro in repro.tar.gz, as opposed to a file. This is used for
+# cases, where we can't convert a Fx GraphModule to a string, and therefore
+# fallback to to_folder for serialization. We accompany this with a repro.py
+# script that imports the saved module, sets it up and runs the model to repro
+# the error.
+# """
+# import tarfile
+
+# subdir = os.path.join(minifier_dir(), "checkpoints")
+# if not os.path.exists(subdir):
+# os.makedirs(subdir, exist_ok=True)
+
+# tmp_dir = os.path.join(subdir, f"{len(gm.graph.nodes)}")
+# if os.path.exists(tmp_dir):
+# shutil.rmtree(tmp_dir)
+# os.makedirs(tmp_dir, exist_ok=True)
+
+# file_name = os.path.join(tmp_dir, "repro.py")
+# gm_dir = os.path.join(tmp_dir, "module")
+# if not os.path.exists(gm_dir):
+# os.makedirs(gm_dir, exist_ok=True)
+# for node in gm.graph.nodes:
+# new_kwargs = {}
+# for k, v in node.kwargs.items():
+# if isinstance(v, torch.device):
+# v = v.type
+# new_kwargs[k] = v
+# node.kwargs = new_kwargs
+# gm.recompile()
+
+# print(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
+# with open(file_name, "w") as fd:
+# # TODO - Add the readable version of to_folder when available
+# gm.to_folder(gm_dir, "Repro")
+# fd.write(
+# generate_dynamo_fx_repro_string(
+# "from module import Repro", args, compiler_name
+# )
+# )
+
+# local_dir = os.path.join(config.base_dir, "repro")
+# if os.path.exists(local_dir):
+# shutil.rmtree(local_dir)
+# shutil.copytree(tmp_dir, local_dir)
+# local_tar_file = os.path.join(config.base_dir, "repro.tar.gz")
+# print(f"Writing checkpoint with {len(gm.graph.nodes)} locally to {local_tar_file}")
+# with tarfile.open(local_tar_file, "w:gz") as tar:
+# tar.add(local_dir, arcname=os.path.basename(local_dir))
+
+
+def dump_backend_state(gm, args, compiler_name, check_accuracy=False):
+ """
+ Dumps the dynamo graph to repro the issue.
+ 1) It tries to convert Fx GraphModule to a string. If we can, it writes to a
+ repro.py file.
+ 2) If we can't convert Fx GraphModule to a string, we use to_folder to save
+ the module and save a tar file.
+ """
+ assert NNModuleToString.can_convert_to_string(gm)
+ return dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy)
+ # return dump_backend_repro_as_tarfile(gm, args, compiler_name)
+
+
+def backend_accuracy_fails(gm, example_inputs, compiler_fn, only_fwd=False):
+ compiled_gm = compiler_fn(copy.deepcopy(gm), clone_inputs(example_inputs))
+ return not same_two_models(gm, compiled_gm, example_inputs, only_fwd)
+
+
+backend_aot_accuracy_fails = functools.partial(backend_accuracy_fails, only_fwd=True)
+
+
+def backend_fails(gm, example_inputs, compiler_fn, orig_failure):
+ """
+ Minifier uses this function to identify if the minified graph module fails
+ with the same error.
+
+ One caveat is that minifier can potentially go into a wrong direction when
+ the resulting graph module fails for a different reason. To avoid this, we
+ save the string for the original exception and check similarity between new
+ and old exception. They can be somewhat different in some cases, when the
+ exception string depends on the failing node information. So, we have a
+ loose similarity metric to guide the minifier path.
+ """
+ from difflib import SequenceMatcher
+
+ try:
+ compiled_gm = compiler_fn(gm, example_inputs)
+ run_fwd_maybe_bwd(compiled_gm, clone_inputs(example_inputs))
+ return False
+ except Exception as e:
+ new_failure = str(e)
+ if SequenceMatcher(None, orig_failure, new_failure).ratio() > 0.5:
+ return True
+ return False
+
+
+def dump_to_minify_after_dynamo(gm, args, compiler_name):
+ model_str = NNModuleToString.convert(gm)
+
+ minifier_backend = "dynamo_minifier_backend"
+ if config.repro_level == 4:
+ minifier_backend = "dynamo_accuracy_minifier_backend"
+
+ contents = textwrap.dedent(
+ f"""
+import os
+from math import inf
+import torch
+from torch import tensor, device
+import torch.fx as fx
+import functools
+import {config.dynamo_import}
+from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd
+from {config.dynamo_import}.optimizations.backends import BACKENDS
+from {config.dynamo_import}.testing import rand_strided
+
+{config.dynamo_import}.config.repro_dir = \"{minifier_dir()}\"
+
+args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]}
+args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]
+
+{model_str}
+mod = Repro().cuda()
+
+# Setup debug minifier compiler
+compiler_fn = BACKENDS["{minifier_backend}"]
+dynamo_minifier_backend = functools.partial(
+ compiler_fn,
+ compiler_name="{compiler_name}",
+)
+opt_mod = {config.dynamo_import}.optimize(dynamo_minifier_backend)(mod)
+
+with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}):
+ opt_mod(*args)
+ """
+ )
+ helper_for_dump_minify(contents)
+
+
+def wrap_backend_debug(compiler_fn, compiler_name: str):
+ """
+ A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
+ As opposed to wrap_compiler_debug, this wrapper intercepts at the
+ TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some
+ level, e.g., it is useful for minifying issues related to Aot Autograd
+ tracing. If an error is found, we minify and save the minified repro in
+ repro.tar.gz.
+ """
+
+ @functools.wraps(compiler_fn)
+ def debug_wrapper(gm, example_inputs, **kwargs):
+ assert config.repro_after in ("dynamo", "aot", None)
+ if config.repro_after == "dynamo":
+ # Ensure that we fail when backend fails
+ config.raise_on_backend_error = True
+ if config.repro_level == 3:
+ dump_to_minify_after_dynamo(gm, example_inputs, compiler_name)
+
+ # Check for either accuracy (level 4) or other type of failures.
+ if config.repro_level == 4:
+ # Check Accuracy
+ compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
+ if backend_accuracy_fails(gm, example_inputs, compiler_fn):
+ log.warning("Accuracy failed for the TorchDyanmo produced graph")
+ dump_to_minify_after_dynamo(
+ fx.GraphModule(gm, copy.deepcopy(gm.graph)),
+ example_inputs,
+ compiler_name,
+ )
+ raise ValueError("Bad accuracy detected")
+ else:
+ try:
+ compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
+ run_fwd_maybe_bwd(compiled_gm, clone_inputs(example_inputs))
+ except Exception as exc:
+ log.warning(
+ "Compiled Fx GraphModule failed with following error. Setting up minifier."
+ )
+ log.exception(exc)
+ if config.repro_level == 1:
+ dump_state_fn = functools.partial(
+ dump_backend_state, compiler_name=compiler_name
+ )
+ dump_state_fn(
+ fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs
+ )
+ elif config.repro_level == 2:
+ dump_to_minify_after_dynamo(
+ fx.GraphModule(gm, copy.deepcopy(gm.graph)),
+ example_inputs,
+ compiler_name,
+ )
+ raise ValueError("Issue deteced. Repro at minifier_launcher.py.")
+ else:
+ compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
+
+ return compiled_gm
+
+ debug_wrapper._torchdynamo_orig_callable = compiler_fn
+
+ return debug_wrapper
+
+
+@register_backend
+def dynamo_minifier_backend(gm, example_inputs, compiler_name):
+ from functorch.compile import minifier
+
+ from .eval_frame import lookup_backend
+
+ compiler_fn = lookup_backend(compiler_name)
+
+ try:
+ compiled_gm = compiler_fn(gm, example_inputs)
+ run_fwd_maybe_bwd(compiled_gm, clone_inputs(example_inputs))
+ raise ValueError("No issue was detected")
+ except Exception as exc:
+ orig_failure = str(exc)
+ log.warning(
+ "Compiled Fx GraphModule failed with following error. Starting minifier."
+ )
+ log.exception(exc)
+ dump_state_fn = functools.partial(
+ dump_backend_state, compiler_name=compiler_name
+ )
+ dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs)
+ fails_fn = functools.partial(
+ backend_fails,
+ compiler_fn=compiler_fn,
+ orig_failure=orig_failure,
+ )
+ minifier(
+ gm,
+ example_inputs,
+ module_fails=fails_fn,
+ dump_state=dump_state_fn,
+ )
+ return gm
+
+
+@register_backend
+def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name):
+ from functorch.compile import minifier
+
+ from torchdynamo.optimizations.backends import BACKENDS
+
+ if compiler_name == "inductor":
+ from torchinductor.compile_fx import compile_fx
+
+ compiler_fn = compile_fx
+ else:
+ compiler_fn = BACKENDS[compiler_name]
+
+ # Set the eval mode to remove randomness.
+ gm.eval()
+
+ # Check Accuracy
+ if backend_accuracy_fails(gm, example_inputs, compiler_fn):
+ log.warning("Accuracy failed for the TorchDyanmo produced graph")
+ dump_state_fn = functools.partial(
+ dump_backend_state, compiler_name=compiler_name, check_accuracy=True
+ )
+ fails_fn = functools.partial(
+ backend_accuracy_fails,
+ compiler_fn=compiler_fn,
+ )
+ dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs)
+ minifier(
+ gm,
+ example_inputs,
+ module_fails=fails_fn,
+ dump_state=dump_state_fn,
+ )
+ else:
+ log.error("Input graph does not fail accuracy testing")
+ return gm
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
new file mode 100644
index 0000000000000..e015699c0ead0
--- /dev/null
+++ b/torch/_dynamo/eval_frame.py
@@ -0,0 +1,704 @@
+import contextlib
+import copy
+import functools
+import inspect
+import logging
+import os
+import sys
+import threading
+import traceback
+import types
+import warnings
+from importlib import import_module
+from unittest.mock import patch
+
+import torch
+import torch.utils._pytree as pytree
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch.nn.parallel.distributed import DistributedDataParallel
+
+from . import config, convert_frame, logging as torchdynamo_logging, skipfiles, utils
+from .exc import ResetRequired
+from .mutation_guard import install_generation_tagging_init
+from .optimizations.distributed import DDPOptimizer
+from .utils import checkpoint_params, clone_inputs, compile_times, same
+
+log = logging.getLogger(__name__)
+
+try:
+ from torch.fx.experimental import proxy_tensor
+except ImportError:
+ proxy_tensor = None
+
+_eval_frame = torch._C._dynamo.eval_frame
+set_eval_frame = _eval_frame.set_eval_frame
+reset_code = _eval_frame.reset_code
+unsupported = _eval_frame.unsupported
+skip_code = _eval_frame.skip_code
+set_guard_fail_hook = _eval_frame.set_guard_fail_hook
+set_guard_error_hook = _eval_frame.set_guard_error_hook
+always_optimize_code_objects = utils.ExactWeakKeyDictionary()
+null_context = contextlib.nullcontext
+unset = object()
+compile_lock = threading.RLock()
+most_recent_backend = None
+
+
+def remove_from_cache(f):
+ """
+ Make sure f.__code__ is not cached to force a recompile
+ """
+ if isinstance(f, types.CodeType):
+ reset_code(f)
+ elif hasattr(f, "__code__"):
+ reset_code(f.__code__)
+ elif hasattr(getattr(f, "forward", None), "__code__"):
+ reset_code(f.forward.__code__)
+ else:
+ from . import reset
+
+ reset()
+ log.warning("could not determine __code__ for %s", f)
+
+
+def nothing():
+ pass
+
+
+def innermost_fn(fn):
+ """
+ In case of nesting of _TorchDynamoContext calls, find the innermost
+ function. TorchDynamo caches on fn.__code__ object, so its necessary to find
+ the innermost function to pass on the optimize, run, disable etc.
+ """
+ unaltered_fn = fn
+ while hasattr(unaltered_fn, "_torchdynamo_orig_callable"):
+ unaltered_fn = unaltered_fn._torchdynamo_orig_callable
+ assert callable(unaltered_fn)
+ return unaltered_fn
+
+
+@functools.lru_cache(None)
+def _step_logger():
+ return torchdynamo_logging.get_step_logger(log)
+
+
+class _TorchDynamoContext:
+ def __init__(
+ self,
+ callback,
+ on_enter=nothing,
+ backend_ctx_ctor=null_context,
+ patch_fn=nothing,
+ first_ctx=False,
+ ):
+ super().__init__()
+ assert callable(callback) or callback is False or callback is None
+ self.callback = callback
+ self.prior = unset
+ self.on_enter = on_enter
+ self.extra_ctx_ctor = backend_ctx_ctor
+ self.first_ctx = first_ctx
+ patch_fn()
+
+ def __enter__(self):
+ if config.raise_on_ctx_manager_usage:
+ raise RuntimeError(
+ "torchdynamo.optimize(...) is used with a context manager. "
+ "Please refer to https://github.com/pytorch/torchdynamo#usage-example "
+ "to use torchdynamo.optimize(...) as an annotation/decorator. "
+ )
+ self.on_enter()
+ self.prior = set_eval_frame(self.callback)
+ self.backend_ctx = self.extra_ctx_ctor()
+ self.backend_ctx.__enter__()
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ set_eval_frame(self.prior)
+ self.prior = unset
+ self.backend_ctx.__exit__(exc_type, exc_val, exc_tb)
+
+ def __call__(self, fn):
+ fn = innermost_fn(fn)
+ # Optimize the forward method of torch.nn.Module object
+ if isinstance(fn, torch.nn.Module):
+ mod = fn
+ optimized_forward = self(mod.forward)
+
+ class TorchDynamoNNModuleWrapper:
+ """
+ A wrapper that redirects the forward call to the optimized
+ forward, while for rest it redirects the calls to the original
+ module.
+ """
+
+ def __getattr__(self, name):
+ return getattr(mod, name)
+
+ def forward(self, *args, **kwargs):
+ return optimized_forward(*args, **kwargs)
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ new_mod = TorchDynamoNNModuleWrapper()
+ # Save the function pointer to find the original callable while nesting
+ # of decorators.
+ new_mod._torchdynamo_orig_callable = mod
+ return new_mod
+
+ assert callable(fn)
+ callback = self.callback
+ on_enter = self.on_enter
+ backend_ctx_ctor = self.extra_ctx_ctor
+
+ @functools.wraps(fn)
+ def _fn(*args, **kwargs):
+ if self.first_ctx:
+ _step_logger()(logging.INFO, "torchdynamo begin tracing")
+
+ on_enter()
+ prior = set_eval_frame(callback)
+ backend_ctx = backend_ctx_ctor()
+ backend_ctx.__enter__()
+ try:
+ return fn(*args, **kwargs)
+ finally:
+ set_eval_frame(prior)
+ backend_ctx.__exit__(None, None, None)
+ if self.first_ctx:
+ _step_logger()(logging.INFO, "torchdynamo done tracing")
+
+ # hooks to properly handle inlining
+ if isinstance(self, DisableContext):
+ _fn._torchdynamo_disable = True
+ else:
+ _fn._torchdynamo_inline = fn
+
+ # Save the function pointer to find the original callable while nesting
+ # of decorators.
+ _fn._torchdynamo_orig_callable = fn
+
+ # If the function is called using torchdynamo.optimize decorator, we
+ # should prevent any type of skipping.
+ if callback not in (None, False):
+ always_optimize_code_objects[fn.__code__] = True
+
+ return _fn
+
+
+class OptimizeContext(_TorchDynamoContext):
+ def __init__(self, callback, backend_ctx_ctor, first_ctx=False):
+ def on_enter():
+ global most_recent_backend
+ if (
+ most_recent_backend is not None
+ and most_recent_backend is not compiler_fn
+ ):
+ raise ResetRequired()
+ most_recent_backend = compiler_fn
+ install_generation_tagging_init()
+
+ compiler_fn = innermost_fn(callback)
+ super().__init__(
+ callback=callback,
+ on_enter=on_enter,
+ backend_ctx_ctor=backend_ctx_ctor,
+ patch_fn=TorchPatcher.patch,
+ first_ctx=first_ctx,
+ )
+
+
+class RunOnlyContext(_TorchDynamoContext):
+ def __init__(self):
+ super().__init__(callback=False)
+
+
+class DisableContext(_TorchDynamoContext):
+ def __init__(self):
+ super().__init__(callback=None)
+
+
+def catch_errors_wrapper(callback):
+ @functools.wraps(callback)
+ def catch_errors(frame, cache_size):
+ try:
+ if frame.f_lasti >= 0 or skipfiles.check(frame.f_code.co_filename):
+ log.debug(f"skipping {frame.f_code.co_name} {frame.f_code.co_filename}")
+ return None
+ if (
+ frame.f_code.co_filename == ""
+ and frame.f_code.co_name == "__new__"
+ ):
+ # nametuple constructor
+ return None
+ if config.optimize_ddp:
+ ddp_module = DistributedDataParallel._get_active_ddp_module()
+ if ddp_module and frame.f_code.co_name == "forward":
+ with compile_lock:
+ ddp_optimizer = DDPOptimizer(
+ bucket_bytes_cap=ddp_module.bucket_bytes_cap,
+ parameters_to_ignore=ddp_module.parameters_to_ignore,
+ backend_compile_fn=callback._torchdynamo_orig_callable,
+ )
+ hijacked_callback = convert_frame.convert_frame(
+ ddp_optimizer.compile_fn, guard_export_fn=None
+ )
+ return hijacked_callback(frame, cache_size)
+
+ with compile_lock:
+ return callback(frame, cache_size)
+ except Exception:
+ log.exception("Error while processing frame")
+ raise
+
+ catch_errors._torchdynamo_orig_callable = callback
+ return catch_errors
+
+
+def _optimize_catch_errors(compile_fn, backend_ctx_ctor=null_context):
+ return OptimizeContext(
+ catch_errors_wrapper(compile_fn),
+ backend_ctx_ctor=backend_ctx_ctor,
+ first_ctx=True,
+ )
+
+
+class WrapperBackend:
+ def __init__(self, backend=None):
+ self.backend = backend
+
+ @property
+ def example_inputs(self):
+ return clone_inputs(self.original_example_inputs)
+
+ def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+
+ self.restore = checkpoint_params(gm)
+ self.original_example_inputs = clone_inputs(example_inputs)
+ self.gm = gm
+ copy_gm = copy.deepcopy(self.gm)
+ self.candidate = self.backend(copy_gm, self.original_example_inputs)
+
+ if self.candidate is None or self.candidate is self.gm.forward:
+ return self.gm.forward
+
+ if not config.verify_correctness:
+ return self.candidate
+
+ # if verify_correctness=True
+ try:
+ correct = self.gm.forward(*self.example_inputs)
+ result = self.candidate(*self.example_inputs)
+
+ # TODO: replace `same` function with the one in testing
+ if same(correct, result):
+ return self.candidate
+
+ raise RuntimeError(f"incorrect results of backend {self}")
+ return self.gm.forward
+
+ except Exception:
+ log.exception("error in verify_correctness")
+ raise
+ finally:
+ self.restore()
+
+
+def get_compiler_fn(compiler_fn):
+ from .debug_utils import wrap_backend_debug
+
+ compiler_str = compiler_fn if isinstance(compiler_fn, str) else None
+ compiler_fn = lookup_backend(compiler_fn)
+ return wrap_backend_debug(compiler_fn, compiler_str)
+
+
+@functools.lru_cache(1)
+def lookup_backend(compiler_fn):
+ """Expand backend strings to functions"""
+ if compiler_fn == "inductor":
+ compiler_fn = import_module(f"{config.inductor_import}.compile_fx").compile_fx
+ elif isinstance(compiler_fn, str):
+ from .optimizations import BACKENDS
+
+ compiler_fn = BACKENDS[compiler_fn]
+ return compiler_fn
+
+
+class _NullDecorator(contextlib.nullcontext):
+ def __call__(self, fn):
+ assert callable(fn)
+ return fn
+
+
+def optimize(
+ backend="inductor", *, nopython=False, guard_export_fn=None, disable=False
+):
+ """
+ The main entrypoint of TorchDynamo. Do graph capture and call
+ backend() to optimize extracted graphs.
+
+ Args:
+ backend: One of the two things:
+ - Either, a function/callable taking a torch.fx.GraphModule and
+ example_inputs and returning a python callable that runs the
+ graph faster.
+ One can also provide additional context for the backend, like
+ torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute.
+ See AOTAutogradMemoryEfficientFusionWithContext for the usage.
+ - Or, a string backend name in `torchdynamo.list_backends()`
+ nopython: If True, graph breaks will be errors and there will
+ be a single whole-program graph.
+ disable: If True, turn this decorator into a no-op
+
+ Example Usage:
+
+ @torchdynamo.optimize()
+ def toy_example(a, b):
+ ...
+ """
+ if disable or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1":
+ return _NullDecorator()
+ if sys.platform == "win32":
+ warnings.warn(
+ "Windows is not currently supported, "
+ + f"{config.dynamo_import}.optimize() will do nothing"
+ )
+ return _NullDecorator()
+ if sys.version_info >= (3, 11):
+ warnings.warn(
+ "Python 3.11+ not yet supported, "
+ f"{config.dynamo_import}.optimize() will do nothing"
+ )
+ return _NullDecorator()
+
+ backend = get_compiler_fn(backend)
+
+ # Find if backend has any extra context manager
+ backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)
+
+ if nopython:
+ return optimize_assert(backend, guard_export_fn=guard_export_fn)
+ return _optimize_catch_errors(
+ convert_frame.convert_frame(backend, guard_export_fn=guard_export_fn),
+ backend_ctx_ctor,
+ )
+
+
+@patch("torchdynamo.symbolic_convert.explain", True)
+def explain(f, *args, **kwargs):
+ # TODO(voz): Do we want a decorator for this?
+ from . import reset
+
+ reset()
+
+ out_guards = []
+ graphs = []
+ ops_per_graph = []
+ op_count = 0
+ break_reasons = []
+
+ def dynamo_graph_accumulating_compiler(gm: torch.fx.GraphModule, example_inputs):
+ nonlocal graphs
+ nonlocal op_count
+ nonlocal ops_per_graph
+
+ graphs.append(gm)
+ ops = []
+ for node in gm.graph.nodes:
+ if node.op == "call_function":
+ ops.append(node.target)
+
+ op_count += len(ops)
+ ops_per_graph.append(ops)
+ if gm.compile_subgraph_reason is not None:
+ break_reasons.append(gm.compile_subgraph_reason)
+ return gm.forward
+
+ def guard_export_print(guards):
+ nonlocal out_guards
+ out_guards.append(guards)
+
+ with patch(f"{__name__}.most_recent_backend", None):
+ opt_f = optimize(
+ dynamo_graph_accumulating_compiler,
+ nopython=False,
+ guard_export_fn=guard_export_print,
+ )(f)
+ # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideffects and reject.
+ opt_f(*args, **kwargs)
+
+ graph_count = len(graphs)
+
+ # For the explanation summary, dedupe reasons by the innermost stack frame and dedupe by it.
+ deduped_reasons = {}
+ for reason in break_reasons:
+ innermost_frame = reason.user_stack[-1]
+ # __repr__ uniquely identifies a FrameSummary so we can use it for deduping
+ deduped_reasons[repr(innermost_frame)] = reason
+
+ formatted_list = ""
+ for idx, break_reason in enumerate(deduped_reasons.values()):
+ formatted_stack = "".join(traceback.format_list(break_reason.user_stack))
+ msg = f"{break_reason.reason}\n{formatted_stack}"
+ formatted_list += f"{idx + 1}. {msg} \n"
+
+ explanation = f"Dynamo produced {graph_count} graphs"
+ explanation += f"with {graph_count - 1} graph break and {op_count} ops"
+ explanation += f"\n Break reasons: \n\n{formatted_list}"
+
+ explanation += compile_times()
+
+ # TODO(voz): Do we want a decorator for this?
+ reset()
+ return explanation, out_guards, graphs, ops_per_graph, break_reasons
+
+
+def export(
+ f, *args, aten_graph=False, decomposition_table=None, tracing_mode="real", **kwargs
+):
+ if decomposition_table is not None or tracing_mode != "real":
+ assert (
+ aten_graph
+ ), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
+ f = innermost_fn(f)
+
+ graph = None
+ out_guards = None
+ graph_captured_input = None
+ graph_captured_result = None
+
+ def produce_matching(source_args, candidate_args):
+ matched_elements_positions = []
+ dict_of_source_args = dict()
+ for i in range(0, len(source_args)):
+ element_id = id(source_args[i])
+ dict_of_source_args[element_id] = i
+
+ for i in range(0, len(candidate_args)):
+ arg = candidate_args[i]
+ # 1-element tensor arg can be unspec int/float
+ if isinstance(arg, torch.Tensor) and torch.numel(arg) == 1:
+ if id(arg) in dict_of_source_args:
+ matched_elements_positions.append(dict_of_source_args[id(arg)])
+ elif id(arg.item()) in dict_of_source_args:
+ matched_elements_positions.append(
+ dict_of_source_args[id(arg.item())]
+ )
+ else:
+ raise AssertionError(
+ "Dynamo input/output is not consistent with traced input/output"
+ )
+ else:
+ assert (
+ id(arg) in dict_of_source_args
+ ), "Dynamo input and output is a strict subset of traced input/output"
+ matched_elements_positions.append(dict_of_source_args[id(arg)])
+
+ return matched_elements_positions
+
+ def guard_export_print(guards):
+ nonlocal out_guards
+ assert out_guards is None, "whole graph export entails exactly one guard export"
+ out_guards = guards
+
+ def dynamo_normalization_capturing_compiler(
+ gm: torch.fx.GraphModule, example_inputs
+ ):
+ nonlocal graph
+
+ assert graph is None, "whole graph export entails exactly one graph"
+ graph = gm
+
+ def result_capturing_wrapper(*graph_inputs):
+ nonlocal graph_captured_result
+ nonlocal graph_captured_input
+
+ graph_captured_input = graph_inputs
+ graph_captured_result = graph(*graph_inputs)
+ return graph_captured_result
+
+ return result_capturing_wrapper
+
+ # TODO(voz): Handle kwargs properly?
+ flat_args, in_spec = pytree.tree_flatten(args)
+
+ remove_from_cache(f)
+ with patch(f"{__name__}.most_recent_backend", None):
+ opt_f = optimize_assert(
+ dynamo_normalization_capturing_compiler,
+ guard_export_fn=guard_export_print,
+ export=True,
+ )(f)
+ # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideffects and reject.
+ result_traced = opt_f(*args, **kwargs)
+ remove_from_cache(f)
+
+ assert graph is not None, "whole graph export entails exactly one call"
+ assert out_guards is not None, "whole graph export entails exactly one guard export"
+
+ matched_input_elements_positions = produce_matching(flat_args, graph_captured_input)
+
+ flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced)
+
+ flat_both = list(graph_captured_result) + flat_args
+ matched_output_elements_positions = produce_matching(flat_both, flat_results_traced)
+
+ class ChangeInputOutputSignature(torch.fx.interpreter.Transformer):
+ def __init__(
+ self,
+ m,
+ ):
+ super().__init__(m)
+ arg_len = len(flat_args)
+ self.new_args = [
+ super(ChangeInputOutputSignature, self).placeholder(f"arg{i}", (), {})
+ for i in range(0, arg_len)
+ ]
+ self.old_args_gen = (
+ self.new_args[i] for i in matched_input_elements_positions
+ )
+
+ def placeholder(self, target, args, kwargs):
+ return next(self.old_args_gen)
+
+ def output(self, target, args, kwargs):
+ dynamo_result_flat = args[0]
+ lookup = [*dynamo_result_flat, *self.new_args]
+ new_result_flat = [lookup[i] for i in matched_output_elements_positions]
+ new_result = pytree.tree_unflatten(new_result_flat, out_spec_traced)
+
+ return super().output(target, (new_result,), {})
+
+ if aten_graph:
+ # Running graph with interpreter is needed for propagating the stack_trace
+ def graph_with_interpreter(*args):
+ with torch.fx.traceback.override_stack_trace():
+ return torch.fx.Interpreter(graph).run(*args)
+
+ graph = make_fx(
+ graph_with_interpreter,
+ decomposition_table=decomposition_table,
+ tracing_mode=tracing_mode,
+ )(*graph_captured_input)
+
+ new_graph = ChangeInputOutputSignature(
+ graph,
+ ).transform()
+
+ return (new_graph, out_guards)
+
+
+def assume_constant_result(fn):
+ fn._dynamo_marked_constant = True
+ assert (
+ not config.fake_tensor_propagation
+ ), "Constant result capture is not supported with fake tensors."
+ return fn
+
+
+def optimize_assert(backend, *, guard_export_fn=None, export=False):
+ """
+ The same as `torchdynamo.optimize(backend, nopython=True)`
+ """
+ backend = get_compiler_fn(backend)
+
+ # Find if backend has any extra context manager
+ backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)
+
+ return _optimize_catch_errors(
+ convert_frame.convert_frame_assert(backend, guard_export_fn, export=export),
+ backend_ctx_ctor,
+ )
+
+
+def run(fn=None):
+ """Don't do any dynamic compiles, just use prior optimizations"""
+ if fn is not None:
+ fn = innermost_fn(fn)
+ assert callable(fn)
+ return RunOnlyContext()(fn)
+ return RunOnlyContext()
+
+
+def disable(fn=None):
+ """Decorator and context manager to disable TorchDynamo"""
+ if fn is not None:
+ fn = innermost_fn(fn)
+ assert callable(fn)
+ return DisableContext()(fn)
+ return DisableContext()
+
+
+def skip(fn=None):
+ """
+ Skip frames associated with the function code, but still process recursively
+ invoked frames
+ """
+ if fn is None:
+ return skip
+ fn = innermost_fn(fn)
+ assert callable(fn)
+ skip_code(fn.__code__)
+ fn._torchdynamo_disable = True
+ return fn
+
+
+class TorchPatcher:
+ @staticmethod
+ @functools.lru_cache(None)
+ def patch():
+ # Disable TorchDynamo on some torch.* compilers generated frames
+ torch.jit.trace = disable(torch.jit.trace)
+ torch.jit.trace_module = disable(torch.jit.trace_module)
+ torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph)
+
+ # symbolic_trace creates new frames. We disable Dynamo on such frames
+ torch.fx._symbolic_trace.Tracer.trace = disable(
+ torch.fx._symbolic_trace.Tracer.trace
+ )
+
+ torch.onnx.export_to_pretty_string = disable(torch.onnx.export_to_pretty_string)
+ torch.distributions.Distribution.set_default_validate_args(False)
+
+ if proxy_tensor is not None:
+ proxy_tensor.dispatch_trace = disable(proxy_tensor.dispatch_trace)
+
+ optimizers = [
+ opt
+ for opt in torch.optim.__dict__.values()
+ if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
+ ]
+
+ # disable dynamo for the wrapper that helps give dynamo hints about entering DDP
+ if hasattr(DistributedDataParallel, "_inside_ddp_forward"):
+ DistributedDataParallel._inside_ddp_forward = skip(
+ DistributedDataParallel._inside_ddp_forward
+ )
+
+ # disable profile hook
+ for opt in optimizers:
+ opt._cuda_graph_capture_health_check = disable(
+ opt._cuda_graph_capture_health_check
+ )
+ # disable any currently set hooks
+ # Note: we only want to disable the profiling hook
+ # which is the *last* hook applied, we want to keep the no_grad hook
+ hooked = getattr(opt.step, "hooked", False)
+ if hooked:
+ unwrapped_step = getattr(opt.step, "__wrapped__", None)
+ if unwrapped_step:
+ opt.step = unwrapped_step
+
+ # disable future hooking
+ opt.step.hooked = True
+
+ @staticmethod
+ def suppress_torch_distributed_warnings(fn):
+ def inner_fn(*args, **kwargs):
+ warnings.filterwarnings(
+ "ignore", category=UserWarning, module="torch.distributed"
+ )
+ return fn(*args, **kwargs)
+
+ return inner_fn
diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py
new file mode 100644
index 0000000000000..3001c8c823924
--- /dev/null
+++ b/torch/_dynamo/exc.py
@@ -0,0 +1,76 @@
+import os
+import textwrap
+
+from .utils import counters
+
+
+class TorchDynamoException(RuntimeError):
+ pass
+
+
+class InternalTorchDynamoError(TorchDynamoException):
+ pass
+
+
+class RestartAnalysis(TorchDynamoException):
+ pass
+
+
+class SkipFrame(TorchDynamoException):
+ pass
+
+
+class TorchRuntimeError(TorchDynamoException):
+ pass
+
+
+class ResetRequired(TorchDynamoException):
+ def __init__(self):
+ super(ResetRequired, self).__init__(
+ textwrap.dedent(
+ """
+ Must call `torchdynamo.reset()` before changing backends. Detected two calls to
+ `torchdynamo.optimize(...)` with a different backend compiler arguments.
+ """
+ )
+ )
+
+
+class BackendCompilerFailed(TorchDynamoException):
+ def __init__(self, backend_fn, inner_exception):
+ self.backend_name = getattr(backend_fn, "__name__", "?")
+ self.inner_exception = inner_exception
+ super().__init__(
+ f"{self.backend_name} raised {type(inner_exception).__name__}: {inner_exception}"
+ "\n\n"
+ "You can suppress this exception and fall back to eager by setting:\n"
+ " torchdynamo.config.raise_on_backend_error = False"
+ )
+
+
+class Unsupported(TorchDynamoException):
+ def __init__(self, msg):
+ super(Unsupported, self).__init__(msg)
+ self.real_stack = []
+ self.msg = msg
+ self.category = None
+ self.add_to_stats()
+
+ def remove_from_stats(self):
+ counters[self.category][self.msg] -= 1
+ if counters[self.category][self.msg] <= 0:
+ del counters[self.category][self.msg]
+
+ def add_to_stats(self, category="unimplemented"):
+ self.category = category
+ counters[category][self.msg] += 1
+
+
+def unimplemented(msg: str):
+ assert msg != os.environ.get("BREAK", False)
+ raise Unsupported(msg)
+
+
+def warning(msg: str):
+ counters["warnings"][msg] += 1
+ assert msg != os.environ.get("BREAK", False)
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
new file mode 100644
index 0000000000000..0076f5e10b4b9
--- /dev/null
+++ b/torch/_dynamo/guards.py
@@ -0,0 +1,638 @@
+import collections
+import dataclasses
+import enum
+import logging
+import math
+import os
+import re
+import textwrap
+import types
+import weakref
+from inspect import currentframe, getframeinfo
+from typing import Any, Callable, Dict, List, Optional, Set
+
+import numpy as np
+
+import torch
+
+from . import config, convert_frame, mutation_guard
+from .eval_frame import set_guard_error_hook, set_guard_fail_hook
+from .exc import unimplemented
+from .utils import (
+ dict_const_keys,
+ dict_param_key_ids,
+ guard_failures,
+ istype,
+ orig_code_map,
+ rename_implicit,
+ tuple_iterator_getitem,
+ tuple_iterator_len,
+)
+
+log = logging.getLogger(__name__)
+TensorGuards = torch._C._dynamo.guards.TensorGuards
+check_obj_id = torch._C._dynamo.guards.check_obj_id
+check_type_id = torch._C._dynamo.guards.check_type_id
+
+
+CLOSURE_VARS = collections.OrderedDict(
+ [
+ ("___check_type_id", check_type_id),
+ ("___check_obj_id", check_obj_id),
+ ("___is_grad_enabled", torch.is_grad_enabled),
+ ("___odict_getitem", collections.OrderedDict.__getitem__),
+ ("___dict_param_key_ids", dict_param_key_ids),
+ ("___dict_const_keys", dict_const_keys),
+ ("___tuple_iterator_len", tuple_iterator_len),
+ ("___tuple_iterator_getitem", tuple_iterator_getitem),
+ ("__math_isnan", math.isnan),
+ ("inf", float("inf")),
+ ]
+)
+
+
+class GuardSource(enum.Enum):
+ LOCAL = 0
+ GLOBAL = 1
+ LOCAL_NN_MODULE = 2
+ GLOBAL_NN_MODULE = 3
+ CONSTANT = 4
+
+ def select(self, locals_, globals_):
+ if self in (GuardSource.LOCAL, GuardSource.LOCAL_NN_MODULE):
+ return locals_
+ if self in (GuardSource.GLOBAL, GuardSource.GLOBAL_NN_MODULE):
+ return globals_
+ raise NotImplementedError()
+
+ def is_nn_module(self):
+ return self in (GuardSource.GLOBAL_NN_MODULE, GuardSource.LOCAL_NN_MODULE)
+
+ def is_local(self):
+ return self in (GuardSource.LOCAL, GuardSource.LOCAL_NN_MODULE)
+
+
+@dataclasses.dataclass
+class Guard:
+ name: str
+ source: GuardSource
+ create_fn: Callable
+ is_volatile: bool = False
+
+ # Export only. These values are written to at time of guard check_fn creation.
+ guard_types: Optional[List[str]] = None
+ code_list: Optional[List[str]] = None
+ obj_weakref: Optional[Any] = None
+ guarded_class_weakref: Optional[type] = None
+
+ def __hash__(self):
+ return hash((self.name, self.source, id(self.create_fn)))
+
+ def sort_key(self):
+ return (
+ self.source.value,
+ len(self.name),
+ self.name,
+ self.create_fn.__code__.co_firstlineno,
+ )
+
+ def __lt__(self, other):
+ return self.sort_key() < other.sort_key()
+
+ def __str__(self):
+ s = f"""
+ {self.source.name.lower()} {repr(self.name)} {self.create_fn.__name__}
+ {{
+ 'guard_types': {self.guard_types},
+ 'code': {self.code_list},
+ 'obj_weakref': {self.obj_weakref}
+ 'guarded_class': {self.guarded_class_weakref}
+ }}
+ """
+ return s
+
+ def create(self, local_builder: "GuardBuilder", global_builder: "GuardBuilder"):
+ return self.create_fn(self.source.select(local_builder, global_builder), self)
+
+ def is_nn_module(self):
+ return self.source.is_nn_module()
+
+ def is_local(self):
+ return self.source.is_local()
+
+ def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref):
+ if not self.guard_types:
+ self.guard_types = list()
+
+ self.guard_types.append(guard_type)
+
+ assert self.guarded_class_weakref in (
+ guarded_class,
+ None,
+ ), "Guarded class id must be identical, or None"
+ self.guarded_class_weakref = guarded_class
+
+ if not self.code_list:
+ self.code_list = code_list
+ else:
+ self.code_list.extend(code_list)
+
+ assert self.obj_weakref in (
+ obj_weakref,
+ None,
+ ), "Guarded object must be identical, or None"
+ self.obj_weakref = obj_weakref
+
+
+def strip_function_call(name):
+ """
+ "___odict_getitem(a, 1)" => "a"
+ """
+ m = re.search(r"([a-z0-9_]+)\(([^(),]+)[^()]*\)", name)
+ if m and m.group(1) != "slice":
+ return strip_function_call(m.group(2))
+ return strip_getattr_getitem(name)
+
+
+def strip_getattr_getitem(name):
+ """
+ "a[1]" => "a"
+ "a.foo" => "a"
+ """
+ return re.split(r"[.\[]", name)[0]
+
+
+class GuardBuilder:
+ def __init__(
+ self, id_ref: Callable, scope: Dict[str, Any], guarded_code, renames=True
+ ):
+ self.id_ref = id_ref
+ if scope:
+ if renames:
+ scope = {rename_implicit(k): v for k, v in scope.items()}
+ else:
+ scope = dict()
+ self.scope = scope
+ self.argnames: List[str] = []
+ # Code is python expression strings generated for each guard
+ self.code: List[str] = []
+ self.tensor_check_names = []
+ self.tensor_check_examples = []
+ self.guarded_code = guarded_code
+
+ def get(self, name: str):
+ return eval(name, self.scope, CLOSURE_VARS)
+
+ def arg_ref(self, guard: Guard):
+ if isinstance(guard, str):
+ name = guard
+ else:
+ name = guard.name
+ base = strip_getattr_getitem(strip_function_call(name))
+ if base not in self.argnames:
+ if re.match(r"^\d+$", base):
+ log.warning(f"invalid var name: {guard}")
+ self.argnames.append(base)
+
+ return name
+
+ def TYPE_MATCH(self, guard: Guard):
+ # ___check_type_id is same as `id(type(x)) == y`
+ t = type(self.get(guard.name))
+ obj_id = self.id_ref(t)
+ code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})"
+ self._produce_guard_code(guard, [code])
+
+ def ID_MATCH(self, guard: Guard):
+ # ___check_obj_id is same as `id(x) == y`
+ m = re.match(r"^type\((.+)\)$", guard.name)
+ if m:
+ # optional optimization to produce cleaner/faster guard code
+ return self.TYPE_MATCH(Guard(m.group(1), guard.source, None))
+
+ code = f"___check_obj_id({self.arg_ref(guard)}, {self.id_ref(self.get(guard.name))})"
+ self._produce_guard_code(guard, [code])
+
+ def NAME_MATCH(self, guard: Guard):
+ obj = self.get(guard.name)
+ code = f"{self.arg_ref(guard)}.__name__ == {obj.__name__})"
+ self._produce_guard_code(guard, [code])
+
+ def HASATTR(self, guard: Guard):
+ m = re.match(r"^(.*)[.]([a-zA-Z0-9_]+)$", guard.name)
+ assert m, f"invalid hasattr check {guard.name}"
+ base, attr = m.group(1, 2)
+ ref = self.arg_ref(base)
+ val = hasattr(self.get(base), attr)
+ code = None
+ if val:
+ code = f"hasattr({ref}, {attr!r})"
+ else:
+ code = f"not hasattr({ref}, {attr!r})"
+
+ self._produce_guard_code(guard, [code], provided_guarded_object=self.get(base))
+
+ def EQUALS_MATCH(self, guard: Guard):
+ ref = self.arg_ref(guard)
+ val = self.get(guard.name)
+ t = type(val)
+ assert istype(
+ val,
+ (
+ int,
+ float,
+ bool,
+ type(None),
+ str,
+ type,
+ list,
+ tuple,
+ set,
+ slice,
+ frozenset,
+ range,
+ torch.Size,
+ torch.device,
+ torch.dtype,
+ np.int8,
+ np.int16,
+ np.int32,
+ np.int64,
+ np.uint8,
+ np.uint16,
+ np.uint32,
+ np.uint64,
+ ),
+ ), t.__name__
+ if istype(val, (torch.device, torch.dtype)):
+ # TODO(jansel): is this slow? perhaps optimize it
+ code = f"str({ref}) == {str(val)!r}"
+ self._produce_guard_code(guard, [code])
+ return
+
+ # Special case for nan because float("nan") == float("nan") evaluates to False
+ if istype(val, float) and math.isnan(val):
+ code = list()
+ code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
+ code.append(f"__math_isnan({ref})")
+ self._produce_guard_code(guard, code)
+ return
+
+ # Add type check to prevent equality check between tensor and non-tensor.
+ code = list()
+ if istype(val, (list, tuple)):
+ self.LIST_LENGTH(guard)
+
+ for idx, elem in enumerate(val):
+ code.append(
+ f"___check_type_id({ref}[{idx}], {self.id_ref(type(elem))})"
+ )
+
+ elif not istype(val, torch.Size):
+ code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
+
+ if istype(val, torch.Size):
+ val = tuple(val)
+
+ code.append(f"{ref} == {val!r}")
+ self._produce_guard_code(guard, code)
+
+ def CONSTANT_MATCH(self, guard: Guard):
+ val = self.get(guard.name)
+ if istype(val, (bool, type(None))):
+ self.ID_MATCH(guard)
+ else:
+ self.EQUALS_MATCH(guard)
+
+ def NN_MODULE(self, guard: Guard):
+ self.ID_MATCH(guard)
+ ref = self.arg_ref(guard)
+ val = self.get(guard.name)
+
+ def setup_guard():
+ assert istype(val.training, bool)
+ self.code.append(f"{ref}.training == {val.training}")
+
+ if hasattr(val, "training"):
+ # There are cases where a monkeypatched object has a guard made between __new__ and __init__
+ setup_guard()
+ else:
+ unimplemented(f"Guard setup for uninitialized class {type(val)}")
+
+ def FUNCTION_MATCH(self, guard: Guard):
+ """things like torch.add and user defined functions"""
+ if guard.is_local():
+ return self.ID_MATCH(guard)
+
+ def BUILTIN_MATCH(self, guard: Guard):
+ return self.FUNCTION_MATCH(guard)
+
+ def PYMODULE_MATCH(self, guard: Guard):
+ return self.FUNCTION_MATCH(guard)
+
+ def LIST_LENGTH(self, guard):
+ ref = self.arg_ref(guard)
+ value = self.get(guard.name)
+ t = type(value)
+
+ code = list()
+ code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
+ code.append(f"len({ref}) == {len(value)}")
+
+ self._produce_guard_code(guard, code)
+
+ def TUPLE_ITERATOR_LEN(self, guard):
+ ref = self.arg_ref(guard)
+ value = self.get(guard.name)
+ t = type(value)
+
+ code = list()
+ code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
+ code.append(f"___tuple_iterator_len({ref}) == {tuple_iterator_len(value)}")
+
+ self._produce_guard_code(guard, code)
+
+ def DICT_KEYS(self, guard):
+ ref = self.arg_ref(guard)
+ value = self.get(guard.name)
+ t = type(value)
+
+ code = list()
+ code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
+ param_key_ids = set(dict_param_key_ids(value))
+ const_keys = set(dict_const_keys(value))
+ if param_key_ids:
+ code.append(f"___dict_param_key_ids({ref}) == {param_key_ids!r}")
+ code.append(f"___dict_const_keys({ref}) == {const_keys!r}")
+ else:
+ code.append(f"set({ref}.keys()) == {const_keys!r}")
+
+ self._produce_guard_code(guard, code)
+
+ def WEAKREF_ALIVE(self, guard):
+ self._produce_guard_code(guard, [f"{self.arg_ref(guard)} is not None"])
+
+ def NN_MODULE_PARAM_NAMES(self, guard):
+ ref = self.arg_ref(guard)
+ value = self.get(guard.name)
+ t = type(value)
+ keys = {k for k, v in value.named_parameters()}
+
+ code = list()
+ code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
+ code.append(f"{{k for k, v in {ref}.named_parameters()}} == {keys!r}")
+
+ self._produce_guard_code(guard, code)
+
+ def ODICT_KEYS(self, guard):
+ """OrderedDict keys match"""
+ ref = self.arg_ref(guard)
+ value = self.get(guard.name)
+ t = type(value)
+
+ code = list()
+ code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
+ code.append(f"str({ref}.keys()) == {str(value.keys())!r}")
+
+ self._produce_guard_code(guard, code)
+
+ def OBJECT_MUTATION(self, guard: Guard):
+ mutation_guard.watch(self.get(guard.name), self.guarded_code)
+
+ def GRAD_MODE(self, guard: Guard):
+ """Guard on the initial grad state"""
+ assert guard.name == ""
+ assert guard.source is GuardSource.GLOBAL
+ code = None
+ if convert_frame.initial_grad_state:
+ code = "___is_grad_enabled()"
+ else:
+ code = "not ___is_grad_enabled()"
+ self._produce_guard_code(guard, [code])
+
+ def TENSOR_MATCH(self, guard: Guard):
+ if guard.is_nn_module():
+ self.ID_MATCH(guard)
+ else:
+ value = self.get(guard.name)
+ self.tensor_check_names.append(self.arg_ref(guard))
+ self.tensor_check_examples.append(value)
+
+ # Note: Guard code produced for tensor_match is a little different.
+ # We accumulate tensor names, then do a single install of `___check_tensors`.
+ # See _guards.cpp and TensorGuard for more information.
+ # TODO(voz): Add tensor matching code to export
+ # Note: this is a bit of a special case, and so does not use _produce_guard_code
+ guard.set_export_info(
+ "TENSOR_MATCH",
+ weakref.ref(type(value)),
+ None,
+ weakref.ref(value),
+ )
+
+ # A util that appends guarded code, or, in the case of export, adds data onto guards
+ def _produce_guard_code(self, guard, code_list, provided_guarded_object=None):
+ caller = currentframe().f_back
+ func_name = getframeinfo(caller)[2]
+ # We use func_name for export, so might as well get a nice defensive check out of it
+ assert func_name in dir(
+ self.__class__
+ ), f"_produce_guard_code must be called from inside GuardedCode. Called from {func_name}"
+
+ self.code.extend(code_list)
+
+ # Not all guards have names, some can be installed globally (see asserts on HAS_GRAD)
+ if provided_guarded_object is None:
+ name_valid = guard.name is not None and guard.name != ""
+
+ guarded_object = self.get(guard.name) if name_valid else None
+ else:
+ guarded_object = provided_guarded_object
+
+ guarded_object_type = (
+ weakref.ref(type(guarded_object)) if guarded_object is not None else None
+ )
+ obj_ref = None
+ if hasattr(guarded_object.__class__, "__weakref__"):
+ obj_ref = weakref.ref(guarded_object)
+
+ guard.set_export_info(
+ func_name,
+ guarded_object_type,
+ code_list,
+ obj_ref,
+ )
+
+
+@dataclasses.dataclass
+class GuardedCode:
+ code: types.CodeType
+ check_fn: Callable
+
+
+# NB: Naively, you'd expect this to only be a function that produces
+# the callable that consistutes the guard. However, there is some
+# delicate handling for invalidating this check function when the
+# locals/globals get invalidated, so there's some extra state
+# we have to hold in this manager class.
+#
+# TODO: this object has reference cycle with itself, via check_fn which
+# references back to CheckFunction via ___guarded_code in closure_vars.
+# Ideally, there shouldn't be any ref cycle so that guards are
+# promptly disposed of.
+class CheckFunctionManager:
+ def __init__(
+ self,
+ guards: Optional[Set[Guard]] = None,
+ f_locals: Optional[Dict] = None,
+ f_globals: Optional[Dict] = None,
+ ):
+ self.valid = True
+ self._weakrefs = []
+ self._seen_ids = set()
+
+ # Note: right overrides left
+ def combine_scopes(left, right):
+ if left is None:
+ return right
+
+ if right is None:
+ return left
+
+ return {**left, **right}
+
+ local_builder = GuardBuilder(
+ self.id_ref, combine_scopes(f_globals, f_locals), self, renames=True
+ )
+ global_builder = GuardBuilder(self.id_ref, f_globals, self, renames=False)
+ for guard in sorted(guards or [], key=Guard.sort_key):
+ if not config.guard_nn_modules and guard.is_nn_module():
+ continue
+ guard.create(local_builder, global_builder)
+ self.check_fn = self.compile_check_fn(local_builder, global_builder)
+ self._seen_ids.clear()
+
+ def compile_check_fn(self, local_builder, global_builder):
+ assert not (set(local_builder.argnames) & set(global_builder.argnames))
+ # see parallel handling of ".0" / "___implicit0" in _eval_frame.c
+ args = [a for a in local_builder.scope.keys() if a == "___implicit0"]
+ args += [a for a in local_builder.argnames if a != "___implicit0"]
+ args += ["**___kwargs_ignored"]
+ args = ",".join(args)
+
+ code_parts = (
+ ["___guarded_code.valid"] + local_builder.code + global_builder.code
+ )
+ # TODO(whc) maybe only the 'check_tensors' one is ambiguous? if so we can be less general..
+ verbose_code_parts = (
+ ["___guarded_code.valid"] + local_builder.code + global_builder.code
+ )
+
+ tensor_check_names = (
+ local_builder.tensor_check_names + global_builder.tensor_check_names
+ )
+ check_tensors_fn = None
+ check_tensors_verbose_fn = None
+ if tensor_check_names:
+ tensor_check_examples = (
+ local_builder.tensor_check_examples
+ + global_builder.tensor_check_examples
+ )
+ tensor_guards = TensorGuards(
+ *tensor_check_examples, dynamic_shapes=config.dynamic_shapes
+ )
+ check_tensors_fn = tensor_guards.check
+ check_tensors_verbose_fn = tensor_guards.check_verbose
+ code_parts.append(f"___check_tensors({', '.join(tensor_check_names)})")
+ verbose_args = ", ".join(
+ tensor_check_names + ["tensor_check_names=tensor_check_names"]
+ )
+ verbose_code_parts.append(f"___check_tensors_verbose({verbose_args})")
+
+ code = " and ".join(unique(code_parts))
+
+ closure_vars = collections.OrderedDict(
+ [
+ ("___guarded_code", self),
+ ("___check_tensors", check_tensors_fn),
+ ("___check_tensors_verbose", check_tensors_verbose_fn),
+ ("tensor_check_names", tensor_check_names),
+ ]
+ )
+ closure_vars.update(CLOSURE_VARS)
+ py_code = textwrap.dedent(
+ f"""
+ def ___make_guard_fn({','.join(closure_vars.keys())}):
+ return lambda {args}: {code}
+ """
+ )
+ if os.environ.get("TORCHDYNAMO_PRINT_GUARDS", None) == "1":
+ print("GUARDS", code)
+ set_guard_fail_hook(guard_fail_hook)
+ out = dict()
+ exec(py_code, global_builder.scope, out)
+ guard_fn = out["___make_guard_fn"](*closure_vars.values())
+ guard_fn.closure_vars = closure_vars
+ # TODO(whc) maybe '.code_parts' was only kept around for the guard callback? so we don't need both
+ guard_fn.code_parts = code_parts
+ guard_fn.verbose_code_parts = verbose_code_parts
+ guard_fn.global_scope = global_builder.scope
+ return guard_fn
+
+ def invalidate(self, ref):
+ # A weakref is no longer valid, self.check_fn should return false
+ self.valid = False
+
+ def id_ref(self, obj):
+ """add a weakref, return the id"""
+ try:
+ if id(obj) not in self._seen_ids:
+ self._weakrefs.append(weakref.ref(obj, self.invalidate))
+ self._seen_ids.add(id(obj))
+ except TypeError:
+ pass # cannot weakref bool object
+ return id(obj)
+
+
+def guard_fail_hook(
+ guard_fn: Callable, code: types.CodeType, f_locals: Dict[str, Any], last: bool
+):
+ """
+ called whenever a guard fails.
+ """
+ if not last:
+ return
+ scope = {rename_implicit(k): v for k, v in f_locals.items()}
+ scope.update(guard_fn.closure_vars)
+ reasons = []
+ for part in guard_fn.verbose_code_parts:
+ fail_reason = eval(part, guard_fn.global_scope, scope)
+ # TODO(whc) hacky for now as not every 'part' in guard_fn.verbose_code_parts
+ # is updated to return a string explaining the failure.
+ if isinstance(fail_reason, str):
+ reasons.append(fail_reason)
+ break
+ elif isinstance(fail_reason, bool) and not fail_reason:
+ reasons.append(part)
+ break
+ guard_failures[orig_code_map[code]].append(reasons)
+
+
+def guard_error_hook(
+ guard_fn: Callable, code: types.CodeType, f_locals: Dict[str, Any], last: bool
+):
+ print(
+ f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}"
+ )
+ print(" ", " and\n ".join(guard_fn.code_parts))
+
+
+set_guard_error_hook(guard_error_hook)
+
+
+def unique(seq):
+ seen = set()
+ for x in seq:
+ if x not in seen:
+ yield x
+ seen.add(x)
diff --git a/torch/_dynamo/logging.py b/torch/_dynamo/logging.py
new file mode 100644
index 0000000000000..750bb7f2f3f7d
--- /dev/null
+++ b/torch/_dynamo/logging.py
@@ -0,0 +1,87 @@
+import itertools
+import logging
+import os
+
+# logging level for dynamo generated graphs/bytecode/guards
+CODE = 15
+
+
+# Return all loggers that torchdynamo/torchinductor is responsible for
+def get_loggers():
+ return [
+ logging.getLogger("torchdynamo"),
+ logging.getLogger("torchinductor"),
+ ]
+
+
+# Set the level of all loggers that torchdynamo is responsible for
+def set_loggers_level(level):
+ for logger in get_loggers():
+ logger.setLevel(level)
+
+
+LOGGING_CONFIG = {
+ "version": 1,
+ "formatters": {
+ "torchdynamo_format": {
+ "format": "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s"
+ },
+ },
+ "handlers": {
+ "torchdynamo_console": {
+ "class": "logging.StreamHandler",
+ "level": "DEBUG",
+ "formatter": "torchdynamo_format",
+ "stream": "ext://sys.stderr",
+ },
+ },
+ "loggers": {
+ "torchdynamo": {
+ "level": "DEBUG",
+ "handlers": ["torchdynamo_console"],
+ "propagate": False,
+ },
+ "torchinductor": {
+ "level": "DEBUG",
+ "handlers": ["torchdynamo_console"],
+ "propagate": False,
+ },
+ },
+ "disable_existing_loggers": False,
+}
+
+
+# initialize torchdynamo loggers
+def init_logging(log_level, log_file_name=None):
+ if "PYTEST_CURRENT_TEST" not in os.environ:
+ logging.config.dictConfig(LOGGING_CONFIG)
+ if log_file_name is not None:
+ log_file = logging.FileHandler(log_file_name)
+ log_file.setLevel(log_level)
+ for logger in get_loggers():
+ logger.addHandler(log_file)
+
+ set_loggers_level(log_level)
+
+
+# Creates a logging function that logs a message with a step # prepended.
+# get_step_logger should be lazily called (i.e. at runtime, not at module-load time)
+# so that step numbers are initialized properly. e.g.:
+
+# @functools.lru_cache(None)
+# def _step_logger():
+# return get_step_logger(logging.getLogger(...))
+
+# def fn():
+# _step_logger()(logging.INFO, "msg")
+
+_step_counter = itertools.count(1)
+
+
+def get_step_logger(logger):
+ step = next(_step_counter)
+
+ def log(level, msg):
+ logger.log(level, f"Step {step}: {msg}")
+
+ return log
diff --git a/torch/_dynamo/mutation_guard.py b/torch/_dynamo/mutation_guard.py
new file mode 100644
index 0000000000000..8d1122a7ab60c
--- /dev/null
+++ b/torch/_dynamo/mutation_guard.py
@@ -0,0 +1,119 @@
+import functools
+import weakref
+
+import torch.nn
+from torch.nn import Module
+
+from .utils import ExactWeakKeyDictionary
+
+
+class MutationTracker:
+ db = ExactWeakKeyDictionary()
+
+ def __init__(self):
+ self.mutation_count = 0
+ self.watchers = []
+
+ def on_mutation(self, name):
+ self.mutation_count += 1
+ tmp = self.watchers
+ self.watchers = []
+ for ref in tmp:
+ guarded = ref()
+ if guarded is not None:
+ guarded.invalidate(ref)
+
+ def track(self, guarded_code):
+ self.watchers.append(weakref.ref(guarded_code))
+
+
+def watch(obj, guarded_code):
+ """invalidate guarded_code when obj is mutated"""
+ ensure_patched(type(obj))
+
+ if obj not in MutationTracker.db:
+ MutationTracker.db[obj] = MutationTracker()
+ tracker = MutationTracker.db[obj]
+ tracker.track(guarded_code)
+
+
+def ensure_patched(cls):
+ if getattr(cls, "___needs_mutation_patch", True):
+ cls.___needs_mutation_patch = False
+ original_setattr = cls.__setattr__
+
+ @functools.wraps(original_setattr)
+ def custom_setattr(self, key, value):
+ try:
+ MutationTracker.db[self].on_mutation(key)
+ except KeyError:
+ pass
+ return original_setattr(self, key, value)
+
+ cls.__setattr__ = custom_setattr
+
+
+class GenerationTracker:
+ generation = 0
+ dynamic_classes = ExactWeakKeyDictionary()
+ generation_values = ExactWeakKeyDictionary()
+
+ @classmethod
+ def tag(cls, obj):
+ cls.generation_values[obj] = cls.generation
+
+ @staticmethod
+ def mark_class_dynamic(cls):
+ assert issubclass(cls, torch.nn.Module)
+ GenerationTracker.dynamic_classes[cls] = True
+
+ @classmethod
+ def get_generation_value(cls, obj):
+ if obj not in cls.generation_values:
+ return -1
+ return cls.generation_values[obj]
+
+ @classmethod
+ def check(cls, obj):
+ return (
+ obj in cls.generation_values
+ and cls.generation_values[obj] == cls.generation
+ )
+
+
+def is_dynamic_nn_module(obj):
+ """Check for nn.Modules() created dynamically or mutated"""
+ if hasattr(obj, "torchdynamo_force_dynamic"):
+ return obj.torchdynamo_force_dynamic
+ dyn = GenerationTracker.dynamic_classes.get(type(obj)) or GenerationTracker.check(
+ obj
+ )
+ return dyn
+
+
+def install_generation_tagging_init():
+ """
+ Monkey patch torch.nn.Module.__init__ and torch.nn.Module.__setstate__
+ so we can detect nn.Module instances created dynamically inside forward methods.
+ """
+
+ if getattr(Module, "___needs_generation_tag_patch", True):
+ init = Module.__init__
+
+ def patched_init(self, *args, **kwargs):
+ init(self, *args, **kwargs)
+ GenerationTracker.tag(self)
+
+ Module.__init__ = patched_init
+
+ setstate = Module.__setstate__
+
+ def patched_setstate(self, state):
+ setstate(self, state)
+ GenerationTracker.tag(self)
+
+ Module.__setstate__ = patched_setstate
+
+ Module.___needs_generation_tag_patch = False
+
+ GenerationTracker.generation += 1
diff --git a/torch/_dynamo/optimizations/__init__.py b/torch/_dynamo/optimizations/__init__.py
new file mode 100644
index 0000000000000..9117517b8bf41
--- /dev/null
+++ b/torch/_dynamo/optimizations/__init__.py
@@ -0,0 +1,6 @@
+from .backends import BACKENDS
+from .training import create_aot_backends
+
+create_aot_backends()
+
+__all__ = ["BACKENDS"]
diff --git a/torch/_dynamo/optimizations/analysis.py b/torch/_dynamo/optimizations/analysis.py
new file mode 100644
index 0000000000000..ccd175bfdae32
--- /dev/null
+++ b/torch/_dynamo/optimizations/analysis.py
@@ -0,0 +1,136 @@
+import copy
+import functools
+import itertools
+import operator
+
+import torch
+from torch.fx.node import map_aggregate
+from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp
+from torch.multiprocessing.reductions import StorageWeakRef
+from torch.utils._pytree import tree_map
+
+from .. import config
+from ..utils import fake_tensors_available
+
+if fake_tensors_available:
+ from torch._subclasses import FakeTensorMode # noqa: F401
+
+ from ..utils import deepcopy_to_fake_tensor, wrap_to_fake_tensor
+
+
+class ShapeAliasingAndMutationProp(ShapeProp):
+ def __init__(self, *args, **kwargs):
+ super(ShapeAliasingAndMutationProp, self).__init__(*args, **kwargs)
+ self.input_alias_groups = set()
+ self.storage_to_alias_group = dict()
+ self.make_alias_group = itertools.count(1)
+
+ def tensor_alias_group(self, value: torch.Tensor):
+ """Assign a unique identifier to the storage of a given tensor"""
+ storage = StorageWeakRef(value.storage())
+ alias_group = self.storage_to_alias_group.get(storage)
+ if alias_group is None:
+ alias_group = next(self.make_alias_group)
+ self.storage_to_alias_group[storage] = alias_group
+ return alias_group
+
+ def placeholder(self, target, args, kwargs):
+ value = super().placeholder(target, args, kwargs)
+ assert isinstance(value, torch.Tensor)
+ self.input_alias_groups.add(self.tensor_alias_group(value))
+ return value
+
+ def run_node(self, n: torch.fx.Node):
+ args, kwargs = self.fetch_args_kwargs_from_env(n)
+ tensor_args = self.extract_tensors((args, kwargs))
+
+ input_versions1 = [obj._version for obj in tensor_args]
+ result = getattr(self, n.op)(n.target, args, kwargs)
+ input_versions2 = [obj._version for obj in tensor_args]
+
+ n.meta["type"] = type(result)
+ n.meta["alias_groups"] = {
+ self.tensor_alias_group(obj) for obj in self.extract_tensors(result)
+ }
+ n.meta["mutates_alias_groups"] = {
+ self.tensor_alias_group(tensor)
+ for tensor, v1, v2 in zip(tensor_args, input_versions1, input_versions2)
+ if v1 != v2
+ }
+ # Partial mutation refers to the mutation caused by getitem that can
+ # potentially result in changing only a slice of the original tensor
+ n.meta["partial_mutation"] = False
+
+ def visit_arg(arg: torch.fx.Node):
+ if (
+ arg.op == "call_function" and arg.target == operator.getitem
+ ) or arg.meta["partial_mutation"]:
+ if bool(n.meta["mutates_alias_groups"] & arg.meta["alias_groups"]):
+ n.meta["partial_mutation"] = True
+
+ torch.fx.map_arg((n.args, n.kwargs), visit_arg)
+ n.meta["is_input_alias"] = bool(
+ self.input_alias_groups & n.meta["alias_groups"]
+ )
+ n.meta["is_input_mutation"] = bool(
+ self.input_alias_groups & n.meta["mutates_alias_groups"]
+ )
+ n.meta["is_mutation"] = bool(n.meta["mutates_alias_groups"])
+ n.meta["tensor_metas"] = [
+ _extract_tensor_metadata(obj) for obj in self.extract_tensors(result)
+ ]
+ tensors = self.extract_tensors(result)
+ if tensors:
+ n.meta["device"] = tensors[0].device
+ n.meta["dtype"] = tensors[0].dtype
+
+ return result
+
+ @staticmethod
+ def extract_tensors(result):
+ """Return a flat list of tensors found in some nested data structure"""
+ seen = set()
+ tensors = []
+
+ def visit(obj):
+ if isinstance(obj, torch.Tensor) and id(obj) not in seen:
+ seen.add(id(obj))
+ tensors.append(obj)
+
+ map_aggregate(result, visit)
+ return tensors
+
+ def run(self, *args):
+ try:
+ super().run(*args)
+ finally:
+ # cleanup
+ self.env.clear()
+
+
+def has_mutation(gm, example_inputs, inputs_only=False):
+ """Check if the graph module has any form of mutation. If inputs_only is
+ true, we only check for mutation of inputs"""
+ # TODO - moco gives bad accuracy with Aliasing. gm is getting mutated in a bad way.
+
+ if fake_tensors_available and config.fake_tensor_propagation:
+ with FakeTensorMode() as fake_mode:
+ pass
+ fake_wrapper = functools.partial(wrap_to_fake_tensor, fake_mode=fake_mode)
+ example_inputs = tree_map(fake_wrapper, example_inputs)
+ new_gm = deepcopy_to_fake_tensor(gm, fake_mode)
+ with fake_mode.restore() if hasattr(fake_mode, "restore") else fake_mode:
+ ShapeAliasingAndMutationProp(new_gm).run(*example_inputs)
+ else:
+ new_gm = copy.deepcopy(gm)
+ example_inputs = copy.deepcopy(example_inputs)
+ ShapeAliasingAndMutationProp(new_gm).run(*example_inputs)
+
+ for node in new_gm.graph.nodes:
+ if node.meta["is_mutation"] or node.meta["is_input_mutation"]:
+ if inputs_only:
+ if node.meta["is_input_alias"]:
+ return True
+ else:
+ return True
+ return False
diff --git a/torch/_dynamo/optimizations/backends.py b/torch/_dynamo/optimizations/backends.py
new file mode 100644
index 0000000000000..1ec5c774de11e
--- /dev/null
+++ b/torch/_dynamo/optimizations/backends.py
@@ -0,0 +1,820 @@
+import copy
+import functools
+import io
+import logging
+import os
+import subprocess
+import tempfile
+
+import numpy as np
+
+import torch
+
+from ..utils import identity
+from .subgraph import SubGraph
+
+log = logging.getLogger(__name__)
+BACKENDS = dict()
+_NP_DTYPE = {
+ torch.float16: np.float16,
+ torch.float32: np.float32,
+ torch.float64: np.float64,
+ torch.uint8: np.uint8,
+ torch.int8: np.int8,
+ torch.int16: np.int16,
+ torch.int32: np.int32,
+ torch.int64: np.longlong,
+ torch.bool: np.bool_,
+}
+
+
+def register_backend(fn):
+ @functools.wraps(fn)
+ def inner(gm, example_inputs, **kwargs):
+ return fn(gm, example_inputs, **kwargs)
+
+ BACKENDS[fn.__name__] = inner
+ return inner
+
+
+def create_backend(fn):
+ @functools.wraps(fn)
+ def inner(model, example_inputs=None, **kwargs):
+ if model is None:
+ return None
+
+ if not isinstance(model, SubGraph):
+ with tempfile.TemporaryDirectory() as tmp:
+ return inner(SubGraph(model, example_inputs, tmp), **kwargs)
+ else:
+ assert example_inputs is None
+
+ try:
+ return fn(model, **kwargs)
+ except KeyboardInterrupt:
+ raise
+ except Exception:
+ log.exception(f"{fn.__name__} error")
+ return None
+
+ BACKENDS[fn.__name__] = inner
+ return inner
+
+
+@create_backend
+def eager(subgraph):
+ return subgraph.model
+
+
+@create_backend
+def ts(subgraph):
+ return subgraph.scripted
+
+
+def reload_jit_model(subgraph, opt_fn=identity):
+ tmp = io.BytesIO()
+ torch.jit.save(subgraph.scripted, tmp)
+ tmp.seek(0)
+ model = torch.jit.load(tmp)
+ model = opt_fn(model)
+ # populate cache
+ for _ in range(3):
+ model(*subgraph.example_inputs)
+ return model
+
+
+def reload_jit_model_ofi(subgraph):
+ return reload_jit_model(subgraph, torch.jit.optimize_for_inference)
+
+
+@create_backend
+def nnc(subgraph):
+ with torch.jit.fuser("fuser1"):
+ return reload_jit_model(subgraph)
+
+
+@create_backend
+def nnc_ofi(subgraph):
+ with torch.jit.fuser("fuser1"):
+ return reload_jit_model_ofi(subgraph)
+
+
+@create_backend
+def nvfuser(subgraph):
+ with torch.jit.fuser("fuser2"):
+ return reload_jit_model(subgraph)
+
+
+@create_backend
+def nvfuser_ofi(subgraph):
+ with torch.jit.fuser("fuser2"):
+ return reload_jit_model_ofi(subgraph)
+
+
+@create_backend
+def onednn(subgraph):
+ with torch.jit.fuser("fuser3"):
+ return reload_jit_model(subgraph)
+
+
+@create_backend
+def ofi(subgraph):
+ return torch.jit.optimize_for_inference(subgraph.scripted)
+
+
+@create_backend
+def static_runtime(subgraph):
+ scripted = subgraph.scripted
+ if hasattr(scripted, "_c"):
+ static_module = torch._C._jit_to_static_module(scripted._c)
+ else:
+ static_module = torch._C._jit_to_static_module(scripted.graph)
+ return subgraph.wrap_returns(static_module)
+
+
+def onnxrt_common(subgraph, provider, onnx_filename=None):
+ import onnxruntime
+
+ assert provider in onnxruntime.get_available_providers()
+ session = onnxruntime.InferenceSession(
+ onnx_filename or subgraph.onnx_filename, providers=[provider]
+ )
+ input_names = subgraph.input_names
+ output_names = subgraph.output_names
+ create_outputs = subgraph.empty_outputs_factory()
+ is_cpu = subgraph.is_cpu
+
+ def _call(*args):
+ binding = session.io_binding()
+ args = [a.contiguous() for a in args]
+ for name, value in zip(input_names, args):
+ dev = value.device
+ binding.bind_input(
+ name,
+ dev.type,
+ dev.index or 0,
+ _NP_DTYPE[value.dtype],
+ value.size(),
+ value.data_ptr(),
+ )
+ outputs = create_outputs()
+ for name, value in zip(output_names, outputs):
+ dev = value.device
+ binding.bind_output(
+ name,
+ dev.type,
+ dev.index or 0,
+ _NP_DTYPE[value.dtype],
+ value.size(),
+ value.data_ptr(),
+ )
+ session.run_with_iobinding(binding)
+ if is_cpu:
+ binding.copy_outputs_to_cpu()
+ return outputs
+
+ return subgraph.wrap_returns(_call)
+
+
+@create_backend
+def onnxrt_cpu(subgraph):
+ return onnxrt_common(subgraph, provider="CPUExecutionProvider")
+
+
+@create_backend
+def onnxrt_cuda(subgraph):
+ return onnxrt_common(subgraph, provider="CUDAExecutionProvider")
+
+
+@create_backend
+def onnx2tensorrt(subgraph):
+ if subgraph.will_tensorrt_barf():
+ # TensorRT fails violently with an abort() on this
+ return None
+
+ return onnxrt_common(subgraph, provider="TensorrtExecutionProvider")
+
+
+@create_backend
+def onnxrt_cpu_numpy(subgraph, provider="CPUExecutionProvider"):
+ """Alternate version that integrates via numpy"""
+ import onnxruntime
+
+ assert provider in onnxruntime.get_available_providers()
+ ort_session = onnxruntime.InferenceSession(
+ subgraph.onnx_filename, providers=[provider]
+ )
+
+ def to_numpy(x):
+ try:
+ return x.numpy()
+ except RuntimeError:
+ return x.detach().numpy()
+
+ def _call(*args):
+ res = ort_session.run(
+ None, {f"i{i}": to_numpy(arg) for i, arg in enumerate(args)}
+ )
+ res = [torch.from_numpy(x) for x in res]
+ return res
+
+ return subgraph.wrap_returns(_call)
+
+
+@create_backend
+def onnxrt(subgraph):
+ if subgraph.is_cuda:
+ return onnxrt_cuda(subgraph)
+ else:
+ return onnxrt_cpu(subgraph)
+
+
+@functools.lru_cache(None)
+def _init_tensorflow():
+ import tensorflow as tf
+
+ # prevent tensorflow from eating all the GPU memory
+ gpus = tf.config.list_physical_devices("GPU")
+ for gpu in gpus:
+ tf.config.experimental.set_memory_growth(gpu, True)
+ return tf
+
+
+@create_backend
+def onnx2tf(subgraph):
+ import onnx
+ from onnx_tf.backend import prepare
+
+ tf = _init_tensorflow()
+ filename = subgraph.filename("tensorflow")
+ input_names = subgraph.input_names
+ output_names = subgraph.output_names
+ device = "/CPU:0" if subgraph.is_cpu else f"/GPU:{subgraph.device_index}"
+ with tf.device(device):
+ if not os.path.exists(filename):
+ prepare(onnx.load(subgraph.onnx_filename)).export_graph(filename)
+ tf_module = tf.saved_model.load(filename)
+ tf_module = tf.function(tf_module, jit_compile=True)
+
+ def run(*args):
+ args = [a.contiguous() for a in args]
+ with tf.device(device):
+ outs = tf_module(
+ **{
+ name: tf.experimental.dlpack.from_dlpack(
+ torch.utils.dlpack.to_dlpack(args[idx])
+ )
+ for idx, name in enumerate(input_names)
+ }
+ )
+ return [
+ torch.utils.dlpack.from_dlpack(
+ tf.experimental.dlpack.to_dlpack(outs[name])
+ )
+ for name in output_names
+ ]
+
+ return subgraph.wrap_returns(run)
+
+
+@create_backend
+def taso(subgraph):
+ taso_filename = subgraph.filename("taso")
+ subprocess.check_call(
+ [
+ os.path.expanduser("~/conda/envs/taso/bin/python"),
+ "-c",
+ "import taso,onnx; onnx.save(taso.export_onnx(taso.optimize("
+ f"taso.load_onnx('{subgraph.onnx_filename}'))), '{taso_filename}')",
+ ]
+ )
+ return onnxrt_common(
+ subgraph, provider="CUDAExecutionProvider", onnx_filename=taso_filename
+ )
+
+
+@create_backend
+def ipex(subgraph, **kwargs):
+ import intel_extension_for_pytorch as ipex
+
+ inputs = subgraph.example_inputs
+ model = subgraph.model
+ with torch.no_grad():
+ model.eval()
+ if kwargs["datatype"] == "bf16":
+ model = ipex.optimize(model, dtype=torch.bfloat16)
+ else:
+ model = ipex.optimize(model, dtype=torch.float32)
+ try:
+ traced_model = torch.jit.trace(model, inputs).eval()
+ traced_model = torch.jit.freeze(traced_model)
+ return traced_model
+ except Exception:
+ log.warning("JIT trace failed during the 'ipex' optimize process.")
+ return model
+
+
+def _raise_timeout(signum, frame):
+ raise TimeoutError()
+
+
+@create_backend
+def fx2trt(subgraph, **kwargs):
+ if subgraph.will_tensorrt_barf():
+ # TensorRT fails violently with an abort() on this
+ return None
+
+ from torch_tensorrt.fx.fx2trt import InputTensorSpec, TRTInterpreter
+ from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem
+ from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
+ from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
+ from torch_tensorrt.fx.trt_module import TRTModule
+ from torch_tensorrt.fx.utils import LowerPrecision
+
+ from .normalize import normalize_ir
+
+ try:
+ model = subgraph.model
+ inputs = subgraph.example_inputs
+ # normalize
+ model = normalize_ir(model, inputs)
+ # pass rewrite
+ model = transform_setitem(model, inputs)
+ acc_model = acc_tracer.trace(model, inputs)
+ # Split out unsupported ops
+ splitter_setting = TRTSplitterSetting()
+ splitter_setting.use_implicit_batch_dim = False
+ splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
+ splitter.node_support_preview()
+ split_mod = splitter()
+ num_piece = 0
+ for name, _ in split_mod.named_children():
+ print(f"graph is split into {name}")
+ num_piece += 1
+
+ # if the graph module is split into pieces larger than 8, we consider its perf
+ # is not good and fall back to non-TRT
+ if num_piece > 8:
+ print(
+ f"The graph module is split into {num_piece} which is large than the \
+ threshold=8. Fall back to non-TRT module."
+ )
+ return None
+
+ if "fp16_mode" in kwargs and kwargs["fp16_mode"]:
+ precision = LowerPrecision.FP16
+ else:
+ precision = LowerPrecision.FP32
+
+ def get_submod_inputs(mod, submod, inputs):
+ acc_inputs = None
+
+ def get_input(self, inputs):
+ nonlocal acc_inputs
+ acc_inputs = inputs
+
+ handle = submod.register_forward_pre_hook(get_input)
+ mod(*inputs)
+ handle.remove()
+ return acc_inputs
+
+ for name, _ in split_mod.named_children():
+ if "_run_on_acc" in name:
+ submod = getattr(split_mod, name)
+ # print("acc=",submod.code)
+ # Get submodule inputs for fx2trt
+ acc_inputs = get_submod_inputs(split_mod, submod, inputs)
+
+ # fx2trt replacement
+ interp = TRTInterpreter(
+ submod,
+ InputTensorSpec.from_tensors(acc_inputs),
+ explicit_batch_dimension=True,
+ )
+ r = interp.run(
+ max_workspace_size=20 << 30,
+ lower_precision=precision,
+ # profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile
+ )
+ # For profile
+ # from fx2trt_oss.fx.tools.trt_profiler_sorted import profile_trt_module
+ # profile_trt_module("", trt_mod, acc_inputs)
+ trt_mod = TRTModule(*r)
+
+ setattr(split_mod, name, trt_mod)
+ else:
+ submod = getattr(split_mod, name)
+ # print("gpu=",submod.code)
+ return subgraph.wrap_returns(split_mod)
+ except Exception:
+ log.exception("FX2TRT conversion error")
+ return None
+
+
+@create_backend
+def torch2trt(subgraph):
+ if subgraph.will_tensorrt_barf():
+ # TensorRT fails violently with an abort() on this
+ return None
+
+ from torch2trt import torch2trt
+
+ inputs = subgraph.example_inputs
+ trt_mod = torch2trt(
+ subgraph.model,
+ inputs,
+ max_batch_size=len(inputs[0]),
+ strict_type_constraints=True,
+ )
+ return subgraph.wrap_returns(trt_mod)
+
+
+@create_backend
+def tensorrt(subgraph):
+ if subgraph.will_tensorrt_barf():
+ # TensorRT fails violently with an abort() on this
+ return None
+
+ model = onnx2tensorrt(subgraph)
+ if model is None:
+ model = torch2trt(subgraph)
+ return model
+
+
+@create_backend
+def onnx2tensorrt_alt(subgraph):
+ if subgraph.will_tensorrt_barf():
+ # TensorRT fails violently with an abort() on this
+ return None
+
+ import tensorrt as trt
+
+ from torch.fx.experimental.fx2trt.trt_module import TRTModule
+
+ inputs = subgraph.example_inputs
+
+ logger = trt.Logger(trt.Logger.ERROR)
+ builder = trt.Builder(logger)
+ config = builder.create_builder_config()
+ assert isinstance(inputs, (list, tuple))
+ inputs = tuple(inputs)
+ input_names = subgraph.input_names
+ output_names = subgraph.output_names
+ network = builder.create_network(
+ 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+ )
+ parser = trt.OnnxParser(network, logger)
+ success = parser.parse(open(subgraph.onnx_filename, "rb").read())
+ for idx in range(parser.num_errors):
+ print(parser.get_error(idx))
+ assert success
+
+ config.max_workspace_size = 1 << 25
+ config.set_flag(trt.BuilderFlag.STRICT_TYPES)
+ builder.max_batch_size = len(inputs[0])
+
+ engine = builder.build_engine(network, config)
+ assert engine
+
+ trt_mod = TRTModule(engine, input_names, output_names)
+ return subgraph.wrap_returns(trt_mod)
+
+
+@create_backend
+def cudagraphs(subgraph):
+ model = subgraph.model
+ inputs = subgraph.example_inputs
+ assert subgraph.is_cuda
+ return subgraph.wrap_returns(cudagraphs_inner(model, inputs))
+
+
+@create_backend
+def cudagraphs_ts(subgraph):
+ assert subgraph.is_cuda
+ model = subgraph.scripted
+ inputs = subgraph.example_inputs
+
+ # warmup
+ for _ in range(3):
+ model(*inputs)
+
+ return subgraph.wrap_returns(cudagraphs_inner(model, inputs))
+
+
+@create_backend
+def cudagraphs_ts_ofi(subgraph):
+ assert subgraph.is_cuda
+ model = torch.jit.optimize_for_inference(torch.jit.freeze(subgraph.scripted))
+ inputs = subgraph.example_inputs
+
+ # warmup
+ for _ in range(3):
+ model(*inputs)
+
+ return subgraph.wrap_returns(cudagraphs_inner(model, inputs))
+
+
+def cudagraphs_inner(model, inputs, copy_outputs=True):
+ assert isinstance(inputs, (list, tuple))
+ static_inputs = [torch.zeros_like(x) for x in inputs]
+
+ # warmup
+ torch.cuda.synchronize()
+ stream = torch.cuda.Stream()
+ stream.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(stream):
+ model(*inputs)
+ stream.synchronize()
+ torch.cuda.current_stream().wait_stream(stream)
+ torch.cuda.synchronize()
+
+ # record
+ graph = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(graph, stream=stream):
+ static_outputs = model(*static_inputs)
+ if not isinstance(static_outputs, (list, tuple)):
+ static_outputs = (static_outputs,)
+
+ def run(*new_inputs):
+ assert len(static_inputs) == len(new_inputs)
+ for dst, src in zip(static_inputs, new_inputs):
+ dst.copy_(src)
+ graph.replay()
+ if copy_outputs:
+ return [x.clone() for x in static_outputs]
+ else:
+ return static_outputs
+
+ return run
+
+
+@create_backend
+def aot_autograd(subgraph, **kwargs):
+ def _wrapped_bw_compiler(*args, **kwargs):
+ # stop TorchDynamo from trying to compile our generated backwards pass
+ return disable(bw_compiler(*args, **kwargs))
+
+ bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
+ kwargs["bw_compiler"] = _wrapped_bw_compiler
+
+ from functorch.compile import aot_module_simplified
+
+ from .. import disable
+
+ return aot_module_simplified(subgraph.model, **kwargs)
+
+
+def tvm_compile(jit_mod, example_inputs, log_file=None, **kwargs):
+ if jit_mod is None:
+ return None
+ try:
+ return tvm_compile_inner(jit_mod, example_inputs, None, log_file, **kwargs)
+ except Exception as e:
+ if log_file and os.path.exists(log_file):
+ os.unlink(log_file)
+ if isinstance(e, KeyboardInterrupt):
+ raise
+ log.exception("tvm error")
+ return None
+
+
+@create_backend
+def tvm(subgraph):
+ return subgraph.wrap_returns(
+ tvm_compile_inner(
+ subgraph.scripted,
+ subgraph.example_inputs,
+ tuning_option=None,
+ cuda=subgraph.is_cuda,
+ )
+ )
+
+
+@create_backend
+def ansor(subgraph):
+ """
+ WARNING: this backend takes hours or days to train and
+ often produces a slower result than the default schedule.
+ """
+ return subgraph.wrap_returns(
+ tvm_compile_inner(
+ subgraph.scripted,
+ subgraph.example_inputs,
+ tuning_option="auto_scheduler",
+ log_file=subgraph.filename("ansor"),
+ cuda=subgraph.is_cuda,
+ )
+ )
+
+
+@create_backend
+def tvm_meta_schedule(subgraph):
+ return subgraph.wrap_returns(
+ tvm_compile_inner(
+ subgraph.scripted,
+ subgraph.example_inputs,
+ tuning_option="meta_schedule",
+ trials=20000,
+ cuda=subgraph.is_cuda,
+ )
+ )
+
+
+@functools.lru_cache(None)
+def llvm_target():
+ if "avx512" in open("/proc/cpuinfo").read():
+ return "llvm -mcpu=skylake-avx512"
+ return "llvm -mcpu=core-avx2"
+
+
+def tvm_compile_inner(
+ jit_mod, example_inputs, tuning_option=None, log_file=None, trials=20000, cuda=False
+):
+ try:
+ import tvm
+ from tvm import relay
+ from tvm.contrib import graph_executor
+
+ shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
+ mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
+ if cuda:
+ dev = tvm.cuda(0)
+ target = tvm.target.cuda()
+ else:
+ dev = tvm.cpu(0)
+ target = tvm.target.Target(llvm_target())
+
+ if tuning_option == "auto_scheduler":
+ from tvm import auto_scheduler
+
+ if log_file is None:
+ log_file = tempfile.NamedTemporaryFile()
+ if not os.path.exists(log_file):
+ tasks, task_weights = auto_scheduler.extract_tasks(
+ mod["main"], params, target
+ )
+ for task in tasks:
+ print(task.compute_dag)
+ else:
+ print("No tasks")
+ if len(tasks) != 0:
+ tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
+ if not os.path.exists(log_file):
+ assert trials > 0
+ tune_option = auto_scheduler.TuningOptions(
+ num_measure_trials=trials,
+ measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+ early_stopping=2000,
+ )
+ try:
+ tuner.tune(tune_option)
+ except Exception:
+ if os.path.exists(log_file):
+ os.unlink(log_file)
+ raise
+
+ with auto_scheduler.ApplyHistoryBest(log_file):
+ with tvm.transform.PassContext(
+ opt_level=3, config={"relay.backend.use_auto_scheduler": True}
+ ):
+ lib = relay.build(mod, target=target, params=params)
+ elif tuning_option == "meta_schedule":
+ from os import path as osp
+
+ from tvm.meta_schedule import TuneConfig
+ from tvm.meta_schedule.database import JSONDatabase
+ from tvm.meta_schedule.tune import tune_relay
+
+ with tempfile.TemporaryDirectory() as work_dir:
+ if log_file is not None:
+ assert osp.isdir(
+ log_file
+ ), "TVM's meta_schedule requires a directory for storing log files."
+ work_dir = log_file
+ lib: tvm.runtime.Module = tune_relay(
+ mod=mod,
+ params=params,
+ target=target,
+ config=TuneConfig(
+ strategy="evolutionary",
+ num_trials_per_iter=64,
+ max_trials_per_task=trials,
+ max_trials_global=trials,
+ ),
+ work_dir=work_dir,
+ database=JSONDatabase(
+ osp.join(work_dir, "workload.json"),
+ osp.join(work_dir, "records.json"),
+ ),
+ )
+ elif tuning_option is None:
+ # no autotuning (for debugging)
+ with tvm.transform.PassContext(opt_level=10):
+ lib = relay.build(mod, target=target, params=params)
+ else:
+ raise NotImplementedError(
+ "This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. "
+ "There are three available options including None, auto_scheduler and meta_schedule."
+ )
+
+ m = graph_executor.GraphModule(lib["default"](dev))
+
+ def to_torch_tensor(nd_tensor):
+ """A helper function to transfer a NDArray to torch.tensor."""
+ if nd_tensor.dtype == "bool":
+ # DLPack does not support boolean so it can't be handled by
+ # torch.utils.dlpack.from_pack. Workaround by going through
+ # numpy, although this brings additional data copy overhead.
+ return torch.from_numpy(nd_tensor.numpy())
+ return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())
+
+ def exec_tvm(*args):
+ args = [a.contiguous() for a in args]
+ for idx, arg in enumerate(args, 0):
+ if arg.dim() != 0:
+ if arg.requires_grad:
+ arg = arg.detach()
+ m.set_input(
+ f"inp_{idx}",
+ tvm.nd.array(arg.numpy(), dev),
+ )
+ m.run()
+ return [
+ to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())
+ ]
+
+ return exec_tvm
+ except Exception:
+ log.exception("tvm error")
+ return jit_mod # explicit fall back to eager
+
+
+@functools.lru_cache(None)
+def _init_ltc():
+ try:
+ import torch._lazy.extract_compiled_graph
+ from torch._lazy.ts_backend import init as init_ts_backend
+
+ # hopefully changing this line to sth like _ltc_init_xla_backend in future
+ # will enable XLA
+ init_ts_backend()
+
+ return torch._lazy
+ except ModuleNotFoundError as e:
+ print(f"ltc backend fails. Can not import {e.name}")
+ raise
+
+
+def ltc_reuse_graph(gm: torch.fx.GraphModule, example_inputs):
+ ltc = _init_ltc()
+ return ltc.extract_compiled_graph.extract_compiled_graph(gm, example_inputs)
+
+
+def ltc_trivial(gm: torch.fx.GraphModule, example_inputs):
+ ltc = _init_ltc()
+ lazy_model = copy.deepcopy(gm).to(device="lazy")
+ ltc.extract_compiled_graph.force_lazy_device(lazy_model)
+
+ def ltc_model(*inputs):
+ orig_device = inputs[0].device if len(inputs) > 0 else "cuda"
+ lazy_inputs = tuple(inp.to(device="lazy") for inp in inputs)
+
+ lazy_out = lazy_model(*lazy_inputs)
+ out = tuple(out.to(device=orig_device) for out in lazy_out)
+ return out
+
+ return ltc_model
+
+
+def ipex_fp32(gm: torch.fx.GraphModule, example_inputs):
+ kwargs_ipex = {"datatype": "fp32"}
+ return BACKENDS["ipex"](gm, example_inputs, **kwargs_ipex)
+
+
+def ipex_bf16(gm: torch.fx.GraphModule, example_inputs):
+ kwargs_ipex = {"datatype": "bf16"}
+ return BACKENDS["ipex"](gm, example_inputs, **kwargs_ipex)
+
+
+def fx2trt_compiler_fp16(gm: torch.fx.GraphModule, example_inputs):
+ kwargs_fx2trt = {"fp16_mode": True}
+ trt_compiled = BACKENDS["fx2trt"](gm, example_inputs, **kwargs_fx2trt)
+ if trt_compiled is not None:
+ return trt_compiled
+ else:
+ print(
+ "FX2TRT conversion failed on the subgraph. Return GraphModule forward instead"
+ )
+ return gm.forward
+
+
+def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs):
+ kwargs_fx2trt = {"fp16_mode": False}
+ trt_compiled = BACKENDS["fx2trt"](gm, example_inputs, **kwargs_fx2trt)
+ if trt_compiled is not None:
+ return trt_compiled
+ else:
+ print(
+ "FX2TRT conversion failed on the subgraph. Return GraphModule forward instead"
+ )
+ return gm.forward
diff --git a/torch/_dynamo/optimizations/distributed.py b/torch/_dynamo/optimizations/distributed.py
new file mode 100644
index 0000000000000..5948f9f03b796
--- /dev/null
+++ b/torch/_dynamo/optimizations/distributed.py
@@ -0,0 +1,183 @@
+from typing import Any, List
+
+import torch
+import torch.fx.traceback as fx_traceback
+from torch import fx
+from torch.fx.node import Node
+
+
+def args_str(args):
+ # a debug helper
+ if torch.is_tensor(args):
+ return f"T[{args.shape}]"
+ elif isinstance(args, tuple):
+ return f"tuple({', '.join([args_str(x) for x in args])})"
+ elif isinstance(args, list):
+ return f"list({', '.join([args_str(x) for x in args])})"
+ else:
+ return str(args)
+
+
+class DDPOptimizer:
+ def __init__(
+ self,
+ bucket_bytes_cap: int,
+ parameters_to_ignore: List[str],
+ backend_compile_fn,
+ debug=False,
+ ):
+ self.bucket_bytes_cap = bucket_bytes_cap
+ self.parameters_to_ignore = parameters_to_ignore
+ self.backend_compile_fn = backend_compile_fn
+ self.debug = debug
+
+ def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
+ """
+ TODO:
+ - handle params_and_buffers_to_ignore
+ - handle kwargs
+ """
+
+ # 1: compute the partition map according to DDP bucket logic
+ bucket_bytes = 0
+ bucket_actual_sizes = []
+ node_splits = [[]]
+ for node in reversed(gm.graph.nodes):
+ if bucket_bytes >= self.bucket_bytes_cap:
+ bucket_actual_sizes.insert(0, bucket_bytes)
+ bucket_bytes = 0
+ node_splits.insert(0, [])
+
+ if node.op == "output" or node.op == "placeholder":
+ continue
+
+ elif node.op == "call_module":
+ target = gm.get_submodule(node.target)
+ params_size_b = sum(
+ [
+ p.storage().nbytes()
+ for p in target.parameters()
+ if p.requires_grad
+ ]
+ )
+ bucket_bytes += params_size_b
+ # print(f"accumulated {params_size_b} b from {node}")
+ else:
+ # TODO(whc) confirm this:
+ # (e.g. call_method, call_function aren't expected to 'have' parameters)
+ pass
+
+ node_splits[0].append(node)
+
+ if len(node_splits) == 1:
+ if self.debug:
+ print(
+ "DDPOptimizer did not split graphs."
+ f" Accumulated {bucket_bytes} bytes, and bucket cap is {self.bucket_bytes_cap}"
+ )
+ return self.backend_compile_fn(gm, example_inputs)
+
+ if len(bucket_actual_sizes) < len(node_splits):
+ bucket_actual_sizes.insert(0, bucket_bytes)
+
+ if self.debug:
+ print(
+ f"DDPOptimizer used bucket cap {self.bucket_bytes_cap}"
+ f" and split graphs into parameter sizes {', '.join([str(b) for b in bucket_actual_sizes])}"
+ )
+
+ # 2: partition the graphmodule according to bucket capacity
+ partition_map = {}
+ for p, nodes in enumerate(node_splits):
+ for node in nodes:
+ partition_map[node] = p
+
+ split_gm = fx.passes.split_module.split_module(
+ gm, None, lambda node: partition_map[node]
+ )
+ if self.debug:
+ with open("debug_ddp_optimizer.log", "w") as dump_file:
+ dump_file.write("---orig graph---")
+ dump_file.write(str(gm.graph))
+ dump_file.write("\n---split graph---")
+ dump_file.write(str(split_gm.graph))
+
+ # 3: compile each of the partitioned submodules using the user-provided compiler
+ class SubmodCompiler(torch.fx.interpreter.Interpreter):
+ def __init__(self, module, compiler, debug=False):
+ super().__init__(module)
+ self.compiler = compiler
+ self.debug = debug
+
+ def compile_submod(self, submod, args, kwargs):
+ """
+ Compile the submodule,
+ using a wrapper to make sure its output is always a tuple,
+ which is required by AotAutograd based compilers
+ """
+ assert len(kwargs) == 0, "We assume only args for these modules"
+
+ class WrapperModule(torch.nn.Module):
+ def __init__(self, compiled_submod, unwrap_singleton_tuple):
+ super().__init__()
+ self.compiled_submod = compiled_submod
+ self.unwrap_singleton_tuple = unwrap_singleton_tuple
+
+ def forward(self, *args):
+ x = self.compiled_submod(*args)
+ # TODO(whc)
+ # for some reason the isinstance check is necessary if I split one node per submod
+ # - even though I supposedly wrapped the output in a tuple in those cases, the real
+ # compiled module was still returning a tensor
+ if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
+ return x[0]
+ return x
+
+ unwrap_singleton_tuple = False
+ for sn in submod.graph.nodes:
+ if sn.op == "output":
+ if not isinstance(sn.args[0], tuple):
+ unwrap_singleton_tuple = True
+ sn.args = (sn.args,)
+ submod.recompile()
+
+ wrapper = WrapperModule(
+ self.compiler(submod, args),
+ unwrap_singleton_tuple,
+ )
+ return wrapper
+
+ def run_node(self, n: Node) -> Any:
+ with fx_traceback.append_stack_trace(n.stack_trace):
+ args, kwargs = self.fetch_args_kwargs_from_env(n)
+ if self.debug:
+ print(f"run_node {n.op}, {n.target} got args {args_str(args)}")
+ assert isinstance(args, tuple)
+ assert isinstance(kwargs, dict)
+
+ # modify the currently running FX graph
+ # maybe this isn't sound in general, but only changing the target of a node might be ok?
+ if n.op == "call_module":
+ submod = self.fetch_attr(n.target)
+ if self.debug:
+ with open("debug_ddp_optimizer.log", "a") as dump_file:
+ dump_file.write(f"\n---{n.target} graph---")
+ dump_file.write(str(submod.graph))
+ compiled_submod = self.compile_submod(submod, args, kwargs)
+ self.module.delete_submodule(n.target)
+ n.target = "compiled_" + n.target
+ self.module.add_submodule(n.target, compiled_submod)
+
+ # then we execute the modified node using the usual logic
+ return getattr(self, n.op)(n.target, args, kwargs)
+
+ submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, self.debug)
+ submod_compiler.run(*example_inputs)
+ split_gm.recompile()
+
+ if self.debug:
+ with open("debug_ddp_optimizer.log", "a") as dump_file:
+ dump_file.write("\n---final graph---")
+ dump_file.write(str(split_gm.graph))
+
+ return split_gm
diff --git a/torch/_dynamo/optimizations/inference.py b/torch/_dynamo/optimizations/inference.py
new file mode 100644
index 0000000000000..0ecf454025490
--- /dev/null
+++ b/torch/_dynamo/optimizations/inference.py
@@ -0,0 +1,197 @@
+import base64
+import hashlib
+import io
+import itertools
+import json
+import logging
+import os
+import time
+from collections import defaultdict
+
+import torch
+
+from .. import config
+from ..utils import (
+ check_is_cuda,
+ checkpoint_params,
+ clone_inputs,
+ count_calls,
+ counters,
+)
+from .normalize import long_name, normalize_ir
+
+log = logging.getLogger(__name__)
+
+
+def string_key(gm: torch.fx.GraphModule, example_inputs):
+ out = io.StringIO()
+ node_to_id = defaultdict(iter(itertools.count()).__next__)
+
+ def argkey(n: torch.fx.Node):
+ return f"#{node_to_id[n]}"
+
+ def tensorkey(t):
+ if isinstance(t, torch.Tensor):
+ requires_grad = t.requires_grad and torch.torch.is_grad_enabled()
+ return (
+ f"{t.__class__.__name__}({t.dtype}, {t.device}, "
+ f"{tuple(t.size())}, {tuple(t.stride())}, {requires_grad})"
+ )
+ return type(t).__name__
+
+ inputs_iter = iter(example_inputs)
+
+ for node in gm.graph.nodes:
+ key = argkey(node)
+ name = "."
+ if node.op == "placeholder":
+ name = tensorkey(next(inputs_iter))
+ elif node.op == "get_attr":
+ val = eval(f"self.{node.target}", {"self": gm})
+ name = tensorkey(val)
+ elif node.op in ("call_function", "call_method", "call_module"):
+ name = long_name(gm, node)
+ out.write(
+ f"{key} {node.op} {name} "
+ f"{torch.fx.map_arg(node.args, argkey)!r} "
+ f"{torch.fx.map_arg(node.kwargs, argkey)!r}\n"
+ )
+ return out.getvalue()
+
+
+def graph_hash(gm: torch.fx.GraphModule, example_inputs):
+ return "g" + base64.urlsafe_b64encode(
+ hashlib.sha256(string_key(gm, example_inputs).encode("utf-8")).digest()
+ )[:39].decode("utf-8")
+
+
+def folder_name(gm: torch.fx.GraphModule, example_inputs):
+ base = os.path.join(config.base_dir, "subgraphs")
+ if not os.path.exists(base):
+ os.mkdir(base)
+ open(os.path.join(base, "__init__.py"), "w").close()
+ return os.path.join(base, graph_hash(gm, example_inputs))
+
+
+def record_graph_stats(gm):
+ for node in gm.graph.nodes:
+ if node.op in ("call_function", "call_method", "call_module"):
+ counters[node.op][long_name(gm, node)] += 1
+ elif node.op in ("placeholder", "output", "get_attr"):
+ pass
+ else:
+ raise AssertionError(node.op)
+
+
+def check_requires_grad(gm, example_inputs):
+ if torch.is_grad_enabled():
+ if any(
+ getattr(x, "requires_grad", False)
+ for x in itertools.chain(example_inputs, gm.parameters(True))
+ ):
+ return True
+ return False
+
+
+def jit_trace(gm, example_inputs):
+ """Wrapper around jit.trace to handle hooks"""
+ restore_backward_hooks = []
+
+ def visit(mod):
+ if mod._backward_hooks:
+ restore_backward_hooks.append((mod, mod._backward_hooks))
+ mod._backward_hooks = []
+
+ if not check_requires_grad(gm, example_inputs):
+ # in inference mode it is safe to ignore backwards hooks to allow tracing
+ gm.apply(visit)
+
+ try:
+ return torch.jit.trace(gm.forward, example_inputs)
+ finally:
+ for mod, hooks in restore_backward_hooks:
+ mod._backward_hooks = hooks
+
+
+def same(left, right):
+ return len(left) == len(right) and all(
+ torch.allclose(a, b, atol=1e-4, rtol=1e-4) for a, b in zip(left, right)
+ )
+
+
+class TorchScriptStrategy(object):
+ """Common base for backend strategies that use TorchScript"""
+
+ @classmethod
+ def compile_fn(cls, gm: torch.fx.GraphModule, example_inputs):
+ if count_calls(gm.graph) < 2:
+ return gm.forward # no point for tiny graphs
+ return cls(gm, example_inputs).verified_candidate()
+
+ def __init__(self, gm: torch.fx.GraphModule, example_inputs):
+ super(TorchScriptStrategy, self).__init__()
+ self.restore = checkpoint_params(gm)
+ self.original_example_inputs = example_inputs
+ self.correct = gm.forward(*self.example_inputs)
+ self.gm = normalize_ir(gm, self.original_example_inputs)
+ self.scripted = jit_trace(self.gm, self.example_inputs)
+
+ @property
+ def example_inputs(self):
+ return clone_inputs(self.original_example_inputs)
+
+ def verified_candidate(self):
+ try:
+ candidate = self.candidate()
+ if candidate is None or candidate is self.gm.forward:
+ return self.gm.forward
+
+ self.restore()
+ result = candidate(*self.example_inputs)
+
+ if same(result, self.correct):
+ return candidate
+
+ print(f"incorrect candidate {self}")
+
+ return self.gm.forward
+ except Exception:
+ log.exception("error in verified_candidate()")
+ return self.gm.forward
+ finally:
+ self.restore()
+
+ def candidate(self):
+ raise NotImplementedError()
+
+
+def save_pt(path, name, data):
+ with open(os.path.join(path, name), "wb") as fd:
+ torch.save(data, fd)
+
+
+def save_metadata(path, gm, example_inputs):
+ with open(os.path.join(path, "metadata.json"), "w") as fd:
+ json.dump(
+ {
+ "is_cuda": check_is_cuda(gm, example_inputs),
+ },
+ fd,
+ )
+
+
+def touch_timestamp(path):
+ open(os.path.join(path, "timestamp"), "w").write(str(time.time()))
+
+
+def argmin(perf):
+ best = "eager"
+ best_sec = float("inf")
+ for name, sec in perf.items():
+ if sec < best_sec:
+ best = name
+ best_sec = float(sec)
+ if name == "eager":
+ # small bias torwards using eager since it is more robust
+ best_sec *= 0.99
+ return best
diff --git a/torch/_dynamo/optimizations/log_args.py b/torch/_dynamo/optimizations/log_args.py
new file mode 100644
index 0000000000000..caa0a9a83ce66
--- /dev/null
+++ b/torch/_dynamo/optimizations/log_args.py
@@ -0,0 +1,74 @@
+import json
+import os
+
+import torch
+from torch.fx.experimental.proxy_tensor import make_fx
+
+aten = torch.ops.aten
+
+
+class ConvArgsAnalysis(torch.fx.Interpreter):
+ """
+ Log arguments like input shape (input, bias, weights shape)
+ and options(padding/stride/kernel size/dilation/etc) for
+ aten.convolution
+ """
+
+ def __init__(self, gm: torch.fx.GraphModule):
+ super().__init__(gm)
+
+ self.nodes_conv_args = {}
+ self.conv_arg_names = [
+ arg.name for arg in aten.convolution.default._schema.arguments
+ ]
+
+ def run(self, *args):
+ run_result = super().run(*args)
+ if self.nodes_conv_args:
+ filename = "tmp/conv_args.json"
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ with open(filename, "a") as fd:
+ json.dump(self.nodes_conv_args, fd)
+ fd.write("\n")
+ return run_result
+
+ def run_node(self, n: torch.fx.Node):
+ result = super().run_node(n)
+
+ if n.op == "call_function":
+ if n.target == aten.convolution.default:
+ args, kwargs = self.fetch_args_kwargs_from_env(n)
+ assert len(args) == len(
+ self.conv_arg_names
+ ), f"aten.convolution should have {len(self.conv_arg_names)} args"
+ conv_args = {}
+ # collect tensor's shape, stride (channel first or last), dtype
+ for i in range(3):
+ arg_name = self.conv_arg_names[i]
+ if args[i] is None:
+ conv_args[arg_name] = {
+ "shape": None,
+ "stride": None,
+ "dtype": None,
+ }
+ else:
+ conv_args[arg_name] = {
+ "shape": args[i].shape,
+ "stride": args[i].stride(),
+ "dtype": str(args[i].dtype),
+ }
+ # collect stride/padding/dilation/transposed/output_padding/groups
+ for i in range(3, len(args)):
+ arg_name = self.conv_arg_names[i]
+ conv_args[arg_name] = args[i]
+
+ self.nodes_conv_args[n.name.replace("_default", "")] = conv_args
+ return result
+
+
+def conv_args_analysis(gm: torch.fx.GraphModule, example_inputs):
+ # lowering graph
+ gm = make_fx(gm)(*example_inputs)
+ # use Interpreter to logs the args of conv
+ ConvArgsAnalysis(gm).run(*example_inputs)
+ return gm
diff --git a/torch/_dynamo/optimizations/normalize.py b/torch/_dynamo/optimizations/normalize.py
new file mode 100644
index 0000000000000..47b2c5703a4d9
--- /dev/null
+++ b/torch/_dynamo/optimizations/normalize.py
@@ -0,0 +1,441 @@
+import builtins
+import dataclasses
+import functools
+import itertools
+import logging
+import math
+import operator
+
+import torch
+from torch.fx import Transformer
+from torch.fx.experimental.normalize import NormalizeOperators
+from torch.fx.operator_schemas import get_signature_for_torch_op
+
+from .. import config
+from ..allowed_functions import torch_get_name
+from ..utils import clone_inputs, counters
+from .analysis import ShapeAliasingAndMutationProp
+
+log = logging.getLogger(__name__)
+
+VIEW_OPS = {
+ # list taken from https://pytorch.org/docs/stable/tensor_view.html
+ "getitem",
+ "as_strided",
+ "detach",
+ "diagonal",
+ "expand",
+ "expand_as",
+ "movedim",
+ "narrow",
+ "permute",
+ "select",
+ "squeeze",
+ "transpose",
+ "t",
+ "T",
+ "real",
+ "imag",
+ "view_as_real",
+ "view_as_imag",
+ "unflatten",
+ "unfold",
+ "unsqueeze",
+ "view",
+ "view_as",
+ "unbind",
+ "split",
+ "split_with_sizes",
+ "swapaxes",
+ "swapdims",
+ "chunk",
+ "indices",
+ "values",
+}
+MAYBE_VIEW_OPS = {"contiguous", "reshape"}
+
+# convert x.foo(...) to torch.foo(x, ...)
+NORMALIZE_METHODS = {
+ # These ones aren't normalized:
+ # ('view', 342)
+ # ('reshape', 285)
+ # ('expand', 87)
+ # ('permute', 78)
+ # ('to', 66)
+ # ('contiguous', 62)
+ # ('reshape_as', 57)
+ # ('masked_fill', 30)
+ # ('float', 22) -- could rewrite
+ # ('expand_as', 14) -- could rewrite
+ # ('detach', 4)
+ # ('repeat', 2)
+ # TODO(jansel): debug why this causes issues in detectron2_maskrcnn
+ # "div": torch.div,
+ "add_": operator.iadd,
+ "all": torch.all,
+ "any": torch.any,
+ "ceil": torch.ceil,
+ "chunk": torch.chunk,
+ "clamp": torch.clamp,
+ "clone": torch.clone,
+ "exp": torch.exp,
+ "flatten": torch.flatten,
+ "flip": torch.flip,
+ "floor": torch.floor,
+ "index_select": torch.index_select,
+ "log2": torch.log2,
+ "log_softmax": torch.nn.functional.log_softmax,
+ "max": torch.max,
+ "mean": torch.mean,
+ "min": torch.min,
+ "mul_": operator.imul,
+ "narrow": torch.narrow,
+ "ne": torch.ne,
+ "nonzero": torch.nonzero,
+ "numel": torch.numel,
+ "pow": torch.pow,
+ "round": torch.round,
+ "rsqrt": torch.rsqrt,
+ "sigmoid": torch.sigmoid,
+ "softmax": torch.nn.functional.softmax,
+ "sort": torch.sort,
+ "split": torch.split,
+ "squeeze": torch.squeeze,
+ "std": torch.std,
+ "sum": torch.sum,
+ "topk": torch.topk,
+ "transpose": torch.transpose,
+ "tril": torch.tril,
+ "t": torch.t,
+ "unbind": torch.unbind,
+ "unsqueeze": torch.unsqueeze,
+}
+DONT_EXPAND_MODULES = {
+ # These have internal control flow
+ "ConvTranspose1d",
+ "ConvTranspose2d",
+ "Conv2d",
+ "ConvReLU2d",
+ "ConvBn2d",
+ "ConvBnReLU2d",
+ "EmbeddingBag",
+ "InstanceNorm2d",
+ "LSTM",
+}
+
+F = torch.nn.functional
+INPLACE_KEYWORD_OPS = {
+ F.mish,
+ F.silu,
+ F.hardsigmoid,
+ F.rrelu,
+ F.leaky_relu,
+ F.celu,
+ F.selu,
+ F.elu,
+ F.relu6,
+ F.hardswish,
+ F.hardtanh,
+ F.relu,
+ F.threshold,
+}
+IOPERATOR_REPLACEMENTS = {
+ "masked_fill_": "masked_fill",
+ "scatter_": "scatter",
+ "unsqueeze_": "unsqueeze",
+ torch.relu_: torch.relu,
+ torch.sigmoid_: torch.sigmoid,
+ operator.iadd: torch.add,
+ operator.iand: torch.bitwise_and,
+ operator.ifloordiv: functools.partial(torch.div, rounding_mode="floor"),
+ operator.itruediv: torch.div,
+ operator.imul: torch.mul,
+ operator.imatmul: torch.matmul,
+ operator.ior: torch.bitwise_or,
+ operator.ipow: torch.pow,
+ operator.isub: torch.sub,
+ operator.ixor: torch.bitwise_xor,
+}
+OPERATOR_REPLACEMENTS = {
+ operator.lt: torch.lt,
+ operator.le: torch.le,
+ operator.eq: torch.eq,
+ operator.ne: torch.ne,
+ operator.ge: torch.ge,
+ operator.gt: torch.gt,
+ operator.abs: torch.abs,
+ operator.add: torch.add,
+ operator.and_: torch.bitwise_and,
+ operator.floordiv: functools.partial(torch.div, rounding_mode="floor"),
+ # operator.truediv: torch.div, # TODO(jansel): debug issue in vision_maskrcnn
+ operator.inv: torch.bitwise_not,
+ operator.invert: torch.bitwise_not,
+ operator.mod: torch.remainder,
+ operator.mul: torch.mul,
+ operator.matmul: torch.matmul,
+ operator.neg: torch.neg,
+ operator.or_: torch.bitwise_or,
+ operator.pos: torch.positive,
+ operator.pow: torch.pow,
+ operator.sub: torch.sub,
+ operator.xor: torch.bitwise_xor,
+ torch.nn.functional.sigmoid: torch.sigmoid,
+ torch.nn.functional.tanh: torch.tanh,
+ torch.nn.functional.relu: torch.relu,
+}
+
+SKIP_INPLACE = {
+ v
+ for v in itertools.chain(
+ math.__dict__.values(), builtins.__dict__.values(), operator.__dict__.values()
+ )
+ if callable(v)
+}
+
+
+def always_true(*args, **kwargs):
+ return True
+
+
+class InliningTracer(torch.fx.Tracer):
+ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
+ return False
+
+
+def expand_module_call(prefix, graph: torch.fx.Graph, module, args, kwargs):
+ # this patch is needed to make BatchNorm2D FX trace
+ module.__dict__["_check_input_dim"] = always_true
+ try:
+ assert not kwargs
+ arg_index = itertools.count()
+ vars = dict()
+ for node in InliningTracer().trace(module).nodes:
+ if node.op == "placeholder":
+ vars[node] = args[next(arg_index)]
+ elif node.op == "output":
+ assert len(node.args) == 1
+ return vars[node.args[0]]
+ elif node.op == "get_attr":
+ vars[node] = graph.get_attr(f"{prefix}{node.target}")
+ else:
+ vars[node] = graph.node_copy(node, vars.__getitem__)
+ raise AssertionError("unreachable")
+ except Exception:
+ print(f"Error while expanding {module.__class__.__name__}")
+ raise
+ finally:
+ del module.__dict__["_check_input_dim"]
+
+
+@dataclasses.dataclass
+class NodeCounts:
+ usages: int = 0
+
+
+def short_name(gm, node: torch.fx.Node):
+ if node.op == "call_function":
+ return node.target.__name__
+ elif node.op == "call_method":
+ return node.target
+ elif node.op == "call_module":
+ return gm.get_submodule(node.target).__class__.__name__
+ elif node.op == "get_attr":
+ return node.target
+ elif node.op == "output":
+ return "output"
+ raise AssertionError(node.op)
+
+
+def long_name(gm, node: torch.fx.Node):
+ name = short_name(gm, node)
+ target = node.target
+ if node.op == "call_function":
+ return torch_get_name(
+ node.target, f"{getattr(target, '__module__', '')}.{name}"
+ )
+ elif node.op == "call_method":
+ return name
+ elif node.op == "call_module":
+ target = gm.get_submodule(target).__class__
+ return f"{getattr(target, '__module__', '')}.{getattr(target, '__name__', '')}"
+ elif node.op == "get_attr":
+ return name
+ elif node.op == "output":
+ return "output"
+ raise AssertionError("unreachable")
+
+
+class Inplacifier:
+ def __init__(self, gm: torch.fx.GraphModule):
+ self.gm = gm
+
+ def can_be_view(self, node):
+ name = short_name(self.gm, node)
+ return name in VIEW_OPS or name in MAYBE_VIEW_OPS
+
+ def inplacify(self):
+ counts = dict()
+
+ def record_usage(node):
+ counts[node].usages += 1
+ return node
+
+ for node in self.gm.graph.nodes:
+ if node.op in ("call_function", "call_method", "call_module"):
+ if self.can_be_view(node):
+ # Aliasing
+ counts[node] = counts[node.args[0]]
+ elif "out" in node.kwargs:
+ counts[node] = counts[node.kwargs["out"]]
+ else:
+ counts[node] = NodeCounts(0)
+ else:
+ counts[node] = NodeCounts(float("inf"))
+
+ for node in reversed(list(self.gm.graph.nodes)):
+ kwargs = dict(node.kwargs)
+ if "inplace" in kwargs:
+ kwargs.pop("inplace")
+ if node.op == "call_function" and len(node.args) + len(kwargs) == 1:
+ arg = node.args[0] if node.args else next(kwargs.values())
+ if isinstance(arg, torch.fx.Node) and counts[arg].usages == 0:
+ if node.target in SKIP_INPLACE:
+ continue
+ elif node.target in INPLACE_KEYWORD_OPS:
+ kwargs["inplace"] = True
+ counters["optimizations"]["inplace"] += 1
+ elif " out: torch.Tensor" in repr(
+ get_signature_for_torch_op(node.target)
+ ):
+ kwargs["out"] = arg
+ counters["optimizations"]["out"] += 1
+ else:
+ continue
+ with self.gm.graph.inserting_before(node):
+ node.replace_all_uses_with(
+ self.gm.graph.call_function(node.target, node.args, kwargs)
+ )
+ self.gm.graph.erase_node(node)
+
+ torch.fx.map_arg((node.args, node.kwargs), record_usage)
+
+
+class Functionalization(Transformer):
+ """
+ Remove most cases of mutation from a given fx Graph.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(Functionalization, self).__init__(*args, **kwargs)
+ self.tracer.tensor_attrs = dict() # TODO(jansel): upstream this fix
+
+ def run_node(self, n: torch.fx.Node):
+
+ patches = []
+ target = n.target
+ args, kwargs = self.fetch_args_kwargs_from_env(n)
+ kwargs = dict(kwargs)
+
+ if (
+ not n.meta["is_input_mutation"]
+ and not n.meta["partial_mutation"]
+ and issubclass(n.meta["type"], torch.Tensor)
+ ):
+ if "inplace" in n.kwargs:
+ if kwargs["inplace"]:
+ patches.append(n.args[0])
+ kwargs.pop("inplace")
+ elif "out" in n.kwargs:
+ kwargs.pop("out")
+ patches.append(n.kwargs["out"])
+ elif n.target in IOPERATOR_REPLACEMENTS:
+ target = IOPERATOR_REPLACEMENTS[n.target]
+ patches.append(n.args[0])
+ elif n.meta["is_mutation"]:
+ counters["mutation"][long_name(self.module, n)] += 1
+
+ if target in OPERATOR_REPLACEMENTS and not kwargs:
+ target = OPERATOR_REPLACEMENTS[target]
+
+ if target is builtins.getattr:
+ if args[1] == "dtype":
+ return n.args[0].meta["dtype"]
+ elif args[1] == "device":
+ return n.args[0].meta["device"]
+ else:
+ counters["getattr"][args[1]] += 1
+
+ if isinstance(target, functools.partial):
+ assert not target.args
+ kwargs.update(target.keywords)
+ target = target.func
+
+ if not issubclass(n.meta["type"], torch.Tensor):
+ counters["nontensor"][long_name(self.module, n)] += 1
+
+ with self._set_current_node(n):
+ result = getattr(self, n.op)(target, args, kwargs)
+
+ # For inplace operators, the output dtype should be equal to the
+ # dtype of tensor being inplace modified.
+ if n.target in IOPERATOR_REPLACEMENTS:
+ result = self.call_method("to", (result, n.args[0].meta["dtype"]), {})
+
+ for patch in patches:
+ assert isinstance(
+ patch, torch.fx.Node
+ ), f"{patch} {n.target} {n.args} {n.kwargs}"
+ if patch in self.env:
+ self.env[patch] = result
+
+ return result
+
+
+def swap_node(graph, old_node, new_node):
+ old_node.replace_all_uses_with(new_node)
+ graph.erase_node(old_node)
+ new_node.meta = old_node.meta
+
+
+def normalize(gm: torch.fx.GraphModule):
+ # gm.graph.print_tabular()
+ graph: torch.fx.Graph = gm.graph
+
+ for node in list(graph.nodes):
+ with graph.inserting_before(node):
+ if node.op == "call_method" and node.target in NORMALIZE_METHODS:
+ swap_node(
+ graph,
+ node,
+ graph.call_function(
+ NORMALIZE_METHODS[node.target], node.args, node.kwargs
+ ),
+ )
+ elif node.op == "call_module":
+ submod = gm.get_submodule(node.target)
+ if submod.__class__.__name__ not in DONT_EXPAND_MODULES:
+ swap_node(
+ graph,
+ node,
+ expand_module_call(
+ f"{node.target}.", graph, submod, node.args, node.kwargs
+ ),
+ )
+
+ # gm.graph.print_tabular()
+
+
+def normalize_ir(gm, example_inputs):
+ if config.normalize_ir:
+ example_inputs = clone_inputs(example_inputs)
+ normalize(gm)
+ try:
+ gm = NormalizeOperators(gm).transform()
+ except AttributeError:
+ # log.exception("NormalizeOperators() failed")
+ pass
+ ShapeAliasingAndMutationProp(gm).run(*example_inputs)
+ gm = Functionalization(gm).transform()
+ gm.recompile()
+ # record_graph_stats(gm)
+ return gm
diff --git a/torch/_dynamo/optimizations/subgraph.py b/torch/_dynamo/optimizations/subgraph.py
new file mode 100644
index 0000000000000..55b7736755667
--- /dev/null
+++ b/torch/_dynamo/optimizations/subgraph.py
@@ -0,0 +1,236 @@
+import functools
+import importlib
+import itertools
+import json
+import logging
+import math
+import operator
+import os
+
+import torch
+
+from .. import config
+from ..utils import check_is_cuda, checkpoint_params, is_jit_model, torchscript
+
+log = logging.getLogger(__name__)
+
+
+def cached(fn):
+ cached_name = f"_{fn.__name__}"
+
+ @functools.wraps(fn)
+ def inner(self):
+ if hasattr(self, cached_name):
+ return getattr(self, cached_name)
+ result = fn(self)
+ setattr(self, cached_name, result)
+ return result
+
+ return inner
+
+
+def load_module_fx(name):
+ pymod = importlib.import_module(f"subgraphs.{name}")
+ # TODO(jansel): upstream these fixes to to_folder()
+ pymod.module._operator_iadd = operator.iadd
+ pymod.module._operator_imul = operator.imul
+ pymod.module._operator_itruediv = operator.itruediv
+ pymod.module._operator_setitem = operator.setitem
+ pymod.module.math_sqrt = math.sqrt
+ pymod.module.device = torch.device
+ pymod.module.inf = float("inf")
+ return pymod.FxModule()
+
+
+def load_module_jit(name):
+ filename = os.path.join(config.base_dir, "subgraphs", name, "model.ts")
+ if not os.path.exists(filename):
+ return None
+ model = torch.jit.load(filename)
+ assert is_jit_model(model)
+ return model
+
+
+class SubGraph(object):
+ @classmethod
+ def load(cls, name):
+ model_dir = os.path.join(config.base_dir, "subgraphs", name)
+ example_inputs = torch.load(os.path.join(model_dir, "example_inputs.pt"))
+ example_outputs = torch.load(os.path.join(model_dir, "example_outputs.pt"))
+ metadata = json.loads(open(os.path.join(model_dir, "metadata.json")).read())
+ model_fx = load_module_fx(name)
+ model_jit = load_module_jit(name)
+ is_cuda = metadata["is_cuda"]
+
+ assert model_jit is not None
+
+ torch.set_rng_state(torch.load(os.path.join(model_dir, "rng_state.pt")))
+ if is_cuda:
+ model_jit = model_jit.cuda()
+ restore_jit = checkpoint_params(model_jit)
+ if model_fx is not None:
+ if is_cuda:
+ model_fx = model_fx.cuda()
+ restore_fx = checkpoint_params(model_fx)
+ else:
+ model_fx = model_jit
+ restore_fx = restore_jit
+
+ def restore():
+ restore_fx()
+ restore_jit()
+
+ subgraph = cls(model_fx, example_inputs, model_dir)
+ subgraph._scripted = model_jit
+ subgraph._example_outputs = example_outputs
+ subgraph._is_cuda = is_cuda
+ subgraph.restore = restore
+ return subgraph
+
+ def __init__(self, model, example_inputs, model_dir):
+ super(SubGraph, self).__init__()
+ self.model = model
+ self.example_inputs = example_inputs
+ self.model_dir = model_dir
+
+ def filename(self, name):
+ return os.path.join(self.model_dir, name)
+
+ @property
+ @cached
+ def scripted(self):
+ return torchscript(self.model, self.example_inputs)
+
+ @property
+ @cached
+ def example_outputs(self):
+ filename = self.filename("example_outputs.pt")
+ if os.path.exists(filename):
+ return torch.load(filename)
+ result = self.model(*self.example_inputs)
+ torch.save(result, filename)
+ return result
+
+ @property
+ def example_outputs_list(self):
+ if self.is_tensor_output:
+ return [self.example_outputs]
+ return self.example_outputs
+
+ @property
+ def input_names(self):
+ return [f"i{i}" for i in range(len(self.example_inputs))]
+
+ @property
+ def is_tensor_output(self):
+ return not isinstance(self.example_outputs, (list, tuple))
+
+ @property
+ def output_names(self):
+ return [f"o{x}" for x in range(len(self.example_outputs_list))]
+
+ @property
+ def device_index(self):
+ return 0
+
+ @property
+ @cached
+ def onnx_filename(self):
+ filename = self.filename("onnx")
+ if os.path.exists(filename):
+ return filename
+
+ try:
+ torch.onnx.export(
+ self.scripted,
+ self.example_inputs,
+ filename,
+ input_names=self.input_names,
+ output_names=self.output_names,
+ do_constant_folding=True,
+ opset_version=14,
+ )
+ except IndexError:
+ # work around bug in constant folding pass
+ torch.onnx.export(
+ self.scripted,
+ self.example_inputs,
+ filename,
+ input_names=self.input_names,
+ output_names=self.output_names,
+ do_constant_folding=False,
+ opset_version=14,
+ )
+ return filename
+
+ @property
+ def is_cpu(self):
+ return not self.is_cuda
+
+ @property
+ @cached
+ def is_cuda(self):
+ return check_is_cuda(self.model, self.example_inputs)
+
+ @property
+ def output_specs(self):
+ return [
+ (o.shape, o.dtype, o.layout, o.device, o.requires_grad)
+ for o in self.example_outputs_list
+ ]
+
+ def empty_outputs_factory(self):
+ specs = self.output_specs
+
+ def create():
+ return [
+ torch.empty(
+ shape,
+ dtype=dtype,
+ layout=layout,
+ device=device,
+ requires_grad=requires_grad,
+ )
+ for shape, dtype, layout, device, requires_grad in specs
+ ]
+
+ return create
+
+ def wrap_returns(self, fn):
+ """Fix [Tensor()] vs Tensor() return type issues"""
+ expected = self.example_outputs
+ actual = fn(*self.example_inputs)
+ if isinstance(expected, (list, tuple)) and not isinstance(
+ actual, (list, tuple)
+ ):
+ assert len(expected) == 1
+ if isinstance(expected, tuple):
+ return lambda *args: (fn(*args),)
+ else:
+ return lambda *args: [fn(*args)]
+ elif not isinstance(expected, (list, tuple)) and isinstance(
+ actual, (list, tuple)
+ ):
+ assert len(actual) == 1
+ return lambda *args: fn(*args)[0]
+ elif isinstance(expected, (list, tuple)) and isinstance(actual, (list, tuple)):
+ assert len(actual) == len(expected)
+ return fn
+ else:
+ return fn
+
+ def has_dtype(self, dtype):
+ for x in itertools.chain(
+ self.example_inputs, self.scripted.parameters(), self.scripted.buffers()
+ ):
+ if x.dtype == dtype:
+ return True
+ return False
+
+ def will_tensorrt_barf(self):
+ return False
+ # code = torch.jit.freeze(self.scripted).code
+ # TODO(jansel): submit a bug report for this one, issue is in opacus_cifar10
+ # if "group_norm" in code or "einsum" in code:
+ # return True
+ # return self.has_dtype(torch.int64)
diff --git a/torch/_dynamo/optimizations/training.py b/torch/_dynamo/optimizations/training.py
new file mode 100644
index 0000000000000..bec450bd37430
--- /dev/null
+++ b/torch/_dynamo/optimizations/training.py
@@ -0,0 +1,556 @@
+import functools
+import logging
+import operator
+from collections import defaultdict
+from functools import partial
+from importlib import import_module
+from typing import Set
+
+import torch
+from torch.fx import GraphModule
+from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
+from torch.multiprocessing.reductions import StorageWeakRef
+from torch.nn import Module
+from torch.utils._pytree import tree_map
+
+from .. import config
+from ..debug_utils import wrap_compiler_debug
+from ..utils import clone_inputs, count_calls, counters
+from .analysis import has_mutation
+from .backends import BACKENDS
+from .normalize import normalize_ir
+
+log = logging.getLogger(__name__)
+
+
+def is_aot_autograd_safe_to_run(gm, example_inputs):
+ """
+ There are some known issues with Aot Autograd. This is a workaround to catch
+ such cases, and fallback to eager. We should fix these quickly.
+
+ Issues
+ 1) LSTM - https://github.com/pytorch/torchdynamo/issues/1147
+ 2) LSTM - https://github.com/pytorch/functorch/issues/586
+ 3) Input mutation - https://github.com/pytorch/torchdynamo/issues/1301
+ """
+
+ def raise_or_warn(reason):
+ msg = f"Unable to use Aot Autograd because of presence of {reason}"
+ if config.raise_on_unsafe_aot_autograd:
+ raise NotImplementedError(msg)
+ else:
+ log.warning(msg)
+ return False
+
+ import functorch.compile
+
+ # 1) LSTM module (tts_angular) - https://github.com/pytorch/functorch/issues/586
+ for submod in gm.modules():
+ if submod.__class__.__name__ == "LSTM":
+ return raise_or_warn("LSTM")
+
+ # 2) Mutation in the graph
+ mutated = False
+ try:
+ if functorch.compile.config.use_functionalize:
+ # There are two problematic classes we still exclude for now with
+ # functionalization:
+ # - data mutation of inputs (fixed when we stop recording the
+ # copy_ directly into the graph)
+ # - metadata mutation of inputs (fixed if we do an extra partition
+ # to avoid AotAutograd on the mutated inputs, or if we some how
+ # get custom autograd function to reflect metadata changes to the
+ # original tensor)
+ mutated = has_mutation(gm, example_inputs, inputs_only=True)
+ else:
+ mutated = has_mutation(gm, example_inputs)
+ except NotImplementedError as e:
+ if "SparseTensorImpl" not in str(e):
+ # TODO - TorchDynamo mutation analysis cannot handle sparse tensors.
+ # So, there is a chance that we could call Aot Autograd when it is
+ # unsafe.
+ # The exception is fairly guarded with string check, so any other
+ # mutation analysis bugs will raise exceptions and will be caught.
+ raise e
+ pass
+
+ if mutated:
+ return raise_or_warn("mutation")
+
+ return True
+
+
+class AotAutogradStrategy(object):
+ """Base class for backend strategies that use AOT Autograd"""
+
+ @classmethod
+ def compile_fn(cls, gm: torch.fx.GraphModule, example_inputs):
+ if count_calls(gm.graph) < 2:
+ return gm # no point for tiny graphs
+ return cls(gm, example_inputs).verified_candidate()
+
+ def __init__(self, gm: torch.fx.GraphModule, example_inputs):
+ import functorch.compile
+
+ functorch.compile.config.use_functionalize = True
+ functorch.compile.config.use_fake_tensor = True
+
+ super(AotAutogradStrategy, self).__init__()
+ counters["aot_autograd"]["total"] += 1
+ self.use_fallback = False
+ self.original_example_inputs = example_inputs
+ self.gm = gm
+
+ if not functorch.compile.config.use_functionalize and config.normalize_ir:
+ try:
+ self.gm = normalize_ir(gm, self.example_inputs)
+ except Exception:
+ log.debug("TorchDynamo unable to remove mutation")
+ self.use_fallback = True
+ pass
+
+ if not is_aot_autograd_safe_to_run(gm, example_inputs):
+ self.use_fallback = True
+
+ @property
+ def example_inputs(self):
+ return clone_inputs(self.original_example_inputs)
+
+ def verified_candidate(self):
+ if self.use_fallback:
+ log.debug("Unable to use AOT Autograd because graph has mutation")
+ counters["aot_autograd"]["not_ok"] += 1
+ return self.gm
+ cg = self.candidate()
+ if cg is None:
+ counters["aot_autograd"]["not_ok"] += 1
+ raise RuntimeError("AOT Autograd failed to compile")
+ counters["aot_autograd"]["ok"] += 1
+ return cg
+
+ def candidate(self):
+ raise NotImplementedError()
+
+
+class AotNop(AotAutogradStrategy):
+ """Useful for debugging purpose"""
+
+ def candidate(self):
+ from functorch.compile import nop
+
+ return BACKENDS["aot_autograd"](self.gm, self.example_inputs, fw_compiler=nop)
+
+
+aot_eager = AotNop.compile_fn
+
+
+class AotTorchscript(AotAutogradStrategy):
+ """
+ AOT Autograd with torchscript backend. Default partitioner.
+ """
+
+ def candidate(self):
+ from functorch.compile import ts_compile
+
+ return BACKENDS["aot_autograd"](
+ self.gm, self.example_inputs, fw_compiler=ts_compile
+ )
+
+
+aot_ts = AotTorchscript.compile_fn
+
+# Global counter to differentiate between different graphs.
+graph_idx = 0
+
+
+class AotPrint(AotNop):
+ """Saves all the gm models so that we can run them separately"""
+
+ def candidate(self):
+ global graph_idx
+ module_idx = "module_" + str(graph_idx)
+ self.gm.to_folder(module_idx, "Bar")
+ for idx, x in enumerate(self.example_inputs):
+ torch.save(x, module_idx + "_tensor" + str(idx) + ".pt")
+ graph_idx += 1
+ return super(AotPrint, self).candidate()
+
+
+aot_print = AotPrint.compile_fn
+
+
+def mem_efficient_fusion_kwargs(use_decomps):
+ from functorch.compile import (
+ default_decompositions,
+ min_cut_rematerialization_partition,
+ ts_compile,
+ )
+
+ kwargs = {
+ # these are taken from memory_efficient_fusion()
+ "fw_compiler": ts_compile,
+ "bw_compiler": ts_compile,
+ "partition_fn": min_cut_rematerialization_partition,
+ }
+
+ if use_decomps:
+ kwargs["decompositions"] = default_decompositions
+
+ return kwargs
+
+
+class AotMemEfficientFusion(AotAutogradStrategy):
+ """Use Min cut rematerilization and NVFuser with AOT Autograd"""
+
+ def candidate(self):
+ kwargs = mem_efficient_fusion_kwargs(use_decomps=True)
+ return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)
+
+
+class AotMemEfficientFusionNoDecomps(AotAutogradStrategy):
+ """Use Min cut rematerilization and NVFuser with AOT Autograd"""
+
+ def candidate(self):
+ kwargs = mem_efficient_fusion_kwargs(use_decomps=False)
+ return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)
+
+
+class AotInductorDebug(AotAutogradStrategy):
+ """
+ Uses TorchInductor Aot Autograd decopms and partitioner to isolate aot vs
+ inductor problems.
+ """
+
+ def candidate(self):
+ from functorch.compile import min_cut_rematerialization_partition, nop
+
+ decompositions = import_module(
+ f"{config.inductor_import}.compile_fx"
+ ).select_decomp_table()
+
+ kwargs = {
+ # these are taken from memory_efficient_fusion()
+ "fw_compiler": nop,
+ "bw_compiler": nop,
+ "decompositions": decompositions,
+ "partition_fn": functools.partial(
+ min_cut_rematerialization_partition, compiler="inductor"
+ ),
+ }
+ return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)
+
+
+aot_inductor_debug = AotInductorDebug.compile_fn
+
+
+class AOTMemEfficientFusionWithContext:
+ """Pass nvfuser context to TorchDynamo"""
+
+ def __init__(self, use_decomps=True):
+ self.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")
+ self.use_decomps = use_decomps
+
+ def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+ if self.use_decomps:
+ return AotMemEfficientFusion.compile_fn(gm, example_inputs)
+ else:
+ return AotMemEfficientFusionNoDecomps.compile_fn(gm, example_inputs)
+
+
+aot_mem_efficient_fusion = AOTMemEfficientFusionWithContext(True)
+aot_mem_efficient_fusion_no_decomp = AOTMemEfficientFusionWithContext(False)
+
+
+class AotPrimsNvfuser(AotAutogradStrategy):
+ """
+ Use FX graph partitioner + Aten2Prims ref + trace executor + nvFuser
+ """
+
+ def __init__(self, gm: torch.fx.GraphModule, example_inputs):
+ super(AotPrimsNvfuser, self).__init__(gm, example_inputs)
+
+ from functorch.compile import min_cut_rematerialization_partition
+
+ from torch.fx.passes.backends.nvfuser import NvFuserBackend
+
+ self.nvfuser = NvFuserBackend()
+ self.min_cut_rematerialization_partition = min_cut_rematerialization_partition
+ self.populate_aten2aten_decomps()
+
+ def populate_aten2aten_decomps(self):
+ from torch._decomp import get_decompositions
+
+ aten = torch.ops.aten
+ default_decompositions = {
+ aten.detach,
+ aten.gelu_backward,
+ aten.leaky_relu_backward,
+ aten.sigmoid_backward,
+ aten.threshold_backward,
+ aten.hardtanh_backward,
+ aten.hardsigmoid_backward,
+ aten.hardswish_backward,
+ aten.tanh_backward,
+ aten.silu_backward,
+ aten.elu_backward,
+ aten.cudnn_batch_norm,
+ aten.cudnn_batch_norm_backward,
+ aten.masked_fill.Scalar,
+ aten.masked_fill.Tensor,
+ aten.elu,
+ aten.leaky_relu,
+ aten.hardtanh,
+ aten.hardswish,
+ aten.hardsigmoid,
+ aten.rsub,
+ aten.native_batch_norm_backward,
+ }
+
+ self.aten2aten_decompositions = get_decompositions(default_decompositions)
+
+ def candidate(self):
+ return BACKENDS["aot_autograd"](
+ self.gm,
+ self.example_inputs,
+ fw_compiler=wrap_compiler_debug(self.nvfuser, "nvfuser"),
+ partition_fn=self.min_cut_rematerialization_partition,
+ decompositions=self.aten2aten_decompositions,
+ )
+
+
+aot_prims_nvfuser = AotPrimsNvfuser.compile_fn
+
+
+def prims_executor(gm, inputs, *, executor):
+ # This function is called once per forward/backward pass of a graph in AOT
+ # Autograd. We use it to set up the nvFuser-specific FX graph and return
+ # execute function.
+ from torch._prims.context import TorchRefsNvfuserCapabilityMode
+ from torch._prims.executor import execute
+ from torch.fx.experimental.proxy_tensor import make_fx
+
+ # First we trace the graph conditionally decomposing nodes
+ # that can be sent to the nvfuser executor
+ with TorchRefsNvfuserCapabilityMode():
+ prim_gm = make_fx(gm)(*inputs)
+
+ # Then we return a callable that executes the "prim_gm" graph
+ return partial(execute, prim_gm, executor=executor)
+
+
+def create_nvprims_backend(*, executor):
+ class NvPrims(AotAutogradStrategy):
+ def __init__(self, gm: torch.fx.GraphModule, example_inputs):
+ super(NvPrims, self).__init__(gm, example_inputs)
+ self.executor = executor
+
+ def candidate(self):
+ return BACKENDS["aot_autograd"](
+ self.gm,
+ self.example_inputs,
+ fw_compiler=partial(prims_executor, executor=self.executor),
+ bw_compiler=partial(prims_executor, executor=self.executor),
+ )
+
+ return NvPrims
+
+
+aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser").compile_fn
+aot_nvprims_aten = create_nvprims_backend(executor="aten").compile_fn
+
+
+def cloner(t):
+ if isinstance(t, torch.Tensor):
+ return t.clone()
+ else:
+ return t
+
+
+class CudaGraphModule(Module):
+ gm: GraphModule
+ mutated_inputs: Set[int]
+
+ def __init__(self, gm, mutated_inputs):
+ super().__init__()
+ self.gm = gm
+ self.mutated_inputs = mutated_inputs
+
+ warmed_up = False
+
+ # these are all None or all filled
+ graph = None
+ static_inputs = None
+ static_outputs = None
+
+ # NB: we override __call__ as we don't need any nn.Module machinery
+ # and to reduce overhead
+ def __call__(self, *args):
+ # TODO: once we've recorded here, we'd like to replace the __call__
+ # implementation with compiled bytecode that copies into static, replays
+ # the cuda graph, then copies out. First condition is the hotpath,
+ # needs optimizing
+ if self.graph is not None:
+ assert len(args) == len(self.static_inputs)
+ for dst, src in zip(self.static_inputs, args):
+ dst.copy_(src)
+ self.graph.replay()
+ for i in self.mutated_inputs:
+ args[i].copy_(self.static_inputs[i])
+ return tree_map(cloner, self.static_outputs)
+
+ elif self.warmed_up:
+ # record
+ self.static_inputs = [x.clone() for x in args]
+ self.graph = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(self.graph):
+ self.static_outputs = self.gm(*self.static_inputs)
+ # NB: recording doesn't actually run the operations, so
+ # now we immediately replay the graph to serve up the result
+ self.graph.replay()
+ for i in self.mutated_inputs:
+ args[i].copy_(self.static_inputs[i])
+ return tree_map(cloner, self.static_outputs)
+
+ else:
+ # warmup
+ stream = torch.cuda.Stream()
+ stream.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(stream):
+ r = self.gm(*args)
+ torch.cuda.current_stream().wait_stream(stream)
+ self.warmed_up = True
+ return r
+
+
+# Interpreter versions of these passes can be found at
+# https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23
+
+
+def find_input_mutations(g):
+ def meta_fk(meta):
+ return meta["val"] if "val" in meta else meta["fake_result"]
+
+ inputs = defaultdict(set)
+ input_idx = 0
+ mutated_inputs = set()
+ for n in g.nodes:
+ if n.op == "placeholder":
+ inputs[StorageWeakRef(meta_fk(n.meta).storage())].add(input_idx)
+ input_idx += 1
+ elif n.op == "call_function":
+ if n.target is operator.getitem:
+ continue
+ schema = n.target._schema
+ for i, arg in enumerate(schema.arguments):
+ if i < len(n.args):
+ argument = n.args[i]
+ else:
+ if arg.name not in n.kwargs:
+ continue
+ argument = n.kwargs[arg.name]
+ mut_arg = False
+ if arg.alias_info:
+ if arg.alias_info.is_write:
+ mut_arg = True
+ if mut_arg:
+ # TODO: not correct for args that contain tensors in a struct
+ # like list
+ mutated_inputs |= inputs[
+ StorageWeakRef(meta_fk(argument.meta).storage())
+ ]
+ # TODO: error on unrecognized nodes
+ return mutated_inputs
+
+
+# Mutates input graph
+def apply_cuda_graphs(gm):
+ for n in gm.graph.nodes:
+ if n.op == "call_module":
+ assert not n.kwargs
+ submod = gm.get_submodule(n.target)
+ gm.delete_submodule(n.target)
+ mutated_inputs = find_input_mutations(submod.graph)
+ gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs))
+ # NB: we didn't actually change the graph, no need for recompile
+
+
+def cudagraphs(model, inputs):
+ model = partition_cudagraphs(model, inputs)
+ apply_cuda_graphs(model)
+ return model
+
+
+def raw_aot_autograd_cudagraphs(model, inputs):
+ kwargs = {
+ # these are taken from memory_efficient_fusion()
+ "fw_compiler": cudagraphs,
+ "bw_compiler": cudagraphs,
+ }
+
+ def _wrapped_bw_compiler(*args, **kwargs):
+ # stop TorchDynamo from trying to compile our generated backwards pass
+ return disable(bw_compiler(*args, **kwargs)) # type: ignore[operator]
+
+ bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
+ kwargs["bw_compiler"] = _wrapped_bw_compiler
+
+ from functorch.compile import aot_module_simplified # type: ignore[import]
+
+ from .. import disable
+
+ return aot_module_simplified(model, **kwargs)
+
+
+class AotAutogradCudaGraphs(AotAutogradStrategy):
+ def candidate(self):
+ return raw_aot_autograd_cudagraphs(self.gm, self.example_inputs)
+
+
+aot_cudagraphs = AotAutogradCudaGraphs.compile_fn
+
+
+def create_aot_backends():
+ """
+ Register aliases for the AOT backends
+ """
+ # aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging.
+ BACKENDS["aot_eager"] = aot_eager
+
+ # aot_eager uses AOT Autograd backend with print compiler. It prints the
+ # graphs and also saves the graph modules that are sent to AOT Autograd.
+ # This is helpful for debugging.
+ BACKENDS["aot_print"] = aot_print
+
+ # aot_ts uses torchscript backend. We can use this with both nnc and nvfuser
+ # by using the relevant fuser with torch.jit.fuser(...)
+ BACKENDS["aot_ts"] = aot_ts
+
+ # prims_nvfuser uses the prims and AOT-Autograd to get FX-aten IR. And then
+ # directly lowers to NVFuser without relying no Torchscript.
+ BACKENDS["prims_nvfuser"] = aot_prims_nvfuser
+
+ # "nvprims" is a subset of PrimTorch primitives that are guaranteed to be
+ # supported by nvFuser. This is the preferred backend for nvFuser+PrimTorch.
+ BACKENDS["nvprims_nvfuser"] = aot_nvprims_nvfuser
+ # This is useful for debugging. Can be removed later.
+ BACKENDS["nvprims_aten"] = aot_nvprims_aten
+
+ # aot_nvfuser uses the memory efficient fusion algorithm from AOT Autograd.
+ # It uses min cut rematerialization algorithm, and uses nvfuser as the
+ # compiler backend. This is the most optimized setting with nvfuser for
+ # training.
+ BACKENDS["aot_nvfuser"] = aot_mem_efficient_fusion
+
+ # Similar to aot_nvfuser, but disables the decompositions. Decompositions
+ # can cause accuracy deviations. This setting allows us to compare accuracy
+ # without worrying about the impact of decomposisitons. More details at
+ # https://github.com/pytorch/torchdynamo/issues/611
+ BACKENDS["aot_nvfuser_nodecomps"] = aot_mem_efficient_fusion_no_decomp
+
+ # aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
+ # for debugging and can serve as a perf baseline.
+ BACKENDS["aot_cudagraphs"] = aot_cudagraphs
+
+ # aot_inductor_debug just replaces the inductor compiler with nop to help
+ # isolate inductor vs aot_eager errors
+ BACKENDS["aot_inductor_debug"] = aot_inductor_debug
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
new file mode 100644
index 0000000000000..7a739b7414657
--- /dev/null
+++ b/torch/_dynamo/output_graph.py
@@ -0,0 +1,523 @@
+import collections
+import functools
+import itertools
+import logging
+import operator
+import re
+import traceback
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional
+
+import torch.nn
+from torch import fx
+
+from . import config, logging as torchdynamo_logging, variables
+from .bytecode_transformation import create_instruction, Instruction, unique_id
+from .codegen import PyCodegen
+from .exc import BackendCompilerFailed, unimplemented
+from .guards import GuardBuilder
+from .mutation_guard import is_dynamic_nn_module
+from .side_effects import SideEffects
+from .source import ConstantSource, LocalSource, Source
+from .utils import (
+ CleanupHook,
+ count_calls,
+ counters,
+ fake_tensors_available,
+ format_graph_tabular,
+)
+from .variables.builder import VariableBuilder
+from .variables.nn_module import NNModuleVariable
+from .variables.tensor import (
+ TensorVariable,
+ UnspecializedNumpyVariable,
+ UnspecializedPythonVariable,
+)
+
+log = logging.getLogger(__name__)
+
+
+@dataclass
+class GraphCompileReason:
+ """Stores why a given output graph was compiled; i.e. what caused the graph break."""
+
+ reason: str
+ user_stack: List[traceback.FrameSummary]
+
+
+def _get_gen_rand_values_fn(random_calls):
+ def _gen_rand_values():
+ return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]
+
+ return _gen_rand_values
+
+
+class FakeRootModule(torch.nn.Module):
+ """Trick the constructor of fx.GraphModule"""
+
+ def __init__(self, nn_modules: dict):
+ super(FakeRootModule, self).__init__()
+ for k, v in nn_modules.items():
+ setattr(self, k, v)
+
+ def __repr__(self):
+ return "FakeRootModule(...)"
+
+
+@functools.lru_cache(None)
+def _step_logger():
+ return torchdynamo_logging.get_step_logger(log)
+
+
+class OutputGraph(fx.Tracer):
+ """
+ Wrapper class to hold outputs of InstructionTranslator. Mainly the
+ generated fx.Graph.
+ """
+
+ def __init__(
+ self,
+ f_globals: Dict[str, Any],
+ code_options: Dict[str, Any],
+ compiler_fn: Callable,
+ root_tx,
+ ):
+ super(OutputGraph, self).__init__()
+
+ # Mutable state checkpointed by copy_graphstate()
+ self.graph = torch.fx.Graph()
+ self.graphargs = []
+ self.guards = set()
+ self.nn_modules = dict()
+ self.side_effects = SideEffects()
+ self.code_options = dict(code_options)
+ self.output_instructions = []
+
+ # Not checkpointed
+ self.compiler_fn = compiler_fn
+ self.root_globals = f_globals
+ self.root_tx = root_tx
+ self.cleanups = []
+ self.should_exit = False
+ self.random_values_var = None
+ self.initial_random_state = ()
+ self.unspec_variable_map = {}
+
+ @property
+ def output(self):
+ return self
+
+ @property
+ def fake_mode(self):
+ return self.root_tx.fake_mode
+
+ def copy_graphstate(self):
+ """Create a checkpoint of the current state by copying everything"""
+ graph_nodes = set(self.graph.nodes)
+ return (
+ graph_nodes,
+ list(self.graphargs),
+ set(self.guards),
+ dict(self.nn_modules),
+ self.side_effects.clone(),
+ )
+
+ def restore_graphstate(self, state):
+ """Restore a checkpoint created by self.copy_graphstate()"""
+ (
+ graph_nodes,
+ self.graphargs,
+ self.guards,
+ self.nn_modules,
+ self.side_effects,
+ ) = state
+ # FX deepcopy doesn't work for a partially created graph, so just remove new nodes
+ for node in reversed(list(self.graph.nodes)):
+ if node not in graph_nodes:
+ # Erasing node alone does not remove the meta information
+ # So, remove the help tensor explicitly
+ if "example_value" in node.meta:
+ del node.meta["example_value"]
+ self.graph.erase_node(node)
+
+ def count_calls(self):
+ return count_calls(self.graph)
+
+ def get_submodule(self, keys):
+ assert keys
+ obj = self.nn_modules
+ for k in keys.split("."):
+ if isinstance(obj, dict):
+ obj = obj[k]
+ else:
+ obj = getattr(obj, k)
+ return obj
+
+ def create_graph_input(self, name, type_expr=None):
+ placeholders = [n for n in self.graph.nodes if n.op == "placeholder"]
+
+ # unique
+ used_names = {n.target for n in placeholders}
+ if name in used_names:
+ for i in itertools.count():
+ if f"{name}_{i}" not in used_names:
+ name = f"{name}_{i}"
+ break
+
+ if placeholders:
+ ctx = self.graph.inserting_after(placeholders[-1])
+ else:
+ ctx = self.graph.inserting_before(None)
+ with ctx:
+ return self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
+
+ def new_var(self, name="tmp"):
+ existing = set(self.code_options["co_varnames"])
+ for i in itertools.count():
+ var = f"___{name}_{i}"
+ if var not in existing:
+ self.code_options["co_varnames"] = self.code_options["co_varnames"] + (
+ var,
+ )
+ return var
+
+ def update_co_names(self, name):
+ """Ensure self.code_options.co_names contains name"""
+ if name not in self.code_options["co_names"]:
+ self.code_options["co_names"] = tuple(self.code_options["co_names"]) + (
+ name,
+ )
+
+ def register_attr_or_module(self, mod: torch.nn.Module, *names, **options):
+ if is_dynamic_nn_module(mod):
+ return variables.UnspecializedNNModuleVariable(mod, **options)
+
+ options = dict(options)
+ options["guards"] = set(options.get("guards", []))
+ source: Source = options.get("source", None)
+ if isinstance(mod, torch.Tensor):
+ if source:
+ options["guards"].add(source.make_guard(GuardBuilder.TENSOR_MATCH))
+
+ def wrap_name(module_key):
+ return TensorVariable.create(
+ self,
+ self.create_proxy("get_attr", module_key, tuple(), {}),
+ example_value=mod,
+ **options,
+ )
+
+ elif isinstance(mod, torch.nn.Module):
+ assert isinstance(mod, torch.nn.Module)
+ options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE))
+
+ def wrap_name(module_key):
+ return NNModuleVariable(type(mod), module_key, **options)
+
+ else:
+
+ def wrap_name(module_key):
+ self.output.update_co_names(module_key)
+ self.root_globals[module_key] = mod
+ return VariableBuilder(self, ConstantSource(source_name=module_key))(
+ mod
+ )
+
+ for k, v in self.nn_modules.items():
+ if v is mod:
+ # it already exists
+ return wrap_name(k)
+
+ # create a new unique name
+ name = re.sub(r"[^a-zA-Z0-9]", "_", "_".join(map(str, names)))
+ if not name or not name[0].isalpha():
+ name = "sub" + name
+ base = name
+ for i in itertools.count():
+ if name not in self.nn_modules:
+ self.nn_modules[name] = mod
+ return wrap_name(name)
+ name = f"{base}_{i}"
+
+ raise AssertionError("unreachable")
+
+ def compile_subgraph(
+ self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None
+ ):
+ """
+ Generate a subgraph to continue execution on user code.
+ Automatically restore live variables.
+ """
+ from .eval_frame import disable
+
+ self.partial_convert = partial_convert
+ self.compile_subgraph_reason = reason
+
+ if not all(block.can_restore() for block in tx.block_stack):
+ unimplemented("compile_subgraph with block_depth != 0")
+
+ for block in reversed(tx.block_stack):
+ block.exit(tx)
+
+ tx.prune_dead_locals()
+ stack_values = list(tx.stack)
+ root = FakeRootModule(self.nn_modules)
+
+ # Add all the local vars to the "stack" so restore at the end
+ restore_vars = []
+ val_to_names = collections.OrderedDict()
+ if stack_values:
+ val_to_names[stack_values[-1]] = list()
+ for k, v in tx.symbolic_locals.items():
+ if isinstance(v.source, LocalSource) and v.source.name() == k:
+ continue # no need to restore initial state
+ if v not in val_to_names:
+ val_to_names[v] = list()
+ val_to_names[v].append(k)
+ for v in val_to_names.keys():
+ restore_vars.extend(val_to_names[v])
+ stack_values.extend([v] * len(val_to_names[v]))
+
+ # to handle random calls
+ if len(tx.random_calls) > 0:
+ random_calls_instructions = []
+ self.random_values_var = self.new_var("random_values")
+ rand_fn_name = unique_id("__gen_rand_values")
+ rand_fn = disable(_get_gen_rand_values_fn(tx.random_calls))
+ self.install_global(rand_fn_name, rand_fn)
+ codegen = PyCodegen(tx, root)
+ random_calls_instructions.extend(
+ [
+ codegen.create_load_global("random", add=True),
+ codegen.create_load_attr("setstate"),
+ codegen.create_load_const(tx.output.initial_random_state),
+ create_instruction("CALL_FUNCTION", 1),
+ ]
+ )
+ random_calls_instructions.extend(codegen.load_function_name(rand_fn_name))
+ random_calls_instructions.extend(
+ [
+ create_instruction("CALL_FUNCTION", 0),
+ codegen.create_store(tx.output.random_values_var),
+ ]
+ )
+ self.add_output_instructions(random_calls_instructions)
+
+ if (
+ stack_values
+ and all(
+ not isinstance(
+ v, (UnspecializedNumpyVariable, UnspecializedPythonVariable)
+ )
+ for v in stack_values
+ )
+ and all(isinstance(x, TensorVariable) for x in stack_values)
+ and len(set(stack_values)) == len(stack_values)
+ and self.side_effects.is_empty()
+ ):
+ # optimization to generate better code in a common case
+ self.add_output_instructions(
+ self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
+ + [create_instruction("UNPACK_SEQUENCE", len(stack_values))]
+ )
+ else:
+ graph_output_var = self.new_var("graph_out")
+ pass1 = PyCodegen(tx, root, graph_output_var)
+ self.side_effects.codegen_save_tempvars(pass1)
+ pass1.foreach(stack_values)
+ self.side_effects.codegen_update_mutated(pass1)
+
+ # one more time now that we have established tempvars
+ pass2 = PyCodegen(
+ tx,
+ root,
+ graph_output_var,
+ tempvars={val: None for val, count in pass1.uses.items() if count > 1},
+ )
+ self.side_effects.codegen_save_tempvars(pass2)
+ pass2.foreach(stack_values)
+ self.side_effects.codegen_update_mutated(pass2)
+
+ output = []
+ if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
+ output.extend(
+ self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
+ )
+
+ if len(pass2.graph_outputs) != 0:
+ output.append(pass2.create_store(graph_output_var))
+ else:
+ output.append(create_instruction("POP_TOP"))
+ self.add_output_instructions(output + pass2.get_instructions())
+
+ # restore all the live local vars
+ self.add_output_instructions(
+ [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
+ )
+
+ def compile_and_call_fx_graph(self, tx, rv, root):
+ """
+ Generate code from self.graph and return the Instruction()s to
+ call that generated code.
+ """
+ from .eval_frame import disable
+
+ assert isinstance(rv, list)
+ assert isinstance(root, FakeRootModule)
+ for output in rv:
+ self.guards.update(output.guards)
+
+ self.create_node(
+ "output", "output", (self.create_arg(tuple(x.as_proxy() for x in rv)),), {}
+ )
+ self.remove_unused_graphargs()
+ ncalls = count_calls(self.graph)
+ counters["stats"]["calls_captured"] += ncalls
+ counters["stats"]["fusions_possible"] += ncalls - 1
+
+ if config.dynamic_propagation:
+ # free a bit of memory
+ for node in self.graph.nodes:
+ if "example_value" in node.meta:
+ del node.meta["example_value"]
+
+ gm = fx.GraphModule(root, self.graph)
+ gm.recompile()
+ gm.compile_subgraph_reason = self.compile_subgraph_reason
+ name = unique_id("__compiled_fn")
+ compiled_fn = self.call_user_compiler(gm)
+ compiled_fn = disable(compiled_fn)
+ counters["stats"]["unique_graphs"] += 1
+ self.install_global(name, compiled_fn)
+
+ try:
+ # the call to tabulate can cause a lot of memory to be allocated
+ if config.log_level <= logging.INFO:
+ log.log(
+ torchdynamo_logging.CODE,
+ f"TRACED GRAPH\n {name} {gm.forward.__code__.co_filename} {format_graph_tabular(gm.graph)}\n",
+ )
+ except ImportError:
+ log.warning(
+ "Unable to print graph: `format_graph_tabular` relies on the library `tabulate`, "
+ "which could not be found on this machine. Run `pip "
+ "install tabulate` to install the library."
+ )
+
+ cg = PyCodegen(tx)
+ cg.make_call_generated_code(name)
+ return cg.get_instructions()
+
+ def call_user_compiler(self, gm):
+ try:
+ _step_logger()(logging.INFO, "calling compiler function")
+ compiled_fn = self.compiler_fn(gm, self.example_inputs())
+ _step_logger()(logging.INFO, "done compiler function")
+ assert callable(compiled_fn), "compiler_fn did not return callable"
+ except Exception as e:
+ log.warning("-" * 40 + "\n")
+ log.warning("TORCHDYNAMO: backend compiler failed\n")
+ log.warning(e, exc_info=True)
+ log.warning("-" * 40 + "\n")
+ compiled_fn = gm.forward
+ if config.raise_on_backend_error:
+ raise BackendCompilerFailed(self.compiler_fn, e) from e
+ return compiled_fn
+
+ def example_inputs(self):
+ result = []
+ for arg in self.graphargs:
+ result.extend(arg.get_examples())
+ return result
+
+ def remove_unused_graphargs(self):
+ for node in reversed(list(self.graph.nodes)):
+ if len(list(node.users)) == 0:
+ if node.op == "get_attr":
+ self.graph.erase_node(node)
+ elif node.op == "call_function" and node.target is operator.getitem:
+ self.graph.erase_node(node)
+
+ expanded_graphargs = []
+ for arg in self.graphargs:
+ expanded_graphargs.extend([arg] * len(arg))
+ arg.uses = 0
+
+ for node, arg in zip(self.graph.nodes, expanded_graphargs):
+ assert node.op == "placeholder"
+ arg.uses += len(node.users)
+
+ for node, arg in list(zip(self.graph.nodes, expanded_graphargs)):
+ if arg.uses == 0:
+ if "example_value" in node.meta:
+ del node.meta["example_value"]
+ self.graph.erase_node(node)
+
+ self.graphargs = [arg for arg in self.graphargs if arg.uses > 0]
+
+ def add_output_instructions(self, prefix: List[Instruction]):
+ """
+ We call this on the creation of a new compiled subgraph that is inserted
+ before user code.
+ """
+ self.output_instructions.extend(prefix)
+ self.should_exit = True
+
+ def install_global(self, name, value):
+ self.cleanups.append(CleanupHook.create(self.root_globals, name, value))
+
+ def cleanup(self):
+ # There is a reference cycle between tracer and OutputGraph, causing
+ # some of the tensor objects to be held alive for longer than necessary.
+
+ # Clear cache for conversion of real -> fake tensors
+ if fake_tensors_available:
+ self.root_tx.fake_mode.fake_tensor_converter = None
+ self.root_tx = None
+
+ # Note: generated fx graph will hold a reference to the nn_module,
+ # So depending on the backend they may not be released
+ self.nn_modules = None
+
+ # Cleanup graphargs
+ for graph_arg in self.graphargs:
+ graph_arg.erase()
+
+ for node in self.graph.nodes:
+ if "example_value" in node.meta:
+ del node.meta["example_value"]
+
+ def create_proxy(
+ self,
+ kind,
+ target,
+ args,
+ kwargs,
+ name=None,
+ type_expr=None,
+ proxy_factory_fn=None,
+ current_tx=None,
+ ):
+ rv = super().create_proxy(
+ kind, target, args, kwargs, name, type_expr, proxy_factory_fn
+ )
+
+ # append stack trace to fx node
+ tx = current_tx if current_tx else self.root_tx
+
+ nn_module_stack = tx.nn_module_stack
+ if nn_module_stack:
+ rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
+
+ frame_summaries: List[traceback.FrameSummary] = []
+ while tx:
+ frame_summaries.append(tx.frame_summary())
+ tx = getattr(tx, "parent", None)
+
+ msgs = traceback.StackSummary.from_list(frame_summaries).format()
+
+ # Carry module_stack along with node.stack_trace for reusing stacktrace propagation infra
+ nn_module_stack_str = f"Module stack: {nn_module_stack}\n"
+ rv.node.stack_trace = nn_module_stack_str + " | ".join(msgs)
+
+ return rv
diff --git a/torch/_dynamo/profiler.py b/torch/_dynamo/profiler.py
new file mode 100644
index 0000000000000..b5a667070a8cd
--- /dev/null
+++ b/torch/_dynamo/profiler.py
@@ -0,0 +1,177 @@
+import dataclasses
+import os
+from typing import Any, List
+
+import torch
+
+from . import config
+from .utils import print_once
+
+
+@dataclasses.dataclass
+class ProfileMetrics:
+ microseconds: float = 0.0
+ operators: int = 0
+ fusions: int = 0
+ graphs: int = 0
+
+ def __iadd__(self, other: "ProfileMetrics"):
+ self.microseconds += other.microseconds
+ self.operators += other.operators
+ self.fusions += other.fusions
+ return self
+
+ def __add__(self, other: "ProfileMetrics"):
+ assert isinstance(other, ProfileMetrics)
+ return ProfileMetrics(
+ self.microseconds + other.microseconds,
+ self.operators + other.operators,
+ self.fusions + other.fusions,
+ )
+
+ def __truediv__(self, other):
+ if isinstance(other, int):
+ other = ProfileMetrics(other, other, other)
+ return ProfileMetrics(
+ self.microseconds / max(1, other.microseconds),
+ self.operators / max(1, other.operators),
+ self.fusions / max(1, other.fusions),
+ )
+
+ def __str__(self):
+ return f"{self.operators:4.0%} ops {self.microseconds:4.0%} time"
+
+ def tocsv(self):
+ return [self.operators, self.microseconds]
+
+
+class ProfileResult:
+ def __init__(self, captured, total, unique_graphs):
+ self.captured: ProfileMetrics = captured or ProfileMetrics()
+ self.total: ProfileMetrics = total or ProfileMetrics()
+ self.unique_graphs: int = unique_graphs
+
+ def __iadd__(self, other: ProfileMetrics):
+ self.captured += other.captured
+ self.total += other.total
+ self.unique_graphs += other.unique_graphs
+ return self
+
+ def percent(self):
+ return self.captured / self.total
+
+ def __str__(self):
+ return (
+ f"{self.unique_graphs:2} graphs {self.captured.graphs:2} graph calls "
+ f"{self.captured.operators:4}/{self.total.operators:4} = "
+ + str(self.percent())
+ )
+
+ def tocsv(self):
+ return [
+ self.unique_graphs,
+ self.captured.graphs,
+ self.captured.operators,
+ self.total.operators,
+ ] + self.percent().tocsv()
+
+
+def should_print_missing():
+ return os.environ.get("TORCHDYNAMO_PRINT_MISSING") == "1"
+
+
+def print_missing(stack):
+ if any("/torch/autograd/profiler.py" in x for x in stack):
+ return
+ stack = [
+ x for x in stack if ("> ".join(stack[-3:]))
+
+
+class Profiler:
+ unique_graphs = 0
+
+ def __init__(self):
+ self.prof = torch.profiler.profile(
+ activities=[torch.profiler.ProfilerActivity.CPU],
+ with_stack=should_print_missing(),
+ )
+
+ def results(self):
+ captured_regions = 0
+ captured_ops = 0
+ captured_microseconds = 0
+ total_ops = 0
+ total_microseconds = 0
+
+ last_op_end_time = -1
+ captured_region_end_time = -1
+ events = list(sorted(self.prof.events(), key=lambda x: x.time_range.start))
+ for e in events:
+ if e.name == "TORCHDYNAMO":
+ captured_region_end_time = e.time_range.end
+ captured_regions += 1
+ # ignore `handle = torch.zeros(1)` in record_function.__init__()
+ total_ops -= 1
+ elif e.time_range.start >= last_op_end_time:
+ last_op_end_time = e.time_range.end
+ if e.time_range.end <= captured_region_end_time:
+ captured_ops += 1
+ captured_microseconds += e.time_range.elapsed_us()
+ elif should_print_missing():
+ print_missing(e.stack)
+ total_ops += 1
+ total_microseconds += e.time_range.elapsed_us()
+ else:
+ pass # ops recursively called from other ops (ignored)
+
+ unique_graphs = Profiler.unique_graphs
+ Profiler.unique_graphs = 0
+
+ return ProfileResult(
+ captured=ProfileMetrics(
+ microseconds=captured_microseconds,
+ operators=captured_ops,
+ fusions=captured_ops - captured_regions,
+ graphs=captured_regions,
+ ),
+ total=ProfileMetrics(
+ microseconds=total_microseconds,
+ operators=total_ops,
+ fusions=total_ops - 1,
+ ),
+ unique_graphs=unique_graphs,
+ )
+
+
+def shapes_of(it):
+ if it:
+ return [tuple(getattr(x, "shape", [])) for x in it]
+
+
+def fx_insert_profiling(gm: torch.fx.GraphModule, example_inputs: List[Any]):
+ input_shapes = shapes_of(example_inputs)
+ output_shapes = None
+
+ def debug_print(extra):
+ gm.graph.print_tabular()
+ return f"shape mismatch in={input_shapes} out={output_shapes} got={extra}"
+
+ def _wrapped(*args):
+ nonlocal output_shapes
+ with torch.profiler.record_function("TORCHDYNAMO"):
+ assert (
+ shapes_of(args) == input_shapes or config.dynamic_shapes
+ ), debug_print(shapes_of(args))
+ result = gm.forward(*args)
+ if output_shapes is None:
+ output_shapes = shapes_of(result)
+ else:
+ assert (
+ shapes_of(result) == output_shapes or config.dynamic_shapes
+ ), debug_print(shapes_of(result))
+ return result
+
+ Profiler.unique_graphs += 1
+ return _wrapped
diff --git a/torch/_dynamo/replay_record.py b/torch/_dynamo/replay_record.py
new file mode 100644
index 0000000000000..f09d9bf9c8783
--- /dev/null
+++ b/torch/_dynamo/replay_record.py
@@ -0,0 +1,118 @@
+import dataclasses
+from dataclasses import field
+from types import CodeType, ModuleType
+from typing import Any, Dict
+
+try:
+ import dill
+except ImportError:
+ dill = None
+
+
+@dataclasses.dataclass
+class ModuleRecord:
+ module: ModuleType
+ accessed_attrs: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclasses.dataclass
+class DummyModule:
+ name: str
+
+
+@dataclasses.dataclass
+class ExecutionRecord:
+ code: CodeType
+ globals: Dict[str, Any] = field(default_factory=dict)
+ locals: Dict[str, Any] = field(default_factory=dict)
+ builtins: Dict[str, Any] = field(default_factory=dict)
+ code_options: Dict[str, Any] = field(default_factory=dict)
+
+ def dump(self, f):
+ assert dill is not None, "replay_record requires `pip install dill`"
+ dill.dump(self, f)
+
+ @classmethod
+ def load(cls, f):
+ assert dill is not None, "replay_record requires `pip install dill`"
+ return dill.load(f)
+
+
+@dataclasses.dataclass
+class ExecutionRecorder:
+ MOD_EXCLUDES = ["torch"]
+ LOCAL_MOD_PREFIX = "___local_mod_"
+
+ code: CodeType
+ globals: Dict[str, Any] = field(default_factory=dict)
+ locals: Dict[str, Any] = field(default_factory=dict)
+ builtins: Dict[str, Any] = field(default_factory=dict)
+ code_options: Dict[str, Any] = field(default_factory=dict)
+ name_to_modrec: Dict[str, Any] = field(default_factory=dict)
+
+ def add_local_var(self, name, var):
+ if isinstance(var, ModuleType):
+ if self._is_excl(var):
+ return
+ self.locals[name] = self._add_mod(var)
+ else:
+ self.locals[name] = var
+
+ def add_global_var(self, name, var):
+ if isinstance(var, ModuleType):
+ if self._is_excl(var):
+ return
+ self.globals[name] = self._add_mod(var)
+ else:
+ self.globals[name] = var
+
+ def add_local_mod(self, name, mod):
+ assert isinstance(mod, ModuleType)
+ if self._is_excl(mod):
+ return
+
+ self.add_global_var(name, mod)
+
+ def record_module_access(self, mod, name, val):
+ if self._is_excl(mod):
+ return
+ if isinstance(val, ModuleType):
+ self.name_to_modrec[mod.__name__].accessed_attrs[name] = self._add_mod(val)
+ return
+
+ self.name_to_modrec[mod.__name__].accessed_attrs[name] = val
+
+ def get_record(self):
+ return ExecutionRecord(
+ self.code,
+ ExecutionRecorder._resolve_modules(self.globals),
+ ExecutionRecorder._resolve_modules(self.locals),
+ self.builtins.copy(),
+ self.code_options.copy(),
+ )
+
+ def _add_mod(self, mod):
+ if mod.__name__ not in self.name_to_modrec:
+ self.name_to_modrec[mod.__name__] = ModuleRecord(mod)
+
+ return self.name_to_modrec[mod.__name__]
+
+ @classmethod
+ def _is_excl(cls, mod):
+ return any([mod.__name__ == excl for excl in cls.MOD_EXCLUDES])
+
+ # Convert ModuleRecords -> DummyModule tree
+ @classmethod
+ def _resolve_modules(cls, vars):
+ def resolve_module(var):
+ if not isinstance(var, ModuleRecord):
+ return var
+
+ dummy_mod = DummyModule(var.module.__name__)
+ for attr_name, attr_value in var.accessed_attrs.items():
+ attr_value = resolve_module(attr_value)
+ dummy_mod.__setattr__(attr_name, attr_value)
+
+ return dummy_mod
+
+ return {k: resolve_module(v) for k, v in vars.items()}
diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py
new file mode 100644
index 0000000000000..c05f610d67124
--- /dev/null
+++ b/torch/_dynamo/resume_execution.py
@@ -0,0 +1,304 @@
+import copy
+import dataclasses
+import sys
+import types
+from typing import Any, Dict, List
+
+from .bytecode_transformation import (
+ create_instruction,
+ Instruction,
+ transform_code_object,
+)
+from .codegen import PyCodegen
+from .utils import ExactWeakKeyDictionary
+
+# taken from code.h in cpython
+CO_OPTIMIZED = 0x0001
+CO_NEWLOCALS = 0x0002
+CO_VARARGS = 0x0004
+CO_VARKEYWORDS = 0x0008
+CO_NESTED = 0x0010
+CO_GENERATOR = 0x0020
+CO_NOFREE = 0x0040
+CO_COROUTINE = 0x0080
+CO_ITERABLE_COROUTINE = 0x0100
+CO_ASYNC_GENERATOR = 0x0200
+
+
+@dataclasses.dataclass(frozen=True)
+class ReenterWith:
+ stack_index: int = None
+
+ def __call__(self, code_options, cleanup):
+ if sys.version_info < (3, 9):
+ with_cleanup_start = create_instruction("WITH_CLEANUP_START")
+ if sys.version_info < (3, 8):
+ begin_finally = create_instruction(
+ "LOAD_CONST", PyCodegen.get_const_index(code_options, None), None
+ )
+ else:
+ begin_finally = create_instruction("BEGIN_FINALLY")
+ cleanup[:] = [
+ create_instruction("POP_BLOCK"),
+ begin_finally,
+ with_cleanup_start,
+ create_instruction("WITH_CLEANUP_FINISH"),
+ create_instruction("END_FINALLY"),
+ ] + cleanup
+
+ return [
+ create_instruction("CALL_FUNCTION", 0),
+ create_instruction("SETUP_WITH", target=with_cleanup_start),
+ create_instruction("POP_TOP"),
+ ]
+ else:
+
+ with_except_start = create_instruction("WITH_EXCEPT_START")
+ pop_top_after_with_except_start = create_instruction("POP_TOP")
+
+ cleanup_complete_jump_target = create_instruction("NOP")
+
+ cleanup[:] = [
+ create_instruction("POP_BLOCK"),
+ create_instruction(
+ "LOAD_CONST", PyCodegen.get_const_index(code_options, None), None
+ ),
+ create_instruction("DUP_TOP"),
+ create_instruction("DUP_TOP"),
+ create_instruction("CALL_FUNCTION", 3),
+ create_instruction("POP_TOP"),
+ create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
+ with_except_start,
+ create_instruction(
+ "POP_JUMP_IF_TRUE", target=pop_top_after_with_except_start
+ ),
+ create_instruction("RERAISE"),
+ pop_top_after_with_except_start,
+ create_instruction("POP_TOP"),
+ create_instruction("POP_TOP"),
+ create_instruction("POP_EXCEPT"),
+ create_instruction("POP_TOP"),
+ cleanup_complete_jump_target,
+ ] + cleanup
+
+ return [
+ create_instruction("CALL_FUNCTION", 0),
+ create_instruction("SETUP_WITH", target=with_except_start),
+ create_instruction("POP_TOP"),
+ ]
+
+
+@dataclasses.dataclass
+class ResumeFunctionMetadata:
+ code: types.CodeType
+ instructions: List[Instruction] = None
+
+
+class ContinueExecutionCache:
+ cache = ExactWeakKeyDictionary()
+ generated_code_metadata = ExactWeakKeyDictionary()
+
+ @classmethod
+ def lookup(cls, code, lineno, *key):
+ if code not in cls.cache:
+ cls.cache[code] = dict()
+ key = tuple(key)
+ if key not in cls.cache[code]:
+ cls.cache[code][key] = cls.generate(code, lineno, *key)
+ return cls.cache[code][key]
+
+ @classmethod
+ def generate(
+ cls,
+ code,
+ lineno,
+ offset: int,
+ nstack: int,
+ argnames: List[str],
+ setup_fns: List[ReenterWith],
+ ):
+ assert offset is not None
+ assert not (
+ code.co_flags
+ & (CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR)
+ )
+ assert code.co_flags & CO_OPTIMIZED
+ if code in ContinueExecutionCache.generated_code_metadata:
+ return cls.generate_based_on_original_code_object(
+ code, lineno, offset, nstack, argnames, setup_fns
+ )
+
+ meta = ResumeFunctionMetadata(code)
+
+ def update(instructions: List[Instruction], code_options: Dict[str, Any]):
+ meta.instructions = copy.deepcopy(instructions)
+
+ args = [f"___stack{i}" for i in range(nstack)]
+ args.extend(v for v in argnames if v not in args)
+ freevars = tuple(code_options["co_cellvars"] or []) + tuple(
+ code_options["co_freevars"] or []
+ )
+ code_options["co_name"] = f""
+ code_options["co_firstlineno"] = lineno
+ code_options["co_cellvars"] = tuple()
+ code_options["co_freevars"] = freevars
+ code_options["co_argcount"] = len(args)
+ code_options["co_posonlyargcount"] = 0
+ code_options["co_kwonlyargcount"] = 0
+ code_options["co_varnames"] = tuple(
+ args + [v for v in code_options["co_varnames"] if v not in args]
+ )
+ code_options["co_flags"] = code_options["co_flags"] & ~(
+ CO_VARARGS | CO_VARKEYWORDS
+ )
+ (target,) = [i for i in instructions if i.offset == offset]
+
+ prefix = []
+ cleanup = []
+ hooks = {fn.stack_index: fn for fn in setup_fns}
+ for i in range(nstack):
+ prefix.append(create_instruction("LOAD_FAST", f"___stack{i}"))
+ if i in hooks:
+ prefix.extend(hooks.pop(i)(code_options, cleanup))
+ assert not hooks
+
+ prefix.append(create_instruction("JUMP_ABSOLUTE", target=target))
+
+ # because the line number table monotonically increases from co_firstlineno
+ # remove starts_line for any instructions before the graph break instruction
+ # this will ensure the instructions after the break have the correct line numbers
+ target_ind = int(target.offset / 2)
+ for inst in instructions[0:target_ind]:
+ inst.starts_line = None
+
+ if cleanup:
+ prefix.extend(cleanup)
+ prefix.extend(cls.unreachable_codes(code_options))
+
+ # TODO(jansel): add dead code elimination here
+ instructions[:] = prefix + instructions
+
+ new_code = transform_code_object(code, update)
+ ContinueExecutionCache.generated_code_metadata[new_code] = meta
+ return new_code
+
+ @staticmethod
+ def unreachable_codes(code_options):
+ """Codegen a `raise None` to make analysis work for unreachable code"""
+ if None not in code_options["co_consts"]:
+ code_options["co_consts"] = tuple(code_options["co_consts"]) + (None,)
+ return [
+ create_instruction(
+ "LOAD_CONST",
+ argval=None,
+ arg=code_options["co_consts"].index(None),
+ ),
+ create_instruction("RAISE_VARARGS", 1),
+ ]
+
+ @classmethod
+ def generate_based_on_original_code_object(cls, code, lineno, offset: int, *args):
+ """
+ This handles the case of generating a resume into code generated
+ to resume something else. We want to always generate starting
+ from the original code object so that if control flow paths
+ converge we only generated 1 resume function (rather than 2^n
+ resume functions).
+ """
+
+ meta: ResumeFunctionMetadata = ContinueExecutionCache.generated_code_metadata[
+ code
+ ]
+ new_offset = None
+
+ def find_new_offset(
+ instructions: List[Instruction], code_options: Dict[str, Any]
+ ):
+ nonlocal new_offset
+ (target,) = [i for i in instructions if i.offset == offset]
+ # match the functions starting at the last instruction as we have added a prefix
+ (new_target,) = [
+ i2
+ for i1, i2 in zip(reversed(instructions), reversed(meta.instructions))
+ if i1 is target
+ ]
+ assert target.opcode == new_target.opcode
+ new_offset = new_target.offset
+
+ transform_code_object(code, find_new_offset)
+ return ContinueExecutionCache.lookup(meta.code, lineno, new_offset, *args)
+
+
+"""
+# partially finished support for with statements
+
+def convert_locals_to_cells(
+ instructions: List[Instruction],
+ code_options: Dict[str, Any]):
+
+ code_options["co_cellvars"] = tuple(
+ var
+ for var in code_options["co_varnames"]
+ if var not in code_options["co_freevars"]
+ and not var.startswith("___stack")
+ )
+ cell_and_free = code_options["co_cellvars"] + code_options["co_freevars"]
+ for inst in instructions:
+ if str(inst.argval).startswith("___stack"):
+ continue
+ elif inst.opname == "LOAD_FAST":
+ inst.opname = "LOAD_DEREF"
+ elif inst.opname == "STORE_FAST":
+ inst.opname = "STORE_DEREF"
+ elif inst.opname == "DELETE_FAST":
+ inst.opname = "DELETE_DEREF"
+ else:
+ continue
+ inst.opcode = dis.opmap[inst.opname]
+ assert inst.argval in cell_and_free, inst.argval
+ inst.arg = cell_and_free.index(inst.argval)
+
+def patch_setup_with(
+ instructions: List[Instruction],
+ code_options: Dict[str, Any]
+):
+ nonlocal need_skip
+ need_skip = True
+ target_index = [
+ idx for idx, i in enumerate(instructions) if i.offset == offset
+ ][0]
+ assert instructions[target_index].opname == "SETUP_WITH"
+ convert_locals_to_cells(instructions, code_options)
+
+ stack_depth_before = nstack + stack_effect(instructions[target_index].opcode,
+ instructions[target_index].arg)
+
+ inside_with = []
+ inside_with_resume_at = None
+ stack_depth = stack_depth_before
+ idx = target_index + 1
+ for idx in range(idx, len(instructions)):
+ inst = instructions[idx]
+ if inst.opname == "BEGIN_FINALLY":
+ inside_with_resume_at = inst
+ break
+ elif inst.target is not None:
+ unimplemented("jump from with not supported")
+ elif inst.opname in ("BEGIN_FINALLY", "WITH_CLEANUP_START", "WITH_CLEANUP_FINISH", "END_FINALLY",
+ "POP_FINALLY", "POP_EXCEPT",
+ "POP_BLOCK", "END_ASYNC_FOR"):
+ unimplemented("block ops not supported")
+ inside_with.append(inst)
+ stack_depth += stack_effect(inst.opcode, inst.arg)
+ assert inside_with_resume_at
+
+ instructions = [
+ create_instruction("LOAD_FAST", f"___stack{i}") for i in range(nstack)
+ ] + [
+ create_instruction("SETUP_WITH", target=instructions[target_index].target)
+ ... call the function ...
+ unpack_tuple
+ ] + [
+ create_instruction("JUMP_ABSOLUTE", target=inside_with_resume_at)
+ ]
+"""
diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py
new file mode 100644
index 0000000000000..1f8675ae1c9e3
--- /dev/null
+++ b/torch/_dynamo/side_effects.py
@@ -0,0 +1,336 @@
+import collections
+import dataclasses
+import inspect
+from typing import Any
+
+import torch.nn
+
+from . import utils, variables
+from .bytecode_transformation import create_instruction
+from .codegen import PyCodegen
+from .source import LocalSource, Source
+from .utils import object_new
+from .variables.base import VariableTracker
+
+
+@dataclasses.dataclass
+class MutableSideEffects:
+ """
+ VariableTracker.mutable_local marker to indicate a list passed as
+ an input that if we mutate we need to re-apply those mutations after
+ the graph runs.
+ """
+
+ source: Source
+ is_modified: bool = False
+
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ return self is other
+
+
+@dataclasses.dataclass
+class AttributeMutation:
+ """
+ VariableTracker.mutable_local marker to track changes to attributes
+ """
+
+ source: Source
+
+
+class AttributeMutationExisting(AttributeMutation):
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ return self is other
+
+
+@dataclasses.dataclass
+class AttributeMutationNew(AttributeMutation):
+ cls_source: Source
+
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ return self is other
+
+
+class SideEffects(object):
+ """
+ Track side effects (list mutation, setattr, etc) that need to be
+ applied after an FX graph is run.
+ """
+
+ def __init__(self, id_to_variable=None, store_attr_mutations=None, keepalive=None):
+ super(SideEffects, self).__init__()
+ self.id_to_variable = id_to_variable or collections.OrderedDict()
+ self.store_attr_mutations = store_attr_mutations or collections.OrderedDict()
+ self.keepalive = keepalive or []
+
+ def clone(self):
+ """Create a shallow copy"""
+ return self.__class__(
+ id_to_variable=collections.OrderedDict(self.id_to_variable),
+ store_attr_mutations=collections.OrderedDict(
+ (k, collections.OrderedDict(v))
+ for k, v in self.store_attr_mutations.items()
+ ),
+ keepalive=list(self.keepalive),
+ )
+
+ def apply(self, fn, cache=None):
+ if cache is None:
+ cache = dict()
+
+ self.id_to_variable = collections.OrderedDict(
+ (k, VariableTracker.apply(fn, v, cache))
+ for k, v in self.id_to_variable.items()
+ )
+ self.store_attr_mutations = collections.OrderedDict(
+ (k, VariableTracker.apply(fn, v, cache))
+ for k, v in self.store_attr_mutations.items()
+ )
+
+ def __contains__(self, item):
+ return id(item) in self.id_to_variable
+
+ def __getitem__(self, item):
+ return self.id_to_variable[id(item)]
+
+ def store_attr(self, item: VariableTracker, name: str, value: VariableTracker):
+ assert self.is_attribute_mutation(item)
+ if item.mutable_local not in self.store_attr_mutations:
+ self.store_attr_mutations[item.mutable_local] = collections.OrderedDict()
+ self.store_attr_mutations[item.mutable_local][name] = value
+
+ def load_attr(self, item, name):
+ assert self.is_attribute_mutation(item)
+ return self.store_attr_mutations[item.mutable_local][name]
+
+ def store_cell(self, cellvar, value):
+ assert isinstance(cellvar, variables.NewCellVariable)
+ assert isinstance(value, variables.VariableTracker)
+ self.store_attr(cellvar, "cell_contents", value)
+
+ def load_cell(self, cellvar):
+ assert isinstance(cellvar, variables.NewCellVariable)
+ return self.load_attr(cellvar, "cell_contents")
+
+ def load_global(self, gvar: VariableTracker, name: str):
+ assert isinstance(gvar, variables.VariableTracker)
+ return self.load_attr(gvar, name)
+
+ def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker):
+ assert isinstance(gvar, variables.VariableTracker)
+ assert isinstance(value, variables.VariableTracker)
+ self.store_attr(gvar, name, value)
+
+ @staticmethod
+ def cls_supports_mutation_side_effects(cls):
+ return inspect.getattr_static(cls, "__setattr__", None) in (
+ object.__setattr__,
+ torch.nn.Module.__setattr__,
+ )
+
+ def is_attribute_mutation(self, item):
+ return isinstance(item.mutable_local, AttributeMutation)
+
+ def is_modified(self, item):
+ if isinstance(item.mutable_local, AttributeMutationNew):
+ return True
+ if self.is_attribute_mutation(item):
+ return item.mutable_local in self.store_attr_mutations
+ return item.mutable_local.is_modified
+
+ def _track_obj(
+ self,
+ source: Source,
+ item: Any,
+ variable: VariableTracker,
+ mutable_cls=MutableSideEffects,
+ ):
+ """Start tracking a new variable for mutation"""
+ variable = variable.clone(mutable_local=mutable_cls(source), source=source)
+ self.id_to_variable[id(item)] = variable
+ self.keepalive.append(item)
+ return variable
+
+ track_list = _track_obj
+ track_dict = _track_obj
+
+ def track_object_existing(
+ self,
+ source: Source,
+ item: Any,
+ variable: VariableTracker,
+ ):
+ return self._track_obj(
+ source, item, variable, mutable_cls=AttributeMutationExisting
+ )
+
+ def track_object_new(
+ self,
+ cls_source: Source,
+ user_cls: Any,
+ variable_cls: Any,
+ options,
+ ):
+ obj = object_new(user_cls)
+ variable = variable_cls(
+ obj, mutable_local=AttributeMutationNew(None, cls_source), **options
+ )
+ self.id_to_variable[id(obj)] = variable
+ self.keepalive.append(obj)
+ return variable
+
+ def track_cell_new(
+ self,
+ ):
+ obj = object()
+ variable = variables.NewCellVariable(
+ mutable_local=AttributeMutationNew(None, None),
+ )
+ self.id_to_variable[id(obj)] = variable
+ self.keepalive.append(obj)
+ return variable
+
+ def track_cell_existing(self, source: Source, item: Any):
+ variable = variables.NewCellVariable(
+ mutable_local=AttributeMutationExisting(source),
+ )
+ self.id_to_variable[id(item)] = variable
+ self.keepalive.append(item)
+ return variable
+
+ def track_global_existing(self, source: Source, item: Any):
+ variable = variables.NewGlobalVariable(
+ mutable_local=AttributeMutationExisting(source),
+ )
+ self.id_to_variable[id(item)] = variable
+ self.keepalive.append(item)
+ return variable
+
+ def prune_dead_object_new(self, tx):
+ live_new_objects = set()
+ skip_obj = None
+
+ def visit(var: VariableTracker):
+ if (
+ isinstance(var.mutable_local, AttributeMutationNew)
+ and var.mutable_local is not skip_obj
+ ):
+ live_new_objects.add(var.mutable_local)
+ return var
+
+ def is_live(var: VariableTracker):
+ if isinstance(var, AttributeMutationNew):
+ return var in live_new_objects
+ if isinstance(var, VariableTracker):
+ return is_live(var.mutable_local)
+ return True
+
+ VariableTracker.apply(visit, (tx.stack, tx.symbolic_locals))
+ for var in self.id_to_variable.values():
+ if not isinstance(var.mutable_local, AttributeMutationNew):
+ VariableTracker.apply(visit, var)
+
+ for skip_obj, setattrs in self.store_attr_mutations.items():
+ VariableTracker.apply(visit, setattrs)
+
+ self.id_to_variable = collections.OrderedDict(
+ (k, v) for k, v in self.id_to_variable.items() if is_live(v)
+ )
+ self.store_attr_mutations = collections.OrderedDict(
+ (k, v) for k, v in self.store_attr_mutations.items() if is_live(k)
+ )
+
+ def mutation(self, oldvar, newvar):
+ return newvar.clone(
+ mutable_local=MutableSideEffects(oldvar.mutable_local.source, True)
+ )
+
+ def _get_modified_vars(self):
+ return [var for var in self.id_to_variable.values() if self.is_modified(var)]
+
+ def codegen_save_tempvars(self, cg: PyCodegen):
+ for var in self._get_modified_vars():
+ if isinstance(
+ var.mutable_local, (AttributeMutationExisting, AttributeMutationNew)
+ ) and isinstance(var, variables.NewCellVariable):
+ cg.load_import_from(utils.__name__, "make_cell")
+ cg.extend_output([create_instruction("CALL_FUNCTION", 0)])
+ cg.add_cache(var)
+ if isinstance(var.mutable_local, AttributeMutationNew):
+ var.mutable_local.source = LocalSource(cg.tempvars[var])
+ elif isinstance(var.mutable_local, AttributeMutationNew):
+ cg.load_import_from(utils.__name__, "object_new")
+ cg(var.mutable_local.cls_source)
+ cg.extend_output([create_instruction("CALL_FUNCTION", 1)])
+ cg.add_cache(var)
+ var.mutable_local.source = LocalSource(cg.tempvars[var])
+ elif var in cg.tempvars:
+ assert cg.tempvars.get(var) is None
+ # subsequent usage should point to the original variable
+ cg(var.mutable_local.source)
+ cg.add_cache(var)
+
+ def codegen_update_mutated(self, cg: PyCodegen):
+ suffixes = []
+ for var in self._get_modified_vars():
+ if isinstance(var, variables.ListVariable):
+ # old[:] = new
+ cg(var, allow_cache=False)
+ cg(var.mutable_local.source)
+ cg.extend_output(
+ [
+ cg.create_load_const(None),
+ cg.create_load_const(None),
+ create_instruction("BUILD_SLICE", 2),
+ ]
+ )
+ suffixes.append([create_instruction("STORE_SUBSCR")])
+ elif isinstance(var, variables.ConstDictVariable):
+ cg.tx.output.update_co_names("clear")
+ cg.tx.output.update_co_names("update")
+
+ cg(var.mutable_local.source)
+ cg.extend_output([create_instruction("LOAD_METHOD", "update")])
+ cg(var, allow_cache=False)
+
+ cg(var.mutable_local.source)
+ cg.extend_output([create_instruction("LOAD_METHOD", "clear")])
+
+ suffixes.append(
+ [
+ create_instruction("CALL_METHOD", 0), # clear
+ create_instruction("POP_TOP"),
+ create_instruction("CALL_METHOD", 1), # update
+ create_instruction("POP_TOP"),
+ ]
+ )
+ elif self.is_attribute_mutation(var):
+ for name, value in self.store_attr_mutations.get(
+ var.mutable_local, {}
+ ).items():
+ if isinstance(var, variables.NewGlobalVariable):
+ cg.tx.output.update_co_names(name)
+ cg(value)
+ suffixes.append([create_instruction("STORE_GLOBAL", name)])
+ else:
+ cg.tx.output.update_co_names(name)
+ cg(value)
+ cg(var.mutable_local.source)
+ suffixes.append([create_instruction("STORE_ATTR", name)])
+ else:
+ raise AssertionError(type(var))
+
+ # do all the actual mutations at the very end to handle dependencies
+ for suffix in reversed(suffixes):
+ cg.extend_output(suffix)
+
+ def is_empty(self):
+ return not any(map(self.is_modified, self.id_to_variable.values()))
diff --git a/torch/_dynamo/skipfiles.py b/torch/_dynamo/skipfiles.py
new file mode 100644
index 0000000000000..2b6fbb3959c8d
--- /dev/null
+++ b/torch/_dynamo/skipfiles.py
@@ -0,0 +1,208 @@
+import _collections_abc
+import _weakrefset
+import abc
+import collections
+import contextlib
+import copy
+import copyreg
+import dataclasses
+import enum
+import functools
+import importlib
+import inspect
+import linecache
+import logging
+import multiprocessing
+import operator
+import os
+import posixpath
+import random
+import re
+import selectors
+import signal
+import tempfile
+import threading
+import tokenize
+import traceback
+import types
+import typing
+import unittest
+import weakref
+
+import torch
+
+try:
+ import torch._prims
+
+ # isort: split
+ # TODO: Hack to unblock simultaneous landing changes. Fix after https://github.com/pytorch/pytorch/pull/81088 lands
+ import torch._prims.utils
+ import torch._prims.wrappers
+ import torch._refs
+ import torch._refs.nn
+ import torch._refs.nn.functional
+ import torch._refs.special
+
+ HAS_PRIMS_REFS = True
+except ImportError:
+ HAS_PRIMS_REFS = False
+
+from . import config
+
+
+def _strip_init_py(s):
+ return re.sub(r"__init__.py$", "", s)
+
+
+def _module_dir(m: types.ModuleType):
+ return _strip_init_py(m.__file__)
+
+
+SKIP_DIRS = [
+ # torch.*
+ _module_dir(torch),
+ # torchdynamo.*
+ os.path.dirname(__file__) + "/",
+ "",
+] + [
+ # skip some standard libs
+ _module_dir(m)
+ for m in (
+ abc,
+ collections,
+ contextlib,
+ copy,
+ copyreg,
+ dataclasses,
+ enum,
+ functools,
+ importlib,
+ inspect,
+ linecache,
+ logging,
+ multiprocessing,
+ operator,
+ os,
+ posixpath,
+ random,
+ re,
+ selectors,
+ signal,
+ tempfile,
+ threading,
+ tokenize,
+ traceback,
+ types,
+ typing,
+ unittest,
+ weakref,
+ _collections_abc,
+ _weakrefset,
+ )
+]
+FILENAME_ALLOWLIST = {
+ torch.nn.Sequential.__init__.__code__.co_filename,
+ torch.set_rng_state.__code__.co_filename,
+}
+
+# Include optimizer code for tracing
+FILENAME_ALLOWLIST |= set(
+ [
+ inspect.getfile(obj)
+ for obj in torch.optim.__dict__.values()
+ if inspect.isclass(obj)
+ ]
+)
+
+FILENAME_ALLOWLIST |= {torch.optim._functional.__file__}
+
+if HAS_PRIMS_REFS:
+ FILENAME_ALLOWLIST |= {
+ torch._prims.__file__,
+ torch._prims.utils.__file__,
+ torch._prims.wrappers.__file__,
+ torch._refs.__file__,
+ torch._refs.special.__file__,
+ torch._refs.nn.functional.__file__,
+ }
+
+FILENAME_ALLOWLIST |= {torch.optim._functional.__file__}
+
+SKIP_DIRS_RE = None
+
+
+def _recompile_re():
+ global SKIP_DIRS_RE
+ SKIP_DIRS_RE = re.compile(f"^({'|'.join(map(re.escape, SKIP_DIRS))})")
+
+
+def add(import_name: str):
+ if isinstance(import_name, types.ModuleType):
+ return add(import_name.__name__)
+ assert isinstance(import_name, str)
+ module_spec = importlib.util.find_spec(import_name)
+ if not module_spec:
+ return
+ origin = module_spec.origin
+ if origin is None:
+ return
+ global SKIP_DIRS_RE
+ SKIP_DIRS.append(_strip_init_py(origin))
+ _recompile_re()
+
+
+def check(filename, allow_torch=False):
+ """Should skip this file?"""
+ if filename is None:
+ return True
+ if filename in FILENAME_ALLOWLIST:
+ return False
+ if allow_torch and is_torch(filename):
+ return False
+ return bool(SKIP_DIRS_RE.match(filename))
+
+
+# skip common third party libs
+for _name in (
+ "functorch",
+ "intel_extension_for_pytorch",
+ "networkx",
+ "numpy",
+ "omegaconf",
+ "onnx",
+ "onnxruntime",
+ "onnx_tf",
+ "pandas",
+ "sklearn",
+ "tabulate",
+ "tensorflow",
+ "tensorrt",
+ "torch2trt",
+ "tqdm",
+ "tree",
+ "tvm",
+ "fx2trt_oss",
+ "xarray",
+):
+ add(_name)
+
+_recompile_re()
+
+
+def is_torch_inline_allowed(filename):
+ return any(
+ filename.startswith(_module_dir(mod))
+ for mod in config.skipfiles_inline_module_allowlist
+ )
+
+
+@functools.lru_cache(None)
+def dynamo_dir():
+ return _module_dir(importlib.import_module(config.dynamo_import))
+
+
+def is_torch(filename):
+ if filename.startswith(dynamo_dir()):
+ return False
+ return filename.startswith(_module_dir(torch))
diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py
new file mode 100644
index 0000000000000..6b5d63ab850e1
--- /dev/null
+++ b/torch/_dynamo/source.py
@@ -0,0 +1,256 @@
+import collections
+import dataclasses
+from typing import Any
+
+from . import utils
+from .bytecode_transformation import create_instruction
+from .guards import Guard, GuardSource
+from .utils import rename_implicit
+
+_GUARD_SOURCE_NN_MODULE = {
+ GuardSource.LOCAL: GuardSource.LOCAL_NN_MODULE,
+ GuardSource.GLOBAL: GuardSource.GLOBAL_NN_MODULE,
+ GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_NN_MODULE,
+ GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_NN_MODULE,
+}
+
+_GUARD_SOURCE_NOT_NN_MODULE = {
+ GuardSource.LOCAL: GuardSource.LOCAL,
+ GuardSource.GLOBAL: GuardSource.GLOBAL,
+ GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL,
+ GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL,
+}
+
+
+def is_constant_source(source):
+ if isinstance(source, ConstantSource):
+ return True
+ try:
+ if source.guard_source() == GuardSource.CONSTANT:
+ return True
+ except NotImplementedError:
+ pass
+
+ return False
+
+
+@dataclasses.dataclass
+class Source:
+ def reconstruct(self, codegen):
+ raise NotImplementedError()
+
+ def guard_source(self):
+ raise NotImplementedError()
+
+ def name(self):
+ raise NotImplementedError()
+
+ def make_guard(self, fn, is_volatile=False):
+ if self.guard_source() is GuardSource.CONSTANT:
+ raise NotImplementedError()
+ return Guard(self.name(), self.guard_source(), fn, is_volatile)
+
+ def is_nn_module(self):
+ return self.guard_source() in (
+ GuardSource.LOCAL_NN_MODULE,
+ GuardSource.GLOBAL_NN_MODULE,
+ )
+
+
+@dataclasses.dataclass
+class LocalSource(Source):
+ local_name: str
+
+ def reconstruct(self, codegen):
+ return [codegen.create_load(self.local_name)]
+
+ def guard_source(self):
+ return GuardSource.LOCAL
+
+ def name(self):
+ return rename_implicit(self.local_name)
+
+
+@dataclasses.dataclass
+class RandomValueSource(Source):
+ random_call_index: int
+
+ def reconstruct(self, codegen):
+ return [
+ codegen.create_load(codegen.tx.output.random_values_var),
+ codegen.create_load_const(self.random_call_index),
+ create_instruction("BINARY_SUBSCR"),
+ ]
+
+ def name(self):
+ return rename_implicit(f"random_value_{self.random_call_index}")
+
+
+@dataclasses.dataclass
+class GlobalSource(Source):
+ global_name: str
+
+ def reconstruct(self, codegen):
+ return [codegen.create_load_global(self.global_name, add=True)]
+
+ def guard_source(self):
+ return GuardSource.GLOBAL
+
+ def name(self):
+ return self.global_name
+
+
+@dataclasses.dataclass
+class GlobalWeakRefSource(Source):
+ global_name: str
+
+ def reconstruct(self, codegen):
+ return [
+ codegen.create_load_global(self.global_name, add=True),
+ create_instruction("CALL_FUNCTION", 0),
+ ]
+
+ def guard_source(self):
+ return GuardSource.GLOBAL
+
+ def name(self):
+ return f"{self.global_name}()"
+
+
+@dataclasses.dataclass
+class AttrSource(Source):
+ base: Source
+ member: str
+
+ def __init__(self, base, member):
+ super().__init__()
+ if "." in member:
+ member_parts = member.split(".")
+ self.base = AttrSource(base, ".".join(member_parts[:-1]))
+ self.member = member_parts[-1]
+ else:
+ self.base = base
+ self.member = member
+
+ def reconstruct(self, codegen):
+ return self.base.reconstruct(codegen) + codegen.create_load_attrs(self.member)
+
+ def guard_source(self):
+ return self.base.guard_source()
+
+ def name(self):
+ if self.member.isnumeric():
+ return f"getattr({self.base.name()}, {self.member!r})"
+ return f"{self.base.name()}.{self.member}"
+
+
+@dataclasses.dataclass
+class GetItemSource(Source):
+ base: Source
+ index: Any
+
+ def reconstruct(self, codegen):
+ instrs = self.base.reconstruct(codegen)
+
+ if isinstance(self.index, Source):
+ instrs.extend(self.index.reconstruct(codegen))
+ else:
+ instrs.append(codegen.create_load_const(self.index))
+ instrs.append(create_instruction("BINARY_SUBSCR"))
+
+ return instrs
+
+ def guard_source(self):
+ return self.base.guard_source()
+
+ def name(self):
+ if isinstance(self.index, Source):
+ return f"{self.base.name()}[{self.index.name()}]"
+ else:
+ return f"{self.base.name()}[{self.index!r}]"
+
+
+@dataclasses.dataclass
+class TupleIteratorGetItemSource(GetItemSource):
+ def reconstruct(self, codegen):
+ codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
+ return self.base.reconstruct(codegen) + [
+ codegen.create_load_const(self.index),
+ create_instruction("CALL_FUNCTION", 2),
+ ]
+
+ def name(self):
+ return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
+
+
+@dataclasses.dataclass
+class TypeSource(Source):
+ base: Source
+
+ def reconstruct(self, codegen):
+ codegen.load_import_from("builtins", "type")
+ return self.base.reconstruct(codegen) + [create_instruction("CALL_FUNCTION", 1)]
+
+ def guard_source(self):
+ return self.base.guard_source()
+
+ def name(self):
+ return f"type({self.base.name()})"
+
+
+@dataclasses.dataclass
+class ODictGetItemSource(Source):
+ base: Source
+ index: Any
+
+ def reconstruct(self, codegen):
+ return (
+ [codegen._create_load_const(collections.OrderedDict.__getitem__)]
+ + self.base.reconstruct(codegen)
+ + [
+ codegen.create_load_const(self.index),
+ create_instruction("CALL_FUNCTION", 2),
+ ]
+ )
+
+ def guard_source(self):
+ return self.base.guard_source()
+
+ def name(self):
+ return f"___odict_getitem({self.base.name()}, {self.index!r})"
+
+
+@dataclasses.dataclass
+class NNModuleSource(Source):
+ inner: Source
+
+ def reconstruct(self, codegen):
+ return self.inner.reconstruct(codegen)
+
+ def guard_source(self):
+ return _GUARD_SOURCE_NN_MODULE[self.inner.guard_source()]
+
+ def name(self):
+ return self.inner.name()
+
+
+class NotNNModuleSource(NNModuleSource):
+ def guard_source(self):
+ return _GUARD_SOURCE_NOT_NN_MODULE[self.inner.guard_source()]
+
+
+@dataclasses.dataclass
+class ConstantSource(Source):
+ source_name: str
+
+ def reconstruct(self, codegen):
+ return [codegen.create_load_global(self.source_name, add=False)]
+
+ def guard_source(self):
+ return GuardSource.CONSTANT
+
+ def name(self):
+ return self.source_name
+
+ def make_guard(self, fn, is_volatile=False):
+ raise NotImplementedError()
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
new file mode 100644
index 0000000000000..dba37ee0f214c
--- /dev/null
+++ b/torch/_dynamo/symbolic_convert.py
@@ -0,0 +1,1663 @@
+import collections
+import dataclasses
+import dis
+import functools
+import importlib
+import inspect
+import itertools
+import logging
+import operator
+import sys
+import traceback
+import types
+import typing
+import weakref
+from typing import Any, Dict, Iterable, List
+from unittest.mock import patch
+
+import torch
+
+from . import config, exc, side_effects, skipfiles, variables
+from .allowed_functions import is_allowed, is_builtin_callable, is_builtin_constant
+from .bytecode_analysis import livevars_analysis
+from .bytecode_transformation import (
+ cleaned_instructions,
+ create_instruction,
+ Instruction,
+ is_generator,
+ unique_id,
+)
+from .codegen import PyCodegen
+from .exc import unimplemented, Unsupported
+from .guards import GuardBuilder
+from .output_graph import GraphCompileReason, OutputGraph
+from .replay_record import DummyModule, ExecutionRecorder
+from .resume_execution import ContinueExecutionCache, ReenterWith
+from .source import (
+ AttrSource,
+ GetItemSource,
+ GlobalSource,
+ GlobalWeakRefSource,
+ LocalSource,
+)
+from .utils import (
+ counters,
+ fake_tensors_available,
+ graph_break_dup_warning_checker,
+ istype,
+)
+from .variables.base import MutableLocal, typestr, VariableTracker
+from .variables.builder import VariableBuilder
+from .variables.builtin import BuiltinVariable
+from .variables.constant import ConstantVariable
+from .variables.dicts import ConstDictVariable
+from .variables.functions import (
+ BaseUserFunctionVariable,
+ NestedUserFunctionVariable,
+ UserFunctionVariable,
+)
+from .variables.lists import (
+ BaseListVariable,
+ ListIteratorVariable,
+ ListVariable,
+ SliceVariable,
+ TupleVariable,
+)
+from .variables.misc import (
+ ClosureVariable,
+ ContextWrappingVariable,
+ GetAttrVariable,
+ GradModeVariable,
+ PythonModuleVariable,
+ UnknownVariable,
+ WithExitFunctionVariable,
+)
+from .variables.nn_module import NNModuleVariable
+from .variables.tensor import TensorVariable
+from .variables.torch import TorchVariable
+from .variables.user_defined import UserDefinedVariable
+
+log = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class BlockStackEntry:
+ target: Instruction
+ stack_index: int = None
+ with_context: ContextWrappingVariable = None
+
+ def can_restore(self):
+ return self.with_context is not None
+
+ def resume_fn(self):
+ assert self.stack_index is not None
+ return ReenterWith(self.stack_index)
+
+ def exit(self, tx):
+ return self.with_context.exit(tx)
+
+
+def stack_op(fn: typing.Callable):
+ nargs = len(inspect.signature(fn).parameters)
+ fn_var = BuiltinVariable(fn)
+
+ @functools.wraps(fn)
+ def impl(self: "InstructionTranslatorBase", inst: Instruction):
+ self.push(fn_var.call_function(self, self.popn(nargs), {}))
+
+ return impl
+
+
+def generic_jump(truth_fn: typing.Callable, push: bool):
+ def inner(self: "InstructionTranslatorBase", inst: Instruction):
+ value: VariableTracker = self.pop()
+ self.output.guards.update(value.guards)
+ if value.is_python_constant():
+ if truth_fn(value.as_python_constant()):
+ push and self.push(value)
+ self.jump(inst)
+ elif isinstance(value, TensorVariable) and self.should_compile_partial_graph():
+ # compile a partial subgraph prefix then jump into user code
+ self.push(value)
+ self.output.compile_subgraph(
+ self,
+ reason=GraphCompileReason(
+ f"generic_jump {typestr(value)}", [self.frame_summary()]
+ ),
+ )
+ self.pop()
+
+ if_next = self.create_call_resume_at(self.next_instruction)
+ push and self.push(value)
+ if_jump = self.create_call_resume_at(inst.target)
+
+ self.output.add_output_instructions(
+ [(create_instruction(inst.opname, target=if_jump[0]))]
+ + if_next
+ + if_jump
+ )
+ elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
+ self
+ ):
+ if truth_fn(len(value.unpack_var_sequence(self))):
+ push and self.push(value)
+ self.jump(inst)
+ else:
+ unimplemented(f"generic_jump {typestr(value)}")
+
+ return inner
+
+
+explain = False
+
+
+def break_graph_if_unsupported(*, push):
+ def decorator(inner_fn):
+ @functools.wraps(inner_fn)
+ def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
+ state = self.copy_graphstate()
+ reason = None
+ try:
+ return inner_fn(self, inst)
+ except Unsupported as exc:
+ if not self.should_compile_partial_graph():
+ raise
+ user_stack = [self.frame_summary()] + list(reversed(exc.real_stack))
+ user_stack_formatted = "".join(traceback.format_list(user_stack))
+ frame_loc = (user_stack[-1].filename, user_stack[-1].lineno)
+ # torchdynamo.explain() formats this a little nicer, and presents a slightly
+ # more actionable user code pointer
+ if not explain and graph_break_dup_warning_checker.add(frame_loc):
+ log.warning(
+ f"Graph break: {exc} from user code at {user_stack_formatted}"
+ )
+
+ exc.remove_from_stats()
+ exc.add_to_stats("graph_break")
+ reason = GraphCompileReason(exc.msg, user_stack)
+ self.restore_graphstate(state)
+ self.output.compile_subgraph(self, reason=reason)
+ self.popn(push - dis.stack_effect(inst.opcode, inst.arg))
+
+ for _ in range(push):
+ self.push(UnknownVariable())
+
+ resume_call_insts = self.create_call_resume_at(self.next_instruction)
+ # Check if there is a block stack entry with GradModeVariable. And
+ # wrap the instruction causing the graph break inside a try..finally
+ # block. See more details at
+ # https://github.com/pytorch/torchdynamo/issues/207
+ cleanup = []
+ if len(self.block_stack) == 1 and isinstance(
+ self.block_stack[0].with_context, GradModeVariable
+ ):
+ ctx_variable = self.block_stack[0].with_context
+
+ cg = PyCodegen(self)
+ setup_finally, cleanup = ctx_variable.reconstruct(
+ cg, resume_call_insts[0]
+ )
+ self.output.add_output_instructions(setup_finally)
+
+ self.output.add_output_instructions([inst])
+
+ # Add the cleanup instructions from try..finally block
+ self.output.add_output_instructions(cleanup)
+ self.output.add_output_instructions(
+ resume_call_insts,
+ )
+
+ return wrapper
+
+ return decorator
+
+
+class InstructionTranslatorBase(object):
+ def cell_and_freevars(self):
+ if not hasattr(self, "_cell_and_freevars"):
+ self._cell_and_freevars = tuple(
+ self.code_options["co_cellvars"] or []
+ ) + tuple(self.code_options["co_freevars"] or [])
+ return self._cell_and_freevars
+
+ def prune_dead_locals(self):
+ reads = livevars_analysis(self.instructions, self.current_instruction)
+ # implicit use by super()
+ # reads = reads | {"__class__"}
+ # output variables?
+ reads = reads | set(self.cell_and_freevars())
+ self.symbolic_locals = collections.OrderedDict(
+ [(k, v) for k, v in self.symbolic_locals.items() if k in reads]
+ )
+ self.output.side_effects.prune_dead_object_new(self)
+
+ def call_function(
+ self,
+ fn: VariableTracker,
+ args: List[VariableTracker],
+ kwargs: Dict[str, VariableTracker],
+ ):
+ assert isinstance(fn, VariableTracker)
+ assert isinstance(args, list)
+ assert isinstance(kwargs, dict)
+ assert all(
+ isinstance(x, VariableTracker)
+ for x in itertools.chain(args, kwargs.values())
+ )
+ self.push(fn.call_function(self, args, kwargs))
+
+ def update_locals_and_stack(self, oldvar: VariableTracker, newvar: VariableTracker):
+ def repl(v: VariableTracker):
+ if v.mutable_local is oldvar.mutable_local:
+ return newvar
+ return v
+
+ cache = dict()
+ self.output.side_effects.apply(repl, cache)
+ self.stack = [VariableTracker.apply(repl, x, cache) for x in self.stack]
+ for k, x in self.symbolic_locals.items():
+ self.symbolic_locals[k] = VariableTracker.apply(repl, x, cache)
+
+ def replace_all(self, oldvar: VariableTracker, newvar: VariableTracker):
+ if isinstance(oldvar.mutable_local, side_effects.MutableSideEffects):
+ newvar = self.output.side_effects.mutation(oldvar, newvar)
+ else:
+ assert isinstance(oldvar.mutable_local, variables.base.MutableLocal)
+ newvar = newvar.clone(mutable_local=variables.base.MutableLocal())
+ self.update_locals_and_stack(oldvar, newvar)
+ return newvar
+
+ def inline_user_function_return(self, fn, args, kwargs):
+ """
+ A call to some user defined function by inlining it.
+ """
+ state = self.copy_graphstate()
+ try:
+ result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
+ self.output.guards.update(fn.guards)
+ return result
+ except Exception:
+ self.restore_graphstate(state)
+ raise
+
+ def step(self):
+ """Process exactly one instruction, return False we should exit"""
+ inst = self.instructions[self.instruction_pointer]
+ self.current_instruction = inst
+ self.instruction_pointer += 1
+ if self.instruction_pointer < len(self.instructions):
+ self.next_instruction = self.instructions[self.instruction_pointer]
+ else:
+ self.instruction_pointer = None
+ self.next_instruction = None
+ if inst.starts_line:
+ self.lineno = inst.starts_line
+ log.debug(f"TRACE starts_line {self.f_code.co_filename}:{self.lineno}")
+
+ if len(self.stack) == 0 and self.should_compile_partial_graph():
+ self.checkpoint = inst, self.copy_graphstate()
+
+ log.debug(f"TRACE {inst.opname} {inst.argval} {self.stack}")
+
+ try:
+ if not hasattr(self, inst.opname):
+ unimplemented(f"missing: {inst.opname}")
+ getattr(self, inst.opname)(inst)
+ return inst.opname != "RETURN_VALUE"
+ except Unsupported as exc:
+ exc.real_stack.append(self.frame_summary())
+ if self.empty_checkpoint():
+ raise
+ except Exception as exc:
+ real_stack = getattr(exc, "real_stack", [])
+ real_stack.append(self.frame_summary())
+ exc.real_stack = real_stack
+ raise
+
+ # generate code from checkpoint
+ assert not self.output.output_instructions
+ continue_inst, state = self.checkpoint
+ self.restore_graphstate(state)
+ self.output.compile_subgraph(self, partial_convert=True)
+ self.output.add_output_instructions(
+ [create_instruction("JUMP_ABSOLUTE", target=continue_inst)]
+ + self.instructions
+ )
+
+ def run(self):
+ try:
+ while (
+ self.instruction_pointer is not None
+ and not self.output.should_exit
+ and self.step()
+ ):
+ pass
+ except Exception as e:
+ if config.replay_record_enabled:
+ e.exec_record = self.exec_recorder.get_record()
+
+ raise
+ finally:
+ # Cleanup the outputGraph to delete the held tensors. We perform the
+ # cleanup only for InstructionTranslator and not
+ # InliningInstructionTranslator. The InliningInstructionTranslator
+ # mutates the output object and is restored to original state if
+ # there was an exception.
+ if isinstance(self, InstructionTranslator):
+ self.output.cleanup()
+
+ def push(self, val):
+ assert val is None or isinstance(
+ val, VariableTracker
+ ), f"push expects VariableTracker, got {typestr(val)}"
+ self.stack.append(val)
+
+ def push_many(self, vals: List[TensorVariable]):
+ for val in vals:
+ self.push(val)
+
+ def pop(self) -> TensorVariable:
+ return self.stack.pop()
+
+ def popn(self, n: int) -> List[TensorVariable]:
+ assert n >= 0
+ return list(reversed([self.pop() for _ in range(n)]))
+
+ def LOAD_FAST(self, inst):
+ name = inst.argval
+
+ if name in self.f_locals and config.replay_record_enabled:
+ self.exec_recorder.add_local_var(name, self.f_locals[name])
+
+ if name.startswith(".") and name not in self.symbolic_locals:
+ # This happens in dict/list comprehensions
+ name = name.replace(".", "implicit")
+ assert name not in self.cell_and_freevars()
+ if name not in self.symbolic_locals:
+ unimplemented("undefined LOAD_FAST")
+ self.push(self.symbolic_locals[name])
+ if name.startswith("___stack"):
+ self.symbolic_locals.pop(name)
+
+ def LOAD_DEREF(self, inst):
+ assert inst.argval in self.cell_and_freevars()
+
+ if inst.argval in self.f_locals and config.replay_record_enabled:
+ self.exec_recorder.add_local_var(inst.argval, self.f_locals[inst.argval])
+
+ if inst.argval not in self.symbolic_locals:
+ unimplemented(f"undefined LOAD_DEREF {inst.argval}")
+ self.push(self.symbolic_locals[inst.argval])
+
+ def STORE_FAST(self, inst):
+ self.symbolic_locals[inst.argval] = self.pop()
+
+ def DELETE_FAST(self, inst):
+ del self.symbolic_locals[inst.argval]
+
+ STORE_DEREF = STORE_FAST
+
+ def LOAD_CLOSURE(self, inst):
+ self.push(ClosureVariable(name=inst.argval))
+
+ def LOAD_CONST(self, inst):
+ self.push(ConstantVariable(value=inst.argval))
+
+ def get_global_source(self, name):
+ if self.output.root_globals is self.f_globals:
+ source = GlobalSource(name)
+ else:
+ if "__name__" in self.f_globals:
+ source = AttrSource(
+ self.import_source(self.f_globals["__name__"]), name
+ )
+ else:
+ mangled_name = f"___unnamed_scope_{id(self.f_globals)}"
+ if mangled_name not in self.output.root_globals:
+ self.output.install_global(mangled_name, self.f_globals)
+ source = GetItemSource(GlobalSource(mangled_name), name)
+ return source
+
+ def LOAD_GLOBAL(self, inst):
+ name = inst.argval
+
+ if config.replay_record_enabled:
+ if name in self.f_globals:
+ self.exec_recorder.add_global_var(name, self.f_globals[name])
+ else:
+ assert name in self.f_builtins
+ self.exec_recorder.builtins[name] = self.f_builtins[name]
+
+ if name in self.symbolic_globals:
+ variable = self.output.side_effects[self.symbolic_globals[name]]
+ self.push(self.output.side_effects.load_global(variable, name))
+ return
+
+ try:
+ value = self.f_globals[name]
+ except KeyError:
+ return self.load_builtin(inst)
+
+ source = self.get_global_source(name)
+ self.push(VariableBuilder(self, source)(value))
+
+ def STORE_GLOBAL(self, inst):
+ value = self.pop()
+ name = inst.argval
+ source = self.get_global_source(name)
+ if name not in self.symbolic_globals:
+ self.symbolic_globals[name] = object() # sentinel object
+ variable = self.output.side_effects.track_global_existing(
+ source, self.symbolic_globals[name]
+ )
+ self.output.side_effects.store_global(variable, name, value)
+
+ def import_source(self, module_name):
+ """Create an alias to a module for use in guards"""
+ value = importlib.import_module(module_name)
+ alias = f"__import_{module_name.replace('.', '_dot_')}"
+ f_globals = self.output.root_globals
+ assert alias not in f_globals or f_globals[alias] is value
+ f_globals[alias] = value
+ self.output.update_co_names(alias)
+ return GlobalSource(alias)
+
+ def resolve_name(self, name, package, level):
+ """
+ Copied from the Cpython implementation of __import__
+ Resolve a relative module name to an absolute one.
+ https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902
+ """
+ bits = package.rsplit(".", level - 1)
+ if len(bits) < level:
+ raise ImportError("attempted relative import beyond top-level package")
+ base = bits[0]
+ return "{}.{}".format(base, name) if name else base
+
+ def calc_package(self):
+ """
+ Copied from the Cpython implementation of __import__
+ https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090
+ """
+ package = self.f_globals.get("__package__")
+ spec = self.f_globals.get("__spec__")
+ if package is not None:
+ if spec is not None and package != spec.parent:
+ log.warning(
+ "__package__ != __spec__.parent "
+ f"({package!r} != {spec.parent!r})",
+ ImportWarning,
+ stacklevel=3,
+ )
+ return package
+ elif spec is not None:
+ return spec.parent
+ else:
+ log.warning(
+ "can't resolve package from __spec__ or __package__, "
+ "falling back on __name__ and __path__",
+ ImportWarning,
+ stacklevel=3,
+ )
+ package = self.f_globals["__name__"]
+ if "__path__" not in self.f_globals:
+ package = package.rpartition(".")[0]
+ return package
+
+ def IMPORT_NAME(self, inst):
+ level, fromlist = self.popn(2)
+ level = level.as_python_constant()
+ fromlist = fromlist.as_python_constant()
+ module_name = inst.argval
+
+ # Are we replaying? if so, load recorded module
+ recorded_name = (
+ f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}"
+ )
+ if recorded_name in self.f_globals:
+ value = self.f_globals[recorded_name]
+ source = GlobalSource(recorded_name)
+ else:
+ value = __import__(
+ module_name,
+ fromlist=fromlist,
+ level=level,
+ globals=self.f_globals,
+ )
+
+ if level != 0:
+ pkg = self.calc_package()
+ module_name = self.resolve_name(module_name, pkg, level)
+
+ # For __import__, when the name variable is of the form package.module,
+ # normally, the top-level package (the name up till the first dot) is
+ # returned, not the module named by module_name. However, when a
+ # non-empty fromlist argument is given, the module named by name is
+ # returned. Therefore, we set the source correctly here.
+ if not fromlist:
+ top_level_module_name = module_name.partition(".")[0]
+ source = self.import_source(top_level_module_name)
+ else:
+ source = self.import_source(module_name)
+
+ if config.replay_record_enabled:
+ self.exec_recorder.add_local_mod(recorded_name, value)
+
+ if is_allowed(value):
+ self.push(TorchVariable(value, source=source))
+ elif istype(value, (types.ModuleType, DummyModule)):
+ self.push(PythonModuleVariable(value, source=source))
+ else:
+ unimplemented(f"IMPORT_NAME {typestr(value)}")
+
+ def IMPORT_FROM(self, inst):
+ self.DUP_TOP(inst)
+ self.LOAD_ATTR(inst)
+
+ def load_builtin(self, inst):
+ assert inst.argval in self.f_builtins
+ val = self.f_builtins[inst.argval]
+
+ if callable(val):
+ assert is_builtin_callable(val)
+ self.push(VariableBuilder(self, GlobalSource(inst.argval))(val))
+ else:
+ assert is_builtin_constant(val)
+ self.push(ConstantVariable(value=val))
+
+ def jump(self, inst):
+ self.instruction_pointer = self.indexof[id(inst.target)]
+
+ JUMP_FORWARD = jump
+ JUMP_ABSOLUTE = jump
+
+ POP_JUMP_IF_FALSE = generic_jump(operator.not_, False)
+ POP_JUMP_IF_TRUE = generic_jump(operator.truth, False)
+ JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True)
+ JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True)
+
+ def SETUP_LOOP(self, inst):
+ # only exists in python<=3.7
+ self.block_stack.append(BlockStackEntry(inst.target))
+
+ def SETUP_EXCEPT(self, inst):
+ # only exists in python<=3.7
+ self.block_stack.append(BlockStackEntry(inst.target))
+
+ def POP_BLOCK(self, inst):
+ self.block_stack.pop()
+
+ def SETUP_WITH(self, inst):
+ ctx = self.pop()
+ if not isinstance(ctx, ContextWrappingVariable):
+ unimplemented(f"SETUP_WITH {ctx}")
+ self.output.guards.update(ctx.guards)
+
+ if isinstance(self, InstructionTranslator):
+ self.block_stack.append(BlockStackEntry(inst.target, len(self.stack), ctx))
+ else:
+ # can't restore this while inlining
+ self.block_stack.append(BlockStackEntry(inst.target))
+ self.push(
+ WithExitFunctionVariable(
+ ctx,
+ inst.target,
+ **VariableTracker.propagate(ctx),
+ )
+ )
+ self.push(ctx.enter(self))
+
+ def SETUP_FINALLY(self, inst):
+ self.block_stack.append(BlockStackEntry(inst.target))
+
+ def BEGIN_FINALLY(self, inst):
+ self.push(None)
+
+ def WITH_CLEANUP_START(self, inst):
+ exit, exc = self.popn(2)
+ if sys.version_info < (3, 8):
+ assert exc.is_python_constant()
+ assert exc.as_python_constant() is None
+ else:
+ assert exc is None
+ self.push(exc)
+ self.push(exit.call_function(self, [ConstantVariable(None)] * 3, {}))
+
+ def WITH_CLEANUP_FINISH(self, inst):
+ self.popn(2)
+ self.push(None)
+
+ def END_FINALLY(self, inst):
+ tos = self.pop()
+ if sys.version_info < (3, 8):
+ # python3.7 and 3.8 can have END_FINALLY without BEGIN_FINALLY
+ assert tos is None or (
+ tos.is_python_constant() and tos.as_python_constant() is None
+ )
+ else:
+ assert tos is None
+
+ def FOR_ITER(self, inst):
+ it = self.pop()
+ if isinstance(it, ListIteratorVariable):
+ self.output.guards.update(it.guards)
+ try:
+ val, next_iter = it.next_variables()
+ self.replace_all(it, next_iter)
+ self.push(next_iter)
+ self.push(val)
+ except StopIteration:
+ self.jump(inst)
+ else:
+ unimplemented(f"FOR_ITER {typestr(it)}")
+
+ def COMPARE_OP(self, inst):
+ left, right = self.popn(2)
+ left = left.as_specialized(self)
+ right = right.as_specialized(self)
+ options = VariableTracker.propagate([left, right])
+ op = inst.argval
+ supported_is_const = {
+ "is": operator.is_,
+ "is not": operator.is_not,
+ "==": operator.eq,
+ "!=": operator.ne,
+ }
+ supported_tensors = {
+ ">": operator.gt,
+ "<": operator.lt,
+ ">=": operator.ge,
+ "<=": operator.le,
+ "==": operator.eq,
+ "!=": operator.ne,
+ }
+ supported_any = dict(
+ itertools.chain(supported_tensors.items(), supported_is_const.items())
+ )
+ if (
+ isinstance(
+ left,
+ (
+ TensorVariable,
+ NNModuleVariable,
+ BaseListVariable,
+ UserDefinedVariable,
+ BaseUserFunctionVariable,
+ ConstDictVariable,
+ ),
+ )
+ and isinstance(right, ConstantVariable)
+ and right.value is None
+ and op in supported_is_const
+ ):
+ # is None
+ self.push(
+ ConstantVariable(
+ supported_is_const[op](object(), right.value), **options
+ )
+ )
+ elif (
+ isinstance(left, TensorVariable) or isinstance(right, TensorVariable)
+ ) and op in supported_tensors:
+ self.push(
+ TensorVariable.create(
+ self,
+ supported_tensors[op](left.as_proxy(), right.as_proxy()),
+ **options,
+ )
+ )
+ elif (
+ left.is_python_constant()
+ and right.is_python_constant()
+ and op in supported_any
+ ):
+ # constant fold
+ self.push(
+ ConstantVariable(
+ supported_any[op](
+ left.as_python_constant(), right.as_python_constant()
+ ),
+ **options,
+ )
+ )
+ elif op in ("in", "not in"):
+ self.push(right.call_method(self, "__contains__", [left], {}))
+ if op == "not in":
+ self.UNARY_NOT(inst)
+ else:
+ unimplemented(f"COMPARE_OP {typestr(left)} {op} {typestr(right)}")
+
+ def GET_ITER(self, inst):
+ self.call_function(BuiltinVariable(iter), [self.pop()], {})
+
+ @break_graph_if_unsupported(push=1)
+ def CALL_FUNCTION(self, inst):
+ args = self.popn(inst.argval)
+ fn = self.pop()
+ self.call_function(fn, args, {})
+
+ @break_graph_if_unsupported(push=1)
+ def CALL_FUNCTION_EX(self, inst):
+ if inst.argval == 0:
+ kwargsvars = ConstDictVariable({}, dict)
+ argsvars = self.pop()
+ elif inst.argval == 1:
+ kwargsvars = self.pop()
+ argsvars = self.pop()
+ else:
+ unimplemented("CALL_FUNCTION_EX")
+ fn = self.pop()
+ self.output.guards.update(argsvars.guards)
+ self.output.guards.update(kwargsvars.guards)
+
+ if (
+ isinstance(fn, GetAttrVariable)
+ and isinstance(fn.obj, TensorVariable)
+ and fn.name == "view"
+ and isinstance(argsvars, (ConstantVariable, TensorVariable))
+ ):
+ # Hack to handle special case in some bert models. Converts
+ # x.view(*shape) into x.view(shape), which is correct for view()
+ # but not generally. See test_transpose_for_scores().
+ argsvars = TupleVariable([argsvars])
+
+ if not isinstance(
+ argsvars, BaseListVariable
+ ) and argsvars.has_unpack_var_sequence(self):
+ argsvars = TupleVariable(argsvars.unpack_var_sequence(self))
+
+ if not isinstance(argsvars, BaseListVariable) or not isinstance(
+ kwargsvars, ConstDictVariable
+ ):
+ unimplemented(f"non-static call {typestr(argsvars)} {typestr(kwargsvars)}")
+
+ self.call_function(fn, argsvars.items, kwargsvars.items)
+
+ @break_graph_if_unsupported(push=1)
+ def CALL_FUNCTION_KW(self, inst):
+ argnames = self.pop()
+ args = self.popn(inst.argval)
+ fn = self.pop()
+ assert isinstance(argnames, ConstantVariable)
+ argnames = argnames.value
+ args, kwargs = args[: -len(argnames)], args[-len(argnames) :]
+ kwargs = dict(zip(argnames, kwargs))
+ assert len(kwargs) == len(argnames)
+ self.call_function(fn, args, kwargs)
+
+ def LOAD_METHOD(self, inst):
+ self.LOAD_ATTR(inst)
+ self.push(self.pop())
+ self.push(None)
+
+ def CALL_METHOD(self, inst):
+ args = self.popn(inst.argval)
+ dummy = self.pop()
+ assert dummy is None
+ fn = self.pop()
+ self.call_function(fn, args, {})
+
+ def LOAD_ATTR(self, inst):
+ obj = self.pop()
+ result = BuiltinVariable(getattr).call_function(
+ self, [obj, ConstantVariable(inst.argval)], {}
+ )
+ self.push(result)
+
+ def STORE_ATTR(self, inst):
+ prior = self.copy_graphstate()
+ val, obj = self.popn(2)
+ try:
+ self.output.guards.update(
+ BuiltinVariable(setattr)
+ .call_function(self, [obj, ConstantVariable(inst.argval), val], {})
+ .guards
+ )
+ return
+ except Unsupported as e:
+ if not self.should_compile_partial_graph():
+ raise
+ e.remove_from_stats()
+ e.add_to_stats("graph_break")
+ self.restore_graphstate(prior)
+
+ # break the graph
+ self.output.compile_subgraph(
+ self, reason=GraphCompileReason("store_attr", [self.frame_summary()])
+ )
+ self.output.add_output_instructions([inst])
+ self.popn(2)
+ self.output.add_output_instructions(
+ self.create_call_resume_at(self.next_instruction)
+ )
+
+ @break_graph_if_unsupported(push=0)
+ def STORE_SUBSCR(self, inst):
+ val, obj, key = self.popn(3)
+ result = obj.call_method(self, "__setitem__", [key, val], {})
+ # no result is pushed, so need to lift the guards to global
+ self.output.guards.update(result.guards)
+
+ def BUILD_TUPLE(self, inst):
+ items = self.popn(inst.argval)
+ options = VariableTracker.propagate(items)
+ self.push(TupleVariable(items, **options))
+
+ def BUILD_SLICE(self, inst):
+ items = self.popn(inst.argval)
+ options = VariableTracker.propagate(items)
+ self.push(
+ SliceVariable(
+ [x.as_specialized(self) for x in items],
+ **options,
+ )
+ )
+
+ def BUILD_LIST(self, inst):
+ items = self.popn(inst.argval)
+ options = VariableTracker.propagate(items)
+ self.push(ListVariable(items, mutable_local=MutableLocal(), **options))
+
+ def BUILD_LIST_UNPACK(self, inst, cls=ListVariable):
+ seqs = self.popn(inst.argval)
+ options = VariableTracker.propagate(seqs)
+ items = list()
+ for seq in seqs:
+ try:
+ items.extend(seq.unpack_var_sequence(self))
+ except NotImplementedError:
+ unimplemented(f"BUILD_LIST_UNPACK {seq}")
+ self.push(cls(items, mutable_local=MutableLocal(), **options))
+
+ def BUILD_TUPLE_UNPACK(self, inst):
+ self.BUILD_LIST_UNPACK(inst, cls=TupleVariable)
+
+ BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK
+
+ def BUILD_MAP(self, inst):
+ items = self.popn(inst.argval * 2)
+ options = VariableTracker.propagate(items)
+ result = dict()
+ for k, v in zip(items[::2], items[1::2]):
+ assert isinstance(k, ConstantVariable) or (
+ isinstance(k, TensorVariable) and k.parameter_value is not None
+ )
+
+ result[ConstDictVariable.get_key(k)] = v
+ assert len(result) == len(items) / 2
+ self.push(
+ ConstDictVariable(result, dict, mutable_local=MutableLocal(), **options)
+ )
+
+ def BUILD_CONST_KEY_MAP(self, inst):
+ keys = self.pop()
+ values = self.popn(inst.argval)
+ options = VariableTracker.propagate([keys] + values)
+ assert isinstance(keys, ConstantVariable)
+ keys = keys.value
+ assert istype(keys, tuple)
+ assert len(keys) == len(values)
+ self.push(
+ ConstDictVariable(
+ dict(zip(keys, values)),
+ dict,
+ mutable_local=MutableLocal(),
+ **options,
+ )
+ )
+
+ def MAP_ADD(self, inst):
+ if sys.version_info < (3, 8):
+ v, k = self.popn(2)
+ else:
+ k, v = self.popn(2)
+
+ assert inst.argval > 0
+ obj = self.stack[-inst.arg]
+ assert isinstance(obj, ConstDictVariable)
+ assert obj.mutable_local
+ items = dict(obj.items)
+ items[k.as_python_constant()] = v
+ self.replace_all(
+ obj,
+ ConstDictVariable(
+ items,
+ obj.user_cls,
+ **VariableTracker.propagate([obj, k, v]),
+ ),
+ )
+
+ def LIST_APPEND(self, inst):
+ v = self.pop()
+ assert inst.argval > 0
+ obj = self.stack[-inst.arg]
+ assert isinstance(obj, ListVariable)
+ assert obj.mutable_local
+ self.replace_all(
+ obj,
+ ListVariable(
+ obj.items + [v],
+ **VariableTracker.propagate([obj, v]),
+ ),
+ )
+
+ def MAKE_FUNCTION(self, inst):
+ flags = inst.arg
+ old_stack = list(self.stack)
+ fn_name = self.pop()
+ code = self.pop()
+ defaults = None
+ closure = None
+ annotations = None
+ kwdefaults = None
+
+ if flags & 0x08:
+ closure = self.pop()
+ if flags & 0x04:
+ annotations = self.pop()
+ if flags & 0x02:
+ kwdefaults = self.pop()
+ if flags & 0x01:
+ defaults = self.pop()
+
+ options = VariableTracker.propagate(old_stack[len(self.stack) :])
+ self.push(
+ NestedUserFunctionVariable(
+ fn_name,
+ code,
+ self.f_globals,
+ defaults,
+ kwdefaults,
+ annotations,
+ closure,
+ closure_scope=self,
+ **options,
+ )
+ )
+
+ def UNPACK_SEQUENCE(self, inst):
+ # TODO(jansel): rewrite this using unpack_var_sequence
+ seq = self.pop()
+ options = VariableTracker.propagate([seq])
+ if isinstance(seq, BaseListVariable):
+ assert len(seq.items) == inst.argval
+ self.output.guards.update(seq.guards)
+ for i in reversed(seq.items):
+ self.push(i)
+ elif seq.is_python_constant() and isinstance(seq, ConstantVariable):
+ val = seq.as_python_constant()
+ assert len(val) == inst.argval
+ for i in reversed(val):
+ self.push(ConstantVariable(i, **options))
+ elif isinstance(seq, TensorVariable):
+ proxy = seq.as_proxy()
+ for i in reversed(range(inst.argval)):
+ self.push(TensorVariable.create(self, proxy[i], **options))
+ elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable):
+ # x, y = a.shape
+ proxy = getattr(seq.obj.as_proxy(), seq.name)
+ for i in reversed(range(inst.argval)):
+ self.push(TensorVariable.create(self, proxy[i], **options))
+ else:
+ unimplemented(f"UNPACK_SEQUENCE {seq}")
+
+ def UNPACK_EX(self, inst):
+ assert 0 <= inst.argval <= 0xFFFF
+ prefix = inst.argval & 0xFF # low byte
+ suffix = inst.argval >> 8 # high byte
+ seq = self.pop()
+ options = VariableTracker.propagate(seq)
+ if seq.has_unpack_var_sequence(self):
+ vals = list(seq.unpack_var_sequence(self))
+ assert len(vals) >= prefix + suffix
+ vals_prefix = vals[:prefix]
+ vals_list = vals[prefix : len(vals) - suffix]
+ vals_suffix = vals[len(vals) - suffix :]
+ for item in reversed(vals_suffix):
+ self.push(item.add_options(options))
+ self.push(TupleVariable(vals_list, **options))
+ for item in reversed(vals_prefix):
+ self.push(item.add_options(options))
+ else:
+ unimplemented(f"UNPACK_EX {seq}")
+
+ def NOP(self, inst):
+ pass
+
+ def POP_TOP(self, inst):
+ self.pop()
+
+ def ROT_TWO(self, inst):
+ a = self.pop()
+ b = self.pop()
+ self.push(a)
+ self.push(b)
+
+ def ROT_THREE(self, inst):
+ a = self.pop()
+ b = self.pop()
+ c = self.pop()
+ self.push(a)
+ self.push(c)
+ self.push(b)
+
+ def ROT_FOUR(self, inst):
+ a = self.pop()
+ b = self.pop()
+ c = self.pop()
+ d = self.pop()
+ self.push(a)
+ self.push(d)
+ self.push(c)
+ self.push(b)
+
+ def DUP_TOP(self, inst):
+ a = self.pop()
+ self.push(a)
+ self.push(a)
+
+ def DUP_TOP_TWO(self, inst):
+ a = self.pop()
+ b = self.pop()
+ self.push(b)
+ self.push(a)
+ self.push(b)
+ self.push(a)
+
+ def FORMAT_VALUE(self, inst):
+ flags = inst.arg
+ if (flags & 0x04) == 0x04:
+ fmt_spec = self.pop()
+ else:
+ fmt_spec = ConstantVariable("")
+
+ value = self.pop()
+
+ if (flags & 0x03) == 0x01:
+ value = BuiltinVariable(str).call_function(self, [value], {})
+ elif (flags & 0x03) == 0x02:
+ value = BuiltinVariable(repr).call_function(self, [value], {})
+ elif (flags & 0x03) == 0x03:
+ value = BuiltinVariable(ascii).call_function(self, [value], {})
+
+ fmt_var = ConstantVariable(
+ "{:" + fmt_spec.as_python_constant() + "}"
+ ).add_options(fmt_spec)
+
+ self.call_function(BuiltinVariable(str.format), [fmt_var, value], {})
+
+ def BUILD_STRING(self, inst):
+ result = ""
+ for _ in range(inst.arg):
+ str_var = self.pop()
+ assert isinstance(str_var, ConstantVariable)
+ result = str_var.value + result
+ self.push(ConstantVariable(value=result))
+
+ def IS_OP(self, inst):
+ assert inst.argval == 0 or inst.argval == 1
+ if inst.argval == 0:
+ new_argval = "is"
+ else:
+ new_argval = "is not"
+ new_inst = create_instruction("COMPARE_OP", argval=new_argval)
+ self.COMPARE_OP(new_inst)
+
+ def CONTAINS_OP(self, inst):
+ assert inst.argval == 0 or inst.argval == 1
+ left, right = self.popn(2)
+ op = inst.argval
+ self.push(right.call_method(self, "__contains__", [left], {}))
+ if op == 1:
+ self.UNARY_NOT(inst)
+
+ def LIST_EXTEND(self, inst):
+ v = self.pop()
+ assert inst.argval > 0
+ obj = self.stack[-inst.arg]
+ assert isinstance(obj, ListVariable)
+ assert obj.mutable_local
+ obj.call_method(self, "extend", [v], {})
+
+ def LIST_TO_TUPLE(self, inst):
+ self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {}))
+
+ def DICT_MERGE(self, inst):
+ v = self.pop()
+ assert inst.argval > 0
+ obj = self.stack[-inst.arg]
+ assert isinstance(obj, ConstDictVariable)
+ assert obj.mutable_local
+ obj.call_method(self, "update", [v], {})
+
+ def GEN_START(self, inst):
+ self.pop()
+
+ def GET_LEN(self, inst):
+ tos = self.stack[-1]
+ if tos.is_python_constant():
+ self.push(ConstantVariable(len(tos.as_python_constant())))
+ else:
+ self.push(tos.call_method(self, "__len__", [], {}))
+
+ def MATCH_MAPPING(self, inst):
+ tos = self.stack[-1]
+ assert isinstance(tos, ConstDictVariable)
+ if isinstance(tos.items, collections.abc.Mapping):
+ self.push(ConstantVariable(True))
+ else:
+ self.push(ConstantVariable(False))
+
+ def MATCH_SEQUENCE(self, inst):
+ tos = self.stack[-1]
+ assert tos.is_python_constant()
+ tos_value = tos.as_python_constant()
+ if isinstance(tos_value, collections.abc.Sequence) and not isinstance(
+ tos_value, (str, bytes, bytearray)
+ ):
+ self.push(ConstantVariable(True))
+ else:
+ self.push(ConstantVariable(False))
+
+ def MATCH_KEYS(self, inst):
+ tos = self.stack[-1]
+ assert tos.is_python_constant()
+ keys = tos.as_python_constant()
+ tos1 = self.stack[-2]
+ assert isinstance(tos1, ConstDictVariable)
+ match_obj = tos1.items
+ if all(key in match_obj for key in keys):
+ self.push(TupleVariable(list(match_obj[key] for key in keys)))
+ self.push(ConstantVariable(True))
+ else:
+ self.push(ConstantVariable(None))
+ self.push(ConstantVariable(False))
+
+ UNARY_POSITIVE = stack_op(operator.pos)
+ UNARY_NEGATIVE = stack_op(operator.neg)
+ UNARY_NOT = stack_op(operator.not_)
+ UNARY_INVERT = stack_op(operator.invert)
+
+ BINARY_POWER = stack_op(operator.pow)
+ BINARY_MULTIPLY = stack_op(operator.mul)
+ BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul)
+ BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv)
+ BINARY_TRUE_DIVIDE = stack_op(operator.truediv)
+ BINARY_MODULO = stack_op(operator.mod)
+ BINARY_ADD = stack_op(operator.add)
+ BINARY_SUBTRACT = stack_op(operator.sub)
+ BINARY_SUBSCR = break_graph_if_unsupported(push=1)(stack_op(operator.getitem))
+ BINARY_LSHIFT = stack_op(operator.lshift)
+ BINARY_RSHIFT = stack_op(operator.rshift)
+ BINARY_AND = stack_op(operator.and_)
+ BINARY_OR = stack_op(operator.or_)
+ BINARY_XOR = stack_op(operator.xor)
+
+ INPLACE_POWER = stack_op(operator.ipow)
+ INPLACE_MULTIPLY = stack_op(operator.imul)
+ INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul)
+ INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv)
+ INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv)
+ INPLACE_MODULO = stack_op(operator.imod)
+ INPLACE_ADD = stack_op(operator.iadd)
+ INPLACE_SUBTRACT = stack_op(operator.isub)
+ INPLACE_LSHIFT = stack_op(operator.ilshift)
+ INPLACE_RSHIFT = stack_op(operator.irshift)
+ INPLACE_AND = stack_op(operator.iand)
+ INPLACE_XOR = stack_op(operator.ixor)
+ INPLACE_OR = stack_op(operator.ior)
+
+ def copy_graphstate(self):
+ """Create a checkpoint of the current state by copying everything"""
+ return (
+ self.output.copy_graphstate(),
+ collections.OrderedDict(self.symbolic_locals),
+ list(self.stack),
+ list(self.block_stack),
+ self.instruction_pointer,
+ self.current_instruction,
+ self.next_instruction,
+ self.lineno,
+ )
+
+ def restore_graphstate(self, state):
+ """Restore a checkpoint created by self.copy_graphstate()"""
+ (
+ output_state,
+ self.symbolic_locals,
+ self.stack,
+ self.block_stack,
+ self.instruction_pointer,
+ self.current_instruction,
+ self.next_instruction,
+ self.lineno,
+ ) = state
+ self.output.restore_graphstate(output_state)
+
+ def empty_checkpoint(self):
+ if self.checkpoint is None:
+ return True
+ output_graphstate = self.checkpoint[1][0]
+ graphstate = self.checkpoint[1][1:]
+ state = (*output_graphstate, *graphstate)
+ for obj in state:
+ if isinstance(obj, Iterable):
+ if len(obj) != 0:
+ return False
+ return True
+
+ def format_frame_summary(self, additional_stack_frames=None):
+ if additional_stack_frames is None:
+ additional_stack_frames = []
+ return "".join(
+ traceback.format_list(
+ ([self.frame_summary()] + list(reversed(additional_stack_frames)))
+ )
+ )
+
+ def frame_summary(self):
+ return traceback.FrameSummary(
+ getattr(self.f_code, "co_filename", ""),
+ self.lineno,
+ getattr(self.f_code, "co_name", ""),
+ lookup_line=False,
+ )
+
+ def store_dict_key(self, name, value):
+ self.output.guards.add(
+ GlobalWeakRefSource(name).make_guard(GuardBuilder.WEAKREF_ALIVE)
+ )
+ if name not in self.output.root_globals:
+ self.output.install_global(name, weakref.ref(value))
+
+ @property
+ def fake_mode(self):
+ return self._fake_mode
+
+ def find_symbolic_locals_name(self, tensor_variable):
+ for key, value in self.symbolic_locals.items():
+ if value is tensor_variable:
+ return key
+ return None
+
+ def __init__(
+ self,
+ output: OutputGraph,
+ instructions: List[Instruction],
+ f_locals: Dict[str, Any],
+ f_globals: Dict[str, Any],
+ f_builtins: Dict[str, Any],
+ code_options: Dict[str, Any],
+ symbolic_locals: Dict[str, VariableTracker],
+ symbolic_globals: Dict[str, VariableTracker],
+ f_code: types.CodeType,
+ ):
+ super(InstructionTranslatorBase, self).__init__()
+
+ # Mutable state checkpointed by copy_graphstate()
+ self.output: OutputGraph = output
+ self.symbolic_locals: Dict[str, VariableTracker] = symbolic_locals
+ self.symbolic_globals: Dict[str, VariableTracker] = symbolic_globals
+ self.stack: List[VariableTracker] = []
+ self.instruction_pointer: int = 0
+ self.current_instruction: Instruction = create_instruction("NOP")
+ self.next_instruction: typing.Optional[Instruction] = None
+ self.block_stack: List[BlockStackEntry] = []
+ self.lineno: int = code_options.get("co_firstlineno")
+
+ # Properties of the input/output code
+ self.instructions: List[Instruction] = instructions
+ self.indexof: Dict[int, int] = {id(i): n for n, i in enumerate(instructions)}
+ self.f_locals: Dict[
+ str, Any
+ ] = f_locals # needed for recording accessed locals for replay
+ self.f_globals: Dict[str, Any] = f_globals
+ self.f_builtins: Dict[str, Any] = f_builtins
+ self.code_options: Dict[str, Any] = code_options
+ self.f_code: types.CodeType = f_code
+
+ # Execution record for replaying errors
+ self.exec_recorder = ExecutionRecorder(code=f_code, code_options=code_options)
+ # Stack of module being parsed, current nn.module is at the end of ordered dict
+ self.nn_module_stack: Dict[str, str] = {}
+
+ if fake_tensors_available:
+ with torch._subclasses.FakeTensorMode(
+ throw_on_data_dependent_ops=True
+ ) as fake_mode:
+ pass
+ self._fake_mode = fake_mode
+
+ self.checkpoint = None
+ self.random_calls: List[tuple] = []
+
+ if sys.version_info >= (3, 10):
+ from .resume_execution import (
+ CO_ASYNC_GENERATOR,
+ CO_COROUTINE,
+ CO_GENERATOR,
+ CO_ITERABLE_COROUTINE,
+ )
+
+ if f_code.co_flags & (
+ CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR
+ ):
+ self.push(BuiltinVariable(None))
+
+
+class InstructionTranslator(InstructionTranslatorBase):
+ def __init__(
+ self,
+ instructions: List[Instruction],
+ f_code,
+ f_locals,
+ f_globals,
+ f_builtins,
+ code_options,
+ compiler_fn,
+ one_graph,
+ export,
+ ):
+ super(InstructionTranslator, self).__init__(
+ output=OutputGraph(f_globals, code_options, compiler_fn, self),
+ instructions=instructions,
+ f_locals=f_locals,
+ f_globals=f_globals,
+ f_builtins=f_builtins,
+ code_options=code_options,
+ symbolic_locals=collections.OrderedDict(), # set below
+ # A global var is inserted only after a STORE_GLOBAL happens to it
+ symbolic_globals=collections.OrderedDict(),
+ f_code=f_code,
+ )
+ self.one_graph: bool = one_graph
+ self.export = export
+ if self.export:
+ assert (
+ self.one_graph
+ ), "Export without one graph - something has gone wrong."
+
+ vars = list(code_options["co_varnames"])
+ vars.extend(x for x in self.cell_and_freevars() if x not in vars)
+ self.symbolic_locals = collections.OrderedDict(
+ (k, VariableBuilder(self, LocalSource(k))(f_locals[k]))
+ for k in vars
+ if k in f_locals
+ )
+
+ # symbolic_locals contains the mapping from original f_locals to the
+ # Variable objects. During the Variable building phase, each object also
+ # has its associated guards. At the end, we will accumulate these
+ # guards.
+ #
+ # One way of handling these guards is to just accumulate all of them
+ # right now. However, many f_locals might not be used in the frame and
+ # thus can unnecessarily increase guard execution overhead. Therefore,
+ # we selectively update output.guards as we run the Python Bytecode
+ # instruction by instruction.
+ #
+ # An exception here is list/dict variables. Guards related to these
+ # variables have indexed access, like Tensor_match on args[0], and if
+ # args is not used in this frame, we will miss a LIST_LENGTH check like
+ # len(args) == 2. Missing the LIST_LENGTH check causes problem for the
+ # next invocation when args is not a list, and args[0] is a runtime
+ # error. Therefore, we recursively add guards for list/dict variable here.
+ for val in self.symbolic_locals.values():
+ if isinstance(
+ val, (ListIteratorVariable, BaseListVariable, ConstDictVariable)
+ ):
+ local_guards = VariableTracker.propagate(val)["guards"]
+ index_guards = [
+ guard
+ for guard in local_guards
+ if guard.create_fn
+ in (
+ GuardBuilder.LIST_LENGTH,
+ GuardBuilder.DICT_KEYS,
+ GuardBuilder.ODICT_KEYS,
+ GuardBuilder.TUPLE_ITERATOR_LEN,
+ )
+ ]
+ self.output.guards.update(index_guards)
+
+ self._freevars_ids = dict()
+ for name in self.code_options["co_freevars"]:
+ if name in f_locals:
+ self._freevars_ids[name] = id(f_locals[name])
+
+ def match_nested_cell(self, name, cell):
+ """Match a cell in this method to one in a function we are inlining"""
+ value = cell.cell_contents
+ # TODO(jansel): check the id of the cell rather than the contents
+ if id(value) != self._freevars_ids.get(name):
+ return None
+ return self.symbolic_locals[name]
+
+ def should_compile_partial_graph(self):
+ return all(b.can_restore() for b in self.block_stack) and not self.one_graph
+
+ def create_call_resume_at(self, inst):
+ self.instruction_pointer = None
+
+ if inst.opname == "RETURN_VALUE":
+ return [create_instruction("RETURN_VALUE")]
+
+ reads = livevars_analysis(self.instructions, inst)
+ argnames = tuple(
+ k
+ for k in self.symbolic_locals.keys()
+ if k in reads and k not in self.cell_and_freevars()
+ )
+ nargs = len(self.stack) + len(argnames)
+
+ name = unique_id(f"__resume_at_{inst.offset}")
+
+ new_code: types.CodeType = ContinueExecutionCache.lookup(
+ self.f_code,
+ self.lineno,
+ inst.offset,
+ len(self.stack),
+ argnames,
+ tuple(b.resume_fn() for b in self.block_stack),
+ )
+
+ cg = PyCodegen(self)
+
+ if new_code.co_freevars:
+ cg.make_function_with_closure(name, new_code, len(self.stack))
+ else:
+ self.output.install_global(
+ name, types.FunctionType(new_code, self.f_globals, name)
+ )
+ cg.extend_output(cg.load_function_name(name, len(self.stack)))
+
+ cg.extend_output([cg.create_load(k) for k in argnames])
+ cg.extend_output(
+ [
+ create_instruction("CALL_FUNCTION", nargs),
+ create_instruction("RETURN_VALUE"),
+ ]
+ )
+ return cg.get_instructions()
+
+ def RETURN_VALUE(self, inst):
+ if self.output.count_calls() == 0 and not self.export:
+ raise exc.SkipFrame()
+ self.instruction_pointer = None
+ self.output.compile_subgraph(self)
+ self.output.add_output_instructions([create_instruction("RETURN_VALUE")])
+
+
+class InliningInstructionTranslator(InstructionTranslatorBase):
+ """Trace and inline a called method"""
+
+ @classmethod
+ def inline_call(cls, parent, func, args, kwargs):
+ with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
+ return cls.inline_call_(parent, func, args, kwargs)
+
+ @staticmethod
+ def inline_call_(parent, func, args, kwargs):
+ assert isinstance(
+ func,
+ (UserFunctionVariable, NestedUserFunctionVariable),
+ )
+ if func.has_self():
+ unimplemented("inline with __self__")
+
+ if func.get_name() == "patched_init":
+ unimplemented("Patched init cannot be inlined.")
+
+ if skipfiles.check(
+ func.get_filename()
+ ) and not skipfiles.is_torch_inline_allowed(func.get_filename()):
+ unimplemented(
+ f"inline in skipfiles: {func.get_name()} {func.get_filename()}"
+ )
+
+ try:
+ sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
+ except TypeError as exc:
+ log.warning(
+ f"{func.get_filename()} {func.get_function()} {args} {kwargs} {exc}"
+ )
+ unimplemented("arg mismatch inlining")
+
+ for v in itertools.chain(sub_locals.values(), closure_cells.values()):
+ if not isinstance(v, VariableTracker):
+ unimplemented(f"unconverted arg {v}")
+
+ code: types.CodeType = func.get_code()
+ if code.co_name in ("__setitem__", "__setattr__"):
+ unimplemented(f"inline {code.co_name}")
+
+ log.debug(f"INLINING {code} \n {dis.Bytecode(code).dis()} \n")
+
+ if is_generator(code):
+ tracer = InliningGeneratorInstructionTranslator(
+ parent, code, sub_locals, parent.symbolic_globals, closure_cells, func
+ )
+ else:
+ tracer = InliningInstructionTranslator(
+ parent, code, sub_locals, parent.symbolic_globals, closure_cells, func
+ )
+
+ tracer.run()
+ assert tracer.symbolic_result is not None
+ func.export_freevars(parent, tracer)
+
+ if tracer.f_globals is parent.f_globals:
+ # Merge symbolic_globals back if parent and child are in the same namespace
+ parent.symbolic_globals.update(tracer.symbolic_globals)
+
+ log.debug(f"DONE INLINING {code}")
+
+ if is_generator(code):
+ assert tracer.symbolic_result.as_python_constant() is None
+ return ListIteratorVariable(
+ tracer.generated_items,
+ mutable_local=MutableLocal(),
+ **VariableTracker.propagate(tracer.symbolic_result),
+ )
+ else:
+ return tracer.symbolic_result
+
+ def __init__(
+ self,
+ parent: InstructionTranslatorBase,
+ code: types.CodeType,
+ symbolic_locals: Dict[str, VariableTracker],
+ symbolic_globals: Dict[str, VariableTracker],
+ closure_cells: Dict[str, VariableTracker],
+ funcvar: BaseUserFunctionVariable,
+ ):
+ f_globals = funcvar.get_globals()
+ f_builtins = f_globals["__builtins__"]
+ if not isinstance(f_builtins, dict):
+ f_builtins = f_builtins.__dict__
+ super(InliningInstructionTranslator, self).__init__(
+ output=parent.output,
+ f_locals={},
+ f_globals=f_globals,
+ f_builtins=f_builtins,
+ symbolic_locals=symbolic_locals,
+ symbolic_globals=symbolic_globals,
+ instructions=cleaned_instructions(code),
+ code_options={k: getattr(code, k) for k in dir(code)},
+ f_code=code,
+ )
+ self.parent = parent
+ self.symbolic_result = None
+ self.closure_cells = closure_cells
+ self.nn_module_stack = parent.nn_module_stack.copy()
+
+ @property
+ def fake_mode(self):
+ return self.parent.fake_mode
+
+ def STORE_DEREF(self, inst):
+ if inst.argval in self.closure_cells:
+ cell = self.closure_cells[inst.argval]
+ val = self.pop()
+ if isinstance(cell, ClosureVariable):
+ self.output.root_tx.symbolic_locals[cell.name] = val
+ else:
+ self.output.side_effects.store_cell(cell, val)
+ else:
+ if isinstance(
+ self.symbolic_locals.get(inst.argval),
+ variables.NewCellVariable,
+ ):
+ self.output.side_effects.store_cell(
+ self.symbolic_locals[inst.argval], self.pop()
+ )
+ else:
+ unimplemented("write to __closure__ while inlining")
+
+ def LOAD_DEREF(self, inst):
+ if inst.argval in self.closure_cells:
+ cell = self.closure_cells[inst.argval]
+ if isinstance(cell, ClosureVariable):
+ self.push(self.output.root_tx.symbolic_locals[cell.name])
+ else:
+ self.push(self.output.side_effects.load_cell(cell))
+ else:
+ maybe_sym_local = self.symbolic_locals.get(inst.argval, None)
+ if isinstance(maybe_sym_local, variables.NewCellVariable):
+ self.push(self.output.side_effects.load_cell(maybe_sym_local))
+ else:
+ super().LOAD_DEREF(inst)
+
+ def LOAD_CLOSURE(self, inst):
+ assert inst.argval in self.cell_and_freevars()
+ self.push(self.closure_cells[inst.argval])
+
+ def replace_all(self, oldvar: VariableTracker, newvar: VariableTracker):
+ newvar = super().replace_all(oldvar, newvar)
+ # recursively check and update parent's locals and stack in case oldvar is from parent
+ translator = self
+ while hasattr(translator, "parent"):
+ translator = translator.parent
+ translator.update_locals_and_stack(oldvar, newvar)
+ return newvar
+
+ def should_compile_partial_graph(self):
+ return False # inlining functions is all-or-nothing
+
+ def create_call_resume_at(self, offset):
+ unimplemented("cant resume while inlining")
+
+ def RETURN_VALUE(self, inst):
+ self.symbolic_result = self.pop()
+ self.instruction_pointer = None
+
+
+class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
+ def __init__(self, *args, **kwargs):
+ super(InliningGeneratorInstructionTranslator, self).__init__(*args, **kwargs)
+ self.generated_items = []
+
+ def YIELD_VALUE(self, inst: Instruction):
+ self.generated_items.append(self.pop())
+ # TODO(jansel): figure out why this is needed, it isn't in the docs for YIELD_VALUE
+ self.push(ConstantVariable(None))
diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py
new file mode 100644
index 0000000000000..790de24e20e54
--- /dev/null
+++ b/torch/_dynamo/testing.py
@@ -0,0 +1,322 @@
+import contextlib
+import dis
+import functools
+import importlib
+import logging
+import os.path
+import sys
+import types
+import unittest
+from unittest.mock import patch
+
+import torch
+import torch.testing._internal.common_utils
+from torch import fx
+
+from . import config, eval_frame, optimize_assert, reset, utils
+from .bytecode_transformation import (
+ create_instruction,
+ debug_checks,
+ is_generator,
+ transform_code_object,
+)
+from .guards import CheckFunctionManager, GuardedCode
+from .utils import same
+
+unsupported = eval_frame.unsupported
+three = 3
+
+log = logging.getLogger(__name__)
+
+
+def run_tests(needs=()):
+ return # TEMPORARY: disable all tests
+
+ from torch.testing._internal.common_utils import (
+ IS_WINDOWS,
+ run_tests,
+ TEST_WITH_CROSSREF,
+ TEST_WITH_TORCHDYNAMO,
+ )
+
+ if (
+ TEST_WITH_TORCHDYNAMO
+ or IS_WINDOWS
+ or TEST_WITH_CROSSREF
+ or sys.version_info >= (3, 11)
+ ):
+ return # skip testing
+
+ if isinstance(needs, str):
+ needs = (needs,)
+ for need in needs:
+ if need == "cuda" and not torch.cuda.is_available():
+ return
+ else:
+ try:
+ importlib.import_module(need)
+ except ImportError:
+ return
+ run_tests()
+
+
+def clone_me(x):
+ if x is None:
+ return None
+ return x.detach().clone().requires_grad_(x.requires_grad)
+
+
+def collect_results(model, prediction, loss, example_inputs):
+ results = []
+ results.append(prediction)
+ results.append(loss)
+ if isinstance(loss, torch.Tensor) and loss.item() > 1:
+ log.warning(
+ f"High loss value alert - {loss:.2f}. Can result in unstable gradients."
+ )
+
+ grads = dict()
+ params = dict()
+ for name, param in model.named_parameters():
+ param_copy = param
+ grad = param.grad
+ # Treat None and zero grad as same
+ if param.grad is None:
+ grad = torch.zeros_like(param)
+ grads[name + ".grad"] = grad
+ params[name] = param_copy
+ results.append(grads)
+ results.append(params)
+ for example in example_inputs:
+ if isinstance(example, (tuple, list)):
+ for inp in example:
+ if isinstance(inp, torch.Tensor):
+ results.append(inp.grad)
+ else:
+ if isinstance(example, torch.Tensor):
+ results.append(example.grad)
+ return results
+
+
+def requires_bwd_pass(out):
+ if isinstance(out, torch.Tensor):
+ return out.requires_grad
+ elif isinstance(out, (list, tuple)):
+ return any([requires_bwd_pass(x) for x in out])
+ elif out is None:
+ return False
+ raise NotImplementedError("Don't know how to reduce", type(out))
+
+
+def reduce_to_scalar_loss(out):
+ """Reduce the output of a model to get scalar loss"""
+ if isinstance(out, torch.Tensor):
+ # Mean does not work on integer tensors
+ return out.sum() / out.numel()
+ elif isinstance(out, (list, tuple)):
+ return sum([reduce_to_scalar_loss(x) for x in out]) / len(out)
+ elif type(out).__name__ in (
+ "MaskedLMOutput",
+ "Seq2SeqLMOutput",
+ "CausalLMOutputWithCrossAttentions",
+ ):
+ return reduce_to_scalar_loss(out.logits)
+ elif type(out).__name__ == "SquashedNormal":
+ return out.mean.sum()
+ elif isinstance(out, dict):
+ return sum([reduce_to_scalar_loss(value) for value in out.values()]) / len(
+ out.keys()
+ )
+ raise NotImplementedError("Don't know how to reduce", type(out))
+
+
+def debug_dir():
+ path = os.path.join(os.path.dirname(__file__), "../debug")
+ if not os.path.exists(path):
+ os.mkdir(path)
+ return path
+
+
+def debug_dump(name, code: types.CodeType, extra=""):
+ with open(os.path.join(debug_dir(), name), "w") as fd:
+ fd.write(
+ f"{dis.Bytecode(code).info()}\n\n{dis.Bytecode(code).dis()}\n\n{extra}\n"
+ )
+
+
+def debug_insert_nops(frame, cache_size):
+ """used to debug jump updates"""
+
+ def insert_nops(instructions, code_options):
+ instructions.insert(0, create_instruction("NOP"))
+ instructions.insert(0, create_instruction("NOP"))
+
+ if is_generator(frame.f_code):
+ return None
+
+ debug_checks(frame.f_code)
+ code = transform_code_object(frame.f_code, insert_nops)
+
+ return GuardedCode(code, CheckFunctionManager().check_fn)
+
+
+class CompileCounter:
+ def __init__(self):
+ self.frame_count = 0
+ self.op_count = 0
+
+ def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+ self.frame_count += 1
+ for node in gm.graph.nodes:
+ if "call" in node.op:
+ self.op_count += 1
+ return gm.forward
+
+ def clear(self):
+ self.frame_count = 0
+ self.op_count = 0
+
+
+class CompileCounterWithBackend:
+ def __init__(self, backend):
+ self.frame_count = 0
+ self.op_count = 0
+ self.backend = backend
+
+ def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+ from torchdynamo.eval_frame import lookup_backend
+
+ self.frame_count += 1
+ for node in gm.graph.nodes:
+ if "call" in node.op:
+ self.op_count += 1
+ return lookup_backend(self.backend)(gm, example_inputs)
+
+
+def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None):
+ if config.dynamic_shapes and expected_ops_dynamic is not None:
+ expected_ops = expected_ops_dynamic
+
+ actual = CompileCounter()
+ if expected_ops is None:
+ expected = CompileCounter()
+ try:
+ gm = torch.fx.symbolic_trace(fn)
+ expected(gm)
+ print("\nfx.symbolic_trace graph:")
+ gm.graph.print_tabular()
+ expected_ops = expected.op_count
+ except Exception:
+ pass # Silently ignore FX errors (not our issue)
+
+ args1 = [torch.randn(10, 10) for _ in range(nargs)]
+ args2 = [torch.randn(10, 10) for _ in range(nargs)]
+ correct1 = fn(*args1)
+ correct2 = fn(*args2)
+ reset()
+ opt_fn = optimize_assert(actual)(fn)
+ val1a = opt_fn(*args1)
+ val2a = opt_fn(*args2)
+ val1b = opt_fn(*args1)
+ val2b = opt_fn(*args2)
+ reset()
+ self.assertTrue(same(val1a, correct1))
+ self.assertTrue(same(val1b, correct1))
+ self.assertTrue(same(val2a, correct2))
+ self.assertTrue(same(val2b, correct2))
+ self.assertEqual(actual.frame_count, 1)
+ if expected_ops is not None:
+ self.assertEqual(actual.op_count, expected_ops)
+
+
+class TestCase(torch.testing._internal.common_utils.TestCase):
+ @classmethod
+ def tearDownClass(cls):
+ cls._exit_stack.close()
+ super().tearDownClass()
+
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ cls._exit_stack = contextlib.ExitStack()
+ cls._exit_stack.enter_context(
+ patch.object(config, "raise_on_backend_error", True)
+ )
+ cls._exit_stack.enter_context(
+ patch.object(config, "raise_on_ctx_manager_usage", True)
+ )
+
+ def setUp(self):
+ super().setUp()
+ reset()
+ utils.counters.clear()
+
+ def tearDown(self):
+ for k, v in utils.counters.items():
+ print(k, v.most_common())
+ reset()
+ utils.counters.clear()
+ super().tearDown()
+
+
+def dummy_fx_compile(gm: fx.GraphModule, example_inputs):
+ return gm.forward
+
+
+def format_speedup(speedup, pvalue, is_correct=True, pvalue_threshold=0.1):
+ if not is_correct:
+ return "ERROR"
+ if pvalue > pvalue_threshold:
+ return f"{speedup:.3f}x SAME"
+ return f"{speedup:.3f}x p={pvalue:.2f}"
+
+
+def requires_static_shapes(fn):
+ @functools.wraps(fn)
+ def _fn(*args, **kwargs):
+ if config.dynamic_shapes:
+ raise unittest.SkipTest("requires static shapes")
+ return fn(*args, **kwargs)
+
+ return _fn
+
+
+def rand_strided(size, stride, dtype=torch.float32, device="cpu"):
+ needed_size = sum((shape - 1) * stride for shape, stride in zip(size, stride)) + 1
+ if dtype.is_floating_point:
+ buffer = torch.randn(needed_size, dtype=dtype, device=device)
+ else:
+ buffer = torch.ones(size=[needed_size], dtype=dtype, device=device)
+ return torch.as_strided(buffer, size, stride)
+
+
+def _make_fn_with_patches(fn, *patches):
+ @functools.wraps(fn)
+ def _fn(*args, **kwargs):
+ with contextlib.ExitStack() as stack:
+ for attr, val in patches:
+ stack.enter_context(patch.object(config, attr, val))
+
+ return fn(*args, **kwargs)
+
+ return _fn
+
+
+def make_test_cls_with_patches(cls, cls_prefix, fn_suffix, *patches):
+ class DummyTestClass(cls):
+ pass
+
+ DummyTestClass.__name__ = f"{cls_prefix}{cls.__name__}"
+
+ for name in dir(cls):
+ if name.startswith("test_"):
+ fn = getattr(cls, name)
+ if not callable(fn):
+ continue
+ new_name = f"{name}{fn_suffix}"
+ fn = _make_fn_with_patches(fn, *patches)
+ fn.__name__ = new_name
+ setattr(DummyTestClass, name, None)
+ setattr(DummyTestClass, new_name, fn)
+
+ return DummyTestClass
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
new file mode 100644
index 0000000000000..b66c240e0f04d
--- /dev/null
+++ b/torch/_dynamo/utils.py
@@ -0,0 +1,930 @@
+import collections
+import contextlib
+import copy
+import cProfile
+import dataclasses
+import dis
+import functools
+import gc
+import inspect
+import itertools
+import logging
+import logging.config
+import math
+import operator
+import os
+import pstats
+import re
+import sys
+import time
+import types
+import weakref
+from contextlib import contextmanager
+from functools import lru_cache
+from typing import Any, Dict
+
+import numpy as np
+
+import torch
+from torch import fx
+from torch.nn.modules.lazy import LazyModuleMixin
+
+from . import config, logging as torchdynamo_logging
+
+counters = collections.defaultdict(collections.Counter)
+troubleshooting_url = (
+ "https://github.com/pytorch/torchdynamo/blob/main/TROUBLESHOOTING.md"
+)
+
+log = logging.getLogger(__name__)
+
+# profiling compilation time
+compilation_metrics = collections.OrderedDict()
+
+
+timer_counter = itertools.count()
+
+
+def tabulate(rows, headers):
+ try:
+ import tabulate
+
+ return tabulate.tabulate(rows, headers=headers)
+ except ImportError:
+ return "\n".join(
+ ", ".join(map(str, row)) for row in itertools.chain([headers], rows)
+ )
+
+
+def dynamo_profiled(func):
+ def profile_wrapper(*args, **kwargs):
+ global timer_counter
+ datafn = (
+ func.__name__ + f"{next(timer_counter)}.profile"
+ ) # Name the data file sensibly
+ prof = cProfile.Profile()
+ prof.enable()
+ retval = prof.runcall(func, *args, **kwargs)
+ prof.disable()
+ print(f"### Cprofile for {func.__name__} iter {next(timer_counter)} ###")
+ ps = pstats.Stats(prof)
+ ps.sort_stats(pstats.SortKey.TIME).print_stats(20)
+ ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20)
+ prof.dump_stats(datafn)
+ return retval
+
+ return profile_wrapper
+
+
+def dynamo_timed(func):
+ def time_wrapper(*args, **kwargs):
+ key = func.__qualname__
+ if key not in compilation_metrics:
+ compilation_metrics[key] = []
+ t0 = time.time()
+ r = func(*args, **kwargs)
+ compilation_metrics[key].append(time.time() - t0)
+ return r
+
+ return time_wrapper
+
+
+def compile_times(repr="str", aggregate=False):
+ """
+ Get metrics about torchdynamo frontend/backend compilation times.
+
+ Accumulates information from functions tagged with `@dynamo_timed`.
+
+ repr='str' returns a printable string for user interaction, and 'csv'
+ returns headers, rows which can be logged for output
+
+ aggregate causes values from multiple compilations (e.g. split graphs)
+ to be accumulated into one value. If false, expect more than one value
+ per metric.
+ """
+
+ def fmt_fn(values, item_fn=lambda x: x):
+
+ if aggregate:
+ return item_fn(sum(values))
+ return ", ".join(map(item_fn, values))
+
+ if repr == "str":
+ rows = [
+ (k, fmt_fn(compilation_metrics[k], item_fn=lambda x: f"{x:.4f}"))
+ for k in compilation_metrics
+ ]
+ out = "TorchDynamo compilation metrics:\n"
+ out += tabulate(rows, headers=("Function", "Runtimes (s)"))
+ return out
+ elif repr == "csv":
+ values = [
+ fmt_fn(v, item_fn=lambda x: f"{x:.6f}")
+ for v in compilation_metrics.values()
+ ]
+ headers = list(compilation_metrics.keys())
+ return headers, values
+
+
+tensortype_to_dtype = {
+ torch.FloatTensor: (torch.float32, torch.float),
+ torch.DoubleTensor: (torch.float64, torch.double),
+ torch.HalfTensor: (torch.float16, torch.half),
+ torch.BFloat16Tensor: (torch.bfloat16,),
+ torch.ByteTensor: (torch.uint8,),
+ torch.CharTensor: (torch.int8,),
+ torch.LongTensor: (torch.int64, torch.long),
+ torch.IntTensor: (torch.int32, torch.int),
+ torch.ShortTensor: (torch.int16, torch.short),
+ torch.BoolTensor: (torch.bool,),
+}
+
+
+class DuplicateWarningChecker(object):
+ def __init__(self, maxsize=4096):
+ self.maxsize = maxsize
+ self.reset()
+
+ def reset(self):
+ self.set = collections.OrderedDict()
+
+ def add(self, key):
+ if key in self.set:
+ self.set.move_to_end(key, last=True)
+ if not config.verbose:
+ return False
+ else:
+ self.set[key] = None
+ while len(self.set) > self.maxsize:
+ self.set.popitem(last=False)
+ return True
+
+
+graph_break_dup_warning_checker = DuplicateWarningChecker()
+
+
+def init_logging():
+ torchdynamo_logging.init_logging(
+ config.log_level, log_file_name=config.log_file_name
+ )
+ graph_break_dup_warning_checker.reset()
+
+
+# filter out all frames after entering dynamo
+def filter_stack(stack):
+ user_stack = []
+ for frame in stack:
+ if "convert_frame" in frame.filename:
+ break
+ if (
+ "eval_frame" in frame.filename
+ or f"{config.dynamo_import}.optimize(" in frame.line
+ ):
+ continue
+ user_stack.append(frame)
+
+ return user_stack
+
+
+def format_graph_tabular(graph):
+ node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in graph.nodes]
+ return tabulate(node_specs, headers=["opcode", "name", "target", "args", "kwargs"])
+
+
+def format_bytecode(prefix, name, filename, line_no, code):
+ return f"{prefix} {name} {filename}\
+ line {line_no} \n{dis.Bytecode(code).dis()}\n "
+
+
+def gen_record_file_name(exc, code):
+ return f"{config.replay_record_dir_name}/\
+{code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec"
+
+
+def write_record_to_file(filename, exec_record):
+ try:
+ if os.path.exists(filename):
+ log.warning(
+ f"Unable to write execution record {filename}; file already exists."
+ )
+ else:
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ with open(filename, "wb") as f:
+ exec_record.dump(f)
+ except Exception:
+ log.error(f"Unable to write execution record {filename}", exc_info=1)
+
+
+def count_calls(g: fx.Graph):
+ c = 0
+ for n in g.nodes:
+ if "call" in n.op:
+ c += 1
+ return c
+
+
+def identity(x):
+ return x
+
+
+def nothing(*args, **kwargs):
+ pass
+
+
+class ExactWeakKeyDictionary:
+ """Similar to weakref.WeakKeyDictionary, but use `is`/`id` rather than `==` to compare equality"""
+
+ def __init__(self):
+ self.values = dict()
+ self.refs = dict()
+
+ def __getitem__(self, key):
+ return self.values[id(key)]
+
+ def get(self, key, default=None):
+ return self.values.get(id(key), default)
+
+ def __contains__(self, key):
+ return id(key) in self.values
+
+ def __setitem__(self, key, value):
+ idx = id(key)
+ if idx not in self.refs:
+ self.refs[idx] = weakref.ref(key, lambda ref: self._remove_id(idx))
+ self.values[idx] = value
+
+ def _remove_id(self, idx):
+ if idx in self.values:
+ del self.values[idx]
+ if idx in self.refs:
+ del self.refs[idx]
+
+ def clear(self):
+ self.refs.clear()
+ self.values.clear()
+
+
+def istype(obj, allowed_types):
+ """isinstance() without subclasses"""
+ if isinstance(allowed_types, (tuple, list, set)):
+ return type(obj) in allowed_types
+ return type(obj) is allowed_types
+
+
+def is_numpy_int_type(value):
+ return istype(
+ value,
+ (
+ np.int8,
+ np.int16,
+ np.int32,
+ np.int64,
+ np.uint8,
+ np.uint16,
+ np.uint32,
+ np.uint64,
+ ),
+ )
+
+
+def is_numpy_float_type(value):
+ return istype(
+ value,
+ (
+ np.float16,
+ np.float32,
+ np.float64,
+ ),
+ )
+
+
+def istensor(obj):
+ """Check of obj is a tensor"""
+ tensor_list = (
+ torch.Tensor,
+ torch.nn.Parameter,
+ *config.traceable_tensor_subclasses,
+ )
+ if fake_tensors_available:
+ tensor_list = tensor_list + (torch._subclasses.FakeTensor,)
+ return istype(obj, tensor_list)
+
+
+def is_lazy_module(mod):
+ return isinstance(mod, LazyModuleMixin)
+
+
+@functools.lru_cache(4096)
+def print_once(*args):
+ print(*args)
+
+
+def make_cell(val=None):
+ """Some black magic to create a cell object that usually only exists in a closure"""
+ x = val
+
+ def f():
+ return x
+
+ assert len(f.__closure__) == 1
+ return f.__closure__[0]
+
+
+def proxy_args_kwargs(args, kwargs):
+ try:
+ proxy_args = tuple(arg.as_proxy() for arg in args)
+ proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
+ return proxy_args, proxy_kwargs
+ except NotImplementedError:
+ from .exc import unimplemented
+ from .variables.base import typestr
+
+ raise unimplemented(
+ f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}"
+ )
+
+
+@dataclasses.dataclass
+class CleanupHook:
+ """Remove a global variable when hook is called"""
+
+ scope: Dict[str, Any]
+ name: str
+
+ def __call__(self, *args):
+ CleanupManager.count -= 1
+ del self.scope[self.name]
+
+ @staticmethod
+ def create(scope, name, val):
+ assert name not in scope
+ CleanupManager.count += 1
+ scope[name] = val
+ return CleanupHook(scope, name)
+
+
+class CleanupManager(ExactWeakKeyDictionary):
+ count = 0
+
+ def _remove_id(self, idx):
+ for hook in self.values[idx]:
+ hook()
+ super()._remove_id(idx)
+
+
+CleanupManager.instance = CleanupManager()
+
+
+def clone_tensor(x):
+ """Clone the tensor and its gradient"""
+ y = x.clone().requires_grad_(x.requires_grad)
+ if x.is_leaf and x.grad is not None:
+ y.grad = x.grad.clone()
+ return y
+
+
+def clone_input(x):
+ """copy while preserving strides"""
+ with torch.no_grad():
+ needed_size = sum(
+ (shape - 1) * stride for shape, stride in zip(x.size(), x.stride())
+ )
+ if x.is_quantized:
+ result = torch.empty_quantized((needed_size + 32,), x)
+ else:
+ result = torch.empty(needed_size + 32, dtype=x.dtype, device=x.device)
+ cache_line_offset = (
+ (x.data_ptr() - result.data_ptr()) % 32
+ ) // x.element_size()
+ result.as_strided_(x.size(), x.stride(), cache_line_offset)
+ try:
+ result.copy_(x.clone())
+ if x.is_leaf:
+ result.requires_grad_(x.requires_grad)
+ if x.is_leaf and x.grad is not None:
+ result.grad = clone_input(x.grad)
+ except RuntimeError:
+ # RuntimeError: unsupported operation: more than one element of the written-to
+ # tensor refers to a single memory location. Please clone() the tensor before
+ # performing the operation.
+ y = torch.clone(x)
+ if x.is_leaf:
+ y.requires_grad_(x.requires_grad)
+ if x.is_leaf and x.grad is not None:
+ y.grad = clone_input(x.grad)
+ return y
+ return result
+
+
+def clone_inputs(example_inputs):
+ if isinstance(example_inputs, dict):
+ res = dict(example_inputs)
+ for key, value in res.items():
+ assert isinstance(value, torch.Tensor)
+ res[key] = clone_input(value)
+ return res
+
+ res = list(example_inputs)
+ for i in range(len(res)):
+ if isinstance(res[i], torch.Tensor):
+ res[i] = clone_input(res[i])
+ return res
+
+
+@contextmanager
+def preserve_rng_state():
+ rng = torch.clone(torch.random.get_rng_state())
+ if torch.cuda.is_available():
+ cuda_rng = torch.clone(torch.cuda.get_rng_state())
+ try:
+ yield
+ finally:
+ torch.random.set_rng_state(rng)
+ if torch.cuda.is_available():
+ torch.cuda.set_rng_state(cuda_rng)
+
+
+def is_jit_model(model0):
+ return isinstance(
+ model0,
+ (
+ torch.jit._trace.TopLevelTracedModule,
+ torch.jit._script.RecursiveScriptModule,
+ torch.jit.ScriptFunction,
+ torch.jit.ScriptModule,
+ ),
+ )
+
+
+def torchscript(model, example_inputs, verbose=False):
+ if is_jit_model(model):
+ # already done?
+ return model
+
+ try:
+ return torch.jit.trace(model, example_inputs)
+ except Exception:
+ try:
+ return torch.jit.script(model)
+ except Exception:
+ if verbose:
+ log.exception("jit error")
+ else:
+ log.error("Both torch.jit.trace and torch.jit.script failed")
+ return None
+
+
+def getfile(obj):
+ try:
+ return inspect.getfile(obj)
+ except TypeError:
+ return None
+
+
+def is_namedtuple(obj):
+ """Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple"""
+ return is_namedtuple_cls(type(obj))
+
+
+def is_namedtuple_cls(cls):
+ """Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple"""
+ try:
+ if issubclass(cls, tuple):
+ bases = getattr(cls, "__bases__", []) or [None]
+ module = getattr(cls, "__module__", None)
+ return module == "torch.return_types" or (
+ bases[0] is tuple and hasattr(cls, "_make") and hasattr(cls, "_fields")
+ )
+ except TypeError:
+ pass
+ return False
+
+
+@functools.lru_cache(1)
+def namedtuple_fields(cls):
+ """Get the fields of a namedtuple or a torch.return_types.* quasi-namedtuple"""
+ if cls is slice:
+ return ["start", "stop", "step"]
+
+ assert issubclass(cls, tuple)
+ if hasattr(cls, "_fields"):
+ # normal namedtuples
+ return cls._fields
+
+ @dataclasses.dataclass
+ class Marker:
+ index: int
+
+ # frustrating ones e.g. torch.return_types.max
+ assert cls.__module__ == "torch.return_types"
+ obj = cls(map(Marker, range(cls.n_fields)))
+ fields = [None] * cls.n_fields
+ for name in dir(obj):
+ if name[0] != "_" and isinstance(getattr(obj, name), Marker):
+ fields[getattr(obj, name).index] = name
+ return fields
+
+
+def checkpoint_params(gm):
+ with torch.no_grad():
+ rng_state = torch.clone(torch.random.get_rng_state())
+ if torch.cuda.is_available():
+ cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
+ saved_state = []
+ for param in itertools.chain(gm.parameters(), gm.buffers()):
+ saved_state.append((param, param._version, torch.clone(param)))
+
+ def restore():
+ with torch.no_grad():
+ torch.random.set_rng_state(rng_state)
+ if torch.cuda.is_available():
+ torch.cuda.set_rng_state(cuda_rng_state)
+ for param, version, original_value in saved_state:
+ if param._version != version:
+ param.copy_(original_value)
+
+ return restore
+
+
+def timed(model, example_inputs, times=1):
+ if torch.cuda.is_available():
+ synchronize = torch.cuda.synchronize
+ else:
+ synchronize = nothing
+
+ synchronize()
+ gc.collect()
+ torch.manual_seed(1337)
+ t0 = time.perf_counter()
+ for _ in range(times):
+ result = model(*example_inputs)
+ synchronize()
+ t1 = time.perf_counter()
+ return result, t1 - t0
+
+
+def check_is_cuda(gm, example_inputs):
+ return all(x.is_cuda for x in itertools.chain(example_inputs, gm.parameters(True)))
+
+
+@lru_cache(32)
+def rot_n_helper(n):
+ assert n > 1
+ vars = [f"v{i}" for i in range(n)]
+ rotated = reversed(vars[-1:] + vars[:-1])
+ fn = eval(f"lambda {','.join(vars)}: ({','.join(rotated)})")
+ fn.__name__ = f"rot_{n}_helper"
+ return fn
+
+
+def is_safe_constant(v):
+ if istype(v, (tuple, frozenset)):
+ return all(map(is_safe_constant, v))
+ return istype(
+ v, (types.CodeType, int, float, bool, str, bytes, type(None), slice, type(type))
+ )
+
+
+def check_constant_args(args, kwargs):
+ return all(x.is_python_constant() for x in itertools.chain(args, kwargs.values()))
+
+
+def check_unspec_python_args(args, kwargs):
+ from .variables.constant import ConstantVariable
+ from .variables.tensor import UnspecializedPythonVariable
+
+ unspec_count = 0
+ for x in itertools.chain(args, kwargs.values()):
+ if isinstance(x, UnspecializedPythonVariable):
+ unspec_count += 1
+ elif not isinstance(x, (UnspecializedPythonVariable, ConstantVariable)):
+ return False
+ else:
+ pass
+
+ return unspec_count > 0
+
+
+def specialize_args_kwargs(tx, args, kwargs):
+ specialized_args = []
+ specialized_kwargs = {}
+ for x in args:
+ specialized_args.append(x.as_specialized(tx))
+ for k, v in kwargs.items():
+ specialized_kwargs.update({k: v.as_specialized(tx)})
+ return specialized_args, specialized_kwargs
+
+
+dict_values = type(dict().values())
+odict_values = type(collections.OrderedDict().values())
+tuple_iterator = type(iter(tuple()))
+tuple_iterator_len = tuple_iterator.__length_hint__
+object_new = object.__new__
+
+
+def product(it):
+ return functools.reduce(operator.mul, it, 1)
+
+
+def tuple_iterator_getitem(it, index):
+ _, (obj,), start = it.__reduce__()
+ return obj[start + index]
+
+
+def dict_param_key_ids(value):
+ return set([id(k) for k in value.keys() if isinstance(k, torch.nn.Parameter)])
+
+
+def dict_const_keys(value):
+ return set(k for k in value.keys() if not isinstance(k, torch.nn.Parameter))
+
+
+def global_key_name(key):
+ return f"__dict_key_{id(key)}"
+
+
+def rename_implicit(v):
+ """
+ Usage of inline comprehensions generates a implicit ".0" variable that
+ trips up guard generation. This renames these variables in guards.
+ """
+ m = re.match(r"^[.](\d+)$", v)
+ if m:
+ assert v == ".0", f"currently only .0 supported: {v}"
+ # to support .1 etc see guards.py and _eval_frame.c
+ return f"___implicit{m.group(1)}"
+ return v
+
+
+# FakeTensors were introduced after pytorch 1.12, so gate their use
+# to allow pytorch 1.12 to work
+fake_tensors_available = True
+try:
+ from torch._subclasses import ( # noqa: F401
+ FakeTensorMode,
+ UnsupportedFakeTensorException,
+ )
+
+ def wrap_fake_exception(fn):
+ try:
+ return fn()
+ except UnsupportedFakeTensorException as e:
+ from .exc import unimplemented
+
+ msg = f"Unsupported: {e.reason} with fake tensor propagation. Run with config.fake_tensor_propagation=False"
+ log.warning(msg)
+ raise unimplemented(msg)
+
+ def wrap_to_fake_tensor(e, fake_mode):
+ if type(e) in (torch.Tensor, torch.nn.Parameter):
+ return wrap_fake_exception(lambda: fake_mode.from_tensor(e))
+ else:
+ return e
+
+ def deepcopy_to_fake_tensor(obj, fake_mode):
+ with torch._subclasses.fake_tensor.FakeCopyMode(fake_mode):
+ return wrap_fake_exception(lambda: copy.deepcopy(obj))
+
+except ImportError:
+ fake_tensors_available = False
+
+
+def rmse(ref, res):
+ """
+ Calculate root mean squared error
+ """
+ return torch.sqrt(torch.mean(torch.square(ref - res)))
+
+
+def same(
+ ref,
+ res,
+ fp64_ref=None,
+ cos_similarity=False,
+ tol=1e-4,
+ equal_nan=False,
+ exact_dtype=True,
+):
+ """Check correctness to see if ref and res match"""
+ if fp64_ref is None:
+ fp64_ref = ref
+ if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)):
+ assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}"
+ return len(ref) == len(res) and all(
+ same(ai, bi, fp64_refi, cos_similarity, tol, equal_nan, exact_dtype)
+ for ai, bi, fp64_refi in zip(ref, res, fp64_ref)
+ )
+ elif isinstance(ref, dict):
+ assert isinstance(res, dict)
+ assert set(ref.keys()) == set(
+ res.keys()
+ ), f"keys mismatch {set(ref.keys())} == {set(res.keys())}"
+ for k in ref.keys():
+ if not (
+ same(
+ ref[k],
+ res[k],
+ fp64_ref[k],
+ cos_similarity=cos_similarity,
+ tol=tol,
+ equal_nan=equal_nan,
+ exact_dtype=exact_dtype,
+ )
+ ):
+ log.error(f"Accuracy failed for key name {k}")
+ return False
+ return True
+ elif isinstance(ref, torch.Tensor):
+ if ref.is_sparse:
+ assert res.is_sparse
+ ref = ref.to_dense()
+ res = res.to_dense()
+ assert isinstance(res, torch.Tensor), f"type mismatch {type(ref)} {type(res)}"
+ if exact_dtype:
+ assert ref.dtype == res.dtype, f"dtype mismatch {ref.dtype}, {res.dtype}"
+ if ref.dtype == torch.bool:
+ # triton stores bool as int8, so add this for more accurate checking
+ return torch.allclose(
+ ref.to(dtype=torch.uint8),
+ res.to(dtype=torch.uint8),
+ atol=tol,
+ rtol=tol,
+ equal_nan=equal_nan,
+ )
+ if cos_similarity:
+ ref = ref.flatten().to(torch.float32)
+ res = res.flatten().to(torch.float32)
+ if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=True):
+ # early exit that handles zero/nan better
+ # cosine_similarity(zeros(10), zeros(10), dim=0) is 0
+ return True
+ res = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6)
+ if res < 0.99:
+ log.warning(f"Similarity score={res.cpu().detach().item()}")
+ return res >= 0.99
+ else:
+ if not exact_dtype:
+ ref = ref.to(res.dtype)
+
+ # First try usual allclose
+ if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=equal_nan):
+ return True
+
+ # Check error from fp64 version
+ if fp64_ref.dtype == torch.float64:
+ ref_error = rmse(fp64_ref, ref).item()
+ res_error = rmse(fp64_ref, res).item()
+ multiplier = 2.0
+
+ if fp64_ref.numel() < 1000 or (
+ ref.ndim == 4 and ref.shape[-1] == ref.shape[-2] == 1
+ ):
+ # In the presence of noise, noise might dominate our error
+ # metric for smaller tensors.
+ # Similary, for 1x1 kenerls, there seems to be high noise with amp.
+ multiplier = 3.0
+
+ passes_test = res_error <= (multiplier * ref_error + 1e-4)
+ if not passes_test:
+ log.error(
+ f"RMSE (res-fp64): {res_error:.5f}, (ref-fp64): {ref_error:.5f} and shape={res.size()}"
+ )
+ # import pdb; pdb.set_trace()
+ return passes_test
+
+ return False
+ elif isinstance(ref, (str, int, type(None), bool, torch.device)):
+ return ref == res
+ elif isinstance(ref, float):
+ return math.isclose(ref, res, rel_tol=tol, abs_tol=tol)
+ elif is_numpy_int_type(ref) or is_numpy_float_type(ref):
+ return (type(ref) is type(res)) and (ref == res)
+ elif type(ref).__name__ in (
+ "MaskedLMOutput",
+ "Seq2SeqLMOutput",
+ "CausalLMOutputWithCrossAttentions",
+ "LongformerMaskedLMOutput",
+ "Instances",
+ "SquashedNormal",
+ "Boxes",
+ "Normal",
+ "TanhTransform",
+ "Foo",
+ "Variable",
+ ):
+ assert type(ref) is type(res)
+ return all(
+ same(
+ getattr(ref, key),
+ getattr(res, key),
+ getattr(fp64_ref, key),
+ cos_similarity=cos_similarity,
+ tol=tol,
+ equal_nan=equal_nan,
+ exact_dtype=exact_dtype,
+ )
+ for key in ref.__dict__.keys()
+ )
+ else:
+ raise RuntimeError(f"unsupported type: {type(ref).__name__}")
+
+
+def format_func_info(code):
+ short_filename = code.co_filename.split("/")[-1]
+ return f"'{code.co_name}' ({short_filename}:{code.co_firstlineno})"
+
+
+@contextlib.contextmanager
+def disable_cache_limit():
+ prior = config.cache_size_limit
+ config.cache_size_limit = sys.maxsize
+
+ try:
+ yield
+ finally:
+ pass
+ config.cache_size_limit = prior
+
+
+# map from transformed code back to original user code
+orig_code_map = ExactWeakKeyDictionary()
+
+# keep a record of code_obj -> list of guard failure reasons for logging
+guard_failures = collections.defaultdict(list)
+
+
+class CompileProfiler:
+ """Utility for profiling how and what dynamo would compile.
+
+ Can be used for
+ * diagnosing recompilation issues
+ * determining an appropriate compile cache limit
+ * (TODO)confirming which functions got compiled/skipped
+ """
+
+ def __init__(self):
+ self.frame_count = 0
+ self.op_count = 0
+ self.backend_ctx_ctor = lambda: disable_cache_limit()
+
+ def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+ self.frame_count += 1
+ for node in gm.graph.nodes:
+ if "call" in node.op:
+ self.op_count += 1
+ return gm.forward
+
+ def get_metrics(self):
+ return {"guard_failures": guard_failures}
+
+ def report(self):
+ metrics = self.get_metrics()
+ gf = metrics["guard_failures"]
+
+ def num_recompiles(code):
+ return len(gf[code])
+
+ def recompile_reasons(code):
+ return "\n".join([str(x) for x in gf[code]])
+
+ summarized_gf = [
+ [format_func_info(code), num_recompiles(code), recompile_reasons(code)]
+ for code in gf
+ ]
+ rpt = "Torchdynamo Profiler Report\n"
+ if "graph_break" in counters:
+ rpt += "\n"
+ rpt += "The following conditions caused torchdynamo to break out of tracing and fall back to python.\n"
+ rpt += (
+ f"You may gain additional insight by passing `nopython=True` to {config.dynamo_import}.optimize, "
+ "to break on the first condition.\n"
+ )
+ graph_breaks = counters["graph_break"]
+ rpt += tabulate(
+ [[msg, graph_breaks[msg]] for msg in graph_breaks],
+ headers=["Graph Break Reason", "Count"],
+ )
+
+ if len(gf):
+ max_recompiles = max([num_recompiles(code) for code in gf])
+ rpt += "\n"
+ rpt += (
+ "These subgraphs were recompiled more than once due to guard failures."
+ )
+ rpt += (
+ "Guard failures indicate some condition assumed to be static by the tracer changed, "
+ "making it unsafe to reuse the compiled program."
+ )
+ rpt += tabulate(
+ summarized_gf,
+ headers=["Function", "Num Recompiles", "Recompile Reasons"],
+ )
+ rpt += "\n"
+ rpt += (
+ f"Set {config.dynamo_import}.config.cache_size_limit to "
+ f"{max_recompiles} to avoid being cache limited.\n"
+ )
+ else:
+ rpt += "No cache-limited recompilations detected.\n"
+
+ return rpt
diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py
new file mode 100644
index 0000000000000..8c80557e3fd01
--- /dev/null
+++ b/torch/_dynamo/variables/__init__.py
@@ -0,0 +1,88 @@
+from .base import VariableTracker
+from .builtin import BuiltinVariable
+from .constant import ConstantVariable, EnumVariable
+from .dicts import ConstDictVariable, DataClassVariable, DefaultDictVariable
+from .functions import (
+ NestedUserFunctionVariable,
+ UserFunctionVariable,
+ UserMethodVariable,
+)
+from .lists import (
+ BaseListVariable,
+ ListIteratorVariable,
+ ListVariable,
+ NamedTupleVariable,
+ RangeVariable,
+ SliceVariable,
+ TupleVariable,
+)
+from .misc import (
+ AutogradFunctionVariable,
+ BlackHoleVariable,
+ ClosureVariable,
+ ContextWrappingVariable,
+ GetAttrVariable,
+ GradModeVariable,
+ InspectSignatureVariable,
+ LambdaVariable,
+ NewCellVariable,
+ NewGlobalVariable,
+ NumpyVariable,
+ PythonModuleVariable,
+ SuperVariable,
+ UnknownVariable,
+ WithExitFunctionVariable,
+)
+from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable
+from .tensor import (
+ FakeItemVariable,
+ TensorVariable,
+ UnspecializedNumpyVariable,
+ UnspecializedPythonVariable,
+)
+from .torch import TorchVariable
+from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable
+
+__all__ = [
+ "AutogradFunctionVariable",
+ "BaseListVariable",
+ "BlackHoleVariable",
+ "BuiltinVariable",
+ "ClosureVariable",
+ "ConstantVariable",
+ "ConstDictVariable",
+ "ContextWrappingVariable",
+ "DataClassVariable",
+ "DefaultDictVariable",
+ "EnumVariable",
+ "FakeItemVariable",
+ "GetAttrVariable",
+ "GradModeVariable",
+ "InspectSignatureVariable",
+ "LambdaVariable",
+ "ListIteratorVariable",
+ "ListVariable",
+ "NamedTupleVariable",
+ "NestedUserFunctionVariable",
+ "NewCellVariable",
+ "NewGlobalVariable",
+ "NNModuleVariable",
+ "NumpyVariable",
+ "PythonModuleVariable",
+ "RangeVariable",
+ "SliceVariable",
+ "SuperVariable",
+ "TensorVariable",
+ "TorchVariable",
+ "TupleVariable",
+ "UnknownVariable",
+ "UnspecializedNNModuleVariable",
+ "UnspecializedNumpyVariable",
+ "UnspecializedPythonVariable",
+ "UserDefinedClassVariable",
+ "UserDefinedObjectVariable",
+ "UserFunctionVariable",
+ "UserMethodVariable",
+ "VariableTracker",
+ "WithExitFunctionVariable",
+]
diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py
new file mode 100644
index 0000000000000..62cddfff0cb29
--- /dev/null
+++ b/torch/_dynamo/variables/base.py
@@ -0,0 +1,262 @@
+import collections
+from typing import Any, Callable, Dict, List, Optional, Set
+
+from .. import variables
+from ..exc import unimplemented
+from ..source import AttrSource, Source
+from ..utils import dict_values, identity, istype, odict_values
+
+
+class MutableLocal:
+ """
+ Marker used to indicate this (list, iter, etc) was constructed in
+ local scope and can be mutated safely in analysis without leaking
+ state.
+ """
+
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ return self is other
+
+
+class VariableTracker:
+ """
+ Base class for tracked locals and stack values
+
+ VariableTracker instances are immutable and should be copied in
+ order to change them.
+ """
+
+ # fields to leave unmodified in apply()
+ _nonvar_fields = ["value"]
+
+ @staticmethod
+ def propagate(*vars: List[List["VariableTracker"]]):
+ """Combine the guards from many VariableTracker into **kwargs for a new instance"""
+ guards = set()
+
+ def visit(var):
+ if type(var) in (list, tuple, dict_values, odict_values):
+ for i in var:
+ visit(i)
+ elif isinstance(var, variables.BaseListVariable):
+ guards.update(var.guards)
+ for i in var.items:
+ visit(i)
+ elif isinstance(var, variables.ConstDictVariable):
+ guards.update(var.guards)
+ visit(var.items.values())
+ else:
+ assert isinstance(var, VariableTracker), typestr(var)
+ guards.update(var.guards)
+
+ visit(vars)
+ return {
+ "guards": guards,
+ }
+
+ def clone(self, **kwargs):
+ """Shallow copy with some (optional) changes"""
+ args = dict(self.__dict__)
+ args.update(kwargs)
+ return self.__class__(**args)
+
+ @classmethod
+ def copy(cls, value):
+ """Deeper (but not full) copy, leaving FX and user objects alone"""
+ return cls.apply(identity, value)
+
+ @classmethod
+ def apply(
+ cls, fn: Callable[["VariableTracker"], "VariableTracker"], value, cache=None
+ ):
+ """
+ Walk this object and call fn on all the VariableTracker
+ instances to produce a new VariableTracker with the results.
+ """
+ if cache is None:
+ cache = dict()
+
+ idx = id(value)
+ if idx in cache:
+ return cache[idx][0]
+
+ if isinstance(value, VariableTracker):
+ updated_dict = dict(value.__dict__)
+ for key in updated_dict.keys():
+ if key not in value._nonvar_fields:
+ updated_dict[key] = cls.apply(fn, updated_dict[key], cache)
+ result = fn(value.clone(**updated_dict))
+ elif istype(value, list):
+ result = [cls.apply(fn, v, cache) for v in value]
+ elif istype(value, tuple):
+ result = tuple(cls.apply(fn, v, cache) for v in value)
+ elif istype(value, collections.OrderedDict):
+ result = collections.OrderedDict(
+ cls.apply(fn, v, cache) for v in value.items()
+ )
+ elif istype(value, dict):
+ result = {k: cls.apply(fn, v, cache) for k, v in list(value.items())}
+ else:
+ result = value
+
+ # save `value` to keep it alive and ensure id() isn't reused
+ cache[idx] = (result, value)
+ return result
+
+ def add_guard(self, guard):
+ return self.clone(guards=set.union(self.guards, {guard}))
+
+ def add_guards(self, guards):
+ if guards is None:
+ return self
+ assert isinstance(guards, set)
+ return self.clone(guards=set.union(self.guards, guards))
+
+ def add_options(self, options, *more):
+ if more:
+ return self.add_options(options).add_options(*more)
+ if isinstance(options, VariableTracker):
+ return self.add_guards(options.guards)
+ assert isinstance(options, dict)
+ return self.add_guards(options.get("guards", set()))
+
+ def __str__(self):
+ return f"{self.__class__.__name__}()"
+
+ def __repr__(self):
+ return str(self)
+
+ def python_type(self):
+ raise NotImplementedError(f"{self} has no type")
+
+ def as_python_constant(self):
+ """For constants"""
+ raise NotImplementedError(f"{self} is not a constant")
+
+ def is_python_constant(self):
+ try:
+ self.as_python_constant()
+ return True
+ except NotImplementedError:
+ return False
+
+ def as_specialized(self, tx):
+ """
+ For specialized variables, return itself,
+ For unspecialized variables, convert to constant variable and return.
+ """
+ return self
+
+ def can_make_guard(self):
+ try:
+ self.make_guard(None)
+ return True
+ except NotImplementedError:
+ return False
+
+ def make_guard(self, fn):
+ if self.source:
+ return self.source.make_guard(fn)
+ raise NotImplementedError()
+
+ def replace_guards(self, guards, *fns):
+ name = self.source.name()
+ new_guards = {g for g in (guards or []) if g.name != name}
+ new_guards.update(self.source.make_guard(fn) for fn in fns)
+ return new_guards
+
+ def const_getattr(self, tx, name: str) -> Any:
+ """getattr(self, name) returning a python constant"""
+ raise NotImplementedError()
+
+ def var_getattr(self, tx, name: str) -> "VariableTracker":
+ """getattr(self, name) returning a new variable"""
+ options = VariableTracker.propagate(self)
+ value = self.const_getattr(tx, name)
+ if not variables.ConstantVariable.is_literal(value):
+ raise NotImplementedError()
+ if self.source:
+ options["source"] = AttrSource(self.source, name)
+ return variables.ConstantVariable(value, **options)
+
+ def is_proxy(self):
+ try:
+ self.as_proxy()
+ return True
+ except NotImplementedError:
+ return False
+
+ def as_proxy(self):
+ raise NotImplementedError(str(self))
+
+ def reconstruct(self, codegen):
+ raise NotImplementedError()
+
+ def unpack_var_sequence(self, tx):
+ raise NotImplementedError()
+
+ def has_unpack_var_sequence(self, tx):
+ try:
+ self.unpack_var_sequence(tx)
+ return True
+ except NotImplementedError:
+ return False
+
+ def num_parameters(self):
+ unimplemented(f"num_parameters: {self}")
+
+ def call_hasattr(self, tx, name: str) -> "VariableTracker":
+ unimplemented(f"hasattr: {self}")
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ unimplemented(f"call_function {self} {args} {kwargs}")
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ if name == "__len__" and self.has_unpack_var_sequence(tx):
+ assert not (args or kwargs)
+ return variables.ConstantVariable(
+ len(self.unpack_var_sequence(tx)), **VariableTracker.propagate(self)
+ )
+ elif (
+ name == "__getattr__"
+ and len(args) == 1
+ and args[0].is_python_constant()
+ and not kwargs
+ ):
+ return self.var_getattr(tx, args[0].as_python_constant()).add_options(
+ self, args[0]
+ )
+ raise unimplemented(f"call_method {self} {name} {args} {kwargs}")
+
+ def __init__(
+ self,
+ guards: Optional[Set] = None,
+ source: Source = None,
+ mutable_local: MutableLocal = None,
+ ):
+ super(VariableTracker, self).__init__()
+ self.guards = guards or set()
+ self.source = source
+ self.mutable_local = mutable_local
+
+
+def typestr(*objs):
+ if len(objs) == 1:
+ (obj,) = objs
+ if isinstance(obj, VariableTracker):
+ return str(obj)
+ else:
+ return type(obj).__name__
+ else:
+ return " ".join(map(typestr, objs))
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
new file mode 100644
index 0000000000000..e22cd2bc7a0a9
--- /dev/null
+++ b/torch/_dynamo/variables/builder.py
@@ -0,0 +1,559 @@
+import collections
+import dataclasses
+import enum
+import functools
+import inspect
+import re
+import types
+from abc import ABCMeta
+from typing import Any, List
+
+import numpy as np
+from functorch.experimental.ops import PyOperator
+
+import torch
+
+from .. import config, mutation_guard, replay_record, skipfiles
+from ..allowed_functions import is_allowed, is_builtin_callable, is_numpy
+from ..exc import unimplemented
+from ..guards import GuardBuilder, GuardSource
+from ..side_effects import SideEffects
+from ..source import (
+ AttrSource,
+ ConstantSource,
+ GetItemSource,
+ GlobalSource,
+ GlobalWeakRefSource,
+ is_constant_source,
+ RandomValueSource,
+ Source,
+ TupleIteratorGetItemSource,
+)
+from ..utils import (
+ getfile,
+ global_key_name,
+ is_namedtuple,
+ is_numpy_int_type,
+ istensor,
+ istype,
+ odict_values,
+ tuple_iterator,
+ tuple_iterator_getitem,
+ tuple_iterator_len,
+)
+from .base import MutableLocal
+from .builtin import BuiltinVariable
+from .constant import ConstantVariable, EnumVariable
+from .dicts import (
+ ConstDictVariable,
+ DataClassVariable,
+ DefaultDictVariable,
+ HFPretrainedConfigVariable,
+)
+from .functions import UserFunctionVariable
+from .lists import (
+ ListIteratorVariable,
+ ListVariable,
+ NamedTupleVariable,
+ RangeVariable,
+ SliceVariable,
+ TupleVariable,
+)
+from .misc import (
+ AutogradFunctionVariable,
+ GetAttrVariable,
+ InspectSignatureVariable,
+ LambdaVariable,
+ NumpyVariable,
+ PythonModuleVariable,
+ SkipFilesVariable,
+ TypingVariable,
+)
+from .nn_module import UnspecializedNNModuleVariable
+from .tensor import (
+ TensorVariable,
+ TensorWithTFOverrideVariable,
+ UnspecializedNumpyVariable,
+ UnspecializedPythonVariable,
+)
+from .torch import (
+ tensor_dunder_fns,
+ torch_special_class_types,
+ TorchPyOperator,
+ TorchVariable,
+)
+from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable
+
+
+@dataclasses.dataclass
+class GraphArg:
+ source: Source
+ example: Any
+ is_unspecialized: bool
+
+ def __post_init__(self):
+ if isinstance(self.example, torch._subclasses.fake_tensor.FakeTensor):
+ raise AssertionError("Fake Tensor observed in TorchDynamo Fx graph inputs")
+
+ def load(self, tx):
+ return self.source.reconstruct(tx)
+
+ def get_examples(self):
+ return [self.example]
+
+ def __len__(self):
+ return 1
+
+ def erase(self):
+ self.example = None
+
+
+class VariableBuilder:
+ """Wrap a python value in a VariableTracker() instance"""
+
+ def __init__(
+ self,
+ tx,
+ source: Source,
+ ):
+ super(VariableBuilder, self).__init__()
+ self.tx = tx
+ self.source = source
+ self.name = source.name()
+
+ def __call__(self, value):
+ if value in self.tx.output.side_effects:
+ # TODO(jansel): add guard for alias relationship
+ return self.tx.output.side_effects[value]
+ return self._wrap(value).clone(**self.options())
+
+ @staticmethod
+ @functools.lru_cache(None)
+ def _common_constants():
+ return set(range(17)).union(
+ {
+ 20,
+ 30,
+ 40,
+ 32,
+ 64,
+ 96,
+ 128,
+ 144,
+ 240,
+ 256,
+ 672,
+ 1024,
+ 2048,
+ 4096,
+ 0.1,
+ 0.01,
+ 0.001,
+ 0.5,
+ 0.05,
+ 800,
+ 1.873536229133606,
+ 4.135166556742356, # Work around for vision_maskrcnn where torch.clamp can't be on different devices
+ }
+ )
+
+ @staticmethod
+ def list_type(value):
+ if is_namedtuple(value):
+ return functools.partial(NamedTupleVariable, tuple_cls=type(value))
+ return {
+ tuple: TupleVariable,
+ list: ListVariable,
+ odict_values: ListVariable,
+ torch.nn.ParameterList: ListVariable,
+ torch.nn.ModuleList: ListVariable,
+ }[type(value)]
+
+ def get_source(self):
+ return self.source
+
+ def options(self):
+ return {"source": self.get_source()}
+
+ def make_guards(self, *guards):
+ source = self.get_source()
+ if (
+ isinstance(source, ConstantSource)
+ or source.guard_source() == GuardSource.CONSTANT
+ ):
+ return None
+ return {source.make_guard(guard) for guard in guards}
+
+ def _wrap(self, value):
+ make_guards = self.make_guards
+ if istensor(value):
+ return self.wrap_tensor(value)
+ elif istype(value, (tuple, list, odict_values)) or is_namedtuple(value):
+ # One can index a tensor with a list/tuple. Therefore, we need to
+ # have a stricter match.
+ if istype(value, (tuple, list)) and all(
+ [isinstance(x, int) or is_numpy_int_type(x) or x is None for x in value]
+ ):
+ guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
+ else:
+ guards = self.make_guards(GuardBuilder.LIST_LENGTH)
+ output = [
+ VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
+ item
+ ).add_guards(guards)
+ for i, item in enumerate(value)
+ ]
+ result = self.list_type(value)(output, guards=guards)
+ if istype(value, list):
+ return self.tx.output.side_effects.track_list(
+ self.source, value, result
+ )
+ return result
+ elif istype(value, tuple_iterator):
+ guards = self.make_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
+ output = [
+ VariableBuilder(
+ self.tx, TupleIteratorGetItemSource(self.get_source(), i)
+ )(tuple_iterator_getitem(value, i)).add_guards(guards)
+ for i in range(tuple_iterator_len(value))
+ ]
+ return ListIteratorVariable(
+ output, mutable_local=MutableLocal(), guards=guards
+ )
+ elif istype(value, range):
+ guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
+ return RangeVariable(value=value, guards=guards)
+ elif istype(
+ value, (dict, collections.defaultdict, collections.OrderedDict)
+ ) and all(
+ map(
+ lambda k: ConstantVariable.is_literal(k)
+ or isinstance(k, torch.nn.Parameter),
+ value.keys(),
+ )
+ ):
+ guards = self.make_guards(GuardBuilder.DICT_KEYS)
+
+ # store key variables in global location for reconstruction
+ for key in value.keys():
+ if isinstance(key, torch.nn.Parameter):
+ self.tx.store_dict_key(global_key_name(key), key)
+
+ def index_source(key):
+ if isinstance(key, torch.nn.Parameter):
+ return GlobalWeakRefSource(global_key_name(key))
+ else:
+ return key
+
+ result = dict(
+ [
+ (
+ k,
+ VariableBuilder(
+ self.tx, GetItemSource(self.get_source(), index_source(k))
+ )(value[k]).add_guards(guards),
+ )
+ for k in value.keys()
+ ]
+ )
+
+ if istype(value, collections.defaultdict):
+ result = DefaultDictVariable(
+ result, type(value), value.default_factory, guards=guards
+ )
+ else:
+ result = ConstDictVariable(result, type(value), guards=guards)
+
+ return self.tx.output.side_effects.track_dict(self.source, value, result)
+ elif isinstance(value, torch.nn.Module):
+ if mutation_guard.is_dynamic_nn_module(value):
+ # created dynamically, don't specialize on it
+ result = UnspecializedNNModuleVariable(
+ value, guards=make_guards(GuardBuilder.TYPE_MATCH)
+ )
+ if not SideEffects.cls_supports_mutation_side_effects(type(value)):
+ # don't allow STORE_ATTR mutation with custom __setattr__
+ return result
+ return self.tx.output.side_effects.track_object_existing(
+ self.source, value, result
+ )
+ elif issubclass(
+ value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
+ ):
+ return UnspecializedNNModuleVariable(
+ value, guards=make_guards(GuardBuilder.TYPE_MATCH)
+ )
+ else:
+ return self.tx.output.register_attr_or_module(
+ value,
+ self.name,
+ source=self.get_source(),
+ # Guards are added inside register_attr_or_module
+ )
+ elif ConstantVariable.is_literal(value) or istype(
+ value, (torch.Size, torch.device, torch.dtype)
+ ):
+ if type(value) in (int, float) and not config.specialize_int_float:
+ # unspecializing int/float by default, but still
+ # specialize for the following conditions
+ if (
+ value in self._common_constants()
+ or isinstance(self.source, GlobalSource)
+ or isinstance(self.source, GetItemSource)
+ or (
+ isinstance(self.source, AttrSource)
+ and isinstance(self.source.base, GlobalSource)
+ )
+ ):
+ return ConstantVariable(
+ value=value,
+ guards=make_guards(GuardBuilder.CONSTANT_MATCH),
+ )
+ else:
+ return self.wrap_unspecialized_primitive(value)
+ else:
+ return ConstantVariable(
+ value=value,
+ guards=make_guards(GuardBuilder.CONSTANT_MATCH),
+ )
+ elif isinstance(value, frozenset) and (
+ all(is_allowed(x) or ConstantVariable.is_literal(x) for x in value)
+ ):
+ # For frozenset, we can guard by object ID instead of value
+ # equality, this allows us to handle non-literal values
+ return ConstantVariable(
+ value=value,
+ guards=make_guards(GuardBuilder.ID_MATCH),
+ )
+ elif isinstance(value, enum.Enum):
+ return EnumVariable(
+ value=value,
+ guards=make_guards(GuardBuilder.ID_MATCH),
+ )
+ elif is_builtin_callable(value):
+ return BuiltinVariable(
+ value,
+ guards=make_guards(GuardBuilder.BUILTIN_MATCH),
+ )
+ elif is_allowed(value):
+ return TorchVariable(
+ value,
+ guards=make_guards(GuardBuilder.FUNCTION_MATCH),
+ )
+ elif value is List:
+ return TypingVariable(
+ value,
+ guards=make_guards(GuardBuilder.ID_MATCH),
+ )
+ elif value is inspect.signature:
+ return LambdaVariable(
+ InspectSignatureVariable.create,
+ guards=make_guards(GuardBuilder.FUNCTION_MATCH),
+ )
+ elif value is dataclasses.fields:
+ return LambdaVariable(
+ _dataclasses_fields_lambda,
+ guards=make_guards(GuardBuilder.FUNCTION_MATCH),
+ )
+ elif is_numpy(value):
+ return NumpyVariable(
+ value,
+ guards=make_guards(
+ GuardBuilder.FUNCTION_MATCH
+ if callable(value)
+ else GuardBuilder.TYPE_MATCH
+ ),
+ )
+ elif value in tensor_dunder_fns:
+ return TorchVariable(
+ value,
+ guards=make_guards(GuardBuilder.FUNCTION_MATCH),
+ )
+ elif (
+ istype(value, (type, types.FunctionType))
+ and skipfiles.check(getfile(value), allow_torch=True)
+ and not inspect.getattr_static(value, "_torchdynamo_inline", False)
+ ):
+ return SkipFilesVariable(
+ value, guards=make_guards(GuardBuilder.FUNCTION_MATCH)
+ )
+ elif istype(value, (type, ABCMeta)):
+ # TODO(whc) the following seems preferable but breaks some tests, debug
+ # elif inspect.isclass(value):
+ return UserDefinedClassVariable(
+ value, guards=make_guards(GuardBuilder.FUNCTION_MATCH)
+ )
+ elif value in tensor_dunder_fns:
+ return TorchVariable(
+ value,
+ guards=make_guards(GuardBuilder.FUNCTION_MATCH),
+ )
+ elif istype(value, types.FunctionType):
+ return UserFunctionVariable(
+ value,
+ guards=make_guards(GuardBuilder.FUNCTION_MATCH),
+ )
+ elif istype(value, (types.ModuleType, replay_record.DummyModule)):
+ return PythonModuleVariable(
+ value,
+ guards=make_guards(GuardBuilder.PYMODULE_MATCH),
+ )
+ elif type(value) is torch.autograd.function.FunctionMeta:
+ return AutogradFunctionVariable(
+ value, guards=make_guards(GuardBuilder.FUNCTION_MATCH)
+ )
+ elif (
+ isinstance(value, types.BuiltinFunctionType)
+ and type(getattr(value, "__self__", None))
+ is torch.autograd.function.FunctionMeta
+ and getattr(value, "__name__", "") == "apply"
+ ):
+ # handle aliased autograd function `apply` calls
+ return GetAttrVariable(
+ AutogradFunctionVariable(
+ value.__self__, guards=make_guards(GuardBuilder.FUNCTION_MATCH)
+ ),
+ "apply",
+ )
+ elif isinstance(value, (int, float, np.number)):
+ return self.wrap_unspecialized_primitive(value)
+ elif DataClassVariable.is_matching_object(value):
+ return DataClassVariable.wrap(self, value).add_guards(
+ make_guards(GuardBuilder.TYPE_MATCH)
+ )
+ elif HFPretrainedConfigVariable.is_matching_object(value):
+ return HFPretrainedConfigVariable(
+ value, guards=make_guards(GuardBuilder.TYPE_MATCH)
+ )
+ elif isinstance(value, slice):
+ items = [
+ VariableBuilder(self.tx, AttrSource(self.get_source(), k))(
+ getattr(value, k)
+ )
+ for k in ("start", "stop", "step")
+ ]
+ return SliceVariable(items, guards=make_guards(GuardBuilder.TYPE_MATCH))
+ elif isinstance(value, PyOperator):
+ return TorchPyOperator(
+ value,
+ guards=self.make_guards(
+ GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH
+ ),
+ )
+ elif type(value).__name__ == "builtin_function_or_method" and isinstance(
+ value.__self__, torch_special_class_types
+ ):
+ return TorchVariable(
+ value,
+ guards=make_guards(GuardBuilder.FUNCTION_MATCH),
+ )
+ else:
+ result = UserDefinedObjectVariable(
+ value,
+ guards=self.make_guards(GuardBuilder.TYPE_MATCH),
+ )
+ if not SideEffects.cls_supports_mutation_side_effects(type(value)):
+ # don't allow STORE_ATTR mutation with custom __setattr__
+ return result
+ return self.tx.output.side_effects.track_object_existing(
+ self.source, value, result
+ )
+
+ def wrap_tensor(self, value: torch.Tensor):
+ if self.get_source().guard_source().is_nn_module():
+ return self.tx.output.register_attr_or_module(
+ value,
+ self.name,
+ source=self.get_source(),
+ # Guards are done inside register_attr_or_module
+ # guards=self.make_guards(GuardBuilder.TENSOR_MATCH),
+ )
+ else:
+ if not is_constant_source(self.get_source()):
+ self.tx.output.graphargs.append(
+ GraphArg(self.get_source(), value, False)
+ )
+ # Disable __torch_function__ to prevent cloning of `value` to hit
+ # us
+ with torch._C.DisableTorchFunction():
+ if is_constant_source(self.get_source()):
+ return self.tx.output.register_attr_or_module(
+ value,
+ re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
+ source=None,
+ # Guards are added inside register_attr_or_module
+ )
+ tensor_variable = TensorVariable.create(
+ tx=self.tx,
+ proxy=self.tx.output.create_graph_input(
+ re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value)
+ ),
+ example_value=value,
+ guards=self.make_guards(GuardBuilder.TENSOR_MATCH),
+ )
+ if torch.overrides.has_torch_function_unary(value):
+ subclass_torch_function__func = value.__torch_function__.__func__
+ subclass_type = type(value)
+ return TensorWithTFOverrideVariable(
+ tensor_variable,
+ self.get_source(),
+ subclass_torch_function__func,
+ subclass_type,
+ )
+ return tensor_variable
+
+ def wrap_unspecialized_primitive(self, value):
+ if self.name in self.tx.output.unspec_variable_map:
+ return self.tx.output.unspec_variable_map[self.name]
+ else:
+ wrapped_value = torch.tensor(value)
+ if not is_constant_source(self.get_source()):
+ self.tx.output.graphargs.append(
+ GraphArg(self.get_source(), wrapped_value, True)
+ )
+ if not isinstance(self.get_source(), RandomValueSource):
+ guards = {self.get_source().make_guard(GuardBuilder.TYPE_MATCH, True)}
+ options = {"guards": guards}
+ else:
+ options = {}
+ options.update({"source": self.get_source()})
+ options.update({"raw_value": value})
+
+ proxy = self.tx.output.create_graph_input(
+ re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(wrapped_value)
+ )
+
+ if isinstance(value, np.number):
+ unspec_var = UnspecializedNumpyVariable.create(
+ tx=self.tx,
+ proxy=proxy,
+ example_value=wrapped_value,
+ **options,
+ )
+ else:
+ unspec_var = UnspecializedPythonVariable.create(
+ tx=self.tx,
+ proxy=proxy,
+ example_value=wrapped_value,
+ **options,
+ )
+ self.tx.output.unspec_variable_map[self.name] = unspec_var
+ return unspec_var
+
+
+def _dataclasses_fields_lambda(obj):
+ if isinstance(obj, UserDefinedObjectVariable):
+ value = obj.value
+ elif isinstance(obj, DataClassVariable):
+ value = obj.user_cls
+ else:
+ unimplemented(f"Dataclass fields handling fails for type {obj}")
+ items = []
+ for field in dataclasses.fields(value):
+ source = None
+ if obj.source:
+ source = GetItemSource(
+ AttrSource(obj.source, "__dataclass_fields__"), field.name
+ )
+ items.append(UserDefinedObjectVariable(field, source=source).add_options(obj))
+ return TupleVariable(items).add_options(obj)
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
new file mode 100644
index 0000000000000..71e094f08db91
--- /dev/null
+++ b/torch/_dynamo/variables/builtin.py
@@ -0,0 +1,787 @@
+import functools
+import inspect
+import itertools
+import logging
+import math
+import operator
+import types
+from typing import Dict, List
+
+import numpy as np
+
+import torch
+
+from .. import config, variables
+from ..allowed_functions import is_allowed
+from ..exc import unimplemented, Unsupported
+from ..guards import GuardBuilder
+from ..replay_record import DummyModule
+from ..source import AttrSource, is_constant_source, TypeSource
+from ..utils import (
+ check_constant_args,
+ check_unspec_python_args,
+ istype,
+ proxy_args_kwargs,
+ specialize_args_kwargs,
+)
+from .base import MutableLocal, VariableTracker
+from .dicts import ConstDictVariable
+from .tensor import DynamicShapeVariable, FakeItemVariable
+
+log = logging.getLogger(__name__)
+
+
+class BuiltinVariable(VariableTracker):
+ @staticmethod
+ @functools.lru_cache(None)
+ def _constant_fold_functions():
+ fns = {
+ abs,
+ all,
+ any,
+ bool,
+ callable,
+ chr,
+ dict,
+ divmod,
+ float,
+ int,
+ len,
+ list,
+ max,
+ min,
+ ord,
+ pow,
+ repr,
+ round,
+ set,
+ str,
+ str.format,
+ sum,
+ tuple,
+ type,
+ operator.pos,
+ operator.neg,
+ operator.not_,
+ operator.invert,
+ operator.pow,
+ operator.mul,
+ operator.matmul,
+ operator.floordiv,
+ operator.truediv,
+ operator.mod,
+ operator.add,
+ operator.sub,
+ operator.getitem,
+ operator.lshift,
+ operator.rshift,
+ operator.and_,
+ operator.or_,
+ operator.xor,
+ operator.ipow,
+ operator.imul,
+ operator.imatmul,
+ operator.ifloordiv,
+ operator.itruediv,
+ operator.imod,
+ operator.iadd,
+ operator.isub,
+ operator.ilshift,
+ operator.irshift,
+ operator.iand,
+ operator.ixor,
+ operator.ior,
+ operator.index,
+ }
+ fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
+ return fns
+
+ def can_constant_fold_through(self):
+ return self.fn in self._constant_fold_functions()
+
+ @staticmethod
+ @functools.lru_cache(None)
+ def _fx_graph_functions():
+ fns = {
+ operator.pos,
+ operator.neg,
+ operator.not_,
+ operator.invert,
+ operator.pow,
+ operator.mul,
+ operator.matmul,
+ operator.floordiv,
+ operator.truediv,
+ operator.mod,
+ operator.add,
+ operator.sub,
+ operator.getitem,
+ operator.lshift,
+ operator.rshift,
+ operator.and_,
+ operator.or_,
+ operator.xor,
+ operator.ipow,
+ operator.imul,
+ operator.imatmul,
+ operator.ifloordiv,
+ operator.itruediv,
+ operator.imod,
+ operator.iadd,
+ operator.isub,
+ operator.ilshift,
+ operator.irshift,
+ operator.iand,
+ operator.ixor,
+ operator.ior,
+ }
+ return fns
+
+ def can_insert_in_graph(self):
+ return self.fn in self._fx_graph_functions()
+
+ def __init__(self, fn, **kwargs):
+ super(BuiltinVariable, self).__init__(**kwargs)
+ self.fn = fn
+
+ def __str__(self):
+ if self.fn is None:
+ name = "None"
+ else:
+ name = self.fn.__name__
+
+ return f"{self.__class__.__name__}({name})"
+
+ def python_type(self):
+ return type(self.fn)
+
+ def as_python_constant(self):
+ return self.fn
+
+ def reconstruct(self, codegen):
+ name = self.fn.__name__
+ assert self.fn.__module__ == "builtins"
+ assert name not in codegen.tx.f_globals, "shadowed global"
+ return [codegen.create_load_global(name, add=True)]
+
+ def constant_args(self, *args, **kwargs):
+ return check_constant_args(args, kwargs)
+
+ def tensor_args(self, *args, **kwargs):
+ return any(
+ isinstance(i, variables.TensorVariable)
+ for i in itertools.chain(args, kwargs.values())
+ ) and not any(
+ isinstance(i, variables.GetAttrVariable)
+ for i in itertools.chain(args, kwargs.values())
+ )
+
+ def unspec_numpy_args(self, *args, **kwargs):
+ return all(
+ isinstance(
+ i,
+ (
+ variables.UnspecializedNumpyVariable,
+ variables.UnspecializedPythonVariable,
+ variables.ConstantVariable,
+ ),
+ )
+ for i in itertools.chain(args, kwargs.values())
+ ) and any(
+ isinstance(x, variables.UnspecializedNumpyVariable)
+ for x in itertools.chain(args, kwargs.values())
+ )
+
+ def unspec_python_args(self, *args, **kwargs):
+ return check_unspec_python_args(args, kwargs)
+
+ @staticmethod
+ def unwrap_unspec_args_kwargs(args, kwargs):
+ unwrapped_args = []
+ unwrapped_kwargs = {}
+ for x in args:
+ if isinstance(
+ x,
+ (
+ variables.UnspecializedNumpyVariable,
+ variables.UnspecializedPythonVariable,
+ ),
+ ):
+ unwrapped_args.append(x.raw_value)
+ else:
+ unwrapped_args.append(x.as_python_constant())
+ for k, v in kwargs:
+ if isinstance(
+ x,
+ (
+ variables.UnspecializedNumpyVariable,
+ variables.UnspecializedPythonVariable,
+ ),
+ ):
+ unwrapped_kwargs.update({k: v.raw_value})
+ else:
+ unwrapped_kwargs.update({k: v.as_python_constant()})
+ return unwrapped_args, unwrapped_kwargs
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ constant_args = check_constant_args(args, kwargs)
+ tensor_args = self.tensor_args(*args, **kwargs)
+ unspec_python_args = self.unspec_python_args(*args, **kwargs)
+ options = VariableTracker.propagate(self, args, kwargs.values())
+ has_constant_handler = self.can_constant_fold_through() and (
+ constant_args or unspec_python_args
+ )
+ assert isinstance(args, list)
+ assert isinstance(kwargs, dict)
+
+ if (
+ self.fn is operator.getitem
+ and len(args) == 2
+ and isinstance(args[1], variables.TensorVariable)
+ and args[1].dtype == torch.bool
+ and not config.dynamic_shapes
+ ):
+ unimplemented("dynamic Tensor.__getitem__(bool[])")
+
+ # args[0] is list and args[1] is unspec
+ if self.fn is operator.getitem and not isinstance(
+ args[0], variables.TensorVariable
+ ):
+ tensor_args = False
+ args, kwargs = specialize_args_kwargs(tx, args, kwargs)
+
+ if (
+ self.can_insert_in_graph()
+ and tensor_args
+ and not (
+ self.fn is operator.getitem
+ and isinstance(args[0], ConstDictVariable)
+ and isinstance(args[1], variables.TensorVariable)
+ )
+ ):
+ try:
+ fn = self.fn
+ if self.fn is operator.iadd and isinstance(
+ args[0], variables.ConstantVariable
+ ):
+ # Work around weird bug in hf_T5
+ fn, args = operator.add, [args[1], args[0]]
+
+ proxy = tx.output.create_proxy(
+ "call_function", fn, *proxy_args_kwargs(args, kwargs), current_tx=tx
+ )
+ if any([isinstance(arg, FakeItemVariable) for arg in args]):
+ return variables.FakeItemVariable.create(
+ tx,
+ proxy,
+ **options,
+ )
+ elif self.unspec_numpy_args(*args, **kwargs):
+ _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs)
+ raw_value = self.fn(*_args, **_kwargs)
+ return variables.UnspecializedNumpyVariable.create(
+ tx,
+ proxy,
+ raw_value=raw_value,
+ **options,
+ )
+ elif self.unspec_python_args(*args, **kwargs):
+ _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs)
+ raw_value = self.fn(*_args, **_kwargs)
+
+ need_unwrap = any(
+ x.need_unwrap
+ for x in itertools.chain(args, kwargs.values())
+ if isinstance(x, variables.UnspecializedPythonVariable)
+ )
+
+ return variables.UnspecializedPythonVariable.create(
+ tx,
+ proxy,
+ raw_value=raw_value,
+ need_unwrap=need_unwrap,
+ **options,
+ )
+ else:
+ # Work around for vision_maskrcnn due to precision difference
+ # specialize the dividend when float divide by tensor
+ if self.fn is operator.truediv and isinstance(
+ args[0], variables.UnspecializedPythonVariable
+ ):
+ args[0] = args[0].convert_to_constant(tx)
+ return variables.TensorVariable.create(tx, proxy, **options)
+
+ except NotImplementedError:
+ unimplemented(f"partial tensor op: {self} {args} {kwargs}")
+
+ # Handle cases like int(torch.seed())
+ if self.fn is int and isinstance(args[0], DynamicShapeVariable):
+ return args[0]
+
+ handler = getattr(self, f"call_{self.fn.__name__}", None)
+ if handler:
+ try:
+ inspect.signature(handler).bind(tx, *args, **kwargs)
+ except TypeError as exc:
+ log.warning(f"incorrect arg count {handler} {exc}")
+ handler = None
+
+ if handler:
+ try:
+ result = handler(tx, *args, **kwargs)
+ if result is not None:
+ return result.add_options(options)
+ except Unsupported as exc:
+ if not has_constant_handler:
+ raise
+ # Actually, we will handle this just fine
+ exc.remove_from_stats()
+
+ if has_constant_handler:
+ args, kwargs = specialize_args_kwargs(tx, args, kwargs)
+ # constant fold
+ return variables.ConstantVariable(
+ self.as_python_constant()(
+ *[x.as_python_constant() for x in args],
+ **{k: v.as_python_constant() for k, v in kwargs.items()},
+ ),
+ **options,
+ )
+
+ return super().call_function(tx, args, kwargs)
+
+ def _call_min_max(self, tx, a, b):
+ if self.tensor_args(a, b):
+ if not isinstance(a, variables.TensorVariable):
+ a, b = b, a
+ assert isinstance(a, variables.TensorVariable)
+
+ # 1. result of an item call is a scalar convert to a tensor
+ # 2. dynamic shape should be resolved to tensor
+ if isinstance(a, (FakeItemVariable, DynamicShapeVariable)):
+ a = variables.TorchVariable(torch.tensor).call_function(tx, [a], {})
+
+ # convert min/max to torch ops
+ if b.is_python_constant():
+ kwargs = {"min": b} if (self.fn is max) else {"max": b}
+ result = variables.TorchVariable(torch.clamp).call_function(
+ tx, [a], kwargs
+ )
+ else:
+ fn = {max: torch.maximum, min: torch.minimum}[self.fn]
+ result = variables.TorchVariable(fn).call_function(tx, [a, b], {})
+
+ # return unspec if both a, b are unspec or const
+ if all(
+ isinstance(
+ i,
+ (
+ variables.UnspecializedNumpyVariable,
+ variables.UnspecializedPythonVariable,
+ variables.ConstantVariable,
+ ),
+ )
+ for i in [a, b]
+ ):
+
+ if any([isinstance(val, FakeItemVariable) for val in [a, b]]):
+ return variables.FakeItemVariable.from_tensor_variable(result)
+
+ if b.is_python_constant():
+ raw_b = b.as_python_constant()
+ else:
+ raw_b = b.raw_value
+ if self.fn is max:
+ raw_res = max(a.raw_value, raw_b)
+ else:
+ raw_res = min(a.raw_value, raw_b)
+
+ if isinstance(raw_res, np.number):
+ return variables.UnspecializedNumpyVariable.from_tensor_variable(
+ result, raw_res
+ )
+ else:
+ need_unwrap = any(
+ x.need_unwrap
+ for x in [a, b]
+ if isinstance(x, variables.UnspecializedPythonVariable)
+ )
+ return variables.UnspecializedPythonVariable.from_tensor_variable(
+ result, raw_res, need_unwrap
+ )
+ # otherwise return tensor
+ else:
+ return result
+ elif isinstance(a, variables.ConstantVariable) and isinstance(
+ b, variables.ConstantVariable
+ ):
+ if self.fn is max:
+ return variables.ConstantVariable(max(a.value, b.value))
+ else:
+ return variables.ConstantVariable(min(a.value, b.value))
+ else:
+ unimplemented(f"unsupported min / max over args {str(a)}, {str(b)}")
+
+ call_min = _call_min_max
+ call_max = _call_min_max
+
+ def call_range(self, tx, *args, **kwargs):
+ if self.unspec_python_args(*args, **kwargs) or self.constant_args(
+ *args, **kwargs
+ ):
+ args, kwargs = specialize_args_kwargs(tx, args, kwargs)
+ return variables.RangeVariable(
+ value=range(
+ *[x.value for x in args],
+ **{k: v.value for k, v in kwargs.items()},
+ ),
+ )
+
+ def call_slice(self, tx, *args):
+ return variables.SliceVariable(args)
+
+ def _call_iter_tuple_list(self, tx, obj=None):
+ cls = variables.BaseListVariable.cls_for(self.fn)
+ if obj is None:
+ return cls(
+ [],
+ mutable_local=MutableLocal(),
+ )
+ elif obj.has_unpack_var_sequence(tx):
+ guards = set()
+ if obj.source and not is_constant_source(obj.source):
+ guards.add(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
+ return cls(
+ list(obj.unpack_var_sequence(tx)),
+ mutable_local=MutableLocal(),
+ guards=guards,
+ ).add_options(self, obj)
+
+ call_iter = _call_iter_tuple_list
+ call_tuple = _call_iter_tuple_list
+ call_list = _call_iter_tuple_list
+
+ def call_dict(self, tx, arg):
+ if isinstance(arg, variables.ConstDictVariable):
+ return arg.clone(mutable_local=MutableLocal())
+
+ def call_zip(self, tx, *args):
+ options = VariableTracker.propagate(self, args)
+ if all(x.has_unpack_var_sequence(tx) for x in args):
+ items = [
+ variables.TupleVariable(list(item), **options)
+ for item in zip(*[arg.unpack_var_sequence(tx) for arg in args])
+ ]
+ return variables.TupleVariable(items, **options)
+
+ def call_enumerate(self, tx, *args):
+ options = VariableTracker.propagate(self, args)
+ if len(args) == 1:
+ start = 0
+ else:
+ assert len(args) == 2
+ assert isinstance(args[1], variables.ConstantVariable)
+ start = args[1].as_python_constant()
+ if args[0].has_unpack_var_sequence(tx):
+ items = [
+ variables.TupleVariable(
+ [variables.ConstantVariable(idx, **options), var],
+ **options,
+ )
+ for idx, var in enumerate(args[0].unpack_var_sequence(tx), start)
+ ]
+ return variables.TupleVariable(items, **options)
+
+ def call_mul(self, tx, a, b):
+ if isinstance(
+ a, (variables.ListVariable, variables.TupleVariable)
+ ) and isinstance(b, variables.ConstantVariable):
+ return a.__class__(
+ items=a.items * b.as_python_constant(), mutable_local=MutableLocal()
+ ).add_options(self, a, b)
+ elif isinstance(
+ b, (variables.ListVariable, variables.TupleVariable)
+ ) and isinstance(a, variables.ConstantVariable):
+ return b.__class__(
+ items=b.items * a.as_python_constant(), mutable_local=MutableLocal()
+ ).add_options(self, a, b)
+ else:
+ return a.call_method(tx, "__mul__", [b], {})
+
+ def call_len(self, tx, *args, **kwargs):
+ return args[0].call_method(tx, "__len__", args[1:], kwargs)
+
+ def call_add(self, tx, *args, **kwargs):
+ return args[0].call_method(tx, "__add__", args[1:], kwargs)
+
+ def call_sub(self, tx, *args, **kwargs):
+ return args[0].call_method(tx, "__sub__", args[1:], kwargs)
+
+ def call_truediv(self, tx, *args, **kwargs):
+ return args[0].call_method(tx, "__truediv__", args[1:], kwargs)
+
+ def call_floordiv(self, tx, *args, **kwargs):
+ return args[0].call_method(tx, "__floordiv__", args[1:], kwargs)
+
+ def call_iadd(self, tx, *args, **kwargs):
+ return args[0].call_method(tx, "__iadd__", args[1:], kwargs)
+
+ def call_getitem(self, tx, *args, **kwargs):
+ if self.unspec_python_args(*args, **kwargs):
+ args, kwargs = specialize_args_kwargs(tx, args, kwargs)
+ return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
+
+ def call_isinstance(self, tx, arg, isinstance_type):
+ arg_type = arg.python_type()
+ isinstance_type = isinstance_type.as_python_constant()
+
+ if isinstance(arg, variables.TensorVariable) and arg.dtype is not None:
+ return variables.ConstantVariable(arg.call_isinstance(isinstance_type))
+ # UserDefinedObject with C extensions can have torch.Tensor attributes,
+ # so break graph.
+ if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance(
+ arg.value, types.MemberDescriptorType
+ ):
+ unimplemented(
+ f"isinstance called on UserDefinedClass {arg} {isinstance_type}"
+ )
+ try:
+ val = issubclass(arg_type, isinstance_type)
+ except TypeError:
+ val = arg_type is isinstance_type
+ return variables.ConstantVariable(val)
+
+ def call_super(self, tx, a, b):
+ return variables.SuperVariable(a, b)
+
+ def call_next(self, tx, arg):
+ if isinstance(arg, variables.ListIteratorVariable):
+ val, next_iter = arg.next_variables()
+ tx.replace_all(arg, next_iter)
+ return val
+ elif isinstance(arg, variables.BaseListVariable):
+ return arg.items[0].add_options(self, arg)
+
+ def call_hasattr(self, tx, obj, attr):
+ if attr.is_python_constant():
+ name = attr.as_python_constant()
+ return obj.call_hasattr(tx, name).add_options(self, obj, attr)
+
+ def call_map(self, tx, fn, seq):
+ if seq.has_unpack_var_sequence(tx):
+ items = [fn.call_function(tx, [x], {}) for x in seq.unpack_var_sequence(tx)]
+ return variables.TupleVariable(items).add_options(self, fn, seq)
+
+ def call_sum(self, tx, seq, **kwargs):
+ # Special case for sum on tuple of floats and ints
+ if (
+ isinstance(seq, (variables.ListVariable, variables.TupleVariable))
+ and all(
+ [
+ isinstance(x, variables.ConstantVariable)
+ and isinstance(x.value, (int, float))
+ for x in seq.items
+ ]
+ )
+ and not kwargs
+ ):
+ new_list = [x.value for x in seq.items]
+ return variables.ConstantVariable(sum(new_list))
+ if seq.has_unpack_var_sequence(tx):
+ start = kwargs.pop(
+ "start", variables.ConstantVariable(0)
+ ).as_python_constant()
+ assert not kwargs
+ items = seq.unpack_var_sequence(tx)[start:]
+ return BuiltinVariable(functools.reduce).call_function(
+ tx,
+ [
+ BuiltinVariable(operator.add),
+ variables.TupleVariable(items),
+ variables.ConstantVariable(0).add_options(self, seq),
+ ],
+ {},
+ )
+
+ def call_reduce(self, tx, function, iterable, initializer=None):
+ if iterable.has_unpack_var_sequence(tx):
+ items = iterable.unpack_var_sequence(tx)
+ if initializer is None:
+ value, items = items[0], items[1:]
+ else:
+ value = initializer
+ for element in items:
+ value = function.call_function(tx, [value, element], {})
+ return value
+
+ def call_getattr(
+ self, tx, obj: VariableTracker, name_var: VariableTracker, default=None
+ ):
+ from . import (
+ ConstantVariable,
+ GetAttrVariable,
+ PythonModuleVariable,
+ TorchVariable,
+ UserFunctionVariable,
+ )
+ from .builder import VariableBuilder
+
+ options = VariableTracker.propagate(self, obj, name_var)
+ guards = options["guards"]
+ name = name_var.as_python_constant()
+
+ if not name_var.is_python_constant():
+ unimplemented("non-const getattr() name")
+
+ if tx.output.side_effects.is_attribute_mutation(obj):
+ try:
+ # re-read a pending side effect?
+ return tx.output.side_effects.load_attr(obj, name).add_options(options)
+ except KeyError:
+ pass
+
+ if default is not None:
+ hasattr_var = self.call_hasattr(tx, obj, name_var)
+ guards.update(hasattr_var.guards)
+ assert hasattr_var.as_python_constant() in (True, False)
+ if not hasattr_var.as_python_constant():
+ return default.add_guards(guards)
+
+ if obj.source:
+ source = AttrSource(obj.source, name)
+ options["source"] = source
+ else:
+ source = None
+
+ if isinstance(obj, variables.NNModuleVariable):
+ return obj.var_getattr(tx, name).add_options(options)
+ elif isinstance(obj, variables.TensorVariable) and name == "grad":
+ if source:
+ # We are going to be raising this tensor as grapharg. So, ensure
+ # that we have real grad value instead of fake tensor value.
+ # Walk through the inputs of the subgraph and find if we already
+ # have the original tensor stored in the graphargs.
+ for grapharg in tx.output.graphargs:
+ if grapharg.source == source.base:
+ example_value = grapharg.example.grad
+ return VariableBuilder(tx, source)(example_value).add_options(
+ options
+ )
+ unimplemented("tensor grad")
+ else:
+ unimplemented("tensor grad")
+ elif isinstance(
+ obj,
+ (
+ variables.TensorVariable,
+ variables.NamedTupleVariable,
+ variables.ConstantVariable,
+ variables.UserDefinedClassVariable,
+ variables.UserDefinedObjectVariable,
+ ),
+ ):
+ try:
+ return (
+ obj.var_getattr(tx, name).clone(source=source).add_options(options)
+ )
+ except NotImplementedError:
+ return GetAttrVariable(obj, name, **options)
+ elif isinstance(obj, TorchVariable):
+ member = getattr(obj.value, name)
+ if is_allowed(member):
+ return TorchVariable(member, **options)
+ elif ConstantVariable.is_literal(member):
+ return ConstantVariable(member, **options)
+ else:
+ return VariableBuilder(tx, source)(member).add_guards(guards)
+ elif isinstance(obj, (PythonModuleVariable, DummyModule)):
+ member = obj.value.__dict__[name]
+
+ if config.replay_record_enabled:
+ tx.exec_recorder.record_module_access(obj.value, name, member)
+
+ return VariableBuilder(tx, source)(member).add_guards(guards)
+ elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
+ return ConstantVariable(
+ getattr(obj.fn, name), **VariableTracker.propagate(obj)
+ )
+ else:
+ try:
+ return (
+ obj.var_getattr(tx, name).clone(source=source).add_options(options)
+ )
+ except NotImplementedError:
+ return GetAttrVariable(obj, name, **options)
+
+ def call_setattr(
+ self, tx, obj: VariableTracker, name_var: VariableTracker, val: VariableTracker
+ ):
+ if isinstance(obj, (variables.BlackHoleVariable, variables.DataClassVariable)):
+ return obj.call_method(tx, "__setattr__", [name_var, val], {})
+ elif (
+ tx.output.side_effects.is_attribute_mutation(obj)
+ and name_var.is_python_constant()
+ ):
+ tx.output.side_effects.store_attr(obj, name_var.as_python_constant(), val)
+ return val.add_options(self, obj, name_var)
+ elif isinstance(obj, variables.UserDefinedObjectVariable):
+ unimplemented(
+ f"setattr(UserDefinedObjectVariable) {type(obj.value).__setattr__}"
+ )
+ elif isinstance(obj, variables.NNModuleVariable):
+ obj.convert_to_unspecialized(tx)
+
+ def call_type(self, tx, obj: VariableTracker):
+ from .builder import VariableBuilder
+
+ try:
+ py_type = obj.python_type()
+ except NotImplementedError:
+ py_type = None
+
+ if istype(obj, variables.TupleVariable):
+ return BuiltinVariable(py_type).add_options(self, obj)
+
+ if py_type is not None and obj.source:
+ return VariableBuilder(tx, TypeSource(obj.source))(py_type).add_options(
+ self, obj
+ )
+
+ unimplemented(f"type({obj})")
+
+ def call_reversed(self, tx, obj: VariableTracker):
+ if obj.has_unpack_var_sequence(tx):
+ items = list(reversed(obj.unpack_var_sequence(tx)))
+ return variables.TupleVariable(
+ items, **VariableTracker.propagate(self, obj)
+ )
+
+ def call_chain(self, tx, *args):
+ if all(obj.has_unpack_var_sequence(tx) for obj in args):
+ items = []
+ for obj in args:
+ items.extend(obj.unpack_var_sequence(tx))
+ return variables.TupleVariable(
+ items, **VariableTracker.propagate(self, *args)
+ )
+
+ def call_islice(self, tx, iterable, *args):
+ if iterable.has_unpack_var_sequence(tx) and all(
+ x.is_python_constant() for x in args
+ ):
+ const_args = [x.as_python_constant() for x in args]
+ items = iterable.unpack_var_sequence(tx)
+ items = list(itertools.islice(items, *const_args))
+ return variables.TupleVariable(
+ items, **VariableTracker.propagate(self, iterable, *args)
+ )
+
+ def call_id(self, tx, *args):
+ if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable):
+ nn_mod_variable = args[0]
+ mod = tx.output.get_submodule(nn_mod_variable.module_key)
+ return variables.ConstantVariable(id(mod))
+ else:
+ unimplemented(f"call_id with args {args}")
diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py
new file mode 100644
index 0000000000000..d3366448e3799
--- /dev/null
+++ b/torch/_dynamo/variables/constant.py
@@ -0,0 +1,128 @@
+import operator
+from typing import Dict, List
+
+import torch
+
+from .. import variables
+from ..exc import unimplemented
+from ..utils import istype
+from .base import typestr, VariableTracker
+
+
+class ConstantVariable(VariableTracker):
+ def __init__(self, value, **kwargs):
+ super(ConstantVariable, self).__init__(**kwargs)
+ assert not isinstance(value, torch.Tensor)
+ self.value = value
+
+ def as_proxy(self):
+ return self.value
+
+ def __str__(self):
+ # return f"ConstantVariable({self.value})"
+ return f"ConstantVariable({type(self.value).__name__})"
+
+ def python_type(self):
+ return type(self.value)
+
+ def as_python_constant(self):
+ return self.value
+
+ @property
+ def items(self):
+ """
+ Need this when adding a BaseListVariable and a ConstantVariable together.
+ Happens in detectron2.
+ """
+ return self.unpack_var_sequence(tx=None)
+
+ def getitem_const(self, arg: VariableTracker):
+ return ConstantVariable(
+ self.value[arg.as_python_constant()],
+ **VariableTracker.propagate([self, arg]),
+ )
+
+ @staticmethod
+ def is_literal(obj):
+ if type(obj) in (int, float, bool, type(None), str):
+ return True
+ if type(obj) in (list, tuple, set, frozenset):
+ return all(ConstantVariable.is_literal(x) for x in obj)
+ return False
+
+ def unpack_var_sequence(self, tx):
+ try:
+ options = VariableTracker.propagate([self])
+ return [ConstantVariable(x, **options) for x in self.as_python_constant()]
+ except TypeError:
+ raise NotImplementedError()
+
+ def const_getattr(self, tx, name):
+ member = getattr(self.value, name)
+ if callable(member):
+ raise NotImplementedError()
+ return member
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ options = VariableTracker.propagate(self, args, kwargs.values())
+
+ if istype(self.value, tuple):
+ # empty tuple constant etc
+ return variables.TupleVariable(
+ items=self.unpack_var_sequence(tx), source=self.source, **options
+ ).call_method(tx, name, args, kwargs)
+
+ try:
+ const_args = [a.as_python_constant() for a in args]
+ const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
+ except NotImplementedError:
+ return super(ConstantVariable, self).call_method(tx, name, args, kwargs)
+
+ def has_arith_binop(num_ty):
+ return (
+ isinstance(self.value, num_ty)
+ and hasattr(operator, name)
+ and len(args) == 1
+ and args[0].is_python_constant()
+ )
+
+ if isinstance(self.value, str) and name in str.__dict__.keys():
+ assert not kwargs
+ method = getattr(self.value, name)
+ return ConstantVariable(method(*const_args, **const_kwargs), **options)
+ elif has_arith_binop(int) or has_arith_binop(float):
+ op = getattr(operator, name)
+ return ConstantVariable(op(self.value, const_args[0]), **options)
+ elif name == "__len__" and not (args or kwargs):
+ return ConstantVariable(len(self.value), **options)
+ elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
+ assert not kwargs
+ search = args[0].as_python_constant()
+ result = search in self.value
+ return ConstantVariable(result, **options)
+
+ unimplemented(f"const method call {typestr(self.value)}.{name}")
+
+
+class EnumVariable(VariableTracker):
+ def __init__(self, value, **kwargs):
+ super(EnumVariable, self).__init__(**kwargs)
+ self.value = value
+
+ def as_proxy(self):
+ return self.value
+
+ def __str__(self):
+ return f"EnumVariable({type(self.value)})"
+
+ def python_type(self):
+ return type(self.value)
+
+ def as_python_constant(self):
+ return self.value
diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py
new file mode 100644
index 0000000000000..26f040d503c2a
--- /dev/null
+++ b/torch/_dynamo/variables/dicts.py
@@ -0,0 +1,413 @@
+import collections
+import dataclasses
+import functools
+import inspect
+from typing import Dict, List
+
+import torch
+
+from .. import variables
+from ..bytecode_transformation import create_instruction
+from ..eval_frame import skip_code
+from ..exc import unimplemented
+from ..source import AttrSource, GlobalWeakRefSource
+from ..utils import global_key_name
+from .base import MutableLocal, VariableTracker
+from .constant import ConstantVariable
+from .tensor import TensorVariable
+
+
+class ConstDictVariable(VariableTracker):
+ def __init__(self, items, user_cls, **kwargs):
+ super(ConstDictVariable, self).__init__(**kwargs)
+ self.items = items
+ self.user_cls = user_cls
+
+ def as_proxy(self):
+ return {k: v.as_proxy() for k, v in self.items.items()}
+
+ def python_type(self):
+ return self.user_cls
+
+ def reconstruct(self, codegen):
+ for key, value in self.items.items():
+ if isinstance(key, torch.nn.Parameter):
+ codegen.extend_output(
+ [
+ codegen.create_load_global(global_key_name(key), add=True),
+ create_instruction("CALL_FUNCTION", 0),
+ ]
+ )
+ else:
+ codegen.append_output(codegen.create_load_const(key))
+ codegen(self.items[key])
+
+ return [create_instruction("BUILD_MAP", len(self.items))]
+
+ def getitem_const(self, arg: VariableTracker):
+ return self.items[ConstDictVariable.get_key(arg)].add_options(self, arg)
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ from . import ConstantVariable, TupleVariable
+
+ options = VariableTracker.propagate(self, args, kwargs.values())
+ val = self.items
+
+ if name == "__getitem__":
+ return self.getitem_const(args[0])
+
+ elif name == "items":
+ assert not (args or kwargs)
+ return TupleVariable(
+ [
+ TupleVariable(
+ [
+ ConstDictVariable._key_to_var(
+ tx,
+ k,
+ **options,
+ ),
+ v,
+ ],
+ **options,
+ )
+ for k, v in val.items()
+ ],
+ **options,
+ )
+ elif name == "keys":
+ assert not (args or kwargs)
+ return TupleVariable(
+ [
+ ConstDictVariable._key_to_var(
+ tx,
+ k,
+ **options,
+ )
+ for k in val.keys()
+ ],
+ **options,
+ )
+
+ elif name == "values":
+ assert not (args or kwargs)
+ return TupleVariable(list(val.values()), **options)
+ elif name == "__len__":
+ assert not (args or kwargs)
+ return ConstantVariable(len(self.items), **options)
+ elif (
+ name == "__setitem__"
+ and args
+ and ConstDictVariable.is_valid_key(args[0])
+ and self.mutable_local
+ ):
+ assert not kwargs and len(args) == 2
+ k = ConstDictVariable.get_key(args[0])
+
+ if isinstance(k, torch.nn.Parameter):
+ tx.store_dict_key(global_key_name(k), k)
+ newval = collections.OrderedDict(val)
+ newval[k] = args[1]
+ return tx.replace_all(self, self.modifed(newval, **options))
+ elif (
+ name in ("pop", "get")
+ and args
+ and ConstDictVariable.is_valid_key(args[0])
+ and ConstDictVariable.get_key(args[0]) not in self.items
+ and len(args) == 2
+ ):
+ # missing item, return the default value
+ return args[1].add_options(options)
+ elif (
+ name == "pop"
+ and args
+ and ConstDictVariable.is_valid_key(args[0])
+ and self.mutable_local
+ ):
+ newval = collections.OrderedDict(val)
+ result = newval.pop(ConstDictVariable.get_key(args[0]))
+ tx.replace_all(self, self.modifed(newval, **options))
+ return result.add_options(options)
+ elif (
+ name == "update"
+ and args
+ and isinstance(args[0], ConstDictVariable)
+ and self.mutable_local
+ ):
+ newval = collections.OrderedDict(val)
+ newval.update(args[0].items)
+ result = self.modifed(newval, **options)
+ return tx.replace_all(self, result)
+ elif (
+ name in ("get", "__getattr__")
+ and args
+ and ConstDictVariable.is_valid_key(args[0])
+ and ConstDictVariable.get_key(args[0]) in self.items
+ ):
+ result = self.items[ConstDictVariable.get_key(args[0])]
+ return result.add_options(options)
+ elif (
+ name == "__contains__" and args and ConstDictVariable.is_valid_key(args[0])
+ ):
+ return ConstantVariable(
+ ConstDictVariable.get_key(args[0]) in self.items, **options
+ )
+ else:
+ return super().call_method(tx, name, args, kwargs)
+
+ def modifed(self, items, **options):
+ """a copy of self with different items"""
+ return self.clone(items=items, **options)
+
+ def unpack_var_sequence(self, tx):
+ options = VariableTracker.propagate([self])
+ val = self.items
+ result = [ConstDictVariable._key_to_var(tx, k, **options) for k in val.keys()]
+ return result
+
+ @classmethod
+ def get_key(cls, arg: VariableTracker):
+ if isinstance(arg, TensorVariable) and arg.parameter_value is not None:
+ return arg.parameter_value
+ else:
+ return arg.as_python_constant()
+
+ @classmethod
+ def is_valid_key(cls, key):
+ return (
+ key.is_python_constant()
+ or isinstance(key, TensorVariable)
+ and key.parameter_value is not None
+ )
+
+ @classmethod
+ def _key_to_var(cls, tx, key, **options):
+ from .builder import VariableBuilder
+
+ if isinstance(key, torch.nn.Parameter):
+ return VariableBuilder(tx, GlobalWeakRefSource(global_key_name(key)))(key)
+ else:
+ assert ConstantVariable.is_literal(key)
+ return ConstantVariable(key, **options)
+
+
+class DefaultDictVariable(ConstDictVariable):
+ def __init__(self, items, user_cls, default_factory=None, **kwargs):
+ super(DefaultDictVariable, self).__init__(items, user_cls, **kwargs)
+ assert user_cls is collections.defaultdict
+ self.default_factory = default_factory
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ from . import ListVariable, TupleVariable
+
+ options = VariableTracker.propagate(self, args, kwargs.values())
+
+ if name == "__getitem__":
+ k = ConstDictVariable.get_key(args[0])
+
+ if k in self.items:
+ return self.getitem_const(args[0])
+ else:
+ if self.default_factory is None:
+ raise KeyError(f"{k}")
+ else:
+ if isinstance(k, torch.nn.Parameter):
+ tx.store_dict_key(global_key_name(k), k)
+ new_val = collections.OrderedDict(self.items)
+ if self.default_factory is list:
+ default_var = ListVariable([], mutable_local=MutableLocal())
+ elif self.default_factory is tuple:
+ default_var = TupleVariable([], mutable_local=MutableLocal())
+ elif self.default_factory is dict:
+ default_var = ConstDictVariable(
+ {}, dict, mutable_local=MutableLocal()
+ )
+ else:
+ unimplemented(
+ f"defaultdict with default_factory = {self.default_factory}"
+ )
+ new_val[k] = default_var
+ tx.replace_all(self, self.modifed(new_val, **options))
+ return default_var
+ else:
+ return super().call_method(tx, name, args, kwargs)
+
+
+class DataClassVariable(ConstDictVariable):
+ """
+ This is a bit of a hack to deal with
+ transformers.file_utils.ModelOutput() from huggingface.
+
+ ModelOutput causes trouble because it a a mix of a dataclass and a
+ OrderedDict and it calls super() methods implemented in C.
+ """
+
+ # ModelOutput() excludes None, though generic datclasses don't
+ include_none = False
+
+ @staticmethod
+ @functools.lru_cache(None)
+ def _patch_once():
+ from transformers.file_utils import ModelOutput
+
+ for obj in ModelOutput.__dict__.values():
+ if callable(obj):
+ skip_code(obj.__code__)
+
+ @staticmethod
+ def is_matching_cls(cls):
+ try:
+ from transformers.file_utils import ModelOutput
+
+ return issubclass(cls, ModelOutput)
+ except ImportError:
+ return False
+
+ @classmethod
+ def is_matching_object(cls, obj):
+ return cls.is_matching_cls(type(obj))
+
+ @classmethod
+ def create(cls, user_cls, args, kwargs, options):
+ DataClassVariable._patch_once()
+
+ skip_code(user_cls.__init__.__code__)
+ keys = [f.name for f in dataclasses.fields(user_cls)]
+ bound = inspect.signature(user_cls).bind(*args, **kwargs)
+ bound.apply_defaults()
+ assert set(bound.arguments.keys()) == set(keys)
+ items = collections.OrderedDict()
+ for key in keys:
+ val = bound.arguments[key]
+ if isinstance(val, VariableTracker):
+ items[key] = val
+ else:
+ if cls.include_none:
+ assert variables.ConstantVariable.is_literal(val)
+ items[key] = variables.ConstantVariable(val)
+ else:
+ assert val is None, f"unexpected {val}"
+
+ if len(items) == 1 and not isinstance(items[keys[0]], variables.TensorVariable):
+ unimplemented("DataClassVariable iterator constructor")
+ # TODO(jansel): implement unpacking logic in ModelOutput.__post_init__
+
+ return cls(items, user_cls, **options)
+
+ @classmethod
+ def wrap(cls, builder, obj):
+ user_cls = type(obj)
+ keys = [f.name for f in dataclasses.fields(user_cls)]
+
+ excluded = []
+ items = collections.OrderedDict()
+ for key in keys:
+ # __init__ function of a dataclass might not have yet defined the key
+ if hasattr(obj, key):
+ val = getattr(obj, key)
+ var = builder.__class__(
+ tx=builder.tx, source=AttrSource(builder.source, key)
+ )(val)
+ if val is not None or cls.include_none:
+ items[key] = var
+ else:
+ excluded.append(var)
+ return cls(
+ items, user_cls, **VariableTracker.propagate(excluded, items.values())
+ )
+
+ def __init__(self, items, user_cls, **options):
+ super(DataClassVariable, self).__init__(items, user_cls, **options)
+ assert self.is_matching_cls(user_cls)
+
+ def as_proxy(self):
+ raise NotImplementedError()
+
+ def reconstruct(self, codegen):
+ codegen.extend_output([codegen._create_load_const(self.user_cls)])
+ keys = tuple(self.items.keys())
+ for key in keys:
+ codegen(self.items[key])
+ return [
+ codegen.create_load_const(keys),
+ create_instruction("CALL_FUNCTION_KW", len(keys)),
+ ]
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ options = VariableTracker.propagate(self, args, kwargs.values())
+ if name == "__getitem__":
+ assert not kwargs and len(args) == 1
+ index = args[0].as_python_constant()
+ if isinstance(index, str):
+ return self.items[index].add_options(options)
+ else:
+ return (
+ self.call_method(tx, "to_tuple", [], {})
+ .call_method(tx, "__getitem__", args, kwargs)
+ .add_options(options)
+ )
+ elif name == "to_tuple":
+ assert not (args or kwargs)
+ return variables.TupleVariable(list(self.items.values()), **options)
+ elif name == "__setattr__":
+ name = "__setitem__"
+ return super(DataClassVariable, self).call_method(tx, name, args, kwargs)
+
+ def var_getattr(self, tx, name: str) -> "VariableTracker":
+ if name in self.items:
+ return self.call_method(
+ tx, "__getitem__", [variables.ConstantVariable(name)], {}
+ )
+ elif not self.include_none:
+ defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
+ if name in defaults:
+ assert variables.ConstantVariable.is_literal(defaults[name])
+ return variables.ConstantVariable(defaults[name]).add_options(self)
+ super(DataClassVariable, self).var_getattr(tx, name)
+
+
+class HFPretrainedConfigVariable(VariableTracker):
+ """
+ Hack for HuggingFace PretrainedConfig
+ """
+
+ @staticmethod
+ def is_matching_cls(cls):
+ try:
+ from transformers.configuration_utils import PretrainedConfig
+
+ return issubclass(cls, PretrainedConfig)
+ except ImportError:
+ return False
+
+ @classmethod
+ def is_matching_object(cls, obj):
+ return cls.is_matching_cls(type(obj))
+
+ def __init__(self, obj, **kwargs):
+ super(HFPretrainedConfigVariable, self).__init__(**kwargs)
+ self.obj = obj
+ assert self.is_matching_cls(type(obj))
+
+ def var_getattr(self, tx, name: str) -> "VariableTracker":
+ from . import ConstantVariable
+
+ return ConstantVariable(getattr(self.obj, name))
diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py
new file mode 100644
index 0000000000000..75fba182ad06c
--- /dev/null
+++ b/torch/_dynamo/variables/functions.py
@@ -0,0 +1,400 @@
+import enum
+import functools
+import inspect
+import itertools
+import types
+from typing import Dict, List
+
+from .. import variables
+from ..bytecode_transformation import create_instruction
+from ..exc import unimplemented
+from ..source import AttrSource, GetItemSource
+from ..utils import make_cell
+from .base import typestr, VariableTracker
+
+
+def wrap_bound_arg(val, options):
+ if isinstance(val, dict):
+ return variables.ConstDictVariable(
+ {k: wrap_bound_arg(v, options) for k, v in val.items()}, dict, **options
+ )
+ elif isinstance(val, (tuple, list)):
+ cls = variables.BaseListVariable.cls_for(type(val))
+ return cls([wrap_bound_arg(x, options) for x in val], **options)
+ elif variables.ConstantVariable.is_literal(val):
+ return variables.ConstantVariable(val, **options)
+ elif isinstance(val, enum.Enum):
+ return variables.EnumVariable(val, **options)
+ else:
+ assert isinstance(val, VariableTracker), typestr(val)
+ return val
+
+
+def wrap_args_kwargs(result, options):
+ for k, v in list(result.items()):
+ if isinstance(v, (tuple, dict)):
+ # args/kwargs
+ result[k] = wrap_bound_arg(v, options)
+
+
+def init_cellvars(parent, result, code):
+ closure_cells = dict()
+ side_effects = parent.output.side_effects
+
+ for name in code.co_cellvars:
+ closure_cells[name] = side_effects.track_cell_new()
+ if name in result:
+ side_effects.store_cell(closure_cells[name], result.pop(name))
+
+ return closure_cells
+
+
+class BaseUserFunctionVariable(VariableTracker):
+ def get_filename(self):
+ return self.get_code().co_filename
+
+ def get_name(self):
+ return self.get_code().co_name
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ return tx.inline_user_function_return(
+ self, list(self.self_args()) + list(args), kwargs
+ )
+
+ def num_parameters(self):
+ return len(inspect.signature(self.get_function()).parameters)
+
+ def closure_vars(self, tx):
+ return {}
+
+
+class UserFunctionVariable(BaseUserFunctionVariable):
+ """Some unsupported user-defined global function"""
+
+ def __init__(self, fn, is_constant=False, **kwargs):
+ super(UserFunctionVariable, self).__init__(**kwargs)
+ if getattr(fn, "_dynamo_marked_constant", False):
+ # This method should be treated as a constant for the purposes of compilation
+ self.is_constant = True
+ else:
+ self.is_constant = False
+
+ assert isinstance(
+ fn, types.FunctionType
+ ), f"expected FunctionType found {typestr(fn)} {fn}"
+ # unpack @torchdynamo.optimize()(fn) wrapped function
+ fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
+ # unpack torch.jit.script_if_tracing
+ if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
+ fn = inspect.getattr_static(fn, "__original_fn", fn)
+ self.fn: types.FunctionType = fn
+
+ def self_args(self):
+ return []
+
+ def get_function(self):
+ return self.fn
+
+ def get_code(self):
+ return self.fn.__code__
+
+ def python_type(self):
+ return types.FunctionType
+
+ def has_self(self):
+ return getattr(self.fn, "__self__", None) is not None
+
+ def get_globals(self):
+ return self.fn.__globals__
+
+ def bind_args(self, parent, args, kwargs):
+ assert not self.is_constant
+ options = VariableTracker.propagate([self])
+ wrap = functools.partial(wrap_bound_arg, options=options)
+
+ fn: types.FunctionType = self.fn
+ fake_func = types.FunctionType(
+ fn.__code__,
+ fn.__globals__,
+ fn.__name__,
+ tuple(map(wrap, fn.__defaults__ or [])),
+ fn.__closure__,
+ )
+ if fn.__kwdefaults__:
+ fake_func.__kwdefaults__ = {
+ k: wrap(v) for k, v in fn.__kwdefaults__.items()
+ }
+
+ bound = inspect.signature(fake_func).bind(*args, **kwargs)
+ bound.apply_defaults()
+ result = dict(bound.arguments.items())
+
+ wrap_args_kwargs(result, options)
+ closure_cells = init_cellvars(parent, result, fn.__code__)
+ closure = self.fn.__closure__ or ()
+ assert len(closure) == len(self.fn.__code__.co_freevars)
+ for idx, name, cell in zip(
+ itertools.count(), self.fn.__code__.co_freevars, closure
+ ):
+ if name == "__class__":
+ result[name] = variables.UserDefinedClassVariable(cell.cell_contents)
+ else:
+ var = parent.output.root_tx.match_nested_cell(name, cell)
+ if var is not None:
+ # optimization for cleaner codegen
+ result[name] = var
+ elif self.source:
+ from .builder import VariableBuilder
+
+ side_effects = parent.output.side_effects
+ if cell in side_effects:
+ out = side_effects[cell]
+ else:
+ closure_cell = GetItemSource(
+ AttrSource(self.source, "__closure__"), idx
+ )
+ closure_cell_contents = AttrSource(
+ closure_cell, "cell_contents"
+ )
+
+ # cells are written to with "cell_contents",
+ # so the source should just be the closure_cell, not its contents
+ out = side_effects.track_cell_existing(closure_cell, cell)
+ side_effects.store_cell(
+ out,
+ VariableBuilder(parent, closure_cell_contents)(
+ cell.cell_contents
+ ),
+ )
+
+ result[name] = out
+
+ else:
+ unimplemented("inline with __closure__")
+
+ return result, closure_cells
+
+ def export_freevars(self, parent, child):
+ pass
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ if self.is_constant:
+ options = VariableTracker.propagate(self, args, kwargs.values())
+ return invoke_and_store_as_constant(
+ tx, self.fn, self.get_name(), options, args, kwargs
+ )
+
+ return super(UserFunctionVariable, self).call_function(tx, args, kwargs)
+
+
+class UserMethodVariable(UserFunctionVariable):
+ """Some unsupported user-defined method"""
+
+ def __init__(self, fn, obj, **kwargs):
+ super(UserMethodVariable, self).__init__(fn=fn, **kwargs)
+ self.obj = obj
+
+ def __str__(self):
+ return f"{self.__class__.__name__}({self.fn}, {self.obj})"
+
+ def self_args(self):
+ return [self.obj]
+
+ def python_type(self):
+ return types.MethodType
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ if (
+ isinstance(self.obj, variables.NNModuleVariable)
+ and getattr(self.fn, "__module__", "").startswith("torch.nn.")
+ or self.is_constant
+ ):
+ return self.obj.call_method(
+ tx, self.fn.__name__, args, kwargs, constant=self.is_constant
+ ).add_options(self)
+ return super().call_function(tx, args, kwargs)
+
+ def num_parameters(self):
+ return super(UserMethodVariable, self).num_parameters() - 1
+
+
+class WrappedUserMethodVariable(UserMethodVariable):
+ def __init__(self, wrapped, context, **kwargs):
+ kwargs.pop("fn", None)
+ kwargs.pop("obj", None)
+ super(WrappedUserMethodVariable, self).__init__(
+ wrapped.fn, wrapped.obj, **kwargs
+ )
+ self.wrapped = wrapped
+ self.context = context
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ self.context.enter(tx)
+ result = super().call_function(tx, args, kwargs)
+ self.context.exit(tx)
+ return result
+
+
+class WrappedUserFunctionVariable(UserFunctionVariable):
+ def __init__(self, wrapped, context, **kwargs):
+ kwargs.pop("fn", None)
+ kwargs.pop("obj", None)
+ super(WrappedUserFunctionVariable, self).__init__(wrapped.fn, **kwargs)
+ self.wrapped = wrapped
+ self.context = context
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ self.context.enter(tx)
+ result = super().call_function(tx, args, kwargs)
+ self.context.exit(tx)
+ return result
+
+
+def invoke_and_store_as_constant(tx, fn, name, options, args, kwargs):
+ def convert(x):
+ if isinstance(x, variables.TensorVariable):
+ return x.proxy.node.meta["example_value"]
+ return x.as_python_constant()
+
+ args = [convert(x) for x in args]
+ kwargs = {k: convert(v) for k, v in kwargs.items()}
+ res = fn(*args, **kwargs)
+ return tx.output.register_attr_or_module(
+ res,
+ name,
+ **options,
+ )
+
+
+class NestedUserFunctionVariable(BaseUserFunctionVariable):
+ def __init__(
+ self,
+ fn_name,
+ code,
+ f_globals,
+ defaults,
+ kwdefaults,
+ annotations,
+ closure,
+ closure_scope,
+ **kwargs,
+ ):
+ super(NestedUserFunctionVariable, self).__init__(**kwargs)
+ assert isinstance(fn_name.as_python_constant(), str)
+ assert isinstance(code.as_python_constant(), types.CodeType)
+ assert isinstance(f_globals, dict)
+ self.fn_name = fn_name
+ self.code = code
+ self.f_globals = f_globals
+ self.defaults = defaults
+ self.kwdefaults = kwdefaults
+ self.annotations = annotations
+ self.closure = closure
+ if closure is None:
+ closure_scope = None
+ self.closure_scope = closure_scope
+
+ def self_args(self):
+ return []
+
+ def get_code(self):
+ return self.code.as_python_constant()
+
+ def get_function(self):
+ if self.closure:
+ raise NotImplementedError()
+ func = types.FunctionType(
+ self.code.as_python_constant(),
+ self.f_globals,
+ self.fn_name.as_python_constant(),
+ )
+ if self.defaults:
+ func.__defaults__ = self.defaults.as_python_constant()
+ if self.kwdefaults:
+ func.__kwdefaults__ = self.kwdefaults.as_python_constant()
+ if self.annotations:
+ func.__annotations__ = self.annotations.as_python_constant()
+ return func
+
+ def has_closure(self):
+ return self.closure is not None
+
+ def has_self(self):
+ return False
+
+ def get_globals(self):
+ return self.f_globals
+
+ def bind_args(self, parent, args, kwargs):
+ code = self.get_code()
+ func = types.FunctionType(
+ code,
+ self.f_globals,
+ self.fn_name.as_python_constant(),
+ tuple(self.defaults.items) if self.defaults else None,
+ tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))),
+ )
+ if self.kwdefaults:
+ func.__kwdefaults__ = self.kwdefaults.items
+
+ bound = inspect.signature(func).bind(*args, **kwargs)
+ bound.apply_defaults()
+ result = dict(bound.arguments.items())
+
+ wrap_args_kwargs(result, VariableTracker.propagate(self))
+ closure_cells = init_cellvars(parent, result, code)
+
+ for idx, name in enumerate(code.co_freevars):
+ assert getattr(self.closure.items[idx], name, name) == name
+ assert name not in result
+ closure_cells[name] = self.closure.items[idx]
+
+ return result, closure_cells
+
+ def export_freevars(self, parent, child):
+ code = self.get_code()
+ for var in code.co_freevars:
+ if var in child.symbolic_locals:
+ parent.symbolic_locals[var] = child.symbolic_locals[var]
+
+ def reconstruct(self, codegen):
+ flags = 0x00
+ if self.defaults:
+ flags |= 0x01
+ codegen(self.defaults)
+ if self.kwdefaults:
+ flags |= 0x02
+ codegen(self.kwdefaults)
+ if isinstance(self.annotations, variables.ConstDictVariable) or isinstance(
+ self.annotations, variables.TupleVariable
+ ):
+ flags |= 0x04
+ try:
+ if isinstance(self.annotations, variables.ConstDictVariable):
+ annotations = {
+ k: v.as_python_constant()
+ for k, v in self.annotations.items.items()
+ }
+ else:
+ annotations = tuple(
+ [v.as_python_constant() for v in self.annotations.items]
+ )
+ codegen.extend_output([codegen._create_load_const(annotations)])
+ except NotImplementedError:
+ codegen(self.annotations)
+ if self.closure:
+ flags |= 0x08
+ codegen(self.closure)
+ codegen(self.code)
+ codegen(self.fn_name)
+ return [create_instruction("MAKE_FUNCTION", flags)]
diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py
new file mode 100644
index 0000000000000..e1c0d584073e4
--- /dev/null
+++ b/torch/_dynamo/variables/lists.py
@@ -0,0 +1,427 @@
+from typing import Dict, List, Optional
+
+import torch
+import torch.fx
+
+from .. import config, variables
+from ..bytecode_transformation import create_instruction
+from ..exc import unimplemented
+from ..source import GetItemSource
+from ..utils import namedtuple_fields
+from .base import MutableLocal, VariableTracker
+from .constant import ConstantVariable
+
+
+class BaseListVariable(VariableTracker):
+ @staticmethod
+ def cls_for(obj):
+ return {
+ iter: ListIteratorVariable,
+ list: ListVariable,
+ slice: SliceVariable,
+ torch.Size: SizeVariable,
+ tuple: TupleVariable,
+ }[obj]
+
+ def __init__(self, items: List[VariableTracker], **kwargs):
+ super(BaseListVariable, self).__init__(**kwargs)
+ assert isinstance(items, list)
+ assert all(isinstance(x, VariableTracker) for x in items)
+ self.items: List[VariableTracker] = items
+
+ def _as_proxy(self):
+ return [x.as_proxy() for x in self.items]
+
+ def as_python_constant(self):
+ return self.python_type()([x.as_python_constant() for x in self.items])
+
+ def as_proxy(self):
+ assert self.python_type() is not SizeVariable
+ return self.python_type()(self._as_proxy())
+
+ def getitem_const(self, arg: VariableTracker):
+ index = arg.as_python_constant()
+ if isinstance(index, slice):
+ if self.source is not None:
+ return self.clone(
+ items=self.items[index],
+ source=GetItemSource(self.source, index),
+ mutable_local=MutableLocal() if self.mutable_local else None,
+ ).add_options(arg, self)
+ else:
+ return self.clone(
+ items=self.items[index],
+ mutable_local=MutableLocal() if self.mutable_local else None,
+ ).add_options(arg, self)
+ else:
+ assert isinstance(index, int)
+ return self.items[index].add_options(arg, self)
+
+ def unpack_var_sequence(self, tx):
+ return [x.add_options(self) for x in self.items]
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ options = VariableTracker.propagate(self, args, kwargs.values())
+ if name == "__getitem__":
+ assert not kwargs and len(args) == 1
+ return self.getitem_const(args[0])
+ elif name == "__add__":
+ assert not kwargs and len(args) == 1
+ return type(self)(self.items + args[0].items, **options)
+ elif (
+ name == "__contains__"
+ and len(args) == 1
+ and args[0].is_python_constant()
+ and all(x.is_python_constant() for x in self.items)
+ ):
+ assert not kwargs
+ search = args[0].as_python_constant()
+ result = any(x.as_python_constant() == search for x in self.items)
+ return variables.ConstantVariable(result, **options)
+
+ return super(BaseListVariable, self).call_method(tx, name, args, kwargs)
+
+
+class RangeVariable(BaseListVariable):
+ def __init__(self, value, items=None, guards=None, **kwargs):
+ if items is None:
+ items = [variables.ConstantVariable(x, guards=guards) for x in value]
+ super().__init__(items, guards=guards, **kwargs)
+ self.value = value
+
+ def python_type(self):
+ return range
+
+ def as_python_constant(self):
+ return self.value
+
+ def reconstruct(self, codegen):
+ assert "range" not in codegen.tx.f_globals
+ range_fn = codegen.create_load_global("range", add=True)
+ if self.value.step == 1:
+ if self.value.start == 0:
+ return [
+ range_fn,
+ codegen.create_load_const(self.value.stop),
+ create_instruction("CALL_FUNCTION", 1),
+ ]
+ return [
+ range_fn,
+ codegen.create_load_const(self.value.start),
+ codegen.create_load_const(self.value.stop),
+ create_instruction("CALL_FUNCTION", 2),
+ ]
+ return [
+ range_fn,
+ codegen.create_load_const(self.value.start),
+ codegen.create_load_const(self.value.stop),
+ codegen.create_load_const(self.value.step),
+ create_instruction("CALL_FUNCTION", 3),
+ ]
+
+
+class ListVariable(BaseListVariable):
+ def python_type(self):
+ return list
+
+ def reconstruct(self, codegen):
+ codegen.foreach(self.items)
+ return [create_instruction("BUILD_LIST", len(self.items))]
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ options = VariableTracker.propagate(self, args, kwargs.values())
+ if name == "append" and self.mutable_local:
+ assert not kwargs
+ (arg,) = args
+ tx.replace_all(
+ self,
+ ListVariable(self.items + [arg], **options),
+ )
+ return ConstantVariable(None)
+ elif (
+ name in ("extend", "__iadd__")
+ and self.mutable_local
+ and args
+ and args[0].has_unpack_var_sequence(tx)
+ ):
+ assert not kwargs
+ (arg,) = args
+ return tx.replace_all(
+ self,
+ ListVariable(
+ list(self.items) + list(arg.unpack_var_sequence(tx)),
+ **options,
+ ),
+ )
+ elif name == "insert" and self.mutable_local:
+ assert not kwargs
+ idx, value = args
+ items = list(self.items)
+ items.insert(idx.as_python_constant(), value)
+ return tx.replace_all(
+ self,
+ ListVariable(items, **options),
+ )
+ elif name == "pop" and self.mutable_local:
+ assert not kwargs
+ items = list(self.items)
+ result = items.pop(*[a.as_python_constant() for a in args])
+ tx.replace_all(
+ self,
+ ListVariable(items, **options),
+ )
+ return result
+ elif name == "clear" and self.mutable_local:
+ assert not kwargs and not args
+ return tx.replace_all(
+ self,
+ ListVariable([], **options),
+ )
+ elif (
+ name == "__setitem__"
+ and self.mutable_local
+ and args
+ and args[0].is_python_constant()
+ ):
+ assert not kwargs
+ key, value = args
+ items = list(self.items)
+ if isinstance(key, SliceVariable):
+ items[key.as_python_constant()] = list(value.items)
+ else:
+ items[key.as_python_constant()] = value
+ result = ListVariable(items, **options)
+ return tx.replace_all(self, result)
+ else:
+ return super().call_method(tx, name, args, kwargs)
+
+
+class TupleVariable(BaseListVariable):
+ def python_type(self):
+ return tuple
+
+ def reconstruct(self, codegen):
+ codegen.foreach(self.items)
+ return [create_instruction("BUILD_TUPLE", len(self.items))]
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ options = VariableTracker.propagate(self, args, kwargs.values())
+ if (
+ name in ("__add__", "__iadd__")
+ and len(args) == 1
+ and isinstance(args[0], TupleVariable)
+ ):
+ assert not kwargs
+ return TupleVariable(self.items + args[0].items, **options)
+ elif (
+ name in ("__add__", "__iadd__")
+ and len(args) == 1
+ and isinstance(args[0], variables.ConstantVariable)
+ ):
+ assert not kwargs
+ return TupleVariable(
+ self.items + list(args[0].unpack_var_sequence(self)), **options
+ )
+ return super().call_method(tx, name, args, kwargs)
+
+
+class SizeVariable(TupleVariable):
+ """torch.Size(...)"""
+
+ def __init__(
+ self,
+ items: List[VariableTracker],
+ proxy: Optional[torch.fx.Proxy] = None,
+ **kwargs,
+ ):
+ self.proxy = proxy
+ super().__init__(items, **kwargs)
+
+ def python_type(self):
+ return torch.Size
+
+ def as_proxy(self):
+ if self.proxy is not None:
+ return self.proxy
+
+ # torch.Size needs special handling. Normally, we pun a list-like
+ # container to directly contain Proxy/Node objects from FX, and FX
+ # knows to look inside containers (via map_aggregate). But torch.Size
+ # is weird; although it subclasses from tuple, it doesn't allow
+ # members which aren't int-like (rejecting Proxy and Node). This
+ # means we can't use the normal representation trick
+ # torch.Size([proxy0, proxy1]). I looked into seeing if I could
+ # relax torch.Size in PyTorch proper, but if torch.Size constructor
+ # sees a type that it doesn't recognize, it will try to call
+ # __index__() on it, so there is no BC way to actually change this
+ # behavior (though it occurs to me that I could have just added a
+ # YOLO no checking alternate constructor.)
+ #
+ # To work around this problem, I represent a torch.Size proxy as
+ # a straight up proxy, that would have been constructed by taking
+ # the constituent proxies as arguments. This trick can be generally
+ # used for any construct that we need a proxy for but we can't
+ # directly represent as an aggregate; I don't see very many examples
+ # of this in torchdynamo though!
+
+ # Look for a proxy. If there are none, do the legacy behavior
+ tracer = None
+ proxies = self._as_proxy()
+ for proxy in proxies:
+ if isinstance(proxy, torch.fx.Proxy):
+ tracer = proxy.tracer
+ break
+
+ if tracer is None:
+ return torch.Size(proxies)
+
+ proxy = tracer.create_proxy("call_function", torch.Size, (proxies,), {})
+ proxy.node.meta["example_value"] = torch.Size(
+ [p.node.meta["example_value"] for p in proxies]
+ )
+ return proxy
+
+ def reconstruct(self, codegen):
+ codegen.load_import_from("torch", "Size")
+ codegen.foreach(self.items)
+ build_torch_size = [
+ create_instruction("BUILD_TUPLE", len(self.items)),
+ create_instruction("CALL_FUNCTION", 1),
+ ]
+ return build_torch_size
+
+
+class ShapeVariable(TupleVariable):
+ """
+ Represents tensor.shape(...) and helps differentiate between a constant
+ TupleVariable and ShapeVariable.
+ """
+
+ pass
+
+
+class NamedTupleVariable(TupleVariable):
+ def __init__(self, items, tuple_cls, **kwargs):
+ super().__init__(items, **kwargs)
+ self.tuple_cls = tuple_cls
+
+ def python_type(self):
+ return self.tuple_cls
+
+ def reconstruct(self, codegen):
+ create_fn = getattr(self.tuple_cls, "_make", self.tuple_cls)
+ codegen.append_output(codegen._create_load_const(create_fn))
+ codegen.foreach(self.items)
+ return [
+ create_instruction("BUILD_TUPLE", len(self.items)),
+ create_instruction("CALL_FUNCTION", 1),
+ ]
+
+ def var_getattr(self, tx, name):
+ fields = namedtuple_fields(self.tuple_cls)
+ if name not in fields:
+ unimplemented(f"NamedTupleVariable.{name}")
+ return self.items[fields.index(name)].add_options(self)
+
+ def call_hasattr(self, tx, name: str) -> "VariableTracker":
+ options = VariableTracker.propagate(self)
+ fields = namedtuple_fields(self.tuple_cls)
+ return variables.ConstantVariable(name in fields, **options)
+
+
+class SliceVariable(BaseListVariable):
+ def __init__(self, items, **kwargs):
+ start, stop, step = [variables.ConstantVariable(None)] * 3
+ if len(items) == 1:
+ (stop,) = items
+ elif len(items) == 2:
+ start, stop = items
+ elif len(items) == 3:
+ start, stop, step = items
+ else:
+ raise AssertionError()
+
+ # Avoids a .item() call in the tensor slice that would attempt to get a
+ # value out fake tensors, and which would determine the output shape of
+ # the slice. It is a workaround until
+ # https://github.com/pytorch/pytorch/pull/83567 is landed and there is
+ # more complete support for breaking on data dependent operators.
+ if not config.capture_scalar_outputs:
+ for limit in (start, stop, step):
+ if isinstance(limit, variables.TensorVariable):
+ unimplemented("Dynamic slicing not supported")
+
+ super().__init__([start, stop, step], **kwargs)
+
+ def as_proxy(self):
+ return slice(*self._as_proxy())
+
+ def python_type(self):
+ return slice
+
+ def as_python_constant(self):
+ return slice(*[x.as_python_constant() for x in self.items])
+
+ def reconstruct(self, codegen):
+ codegen.foreach(self.items)
+ return [create_instruction("BUILD_SLICE", len(self.items))]
+
+ def var_getattr(self, tx, name):
+ fields = ["start", "stop", "step"]
+ if name not in fields:
+ unimplemented(f"slice.{name}")
+ return self.items[fields.index(name)].add_options(self)
+
+
+class ListIteratorVariable(VariableTracker):
+ def __init__(self, items, index: int = 0, **kwargs):
+ super(ListIteratorVariable, self).__init__(**kwargs)
+ assert isinstance(items, list)
+ assert all(isinstance(x, VariableTracker) for x in items)
+ self.items = items
+ self.index = index
+
+ def next_variables(self):
+ assert self.mutable_local
+ if self.index >= len(self.items):
+ raise StopIteration()
+ return self.items[self.index].add_options(self), ListIteratorVariable(
+ self.items,
+ self.index + 1,
+ mutable_local=MutableLocal(),
+ **VariableTracker.propagate([self]),
+ )
+
+ def as_python_constant(self):
+ if self.index > 0:
+ raise NotImplementedError()
+ return iter([x.as_python_constant() for x in self.items])
+
+ def unpack_var_sequence(self, tx):
+ return [x.add_options(self) for x in self.items[self.index :]]
+
+ def reconstruct(self, codegen):
+ remaining_items = self.items[self.index :]
+ codegen.foreach(remaining_items)
+ return [
+ create_instruction("BUILD_TUPLE", len(remaining_items)),
+ create_instruction("GET_ITER"),
+ ]
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
new file mode 100644
index 0000000000000..8dd3478114396
--- /dev/null
+++ b/torch/_dynamo/variables/misc.py
@@ -0,0 +1,674 @@
+import inspect
+import sys
+import types
+from typing import Dict, List
+
+import torch._C
+
+from .. import config, variables
+from ..bytecode_transformation import create_instruction
+from ..exc import unimplemented
+from ..guards import Guard, GuardBuilder, GuardSource
+from ..source import AttrSource
+from ..utils import identity, proxy_args_kwargs
+from .base import VariableTracker
+from .functions import (
+ UserFunctionVariable,
+ UserMethodVariable,
+ WrappedUserFunctionVariable,
+ WrappedUserMethodVariable,
+)
+
+
+class SuperVariable(VariableTracker):
+ def __init__(self, typevar, objvar=None, **kwargs):
+ super(SuperVariable, self).__init__(**kwargs)
+ self.typevar = typevar
+ self.objvar = objvar
+
+ def reconstruct(self, codegen):
+ codegen(variables.BuiltinVariable(super))
+ codegen(self.typevar)
+ if self.objvar is not None:
+ codegen(self.objvar)
+ return [create_instruction("CALL_FUNCTION", 2)]
+ else:
+ return [create_instruction("CALL_FUNCTION", 1)]
+
+ def const_getattr(self, tx, name):
+ assert self.objvar, "1-arg super not implemented"
+ search_type = self.typevar.as_python_constant()
+
+ # We default to the python type of the object. However,
+ # 1. If this is a `type`, then the original object represents the user
+ # defined type.
+ # 2. If this is `torch._C._TensorMeta`, the original object is the user
+ # defined type of a custom tensor subclass.
+ # TODO(future PR): figure out how to do this in a less hacky way
+ type_to_use = self.objvar.python_type()
+ if type_to_use is type or type_to_use is torch._C._TensorMeta:
+ type_to_use = self.objvar.value
+
+ # TODO(jansel): there is a small chance this could trigger user code, prevent that
+ return getattr(super(search_type, type_to_use), name)
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ options = VariableTracker.propagate(
+ self, args, kwargs.values(), self.objvar, self.typevar
+ )
+ inner_fn = self.const_getattr(self, name)
+ if inner_fn is object.__init__:
+ return LambdaVariable(identity, **options)
+ elif isinstance(inner_fn, types.FunctionType):
+ return variables.UserFunctionVariable(inner_fn, **options).call_function(
+ tx, [self.objvar] + args, kwargs
+ )
+ elif isinstance(inner_fn, types.MethodType):
+ return variables.UserMethodVariable(
+ inner_fn.__func__, self.objvar, **options
+ ).call_function(tx, args, kwargs)
+ else:
+ unimplemented(f"non-function or method super: {inner_fn}")
+
+
+class UnknownVariable(VariableTracker):
+ """
+ It could be anything!
+ """
+
+
+class ClosureVariable(UnknownVariable):
+ def __init__(self, name, **kwargs):
+ super(ClosureVariable, self).__init__(**kwargs)
+ self.name = name
+
+ def reconstruct(self, codegen):
+ return [codegen.create_load_closure(self.name)]
+
+
+class NewCellVariable(VariableTracker):
+ def __init__(self, **kwargs):
+ super(NewCellVariable, self).__init__(**kwargs)
+
+
+class NewGlobalVariable(VariableTracker):
+ def __init__(self, **kwargs):
+ super(NewGlobalVariable, self).__init__(**kwargs)
+
+
+class ContextWrappingVariable(VariableTracker):
+ def __init__(self, target_values, initial_values=None, **kwargs):
+ super(ContextWrappingVariable, self).__init__(**kwargs)
+ self.target_values = target_values
+ self.initial_values = initial_values
+
+ def enter(self, tx):
+ self._call_func(tx, self.target_values)
+ return variables.ConstantVariable(None, **VariableTracker.propagate(self))
+
+ def exit(self, tx, *args):
+ self._call_func(tx, self.initial_values)
+ return variables.ConstantVariable(None, **VariableTracker.propagate(self))
+
+ def reconstruct(self, codegen, target_inst=None):
+ """
+ Generate following Python Bytecode, with a `torch._C._set_grad_enable` call
+ Python 3.8
+ 0 LOAD_GLOBAL 0 (torch)
+ 2 LOAD_ATTR 1 (_C)
+ 4 LOAD_METHOD 2 (_set_grad_enable)
+ 6 LOAD_CONST 1 (False)
+ 8 CALL_METHOD 1
+ 10 POP_TOP
+
+ 12 SETUP_FINALLY 10 (to 24)
+
+ 14 LOAD_GLOBAL 3 (user_inst)
+ 16 CALL_FUNCTION 0
+ 18 POP_TOP
+ 20 POP_BLOCK
+ 22 BEGIN_FINALLY
+
+ 24 LOAD_GLOBAL 0 (torch)
+ 26 LOAD_ATTR 1 (_C)
+ 28 LOAD_METHOD 2 (_set_grad_enable)
+ 30 LOAD_CONST 2 (True)
+ 32 CALL_METHOD 1
+ 34 POP_TOP
+ 36 END_FINALLY
+ 38 LOAD_CONST 0 (None)
+ 40 RETURN_VALUE
+
+ Instructions 0-10 and 24-34 call torch._C.set_grad_enable(True/False)
+
+ Python 3.9, 3.10
+ 0 LOAD_GLOBAL 0 (torch)
+ 2 LOAD_ATTR 1 (_C)
+ 4 LOAD_METHOD 2 (_set_grad_enable)
+ 6 LOAD_CONST 1 (False)
+ 8 CALL_METHOD 1
+ 10 POP_TOP
+
+ 12 SETUP_FINALLY 22 (to 36)
+
+ 14 LOAD_GLOBAL 3 (user_inst)
+ 16 CALL_FUNCTION 0
+ 18 POP_TOP
+ 20 POP_BLOCK
+
+ 22 LOAD_GLOBAL 0 (torch)
+ 24 LOAD_ATTR 1 (_C)
+ 26 LOAD_METHOD 2 (_set_grad_enable)
+ 28 LOAD_CONST 2 (True)
+ 30 CALL_METHOD 1
+ 32 POP_TOP
+
+ 34 JUMP_FORWARD 14 (to 50)
+
+ 36 LOAD_GLOBAL 0 (torch)
+ 38 LOAD_ATTR 1 (_C)
+ 40 LOAD_METHOD 2 (_set_grad_enable)
+ 42 LOAD_CONST 2 (True)
+ 44 CALL_METHOD 1
+ 46 POP_TOP
+ 48 RERAISE
+
+ 50 LOAD_CONST 0 (None)
+ 52 RETURN_VALUE
+
+ """
+ if self.target_values == self.initial_values:
+ return ([], [])
+
+ def set_context_insts(values):
+ global_torch_source = codegen.tx.import_source("torch")
+ attr_source = AttrSource(global_torch_source, self._func_name())
+ load_set_context_enabling_insts = attr_source.reconstruct(codegen)
+
+ loads = [codegen.create_load_const(val) for val in values]
+
+ return [
+ *load_set_context_enabling_insts,
+ *loads,
+ create_instruction("CALL_FUNCTION", len(values)),
+ create_instruction("POP_TOP"),
+ ]
+
+ init_block = set_context_insts(self.target_values)
+ finally_block = set_context_insts(self.initial_values)
+ setup_final_inst = create_instruction("SETUP_FINALLY", target=finally_block[0])
+ prologue = init_block + [setup_final_inst]
+
+ # Generate the epilogue - starts with 20 POP_BLOCK and ends at 34 POP_TOP
+ if sys.version_info < (3, 9):
+ # Generate the prologue that ends with setup_finally
+ epilogue = [
+ create_instruction("POP_BLOCK"),
+ codegen.create_begin_finally(),
+ *finally_block,
+ create_instruction("END_FINALLY"),
+ ]
+ else:
+ except_block = set_context_insts(self.initial_values)
+ epilogue = [
+ create_instruction("POP_BLOCK"),
+ *except_block,
+ create_instruction("JUMP_FORWARD", target=target_inst),
+ *finally_block,
+ create_instruction("RERAISE"),
+ ]
+
+ return (prologue, epilogue)
+
+ def _call_func(self, tx, initial_values):
+ raise NotImplementedError("_call_func called on base")
+
+ def _func_name(self):
+ raise NotImplementedError("_func_name called on base")
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ assert len(args) == 1
+ assert isinstance(args[0], UserMethodVariable) or isinstance(
+ args[0], UserFunctionVariable
+ )
+
+ if isinstance(args[0], UserMethodVariable):
+ return WrappedUserMethodVariable(args[0], self)
+
+ if isinstance(args[0], UserFunctionVariable):
+ return WrappedUserFunctionVariable(args[0], self)
+
+
+class GradModeVariable(ContextWrappingVariable):
+ """represents torch.{no_grad,enable_grad,set_grad_mode}()"""
+
+ _guards_singleton = {Guard("", GuardSource.GLOBAL, GuardBuilder.GRAD_MODE)}
+
+ @staticmethod
+ def create(tx, target_value, **kwargs):
+ var = GradModeVariable(
+ target_values=[target_value],
+ initial_values=[torch.is_grad_enabled()],
+ **kwargs,
+ )
+ var._call_func(tx, [target_value])
+ return var
+
+ def __init__(self, target_values, initial_values=None, **kwargs):
+ super(GradModeVariable, self).__init__(
+ target_values=target_values, initial_values=initial_values, **kwargs
+ )
+ self.guards = self.guards | self._guards_singleton
+
+ def enter(self, tx):
+ return variables.ConstantVariable(None, **VariableTracker.propagate(self))
+
+ def _call_func(self, tx, values):
+ assert len(values) == 1
+ value = values[0]
+ tx.output.graph.create_node(
+ "call_function", torch._C._set_grad_enabled, (value,), {}
+ ),
+ torch._C._set_grad_enabled(value)
+
+ def _func_name(self):
+ return "_C._set_grad_enabled"
+
+ def fn_name(self):
+ if self.target_values:
+ return "enable_grad"
+ else:
+ return "no_grad"
+
+
+class AutocastModeVariable(ContextWrappingVariable):
+ @staticmethod
+ def create(tx, target_values, kwargs):
+ values = target_values
+ # device_type : str,
+ # dtype : Optional[_dtype] = None,
+ # enabled : bool = True,
+ # cache_enabled : Optional[bool] = None):cache_enabled
+ assert "device_type" in kwargs
+ values.append(kwargs["device_type"])
+ del kwargs["device_type"]
+
+ if "dtype" in kwargs:
+ values.append(kwargs["dtype"])
+ del kwargs["dtype"]
+ else:
+ values.append(variables.ConstantVariable(None))
+
+ if "enabled" in kwargs:
+ values.append(kwargs["enabled"])
+ del kwargs["enabled"]
+ else:
+ values.append(variables.ConstantVariable(True))
+
+ if "cache_enabled" in kwargs:
+ values.append(kwargs["cache_enabled"])
+ del kwargs["cache_enabled"]
+ else:
+ values.append(variables.ConstantVariable(None))
+
+ var = AutocastModeVariable(tx, target_values, initial_values=None, **kwargs)
+ return var
+
+ def __init__(self, tx, target_values, initial_values=None, **kwargs):
+ super(AutocastModeVariable, self).__init__(
+ target_values=target_values, initial_values=initial_values, **kwargs
+ )
+ self.target_values = [val.as_python_constant() for val in target_values]
+ self.mode = None
+
+ def exit(self, tx, *args):
+ tx.output.graph.create_node(
+ "call_function", exit_functional_autocast, (self.mode,), {}
+ )
+
+ def enter(self, tx):
+ self.mode = tx.output.graph.create_node(
+ "call_function", enter_functional_autocast, (*self.target_values,), {}
+ )
+
+ def _func_name(self):
+ return "torch.amp.autocast_mode.autocast"
+
+ def fn_name(self):
+ return "torch.amp.autocast_mode.autocast"
+
+
+def enter_functional_autocast(*vals):
+ mode = torch.amp.autocast(*vals)
+ mode.__enter__()
+ return mode
+
+
+def exit_functional_autocast(mode):
+ mode.__exit__(None, None, None)
+
+
+class ProfilerContextWrapperVariable(ContextWrappingVariable):
+ def __init__(self, target_values=None, **kwargs):
+ super(ProfilerContextWrapperVariable, self).__init__(
+ target_values=target_values, **kwargs
+ )
+
+ def enter(self, tx):
+ return variables.ConstantVariable(None, **VariableTracker.propagate(self))
+
+ def exit(self, tx, *args):
+ return variables.ConstantVariable(None, **VariableTracker.propagate(self))
+
+ def fn_name(self):
+ return "autograd.profiler.profile"
+
+
+class WithExitFunctionVariable(VariableTracker):
+ def __init__(self, ctx: VariableTracker, target, **kwargs):
+ super(WithExitFunctionVariable, self).__init__(**kwargs)
+ self.ctx = ctx
+ self.target = target
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ assert not kwargs
+ return self.ctx.exit(tx, *args)
+
+ def reconstruct(self, codegen):
+ # Note here we reconstruct the context manager rather than the
+ # exit function. The handler generated by BlockStackEntry
+ # will re-enter the context in the resume function.
+ output = AttrSource(
+ codegen.tx.import_source("torch"), self.ctx.fn_name()
+ ).reconstruct(codegen)
+
+ if codegen.tx.output.partial_convert:
+ output.extend(
+ [
+ create_instruction("CALL_FUNCTION", 0),
+ create_instruction("SETUP_WITH", target=self.target),
+ create_instruction("POP_TOP"),
+ ]
+ )
+ return output
+
+
+class InspectSignatureVariable(VariableTracker):
+ """represents inspect.signature(...)"""
+
+ @staticmethod
+ def create(callable, **kwargs):
+ if kwargs:
+ unimplemented(f"inspect.signature with {kwargs}")
+ return InspectSignatureVariable(callable)
+
+ def __init__(self, inspected, **kwargs):
+ super(InspectSignatureVariable, self).__init__(**kwargs)
+ self.inspected = inspected
+
+
+class AutogradFunctionVariable(VariableTracker):
+ """represents a torch.autograd.Function subclass"""
+
+ def __init__(self, fn_cls, **kwargs):
+ super().__init__(**kwargs)
+ self.fn_cls = fn_cls
+
+ def call_apply(self, tx, args, kwargs):
+ requires_grad = False
+
+ def visit(node):
+ nonlocal requires_grad
+ if isinstance(node, variables.TensorVariable):
+ if node.requires_grad is not False:
+ requires_grad = True
+ if isinstance(node, variables.NNModuleVariable):
+ if node.is_training(tx):
+ requires_grad = True
+ return node
+
+ VariableTracker.apply(visit, (args, kwargs))
+
+ if requires_grad and torch.is_grad_enabled():
+ # TODO(jansel): handle this in training mode
+ unimplemented("autograd.Function with requires_grad")
+
+ args = [BlackHoleVariable()] + list(args)
+ options = VariableTracker.propagate(self, args, kwargs.values())
+ return variables.UserFunctionVariable(
+ self.fn_cls.forward, **options
+ ).call_function(tx, args, kwargs)
+
+ def call_function(self, tx, args, kwargs):
+ options = VariableTracker.propagate(self, args, kwargs.values())
+ return AutogradFunctionVariable(self.fn_cls, **options)
+
+
+class BlackHoleVariable(VariableTracker):
+ """A autograd.function context that just ignores everything (for forward extraction)"""
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ assert name in ("__setattr__", "save_for_backward"), name
+ return variables.ConstantVariable(
+ None, **VariableTracker.propagate(self, args, kwargs.values())
+ )
+
+
+class LambdaVariable(VariableTracker):
+ def __init__(self, fn, **kwargs):
+ super(LambdaVariable, self).__init__(**kwargs)
+ self.fn = fn
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ return self.fn(*args, **kwargs).add_options(self)
+
+
+class GetAttrVariable(VariableTracker):
+ def __init__(self, obj, name, **kwargs):
+ super(GetAttrVariable, self).__init__(**kwargs)
+ assert isinstance(obj, VariableTracker)
+ assert isinstance(name, str)
+ self.obj = obj
+ self.name = name
+
+ def __str__(self):
+ return f"{self.__class__.__name__}({self.obj}, {self.name})"
+
+ def as_proxy(self):
+ return getattr(self.obj.as_proxy(), self.name)
+
+ def const_getattr(self, tx, name):
+ if not isinstance(self.obj, variables.NNModuleVariable):
+ raise NotImplementedError()
+ step1 = tx.output.get_submodule(self.obj.module_key)
+ if self.name not in step1.__dict__:
+ raise NotImplementedError()
+ step2 = inspect.getattr_static(step1, self.name)
+ if name not in step2.__dict__:
+ raise NotImplementedError()
+ return inspect.getattr_static(step2, name)
+
+ def reconstruct(self, codegen):
+ codegen(self.obj)
+ return codegen.create_load_attrs(self.name)
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+
+ # This variable is True when it corresponds to user code such as
+ #
+ # super().__torch_function__(...)
+ #
+ # and the super().__torch_function__ attribute resolves
+ # to torch.Tensor.__torch_function__.
+ is_original_tensor_torch_function = (
+ self.name == "__torch_function__"
+ and isinstance(self.obj, SuperVariable)
+ # for now, only support one level of inheritance
+ and len(self.obj.objvar.value.__mro__) > 1
+ and self.obj.objvar.value.__mro__[1] == torch.Tensor
+ )
+ if is_original_tensor_torch_function:
+ # Instead of tracing inside torch.Tensor.__torch_function__,
+ # record the `call_function` or `call_method` call into the graph.
+ from . import TensorVariable, TorchVariable
+
+ original_torch_or_getattr_variable = args[0]
+ new_args = args[2].items
+ new_kwargs = args[3].items
+ options = VariableTracker.propagate(self, new_args, new_kwargs.values())
+ # Disable __torch_function__ here to prevent the clone of the
+ # example tensor from going into the override.
+ with torch._C.DisableTorchFunction():
+ if isinstance(args[0], TorchVariable):
+ return TensorVariable.create(
+ tx=tx,
+ proxy=tx.output.create_proxy(
+ "call_function",
+ original_torch_or_getattr_variable.value,
+ *proxy_args_kwargs(new_args, new_kwargs),
+ current_tx=tx,
+ ),
+ **options,
+ )
+ elif isinstance(args[0], GetAttrVariable):
+ return TensorVariable.create(
+ tx=tx,
+ proxy=tx.output.create_proxy(
+ "call_method",
+ original_torch_or_getattr_variable.name,
+ *proxy_args_kwargs(new_args, new_kwargs),
+ current_tx=tx,
+ ),
+ **options,
+ )
+ else:
+ unimplemented(
+ f"GetAttrVariable.call_function original __torch_function__ {args}"
+ )
+
+ if isinstance(self.obj, AutogradFunctionVariable) and self.name == "apply":
+ return self.obj.call_apply(tx, args, kwargs).add_options(self)
+ return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ if (
+ name == "__len__"
+ and isinstance(self.obj, InspectSignatureVariable)
+ and self.name == "parameters"
+ ):
+ return variables.ConstantVariable(
+ self.obj.inspected.num_parameters(),
+ **VariableTracker.propagate(self, self.obj, self.obj.inspected),
+ )
+ return super(GetAttrVariable, self).call_method(tx, name, args, kwargs)
+
+
+class PythonModuleVariable(VariableTracker):
+ def __init__(self, value: types.ModuleType, **kwargs):
+ super(PythonModuleVariable, self).__init__(**kwargs)
+ self.value = value
+
+ def python_type(self):
+ return types.ModuleType
+
+
+class SkipFilesVariable(VariableTracker):
+ def __init__(self, value, **kwargs):
+ super().__init__(**kwargs)
+ self.value = value
+
+ def python_type(self):
+ return type(self.value)
+
+ def as_python_constant(self):
+ return self.value
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
+ unimplemented(
+ f"call {config.dynamo_import}.disable() wrapped function {self.value}"
+ )
+ else:
+ try:
+ path = inspect.getfile(self.value)
+ except TypeError:
+ path = f"Builtin {self.value.__name__}"
+ unimplemented("call_function in skip_files " + path)
+
+
+class TypingVariable(VariableTracker):
+ def __init__(self, value, **kwargs):
+ super().__init__(**kwargs)
+ self.value = value
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ if name == "__getitem__" and len(args) == 1:
+ return variables.ConstantVariable(
+ self.value[args[0].as_python_constant()],
+ **VariableTracker.propagate(self, args),
+ )
+ unimplemented("typing")
+
+
+class NumpyVariable(VariableTracker):
+ """
+ Wrapper around `numpy.*` for better error messages.
+ """
+
+ def __init__(self, value, **kwargs):
+ super().__init__(**kwargs)
+ self.value = value
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ unimplemented("numpy")
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ unimplemented("numpy")
+
+ def python_type(self):
+ return type(self.value)
+
+ def as_python_constant(self):
+ return self.value
diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py
new file mode 100644
index 0000000000000..4bf6e33745202
--- /dev/null
+++ b/torch/_dynamo/variables/nn_module.py
@@ -0,0 +1,491 @@
+import functools
+import inspect
+import itertools
+import re
+import types
+from contextlib import contextmanager
+from typing import Dict, List
+
+import torch.nn
+
+from .. import skipfiles, variables
+from ..allowed_functions import is_allowed
+from ..exc import RestartAnalysis, unimplemented
+from ..guards import GuardBuilder
+from ..mutation_guard import GenerationTracker
+from ..source import AttrSource, GetItemSource, NNModuleSource, NotNNModuleSource
+from ..utils import is_lazy_module, istype, proxy_args_kwargs
+from .base import MutableLocal, typestr, VariableTracker
+from .functions import invoke_and_store_as_constant
+from .lists import SliceVariable
+from .user_defined import UserDefinedObjectVariable
+
+
+class NNModuleVariable(VariableTracker):
+ _nonvar_fields = ["module_type", "module_key"]
+
+ def __init__(self, module_type: type, module_key: str, **kwargs):
+ super(NNModuleVariable, self).__init__(**kwargs)
+ self.module_type = module_type
+ self.module_key = module_key
+ assert self.source
+
+ def python_type(self):
+ return self.module_type
+
+ def _wrap_submodule(self, tx, source, submod, *key_extra, **options):
+ return
+
+ def unpack_var_sequence(self, tx):
+ # implement list/iter/tuple/etc calls
+ base = tx.output.get_submodule(self.module_key)
+ options = VariableTracker.propagate([self])
+ assert isinstance(
+ base, (torch.nn.ModuleList, torch.nn.ParameterList, torch.nn.Sequential)
+ ), typestr(base)
+ assert self.source
+ result = []
+ for idx, submod in enumerate(base):
+ result.append(
+ tx.output.register_attr_or_module(
+ submod,
+ self.module_key,
+ idx,
+ source=NNModuleSource(GetItemSource(self.source, idx)),
+ **options,
+ )
+ )
+ return result
+
+ def call_hasattr(self, tx, name: str) -> "VariableTracker":
+ options = VariableTracker.propagate(self)
+ mod = tx.output.get_submodule(self.module_key)
+ result = hasattr(mod, name)
+ return variables.ConstantVariable(result, **options).add_guard(
+ NNModuleSource(AttrSource(self.source, name)).make_guard(
+ GuardBuilder.HASATTR
+ )
+ )
+
+ def is_training(self, tx):
+ mod = tx.output.get_submodule(self.module_key)
+ return getattr(mod, "training", False)
+
+ def convert_to_unspecialized(self, tx):
+ """Restart analysis treating this module as an UnspecializedNNModuleVariable"""
+ mod = tx.output.get_submodule(self.module_key)
+ GenerationTracker.tag(mod)
+
+ # Mark the class dynamic unless its module initialization
+ if tx.f_code.co_name != "__init__":
+ GenerationTracker.mark_class_dynamic(type(mod))
+ raise RestartAnalysis()
+
+ def var_getattr(self, tx, name):
+ from .builder import VariableBuilder
+
+ options = VariableTracker.propagate(self)
+ guards = options.get("guards", set())
+
+ if self.source:
+ source = AttrSource(self.source, name)
+ options["source"] = source
+ else:
+ source = None
+
+ base = tx.output.get_submodule(self.module_key)
+ base_dict = object.__getattribute__(base, "__dict__")
+ object_member = True
+ all_class_attribute_names = set()
+ for x in inspect.getmro(base.__class__):
+ all_class_attribute_names.update(x.__dict__.keys())
+
+ if not self.source:
+ unimplemented("GETATTR with no source")
+
+ if name in base_dict:
+ subobj = base_dict[name]
+ elif (
+ "_modules" in base_dict
+ and name in base_dict["_modules"]
+ and name not in all_class_attribute_names
+ ):
+ subobj = base_dict["_modules"][name]
+ elif "_parameters" in base_dict and name in base_dict["_parameters"]:
+ subobj = base_dict["_parameters"][name]
+ elif "_buffers" in base_dict and name in base_dict["_buffers"]:
+ subobj = base_dict["_buffers"][name]
+ else:
+ subobj = inspect.getattr_static(base, name)
+ object_member = False
+
+ if name == "__class__" and not object_member:
+ return variables.UserDefinedClassVariable(base.__class__, **options)
+
+ if object_member:
+ return VariableBuilder(tx, NNModuleSource(source))(subobj)
+ else:
+ if istype(subobj, property):
+ return variables.UserFunctionVariable(
+ subobj.fget, guards=guards
+ ).call_function(tx, [(self)], {})
+ elif istype(subobj, classmethod):
+ return variables.UserMethodVariable(
+ subobj.__func__,
+ variables.UserDefinedObjectVariable(type(base), guards=guards),
+ **options,
+ )
+ elif istype(subobj, staticmethod):
+ return variables.UserFunctionVariable(subobj.__get__(base), **options)
+ elif istype(subobj, types.FunctionType):
+ return variables.UserMethodVariable(subobj, self, **options)
+ else:
+ unimplemented(f"class property {typestr(base)} {typestr(subobj)}")
+
+ return variables.GetAttrVariable(self, name, **options)
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ options = VariableTracker.propagate(self, args, kwargs.values())
+ mod = tx.output.get_submodule(self.module_key)
+
+ @contextmanager
+ def record_nn_module_stack():
+ try:
+ tx.nn_module_stack[self.module_key] = mod.__class__.__name__
+ yield
+ finally:
+ del tx.nn_module_stack[self.module_key]
+
+ with record_nn_module_stack():
+ is_lazy = is_lazy_module(mod)
+ if (
+ isinstance(mod, torch.nn.Sequential)
+ and mod.__class__.forward is torch.nn.Sequential.forward
+ ):
+ # unroll Sequential()
+ assert not kwargs
+ (arg,) = args
+ for idx, submod in enumerate(mod):
+ tx.call_function(
+ tx.output.register_attr_or_module(
+ submod,
+ self.module_key,
+ idx,
+ source=NNModuleSource(GetItemSource(self.source, idx)),
+ **options,
+ ),
+ [arg],
+ {},
+ )
+ arg = tx.pop()
+ return arg
+ elif is_allowed(mod.__class__):
+ # The module type will change after it is called
+ if is_lazy:
+ self.module_type = mod.cls_to_become
+
+ return variables.TensorVariable.create(
+ tx=tx,
+ proxy=tx.output.create_proxy(
+ "call_module",
+ self.module_key,
+ *proxy_args_kwargs(args, kwargs),
+ current_tx=tx,
+ ),
+ nnmodule=mod,
+ **options,
+ )
+ else:
+ # for lazy modules, run the pre-hooks which will update the type
+ # TODO mlazos: we don't fully support all of the hooks that exist,
+ # so restrict using __call__ only to lazy modules for now
+ if is_lazy:
+ fn = mod.__class__.__call__
+ else:
+ fn = mod.__class__.forward
+
+ return tx.inline_user_function_return(
+ variables.UserFunctionVariable(fn, **options),
+ [self] + args,
+ kwargs,
+ )
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ constant=False,
+ ) -> "VariableTracker":
+ from . import ConstantVariable, ListIteratorVariable, TupleVariable
+
+ options = VariableTracker.propagate(self, args, kwargs.values())
+ key = self.module_key
+ module = tx.output.get_submodule(key)
+
+ if name == "forward":
+ return self.call_function(tx, args, kwargs)
+
+ if name == "_check_input_dim" and skipfiles.is_torch_inline_allowed(
+ inspect.getfile(module.__class__._check_input_dim)
+ ):
+ return ConstantVariable(True, **options)
+
+ if name == "_get_item_by_idx":
+ assert args[1].is_python_constant()
+ assert isinstance(args[0], TupleVariable)
+ mod_var = args[0].items[args[1].value]
+ key = mod_var.module_key
+ submod = tx.output.get_submodule(key)
+ return tx.output.register_attr_or_module(
+ submod,
+ key,
+ key,
+ source=NNModuleSource(GetItemSource(self.source, key)),
+ **options,
+ )
+
+ if constant:
+ fn = getattr(module, name)
+ name = f"{module.__class__.__name__}_{name}_result"
+ return invoke_and_store_as_constant(tx, fn, name, options, args, kwargs)
+
+ if not all(
+ x.is_python_constant() for x in itertools.chain(args, kwargs.values())
+ ):
+ raise unimplemented(f"non-const NNModule method {name}")
+
+ def get_kwargs(*names):
+ fn = getattr(module, name)
+ bound_args = inspect.signature(fn).bind(
+ *([x.as_python_constant() for x in args]),
+ **{k: v.as_python_constant() for k, v in kwargs.items()},
+ )
+ bound_args.apply_defaults()
+ bound_args = bound_args.arguments
+ return {k: bound_args[k] for k in names}
+
+ def wrap_values(items, getsource=AttrSource):
+ result = []
+ for name, submod in items:
+ # layer.0.foo => layer[0].foo
+ name = re.sub(r"[.]([0-9]+)([.]|$)", r"[\1]\2", name)
+ src = NNModuleSource(getsource(self.source, name))
+ result.append(
+ tx.output.register_attr_or_module(
+ submod,
+ key,
+ name,
+ source=src,
+ **options,
+ )
+ )
+ return ListIteratorVariable(result, mutable_local=MutableLocal(), **options)
+
+ def named_embed(name, obj):
+ return TupleVariable(
+ [
+ ConstantVariable(name, **options),
+ tx.output.register_attr_or_module(
+ obj,
+ key,
+ name,
+ source=NNModuleSource(GetItemSource(self.source, name)),
+ **options,
+ ),
+ ]
+ )
+
+ if name == "children":
+ assert not (args or kwargs)
+ return wrap_values(module.named_children())
+ elif name == "named_parameters":
+ result = []
+ for name, param in module.named_parameters(
+ **get_kwargs("prefix", "recurse")
+ ):
+ result.append(named_embed(name, param))
+ return ListIteratorVariable(result, mutable_local=MutableLocal(), **options)
+ elif name == "named_modules":
+ result = []
+ for name, submod in module.named_modules(
+ **get_kwargs("memo", "prefix", "remove_duplicate")
+ ):
+ result.append(named_embed(name, submod))
+ return ListIteratorVariable(result, mutable_local=MutableLocal(), **options)
+ elif name == "parameters":
+ return wrap_values(module.named_parameters(**get_kwargs("recurse")))
+ elif name == "values":
+ assert not (args or kwargs)
+ return wrap_values(module.items(), GetItemSource)
+ elif name == "items":
+ assert not (args or kwargs)
+ result = []
+ for name, submod in module.items():
+ result.append(named_embed(name, submod))
+ return ListIteratorVariable(result, mutable_local=MutableLocal(), **options)
+ elif name == "__len__":
+ assert not (args or kwargs)
+ return ConstantVariable(len(module), **options)
+ elif (
+ name == "__contains__"
+ and isinstance(module, (torch.nn.ModuleDict, torch.nn.ParameterDict))
+ and args
+ and args[0].is_python_constant()
+ ):
+ return ConstantVariable(
+ args[0].as_python_constant() in module._modules, **options
+ )
+ elif name == "__getitem__":
+ assert not kwargs and len(args) == 1
+ assert type(module).__getitem__ in (
+ torch.nn.ModuleDict.__getitem__,
+ torch.nn.ModuleList.__getitem__,
+ torch.nn.ParameterList.__getitem__,
+ torch.nn.Sequential.__getitem__,
+ ), typestr(module)
+ assert self.source
+
+ if isinstance(args[0], SliceVariable):
+ # Build a TupleVariable of NNModules
+ result = []
+
+ # Turn the slice into the list of integers
+ keys = list(range(len(module)))[args[0].as_python_constant()]
+ for idx, submod in enumerate(module[args[0].as_python_constant()]):
+ key = keys[idx]
+ src = NNModuleSource(GetItemSource(self.source, key))
+ result.append(
+ tx.output.register_attr_or_module(
+ submod,
+ key,
+ source=src,
+ **options,
+ )
+ )
+ return TupleVariable(result, **options)
+
+ key = args[0].as_python_constant()
+ submod = module[key]
+ return tx.output.register_attr_or_module(
+ submod,
+ key,
+ args[0].as_python_constant(),
+ source=NNModuleSource(GetItemSource(self.source, key)),
+ **options,
+ )
+ elif name == "_get_abs_string_index":
+ # Inline the function
+ fn = getattr(module, name).__func__
+ return tx.inline_user_function_return(
+ variables.UserFunctionVariable(fn, **options),
+ [self] + args,
+ kwargs,
+ )
+ else:
+ return super().call_method(tx, name, args, kwargs)
+
+
+class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
+ """
+ The above class will specialize on the id() of a module and place
+ parameters on the torch.fx.GraphModule. Giving one graph per
+ module instance. This version treats nn.Modules() like other user
+ defined objects and will pass parameters into the FX graph as inputs.
+ Giving one graph per module class.
+ """
+
+ def __init__(self, value, **kwargs):
+ super(UnspecializedNNModuleVariable, self).__init__(value=value, **kwargs)
+ if self.source and self.source.is_nn_module():
+ # force guard checks even when `not config.guard_nn_modules``
+ self.source = NotNNModuleSource(self.source)
+
+ @staticmethod
+ @functools.lru_cache(None)
+ def _nn_module_method_ids():
+ return {
+ id(x.__code__)
+ for x in torch.nn.Module.__dict__.values()
+ if hasattr(x, "__code__")
+ }
+
+ def unpack_var_sequence(self, tx):
+ from .builder import VariableBuilder
+
+ try:
+ fn = inspect.getattr_static(self.value_type, "__iter__")
+ except AttributeError:
+ raise NotImplementedError()
+
+ if fn in (
+ torch.nn.ModuleList.__iter__,
+ torch.nn.ParameterList.__iter__,
+ torch.nn.Sequential.__iter__,
+ ):
+ assert self.source
+ return [
+ VariableBuilder(tx, source=GetItemSource(self.source, idx))(
+ item
+ ).add_options(self)
+ for idx, item in enumerate(self.value)
+ ]
+
+ return super().unpack_var_sequence(tx)
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ options = VariableTracker.propagate(self, args, kwargs.values())
+
+ # TODO mlazos: only support __call__ for lazy modules
+ # until we can support a larger swath of python
+ if is_lazy_module(self.value):
+ fn = self.value_type.__call__
+ else:
+ fn = self.value_type.forward
+
+ return variables.UserFunctionVariable(fn, **options).call_function(
+ tx, [self] + list(args), kwargs
+ )
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ from .builder import VariableBuilder
+
+ options = VariableTracker.propagate(self, args, kwargs.values())
+
+ if name not in getattr(self.value, "__dict__", {}):
+ try:
+ method = inspect.getattr_static(type(self.value), name)
+ except AttributeError:
+ method = None
+
+ if method is torch.nn.Module.parameters:
+ assert not args or kwargs
+ options["guards"].add(
+ self.source.make_guard(GuardBuilder.NN_MODULE_PARAM_NAMES)
+ )
+ items = []
+ for name, value in self.value.named_parameters():
+ items.append(
+ VariableBuilder(tx, AttrSource(self.source, name))(
+ value
+ ).add_options(options)
+ )
+ return variables.ListIteratorVariable(
+ items, mutable_local=MutableLocal(), **options
+ )
+
+ if id(method.__code__) in self._nn_module_method_ids():
+ unimplemented(f"UnspecializedNNModuleVariable missing {name}")
+
+ return super().call_method(tx, name, args, kwargs)
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
new file mode 100644
index 0000000000000..bc05980e26574
--- /dev/null
+++ b/torch/_dynamo/variables/tensor.py
@@ -0,0 +1,717 @@
+import contextlib
+import copy
+import functools
+import itertools
+import math
+import numbers
+import operator
+from typing import Dict, List
+
+import torch.fx
+import torch.random
+
+from ..utils import fake_tensors_available
+
+if fake_tensors_available:
+ from torch._subclasses import FakeTensor
+ from torch._subclasses.fake_tensor import (
+ DataDependentOutputException,
+ DynamicOutputShapeException,
+ )
+ from ..utils import deepcopy_to_fake_tensor, wrap_to_fake_tensor
+
+import torch.utils._python_dispatch as py_dispatch
+from torch.fx.immutable_collections import immutable_list
+from torch.utils._pytree import tree_map
+
+from .. import config, variables
+from ..exc import TorchRuntimeError, unimplemented, Unsupported
+from ..guards import GuardBuilder
+from ..source import AttrSource
+from ..utils import (
+ clone_input,
+ is_lazy_module,
+ istype,
+ preserve_rng_state,
+ product,
+ proxy_args_kwargs,
+ tensortype_to_dtype,
+)
+from .base import MutableLocal, typestr, VariableTracker
+from .constant import ConstantVariable
+from .lists import ShapeVariable, SizeVariable
+
+
+class TensorVariable(VariableTracker):
+ """A torch.Tensor input or an intermediate value in the FX graph"""
+
+ _nonvar_fields = [
+ "proxy",
+ "dtype",
+ "device",
+ "ndim",
+ "size",
+ "stride",
+ "requires_grad",
+ "is_quantized",
+ "is_contiguous",
+ ]
+
+ @staticmethod
+ def propagate_args_kwargs(node):
+ def visit(n: torch.fx.Node):
+ return n.meta["example_value"]
+
+ return torch.fx.node.map_arg((node.args, node.kwargs), visit)
+
+ @staticmethod
+ def run_proxy(proxy, args, kwargs, nnmodule):
+ op = proxy.node.op
+ if op == "call_function":
+ return proxy.node.target(*args, **kwargs)
+ elif op == "call_method":
+ return getattr(args[0], proxy.node.target)(*args[1:], **kwargs)
+ elif op == "call_module":
+ assert nnmodule is not None
+ return nnmodule(*args, **kwargs)
+ raise AssertionError(op)
+
+ @classmethod
+ def create(cls, tx, proxy, example_value=None, nnmodule=None, **options):
+ if "guards" in options and options["guards"] is not None:
+ tx.output.guards.update(options["guards"])
+
+ assert "example_value" not in proxy.node.meta
+ if not config.dynamic_propagation:
+ if isinstance(example_value, torch.Tensor):
+ options.update(cls.specialize(example_value))
+ return cls(proxy, **options)
+
+ use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation
+ if use_fake_tensors:
+ fake_wrapper = functools.partial(
+ wrap_to_fake_tensor, fake_mode=tx.fake_mode
+ )
+ # python errors if the import isnt here
+ from ..utils import wrap_fake_exception
+ else:
+
+ def wrap_fake_exception(func):
+ return func()
+
+ args = kwargs = None
+ initial_example_value = example_value
+
+ with preserve_rng_state():
+ if example_value is None:
+ op = proxy.node.op
+ args, kwargs = cls.propagate_args_kwargs(proxy.node)
+ if use_fake_tensors:
+ args = tree_map(fake_wrapper, args)
+ kwargs = tree_map(fake_wrapper, kwargs)
+ if op == "call_module" and not is_lazy_module(nnmodule):
+ nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode)
+
+ def context():
+ if hasattr(py_dispatch, "enable_torch_dispatch_mode"):
+ return py_dispatch.enable_torch_dispatch_mode(tx.fake_mode)
+ else:
+ return tx.fake_mode
+
+ else:
+ context = contextlib.nullcontext
+ if op == "call_module" and not is_lazy_module(nnmodule):
+ nnmodule = copy.deepcopy(nnmodule)
+
+ if op == "call_module" and is_lazy_module(nnmodule):
+ assert nnmodule is not None
+ # In the case of a lazy module, we want to run
+ # the pre-hooks which initialize it
+ example_value = nnmodule(*args, **kwargs)
+ try:
+ with context():
+ example_value = wrap_fake_exception(
+ lambda: cls.run_proxy(proxy, args, kwargs, nnmodule)
+ )
+ except Unsupported:
+ raise
+ except RuntimeError as e:
+ if use_fake_tensors and isinstance(e, DataDependentOutputException):
+ if (
+ config.capture_scalar_outputs
+ and proxy.node.target == "item"
+ ):
+ example_value = torch.zeros(
+ size=(), dtype=args[0].dtype
+ ).item()
+ else:
+ unimplemented(f"data dependent operator: {e.func}")
+ elif use_fake_tensors and isinstance(
+ e, DynamicOutputShapeException
+ ):
+ unimplemented(f"dynamic shape operator: {e.func}")
+ else:
+ raise TorchRuntimeError() from e
+ else:
+ if use_fake_tensors:
+ example_value = fake_wrapper(example_value)
+
+ if isinstance(example_value, torch.Tensor):
+ is_parameter = isinstance(example_value, torch.nn.Parameter)
+ parameter_value = initial_example_value if is_parameter else None
+
+ # tensor subclasses will not be converted to FakeTensors and need to be cloned
+ if not use_fake_tensors or not isinstance(example_value, FakeTensor):
+ # NB: ensure strides are preserved
+ example_value = clone_input(example_value)
+ proxy.node.meta["example_value"] = example_value
+ specialized_props = cls.specialize(example_value)
+ if use_fake_tensors and isinstance(example_value, FakeTensor):
+ specialized_props["class_type"] = (
+ torch.nn.Parameter if is_parameter else torch.Tensor
+ )
+
+ specialized_props["parameter_value"] = parameter_value
+
+ options.update(specialized_props)
+ return cls(proxy, **options)
+ elif istype(example_value, (int, bool, float)) and config.dynamic_shapes:
+ proxy.node.meta["example_value"] = example_value
+ return DynamicShapeVariable(proxy, type(example_value), **options)
+ elif istype(example_value, torch.Size) and config.dynamic_shapes:
+ proxy.node.meta["example_value"] = example_value
+ sizes = []
+ for i, v in enumerate(example_value):
+ proxy_i = proxy[i]
+ proxy_i.node.meta["example_value"] = v
+ sizes.append(DynamicShapeVariable(proxy_i, int))
+ return SizeVariable(sizes, proxy, **options)
+ elif istype(example_value, int) and proxy.node.target in (
+ torch.seed,
+ operator.mod,
+ torch.distributed.get_rank,
+ torch.distributed.get_world_size,
+ ):
+ proxy.node.meta["example_value"] = example_value
+ return DynamicShapeVariable(proxy, type(example_value), **options)
+ elif istype(example_value, torch.Size) and all(
+ [isinstance(x, int) for x in example_value]
+ ):
+ sizes = [variables.ConstantVariable(x) for x in example_value]
+ return SizeVariable(sizes, **options)
+ elif isinstance(example_value, (tuple, list)):
+ unpacked = []
+ for i, val in enumerate(example_value):
+ if val is None:
+ # nn.MultiheadAttention() can return None, see issue #175
+ unpacked.append(
+ variables.ConstantVariable(None, **options),
+ )
+ else:
+ unpacked.append(
+ cls.create(
+ tx,
+ proxy.tracer.create_proxy(
+ "call_function", operator.getitem, (proxy, i), {}
+ ),
+ example_value=val,
+ **options,
+ )
+ )
+ if istype(example_value, tuple):
+ return variables.TupleVariable(unpacked, **options)
+ elif istype(example_value, (list, immutable_list)):
+ return variables.ListVariable(
+ unpacked, mutable_local=MutableLocal(), **options
+ )
+ else:
+ assert (
+ example_value.__class__.__module__ == "torch.return_types"
+ or hasattr(example_value, "_fields")
+ ), "namedtuple?"
+ return variables.NamedTupleVariable(
+ unpacked, example_value.__class__, **options
+ )
+ elif example_value is None or proxy.node.target is torch.manual_seed:
+ return variables.ConstantVariable(None, **options)
+ elif (
+ isinstance(example_value, int)
+ and proxy.node.target is torch._utils._element_size
+ ):
+ proxy.node.meta["example_value"] = example_value
+ return variables.ConstantVariable(example_value, **options)
+ elif (
+ isinstance(example_value, numbers.Number)
+ and (
+ proxy.node.target == "item"
+ or proxy.node.target in {math.sqrt, math.pow}
+ )
+ and config.capture_scalar_outputs
+ ):
+ if use_fake_tensors:
+ # item raw value should not be accessed
+ return FakeItemVariable.create(
+ tx=tx,
+ proxy=proxy,
+ example_value=torch.tensor(example_value),
+ **options,
+ )
+ else:
+ return UnspecializedPythonVariable.create(
+ tx=tx,
+ proxy=proxy,
+ example_value=torch.tensor(example_value),
+ raw_value=None if use_fake_tensors else example_value,
+ need_unwrap=False,
+ **options,
+ )
+ elif proxy.node.target == torch._C._DisableFuncTorch:
+ from . import UserDefinedObjectVariable
+
+ return UserDefinedObjectVariable(example_value)
+ elif proxy.node.target.__name__ == "set_state" and isinstance(
+ proxy.node.target.__self__, torch._C.Generator
+ ):
+ from . import TorchVariable
+
+ return TorchVariable(proxy.node.target)
+ else:
+ raise AssertionError(
+ "torch.* op returned non-Tensor "
+ + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}"
+ )
+
+ def __init__(
+ self,
+ proxy: torch.fx.Proxy,
+ dtype=None,
+ device=None,
+ ndim=None,
+ size=None,
+ stride=None,
+ requires_grad=None,
+ is_quantized=None,
+ is_contiguous=None,
+ is_sparse=None,
+ class_type=torch.Tensor,
+ parameter_value=None,
+ **kwargs,
+ ):
+ super(TensorVariable, self).__init__(**kwargs)
+ self.proxy = proxy
+ self.dtype = dtype
+ self.device = device
+ self.ndim = ndim
+ self.size = size
+ self.stride = stride
+ self.requires_grad = requires_grad
+ self.is_quantized = is_quantized
+ self.is_contiguous = is_contiguous
+ self.is_sparse = is_sparse
+ self.class_type = class_type
+ self.parameter_value = parameter_value
+
+ def as_proxy(self):
+ return self.proxy
+
+ def python_type(self):
+ return self.class_type
+
+ def call_isinstance(self, tensor_type):
+ def check_type(ty):
+ if ty not in tensortype_to_dtype:
+ return issubclass(self.python_type(), ty)
+
+ dtypes = tensortype_to_dtype[ty]
+ return self.dtype in dtypes
+
+ if type(tensor_type) is tuple:
+ return any([check_type(ty) for ty in tensor_type])
+ else:
+ return check_type(tensor_type)
+
+ @staticmethod
+ def specialize(value: torch.Tensor):
+ props = {
+ "dtype": value.dtype,
+ "device": value.device,
+ "ndim": int(value.ndim),
+ "requires_grad": value.requires_grad,
+ "is_quantized": value.is_quantized,
+ "is_sparse": value.is_sparse,
+ "class_type": type(value),
+ }
+ if not config.dynamic_shapes:
+ props["size"] = tuple(value.size())
+ props["stride"] = tuple(value.stride())
+ props["is_contiguous"] = value.is_contiguous()
+ return props
+
+ def var_getattr(self, tx, name):
+ from . import ConstantVariable, TorchVariable
+
+ result = None
+ options = VariableTracker.propagate(self)
+ if name == "ndim" and self.ndim is not None:
+ result = ConstantVariable(self.ndim, **options)
+ elif name == "dtype" and self.dtype is not None:
+ result = TorchVariable(self.dtype, **options)
+ elif name == "device" and self.device is not None:
+ result = TorchVariable(self.device, **options)
+ elif name == "is_cuda" and self.device is not None:
+ result = ConstantVariable(self.device.type == "cuda", **options)
+ elif name == "shape" and self.size is not None:
+ sizes = [variables.ConstantVariable(x) for x in self.size]
+ result = ShapeVariable(sizes, **options)
+ elif name == "requires_grad" and self.requires_grad is not None:
+ result = ConstantVariable(self.requires_grad, **options)
+ elif name == "is_quantized" and self.is_quantized is not None:
+ result = ConstantVariable(self.is_quantized, **options)
+ elif name == "is_sparse" and self.is_sparse is not None:
+ result = ConstantVariable(self.is_sparse, **options)
+ elif name == "shape" and self.size is None:
+ result = self.call_method(tx, "size", [], {})
+ elif name == "ndim" and self.ndim is None:
+ result = self.call_method(tx, "dim", [], {})
+
+ if name == "__class__":
+ return TorchVariable(self.python_type(), **options)
+
+ # Add a guard for type matching, these guards are checked before tensor guards
+ # In some cases, a . guard can be evaluated first, and break if
+ # is later changed to another type
+ if result is not None and self.source is not None:
+ result = result.add_guard(self.make_guard(GuardBuilder.TYPE_MATCH))
+
+ if result is None:
+ raise NotImplementedError()
+
+ return result
+
+ def unpack_var_sequence(self, tx):
+ options = VariableTracker.propagate(self)
+ if self.size:
+ return [
+ variables.BuiltinVariable(operator.getitem, **options).call_function(
+ tx, [self, variables.ConstantVariable(i)], {}
+ )
+ for i in range(self.size[0])
+ ]
+
+ return super(TensorVariable, self).unpack_var_sequence(tx)
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ from . import ConstantVariable, TupleVariable
+
+ kwargs = dict(kwargs)
+
+ options = VariableTracker.propagate(self, args, kwargs.values())
+
+ if name == "stride" and self.stride is not None:
+ constant_result = ConstantVariable(self.stride, **options)
+ elif name == "size" and self.size is not None:
+ sizes = [variables.ConstantVariable(x) for x in self.size]
+ constant_result = SizeVariable(sizes, **options)
+ elif name == "numel" and self.size is not None:
+ constant_result = ConstantVariable(product(self.size), **options)
+ elif name in ("ndimension", "dim") and self.ndim is not None:
+ constant_result = ConstantVariable(self.ndim, **options)
+ elif name == "is_floating_point" and self.dtype is not None:
+ constant_result = ConstantVariable(self.dtype.is_floating_point, **options)
+ elif name == "is_contiguous" and self.is_contiguous is not None:
+ if (
+ "memory_format" in kwargs
+ and kwargs["memory_format"].as_python_constant()
+ == torch.contiguous_format
+ ):
+ kwargs.pop("memory_format")
+ constant_result = ConstantVariable(self.is_contiguous, **options)
+ else:
+ constant_result = None
+
+ if constant_result:
+ assert not kwargs, f"Tensor.{name}() unhandled kwargs"
+ if len(args) == 1:
+ return constant_result.getitem_const(args[0])
+ elif args:
+ return TupleVariable(
+ [constant_result.getitem_const(a) for a in args], **options
+ )
+ return constant_result
+ elif (
+ name == "repeat"
+ and not all(
+ x.is_python_constant() for x in itertools.chain(args, kwargs.values())
+ )
+ and not config.dynamic_shapes
+ ):
+ unimplemented("dynamic Tensor.repeat")
+ elif name in ("tolist", "numpy", "backward"):
+ unimplemented(f"Tensor.{name}")
+ elif name == "nonzero" and not config.dynamic_shapes:
+ unimplemented(f"Tensor.{name}")
+ elif name == "item":
+ if config.capture_scalar_outputs:
+ return self.__class__.create(
+ tx,
+ tx.output.create_proxy(
+ "call_method", "item", (self.as_proxy(),), {}, current_tx=tx
+ ),
+ **options,
+ )
+ else:
+ unimplemented(f"Tensor.{name}")
+ elif name == "__len__":
+ if self.size:
+ assert not config.dynamic_shapes
+ return ConstantVariable(self.size[0], **options)
+ else:
+ return self.__class__.create(
+ tx,
+ tx.output.create_proxy(
+ "call_function", len, (self.as_proxy(),), {}, current_tx=tx
+ ),
+ **options,
+ )
+ elif name == "__setitem__":
+ tx.output.guards.update(options["guards"])
+ tx.output.create_proxy(
+ "call_function",
+ operator.setitem,
+ *proxy_args_kwargs([self] + args, kwargs),
+ current_tx=tx,
+ )
+ return ConstantVariable(None, **options)
+ else:
+ # Convert x.new(torch.Size) into x.new_empty(torch.Size),
+ # as Tensor.new acts differently with a Size input versus a tuple input.
+ if (
+ name == "new"
+ and len(args) == 1
+ and isinstance(args[0], (SizeVariable, ShapeVariable))
+ and not config.dynamic_shapes
+ ):
+ name = "new_empty"
+
+ return self.__class__.create(
+ tx,
+ tx.output.create_proxy(
+ "call_method",
+ name,
+ *proxy_args_kwargs([self] + args, kwargs),
+ current_tx=tx,
+ ),
+ **options,
+ )
+
+
+class DynamicShapeVariable(TensorVariable):
+ """
+ Represents a symbolic size, e.g., as returned by tensor.size(0)
+ """
+
+ def __init__(self, proxy, dyn_shape_cls, **kwargs):
+ super(DynamicShapeVariable, self).__init__(proxy, **kwargs)
+ self.dyn_shape_cls = dyn_shape_cls
+
+ def python_type(self):
+ return self.dyn_shape_cls
+
+ def unpack_var_sequence(self, tx):
+ super(DynamicShapeVariable, self).unpack_var_sequence(tx)
+
+
+class TensorWithTFOverrideVariable(VariableTracker):
+ """
+ Represents a tensor subclass instance with a __torch_function__ override.
+ """
+
+ def __init__(
+ self,
+ tensor_variable,
+ orig_tensor_variable_source,
+ subclass_torch_function__func,
+ subclass_type,
+ **kwargs,
+ ):
+ super(TensorWithTFOverrideVariable, self).__init__(**kwargs)
+ self.tensor_variable = tensor_variable
+ self.orig_tensor_variable_source = orig_tensor_variable_source
+ self.subclass_torch_function__func = subclass_torch_function__func
+ self.subclass_type = subclass_type
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ # This code block implements inlining the __torch_function__ override
+ # of `call_method`.
+ from . import GetAttrVariable
+
+ options = VariableTracker.propagate(self, args, kwargs.values())
+ # insert unwrapped version of self as the first argument
+ args = list(args)
+ args.insert(0, self.tensor_variable)
+ func_var = GetAttrVariable(self.tensor_variable, name)
+
+ unwrapped = TensorWithTFOverrideVariable.inline_torch_function_unwrapped(
+ tx,
+ func_var,
+ self.orig_tensor_variable_source,
+ self.subclass_torch_function__func,
+ self.subclass_type,
+ options,
+ args,
+ kwargs,
+ )
+
+ # TODO(future PR): implement rewrapping conditional on method presence
+ # in `torch.overrides.get_default_nowrap_function()`. It's unclear how
+ # to do this easily in the current codebase since the resolution of
+ # `GetAttrVariable` depends on the type of the underlying object.
+
+ return TensorWithTFOverrideVariable(
+ unwrapped,
+ self.orig_tensor_variable_source,
+ self.subclass_torch_function__func,
+ self.subclass_type,
+ )
+
+ @staticmethod
+ def inline_torch_function_unwrapped(
+ tx,
+ original_func_var,
+ tensor_with_tf_override_source,
+ tf_func,
+ subclass_type,
+ options,
+ args,
+ kwargs,
+ ):
+ """
+ This function inlines the `__torch_function__` override for `original_func_var`.
+ For example, if the user code is
+
+ x1 = torch.sigmoid(x0)
+
+ And `x0` has an override, then:
+ * `original_func_var` will be a `VariableTracker` object wrapping `torch.sigmoid`
+ * `tensor_with_tf_override_source` will be the `Source` object from
+ the original tensor override instance in the beginning of the program
+ * `tf_func` will be the custom `__torch_function__` function
+ * `subclass_type` will be `type(x0)`
+
+ The caller is expected to properly massage args and kwargs before
+ passing them into this function.
+
+ The caller is responsible for wrapping the return value, if needed.
+ """
+ from . import UserDefinedClassVariable
+ from .builder import TupleVariable, VariableBuilder
+
+ source = AttrSource(
+ AttrSource(tensor_with_tf_override_source, "__torch_function__"),
+ "__func__",
+ )
+ tf_func_var = VariableBuilder(tx, source)(tf_func)
+ type_var = UserDefinedClassVariable(subclass_type, **options)
+
+ # signature:
+ # def __torch_function__(cls, func, types, args=(), kwargs=None):
+ tf_args = (
+ type_var, # cls
+ original_func_var, # func
+ (type_var,), # types
+ TupleVariable(args), # args
+ kwargs, # kwargs
+ )
+
+ # Disable __torch_function__ here to prevent the clone of the
+ # example tensor from going into the override.
+ with torch._C.DisableTorchFunction():
+ return tx.inline_user_function_return(tf_func_var, tf_args, {})
+
+
+class UnspecializedNumpyVariable(TensorVariable):
+ """
+ This is a 1-element tensor represents unspecialized numpy float/int.
+ """
+
+ def __init__(self, proxy: torch.fx.Proxy, **kwargs):
+ raw_value = kwargs.pop("raw_value", None)
+ super(UnspecializedNumpyVariable, self).__init__(proxy, **kwargs)
+ self.raw_value = raw_value
+
+ @classmethod
+ def from_tensor_variable(cls, tensor_variable, raw_value):
+ # Convert a `TensorVariable` instance into an `UnspecializedNumpyVariable` instance.
+ return UnspecializedNumpyVariable(
+ **dict(tensor_variable.__dict__), raw_value=raw_value
+ )
+
+ def as_specialized(self, tx):
+ for graph_arg in tx.output.graphargs:
+ if graph_arg.source is self.source:
+ graph_arg.erase()
+
+ for g in self.guards:
+ if g.is_volatile:
+ g.create_fn = GuardBuilder.CONSTANT_MATCH
+
+ return ConstantVariable(value=self.raw_value, guards=self.guards)
+
+
+class UnspecializedPythonVariable(TensorVariable):
+ """
+ This is a 1-element tensor represents unspecialized python float/int.
+ """
+
+ def __init__(self, proxy: torch.fx.Proxy, **kwargs):
+ raw_value = kwargs.pop("raw_value", None)
+ need_unwrap = kwargs.pop("need_unwrap", True)
+ super(UnspecializedPythonVariable, self).__init__(proxy, **kwargs)
+ self.raw_value = raw_value
+ self.need_unwrap = need_unwrap
+
+ @classmethod
+ def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True):
+ # Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance.
+ return UnspecializedPythonVariable(
+ **dict(tensor_variable.__dict__),
+ raw_value=raw_value,
+ need_unwrap=need_unwrap,
+ )
+
+ def as_specialized(self, tx):
+ for graph_arg in tx.output.graphargs:
+ if graph_arg.source is self.source:
+ graph_arg.erase()
+
+ for g in self.guards:
+ if g.is_volatile:
+ g.create_fn = GuardBuilder.CONSTANT_MATCH
+
+ return ConstantVariable(value=self.raw_value, guards=self.guards)
+
+
+class FakeItemVariable(TensorVariable):
+ """An unspecialized python variable which prevents access to the underlying raw value.
+ This is needed if item is called on a FakeTensor."""
+
+ def __init__(self, proxy: torch.fx.Proxy, **kwargs):
+ need_unwrap = kwargs.pop("need_unwrap", False)
+ super(FakeItemVariable, self).__init__(proxy, **kwargs)
+ self.need_unwrap = need_unwrap
+
+ @classmethod
+ def from_tensor_variable(cls, tensor_variable):
+ return FakeItemVariable(**dict(tensor_variable.__dict__))
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
new file mode 100644
index 0000000000000..56ec442ed10ce
--- /dev/null
+++ b/torch/_dynamo/variables/torch.py
@@ -0,0 +1,651 @@
+import logging
+import re
+import types
+from typing import Dict, List
+
+import numpy
+
+import torch._C
+import torch.nn
+import torch.onnx.operators
+
+from .. import config, variables
+from ..allowed_functions import torch_get_name
+from ..exc import unimplemented
+from ..source import GetItemSource, NNModuleSource
+from ..utils import (
+ check_constant_args,
+ check_unspec_python_args,
+ istype,
+ product,
+ proxy_args_kwargs,
+ specialize_args_kwargs,
+ tensortype_to_dtype,
+)
+from .base import VariableTracker
+from .lists import ListVariable, TupleVariable
+from .misc import AutocastModeVariable, ProfilerContextWrapperVariable
+from .tensor import TensorWithTFOverrideVariable
+
+log = logging.getLogger(__name__)
+
+# TODO(voz): Maybe rename these later
+tensor_dunder_fns = [
+ torch.Tensor.__rmatmul__,
+ torch.Tensor.__rmod__,
+ torch.Tensor.__rpow__,
+ torch.Tensor.__rsub__,
+ torch._C._TensorBase.__radd__,
+ torch._C._TensorBase.__rmul__,
+ torch._C._TensorBase.__ror__,
+ torch._C._TensorBase.__rxor__,
+ torch._C._TensorBase.__rand__,
+]
+
+torch_special_class_types = (torch._C.Generator,)
+
+REWRITE_OPS_TO_TENSOR_SIZE_METHOD = [
+ torch.onnx.operators.shape_as_tensor,
+ torch._shape_as_tensor,
+]
+
+
+# TODO(voz): perhaps a decorator? This is rather readable for now tho, and not a public API.
+def remap_as_fn___radd__(*args):
+ return torch._C._TensorBase.__radd__(*args)
+
+
+def remap_as_fn___rmul__(*args):
+ return torch._C._TensorBase.__rmul__(*args)
+
+
+def remap_as_fn___ror__(*args):
+ return torch._C._TensorBase.__ror__(*args)
+
+
+def remap_as_fn___rxor__(*args):
+ return torch._C._TensorBase.__rxor__(*args)
+
+
+def remap_as_fn___rand__(*args):
+ return torch._C._TensorBase.__rand__(*args)
+
+
+tensor_dunder_fns_remap = {
+ torch._C._TensorBase.__radd__: remap_as_fn___radd__,
+ torch._C._TensorBase.__rmul__: remap_as_fn___rmul__,
+ torch._C._TensorBase.__ror__: remap_as_fn___ror__,
+ torch._C._TensorBase.__rxor__: remap_as_fn___rxor__,
+ torch._C._TensorBase.__rand__: remap_as_fn___rand__,
+}
+
+
+try:
+ # Wed need to monkeypatch transformers here, sadly.
+ # TODO(voz): Upstream to transformers lib
+ import transformers
+
+ def _dynamo_overriden_transformers_eq(self, other):
+ if not hasattr(other, "__dict__"):
+ return False
+ return self.__dict__ == other.__dict__
+
+ transformers.configuration_utils.PretrainedConfig.__eq__ = (
+ _dynamo_overriden_transformers_eq
+ )
+except ImportError:
+ pass
+
+
+class TorchVariable(VariableTracker):
+ """Points to a module or method in torch.*"""
+
+ def __init__(self, value, **kwargs):
+ super(TorchVariable, self).__init__(**kwargs)
+
+ if value in tensor_dunder_fns_remap:
+ value = tensor_dunder_fns_remap[value]
+ self.value = value
+
+ # the remainder of this is just optional debug checks
+ try:
+ self_should_be_none = getattr(self.value, "__self__", None)
+ except RuntimeError as e:
+ assert "No such operator" in str(e), str(e)
+ self_should_be_none = None
+
+ # assert "_ntuple..parse" not in str(value)
+
+ if self_should_be_none is None:
+ pass
+ elif isinstance(self_should_be_none, types.ModuleType):
+ # weird ones like torch.nn.functional.avg_pool2d have __self__
+ name = self_should_be_none.__name__
+ assert re.match(r"^(torch|math)([.]|$)", name), f"__self__ set to {name}"
+ elif isinstance(
+ self_should_be_none, type(torch._C._get_tracing_state.__self__)
+ ):
+ # some _C functions have __self__ as a null capsule
+ pass
+ elif isinstance(self_should_be_none, torch_special_class_types):
+ pass
+ else:
+ raise AssertionError(f"{value} found with __self__ set")
+
+ def __repr__(self):
+ return f"TorchVariable({self.value})"
+
+ def unique_var_name(self):
+ name = torch_get_name(self.value, f"allowed_fn_{id(self.value)}")
+ return "__" + re.sub(r"[^a-zA-Z0-9_]+", "_", name)
+
+ def reconstruct(self, codegen):
+ return codegen.setup_globally_cached(self.unique_var_name(), self.value)
+
+ def as_proxy(self):
+ return self.value
+
+ def python_type(self):
+ if isinstance(self.value, (torch.Tensor, torch.nn.Module)):
+ return type(self.value)
+ return super().python_type()
+
+ def as_python_constant(self):
+ return self.value
+
+ def can_constant_fold_through(self):
+ if self.value in (
+ torch._assert,
+ torch.device,
+ torch.finfo,
+ torch.iinfo,
+ torch.is_floating_point,
+ torch.is_tensor,
+ torch.overrides.is_tensor_like,
+ ):
+ return True
+ return getattr(self.value, "__module__", None) == "math"
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ from . import ConstantVariable, GradModeVariable, TensorVariable
+
+ constant_args = check_constant_args(args, kwargs)
+ unspec_python_args = check_unspec_python_args(args, kwargs)
+ options = VariableTracker.propagate(self, args, kwargs.values())
+
+ if self.value in config.constant_functions:
+ assert not args and not kwargs
+ return ConstantVariable(config.constant_functions[self.value], **options)
+ elif self.can_constant_fold_through() and (constant_args or unspec_python_args):
+ args, kwargs = specialize_args_kwargs(tx, args, kwargs)
+ # constant fold
+ return ConstantVariable(
+ self.as_python_constant()(
+ *[x.as_python_constant() for x in args],
+ **{k: v.as_python_constant() for k, v in kwargs.items()},
+ ),
+ **options,
+ )
+ elif istype(self.value, type) and issubclass(self.value, torch.nn.Module):
+ if self.value is torch.nn.Softmax:
+ return self._call_softmax(tx, args, kwargs, options)
+ if self.value is torch.nn.CrossEntropyLoss:
+ return self._call_cross_entropy_loss(tx, args, kwargs, options)
+ else:
+ unimplemented(f"construct nn.Module: {self.value.__name__}")
+ elif (
+ self.value
+ in (
+ torch.is_tensor,
+ torch.is_floating_point,
+ torch.is_complex,
+ torch.overrides.is_tensor_like,
+ torch.is_complex,
+ )
+ and isinstance(args[0], TensorVariable)
+ and args[0].dtype is not None
+ ):
+ if self.value in (torch.is_tensor, torch.overrides.is_tensor_like):
+ return ConstantVariable(True, **options)
+ elif self.value is torch.is_floating_point:
+ return ConstantVariable(args[0].dtype.is_floating_point, **options)
+ elif self.value is torch.is_complex:
+ return ConstantVariable(args[0].dtype.is_complex, **options)
+ else:
+ raise AssertionError()
+ elif (
+ self.value is torch.numel
+ and isinstance(args[0], TensorVariable)
+ and args[0].size is not None
+ ):
+ return ConstantVariable(product(args[0].size), **options)
+ elif self.value in REWRITE_OPS_TO_TENSOR_SIZE_METHOD:
+ assert len(args) == 1
+ assert isinstance(args[0], TensorVariable)
+ return args[0].call_method(tx, "size", [], {})
+ elif self.value in (
+ torch.nn.modules.utils._single,
+ torch.nn.modules.utils._pair,
+ torch.nn.modules.utils._triple,
+ torch.nn.modules.utils._quadruple,
+ torch.nn.modules.utils._ntuple,
+ ):
+ return self._call_ntuple(tx, args, kwargs, options)
+ elif self.value is torch.no_grad:
+ return GradModeVariable.create(tx, False, **options)
+ elif self.value is torch.enable_grad:
+ return GradModeVariable.create(tx, True, **options)
+ elif self.value is torch.set_grad_enabled and len(args) == 1:
+ return GradModeVariable.create(tx, args[0].as_python_constant(), **options)
+ elif self.value is torch.is_grad_enabled:
+ assert not (args or kwargs)
+ return ConstantVariable(torch.is_grad_enabled(), **options).add_guards(
+ GradModeVariable._guards_singleton
+ )
+ elif not config.dynamic_shapes and self.is_dynamic_shapes(args, kwargs):
+ unimplemented(f"dynamic shapes: {self.value.__name__}")
+ elif len(args) > 0 and isinstance(args[0], TensorWithTFOverrideVariable):
+ # This code block implements inlining the __torch_function__
+ # override of a tensor.
+
+ tensor_with_tf_override = args[0]
+
+ # TODO(future PR): make this implement the full __torch_function__ API
+ # instead of assuming the relevant override is in the first argument.
+ args[0] = args[0].tensor_variable
+
+ unwrapped = TensorWithTFOverrideVariable.inline_torch_function_unwrapped(
+ tx,
+ self,
+ tensor_with_tf_override.orig_tensor_variable_source,
+ tensor_with_tf_override.subclass_torch_function__func,
+ tensor_with_tf_override.subclass_type,
+ options,
+ args,
+ kwargs,
+ )
+
+ # The wrapping here follows the logic in
+ # `torch.Tensor.__torch_function__`.
+ if self.value in torch.overrides.get_default_nowrap_functions():
+ return unwrapped
+ return TensorWithTFOverrideVariable(
+ unwrapped,
+ tensor_with_tf_override.orig_tensor_variable_source,
+ tensor_with_tf_override.subclass_torch_function__func,
+ tensor_with_tf_override.subclass_type,
+ )
+ elif self.value is torch.amp.autocast_mode.autocast:
+ return AutocastModeVariable.create(tx, target_values=args, kwargs=kwargs)
+ elif self.value in (
+ torch.profiler.profile,
+ torch.profiler.record_function,
+ torch.autograd.profiler.profile,
+ torch.autograd.profiler.record_function,
+ ):
+ log.warning("Profiler will be ignored")
+ return ProfilerContextWrapperVariable(**options)
+ elif self.value is torch.jit.annotate:
+ assert len(args) == 2
+ return args[1]
+ if (
+ self.value.__name__ == "get_state"
+ and hasattr(self.value, "__self__")
+ and isinstance(self.value.__self__, torch._C.Generator)
+ ):
+
+ def get_state_from_generator():
+ return self.value()
+
+ return TensorVariable.create(
+ tx=tx,
+ proxy=tx.output.create_proxy(
+ "call_function",
+ get_state_from_generator,
+ *proxy_args_kwargs(args, kwargs),
+ current_tx=tx,
+ ),
+ example_value=self.value(),
+ **options,
+ )
+ if (
+ self.value.__name__ == "set_state"
+ and hasattr(self.value, "__self__")
+ and isinstance(self.value.__self__, torch._C.Generator)
+ ):
+ assert len(args) == 1
+ assert isinstance(args[0], TensorVariable)
+
+ if config.fake_tensor_propagation:
+ # In fake tensor case, this state doesn't matter, but
+ # it needs to be valid to not segfault. Pull a real tensor out.
+ # The value won't matter since we are running with fake tensors anyway, so rng doesn't matter.
+ # However, it is imperative to record the call_function in the graph with the true args
+ # (Not the fake example_value) - for the sake of graph correctness.
+ example_value = self.value.__self__.get_state()
+ else:
+ example_value = args[0].proxy.node.meta["example_value"]
+
+ self.value.__module__ = self.__module__
+ return TensorVariable.create(
+ tx=tx,
+ proxy=tx.output.create_proxy(
+ "call_function",
+ self.value,
+ *proxy_args_kwargs(args, kwargs),
+ current_tx=tx,
+ ),
+ example_value=example_value,
+ **options,
+ )
+ else:
+ # Handle sth like torch.LongTensor(list(np.int64, np.int64, ...)),
+ # as FX symbolic trace doesn't support numpy int/float as base types.
+ if (
+ self.value in tensortype_to_dtype
+ and len(args) == 1
+ and isinstance(args[0], ListVariable)
+ and args[0].is_python_constant()
+ ):
+ for x in args[0].items:
+ if isinstance(x.value, numpy.generic):
+ x.value = x.value.item()
+
+ tensor_variable = TensorVariable.create(
+ tx=tx,
+ proxy=tx.output.create_proxy(
+ "call_function",
+ self.value,
+ *proxy_args_kwargs(args, kwargs),
+ current_tx=tx,
+ ),
+ **options,
+ )
+
+ if "out" in kwargs:
+ # out variants of torch operators like torch.sort and
+ # torch.sigmoid mutate the tensors in the out field. Track such
+ # tensors and rewrite the symbolic locals.
+ if isinstance(tensor_variable, TupleVariable):
+ assert isinstance(kwargs["out"], TupleVariable)
+ output_tensor_names = [
+ tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
+ ]
+ for idx, name in enumerate(output_tensor_names):
+ assert name in tx.symbolic_locals
+ tx.symbolic_locals[name] = tensor_variable.items[idx]
+ elif isinstance(tensor_variable, TensorVariable):
+ assert isinstance(kwargs["out"], TensorVariable)
+ name = tx.find_symbolic_locals_name(kwargs["out"])
+ assert name in tx.symbolic_locals
+ tx.symbolic_locals[name] = tensor_variable
+ else:
+ unimplemented(f"out variant of {type(kwargs['out'])}")
+
+ return tensor_variable
+
+ def is_dynamic_shapes(self, args, kwargs):
+ """Check for dynamic shapes when shape specialization is enabled"""
+ # TODO(jansel): need to get a complete list
+ if self.value in (
+ torch.nonzero,
+ torch.unique,
+ torch.unique_consecutive,
+ ) or self.value.__name__ in ("nms",):
+ return True
+
+ if self.value is torch.where and len(args) + len(kwargs) == 1:
+ return True
+
+ if self.value in (
+ torch.arange,
+ torch.repeat_interleave,
+ ):
+ none = variables.ConstantVariable(None)
+
+ def has_non_const(it):
+ return not all(x.is_python_constant() for x in it)
+
+ def arange(start=none, end=none, step=none, **kwargs):
+ return has_non_const([start, end, step])
+
+ def repeat_interleave(input, repeats, dim=none, **kwargs):
+ return has_non_const([repeats])
+
+ return locals()[self.value.__name__](*args, **kwargs)
+
+ return False
+
+ def _call_softmax(self, tx, args, kwargs, options):
+ """rewrite the pattern nn.Softmax(dim=-1)(x) to F.softmax(x, -1)"""
+ dim = args[0] if args else kwargs.get("dim", variables.ConstantVariable(None))
+
+ def fake_softmax(input):
+ return variables.TensorVariable.create(
+ tx=tx,
+ proxy=tx.output.create_proxy(
+ "call_function",
+ torch.nn.functional.softmax,
+ *proxy_args_kwargs([input, dim], {}),
+ current_tx=tx,
+ ),
+ **VariableTracker.propagate([self, dim, input]),
+ )
+
+ return variables.LambdaVariable(fake_softmax, **options)
+
+ def _call_cross_entropy_loss(self, tx, args, kwargs, options):
+ """
+ functional: input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean',
+ label_smoothing=0.0
+
+ non functional ctor: weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean',
+ label_smoothing=0.0
+
+ non functional loss call: input, target, optional_output
+ """
+ from . import ConstantVariable
+
+ def normalize_args(
+ weight=ConstantVariable(None),
+ size_average=ConstantVariable(None),
+ ignore_index=ConstantVariable(-100),
+ reduce=ConstantVariable(None),
+ reduction=ConstantVariable("mean"),
+ label_smoothing=ConstantVariable(0.0),
+ ):
+ return (
+ weight,
+ size_average,
+ ignore_index,
+ reduce,
+ reduction,
+ label_smoothing,
+ )
+
+ (
+ weight,
+ size_average,
+ ignore_index,
+ reduce_arg,
+ reduction,
+ label_smoothing,
+ ) = normalize_args(*args, **kwargs)
+
+ def fake_cross_entropy_loss(input, target):
+ return variables.TensorVariable.create(
+ tx=tx,
+ proxy=tx.output.create_proxy(
+ "call_function",
+ torch.nn.functional.cross_entropy,
+ *proxy_args_kwargs(
+ [
+ input,
+ target,
+ weight,
+ size_average,
+ ignore_index,
+ reduce_arg,
+ reduction,
+ label_smoothing,
+ ],
+ {},
+ ),
+ current_tx=tx,
+ ),
+ **VariableTracker.propagate(
+ [
+ self,
+ weight,
+ size_average,
+ ignore_index,
+ reduce_arg,
+ reduction,
+ label_smoothing,
+ input,
+ target,
+ ]
+ ),
+ )
+
+ return variables.LambdaVariable(fake_cross_entropy_loss, **options)
+
+ def _call_ntuple(self, tx, args, kwargs, options):
+ """inline behavior of torch.nn.modules.utils._ntuple"""
+ if self.value is torch.nn.modules.utils._ntuple:
+ count = args[0].as_python_constant()
+ else:
+ count = self.value.__closure__[0].cell_contents
+ assert isinstance(count, int)
+
+ def handle_ntuple(value):
+ if value.has_unpack_var_sequence(tx):
+ return variables.TupleVariable(
+ list(value.unpack_var_sequence(tx)),
+ **VariableTracker.propagate(self, value, args, kwargs.values()),
+ )
+ elif value.is_python_constant():
+ # constant prop through it
+ return variables.ConstantVariable(
+ torch.nn.modules.utils._ntuple(count)(value.as_python_constant()),
+ **VariableTracker.propagate(self, value, args, kwargs.values()),
+ )
+ else:
+ unimplemented(f"torch.nn.modules.utils._ntuple({value})")
+
+ if self.value is torch.nn.modules.utils._ntuple:
+ return variables.LambdaVariable(handle_ntuple, **options)
+ else:
+ return handle_ntuple(args[0])
+
+
+class TorchPyOperator(VariableTracker):
+ def __init__(self, value, **kwargs):
+ super(TorchPyOperator, self).__init__(**kwargs)
+ self.value = value
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ from . import ListVariable, TensorVariable, UserFunctionVariable
+
+ assert kwargs is None or len(kwargs) == 0, "kwargs are not supported, yet"
+
+ def unwrap_real(arg):
+ if isinstance(arg, TensorVariable):
+ return arg.as_proxy().node.meta["example_value"]
+ if isinstance(arg, UserFunctionVariable):
+ return arg.fn
+ if arg.has_unpack_var_sequence(tx):
+ return [
+ unwrap_real(arg_inner) for arg_inner in arg.unpack_var_sequence(tx)
+ ]
+ return arg
+
+ # Get values
+ u_args = [unwrap_real(arg) for arg in args]
+
+ def unwrap_proxy(arg):
+ try:
+ return arg.as_proxy()
+ except NotImplementedError:
+ return arg
+
+ def register_as_subgraph(fn, name, args):
+ from .. import export
+
+ gm, guards = export(fn, *args)
+
+ next_name = None
+ i = 0
+ while not next_name:
+ candidate = f"name_{i}"
+ if candidate in tx.output.nn_modules:
+ i += 1
+ else:
+ next_name = candidate
+
+ gm.__name__ = next_name
+ src = NNModuleSource(GetItemSource(self.source, next_name))
+ gm.torchdynamo_force_dynamic = False
+ tx.output.register_attr_or_module(gm, next_name, source=src)
+ return next_name, gm, guards
+
+ # Get args as proxies
+ p_args = [unwrap_proxy(arg) for arg in args]
+ if self.value.__name__ == "cond":
+ # TODO(voz): Support fake tensor dispatch for recursive
+ # ops - see torch/dispatch/_dispatcher.py
+ from .. import config
+
+ if config.fake_tensor_propagation:
+ unimplemented("Fake tensor mode not yet supported for cond")
+
+ assert len(p_args) == 4
+ assert type(args[0]) is TensorVariable # predicate
+ assert type(p_args[1]) is UserFunctionVariable # true_fn
+ assert type(p_args[2]) is UserFunctionVariable # false_fn
+ assert type(args[3]) is ListVariable # args
+
+ node_args = [unwrap_real(x) for x in args[3].unpack_var_sequence(tx)]
+ proxy_args = [unwrap_proxy(x) for x in args[3].unpack_var_sequence(tx)]
+ true_name, true_graph, true_guards = register_as_subgraph(
+ p_args[1].get_function(), "true", node_args
+ )
+ false_name, false_graph, false_guards = register_as_subgraph(
+ p_args[2].get_function(), "false", node_args
+ )
+
+ if config.enforce_cond_guards_match:
+ assert (
+ true_guards == false_guards
+ ), "Guards for true and false path must be equal."
+
+ def make_attr(name):
+ node = tx.output.create_proxy(
+ "get_attr",
+ name,
+ tuple(proxy_args),
+ {},
+ )
+ return node
+
+ true_node = make_attr(true_name)
+ false_node = make_attr(false_name)
+ p_args[1] = true_node
+ p_args[2] = false_node
+
+ # Store the invocation as a call
+ return variables.TensorVariable.create(
+ tx=tx,
+ proxy=tx.output.create_proxy(
+ "call_function",
+ self.value,
+ args=tuple(p_args),
+ kwargs={},
+ current_tx=tx,
+ ),
+ example_value=self.value(*u_args),
+ )
diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py
new file mode 100644
index 0000000000000..2d33c8328268a
--- /dev/null
+++ b/torch/_dynamo/variables/user_defined.py
@@ -0,0 +1,382 @@
+import collections
+import dataclasses
+import functools
+import importlib
+import inspect
+import random
+import types
+from typing import Dict, List
+
+import torch.nn
+
+from .. import variables
+from ..exc import unimplemented
+from ..guards import Guard, GuardBuilder
+from ..source import AttrSource, ODictGetItemSource, RandomValueSource
+from ..utils import is_namedtuple_cls, namedtuple_fields
+from .base import MutableLocal, VariableTracker
+from .misc import ProfilerContextWrapperVariable
+
+
+class UserDefinedVariable(VariableTracker):
+ pass
+
+
+class UserDefinedClassVariable(UserDefinedVariable):
+ def __init__(self, value, **kwargs):
+ super().__init__(**kwargs)
+ self.value = value
+
+ def as_python_constant(self):
+ return self.value
+
+ def var_getattr(self, tx, name: str) -> "VariableTracker":
+ options = VariableTracker.propagate(self)
+ try:
+ obj = inspect.getattr_static(self.value, name)
+ except AttributeError:
+ obj = None
+
+ if isinstance(obj, staticmethod):
+ return variables.UserFunctionVariable(obj.__get__(self.value), **options)
+ elif isinstance(obj, classmethod):
+ return variables.UserMethodVariable(obj.__func__, self, **options)
+
+ return super(UserDefinedClassVariable, self).var_getattr(tx, name)
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ if (
+ name == "__subclasses__"
+ and len(args) == 0
+ and not kwargs
+ and "__subclasses__" not in self.value.__dict__
+ ):
+ options = VariableTracker.propagate(self, args, kwargs.values())
+ options["mutable_local"] = MutableLocal()
+ subs_as_vars: List[VariableTracker] = list()
+ for sub in self.value.__subclasses__():
+ source = AttrSource(tx.import_source(sub.__module__), sub.__name__)
+ subs_as_vars.append(
+ variables.UserDefinedClassVariable(sub, source=source)
+ )
+
+ return variables.ListVariable(subs_as_vars, **options)
+
+ return super().call_method(tx, args, kwargs)
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ from ..side_effects import SideEffects
+
+ options = VariableTracker.propagate(self, args, kwargs.values())
+
+ if self.value is torch.autograd.profiler.profile:
+ return ProfilerContextWrapperVariable()
+ elif is_namedtuple_cls(self.value):
+ fields = namedtuple_fields(self.value)
+ items = list(args)
+ items.extend([None] * (len(fields) - len(items)))
+ for name, value in kwargs.items():
+ assert name in fields
+ items[fields.index(name)] = value
+ assert all(x is not None for x in items)
+ return variables.NamedTupleVariable(
+ items, self.value, **VariableTracker.propagate(self, items)
+ )
+ elif (
+ inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
+ and SideEffects.cls_supports_mutation_side_effects(self.value)
+ and self.source
+ ):
+ var = tx.output.side_effects.track_object_new(
+ self.source, self.value, UserDefinedObjectVariable, options
+ )
+ return var.add_options(var.call_method(tx, "__init__", args, kwargs))
+ elif variables.DataClassVariable.is_matching_cls(self.value):
+ options["mutable_local"] = MutableLocal()
+ return variables.DataClassVariable.create(self.value, args, kwargs, options)
+
+ return super().call_function(tx, args, kwargs)
+
+ def const_getattr(self, tx, name):
+ if name == "__name__":
+ return self.value.__name__
+ return super().const_getattr(tx, name)
+
+
+class UserDefinedObjectVariable(UserDefinedVariable):
+ """
+ Mostly objects of defined type. Catch-all for something where we only know the type.
+ """
+
+ def __init__(self, value, value_type=None, **kwargs):
+ super(UserDefinedObjectVariable, self).__init__(**kwargs)
+ self.value = value
+ self.value_type = value_type or type(value)
+ assert type(value) is self.value_type
+
+ def __str__(self):
+ inner = self.value_type.__name__
+ if inner in [
+ "builtin_function_or_method",
+ "getset_descriptor",
+ "method_descriptor",
+ "method",
+ ]:
+ inner = str(getattr(self.value, "__name__", None))
+ return f"{self.__class__.__name__}({inner})"
+
+ def python_type(self):
+ return self.value_type
+
+ @staticmethod
+ @functools.lru_cache(None)
+ def _supported_random_functions():
+ fns = {
+ random.random,
+ random.randint,
+ random.randrange,
+ random.uniform,
+ }
+ return fns
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: "List[VariableTracker]",
+ kwargs: "Dict[str, VariableTracker]",
+ ) -> "VariableTracker":
+ from . import ConstantVariable, TupleVariable, UserMethodVariable
+
+ options = VariableTracker.propagate(self, args, kwargs.values())
+
+ if name not in getattr(self.value, "__dict__", {}):
+ try:
+ method = inspect.getattr_static(type(self.value), name)
+ except AttributeError:
+ method = None
+
+ if method is object.__init__:
+ return ConstantVariable(None, **options)
+
+ if method is collections.OrderedDict.keys and self.source:
+ # subclass of OrderedDict
+ assert not (args or kwargs)
+ keys = list(self.value.keys())
+ assert all(map(ConstantVariable.is_literal, keys))
+ return TupleVariable(
+ [ConstantVariable(k, **options) for k in keys], **options
+ ).add_guard(
+ Guard(
+ self.source.name(),
+ self.source.guard_source(),
+ GuardBuilder.ODICT_KEYS,
+ )
+ )
+
+ if (
+ method is collections.OrderedDict.items
+ and isinstance(self.value, collections.OrderedDict)
+ and self.source
+ ):
+ assert not (args or kwargs)
+ items = []
+ keys = self.call_method(tx, "keys", [], {})
+ options = VariableTracker.propagate(self, args, kwargs.values(), keys)
+ for key in keys.unpack_var_sequence(tx):
+ items.append(
+ TupleVariable(
+ [key, self.odict_getitem(tx, key)],
+ **options,
+ )
+ )
+ return TupleVariable(items, **options)
+
+ if method is collections.OrderedDict.__getitem__ and len(args) == 1:
+ assert not kwargs
+ return self.odict_getitem(tx, args[0])
+
+ # check for methods implemented in C++
+ if isinstance(method, types.FunctionType):
+ # TODO(jansel): add a guard to check for monkey patching?
+ return UserMethodVariable(method, self, **options).call_function(
+ tx, args, kwargs
+ )
+
+ return super().call_method(tx, name, args, kwargs)
+
+ def is_supported_random(self):
+ try:
+ return self.value in self._supported_random_functions()
+ except TypeError:
+ # TypeError: unhashable type
+ return False
+
+ def call_function(
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+ ) -> "VariableTracker":
+ from .builder import VariableBuilder
+
+ if (
+ self.is_supported_random()
+ and all(k.is_python_constant() for k in args)
+ and all(v.is_python_constant() for v in kwargs.values())
+ ):
+ args = [x.as_python_constant() for x in args]
+ kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
+ random_call_index = len(tx.random_calls)
+ if random_call_index == 0:
+ tx.output.initial_random_state = random.getstate()
+ example_value = self.value(*args, **kwargs)
+ source = RandomValueSource(random_call_index)
+ tx.random_calls.append((self.value, args, kwargs))
+ return VariableBuilder(tx, source).wrap_unspecialized_primitive(
+ example_value
+ )
+
+ return super().call_function(tx, args, kwargs)
+
+ def _check_for_getattribute(self):
+ try:
+ if isinstance(
+ inspect.getattr_static(type(self.value), "__getattribute__"),
+ types.FunctionType,
+ ):
+ unimplemented("UserDefinedObjectVariable with custom __getattribute__")
+ except AttributeError:
+ pass
+
+ def _check_for_getattr(self):
+ try:
+ getattr_fn = inspect.getattr_static(type(self.value), "__getattr__")
+ except AttributeError:
+ getattr_fn = None
+ if getattr_fn is torch.nn.Module.__getattr__:
+ # ignore this case of getattr
+ getattr_fn = None
+ return getattr_fn
+
+ def _getattr_static(self, name):
+ if isinstance(self.value, (dataclasses.Field, torch.nn.Module)):
+ # getattr_static doesn't work on these
+ subobj = getattr(self.value, name)
+ else:
+ subobj = inspect.getattr_static(self.value, name)
+ return subobj
+
+ def var_getattr(self, tx, name):
+ from . import ConstantVariable
+ from .builder import VariableBuilder
+
+ options = VariableTracker.propagate(self)
+ value = self.value
+ source = AttrSource(self.source, name) if self.source else None
+ self._check_for_getattribute()
+ getattr_fn = self._check_for_getattr()
+
+ try:
+ subobj = self._getattr_static(name)
+ except AttributeError:
+ if isinstance(getattr_fn, types.FunctionType):
+ return variables.UserMethodVariable(
+ getattr_fn, self, **options
+ ).call_function(tx, [ConstantVariable(name)], {})
+ elif getattr_fn is not None:
+ unimplemented("UserDefined with non-function __getattr__")
+
+ if isinstance(subobj, property):
+ return variables.UserMethodVariable(
+ subobj.fget, self, **options
+ ).call_function(tx, [], {})
+
+ if (
+ name in getattr(value, "__dict__", {})
+ or ConstantVariable.is_literal(subobj)
+ or isinstance(
+ subobj,
+ (
+ torch.Tensor,
+ torch.nn.Module,
+ ),
+ )
+ ):
+ if source:
+ return VariableBuilder(tx, source)(subobj).add_options(options)
+ elif ConstantVariable.is_literal(subobj):
+ return ConstantVariable(subobj, **options)
+
+ if (
+ name not in getattr(value, "__dict__", {})
+ and type(value).__module__.startswith("torch.")
+ and "torch.optim" not in type(value).__module__
+ and not callable(value)
+ ):
+ if not source:
+ assert getattr(
+ importlib.import_module(type(value).__module__),
+ type(value).__name__,
+ ) is type(value)
+ source = AttrSource(
+ AttrSource(
+ tx.import_source(type(value).__module__), type(value).__name__
+ ),
+ name,
+ )
+
+ return VariableBuilder(tx, source)(subobj).add_options(options)
+
+ if isinstance(
+ subobj,
+ (
+ torch.distributions.constraints._Interval,
+ torch.distributions.constraints._Real,
+ torch.distributions.constraints.Constraint,
+ ),
+ ):
+ return UserDefinedObjectVariable(subobj, source=source, **options)
+
+ if isinstance(subobj, staticmethod):
+ return variables.UserFunctionVariable(subobj.__get__(self.value), **options)
+ elif isinstance(subobj, classmethod):
+ return variables.UserMethodVariable(subobj.__func__, self, **options)
+
+ if name == "__class__":
+ return UserDefinedClassVariable(type(self.value), source=source, **options)
+
+ return variables.GetAttrVariable(self, name, source=source, **options)
+
+ def call_hasattr(self, tx, name: str) -> "VariableTracker":
+ if not self.source:
+ unimplemented("hasattr no source")
+ options = VariableTracker.propagate(self)
+ options["guards"].add(
+ AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
+ )
+ if self._check_for_getattribute() or self._check_for_getattr():
+ unimplemented("hasattr with custom __getattr__")
+
+ try:
+ self._getattr_static(name)
+ return variables.ConstantVariable(True, **options)
+ except AttributeError:
+ return variables.ConstantVariable(False, **options)
+
+ def odict_getitem(self, tx, key):
+ from .builder import VariableBuilder
+
+ return VariableBuilder(
+ tx,
+ ODictGetItemSource(self.source, key.as_python_constant()),
+ )(
+ collections.OrderedDict.__getitem__(self.value, key.as_python_constant())
+ ).add_options(
+ key, self
+ )
diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
new file mode 100644
index 0000000000000..8b36db16a8a68
--- /dev/null
+++ b/torch/_inductor/codecache.py
@@ -0,0 +1,261 @@
+import base64
+import functools
+import getpass
+import hashlib
+import logging
+import os
+import re
+import shutil
+import subprocess
+import sysconfig
+import tempfile
+import types
+from concurrent.futures import Future, ThreadPoolExecutor
+from ctypes import cdll
+from typing import Any, Dict
+
+import torch
+from torch.utils import cpp_extension
+
+from . import config, exc
+
+LOCK_TIMEOUT = 600
+
+log = logging.getLogger(__name__)
+logging.getLogger("filelock").setLevel(logging.DEBUG if config.debug else logging.INFO)
+
+
+def cache_dir():
+ return f"/tmp/torchinductor_{getpass.getuser()}"
+
+
+def get_lock_dir():
+ lock_dir = os.path.join(cache_dir(), "locks")
+ if not os.path.exists(lock_dir):
+ os.makedirs(lock_dir, exist_ok=True)
+ return lock_dir
+
+
+def code_hash(code):
+ return (
+ "c"
+ + base64.b32encode(hashlib.sha256(code.encode("utf-8")).digest())[:51]
+ .decode("utf-8")
+ .lower()
+ )
+
+
+def write(source_code, ext, extra=""):
+ basename = code_hash(source_code + extra)
+ subdir = os.path.join(cache_dir(), basename[1:3])
+ if not os.path.exists(subdir):
+ os.makedirs(subdir, exist_ok=True)
+ path = os.path.join(subdir, f"{basename}.{ext}")
+ if not os.path.exists(path):
+ # use a temp file for thread safety
+ fd, tmp_path = tempfile.mkstemp(dir=subdir)
+ with os.fdopen(fd, "w") as f:
+ f.write(source_code)
+ os.rename(tmp_path, path)
+ return basename, path
+
+
+def cpp_compiler():
+ if isinstance(config.cpp.cxx, (list, tuple)):
+ search = tuple(config.cpp.cxx)
+ else:
+ search = (config.cpp.cxx,)
+ return cpp_compiler_search(search)
+
+
+@functools.lru_cache(1)
+def cpp_compiler_search(search):
+ for cxx in search:
+ try:
+ if cxx is None:
+ from filelock import FileLock
+
+ lock_dir = get_lock_dir()
+ lock = FileLock(
+ os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT
+ )
+ with lock:
+ cxx = install_gcc_via_conda()
+ subprocess.check_output([cxx, "--version"])
+ return cxx
+ except (subprocess.SubprocessError, FileNotFoundError, ImportError):
+ continue
+ raise exc.InvalidCxxCompiler()
+
+
+def install_gcc_via_conda():
+ """On older systems, this is a quick way to get a modern compiler"""
+ prefix = os.path.join(cache_dir(), "gcc")
+ cxx_path = os.path.join(prefix, "bin", "g++")
+ if not os.path.exists(cxx_path):
+ log.info("Downloading GCC via conda")
+ conda = os.environ.get("CONDA_EXE", "conda")
+ if conda is None:
+ conda = shutil.which("conda")
+ if conda is not None:
+ subprocess.check_call(
+ [
+ conda,
+ "create",
+ f"--prefix={prefix}",
+ "--channel=conda-forge",
+ "--quiet",
+ "-y",
+ "python=3.8",
+ "gxx",
+ ],
+ stdout=subprocess.PIPE,
+ )
+ return cxx_path
+
+
+def is_gcc():
+ return re.search(r"(gcc|g\+\+)", cpp_compiler())
+
+
+def cpp_compile_command(input, output, include_pytorch=False):
+ if include_pytorch:
+ ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")]
+ lpaths = cpp_extension.library_paths() + [sysconfig.get_config_var("LIBDIR")]
+ libs = ["c10", "torch", "torch_cpu", "torch_python", "gomp"]
+ else:
+ # Note - this is effectively a header only inclusion. Usage of some header files may result in
+ # symbol not found, if those header files require a library.
+ # For those cases, include the lpath and libs command as we do for pytorch above.
+ # This approach allows us to only pay for what we use.
+ ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")]
+ lpaths = []
+ libs = ["gomp"]
+ ipaths = " ".join(["-I" + p for p in ipaths])
+ lpaths = " ".join(["-L" + p for p in lpaths])
+ libs = " ".join(["-l" + p for p in libs])
+ return re.sub(
+ r"[ \n]+",
+ " ",
+ f"""
+ {cpp_compiler()} -shared -fPIC -Wall -std=c++14 -Wno-unused-variable
+ {ipaths} {lpaths} {libs}
+ -march=native -O3 -ffast-math -fno-finite-math-only -fopenmp
+ -o{output} {input}
+ """,
+ ).strip()
+
+
+class CppCodeCache:
+ cache = dict()
+ clear = staticmethod(cache.clear)
+
+ @classmethod
+ def load(cls, source_code):
+ key, input_path = write(source_code, "cpp", extra=cpp_compile_command("i", "o"))
+ if key not in cls.cache:
+ from filelock import FileLock
+
+ lock_dir = get_lock_dir()
+ lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
+ with lock:
+ output_path = input_path[:-3] + "so"
+ if not os.path.exists(output_path):
+ cmd = cpp_compile_command(
+ input=input_path, output=output_path
+ ).split(" ")
+ try:
+ subprocess.check_output(cmd, stderr=subprocess.STDOUT)
+ except subprocess.CalledProcessError as e:
+ raise exc.CppCompileError(cmd, e.output)
+
+ cls.cache[key] = cdll.LoadLibrary(output_path)
+ cls.cache[key].key = key
+
+ return cls.cache[key]
+
+
+class PyCodeCache:
+ cache = dict()
+ clear = staticmethod(cache.clear)
+
+ @classmethod
+ def load(cls, source_code):
+ key, path = write(source_code, "py")
+ if key not in cls.cache:
+ with open(path) as f:
+ code = compile(f.read(), path, "exec")
+ mod = types.ModuleType(f"{__name__}.{key}")
+ mod.__file__ = path
+ mod.key = key
+ exec(code, mod.__dict__, mod.__dict__)
+ # another thread might set this first
+ cls.cache.setdefault(key, mod)
+ return cls.cache[key]
+
+
+@functools.lru_cache(None)
+def patch_triton_dir():
+ os.environ["TRITON_CACHE_DIR"] = os.environ.get(
+ "TRITON_CACHE_DIR", os.path.join(cache_dir(), "triton")
+ )
+
+
+class TritonCodeCache:
+ @staticmethod
+ def get_name(mod):
+ (name,) = [n for n in dir(mod) if n.startswith("kernel")]
+ return name
+
+ @classmethod
+ def load(cls, source_code):
+ patch_triton_dir()
+ mod = PyCodeCache.load(source_code)
+ return getattr(mod, cls.get_name(mod))
+
+
+class AsyncCompile:
+ def __init__(self):
+ self._context_keepalive = None
+
+ @staticmethod
+ @functools.lru_cache(1)
+ def pool():
+ assert config.compile_threads > 1
+ return ThreadPoolExecutor(config.compile_threads)
+
+ @classmethod
+ def submit(cls, task):
+ if config.compile_threads <= 1:
+ return task()
+ return cls.pool().submit(task)
+
+ @classmethod
+ def map(cls, fn, seq):
+ if config.compile_threads <= 1 or len(seq) <= 1:
+ return list(map(fn, seq))
+ return [t.result() for t in [cls.pool().submit(fn, x) for x in seq]]
+
+ def triton(self, source_code):
+ if self._context_keepalive is None:
+ # Workaround `CUDA: Error- context is destroyed`
+ self._context_keepalive = torch.tensor([1], device="cuda")
+ kernel = TritonCodeCache.load(source_code)
+
+ def task():
+ kernel.precompile()
+ return kernel
+
+ return self.submit(task)
+
+ def cpp(self, source_code):
+ def task():
+ return CppCodeCache.load(source_code).kernel
+
+ return self.submit(task)
+
+ def wait(self, scope: Dict[str, Any]):
+ if config.compile_threads > 1:
+ for key, result in list(scope.items()):
+ if isinstance(result, Future):
+ scope[key] = result.result()
diff --git a/torch/_inductor/codegen/__init__.py b/torch/_inductor/codegen/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/torch/_inductor/codegen/autotuner.py b/torch/_inductor/codegen/autotuner.py
new file mode 100644
index 0000000000000..6425f28cc18f6
--- /dev/null
+++ b/torch/_inductor/codegen/autotuner.py
@@ -0,0 +1,274 @@
+import builtins
+
+import torch
+
+from .. import config, triton_ops
+from ..triton_ops.autotune import mm_autotune, mm_heuristics
+from ..utils import dynamo_testing
+from ..virtualized import V
+
+aten = torch.ops.aten
+rand_strided = dynamo_testing.rand_strided
+
+
+def str2func(str):
+ module, *name = str.split(".")
+ if module == "aten":
+ runnable = aten
+ elif module == "triton_ops":
+ runnable = triton_ops
+ elif module == "torch":
+ runnable = torch
+ else:
+ raise Exception(f"{str} could not be called")
+
+ for n in name:
+ runnable = getattr(runnable, n)
+ return runnable
+
+
+class Autotuner:
+ def __init__(self):
+ self.cache = dict()
+
+ def _bench(self, kernel, *args, **kwargs):
+ def kernel_call():
+ kernel(*args, **kwargs)
+
+ from triton.testing import do_bench
+
+ return do_bench(kernel_call)
+
+
+autotune = Autotuner()
+
+
+def tuned_conv(
+ x_shape,
+ w_shape,
+ x_stride,
+ w_stride,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ device,
+ dtype,
+ adjust_triton=0.95,
+):
+ """
+ Return the best kernel name given inputs and layer parameters;
+ Considering potential pointwise fusion of conv, we could adjust triton timing
+ by multiplying adjust_triton (default=0.95)
+ """
+
+ sizevars = V.graph.sizevars
+ x_shape = [sizevars.size_hint(s) for s in x_shape]
+ w_shape = [sizevars.size_hint(s) for s in w_shape]
+ x_stride = [sizevars.size_hint(s) for s in x_stride]
+ w_stride = [sizevars.size_hint(s) for s in w_stride]
+ x = rand_strided(x_shape, x_stride, device=device, dtype=dtype)
+ w = rand_strided(w_shape, w_stride, device=device, dtype=dtype)
+ # the identifiable args for the layers
+ id_args = [
+ *x_shape,
+ *w_shape,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ # *x_stride,
+ # *w_stride,
+ ]
+ use_cuda = x.is_cuda
+
+ # gen_key
+ key = tuple(id_args)
+ key = ("conv",) + key
+
+ # candidate kernels
+ kernels = ["aten.convolution"]
+ if use_cuda:
+ kernels += ["triton_ops.conv"]
+
+ # filter kernels that args/kwargs does not meet requirements
+ remove_kernels = []
+ if groups > 1 or transposed:
+ remove_kernels += ["triton_ops.conv"]
+ kernels = [k for k in kernels if k not in remove_kernels]
+
+ # if only one choice, return that kernel
+ if len(kernels) == 1:
+ kernel = kernels[0]
+ # return kernel(
+ # x, w, stride, padding, dilation, transposed, output_padding, groups
+ # )
+ return kernel
+ timings = {}
+ if key not in autotune.cache:
+ for kernel in kernels:
+ runnable_kernel = str2func(kernel)
+ if "triton_ops" in kernel:
+ # because we use nhwc layout by default for triton conv
+ x = x.to(memory_format=torch.channels_last)
+ run_args = (
+ x,
+ w,
+ None,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ )
+ timing, _, _ = autotune._bench(runnable_kernel, *run_args)
+ if "triton_ops" in kernel:
+ timing = timing * adjust_triton
+ timings[kernel] = timing
+ autotune.cache[key] = builtins.min(timings, key=timings.get)
+ if config.debug:
+ print("for key = ", key)
+ print("timing", timings)
+ print("best_kernel", autotune.cache[key])
+ best_kernel = autotune.cache[key]
+ # if best_kernel == "triton_ops.conv":
+ # print(key, best_kernel)
+ return best_kernel
+
+
+def tuned_mm(
+ a_shape,
+ b_shape,
+ a_stride,
+ b_stride,
+ device,
+ dtype,
+ adjust_triton=0.95,
+):
+ """
+ Return the best kernel name given mm input size;
+ Considering potential pointwise fusion of mm, we could adjust triton timing
+ by multiplying adjust_triton (default=0.95)
+ """
+
+ sizevars = V.graph.sizevars
+ a_shape = [sizevars.size_hint(s) for s in a_shape]
+ b_shape = [sizevars.size_hint(s) for s in b_shape]
+ a_stride = [sizevars.size_hint(s) for s in a_stride]
+ b_stride = [sizevars.size_hint(s) for s in b_stride]
+ a = rand_strided(a_shape, a_stride, device=device, dtype=dtype)
+ b = rand_strided(b_shape, b_stride, device=device, dtype=dtype)
+ c = torch.empty((a_shape[0], b_shape[1]), device=device, dtype=dtype)
+ id_args = [
+ *a_shape,
+ *b_shape,
+ ]
+ use_cuda = a.is_cuda
+
+ # gen_key
+ key = tuple(id_args)
+ key = ("mm",) + key
+
+ # candidate kernels
+ kernels = ["aten.mm.out"]
+ if use_cuda:
+ kernels += ["triton_ops.matmul_out"]
+ # if only one choice, return that kernel
+ if len(kernels) == 1:
+ kernel = kernels[0]
+ return kernel
+ timings = {}
+ if key not in autotune.cache:
+ # bench_start = time.time()
+ for kernel in kernels:
+ runnable_kernel = str2func(kernel)
+ if "triton_ops" in kernel:
+ run_args = (a, b, c)
+ run_kwargs = {}
+ inner_kernel = str2func(
+ kernel.replace("matmul_out", "_matmul_out") + ".kernel"
+ )
+ inner_kernel.kernel_decorators = []
+ # fix SPLIT_K = 1 for fusable kernels
+ mm_heuristics()(mm_autotune(get_io_bound_configs=False)(inner_kernel))
+ else:
+ run_args = (a, b)
+ run_kwargs = {"out": c}
+ timing, _, _ = autotune._bench(runnable_kernel, *run_args, **run_kwargs)
+ if "triton_ops" in kernel:
+ timing = timing * adjust_triton
+ timings[kernel] = timing
+ # bench_end = time.time()
+ # bench_time = bench_end - bench_start
+ autotune.cache[key] = builtins.min(timings, key=timings.get)
+ if config.debug:
+ print("for key = ", key)
+ print("timing", timings)
+ print("best_kernel", autotune.cache[key])
+ best_kernel = autotune.cache[key]
+ return best_kernel
+
+
+def tuned_conv_layout(
+ kernel,
+ x_shape,
+ w_shape,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ device,
+ dtype,
+):
+ sizevars = V.graph.sizevars
+ x_shape = [sizevars.size_hint(s) for s in x_shape]
+ w_shape = [sizevars.size_hint(s) for s in w_shape]
+ x = torch.randn(x_shape, device=device, dtype=dtype)
+ w = torch.randn(w_shape, device=device, dtype=dtype)
+ id_args = [
+ *x_shape,
+ *w_shape,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ ]
+
+ # gen_key
+ key = tuple(id_args)
+ key = ("conv_layout",) + key
+ runnable_kernel = str2func(kernel)
+
+ timings = {}
+ if key not in autotune.cache:
+ for memory_format in ["torch.contiguous_format", "torch.channels_last"]:
+ x = x.to(memory_format=str2func(memory_format))
+ run_args = (
+ x,
+ w,
+ None,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ )
+ timing, _, _ = autotune._bench(runnable_kernel, *run_args)
+ timings[memory_format] = timing
+ autotune.cache[key] = builtins.min(timings, key=timings.get)
+ if config.debug:
+ print("for key = ", key)
+ print("timing", timings)
+ print("best_layout", autotune.cache[key])
+ best_layout = autotune.cache[key]
+ return best_layout
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
new file mode 100644
index 0000000000000..e6fc91ea52c63
--- /dev/null
+++ b/torch/_inductor/codegen/common.py
@@ -0,0 +1,586 @@
+import collections
+import contextlib
+import itertools
+import logging
+import math
+import re
+import textwrap
+import typing
+from collections import namedtuple
+from io import StringIO
+from itertools import chain
+
+import sympy
+from sympy.printing.printer import Printer
+
+from .. import metrics
+from ..utils import free_symbol_startswith, sympy_dot, sympy_subs, unique
+from ..virtualized import ops, V
+
+log = logging.getLogger(__name__)
+
+TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype"])
+SizeArg = namedtuple("SizeArg", ["name", "expr"])
+
+
+def index_prevent_reordering(index: typing.List[sympy.Expr], index_vars, sizes):
+ from ..ir import FlexibleLayout
+
+ # added contiguous index prevents reordering
+ return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]
+
+
+class ExprPrinter(Printer):
+ @staticmethod
+ def paren(string):
+ if (
+ re.match(r"^[a-z0-9_.]+$", string, re.I)
+ or re.match(r"^\([^)]*\)$", string, re.I)
+ or string == ""
+ ):
+ return string
+ return f"({string})"
+
+ def _print_Pow(self, expr):
+ # Pow() confuses triton
+ base, exp = expr.args
+ base = self._print(base)
+ assert exp.is_integer
+ exp = int(exp)
+ return "*".join([self.paren(base)] * exp)
+
+ def _print_Mul(self, expr):
+ return "*".join(map(self.paren, map(self._print, expr.args)))
+
+ def _print_Add(self, expr):
+ return " + ".join(map(self.paren, map(self._print, expr.args)))
+
+ def _print_Mod(self, expr):
+ return " % ".join(map(self.paren, map(self._print, expr.args)))
+
+ def _print_CleanDiv(self, expr):
+ return self._print_IndexingDiv(expr)
+
+
+class OpOverrides:
+ def __init__(self, parent):
+ super().__init__()
+ self._parent = parent
+
+ def __getattr__(self, item):
+ return getattr(self._parent, item)
+
+ @staticmethod
+ def identity(value):
+ # used to trigger cse
+ return value
+
+ @staticmethod
+ def constant(value, dtype):
+ return repr(value)
+
+ @staticmethod
+ def sigmoid(x):
+ x = ops.exp(f"-{x}")
+ return f"1 / (1 + {x})"
+
+ @staticmethod
+ def silu(x):
+ return f"{x} * {ops.sigmoid(x)}"
+
+ @staticmethod
+ def reciprocal(x):
+ return ops.div("1", x)
+
+ @staticmethod
+ def square(x):
+ return ops.mul(x, x)
+
+ @staticmethod
+ def sign(x):
+ return ops.where(f"{x} == 0", "0", ops.where(f"{x} < 0", "-1", "1"))
+
+ @staticmethod
+ def bitwise_not(x):
+ return f"~{ExprPrinter.paren(x)}"
+
+ @staticmethod
+ def logical_not(a):
+ return f"{ExprPrinter.paren(a)} == 0"
+
+ @staticmethod
+ def bitwise_and(x, y):
+ return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"
+
+ @staticmethod
+ def bitwise_or(x, y):
+ return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"
+
+ @staticmethod
+ def bitwise_xor(x, y):
+ return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"
+
+ @staticmethod
+ def remainder(a, b):
+ r = ops.mod(a, b)
+ return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r)
+
+
+class IndentedBuffer:
+ tabwidth = 4
+
+ def __init__(self, initial_indent=0):
+ self._lines = []
+ self._indent = initial_indent
+
+ def getvalue(
+ self,
+ ):
+ buf = StringIO()
+ for line in self._lines:
+ if isinstance(line, DeferredLine):
+ line = line()
+ if line is None:
+ continue
+ assert isinstance(line, str)
+ buf.write(line)
+ buf.write("\n")
+ return buf.getvalue()
+
+ def clear(self):
+ self._lines.clear()
+
+ def __bool__(self):
+ return bool(self._lines)
+
+ def prefix(self):
+ return " " * (self._indent * self.tabwidth)
+
+ def writeline(self, line):
+ if isinstance(line, DeferredLine):
+ self._lines.append(line.with_prefix(self.prefix()))
+ elif line.strip():
+ self._lines.append(f"{self.prefix()}{line}")
+ else:
+ self._lines.append("")
+
+ def writelines(self, lines):
+ for line in lines:
+ self.writeline(line)
+
+ def indent(self, offset=1):
+ @contextlib.contextmanager
+ def ctx():
+ self._indent += offset
+ yield
+ self._indent -= offset
+
+ return ctx()
+
+ def splice(self, other_code, strip=False):
+ if isinstance(other_code, IndentedBuffer):
+ dedent = float("inf")
+ for line in other_code._lines:
+ if line:
+ dedent = min(dedent, len(line) - len(line.lstrip()))
+ if math.isinf(dedent):
+ dedent = 0
+ for line in other_code._lines:
+ IndentedBuffer.writeline(self, line[dedent:])
+ else:
+ other_code = textwrap.dedent(other_code)
+ if strip:
+ other_code = other_code.lstrip()
+ if not other_code:
+ return
+ other_code = other_code.rstrip()
+ for line in other_code.split("\n"):
+ self.writeline(line)
+
+
+class DeferredLine:
+ """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""
+
+ def __init__(self, name, line):
+ if not line.strip():
+ line = ""
+ self.name = name
+ self.line = line
+
+ def __call__(self):
+ if self.name not in V.graph.removed_buffers:
+ return self.line
+ return None
+
+ def with_prefix(self, prefix):
+ return DeferredLine(self.name, f"{prefix}{self.line}")
+
+ def lstrip(self):
+ return DeferredLine(self.name, self.line.lstrip())
+
+ def __getitem__(self, index):
+ return DeferredLine(self.name, self.line[index])
+
+ def __bool__(self):
+ return bool(self.line)
+
+ def __len__(self):
+ return len(self.line)
+
+
+class DeferredIndentedBuffer(IndentedBuffer):
+ def __init__(self, initial_indent=0):
+ super(DeferredIndentedBuffer, self).__init__(initial_indent)
+
+ def writeline(self, name, line):
+ if name is None:
+ return super().writeline(line)
+ assert "buf" in name
+ return super().writeline(DeferredLine(name, line))
+
+ def writelines(self, name, lines):
+ for line in lines:
+ self.writeline(name, line)
+
+
+class BracesBuffer(IndentedBuffer):
+ def indent(self, offset=1):
+ @contextlib.contextmanager
+ def ctx():
+ for _ in range(offset):
+ self.writeline("{")
+ self._indent += 1
+ for _ in range(-offset):
+ self._indent -= 1
+ self.writeline("}")
+ yield
+ for _ in range(-offset):
+ self.writeline("{")
+ self._indent += 1
+ for _ in range(offset):
+ self._indent -= 1
+ self.writeline("}")
+
+ return ctx()
+
+
+class InplacedBuffer(typing.NamedTuple):
+ inner_name: str
+ other_names: typing.List[str]
+
+
+class KernelArgs:
+ @staticmethod
+ def _lookup(prefix, odict, name):
+ assert isinstance(name, (str, sympy.Symbol))
+ name = str(name)
+ if name not in odict:
+ odict[name] = f"{prefix}{len(odict)}"
+ return odict[name]
+
+ def __init__(self, sizevars=None):
+ self.input_buffers = collections.OrderedDict()
+ self.output_buffers = collections.OrderedDict()
+ self.inplace_buffers = collections.OrderedDict()
+ self.sizevars = sizevars or collections.OrderedDict()
+
+ def input(self, name):
+ name = V.graph.scheduler.mutation_real_name.get(name, name)
+ assert name not in V.graph.removed_buffers, name
+ if name in self.output_buffers:
+ return self.output_buffers[name]
+ if name.startswith("seed"):
+ return self._lookup("seed", self.input_buffers, name)
+ return self._lookup("in_ptr", self.input_buffers, name)
+
+ def output(self, name):
+ name = V.graph.scheduler.mutation_real_name.get(name, name)
+ assert name not in V.graph.removed_buffers, name
+ return self._lookup("out_ptr", self.output_buffers, name)
+
+ def make_inplace(self, input_name, output_name):
+ buf = InplacedBuffer(
+ f"in_out_ptr{len(self.inplace_buffers)}", [input_name, output_name]
+ )
+ self.inplace_buffers[input_name] = buf
+ self.inplace_buffers[output_name] = buf
+
+ def size(self, name):
+ if str(name) == "seed":
+ self.sizevars["seed"] = "seed"
+ return "seed"
+ return self._lookup("ks", self.sizevars, name)
+
+ def call_names(self):
+ return chain(
+ self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
+ )
+
+ def cpp_argdefs(self):
+ from .cpp import DTYPE_TO_CPP, INDEX_TYPE
+
+ # TODO(jansel): replace this with data from scheduler
+ buffer_types = {x.get_name(): x.get_dtype() for x in V.graph.buffers}
+ buffer_types.update(
+ {name: val.get_dtype() for name, val in V.graph.graph_inputs.items()}
+ )
+ buffer_types.update(
+ {name: val.dtype for name, val in V.graph.constants.items()}
+ )
+
+ call_args = []
+ arg_defs = []
+ for inplaced in unique(self.inplace_buffers.values()):
+ outer = inplaced.other_names[0]
+ inner = inplaced.inner_name
+ dtype = buffer_types[outer]
+ arg_defs.append(f"{DTYPE_TO_CPP[dtype]}* __restrict__ {inner}")
+ name = inplaced.other_names[-1]
+ call_args.append(f"c_void_p({name}.data_ptr())")
+ for outer, inner in self.input_buffers.items():
+ if outer in self.inplace_buffers:
+ continue
+ dtype = buffer_types[outer]
+ arg_defs.append(f"const {DTYPE_TO_CPP[dtype]}* __restrict__ {inner}")
+ call_args.append(f"c_void_p({outer}.data_ptr())")
+ for outer, inner in self.output_buffers.items():
+ if outer in self.inplace_buffers or inner == "REMOVED":
+ continue
+ dtype = buffer_types[outer]
+ arg_defs.append(f"{DTYPE_TO_CPP[dtype]}* __restrict__ {inner}")
+ call_args.append(f"c_void_p({outer}.data_ptr())")
+ for outer, inner in self.sizevars.items():
+ arg_defs.append(f"const {INDEX_TYPE} {inner}")
+ call_args.append(f"c_long({outer})")
+ return arg_defs, call_args
+
+ def python_argdefs(self):
+ arg_defs = []
+ call_args = []
+ precompile_args = []
+ for inplaced in unique(self.inplace_buffers.values()):
+ arg_defs.append(inplaced.inner_name)
+ call_args.append(inplaced.other_names[-1])
+ precompile_args.append(
+ TensorArg(
+ inplaced.inner_name,
+ inplaced.other_names[-1],
+ V.graph.get_dtype(inplaced.other_names[-1]),
+ )
+ )
+ for outer, inner in chain(
+ self.input_buffers.items(), self.output_buffers.items()
+ ):
+ if outer in self.inplace_buffers or inner == "REMOVED":
+ continue
+ arg_defs.append(inner)
+ call_args.append(outer)
+ precompile_args.append(TensorArg(inner, outer, V.graph.get_dtype(outer)))
+ for outer, inner in self.sizevars.items():
+ arg_defs.append(inner)
+ call_args.append(outer)
+ precompile_args.append(SizeArg(inner, sympy.expand(outer)))
+ return arg_defs, call_args, precompile_args
+
+ def aliases(self):
+ for inplaced in unique(self.inplace_buffers.values()):
+ for other in inplaced.other_names:
+ if other in self.input_buffers:
+ yield self.input_buffers[other], inplaced.inner_name
+ if other in self.output_buffers:
+ yield self.output_buffers[other], inplaced.inner_name
+
+
+class CSE:
+ """Common subexpression elimination"""
+
+ def __init__(
+ self,
+ prefix="",
+ suffix="",
+ name_prefix="tmp",
+ iter_buffers=None,
+ store_cache=None,
+ reduction_cache=None,
+ ):
+ self.prefix = prefix
+ self.suffix = suffix
+ self.cache = {}
+ self.name_prefix = name_prefix
+ self.store_cache = store_cache or {}
+ self.reduction_cache = reduction_cache or {}
+ self.iter_buffer_ids = iter_buffers or itertools.count()
+ self.invalidated_stores = set()
+
+ def invalidate(self, keep_vars: typing.Set[str]):
+ for name, tmp in list(self.store_cache.items()):
+ if tmp not in keep_vars:
+ del self.store_cache[name]
+ self.invalidated_stores.add(name)
+ self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}
+
+ def clone(self):
+ return CSE(
+ self.prefix,
+ self.suffix,
+ self.name_prefix,
+ self.iter_buffer_ids,
+ self.store_cache,
+ )
+
+ def generate(self, buffer: IndentedBuffer, expr: str, write=True):
+ assert isinstance(expr, str), expr
+ if expr.startswith(self.name_prefix) and re.match(r"^[a-z0-9]+$", expr):
+ return expr
+ if expr not in self.cache:
+ var = self.newvar()
+ self.cache[expr] = var
+ if write:
+ V.kernel.current_node.codegen_originating_info(buffer, only_once=True)
+ buffer.writeline(f"{self.prefix}{var} = {expr}{self.suffix}")
+ return self.cache[expr]
+
+ def newvar(self):
+ return f"{self.name_prefix}{next(self.iter_buffer_ids)}"
+
+
+class CodeGen:
+ def __init__(self):
+ super().__init__()
+ self.exit_stack = contextlib.ExitStack()
+
+ def __enter__(self):
+ self.exit_stack.__enter__()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
+
+
+class Kernel(CodeGen):
+ newvar_prefix = ""
+ suffix = ""
+ overrides = None
+ load_format = None
+ store_format = None
+
+ def __init__(self, args=None):
+ super().__init__()
+ metrics.generated_kernel_count += 1
+ self.args = args or KernelArgs()
+ self.loads = IndentedBuffer()
+ self.compute = IndentedBuffer()
+ self.stores = DeferredIndentedBuffer()
+ self.cse = CSE(self.newvar_prefix, self.suffix)
+ self.must_keep_buffers = set()
+ self.current_node = None
+ self.store_buffer_names = set()
+
+ @contextlib.contextmanager
+ def set_current_node(self, node):
+ prior = self.current_node
+ self.current_node = node
+ yield
+ self.current_node = prior
+
+ @contextlib.contextmanager
+ def swap_buffers(self, lb, cb=None, sb=None):
+ if cb is None:
+ cb = lb
+ loads = self.loads
+ compute = self.compute
+ stores = self.stores
+ cse = self.cse
+ self.loads = lb
+ self.compute = cb
+ self.stores = sb
+ self.cse = cse.clone()
+ yield
+ self.loads = loads
+ self.compute = compute
+ self.stores = stores
+ self.cse = cse
+
+ def load(self, name: str, index: sympy.Expr):
+ raise NotImplementedError()
+
+ def indirect_load(self, name: str, index: sympy.Expr):
+ """A load the depends on an index we have read"""
+ prior = self.loads
+ try:
+ # put the load in the compute section as it might have deps
+ self.loads = self.compute
+ return self.load(name, index)
+ finally:
+ self.loads = prior
+
+ def store(self, name, index, value, mode=None):
+ raise NotImplementedError()
+
+ def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
+ raise NotImplementedError()
+
+ def __enter__(self):
+ class CSEProxy:
+ @staticmethod
+ def __getattr__(name):
+ def inner(*args, **kwargs):
+ return self.cse.generate(
+ self.compute, getattr(parent_handler, name)(*args, **kwargs)
+ )
+
+ return inner
+
+ @staticmethod
+ def indirect_indexing(index_var):
+ return sympy.Symbol(str(index_var))
+
+ @staticmethod
+ def load(name: str, index: sympy.Expr):
+ if name in self.cse.invalidated_stores:
+ # A load from an invalidated store requires us to
+ # keep the actual buffer around
+ V.kernel.must_keep_buffers.add(name)
+ if free_symbol_startswith(index, "tmp"):
+ return self.indirect_load(name, index)
+ store_cache = self.cse.store_cache
+ if name in store_cache:
+ return store_cache[name]
+ return self.load(name, index)
+
+ @staticmethod
+ def store(name, index, value, mode=None):
+ self.store_buffer_names.add(name)
+ if mode is None:
+ self.cse.store_cache[name] = value
+ for other_name in self.current_node.get_mutations():
+ self.cse.store_cache[other_name] = value
+ if name not in V.graph.removed_buffers:
+ return self.store(name, index, value, mode=mode)
+
+ @staticmethod
+ def reduction(name, dtype, src_dtype, reduction_type, index, value):
+ self.store_buffer_names.add(name)
+ return self.reduction(
+ name, dtype, src_dtype, reduction_type, index, value
+ )
+
+ super().__enter__()
+ parent_handler = self.overrides(V.get_ops_handler())
+ self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
+ self.exit_stack.enter_context(V.set_kernel_handler(self))
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ V.graph.scheduler.remove_kernel_local_buffers()
+ super().__exit__(exc_type, exc_val, exc_tb)
+
+ def rename_indexing(self, index) -> sympy.Expr:
+ if isinstance(index, (list, tuple)):
+ return [self.rename_indexing(x) for x in index]
+ index = V.graph.sizevars.simplify(index)
+ sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
+ replacements = {
+ x: self.args.size(x) for x in sorted_symbols if x.name.startswith("s")
+ }
+ return sympy_subs(index, replacements)
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
new file mode 100644
index 0000000000000..788fe53ab378c
--- /dev/null
+++ b/torch/_inductor/codegen/cpp.py
@@ -0,0 +1,716 @@
+import contextlib
+import dataclasses
+import functools
+from pathlib import Path
+from typing import Dict, List
+
+import sympy
+
+import torch
+from torch._prims_common import is_float_dtype
+
+from .. import codecache, config
+from ..utils import sympy_product
+from ..virtualized import ops, V
+from .common import (
+ BracesBuffer,
+ DeferredIndentedBuffer,
+ ExprPrinter,
+ IndentedBuffer,
+ Kernel,
+ KernelArgs,
+ OpOverrides,
+)
+
+DTYPE_TO_CPP = {
+ torch.float32: "float",
+ torch.float64: "double",
+ torch.float16: "half",
+ torch.int64: "long",
+ torch.int32: "int",
+ torch.int16: "short",
+ torch.int8: "signed char",
+ torch.uint8: "unsigned char",
+ torch.bool: "bool",
+ torch.bfloat16: "bfloat16",
+}
+INDEX_TYPE = "long"
+
+RTYPE_TO_CPP = {
+ "sum": "+",
+ "min": "min",
+ "max": "max",
+ "argmin": "argmin",
+ "argmax": "argmax",
+ "any": "||",
+}
+
+
+def reduction_init(reduction_type, dtype):
+ if reduction_type in ("sum", "any"):
+ return 0
+ if reduction_type in {"max", "argmax"}:
+ return (
+ f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
+ if is_float_dtype(dtype)
+ else f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::min()"
+ )
+ if reduction_type in {"min", "argmin"}:
+ return (
+ f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
+ if is_float_dtype(dtype)
+ else f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::max()"
+ )
+ raise AssertionError(reduction_type)
+
+
+def reduction_combine(reduction_type, var, next_value):
+ if reduction_type == "sum":
+ return f"{var} += {next_value}"
+ if reduction_type == "any":
+ return f"{var} = {var} || {next_value}"
+ return f"{var} = std::{reduction_type}({var}, {next_value})"
+
+
+index_value_name_counter = 1
+
+
+def argmax_argmin_prefix(reduction_type, src_dtype, tmpvar):
+ global index_value_name_counter
+ struct_name = f"IndexValue_{index_value_name_counter}"
+ index_value_name_counter += 1
+
+ # A small annoyance, due to it being a little cumbersome to just throw {} into strings
+ prefix = [
+ f"struct {struct_name} {{size_t index; {DTYPE_TO_CPP[src_dtype]} value;}};",
+ f"{struct_name} {tmpvar}{{0, {reduction_init(reduction_type, src_dtype)}}};",
+ ]
+ if reduction_type == "argmax":
+ prefix.extend(
+ [
+ f"#pragma omp declare reduction(argmax : struct {struct_name} :\\",
+ " omp_out.value = omp_in.value < omp_out.value ? omp_out.value : omp_in.value,\\",
+ " omp_out.index = omp_in.value < omp_out.value ? omp_out.index : omp_in.index)\\",
+ f"\tinitializer(omp_priv = {{0, {reduction_init(reduction_type, src_dtype)}}})",
+ ]
+ )
+ elif reduction_type == "argmin":
+ prefix.extend(
+ [
+ f"#pragma omp declare reduction(argmin : struct {struct_name} :\\",
+ " omp_out.value = omp_in.value > omp_out.value ? omp_out.value : omp_in.value,\\",
+ " omp_out.index = omp_in.value > omp_out.value ? omp_out.index : omp_in.index)\\",
+ f"\tinitializer(omp_priv = {{0, {reduction_init(reduction_type, src_dtype)}}})",
+ ]
+ )
+ return prefix
+
+
+def float16_reduction_prefix(rtype):
+ # TODO: This user-defined reduction uses float16 accumulation for sum. To reduce numerical
+ # errors, float32 accumulation should be used instead.
+ assert rtype in (
+ "sum",
+ "any",
+ ), f"float16 user-defined reduction only supports 'sum' and 'any' but got {rtype}"
+ prefix = [
+ f"#pragma omp declare reduction({RTYPE_TO_CPP[rtype]}:{DTYPE_TO_CPP[torch.float16]}:"
+ + f"omp_out = omp_out {RTYPE_TO_CPP[rtype]} omp_in)"
+ ]
+ return prefix
+
+
+@functools.lru_cache()
+def cpp_prefix():
+ path = Path(__file__).parent / "cpp_prefix.h"
+ with path.open() as f:
+ _, filename = codecache.write(
+ f.read(),
+ "h",
+ )
+ return f'#include "{filename}"'
+
+
+class CppPrinter(ExprPrinter):
+ def _print_ModularIndexing(self, expr):
+ x, div, mod = expr.args
+ x = self.paren(self.doprint(x))
+ div = self.paren(self.doprint(div))
+ mod = self.paren(self.doprint(mod))
+ if div != "1":
+ x = f"({x} / {div})"
+ return f"{x} % {mod}"
+
+ def _print_IndexingDiv(self, expr):
+ x, div = expr.args
+ x = self.paren(self.doprint(x))
+ div = self.paren(self.doprint(div))
+ return f"({x} / {div})"
+
+
+cexpr = CppPrinter().doprint
+
+
+class CppOverrides(OpOverrides):
+ """Map element-wise ops to C++"""
+
+ @staticmethod
+ def to_dtype(x, dtype):
+ assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP"
+ return f"static_cast<{DTYPE_TO_CPP[dtype]}>({x})"
+
+ @staticmethod
+ def abs(x):
+ return f"std::abs({x})"
+
+ @staticmethod
+ def sin(x):
+ return f"std::sin({x})"
+
+ @staticmethod
+ def cos(x):
+ return f"std::cos({x})"
+
+ @staticmethod
+ def exp(x):
+ # return f"Sleef_expf_u10({x})"
+ return f"std::exp({x})"
+
+ @staticmethod
+ def sqrt(x):
+ return f"std::sqrt({x})"
+
+ @staticmethod
+ def rsqrt(x):
+ return f"1 / std::sqrt({x})"
+
+ @staticmethod
+ def signbit(x):
+ return f"std::signbit({x})"
+
+ @staticmethod
+ def pow(a, b):
+ return f"std::pow({a}, {b})"
+
+ @staticmethod
+ def log(x):
+ return f"std::log({x})"
+
+ @staticmethod
+ def round(x):
+ return f"std::nearbyint({x})"
+
+ @staticmethod
+ def floor(x):
+ return f"std::floor({x})"
+
+ @staticmethod
+ def floordiv(a, b):
+ # a and b are integer type
+ quot = f"{a} / {b}"
+ rem = f"{a} % {b}"
+ return f"(({a} < 0) != ({b} < 0) ? ({rem} != 0 ? {quot} - 1 : {quot}) : {quot})"
+
+ @staticmethod
+ def ceil(x):
+ return f"std::ceil({x})"
+
+ @staticmethod
+ def trunc(x):
+ return f"std::trunc({x})"
+
+ @staticmethod
+ def truncdiv(a, b):
+ # a and b are integer type
+ return f"{a} / {b}"
+
+ @staticmethod
+ def fmod(a, b):
+ return f"std::fmod({a}, {b})"
+
+ @staticmethod
+ def isinf(x):
+ return f"std::isinf({x})"
+
+ @staticmethod
+ def isnan(x):
+ return f"std::isnan({x})"
+
+ @staticmethod
+ def lgamma(x):
+ return f"std::lgamma({x})"
+
+ @staticmethod
+ def relu(x):
+ return f"{x} * ({x}>0)"
+
+ @staticmethod
+ def minimum(a, b):
+ return f"std::min({a}, {b})"
+
+ @staticmethod
+ def maximum(a, b):
+ return f"std::max({a}, {b})"
+
+ @staticmethod
+ def where(a, b, c):
+ return f"{a} ? {b} : {c}"
+
+ @staticmethod
+ def mod(a, b):
+ return f"mod({a}, {b})"
+
+ @staticmethod
+ def constant(val, dtype):
+ if val == float("inf"):
+ return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
+ elif val == float("-inf"):
+ return f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
+ elif val is True or val is False:
+ return ops.to_dtype(str(val).lower(), dtype)
+ return ops.to_dtype(repr(val), dtype)
+
+ @staticmethod
+ def index_expr(expr, dtype):
+ return ops.to_dtype(cexpr(V.kernel.rename_indexing(expr)), dtype)
+
+ @staticmethod
+ def masked(mask, body, other):
+ code = BracesBuffer()
+ var = V.kernel.cse.newvar()
+ if other == float("-inf"):
+ code.writeline(f"float {var} = -std::numeric_limits::infinity();")
+ elif other == float("inf"):
+ code.writeline(f"float {var} = std::numeric_limits::infinity();")
+ else:
+ code.writeline(f"auto {var} = {other!r};")
+ code.writeline(f"if({mask})")
+ with V.kernel.swap_buffers(code), code.indent():
+ result = body()
+ code.writeline(f"{var} = {result};")
+ V.kernel.compute.splice(code)
+ return var
+
+ @staticmethod
+ def logical_and(a, b):
+ return f"{a} && {b}"
+
+ @staticmethod
+ def logical_or(a, b):
+ return f"{a} || {b}"
+
+ @staticmethod
+ def rand(seed: sympy.Expr, offset: sympy.Expr, dtype):
+ return f"static_cast<{DTYPE_TO_CPP[dtype]}>(normalized_rand_cpu({seed}, {offset}));"
+
+ @staticmethod
+ def randn(seed: sympy.Expr, offset: sympy.Expr, dtype):
+ return f"static_cast<{DTYPE_TO_CPP[dtype]}>(randn_cpu({seed}, {offset}));"
+
+
+class CppKernel(Kernel):
+ overrides = CppOverrides
+ sexpr = cexpr
+ newvar_prefix = "auto "
+ suffix = ";"
+
+ def __init__(self, args, num_threads):
+ super(CppKernel, self).__init__(args)
+ self.call_ranges = None
+ self.ranges = None
+ self.itervars = None
+ self.reduction_depth = None
+ self.reduction_prefix = IndentedBuffer()
+ self.reduction_suffix = DeferredIndentedBuffer()
+ self.reduction_vars = {}
+ self.num_threads = num_threads # num_threads the kernel specialized for
+
+ def load(self, name: str, index: sympy.Expr):
+ var = self.args.input(name)
+ index = self.rename_indexing(index)
+ line = f"{var}[{cexpr(index)}]"
+ if V.graph.get_dtype(name) in (torch.float16, torch.bfloat16):
+ line = f"static_cast({line})"
+ return self.cse.generate(self.loads, line)
+
+ def store(self, name, index, value, mode=None):
+ assert "buf" in name
+ var = self.args.output(name)
+ index = self.rename_indexing(index)
+ if mode is None:
+ line = f"{var}[{cexpr(index)}] = {value};"
+ elif mode == "atomic_add":
+ if not config.cpp.dynamic_threads and self.num_threads == 1:
+ line = f"{var}[{cexpr(index)}] += {value};"
+ else:
+ line = f"atomic_add(&{var}[{cexpr(index)}], {value});"
+ else:
+ raise NotImplementedError(f"store mode={mode}")
+ self.stores.writeline(name, line)
+
+ def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
+ argmax_or_argmin = reduction_type in {"argmax", "argmin"}
+ tmpvar = self.cse.generate(
+ self.loads, f"reduction {name} {cexpr(index)}", write=False
+ )
+ index = self.rename_indexing(index)
+ self.reduction_vars[tmpvar] = reduction_type
+ if argmax_or_argmin:
+ self.reduction_prefix.writelines(
+ argmax_argmin_prefix(reduction_type, src_dtype, tmpvar)
+ )
+ compare_op = "<" if reduction_type == "argmax" else ">"
+ self.stores.writelines(
+ None,
+ [
+ f"if ({tmpvar}.value {compare_op} {value}) {{",
+ f" {tmpvar}.index = {self.itervars[-1]}; {tmpvar}.value = {value};",
+ "}",
+ ],
+ )
+ else:
+ if dtype == torch.float16:
+ self.reduction_prefix.writelines(
+ float16_reduction_prefix(reduction_type)
+ )
+ self.reduction_prefix.writeline(
+ f"{DTYPE_TO_CPP[dtype]} {tmpvar} = {reduction_init(reduction_type, dtype)};"
+ )
+ self.stores.writeline(
+ None, f"{reduction_combine(reduction_type, tmpvar, value)};"
+ )
+
+ if name not in V.graph.removed_buffers:
+ var = self.args.output(name)
+ member_name = ".index" if argmax_or_argmin else ""
+ self.reduction_suffix.writeline(
+ name, f"{var}[{cexpr(index)}] = {tmpvar}{member_name};"
+ )
+ self.cse.store_cache[name] = tmpvar
+
+ def set_ranges(self, lengths, reduction_lengths):
+ if self.call_ranges:
+ assert self.call_ranges == tuple(lengths) + tuple(
+ reduction_lengths
+ ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}"
+ assert self.reduction_depth == len(lengths)
+ else:
+ self.call_ranges = tuple(lengths) + tuple(reduction_lengths)
+ self.ranges = [self.rename_indexing(x) for x in self.call_ranges]
+ self.itervars = [sympy.Symbol(f"i{n}") for n in range(len(self.ranges))]
+ self.reduction_depth = len(lengths)
+ return (
+ self.itervars[: self.reduction_depth],
+ self.itervars[self.reduction_depth :],
+ )
+
+ def size_hint(self):
+ return V.graph.sizevars.size_hint(sympy_product(self.call_ranges))
+
+ def codegen_loops(self, code, worksharing):
+ threads = config.cpp.threads
+ if threads < 1:
+ threads = torch.get_num_threads()
+
+ loops = [LoopLevel(var, size) for var, size in zip(self.itervars, self.ranges)]
+ loops, reductions = LoopNest(loops[: self.reduction_depth]), LoopNest(
+ loops[self.reduction_depth :]
+ )
+ reductions.mark_reduction(self.reduction_vars)
+
+ if config.cpp.simdlen:
+ # TODO(jansel): detect stride-1 dimension and vectorize that
+ if reductions:
+ reductions.loops[-1].simd = True
+ else:
+ loops.loops[-1].simd = True
+
+ par_depth = 0
+ reduction_par_depth = 0
+ if loops:
+ par_depth = self.decide_parallel_depth(
+ self.call_ranges[: self.reduction_depth], threads
+ )
+ else:
+ reduction_par_depth = self.decide_parallel_depth(
+ self.call_ranges[self.reduction_depth :], threads
+ )
+
+ with contextlib.ExitStack() as stack:
+ if par_depth:
+ worksharing.parallel(threads)
+ loops.mark_parallel(par_depth)
+ elif reduction_par_depth:
+ # need to close the worksharing scope to define reduction vars outside it
+ worksharing.close()
+ reductions.mark_parallel(reduction_par_depth)
+ elif threads > 1:
+ if worksharing.single():
+ stack.enter_context(code.indent())
+
+ loops.codegen(code, stack)
+
+ with contextlib.ExitStack() as stack_outer:
+ if self.reduction_prefix:
+ stack_outer.enter_context(code.indent())
+ code.splice(self.reduction_prefix)
+
+ if reduction_par_depth:
+ worksharing.parallel(threads)
+
+ with contextlib.ExitStack() as stack:
+ reductions.codegen(code, stack)
+ code.splice(self.loads)
+ code.splice(self.compute)
+ code.splice(self.stores)
+
+ if reduction_par_depth:
+ worksharing.close()
+
+ code.splice(self.reduction_suffix)
+
+ def decide_parallel_depth(self, ranges, threads):
+ seq = self.size_hint()
+ par = 1
+ depth = 0
+ for expr in ranges:
+ hint = V.graph.sizevars.size_hint(expr)
+ if par >= 2 * threads or par == threads:
+ break
+ if seq // threads < config.cpp.min_chunk_size:
+ # not enough work
+ break
+ depth += 1
+ par *= hint
+ seq /= hint
+ # if we assume thread number is dynamic, make sure we
+ # have at least one parallel scope and let OMP runtime
+ # to manage the serial vs. parallel.
+ if config.cpp.dynamic_threads and depth == 0 and len(ranges) > 0:
+ depth = 1
+ return depth
+
+ @contextlib.contextmanager
+ def write_to_suffix(self):
+ prior = (self.loads, self.compute, self.stores, self.cse)
+ self.loads = IndentedBuffer()
+ self.compute = IndentedBuffer()
+ self.stores = DeferredIndentedBuffer()
+ self.cse = self.cse.clone()
+ yield
+ self.reduction_suffix.splice(self.loads)
+ self.reduction_suffix.splice(self.compute)
+ self.reduction_suffix.splice(self.stores)
+ (self.loads, self.compute, self.stores, self.cse) = prior
+
+
+class CppScheduling:
+ def __init__(self, scheduler):
+ self.scheduler = scheduler
+ self.kernel_group = KernelGroup()
+
+ def group_fn(self, sizes):
+ return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes)
+
+ @staticmethod
+ def can_fuse_horizontal(node1, node2):
+ _, (vars1, reduce1) = node1.group
+ _, (vars2, reduce2) = node2.group
+ if vars1 == vars2 and reduce1 == reduce2:
+ return True
+ if reduce1 == () and vars1 == vars2 + reduce2:
+ return True
+ # TODO(jansel): allow fusion pointwise (vars1, ()) suffix?
+ return False
+
+ @classmethod
+ def can_fuse_vertical(cls, node1, node2):
+ return cls.can_fuse_horizontal(node1, node2) and not node1.is_reduction()
+
+ def codegen_nodes(self, nodes):
+ """
+ Turn an set of pre-fused nodes into a C++ kernel.
+ """
+ kernel_group = self.kernel_group
+ scheduler = self.scheduler
+ _, (group, reduction_group) = max(
+ nodes, key=lambda x: int(x.is_reduction())
+ ).group
+ in_suffix = False
+
+ with kernel_group.new_kernel() as kernel:
+ vars, reduction_vars = kernel.set_ranges(group, reduction_group)
+
+ for node in nodes:
+ if node.group[1] in [
+ (group, reduction_group),
+ (group + reduction_group, ()),
+ ]:
+ assert not in_suffix
+ node.run(vars, reduction_vars)
+ else:
+ in_suffix = True
+ assert node.group[1] == (
+ group,
+ (),
+ ), f"unexpected group: {node.group[1]} != {group}, {reduction_group}"
+ # we can fuse in some extra pointwise into the suffix
+ with kernel.write_to_suffix():
+ node.run(vars, ())
+
+ kernel_group.finalize_kernel(kernel, scheduler)
+
+ def flush(self):
+ self.kernel_group.codegen_define_and_call(V.graph.wrapper_code)
+ self.kernel_group = KernelGroup()
+
+
+class KernelGroup:
+ def __init__(self):
+ super().__init__()
+ self.args = KernelArgs()
+ self.loops_code = BracesBuffer()
+ self.ws = WorkSharing(self.loops_code)
+ self.stack = contextlib.ExitStack()
+ self.stack.enter_context(self.ws)
+ self.count = 0
+
+ def new_kernel(self):
+ return CppKernel(self.args, self.ws.num_threads)
+
+ def finalize_kernel(self, new_kernel, scheduler):
+ self.count += 1
+ code = self.loops_code
+ ws = self.ws
+ new_kernel.codegen_loops(code, ws)
+
+ def codegen_define_and_call(self, wrapper):
+ self.stack.close()
+ if self.count == 0:
+ return
+
+ arg_defs, call_args = self.args.cpp_argdefs()
+ arg_defs = ",\n".ljust(25).join(arg_defs)
+ code = BracesBuffer()
+ code.writelines([cpp_prefix(), "" f'extern "C" void kernel({arg_defs})'])
+ with code.indent():
+ for old, new in self.args.aliases():
+ code.writeline(f"auto {old} = {new};")
+ code.splice(self.loops_code)
+
+ codecache_def = IndentedBuffer()
+ codecache_def.writeline("async_compile.cpp('''")
+ codecache_def.splice(code)
+ codecache_def.writeline("''')")
+
+ kernel_name = wrapper.next_kernel_name()
+ codecache_str = codecache_def.getvalue()
+ # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
+ # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
+ codecache_str = codecache_str.replace("#pragma CMT", "//")
+ wrapper.define_kernel(kernel_name, codecache_str)
+
+ # generate the code to call this
+ wrapper.writeline(
+ "{}({})".format(kernel_name, ", ".join(call_args)),
+ )
+
+
+class WorkSharing:
+ def __init__(self, code):
+ self.code = code
+ self.in_parallel = False
+ self.num_threads = None
+ self.stack = contextlib.ExitStack()
+
+ def parallel(self, threads):
+ if self.in_parallel and threads != self.num_threads:
+ # wrong number of threads
+ self.close()
+ if not self.in_parallel:
+ self.num_threads = threads
+ self.in_parallel = True
+ if config.cpp.dynamic_threads:
+ self.code.writeline("#pragma omp parallel")
+ else:
+ self.code.writeline(f"#pragma omp parallel num_threads({threads})")
+ self.stack.enter_context(self.code.indent())
+
+ def single(self):
+ if self.in_parallel:
+ self.code.writeline("#pragma omp single")
+ return self.in_parallel
+
+ def close(self):
+ self.stack.close()
+ self.in_parallel = False
+
+ def __enter__(self):
+ self.stack.__enter__()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.stack.__exit__(exc_type, exc_val, exc_tb)
+
+
+@dataclasses.dataclass
+class LoopLevel:
+ var: sympy.Expr
+ size: sympy.Expr
+ parallel: int = 0
+ simd: bool = False
+ collapsed: bool = False
+ reduction_vars: Dict[str, str] = None
+
+ def lines(self):
+ if self.reduction_vars:
+ reduction = " " + " ".join(
+ f"reduction({RTYPE_TO_CPP[rtype]}:{var})"
+ for var, rtype in self.reduction_vars.items()
+ )
+ else:
+ reduction = ""
+ simd = f"simd simdlen({config.cpp.simdlen})"
+ if self.parallel:
+ # TODO(jansel): look into chunk size and other schedules
+ line1 = f"#pragma omp for{reduction} "
+ if self.parallel > 1:
+ line1 += f" collapse({self.parallel})"
+ if self.simd:
+ line1 = line1.replace(" for ", f" for {simd}")
+ elif self.simd:
+ line1 = f"#pragma omp {simd}{reduction}"
+ elif not self.reduction_vars and codecache.is_gcc():
+ line1 = "#pragma GCC ivdep"
+ else:
+ line1 = ""
+ line2 = f"for({INDEX_TYPE} {self.var}=0; {self.var}<{cexpr(self.size)}; ++{self.var})"
+ if self.collapsed or not line1:
+ return [line2]
+ return [line1, line2]
+
+
+@dataclasses.dataclass
+class LoopNest:
+ loops: List[LoopLevel]
+
+ def __bool__(self):
+ return bool(self.loops)
+
+ def mark_reduction(self, reduction_vars):
+ for loop in self.loops:
+ loop.reduction_vars = reduction_vars
+
+ def mark_parallel(self, par_depth):
+ loops = self.loops
+ loops[0].parallel = par_depth
+ for i in range(1, par_depth):
+ loops[i].collapsed = True
+ loops[0].simd = loops[par_depth - 1].simd
+
+ def codegen(self, code, stack):
+ for loop in self.loops:
+ code.writelines(loop.lines())
+ stack.enter_context(code.indent())
+ else:
+ stack.enter_context(code.indent())
diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h
new file mode 100644
index 0000000000000..d9c0a99f5f42c
--- /dev/null
+++ b/torch/_inductor/codegen/cpp_prefix.h
@@ -0,0 +1,55 @@
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "ATen/core/PhiloxRNGEngine.h"
+#include
+#include
+
+typedef at::Half half;
+typedef at::BFloat16 bfloat16;
+
+template inline T mod(T a, T b) { return a % b; }
+template <> inline float mod(float a, float b) { return std::fmod(a, b); }
+template <> inline double mod(double a, double b) { return std::fmod(a, b); }
+
+constexpr float uint32_to_uniform_float(uint32_t value) {
+ // maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
+ constexpr float scale = 4.6566127342e-10;
+ return static_cast(value & 0x7FFFFFFF) * scale;
+}
+
+float normalized_rand_cpu(uint32_t seed, uint32_t offset) {
+ return uint32_to_uniform_float(at::Philox4_32(seed, 0, offset)());
+}
+
+float randn_cpu(uint32_t seed, uint32_t offset) {
+ at::Philox4_32 engine(seed, 0, offset);
+ return engine.randn(10);
+}
+
+template struct AsIntegerType { typedef T type; };
+template <> struct AsIntegerType { typedef uint32_t type; };
+template <> struct AsIntegerType { typedef uint64_t type; };
+
+template void atomic_add(volatile T *addr, T offset) {
+ typedef typename AsIntegerType::type alt_type;
+
+ static_assert(sizeof(std::atomic) == sizeof(T),
+ "std::atomic issue");
+
+ alt_type expected;
+
+ alt_type desired;
+
+ std::atomic *atomic_addr = (std::atomic *)addr;
+ do {
+ T val = *addr;
+ reinterpret_cast(&expected)[0] = val;
+ reinterpret_cast(&desired)[0] = val + offset;
+ } while (!atomic_addr->compare_exchange_weak(expected, desired,
+ std::memory_order_relaxed));
+}
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
new file mode 100644
index 0000000000000..062ca366a2894
--- /dev/null
+++ b/torch/_inductor/codegen/triton.py
@@ -0,0 +1,1399 @@
+import collections
+import contextlib
+import dataclasses
+import functools
+import itertools
+import logging
+import math
+import operator
+from typing import Dict, List
+
+import sympy
+
+import torch
+
+from .. import config, ir, scheduler
+from ..ir import ReductionHint
+from ..utils import (
+ dynamo_logging,
+ free_symbol_startswith,
+ instance_descriptor,
+ sympy_product,
+ sympy_subs,
+)
+from ..virtualized import ops, V
+from .common import (
+ DeferredLine,
+ ExprPrinter,
+ IndentedBuffer,
+ index_prevent_reordering,
+ Kernel,
+ OpOverrides,
+ SizeArg,
+ TensorArg,
+)
+
+log = logging.getLogger(__name__)
+
+
+def signature_of(arg):
+ from triton.runtime.jit import JITFunction
+
+ if isinstance(arg, TensorArg):
+ return JITFunction._type_of(arg.dtype)
+ if isinstance(arg, SizeArg):
+ return JITFunction._key_of(V.graph.sizevars.size_hint(arg.expr))
+ raise NotImplementedError(f"unhandled {type(arg)}: {arg}")
+
+
+def config_of(args):
+ from ..compile_fx import ALIGNMENT
+
+ def is_aligned(x):
+ if isinstance(x, TensorArg):
+ return x.buffer not in V.graph.unaligned_buffers
+ assert isinstance(x, SizeArg)
+ return V.graph.sizevars.maybe_guard_multiple_of(x.expr, ALIGNMENT)
+
+ divisible_by_16 = [i for i, arg in enumerate(args) if is_aligned(arg)]
+ return instance_descriptor(tuple(divisible_by_16), ())
+
+
+class TritonPrinter(ExprPrinter):
+ def _print_ModularIndexing(self, expr):
+ x, div, mod = expr.args
+ x = self.paren(self.doprint(x))
+ div = self.paren(self.doprint(div))
+ mod = self.paren(self.doprint(mod))
+ if div != "1":
+ x = f"({x} // {div})"
+ return f"{x} % {mod}"
+
+ def _print_IndexingDiv(self, expr):
+ x, div = expr.args
+ x = self.paren(self.doprint(x))
+ div = self.paren(self.doprint(div))
+ return f"({x} // {div})"
+
+
+texpr = TritonPrinter().doprint
+
+
+def triton_compute_type(dtype):
+ triton_type_name = str(dtype).split(".")[-1]
+ if triton_type_name == "bool":
+ triton_type_name = "int1"
+ if triton_type_name in ("float16", "bfloat16"):
+ # float16 math is done in float32 inside the kernel
+ triton_type_name = "float32"
+ return f"tl.{triton_type_name}"
+
+
+def triton_constant(value):
+ if value == float("inf"):
+ return 'float("inf")'
+ elif value == float("-inf"):
+ return 'float("-inf")'
+ elif math.isnan(value):
+ return 'float("nan")'
+ return repr(value)
+
+
+class TritonOverrides(OpOverrides):
+ """Map element-wise ops to Triton"""
+
+ @staticmethod
+ def to_dtype(x, dtype: torch.dtype):
+ if dtype == torch.bool:
+ return f"({x} != 0)"
+ return f"{x}.to({triton_compute_type(dtype)})"
+
+ @staticmethod
+ def constant(value, dtype):
+ return triton_constant(value)
+
+ @staticmethod
+ def abs(x):
+ return f"tl.libdevice.abs({x}) if ({x}).dtype is tl.float64 else tl.abs({x})"
+
+ @staticmethod
+ def exp(x):
+ return f"tl.libdevice.exp({x}) if ({x}).dtype is tl.float64 else tl.exp({x})"
+
+ @staticmethod
+ def sqrt(x):
+ return f"tl.libdevice.sqrt({x}) if ({x}).dtype is tl.float64 else tl.sqrt({x})"
+
+ @staticmethod
+ def relu(x):
+ return ops.maximum("0", x)
+
+ @staticmethod
+ def minimum(a, b):
+ return f"tl.minimum({a}, {b})"
+
+ @staticmethod
+ def maximum(a, b):
+ return f"tl.maximum({a}, {b})"
+
+ @staticmethod
+ def where(a, b, c):
+ if not config.triton.simple_where:
+ # wonkyness to work around https://github.com/openai/triton/issues/532
+ # identity calls to force new triton variables (and get access to .shape/.dtype/.numel
+ a = ops.identity(a)
+ b = ops.identity(b)
+ c = ops.identity(c)
+ a = ops.identity(
+ f"{a} | tl.zeros({b}.shape, {a}.dtype) if {b}.numel > 1 else {a}"
+ )
+ a = ops.identity(
+ f"{a} | tl.zeros({c}.shape, {a}.dtype) if {c}.numel > 1 else {a}"
+ )
+ return f"tl.where({a}, {b}, {c})"
+
+ @staticmethod
+ def cos(x):
+ return f"tl.libdevice.cos({x}) if ({x}).dtype is tl.float64 else tl.cos({x})"
+
+ @staticmethod
+ def sin(x):
+ return f"tl.libdevice.sin({x}) if ({x}).dtype is tl.float64 else tl.sin({x})"
+
+ @staticmethod
+ def index_expr(expr, dtype):
+ return V.kernel.indexing(expr)[0]
+
+ @staticmethod
+ def masked(mask, body, other):
+ with V.kernel.mask_loads(mask) as new_mask:
+ result = body()
+ return ops.where(
+ new_mask, result, TritonOverrides.constant(other, torch.float32)
+ )
+
+ @staticmethod
+ def lgamma(x):
+ return f"tl.libdevice.lgamma({x})"
+
+ @staticmethod
+ def logical_and(a, b):
+ return f"{a} & {b}"
+
+ @staticmethod
+ def logical_or(a, b):
+ return f"{a} | {b}"
+
+ @staticmethod
+ def rand(seed, offset, _): # _ here to keep the contract identical to CPU rand op
+ return f"tl.rand({seed}, {offset})"
+
+ @staticmethod
+ def randn(seed, offset, _): # _ here to keep the contract identical to CPU randn op
+ return f"tl.randn({seed}, {offset})"
+
+ @staticmethod
+ def rsqrt(x):
+ return f"tl.libdevice.rsqrt({x})"
+
+ @staticmethod
+ def signbit(x):
+ # XX: This is wrong for the value -0.0 in floating point
+ return f"tl.libdevice.signbitf({x}) if ({x}).dtype is tl.float32 else {x} < 0"
+
+ @staticmethod
+ def fmod(a, b):
+ return f"tl.libdevice.fmod({a}, ({b}).to(tl.float32))"
+
+ @staticmethod
+ def pow(a, b):
+ return f"tl.libdevice.pow({a}, {b})"
+
+ @staticmethod
+ def log(x):
+ return f"tl.libdevice.log({x}) if ({x}).dtype is tl.float64 else tl.log({x})"
+
+ @staticmethod
+ def isinf(x):
+ return f"tl.libdevice.isinfd({x}) if ({x}).dtype is tl.float64 else tl.libdevice.isinff({x})"
+
+ @staticmethod
+ def isnan(x):
+ return f"tl.libdevice.isnand({x}) if ({x}).dtype is tl.float64 else tl.libdevice.isnanf({x})"
+
+ @staticmethod
+ def round(x):
+ return f"tl.libdevice.nearbyint({x})"
+
+ @staticmethod
+ def floor(x):
+ return f"tl.libdevice.floor({x})"
+
+ @staticmethod
+ def floordiv(a, b):
+ # See the comment in lowering.div_mode. a and b are integer type.
+ # Similar to div_floor_kernel_cuda in pytorch core.
+ # Notice that // in triton behaves as truncdiv instead of floordiv
+ quot = f"{a} // {b}"
+ rem = f"{a} % {b}"
+ return f"tl.where(({a} < 0) != ({b} < 0), tl.where({rem} != 0, {quot} - 1, {quot}), {quot})"
+
+ @staticmethod
+ def trunc(x):
+ return f"tl.libdevice.trunc({x})"
+
+ @staticmethod
+ def truncdiv(a, b):
+ # See the comment in lowering.div_mode. a and b are integer type.
+ # Notice that // in triton behaves as truncdiv instead of floordiv
+ return f"{a} // {b}"
+
+ @staticmethod
+ def ceil(x):
+ return f"tl.libdevice.ceil({x})"
+
+
+@dataclasses.dataclass
+class IterationRanges:
+ """
+ Each range tree represents multiple sets of iteration indexing
+ in a single tiled dimension in the output kernel.
+
+ If you have two loops ranges one (4, 3, 2) and another (4, 6),
+ then the range tree will be:
+ 4 (i0)
+ 3 (i1) 6 (i3)
+ 2 (i2)
+ Where i0 is shared between both loops, but then the split into
+ different indexing vars. All loop ranges must iterate over
+ the same number of elements.
+ """
+
+ def __init__(
+ self,
+ name: str,
+ var_list: List[sympy.Symbol],
+ var_ranges: Dict[sympy.Symbol, sympy.Expr],
+ numel: sympy.Expr,
+ prefix: str,
+ divisor=sympy.Integer(1),
+ length=sympy.Integer(1),
+ ):
+ super(IterationRanges, self).__init__()
+ self.name = name
+ self.var_list = var_list
+ self.var_ranges = var_ranges
+ self.numel = numel
+ self.prefix = prefix
+ self.divisor = divisor
+ self.length = length
+
+ def is_loop(self):
+ return self.prefix == "r"
+
+
+class IterationRangesRoot(IterationRanges):
+ def __init__(
+ self,
+ name: str,
+ numel: sympy.Expr,
+ prefix: str,
+ index: int,
+ kernel: "Kernel",
+ pid_cache=None,
+ ):
+ if pid_cache is None:
+ pid_cache = {}
+ super(IterationRangesRoot, self).__init__(
+ name=name,
+ var_list=[],
+ var_ranges={},
+ numel=numel,
+ prefix=prefix,
+ )
+ self.index = index
+ self.kernel = kernel
+ # Store all the nodes in one flat list
+ self.nodes: Dict[sympy.Expr, IterationRangesEntry] = {}
+ # This is for re-ordering program ID in triton mm template
+ # pid_cache["tl.program_id(0)"] = pid_m
+ self.pid_cache: Dict[str, str] = pid_cache
+
+ def cache_clear(self):
+ for node in self.nodes.values():
+ node.cache_clear()
+
+ def lookup(self, divisor, length):
+ """
+ Lookup a given RangeTreeEntry, creating it if needed
+ """
+ if V.graph.sizevars.maybe_guard_equals(divisor * length, self.numel):
+ expr = ir.IndexingDiv(sympy.Symbol(f"{self.prefix}index"), divisor)
+ else:
+ expr = ir.ModularIndexing(
+ sympy.Symbol(f"{self.prefix}index"), divisor, length
+ )
+
+ if expr not in self.nodes:
+ node = IterationRangesEntry(
+ f"{self.prefix}{next(V.kernel.iter_vars_count)}",
+ divisor,
+ length,
+ expr,
+ self,
+ )
+ V.kernel.range_tree_nodes[node.symbol()] = node
+ self.var_list.append(node.symbol())
+ self.var_ranges[node.symbol()] = length
+ self.nodes[expr] = node
+ return self.nodes[expr]
+
+ def construct(self, lengths: List[sympy.Expr]):
+ divisor = sympy.Integer(1)
+ itervars = []
+ for length in reversed(lengths):
+ itervars.append(self.lookup(divisor, length).symbol())
+ divisor = divisor * length
+ return list(reversed(itervars))
+
+ def vars_and_sizes(self, index: sympy.Expr):
+ """Figure out vars from this tree used in index"""
+ nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols]
+ nodes = [n for n in nodes if n and n.prefix == self.prefix]
+ nodes.sort(key=lambda x: V.graph.sizevars.size_hint(x.divisor))
+ divisor = sympy.Integer(1)
+ index_vars = []
+ sizes = []
+
+ def add(node):
+ nonlocal divisor
+ index_vars.append(node.symbol())
+ sizes.append(node.length)
+ divisor = divisor * node.length
+
+ for node in nodes:
+ if not V.graph.sizevars.maybe_guard_equals(node.divisor, divisor):
+ # fill in unused index var
+ add(self.lookup(divisor, ir.IndexingDiv(node.divisor, divisor)))
+ divisor = node.divisor
+ add(node)
+ if not V.graph.sizevars.maybe_guard_equals(self.numel, divisor):
+ # fill in unused index var
+ add(self.lookup(divisor, ir.IndexingDiv(self.numel, divisor)))
+
+ return list(reversed(index_vars)), list(reversed(sizes))
+
+ def ranges_code(self):
+ size = self.kernel.reshape_size_str(self.index, self.prefix)
+ return f"tl.reshape(tl.arange(0, {self.prefix.upper()}BLOCK), {size})"
+
+ def pid_cache_lookup(self, key):
+ if key in self.pid_cache:
+ return self.pid_cache[key]
+ return key
+
+ def codegen_header(self, code):
+ x = self.prefix
+ if self.is_loop():
+ code.writeline(f"{self.name} = {x}offset + {x}base")
+ else:
+ pid = self.pid_cache_lookup(f"tl.program_id({self.index})")
+ code.writelines(
+ [
+ f"{x}offset = {pid} * {x.upper()}BLOCK",
+ f"{self.name} = {x}offset + {self.ranges_code()}",
+ ]
+ )
+ code.writeline(f"{x}mask = {self.name} < {x}numel")
+
+
+class IterationRangesEntry(IterationRanges):
+ def __init__(
+ self,
+ name: str,
+ divisor: sympy.Expr,
+ length: sympy.Expr,
+ expr: sympy.Expr,
+ parent: IterationRanges,
+ ):
+ super(IterationRangesEntry, self).__init__(
+ name=name,
+ numel=parent.numel / length,
+ var_list=parent.var_list,
+ var_ranges=parent.var_ranges,
+ prefix=parent.prefix,
+ divisor=divisor,
+ length=length,
+ )
+ self.parent = parent
+ self.codegen = functools.lru_cache(None)(self._codegen)
+ self.expr = expr
+
+ def cache_clear(self):
+ self.codegen.cache_clear()
+
+ def writeline(self, line):
+ if self.is_loop():
+ V.kernel.indexing_code.writeline(line)
+ else:
+ # lift non-reduction stores outside loop
+ V.kernel.body.writeline(line)
+
+ def _codegen(self):
+ self.writeline(f"{self.name} = " + texpr(V.kernel.rename_indexing(self.expr)))
+ return self.name
+
+ def symbol(self):
+ return sympy.Symbol(self.name)
+
+ def __hash__(self):
+ return hash(self.name)
+
+ def __eq__(self, other):
+ return self.name == other.name
+
+
+class TritonKernel(Kernel):
+ overrides = TritonOverrides
+ sexpr = texpr
+
+ def __init__(self, *groups, pid_cache=None, reduction_hint=ReductionHint.DEFAULT):
+ if pid_cache is None:
+ pid_cache = {}
+ super(TritonKernel, self).__init__()
+ self.numels = [V.graph.sizevars.simplify(s) for s in groups]
+ self.range_trees = []
+ self.range_tree_nodes = {}
+ self.iter_vars_count = itertools.count()
+ self.inside_reduction = self.numels[-1] != 1
+ self._load_mask = None
+ self.body = IndentedBuffer()
+ self.indexing_code = IndentedBuffer()
+ self.suffix = IndentedBuffer()
+ self.outside_loop_vars = set()
+ self.initialize_range_tree(pid_cache)
+ self.reduction_hint = reduction_hint
+
+ # define this in a closure to make cache local to object
+ @functools.lru_cache(None)
+ def simplify_indexing(index: sympy.Expr):
+ index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges())
+ for tree in self.range_trees:
+ index = self.combine_contiguous_dims(index, tree)
+ return index
+
+ self.simplify_indexing = simplify_indexing
+
+ def initialize_range_tree(self, pid_cache):
+ names = ["xindex", "yindex", "zindex"][: len(self.numels) - 1] + ["rindex"]
+ for i in range(len(self.numels)):
+ self.range_trees.append(
+ IterationRangesRoot(
+ names[i], self.numels[i], names[i][0], i, self, pid_cache
+ )
+ )
+ for tree in self.range_trees:
+ # reduction indexing goes inside a loop
+ if tree.prefix != "r":
+ tree.codegen_header(self.body)
+ if self.inside_reduction and self.range_trees[-1].is_loop():
+ # workaround for this issue:
+ # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7
+ self.body.writeline(f"rbase = {self.range_trees[-1].ranges_code()}")
+
+ def disable_reduction(self):
+ @contextlib.contextmanager
+ def ctx():
+ if self.numels[-1] == 1:
+ assert not self.inside_reduction
+ yield
+ return
+ # calling codegen_body() will flush all the pending buffers
+ # and write out a reduction loop
+ self.codegen_body()
+ self.inside_reduction = False
+ yield
+ # flush out any code before opening the next loop
+ self.codegen_body()
+ self.inside_reduction = True
+
+ return ctx()
+
+ def set_ranges(self, *lengths):
+ assert len(lengths) == len(self.range_trees)
+ return [
+ ranges.construct(length)
+ for length, ranges in zip(lengths, self.range_trees)
+ ]
+
+ @staticmethod
+ def _split_iteration_ranges(
+ groups: List[sympy.Expr], lengths: List[List[sympy.Expr]]
+ ):
+ sv = V.graph.sizevars
+ new_ranges = [[] for _ in groups]
+ remaining = [sv.simplify(g) for g in groups]
+ var_count = itertools.count()
+
+ def add_range(i, expr):
+ expr = sv.simplify(expr)
+ if not sv.maybe_guard_multiple_of(remaining[i], expr):
+ raise CantSplit()
+ # guard on the last item out
+ sv.maybe_guard_equals(remaining[i], expr)
+ remaining[i] = ir.IndexingDiv(remaining[i], expr)
+ new_ranges[i].append(expr)
+ return next(var_count)
+
+ def make_combined(size, idx1, idx2):
+ def getter(flat_vars):
+ return size * flat_vars[idx1] + flat_vars[idx2]
+
+ return getter
+
+ return_getters_groups = []
+ current_group = 0
+ for length_group in lengths:
+ return_getters = []
+ for size in length_group:
+ if sv.maybe_guard_equals(size, 1):
+ return_getters.append(lambda _: sympy.Integer(0))
+ continue
+
+ while (
+ current_group < len(remaining)
+ and sv.size_hint(remaining[current_group]) == 1
+ ):
+ # scroll to next group with remaining elements
+ current_group += 1
+
+ if sv.size_hint(size) > sv.size_hint(remaining[current_group]):
+ # need to break size in two
+ if not sv.maybe_guard_multiple_of(size, remaining[current_group]):
+ raise CantSplit()
+ size1 = remaining[current_group]
+ size2 = ir.IndexingDiv(size, remaining[current_group])
+ return_getters.append(
+ make_combined(
+ size2,
+ add_range(current_group, size1),
+ add_range(current_group + 1, size2),
+ )
+ )
+ else:
+ return_getters.append(
+ operator.itemgetter(add_range(current_group, size))
+ )
+ return_getters_groups.append(return_getters)
+
+ assert all(
+ V.graph.sizevars.size_hint(s) == 1 for s in remaining
+ ), f"failed to set ranges {remaining} {lengths}"
+
+ return new_ranges, return_getters_groups
+
+ @classmethod
+ def is_compatible(cls, groups: List[sympy.Expr], lengths: List[List[sympy.Expr]]):
+ try:
+ cls._split_iteration_ranges(groups, lengths)
+ return True
+ except CantSplit:
+ return False
+
+ def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]):
+ """
+ We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1).
+
+ To do this we need to split up the iteration space of i0 into something like:
+ for i1 in s0:
+ for i2 in s1:
+ i0 = i1*s1 + i2
+ ....
+
+ This function matches and resplits lengths to the groups of
+ this kernel to enable tiled + non-tiled fusions.
+ """
+ groups = [rt.numel for rt in self.range_trees]
+ if not self.inside_reduction:
+ groups[-1] = sympy.Integer(1)
+
+ if len(lengths) == len(self.range_trees) and all(
+ V.graph.sizevars.simplify(sympy_product(x) - g) == 0
+ for x, g in zip(lengths, groups)
+ ):
+ return self.set_ranges(*lengths)
+
+ new_ranges, return_getters_groups = self._split_iteration_ranges(
+ groups, lengths
+ )
+ itervars = list(itertools.chain(*self.set_ranges(*new_ranges)))
+ return [[fn(itervars) for fn in fns] for fns in return_getters_groups]
+
+ def is_indirect_indexing(self, index: sympy.Expr):
+ # tmpX means indirect indexing
+ return free_symbol_startswith(index, "tmp")
+
+ def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot):
+ """
+ More aggressive simplification to merge contiguous dims
+ """
+ if isinstance(index, (sympy.Integer, sympy.Symbol)):
+ return index
+ index_vars, sizes = tree.vars_and_sizes(index)
+ if len(sizes) <= 1:
+ return index
+ new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
+ index_vars, sizes, index_prevent_reordering([index], index_vars, sizes)
+ )
+ if new_sizes == sizes:
+ return index
+ new_index_vars = tree.construct(new_sizes)
+ new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars))))
+ return new_index
+
+ def indexing(
+ self,
+ index: sympy.Expr,
+ copy_shape=None,
+ dense_indexing=False,
+ ):
+ """
+ Compute the index and mask to pass to tl.load() or tl.store()
+ """
+ index = self.simplify_indexing(index)
+ index_vars = index.free_symbols
+ index_str = texpr(self.rename_indexing(self.codegen_indexing(index)))
+ indirect_indexing = self.is_indirect_indexing(index)
+
+ need_dense = (
+ config.triton.dense_indexing
+ or dense_indexing
+ or indirect_indexing
+ or self._load_mask is not None
+ ) and index != 0
+
+ have_dense = True
+ have_loop_vars = False
+ mask = []
+ dense_mask = []
+
+ for tree in self.range_trees:
+ if tree.prefix == "r" and not self.inside_reduction:
+ continue
+ if index_vars.intersection(tree.var_list):
+ have_loop_vars = True
+ have_dense = False
+ mask.append(f"{tree.prefix}mask")
+ dense_mask.append(f"{tree.prefix}mask")
+
+ if (need_dense and not have_dense) or index == 0:
+ index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
+ if index == 0:
+ return index_str, "None"
+ else:
+ mask = dense_mask
+
+ elif not have_loop_vars and copy_shape:
+ mask = dense_mask
+ index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"
+ elif indirect_indexing:
+ mask = dense_mask
+
+ if self._load_mask:
+ mask.append(self._load_mask)
+ elif not mask:
+ mask = ["None"]
+
+ if mask == ["xmask"] and index == 0 and self.range_trees[0].numel == 1:
+ # This causes a triton error:
+ # https://github.com/openai/triton/issues/633
+ mask = ["None"]
+
+ return index_str, " & ".join(mask)
+
+ def var_ranges(self):
+ return dict(
+ itertools.chain.from_iterable(
+ tree.var_ranges.items() for tree in self.range_trees
+ )
+ )
+
+ def codegen_indexing(self, expr: sympy.Expr):
+ expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges())
+ for sym in sorted(expr.free_symbols, key=str):
+ if sym in self.range_tree_nodes:
+ self.range_tree_nodes[sym].codegen()
+ return expr
+
+ @contextlib.contextmanager
+ def mask_loads(self, mask):
+ """Context manager to add an additional mask to tl.load/store"""
+ prior = self._load_mask
+ if prior:
+ mask = self.cse.generate(self.compute, f"{mask} & {prior}")
+
+ self._load_mask = mask
+ with self.swap_buffers(self.compute, self.compute):
+ # TODO(jansel): do we need a reshape here?
+ yield mask
+ self._load_mask = prior
+
+ def load(self, name: str, index: sympy.Expr):
+ var = self.args.input(name)
+ indirect_indexing = self.is_indirect_indexing(index)
+ index, mask = self.indexing(index)
+
+ if "rmask" in mask:
+ # This eviction policy heuristic is untested.
+ # ptillet suggested we should try only doing this for
+ # the first N-1 loops and not for the final loop.
+ ep = ", eviction_policy='evict_last'"
+ else:
+ ep = ""
+ # "other" below is a workaround for https://github.com/openai/triton/issues/737
+ # for bool, even though it's likely subject to the same bug, setting `other` leads
+ # to LLVM errors so we are skipping it for now
+ if "tmp" in mask and V.graph.get_dtype(name) != torch.bool:
+ other = ", other=0"
+ else:
+ other = ""
+ line = f"tl.load({var} + ({index}), {mask}{ep}{other})"
+ if V.graph.get_dtype(name) in (torch.float16, torch.bfloat16):
+ line += ".to(tl.float32)"
+
+ if (
+ self.inside_reduction
+ and "rmask" not in mask
+ and "tmp" not in mask
+ and not indirect_indexing
+ ):
+ # can lift a common load outside of reduction loop
+ # One exception is when this is an indirect_load.
+ tmp = self.cse.generate(self.body, line)
+ else:
+ tmp = self.cse.generate(self.loads, line)
+
+ if not self.inside_reduction or "rmask" not in mask:
+ self.outside_loop_vars.add(tmp)
+ return tmp
+
+ def store(self, name, index, value, mode=None):
+ var = self.args.output(name)
+ index, mask = self.indexing(index, value, dense_indexing=True)
+ if mode is None:
+ line = f"tl.store({var} + ({index}), {value}, {mask})"
+ elif mode == "atomic_add":
+ line = f"tl.atomic_add({var} + ({index}), {value}, {mask})"
+ else:
+ raise NotImplementedError(f"store mode={mode}")
+ self.stores.writeline(name, line)
+ if not self.inside_reduction:
+ self.outside_loop_vars.add(value)
+
+ def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
+ assert self.inside_reduction
+ default = triton_constant(ir.Reduction.default_value(reduction_type, src_dtype))
+ masks = [f"{tree.prefix}mask" for tree in self.range_trees]
+ if self._load_mask:
+ masks.append(self._load_mask)
+ sizes = [f"{tree.prefix.upper()}BLOCK" for tree in self.range_trees]
+ sizes[-1] = "1"
+ reduction_range_prefix = self.range_trees[-1].prefix
+ reduction_sizes = ["1" for _ in self.range_trees]
+ reduction_sizes[-1] = f"{reduction_range_prefix.upper()}BLOCK"
+
+ if reduction_type == "any":
+ reduction_type = "max"
+
+ dim = len(self.range_trees) - 1
+ result_var = self.cse.newvar()
+ if (src_dtype, reduction_type, value) not in self.cse.reduction_cache:
+ self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var
+ accumulator = f"_{result_var}"
+ self.body.writeline(
+ f"{accumulator} = tl.zeros({self.dense_size_str()}, {triton_compute_type(src_dtype)}) + {default}"
+ )
+ accumulator_index = None
+ if reduction_type in {"argmax", "argmin"}:
+ accumulator_index = f"_{result_var}_index"
+ self.body.writeline(
+ f"{accumulator_index} = tl.zeros({self.dense_size_str()}, tl.int64)"
+ )
+
+ updated = value
+ if reduction_type in {"min", "argmin"}:
+ masks.append(f"({accumulator} > {value})")
+ elif reduction_type in {"max", "argmax"}:
+ masks.append(f"({accumulator} < {value})")
+ elif reduction_type == "sum":
+ updated = f"{accumulator} + {value}"
+ else:
+ raise NotImplementedError(f"reduction_type {reduction_type}")
+
+ cond = " & ".join(masks)
+
+ if accumulator_index:
+ # argmax or argmin
+ self.compute.writeline(
+ f"{accumulator_index} = tl.where({cond}, {reduction_range_prefix}index, {accumulator_index})",
+ )
+ self.compute.writeline(
+ f"{accumulator} = tl.where({cond}, {updated}, {accumulator})"
+ )
+
+ if accumulator_index:
+ # argmax, argmin
+ self.suffix.writelines(
+ [
+ f"{accumulator_index}_reduce = tl.reshape(",
+ f"\ttl.{reduction_type}({accumulator}, {dim}), [{', '.join(sizes)}]).to(tl.int32)",
+ f"{accumulator_index}_mask = (tl.reshape(tl.arange(0, {reduction_range_prefix.upper()}BLOCK),",
+ f"\t[{', '.join(reduction_sizes)}]) == {accumulator_index}_reduce)",
+ f"{result_var} = tl.reshape(tl.sum(",
+ f"\ttl.where({accumulator_index}_mask, {accumulator_index}, 0), {dim}), [{', '.join(sizes)}])",
+ ]
+ )
+ else:
+ self.suffix.writeline(
+ f"{result_var} = tl.reshape(tl.{reduction_type}({accumulator}, {dim}), [{', '.join(sizes)}])"
+ )
+ else:
+ var_name = self.cse.reduction_cache[(src_dtype, reduction_type, value)]
+ self.suffix.writeline(f"{result_var} = {var_name}")
+ self.inside_reduction = False
+ index, mask = self.indexing(index, result_var)
+ assert "rmask" not in index
+ self.inside_reduction = True
+ self.outside_loop_vars.add(result_var)
+ self.cse.store_cache[name] = result_var
+ if name not in V.graph.removed_buffers:
+ var = self.args.output(name)
+ self.suffix.writeline(
+ DeferredLine(name, f"tl.store({var} + {index}, {result_var}, {mask})")
+ )
+
+ def codegen_body(self):
+ """
+ Concat output code from index_code, loads, compute, stores,
+ suffix into self.body.
+
+ For pointwise kernels, this is called just once at the end.
+
+ For reduction kernels, this generates a loop over the reduction
+ axis.
+ """
+ if not (
+ self.indexing_code
+ or self.loads
+ or self.stores
+ or self.compute
+ or self.suffix
+ ):
+ return
+
+ if self.inside_reduction:
+ self.body.writeline("for roffset in range(0, rnumel, RBLOCK):")
+ with self.body.indent():
+ # last range tree is always reduction
+ self.range_trees[-1].codegen_header(self.body)
+ self.body.splice(self.indexing_code)
+ self.body.splice(self.loads)
+ self.body.splice(self.compute)
+ self.body.splice(self.stores)
+
+ # invalidate any caches that came from inside the reduction loop
+ self.cse.invalidate(self.outside_loop_vars)
+ self.range_trees[-1].cache_clear()
+ else:
+ self.body.splice(self.indexing_code)
+ self.body.splice(self.loads)
+ self.body.splice(self.compute)
+ self.body.splice(self.stores)
+ self.body.splice(self.suffix)
+ self.indexing_code.clear()
+ self.loads.clear()
+ self.compute.clear()
+ self.stores.clear()
+ self.suffix.clear()
+
+ def codegen_kernel(self, name=None):
+ from triton import next_power_of_2
+
+ code = IndentedBuffer()
+ size_hints = [
+ next_power_of_2(V.graph.sizevars.size_hint(numel)) for numel in self.numels
+ ]
+ if not self.inside_reduction:
+ size_hints.pop()
+ heuristics = "pointwise"
+ else:
+ heuristics = "reduction"
+
+ if name is None:
+ code.splice(
+ f"""
+ import triton
+ import triton.language as tl
+ from {config.inductor_import}.ir import ReductionHint
+ from {config.inductor_import}.triton_ops.autotune import {heuristics}
+ from {config.inductor_import}.utils import instance_descriptor
+ """
+ )
+
+ argdefs, _, signature = self.args.python_argdefs()
+ triton_meta = {
+ "signature": dict(enumerate(map(signature_of, signature))),
+ "device": V.graph.scheduler.current_device.index,
+ "configs": [config_of(signature)],
+ "constants": {},
+ }
+
+ for tree in self.range_trees:
+ if tree.prefix != "r" or self.inside_reduction:
+ triton_meta["signature"][len(argdefs)] = signature_of(
+ SizeArg(f"{tree.prefix}numel", tree.numel)
+ )
+ argdefs.append(f"{tree.prefix}numel")
+ # constexpr version causes issues, see
+ # https://github.com/pytorch/torchdynamo/pull/1362
+ # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
+ # tree.numel
+ # )
+ # argdefs.append(f"{tree.prefix}numel: tl.constexpr")
+
+ for tree in self.range_trees:
+ if tree.prefix != "r" or self.inside_reduction:
+ argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr")
+
+ if self.inside_reduction:
+ reduction_hint = self.reduction_hint
+ heuristics_line = f"""
+ @{heuristics}(size_hints={size_hints!r},
+ reduction_hint={reduction_hint},
+ filename=__file__,
+ meta={triton_meta!r})
+ @triton.jit
+ """
+ else:
+ heuristics_line = f"""
+ @{heuristics}(size_hints={size_hints!r}, filename=__file__, meta={triton_meta!r})
+ @triton.jit
+ """
+ code.splice(heuristics_line)
+ code.writeline(f"def {name or 'KERNEL_NAME'}({', '.join(argdefs)}):")
+ self.codegen_body()
+ with code.indent():
+ self.codegen_static_numels(code)
+ for old, new in self.args.aliases():
+ code.writeline(f"{old} = {new}")
+ code.splice(self.body)
+
+ if name is not None:
+ return code.getvalue()
+
+ wrapper = IndentedBuffer()
+ wrapper.writeline("async_compile.triton('''")
+ wrapper.splice(code.getvalue(), strip=True)
+ wrapper.writeline("''')")
+ return wrapper.getvalue()
+
+ def codegen_static_numels(self, code):
+ """
+ We get a small speedup from hard coding numels if they are static.
+ """
+ for tree in self.range_trees:
+ if tree.prefix != "r" or self.inside_reduction:
+ if isinstance(V.graph.sizevars.simplify(tree.numel), sympy.Integer):
+ code.writeline(
+ f"{tree.prefix}numel = {V.graph.sizevars.size_hint(tree.numel)}"
+ )
+ elif not config.dynamic_shapes:
+ code.writeline(
+ f"{tree.prefix}numel = {V.graph.sizevars.size_hint(tree.numel)} # dynamic_shapes=False"
+ )
+
+ def reshape_size_str(self, i=None, x=None):
+ sizes = ["1"] * (len(self.range_trees) - int(self.numels[-1] == 1))
+ if i is not None:
+ sizes[i] = f"{x.upper()}BLOCK"
+ return f"[{', '.join(sizes)}]"
+
+ def dense_size_str(self):
+ sizes = []
+ for tree in self.range_trees:
+ if tree.prefix != "r" or self.inside_reduction:
+ sizes.append(f"{tree.prefix.upper()}BLOCK")
+ elif tree.prefix == "r" and tree.numel != 1:
+ sizes.append("1")
+ return f"[{', '.join(sizes)}]"
+
+ def call_kernel(self, code, name: str):
+ _, call_args, _ = self.args.python_argdefs()
+ grid = []
+ # TODO(jansel): if there are constants, we shouldn't bother passing them as args
+ for tree in self.range_trees:
+ if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)):
+ expr = texpr(tree.numel)
+ else:
+ expr = f"{name}_{tree.prefix}numel"
+ code.writeline(f"{expr} = {texpr(tree.numel)}")
+ if tree.prefix != "r" or self.inside_reduction:
+ call_args.append(expr)
+ if tree.prefix != "r":
+ grid.append(expr)
+ call_args = ", ".join(call_args)
+ stream_name = code.write_get_cuda_stream(V.graph.scheduler.current_device.index)
+ code.writeline(
+ f"{name}.run({call_args}, grid=grid({', '.join(grid)}), stream={stream_name})"
+ )
+
+
+class TritonScheduling:
+ def __init__(self, scheduler):
+ self.scheduler = scheduler
+
+ def group_fn(self, sizes):
+ return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)
+
+ def can_fuse(self, node1, node2):
+ """
+ Hook called by Scheduler to determine if the Triton backend
+ can fuse node1 and node2. These nodes might already be
+ FusedSchedulerNodes.
+ """
+ _, (numel1, rnumel1) = node1.group
+ _, (numel2, rnumel2) = node2.group
+
+ if node1.is_reduction() and node2.is_reduction():
+ return numel1 == numel2 and rnumel1 == rnumel2
+
+ if not node1.is_reduction() and not node2.is_reduction():
+ if not (numel1 == numel2 and rnumel1 == rnumel2):
+ return False
+
+ # check for a bad combined tiling
+ tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1)
+ tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1)
+ tiling3 = self.select_tiling(
+ node1.get_nodes() + node2.get_nodes(), numel1, rnumel1
+ )
+ if config.triton.tiling_prevents_pointwise_fusion:
+ if len(tiling1) > 2:
+ if len(tiling2) > 2:
+ return tiling1 == tiling2 == tiling3
+ else:
+ return tiling1 == tiling3
+ elif len(tiling2) > 2:
+ return tiling2 == tiling3
+
+ return True
+
+ if not node1.is_reduction() and node2.is_reduction():
+ assert rnumel1 == 1 and rnumel2 != 1
+ if numel1 == numel2 * rnumel2:
+ if not all(
+ TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges())
+ for n in node1.get_nodes()
+ ):
+ return False
+ if config.triton.tiling_prevents_reduction_fusion:
+ return self.select_tiling(node1.get_nodes(), numel1) in (
+ (numel1, 1),
+ (numel2, rnumel2, 1),
+ )
+ return True
+
+ return numel1 == numel2
+
+ assert node1.is_reduction() and not node2.is_reduction()
+ # swap args to hit the case above
+ return self.can_fuse_horizontal(node2, node1)
+
+ can_fuse_vertical = can_fuse
+ can_fuse_horizontal = can_fuse
+
+ def codegen_nodes(self, nodes):
+ """
+ Given a set of pre-fused nodes, generate a Triton kernel.
+ """
+ _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
+ node_schedule = []
+ current_loop_writes = set()
+ done = set()
+
+ def fits_in_main_body(n):
+ _, (node_numel, node_rnumel) = n.group
+ return (node_numel == numel and node_rnumel == rnumel) or (
+ node_numel == numel * rnumel and node_rnumel == 1
+ )
+
+ def fits_outside_reduction(n):
+ _, (node_numel, node_rnumel) = n.group
+ return node_numel == numel and node_rnumel == 1 and rnumel != 1
+
+ @contextlib.contextmanager
+ def end_current_reduction_loop():
+ if current_loop_writes:
+ # flush out any other runnable nodes to reduce number of loops
+ for other_node in nodes[index + 1 :]:
+ if (
+ node not in done
+ and fits_in_main_body(other_node)
+ and not (
+ current_loop_writes & other_node.recursive_predecessors
+ )
+ ):
+ done.add(node)
+ current_loop_writes.add(node.get_name())
+ node_schedule.append(node)
+
+ if node_schedule and node_schedule[-1] is EnableReduction:
+ node_schedule.pop()
+ else:
+ node_schedule.append(DisableReduction)
+ yield
+ node_schedule.append(EnableReduction)
+ current_loop_writes.clear()
+
+ for index, node in enumerate(nodes):
+ if node in done:
+ continue
+ done.add(node)
+
+ if fits_in_main_body(node):
+ if current_loop_writes & node.recursive_predecessors and rnumel != 1:
+ with end_current_reduction_loop():
+ pass # need to start a new reduction loop
+ current_loop_writes.add(node.get_name())
+ node_schedule.append(node)
+ elif fits_outside_reduction(node):
+ with end_current_reduction_loop():
+ node_schedule.append(node)
+ else:
+ raise NotImplementedError(
+ f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}"
+ )
+
+ for node in node_schedule:
+ if node not in (EnableReduction, DisableReduction):
+ node.mark_run()
+
+ log.log(dynamo_logging.CODE, "schedule: %s", node_schedule)
+ return self.codegen_node_schedule(node_schedule, numel, rnumel)
+
+ @staticmethod
+ def reduction_hint(node):
+ assert node.is_reduction()
+ if all(
+ dep.is_contiguous()
+ for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes)
+ ):
+ return ReductionHint.INNER
+ else:
+ return node.node.data.reduction_hint
+
+ def codegen_node_schedule(self, node_schedule, numel, reduction_numel):
+ tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel)
+ reductions = list(
+ filter(
+ lambda n: n not in (EnableReduction, DisableReduction)
+ and n.is_reduction(),
+ node_schedule,
+ )
+ )
+ if len(reductions) > 0:
+ hints = [self.reduction_hint(n) for n in reductions]
+ if hints.count(hints[0]) == len(hints):
+ reduction_hint_val = hints[0]
+ else:
+ reduction_hint_val = ReductionHint.DEFAULT
+ else:
+ reduction_hint_val = ReductionHint.DEFAULT
+ with TritonKernel(*tiled_groups, reduction_hint=reduction_hint_val) as kernel:
+ stack = contextlib.ExitStack()
+ for node in node_schedule:
+ if node is DisableReduction:
+ stack.enter_context(kernel.disable_reduction())
+ elif node is EnableReduction:
+ stack.close()
+ else:
+ node.codegen(kernel.split_and_set_ranges(node.get_ranges()))
+
+ wrapper = V.graph.wrapper_code
+ src_code = kernel.codegen_kernel()
+ if src_code in wrapper.kernels:
+ kernel_name = wrapper.kernels[src_code]
+ else:
+ kernel_name = wrapper.next_kernel_name()
+ wrapper.kernels[src_code] = kernel_name
+ subs_name = kernel_name if config.triton.ordered_kernel_names else "kernel"
+ src_code = src_code.replace("KERNEL_NAME", subs_name)
+ # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
+ # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
+ src_code = src_code.replace("#pragma CMT", "#")
+ wrapper.define_kernel(kernel_name, src_code)
+ kernel.call_kernel(wrapper, kernel_name)
+ self.scheduler.free_buffers()
+
+ @staticmethod
+ @functools.lru_cache(32)
+ def candidate_tilings(node):
+ ranges, reduction_ranges = node.get_ranges()
+ if len(ranges) <= 1:
+ return ()
+
+ rw = node.pointwise_read_writes()
+ assert len(rw.range_vars) == len(ranges)
+
+ deps = [
+ dep
+ for dep in itertools.chain(rw.reads, rw.writes)
+ if dep.name not in V.graph.removed_buffers
+ ]
+ write_names = {dep.name for dep in rw.writes}
+
+ tilings = []
+
+ for dep in deps:
+ strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars)
+ assert len(strides) == len(ranges)
+ try:
+ split = strides.index(1) + 1
+ if split == len(ranges):
+ continue
+ if all(s == 0 for s in strides[split:]):
+ # if this is a broadcasted tensor and all dimensions after split are broadcast,
+ # this is not a real split
+ continue
+
+ except ValueError:
+ continue
+ tiled_groups = (
+ V.graph.sizevars.simplify(sympy_product(ranges[:split])),
+ V.graph.sizevars.simplify(sympy_product(ranges[split:])),
+ )
+ # score by number of elements
+ score = V.graph.sizevars.size_hint(
+ sympy_product(
+ size for size, stride in zip(ranges, strides) if stride != 0
+ )
+ )
+ if dep.name in write_names:
+ # ngimel said contiguous writes is more important than reads
+ score *= 2
+ if CandidateTiling.is_good_size(tiled_groups[0]):
+ score *= 2
+ if CandidateTiling.is_good_size(tiled_groups[1]):
+ score *= 2
+
+ if (
+ V.graph.sizevars.size_hint(
+ score - sympy_product(itertools.chain(ranges, reduction_ranges))
+ )
+ >= 0
+ ):
+ tilings.append(CandidateTiling(tiled_groups, score, dep.name))
+ return tilings
+
+ @classmethod
+ def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)):
+ """
+ Heuristics to decide how to tile kernels.
+ Currently, we tile based on stride-1 dimensions.
+
+ Returns:
+ `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel`
+
+ """
+ if reduction_numel != 1 or config.triton.max_tiles <= 1:
+ # TODO(jansel): should we tile reductions?
+ return (numel, reduction_numel)
+
+ seen_names = set()
+ candidate_tiles = collections.Counter()
+ for node in EnableReduction.filter(node_schedule):
+ for tiling in cls.candidate_tilings(node):
+ if tiling.name in seen_names:
+ continue
+ seen_names.add(tiling.name)
+ candidate_tiles[tiling.tiling] += tiling.score
+
+ ranked_tilings = [tiling for tiling, score in candidate_tiles.most_common()]
+
+ if config.triton.max_tiles >= 3:
+ # Add one 3D tiling choice
+ for i in range(1, len(ranked_tilings)):
+ a0, a1 = ranked_tilings[0]
+ b0, b1 = ranked_tilings[i]
+ if V.graph.sizevars.size_hint(a1 - b1) == 0:
+ continue
+ if V.graph.sizevars.size_hint(a1 - b1) < 0:
+ # swap so a0 is bigger
+ a0, a1 = ranked_tilings[i]
+ b0, b1 = ranked_tilings[0]
+ assert V.graph.sizevars.size_hint(a1 - b1) > 0
+ if V.graph.sizevars.maybe_guard_multiple_of(a1, b1):
+ tiling = (a0, ir.IndexingDiv(a1, b1), b1)
+ ranked_tilings = [tiling] + ranked_tilings
+ break # only 1 choice for now
+
+ for tiled_groups in ranked_tilings:
+ new_groups = (*tiled_groups, reduction_numel)
+ if all(
+ TritonKernel.is_compatible(new_groups, node.get_ranges())
+ for node in node_schedule
+ if isinstance(node, scheduler.SchedulerNode)
+ ):
+ return new_groups
+
+ return (numel, reduction_numel)
+
+ def flush(self):
+ pass
+
+
+@dataclasses.dataclass
+class CandidateTiling:
+ tiling: List[sympy.Expr]
+ score: int # higher is better
+ name: str = None
+
+ @staticmethod
+ def is_good_size(s):
+ """Somewhat arbitrary heuristic used to boost scores for some sizes"""
+ s = V.graph.sizevars.size_hint(s)
+ return s >= 32 and (s % 32 == 0)
+
+
+class DisableReduction:
+ """
+ Marker to invoke `kernel.disable_reduction()`. This closes a
+ reduction loop and allows for pointwise ops to occur on the output
+ of a reduction.
+ """
+
+
+class EnableReduction:
+ """
+ Marker to end a DisableReduction block.
+ """
+
+ @staticmethod
+ def filter(node_schedule):
+ """
+ Get the nodes from node_schedule skipping those in a
+ DisableReduction block.
+ """
+ disabled = False
+ for node in node_schedule:
+ if node in (EnableReduction, DisableReduction):
+ # Don't tile stuff outside the main reduction loop
+ disabled = node is DisableReduction
+ elif disabled:
+ pass
+ else:
+ yield node
+
+
+class CantSplit(Exception):
+ pass
diff --git a/torch/_inductor/codegen/triton_conv_delta_x.j2 b/torch/_inductor/codegen/triton_conv_delta_x.j2
new file mode 100644
index 0000000000000..a7bf8ac433eac
--- /dev/null
+++ b/torch/_inductor/codegen/triton_conv_delta_x.j2
@@ -0,0 +1,181 @@
+
+@conv_heuristics()
+@triton.jit
+def {{kernel_name}}(
+ {% for i in template_inout_argdefs %}
+ {{i}},
+ {% endfor %}
+ # stride of tensor
+ stride_xn,
+ stride_xc,
+ stride_xh,
+ stride_xw,
+ stride_wn,
+ stride_wc,
+ stride_wh,
+ stride_ww,
+ stride_yn,
+ stride_yc,
+ stride_yh,
+ stride_yw,
+ stride_biasn,
+ # Tensor dimensions
+ BATCH,
+ IN_C,
+ IN_H,
+ IN_W,
+ KERNEL_N,
+ KERNEL_H,
+ KERNEL_W,
+ OUT_H,
+ OUT_W,
+ # parameters of conv
+ stride_h,
+ stride_w,
+ padding_h,
+ padding_w,
+ dilation_h,
+ dilation_w,
+ output_padding_h,
+ output_padding_w,
+ groups: tl.constexpr,
+ # pointer inc for x
+ delta_x_ptr,
+ # fusable kernels args
+ {% for i in extra_argdefs %}
+ {{i}},
+ {% endfor %}
+ # Metaparameters
+ ACC_TYPE: tl.constexpr,
+ CONV1X1_NHWC: tl.constexpr,
+ # blocks in different dimension
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ # reduction tiling parameter for matmul
+ BLOCK_K: tl.constexpr,
+):
+ """
+ each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y
+ """
+ # -----------------------------------------------------------
+ # Map program ids `pid` to the block of y it should compute.
+ pid_nhw = tl.program_id(0)
+ pid_k = tl.program_id(1)
+
+ # offset for output y
+ off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
+ off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_y_n = off_y_nhw // (OUT_H * OUT_W)
+ off_y_hw = off_y_nhw % (OUT_H * OUT_W)
+ off_y_h = off_y_hw // OUT_W
+ off_y_w = off_y_hw % OUT_W
+
+ # offset for the initial ptr for x
+ off_x_n = off_y_n
+ off_x_h = off_y_h * stride_h - padding_h
+ off_x_w = off_y_w * stride_w - padding_w
+ off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
+ off_x_crs = tl.arange(0, BLOCK_K)
+
+ CRS = IN_C * KERNEL_H * KERNEL_W
+ # load inc ptr of x, upade x_ptrs
+ if not CONV1X1_NHWC:
+ delta_x_ptrs = delta_x_ptr + off_x_crs
+ off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS, other=0)
+ x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
+ else:
+ x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]
+
+ mask_x = (
+ (off_x_n < BATCH)
+ & (off_x_h >= 0)
+ & (off_x_h < IN_H)
+ & (off_x_w >= 0)
+ & (off_x_w < IN_W)
+ )[:, None] & (off_x_crs < CRS)[None, :]
+
+ # offset for the inital ptr for w
+ off_w_crs = tl.arange(0, BLOCK_K)
+ off_w_k = off_y_k
+ w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
+ mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
+
+ # ------ load x ------
+ matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
+ # ------ load w ------
+ matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
+
+ # -----------------------------------------------------------
+ # allocate accumulator
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for crs in range(0, CRS, BLOCK_K):
+
+ # ------ matrix multiplication ------
+ acc += tl.dot(matrix_x, matrix_w)
+ # ------ update ptrs ------
+ w_ptrs += BLOCK_K
+ # load inc ptr of x, upade x_ptrs
+ if not CONV1X1_NHWC:
+ delta_x_ptrs += BLOCK_K
+ off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
+ off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS, other=0)
+ x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
+ else:
+ off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
+ x_ptrs += BLOCK_K
+
+ mask_x = (
+ (off_x_n < BATCH)
+ & (off_x_h >= 0)
+ & (off_x_h < IN_H)
+ & (off_x_w >= 0)
+ & (off_x_w < IN_W)
+ )[:, None] & (off_x_crs < CRS)[None, :]
+ mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
+ # ------ prefetch ------
+ # ------ load x ------
+ matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
+ # ------ load w ------
+ matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
+
+ acc = acc.to({{out_def}}.dtype.element_ty)
+
+{% if keep_store %}
+ # rematerialize -- this saves some registers
+ # offset for output y
+ off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
+ off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_y_n = off_y_nhw // (OUT_H * OUT_W)
+ off_y_hw = off_y_nhw % (OUT_H * OUT_W)
+ # consider output padding
+ off_y_h = off_y_hw // OUT_W + output_padding_h
+ off_y_w = off_y_hw % OUT_W + output_padding_w
+
+ # y ptrs in the block of [BLOCK_M, BLOCK_N]
+ y_ptrs = (
+ {{out_def}}
+ + off_y_n[:, None] * stride_yn
+ + off_y_h[:, None] * stride_yh
+ + off_y_w[:, None] * stride_yw
+ + off_y_k[None, :] * stride_yc
+ )
+
+ # out-of-bounds check
+ mask_y = (
+ (off_y_n < BATCH)[:, None]
+ & (off_y_h < OUT_H + output_padding_h)[:, None]
+ & (off_y_w < OUT_W + output_padding_w)[:, None]
+ & (off_y_k < KERNEL_N)[None, :]
+ )
+ tl.store(y_ptrs, acc, mask=mask_y)
+{% endif %}
+
+{% if pointwise_code %}
+{{ pointwise_code | indent(4, true) }}
+ {#
+ z = tl.load(z_ptrs, mask=mask_z)
+ acc += z
+ #}
+{% endif %}
+
+ return
diff --git a/torch/_inductor/codegen/triton_conv_delta_x_hwc.j2 b/torch/_inductor/codegen/triton_conv_delta_x_hwc.j2
new file mode 100644
index 0000000000000..34f2c3881272a
--- /dev/null
+++ b/torch/_inductor/codegen/triton_conv_delta_x_hwc.j2
@@ -0,0 +1,200 @@
+
+@conv_heuristics()
+@triton.jit
+def {{kernel_name}}(
+ {% for i in template_inout_argdefs %}
+ {{i}},
+ {% endfor %}
+ # stride of tensor
+ stride_xn,
+ stride_xc,
+ stride_xh,
+ stride_xw,
+ stride_wn,
+ stride_wc,
+ stride_wh,
+ stride_ww,
+ stride_yn,
+ stride_yc,
+ stride_yh,
+ stride_yw,
+ stride_biasn,
+ # Tensor dimensions
+ BATCH,
+ IN_C,
+ IN_H,
+ IN_W,
+ KERNEL_N,
+ KERNEL_H,
+ KERNEL_W,
+ OUT_H,
+ OUT_W,
+ # parameters of conv
+ stride_h,
+ stride_w,
+ padding_h,
+ padding_w,
+ dilation_h,
+ dilation_w,
+ output_padding_h,
+ output_padding_w,
+ groups,
+ # pointer inc for x
+ delta_xh_ptr,
+ delta_xw_ptr,
+ delta_xc_ptr,
+ # fusable kernels args
+ {% for i in extra_argdefs %}
+ {{i}},
+ {% endfor %}
+ # Metaparameters
+ ACC_TYPE: tl.constexpr,
+ CONV1X1_NHWC: tl.constexpr,
+ # blocks in different dimension
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ # reduction tiling parameter for matmul
+ BLOCK_K: tl.constexpr,
+):
+ """
+ each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y
+ """
+ # -----------------------------------------------------------
+ # Map program ids `pid` to the block of y it should compute.
+ pid_nhw = tl.program_id(0)
+ pid_k = tl.program_id(1)
+
+ # offset for output y
+ off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
+ off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_y_n = off_y_nhw // (OUT_H * OUT_W)
+ off_y_hw = off_y_nhw % (OUT_H * OUT_W)
+ off_y_h = off_y_hw // OUT_W + output_padding_h
+ off_y_w = off_y_hw % OUT_W + output_padding_w
+
+ # offset for the initial ptr for x
+ off_x_n = off_y_n
+ off_x_h = off_y_h * stride_h - padding_h
+ off_x_w = off_y_w * stride_w - padding_w
+ off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
+ off_x_crs = tl.arange(0, BLOCK_K)
+
+ CRS = IN_C * KERNEL_H * KERNEL_W
+ # load inc ptr of x, upade x_ptrs
+ if not CONV1X1_NHWC:
+ delta_xh_ptrs = delta_xh_ptr + off_x_crs
+ delta_xw_ptrs = delta_xw_ptr + off_x_crs
+ delta_xc_ptrs = delta_xc_ptr + off_x_crs
+ delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0)
+ delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0)
+ delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0)
+ off_x_crs_unpacked = (
+ delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc
+ )
+ x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
+ else:
+ x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]
+ delta_xh = 0
+ delta_xw = 0
+
+ mask_x = (
+ (off_x_n < BATCH)[:, None]
+ & (off_x_crs < CRS)[None, :]
+ & (off_x_h[:, None] + delta_xh[None, :] >= 0)
+ & (off_x_h[:, None] + delta_xh[None, :] < IN_H)
+ & (off_x_w[:, None] + delta_xw[None, :] >= 0)
+ & (off_x_w[:, None] + delta_xw[None, :] < IN_W)
+ )
+
+ # offset for the inital ptr for w
+ off_w_crs = tl.arange(0, BLOCK_K)
+ off_w_k = off_y_k
+ w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
+ mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
+
+ # ------ load x ------
+ matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
+ # ------ load w ------
+ matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
+
+ # -----------------------------------------------------------
+ # allocate accumulator
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for crs in range(0, CRS, BLOCK_K):
+
+ # ------ matrix multiplication ------
+ acc += tl.dot(matrix_x, matrix_w)
+ # ------ update ptrs ------
+ w_ptrs += BLOCK_K
+ # load inc ptr of x, upade x_ptrs
+ off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
+ if not CONV1X1_NHWC:
+ delta_xh_ptrs += BLOCK_K
+ delta_xw_ptrs += BLOCK_K
+ delta_xc_ptrs += BLOCK_K
+ delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0)
+ delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0)
+ delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0)
+ off_x_crs_unpacked = (
+ delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc
+ )
+ x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
+ else:
+ x_ptrs += BLOCK_K
+
+ mask_x = (
+ (off_x_n < BATCH)[:, None]
+ & (off_x_crs < CRS)[None, :]
+ & (off_x_h[:, None] + delta_xh[None, :] >= 0)
+ & (off_x_h[:, None] + delta_xh[None, :] < IN_H)
+ & (off_x_w[:, None] + delta_xw[None, :] >= 0)
+ & (off_x_w[:, None] + delta_xw[None, :] < IN_W)
+ )
+ mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
+ # ------ prefetch ------
+ # ------ load x ------
+ matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
+ # ------ load w ------
+ matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
+
+ acc = acc.to({{out_def}}.dtype.element_ty)
+
+{% if keep_store %}
+ # rematerialize -- this saves some registers
+ # offset for output y
+ off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
+ off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_y_n = off_y_nhw // (OUT_H * OUT_W)
+ off_y_hw = off_y_nhw % (OUT_H * OUT_W)
+ # consider output padding
+ off_y_h = off_y_hw // OUT_W + output_padding_h
+ off_y_w = off_y_hw % OUT_W + output_padding_w
+
+ # y ptrs in the block of [BLOCK_M, BLOCK_N]
+ y_ptrs = (
+ {{out_def}}
+ + off_y_n[:, None] * stride_yn
+ + off_y_h[:, None] * stride_yh
+ + off_y_w[:, None] * stride_yw
+ + off_y_k[None, :] * stride_yc
+ )
+
+ # out-of-bounds check
+ mask_y = (
+ (off_y_n < BATCH)[:, None]
+ & (off_y_h < OUT_H + output_padding_h)[:, None]
+ & (off_y_w < OUT_W + output_padding_w)[:, None]
+ & (off_y_k < KERNEL_N)[None, :]
+ )
+ tl.store(y_ptrs, acc, mask=mask_y)
+{% endif %}
+
+{% if pointwise_code %}
+{{ pointwise_code | indent(4, true) }}
+ {#
+ z = tl.load(z_ptrs, mask=mask_z)
+ acc += z
+ #}
+{% endif %}
+
+ return
diff --git a/torch/_inductor/codegen/triton_mm.j2 b/torch/_inductor/codegen/triton_mm.j2
new file mode 100644
index 0000000000000..3073b3f490714
--- /dev/null
+++ b/torch/_inductor/codegen/triton_mm.j2
@@ -0,0 +1,80 @@
+import torch
+import triton
+import triton.language as tl
+{# from triton.ops.matmul import get_configs_io_bound #}
+
+@mm_autotune()
+@mm_heuristics()
+@triton.jit
+def {{kernel_name}}(
+ {% for i in template_inout_argdefs %}
+ {{i}},
+ {% endfor %}
+ M,
+ N,
+ K,
+ stride_am,
+ stride_ak,
+ stride_bk,
+ stride_bn,
+ stride_cm,
+ stride_cn,
+ # fusable kernels args
+ {% for i in extra_argdefs %}
+ {{i}},
+ {% endfor %}
+ allow_tf32: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+ SPLIT_K: tl.constexpr,
+ EVEN_K: tl.constexpr,
+ ACC_TYPE: tl.constexpr,
+):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_z = tl.program_id(1)
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
+ # re-order program ID for better L2 performance
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // (group_size)
+ # do matrix multiplication
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
+ # pointers
+ A_ptrs = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B_ptrs = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(K, 0, -BLOCK_K * SPLIT_K):
+ if EVEN_K:
+ a = tl.load(A_ptrs)
+ b = tl.load(B_ptrs)
+ else:
+ a = tl.load(A_ptrs, mask=rk[None, :] < k, other=0.0)
+ b = tl.load(B_ptrs, mask=rk[:, None] < k, other=0.0)
+ acc += tl.dot(a, b, allow_tf32=allow_tf32)
+ A_ptrs += BLOCK_K * SPLIT_K * stride_ak
+ B_ptrs += BLOCK_K * SPLIT_K * stride_bk
+ acc = acc.to({{out_def}}.dtype.element_ty)
+
+{% if keep_store %}
+ # rematerialize rm and rn to save registers
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C_ptrs = {{out_def}} + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
+ # handles write-back with reduction-splitting
+ tl.store(C_ptrs, acc, mask=mask)
+{% endif %}
+
+{% if pointwise_code %}
+{{ pointwise_code | indent(4, true) }}
+{% endif %}
diff --git a/torch/_inductor/codegen/triton_template.py b/torch/_inductor/codegen/triton_template.py
new file mode 100644
index 0000000000000..308b1c1f45d91
--- /dev/null
+++ b/torch/_inductor/codegen/triton_template.py
@@ -0,0 +1,349 @@
+import logging
+import os
+
+import sympy
+
+from .. import config, ir
+from ..virtualized import V
+from .common import IndentedBuffer
+from .triton import TritonKernel
+
+log = logging.getLogger((__name__))
+template_dict = {ir.Convolution: "triton_conv", ir.MatrixMultiply: "triton_mm"}
+
+
+class TritonTemplateKernel(TritonKernel):
+ def __init__(self, node: ir.ExternKernel, *groups):
+ from jinja2 import Environment, FileSystemLoader, StrictUndefined
+
+ self.node = node
+ self.template_name = template_dict[type(node)]
+ env = Environment(
+ loader=FileSystemLoader(os.path.dirname(__file__)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ undefined=StrictUndefined,
+ )
+ pid_cache = {}
+ if isinstance(node, ir.Convolution):
+ pid_cache = {
+ "tl.program_id(0)": "pid_nhw",
+ "tl.program_id(1)": "pid_k",
+ }
+ self.map_args()
+ KERNEL_H = self.args_dict["KERNEL_H"]
+ KERNEL_W = self.args_dict["KERNEL_W"]
+ padding_h = self.args_dict["padding_h"]
+ padding_w = self.args_dict["padding_w"]
+ if ((KERNEL_H == "1" and KERNEL_W == "1")) or (
+ (padding_h == "0") and (padding_w == "0")
+ ):
+ self.template_name += "_delta_x"
+ else:
+ self.template_name += "_delta_x_hwc"
+ elif isinstance(node, ir.MatrixMultiply):
+ pid_cache = {
+ "tl.program_id(0)": "pid_m",
+ "tl.program_id(1)": "pid_n",
+ }
+
+ self.template = env.get_template(self.template_name + ".j2")
+ super(TritonTemplateKernel, self).__init__(*groups, pid_cache=pid_cache)
+
+ def rename_vars(self):
+ for k, v in self.inout_dict.items():
+ self.args.output_buffers[v] = k
+ if isinstance(self.node, ir.Convolution):
+ self.cse.store_cache[self.inout_dict["y"]] = "acc"
+ elif isinstance(self.node, ir.MatrixMultiply):
+ self.cse.store_cache[self.inout_dict["C"]] = "acc"
+
+ def assign_block_numel(self):
+ code = IndentedBuffer()
+ if isinstance(self.node, ir.Convolution):
+ code.writeline("XBLOCK: tl.constexpr = BLOCK_M")
+ code.writeline("YBLOCK: tl.constexpr = BLOCK_N")
+ code.writeline(
+ "xnumel = BATCH * (OUT_H + 2 * output_padding_h) * (OUT_W + 2 * output_padding_w)"
+ )
+ code.writeline("ynumel = KERNEL_N")
+ elif isinstance(self.node, ir.MatrixMultiply):
+ code.writeline("XBLOCK: tl.constexpr = BLOCK_M")
+ code.writeline("YBLOCK: tl.constexpr = BLOCK_N")
+ code.writeline("xnumel = M")
+ code.writeline("ynumel = N")
+
+ return code
+
+ def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=True):
+ # use dense_indexing for TritonTemplateKernel to avoid map::at error
+ return super().indexing(index, copy_shape, dense_indexing)
+
+ def codegen_body(
+ self, name, fuse, could_remove_kernel_buf, kernel_buf_replace_name
+ ):
+ """
+ put render_variables into the template
+ to generate the final code
+ """
+ # get extra_argdefs from fusable triton kernels
+ self.extra_argdefs = []
+ self.extra_call_args = []
+ argdefs, call_args, _ = self.args.python_argdefs()
+ # add extra args if it is different from
+ # current TritonTemplateKernel args
+ for (argdef, call_arg) in zip(argdefs, call_args):
+ if (
+ argdef not in self.inout_dict.keys()
+ and argdef not in self.args_dict.keys()
+ ):
+ self.extra_argdefs.append(argdef)
+ self.extra_call_args.append(call_arg)
+
+ if could_remove_kernel_buf:
+ if isinstance(self.node, ir.Convolution):
+ self.inout_dict.pop("y")
+ elif isinstance(self.node, ir.MatrixMultiply):
+ self.inout_dict.pop("C")
+ self.template_inout_argdefs = list(self.inout_dict.keys())
+
+ if kernel_buf_replace_name is not None:
+ idx = self.extra_call_args.index(kernel_buf_replace_name)
+ kernel_buf_replace_def = self.extra_argdefs[idx]
+
+ super().codegen_body()
+ self.pointwise_code = IndentedBuffer()
+ self.pointwise_code.splice(self.assign_block_numel())
+ self.pointwise_code.splice(self.body)
+ render_dict = {}
+ render_dict["kernel_name"] = name
+ render_dict["template_inout_argdefs"] = self.template_inout_argdefs
+ render_dict["extra_argdefs"] = self.extra_argdefs
+ render_dict["pointwise_code"] = self.pointwise_code.getvalue() if fuse else None
+ render_dict["keep_store"] = not could_remove_kernel_buf
+ render_dict["out_def"] = (
+ self.out_def() if not could_remove_kernel_buf else kernel_buf_replace_def
+ )
+ self.body = self.template.render(render_dict) + "\n"
+
+ def out_def(self):
+ if isinstance(self.node, ir.Convolution):
+ return "y"
+ elif isinstance(self.node, ir.MatrixMultiply):
+ return "C"
+
+ def codegen_kernel(
+ self,
+ name=None,
+ fuse=False,
+ could_remove_kernel_buf=False,
+ kernel_buf_replace_name=None,
+ ):
+
+ code = IndentedBuffer()
+
+ self.codegen_body(name, fuse, could_remove_kernel_buf, kernel_buf_replace_name)
+ code.splice(self.body)
+
+ if name is not None:
+ return code.getvalue()
+
+ wrapper = IndentedBuffer()
+ wrapper.writeline("TritonCodeCache.load('''")
+ wrapper.splice(code.getvalue(), strip=True)
+ wrapper.writeline("''').kernel")
+
+ return wrapper.getvalue()
+
+ def map_args(self):
+ """
+ map the constant args or
+ kernel[grid](..., IN_C, IN_H, IN_W, strides,...)
+ """
+ (
+ self.inout_dict,
+ self.args_dict,
+ self.const_dict,
+ self.other_dict,
+ ) = self.node.map_args()
+
+ def precompute(self, wrapper, kernel_name):
+ """
+ some triton kernels needs host precompute tensor
+ for example, triton_conv needs precompte delta_x_ptr
+ """
+ if isinstance(self.node, ir.Convolution):
+ if self.const_dict["CONV1X1_NHWC"] == "False":
+ IN_C = self.args_dict["IN_C"]
+ KERNEL_H = self.args_dict["KERNEL_H"]
+ KERNEL_W = self.args_dict["KERNEL_W"]
+ dilation_h = self.args_dict["dilation_h"]
+ dilation_w = self.args_dict["dilation_w"]
+ stride_wc = self.args_dict["stride_wc"]
+ stride_wh = self.args_dict["stride_wh"]
+ stride_ww = self.args_dict["stride_ww"]
+ stride_xc = self.args_dict["stride_xc"]
+ stride_xh = self.args_dict["stride_xh"]
+ stride_xw = self.args_dict["stride_xw"]
+ device = self.other_dict["device"]
+ if self.template_name == "triton_conv_delta_x":
+ assert "delta_x_ptr" not in self.args_dict.keys()
+ self.args_dict["delta_x_ptr"] = f"delta_x_{kernel_name}"
+ wrapper.writeline(
+ f"from {config.inductor_import}.triton_ops import _conv"
+ )
+ wrapper.writeline(
+ f"delta_x_{kernel_name} = _conv._delta_x_ptr("
+ f"{IN_C}, {KERNEL_H}, {KERNEL_W}, "
+ f"{dilation_h}, {dilation_w}, "
+ f"{stride_wc}, {stride_wh}, {stride_ww}, "
+ f"{stride_xc}, {stride_xh}, {stride_xw}, {device})"
+ )
+ # triton_conv_delta_x_hwc
+ else:
+ assert "delta_xh_ptr" not in self.args_dict.keys()
+ assert "delta_xw_ptr" not in self.args_dict.keys()
+ assert "delta_xc_ptr" not in self.args_dict.keys()
+ self.args_dict["delta_xh_ptr"] = f"delta_xh_{kernel_name}"
+ self.args_dict["delta_xw_ptr"] = f"delta_xw_{kernel_name}"
+ self.args_dict["delta_xc_ptr"] = f"delta_xc_{kernel_name}"
+ wrapper.writeline(
+ f"from {config.inductor_import}.triton_ops import _conv"
+ )
+ wrapper.writeline(
+ f"delta_xh_{kernel_name}, delta_xw_{kernel_name}, delta_xc_{kernel_name}"
+ f" = _conv._delta_x_ptr_hwc("
+ f"{IN_C}, {KERNEL_H}, {KERNEL_W}, "
+ f"{dilation_h}, {dilation_w}, "
+ f"{stride_wc}, {stride_wh}, {stride_ww}, "
+ f"{stride_xc}, {stride_xh}, {stride_xw}, {device})"
+ )
+
+ # else, delta_x_ptr is None
+ else:
+ assert "delta_x_ptr" not in self.args_dict.keys()
+ self.args_dict["delta_x_ptr"] = "None"
+ return
+
+ def gen_grid(self, name):
+ code = IndentedBuffer()
+ if isinstance(self.node, ir.Convolution):
+ BATCH = self.args_dict["BATCH"]
+ OUT_H = self.args_dict["OUT_H"]
+ OUT_W = self.args_dict["OUT_W"]
+ KERNEL_N = self.args_dict["KERNEL_N"]
+ with code.indent():
+ code.splice(
+ f"""
+ def grid_{name}(META):
+ return (
+ triton.cdiv({BATCH} * {OUT_H} * {OUT_W}, META["BLOCK_M"]),
+ triton.cdiv({KERNEL_N}, META["BLOCK_N"]),
+ )
+ """
+ )
+ if isinstance(self.node, ir.MatrixMultiply):
+ M = self.args_dict["M"]
+ N = self.args_dict["N"]
+ with code.indent():
+ code.splice(
+ f"""
+ def grid_{name}(META):
+ return (
+ triton.cdiv({M}, META["BLOCK_M"]) * triton.cdiv({N}, META["BLOCK_N"]),
+ META["SPLIT_K"],
+ )
+ """
+ )
+ return code.getvalue()
+
+ def call_kernel(self, wrapper, name: str):
+ # gen code to call kernel
+ # e.g.
+ # def grid(META):
+ # return (...)
+ # kernel1[grid](arg0, arg1, ...)
+ extra_args = ", ".join(self.extra_call_args)
+ self_args = ", ".join({**self.inout_dict, **self.args_dict}.values())
+ self_const_kwargs = ", ".join(f"{k}={v}" for k, v in self.const_dict.items())
+ args = self_args + (
+ ", " + extra_args if extra_args and len(extra_args) > 0 else ""
+ )
+ args_kwargs = args + ", " + self_const_kwargs
+ wrapper.writeline(self.gen_grid(name))
+ wrapper.writeline(f"{name}[grid_{name}]({args_kwargs})")
+
+
+def should_use_template(node: ir.ExternKernel):
+ template_kernels = [ir.Convolution, ir.MatrixMultiply]
+ if type(node) in template_kernels and ir.is_triton(node.get_device()):
+ if isinstance(node, ir.Convolution):
+ return node.kernel != "aten.convolution"
+ elif isinstance(node, ir.MatrixMultiply):
+ return node.kernel != "aten.mm.out"
+ return False
+
+
+def template_can_fuse(snode1, snode2):
+ assert snode1.is_template()
+ if snode1.group != snode2.group:
+ return False
+ tiling = snode1.get_nodes()[0].node.get_template_tiling()
+ for node in snode2.get_nodes():
+ if not TritonKernel.is_compatible(tiling, node.get_ranges()):
+ return False
+ return True
+
+
+def template_codegen(scheduler, scheduler_node, epilogue):
+ """
+ codegen function for triton templates
+ scheduler: Scheduler
+ scheduler_node: ExternKernelSchedulerNode
+ """
+ log.debug("template_codegen: %s -- %s", scheduler_node, epilogue)
+
+ wrapper = V.graph.wrapper_code
+ _, groups = scheduler_node.group
+
+ with TritonTemplateKernel(
+ scheduler_node.node, *scheduler_node.node.get_template_tiling()
+ ) as kernel:
+ # map const args/ shape/ strides to kernel args
+ kernel.map_args()
+ # set self.args name to match the TritonTemplateKernel's args names
+ kernel.rename_vars()
+ # scheduler.pop_group will keep iterating all reachable fusable SchedulerNodes
+ assert type(kernel.node) in template_dict.keys()
+
+ kernel.store_buffer_names.add(scheduler_node.get_name())
+
+ for node in epilogue:
+ node.mark_run()
+ node.codegen(kernel.split_and_set_ranges(node.get_ranges()))
+
+ could_remove_kernel_buf = (
+ kernel.args.output_buffers[scheduler_node.get_name()] == "REMOVED"
+ )
+ kernel_buf_replace_name = None
+ if could_remove_kernel_buf:
+ for node in epilogue:
+ if kernel.args.output_buffers[node.get_name()] != "REMOVED":
+ kernel_buf_replace_name = node.get_name()
+ break
+ assert kernel_buf_replace_name is not None
+
+ kernel_name = wrapper.next_kernel_name()
+ # code gen kernel
+ wrapper.header.splice(
+ kernel.codegen_kernel(
+ kernel_name,
+ bool(epilogue),
+ could_remove_kernel_buf,
+ kernel_buf_replace_name,
+ )
+ )
+ # gen precompute tensor (like delta_x_ptr) if needed
+ kernel.precompute(wrapper, kernel_name)
+ # code gen call to kernel
+ kernel.call_kernel(wrapper, kernel_name)
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
new file mode 100644
index 0000000000000..996ed9c64bb10
--- /dev/null
+++ b/torch/_inductor/codegen/wrapper.py
@@ -0,0 +1,398 @@
+import collections
+import dataclasses
+import functools
+import hashlib
+from itertools import count
+from typing import Any, Dict, List
+
+from .. import codecache, config, ir
+from ..utils import dynamo_utils, has_triton, sympy_dot, sympy_product
+from ..virtualized import V
+from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel
+from .triton import texpr
+
+pexpr = texpr
+
+
+def buffer_reuse_key(node: ir.Buffer):
+ size = node.get_size()
+ stride = node.get_stride()
+ last_element = sympy_dot([s - 1 for s in size], stride)
+ return (
+ node.get_device(),
+ node.get_dtype(),
+ V.graph.sizevars.simplify(sympy_product(size)),
+ # Detect gaps in tensor storage caused by strides
+ V.graph.sizevars.size_hint(last_element),
+ )
+
+
+def make_buffer_reuse(old, new):
+ assert old.get_dtype() == new.get_dtype()
+ if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
+ return f"{new.get_name()} = {old.get_name()}; del {old.get_name()}"
+
+ return (
+ f"{new.get_name()} = as_strided({old.get_name()}, "
+ f"{V.graph.sizevars.codegen_shape_tuple(new.get_size())}, "
+ f"{V.graph.sizevars.codegen_shape_tuple(new.get_stride())}); del {old.get_name()}"
+ )
+
+
+def make_buffer_allocation(buffer):
+ device = buffer.get_device()
+ dtype = buffer.get_dtype()
+ shape = tuple(buffer.get_size())
+ stride = tuple(buffer.get_stride())
+ return (
+ f"{buffer.get_name()} = empty_strided("
+ f"{V.graph.sizevars.codegen_shape_tuple(shape)}, "
+ f"{V.graph.sizevars.codegen_shape_tuple(stride)}, "
+ f"device='{device.type}', dtype={dtype})"
+ )
+
+
+class MemoryPlanningState:
+ def __init__(self):
+ super().__init__()
+ self.reuse_pool: Dict[
+ Any, List["FreeIfNotReusedLine"]
+ ] = collections.defaultdict(list)
+
+ def __contains__(self, key):
+ return bool(self.reuse_pool.get(key, None))
+
+ def pop(self, key) -> "FreeIfNotReusedLine":
+ item = self.reuse_pool[key].pop()
+ assert not item.is_reused
+ return item
+
+ def push(self, key, item: "FreeIfNotReusedLine"):
+ assert not item.is_reused
+ self.reuse_pool[key].append(item)
+
+
+class MemoryPlanningLine:
+ def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine":
+ """First pass to find reuse"""
+ return self
+
+ def codegen(self, code: IndentedBuffer):
+ """Second pass to output code"""
+ pass
+
+
+@dataclasses.dataclass
+class AllocateLine(MemoryPlanningLine):
+ node: ir.Buffer
+
+ def plan(self, state: MemoryPlanningState):
+ if self.node.get_name() in V.graph.removed_buffers:
+ return NullLine()
+
+ # try to reuse a recently freed buffer
+ key = buffer_reuse_key(self.node)
+ if key in state:
+ free_line = state.pop(key)
+ free_line.is_reused = True
+ return ReuseLine(free_line.node, self.node)
+
+ return self
+
+ def codegen(self, code: IndentedBuffer):
+ assert self.node.get_name() not in V.graph.removed_buffers
+ code.writeline(make_buffer_allocation(self.node))
+
+
+@dataclasses.dataclass
+class FreeIfNotReusedLine(MemoryPlanningLine):
+ node: ir.Buffer
+ is_reused: bool = False
+
+ def plan(self, state: MemoryPlanningState):
+ assert not self.is_reused
+ if self.node.get_name() in V.graph.removed_buffers:
+ return NullLine()
+ state.push(buffer_reuse_key(self.node), self)
+ return self
+
+ def codegen(self, code: IndentedBuffer):
+ assert self.node.get_name() not in V.graph.removed_buffers
+ if not self.is_reused:
+ code.writeline(f"del {self.node.get_name()}")
+
+
+@dataclasses.dataclass
+class ReuseLine(MemoryPlanningLine):
+ node: ir.Buffer
+ reused_as: ir.Buffer
+
+ def plan(self, state: MemoryPlanningState):
+ if self.reused_as.get_name() in V.graph.removed_buffers:
+ # we hit this case only for inplace buffers
+ return FreeLine(self.node).plan(state)
+ assert self.node.get_name() not in V.graph.removed_buffers
+ return self
+
+ def codegen(self, code: IndentedBuffer):
+ assert self.node.get_name() not in V.graph.removed_buffers
+ assert self.reused_as.get_name() not in V.graph.removed_buffers
+ code.writeline(make_buffer_reuse(self.node, self.reused_as) + " # reuse")
+
+
+@dataclasses.dataclass
+class FreeLine(MemoryPlanningLine):
+ node: ir.Buffer
+
+ def plan(self, state: MemoryPlanningState):
+ if self.node.get_name() in V.graph.removed_buffers:
+ return NullLine()
+ return self
+
+ def codegen(self, code: IndentedBuffer):
+ assert self.node.get_name() not in V.graph.removed_buffers
+ code.writeline(f"del {self.node.get_name()}")
+
+
+class NullLine(MemoryPlanningLine):
+ pass
+
+
+class WrapperCodeGen(CodeGen):
+ """
+ The outer wrapper that calls the kernels.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self._names_iter = count()
+ self.header = IndentedBuffer()
+ self.prefix = IndentedBuffer()
+ self.kernels = {}
+ self.lines = []
+ self.header.splice(
+ f"""
+ from ctypes import c_void_p, c_long
+ import torch
+ import random
+ from torch import empty_strided, as_strided, device
+ from {codecache.__name__} import AsyncCompile
+
+ aten = torch.ops.aten
+ async_compile = AsyncCompile()
+
+ """
+ )
+
+ if has_triton():
+ self.header.splice(
+ f"""
+ import triton
+ import triton.language as tl
+ from {config.inductor_import}.triton_ops.autotune import grid
+ from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
+ """
+ )
+
+ if config.triton.convolution != "aten":
+ self.header.splice(
+ f"""
+ from {config.inductor_import}.triton_ops.conv_perf_model import early_config_prune
+ from {config.inductor_import}.triton_ops.conv_perf_model import estimate_conv_time
+ from {config.inductor_import}.triton_ops.autotune import conv_heuristics
+ """
+ )
+
+ if config.triton.mm != "aten":
+ self.header.splice(
+ f"""
+ from {config.inductor_import}.triton_ops.autotune import mm_heuristics
+ from {config.inductor_import}.triton_ops.autotune import mm_autotune
+ """
+ )
+
+ if config.triton.use_bmm:
+ self.header.writeline(
+ f"from {config.inductor_import}.triton_ops.batched_matmul import bmm_out as triton_bmm_out"
+ )
+
+ self.prefix.splice(
+ f"""
+
+ async_compile.wait(globals())
+ del async_compile
+
+ def call({', '.join(V.graph.graph_inputs.keys())}):
+ """
+ )
+ with self.prefix.indent():
+ for name in V.graph.randomness_seeds:
+ self.prefix.writeline(
+ f"torch.randint(2**31, size=(), dtype=torch.int64, out={name})"
+ )
+ V.graph.sizevars.codegen(self.prefix, V.graph.graph_inputs)
+
+ for name, value in V.graph.constants.items():
+ # include a hash so our code cache gives different constants different files
+ hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest()
+ self.header.writeline(f"{name} = None # {hashed}")
+
+ self.allocated = set()
+ self.freed = set()
+ self.write_get_cuda_stream = functools.lru_cache(None)(
+ self.write_get_cuda_stream
+ )
+
+ def write_get_cuda_stream(self, index):
+ name = f"stream{index}"
+ self.writeline(f"{name} = get_cuda_stream({index})")
+ return name
+
+ def next_kernel_name(self):
+ return f"kernel{next(self._names_iter)}"
+
+ def codegen_allocation(self, buffer):
+ name = buffer.get_name()
+ if name in V.graph.removed_buffers or name in self.allocated:
+ return
+ self.allocated.add(name)
+
+ layout = buffer.get_layout()
+ if isinstance(layout, ir.MutationLayout):
+ return
+ if isinstance(layout, ir.AliasedLayout):
+ assert isinstance(layout.view, ir.ReinterpretView)
+ if not layout.maybe_guard_aligned():
+ V.graph.unaligned_buffers.add(name)
+ self.codegen_allocation(layout.view.data)
+ allocation = DeferredLine(
+ name, f"{name} = {layout.view.codegen_reference()} # alias"
+ )
+ self.writeline(allocation)
+ return
+
+ self.writeline(AllocateLine(buffer))
+
+ def codegen_free(self, buffer):
+ name = buffer.get_name()
+ if not self.can_reuse(buffer):
+ return
+ self.freed.add(name)
+
+ layout = buffer.get_layout()
+ if isinstance(layout, (ir.AliasedLayout, ir.MultiOutputLayout)):
+ self.writeline(f"del {name}")
+ return
+
+ self.writeline(FreeIfNotReusedLine(buffer))
+
+ def can_reuse(self, buffer):
+ name = buffer.get_name()
+ if (
+ name in V.graph.removed_buffers
+ or name in V.graph.graph_inputs
+ or name in V.graph.constants
+ or name in self.freed
+ ):
+ return False
+ return True
+
+ def codegen_inplace_reuse(self, input_buffer, output_buffer):
+ assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer)
+ self.codegen_allocation(input_buffer)
+ self.freed.add(input_buffer.get_name())
+ self.allocated.add(output_buffer.get_name())
+ self.writeline(ReuseLine(input_buffer, output_buffer))
+
+ @dynamo_utils.dynamo_timed
+ def generate(self):
+ result = IndentedBuffer()
+ result.splice(self.header)
+ result.splice(self.prefix)
+
+ out_names = V.graph.get_output_names()
+ with result.indent():
+ while (
+ self.lines
+ and isinstance(self.lines[-1], MemoryPlanningLine)
+ and self.lines[-1].node.name not in out_names
+ ):
+ # these lines will be pointless
+ self.lines.pop()
+
+ # codegen allocations in two passes
+ planning_state = MemoryPlanningState()
+ for i in range(len(self.lines)):
+ if isinstance(self.lines[i], MemoryPlanningLine):
+ self.lines[i] = self.lines[i].plan(planning_state)
+
+ for line in self.lines:
+ if isinstance(line, MemoryPlanningLine):
+ line.codegen(result)
+ else:
+ result.writeline(line)
+
+ output_refs = [x.codegen_reference() for x in V.graph.graph_outputs]
+ if output_refs:
+ result.writeline("return (" + ", ".join(output_refs) + ", )")
+ else:
+ result.writeline("return ()")
+
+ self.add_benchmark_harness(result)
+
+ return result.getvalue()
+
+ def add_benchmark_harness(self, output):
+ """
+ Append a benchmark harness to generated code for debugging
+ """
+ if not config.benchmark_harness:
+ return
+
+ def add_fake_input(name, shape, stride, device, dtype):
+ output.writeline(
+ f"{name} = rand_strided("
+ f"{V.graph.sizevars.codegen_shape_tuple(shape)}, "
+ f"{V.graph.sizevars.codegen_shape_tuple(stride)}, "
+ f"device='{device.type}', dtype={dtype})"
+ )
+
+ output.writelines(["", "", 'if __name__ == "__main__":'])
+ with output.indent():
+ output.splice(
+ f"""
+ from {config.dynamo_import}.testing import rand_strided
+ from {config.inductor_import}.utils import print_performance
+ """,
+ strip=True,
+ )
+
+ for name, value in V.graph.constants.items():
+ add_fake_input(
+ name, value.size(), value.stride(), value.device, value.dtype
+ )
+
+ for name, value in V.graph.graph_inputs.items():
+ shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
+ stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
+ add_fake_input(
+ name, shape, stride, value.get_device(), value.get_dtype()
+ )
+
+ output.writeline(
+ f"print_performance(lambda: call({', '.join(V.graph.graph_inputs.keys())}))"
+ )
+
+ def define_kernel(self, name: str, kernel: str):
+ self.header.splice(f"\n\n{name} = {kernel}")
+
+ def call_kernel(self, name: str, kernel: Kernel):
+ tmp = IndentedBuffer()
+ kernel.call_kernel(self, tmp, name)
+ for line in tmp.getvalue().split("\n"):
+ line = line.strip()
+ if line:
+ self.writeline(line)
+
+ def writeline(self, line):
+ self.lines.append(line)
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
new file mode 100644
index 0000000000000..0f7fcbbf96acc
--- /dev/null
+++ b/torch/_inductor/compile_fx.py
@@ -0,0 +1,368 @@
+import dataclasses
+import functools
+import itertools
+import logging
+from typing import List
+
+import functorch
+from functorch.compile import make_boxed_compiler, min_cut_rematerialization_partition
+
+import torch.fx
+from torch._subclasses.fake_tensor import FakeTensor
+from torch.utils._mode_utils import no_dispatch
+
+from . import config, overrides
+from .debug import DebugContext
+from .decomposition import select_decomp_table
+from .graph import GraphLowering
+from .utils import (
+ dynamo_logging,
+ dynamo_optimizations,
+ dynamo_utils,
+ has_incompatible_cudagraph_ops,
+)
+from .virtualized import V
+
+log = logging.getLogger(__name__)
+ALIGNMENT = 16
+
+aot_autograd = dynamo_optimizations.backends.aot_autograd
+normalize_ir = dynamo_optimizations.normalize.normalize_ir
+is_aot_autograd_safe_to_run = dynamo_optimizations.training.is_aot_autograd_safe_to_run
+count_calls = dynamo_utils.count_calls
+
+
+@dataclasses.dataclass
+class BoxedBool:
+ value: bool
+
+ def __bool__(self):
+ return self.value
+
+ @staticmethod
+ def disable(obj):
+ if isinstance(obj, BoxedBool):
+ obj.value = False
+ return obj
+ return False
+
+
+# copy_ fails when trying to write to tensors with memory overlap,
+# for expanded dimensions (a dimension which used to have size 1 -> ?)
+# we can select one element from that dimension and write to it
+# to achieve writing to all values of that dimension of the input tensor
+def get_expanded_dims(t):
+ return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]
+
+
+def index_expanded_dims(t, expanded_dims):
+ for expanded_dim in expanded_dims:
+ t = torch.ops.aten.slice(t, expanded_dim, 0, 1)
+ return t
+
+
+def complex_memory_overlap(t):
+ indexed_tensor = index_expanded_dims(t, get_expanded_dims(t))
+ return torch._debug_has_internal_overlap(indexed_tensor) != 0
+
+
+def is_unspec_input(t):
+ return t.device.type == "cpu" and t.dim() == 0
+
+
+@functools.lru_cache(None)
+def _step_logger():
+ return dynamo_logging.get_step_logger(log)
+
+
+@DebugContext.wrap
+@no_dispatch()
+def compile_fx_inner(
+ gm: torch.fx.GraphModule,
+ example_inputs: List[torch.Tensor],
+ cudagraphs=None,
+ num_fixed=0,
+ is_backward=False,
+ graph_id=None,
+):
+ if dynamo_utils.count_calls(gm.graph) == 0:
+ return gm
+
+ _step_logger()(
+ logging.INFO,
+ "torchinductor compiling "
+ f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
+ f"graph {graph_id}",
+ )
+
+ V.debug.fx_graph(gm, example_inputs)
+
+ if cudagraphs is None:
+ cudagraphs = config.triton.cudagraphs
+
+ graph = GraphLowering(gm, num_dynamic_inputs=len(example_inputs))
+ with V.set_graph_handler(graph):
+ graph.run(*example_inputs)
+ compiled_fn = graph.compile_to_fn()
+
+ complex_memory_overlap_inputs = any(
+ complex_memory_overlap(t) for t in example_inputs
+ )
+
+ if (
+ cudagraphs
+ and set(graph.device_types) == {"cuda"}
+ and not graph.mutated_inputs
+ and not has_incompatible_cudagraph_ops(gm)
+ and not complex_memory_overlap_inputs
+ ):
+ compiled_fn = cudagraphify(
+ compiled_fn, example_inputs, static_input_idxs=range(num_fixed)
+ )
+ elif cudagraphs:
+ BoxedBool.disable(cudagraphs)
+
+ if len(set(graph.device_types)) > 1:
+ log.warning("skipping cudagraphs due to multiple devices")
+ elif set(graph.device_types) == {"cuda"}:
+ if graph.mutated_inputs:
+ log.warning("skipping cudagraphs due to input mutation")
+ elif complex_memory_overlap_inputs:
+ log.warning("skipping cudagraphs due to complex input striding")
+
+ result = align_inputs(compiled_fn, example_inputs, range(num_fixed))
+ _step_logger()(
+ logging.INFO,
+ "torchinductor done compiling "
+ f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
+ f"graph {graph_id}",
+ )
+ return result
+
+
+def clone_preserve_strides(x):
+ needed_size = (
+ sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
+ )
+ buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
+ return torch.as_strided(buffer, x.size(), x.stride())
+
+
+def align_inputs(model, inputs, static_input_idxs=()):
+ check_inputs = [
+ i
+ for i in range(len(inputs))
+ if (i not in static_input_idxs or (inputs[i].data_ptr() % ALIGNMENT) != 0)
+ and inputs[i].device.type == "cuda"
+ ]
+
+ if len(check_inputs) == 0:
+ return model
+
+ def run(*new_inputs):
+ for i in check_inputs:
+ if new_inputs[i].data_ptr() % ALIGNMENT:
+ if isinstance(new_inputs, tuple):
+ new_inputs = list(new_inputs)
+ new_inputs[i] = clone_preserve_strides(new_inputs[i])
+ new_inputs = [x.to("cuda") if is_unspec_input(x) else x for x in new_inputs]
+ return model(*new_inputs)
+
+ return run
+
+
+@dynamo_utils.dynamo_timed
+def cudagraphify(model, inputs, static_input_idxs=()):
+ # if using fake tensors, defer cudagraphs until we get real inputs at runtime
+ if not any(isinstance(inp, FakeTensor) for inp in inputs):
+ return cudagraphify_impl(model, inputs, static_input_idxs)
+
+ compiled_fn = None
+
+ def run(*new_inputs):
+ nonlocal compiled_fn
+ if compiled_fn is None:
+ with dynamo_utils.preserve_rng_state():
+ compiled_fn = cudagraphify_impl(model, new_inputs, static_input_idxs)
+
+ return compiled_fn(*new_inputs)
+
+ return run
+
+
+def remove_unaligned_input_idxs(inputs, static_input_idxs):
+ """
+ We require all inputs to be aligned, so introduce a copy for any
+ that aren't.
+ """
+ aligned_static_input_idxs = {
+ idx for idx in static_input_idxs if (inputs[idx].data_ptr() % ALIGNMENT) == 0
+ }
+ if len(aligned_static_input_idxs) != len(static_input_idxs):
+ return aligned_static_input_idxs
+ return static_input_idxs
+
+
+def cudagraphify_impl(model, inputs, static_input_idxs=()):
+ """
+ Assumes inputs[static_input_idxs[i]] are always the same memory address
+ """
+ static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
+
+ def static_input(x):
+ """
+ Copy and input while preserving strides
+ """
+ # TODO(jansel): figure out why this version doesn't work:
+ # return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
+ needed_size = (
+ sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
+ )
+ buffer = torch.zeros(needed_size, dtype=x.dtype, device=x.device)
+ return torch.as_strided(buffer, x.size(), x.stride())
+
+ assert isinstance(inputs, (list, tuple))
+ # dynamo wraps unspec variable as 0 dim tensor on CPU, need to move to GPU explicitly
+ inputs = [x.to("cuda") if is_unspec_input(x) else x for x in inputs]
+
+ static_inputs = [
+ static_input(x) if idx not in static_input_idxs else x
+ for idx, x in enumerate(inputs)
+ ]
+
+ inps_expanded_dims = [
+ get_expanded_dims(x) if idx not in static_input_idxs else []
+ for idx, x in enumerate(inputs)
+ ]
+
+ # warmup
+ torch.cuda.synchronize()
+ stream = torch.cuda.Stream()
+ stream.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(stream):
+ model(*static_inputs)
+ stream.synchronize()
+ torch.cuda.current_stream().wait_stream(stream)
+ torch.cuda.synchronize()
+
+ # record
+ graph = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(graph, stream=stream):
+ static_outputs = model(*static_inputs)
+ if not isinstance(static_outputs, (list, tuple)):
+ static_outputs = (static_outputs,)
+
+ if config.size_asserts:
+
+ def run(*new_inputs):
+ assert len(static_inputs) == len(new_inputs)
+ for idx, (dst, src, expanded_dims) in enumerate(
+ zip(static_inputs, new_inputs, inps_expanded_dims)
+ ):
+ if idx in static_input_idxs:
+ assert dst.data_ptr() == src.data_ptr()
+ else:
+ # TODO - could make one single op of multiple slices
+ # and avoid dispatch.
+ # Could also pre-index the `dst` tensors
+ dst = index_expanded_dims(dst, expanded_dims)
+ src = index_expanded_dims(src, expanded_dims)
+ dst.copy_(src)
+ graph.replay()
+ return static_outputs
+
+ else:
+ copy_indices = [
+ idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
+ ]
+
+ def run(*new_inputs):
+ for idx in copy_indices:
+ src = index_expanded_dims(static_inputs[idx], inps_expanded_dims[idx])
+ dst = index_expanded_dims(new_inputs[idx], inps_expanded_dims[idx])
+ dst.copy_(src)
+ graph.replay()
+ return static_outputs
+
+ return run
+
+
+def count_tangents(fx_g: torch.fx.GraphModule):
+ """
+ Infers which inputs are static for a backwards graph
+ """
+
+ def is_not_gradout(x):
+ return "tangents" not in x.name
+
+ arg_count = 0
+ static_arg_idxs = []
+ for n in fx_g.graph.nodes:
+ if n.op == "placeholder":
+ if is_not_gradout(n):
+ static_arg_idxs.append(arg_count)
+ arg_count += 1
+
+ assert static_arg_idxs == list(range(len(static_arg_idxs)))
+ return len(static_arg_idxs)
+
+
+_graph_counter = itertools.count(0)
+
+
+def compile_fx(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]):
+ """Main entrypoint to a compile given FX graph"""
+
+ if not is_aot_autograd_safe_to_run(model_, example_inputs_):
+ log.warning("Aot Autograd is not safe to run, so falling back to eager")
+ return model_
+
+ functorch.compile.config.use_functionalize = True
+ functorch.compile.config.use_fake_tensor = True
+
+ with overrides.patch_functions():
+ model_ = normalize_ir(model_, example_inputs_)
+ model_ = overrides.replace_fx(model_)
+ num_example_inputs = len(example_inputs_)
+ cudagraphs = BoxedBool(config.triton.cudagraphs)
+
+ graph_id = next(_graph_counter)
+
+ @dynamo_utils.dynamo_timed
+ def fw_compiler(model: torch.fx.GraphModule, example_inputs):
+ fixed = len(example_inputs) - num_example_inputs
+ return compile_fx_inner(
+ model,
+ example_inputs,
+ num_fixed=fixed,
+ cudagraphs=cudagraphs,
+ graph_id=graph_id,
+ )
+
+ @dynamo_utils.dynamo_timed
+ def bw_compiler(model: torch.fx.GraphModule, example_inputs):
+ fixed = count_tangents(model)
+ return compile_fx_inner(
+ model,
+ example_inputs,
+ num_fixed=fixed,
+ cudagraphs=cudagraphs,
+ is_backward=True,
+ graph_id=graph_id,
+ )
+
+ with overrides.patch_functions():
+
+ # TODO: can add logging before/after the call to create_aot_dispatcher_function
+ # in functorch/_src/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
+ # once torchdynamo is merged into pytorch
+ return aot_autograd(
+ model_,
+ example_inputs_,
+ fw_compiler=make_boxed_compiler(fw_compiler),
+ bw_compiler=make_boxed_compiler(bw_compiler),
+ decompositions=select_decomp_table(),
+ partition_fn=functools.partial(
+ min_cut_rematerialization_partition, compiler="inductor"
+ ),
+ )
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
new file mode 100644
index 0000000000000..2850143c22e8f
--- /dev/null
+++ b/torch/_inductor/config.py
@@ -0,0 +1,153 @@
+import os
+
+# add some debug printouts
+debug = False
+
+# dead code elimination
+dce = False
+
+# assume input tensors are dynamic
+dynamic_shapes = True
+
+# assume weight tensors are fixed size
+static_weight_shapes = True
+
+# put correctness assertions in generated code
+size_asserts = True
+
+# enable loop reordering based on input orders
+pick_loop_orders = True
+
+# generate inplace computations
+inplace_buffers = False
+
+# codegen benchmark harness
+benchmark_harness = True
+
+# control store vs recompute heuristic
+realize_reads_threshold = 4
+realize_bytes_threshold = 2000
+
+# fallback to eager for random/dropout, this is slow but useful for debugging
+fallback_random = False
+
+# automatically create fallbacks when encountering an unhandled op
+implicit_fallbacks = True
+
+# Enables a fusion pass that groups nodes together before the scheduler
+prefuse_nodes = True
+
+# do bench to decide best layout, currently only for aten.conv
+tune_layout = False
+
+# fuse even in cases without common reads
+aggressive_fusion = False
+
+# how many nodes to allow into a single fusion
+max_fusion_size = 64
+
+# replace small reductions with pointwise, disable with `= 1`
+unroll_reductions_threshold = 8
+
+comment_origin = False
+
+compile_threads = 1
+
+# How to import torchinductor, either torchinductor or torch.inductor
+inductor_import = __name__.replace(".config", "")
+
+# How to import torchdynamo, either torchdynamo or torch.dynamo
+dynamo_import = inductor_import.replace("inductor", "dynamo")
+
+
+# config specific to codegen/cpp.pp
+class cpp:
+ # set to torch.get_num_threads()
+ threads = -1
+
+ # Assume number of threads is dynamic, don't specialize thread number.
+ # Kernels don't recompile on thread number changes with this flag on.
+ # For single-threaded workload, turning it on would incur a slight
+ # performance degradation.
+ dynamic_threads = False
+
+ simdlen = None
+ min_chunk_size = 4096
+ cxx = (
+ None, # download gcc12 from conda-forge if conda is installed
+ "g++-12",
+ "g++-11",
+ "g++-10",
+ "clang++",
+ "g++",
+ )
+
+
+# config specific to codegen/triton.py
+class triton:
+
+ # Use cudagraphs on output code
+ cudagraphs = True
+
+ # choose conv backend, "aten" or "triton" or "autotune"
+ convolution = "aten"
+
+ # choose mm backend, "aten" or "triton" or "autotune"
+ mm = "aten"
+
+ # Always load full blocks (rather than broadcasting inside the block)
+ # Set default as True because otherwise will encouter `map::at` error
+ # in triton if loading from 1-dim tensor using 2-dim pointer offset
+ # https://triton-lang.slack.com/archives/C01L1FLTX70/p1656023403343639
+ # could be set as False if triton fixes the bug later
+ dense_indexing = False
+
+ # limit tiling dimensions
+ max_tiles = 2
+
+ # use triton.autotune?
+ autotune = True
+
+ use_bmm = False
+
+ # should we stop a fusion to allow better tiling?
+ tiling_prevents_pointwise_fusion = True
+ tiling_prevents_reduction_fusion = True
+ # should we give different names to kernels
+ ordered_kernel_names = False
+ # should we use natural codegen for where, needs newer triton version
+ simple_where = True
+
+
+# create a directory containing lots of debug information
+class trace:
+ # master switch for all debugging flags below
+ enabled = os.environ.get("TORCHINDUCTOR_TRACE", "0") == "1"
+
+ # Save python logger call >=logging.DEBUG
+ debug_log = True
+
+ # Save python logger call >=logging.INFO
+ info_log = False
+
+ # Save input FX graph (post decomps)
+ fx_graph = True
+
+ # Save TorchInductor IR before fusion pass
+ ir_pre_fusion = True
+
+ # Save TorchInductor IR after fusion pass
+ ir_post_fusion = True
+
+ # Copy generated code to trace dir
+ output_code = True
+
+ # SVG figure showing post-fusion graph
+ graph_diagram = False
+
+ # Store cProfile (see snakeviz to view)
+ compile_profile = False
+
+ # Upload the .tar.gz file
+ # Needs to be overriden based on specific environment needs
+ upload_tar = None
diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py
new file mode 100644
index 0000000000000..d2bc9bcd73344
--- /dev/null
+++ b/torch/_inductor/debug.py
@@ -0,0 +1,325 @@
+import collections
+import contextlib
+import cProfile
+import functools
+import itertools
+import logging
+import os.path
+import pstats
+import shutil
+import subprocess
+from typing import Any, List
+
+from functorch.compile import draw_graph, get_graph_being_compiled
+
+import torch
+from torch import fx as fx
+from torch.fx.graph_module import GraphModule
+from torch.fx.passes.shape_prop import TensorMetadata
+from torch.fx.passes.tools_common import legalize_graph
+
+from . import config, ir
+from .codecache import cache_dir
+from .scheduler import (
+ BaseSchedulerNode,
+ ExternKernelSchedulerNode,
+ FusedSchedulerNode,
+ NopKernelSchedulerNode,
+ OutputNode,
+ SchedulerNode,
+ TemplateSchedulerNode,
+)
+from .utils import dynamo_config, dynamo_debug_utils, dynamo_utils
+from .virtualized import V
+
+log = logging.getLogger(__name__)
+
+
+@functools.lru_cache(None)
+def has_dot():
+ try:
+ subprocess.check_output(["which", "dot"], stderr=subprocess.PIPE)
+ return True
+ except subprocess.SubprocessError:
+ return False
+
+
+def draw_buffers(nodes, print_graph=False, fname=None):
+ """
+ Draw a graph in fname.svg.
+ nodes is a list of SchedulerNode objects.
+ """
+ if not has_dot():
+ log.warning("draw_buffers() requires `graphviz` package")
+ return
+
+ if fname is None:
+ fname = get_graph_being_compiled()
+
+ graph = create_fx_from_snodes(nodes)
+
+ for node in graph.nodes:
+ if "fusion_meta" not in node.meta:
+ continue
+ group = node.meta["fusion_meta"].group
+ if isinstance(group, tuple):
+ group = group[1]
+
+ # gather meta data
+ dtype = None
+ if isinstance(node, ir.ComputedBuffer):
+ dtype = node.data.dtype
+
+ metadata = TensorMetadata(group, dtype, None, None, None, None, None)
+ node.meta["tensor_meta"] = metadata
+
+ if print_graph:
+ print(graph)
+
+ gm = GraphModule({}, graph)
+ legalize_graph(gm)
+ gm.graph.lint()
+ draw_graph(gm, fname, clear_meta=False)
+
+
+def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
+ """
+ Creates a FX Graph from a list of SchedulerNode objects.
+ """
+
+ def get_fake_func(name):
+ def func1(*args):
+ return 0
+
+ func1.__name__ = name
+ return func1
+
+ FusionMeta = collections.namedtuple("FusionMeta", ["group", "snodes", "type"])
+
+ func_dict = {s: get_fake_func(s) for s in ["extern", "nop", "compute", "fused"]}
+ buf_to_fx_node = {}
+ graph = torch.fx.Graph()
+ first_node = None
+
+ outputs = []
+ group: Any = None
+ # create call_function node for each Buffer and Kernel
+ for snode in snodes:
+ if isinstance(snode, ExternKernelSchedulerNode):
+ node_type = "extern"
+ group = node_type
+ elif isinstance(snode, TemplateSchedulerNode):
+ node_type = "template"
+ group = node_type
+ elif isinstance(snode, NopKernelSchedulerNode):
+ node_type = "nop"
+ group = node_type
+ elif isinstance(snode, SchedulerNode):
+ node_type = "compute"
+ group = snode.group
+ elif isinstance(snode, FusedSchedulerNode):
+ node_type = "fused"
+ group = snode.group
+ else:
+ raise RuntimeError("Unknown node type")
+ node_func = func_dict[node_type]
+ fx_node = graph.call_function(node_func, args=(), kwargs=None)
+
+ def in_output(snode):
+ if isinstance(snode, FusedSchedulerNode):
+ return any([in_output(x) for x in snode.snodes])
+ return any([isinstance(user.node, OutputNode) for user in snode.users])
+
+ if in_output(snode):
+ outputs.append(fx_node)
+ name = snode.get_name()
+ fx_node.name = name
+
+ fx_node.meta["fusion_meta"] = FusionMeta(group, [snode], node_type)
+
+ if isinstance(snode, FusedSchedulerNode):
+ for x in snode.snodes:
+ buf_to_fx_node[x.get_name()] = fx_node
+ buf_to_fx_node[name] = fx_node
+
+ if first_node is None:
+ first_node = fx_node
+
+ # create edges between nodes
+ for snode in snodes:
+ name = snode.get_name()
+ deps = snode.read_writes.reads
+
+ fx_node = buf_to_fx_node[name]
+ new_args = []
+ for dep in deps:
+ if dep.name in buf_to_fx_node:
+ dep_node = buf_to_fx_node[dep.name]
+ else:
+ with graph.inserting_before(first_node):
+ dep_node = graph.placeholder(dep.name)
+ buf_to_fx_node[dep.name] = dep_node
+ new_args.append(dep_node)
+
+ fx_node.args = tuple(new_args)
+
+ graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
+ return graph
+
+
+class DebugContext:
+ _counter = itertools.count()
+
+ @staticmethod
+ def wrap(fn):
+ @functools.wraps(fn)
+ def inner(*args, **kwargs):
+ with DebugContext():
+ return fn(*args, **kwargs)
+
+ return dynamo_debug_utils.wrap_compiler_debug(inner, compiler_name="inductor")
+
+ @staticmethod
+ def create_debug_dir():
+ for n in DebugContext._counter:
+ dirname = os.path.join(cache_dir(), f"debug.{os.getpid()}.{n}")
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+ return dirname
+
+ def __init__(self):
+ self._prof = None
+ self._path = None
+ self._stack = contextlib.ExitStack()
+
+ def rename(self, new_path: str):
+ if not self._path:
+ return
+ assert new_path.endswith(".debug"), new_path
+ if os.path.exists(new_path):
+ shutil.rmtree(new_path)
+ try:
+ os.rename(self._path, new_path)
+ self._path = new_path
+ except OSError:
+ # other OS might have troubling renaming dir with open files
+ pass
+
+ def fopen(self, filename):
+ assert self._path
+ return open(os.path.join(self._path, filename), "w")
+
+ def filename(self, suffix):
+ return os.path.join(self._path, suffix)
+
+ def upload_tar(self):
+ if config.trace.upload_tar is not None:
+ import tarfile
+
+ assert self._path
+ tar_file = os.path.join(
+ self._path, f"{os.path.basename(self._path)}.tar.gz"
+ )
+ with tarfile.open(tar_file, "w:gz") as tar:
+ tar.add(self._path, arcname=os.path.basename(self._path))
+ config.trace.upload_tar(tar_file)
+
+ def __enter__(self):
+ log = logging.getLogger(config.inductor_import)
+ if not log.handlers:
+ dynamo_utils.init_logging()
+
+ if config.debug:
+ dynamo_config.log_level = logging.DEBUG
+
+ self._stack.enter_context(V.set_debug_handler(self))
+
+ if not config.trace.enabled:
+ return
+
+ self._path = self.create_debug_dir()
+
+ if config.trace.debug_log:
+ self._setup_log_capture("debug.log", logging.DEBUG)
+ if config.trace.info_log:
+ self._setup_log_capture("info.log", logging.INFO)
+ if config.trace.compile_profile:
+ self._prof = cProfile.Profile()
+ self._prof.enable()
+
+ def _setup_log_capture(self, filename, level):
+ log = logging.getLogger(config.inductor_import)
+ fd = self._stack.enter_context(self.fopen(filename))
+ ch = logging.StreamHandler(fd)
+ ch.setLevel(level)
+ ch.setFormatter(
+ logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
+ )
+ log.addHandler(ch)
+ log.setLevel(min(log.level, level))
+ self._stack.callback(log.removeHandler, ch)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self._prof:
+ self._prof.disable()
+ self._save_profile_data()
+
+ if self._path:
+ self.upload_tar()
+ log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
+ self._stack.close()
+
+ def _save_profile_data(self):
+ self._prof.dump_stats(self.filename("compile.prof"))
+ with self.fopen("compile.stats") as fd:
+ stats = pstats.Stats(self._prof, stream=fd)
+ stats.strip_dirs()
+ stats.sort_stats("cumtime")
+ stats.print_stats(100)
+ stats.sort_stats("tottime")
+ stats.print_stats(100)
+
+ def __getattr__(self, name):
+ if config.trace.enabled and getattr(config.trace, name):
+ try:
+ return getattr(DebugFormatter(self), name)
+ except Exception:
+ log.warning("Ignoring exception in debug code", exc_info=True)
+ else:
+
+ def ignored(*args, **kwargs):
+ pass
+
+ return ignored
+
+
+SchedulerNodeList = List[Any]
+
+
+class DebugFormatter:
+ def __init__(self, handler):
+ self.fopen = handler.fopen
+ self.filename = handler.filename
+ self.handler = handler
+
+ def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
+ with self.fopen("fx_graph.py") as fd:
+ dynamo_debug_utils.save_graph_repro(fd, gm, inputs, "inductor")
+
+ def ir_pre_fusion(self, nodes: SchedulerNodeList):
+ self._write_ir("ir_pre_fusion.txt", nodes)
+
+ def ir_post_fusion(self, nodes: SchedulerNodeList):
+ self._write_ir("ir_post_fusion.txt", nodes)
+
+ def _write_ir(self, filename: str, nodes: SchedulerNodeList):
+ with self.fopen(filename) as fd:
+ for node in nodes:
+ fd.write(node.debug_str())
+ fd.write("\n\n\n")
+
+ def graph_diagram(self, nodes: SchedulerNodeList):
+ draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))
+
+ def output_code(self, filename):
+ shutil.copy(filename, self.filename("output_code.py"))
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
new file mode 100644
index 0000000000000..ede2aca75bef7
--- /dev/null
+++ b/torch/_inductor/decomposition.py
@@ -0,0 +1,327 @@
+import functools
+import logging
+import math
+import numbers
+
+from functorch._src.aot_autograd import aot_autograd_decompositions
+
+import torch
+import torch._decomp as decomp
+from torch import Tensor
+from torch._decomp import get_decompositions
+from torch._prims_common import is_boolean_dtype, is_integer_dtype
+
+from . import config
+
+log = logging.getLogger(__name__)
+aten = torch.ops.aten
+log = logging.getLogger(__name__)
+
+decompositions = get_decompositions(
+ [
+ aten._adaptive_avg_pool2d_backward,
+ aten.addcmul,
+ aten.avg_pool2d_backward,
+ aten.binary_cross_entropy_with_logits,
+ aten.clamp_max,
+ aten.clamp_min,
+ aten.col2im,
+ aten.cudnn_batch_norm,
+ aten.cudnn_batch_norm_backward,
+ aten.detach,
+ aten.dot,
+ aten.elu,
+ aten.elu_backward,
+ aten._embedding_bag,
+ aten.embedding_dense_backward,
+ aten.expand_as,
+ aten.eye,
+ aten.flip,
+ aten._fused_moving_avg_obs_fq_helper,
+ aten.gelu,
+ aten.gelu_backward,
+ aten.glu_backward,
+ aten.grid_sampler_2d,
+ aten.hardsigmoid,
+ aten.hardsigmoid_backward,
+ aten.hardswish,
+ aten.hardswish_backward,
+ aten.hardtanh,
+ aten.hardtanh_backward,
+ aten.im2col,
+ aten.index_add,
+ aten.index_add_,
+ aten.index_select,
+ aten.l1_loss,
+ aten.leaky_relu,
+ aten.leaky_relu_backward,
+ aten.linalg_vector_norm,
+ aten.logit,
+ aten.logit_backward,
+ aten._log_softmax,
+ aten._log_softmax_backward_data,
+ aten.logsumexp.default,
+ aten.max_pool2d_with_indices_backward,
+ aten.mse_loss,
+ aten.mse_loss_backward,
+ aten.mv,
+ aten.narrow,
+ aten.native_batch_norm,
+ aten.native_batch_norm_backward,
+ aten.native_dropout_backward,
+ aten.native_group_norm,
+ aten.native_group_norm_backward,
+ aten.native_layer_norm,
+ aten.native_layer_norm_backward,
+ aten.new_empty,
+ aten.new_full,
+ aten.new_ones,
+ aten.nll_loss_backward,
+ aten.nll_loss_forward,
+ aten.norm,
+ aten.reflection_pad2d_backward,
+ aten._reshape_alias,
+ aten.select_backward,
+ aten.select_scatter,
+ aten.sigmoid_backward,
+ aten.silu_backward,
+ aten.slice_backward,
+ aten.sgn,
+ aten.std_mean.correction,
+ aten._softmax,
+ aten._softmax_backward_data,
+ aten.stack,
+ aten.t,
+ aten.tanh_backward,
+ aten.threshold_backward,
+ aten.transpose.int,
+ aten.tril.default,
+ aten.upsample_bilinear2d.vec,
+ aten.upsample_nearest2d_backward,
+ ]
+)
+decompositions.update(aot_autograd_decompositions)
+
+
+def register_decomposition(ops):
+ for op in [ops] if callable(ops) else ops:
+ if op in decompositions:
+ log.warning(f"duplicate decomp: {ops}")
+ return decomp.register_decomposition(ops, decompositions, disable_meta=True)
+
+
+@register_decomposition([aten.clamp])
+def clamp(x, min=None, max=None):
+ if min is not None:
+ x = torch.maximum(x, torch.tensor(min, dtype=x.dtype, device=x.device))
+ if max is not None:
+ x = torch.minimum(x, torch.tensor(max, dtype=x.dtype, device=x.device))
+ return x
+
+
+@register_decomposition([aten.tanh])
+def tanh(x):
+ return 2.0 / (1.0 + torch.exp(-2.0 * x)) - 1.0
+
+
+# TorchInductor-only decomposition. It should not be taken to core.
+# See https://github.com/pytorch/torchdynamo/pull/1120
+@register_decomposition([aten.floor_divide.default])
+def floordiv(a, b):
+ return aten.div.Tensor_mode(a, b, rounding_mode="floor")
+
+
+@register_decomposition([aten.addmm])
+def addmm(input, mat1, mat2, *, beta=1, alpha=1):
+ if config.triton.mm != "aten":
+ out = torch.mm(mat1, mat2)
+ if not isinstance(alpha, numbers.Number) or alpha != 1:
+ out = out * alpha
+ if not isinstance(beta, numbers.Number) or beta != 1:
+ input = input * beta
+ return input + out
+ else:
+ return NotImplemented # go directly to lowering
+
+
+@register_decomposition([aten.rsqrt])
+def rsqrt(x):
+ return torch.reciprocal(torch.sqrt(x))
+
+
+@register_decomposition([aten.log2])
+def log2(x):
+ return torch.log(x) * (1.0 / math.log(2.0))
+
+
+@register_decomposition([aten.round.decimals])
+def round_dec(x, decimals=0):
+ ten_pow_decimals = 10.0**decimals
+ return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals)
+
+
+@register_decomposition([aten.special_erf, aten.erf])
+def special_erf(x):
+ # TODO(jansel): this might be crazy slow. Triton doesn't have the
+ # cuda ::erf() builtin. I've made a feature request for this,
+ # so it may be coming soon.
+
+ # from https://www.johndcook.com/blog/2009/01/19/stand-alone-error-function-erf/
+ a1 = 0.254829592
+ a2 = -0.284496736
+ a3 = 1.421413741
+ a4 = -1.453152027
+ a5 = 1.061405429
+ p = 0.3275911
+
+ sign = torch.sign(x)
+ x = torch.abs(x)
+
+ # A & S 7.1.26
+ t = 1.0 / (1.0 + p * x)
+ y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * torch.exp(-x * x)
+
+ return sign * y
+
+
+@register_decomposition([aten.rsub.Tensor, aten.rsub.Scalar])
+def rsub(a, b):
+ if isinstance(b, numbers.Number):
+ b = torch.tensor(b, dtype=a.dtype, device=a.device)
+ return b - a
+
+
+@register_decomposition([aten.masked_fill])
+def masked_fill(value, mask, other):
+ if isinstance(other, numbers.Number):
+ other = torch.tensor(other, dtype=value.dtype, device=value.device)
+ if other.device != value.device and other.numel() == 1:
+ other = other.to(value.device)
+ value, mask, other = torch.broadcast_tensors(value, mask, other)
+ return torch.where(mask, other, value)
+
+
+@register_decomposition([aten.nan_to_num])
+def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
+ if is_boolean_dtype(x.dtype) or is_integer_dtype(x.dtype):
+ return x
+
+ if nan is None:
+ nan = 0.0
+ if posinf is None:
+ posinf = torch.finfo(x.dtype).max
+ if neginf is None:
+ neginf = torch.finfo(x.dtype).min
+ nan, posinf, neginf = (
+ torch.tensor(v, dtype=x.dtype, device=x.device) for v in (nan, posinf, neginf)
+ )
+ x = torch.where(x != x, nan, x)
+ x = torch.where(x == float("inf"), posinf, x)
+ x = torch.where(x == float("-inf"), neginf, x)
+ return x
+
+
+@register_decomposition([aten.all.default])
+def all(input):
+ return torch.logical_not(torch.any(torch.logical_not(input)))
+
+
+@register_decomposition([aten.all.dim])
+def all_dim(input, dim, keeepdim=False):
+ return torch.logical_not(torch.any(torch.logical_not(input), dim, keeepdim))
+
+
+@register_decomposition(aten.hardswish_)
+def hardswish_(x):
+ return x.copy_(aten.hardswish(x))
+
+
+@register_decomposition(aten.hardtanh_)
+def hardtanh_(x, min_val=-1, max_val=1):
+ return x.copy_(aten.hardtanh(x, min_val, max_val))
+
+
+@register_decomposition(aten.leaky_relu_)
+def leaky_relu_(x, negative_slope=0.01):
+ return x.copy_(aten.leaky_relu(x, negative_slope))
+
+
+@register_decomposition(aten.silu_)
+def silu_(x):
+ return x.copy_(aten.silu(x))
+
+
+@register_decomposition(aten.masked_fill_)
+def masked_fill_(x, mask, value):
+ return x.copy_(aten.masked_fill(x, mask, value))
+
+
+@register_decomposition([aten.log1p])
+def log1p(x):
+ return torch.log(x + 1)
+
+
+@register_decomposition([aten.baddbmm])
+def baddbmm(self, batch1, batch2, beta=1, alpha=1):
+ result = torch.bmm(batch1, batch2)
+ if not isinstance(alpha, numbers.Number) or alpha != 1:
+ result = result * alpha
+ if not isinstance(beta, numbers.Number) or beta != 1:
+ self = self * beta
+ return self + result
+
+
+@register_decomposition([aten.conj_physical])
+def conj_physical(self):
+ assert not self.is_complex(), "TODO: implement this"
+ return self
+
+
+@register_decomposition([aten.lift, aten.detach_])
+def lift(self):
+ return self
+
+
+@register_decomposition([aten.fill.Scalar])
+def fill_scalar(self, value):
+ return torch.full_like(self, value)
+
+
+@register_decomposition([aten.fill.Tensor])
+def fill_tensor(self, value: Tensor):
+ assert value.dim() == 0, "aten.fill.Tensor only supports 0-dimension value tensor"
+ return torch.full_like(self, value.item())
+
+
+@register_decomposition([aten.bernoulli.default])
+def bernoulli(self, *, generator=None):
+ assert generator is None
+ return torch.rand_like(self, dtype=torch.float32) < self
+
+
+"""
+Some decomps result in differences from eager related to randomness.
+We put these decomps in a separate table `extra_random_decomps` to allow
+turning them on and off via `config.fallback_random`.
+"""
+extra_random_decomps = get_decompositions([aten.native_dropout])
+register_extra_random_decomp = functools.partial(
+ decomp.register_decomposition, registry=extra_random_decomps, disable_meta=True
+)
+
+
+@register_extra_random_decomp([aten.bernoulli_])
+def bernoulli_(self, p=0.5):
+ return self.copy_(torch.rand_like(self) < p)
+
+
+@functools.lru_cache(None)
+def fast_random_decomps():
+ return {**decompositions, **extra_random_decomps}
+
+
+def select_decomp_table():
+ """decomps can change based on config"""
+ if config.fallback_random:
+ return decompositions
+ return fast_random_decomps()
diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py
new file mode 100644
index 0000000000000..253bef1236b53
--- /dev/null
+++ b/torch/_inductor/dependencies.py
@@ -0,0 +1,251 @@
+import collections
+import dataclasses
+import itertools
+import logging
+import typing
+from typing import Callable, cast, Dict, List, Optional, Set, Tuple, Union
+
+import sympy
+
+from .codegen.common import index_prevent_reordering
+from .utils import sympy_product, sympy_str, sympy_subs, VarRanges
+from .virtualized import V
+
+log = logging.getLogger(__name__)
+
+Dep = Union["MemoryDep", "StarDep"]
+
+
+class MemoryDep(typing.NamedTuple):
+ name: str
+ index: sympy.Expr # type: ignore[assignment]
+ size: Tuple[sympy.Expr, ...]
+
+ def broadcast_extend_sizes(self, extra_sizes: List[sympy.Expr]) -> "MemoryDep":
+ size = (*self.size, *[x for x in extra_sizes if x != 1])
+ return MemoryDep(self.name, self.index, size)
+
+ def maybe_swap_sizes(self) -> "MemoryDep":
+ # swap only in simple cases where index is trivial and
+ # there are just 2 sizes
+ if (
+ len(self.size) == 2
+ and len(self.index.args) == 0
+ and cast(sympy.Symbol, self.index).name == canonicalization_prefix() + "0"
+ ):
+ c = canonicalization_prefix()
+ size = (self.size[1], self.size[0])
+ s0 = sympy.Symbol(c + "0")
+ s1 = sympy.Symbol(c + "1")
+ index = sympy_subs(self.index, {s0: s1})
+ return MemoryDep(self.name, index, size)
+ else:
+ return self
+
+ def strip_last_size(self) -> "MemoryDep":
+ nsizes = len(self.size)
+ if not (nsizes >= 1 and len(self.index.args) <= nsizes - 1):
+ return self
+ # make sure last dim index is not used
+ prefix = canonicalization_prefix()
+ len_prefix = len(prefix)
+ prefixes = [
+ fs.name[:len_prefix]
+ for fs in cast(Set[sympy.Symbol], self.index.free_symbols)
+ ]
+ assert (
+ len(prefixes) == 0 or prefix in prefixes
+ ), "index expression should contain canonicalized symbols"
+ last_index = f"{prefix}{len(self.size)-1}"
+ if last_index not in self.index.free_symbols:
+ size = self.size[:-1]
+ return MemoryDep(self.name, self.index, size)
+ else:
+ return self
+
+ def rename(self, renames: Dict[str, str]) -> "MemoryDep":
+ if self.name in renames:
+ return MemoryDep(renames[self.name], self.index, self.size)
+ return self
+
+ def numel_hint(self):
+ vars = set(self.index.free_symbols)
+ return V.graph.sizevars.size_hint(
+ sympy_product([s for s in self.size if s in vars])
+ )
+
+ def is_contiguous(self) -> bool:
+ return isinstance(self.index, (sympy.Symbol, sympy.Integer))
+
+
+class StarDep(typing.NamedTuple):
+ # depends on the entire buffer
+ name: str
+
+ def rename(self, renames: Dict[str, str]) -> "StarDep":
+ if self.name in renames:
+ return StarDep(renames[self.name])
+ return self
+
+ def numel_hint(self):
+ return 1
+
+ def is_contiguous(self) -> bool:
+ return False
+
+
+class IndexExprDep(typing.NamedTuple):
+ index: sympy.Expr # type: ignore[assignment]
+ size: Tuple[sympy.Expr, ...]
+
+
+@dataclasses.dataclass
+class ReadWrites:
+ reads: Set[Dep]
+ writes: Set[Dep]
+ index_exprs: Set[IndexExprDep]
+ range_vars: Optional[List[sympy.Expr]] = None
+ var_ranges: Optional[VarRanges] = None
+
+ def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
+ return ReadWrites(
+ {dep.rename(renames) for dep in self.reads},
+ {dep.rename(renames) for dep in self.writes},
+ self.index_exprs,
+ self.range_vars,
+ self.var_ranges,
+ )
+
+ def with_read(self, name: str) -> "ReadWrites":
+ assert isinstance(name, str)
+ return ReadWrites(
+ set.union(self.reads, {StarDep(name)}),
+ self.writes,
+ self.index_exprs,
+ self.range_vars,
+ self.var_ranges,
+ )
+
+ def merge(self, other):
+ reads = set.union(self.reads, other.reads)
+ writes = set.union(self.writes, other.writes)
+ index_exprs = set.union(self.index_exprs, other.index_exprs)
+ return ReadWrites(
+ reads - writes,
+ writes,
+ index_exprs,
+ )
+
+
+class RecordLoadStore(V.MockHandler): # type: ignore[name-defined]
+ def __init__(self, var_ranges: VarRanges, normalize: bool):
+ super(RecordLoadStore, self).__init__()
+ self._reads: Set[MemoryDep] = set()
+ self._writes: Set[MemoryDep] = set()
+ self._index_exprs: Set[IndexExprDep] = set()
+ self._var_ranges: VarRanges = var_ranges
+ self._normalize: bool = normalize
+
+ def canonicalize(
+ self, index: sympy.Expr
+ ) -> Tuple[sympy.Expr, Tuple[sympy.Expr, ...]]:
+ sizes = list(self._var_ranges.values())
+ sizes = [V.graph.sizevars.simplify(x) for x in sizes]
+ if not self._normalize:
+ return index, tuple([x for x in sizes if x != 1])
+
+ # Try to further simplify the indexes even if simplify_loops didn't
+ # convert it to the simpliest form because of the interference from
+ # different indexing formulas.
+ index_vars = list(self._var_ranges.keys())
+ new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
+ index_vars,
+ sizes,
+ index_prevent_reordering([index], index_vars, sizes),
+ )
+
+ # assign new variables each dimension to deal with numbering mismatches
+ # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
+ _, add_var = var_builder(canonicalization_prefix())
+ replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
+
+ index = sympy_subs(sympy.expand(index), replacement)
+ return index, tuple(new_sizes)
+
+ def load(self, name: str, index: sympy.Expr) -> str:
+ canonicalized_index, canonicalized_size = self.canonicalize(index)
+ self._reads.add(MemoryDep(name, canonicalized_index, canonicalized_size))
+ return f"load({name}, {sympy_str(index)})"
+
+ def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str:
+ canonicalized_index, canonicalized_size = self.canonicalize(index)
+ self._writes.add(MemoryDep(name, canonicalized_index, canonicalized_size))
+ return f"store({name}, {sympy_str(index)}, {value}, {mode})"
+
+ def reduction(
+ self, name: str, dtype, src_dtype, reduction_type, index, value
+ ) -> str:
+ return self.store(name, index, f"reduce_{reduction_type})({value})")
+
+ def index_expr(self, index: sympy.Expr, dtype) -> str:
+ canonicalized_index, canonicalized_size = self.canonicalize(index)
+ self._index_exprs.add(IndexExprDep(canonicalized_index, canonicalized_size))
+ return f"index_expr({sympy_str(index)}, {dtype})"
+
+
+def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]:
+ cnt = itertools.count()
+ var_ranges: VarRanges = collections.OrderedDict()
+
+ def add_var(length: sympy.Expr) -> sympy.Symbol:
+ v = sympy.Symbol(f"{prefix}{next(cnt)}")
+ var_ranges[v] = length
+ return v
+
+ return var_ranges, add_var
+
+
+def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str):
+ var_ranges, add_var = var_builder(prefix)
+ args: List[List[sympy.Symbol]] = []
+ for size in argsizes:
+ args.append(list(map(add_var, size)))
+ return args, var_ranges
+
+
+def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"):
+ from .ir import SqueezeView
+
+ var_ranges, add_var = var_builder(prefix)
+ args: List[List[sympy.Expr]] = []
+ new_sizes: List[List[sympy.Expr]] = []
+ for size in argsizes:
+ new_size, reindex = SqueezeView.squeezer(size)
+ new_sizes.append(new_size)
+ args.append(reindex(list(map(add_var, new_size))))
+ return new_sizes, args, var_ranges
+
+
+def extract_read_writes(
+ fn: Callable,
+ *argsizes: Tuple[sympy.Expr, ...],
+ normalize: bool = False,
+ prefix: str = "d",
+):
+ _, args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)
+ rw = RecordLoadStore(var_ranges, normalize=normalize)
+ with V.set_ops_handler(rw): # type: ignore[call-arg]
+ fn(*args)
+
+ if normalize:
+ range_vars = [] # Number of vars could differ due to normalization
+ else:
+ range_vars = [*itertools.chain(*args)]
+
+ return ReadWrites(
+ set(rw._reads), set(rw._writes), rw._index_exprs, range_vars, var_ranges
+ )
+
+
+def canonicalization_prefix():
+ return "c"
diff --git a/torch/_inductor/exc.py b/torch/_inductor/exc.py
new file mode 100644
index 0000000000000..8b70874d9542d
--- /dev/null
+++ b/torch/_inductor/exc.py
@@ -0,0 +1,85 @@
+import os
+import textwrap
+from functools import lru_cache
+
+from . import config
+
+if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1":
+
+ @lru_cache(None)
+ def _record_missing_op(target):
+ with open("/tmp/missing_ops.txt", "a") as fd:
+ fd.write(str(target) + "\n")
+
+else:
+
+ def _record_missing_op(target):
+ pass
+
+
+class OperatorIssue(RuntimeError):
+ @staticmethod
+ def operator_str(target, args, kwargs):
+ lines = [f"target: {target}"] + [
+ f"args[{i}]: {arg}" for i, arg in enumerate(args)
+ ]
+ if kwargs:
+ lines.append(f"kwargs: {kwargs}")
+ return textwrap.indent("\n".join(lines), " ")
+
+
+class MissingOperatorWithoutDecomp(OperatorIssue):
+ def __init__(self, target, args, kwargs):
+ _record_missing_op(target)
+ super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}")
+
+
+class MissingOperatorWithDecomp(OperatorIssue):
+ def __init__(self, target, args, kwargs):
+ _record_missing_op(target)
+ super().__init__(
+ f"missing decomposition\n{self.operator_str(target, args, kwargs)}"
+ + textwrap.dedent(
+ f"""
+
+ There is a decomposition available for {target} in
+ torch._decomp.get_decompositions(). Please add this operator to the
+ `decompositions` list in {config.inductor_import}.decompositions
+ """
+ )
+ )
+
+
+class LoweringException(OperatorIssue):
+ def __init__(self, exc, target, args, kwargs):
+ super().__init__(
+ f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}"
+ )
+
+
+class InvalidCxxCompiler(RuntimeError):
+ def __init__(self):
+ from . import config
+
+ super().__init__(
+ f"No working C++ compiler found in {config.__name__}.cpp.cxx: {config.cpp.cxx}"
+ )
+
+
+class CppCompileError(RuntimeError):
+ def __init__(self, cmd, output):
+ super().__init__(
+ textwrap.dedent(
+ """
+ C++ compile error
+
+ Command:
+ {cmd}
+
+ Output:
+ {output}
+ """
+ )
+ .strip()
+ .format(cmd=" ".join(cmd), output=output.decode("utf-8"))
+ )
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
new file mode 100644
index 0000000000000..61cf20743fe21
--- /dev/null
+++ b/torch/_inductor/graph.py
@@ -0,0 +1,354 @@
+import logging
+import operator
+import os
+import time
+
+import sympy
+from sympy import Integer
+
+import torch
+import torch.fx
+from torch._decomp import get_decompositions
+from torch.utils._mode_utils import no_dispatch
+
+from . import config, ir
+from .codegen.wrapper import WrapperCodeGen
+from .exc import (
+ LoweringException,
+ MissingOperatorWithDecomp,
+ MissingOperatorWithoutDecomp,
+)
+from .ir import Constant, FixedLayout, InputBuffer, TensorBox
+from .lowering import lowerings, make_fallback, needs_realized_inputs
+from .sizevars import SizeVarAllocator
+from .utils import dynamo_logging, dynamo_utils
+from .virtualized import V
+
+log = logging.getLogger(__name__)
+
+
+class GraphLowering(torch.fx.Interpreter):
+ def symbolic_sizes_strides(self, ex: torch.Tensor):
+ """
+ Support dynamic shapes and dynamic strides by assigning variables
+ to each dimension. We duck-shape tensors, so if two tensors
+ have the same size they get assigned the same symbolic variable.
+ """
+ size = [self.sizevars[i] for i in ex.size()]
+ stride = [None] * len(size)
+ for i, val in enumerate(ex.stride()):
+ if val in (0, 1):
+ stride[i] = Integer(val)
+ while any(x is None for x in stride):
+ candidates = {
+ ex.size(i) * ex.stride()[i]: size[i] * stride[i]
+ for i in range(len(size))
+ if stride[i] is not None and ex.stride()[i] >= 0
+ }
+ # iterate over unbound strides in sorted order
+ val_list = sorted(
+ [(ex.stride()[i], i) for i in range(len(stride)) if stride[i] is None]
+ )
+ for _, i in val_list:
+ if stride[i] is None and ex.stride()[i] in candidates:
+ stride[i] = candidates[ex.stride()[i]]
+ candidates[ex.size(i) * ex.stride()[i]] = size[i] * stride[i]
+ if any(x is None for x in stride):
+ # bind the smallest unbound stride to a new variable
+ val, i = sorted(
+ [
+ (ex.stride()[i], i)
+ for i in range(len(stride))
+ if stride[i] is None
+ ]
+ )[0]
+ stride[i] = self.sizevars[val]
+ return size, stride
+
+ def static_sizes_strides(self, ex: torch.Tensor):
+ """
+ Primarily used to weights
+ """
+ size = [sympy.Integer(i) for i in ex.size()]
+ stride = [sympy.Integer(i) for i in ex.stride()]
+ return size, stride
+
+ def __init__(self, gm: torch.fx.GraphModule, num_dynamic_inputs=None):
+ super().__init__(gm)
+ self.sizevars = SizeVarAllocator("s")
+ self.graph_inputs = {}
+ self.graph_inputs_original = {}
+ self.graph_outputs = None
+ self.device_types = set()
+ self.buffers = []
+ self.constants = {}
+ self.removed_buffers = set()
+ self.wrapper_code = None
+ self.num_dynamic_inputs = num_dynamic_inputs
+ self.num_static_inputs = None
+ self.mutated_inputs = set()
+ self.unaligned_buffers = set()
+ self.randomness_offset = sympy.Integer(0)
+ self.randomness_seeds = []
+ self.name_to_buffer = {}
+ self.creation_time = time.time()
+
+ def get_dtype(self, buffer_name):
+ if buffer_name in self.constants:
+ return self.constants[buffer_name].dtype
+ if buffer_name in self.name_to_buffer:
+ return self.name_to_buffer[buffer_name].get_dtype()
+ if buffer_name in self.graph_inputs:
+ return self.graph_inputs[buffer_name].get_dtype()
+ raise KeyError(f"could not find {buffer_name}")
+
+ def random_seed_buffer(self, device: torch.device):
+ """
+ Return a device-unique 1-element tensor storing our RNG seed.
+ This will get initialized at the start of each graph in
+ `wrapper.py`.
+
+ Note this is only used by cuda backends. The CPU backend handles
+ RNG seeds as a sizevar.
+ """
+ name = f"seed_{device.type}_{device.index}"
+ if name not in self.constants:
+ self.constants[name] = torch.zeros((), device=device, dtype=torch.int64)
+ self.randomness_seeds.append(name)
+
+ return ir.RandSeedBuffer(
+ name=name,
+ layout=ir.FixedLayout(
+ device=device,
+ dtype=torch.int64,
+ size=[],
+ stride=[],
+ ),
+ )
+
+ def increment_randomness_offset(self, numel):
+ """
+ A global counter of how many random numbers we have handed out so far.
+ """
+ offset = self.randomness_offset
+ self.randomness_offset = offset + numel
+ return offset
+
+ @dynamo_utils.dynamo_timed
+ def run(self, *args):
+ if self.num_dynamic_inputs is None:
+ self.num_dynamic_inputs = len(args)
+ self.num_static_inputs = len(args) - self.num_dynamic_inputs
+ return super().run(*args)
+
+ def register_buffer(self, buffer: ir.ComputedBuffer):
+ name = f"buf{len(self.buffers)}"
+ self.buffers.append(buffer)
+ self.name_to_buffer[name] = buffer
+ return name
+
+ def realize_users_of(self, name: str):
+ """
+ When a buffer is mutated we need to make sure all the reads to
+ the old version are realized before the mutation happens.
+ """
+ assert isinstance(name, str)
+
+ def visit(value):
+ if isinstance(value, (list, tuple)):
+ return [visit(x) for x in value]
+ if isinstance(value, ir.IRNode):
+ if value.is_user_of(name):
+ value.realize()
+ return value
+
+ for key, value in self.env.items():
+ try:
+ visit(value)
+ except Exception:
+ log.warning("error in realize_users_of", exc_info=True)
+
+ def add_tensor_constant(self, data):
+ def allocate():
+ for name, value in self.constants.items():
+ if (
+ data.size() == value.size()
+ and data.stride() == value.stride()
+ and data.dtype == value.dtype
+ and data.device == value.device
+ and torch.eq(data, value).all()
+ ):
+ return name
+ name = f"constant{len(self.constants)}"
+ self.constants[name] = data
+ return name
+
+ return TensorBox.create(
+ ir.ConstantBuffer(
+ allocate(),
+ FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)),
+ )
+ )
+
+ def constant_name(self, name: str, device_override: torch.device):
+ """
+ We AOT copy constants to the devices they are needed on.
+ If device_override doesn't match the constant's device, then
+ copy it and return a different name.
+ """
+ if self.constants[name].device == device_override or device_override is None:
+ return name
+ alt_name = f"{name}_{device_override.type}{device_override.index or 0}"
+ if alt_name not in self.constants:
+ self.constants[alt_name] = self.constants[name].to(device_override)
+ return alt_name
+
+ def placeholder(self, target, args, kwargs):
+ example: torch.Tensor = super().placeholder(target, args, kwargs)
+ if config.static_weight_shapes and (
+ len(self.graph_inputs) < self.num_static_inputs or not config.dynamic_shapes
+ ):
+ # the first N inputs are weights
+ sizes, strides = self.static_sizes_strides(example)
+ else:
+ sizes, strides = self.symbolic_sizes_strides(example)
+ # TODO(jansel): handle input aliasing
+ tensor = TensorBox.create(
+ InputBuffer(
+ target,
+ FixedLayout(example.device, example.dtype, sizes, strides),
+ )
+ )
+ self.graph_inputs[target] = tensor
+ self.graph_inputs_original[target] = tensor.data.data
+ if example.dim() != 0:
+ self.device_types.add(example.device.type)
+ return tensor
+
+ def call_function(self, target, args, kwargs):
+ if target is operator.getitem and isinstance(args[0], (list, tuple)):
+ return super().call_function(target, args, kwargs)
+
+ if target not in lowerings:
+ if config.implicit_fallbacks:
+ error = (
+ MissingOperatorWithDecomp
+ if get_decompositions([target])
+ else MissingOperatorWithoutDecomp
+ )
+ log.warning(
+ "Creating implicit fallback for:\n%s",
+ error.operator_str(target, args, kwargs),
+ )
+ make_fallback(target)
+ elif get_decompositions([target]):
+ # There isn't a good way to dynamically patch this in
+ # since AOT Autograd already ran. The error message tells
+ # the user how to fix it.
+ raise MissingOperatorWithDecomp(target, args, kwargs)
+ else:
+ raise MissingOperatorWithoutDecomp(target, args, kwargs)
+
+ try:
+ return lowerings[target](*args, **kwargs)
+ except Exception as e:
+ raise LoweringException(e, target, args, kwargs) from e
+
+ def get_attr(self, target, args, kwargs):
+ # this is a constant
+ value = getattr(self.module, target)
+ with no_dispatch():
+ if value.shape == ():
+ return Constant(value.item(), value.dtype, value.device)
+ if len(value.shape) == 1 and value.shape[0] <= 8:
+ # tensor lowering has constant inlining logic
+ from .lowering import tensor
+
+ return tensor(value.tolist(), dtype=value.dtype, device=value.device)
+
+ return self.add_tensor_constant(value)
+
+ def call_module(self, target, args, kwargs):
+ raise AssertionError()
+
+ def call_method(self, target, args, kwargs):
+ raise AssertionError()
+
+ def output(self, target, args, kwargs):
+ result = super().output(target, args, kwargs)
+ assert isinstance(result, (tuple, list)), type(result)
+ assert all(
+ isinstance(x, (TensorBox, ir.Constant, type(None), ir.ConstantBuffer))
+ for x in result
+ ), result
+ self.graph_outputs = [ir.ExternKernel.realize_input(x) for x in result]
+ for name, value in self.graph_inputs.items():
+ value.realize()
+ assert isinstance(value, TensorBox)
+ value = value.data
+ assert isinstance(value, ir.StorageBox)
+ value_storage_box = value
+ value = value.data
+ if not isinstance(value, InputBuffer) or value.get_name() != name:
+ # one of our inputs was mutated, need to turn that into a copy
+ ir.MutationLayout.realize_into(value, self.graph_inputs_original[name])
+ # replace output with mutated input
+ try:
+ ind = self.graph_outputs.index(value_storage_box)
+ self.graph_outputs[ind] = self.graph_inputs_original[name]
+ except ValueError:
+ pass
+
+ self.finalize()
+
+ def finalize(self):
+ for buf in self.buffers:
+ buf.decide_layout()
+
+ def run_node(self, n: torch.fx.Node):
+ with ir.IRNode.current_origins({n}):
+ result = super().run_node(n)
+ num_users = len(set(n.users))
+ if num_users > 1 and isinstance(result, TensorBox):
+ for user in n.users:
+ if user.target in needs_realized_inputs or user.op == "output":
+ result.realize_hint()
+
+ # TODO(jansel): introduce a store vs inline choice
+ result.mark_reuse(len(n.users))
+ return result
+
+ def codegen(self):
+ from .scheduler import Scheduler
+
+ self.wrapper_code = WrapperCodeGen()
+ self.scheduler = Scheduler(self.buffers)
+ self.scheduler.codegen()
+ return self.wrapper_code.generate()
+
+ @dynamo_utils.dynamo_timed
+ def compile_to_module(self):
+ from .codecache import PyCodeCache
+
+ code = self.codegen()
+ if config.debug:
+ print(code)
+
+ mod = PyCodeCache.load(code)
+ for name, value in self.constants.items():
+ setattr(mod, name, value)
+
+ log.log(dynamo_logging.CODE, "Output code: %s", mod.__file__)
+ V.debug.output_code(mod.__file__)
+ V.debug.rename(os.path.splitext(mod.__file__)[0] + ".debug")
+ return mod
+
+ def compile_to_fn(self):
+ return self.compile_to_module().call
+
+ def get_output_names(self):
+ return [
+ node.get_name()
+ for node in self.graph_outputs
+ if not isinstance(node, ir.NoneAsConstantBuffer)
+ ]
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
new file mode 100644
index 0000000000000..e44a9fb01850a
--- /dev/null
+++ b/torch/_inductor/ir.py
@@ -0,0 +1,3566 @@
+import contextlib
+import dataclasses
+import functools
+import itertools
+import logging
+import re
+import textwrap
+from collections import OrderedDict
+from enum import Enum
+from functools import partial
+from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
+from unittest.mock import patch
+
+import numpy
+import sympy
+from sympy import Expr, Integer
+
+import torch.fx
+import torch.utils._pytree as pytree
+from torch._prims_common import is_boolean_dtype, is_float_dtype
+
+from . import config, dependencies
+from .codegen.common import index_prevent_reordering
+from .dependencies import extract_read_writes, var_builder
+from .utils import cache_on_self, sympy_dot, sympy_product, sympy_subs
+from .virtualized import ops, V
+
+log = logging.getLogger(__name__)
+indent = functools.partial(textwrap.indent, prefix=" ")
+
+
+def inverse_reorder(order):
+ inv_order = dict(zip(order, range(len(order))))
+
+ def reindex(index):
+ assert len(index) == len(inv_order)
+ return [index[inv_order[i]] for i in range(len(index))]
+
+ return reindex
+
+
+def same_reorder(order):
+ def reindex(index):
+ assert len(index) == len(order)
+ return [index[order[i]] for i in range(len(index))]
+
+ return reindex
+
+
+def fuse_reindexing(reindex1, reindex2):
+ def reindex(index):
+ return reindex1(reindex2(index))
+
+ return reindex
+
+
+def stride_order2fill_order(order):
+ """
+ Convert stride order to fill order
+ For channel last format,
+ stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0]
+ """
+ lookup = {pos: idx for idx, pos in enumerate(order)}
+ fill_order = [lookup[i] for i in range(len(order))]
+ return fill_order
+
+
+def reads_from_conv(buf, var_ranges):
+ """
+ return:
+ if reads_from_conv: boolean
+ the new memory_addr: Sympy Expression
+ """
+ if buf is None:
+ return False, None
+ if isinstance(buf, Convolution):
+ indexer = buf.layout.as_fixed().make_indexer()
+ index_vars = sorted(var_ranges, key=lambda var: var.name)
+ index = indexer(index_vars)
+ return True, index
+ # for case like
+ # buf0 = conv(x, w)
+ # return torch.cat([buf0, buf1]), torch.cat([buf0, buf2])
+ # Because of ConcatKernel, it will create two bufs buf3 and 4
+ # buf3 has the AliasedLayout which reads from buf0(Convolution)
+ # but buf4 is a copy of buf3 which reads from buf3
+ # we want to know that buf4 also follows buf0 conv's layout
+ if isinstance(buf.layout, AliasedLayout):
+ reads = buf.get_read_writes().reads
+ reads_bufs = [
+ V.graph.name_to_buffer[r.name]
+ if r.name in V.graph.name_to_buffer.keys()
+ else None
+ for r in reads
+ ]
+ for reads_buf in reads_bufs:
+ read_from_conv, addr = reads_from_conv(reads_buf, var_ranges)
+ if read_from_conv:
+ return True, addr
+ return False, None
+
+
+def layout_priority_idx(reads_bufs, memory_addrs, var_ranges):
+ """
+ if reads from conv that needs to use specific layout
+ return:
+ priority_idx regarding memory_addrs idx
+ memory_addrs - update memory_addrs with the true addr if needed
+ """
+
+ priority_idx = []
+ for i, reads_buf in enumerate(reads_bufs):
+ read_from_conv, mem_addr = reads_from_conv(reads_buf, var_ranges)
+ if read_from_conv:
+ priority_idx.append(i)
+ memory_addrs[i] = mem_addr
+ return priority_idx, memory_addrs
+
+
+class ModularIndexing(sympy.Function):
+ """
+ ModularIndexing(a, b, c) => (a // b) % c
+ """
+
+ nargs = (3,)
+
+ @classmethod
+ def eval(cls, base, divisor, modulus):
+ if base == 0 or modulus == 1:
+ return sympy.Integer(0)
+
+ if (
+ isinstance(base, sympy.Integer)
+ and isinstance(divisor, sympy.Integer)
+ and isinstance(modulus, sympy.Integer)
+ ):
+ return (base // divisor) % modulus
+
+ if divisor != 1:
+ gcd = sympy.gcd(base, divisor)
+ if gcd != 1:
+ return ModularIndexing(base / gcd, divisor / gcd, modulus)
+
+ if isinstance(base, sympy.Add):
+ new_terms = []
+ for term in base.args:
+ if sympy.gcd(term, modulus * divisor) != modulus * divisor:
+ new_terms.append(term)
+ if len(new_terms) != len(base.args):
+ return ModularIndexing(sum(new_terms), divisor, modulus)
+
+ if isinstance(base, IndexingDiv):
+ return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
+
+
+class IndexingDiv(sympy.Function):
+ """
+ a // b used in indexing where we need to be careful about simplification.
+ We don't use sympy.FloorDiv to bypass some simplification rules.
+ """
+
+ nargs = (2,)
+
+ @classmethod
+ def eval(cls, base, divisor):
+ if base == 0:
+ return sympy.Integer(0)
+ if divisor == 1:
+ return base
+ if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
+ return base // divisor
+ if isinstance(base, IndexingDiv):
+ return IndexingDiv(base.args[0], base.args[1] * divisor)
+
+ if isinstance(base, sympy.Add):
+ for a in base.args:
+ gcd = sympy.gcd(a, divisor)
+ if gcd == divisor:
+ return IndexingDiv(base - a, divisor) + a / gcd
+ gcd = sympy.gcd(base, divisor)
+ if gcd != 1:
+ return IndexingDiv(
+ sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
+ )
+
+
+class CleanDiv(IndexingDiv):
+ """
+ Div where we can assume no rounding.
+ This is to enable future optimizations.
+ """
+
+ pass
+
+
+class CeilDiv(sympy.Function):
+ """
+ Div used in indexing that rounds up.
+ """
+
+ def __new__(cls, base, divisor):
+ if sympy.gcd(base, divisor) == divisor:
+ return CleanDiv(base, divisor)
+ else:
+ return IndexingDiv(base + (divisor - 1), divisor)
+
+
+def get_device_type(x):
+ if getattr(x, "get_device", None):
+ return get_device_type(x.get_device())
+ if isinstance(x, torch.device):
+ return x.type
+ return None
+
+
+def is_triton(x):
+ return get_device_type(x) == "cuda"
+
+
+def is_cpu(x):
+ return get_device_type(x) == "cpu"
+
+
+@dataclasses.dataclass
+class IRNode(object):
+ _current_origins: ClassVar[Set[Any]] = set()
+
+ @staticmethod
+ @contextlib.contextmanager
+ def current_origins(origins: Set[torch.fx.Node]):
+ old = IRNode._current_origins
+ IRNode._current_origins = old | origins
+ yield
+ IRNode._current_origins = old
+
+ def __post_init__(self):
+ self.origins = set(self._current_origins)
+
+ def common_repr(self):
+ return (
+ [f"origins={self.origins}"] if hasattr(self, "origins") else ["no origins?"]
+ )
+
+ def str_helper(self, lines):
+ lines = lines + self.common_repr()
+ lines = indent(",\n".join(map(str, lines)))
+ return f"{type(self).__name__}(\n{lines}\n)"
+
+ def is_user_of(self, name):
+ return any(name == dep.name for dep in self.get_reads())
+
+ def get_numel(self):
+ return sympy_product(self.get_size())
+
+
+@dataclasses.dataclass
+class Loops(IRNode):
+ device: torch.device
+ dtype: torch.dtype
+ inner_fn: Callable
+ ranges: List[Expr]
+
+ def __str__(self, names=("ranges",)):
+ return self.str_helper(
+ [
+ f"'{self.device.type}'",
+ str(self.dtype),
+ self.inner_fn_str(),
+ ]
+ + [f"{name}={getattr(self, name)}" for name in names]
+ )
+
+ __repr__ = __str__
+
+ def get_dtype(self):
+ return self.dtype
+
+ def get_device(self):
+ return self.device
+
+ def get_size(self):
+ return self.ranges
+
+ def is_extern(self):
+ return False
+
+ @classmethod
+ def create(cls, *args, **kwargs):
+ return TensorBox.create(cls(*args, **kwargs))
+
+ @staticmethod
+ def _index(ranges, prefix="i"):
+ return [
+ sympy.Integer(0) if s == 1 else sympy.Symbol(f"{prefix}{n}")
+ for n, s in enumerate(ranges)
+ ]
+
+ @cache_on_self
+ def inner_fn_str(self):
+ try:
+ with V.set_ops_handler(V.MockHandler()), patch.object(
+ FlexibleLayout, "allow_indexing", True
+ ):
+ return self.inner_fn(self._index(self.ranges))
+ except Exception as e:
+ return f"inner_fn(): {e}"
+
+ def is_zero_elements(self):
+ return any(r == 0 for r in self.ranges)
+
+ @cache_on_self
+ def get_reads(self):
+ with patch.object(FlexibleLayout, "allow_indexing", True):
+ if self.get_reduction_type():
+ return extract_read_writes(
+ self.make_loader(),
+ self.get_size(),
+ self.get_reduction_size(),
+ ).reads
+ else:
+ return extract_read_writes(
+ self.make_loader(),
+ self.get_size(),
+ ).reads
+
+
+class Pointwise(Loops):
+ def make_loader(self):
+ return self.inner_fn
+
+ def get_reduction_size(self):
+ return []
+
+ def get_reduction_type(self):
+ return None
+
+ def store_output(self, output_name, indexer, vars):
+ return ops.store(output_name, indexer(vars), self.inner_fn(vars))
+
+ def constant_to_device(self, device):
+ """Move this to a given device. Requires that all reads are to constants."""
+ loader = self.make_loader()
+ loader = patch.object(ConstantBuffer, "override_device", device)(loader)
+ return Pointwise(device, self.dtype, loader, self.ranges)
+
+
+@dataclasses.dataclass
+class Scatter(Pointwise):
+ output_indexer: Callable[[List[Expr]], Expr]
+ scatter_mode: Optional[str] = None
+
+ def constant_to_device(self, device):
+ """Move this to a given device. Requires that all reads are to constants."""
+ loader = self.make_loader()
+ loader = patch.object(ConstantBuffer, "override_device", device)(loader)
+ return Scatter(
+ device,
+ self.dtype,
+ loader,
+ self.ranges,
+ self.output_indexer,
+ self.scatter_mode,
+ )
+
+ def store_output(self, output_name, indexer, vars):
+ return ops.store(
+ output_name,
+ indexer(self.output_indexer(vars)),
+ self.inner_fn(vars),
+ mode=self.scatter_mode,
+ )
+
+
+class ReductionHint(Enum):
+ INNER = 0
+ OUTER = 1
+ OUTER_TINY = 2
+ DEFAULT = 3
+
+
+@dataclasses.dataclass
+class Reduction(Loops):
+ reduction_ranges: List[Expr]
+ reduction_type: str
+ # self.dtype represents the dst dtype
+ src_dtype: torch.dtype
+ reduction_hint: ReductionHint
+
+ def __str__(self):
+ return Loops.__str__(
+ self, names=("ranges", "reduction_ranges", "reduction_type")
+ )
+
+ __repr__ = __str__
+
+ def get_reduction_size(self):
+ return self.reduction_ranges
+
+ def get_reduction_type(self):
+ return self.reduction_type
+
+ def store_reduction(self, output_name, indexer, vars, reduction_vars):
+ return ops.reduction(
+ output_name,
+ self.dtype,
+ self.src_dtype,
+ self.reduction_type,
+ indexer(vars),
+ self.inner_fn(vars, reduction_vars),
+ )
+
+ def index_length(self):
+ return len(self.ranges) + len(self.reduction_ranges)
+
+ @cache_on_self
+ def inner_fn_str(self):
+ try:
+ with V.set_ops_handler(V.MockHandler()), patch.object(
+ FlexibleLayout, "allow_indexing", True
+ ):
+ return self.inner_fn(
+ self._index(self.ranges), self._index(self.reduction_ranges, "r")
+ )
+ except Exception as e:
+ return f"inner_fn(): {e}"
+
+ def constant_to_device(self, device):
+ """Move this to a given device. Requires that all reads are to constants."""
+ loader = self.make_loader()
+ loader = patch.object(ConstantBuffer, "override_device", device)(loader)
+ return Reduction(
+ device,
+ self.dtype,
+ loader,
+ self.ranges,
+ self.reduction_ranges,
+ self.reduction_type,
+ self.src_dtype,
+ ReductionHint.DEFAULT,
+ )
+
+ @staticmethod
+ def num_splits(
+ device,
+ dst_dtype,
+ src_dtype,
+ inner_fn,
+ ranges,
+ reduction_ranges,
+ reduction_type,
+ reduction_numel,
+ ):
+ num_sm = torch.cuda.get_device_properties(device).multi_processor_count
+ min_elements_per_thread = 32
+ max_elements_per_thread = 512
+ threads_per_sm = 2048
+ min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm
+ max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm
+
+ def inner_reduction_splits(reduction_numel_hint, numel_hint):
+ # do heuristics that's close to eager mode for split inner reduction
+ # we leak reduction autotune configs here, and will need to refactor to avoid this later
+ num_warps = 8
+ num_threads = 32 * num_warps
+ if numel_hint >= 2 * num_sm: # don't split if there are enough outputs
+ return 1
+ if reduction_numel_hint <= 8192:
+ return 1
+ if reduction_numel_hint * numel_hint <= min_elements_per_device:
+ split_size = min_elements_per_thread
+ elif reduction_numel_hint * numel_hint < max_elements_per_device:
+ target_blocks = num_sm * threads_per_sm // (2 * num_threads)
+ blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint
+ tmp_split_size = (
+ reduction_numel_hint + num_threads * blocks_per_output - 1
+ ) // (num_threads * blocks_per_output)
+ divisors = sympy.divisors(reduction_numel_hint)
+ closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
+ if abs(closest - tmp_split_size) < 30:
+ # prefer even splits, but never smalle than min_elements_per_thread
+ split_size = max(closest, min_elements_per_thread)
+ else:
+ split_size = tmp_split_size
+ else:
+ divisors = sympy.divisors(reduction_numel_hint)
+ closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
+ if abs(closest - max_elements_per_thread) < 50:
+ # prefer even splits
+ split_size = closest
+ else:
+ split_size = max_elements_per_thread
+ return (reduction_numel_hint + split_size * num_threads - 1) // (
+ split_size * num_threads
+ )
+
+ def outer_reduction_splits(reduction_numel_hint, numel_hint):
+ # TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) 128
+ # extend to even smaller number of outputs
+ num_warps = 8
+ num_threads = num_warps * 32
+ rvals_per_thread = 4 # comes from heuristics, refactor to not leak here
+ xvals_per_block = 128
+ xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block
+ if reduction_numel_hint * numel_hint < min_elements_per_device:
+ split_size = min_elements_per_thread
+ elif reduction_numel_hint * numel_hint < max_elements_per_device:
+ target_blocks = num_sm * threads_per_sm // (num_threads)
+ target_blocks = (target_blocks + xblocks - 1) // xblocks
+ tmp_split_size = (
+ reduction_numel_hint + rvals_per_thread * target_blocks - 1
+ ) // (rvals_per_thread * target_blocks)
+ divisors = sympy.divisors(reduction_numel_hint)
+ closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
+ if abs(tmp_split_size - closest) < 20:
+ split_size = max(closest, min_elements_per_thread)
+ else:
+ split_size = tmp_split_size
+ else:
+ divisors = sympy.divisors(reduction_numel_hint)
+ closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
+ if abs(closest - max_elements_per_thread) < 50:
+ # prefer even splits
+ split_size = closest
+ else:
+ split_size = max_elements_per_thread
+
+ return (reduction_numel_hint + rvals_per_thread * split_size - 1) // (
+ rvals_per_thread * split_size
+ )
+
+ reduction_numel_hint = V.graph.sizevars.size_hint(reduction_numel)
+ numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
+ # easy cases
+ if numel_hint == 1:
+ return ReductionHint.INNER, inner_reduction_splits(
+ reduction_numel_hint, numel_hint
+ )
+ if (
+ reduction_numel_hint <= min_elements_per_thread
+ or numel_hint >= num_sm * 2 * 32
+ ):
+ return ReductionHint.DEFAULT, 1
+
+ r = Reduction(
+ device,
+ dst_dtype,
+ inner_fn,
+ ranges,
+ reduction_ranges,
+ reduction_type,
+ src_dtype,
+ ReductionHint.DEFAULT,
+ )
+ read_writes = ComputedBuffer(
+ name=None,
+ layout=FlexibleLayout(
+ device=r.get_device(),
+ dtype=r.get_dtype(),
+ size=r.get_size(),
+ ),
+ data=r,
+ ).get_read_writes()
+ # try finding the full size producer
+ # TODO this will fail for something like ((1, N) * (N, 1)).sum()
+ # this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare
+ # TODO maybe go over all full size producers and pick the most common one?
+ range_vars = [
+ r
+ for r in read_writes.range_vars
+ if isinstance(r, sympy.Expr) and not isinstance(r, sympy.Number)
+ ]
+ index = None
+ for md in read_writes.reads:
+ if all([r in md.index.free_symbols for r in range_vars]):
+ index = md.index
+ break
+ if not index:
+ # TODO determine splits when all inputs are broadcasted
+ return ReductionHint.DEFAULT, 1
+ reduction_vars = [
+ rv for rv in range_vars if read_writes.var_ranges[rv] in reduction_ranges
+ ]
+ strides = V.graph.sizevars.stride_hints(index, reduction_vars)
+ outer = all([s > 1 for s in strides])
+ if not outer:
+ return ReductionHint.INNER, inner_reduction_splits(
+ reduction_numel_hint, numel_hint
+ )
+ else: # outer reduction
+ return ReductionHint.OUTER, outer_reduction_splits(
+ reduction_numel_hint, numel_hint
+ )
+
+ @staticmethod
+ def _unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type):
+ """Convert inner_fn from a reduction to an pointwise"""
+ reduction_ranges = [
+ V.graph.sizevars.guard_static_shape(x) for x in reduction_ranges
+ ]
+
+ if reduction_type == "sum":
+
+ def combine_fn(a, b):
+ return ops.add(a, b)
+
+ elif reduction_type == "min":
+
+ def combine_fn(a, b):
+ return ops.minimum(a, b)
+
+ elif reduction_type == "max":
+
+ def combine_fn(a, b):
+ return ops.maximum(a, b)
+
+ elif reduction_type == "any":
+
+ def combine_fn(a, b):
+ return ops.logical_or(a, b)
+
+ elif reduction_type == "argmin":
+
+ def combine_fn(a, b):
+ return ops.minimum(a[0], b[0]), ops.where(
+ ops.lt(b[0], a[0]), b[1], a[1]
+ )
+
+ elif reduction_type == "argmax":
+
+ def combine_fn(a, b):
+ return ops.maximum(a[0], b[0]), ops.where(
+ ops.gt(b[0], a[0]), b[1], a[1]
+ )
+
+ else:
+ raise NotImplementedError(f"unknown reduction_type={reduction_type}")
+
+ def fn(index):
+ return functools.reduce(
+ combine_fn,
+ (
+ value_fn(index, rindex)
+ for rindex in itertools.product(
+ *[range(x) for x in reduction_ranges]
+ )
+ ),
+ )
+
+ if reduction_type in ("argmin", "argmax"):
+ flatten_index = FixedLayout(
+ None,
+ None,
+ reduction_ranges,
+ FlexibleLayout.contiguous_strides(reduction_ranges),
+ ).make_indexer()
+
+ def value_fn(index, rindex):
+ rindex = [sympy.expand(i) for i in rindex]
+ return (
+ inner_fn(index, rindex),
+ ops.index_expr(flatten_index(rindex), torch.int64),
+ )
+
+ return lambda index: fn(index)[1]
+ else:
+ value_fn = inner_fn
+ return fn
+
+ @classmethod
+ def create(
+ cls,
+ device: torch.device,
+ dst_dtype: torch.dtype,
+ src_dtype: torch.dtype,
+ inner_fn: Callable,
+ ranges: List[Expr],
+ reduction_ranges: List[Expr],
+ reduction_type: str,
+ reduction_hint: ReductionHint = ReductionHint.DEFAULT,
+ ):
+ reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
+ if reduction_numel == 1:
+ # this reduction is actually a pointwise op
+ if reduction_type in ("argmin", "argmax"):
+
+ def fn(index):
+ assert len(index) <= 1
+ return 0
+
+ else:
+
+ def fn(index):
+ reduction_index = [sympy.Integer(0) for _ in reduction_ranges]
+ return inner_fn(index, reduction_index)
+
+ return Pointwise.create(device, dst_dtype, fn, ranges)
+
+ if (
+ isinstance(reduction_numel, sympy.Integer)
+ and V.graph.sizevars.size_hint(reduction_numel)
+ < config.unroll_reductions_threshold
+ and sympy_product(ranges) != 1
+ ):
+ return Pointwise.create(
+ device,
+ dst_dtype,
+ cls._unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type),
+ ranges,
+ )
+
+ if is_triton(device) and reduction_type not in {"argmax", "argmin"}:
+ # triton doesn't support reduce to single element well, so break it up
+ hint, split = cls.num_splits(
+ device,
+ dst_dtype,
+ src_dtype,
+ inner_fn,
+ ranges,
+ reduction_ranges,
+ reduction_type,
+ reduction_numel,
+ )
+ # intermediate reduction in split can contain complex indexing,
+ # and num_splits will fail to correctly set the hint
+ # reuse the passed hint if available
+ if reduction_hint == ReductionHint.DEFAULT:
+ reduction_hint = hint
+ if split > 1:
+ # triton doesn't support reduce to single element well, so break it up
+ return cls.create_multilayer(
+ device,
+ dst_dtype,
+ src_dtype,
+ inner_fn,
+ ranges,
+ reduction_ranges,
+ reduction_type,
+ split,
+ reduction_hint,
+ )
+
+ return TensorBox.create(
+ Reduction(
+ device,
+ dst_dtype,
+ inner_fn,
+ ranges,
+ reduction_ranges,
+ reduction_type,
+ src_dtype,
+ reduction_hint,
+ )
+ )
+
+ @staticmethod
+ def default_value(reduction_type, dtype):
+ if reduction_type in {"max", "argmax"}:
+ if is_float_dtype(dtype):
+ return float("-inf")
+ elif is_boolean_dtype(dtype):
+ return 0
+ else:
+ return torch.iinfo(dtype).min
+ if reduction_type in {"min", "argmin"}:
+ if is_float_dtype(dtype):
+ return float("inf")
+ elif is_boolean_dtype(dtype):
+ return 1
+ else:
+ return torch.iinfo(dtype).max
+
+ return {
+ "sum": 0,
+ "any": 0,
+ }[reduction_type]
+
+ @classmethod
+ def create_multilayer(
+ cls,
+ device: torch.device,
+ dst_dtype: torch.dtype,
+ src_dtype: torch.dtype,
+ inner_fn: Callable,
+ ranges: List[Expr],
+ reduction_ranges: List[Expr],
+ reduction_type: str,
+ split: int,
+ reduction_hint: ReductionHint,
+ ):
+ """
+ Break a large reduction up into multiple smaller reductions
+ recursively
+ """
+ reduction_numel = sympy_product(reduction_ranges)
+
+ # TODO(jansel): convert this to dynamic shapes
+ # TODO(jansel): realize the reduction so we can do dynamic indexing
+ reduction_ranges = [
+ sympy.Integer(V.graph.sizevars.guard_static_shape(s))
+ for s in reduction_ranges
+ ]
+ reduction_numel = sympy.Integer(
+ V.graph.sizevars.guard_static_shape(reduction_numel)
+ )
+
+ if V.graph.sizevars.size_hint(reduction_numel) % split == 0:
+ need_mask = False
+ else:
+ need_mask = True
+
+ split = sympy.Integer(split)
+ block_size = IndexingDiv(reduction_numel + (split - 1), split)
+
+ reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel])
+
+ def wrapper_fn(index, reduction_index):
+ (reduction_index,) = reduction_index
+ *new_index, reduction_block = index
+ indices = block_size * reduction_block + reduction_index
+
+ def body():
+ return inner_fn(new_index, reindex([indices]))
+
+ if need_mask:
+ mask = ops.lt(
+ ops.index_expr(indices, torch.int32),
+ ops.index_expr(reduction_numel, torch.int32),
+ )
+ return ops.masked(
+ mask, body, cls.default_value(reduction_type, dst_dtype)
+ )
+ else:
+ return body()
+
+ # triton will automatically compute reductions in fp32 if reducing over fp16/bf16
+ # within the kernel. keep the intermediate in fp32 so as to keep the whole reduction
+ # in fp32 and not reduce precision by breaking up the kernel into multiple layers
+ intermediate_dtype = (
+ dst_dtype
+ if dst_dtype not in (torch.float16, torch.bfloat16)
+ else torch.float
+ )
+ intermediate = Reduction.create(
+ device,
+ intermediate_dtype,
+ src_dtype,
+ wrapper_fn,
+ [*ranges, split],
+ [block_size],
+ reduction_type,
+ reduction_hint,
+ )
+ intermediate.realize()
+ intermediate_loader = intermediate.make_loader()
+
+ def intermediate_fn(index, reduction_index):
+ return intermediate_loader([*index, *reduction_index])
+
+ numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
+ if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER:
+ reduction_hint = ReductionHint.OUTER_TINY
+ return TensorBox.create(
+ Reduction(
+ device,
+ dst_dtype,
+ intermediate_fn,
+ ranges,
+ [split],
+ reduction_type,
+ src_dtype,
+ reduction_hint,
+ )
+ )
+
+
+def is_storage_and_layout(x):
+ try:
+ as_storage_and_layout(x, freeze=False)
+ return True
+ except NotImplementedError:
+ return False
+
+
+def is_contiguous_storage_and_layout(x):
+ try:
+ buffer, layout = as_storage_and_layout(x, freeze=False)
+ return layout.is_contiguous()
+ except NotImplementedError:
+ return False
+
+
+def as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=None):
+ """Try to simplify x into a StorageBox and a Layout"""
+ if isinstance(x, TensorBox):
+ return as_storage_and_layout(
+ x.data,
+ freeze=freeze,
+ want_contiguous=want_contiguous,
+ stride_order=stride_order,
+ )
+ if isinstance(x, StorageBox) and isinstance(x.data, Buffer):
+ if freeze:
+ if want_contiguous:
+ x.data.freeze_layout()
+ elif stride_order is not None:
+ x.data.freeze_layout_with_stride_order(stride_order)
+ else:
+ x.data.decide_layout()
+ return x, x.data.layout
+ if isinstance(x, ReinterpretView):
+ buffer, _ = as_storage_and_layout(
+ x.data,
+ freeze=freeze,
+ want_contiguous=want_contiguous,
+ stride_order=stride_order,
+ )
+ return buffer, x.layout
+ raise NotImplementedError
+
+
+as_contiguous_storage_and_layout = functools.partial(
+ as_storage_and_layout, want_contiguous=True
+)
+
+
+def is_stride_order_storage_and_layout(x, stride_order):
+ try:
+ buffer, layout = as_storage_and_layout(x, freeze=False)
+ return layout.is_stride_ordered(stride_order)
+ except NotImplementedError:
+ return False
+
+
+@dataclasses.dataclass
+class BaseView(IRNode):
+ data: IRNode
+
+ def get_dtype(self):
+ return self.data.get_dtype()
+
+ def get_device(self):
+ return self.data.get_device()
+
+ def get_name(self):
+ return self.data.get_name()
+
+ def mark_reuse(self, users):
+ return self.data.mark_reuse(users)
+
+ def realize(self):
+ return self.data.realize()
+
+ def realize_hint(self):
+ return self.data.realize_hint()
+
+ def get_storage_numel(self):
+ return self.data.get_storage_numel()
+
+ def is_extern(self):
+ return self.data.is_extern()
+
+ @cache_on_self
+ def get_reads(self):
+ with patch.object(FlexibleLayout, "allow_indexing", True):
+ return extract_read_writes(
+ self.make_loader(),
+ self.get_size(),
+ ).reads
+
+ def unwrap_view(self):
+ x = self
+ while isinstance(x, BaseView):
+ x = x.data
+ return x
+
+ def constant_to_device(self, device):
+ """Move this to a given device. Requires that all reads are to constants."""
+ loader = self.make_loader()
+ loader = patch.object(ConstantBuffer, "override_device", device)(loader)
+ return Pointwise(device, self.get_dtype(), loader, self.get_size())
+
+
+@dataclasses.dataclass
+class ExpandView(BaseView):
+ size: List[Expr]
+
+ @staticmethod
+ def _normalize_size(x, new_size):
+ """Replace `-1` with correct sizes"""
+ new_size = list(map(sympy.expand, new_size))
+ old_size = x.get_size()
+ old_size = [None] * (len(new_size) - len(old_size)) + list(old_size)
+ assert len(new_size) == len(old_size)
+ for i in range(len(new_size)):
+ if new_size[i] == -1:
+ assert old_size[i] is not None
+ new_size[i] = old_size[i]
+ return new_size
+
+ @classmethod
+ def create(cls, x, new_size):
+ new_size = cls._normalize_size(x, new_size)
+
+ if is_storage_and_layout(x):
+ storage, old_layout = as_storage_and_layout(x)
+ skip = len(new_size) - len(old_layout.size)
+ assert skip >= 0
+ new_stride = [sympy.Integer(0)] * skip
+ for stride, size in zip(old_layout.stride, old_layout.size):
+ new_stride.append(stride if size != 1 else sympy.Integer(0))
+ new_layout = FixedLayout(
+ old_layout.device,
+ old_layout.dtype,
+ list(new_size),
+ new_stride,
+ old_layout.offset,
+ )
+ return ReinterpretView(storage, new_layout)
+
+ return ExpandView(x, new_size)
+
+ def get_size(self):
+ return self.size
+
+ def make_loader(self):
+ target = self.get_size()
+ actual = self.data.get_size()
+ skip = len(target) - len(actual)
+ inner = self.data.make_loader()
+
+ def load(index):
+ index = list(index[skip:])
+ assert len(index) == len(actual)
+ for i in range(len(actual)):
+ if actual[i] == 1:
+ # zero out broadcast dimension
+ index[i] = sympy.Integer(0)
+ return inner(index)
+
+ return load
+
+
+@dataclasses.dataclass
+class PermuteView(BaseView):
+ dims: List[Expr]
+
+ @classmethod
+ def create(cls, x, dims):
+ assert set(cls._map_neg_dims(dims)) == set(range(len(dims)))
+
+ if is_storage_and_layout(x):
+ storage, old_layout = as_storage_and_layout(x)
+ new_layout = FixedLayout(
+ old_layout.device,
+ old_layout.dtype,
+ [old_layout.size[i] for i in dims],
+ [old_layout.stride[i] for i in dims],
+ old_layout.offset,
+ )
+ return ReinterpretView(storage, new_layout)
+
+ return PermuteView(x, dims)
+
+ @classmethod
+ def _map_neg_dims(cls, dims):
+ return [dim if dim >= 0 else len(dims) + dim for dim in dims]
+
+ def get_size(self):
+ assert set(self._map_neg_dims(self.dims)) == set(range(len(self.dims)))
+ size = self.data.get_size()
+ return [size[i] for i in self.dims]
+
+ def make_loader(self):
+ inner = self.data.make_loader()
+ inv = {j: i for i, j in enumerate(self.dims)}
+ inv = [inv[i] for i in range(len(self.dims))]
+ assert set(inv) == set(range(len(self.dims)))
+
+ def load(index):
+ index = [index[i] for i in inv]
+ return inner(index)
+
+ return load
+
+
+class SqueezeView(BaseView):
+ @classmethod
+ def create(cls, x, *, dim=None):
+
+ if is_storage_and_layout(x):
+ storage, old_layout = as_storage_and_layout(x)
+ new_size = []
+ new_stride = []
+ if dim is not None:
+ assert isinstance(dim, int), "expected integer dim argument"
+ assert 0 <= dim and dim < len(old_layout.size)
+
+ for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)):
+ if dim is None:
+ if size != 1:
+ new_size.append(size)
+ new_stride.append(stride)
+ else:
+ if i != dim:
+ new_size.append(size)
+ new_stride.append(stride)
+ else:
+ assert size == 1, "expected squeezed size to be 1"
+
+ new_layout = FixedLayout(
+ old_layout.device,
+ old_layout.dtype,
+ new_size,
+ new_stride,
+ old_layout.offset,
+ )
+ return ReinterpretView(storage, new_layout)
+
+ if dim is None:
+ # redirect to a generic view
+ return View.create(x, [s for s in x.get_size() if s != 1])
+ else:
+ assert x.get_size()[dim] == 1
+ return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim])
+
+ @staticmethod
+ def squeezer(size: Tuple[sympy.Expr, ...]):
+ new_size = [s for s in size if s != 1]
+ not_one = [i for i, s in enumerate(size) if s != 1]
+ length = len(size)
+
+ def reindex(index: List[sympy.Expr]) -> List[sympy.Expr]:
+ assert len(index) == len(not_one), f"{index} {not_one}"
+ new_index = [sympy.Integer(0)] * length
+ for idx, s in zip(not_one, index):
+ new_index[idx] = s
+ return tuple(new_index)
+
+ return new_size, reindex
+
+ def __init__(self, data):
+ raise AssertionError("use SqueezeView.create()")
+
+
+@dataclasses.dataclass
+class View(BaseView):
+ size: List[Expr]
+ reindex: Callable
+
+ def make_indexer(self):
+ base_indexer = self.data.make_indexer()
+
+ def indexer(idx):
+ return base_indexer(self.reindex(idx))
+
+ return indexer
+
+ @staticmethod
+ def handle_negative_index(idx, size):
+ idx = sympy.expand(idx)
+ size = sympy.expand(size)
+ sizevars = V.graph.sizevars
+ if sizevars.size_hint(idx) < 0:
+ sizevars.guard_lt(idx, 0)
+ idx = idx + size
+ return idx
+
+ def reindex_str(self):
+ index_old = [sympy.Symbol(f"i{n}") for n in range(len(self.size))]
+ index_new = list(self.reindex(index_old))
+ return f"lambda {', '.join(map(str, index_old))}: {index_new}"
+
+ def __str__(self):
+ return self.str_helper(
+ [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"]
+ )
+
+ __repr__ = __str__
+
+ @classmethod
+ def create(cls, x, new_size):
+ assert isinstance(new_size, (tuple, list))
+ old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size)
+
+ if V.graph.sizevars.maybe_guard_list_equals(old_size, new_size):
+ return x
+
+ # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout
+ if is_contiguous_storage_and_layout(x) and not isinstance(
+ x.data, ExternKernelAlloc
+ ):
+ storage, old_layout = as_contiguous_storage_and_layout(x)
+ new_layout = FixedLayout(
+ old_layout.device,
+ old_layout.dtype,
+ new_size,
+ FlexibleLayout.contiguous_strides(new_size),
+ old_layout.offset,
+ )
+ return ReinterpretView(storage, new_layout)
+
+ reindex = cls.dynamic_reshape_indexer(old_size, new_size)
+ return cls(x, tuple(new_size), reindex)
+
+ @staticmethod
+ def resolve_negative_size(old_size, new_size):
+ new_size = [V.graph.sizevars.simplify(x) for x in new_size]
+ old_size = [V.graph.sizevars.simplify(x) for x in old_size]
+
+ new_size = list(new_size)
+ for i in range(len(new_size)):
+ if new_size[i] == -1:
+ new_size[i] = sympy.Integer(1)
+ new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size))
+ break
+
+ V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size))
+ return old_size, new_size
+
+ @classmethod
+ def dynamic_reshape_indexer(cls, old_size, new_size):
+ try:
+ reindex = cls._dynamic_reshape_indexer(old_size, new_size)
+ except (AssertionError, IndexError):
+ # optimistic algorithm failed, lets do a fallback
+ flat = [sympy_product(old_size)]
+ reindex1 = cls._dynamic_reshape_indexer(old_size, flat)
+ reindex2 = cls._dynamic_reshape_indexer(flat, new_size)
+ reindex = fuse_reindexing(reindex1, reindex2)
+ return reindex
+
+ @staticmethod
+ def _dynamic_reshape_indexer(old_size, new_size):
+ """
+ Perform a reshape entirely by modifying indexing math
+ """
+ size_hint = V.graph.sizevars.size_hint
+ vars = [sympy.Symbol(f"view{i}") for i in range(len(new_size))]
+
+ stack_new = list(zip(vars, new_size))
+ stack_old = list(old_size)
+
+ view_expr = []
+ while stack_new and stack_old:
+ size_old = stack_old.pop()
+ var, size_new = stack_new.pop()
+ if size_old == 1:
+ view_expr.append(sympy.Integer(0))
+ stack_new.append((var, size_new)) # re-add
+ elif size_new == 1:
+ stack_old.append(size_old) # re-add
+ elif size_hint(size_new) == size_hint(size_old):
+ view_expr.append(var)
+ V.graph.sizevars.guard_equals(size_new, size_old)
+ elif size_hint(size_new) < size_hint(size_old):
+ while size_hint(size_new) < size_hint(size_old):
+ var2, size_new2 = stack_new.pop()
+ var = var2 * size_new + var
+ size_new = size_new * size_new2
+ view_expr.append(var)
+ V.graph.sizevars.guard_equals(size_new, size_old)
+ elif size_hint(size_new) > size_hint(size_old):
+ divisor = sympy.Integer(1)
+ modulus = size_old
+ view_expr.append(ModularIndexing(var, divisor, modulus))
+ divisor = divisor * modulus
+ while size_hint(size_new) > size_hint(size_old):
+ modulus = stack_old.pop()
+ view_expr.append(ModularIndexing(var, divisor, modulus))
+ divisor = divisor * modulus
+ size_old = size_old * modulus
+ V.graph.sizevars.guard_equals(size_new, size_old)
+ else:
+ raise AssertionError()
+
+ while stack_old:
+ size_old = stack_old.pop()
+ assert size_old == 1
+ view_expr.append(sympy.Integer(0))
+
+ while stack_new:
+ var, size_new = stack_new.pop()
+ assert size_new == 1
+
+ view_expr = list(reversed(view_expr))
+ assert len(view_expr) == len(old_size)
+
+ def reindex(index):
+ assert len(index) == len(vars), (len(index), len(vars))
+ replacements = dict(zip(vars, index))
+ return tuple(sympy_subs(x, replacements) for x in view_expr)
+
+ return reindex
+
+ def get_size(self):
+ return self.size
+
+ def make_loader(self):
+ def load(index):
+ return inner(self.reindex(index))
+
+ inner = self.data.make_loader()
+ return load
+
+
+@dataclasses.dataclass
+class ReinterpretView(BaseView):
+ """Pretend our storage has a different layout"""
+
+ layout: "Layout"
+
+ def __str__(self):
+ return self.str_helper(
+ [
+ self.data,
+ self.layout,
+ ]
+ )
+
+ __repr__ = __str__
+
+ def get_name(self):
+ return self.data.get_name()
+
+ def get_device(self):
+ return self.layout.device
+
+ def get_dtype(self):
+ return self.layout.dtype
+
+ def get_size(self):
+ return self.layout.size
+
+ def get_stride(self):
+ return self.layout.stride
+
+ def make_loader(self):
+ def loader(index):
+ indexer = self.layout.make_indexer()
+ return ops.load(self.get_name(), indexer(index))
+
+ return loader
+
+ def make_indexer(self):
+ return self.layout.make_indexer()
+
+ def get_layout(self):
+ return self.layout
+
+ def freeze_layout(self):
+ pass
+
+ def codegen_reference(self):
+ size = V.graph.sizevars.codegen_shape_tuple(self.layout.size)
+ stride = V.graph.sizevars.codegen_shape_tuple(self.layout.stride)
+ offset = V.graph.sizevars.codegen_sizevar(self.layout.offset)
+ if offset != "0":
+ return f"as_strided({self.get_name()}, {size}, {stride}, {offset})"
+ return f"as_strided({self.get_name()}, {size}, {stride})"
+
+
+class SliceView(View):
+ @classmethod
+ def create(cls, x, dim, start, end, step=1):
+ step = sympy.expand(step)
+ assert step > 0
+ try:
+ if start == 0 and end >= 2**63 and step == 1:
+ return x
+ except TypeError:
+ pass
+
+ sizevars = V.graph.sizevars
+ new_size = list(x.get_size())
+
+ start = cls.handle_negative_index(start, new_size[dim])
+ end = cls.handle_negative_index(end, new_size[dim])
+
+ end = sizevars.guard_min(end, new_size[dim])
+ start = sizevars.guard_min(sizevars.guard_min(start, new_size[dim]), end)
+ if start == 0 and sizevars.size_hint(end - new_size[dim]) == 0 and step == 1:
+ sizevars.guard_equals(end, new_size[dim])
+ return x
+
+ new_size[dim] = IndexingDiv(end - start + (step - 1), step)
+
+ if is_storage_and_layout(x):
+ # Fast path
+ storage, old_layout = as_storage_and_layout(x)
+ new_stride = list(old_layout.stride)
+ new_stride[dim] = new_stride[dim] * step
+ new_layout = FixedLayout(
+ old_layout.device,
+ old_layout.dtype,
+ new_size,
+ new_stride,
+ old_layout.offset + old_layout.stride[dim] * start,
+ )
+ return ReinterpretView(storage, new_layout)
+
+ def reindex(index):
+ assert len(index) == len(new_size), f"wrong ndim {index} {new_size}"
+ index = list(index)
+ index[dim] = index[dim] * step + start
+ return index
+
+ # redirect to a generic view
+ return SliceView(x, size=new_size, reindex=reindex)
+
+
+class BaseConstant(IRNode):
+ def get_size(self):
+ return ()
+
+ def get_dtype(self):
+ return self.dtype
+
+ def get_device(self):
+ return self.device
+
+ def mark_reuse(self, users):
+ pass
+
+ def get_reads(self):
+ return ()
+
+ def is_extern(self):
+ return False
+
+
+@dataclasses.dataclass
+class Constant(BaseConstant):
+ value: Any
+ dtype: torch.dtype
+ device: torch.device
+
+ def make_loader(self):
+ def loader(index):
+ return ops.constant(self.value, self.dtype)
+
+ return loader
+
+
+@dataclasses.dataclass
+class IndexingConstant(BaseConstant):
+ index: Any
+ dtype: torch.dtype
+ device: torch.device
+
+ def make_loader(self):
+ def loader(index):
+ return ops.index_expr(self.index, self.dtype)
+
+ return loader
+
+
+@dataclasses.dataclass
+class Layout(IRNode):
+ device: torch.device
+ dtype: torch.dtype
+ size: List[Expr]
+ stride: List[Expr]
+ offset: Expr = Integer(0)
+
+ def __str__(self):
+ offset = ""
+ if self.offset != 0:
+ offset = f", offset={self.offset}"
+ return (
+ f"{type(self).__name__}('{self.device.type}', {self.dtype}, "
+ f"size={self.size}, stride={self.stride}{offset})"
+ )
+
+ __repr__ = __str__
+
+ def is_contiguous(self):
+ for left, right, size in zip(
+ self.stride, FlexibleLayout.contiguous_strides(self.size), self.size
+ ):
+ if size != 1 and left != right:
+ return False
+ return True
+
+ def is_transposed(self):
+ for left, right, size in zip(
+ self.stride,
+ reversed(FlexibleLayout.contiguous_strides(self.size)),
+ self.size,
+ ):
+ if size != 1 and left != right:
+ return False
+ return True
+
+ def is_stride_ordered(self, order):
+ assert len(self.stride) == len(order)
+ # reorder the stride given order
+ stride_ordered = [None] * len(order)
+ for i in range(len(order)):
+ stride_ordered[order[i]] = V.graph.sizevars.size_hint(self.stride[i])
+ # check if it is in ascending order
+ for i in range(len(order) - 1):
+ if stride_ordered[i] > stride_ordered[i + 1]:
+ return False
+ return True
+
+ def is_channels_last_stride_ordered(self):
+ # create channels_last order(NCHW, NCDHW, the C is the first order).
+ order = [0] + list(reversed(range(1, len(self.stride) - 1)))
+ order = [len(order)] + order
+ return self.is_stride_ordered(order)
+
+ def as_fixed(self):
+ return FixedLayout(
+ self.device,
+ self.dtype,
+ self.size,
+ self.stride,
+ self.offset,
+ )
+
+ def make_indexer(self):
+ assert (
+ FlexibleLayout.allow_indexing
+ ), f"convert {type(self).__name__} to FixedLayout first"
+ return self.as_fixed().make_indexer()
+
+ def __eq__(self, other) -> bool:
+ return (
+ self.device == other.device
+ and self.dtype == other.dtype
+ and self.size == other.size
+ and self.stride == other.stride
+ and self.offset == other.offset
+ )
+
+
+class FixedLayout(Layout):
+ """A Tensor layout we cannot change"""
+
+ def make_indexer(self):
+ """A closure containing math to read a given element"""
+
+ def indexer(index):
+ assert len(index) == len(self.stride) == len(self.size)
+ result = self.offset
+ for idx, stride, sz in zip(index, self.stride, self.size):
+ if sz != 1:
+ result = result + idx * stride
+ return result
+
+ return indexer
+
+
+class FlexibleLayout(Layout):
+ """A Tensor layout we are allowed to change"""
+
+ allow_indexing = False
+
+ @staticmethod
+ def contiguous_strides(sizes):
+ if len(sizes) == 0:
+ return []
+ reversed_strides = [sympy.Integer(1)]
+ for size in reversed(sizes[1:]):
+ reversed_strides.append(size * reversed_strides[-1])
+ return list(reversed(reversed_strides))
+
+ @staticmethod
+ def fill_ordered(sizes, order):
+ """
+ Create a stride based on the order the dimensions should be filled in.
+
+ In this format, channels last would be:
+ [1, 3, 2, 0]
+ """
+ assert set(range(len(sizes))) == set(order)
+ next_stride = sympy.Integer(1)
+ strides = [None] * len(order)
+
+ for i in order:
+ strides[i] = next_stride
+ next_stride = next_stride * sizes[i]
+ return strides
+
+ @staticmethod
+ def stride_ordered(sizes, order):
+ """
+ Create a stride based on the sorted order of a permuted range.
+
+ In this format, channels last would be:
+ [3, 0, 2, 1]
+ """
+ assert set(range(len(sizes))) == set(order)
+ fill_order = stride_order2fill_order(order)
+ return FlexibleLayout.fill_ordered(sizes, fill_order)
+
+ @staticmethod
+ def same_ordered(sizes, stride):
+ """
+ Create a stride that has the same stride order as given stride
+
+ For example, if given stride is [1000, 1, 100, 10],
+ the fill order should be [1, 3, 2, 0]
+ """
+ assert len(sizes) == len(stride)
+ stride = [V.graph.sizevars.size_hint(x) for x in stride]
+ fill_order = sorted(range(len(stride)), key=stride.__getitem__)
+ return FlexibleLayout.fill_ordered(sizes, fill_order)
+
+ def as_stride_order(self, order):
+ return FixedLayout(
+ self.device,
+ self.dtype,
+ self.size,
+ self.stride_ordered(self.size, order),
+ self.offset,
+ )
+
+ def as_fill_order(self, order):
+ return FixedLayout(
+ self.device,
+ self.dtype,
+ self.size,
+ self.fill_ordered(self.size, order),
+ self.offset,
+ )
+
+ def as_same_order(self, stride):
+ return FixedLayout(
+ self.device,
+ self.dtype,
+ self.size,
+ self.same_ordered(self.size, stride),
+ self.offset,
+ )
+
+ def __init__(self, device, dtype, size, stride_order=None):
+ super(FlexibleLayout, self).__init__(
+ device, dtype, size, FlexibleLayout.contiguous_strides(size)
+ )
+ self.preferred_stride_order = stride_order
+
+
+class AliasedLayout(Layout):
+ """Shares the same storage as another tensor"""
+
+ def __init__(self, view: "ReinterpretView"):
+ layout = view.get_layout()
+ super().__init__(
+ layout.device,
+ layout.dtype,
+ layout.size,
+ layout.stride,
+ )
+ self.view = view
+
+ def make_indexer(self):
+ return self.as_fixed().make_indexer()
+
+ def maybe_guard_aligned(self):
+ offset = self.view.get_layout().offset
+ if offset == 0:
+ return True
+ from .compile_fx import ALIGNMENT
+
+ return V.graph.sizevars.maybe_guard_multiple_of(offset, ALIGNMENT)
+
+
+class MutationLayout(Layout):
+ def __init__(self, target: IRNode):
+ super().__init__(
+ target.get_device(),
+ target.get_dtype(),
+ target.get_size(),
+ None, # type: ignore[arg-type]
+ )
+ self.target = target
+
+ @classmethod
+ def realize_into(cls, src, dst):
+ dst.realize()
+ V.graph.realize_users_of(dst.get_name())
+
+ if isinstance(src, TensorBox):
+ src = src.data
+
+ if not isinstance(src, StorageBox) or src.is_user_of(dst.get_name()):
+ need_copy = True
+ else:
+ src.realize()
+ need_copy = not isinstance(src.data.layout, FlexibleLayout)
+
+ if need_copy:
+ src = Pointwise.create(
+ device=src.get_device(),
+ dtype=src.get_dtype(),
+ inner_fn=src.make_loader(),
+ ranges=[
+ V.graph.sizevars.guard_equals(a, b)
+ for a, b in zip(src.get_size(), dst.get_size())
+ ],
+ ).data
+ src.realize()
+
+ assert isinstance(src.data.layout, FlexibleLayout)
+ src.data.layout = MutationLayout(dst)
+ return src.data
+
+ def as_fixed(self):
+ return self
+
+ def make_indexer(self):
+ return self.target.make_indexer()
+
+
+@dataclasses.dataclass
+class Buffer(IRNode):
+ name: str
+ layout: Layout
+
+ def make_indexer(self):
+ return self.layout.make_indexer()
+
+ def get_name(self):
+ assert self.name
+ return self.name
+
+ def get_device(self):
+ return self.layout.device
+
+ def get_dtype(self):
+ return getattr(self.layout, "dtype", None)
+
+ def get_size(self):
+ return self.layout.size
+
+ def get_stride(self):
+ return self.layout.stride
+
+ def get_layout(self):
+ return self.layout
+
+ def get_storage_numel(self):
+ return self.get_numel()
+
+ def is_extern(self):
+ return False
+
+ def freeze_layout(self):
+ if not isinstance(self.layout, MultiOutputLayout):
+ self.layout = self.layout.as_fixed()
+
+ def freeze_layout_with_stride_order(self, order):
+ assert isinstance(self.layout, FlexibleLayout)
+ self.layout = self.layout.as_stride_order(order)
+
+ def freeze_layout_with_fill_order(self, order):
+ assert isinstance(self.layout, FlexibleLayout)
+ self.layout = self.layout.as_fill_order(order)
+
+ def freeze_layout_with_same_order(self, stride):
+ assert isinstance(self.layout, FlexibleLayout)
+ self.layout = self.layout.as_same_order(stride)
+
+ def make_loader(self):
+ def loader(index):
+ indexer = self.layout.make_indexer()
+ return ops.load(self.name, indexer(index))
+
+ return loader
+
+ def is_no_op(self):
+ return False
+
+ def codegen_reference(self):
+ return self.get_name()
+
+ def decide_layout(self):
+ pass
+
+ def get_alias_names(self):
+ if isinstance(self.layout, AliasedLayout):
+ return [self.layout.view.get_name()]
+ return ()
+
+ def get_mutation_names(self):
+ if isinstance(self.layout, MutationLayout):
+ return [self.layout.target.get_name()]
+ return ()
+
+ @cache_on_self
+ def get_read_writes(self):
+ with patch.object(FlexibleLayout, "allow_indexing", True):
+ return extract_read_writes(
+ self.make_loader(),
+ self.get_size(),
+ )
+
+ def get_reads(self):
+ return self.get_read_writes().reads
+
+ def realize(self):
+ pass
+
+
+class InputBuffer(Buffer):
+ pass
+
+
+class ConstantBuffer(InputBuffer):
+ override_device = None
+
+ def make_loader(self):
+ def loader(index):
+ indexer = self.layout.make_indexer()
+ return ops.load(
+ V.graph.constant_name(self.name, self.override_device), indexer(index)
+ )
+
+ return loader
+
+ def constant_to_device(self, device):
+ return ConstantBuffer(V.graph.constant_name(self.name, device), self.layout)
+
+
+class RandSeedBuffer(ConstantBuffer):
+ def codegen_reference(self):
+ # Clone makes sure if we pass this from forwards to backwards
+ # the value does not get clobbered by the time backwards is run.
+ return self.get_name() + ".clone()"
+
+
+class NoneAsConstantBuffer(IRNode):
+ def codegen_reference(self):
+ return "None"
+
+
+@dataclasses.dataclass
+class ComputedBuffer(Buffer):
+ data: Loops
+
+ @cache_on_self
+ def get_read_writes(self):
+ with patch.object(FlexibleLayout, "allow_indexing", True):
+ if self.data.get_reduction_type():
+ return extract_read_writes(
+ self.get_store_function(),
+ self.data.get_size(),
+ self.data.get_reduction_size(),
+ )
+ else:
+ return extract_read_writes(
+ self.get_store_function(),
+ self.data.get_size(),
+ )
+
+ def get_store_function(self):
+ indexer = self.layout.as_fixed().make_indexer()
+ if self.data.get_reduction_type():
+ return partial(self.data.store_reduction, self.name, indexer)
+ else:
+ return partial(self.data.store_output, self.name, indexer)
+
+ def decide_layout(self):
+ """
+ If our layout is still flexible, try to set it based on stride orders of reads.
+
+ TODO(jansel): A better algorithm here would look at downstream consumers of this
+ value and try to do global graph-level layout optimization.
+ This is also something just begging to be autotuned.
+ """
+ if isinstance(self.layout, FlexibleLayout):
+ _, (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze(
+ self.data.get_size(), self.data.get_reduction_size()
+ )
+ reads = self.get_read_writes().reads
+ reads_bufs = [
+ V.graph.name_to_buffer[r.name]
+ if r.name in V.graph.name_to_buffer.keys()
+ else None
+ for r in reads
+ ]
+ priority_idx = []
+ for i, reads_buf in enumerate(reads_bufs):
+ if (
+ isinstance(reads_buf, Convolution)
+ and reads_buf.kernel != "aten.convolution"
+ ):
+ # prioritize Conv layout order
+ priority_idx.append(i)
+ # only consider reads to buffer of same size
+ reads = [
+ sympy_subs(
+ r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0}
+ )
+ for r in reads
+ ]
+
+ if reads:
+ stride_lengths = numpy.array(
+ [V.graph.sizevars.stride_hints(expr, index_vars) for expr in reads],
+ dtype=numpy.int64,
+ )
+ from .scheduler import pick_loop_order
+
+ self.freeze_layout_with_fill_order(
+ pick_loop_order(stride_lengths, self.get_size(), priority_idx)
+ )
+
+ if isinstance(self.layout, FlexibleLayout):
+ self.freeze_layout()
+
+ def simplify_and_reorder(self):
+ """
+ This is a main place where we do loop transformations in a
+ backend-agnostic way.
+
+ Here we:
+ 1) Remove any 1 dimensions
+ 2) Fuse contiguous dimensions together
+ 3) Reorder dimensions based on stride orders
+ """
+ _, args, var_ranges = dependencies.index_vars_squeeze(
+ self.data.get_size(), self.data.get_reduction_size(), prefix="q"
+ )
+ with patch.object(ConstantBuffer, "override_device", self.get_device()):
+ body = LoopBody(
+ self.get_store_function(),
+ (args if self.get_reduction_type() else args[:1]),
+ var_ranges,
+ )
+ index_formulas = [*body.indexing_exprs.values()]
+ reads_bufs = [
+ V.graph.name_to_buffer[reads_name]
+ if reads_name in V.graph.name_to_buffer.keys()
+ else None
+ for reads_name in body.reads_name2expr.keys()
+ ]
+ priority_idx = []
+ if config.triton.convolution == "aten":
+ memory_addrs = [
+ *body.reads_name2expr.values(),
+ *body.writes_name2expr.values(),
+ ]
+ else:
+ # prioritize reads layout/loop_ordering over writes
+ if len(body.reads_name2expr.values()) > 0:
+ memory_addrs = [*body.reads_name2expr.values()]
+ else:
+ memory_addrs = [*body.writes_name2expr.values()]
+ for i, reads_buf in enumerate(reads_bufs):
+ if isinstance(reads_buf, Convolution):
+ priority_idx.append(i)
+
+ index_vars = []
+ reduce_vars = []
+ index_size = []
+ reduce_size = []
+ for v, s in var_ranges.items():
+ if v in args[0]:
+ assert not reduce_vars
+ index_vars.append(v)
+ index_size.append(s)
+ else:
+ assert v in args[1]
+ reduce_vars.append(v)
+ reduce_size.append(s)
+
+ # the reordering_reindex in reads' simplify_reorder_and_tile
+ reordering_reindex = [same_reorder(range(len(index_vars)))] * len(memory_addrs)
+ for i, reads_buf in enumerate(reads_bufs):
+ if isinstance(reads_buf, ComputedBuffer) and hasattr(
+ reads_buf, "iter_reordering_reindex"
+ ):
+ reordering_reindex[i] = reads_buf.iter_reordering_reindex
+
+ def simplify_and_reorder(x_vars, sizes, reordering_reindex=None):
+ sizes, reindex0, reindex1 = self._apply_loop_reordering(
+ x_vars, sizes, memory_addrs, reordering_reindex, priority_idx
+ )
+ # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1]
+ x_vars = reindex0(x_vars)
+ sizes, reindex2, prune = V.graph.sizevars._simplify_loops(
+ x_vars,
+ sizes,
+ index_prevent_reordering(index_formulas, x_vars, sizes),
+ )
+ x_vars = prune(x_vars)
+ # sizes, reindex1, prune = _simplify_loops(x_vars, sizes, index_formulas)
+ # x_vars = prune(x_vars)
+ # sizes, reindex2 = self._apply_loop_reordering(x_vars, sizes, memory_addrs)
+ reindex = fuse_reindexing(reindex1, reindex2)
+ return sizes, reindex, reindex1
+
+ iter_ranges, iter_reindex, iter_reordering_reindex = simplify_and_reorder(
+ index_vars, index_size, reordering_reindex
+ )
+ reduce_ranges, reduce_reindex, _ = simplify_and_reorder(
+ reduce_vars, reduce_size
+ )
+
+ # remember the reordering order
+ self.iter_reordering_reindex = iter_reordering_reindex
+ # retrace the loop body with simplification and reordering applied
+ (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
+ iter_ranges, reduce_ranges, prefix="z"
+ )
+ body = LoopBody(
+ body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], var_ranges
+ )
+ return (iter_ranges, reduce_ranges), body
+
+ @staticmethod
+ def _apply_loop_reordering(
+ index_vars, sizes, memory_addrs, reordering_reindex=None, priority_idx=None
+ ):
+ """
+ Shuffle the order of loops around to hopefully improve performance.
+ """
+ from .scheduler import pick_loop_order
+
+ if priority_idx is None:
+ priority_idx = []
+
+ try:
+ strides = numpy.array(
+ [
+ V.graph.sizevars.stride_hints(expr, index_vars)
+ for expr in memory_addrs
+ ],
+ dtype=numpy.int64,
+ )
+ assert strides.shape == (len(memory_addrs), len(index_vars))
+ # consider both layout(strides) and reordering(reordering_reindex)
+ if reordering_reindex is not None:
+ for i in range(len(memory_addrs)):
+ try:
+ strides[i] = reordering_reindex[i](strides[i])
+ # if len(order) != len(strides), do not reorder
+ except AssertionError:
+ pass
+ order = list(reversed(pick_loop_order(strides, sizes, priority_idx)))
+ except Exception:
+ if config.debug:
+ log.warning(
+ f"Did not simplify complex index:\n{dict(zip(index_vars, sizes))}\n{memory_addrs}"
+ )
+ order = list(range(len(sizes)))
+ sizes = [sizes[i] for i in order]
+ return sizes, same_reorder(order), inverse_reorder(order)
+
+ def get_reduction_size(self):
+ return self.data.get_reduction_size()
+
+ def get_reduction_type(self):
+ return self.data.get_reduction_type()
+
+ def is_no_op(self):
+ return self.data.is_zero_elements()
+
+ def should_allocate(self):
+ return True
+
+ def constant_to_device(self, device):
+ """Move this to a given device. Requires that all reads are to constants."""
+ return self.data.constant_to_device(device)
+
+
+@dataclasses.dataclass
+class InputsKernel(Buffer):
+ inputs: List[Buffer]
+
+ def get_read_writes(self):
+ return dependencies.ReadWrites(
+ {dependencies.StarDep(x.get_name()) for x in self.inputs},
+ {dependencies.StarDep(self.get_name())},
+ set(),
+ [],
+ None,
+ )
+
+ @staticmethod
+ def unwrap_storage(inputs):
+ inputs_new = []
+ for x in inputs:
+ if isinstance(x, TensorBox):
+ x = x.data
+ if isinstance(x, StorageBox):
+ x = x.data
+ if isinstance(x, BaseView) and not isinstance(x, ReinterpretView):
+ x = ExternKernel.realize_input(x)
+ assert isinstance(x, (Buffer, ReinterpretView)), x
+ inputs_new.append(x)
+ return inputs_new
+
+ def is_extern(self):
+ return True
+
+
+class NopKernel(InputsKernel):
+ def is_no_op(self):
+ return True
+
+
+class ConcatKernel(NopKernel):
+ """
+ There isn't actually a real kernel for concat, we just change the
+ storage for the upstream data.
+ """
+
+ @classmethod
+ def create(cls, inputs, dim):
+ device = inputs[0].get_device()
+ dtype = inputs[0].get_dtype()
+ new_size = list(inputs[0].get_size())
+ offsets_start = [0]
+ offsets_end = [new_size[dim]]
+ assert 0 <= dim < len(new_size)
+ for i in range(1, len(inputs)):
+ input_size = inputs[i].get_size()
+ offsets_start.append(new_size[dim])
+ assert len(input_size) == len(new_size)
+ assert inputs[i].get_dtype() == dtype
+ assert inputs[i].get_device() == device
+ for j in range(len(new_size)):
+ if j == dim:
+ new_size[j] = new_size[j] + input_size[j]
+ else:
+ new_size[j] = V.graph.sizevars.guard_equals(
+ new_size[j], input_size[j]
+ )
+ offsets_end.append(new_size[dim])
+
+ kernel = ConcatKernel(
+ name=None,
+ layout=FixedLayout(
+ device=device,
+ dtype=dtype,
+ size=new_size,
+ stride=FlexibleLayout.contiguous_strides(new_size),
+ ),
+ inputs=[],
+ )
+ kernel = StorageBox(kernel)
+ for i in range(len(inputs)):
+ kernel.data.inputs.append(
+ cls.realize_into(
+ inputs[i],
+ SliceView.create(kernel, dim, offsets_start[i], offsets_end[i]),
+ )
+ )
+ kernel.data.name = V.graph.register_buffer(kernel.data)
+ kernel.data.inputs = cls.unwrap_storage(kernel.data.inputs)
+
+ return kernel
+
+ @classmethod
+ def realize_into(cls, src, dst):
+ # Attempt to turn this into a ReinterpretView rather than assert.
+ # This has concessions around layout, as as_storage_and_layout
+ # can cause us to go from flexible to fixed layout.
+ if not isinstance(dst, ReinterpretView):
+ if is_storage_and_layout(dst):
+ storage, layout = as_storage_and_layout(dst)
+ dst = ReinterpretView(storage, layout)
+ assert isinstance(dst, ReinterpretView), dst
+ if isinstance(src, TensorBox):
+ # unwrap a TensorBox
+ return cls.realize_into(src.data, dst)
+ if isinstance(src, StorageBox):
+ src.realize()
+ # ExternKernelAlloc has specific requirements for output layout, should create a copy
+ if isinstance(src.data.layout, FlexibleLayout) and not isinstance(
+ src.data, ExternKernelAlloc
+ ):
+ src.data.layout = AliasedLayout(dst)
+ return src.data
+ # introduce a copy
+ pw = Pointwise.create(
+ device=src.get_device(),
+ dtype=src.get_dtype(),
+ inner_fn=src.make_loader(),
+ ranges=[
+ V.graph.sizevars.guard_equals(a, b)
+ for a, b in zip(src.get_size(), dst.get_size())
+ ],
+ )
+ return cls.realize_into(pw, dst)
+
+ def should_allocate(self):
+ return True
+
+
+@dataclasses.dataclass
+class ExternKernel(InputsKernel):
+ constant_args: Tuple[Any, ...] = ()
+ kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+ output_view: Optional[ReinterpretView] = None
+
+ def decide_layout(self):
+ if isinstance(self.layout, FlexibleLayout):
+ self.apply_constraint()
+ self.freeze_layout()
+
+ def codegen(self, wrapper):
+ raise NotImplementedError
+
+ @staticmethod
+ def copy_input(x):
+ pw = Pointwise.create(
+ device=x.get_device(),
+ dtype=x.get_dtype(),
+ inner_fn=x.make_loader(),
+ ranges=x.get_size(),
+ )
+ pw.realize()
+ return pw
+
+ @classmethod
+ def convert_to_reinterpret_view(cls, x):
+ """
+ In order to pass this to an extern kernel we need a
+ ReinterpretView not a View. This allows us to avoid some
+ uneeded copies.
+ """
+ assert isinstance(x, BaseView)
+ if isinstance(x, ReinterpretView):
+ return x
+
+ x.unwrap_view().freeze_layout()
+ rw = extract_read_writes(x.make_loader(), x.get_size(), normalize=False)
+ assert len(rw.reads) == 1
+
+ index = V.graph.sizevars.simplify_with_ranges(
+ list(rw.reads)[0].index, rw.var_ranges
+ )
+ strides = V.graph.sizevars.stride_vars(index, rw.range_vars)
+ offset = V.graph.sizevars.offset_var(index, rw.range_vars)
+ expected = sympy_dot(rw.range_vars, strides) + offset
+
+ if index != expected:
+ log.debug(
+ "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s",
+ strides,
+ offset,
+ index,
+ )
+ raise NotImplementedError()
+
+ return ReinterpretView(
+ data=x.data,
+ layout=FixedLayout(
+ device=x.get_device(),
+ dtype=x.get_dtype(),
+ size=x.get_size(),
+ stride=strides,
+ offset=offset,
+ ),
+ )
+
+ @classmethod
+ def realize_input(cls, x):
+ if x is None:
+ return NoneAsConstantBuffer()
+ if isinstance(x, Constant):
+ return V.graph.add_tensor_constant(
+ torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
+ )
+ if isinstance(x, ConstantBuffer):
+ return x
+ if isinstance(x, TensorBox):
+ return cls.realize_input(x.data)
+ if isinstance(x, ReinterpretView):
+ return x
+ if isinstance(x, BaseView):
+ x.realize()
+ if is_storage_and_layout(x.unwrap_view()) and not isinstance(
+ x.unwrap_view().data, ExternKernelAlloc
+ ):
+ try:
+ return cls.convert_to_reinterpret_view(x)
+ except NotImplementedError:
+ pass
+ if isinstance(x, StorageBox):
+ # TODO(jansel): impose layout preference on realized buffer
+ x.realize()
+ return x
+ return cls.copy_input(x)
+
+ @classmethod
+ def require_stride1(cls, x):
+ if len(x.get_stride()) == 0:
+ return x
+ for stride in x.get_stride():
+ if stride == 1:
+ return x
+ return cls.copy_input(x)
+
+ @classmethod
+ def require_contiguous(cls, x):
+ if is_contiguous_storage_and_layout(x):
+ as_contiguous_storage_and_layout(x, freeze=True)
+ return x
+ x = cls.copy_input(x)
+ assert is_contiguous_storage_and_layout(x)
+ as_contiguous_storage_and_layout(x, freeze=True)
+ return x
+
+ @classmethod
+ def require_stride_order(cls, x, order):
+ # require x to have the layout as strided_ordered as order
+ if isinstance(
+ x.get_layout(), FlexibleLayout
+ ) and is_stride_order_storage_and_layout(x, order):
+ # fix flexiblelayout to be FixedLayout with stride_order
+ as_storage_and_layout(
+ x, freeze=True, want_contiguous=False, stride_order=order
+ )
+ return x
+ elif isinstance(x.get_layout(), FixedLayout) and x.layout.is_stride_ordered(
+ order
+ ):
+ return x
+ x = cls.copy_input(x)
+ as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order)
+ assert is_stride_order_storage_and_layout(x, order)
+ return x
+
+ def apply_constraint(self):
+ pass
+
+ def codegen_args(self):
+ args = [x.codegen_reference() for x in self.inputs]
+ args.extend(map(repr, self.constant_args))
+ return args
+
+ def codegen_kwargs(self):
+ kwargs = []
+ if self.kwargs:
+ kwargs = [f"{k}={repr(v)}" for k, v in self.kwargs.items()]
+ return kwargs
+
+ def codegen_size_asserts(self, wrapper):
+ if config.size_asserts:
+ size = V.graph.sizevars.codegen_shape_tuple(self.get_size())
+ stride = V.graph.sizevars.codegen_shape_tuple(self.get_stride())
+ wrapper.writeline(f"assert {self.get_name()}.size() == {size}")
+ wrapper.writeline(f"assert {self.get_name()}.stride() == {stride}")
+
+ def get_group_stride(self):
+ """
+ get output sizes and strides, for template_codegen
+ """
+ _size = self.get_size()
+ _stride = self.get_stride()
+ # iter_ranges = _size of output tensor, reduce_range = [] because no reduction
+ return [_size, []], _stride
+
+ def canonicalize(self):
+ """
+ Manually get cononicalization of the output index
+ """
+ # manually generate index formula for conv
+ sizevars = V.graph.sizevars
+ sizes = self.get_size()
+ strides = self.get_stride()
+ strides = [sizevars.size_hint(x) for x in strides]
+ index_vars = [sympy.Symbol(f"d{i}") for i in range(len(sizes))]
+ # reorder index vars according to stride
+ index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
+ lookup = {pos: idx for idx, pos in enumerate(index_order)}
+ order = [lookup[i] for i in range(len(lookup))]
+ index_vars = [index_vars[i] for i in order]
+ indexer = self.make_indexer()
+ index = indexer(index_vars)
+
+ new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
+ index_vars, sizes, [index]
+ )
+
+ # assign new variables each dimension to deal with numbering mismatches
+ # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
+ _, add_var = var_builder("c")
+ replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
+
+ index = sympy_subs(sympy.expand(index), replacement)
+ return index, tuple(new_sizes)
+
+ def __str__(self):
+ lines = [
+ f"{field.name}={getattr(self, field.name)}"
+ for field in dataclasses.fields(self)
+ ]
+ return self.str_helper(lines)
+
+
+@dataclasses.dataclass
+class ExternKernelOut(ExternKernel):
+ output_view: Optional[ReinterpretView] = None
+
+ def codegen(self, wrapper):
+ args = self.codegen_args()
+
+ kwargs = self.codegen_kwargs()
+ if kwargs:
+ args.extend(kwargs)
+
+ if self.output_view:
+ args.append(f"out={self.output_view.codegen_reference()}")
+ else:
+ args.append(f"out={self.codegen_reference()}")
+ wrapper.writeline(f"{self.kernel}({', '.join(args)})")
+
+ def __init__(self, layout, inputs, constant_args=(), kwargs=None, output_view=None):
+ super().__init__(
+ None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}
+ )
+ self.output_view = output_view
+ self.name = V.graph.register_buffer(self)
+
+ def should_allocate(self):
+ return True
+
+
+class ExternKernelAlloc(ExternKernel):
+ def codegen(self, wrapper):
+ wrapper.writeline(
+ f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
+ )
+ if isinstance(self.layout, Layout):
+ self.codegen_size_asserts(wrapper)
+
+ def __init__(self, layout, inputs, constant_args=()):
+ super().__init__(None, layout, self.unwrap_storage(inputs), constant_args)
+ self.name = V.graph.register_buffer(self)
+
+ def should_allocate(self):
+ return False
+
+ def apply_constraint(self):
+ raise NotImplementedError
+
+
+class InplaceBernoulliFallback(ExternKernel):
+ """
+ This needs to be a custom class to handle mutation properly
+ """
+
+ kernel = "aten.bernoulli_"
+
+ def codegen(self, wrapper):
+ (x,) = [t.codegen_reference() for t in self.inputs]
+ wrapper.writeline(
+ f"{self.kernel}({x}, {', '.join(map(repr, self.constant_args))})"
+ )
+
+ def should_allocate(self):
+ return False
+
+ def get_mutation_names(self):
+ assert isinstance(self.layout, MutationLayout)
+ return (self.layout.target.get_name(),)
+
+ def __init__(self, x, *constant_args):
+ super().__init__(
+ None,
+ MutationLayout(x),
+ self.unwrap_storage([x]),
+ constant_args,
+ )
+ self.name = V.graph.register_buffer(self)
+
+
+class IndexPutFallback(ExternKernel):
+ """
+ This needs to be a custom class to handle mutation and indices properly
+ """
+
+ kernel = "aten.index_put_"
+
+ def codegen(self, wrapper):
+ (x, values, *valid_indices) = [t.codegen_reference() for t in self.inputs]
+ indices = []
+ iter_valid_indices = iter(valid_indices)
+ for i, _ in enumerate(self.indices):
+ if self.indices[i] is not None:
+ indices.append(next(iter_valid_indices))
+ else:
+ indices.append("None")
+ wrapper.writeline(
+ f"{self.kernel}({x}, [{','.join(indices)}], {values}, {repr(self.constant_args[0])})"
+ )
+
+ def should_allocate(self):
+ return False
+
+ def __init__(self, x, indices, values, accumulate):
+ self.indices = indices
+ valid_indices = [i for i in indices if i is not None]
+ tensors = [self.realize_input(x) for x in [x, values, *valid_indices]]
+ super().__init__(
+ None,
+ MutationLayout(x),
+ self.unwrap_storage(tensors),
+ [accumulate],
+ )
+ self.name = V.graph.register_buffer(self)
+
+
+class MatrixMultiply(ExternKernelOut):
+ kernel = "aten.mm.out"
+
+ def __init__(
+ self, layout, inputs, constant_args=(), output_view=None, kernel="aten.mm.out"
+ ):
+ super().__init__(layout, inputs, constant_args, output_view)
+ self.kernel = kernel
+
+ @classmethod
+ def create(cls, a, b):
+ *m, k1 = a.get_size()
+ k2, n = b.get_size()
+ V.graph.sizevars.guard_equals(k1, k2)
+ a = cls.realize_input(a)
+ b = cls.realize_input(b)
+ if len(m) != 1 and not a.get_layout().is_contiguous():
+ a = cls.copy_input(a)
+ else:
+ a = cls.require_stride1(a)
+ b = cls.require_stride1(b)
+
+ # choose runtime kernel
+ config_mm = config.triton.mm
+ # default kernel is aten
+ kernel = "aten.mm.out"
+ if config_mm == "aten":
+ kernel = "aten.mm.out"
+ elif config_mm == "triton" and a.get_device().type == "cuda":
+ kernel = "triton_ops.matmul_out"
+ elif config_mm == "autotune":
+ from .codegen.autotuner import tuned_mm
+
+ kernel = tuned_mm(
+ a.get_size(),
+ b.get_size(),
+ a.get_stride(),
+ b.get_stride(),
+ a.get_device(),
+ a.get_dtype(),
+ )
+
+ return MatrixMultiply(
+ layout=FlexibleLayout(
+ device=a.get_device(),
+ dtype=a.get_dtype(),
+ size=list(m) + [n],
+ ),
+ inputs=[a, b],
+ kernel=kernel,
+ )
+
+ def get_template_tiling(self):
+ tile1, tile2 = self.get_size()
+ return (
+ tile1,
+ tile2,
+ sympy.Integer(1),
+ )
+
+ def map_args(self):
+ # a, b
+ in_args = [x.codegen_reference() for x in self.inputs]
+ # const_args = self.constant_args
+ inout_dict = OrderedDict(
+ [
+ ("A", f"{in_args[0]}"),
+ ("B", f"{in_args[1]}"),
+ ("C", f"{self.get_name()}"),
+ ]
+ )
+ # batch==1 bmm->mm
+ if len(self.get_stride()) == 3:
+ assert self.get_size()[0] == 1
+ stride_cm = self.get_stride()[1]
+ stride_cn = self.get_stride()[2]
+ else:
+ stride_cm = self.get_stride()[0]
+ stride_cn = self.get_stride()[1]
+ args_dict = OrderedDict(
+ [
+ ("M", f"{self.inputs[0].get_size()[0]}"),
+ ("N", f"{self.inputs[1].get_size()[1]}"),
+ ("K", f"{self.inputs[0].get_size()[1]}"),
+ ("stride_am", f"{self.inputs[0].get_stride()[0]}"),
+ ("stride_ak", f"{self.inputs[0].get_stride()[1]}"),
+ ("stride_bk", f"{self.inputs[1].get_stride()[0]}"),
+ ("stride_bn", f"{self.inputs[1].get_stride()[1]}"),
+ ("stride_cm", f"{stride_cm}"),
+ ("stride_cn", f"{stride_cn}"),
+ ]
+ )
+ # accumulator types
+ ACC_TYPE = (
+ "tl.float32"
+ if self.inputs[0].get_dtype()
+ in [torch.float16, torch.bfloat16, torch.float32]
+ else "tl.int32"
+ )
+ # dict for tl.constexpr
+ const_dict = OrderedDict(
+ [
+ ("GROUP_M", "8"),
+ ("ACC_TYPE", ACC_TYPE),
+ ("allow_tf32", f"{torch.backends.cuda.matmul.allow_tf32}"),
+ ]
+ )
+
+ other_dict = OrderedDict()
+
+ return inout_dict, args_dict, const_dict, other_dict
+
+
+class MatrixMultiplyAdd(ExternKernelOut):
+ def __init__(self, layout, inputs, constant_args=(), kwargs=None, output_view=None):
+ super().__init__(layout, inputs, constant_args, kwargs or {}, output_view)
+ self.kernel = "aten.addmm.out"
+
+ @classmethod
+ def create(cls, inp, a, b, beta, alpha):
+ m, k1 = a.get_size()
+ k2, n = b.get_size()
+ V.graph.sizevars.guard_equals(k1, k2)
+ inp = cls.realize_input(inp)
+ a = cls.realize_input(a)
+ b = cls.realize_input(b)
+ a = cls.require_stride1(a)
+ b = cls.require_stride1(b)
+ return MatrixMultiplyAdd(
+ layout=FlexibleLayout(
+ device=a.get_device(),
+ dtype=a.get_dtype(),
+ size=[m] + [n],
+ ),
+ inputs=[inp, a, b],
+ kwargs={"beta": beta, "alpha": alpha},
+ )
+
+
+class BatchMatrixMultiply(ExternKernelOut):
+ kernel = "aten.bmm.out"
+
+ def __init__(self, layout, inputs, constant_args=(), output_view=None):
+ super().__init__(layout, inputs, constant_args, output_view)
+ if (
+ config.triton.use_bmm
+ and len(inputs) > 0
+ and inputs[0].get_device().type == "cuda"
+ ):
+ self.kernel = "triton_bmm_out"
+
+ @classmethod
+ def create(cls, a, b):
+ b1, m, k1 = a.get_size()
+ b2, k2, n = b.get_size()
+ b3 = V.graph.sizevars.guard_equals(b1, b2)
+ V.graph.sizevars.guard_equals(k1, k2)
+ a = cls.require_stride1(cls.realize_input(a))
+ b = cls.require_stride1(cls.realize_input(b))
+
+ output_layout = FlexibleLayout(
+ device=a.get_device(),
+ dtype=a.get_dtype(),
+ size=[b3, m, n],
+ ).as_fixed()
+
+ if b3 == 1:
+ # convert to normal mm
+ data = MatrixMultiply(
+ layout=output_layout.as_fixed(),
+ inputs=[SqueezeView.create(a, dim=0), SqueezeView.create(b, dim=0)],
+ )
+ data.output_view = ReinterpretView(
+ data,
+ FlexibleLayout(
+ device=a.get_device(),
+ dtype=a.get_dtype(),
+ size=[m, n],
+ ).as_fixed(),
+ )
+ else:
+ data = BatchMatrixMultiply(
+ layout=output_layout,
+ inputs=[a, b],
+ )
+ return data
+
+
+class DeviceCopy(ExternKernelOut):
+ @classmethod
+ def create(cls, x, device):
+ if not x.is_extern() and all(
+ (r.name in V.graph.constants and hasattr(r, "index")) for r in x.get_reads()
+ ):
+ return x.constant_to_device(device)
+
+ V.graph.device_types.add(device.type)
+ V.graph.device_types.add(x.get_device().type)
+
+ log.warning("DeviceCopy")
+ return DeviceCopy(
+ FlexibleLayout(
+ device=device,
+ dtype=x.get_dtype(),
+ size=x.get_size(),
+ ),
+ [cls.realize_input(x)],
+ )
+
+ def codegen(self, wrapper):
+ args = self.codegen_args()
+ assert len(args) == 1
+ if self.output_view:
+ wrapper.writeline(
+ f"{self.output_view.codegen_reference()}.copy_({args[0]})"
+ )
+ else:
+ wrapper.writeline(f"{self.codegen_reference()}.copy_({args[0]})")
+
+
+class DynamicScalar(IRNode):
+ """
+ The result of a call to aten._local_scalar_dense.
+
+ This is not yet implemented. The one model (so far) that calls this
+ (fastNLP_Bert) does not actually use the result. So we expect this
+ node to get dead code eliminated.
+ """
+
+ def get_reads(self):
+ return ()
+
+
+class AdaptiveAvgPool2d(ExternKernelAlloc):
+ kernel = "aten._adaptive_avg_pool2d"
+
+ @classmethod
+ def create(cls, x, target_size):
+ # x = cls.require_stride1(cls.realize_input(x))
+ x = cls.realize_input(x)
+ output_size = [
+ *x.get_size()[: -len(target_size)],
+ *map(sympy.Integer, target_size),
+ ]
+ # contigouse stride order
+ stride_order = list(reversed(range(len(output_size))))
+ return cls(
+ FlexibleLayout(
+ x.get_device(),
+ x.get_dtype(),
+ output_size,
+ # TODO(jansel): fix channels last case
+ # FlexibleLayout.contiguous_strides(output_size),
+ stride_order,
+ ),
+ (x,),
+ (tuple(target_size),),
+ )
+
+ def apply_constraint(self):
+ x = self.inputs[0]
+ if isinstance(x.get_layout(), FixedLayout):
+ # fix self's layout to be the same order as x
+ self.freeze_layout_with_same_order(x.get_layout().stride)
+ else:
+ x = self.require_stride_order(x, self.layout.preferred_stride_order)
+ self.inputs[0] = x
+ self.freeze_layout_with_stride_order(self.layout.preferred_stride_order)
+
+
+@dataclasses.dataclass
+class FallbackKernel(ExternKernelAlloc):
+ def __init__(
+ self,
+ layout,
+ kernel,
+ tensor_args,
+ nontensor_args,
+ unflatten_args,
+ kwargs=None,
+ ):
+ super(FallbackKernel, self).__init__(
+ layout,
+ tuple(tensor_args),
+ tuple(nontensor_args),
+ )
+ if getattr(torch.ops.aten, kernel.__name__, None) is kernel:
+ self.kernel = f"aten.{kernel.__name__}"
+ else:
+ self.kernel = (
+ f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}"
+ )
+ self.unflatten_args = unflatten_args
+ self.kwargs = {} if kwargs is None else kwargs
+ if self.kernel not in ("aten.convolution_backward",):
+ log.warning(f"Using FallbackKernel: {self.kernel}")
+
+ def codegen_args(self):
+ @dataclasses.dataclass
+ class Shim:
+ ref: Any
+
+ def __repr__(self):
+ return self.ref
+
+ tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
+ constant_args = [Shim(repr(x)) for x in self.constant_args]
+
+ def gen_kwarg(k, v):
+ return f"{k}={repr(v)}"
+
+ kwargs = list(gen_kwarg(k, v) for k, v in self.kwargs.items())
+
+ return list(map(repr, self.unflatten_args(tensor_args, constant_args))) + kwargs
+
+ @classmethod
+ def create(cls, kernel, *args, **kwargs):
+ args_flat, args_spec = pytree.tree_flatten(args)
+
+ is_arg_tensor = []
+ tensor_args = []
+ non_tensor_args = []
+ for arg in args_flat:
+ is_arg_tensor.append(isinstance(arg, IRNode))
+ if is_arg_tensor[-1]:
+ tensor_args.append(arg)
+ else:
+ non_tensor_args.append(arg)
+
+ def unflatten_args(new_tensor_args, new_non_tensor_args):
+ new_args = []
+ it_tensors = iter(new_tensor_args)
+ it_non_tensors = iter(new_non_tensor_args)
+ for is_tensor in is_arg_tensor:
+ if is_tensor:
+ new_args.append(next(it_tensors))
+ else:
+ new_args.append(next(it_non_tensors))
+ return pytree.tree_unflatten(new_args, args_spec)
+
+ tensor_args = [
+ cls.require_contiguous(cls.realize_input(x)) for x in tensor_args
+ ]
+
+ # We don't have generic shape formulas, so just burn in the
+ # shapes and run an example input.
+ # TODO(jansel): replace this with dynamic shape formulas
+ example_args = [
+ torch.zeros(
+ [V.graph.sizevars.guard_static_shape(s) for s in x.get_size()],
+ dtype=x.get_dtype(),
+ device=x.get_device(),
+ )
+ for x in tensor_args
+ ]
+ example_output = kernel(
+ *unflatten_args(example_args, non_tensor_args), **kwargs
+ )
+
+ if isinstance(example_output, (list, tuple)):
+ packed = FallbackKernel(
+ MultiOutputLayout(tensor_args[0].get_device()),
+ kernel,
+ tensor_args,
+ non_tensor_args,
+ unflatten_args,
+ )
+ return [
+ (
+ MultiOutput(
+ FixedLayout(
+ example_output[i].device,
+ example_output[i].dtype,
+ [sympy.Integer(s) for s in example_output[i].size()],
+ [sympy.Integer(s) for s in example_output[i].stride()],
+ ),
+ packed,
+ i,
+ )
+ if example_output[i] is not None
+ else None
+ )
+ for i in range(len(example_output))
+ ]
+ else:
+ return FallbackKernel(
+ FixedLayout(
+ example_output.device,
+ example_output.dtype,
+ [sympy.Integer(s) for s in example_output.size()],
+ [sympy.Integer(s) for s in example_output.stride()],
+ ),
+ kernel,
+ tensor_args,
+ non_tensor_args,
+ unflatten_args,
+ kwargs,
+ )
+
+ def apply_constraint(self):
+ return super().apply_constraint()
+
+
+@dataclasses.dataclass
+class MultiOutputLayout(IRNode):
+ device: torch.device
+
+
+class MultiOutput(ExternKernel):
+ def codegen(self, wrapper):
+ wrapper.writeline(
+ f"{self.get_name()} = {self.inputs[0].get_name()}[{self.index}]"
+ )
+ self.codegen_size_asserts(wrapper)
+
+ def __init__(self, layout, input, index):
+ super().__init__(None, layout, [input], ())
+ self.name = V.graph.register_buffer(self)
+ self.index = index
+
+ def should_allocate(self):
+ return False
+
+
+class Convolution(ExternKernelAlloc):
+ kernel = "aten.convolution"
+
+ def __init__(
+ self,
+ layout,
+ inputs,
+ constant_args=(),
+ preferred_stride_order=None,
+ kernel="aten.convolution",
+ ):
+ super().__init__(layout, inputs, constant_args)
+ self.kernel = kernel
+ self.preferred_stride_order = preferred_stride_order
+
+ def codegen(self, wrapper):
+ if self.kernel == "triton_ops.conv":
+ wrapper.header.writeline(
+ f"import {config.inductor_import}.triton_ops.conv as {self.kernel}"
+ )
+ wrapper.writeline(
+ f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
+ )
+ if isinstance(self.layout, Layout):
+ self.codegen_size_asserts(wrapper)
+
+ @classmethod
+ def create(
+ cls,
+ x: "TensorBox",
+ weight: "TensorBox",
+ bias: "TensorBox",
+ stride_: List[int],
+ padding_: List[int],
+ dilation_: List[int],
+ transposed: bool,
+ output_padding_: List[int],
+ groups: int,
+ ):
+ x = cls.require_stride1(cls.realize_input(x))
+ weight = cls.require_stride1(cls.realize_input(weight))
+ stride = tuple(stride_)
+ padding = tuple(padding_)
+ dilation = tuple(dilation_)
+ assert isinstance(transposed, bool)
+ output_padding = tuple(output_padding_)
+ assert isinstance(groups, int)
+
+ weight_shape = [
+ sympy.Integer(V.graph.sizevars.guard_static_shape(s))
+ for s in weight.get_size()
+ ]
+
+ out_channels, in_channels1, *kernel_size = weight_shape
+ in_channels1 = in_channels1 * groups
+ if transposed:
+ out_channels, in_channels1 = in_channels1, out_channels
+
+ if bias is not None:
+ bias = cls.require_stride1(cls.realize_input(bias))
+ (bias_shape,) = [
+ sympy.Integer(V.graph.sizevars.guard_static_shape(s))
+ for s in bias.get_size()
+ ]
+ assert bias_shape == out_channels, f"{bias_shape} == {out_channels}"
+
+ if len(x.get_size()) == 1 + len(kernel_size):
+ in_channels2, *input_size = x.get_size()
+ in_channels_stride, *_ = x.get_stride()
+ output_size = []
+ else:
+ assert len(x.get_size()) == 2 + len(kernel_size)
+ batch, in_channels2, *input_size = x.get_size()
+ _, in_channels_stride, *_ = x.get_stride()
+ output_size = [batch]
+
+ V.graph.sizevars.guard_equals(in_channels1, in_channels2)
+
+ output_size.append(out_channels)
+
+ assert (
+ len(stride)
+ == len(padding)
+ == len(dilation)
+ == len(output_padding)
+ == len(kernel_size)
+ == len(input_size)
+ )
+ for i in range(len(stride)):
+ if transposed:
+ output_size.append(
+ (input_size[i] - 1) * stride[i]
+ - 2 * padding[i]
+ + dilation[i] * (kernel_size[i] - 1)
+ + output_padding[i]
+ + 1
+ )
+ else:
+ output_size.append(
+ IndexingDiv(
+ input_size[i]
+ + 2 * padding[i]
+ - dilation[i] * (kernel_size[i] - 1)
+ - 1
+ + stride[i],
+ stride[i],
+ )
+ + 2 * output_padding[i]
+ )
+ output_size[-1] = sympy.Integer(
+ V.graph.sizevars.guard_static_shape(output_size[-1])
+ )
+
+ # choose runtime kernel
+ config_conv = config.triton.convolution
+ if (
+ config_conv == "aten"
+ or len(kernel_size) != 2 # triton conv only supports conv2d
+ or not is_triton(x.get_device())
+ or transposed
+ or groups != 1
+ # or x.get_dtype() == torch.float16
+ # or x.get_dtype() == torch.bfloat16
+ ):
+ kernel = "aten.convolution"
+ elif config_conv == "triton":
+ kernel = "triton_ops.conv"
+ else:
+ assert config_conv == "autotune"
+ from .codegen.autotuner import tuned_conv
+
+ kernel = tuned_conv(
+ x.get_size(),
+ weight.get_size(),
+ x.get_stride(),
+ weight.get_stride(),
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ x.get_device(),
+ x.get_dtype(),
+ )
+
+ # for conv2d or conv3d, prefer channels last format
+ if kernel == "triton_ops.conv":
+ output_layout_str = "torch.channels_last"
+ elif config.tune_layout:
+ from .codegen.autotuner import tuned_conv_layout
+
+ output_layout_str = tuned_conv_layout(
+ kernel,
+ x.get_size(),
+ weight.get_size(),
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ x.get_device(),
+ x.get_dtype(),
+ )
+ else:
+ output_layout_str = "torch.contiguous_format"
+ # If x or weight have one channels_last(2d or 3d) format, it will call channels_last path,
+ # which align with aten.convolutuion path(cpu only support 2d case now).
+ # TODO: after cpu 3d convolution support channels_last path, the size check can be removed.
+ # TODO: the gpu channels_last path depend on cudnn version, see
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvUtils.h.
+ if (
+ x.get_device().type == "cpu"
+ and len(x.get_size()) == 4
+ and (
+ x.get_layout().is_channels_last_stride_ordered()
+ or weight.get_layout().is_channels_last_stride_ordered()
+ )
+ ):
+ output_layout_str = "torch.channels_last"
+
+ if output_layout_str == "torch.channels_last":
+ stride_order = [0] + list(reversed(range(1, len(kernel_size) + 1)))
+ if len(stride_order) < len(output_size):
+ # add batch dim if it exists
+ stride_order = [len(stride_order)] + stride_order
+ else:
+ stride_order = list(reversed(range(len(output_size))))
+
+ output_layout = FlexibleLayout(
+ x.get_device(),
+ x.get_dtype(),
+ output_size,
+ stride_order,
+ )
+
+ if bias is not None:
+ return Convolution(
+ output_layout,
+ (x, weight, bias),
+ (stride, padding, dilation, transposed, output_padding, groups),
+ stride_order,
+ kernel,
+ )
+ else:
+ return Convolution(
+ output_layout,
+ (x, weight),
+ (bias, stride, padding, dilation, transposed, output_padding, groups),
+ stride_order,
+ kernel,
+ )
+
+ def apply_constraint(self):
+ x = self.inputs[0]
+ # FixedLayout of input
+ x = self.require_stride_order(x, self.layout.preferred_stride_order)
+ self.inputs[0] = x
+ self.freeze_layout_with_stride_order(self.layout.preferred_stride_order)
+
+ def map_args(self):
+ # x, w, bias
+ in_args = [x.codegen_reference() for x in self.inputs]
+ # stride, padding, dilation, transposed, output_padding, groups
+ const_args = self.constant_args
+ if len(in_args) < 3:
+ # otherwise, bias=None is the first constant_args
+ const_args = const_args[1:]
+
+ inout_dict = OrderedDict(
+ [
+ ("x", f"{in_args[0]}"),
+ ("w", f"{in_args[1]}"),
+ ("y", f"{self.get_name()}"),
+ ]
+ )
+ args_dict = OrderedDict(
+ [
+ ("stride_xn", f"{self.inputs[0].get_stride()[0]}"),
+ ("stride_xc", f"{self.inputs[0].get_stride()[1]}"),
+ ("stride_xh", f"{self.inputs[0].get_stride()[2]}"),
+ ("stride_xw", f"{self.inputs[0].get_stride()[3]}"),
+ ("stride_wn", f"{self.inputs[1].get_stride()[0]}"),
+ ("stride_wc", f"{self.inputs[1].get_stride()[1]}"),
+ ("stride_wh", f"{self.inputs[1].get_stride()[2]}"),
+ ("stride_ww", f"{self.inputs[1].get_stride()[3]}"),
+ ("stride_yn", f"{self.get_stride()[0]}"),
+ ("stride_yc", f"{self.get_stride()[1]}"),
+ ("stride_yh", f"{self.get_stride()[2]}"),
+ ("stride_yw", f"{self.get_stride()[3]}"),
+ (
+ "stride_biasn",
+ f"{self.inputs[0].get_stride()[0]}"
+ if len(in_args) >= 3
+ else "None",
+ ),
+ # ("delta_x_ptr", "None"),
+ ("BATCH", f"{self.inputs[0].get_size()[0]}"),
+ ("IN_C", f"{self.inputs[0].get_size()[1]}"),
+ ("IN_H", f"{self.inputs[0].get_size()[2]}"),
+ ("IN_W", f"{self.inputs[0].get_size()[3]}"),
+ ("KERNEL_N", f"{self.inputs[1].get_size()[0]}"),
+ ("KERNEL_H", f"{self.inputs[1].get_size()[2]}"),
+ ("KERNEL_W", f"{self.inputs[1].get_size()[3]}"),
+ ("OUT_H", f"{self.get_size()[2]}"),
+ ("OUT_W", f"{self.get_size()[3]}"),
+ ("stride_h", f"{const_args[0][0]}"),
+ ("stride_w", f"{const_args[0][1]}"),
+ ("padding_h", f"{const_args[1][0]}"),
+ ("padding_w", f"{const_args[1][1]}"),
+ ("dilation_h", f"{const_args[2][0]}"),
+ ("dilation_w", f"{const_args[2][1]}"),
+ # ("transposed", f"{const_args[3]}"),
+ ("output_padding_h", f"{const_args[4][0]}"),
+ ("output_padding_w", f"{const_args[4][1]}"),
+ ("groups", f"{const_args[5]}"),
+ ]
+ )
+
+ # accumulator type
+ ACC_TYPE = (
+ "tl.float32"
+ if self.inputs[0].get_dtype()
+ in [torch.float16, torch.bfloat16, torch.float32]
+ else "tl.int32"
+ )
+ CONV1X1_NHWC = (
+ "True"
+ if self.inputs[0].get_stride()[1] == 1
+ and self.inputs[1].get_size()[2] == 1
+ and self.inputs[1].get_size()[3] == 1
+ else "False"
+ )
+ # dict for tl.constexpr
+ const_dict = OrderedDict(
+ [
+ ("ACC_TYPE", ACC_TYPE),
+ ("CONV1X1_NHWC", CONV1X1_NHWC),
+ ]
+ )
+
+ # dict for non-kernel args (e.g. delta_x_ptr)
+ other_dict = OrderedDict(
+ [
+ ("device", f'"{self.inputs[0].get_device()}"'),
+ ]
+ )
+
+ return inout_dict, args_dict, const_dict, other_dict
+
+ def get_template_tiling(self):
+ n, c, h, w = self.get_size()
+ return (
+ n * h * w,
+ c,
+ sympy.Integer(1),
+ )
+
+
+@dataclasses.dataclass
+class MutableBox(IRNode):
+ """
+ TensorBox / StorageBox allow in-place mutation of Tensors
+ """
+
+ data: IRNode
+
+ def __getattr__(self, name):
+ fn = getattr(self.data, name)
+ if callable(fn):
+ return fn
+ raise AttributeError(f"{type(self.data).__name__}.{name} not callable")
+
+ def __str__(self):
+ if isinstance(self.data, MutableBox):
+ line0 = f"{type(self).__name__}({type(self.data).__name__}("
+ endl = "))"
+ inner = self.data.data
+ else:
+ line0 = f"{type(self).__name__}("
+ inner = self.data
+ endl = ")"
+
+ lines = [
+ line0,
+ indent(str(inner)),
+ endl,
+ ]
+ return "\n".join(lines)
+
+ __repr__ = __str__
+
+
+class TensorBox(MutableBox):
+ @staticmethod
+ def create(data):
+ return TensorBox(StorageBox(data))
+
+
+class StorageBox(MutableBox):
+ def is_input_buffer(self):
+ if isinstance(self.data, (InputBuffer, ReinterpretView)):
+ return self.data.get_name() in V.graph.graph_inputs
+ return False
+
+ def realize(self):
+ if isinstance(
+ self.data, (ComputedBuffer, InputsKernel, InputBuffer, ReinterpretView)
+ ):
+ return self.data.get_name()
+ assert isinstance(self.data, (Pointwise, Reduction)), type(self.data)
+ self.data = ComputedBuffer(
+ name=None,
+ layout=FlexibleLayout(
+ device=self.data.get_device(),
+ dtype=self.data.get_dtype(),
+ size=self.data.get_size(),
+ ),
+ data=self.data,
+ )
+ self.data.name = V.graph.register_buffer(self.data)
+ return self.data.name
+
+ def realize_hint(self):
+ """
+ Called on buffers we expect to be forced to realize later.
+ """
+ if isinstance(self.data, (Pointwise, Reduction)) and self.num_reads() > 1:
+ self.realize()
+
+ def mark_reuse(self, users):
+ """
+ A heuristic to decide if we should realize a tensor
+ that is used multiple times.
+ """
+
+ def should_realize_on_cpu(loops: Union[Pointwise, Reduction]):
+ """
+ The heuristic for realizing reused result of heavy ops on cpu
+ """
+ heavy_ops = ["exp"] # a list of heavy ops
+ fn_str = loops.inner_fn_str()
+ return any([fn_str.startswith(op + "(") for op in heavy_ops])
+
+ if (
+ users > 1
+ and isinstance(self.data, (Pointwise, Reduction))
+ and (
+ self.num_reads() > config.realize_reads_threshold
+ or len(self.inner_fn_str()) > config.realize_bytes_threshold
+ or (is_cpu(self.data) and should_realize_on_cpu(self.data))
+ )
+ ):
+ self.realize()
+
+ @cache_on_self
+ def num_reads(self):
+ data = self.data
+ if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)):
+ return 1
+ if isinstance(data, ComputedBuffer):
+ read_writes = data.get_read_writes()
+ else:
+ assert isinstance(data, (Pointwise, Reduction)), type(data)
+ read_writes = ComputedBuffer(
+ name=None,
+ layout=FlexibleLayout(
+ device=data.get_device(),
+ dtype=data.get_dtype(),
+ size=data.get_size(),
+ ),
+ data=data,
+ ).get_read_writes()
+ return len(read_writes.reads)
+
+
+class LoopBody:
+ """
+ Captures the body of a Loops subclass into an FX graph. Persists any
+ indexing simplifications and makes it easier to analyze loop bodies.
+ """
+
+ def __init__(self, fn, args, var_ranges):
+ super().__init__()
+ self.var_ranges = var_ranges
+ self.indexing_exprs = {}
+ self.indexing_exprs_name = {}
+ self.reads = []
+ self.writes = []
+ self.reads_name2expr = {}
+ self.writes_name2expr = {}
+ self.other = []
+ self.submodules = {"get_index": self.get_index}
+ self.subblocks = {}
+ self.indirect_vars = []
+ self.root_block = LoopBodyBlock(self, fn, args)
+ self.indexing = None
+
+ def debug_str(self):
+ lines = [f"var_ranges = {dict(self.var_ranges)}"]
+ lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])
+ lines.extend(
+ [
+ block.debug_str(name)
+ for name, block in itertools.chain(
+ [("body", self.root_block)], self.subblocks.items()
+ )
+ ]
+ )
+ return "\n".join(lines)
+
+ def add_index_expr(self, expr: sympy.Expr, category, buf_name):
+ getattr(self, category).append(expr)
+ if buf_name is not None:
+ getattr(self, f"{category}_name2expr")[buf_name] = expr
+ if expr not in self.indexing_exprs_name:
+ name = f"index{len(self.indexing_exprs)}"
+ self.indexing_exprs_name[expr] = name
+ self.indexing_exprs[name] = expr
+ return self.indexing_exprs_name[expr]
+
+ def add_submodule(self, block, prefix):
+ """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes"""
+ if prefix[-1].isnumeric() and prefix not in self.submodules:
+ name = prefix
+ else:
+ name = f"{prefix}{len(self.submodules)}"
+ self.submodules[name] = block
+ return name
+
+ def add_indirect(self):
+ name = f"indirect{len(self.indirect_vars)}"
+ var = sympy.Symbol(name)
+ self.indirect_vars.append([var])
+ return var
+
+ def replace_indirect(self, old, new):
+ """Swap in a variable used in indirect indexing"""
+ if str(old) == str(new):
+ return
+ self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}
+
+ def get_index(self, name):
+ return self.indexing[name]
+
+ def __call__(self, *indices):
+ index = list(itertools.chain(*indices))
+ assert len(index) == len(self.var_ranges), (index, self.var_ranges)
+ assert all(v not in self.var_ranges for v in index)
+ replacements = dict(zip(self.var_ranges.keys(), index))
+ self.indexing = {
+ name: sympy_subs(expr, replacements)
+ for name, expr in self.indexing_exprs.items()
+ }
+ result = self.root_block()
+ self.indexing = None
+ return result
+
+
+class LoopBodyBlock:
+ """
+ Captures the body of a Loops subclass into an FX graph.
+ In normal cases there will be a 1:1 mapping between LoopBody and
+ LoopBodyBlock, hower in the case of ops.masked() the masked out
+ operations will manifest as an extra LoopBodyBlock.
+ """
+
+ def __init__(self, body: LoopBody, fn: Callable, args: List[Any]):
+ self.body = body
+
+ def add_index(expr, category, buf_name=None):
+ return tracer.create_proxy(
+ "call_module",
+ "get_index",
+ (self.body.add_index_expr(expr, category, buf_name),),
+ {},
+ )
+
+ class CaptureIndexing(V.WrapperHandler):
+ def load(self, name: str, index: sympy.Expr):
+ index = add_index(index, "reads", name)
+ return self._inner.load(name, index)
+
+ def store(self, name, index, value, mode=None):
+ index = add_index(index, "writes", name)
+ return self._inner.store(name, index, value, mode)
+
+ def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
+ index = add_index(index, "writes", name)
+ return self._inner.reduction(
+ name, dtype, src_dtype, reduction_type, index, value
+ )
+
+ def index_expr(self, index, dtype):
+ if isinstance(index, (int, sympy.Integer)):
+ return ops.constant(int(index), dtype)
+ index = add_index(index, "other")
+ return self._inner.index_expr(index, dtype)
+
+ @staticmethod
+ def masked(mask_proxy, masked_body: Callable, other_proxy):
+ """
+ Recursively capture the masked out body in another LoopBodyBlock
+ """
+
+ def shim(mask, other):
+ return V.ops.masked(mask, subblock, other)
+
+ name = self.body.add_submodule(shim, "masked_subblock")
+ subblock = LoopBodyBlock(self.body, masked_body, [])
+ self.body.subblocks[name] = subblock
+ return tracer.create_proxy(
+ "call_module", name, (mask_proxy, other_proxy), {}
+ )
+
+ @staticmethod
+ def indirect_indexing(index_proxy):
+ """
+ Flow data from tensors into indexing formulas.
+ Introduce a call_module to update the indexing.
+ """
+
+ def set_indirect(new_var):
+ self.body.replace_indirect(var, V.ops.indirect_indexing(new_var))
+
+ var = self.body.add_indirect()
+ tracer.create_proxy(
+ "call_module",
+ self.body.add_submodule(set_indirect, f"set_{var}"),
+ (index_proxy,),
+ {},
+ )
+ return var
+
+ tracer = torch.fx.Tracer()
+ tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
+ proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
+ from .sizevars import SimplifyIndexing
+
+ with V.set_ops_handler(
+ SimplifyIndexing(CaptureIndexing(proxy_ops), self.body.var_ranges)
+ ):
+ tracer.create_proxy("output", "output", (fn(*args),), {})
+ self.graph = tracer.graph
+
+ def __call__(self):
+ graph = self.graph
+ submodules = self.body.submodules
+
+ class InterpreterShim(torch.fx.Interpreter):
+ def __init__(self):
+ """
+ We don't call super() here to avoid constructing a
+ GraphModule which is very expensive (it does codegen).
+ """
+ self.module = self
+ self.graph = graph
+ self.submodules = submodules
+ self.garbage_collect_values = False
+ self.env = {}
+ self.fetch_attr = submodules.__getitem__
+
+ return InterpreterShim().run(V.get_ops_handler())
+
+ def debug_str(self, name="block"):
+ code = torch.fx.GraphModule(self.body.submodules, self.graph).code
+ return re.sub(
+ # strip `; del var0` suffixes to make output prettier
+ r";[^\n]*",
+ "",
+ code.strip().replace("def forward(", f"def {name}("),
+ )
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
new file mode 100644
index 0000000000000..90657b7db1d83
--- /dev/null
+++ b/torch/_inductor/lowering.py
@@ -0,0 +1,3301 @@
+import functools
+import itertools
+import logging
+import operator
+from collections.abc import Iterable
+from typing import List, Optional, Tuple
+
+import sympy
+
+import torch
+import torch.fx
+from torch._prims_common import (
+ elementwise_dtypes,
+ ELEMENTWISE_TYPE_PROMOTION_KIND,
+ is_boolean_dtype,
+ is_integer_dtype,
+ Number,
+)
+
+from . import config, ir, overrides
+from .decomposition import decompositions, get_decompositions
+from .ir import (
+ ExpandView,
+ PermuteView,
+ Pointwise,
+ Reduction,
+ SqueezeView,
+ TensorBox,
+ View,
+)
+from .utils import ceildiv, has_torchvision_roi_align, sympy_product
+from .virtualized import ops, V
+
+log = logging.getLogger(__name__)
+lowerings = {}
+fallbacks = set()
+aten = torch.ops.aten
+prims = torch.ops.prims
+needs_realized_inputs = set()
+
+
+def add_needs_realized_inputs(fn):
+ if isinstance(fn, (list, tuple, set)):
+ return [add_needs_realized_inputs(x) for x in fn]
+ needs_realized_inputs.add(fn)
+ if isinstance(fn, torch._ops.OpOverloadPacket):
+ for overload in fn.overloads():
+ needs_realized_inputs.add(getattr(fn, overload))
+
+
+add_needs_realized_inputs(
+ [
+ aten.as_strided,
+ aten.avg_pool2d,
+ aten.avg_pool2d_backward,
+ aten.bmm,
+ aten.convolution,
+ aten.convolution_backward,
+ aten.max_pool2d_with_indices,
+ aten.max_pool2d_with_indices_backward,
+ aten.mm,
+ aten.upsample_bilinear2d,
+ aten.upsample_nearest2d,
+ aten.upsample_bicubic2d,
+ ]
+)
+
+# TODO(jansel): ezyang says we won't need this in the future, try removing it
+# based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28
+DTYPE_ID_LOOKUP = {
+ 0: torch.uint8,
+ 1: torch.int8,
+ 2: torch.int16,
+ 3: torch.int32,
+ 4: torch.int64,
+ 5: torch.float16,
+ 6: torch.float32,
+ 7: torch.float64,
+ 8: torch.complex32,
+ 9: torch.complex64,
+ 10: torch.complex32,
+ 11: torch.bool,
+ 15: torch.bfloat16,
+ # TODO(jansel): add quantized types?
+ # _(c10::qint8, QInt8) /* 12 */
+ # _(c10::quint8, QUInt8) /* 13 */
+ # _(c10::qint32, QInt32) /* 14 */
+ # _(c10::quint4x2, QUInt4x2) /* 16 */
+ # _(c10::quint2x4, QUInt2x4) /* 17 */
+}
+
+
+def decode_dtype(dtype: int):
+ if not isinstance(dtype, int):
+ return dtype
+ assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP"
+ dtype = DTYPE_ID_LOOKUP[dtype]
+ return dtype
+
+
+def is_integer_type(x):
+ if isinstance(x, TensorBox):
+ return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
+ else:
+ return isinstance(x, int)
+
+
+def is_boolean_type(x):
+ if isinstance(x, TensorBox):
+ return is_boolean_dtype(x.get_dtype())
+ else:
+ return isinstance(x, bool)
+
+
+def decode_device(device):
+ if device is None:
+ return torch.tensor(0.0).device # default device
+ if isinstance(device, str):
+ device = torch.device(device)
+ if device.type == "cuda" and device.index is None:
+ return torch.device("cuda", index=torch.cuda.current_device())
+ return device
+
+
+def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND):
+ def construct_input(inp):
+ if isinstance(inp, Number):
+ return inp
+ else:
+ assert hasattr(inp, "get_dtype")
+ dim = len(inp.get_size())
+ # construct a tmp tensor to feed into torch.result_type
+ return torch.zeros([1] * dim, dtype=inp.get_dtype())
+
+ inps = [construct_input(arg) for arg in args]
+ _, dtype = elementwise_dtypes(*inps, type_promotion_kind=type_promotion_kind)
+ return dtype
+
+
+def _register_lowering(
+ aten_fn, decomp_fn, broadcast, type_promotion_kind, convert_input_to_bool
+):
+ """
+ Add a lowering to lowerings dict
+
+ Arguments:
+ aten_fn: torch.ops.aten.* fn we are lowering
+ decomp_fn: alternate implementation on our IR
+ broadcast: True to apply broadcasting to tensor inputs
+ type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion
+ convert_input_to_bool: some logical ops require inputs are converted to bool
+ """
+
+ @functools.wraps(decomp_fn)
+ def wrapped(*args, **kwargs):
+ args = list(args)
+ # Only look at args that are Tensors
+ indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
+ # kwargs tensors not supported yet
+ assert not any(isinstance(x, TensorBox) for x in kwargs.values())
+
+ if (type_promotion_kind or convert_input_to_bool) and indices:
+ if convert_input_to_bool:
+ dtype = torch.bool
+ else:
+ # FIXME that's a crude approximation for promoting args
+ promoting_args = [
+ a for a in args if isinstance(a, Number) or hasattr(a, "get_dtype")
+ ]
+ dtype = get_promoted_dtype(
+ *promoting_args, type_promotion_kind=type_promotion_kind
+ )
+ for i in indices:
+ args[i] = to_dtype(args[i], dtype)
+ for i in range(len(args)):
+ if isinstance(args[i], ir.Constant):
+ args[i] = ir.Constant(
+ args[i].value, dtype, args[indices[0]].get_device()
+ )
+
+ if broadcast and indices:
+ for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])):
+ args[i] = x
+ for i in range(len(args)):
+ if isinstance(args[i], ir.Constant):
+ args[i] = ExpandView.create(
+ args[i], list(args[indices[0]].get_size())
+ )
+
+ return decomp_fn(*args, **kwargs)
+
+ if not isinstance(aten_fn, (list, tuple)):
+ aten_fn = [aten_fn]
+ else:
+ aten_fn = list(aten_fn)
+
+ for fn in list(aten_fn):
+ if isinstance(fn, torch._ops.OpOverloadPacket):
+ for overload in fn.overloads():
+ other_fn = getattr(fn, overload)
+ if other_fn not in lowerings:
+ aten_fn.append(other_fn)
+
+ lowerings.update({fn: wrapped for fn in aten_fn})
+ return wrapped
+
+
+def register_lowering(
+ aten_fn,
+ broadcast=False,
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+ convert_input_to_bool=False,
+):
+ """
+ Shim to support decorator syntax.
+ """
+ return functools.partial(
+ _register_lowering,
+ aten_fn,
+ broadcast=broadcast,
+ type_promotion_kind=type_promotion_kind,
+ convert_input_to_bool=convert_input_to_bool,
+ )
+
+
+def broadcast_symbolic_shapes(a, b):
+ """
+ Broadcasting logic based on symbolic shapes.
+
+ We give the shapes 0 and 1 concrete values, while all other shapes
+ are symbolic sympy formulas.
+ """
+ output = []
+ for a, b in itertools.zip_longest(
+ reversed(a), reversed(b), fillvalue=sympy.Integer(1)
+ ):
+ if b == 1:
+ output.append(a)
+ elif a == 1:
+ output.append(b)
+ else:
+ V.graph.sizevars.guard_equals(a, b)
+ if len(sympy.expand(b).free_symbols) < len(sympy.expand(a).free_symbols):
+ output.append(b) # prefer shorter formula
+ else:
+ output.append(a)
+ return tuple(reversed(output))
+
+
+def promote_constants(inputs, override_return_dtype=None):
+ if not any(isinstance(x, (int, float)) for x in inputs):
+ return inputs
+ if all(isinstance(x, (int, float)) for x in inputs):
+ dtype = override_return_dtype or get_promoted_dtype(
+ *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ )
+ return [ir.Constant(x, dtype, decode_device(None)) for x in inputs]
+ ex = next(x for x in inputs if isinstance(x, TensorBox))
+ return [
+ (
+ ExpandView.create(
+ ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size())
+ )
+ if isinstance(x, (int, float))
+ else x
+ )
+ for x in inputs
+ ]
+
+
+def make_pointwise(
+ fn,
+ override_return_dtype=None,
+ override_device=None,
+ override_fn_when_input_bool=None,
+ allow_alpha=False,
+):
+ def inner(*inputs: List[TensorBox], alpha=None):
+ inputs = promote_constants(inputs, override_return_dtype)
+ if allow_alpha:
+ if alpha is not None and alpha != 1:
+ inputs = list(inputs)
+ inputs[-1] = mul(inputs[-1], alpha)
+ else:
+ assert alpha is None
+ loaders = [x.make_loader() for x in inputs]
+ ranges = inputs[0].get_size()
+ dtype = override_return_dtype or inputs[0].get_dtype()
+
+ for other in inputs[1:]:
+ assert isinstance(other, ir.BaseConstant) or len(ranges) == len(
+ other.get_size()
+ ), f"ndim mismatch {fn} {ranges} {other.get_size()}"
+
+ def inner_fn(index):
+ assert len(index) == len(ranges), f"wrong ndim {index} {ranges}"
+ if dtype == torch.bool and override_fn_when_input_bool is not None:
+ return override_fn_when_input_bool(*[load(index) for load in loaders])
+ else:
+ return fn(*[load(index) for load in loaders])
+
+ return Pointwise.create(
+ device=override_device or inputs[0].get_device(),
+ dtype=dtype,
+ inner_fn=inner_fn,
+ ranges=ranges,
+ )
+
+ return inner
+
+
+@register_lowering(prims.convert_element_type, type_promotion_kind=None)
+def to_dtype(x: TensorBox, dtype: torch.dtype):
+ if x.get_dtype() == dtype:
+ return x
+
+ def _to_dtype(x):
+ return ops.to_dtype(x, dtype)
+
+ return make_pointwise(_to_dtype, override_return_dtype=dtype)(x)
+
+
+def to_device(x: TensorBox, device: torch.device):
+ device = decode_device(device)
+ if x.get_device() == device:
+ return x
+ return TensorBox.create(ir.DeviceCopy.create(x, device))
+
+
+@register_lowering(aten._to_copy)
+def _to_copy(
+ x,
+ *,
+ dtype=None,
+ layout=None,
+ device=None,
+ pin_memory=None,
+ non_blocking=False,
+ memory_format=None,
+):
+ assert not layout or layout == torch.strided, "TODO"
+ assert not pin_memory, "TODO"
+ assert not memory_format, "TODO"
+ if device:
+ device = decode_device(device)
+ if device is not None and device != x.get_device():
+ if dtype is not None and device.type == "cpu":
+ # CPU can do fewer type conversions
+ x = to_dtype(x, decode_dtype(dtype))
+ x = to_device(x, device)
+ if dtype is not None:
+ x = to_dtype(x, decode_dtype(dtype))
+ return x
+
+
+@register_lowering(aten.to)
+def to(
+ x,
+ device_or_dtype=None,
+ non_blocking=False,
+ copy=False,
+ memory_format=None,
+ device=None,
+ dtype=None,
+ layout=None,
+):
+ assert not memory_format, "TODO"
+ assert layout in (None, torch.strided)
+ if isinstance(device_or_dtype, torch.dtype):
+ return to_dtype(x, device_or_dtype)
+ elif isinstance(device_or_dtype, torch.device):
+ return to_device(x, device_or_dtype)
+ else:
+ assert device_or_dtype is None, device_or_dtype
+
+ if device is not None:
+ x = to_device(x, device)
+ if dtype is not None:
+ x = to_dtype(x, dtype)
+ return x
+
+
+def ops_wrapper(name):
+ assert isinstance(name, str)
+
+ def fn(*args, **kwargs):
+ return getattr(ops, name)(*args, **kwargs)
+
+ return fn
+
+
+def register_pointwise(
+ aten_fn,
+ name=None,
+ broadcast=True,
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+ convert_input_to_bool=False,
+ override_return_dtype=None,
+ override_fn_when_input_bool=None,
+ allow_alpha=False,
+):
+ """A pointwise function that maps ops.{name} to inputs"""
+ name = name or aten_fn.__name__
+ fn = ops_wrapper(name)
+ if override_fn_when_input_bool is not None:
+ override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool)
+
+ fn = make_pointwise(
+ fn,
+ override_return_dtype=override_return_dtype,
+ override_fn_when_input_bool=override_fn_when_input_bool,
+ allow_alpha=allow_alpha,
+ )
+ fn = register_lowering(
+ aten_fn,
+ broadcast=broadcast,
+ type_promotion_kind=type_promotion_kind,
+ convert_input_to_bool=convert_input_to_bool,
+ )(fn)
+
+ if hasattr(prims, name):
+ register_lowering(
+ getattr(prims, name),
+ type_promotion_kind=None,
+ convert_input_to_bool=convert_input_to_bool,
+ )(fn)
+ return fn
+
+
+@register_lowering(aten.where, broadcast=True, type_promotion_kind=None)
+def where(cond, a, b):
+ def fn(*args):
+ return ops.where(*args)
+
+ if isinstance(a, (float, int)):
+ a = constant_like(a)(b)
+ if isinstance(b, (float, int)):
+ b = constant_like(b)(a)
+
+ dtype = torch.promote_types(a.get_dtype(), b.get_dtype())
+ return make_pointwise(fn, override_return_dtype=dtype)(
+ cond, to_dtype(a, dtype), to_dtype(b, dtype)
+ )
+
+
+@register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None)
+def broadcast_tensors(*inputs):
+ if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)):
+ return broadcast_tensors(*inputs[0])
+ target = functools.reduce(
+ broadcast_symbolic_shapes, [x.get_size() for x in inputs], ()
+ )
+ outputs = []
+ for x in inputs:
+ sizes = x.get_size()
+ if len(sizes) != len(target) or any(
+ ((a == 1 and b != 1) or (a != 1 and b == 1)) for a, b in zip(sizes, target)
+ ):
+ x = expand(x, target)
+ outputs.append(x)
+ return outputs
+
+
+@register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of])
+def nop(x):
+ return x # AOT autograd handles this for us
+
+
+if hasattr(aten, "lift_fresh"):
+ register_lowering(aten.lift_fresh)(nop)
+
+
+@register_lowering(aten.squeeze, type_promotion_kind=None)
+def squeeze(x, dim=None):
+ assert isinstance(x, TensorBox)
+ if dim is None:
+ return TensorBox(SqueezeView.create(x.data))
+
+ dim = _validate_dim(x, dim, 0)
+ new_shape = list(x.get_size())
+ removed = new_shape.pop(dim)
+ if V.graph.sizevars.maybe_guard_equals(removed, 1):
+ return view(x, new_shape)
+
+ # squeeze does nothing if the size isn't 1
+ return x
+
+
+@register_lowering([aten.squeeze_])
+def squeeze_(x, dim=None):
+ val = squeeze(x, dim)
+ assert isinstance(x, TensorBox)
+ assert isinstance(val, TensorBox)
+ x.data = val.data
+ return x
+
+
+@register_lowering(aten.isinf)
+def isinf(x):
+ if is_integer_type(x):
+ return full_like(x, False, dtype=torch.bool)
+ fn = ops_wrapper("isinf")
+ return make_pointwise(fn, override_return_dtype=torch.bool)(x)
+
+
+@register_lowering(aten.isnan)
+def isnan(x):
+ if is_integer_type(x):
+ return full_like(x, False, dtype=torch.bool)
+ fn = ops_wrapper("isnan")
+ return make_pointwise(fn, override_return_dtype=torch.bool)(x)
+
+
+@register_lowering(aten.ceil)
+def ceil(x):
+ if is_integer_type(x):
+ return x
+ fn = ops_wrapper("ceil")
+ return make_pointwise(fn)(x)
+
+
+@register_lowering(aten.floor)
+def floor(x):
+ if is_integer_type(x):
+ return x
+ fn = ops_wrapper("floor")
+ return make_pointwise(fn)(x)
+
+
+@register_lowering(aten.round)
+def round(x):
+ if is_integer_type(x):
+ return x
+ fn = ops_wrapper("round")
+ return make_pointwise(fn)(x)
+
+
+@register_lowering(aten.trunc)
+def trunc(x):
+ if is_integer_type(x):
+ return x
+ fn = ops_wrapper("trunc")
+ return make_pointwise(fn)(x)
+
+
+@register_lowering(aten.expand, type_promotion_kind=None)
+def expand(x, sizes):
+ if isinstance(x, ir.BaseConstant):
+ return ExpandView.create(x, tuple(sizes))
+ assert isinstance(x, TensorBox)
+ assert isinstance(sizes, (list, tuple))
+ if tuple(x.get_size()) == tuple(sizes):
+ return x
+
+ x_size_product = sympy_product(x.get_size())
+ try:
+ if x_size_product > 0:
+ x.mark_reuse(
+ V.graph.sizevars.size_hint(sympy_product(sizes) / x_size_product)
+ )
+ except TypeError:
+ # Certain sympy products cannot be compared, fails with
+ # cannot determine truth value of Relational
+ pass
+ return TensorBox(ExpandView.create(x.data, tuple(sizes)))
+
+
+@register_lowering(prims.broadcast_in_dim, type_promotion_kind=None)
+def broadcast_in_dim(a, shape, broadcast_dimensions):
+ s = list(shape)
+ for broadcast_dimension in broadcast_dimensions:
+ s[broadcast_dimension] = -1
+
+ v = a
+ for idx, x in enumerate(s):
+ if x != -1:
+ v = unsqueeze(v, idx)
+
+ return expand(v, shape)
+
+
+@register_lowering(aten.expand_as, type_promotion_kind=None)
+def expand_as(x, y):
+ return expand(x, y.get_size())
+
+
+@register_lowering(aten.repeat)
+def repeat(x, repeats):
+ old_size = list(x.get_size())
+ if len(repeats) > len(old_size):
+ old_size = [sympy.Integer(1)] * (len(repeats) - len(old_size)) + old_size
+ x = view(x, list(old_size))
+ assert len(repeats) == len(x.get_size())
+
+ new_size = list(x.get_size())
+
+ for i in range(len(repeats)):
+ assert repeats[i] >= 1
+ if repeats[i] > 1:
+ new_size[i] = new_size[i] * repeats[i]
+
+ if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)):
+ return expand(x, new_size)
+
+ def inner_fn(index):
+ assert len(index) == len(repeats)
+ index = list(index)
+ for i in range(len(repeats)):
+ if repeats[i] > 1:
+ if old_size[i] == 1:
+ index[i] = sympy.Integer(0)
+ else:
+ index[i] = ir.ModularIndexing(index[i], 1, old_size[i])
+ return x_loader(index)
+
+ old_size_product = sympy_product(old_size)
+ try:
+ if old_size_product > 0:
+ x.mark_reuse(
+ V.graph.sizevars.size_hint(sympy_product(new_size) / old_size_product)
+ )
+ except TypeError:
+ # Certain sympy products cannot be compared, fails with
+ # cannot determine truth value of Relational
+ pass
+
+ x_loader = x.make_loader()
+ return Pointwise.create(
+ device=x.get_device(),
+ dtype=x.get_dtype(),
+ inner_fn=inner_fn,
+ ranges=list(new_size),
+ )
+
+
+@register_lowering(aten._unsafe_view, type_promotion_kind=None)
+@register_lowering(aten.view, type_promotion_kind=None)
+@register_lowering(aten.reshape, type_promotion_kind=None)
+def view(x, sizes):
+ assert isinstance(x, TensorBox)
+ assert isinstance(sizes, (list, tuple))
+ return TensorBox(View.create(x.data, sizes))
+
+
+@register_lowering(aten.permute, type_promotion_kind=None)
+def permute(x, dims):
+ assert isinstance(x, TensorBox)
+ assert isinstance(dims, (list, tuple))
+ return TensorBox(PermuteView.create(x.data, tuple(dims)))
+
+
+@register_lowering(aten.slice, type_promotion_kind=None)
+def slice_(x, dim=0, start=0, end=2**63, step=1):
+ assert isinstance(x, TensorBox)
+ dim = _validate_dim(x, dim, 0)
+ return TensorBox(ir.SliceView.create(x.data, dim, start, end, step))
+
+
+@register_lowering(aten.roll, type_promotion_kind=None)
+def roll(a, shifts, dims=tuple()):
+ """
+ This is based on torch._refs.roll(), but uses ir.ModularIndexing().
+
+ We can't use the ref here because it is based on multiple calls to
+ torch.cat() that this will result in terrible code.
+ """
+ # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
+ if not isinstance(shifts, Iterable):
+ shifts = (shifts,)
+ if not isinstance(dims, Iterable):
+ dims = (dims,)
+ dims = [_validate_dim(a, d) for d in dims]
+
+ if sympy_product(a.get_size()) == 0:
+ return clone(a)
+
+ len_shifts = len(shifts)
+ len_dims = len(dims)
+ if len_shifts != 1 or len_dims != 1:
+ if len_shifts == 0:
+ raise RuntimeError("`shifts` required")
+ # Takes care of the case when dims is not specified (default)
+ # By default, the tensor is flattened before shifting, after which the original shape is restored
+ if len_dims == 0 and len_shifts == 1:
+ flat = view(a, [sympy_product(a.get_size())])
+ rolled = roll(flat, shifts, 0)
+ return view(rolled, list(a.get_size()))
+ if len_shifts != len_dims:
+ raise RuntimeError(
+ f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
+ )
+ tail_shifts = shifts[1:]
+ tail_dims = dims[1:]
+ first_dim_rolled = roll(a, shifts[0], dims[0])
+ return roll(first_dim_rolled, tail_shifts, tail_dims)
+
+ (dim,) = dims
+ size = V.graph.sizevars.guard_static_shape(a.get_size()[dim])
+ start = (size - shifts[0]) % size
+ a_loader = a.make_loader()
+
+ def fn(index):
+ index = list(index)
+ index[dim] = ir.ModularIndexing(
+ index[dim] + start, sympy.Integer(1), sympy.expand(size)
+ )
+ return a_loader(index)
+
+ return Pointwise.create(
+ device=a.get_device(),
+ dtype=a.get_dtype(),
+ inner_fn=fn,
+ ranges=a.get_size(),
+ )
+
+
+@register_lowering(aten.as_strided, type_promotion_kind=None)
+def as_strided(x, size, stride, storage_offset=None):
+ if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView):
+ # as_strided ignores views
+ x = x.data.unwrap_view()
+ x.realize()
+ if not ir.is_contiguous_storage_and_layout(x):
+ raise NotImplementedError(f"unrealized as_strided({x}, ...)")
+ storage, old_layout = ir.as_contiguous_storage_and_layout(x)
+ new_layout = ir.FixedLayout(
+ old_layout.device,
+ old_layout.dtype,
+ [sympy.expand(s) for s in size],
+ [sympy.expand(s) for s in stride],
+ sympy.expand(storage_offset or 0),
+ )
+ return TensorBox(ir.ReinterpretView(storage, new_layout))
+
+
+@register_lowering(aten.as_strided_)
+def as_strided_(x, size, stride, storage_offset=None):
+ assert isinstance(x, TensorBox)
+ x.data = as_strided(x, size, stride, storage_offset).data
+ return x
+
+
+@register_lowering(aten.cat)
+def cat(inputs, dim=0):
+ if len(inputs) == 1:
+ return inputs[0]
+ dim = _validate_dim(inputs[0], dim, 0)
+ return TensorBox(ir.ConcatKernel.create(inputs, dim))
+
+
+@register_lowering(aten.select, type_promotion_kind=None)
+def select(x, dim, idx):
+ idx = View.handle_negative_index(idx, x.get_size()[dim])
+ return squeeze(slice_(x, dim, idx, idx + 1), dim)
+
+
+@register_lowering(aten.split, type_promotion_kind=None)
+def split(x, sizes, dim=0):
+ dim = _validate_dim(x, dim, 0)
+ x_size = V.graph.sizevars.guard_static_shape(x.get_size()[dim])
+ if isinstance(sizes, int):
+ sizes = [sizes] * ((x_size + sizes - 1) // sizes)
+ result = []
+ start = 0
+ for size in sizes:
+ end = start + size
+ result.append(slice_(x, dim, start, end))
+ start = end
+ return result
+
+
+@register_lowering(aten.split_with_sizes, type_promotion_kind=None)
+def split_with_sizes(x, sizes, dim=0):
+ return split(x, sizes, dim)
+
+
+@register_lowering(aten.unbind, type_promotion_kind=None)
+def unbind(x, dim=0):
+ dim = _validate_dim(x, dim, 0)
+ x_size = V.graph.sizevars.guard_static_shape(x.get_size()[dim])
+ result = []
+ for i in range(x_size):
+ result.append(select(x, dim, i))
+ return result
+
+
+@register_lowering(aten.unsqueeze, type_promotion_kind=None)
+def unsqueeze(x, dim):
+ dim = _validate_dim(x, dim, 1)
+ new_shape = list(x.get_size())
+ new_shape.insert(dim, sympy.Integer(1))
+ return view(x, new_shape)
+
+
+@register_lowering(aten.unsqueeze_, type_promotion_kind=None)
+def unsqueeze_(x, dim):
+ val = unsqueeze(x, dim)
+ assert isinstance(x, TensorBox)
+ assert isinstance(val, TensorBox)
+ x.data = val.data
+ return x
+
+
+def _validate_dim(x, dim, offset=0):
+ assert isinstance(dim, int)
+ ndim = len(x.get_size())
+ if dim < 0:
+ dim += ndim + offset
+ assert 0 <= dim < ndim + offset
+ return dim
+
+
+@register_lowering(aten.glu)
+def glu(x, dim=-1):
+ dim = _validate_dim(x, dim, 0)
+ new_len = V.graph.sizevars.guard_static_shape(x.get_size()[dim]) // 2
+ a = slice_(x, dim, 0, new_len)
+ b = slice_(x, dim, new_len, new_len * 2)
+ return mul(a, sigmoid(b))
+
+
+@register_lowering(aten.mm)
+def mm(a: TensorBox, b: TensorBox):
+ return TensorBox.create(ir.MatrixMultiply.create(a, b))
+
+
+@register_lowering(aten.addmm)
+def addmm(inp: TensorBox, a: TensorBox, b: TensorBox, beta=1, alpha=1):
+ return TensorBox.create(ir.MatrixMultiplyAdd.create(inp, a, b, beta, alpha))
+
+
+@register_lowering(aten.bmm)
+def bmm(a: TensorBox, b: TensorBox):
+ return TensorBox.create(ir.BatchMatrixMultiply.create(a, b))
+
+
+def fallback_handler(kernel):
+ fallbacks.add(kernel)
+
+ def handler(*args, **kwargs):
+ result = ir.FallbackKernel.create(kernel, *args, **kwargs)
+ if isinstance(result, (list, tuple)):
+ return list(map(TensorBox.create, result))
+ else:
+ return TensorBox.create(result)
+
+ return handler
+
+
+def make_fallback(kernel):
+ assert (
+ kernel not in decompositions
+ ), f"both a fallback and a decomp for same kernel: {kernel}"
+ if get_decompositions([kernel]) and kernel is not aten.cumsum:
+ log.warning(
+ f"make_fallback({kernel}): a decomposition exists, we should switch to it"
+ )
+
+ add_needs_realized_inputs(kernel)
+ return register_lowering(kernel, type_promotion_kind=None)(fallback_handler(kernel))
+
+
+@register_lowering(aten.native_dropout, type_promotion_kind=None)
+def native_dropout(x, p, train):
+ assert (
+ config.fallback_random
+ ), "this should be handled in decomps unless config.fallback_random"
+ if train:
+ return list(
+ map(
+ TensorBox.create,
+ ir.FallbackKernel.create(aten.native_dropout, x, p, train),
+ )
+ )
+ return x, ones_like(x, dtype=torch.bool)
+
+
+@register_lowering(aten.bernoulli_, type_promotion_kind=None)
+def bernoulli_(x, *args):
+ assert (
+ config.fallback_random
+ ), "this should be handled in decomps unless config.fallback_random"
+ x.realize()
+ V.graph.realize_users_of(x.get_name())
+ ir.InplaceBernoulliFallback(x, *args)
+ return x
+
+
+# This shouldn't be called in general
+@register_lowering(aten._foobar)
+def _foobar(_):
+ raise AssertionError()
+
+
+@functools.lru_cache(1)
+def _warn_triton_random(salt):
+ log.warning("using triton random, expect difference from eager")
+
+
+def warn_triton_random():
+ # only warn once per graph
+ _warn_triton_random(V.graph.creation_time)
+
+
+def make_rand(fn_name):
+ def rand_or_randn(
+ *size,
+ dtype=None,
+ layout=0,
+ device=None,
+ pin_memory=False,
+ memory_format=None,
+ ):
+ warn_triton_random()
+ assert not pin_memory
+ assert layout in (0, torch.strided)
+ assert memory_format in (None, torch.contiguous_format)
+ device = decode_device(device)
+ dtype = dtype or torch.get_default_dtype()
+ if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
+ size = tuple(size[0])
+ size = [sympy.expand(s) for s in size]
+ offset = V.graph.increment_randomness_offset(sympy_product(size))
+
+ random_pos = ir.FixedLayout(
+ device,
+ dtype,
+ size,
+ ir.FlexibleLayout.contiguous_strides(size),
+ offset=offset,
+ ).make_indexer()
+
+ seed_buffer = V.graph.random_seed_buffer(device).make_loader()
+
+ def inner_fn(index):
+ seed = seed_buffer([])
+ # change seed so that we don't collide with philox_rand_like()
+ # TODO(jansel): migrate everything to philox_rand_like()
+ seed = ops.bitwise_xor(seed, ops.constant(0xFFFF, torch.int32))
+ return getattr(ops, fn_name)(
+ seed,
+ ops.index_expr(random_pos(index), torch.int32),
+ dtype,
+ )
+
+ return Pointwise.create(
+ device=device,
+ dtype=dtype,
+ inner_fn=inner_fn,
+ ranges=list(size),
+ )
+
+ return rand_or_randn
+
+
+fallback_rand = fallback_handler(aten.rand)
+fallback_randn = fallback_handler(aten.randn)
+fast_rand = make_rand("rand")
+fast_randn = make_rand("randn")
+
+
+@register_lowering([aten.rand, torch.rand])
+def rand(*args, **kwargs):
+ if config.fallback_random:
+ return fallback_rand(*args, **kwargs)
+ else:
+ return fast_rand(*args, **kwargs)
+
+
+@register_lowering([aten.randn, torch.randn])
+def randn(*args, **kwargs):
+ if config.fallback_random:
+ return fallback_randn(*args, **kwargs)
+ else:
+ return fast_randn(*args, **kwargs)
+
+
+@register_lowering(overrides.philox_seed_like._overloadpacket)
+def philox_seed_like(x):
+ warn_triton_random()
+ return V.graph.random_seed_buffer(x.get_device())
+
+
+@register_lowering(overrides.philox_rand_like._overloadpacket, type_promotion_kind=None)
+def philox_rand_like(x, seed, offset):
+ device = x.get_device()
+ dtype = x.get_dtype()
+ size = x.get_size()
+ random_pos = ir.FixedLayout(
+ device,
+ dtype,
+ size,
+ ir.FlexibleLayout.contiguous_strides(size),
+ offset=sympy.expand(offset),
+ ).make_indexer()
+ seed_loader = seed.make_loader()
+
+ def inner_fn(index):
+ return ops.rand(
+ seed_loader([]),
+ ops.index_expr(random_pos(index), torch.int32),
+ dtype,
+ )
+
+ return Pointwise.create(
+ device=device,
+ dtype=dtype,
+ inner_fn=inner_fn,
+ ranges=list(size),
+ )
+
+
+if has_torchvision_roi_align():
+ make_fallback(torch.ops.torchvision.roi_align)
+
+# TODO(jansel): we should implement decomps or lowerings for these
+# https://github.com/pytorch/torchdynamo/issues/327
+make_fallback(aten._adaptive_avg_pool2d_backward)
+make_fallback(aten.as_strided_scatter)
+make_fallback(aten.convolution_backward)
+make_fallback(aten._cudnn_rnn)
+make_fallback(aten._cudnn_rnn_backward)
+make_fallback(aten.cumsum)
+make_fallback(aten._embedding_bag)
+make_fallback(aten._embedding_bag_forward_only)
+make_fallback(aten._fused_moving_avg_obs_fq_helper)
+make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
+make_fallback(aten.grid_sampler_2d_backward)
+make_fallback(aten.randperm)
+make_fallback(aten.sort)
+make_fallback(aten.sort.stable)
+make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
+make_fallback(aten._thnn_fused_lstm_cell)
+make_fallback(aten.topk)
+make_fallback(aten.unfold)
+make_fallback(aten.unfold_backward)
+make_fallback(aten.upsample_bicubic2d_backward)
+make_fallback(aten.upsample_bilinear2d_backward)
+
+
+@register_lowering(aten.convolution)
+def convolution(
+ x: TensorBox,
+ weight: TensorBox,
+ bias: TensorBox,
+ stride: List[int],
+ padding: List[int],
+ dilation: List[int],
+ transposed: bool,
+ output_padding: List[int],
+ groups: int,
+):
+ result = TensorBox.create(
+ ir.Convolution.create(
+ x,
+ weight,
+ None, # bias handled below
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ )
+ )
+ if bias is not None:
+ kernel_dims = len(weight.get_size()) - 2
+ out_chan = result.get_size()[-1 - kernel_dims]
+ bias = view(bias, [out_chan] + kernel_dims * [1])
+ result = add(result, bias)
+ return result
+
+
+@register_lowering(aten._convolution)
+def _convolution(
+ x,
+ weight,
+ bias,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ benchmark,
+ deterministic,
+ cudnn_enabled,
+ allow_tf32,
+):
+ return convolution(
+ x, weight, bias, stride, padding, dilation, transposed, output_padding, groups
+ )
+
+
+@register_lowering(aten.clone)
+def clone(x, *, memory_format=0):
+ # TODO(jansel): memory format
+ return Pointwise.create(
+ device=x.get_device(),
+ dtype=x.get_dtype(),
+ inner_fn=x.make_loader(),
+ ranges=list(x.get_size()),
+ )
+
+
+if hasattr(aten, "lift_fresh_copy"):
+ register_lowering(aten.lift_fresh_copy)(clone)
+
+
+fallback_arange = fallback_handler(aten.arange)
+
+
+@register_lowering([torch.arange, aten.arange])
+def arange(
+ start,
+ end=None,
+ step=1,
+ *,
+ dtype=None,
+ device=None,
+ layout=torch.strided,
+ pin_memory=False,
+):
+ assert layout == torch.strided
+ assert not pin_memory
+ if end is None:
+ end = start
+ start = 0
+
+ if isinstance(start, float) and int(start) == start:
+ start = int(start)
+ if isinstance(end, float) and int(end) == end:
+ end = int(end)
+ if isinstance(step, float) and int(step) == step:
+ step = int(step)
+
+ # Triton kernel doesn't support float arange yet, fallback to aten.arange
+ if not (isinstance(start, int) and isinstance(end, int) and isinstance(step, int)):
+ return fallback_arange(
+ start,
+ end,
+ step,
+ dtype=dtype,
+ device=device,
+ layout=layout,
+ pin_memory=pin_memory,
+ )
+
+ dtype = dtype or torch.int64
+ length = ceildiv((end - start), step)
+ start = sympy.Integer(start)
+ step = sympy.Integer(step)
+
+ return Pointwise.create(
+ device=decode_device(device),
+ dtype=dtype,
+ inner_fn=lambda index: ops.index_expr(step * index[0] + start, dtype),
+ ranges=[sympy.Integer(length)],
+ )
+
+
+@register_lowering([torch.linspace, aten.linspace])
+def linspace(start, end, steps, *, dtype=None, device=None, pin_memory=False):
+ assert not pin_memory
+ dtype = dtype or torch.get_default_dtype()
+
+ step_size = (end - start) / (steps - 1)
+
+ def inner_fn(index):
+ return ops.add(
+ ops.mul(ops.constant(step_size, dtype), ops.index_expr(index[0], dtype)),
+ ops.constant(start, dtype),
+ )
+
+ return Pointwise.create(
+ device=decode_device(device),
+ dtype=dtype,
+ inner_fn=inner_fn,
+ ranges=[sympy.Integer(steps)],
+ )
+
+
+@register_lowering(aten.triu)
+def triu(x, diagonal=0):
+ x_loader = x.make_loader()
+ dtype = x.get_dtype()
+
+ def inner_fn(index):
+ *_, i, j = index
+ return ops.where(
+ ops.ge(
+ ops.index_expr(j - i - diagonal, torch.int32),
+ ops.constant(0, torch.int32),
+ ),
+ x_loader(index),
+ ops.constant(0, dtype),
+ )
+
+ return Pointwise.create(
+ device=x.get_device(),
+ dtype=dtype,
+ inner_fn=inner_fn,
+ ranges=list(x.get_size()),
+ )
+
+
+@register_lowering(aten.select_scatter, type_promotion_kind=None)
+def select_scatter(x, src, dim: int, index: int):
+ assert x.get_dtype() == src.get_dtype()
+ x_loader = x.make_loader()
+ dim = _validate_dim(x, dim, 0)
+ if index < 0:
+ index = index + x.get_size()[dim]
+ V.graph.sizevars.guard_leq(0, index)
+ V.graph.sizevars.guard_lt(index, x.get_size()[dim])
+ src = expand(unsqueeze(src, dim), x.get_size())
+ src_loader = src.make_loader()
+
+ def inner_fn(idx):
+ return ops.where(
+ ops.eq(
+ ops.index_expr(idx[dim], torch.int32),
+ ops.index_expr(index, torch.int32),
+ ),
+ src_loader(idx),
+ x_loader(idx),
+ )
+
+ return Pointwise.create(
+ device=x.get_device(),
+ dtype=x.get_dtype(),
+ inner_fn=inner_fn,
+ ranges=list(x.get_size()),
+ )
+
+
+@register_lowering(aten.slice_scatter, type_promotion_kind=None)
+def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
+ assert x.get_dtype() == src.get_dtype()
+ x_loader = x.make_loader()
+ dim = _validate_dim(x, dim, 0)
+ dim_size = x.get_size()[dim]
+ if start is not None and start < 0:
+ start = start + dim_size
+ if end is not None and end < 0:
+ end = end + dim_size
+ if start is None:
+ start = 0
+ if end is None or V.graph.sizevars.maybe_guard_leq(x.get_size()[dim], end):
+ end = dim_size
+
+ src_size = list(x.get_size())
+ src_size[dim] = ir.IndexingDiv(sympy.expand(end - start), sympy.expand(step))
+ src = expand(src, src_size)
+ src_loader = src.make_loader()
+
+ def inner_fn(idx):
+ if start == 0 and end == dim_size and step == 1:
+ # selecting every element is the same as just src.clone()
+ return src_loader(idx)
+
+ idx_dim = ops.index_expr(idx[dim], torch.int32)
+ src_idx = list(idx)
+ src_idx[dim] = ir.IndexingDiv(idx[dim] - start, step)
+
+ mask = []
+ if start != 0:
+ mask.append(
+ ops.ge(
+ idx_dim,
+ ops.index_expr(sympy.expand(start), torch.int32),
+ )
+ )
+ if end != dim_size:
+ mask.append(
+ ops.lt(
+ idx_dim,
+ ops.index_expr(sympy.expand(end), torch.int32),
+ )
+ )
+ if step != 1:
+ mask.append(
+ ops.eq(
+ ops.index_expr(
+ ir.ModularIndexing(idx[dim] - start, 1, step), torch.int32
+ ),
+ ops.constant(0, torch.int32),
+ )
+ )
+ assert mask
+ mask = functools.reduce(ops.and_, mask)
+ src_val = ops.masked(
+ mask,
+ lambda: src_loader(src_idx),
+ 0 if is_integer_type(x) else 0.0,
+ )
+ return ops.where(
+ mask,
+ src_val,
+ x_loader(idx),
+ )
+
+ return Pointwise.create(
+ device=x.get_device(),
+ dtype=x.get_dtype(),
+ inner_fn=inner_fn,
+ ranges=list(x.get_size()),
+ )
+
+
+def _unwrap(x):
+ if isinstance(x, (list, tuple)) and len(x) > 0:
+ return _unwrap(x[0])
+ return x
+
+
+@register_lowering([torch.tensor, aten.scalar_tensor])
+def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False):
+ assert layout in (None, torch.strided)
+ assert pin_memory is False
+ if isinstance(_unwrap(data), int):
+ dtype = dtype or torch.int64
+ else:
+ dtype = dtype or torch.get_default_dtype()
+
+ if isinstance(data, (float, int)):
+ ranges = []
+
+ def inner_fn(index):
+ return ops.constant(data, dtype)
+
+ elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8:
+ # inline small tensors
+ ranges = [sympy.Integer(len(data))]
+
+ def inner_fn(index):
+ def binary_search(start, end):
+ assert start < end
+ if end - start == 1:
+ return ops.constant(data[start], dtype)
+ mid = (end - start) // 2 + start
+ return ops.where(
+ ops.lt(
+ ops.index_expr(index[0], torch.int64),
+ ops.constant(mid, torch.int64),
+ ),
+ binary_search(start, mid),
+ binary_search(mid, end),
+ )
+
+ if len(data) == 0:
+ return ops.constant(0, dtype)
+ return binary_search(0, len(data))
+
+ else:
+ return V.graph.add_tensor_constant(
+ torch.tensor(data, dtype=dtype, device=device)
+ )
+
+ return Pointwise.create(
+ device=decode_device(device),
+ dtype=dtype,
+ inner_fn=inner_fn,
+ ranges=ranges,
+ )
+
+
+@register_lowering(torch.as_tensor)
+def as_tensor(data, dtype=None, device=None):
+ if isinstance(data, TensorBox):
+ if dtype is not None:
+ data = to(data, dtype)
+ if device is not None:
+ data = to(data, device)
+ return data
+ return tensor(data, dtype=dtype, device=device)
+
+
+@register_lowering(torch.LongTensor)
+def long_tensor(data):
+ return tensor(data, dtype=torch.int64)
+
+
+@register_lowering(aten._local_scalar_dense)
+def _local_scalar_dense(data):
+ return ir.DynamicScalar()
+
+
+def _full(fill_value, device, dtype, size):
+ value = fill_value
+ if not isinstance(fill_value, (int, float)) and hasattr(value, "value"):
+ value = value.value
+ if isinstance(value, (int, float)):
+
+ def inner_fn(index):
+ return ops.constant(value, dtype)
+
+ else:
+ assert len(value.get_size()) == 0
+ value_loader = value.make_loader()
+
+ def inner_fn(index):
+ return value_loader([])
+
+ return Pointwise.create(
+ device=device,
+ dtype=dtype,
+ inner_fn=inner_fn,
+ ranges=list(size),
+ )
+
+
+@register_lowering(aten.full_like, type_promotion_kind=None)
+def full_like(x, fill_value, **kwargs):
+ return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs)
+
+
+def tensor_constructor(fill_value):
+ # torch.zeros, torch.ones, etc
+ def inner(
+ *size,
+ names=None,
+ dtype=None,
+ device=None,
+ layout=0,
+ pin_memory=False,
+ memory_format=None,
+ ):
+ assert names is None
+ assert not pin_memory
+ assert layout in (0, torch.strided)
+ assert memory_format in (None, torch.contiguous_format)
+ device = decode_device(device)
+ dtype = dtype or torch.get_default_dtype()
+ if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
+ size = tuple(size[0])
+ size = [sympy.expand(s) for s in size]
+ return _full(fill_value, device, dtype, size)
+
+ return inner
+
+
+empty = register_lowering([torch.empty, aten.empty])(tensor_constructor(0))
+zeros = register_lowering([torch.zeros, aten.zeros])(tensor_constructor(0))
+ones = register_lowering([torch.ones, aten.ones])(tensor_constructor(1))
+
+
+def create_tensor_like(creation_fn):
+ """
+ Shim to convert X_like(...) into X(...). For example zeros_like() into zeros().
+ """
+
+ def _constant_like(
+ x, *, dtype=None, device=None, layout=0, pin_memory=False, memory_format=None
+ ):
+ assert not pin_memory
+ assert layout in (0, torch.strided)
+ if dtype is None:
+ dtype = x.get_dtype()
+ else:
+ dtype = decode_dtype(dtype)
+ device = device or x.get_device()
+ size = list(x.get_size())
+ return creation_fn(
+ size, dtype=dtype, device=device, layout=layout, pin_memory=pin_memory
+ )
+
+ return _constant_like
+
+
+def constant_like(fill_value):
+ return create_tensor_like(tensor_constructor(fill_value))
+
+
+empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty))
+zeros_like = register_lowering(aten.zeros_like)(create_tensor_like(zeros))
+ones_like = register_lowering(aten.ones_like)(create_tensor_like(ones))
+if not config.fallback_random:
+ rand_like = register_lowering(aten.rand_like)(create_tensor_like(rand))
+
+register_lowering(aten.zero)(zeros_like)
+
+
+def new_constant(fill_value):
+ def _new_constant(
+ x, size, *, dtype=None, layout=None, device=None, pin_memory=None
+ ):
+ assert isinstance(size, (list, type))
+ assert not pin_memory
+ assert not layout or layout == torch.strided
+ dtype = decode_dtype(dtype) or x.get_dtype()
+ device = device or x.get_device()
+ size = [sympy.Integer(s) for s in size]
+ return _full(fill_value, device, dtype, size)
+
+ return _new_constant
+
+
+register_lowering(aten.new_empty)(new_constant(0))
+register_lowering(aten.new_zeros)(new_constant(0))
+register_lowering(aten.new_ones)(new_constant(1))
+
+
+@register_lowering(aten.empty_strided)
+def empty_strided(
+ size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
+):
+ assert isinstance(size, (list, type))
+ assert isinstance(stride, (list, type))
+ assert not pin_memory
+ assert not layout or layout == torch.strided
+ dtype = decode_dtype(dtype) or torch.get_default_dtype()
+ device = device or torch.tensor(0.0).device
+ pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size)
+ if tuple(ir.FlexibleLayout.contiguous_strides(size)) == tuple(stride):
+ # fast path, no need to realize it
+ return pointwise
+ pointwise.realize()
+ buffer = pointwise.data.data
+ assert isinstance(buffer, ir.ComputedBuffer)
+ buffer.layout = ir.FixedLayout(
+ device=device,
+ dtype=dtype,
+ size=[sympy.expand(s) for s in size],
+ stride=[sympy.expand(s) for s in stride],
+ )
+ return pointwise
+
+
+@register_lowering(aten.new_empty_strided)
+def new_empty_strided(
+ x, size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
+):
+ if dtype is None:
+ dtype = x.get_dtype()
+ if device is None:
+ device = x.get_device()
+ return empty_strided(
+ size, stride, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
+ )
+
+
+@register_lowering([torch.full, aten.full])
+def full(size, fill_value, **kwargs):
+ return tensor_constructor(fill_value)(size, **kwargs)
+
+
+@register_lowering(aten.gather, type_promotion_kind=None)
+def gather(x, dim, index):
+ assert isinstance(x, TensorBox)
+ assert index.get_dtype() == torch.int64
+ offset = len(x.get_size()) == 0
+ dim = _validate_dim(x, dim, offset)
+
+ x_loader = x.make_loader()
+ index_loader = index.make_loader()
+
+ def fn(idx):
+ idx = list(idx)
+ if len(idx) != 0:
+ idx[dim] = ops.indirect_indexing(index_loader(idx))
+ return x_loader(idx)
+
+ return Pointwise.create(
+ device=x.get_device(),
+ dtype=x.get_dtype(),
+ inner_fn=fn,
+ ranges=index.get_size(),
+ )
+
+
+@register_lowering(aten.embedding, type_promotion_kind=None)
+def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
+ assert not sparse
+ assert isinstance(weight, TensorBox)
+ assert isinstance(indices, TensorBox)
+ assert "int" in str(indices.get_dtype())
+
+ weight_loader = weight.make_loader()
+ indices_loader = indices.make_loader()
+ indices_ndim = len(indices.get_size())
+ new_size = [*indices.get_size(), *weight.get_size()[1:]]
+
+ def fn(idx):
+ assert len(idx) == len(new_size), f"{idx} != {new_size}"
+ var_index = indices_loader(idx[:indices_ndim])
+ weight_idx = [ops.indirect_indexing(var_index)] + [*idx[indices_ndim:]]
+ return weight_loader(weight_idx)
+
+ return Pointwise.create(
+ device=weight.get_device(),
+ dtype=weight.get_dtype(),
+ inner_fn=fn,
+ ranges=new_size,
+ )
+
+
+def check_and_broadcast_indices(indices):
+ assert all(
+ i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8)
+ for i in indices
+ if i is not None
+ ), f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}"
+ assert all(
+ [i.get_dtype() in (torch.int32, torch.int64) for i in indices if i is not None]
+ ), "bool indices are not supported yet"
+ valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)]
+ assert len(valid_idxs) > 0, "requires at least 1 non-None index"
+ new_indices = [None] * len(indices)
+ for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])):
+ new_indices[i] = x
+ output_dim = len(x.get_size())
+ start_offset = 0
+ # only support None at start or end for now
+ tmp = list(new_indices)
+ while tmp and tmp[-1] is None:
+ tmp.pop()
+ while tmp and tmp[0] is None:
+ tmp.pop(0)
+ start_offset += 1
+ assert all((i is not None) for i in tmp)
+ end_offset = output_dim + start_offset
+
+ return new_indices, start_offset, end_offset
+
+
+@register_lowering(aten.index, type_promotion_kind=None)
+def index(x, indices):
+ assert isinstance(indices, (list, tuple))
+ x_loader = x.make_loader()
+ indices, start_offset, end_offset = check_and_broadcast_indices(indices)
+ indices_sizes = [i.get_size() for i in indices if i is not None]
+ indices_loaders = [i.make_loader() for i in indices if i is not None]
+ # no guards on output size, all the guards are set in broadcast_tensors
+ output_size = list(indices_sizes[0])
+
+ x_size = x.get_size()
+ output_size = [
+ *x_size[:start_offset],
+ *output_size,
+ *x_size[start_offset + len(indices_loaders) :],
+ ]
+
+ def fn(idx):
+ assert len(idx) == len(output_size)
+ new_index = [
+ ops.indirect_indexing(loader(idx[start_offset:end_offset]))
+ for loader in indices_loaders
+ ]
+ new_index = [*idx[:start_offset], *new_index, *idx[end_offset:]]
+ return x_loader(new_index)
+
+ return Pointwise.create(
+ device=x.get_device(),
+ dtype=x.get_dtype(),
+ inner_fn=fn,
+ ranges=output_size,
+ )
+
+
+# This is moved from decomposition to lowering because this decomp introduced
+# mutation in the graph, which is bad for Aot Autograd. Aot Autograd runs dead
+# code elimination and common subexpression elimination optimizations, which
+# assume graphs to be side-effect free. More details at
+# https://github.com/pytorch/torchdynamo/issues/1235.
+# Moving such reinplacing type of decomps to lowering ensures that AotAutograd
+# gets good graphs.
+@register_lowering([aten.index_put])
+def index_put(x, indices, values, accumulate=False):
+ return index_put_(clone(x), indices, values, accumulate)
+
+
+def index_put_as_masked_fill(self, indices, value, accumulate):
+ if value.get_device() != self.get_device():
+ value = to_device(value, self.get_device())
+ if accumulate:
+ value = add(self, value)
+ return mutate_to(self, where(indices[0], value, self))
+
+
+def index_put_fallback(self, indices, values, accumulate):
+ ir.IndexPutFallback(self, indices, values, accumulate)
+ return self
+
+
+@register_lowering(aten.index_put_, type_promotion_kind=None)
+def index_put_(self, indices, values, accumulate=False):
+ # Dispatch to masked fill for single boolean index with single value
+ if (
+ values.get_numel() == 1
+ and len(indices) == 1
+ and indices[0].get_dtype() in {torch.bool, torch.uint8}
+ ):
+ return index_put_as_masked_fill(self, indices, values, accumulate)
+
+ # Fallback if there is a boolean index
+ for index in indices:
+ if index is not None and index.get_dtype() in {torch.bool, torch.uint8}:
+ return index_put_fallback(self, indices, values, accumulate)
+
+ x_size = self.get_size()
+ x_ndim = len(x_size)
+
+ # fallback to aten.index_put_, as tl.atomic_add does NOT support int64 or bool
+ if self.get_dtype() in {torch.int64, torch.bool}:
+ # self is an scalar Tensor
+ if x_ndim == 0:
+ self = view(self, [1])
+ self = index_put_fallback(self, indices, values, accumulate)
+ if x_ndim == 0:
+ self = view(self, [])
+ return self
+
+ values = to_dtype(values, self.get_dtype())
+ indices, start_offset, end_offset = check_and_broadcast_indices(indices)
+ indices_sizes = [i.get_size() for i in indices if i is not None]
+ indices_loaders = [i.make_loader() for i in indices if i is not None]
+
+ assert isinstance(self, TensorBox)
+ self.realize()
+ V.graph.realize_users_of(self.get_name())
+
+ # self is an scalar Tensor
+ if x_ndim == 0:
+ self = view(self, [1])
+
+ output_size = list(indices_sizes[0])
+ expected_vals_size = [
+ *x_size[:start_offset],
+ *output_size,
+ *x_size[start_offset + len(indices_sizes) :],
+ ]
+
+ values = expand(values, expected_vals_size)
+ # all guards are set above during broadcast_tensors and expand
+
+ def output_indexer(index):
+ assert len(index) == len(expected_vals_size)
+ new_index = [
+ ops.indirect_indexing(loader(index[start_offset:end_offset]))
+ for loader in indices_loaders
+ ]
+ new_index = [*index[:start_offset], *new_index, *index[end_offset:]]
+ return new_index
+
+ scatter = ir.Scatter(
+ device=self.get_device(),
+ dtype=self.get_dtype(),
+ inner_fn=values.make_loader(),
+ ranges=expected_vals_size, # iter_ranges,
+ output_indexer=output_indexer,
+ scatter_mode="atomic_add" if accumulate else None,
+ )
+ buffer = ir.ComputedBuffer(
+ None,
+ ir.MutationLayout(self),
+ scatter,
+ )
+ buffer.name = V.graph.register_buffer(buffer)
+
+ if x_ndim == 0:
+ self = view(self, [])
+ return self
+
+
+@register_lowering(aten.scatter, type_promotion_kind=None)
+def scatter(x, dim: int, index, src, **kwargs):
+ return scatter_(clone(x), dim, index, src, **kwargs)
+
+
+@register_lowering(aten.scatter_, type_promotion_kind=None)
+def scatter_(self, dim: int, index, src, *, reduce: str = None):
+ if reduce == "add":
+ reduce = "sum"
+ elif reduce == "multiply":
+ reduce = "prod"
+ else:
+ assert reduce is None
+ return scatter_reduce_(self, dim, index, src, reduce)
+
+
+@register_lowering(aten.scatter_add, type_promotion_kind=None)
+def scatter_add(x, dim: int, index, src):
+ return scatter_add_(clone(x), dim, index, src)
+
+
+@register_lowering(aten.scatter_add_, type_promotion_kind=None)
+def scatter_add_(x, dim: int, index, src):
+ return scatter_reduce_(clone(x), dim, index, src, "sum")
+
+
+@register_lowering(aten.scatter_reduce, type_promotion_kind=None)
+def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs):
+ return scatter_reduce_(clone(x), dim, index, src, reduction_type, **kwargs)
+
+
+fallback_scatter_reduce_ = fallback_handler(aten.scatter_reduce_)
+
+
+@register_lowering(aten.scatter_reduce_, type_promotion_kind=None)
+def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True):
+ assert reduce in {None, "sum", "prod", "mean", "amax", "amin"}
+
+ # TODO: Need to support more reduction type
+ # For reduction of "sum", tl.atomic_add doesn't support bool or int64
+ if reduce not in {None, "sum"} or (
+ reduce == "sum" and self.get_dtype() in {torch.bool, torch.int64}
+ ):
+ self.realize()
+ return fallback_scatter_reduce_(
+ self, dim, index, src, reduce, include_self=include_self
+ )
+
+ assert isinstance(self, TensorBox)
+ assert "int" in str(index.get_dtype())
+
+ ndim = len(self.get_size())
+ if ndim == 0:
+ self = view(self, [1])
+
+ if isinstance(src, TensorBox) and len(src.get_size()) == 0:
+ src = view(src, [1])
+
+ if isinstance(index, TensorBox) and len(index.get_size()) == 0:
+ index = view(index, [1])
+
+ assert -len(self.get_size()) <= dim < len(self.get_size())
+
+ self.realize()
+ V.graph.realize_users_of(self.get_name())
+ index_loader = index.make_loader()
+ src_loader = src.make_loader() if isinstance(src, TensorBox) else None
+
+ def output_indexer(idx):
+ indirect_idx = list(idx)
+ indirect_idx[dim] = ops.indirect_indexing(index_loader(idx))
+ return indirect_idx
+
+ def fn(idx):
+ if src_loader:
+ return src_loader(idx)
+ else:
+ # src is a scalar
+ return ops.constant(src, self.get_dtype())
+
+ def backend_reduce_str(reduce):
+ if reduce == "sum":
+ return "atomic_add"
+ else:
+ # TODO: Need to support more reduction type
+ assert reduce is None
+ return None
+
+ if not include_self:
+ # zero out the corresponding elements first
+ zero_out = ir.Scatter(
+ device=self.get_device(),
+ dtype=self.get_dtype(),
+ inner_fn=lambda index: ops.constant(0, self.get_dtype()),
+ ranges=index.get_size(),
+ output_indexer=output_indexer,
+ scatter_mode=None,
+ )
+ buffer = ir.ComputedBuffer(
+ None,
+ ir.MutationLayout(self),
+ zero_out,
+ )
+ buffer.name = V.graph.register_buffer(buffer)
+
+ # self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0
+ # self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1
+ # self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2
+ scatter = ir.Scatter(
+ device=self.get_device(),
+ dtype=self.get_dtype(),
+ inner_fn=fn,
+ ranges=index.get_size(),
+ output_indexer=output_indexer,
+ scatter_mode=backend_reduce_str(reduce),
+ )
+ buffer = ir.ComputedBuffer(
+ None,
+ ir.MutationLayout(self),
+ scatter,
+ )
+ buffer.name = V.graph.register_buffer(buffer)
+
+ if ndim == 0:
+ self = view(self, [])
+ return self
+
+
+@register_lowering(aten.upsample_nearest2d)
+def upsample_nearest2d(x, output_size=None, scale_factors=None):
+ x.realize_hint() # elements are reused
+ x_loader = x.make_loader()
+
+ *batch, ih, iw = x.get_size()
+ ih = V.graph.sizevars.guard_static_shape(ih)
+ iw = V.graph.sizevars.guard_static_shape(iw)
+
+ if scale_factors:
+ assert not output_size
+ sh, sw = scale_factors
+ oh = int(ih * sh)
+ ow = int(iw * sw)
+ else:
+ oh, ow = output_size
+
+ scale_h = ih / oh
+ scale_w = iw / ow
+
+ def scale(x, scale):
+ x = ops.index_expr(x, torch.float32)
+ x = ops.mul(x, ops.constant(scale, torch.float32))
+ x = ops.to_dtype(x, torch.int32)
+ return ops.indirect_indexing(x)
+
+ def fn(idx):
+ *b, x, y = idx
+ return x_loader([*b, scale(x, scale_h), scale(y, scale_w)])
+
+ return Pointwise.create(
+ device=x.get_device(),
+ dtype=x.get_dtype(),
+ inner_fn=fn,
+ ranges=[*batch, sympy.Integer(oh), sympy.Integer(ow)],
+ )
+
+
+@register_lowering(aten.upsample_bicubic2d.default)
+def upsample_bicubic2d_default(
+ x,
+ output_size,
+ align_corners: bool,
+ scales_h: Optional[float] = None,
+ scales_w: Optional[float] = None,
+):
+ x.realize_hint()
+ x_loader = x.make_loader()
+
+ N, C, iH, iW = x.get_size()
+ oH, oW = output_size
+
+ iH = V.graph.sizevars.guard_static_shape(iH)
+ iW = V.graph.sizevars.guard_static_shape(iW)
+
+ def get_int_dtype(maxval):
+ if maxval > torch.iinfo(torch.int32).max:
+ return torch.int64
+ return torch.int32
+
+ def compute_scale(in_size, out_size, align_corners, scale=None):
+ if align_corners:
+ return (in_size - 1) / (out_size - 1) if out_size > 1 else 0
+ else:
+ return 1 / scale if scale is not None and scale > 0 else in_size / out_size
+
+ def compute_source_index(scale, dst_index, align_corners):
+ dst_index_ie = ops.index_expr(dst_index, torch.float32)
+ if align_corners:
+ return ops.mul(scale, dst_index_ie)
+ else:
+ return ops.sub(
+ ops.mul(scale, ops.add(dst_index_ie, 0.5)), 0.5
+ ) # scale * (dst_index + 0.5) - 0.5
+
+ def cubic_convolution1(x, A):
+ # ((A + 2) * x - (A+3)) * x * x + 1
+ return ops.add(ops.mul(ops.mul(ops.sub(ops.mul(A + 2, x), A + 3), x), x), 1.0)
+
+ def cubic_convolution2(x, A):
+ # ((A * x - 5 * A) * x + 8 * A) * x - 4*A
+ return ops.sub(
+ ops.mul(ops.add(ops.mul(ops.sub(ops.mul(A, x), 5 * A), x), 8 * A), x), 4 * A
+ )
+
+ def get_cubic_upsample_coefficients(t):
+ A = -0.75
+ c0 = cubic_convolution2(ops.add(t, 1.0), A)
+ c1 = cubic_convolution1(t, A)
+
+ x2 = ops.sub(1.0, t)
+ c2 = cubic_convolution1(x2, A)
+ c3 = cubic_convolution2(ops.add(x2, 1.0), A)
+ return (
+ c0,
+ c1,
+ c2,
+ c3,
+ )
+
+ def cubic_interp1d(xs, t):
+ cs = get_cubic_upsample_coefficients(t)
+ # dot product between xs and cs
+ return ops.add(
+ ops.mul(xs[0], cs[0]),
+ ops.add(
+ ops.mul(xs[1], cs[1]),
+ ops.add(ops.mul(xs[2], cs[2]), ops.mul(xs[3], cs[3])),
+ ),
+ )
+
+ height_scale = compute_scale(iH, oH, align_corners, scales_h)
+ width_scale = compute_scale(iW, oW, align_corners, scales_h)
+
+ def clamp(v, min, max):
+ return ops.maximum(min, ops.minimum(max, v))
+
+ def fn(idx):
+ n, c, oy, ox = idx
+
+ real_x = compute_source_index(width_scale, ox, align_corners)
+ in_x = ops.floor(real_x)
+ t_x = ops.sub(real_x, in_x)
+
+ real_y = compute_source_index(height_scale, oy, align_corners)
+ in_y = ops.floor(real_y)
+ t_y = ops.sub(real_y, in_y)
+
+ def load_bounded(fy, fx):
+ iy = ops.indirect_indexing(clamp(fy, 0, iH - 1))
+ ix = ops.indirect_indexing(clamp(fx, 0, iW - 1))
+ return x_loader([n, c, iy, ix])
+
+ iy = ops.to_dtype(in_y, get_int_dtype(iH + 1))
+ ix = ops.to_dtype(in_x, get_int_dtype(iW + 1))
+ iys_ofs = tuple((ops.add(iy, ofs) for ofs in (-1, 0, 1, 2)))
+ ixs_ofs = tuple((ops.add(ix, ofs) for ofs in (-1, 0, 1, 2)))
+
+ def get_x_interp(y):
+ coeffs_x = tuple((load_bounded(y, x) for x in ixs_ofs))
+ return cubic_interp1d(coeffs_x, t_x)
+
+ coeffs_y = tuple(get_x_interp(y) for y in iys_ofs)
+ return cubic_interp1d(coeffs_y, t_y)
+
+ return Pointwise.create(
+ device=x.get_device(),
+ dtype=x.get_dtype(),
+ inner_fn=fn,
+ ranges=[N, C, sympy.Integer(oH), sympy.Integer(oW)],
+ )
+
+
+@register_lowering(aten.upsample_bicubic2d.vec)
+def upsample_bicubic2d_vec(
+ a,
+ output_size,
+ align_corners: bool,
+ scale_factors: Optional[Tuple[float, float]] = None,
+):
+ _, _, iH, iW = a.get_size()
+ iH = V.graph.sizevars.guard_static_shape(iH)
+ iW = V.graph.sizevars.guard_static_shape(iW)
+
+ if bool(output_size) + bool(scale_factors) != 1:
+ raise RuntimeError("Must specify exactly one of output_size and scale_factor.")
+ if output_size is None:
+ assert scale_factors is not None
+ output_size = (int(iH * scale_factors[0]), int(iW * scale_factors[1]))
+ scale_h, scale_w = scale_factors if scale_factors else (None, None)
+ return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w)
+
+
+@register_lowering(aten.reflection_pad2d)
+def reflection_pad2d(x, padding):
+ assert len(padding) == 4
+ left, right, top, bot = padding
+
+ x_loader = x.make_loader()
+ *batch, h, w = x.get_size()
+ h = V.graph.sizevars.guard_static_shape(h)
+ w = V.graph.sizevars.guard_static_shape(w)
+
+ def reflect(x, size, offset):
+ size = ops.constant(size - 1, torch.int32)
+ x = ops.index_expr(x, torch.int32)
+ x = ops.sub(x, ops.constant(offset, torch.int32))
+ x = ops.sub(size, ops.abs(ops.sub(size, ops.abs(x))))
+ return ops.indirect_indexing(x)
+
+ def fn(idx):
+ *b, x, y = idx
+ x = reflect(x, h, top)
+ y = reflect(y, w, left)
+ return x_loader([*b, x, y])
+
+ return Pointwise.create(
+ device=x.get_device(),
+ dtype=x.get_dtype(),
+ inner_fn=fn,
+ ranges=[*batch, sympy.Integer(h + top + bot), sympy.Integer(w + left + right)],
+ )
+
+
+@register_lowering(aten.reflection_pad2d_backward)
+def reflection_pad2d_backward(grad_output, x, padding):
+ assert len(padding) == 4
+ left, right, top, bot = padding
+
+ *_, h, w = x.get_size()
+ h = V.graph.sizevars.guard_static_shape(h) - 1
+ w = V.graph.sizevars.guard_static_shape(w) - 1
+ grad_loader = grad_output.make_loader()
+
+ def fn(idx):
+ *b, x, y = idx
+
+ def load_from_output(x, y):
+ x = ops.indirect_indexing(ops.index_expr(x, torch.int32))
+ y = ops.indirect_indexing(ops.index_expr(y, torch.int32))
+ return grad_loader([*b, x, y])
+
+ def index_range_condition(index_range):
+ i, lb, ub = index_range
+ i = ops.index_expr(i, torch.int32)
+ return ops.and_(ops.ge(i, lb), ops.le(i, ub))
+
+ def accumulate(out_x, out_y, index_range1, index_range2=None):
+ nonlocal grad
+
+ # If the upper bound is less than the lower bound, we can get rid of one accumulation.
+ # This happens when the padding size is zero.
+ if index_range1[2] < index_range1[1]:
+ return
+ cond = index_range_condition(index_range1)
+ if index_range2 is not None:
+ if index_range2[2] < index_range2[1]:
+ return
+ cond = ops.and_(cond, index_range_condition(index_range2))
+ g = ops.masked(cond, lambda: load_from_output(out_x, out_y), 0.0)
+ grad = ops.add(grad, g)
+
+ # Areas after reflection:
+ #
+ # top-left | top | top-right
+ # -----------------------------------------
+ # left | center | right
+ # -----------------------------------------
+ # bottom-left | bottom | bottom-right
+ #
+ # The center area is the orignial matrix. Other areas are reflections.
+
+ center_x, center_y = x + top, y + left
+ top_reflect_x, left_reflect_y = top - x, left - y
+ bot_reflect_x, right_reflect_y = 2 * h + top - x, 2 * w + left - y
+
+ # Accumulate gradients from different areas
+ grad = load_from_output(center_x, center_y)
+ accumulate(center_x, left_reflect_y, (y, 1, left))
+ accumulate(center_x, right_reflect_y, (y, w - right, w - 1))
+ accumulate(top_reflect_x, center_y, (x, 1, top))
+ accumulate(bot_reflect_x, center_y, (x, h - bot, h - 1))
+ accumulate(top_reflect_x, left_reflect_y, (x, 1, top), (y, 1, left))
+ accumulate(top_reflect_x, right_reflect_y, (x, 1, top), (y, w - right, w - 1))
+ accumulate(bot_reflect_x, left_reflect_y, (x, h - bot, h - 1), (y, 1, left))
+ accumulate(
+ bot_reflect_x, right_reflect_y, (x, h - bot, h - 1), (y, w - right, w - 1)
+ )
+
+ return grad
+
+ return Pointwise.create(
+ device=grad_output.get_device(),
+ dtype=grad_output.get_dtype(),
+ inner_fn=fn,
+ ranges=list(x.get_size()),
+ )
+
+
+@register_lowering(prims.rev.default)
+def rev(x, dims):
+ # note - dims pre-canoncalized
+ x_loader = x.make_loader()
+ sizes = x.get_size()
+
+ def loader(idx):
+ idx = list(idx)
+ assert len(idx) == len(sizes)
+ for dim in dims:
+ idx[dim] = (sizes[dim] - 1) - idx[dim]
+
+ return x_loader(idx)
+
+ return Pointwise.create(
+ device=x.get_device(),
+ dtype=x.get_dtype(),
+ inner_fn=loader,
+ ranges=sizes,
+ )
+
+
+@register_lowering(aten.constant_pad_nd, type_promotion_kind=None)
+def constant_pad_nd(x, padding, fill_value=0):
+ assert (len(padding) % 2) == 0
+ if all(p == 0 for p in padding):
+ return x
+
+ sizes = x.get_size()
+
+ bounds = list(reversed(list(zip(padding[::2], padding[1::2]))))
+ n = len(sizes) - len(bounds)
+
+ output_size = list(sizes[:n])
+ mask_sizes = []
+ for (low, high), size in zip(bounds, sizes[n:]):
+ size = V.graph.sizevars.guard_static_shape(size)
+ mask_sizes.append(size)
+ output_size.append(sympy.expand(size + low + high))
+ assert len(output_size) == len(sizes)
+
+ def mask(index):
+ mask = []
+ for idx, (low, high), length in zip(index[n:], bounds, mask_sizes):
+ if low != 0:
+ mask.append(range_mask_low(idx))
+ if high != 0:
+ mask.append(range_mask_high(idx, length))
+ mask = functools.reduce(ops.and_, mask)
+ return ops.masked(mask, lambda: x_loader(index), fill_value)
+
+ def offset_fn(index):
+ new_index = list(index[:n])
+ for idx, (low, high) in zip(index[n:], bounds):
+ new_index.append(idx - low)
+ assert len(new_index) == len(index)
+ return mask(new_index)
+
+ x_loader = x.make_loader()
+ return Pointwise.create(
+ device=x.get_device(),
+ dtype=x.get_dtype(),
+ inner_fn=offset_fn,
+ ranges=output_size,
+ )
+
+
+def range_mask_low(i: sympy.Expr):
+ return ops.ge(
+ ops.index_expr(i, torch.int64),
+ ops.index_expr(sympy.Integer(0), torch.int64),
+ )
+
+
+def range_mask_high(i: sympy.Expr, length: sympy.Expr):
+ return ops.lt(
+ ops.index_expr(i, torch.int64),
+ ops.index_expr(length, torch.int64),
+ )
+
+
+def range_mask(i: sympy.Expr, length: sympy.Expr):
+ return ops.and_(
+ range_mask_low(i),
+ range_mask_high(i, length),
+ )
+
+
+def constant_boundary_condition_2d(x, fill_value, padding):
+ *_, h, w = x.get_size()
+ x_loader = x.make_loader()
+
+ def load(index):
+ *prefix, ih, iw = index
+
+ mask = ops.and_(
+ range_mask(ih, h),
+ range_mask(iw, w),
+ )
+ return ops.masked(mask, lambda: x_loader([*prefix, ih, iw]), fill_value)
+
+ return load
+
+
+def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
+
+ x_out = ir.IndexingDiv(
+ x + 2 * padding[i] - (kernel_size[i] - 1) + (stride[i] - 1), stride[i]
+ )
+
+ if ceil_mode:
+ x_alt = ir.IndexingDiv(
+ x + 2 * padding[i] - (kernel_size[i] - 1) + 2 * (stride[i] - 1), stride[i]
+ )
+
+ if V.graph.sizevars.size_hint(x_out - x_alt) == 0:
+ # ceil mode is actually a no-op, lets guard on that
+ V.graph.sizevars.guard_equals(x_out, x_alt)
+ ceil_mode = False
+ else:
+ x_out = x_alt
+ return x_out, ceil_mode
+
+
+@register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None)
+def max_pool2d_with_indices(
+ x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False
+):
+ if padding == 0:
+ padding = [0, 0]
+ if not stride:
+ stride = kernel_size
+
+ assert dilation == 1 or all(d == 1 for d in dilation)
+ assert isinstance(x, TensorBox)
+ assert len(kernel_size) == 2
+ assert len(stride) == 2
+ assert len(padding) == 2
+ assert len(x.get_size()) in (3, 4)
+
+ x.realize_hint()
+ *batch, h, w = x.get_size()
+
+ h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode)
+ w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode)
+
+ if padding[0] or padding[1] or ceil_mode1 or ceil_mode2:
+ x_loader = constant_boundary_condition_2d(x, float("-inf"), padding)
+ else:
+ x_loader = x.make_loader()
+
+ new_size = list(batch) + [h_out, w_out]
+
+ def fn(idx, return_index):
+ *prefix, bh, bw = idx
+ maxval = None
+ maxindex = None
+ for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
+ ih = bh * stride[0] + ih - padding[0]
+ iw = bw * stride[1] + iw - padding[1]
+ val = x_loader([*prefix, ih, iw])
+ index = ops.index_expr(ih * w + iw, torch.int64)
+ if maxval is None:
+ maxindex = index
+ maxval = val
+ else:
+ maxindex = ops.where(ops.gt(val, maxval), index, maxindex)
+ maxval = ops.maximum(val, maxval)
+ if return_index:
+ return maxindex
+ else:
+ return maxval
+
+ r1 = Pointwise.create(
+ device=x.get_device(),
+ dtype=x.get_dtype(),
+ inner_fn=functools.partial(fn, return_index=False),
+ ranges=new_size,
+ )
+ r2 = Pointwise.create(
+ device=x.get_device(),
+ dtype=torch.int64,
+ inner_fn=functools.partial(fn, return_index=True),
+ ranges=new_size,
+ )
+ # TODO(jansel): should we force these to be realized?
+ return r1, r2
+
+
+@register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None)
+def max_pool2d_with_indices_backward(
+ grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
+):
+ if padding == 0:
+ padding = [0, 0]
+ if not stride:
+ stride = kernel_size
+
+ assert dilation == 1 or all(d == 1 for d in dilation)
+ assert isinstance(x, TensorBox)
+ assert len(kernel_size) == 2
+ assert len(stride) == 2
+ assert len(padding) == 2
+ assert len(x.get_size()) in (3, 4)
+
+ # we will read this many times, so make sure it is computed
+ grad_output.realize_hint()
+ indices.realize_hint()
+
+ *batch, height, width = x.get_size()
+ *_, pooled_height, pooled_width = grad_output.get_size()
+
+ indices_loader = indices.make_loader()
+ grad_loader = grad_output.make_loader()
+ new_size = list(x.get_size())
+
+ h_window_size = max(
+ [
+ max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1)
+ for h in range(kernel_size[0] * 2)
+ ]
+ )
+ w_window_size = max(
+ [
+ max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1)
+ for w in range(kernel_size[1] * 2)
+ ]
+ )
+
+ def fn(idx):
+ *prefix, h, w = idx
+ index_test = ops.index_expr(h * width + w, torch.int32)
+ h = h + padding[0]
+ w = w + padding[1]
+ phstart = ops.index_expr(
+ ir.IndexingDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
+ )
+ pwstart = ops.index_expr(
+ ir.IndexingDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
+ )
+ phend = ops.index_expr(ir.IndexingDiv(h, stride[0]) + 1, torch.int32)
+ pwend = ops.index_expr(ir.IndexingDiv(w, stride[1]) + 1, torch.int32)
+
+ phstart = ops.maximum(phstart, ops.constant(0, torch.int32))
+ pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32))
+ phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32))
+ pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32))
+
+ gradient = None
+ for ph_ in range(h_window_size):
+ for pw_ in range(w_window_size):
+ ph = ops.add(phstart, ops.constant(ph_, torch.int32))
+ pw = ops.add(pwstart, ops.constant(pw_, torch.int32))
+ grad_index = [
+ *prefix,
+ ops.indirect_indexing(
+ ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32)))
+ ),
+ ops.indirect_indexing(
+ ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32)))
+ ),
+ ]
+
+ index_actual = indices_loader(grad_index)
+ grad_part = grad_loader(grad_index)
+ check = ops.eq(index_actual, index_test)
+
+ if gradient is None:
+ # don't need mask for 0, 0
+ gradient = ops.where(
+ check, grad_part, ops.constant(0.0, torch.float32)
+ )
+ else:
+ mask = ops.and_(
+ ops.and_(
+ ops.lt(ph, phend),
+ ops.lt(pw, pwend),
+ ),
+ check,
+ )
+ gradient = ops.where(mask, ops.add(gradient, grad_part), gradient)
+ assert gradient is not None
+ return gradient
+
+ return Pointwise.create(
+ device=grad_output.get_device(),
+ dtype=grad_output.get_dtype(),
+ inner_fn=fn,
+ ranges=new_size,
+ )
+
+
+def pad_adaptive_loader(x):
+ *_, h, w = x.get_size()
+ x_loader = x.make_loader()
+
+ def load(prefix, increments, start_indices, end_indices):
+ ih, iw = increments
+ h_start_index, w_start_index = start_indices
+ h_end_index, w_end_index = end_indices
+
+ mask = ops.and_(
+ ops.lt(
+ ops.index_expr(h_start_index + ih, torch.int64),
+ ops.index_expr(h_end_index, torch.int64),
+ ),
+ ops.lt(
+ ops.index_expr(w_start_index + iw, torch.int64),
+ ops.index_expr(w_end_index, torch.int64),
+ ),
+ )
+
+ return ops.masked(
+ mask,
+ lambda: x_loader([*prefix, h_start_index + ih, w_start_index + iw]),
+ 0.0,
+ )
+
+ return load
+
+
+def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns):
+ h_start_index_fn, w_start_index_fn = start_index_fns
+ h_end_index_fn, w_end_index_fn = end_index_fns
+
+ def fn_sum(idx, loader):
+ *prefix, bh, bw = idx
+
+ h_start_index = h_start_index_fn(bh)
+ h_end_index = h_end_index_fn(bh)
+
+ w_start_index = w_start_index_fn(bw)
+ w_end_index = w_end_index_fn(bw)
+
+ total = None
+ for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
+ val = loader(
+ prefix,
+ [ih, iw],
+ [h_start_index, w_start_index],
+ [h_end_index, w_end_index],
+ )
+ if total is None:
+ total = val
+ else:
+ total = ops.add(val, total)
+ return total
+
+ return fn_sum
+
+
+@register_lowering(aten._adaptive_avg_pool2d)
+def _adaptive_avg_pool2d(x, output_size):
+ assert isinstance(x, TensorBox)
+ assert len(output_size) == 2
+ x.realize_hint()
+
+ *batch, h_in, w_in = x.get_size()
+
+ h_in = V.graph.sizevars.guard_static_shape(h_in)
+ w_in = V.graph.sizevars.guard_static_shape(w_in)
+
+ h_out, w_out = output_size
+
+ # no-op if the same input and output
+ if h_in == h_out and w_in == w_out:
+ return clone(x)
+
+ if h_in % h_out == 0 and w_in % w_out == 0:
+ kernel_size = [h_in // h_out, w_in // w_out]
+ return avg_pool2d(x, kernel_size)
+
+ h_kernel_max = ceildiv((h_in + h_out - 1), h_out)
+ w_kernel_max = ceildiv((w_in + w_out - 1), w_out)
+
+ new_size = list(batch) + [h_out, w_out]
+ dtype = x.get_dtype()
+
+ def start_index(index, out_dim, inp_dim):
+ return ir.IndexingDiv((index * inp_dim), out_dim)
+
+ def end_index(index, out_dim, inp_dim):
+ return ir.IndexingDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
+
+ h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
+ h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
+
+ w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
+ w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
+
+ fn_sum = _adaptive_pooling_idx_sum(
+ [h_kernel_max, w_kernel_max],
+ [h_start_index, w_start_index],
+ [h_end_index, w_end_index],
+ )
+
+ ones_loader = pad_adaptive_loader(ones_like(x))
+
+ def fn(idx):
+ return ops.div(fn_sum(idx, pad_adaptive_loader(x)), fn_sum(idx, ones_loader))
+
+ rv = Pointwise.create(
+ device=x.get_device(),
+ dtype=dtype,
+ inner_fn=fn,
+ ranges=new_size,
+ )
+ # TODO: should we force these to be realized?
+ return rv
+
+
+@register_lowering(aten.upsample_nearest2d_backward.vec)
+def upsample_nearest2d_backward(
+ x, output_size=None, input_size=None, scale_factors=None
+):
+ x.realize_hint()
+
+ *batch, inp_h, inp_w = x.get_size()
+ inp_h = V.graph.sizevars.guard_static_shape(inp_h)
+ inp_w = V.graph.sizevars.guard_static_shape(inp_w)
+
+ *batch, out_h, out_w = input_size
+
+ if inp_h % out_h == 0 and inp_w % out_w == 0:
+ return avg_pool2d(x, [inp_h // out_h, inp_w // out_w], divisor_override=1)
+
+ h_kernel_max = ceildiv(inp_h, out_h)
+ w_kernel_max = ceildiv(inp_w, out_w)
+
+ def start_index(index, out_dim, inp_dim):
+ return ir.CeilDiv(index * inp_dim, out_dim)
+
+ def end_index(index, out_dim, inp_dim):
+ return start_index((index + 1), out_dim, inp_dim)
+
+ h_start_index = functools.partial(start_index, out_dim=out_h, inp_dim=inp_h)
+ h_end_index = functools.partial(end_index, out_dim=out_h, inp_dim=inp_h)
+
+ w_start_index = functools.partial(start_index, out_dim=out_w, inp_dim=inp_w)
+ w_end_index = functools.partial(end_index, out_dim=out_w, inp_dim=inp_w)
+
+ fn_sum = _adaptive_pooling_idx_sum(
+ [h_kernel_max, w_kernel_max],
+ [h_start_index, w_start_index],
+ [h_end_index, w_end_index],
+ )
+
+ def fn(idx):
+ return fn_sum(idx, pad_adaptive_loader(x))
+
+ rv = Pointwise.create(
+ device=x.get_device(),
+ dtype=x.get_dtype(),
+ inner_fn=fn,
+ ranges=list(input_size),
+ )
+
+ return rv
+
+
+@register_lowering(aten.avg_pool2d, type_promotion_kind=None)
+def avg_pool2d(
+ x,
+ kernel_size,
+ stride=(),
+ padding=0,
+ ceil_mode=False,
+ count_include_pad=True,
+ divisor_override=None,
+):
+ if not stride:
+ stride = kernel_size
+ if not padding:
+ padding = [0, 0]
+
+ assert isinstance(x, TensorBox)
+ assert len(kernel_size) == 2
+ assert len(stride) == 2
+ assert len(padding) == 2
+ assert len(x.get_size()) in (3, 4)
+
+ x.realize_hint()
+ *batch, h, w = x.get_size()
+
+ h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode)
+ w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode)
+
+ if padding[0] or padding[1] or ceil_mode1 or ceil_mode2:
+ x_loader = constant_boundary_condition_2d(x, 0.0, padding)
+ had_padding = True
+ else:
+ x_loader = x.make_loader()
+ had_padding = False
+
+ new_size = list(batch) + [h_out, w_out]
+ dtype = x.get_dtype()
+
+ def fn_sum(idx, loader):
+ *prefix, bh, bw = idx
+ total = None
+ for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
+ ih = bh * stride[0] + ih - padding[0]
+ iw = bw * stride[1] + iw - padding[1]
+ val = loader([*prefix, ih, iw])
+ if total is None:
+ total = val
+ else:
+ total = ops.add(val, total)
+ return total
+
+ if count_include_pad or not had_padding or divisor_override:
+ if divisor_override:
+ scale = 1 / divisor_override
+ else:
+ scale = 1.0 / (kernel_size[0] * kernel_size[1])
+
+ def fn(idx):
+ return ops.mul(fn_sum(idx, x_loader), ops.constant(scale, dtype))
+
+ else:
+ ones_loader = constant_boundary_condition_2d(ones_like(x), 0.0, padding)
+
+ def fn(idx):
+ # TODO(jansel): optimize to do `int(x p
+ return bool_mask.to(x.dtype) * x * scale
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ p = ctx.p
+ scale = float(1.0 / (1.0 - p))
+ (seed,) = ctx.saved_tensors
+ bool_mask = philox_rand_like(grad_output, seed, ctx.offset) > p
+ return bool_mask.to(grad_output.dtype) * grad_output * scale, None
+
+
+@torch.fx.wrap
+def lowmem_dropout(input, p, training=True, inplace=False):
+ if isinstance(input, torch.fx.Proxy):
+ # double check we don't FX trace this
+ return input.tracer.create_proxy(
+ "call_function",
+ lowmem_dropout,
+ (input, p, training),
+ {},
+ )
+ if not training or p == 0:
+ return input
+ result = LowmemDropout.apply(input, p)
+ if inplace:
+ input.copy_(result)
+ return result
+
+
+@torch.fx.wrap
+def rand_like(x, **kwargs):
+ if isinstance(x, torch.fx.Proxy):
+ # double check we don't FX trace this
+ return x.tracer.create_proxy("call_function", rand_like, (x), kwargs)
+ assert kwargs.get("device", x.device) == x.device
+ seed, offset = PhiloxRandomState.get_seed_offset(x)
+ return philox_rand_like(x, seed, offset).to(kwargs.get("dtype", torch.float32))
+
+
+replacements = {torch.nn.functional.dropout: lowmem_dropout, torch.rand_like: rand_like}
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
new file mode 100644
index 0000000000000..88181fb0ce7f2
--- /dev/null
+++ b/torch/_inductor/scheduler.py
@@ -0,0 +1,1083 @@
+import collections
+import dataclasses
+import functools
+import itertools
+import logging
+import os
+import pprint
+import textwrap
+from typing import Dict, List, Optional, Set, Union
+
+import numpy as np
+import sympy
+
+import torch
+
+from . import config, dependencies, ir
+from .dependencies import MemoryDep, StarDep
+from .sizevars import SimplifyIndexing
+from .utils import cache_on_self, cmp, dynamo_utils
+from .virtualized import V
+
+log = logging.getLogger(__name__)
+
+
+def pformat(obj):
+ if isinstance(obj, set):
+ # pformat has trouble with sets of sympy exprs
+ obj = sorted(obj, key=str)
+ result = pprint.pformat(obj, indent=4)
+ if "\n" in result:
+ return f"\n{textwrap.indent(result, ' '*4)}"
+ return result
+
+
+class OutputNode:
+ def __init__(self, dep):
+ self.unmet_dependencies = {dep}
+ self.inverse_users = []
+
+ def is_reduction(self):
+ return False
+
+ def get_alias_names(self):
+ return ()
+
+ def get_name(self):
+ return "OUTPUT"
+
+ __repr__ = get_name
+
+
+class BaseSchedulerNode:
+ def __init__(self, scheduler: "Scheduler", node: ir.Buffer):
+ self.scheduler: "Scheduler" = scheduler
+ self.node: ir.Buffer = node
+ self.users: Optional[List[NodeUser]] = None
+ self.inverse_users: List[BaseSchedulerNode] = []
+ self.set_read_writes(node.get_read_writes())
+ self.recursive_predecessors: Optional[Set[str]] = None
+ self.min_order: Optional[int] = None
+ self.max_order: Optional[int] = None
+ self.last_usage: Set[str] = None # buffers that won't be used after this kernel
+ self.written = False
+
+ def __repr__(self):
+ return f"{type(self).__name__}(name={self.get_name()!r})"
+
+ def debug_str(self):
+ """Longer form printout for trace logs"""
+ name = self.get_name()
+ lines = [
+ f"{name}: {type(self).__name__}({type(self.node).__name__})",
+ f"{name}.writes = {pformat(self.read_writes.writes)}",
+ f"{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}",
+ f"{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}",
+ ]
+ try:
+ lines += [
+ self.debug_str_extra(),
+ ]
+ except Exception:
+ log.warning("Ignoring error in debug_str()", exc_info=True)
+ return "\n".join(lines).rstrip()
+
+ def debug_str_extra(self):
+ return ""
+
+ def log_details(self):
+ log.info(
+ "%s: unmet_dependencies = %s, writes = %s",
+ self,
+ self.unmet_dependencies,
+ self.read_writes.writes,
+ )
+
+ def update_mutated_names(self, renames: Dict[str, str]):
+ self.set_read_writes(self.read_writes.rename(renames))
+
+ def add_mutation_dep(self, name):
+ self.set_read_writes(self.read_writes.with_read(name))
+
+ def set_users(self, users: List["NodeUser"]):
+ # deduplicate
+ result: Dict[int, NodeUser] = {}
+ for use in users:
+ if id(use.node) in result:
+ result[id(use.node)] = NodeUser(
+ use.node, result[id(use.node)].can_inplace and use.can_inplace
+ )
+ else:
+ result[id(use.node)] = use
+ self.users = list(result.values())
+
+ def get_aliases(self):
+ return self.node.get_alias_names()
+
+ def get_mutations(self):
+ return self.node.get_mutation_names()
+
+ def set_read_writes(self, rw: dependencies.ReadWrites):
+ self.read_writes: dependencies.ReadWrites = rw
+ self.unmet_dependencies = self.read_writes.reads
+ self.prune_deps()
+
+ def used_buffer_names(self) -> Set[str]:
+ return {
+ dep.name
+ for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes)
+ }
+
+ def prune_deps(self):
+ self.unmet_dependencies = {
+ dep
+ for dep in self.unmet_dependencies
+ if dep.name not in self.scheduler.available_buffer_names
+ }
+
+ def get_name(self) -> str:
+ return self.node.get_name()
+
+ def get_first_name(self) -> str:
+ return self.get_name()
+
+ def get_names(self) -> Set[str]:
+ return set([self.get_name()])
+
+ def get_nodes(self) -> List["BaseSchedulerNode"]:
+ return [self]
+
+ def get_device(self):
+ return self.node.get_device()
+
+ def is_reduction(self):
+ return False
+
+ def is_template(self):
+ return False
+
+ def is_extern(self):
+ return False
+
+ def can_inplace(self, read_dep: dependencies.MemoryDep):
+ return False
+
+ def allocate(self):
+ from .codegen.triton_template import should_use_template
+
+ if self.node.should_allocate() or should_use_template(self.node):
+ # if self.node should allocate or
+ # if self.node is generated by TritonKernelTemplates
+ # because Triton kernel could not allocate tensor itself
+ V.graph.wrapper_code.codegen_allocation(self.node)
+
+ def can_free(self):
+ for use in self.users:
+ if isinstance(use.node, OutputNode):
+ return False
+ return True
+
+ def codegen_originating_info(self, buffer, only_once=True):
+ if not config.comment_origin:
+ return
+
+ if only_once and self.written:
+ return
+ origins = self.node.origins
+ out_lines = []
+
+ for o in origins:
+ if o.op == "output":
+ # These are boring and samey
+ continue
+
+ out_lines.append("")
+ # TODO(voz): Should the pragma be constant somewhere?
+ out_lines.append("#pragma CMT ORIGIN:")
+ out_lines.append(f"#pragma CMT {o.op} {o.target}")
+ if "stack_trace" in o.meta:
+ stack_trace = f"{o.meta['stack_trace']}"
+ stack_trace_last_line = stack_trace.split("|")[-1]
+ out_lines.append(
+ "#pragma CMT "
+ + stack_trace_last_line.replace("{", "{{")
+ .replace("}", "}}")
+ .replace("\n", "\\")
+ )
+ out_lines.append("#pragma CMT END ORIGIN")
+ out_lines.append("")
+
+ if len(out_lines) == 0:
+ return
+
+ # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
+ # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
+ buffer.writelines(out_lines)
+ self.written = True
+
+
+class ExternKernelSchedulerNode(BaseSchedulerNode):
+ def debug_str_extra(self):
+ return f"{self.get_name()}.node.kernel = {getattr(self.node, 'kernel', None)}"
+
+ def is_extern(self):
+ return True
+
+
+class TemplateSchedulerNode(BaseSchedulerNode):
+ def __init__(self, scheduler: "Scheduler", node: ir.ExternKernel, group_fn):
+ super().__init__(scheduler, node)
+ (self._sizes, self._stride) = node.get_group_stride()
+ self.group = (node.get_device(), group_fn(self._sizes))
+ self.set_read_writes(node.get_read_writes())
+ self.update_dep_type()
+
+ def is_template(self):
+ return True
+
+ def update_dep_type(self):
+ assert len(self.read_writes.writes) == 1
+ write = self.read_writes.writes.pop()
+ if isinstance(write, StarDep):
+ name = write.name
+ canonicalized_index, canonicalized_size = self.node.canonicalize()
+ new_dep = MemoryDep(name, canonicalized_index, canonicalized_size)
+ self.read_writes.writes.add(new_dep)
+ else:
+ self.read_writes.writes.add(write)
+
+ def get_ranges(self):
+ return self._sizes
+
+
+class NopKernelSchedulerNode(BaseSchedulerNode):
+ pass
+
+
+class SchedulerNode(BaseSchedulerNode):
+ def __init__(self, scheduler: "Scheduler", node: ir.ComputedBuffer, group_fn):
+ super().__init__(scheduler, node)
+ (
+ self._sizes,
+ self._body,
+ ) = node.simplify_and_reorder()
+
+ self.group = (node.get_device(), group_fn(self._sizes))
+
+ self.set_read_writes(
+ dependencies.extract_read_writes(self._body, *self._sizes, normalize=True)
+ )
+ if self.is_reduction():
+ # reduction has last (reduced) dim in its sizes, and some
+ # downstream dependencies get confused by it
+ self.read_writes.writes = self.read_writes.writes | {
+ w.strip_last_size() for w in self.read_writes.writes
+ }
+ # reduction not on the last dim swaps the sizes, and downstream
+ # dependencies expect unswapped
+ # TODO swapping sizes doesn't work, leads to
+ # File "/scratch/ngimel/work/repos/torchdynamo/torchinductor/sizevars.py", line 130, in guard_equals
+ # if len(right.free_symbols) < len(left.free_symbols):
+ # AttributeError: 'int' object has no attribute 'free_symbols'
+ # even though memory dep looks correct
+ # self.read_writes.writes = self.read_writes.writes | {
+ # w.maybe_swap_sizes() for w in self.read_writes.writes
+ # }
+
+ def debug_str_extra(self):
+ name = self.get_name()
+ lines = [
+ f"{name}.group.device = {self.group[0]}",
+ f"{name}.group.iteration = {self.group[1]}",
+ f"{name}.sizes = {self._sizes}",
+ ]
+ if self.get_aliases():
+ lines.append(f"{name}.aliases = {pformat(self.get_aliases())}")
+ if self.get_mutations():
+ lines.append(f"{name}.mutations = {pformat(self.get_mutations())}")
+ if isinstance(self._body, ir.LoopBody):
+ lines.append(f"class {name}_loop_body:")
+ lines.append(textwrap.indent(self._body.debug_str(), " "))
+ return "\n".join(lines)
+
+ def get_ranges(self):
+ return self._sizes
+
+ def is_reduction(self):
+ return bool(self.node.data.get_reduction_type())
+
+ def allocate(self):
+ if (
+ not self.node.should_allocate()
+ or self.node.get_alias_names()
+ or self.node.get_mutation_names()
+ ):
+ return super().allocate()
+
+ if config.inplace_buffers:
+ raise AssertionError("https://github.com/pytorch/torchdynamo/issues/823")
+ """
+ for read in self.read_writes.reads:
+ input_node: BaseSchedulerNode = self.scheduler.name_to_node.get(
+ read.name
+ )
+ if input_node and V.graph.wrapper_code.can_reuse(input_node):
+ remaining_uses = [
+ x
+ for x in input_node.users
+ if x.node.get_name()
+ not in self.scheduler.available_buffer_names
+ ]
+ if (
+ len(remaining_uses) == 1
+ and remaining_uses[0].can_inplace
+ and remaining_uses[0].node is self
+ ):
+ V.graph.wrapper_code.codegen_inplace_reuse(
+ input_node.node, self.node
+ )
+ V.kernel.args.make_inplace(
+ input_node.get_name(), self.get_name()
+ )
+ return
+ """
+ super().allocate()
+
+ def run(self, *index_vars):
+ self.mark_run()
+ self.codegen(index_vars)
+
+ def mark_run(self):
+ self.allocate()
+
+ def codegen(self, index_vars):
+ sizes = self._sizes
+ assert sum(map(len, sizes)) == sum(map(len, index_vars))
+ var_ranges = dict(
+ zip(
+ itertools.chain.from_iterable(index_vars),
+ itertools.chain.from_iterable(sizes),
+ )
+ )
+ try:
+ with V.set_ops_handler(
+ SimplifyIndexing(V.get_ops_handler(), var_ranges)
+ ), V.kernel.set_current_node(self):
+ self._body(*index_vars)
+ except Exception:
+ log.fatal("Error in codegen for %s", self.node)
+ raise
+
+ def pointwise_read_writes(self):
+ """
+ Get the memory dependencies in the non-reduction axis.
+ """
+ sizes, reduction_sizes = self._sizes
+
+ def fn(index):
+ return self._body(index, [sympy.Integer(0) for _ in reduction_sizes])
+
+ return dependencies.extract_read_writes(fn, sizes)
+
+ def can_inplace(self, read_dep: dependencies.MemoryDep):
+ if self.get_aliases():
+ return False
+ if len(self.read_writes.writes) == 1 and hasattr(read_dep, "index"):
+ write_dep = next(iter(self.read_writes.writes))
+ return read_dep.index == write_dep.index and read_dep.size == write_dep.size
+ return False
+
+
+class FusedSchedulerNode(BaseSchedulerNode):
+ """
+ This is a "fake" scheduler node that represents a group of scheduler nodes
+ that are meant to be fused together. The way it does this is by maintaining
+ its unmet dependencies as the union of its constituent nodes.
+ """
+
+ @classmethod
+ def fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
+ assert node1.scheduler is node2.scheduler
+ return cls(node1.scheduler, node1.get_nodes() + node2.get_nodes())
+
+ def __init__(self, scheduler: "Scheduler", snodes: List[SchedulerNode]):
+ # NB: No need to call super().__init__() because we don't need to re-use any of its logic.
+ self.snodes = snodes
+ self.scheduler = scheduler
+ self.node = None # type: ignore[assignment]
+ self.users = None
+ self.inverse_users = []
+ self.group = max(snodes, key=lambda x: int(x.is_reduction())).group
+ self.recursive_predecessors = functools.reduce(
+ set.union, [x.recursive_predecessors for x in snodes]
+ )
+ self.set_read_writes(
+ functools.reduce(
+ dependencies.ReadWrites.merge, [x.read_writes for x in snodes]
+ )
+ )
+ names = set(self.get_names())
+ self.unmet_dependencies = {
+ dep
+ for dep in functools.reduce(
+ set.union, [x.unmet_dependencies for x in snodes]
+ )
+ if dep.name not in names
+ } - self.read_writes.writes
+ self.min_order = min([x.min_order for x in self.snodes])
+ self.max_order = max([x.max_order for x in self.snodes])
+
+ @cache_on_self
+ def get_name(self) -> str:
+ return "_".join([x.get_name() for x in self.snodes])
+
+ def get_first_name(self) -> str:
+ return self.snodes[0].get_name()
+
+ @cache_on_self
+ def get_names(self) -> Set[str]:
+ return functools.reduce(set.union, [x.get_names() for x in self.snodes])
+
+ def debug_str_extra(self):
+ return (
+ f"{self.get_name()}.snodes = {pformat([x.get_name() for x in self.snodes])}"
+ )
+
+ @cache_on_self
+ def used_buffer_names(self) -> Set[str]:
+ return functools.reduce(set.union, [x.used_buffer_names() for x in self.snodes])
+
+ def get_nodes(self) -> List[BaseSchedulerNode]:
+ return self.snodes
+
+ def __repr__(self):
+ return f"{type(self).__name__}(nodes={self.get_name()})"
+
+ @cache_on_self
+ def is_reduction(self):
+ return any(x.is_reduction() for x in self.snodes)
+
+ @cache_on_self
+ def is_template(self):
+ return any(x.is_template() for x in self.snodes)
+
+ def get_device(self):
+ return self.group[0]
+
+ # None of these need to be implemented, as a FusedSchedulerNode is just an
+ # abstraction for scheduling purposes
+ def update_mutated_names(self, renames: Dict[str, str]):
+ raise NotImplementedError
+
+ def add_mutation_dep(self, name):
+ raise NotImplementedError
+
+ def set_users(self, users: List["NodeUser"]):
+ raise NotImplementedError
+
+ def get_aliases(self):
+ raise NotImplementedError
+
+ def get_mutations(self):
+ raise NotImplementedError
+
+ def can_inplace(self, read_dep: dependencies.MemoryDep):
+ raise NotImplementedError
+
+ def allocate(self):
+ raise NotImplementedError
+
+ def can_free(self):
+ raise NotImplementedError
+
+
+def pick_loop_order(stride_lengths, sizes, priority_idx=()):
+ """
+ A heuristic to decide loop iteration orders. This has not been well
+ tuned and may be something we should autotune.
+ """
+
+ @functools.cmp_to_key
+ def index_cmp(a, b):
+ if sizes[a] == 1 or sizes[b] == 1:
+ # 1-sizes don't matter, just move them to the end
+ return cmp(sizes[a] == 1, sizes[b] == 1)
+
+ a_first = np.logical_or(
+ stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]
+ ).all()
+ b_first = np.logical_or(
+ stride_lengths[:, a] == 0, stride_lengths[:, a] > stride_lengths[:, b]
+ ).all()
+
+ if a_first and not b_first:
+ return -1
+ if b_first and not a_first:
+ return 1
+
+ # otherwise contiguous
+ return cmp(b, a)
+
+ order = list(reversed(range(stride_lengths.shape[1])))
+ if len(priority_idx) > 0:
+ # if we have priority node, only use that node's order
+ stride_lengths = stride_lengths[priority_idx]
+ if config.pick_loop_orders:
+ order.sort(key=index_cmp)
+ return order
+
+
+@dataclasses.dataclass
+class NodeUser:
+ node: BaseSchedulerNode
+ can_inplace: bool = False
+
+ def get_name(self):
+ return self.node.get_name()
+
+
+class Scheduler:
+ @dynamo_utils.dynamo_timed
+ def __init__(self, nodes):
+ from .codegen.triton_template import should_use_template
+
+ super(Scheduler, self).__init__()
+ self.backends = {}
+
+ self.nodes = []
+ self.available_buffer_names = {
+ *V.graph.graph_inputs.keys(),
+ *V.graph.constants.keys(),
+ }
+ for node in nodes:
+ assert (
+ node.origins is not None
+ ), "All nodes passed to scheduling must have an origin"
+ if node.is_no_op():
+ self.nodes.append(NopKernelSchedulerNode(self, node))
+ elif isinstance(node, ir.ComputedBuffer):
+ group_fn = self.get_backend(node.get_device()).group_fn
+ self.nodes.append(SchedulerNode(self, node, group_fn))
+ elif isinstance(node, ir.ExternKernel) and should_use_template(node):
+ group_fn = self.get_backend(node.get_device()).group_fn
+ self.nodes.append(TemplateSchedulerNode(self, node, group_fn))
+ elif isinstance(node, ir.ExternKernel):
+ self.nodes.append(ExternKernelSchedulerNode(self, node))
+ else:
+ raise NotImplementedError(node)
+ # some new constants could have been created above
+ self.available_buffer_names.update(V.graph.constants.keys())
+ for node in self.nodes:
+ node.prune_deps()
+
+ self.name_to_node = {node.get_name(): node for node in self.nodes}
+ self.name_to_fused_node = None # set in fuse_nods()
+
+ # we handle mutation by renaming modified versions of the same
+ # buffer in the dependency graph to prevent cycles.
+ # mutation_renames: tracks the current name for a given buffer
+ # (changed once per mutation)
+ self.mutation_real_name = {}
+ # mutation_real_name: maps back to the original name for codegen
+ self.mutation_renames = {}
+
+ self.compute_dependencies()
+ self.topological_sort_schedule()
+ self.compute_predecessors()
+ self.dead_node_elimination()
+
+ V.debug.ir_pre_fusion(self.nodes)
+ self.num_orig_nodes = len(self.nodes)
+ self.name_to_fused_node = {n.get_name(): n for n in self.nodes}
+ self.fuse_nodes()
+ self.compute_last_usage()
+ V.debug.ir_post_fusion(self.nodes)
+ V.debug.graph_diagram(self.nodes)
+ self.debug_draw_graph()
+
+ # used during codegen:
+ self.current_device = None
+ self.buffer_names_to_free = set()
+ self.buffer_names_no_longer_needed = set()
+
+ def debug_draw_graph(self):
+ """Generate an image of the graph for debugging"""
+ if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1":
+ from .debug import draw_buffers
+
+ draw_buffers(self.nodes, print_graph=True)
+
+ def debug_print_nodes(self, label):
+ if log.isEnabledFor(logging.INFO):
+ log.info("%s:", label)
+ for node in self.nodes:
+ node.log_details()
+
+ def compute_dependencies(self):
+ """
+ Create dependency edges between nodes, handling aliasing and
+ mutation properly.
+ """
+ name_to_users = collections.defaultdict(list)
+
+ # handle aliasing by using python aliasing in name_to_users
+ # if foo aliases bar then we will make name_to_users["foo"] point
+ # to the same python list as name_to_users["bar"]
+ for node1 in self.nodes:
+ node1_name = node1.get_name()
+ for node2_name in node1.get_aliases():
+ if node1_name in name_to_users and node2_name in name_to_users:
+ # merge the two
+ list1 = name_to_users[node1_name]
+ list2 = name_to_users[node2_name]
+ combined = list1 + list2
+ for key in name_to_users.keys():
+ if name_to_users[key] is list1 or name_to_users[key] is list2:
+ name_to_users[key] = combined
+ elif node1_name in name_to_users:
+ name_to_users[node2_name] = name_to_users[node1_name]
+ else:
+ name_to_users[node1_name] = name_to_users[node2_name]
+
+ def rename(n):
+ if n in self.mutation_renames:
+ return rename(self.mutation_renames[n])
+ return n
+
+ def dep_closure(node_name):
+ reachable_names = {node_name}
+ node = self.name_to_node[node_name]
+ write_dep = list(node.read_writes.writes)[0]
+ for read_dep in node.read_writes.reads:
+ if (
+ read_dep.name in self.name_to_node
+ and read_dep.index == write_dep.index
+ and read_dep.size == write_dep.size
+ ):
+ reachable_names.update(dep_closure(read_dep.name))
+ return reachable_names
+
+ def add_user(used_by_name, user_node, can_inplace=False):
+ name_to_users[rename(used_by_name)].append(NodeUser(user_node, can_inplace))
+
+ for node in self.nodes:
+ # a node will mutate either 0 or 1 buffers
+ for alt_name in node.get_mutations():
+ alt_name = rename(alt_name)
+ # this node must run after the prior writer
+ add_user(alt_name, node)
+ node.add_mutation_dep(alt_name)
+ for other_node in name_to_users[alt_name]:
+ # this node must run after all prior readers
+ other_name = rename(other_node.get_name())
+ known_dep_node_names = dep_closure(node.get_name())
+ if other_name not in known_dep_node_names:
+ # If this node alreay directly or indirectly depends on other_node,
+ # we don't need to insert an extra StarDep.
+ node.add_mutation_dep(other_name)
+ add_user(other_name, node)
+
+ # add normal non-mutation dependencies
+ for read in node.read_writes.reads:
+ add_user(read.name, node, node.can_inplace(read))
+
+ node.update_mutated_names(self.mutation_renames)
+
+ # update our renaming scheme for the next iteration
+ for alt_name in node.get_mutations():
+ self.mutation_renames[rename(alt_name)] = node.get_name()
+ self.mutation_renames[alt_name] = node.get_name()
+ self.mutation_real_name[node.get_name()] = self.mutation_real_name.get(
+ alt_name, alt_name
+ )
+
+ # make sure outputs aren't dead-code-eliminated
+ for node_name in V.graph.get_output_names():
+ add_user(node_name, OutputNode(StarDep(node_name)))
+
+ # make sure input mutation isn't dead-code-eliminated
+ for name in self.mutation_renames:
+ if name in V.graph.graph_inputs:
+ add_user(name, OutputNode(StarDep(name)))
+ V.graph.mutated_inputs.add(name)
+
+ # copy users information onto the nodes
+ for node in self.nodes:
+ node.set_users(name_to_users[node.get_name()])
+
+ # populate inverse_users
+ for node in self.nodes:
+ for user in node.users:
+ user.node.inverse_users.append(node)
+
+ def dead_node_elimination(self):
+ """
+ Remove any nodes without users
+ """
+ updated_nodes = []
+ for node in self.nodes:
+ if node.users:
+ updated_nodes.append(node)
+ else:
+ # dead code
+ log.debug("removed dead node: %s", node.get_name())
+ V.graph.removed_buffers.add(node.get_name())
+ self.nodes = updated_nodes
+
+ def topological_sort_schedule(self):
+ """
+ Ensure self.nodes is in topologically sorted order
+ """
+ seen = set()
+ name_to_node = dict()
+ result = []
+
+ def visit(n):
+ if n not in seen:
+ seen.add(n)
+ for dep in sorted(n.unmet_dependencies, key=lambda d: d.name):
+ visit(name_to_node[dep.name])
+ result.append(n)
+
+ for node in self.nodes:
+ for name in node.get_names():
+ name_to_node[name] = node
+ for node in self.nodes:
+ visit(node)
+ self.nodes = result
+
+ def compute_predecessors(self):
+ """
+ Populate each node.recursive_predecessors
+ """
+ # note self.nodes is topologically sorted
+ name_to_predecessors = {}
+ for node in self.nodes:
+ recursive_predecessors = set()
+ for dep in node.unmet_dependencies:
+ recursive_predecessors.add(dep.name)
+ recursive_predecessors |= name_to_predecessors[dep.name]
+ name_to_predecessors[node.get_name()] = recursive_predecessors
+ node.recursive_predecessors = recursive_predecessors
+
+ for order, node in enumerate(self.nodes):
+ node.min_order = order
+ node.max_order = order
+
+ def fuse_nodes(self):
+ """
+ Mutates self.nodes to combine nodes into FusedSchedulerNodes.
+ """
+ for _ in range(10):
+ old_len = len(self.nodes)
+ self.fuse_nodes_once()
+ if len(self.nodes) == old_len:
+ break
+
+ def fuse_nodes_once(self):
+ """
+ Mutates self.nodes to combine nodes into FusedSchedulerNodes.
+
+ This relies on two key functions to control the logic:
+ - self.can_fuses(): checks if a fusion is legal
+ - self.score_fusion(): assigns priority to a given fusion
+ """
+ fused_nodes = set(self.nodes)
+ for node1, node2 in self.get_possible_fusions():
+ node1 = self.name_to_fused_node[node1.get_first_name()]
+ node2 = self.name_to_fused_node[node2.get_first_name()]
+ if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle(
+ node1, node2
+ ):
+ node3 = FusedSchedulerNode.fuse(node1, node2)
+ fused_nodes.remove(node1)
+ fused_nodes.remove(node2)
+ fused_nodes.add(node3)
+ self.name_to_fused_node.update(
+ {n.get_name(): node3 for n in node3.get_nodes()}
+ )
+ self.nodes = sorted(fused_nodes, key=lambda x: x.min_order)
+ self.topological_sort_schedule()
+
+ def get_possible_fusions(self):
+ """
+ Helper to find all legal fusion opportunities, sorted by self.score_fusion()
+ """
+ possible_fusions = []
+ seen = set()
+
+ def check_all_pairs(nodes):
+ for node1_index, node1 in enumerate(nodes):
+ for node2 in nodes[node1_index + 1 :]:
+ key = (node1, node2)
+ if key in seen:
+ continue
+ seen.add(key)
+
+ if self.can_fuse(node1, node2):
+ possible_fusions.append(key)
+ elif node2.is_template() and self.can_fuse(node2, node1):
+ # epilogue fusions are order dependent
+ possible_fusions.append((node2, node1))
+
+ buffer_names_grouping = collections.defaultdict(list)
+ for node in self.nodes:
+ for buf in node.used_buffer_names():
+ buffer_names_grouping[buf].append(node)
+ for node_grouping in buffer_names_grouping.values():
+ check_all_pairs(node_grouping)
+
+ if config.aggressive_fusion:
+ group_grouping = collections.defaultdict(list)
+ for node in self.nodes:
+ group = getattr(node, "group", None)
+ if group:
+ group_grouping[group].append(node)
+ for node_grouping in group_grouping.values():
+ check_all_pairs(node_grouping)
+
+ return sorted(possible_fusions, key=self.score_fusion_key, reverse=True)
+
+ def will_fusion_create_cycle(self, node1, node2):
+ """Finds whether there's a path from src to dst caused indirectly by fusion"""
+
+ def check(node):
+ if isinstance(node, FusedSchedulerNode) and node not in visited:
+ visited.add(node)
+ return bool(combined_names & node.recursive_predecessors) or any(
+ check(self.name_to_fused_node[n])
+ for n in node.recursive_predecessors - combined_predecessors
+ )
+ return False
+
+ visited = set()
+ combined_names = node1.get_names() | node2.get_names()
+ combined_predecessors = (
+ node1.recursive_predecessors | node2.recursive_predecessors
+ ) - combined_names
+ return any(check(self.name_to_fused_node[n]) for n in combined_predecessors)
+
+ def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
+ """
+ Determine if it is possible to combine node1 and node2 into a
+ single fused node.
+ """
+ if node1 is node2:
+ return False
+ if (
+ isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
+ and not node1.is_template()
+ ):
+ return False
+ if (
+ isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
+ and not node2.is_template()
+ ):
+ return False
+ if node2.get_names() & node1.recursive_predecessors:
+ return False # node2 must go before node1
+ if node2.is_template():
+ return False # only epilogues
+
+ device = node1.get_device()
+ if device != node2.get_device():
+ return False # wrong device
+
+ no_shared_data = self.score_fusion_memory(node1, node2) == 0
+ if no_shared_data and (
+ not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction()
+ ):
+ return False # heuristic not needed for correctness
+
+ if len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size:
+ return False # heuristic not needed for correctness
+
+ if node1.get_names() & node2.recursive_predecessors:
+ # node2 depends on node1 outputs
+ if not self.can_fuse_vertical(node1, node2):
+ return False
+ if node1.is_template():
+ from .codegen.triton_template import template_can_fuse
+
+ return template_can_fuse(node1, node2)
+ return self.get_backend(device).can_fuse_vertical(node1, node2)
+ else: # nodes don't depend on each other, but may have common reads
+ if node1.is_template():
+ return False
+ return self.get_backend(device).can_fuse_horizontal(node1, node2)
+
+ def can_fuse_vertical(self, node1, node2):
+ """
+ Check if it is legal to fuse a consumer (node2) into a producer (node1).
+
+ We can fuse them if all the reads of node2 either match
+ corresponding writes in node1, or are written by nodes that can
+ be scheduled before the fusion of node1 and node2.
+ """
+ node1_names = node1.get_names()
+ remaining_deps = {
+ dep.name for dep in node2.unmet_dependencies - node1.read_writes.writes
+ }
+ if remaining_deps & node1_names:
+ # MemoryDeps didn't match and read different locations of the same buffer.
+ # Examples here include:
+ # - MemoryDep("foo", x) != MemoryDep("foo", x + 1)
+ # - MemoryDep("foo", x) != StarDep("foo")
+ return False
+ for name in remaining_deps:
+ if node1_names & self.name_to_fused_node[name].recursive_predecessors:
+ return False
+ return True
+
+ def score_fusion(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
+ """
+ Assign a score (higher comes first) to the fusion of node1
+ and node2. When different fusions conflict with each other,
+ this is the way we decide what order to run them in.
+
+ Our current score is based on:
+ - Estimate of the saved memory operations
+ - Fusions closer together in original order
+ """
+ memory_score = self.score_fusion_memory(node1, node2)
+ proximity_score = -max(
+ abs(node1.min_order - node2.max_order),
+ abs(node2.min_order - node1.max_order),
+ )
+ return (
+ node1.is_reduction() == node2.is_reduction() and memory_score > 0,
+ memory_score,
+ proximity_score,
+ )
+
+ def score_fusion_memory(self, node1, node2):
+ """
+ The first term in our fusion score that estimates number of saved memory operations.
+ """
+ common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & (
+ node2.read_writes.reads | node2.read_writes.writes
+ )
+ return sum(dep.numel_hint() for dep in common_memory_deps)
+
+ def score_fusion_key(self, nodes):
+ """
+ Shim for list.sort(key=...)
+ """
+ node1, node2 = nodes
+ return self.score_fusion(node1, node2)
+
+ def compute_last_usage(self):
+ """
+ Populate node.last_usage
+ """
+
+ future_used_buffers = set()
+ for node_name in V.graph.get_output_names():
+ future_used_buffers.add(node_name)
+
+ for node in reversed(self.nodes):
+ used_buffers = node.used_buffer_names()
+ used_buffers = {self.mutation_real_name.get(k, k) for k in used_buffers}
+ node.last_usage = used_buffers - future_used_buffers
+ future_used_buffers.update(used_buffers)
+
+ def free_buffers(self):
+ """Free any buffers that are no longer needed"""
+ for name in sorted(self.buffer_names_to_free - V.graph.removed_buffers):
+ if name in self.name_to_node:
+ node = self.name_to_node[name]
+ if node.can_free():
+ V.graph.wrapper_code.codegen_free(node.node)
+ self.buffer_names_to_free.clear()
+
+ def remove_kernel_local_buffers(self):
+ """
+ Any buffers that are both created and have a last use in the
+ same kernel can be removed.
+ """
+ for name in V.kernel.store_buffer_names & self.buffer_names_no_longer_needed:
+ if (
+ name not in V.kernel.must_keep_buffers
+ and name not in V.kernel.args.input_buffers
+ and name not in self.mutation_renames
+ and name not in self.mutation_real_name
+ ):
+ self.remove_buffer(name)
+
+ def remove_buffer(self, name):
+ # Assign a special value instead of deleting the entry
+ # because we still rely on output_buffers's length to
+ # generate unique arg name.
+ log.debug("remove_buffer(%r)", name)
+ V.kernel.args.output_buffers[name] = "REMOVED"
+ V.graph.removed_buffers.add(name)
+
+ def flush(self):
+ for backend in self.backends.values():
+ backend.flush()
+ self.free_buffers()
+
+ def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode):
+ assert isinstance(scheduler_node, ExternKernelSchedulerNode)
+ scheduler_node.allocate()
+ node = scheduler_node.node
+ node.codegen(V.graph.wrapper_code)
+ self.free_buffers()
+
+ def codegen_template_call(
+ self, scheduler_node: Union[FusedSchedulerNode, TemplateSchedulerNode]
+ ):
+ from .codegen.triton_template import template_codegen
+
+ node, *epilogue = scheduler_node.get_nodes()
+ node.allocate()
+ template_codegen(self, node, epilogue)
+ self.free_buffers()
+
+ def create_backend(self, device: torch.device):
+ assert (
+ device.type != "cuda" or device.index is not None
+ ), f"{device} should have been normalized in lowering"
+ V.graph.device_types.add(device.type)
+ if device.type == "cpu":
+ from .codegen.cpp import CppScheduling
+
+ return CppScheduling(self)
+ else:
+ from .codegen.triton import TritonScheduling
+
+ return TritonScheduling(self)
+
+ def get_backend(self, device: torch.device):
+ if device not in self.backends:
+ self.backends[device] = self.create_backend(device)
+ return self.backends[device]
+
+ @dynamo_utils.dynamo_timed
+ def codegen(self):
+ for node in self.nodes:
+ self.buffer_names_no_longer_needed.update(node.last_usage)
+
+ if not isinstance(node, NopKernelSchedulerNode):
+ device = node.get_device()
+ if (
+ device != self.current_device
+ or node.is_extern()
+ or node.is_template()
+ ):
+ self.flush()
+ self.current_device = device
+
+ self.buffer_names_to_free.update(node.last_usage)
+
+ if node.is_template():
+ self.codegen_template_call(node)
+ elif node.is_extern():
+ self.codegen_extern_call(node)
+ elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
+ self.get_backend(device).codegen_nodes(node.get_nodes())
+ else:
+ assert isinstance(node, NopKernelSchedulerNode)
+ node.allocate()
+
+ self.flush()
diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py
new file mode 100644
index 0000000000000..8c7c74e17c964
--- /dev/null
+++ b/torch/_inductor/sizevars.py
@@ -0,0 +1,591 @@
+import collections
+import dataclasses
+import functools
+import itertools
+import logging
+from typing import Callable, Dict, List, Tuple
+
+import sympy
+from sympy import Expr, Integer, Symbol
+
+from . import ir
+from .codegen.common import IndentedBuffer
+from .utils import sympy_subs, VarRanges
+from .virtualized import V
+
+log = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class ZeroGuard:
+ """
+ An expression we should check equals zero.
+ Guards are currently not checked. Plan to add this later.
+ """
+
+ expr: Expr
+
+
+@dataclasses.dataclass
+class PositiveGuard:
+ """
+ An expression we should check for > 0
+ Guards are currently not checked. Plan to add this later.
+ """
+
+ expr: Expr
+
+
+class SizeVarAllocator(object):
+ def __init__(self, prefix="s", zero_one_const=True):
+ super().__init__()
+ self.prefix = prefix
+ self.val_to_var: Dict[int, Expr] = {0: Integer(0), 1: Integer(1)}
+ self.var_to_val: Dict[Expr, int] = collections.OrderedDict()
+ self.guards = []
+ self.replacements: Dict[sympy.Symbol, Expr] = {}
+ self.need_seed = False
+ self.stride_vars = self.make_stride_vars_cache()
+ if not zero_one_const:
+ self.val_to_var.clear()
+ self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
+ self._simplify_loops = self.make_simplify_loops_cache()
+
+ def seed(self):
+ """
+ Seed is a special variable used to hold the rng seed for a graph.
+
+ Note this is only used by the CPU backend, we put seeds in a
+ 1-element tensor for the CUDA backend.
+ """
+ self.need_seed = True
+ return sympy.Symbol("seed")
+
+ def simplify(self, expr: Expr):
+ return sympy.expand(expr).xreplace(self.replacements)
+
+ def make_simplify_with_ranges_cache(self):
+ """
+ self._simplify_with_ranges() can be expensive, cache its results
+ """
+ cache = dict()
+ replacement_count = len(self.replacements)
+
+ def simplify_with_ranges(expr: Expr, var_ranges: VarRanges):
+ nonlocal replacement_count
+ if replacement_count != len(self.replacements):
+ # new replacements invalidates cached results
+ cache.clear()
+ replacement_count = len(self.replacements)
+ key = (expr, *var_ranges.items())
+ result = cache.get(key, None)
+ if result is None:
+ result = self._simplify_with_ranges(expr, var_ranges)
+ cache[key] = result
+ return result
+
+ return simplify_with_ranges
+
+ def make_simplify_loops_cache(self):
+ """
+ self._simplify_with_ranges() can be expensive, cache its results
+ """
+ cache = dict()
+ replacement_count = len(self.replacements)
+
+ def simplify_loops(index_vars, sizes, index_formulas):
+ nonlocal replacement_count
+ if replacement_count != len(self.replacements):
+ # new replacements invalidates cached results
+ cache.clear()
+ replacement_count = len(self.replacements)
+ key = (*index_vars, *sizes, *index_formulas)
+ result = cache.get(key, None)
+ if result is None:
+ result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
+ cache[key] = result
+ return result
+
+ return simplify_loops
+
+ def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges):
+ """
+ Simplify indexing expression with knowledge of the ranges of
+ iteration variables.
+ """
+ from .ir import IndexingDiv, ModularIndexing
+
+ expr = join_dimensions(self.simplify(expr))
+ original_expr = expr
+
+ def remove_zero_terms(base, divisor):
+ """Symbols smaller than the divisor are zero"""
+ for v in base.free_symbols:
+ if v in var_ranges:
+ # var smaller than divisor can be removed
+ # if the rest is guaranteed to be multiple of divisor
+ rest = sympy.Wild("_rest", exclude=[v])
+ m = base.match(v + rest)
+ if m and v not in m[rest].free_symbols:
+ gcd = sympy.gcd(m[rest], divisor)
+ if gcd == divisor:
+ if self.maybe_guard_leq(var_ranges[v], divisor):
+ base = m[rest]
+ return base
+
+ def visit_indexing_div(base, divisor):
+ return IndexingDiv(remove_zero_terms(base, divisor), divisor)
+
+ def visit_modular_indexing(base, divisor, modulus):
+ base = remove_zero_terms(base, divisor)
+ if isinstance(base, ModularIndexing):
+ # for modular indexing, biggest values from the ranges don't necessarily result in
+ # the biggest result, the biggest result is modulus - 1
+ base_s = base.args[2] - 1
+ elif not base.has(ModularIndexing):
+ # actual iteration range is to size-1
+ iter_ranges = {k: v - 1 for k, v in var_ranges.items()}
+ base_s = sympy_subs(base, iter_ranges)
+ else:
+ base_s = base
+ if self.maybe_guard_lt(base_s, modulus * divisor):
+ return IndexingDiv(base, divisor)
+ return ModularIndexing(base, divisor, modulus)
+
+ if expr.has(ModularIndexing):
+ expr = expr.replace(
+ ModularIndexing(
+ sympy.Wild("base"),
+ sympy.Wild("divisor"),
+ sympy.Wild("modulus"),
+ ),
+ visit_modular_indexing,
+ )
+
+ if expr.has(IndexingDiv):
+ expr = expr.replace(
+ IndexingDiv(
+ sympy.Wild("base"),
+ sympy.Wild("divisor"),
+ ),
+ visit_indexing_div,
+ )
+
+ if expr != original_expr:
+ return self._simplify_with_ranges(expr, var_ranges)
+ return expr
+
+ def _simplify_loops_impl(self, index_vars, sizes, index_formulas):
+ """
+ Try to remove as many axis from loop iterations as possible, by:
+ 1) removing size==1 dimensions
+ 2) fuse contiguous dimensions into a single loop
+ If channel_last = True, we will prevent the last dim fused with other dims
+ """
+ sizes = list(map(self.simplify, sizes))
+
+ strides = [self.stride_vars(x, index_vars) for x in index_formulas]
+ assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))
+
+ for i in range(len(sizes)):
+ if sizes[i] == 1:
+ # remove dim
+ sizes[i] = None
+
+ def can_merge_dims(a, b):
+ for k in range(len(strides)):
+ if self.simplify(strides[k][a] * sizes[a]) == self.simplify(
+ strides[k][b]
+ ):
+ # approximate test passed, try sound version
+ va = index_vars[a]
+ vb = index_vars[b]
+ v = sympy.Symbol("_merge_tester")
+ expr1 = sympy_subs(index_formulas[k], {va: v * sizes[a], vb: 0})
+ expr2 = sympy_subs(index_formulas[k], {va: 0, vb: v})
+ if self.simplify(expr1) == self.simplify(expr2):
+ continue
+ return False
+ return True
+
+ changed = True
+ while changed:
+ changed = False
+ for i, j in itertools.product(
+ reversed(range(len(sizes))), reversed(range(len(sizes)))
+ ):
+ if i == j or sizes[i] is None or sizes[j] is None:
+ continue
+ if can_merge_dims(i, j):
+ changed = True
+ sizes[i] = sizes[i] * sizes[j]
+ sizes[j] = None
+
+ def reindex(index):
+ it = list(reversed(index))
+ new_index = []
+ for size in sizes:
+ if size is None:
+ new_index.append(sympy.Integer(0))
+ else:
+ new_index.append(it.pop())
+ assert not it
+ return new_index
+
+ def prune(index):
+ assert len(index) == len(sizes)
+ return [i for i, s in zip(index, sizes) if s is not None]
+
+ return [x for x in sizes if x is not None], reindex, prune
+
+ def guard_equals(self, left: Expr, right: Expr) -> Expr:
+ left = sympy.expand(left)
+ right = sympy.expand(right)
+ if left == right:
+ return left
+ expr = self.simplify(left - right)
+ assert self.size_hint(expr) == 0, (expr, self.size_hint(expr))
+ free = list(expr.free_symbols)
+ if len(free) == 0:
+ assert expr == 0
+ return left
+ elif len(free) in (1, 2, 3):
+ # remove the largest of the guarded variables
+ free.sort(key=self.size_hint)
+ try:
+ solutions = sympy.solve(expr, free[-1])
+ if (
+ len(solutions) == 1
+ and solutions[0]
+ and "/" not in str(solutions[0])
+ ):
+ self.replacements[free[-1]] = solutions[0]
+ except NotImplementedError:
+ pass
+
+ self.guards.append(ZeroGuard(expr))
+
+ if len(right.free_symbols) < len(left.free_symbols):
+ return right
+ else:
+ return left
+
+ def maybe_guard_equals(self, left: Expr, right: Expr) -> bool:
+ """if left==right, guard on that fact and return true"""
+ if left == right:
+ return True
+ if self.size_hint(left - right) == 0:
+ self.guard_equals(left, right)
+ return True
+ return False
+
+ def maybe_guard_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
+ """if left==right, guard on that fact and return true"""
+ if len(left) != len(right):
+ return False
+ if all(self.size_hint(a - b) == 0 for a, b in zip(left, right)):
+ for a, b in zip(left, right):
+ self.guard_equals(a, b)
+ return True
+ return False
+
+ def maybe_guard_leq(self, left: Expr, right: Expr) -> bool:
+ try:
+ if self.size_hint(left) > self.size_hint(right):
+ return False
+ except TypeError:
+ return False
+ self.guard_leq(left, right)
+ return True
+
+ def maybe_guard_lt(self, left: Expr, right: Expr) -> bool:
+ try:
+ if self.size_hint(left) >= self.size_hint(right):
+ return False
+ except TypeError:
+ return False
+ self.guard_lt(left, right)
+ return True
+
+ def guard_leq(self, left: Expr, right: Expr) -> None:
+ return self.guard_lt(left, right + 1)
+
+ def guard_lt(self, left: Expr, right: Expr) -> None:
+ expr = self.simplify(right - left)
+ assert self.size_hint(expr) > 0
+ if len(expr.free_symbols) == 0:
+ return
+ if "-" in str(expr):
+ # all vars are positive, so needs a minus sign to get negative values
+ self.guards.append(PositiveGuard(expr))
+
+ def guard_min(self, left: Expr, right: Expr) -> Expr:
+ """return the smaller of left and right, and guard on that choice"""
+ lv = self.size_hint(left)
+ rv = self.size_hint(right)
+ if lv == rv:
+ return self.guard_equals(left, right)
+ elif lv < rv:
+ self.guard_lt(left, right)
+ return left
+ else:
+ self.guard_lt(right, left)
+ return right
+
+ def guard_max(self, left: Expr, right: Expr) -> Expr:
+ """return the larger of left and right, and guard on that choice"""
+ return -self.guard_min(-left, -right)
+
+ def maybe_guard_multiple_of(self, numerator: Expr, denominator: Expr) -> bool:
+ """if denominator divides numerator, return True and guard on that fact"""
+ if sympy.gcd(numerator, denominator) == denominator:
+ # can prove it symbolically
+ return True
+ if self.size_hint(numerator) % self.size_hint(denominator) == 0:
+ multiple = self.size_hint(numerator) // self.size_hint(denominator)
+ self.guard_equals(multiple * denominator, numerator)
+ return True
+ return False
+
+ def guard_static_shape(self, left: Expr) -> int:
+ right = self.size_hint(left)
+ self.guard_equals(left, sympy.Integer(right))
+ return int(right)
+
+ def __getitem__(self, val: int) -> Expr:
+ if val < 0:
+ # all variables are positive
+ return -self[-val]
+ if val in self.val_to_var:
+ return self.val_to_var[val]
+ var = Symbol(f"{self.prefix}{len(self.var_to_val)}")
+ self.val_to_var[val] = var
+ self.var_to_val[var] = val
+ return var
+
+ def size_hint(self, expr: Expr) -> int:
+ return int(sympy_subs(sympy.expand(expr), self.var_to_val))
+
+ def _lru_cache(self, fn, maxsize=None):
+ """
+ Wrapper around functools.lru_cache that clears when replacements
+ has been invalidated.
+ """
+ fn_cache = functools.lru_cache(maxsize)(fn)
+ prior_len = len(self.replacements)
+
+ @functools.wraps(fn)
+ def wrapper(*args, **kwargs):
+ nonlocal prior_len
+ if prior_len != len(self.replacements):
+ prior_len = len(self.replacements)
+ fn_cache.cache_clear()
+ return fn_cache(*args, **kwargs)
+
+ return wrapper
+
+ def make_stride_vars_cache(self):
+ cache = self._lru_cache(self._stride_vars)
+
+ def stride_vars(index: Expr, vars: List[sympy.Symbol]) -> List[Expr]:
+ return cache(index, tuple(vars))
+
+ return stride_vars
+
+ def _stride_vars(self, index: Expr, vars: List[sympy.Symbol]) -> List[Expr]:
+ """Convert an indexing expression back into strides"""
+ strides = []
+ index = self.simplify(index)
+ # remove any offset
+ index = index - sympy_subs(index, {v: sympy.Integer(0) for v in vars if v != 0})
+ for i in range(len(vars)):
+ # drop all the other dims
+ index_dim = sympy_subs(
+ index,
+ {
+ vars[j]: sympy.Integer(0)
+ for j in range(len(vars))
+ if i != j and vars[j] != 0
+ },
+ )
+ v = vars[i]
+ if v == 0:
+ strides.append(sympy.Integer(0))
+ else:
+ # TODO(jansel): should we use sympy.diff here?
+ strides.append(
+ sympy_subs(index_dim, {v: sympy.Integer(1)})
+ - sympy_subs(index_dim, {v: sympy.Integer(0)})
+ )
+ return strides
+
+ def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr:
+ """Extract offset part of an indexing expression"""
+ index = self.simplify(index)
+ return sympy_subs(index, {v: sympy.Integer(0) for v in vars if v != 0})
+
+ def stride_hints(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]:
+ for v in index.free_symbols:
+ if v.name.startswith("indirect"):
+ index = sympy_subs(index, {v: 0})
+ result = []
+ for s in self.stride_vars(index, vars):
+ try:
+ result.append(self.size_hint(s))
+ except TypeError:
+ result.append(0)
+ return result
+
+ def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]:
+ strides = tuple(
+ map(lambda x: abs(x), self.stride_hints(index, vars))
+ ) # lambda to placate mypy
+ order = list(range(len(strides)))
+ order.sort(key=lambda x: (strides[x] == 0, strides[x]))
+ return order
+
+ def codegen(self, code: IndentedBuffer, graph_inputs: Dict[str, ir.Buffer]):
+ """Assign all symbolic shapes to locals"""
+ if self.need_seed:
+ code.writeline(
+ "seed = torch.randint(2**31, size=(), dtype=torch.int32).item()"
+ )
+
+ @functools.lru_cache(None)
+ def sizeof(name):
+ code.writeline(f"{name}_size = {name}.size()")
+ return f"{name}_size"
+
+ @functools.lru_cache(None)
+ def strideof(name):
+ code.writeline(f"{name}_stride = {name}.stride()")
+ return f"{name}_stride"
+
+ # TODO: This should be the below, but causes test/test_torchinductor.py::GpuTests::test_triton_conv_cuda to fail
+ # needed_vars = set(self.var_to_val.keys()) - set(self.replacements.keys())
+
+ needed_vars = set(self.var_to_val.keys())
+ needed = set(map(str, needed_vars))
+
+ for name, value in graph_inputs.items():
+ shapes = value.get_size()
+ for dim, shape in enumerate(shapes):
+ shape = str(shape)
+ if shape in needed:
+ needed.remove(shape)
+ code.writeline(f"{shape} = {sizeof(name)}[{dim}]")
+
+ for name, value in graph_inputs.items():
+ shapes = value.get_stride()
+ for dim, shape in enumerate(shapes):
+ shape = str(shape)
+ if shape in needed:
+ needed.remove(shape)
+ code.writeline(f"{shape} = {strideof(name)}[{dim}]")
+
+ assert not needed
+
+ def codegen_sizevar(self, x: Expr) -> str:
+ from .codegen.wrapper import pexpr
+
+ return pexpr(self.simplify(x))
+
+ def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
+ parts = list(map(self.codegen_sizevar, shape))
+ if len(parts) == 0:
+ return "()"
+ if len(parts) == 1:
+ return f"({parts[0]}, )"
+ return f"({', '.join(parts)})"
+
+
+def join_dimensions(expr: Expr) -> Expr:
+ from .ir import ModularIndexing
+
+ if not isinstance(expr, sympy.Add) or not expr.has(ModularIndexing):
+ return expr # fast exit path
+ return _join_dimensions_cached(expr)
+
+
+@functools.lru_cache(256)
+def _join_dimensions_cached(expr: Expr) -> Expr:
+ """
+ ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
+ becomes
+ ModularIndexing(i0, 1, 128)
+ ModularIndexing(i0, 1, 32) + 32 * IndexingDiv(i0, 32)
+ becomes i0
+
+
+ This type of pattern can come from view operations
+ """
+ from .ir import IndexingDiv, ModularIndexing
+
+ assert isinstance(expr, sympy.Add)
+
+ scale = sympy.Wild("scale", exclude=[0])
+ base = sympy.Wild("base")
+ divisor = sympy.Wild("divisor")
+ mod1 = sympy.Wild("modulus")
+ mod2 = sympy.Wild("modulus2")
+ for term1 in expr.args:
+ m1 = term1.match(scale * ModularIndexing(base, divisor, mod1))
+ if m1:
+ for term2 in expr.args:
+ m2 = term2.match(
+ m1[scale]
+ * m1[mod1]
+ * ModularIndexing(m1[base], m1[divisor] * m1[mod1], mod2)
+ )
+ if m2 and term1 != term2:
+ expr = join_dimensions(
+ expr
+ - term1
+ - term2
+ + m1[scale]
+ * ModularIndexing(m1[base], m1[divisor], m1[mod1] * m2[mod2])
+ )
+ return expr
+ for term1 in expr.args:
+ m1 = term1.match(scale * ModularIndexing(base, divisor, mod1))
+ if m1:
+ for term2 in expr.args:
+ m2 = term2.match(
+ m1[scale] * m1[mod1] * IndexingDiv(m1[base], m1[divisor] * m1[mod1])
+ )
+ if m2 is not None: # in case of success we get an empty dict here
+ expr = join_dimensions(
+ expr
+ - term1
+ - term2
+ + m1[scale] * IndexingDiv(m1[base], m1[divisor])
+ )
+ return expr
+ return expr
+
+
+class SimplifyIndexing(V.WrapperHandler): # type: ignore[name-defined]
+ """
+ A wrapper around .virtualize.ops that uses var range information to
+ simplify ir.ModularIndexing/ir.IndexingDiv.
+ """
+
+ def __init__(self, inner, var_ranges: VarRanges):
+ super().__init__(inner)
+ self._simplify: Callable[
+ [Expr], Expr
+ ] = lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges)
+
+ def load(self, name: str, index: sympy.Expr):
+ return self._inner.load(name, self._simplify(index))
+
+ def store(self, name, index, value, mode=None):
+ return self._inner.store(name, self._simplify(index), value, mode=mode)
+
+ def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
+ return self._inner.reduction(
+ name, dtype, src_dtype, reduction_type, self._simplify(index), value
+ )
+
+ def index_expr(self, index, dtype):
+ return self._inner.index_expr(self._simplify(index), dtype)
diff --git a/torch/_inductor/triton_ops/__init__.py b/torch/_inductor/triton_ops/__init__.py
new file mode 100644
index 0000000000000..b3f6ecc3ff429
--- /dev/null
+++ b/torch/_inductor/triton_ops/__init__.py
@@ -0,0 +1,8 @@
+from ..utils import has_triton
+
+if has_triton():
+ from .conv import _conv, conv
+ from .conv1x1 import _conv1x1, conv1x1
+ from .matmul import _matmul_out, matmul_out
+
+ __all__ = ["_conv", "conv", "_conv1x1", "conv1x1", "_matmul_out", "matmul_out"]
diff --git a/torch/_inductor/triton_ops/autotune.py b/torch/_inductor/triton_ops/autotune.py
new file mode 100644
index 0000000000000..f6d05cf2f8cfd
--- /dev/null
+++ b/torch/_inductor/triton_ops/autotune.py
@@ -0,0 +1,673 @@
+import builtins
+import copy
+import hashlib
+import json
+import logging
+import multiprocessing
+import os.path
+import threading
+from typing import List
+
+import torch
+
+from .. import config
+from ..codecache import AsyncCompile
+from ..ir import ReductionHint
+from ..triton_ops.mm_perf_model import estimate_matmul_time
+from ..utils import conditional_product, has_triton
+from .conv_perf_model import (
+ early_config_prune as conv_early_config_prune,
+ estimate_conv_time,
+)
+
+log = logging.getLogger(__name__)
+
+if has_triton():
+ import triton
+ from triton import cdiv, Config, next_power_of_2
+ from triton.runtime.jit import get_cuda_stream, KernelInterface
+else:
+ cdiv = None
+ Config = object
+ get_cuda_stream = None
+ KernelInterface = object
+ next_power_of_2 = None
+ triton = None
+
+
+class CachingAutotuner(KernelInterface):
+ """
+ Simplified version of Triton autotuner that has no invalidation
+ key and caches the best config to disk to improve cold start times.
+ Unlike the main triton Autotuner, this version can precompile all
+ configs, and does not rely on the Triton JIT.
+ """
+
+ def __init__(self, fn, meta, configs, save_cache_hook):
+ super().__init__()
+ self.fn = fn
+ self.meta = meta
+ self.save_cache_hook = save_cache_hook
+ self.configs = configs
+ self.launchers = []
+ self.lock = threading.Lock()
+
+ def precompile(self):
+ with self.lock:
+ if self.launchers:
+ return
+ self.launchers = AsyncCompile.map(self._precompile_config, self.configs)
+ self.configs = None
+
+ def _precompile_config(self, cfg: Config):
+ """Ahead of time compile a given autotuner config."""
+ torch.cuda.set_device(torch.cuda.current_device())
+ compile_meta = copy.deepcopy(self.meta)
+ for k, v in cfg.kwargs.items():
+ compile_meta["constants"][self.fn.arg_names.index(k)] = v
+ compile_meta["num_warps"] = cfg.num_warps
+ compile_meta["num_stages"] = cfg.num_stages
+
+ if config.compile_threads > 1:
+ major, minor = torch.cuda.get_device_capability(compile_meta["device"])
+ compile_meta["cc"] = major * 10 + minor
+ try:
+ p = multiprocessing.Process(
+ target=triton.compile,
+ args=(self.fn,),
+ kwargs={**compile_meta, "warm_cache_only": True},
+ )
+ p.start()
+ p.join()
+ except Exception:
+ log.exception("Error in async Triton compile")
+ # continue on to hopefully get a better error message below
+
+ binary = triton.compile(
+ self.fn,
+ **compile_meta,
+ )
+
+ call_args = [
+ arg
+ for i, arg in enumerate(self.fn.arg_names)
+ if i not in self.fn.constexprs
+ ]
+ def_args = list(self.fn.arg_names)
+ while def_args and def_args[-1] in cfg.kwargs:
+ def_args.pop()
+
+ scope = {
+ "grid_meta": cfg.kwargs,
+ "bin": binary,
+ "torch": torch,
+ "set_device": torch.cuda.set_device,
+ "current_device": torch.cuda.current_device,
+ }
+ exec(
+ f"""
+ def launcher({', '.join(def_args)}, grid, stream):
+ # set_device(current_device()) # TODO(jansel): is this needed?
+ grid_0, grid_1, grid_2 = grid(grid_meta)
+ bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
+ stream, bin.cu_function, None, None, None,
+ {', '.join(call_args)})
+ """.lstrip(),
+ scope,
+ )
+ launcher = scope["launcher"]
+ launcher.config = cfg
+ return launcher
+
+ def bench(self, launcher, *args, grid):
+ """Measure the performance of a given launcher"""
+ stream = get_cuda_stream(torch.cuda.current_device())
+
+ def kernel_call():
+ if launcher.config.pre_hook is not None:
+ launcher.config.pre_hook(
+ {**zip(self.arg_names, args), **launcher.config.kwargs}
+ )
+ launcher(
+ *args,
+ grid=grid,
+ stream=stream,
+ )
+
+ from triton.testing import do_bench
+
+ return do_bench(kernel_call)
+
+ def autotune_to_one_config(self, *args, **kwargs):
+ """Do the actual autotuning"""
+ timings = {
+ launcher: self.bench(launcher, *args, **kwargs)
+ for launcher in self.launchers
+ }
+ self.launchers = [builtins.min(timings, key=timings.get)]
+ if self.save_cache_hook:
+ self.save_cache_hook(self.launchers[0].config)
+
+ def run(self, *args, grid, stream):
+ if len(self.launchers) != 1:
+ if len(self.launchers) == 0:
+ self.precompile()
+ if len(self.launchers) > 1:
+ self.autotune_to_one_config(*args, grid=grid)
+
+ (launcher,) = self.launchers
+ if launcher.config.pre_hook is not None:
+ launcher.config.pre_hook(
+ {**zip(self.arg_names, args), **launcher.config.kwargs}
+ )
+ return launcher(
+ *args,
+ grid=grid,
+ stream=stream,
+ )
+
+
+def hash_configs(configs: List[Config]):
+ """
+ Hash used to check for changes in configurations
+ """
+ hasher = hashlib.sha256()
+ for cfg in configs:
+ hasher.update(
+ f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode(
+ "utf-8"
+ )
+ )
+ return hasher.hexdigest()
+
+
+def load_cached_autotuning(
+ cache_filename: str, configs_hash: str, configs: List[Config]
+):
+ """
+ Read a cached autotuning result from disk
+ """
+ if not os.path.exists(cache_filename):
+ return None
+
+ best_config = json.loads(open(cache_filename).read())
+ if best_config.get("configs_hash") != configs_hash:
+ return None
+
+ matching_configs = [
+ cfg
+ for cfg in configs
+ if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
+ ]
+ if len(matching_configs) != 1:
+ return None
+
+ return matching_configs[0]
+
+
+def cached_autotune(
+ configs: List[Config],
+ meta,
+ filename=None,
+):
+ """
+ A copy of triton.autotune that calls our subclass. Our subclass
+ has additional debugging, error handling, and on-disk caching.
+ """
+ configs = unique_configs(configs)
+ assert len(configs) == 1 or filename
+
+ # on disk caching logic
+ if filename is not None and len(configs) > 1:
+ cache_filename = os.path.splitext(filename)[0] + ".best_config"
+ configs_hash = hash_configs(configs)
+ best_config = load_cached_autotuning(cache_filename, configs_hash, configs)
+ if best_config:
+ configs = [best_config]
+
+ def save_cache_hook(cfg):
+ with open(cache_filename, "w") as fd:
+ fd.write(json.dumps({**cfg.kwargs, "configs_hash": configs_hash}))
+
+ else:
+ save_cache_hook = None
+
+ def decorator(fn):
+ return CachingAutotuner(
+ fn, meta=meta, configs=configs, save_cache_hook=save_cache_hook
+ )
+
+ return decorator
+
+
+def unique_configs(configs: List[Config]):
+ """Remove duplicate configurations"""
+ seen = set()
+ pruned_configs = []
+ for cfg in configs:
+ key = tuple(cfg.kwargs.items())
+ if key not in seen:
+ seen.add(key)
+ pruned_configs.append(cfg)
+ return pruned_configs
+
+
+def triton_config(size_hints, x, y=None, z=None, num_stages=1) -> Config:
+ """
+ Construct a pointwise triton config with some adjustment heuristics
+ based on size_hints. Size_hints is a tuple of numels in each tile
+ dimension and will be rounded up to the nearest power of 2.
+ """
+ # Ideally we want to read this from some device config
+ maxGridSize = [2147483647, 65535, 65535]
+
+ target = conditional_product(x, y, z)
+ if conditional_product(*size_hints) < target:
+ target //= 8
+
+ # shrink sizes to size hints
+ x = min(x, size_hints[0])
+ if y:
+ y = min(y, size_hints[1])
+ if z:
+ z = min(z, size_hints[2])
+
+ # if we are below original block size, scale up where we can;
+ # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension
+ while x < size_hints[0] and (
+ x * maxGridSize[0] < size_hints[0] or conditional_product(x, y, z) < target
+ ):
+ x *= 2
+ while (
+ y
+ and y < size_hints[1]
+ and (
+ y * maxGridSize[1] < size_hints[1] or conditional_product(x, y, z) < target
+ )
+ ):
+ y *= 2
+ while (
+ z
+ and z < size_hints[2]
+ and (
+ z * maxGridSize[2] < size_hints[2] or conditional_product(x, y, z) < target
+ )
+ ):
+ z *= 2
+
+ cfg = {"XBLOCK": x}
+ if y:
+ cfg["YBLOCK"] = y
+ if z:
+ cfg["ZBLOCK"] = z
+ num_warps = next_power_of_2(min(max(conditional_product(x, y, z) // 256, 1), 8))
+ return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+
+def triton_config_reduction(size_hints, x, r, num_stages=2) -> Config:
+ """
+ Construct a reduction triton config with some adjustment heuristics
+ based on size_hints. Size_hints is a tuple of numels in each tile
+ dimension and will be rounded up to the nearest power of 2.
+ """
+
+ target = conditional_product(x, r)
+ if conditional_product(*size_hints) < target:
+ target //= 8
+
+ # shrink sizes to size hints
+ x = min(x, size_hints[0])
+ r = min(r, size_hints[1])
+
+ # if we are below original block size, scale up where we can
+ while x < size_hints[0] and conditional_product(x, r) < target:
+ x *= 2
+ while r < size_hints[1] and conditional_product(x, r) < target:
+ r *= 2
+
+ cfg = {"XBLOCK": x, "RBLOCK": r}
+ num_warps = next_power_of_2(min(max(conditional_product(x, r) // 128, 1), 8))
+ return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+
+def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=2):
+ """
+ Construct a tile reduction triton config with some adjustment
+ heuristics based on size_hints. Size_hints is a tuple of numels in
+ each tile dimension and will be rounded up to the nearest power of 2.
+ """
+
+ target = conditional_product(x, y, r)
+ if conditional_product(*size_hints) < target:
+ target //= 8
+
+ # shrink sizes to size hints
+ x = min(x, size_hints[0])
+ y = min(y, size_hints[1])
+ r = min(r, size_hints[2])
+
+ # if we are below original block size, scale up where we can
+ while x < size_hints[0] and conditional_product(x, y, r) < target:
+ x *= 2
+ while r < size_hints[2] and conditional_product(x, y, r) < target:
+ r *= 2
+ while y < size_hints[1] and conditional_product(x, y, r) < target:
+ y *= 2
+
+ cfg = {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r}
+ num_warps = next_power_of_2(min(max(conditional_product(x, y, r) // 256, 1), 8))
+ return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+
+def pointwise(size_hints, meta, filename=None):
+ """
+ Construct @triton.heuristics() based on size_hints.
+ """
+ if len(size_hints) == 1:
+ return cached_autotune([triton_config(size_hints, 1024)], meta=meta)
+ if len(size_hints) == 2:
+ if not config.triton.autotune:
+ return cached_autotune([triton_config(size_hints, 64, 64)], meta=meta)
+ return cached_autotune(
+ [
+ triton_config(size_hints, 32, 32),
+ triton_config(size_hints, 8, 256),
+ triton_config(size_hints, 256, 8),
+ triton_config(size_hints, 1, 1024),
+ triton_config(size_hints, 1024, 1),
+ ],
+ meta=meta,
+ filename=filename,
+ )
+ if len(size_hints) == 3:
+ if not config.triton.autotune:
+ return cached_autotune([triton_config(size_hints, 16, 16, 16)], meta=meta)
+ return cached_autotune(
+ [
+ triton_config(size_hints, 16, 16, 16),
+ triton_config(size_hints, 64, 8, 8),
+ triton_config(size_hints, 8, 64, 8),
+ triton_config(size_hints, 8, 8, 64),
+ triton_config(size_hints, 1024, 1, 1),
+ triton_config(size_hints, 1, 1024, 1),
+ triton_config(size_hints, 1, 1, 1024),
+ ],
+ meta=meta,
+ filename=filename,
+ )
+ raise NotImplementedError(f"size_hints: {size_hints}")
+
+
+def reduction(size_hints, reduction_hint=False, meta=None, filename=None):
+ """args to @triton.heuristics()"""
+ assert meta is not None
+ rnumel = size_hints[-1]
+ if len(size_hints) == 2:
+ contiguous_config = triton_config_reduction(
+ size_hints, 1, (rnumel if 256 <= rnumel < 2048 else 2048), num_stages=1
+ )
+ outer_config = triton_config_reduction(size_hints, 128, 8)
+ tiny_config = triton_config_reduction(
+ size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel
+ )
+ if reduction_hint == ReductionHint.INNER:
+ return cached_autotune([contiguous_config], meta=meta)
+ elif reduction_hint == ReductionHint.OUTER:
+ return cached_autotune([outer_config], meta=meta)
+ elif reduction_hint == ReductionHint.OUTER_TINY:
+ return cached_autotune([tiny_config], meta=meta)
+ if not config.triton.autotune:
+ return cached_autotune(
+ [triton_config_reduction(size_hints, 32, 128)], meta=meta
+ )
+ return cached_autotune(
+ [
+ triton_config_reduction(size_hints, 64, 64),
+ triton_config_reduction(
+ size_hints, 128, 8
+ ), # this one is the best for outer reduction
+ triton_config_reduction(
+ size_hints, 8, 512
+ ), # this and the next one seem very similar but both are needed for perf
+ contiguous_config,
+ ],
+ meta=meta,
+ filename=filename,
+ )
+ raise NotImplementedError(f"size_hints: {size_hints}")
+
+
+def conv_heuristics():
+ configs = [
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=2, num_warps=8
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=2, num_warps=8
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 32}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 16, "BLOCK_K": 32}, num_stages=4, num_warps=2
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=8
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 16, "BLOCK_K": 32}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=4
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=4, num_warps=2
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=2
+ ),
+ # triton.Config(
+ # {"BLOCK_M": 128, "BLOCK_N": 16, "BLOCK_K": 64}, num_stages=4, num_warps=2
+ # ),
+ ]
+ key = [
+ "BATCH",
+ "IN_C",
+ "IN_H",
+ "IN_W",
+ "KERNEL_N",
+ "KERNEL_H",
+ "KERNEL_W",
+ "OUT_H",
+ "OUT_W",
+ # parameters of conv
+ "stride_h",
+ "stride_w",
+ "padding_h",
+ "padding_w",
+ "dilation_h",
+ "dilation_w",
+ "output_padding_h",
+ "output_padding_w",
+ "groups",
+ ]
+ prune_configs_by = {
+ "early_config_prune": conv_early_config_prune,
+ "perf_model": estimate_conv_time,
+ "top_k": 10,
+ }
+ return triton.autotune(configs, key, prune_configs_by=prune_configs_by)
+
+
+def mm_heuristics():
+ from triton import heuristics
+
+ mm_heuristic = heuristics(
+ {
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+ }
+ )
+ return mm_heuristic
+
+
+def mm_autotune(get_io_bound_configs=False):
+ from triton.ops.matmul import get_configs_io_bound
+ from triton.ops.matmul_perf_model import early_config_prune as mm_early_config_prune
+
+ configs = [
+ # basic configs for compute-bound matmuls
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=5,
+ num_warps=2,
+ ),
+ # good for int8
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=5,
+ num_warps=2,
+ ),
+ ]
+ if get_io_bound_configs:
+ configs += get_configs_io_bound()
+ key = ["M", "N", "K"]
+ prune_configs_by = {
+ "early_config_prune": mm_early_config_prune,
+ "perf_model": estimate_matmul_time,
+ "top_k": 10,
+ }
+ return triton.autotune(configs, key, prune_configs_by=prune_configs_by)
+
+
+def grid(xnumel, ynumel=None, znumel=None):
+ """Helper function to compute triton grids"""
+
+ if ynumel and znumel:
+
+ def grid_fn(meta):
+ return (
+ cdiv(xnumel, meta["XBLOCK"]),
+ cdiv(ynumel, meta["YBLOCK"]),
+ cdiv(znumel, meta["ZBLOCK"]),
+ )
+
+ elif ynumel:
+
+ def grid_fn(meta):
+ return (
+ cdiv(xnumel, meta["XBLOCK"]),
+ cdiv(ynumel, meta["YBLOCK"]),
+ 1,
+ )
+
+ else:
+
+ def grid_fn(meta):
+ return (
+ cdiv(xnumel, meta["XBLOCK"]),
+ 1,
+ 1,
+ )
+
+ return grid_fn
diff --git a/torch/_inductor/triton_ops/batched_matmul.py b/torch/_inductor/triton_ops/batched_matmul.py
new file mode 100644
index 0000000000000..7e7a65596b021
--- /dev/null
+++ b/torch/_inductor/triton_ops/batched_matmul.py
@@ -0,0 +1,274 @@
+import torch
+
+from ..utils import has_triton
+
+if has_triton():
+ import triton
+ import triton.language as tl
+
+ def init_to_zero(name):
+ return lambda nargs: nargs[name].zero_()
+
+ @triton.heuristics(
+ {
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+ }
+ )
+ @triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
+ num_stages=5,
+ num_warps=2,
+ ),
+ # additional configs
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=3,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=2,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=2,
+ num_warps=4,
+ ),
+ # additional configs for K = 64
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=1,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=1,
+ num_warps=8,
+ ),
+ triton.Config(
+ {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=1,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=1,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=1,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=5,
+ num_warps=2,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=1,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=1,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=1,
+ num_warps=4,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
+ num_stages=1,
+ num_warps=2,
+ ),
+ ],
+ # + get_configs_io_bound(),
+ key=["M", "N", "K"],
+ #
+ # key=["M", "N", "K"],
+ # prune_configs_by={
+ # "early_config_prune": early_config_prune,
+ # "perf_model": estimate_matmul_time,
+ # "top_k": 18,
+ # },
+ )
+ @triton.jit
+ def _kernel(
+ A,
+ B,
+ C,
+ M,
+ N,
+ K,
+ stride_am,
+ stride_ak,
+ stride_bk,
+ stride_bn,
+ stride_cm,
+ stride_cn,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+ SPLIT_K: tl.constexpr,
+ EVEN_K: tl.constexpr,
+ ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_z = tl.program_id(1)
+ bid = tl.program_id(2)
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
+ # re-order program ID for better L2 performance
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // (group_size)
+ # do matrix multiplication
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ A += bid * M * K
+ B += bid * K * N
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(K, 0, -BLOCK_K * SPLIT_K):
+ if EVEN_K:
+ a = tl.load(A)
+ b = tl.load(B)
+ else:
+ a = tl.load(A, mask=rk[None, :] < k, other=0.0)
+ b = tl.load(B, mask=rk[:, None] < k, other=0.0)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * SPLIT_K * stride_ak
+ B += BLOCK_K * SPLIT_K * stride_bk
+ acc = acc.to(C.dtype.element_ty)
+
+ # rematerialize rm and rn to save registers
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+ C += bid * M * N
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
+ # handles write-back with reduction-splitting
+ if SPLIT_K == 1:
+ tl.store(C, acc, mask=mask)
+ else:
+ tl.atomic_add(C, acc, mask=mask)
+
+ def bmm_out(a, b, out):
+ # handle non-contiguous inputs if necessary
+ if a.stride(0) > 1 and a.stride(1) > 1:
+ a = a.contiguous()
+ if b.stride(0) > 1 and b.stride(1) > 1:
+ b = b.contiguous()
+ # checks constraints
+ assert a.shape[2] == b.shape[1], "incompatible dimensions"
+ B, M, K = a.shape
+ _, _, N = b.shape
+ # allocates output
+ c = out
+ # accumulator types
+ ACC_TYPE = (
+ tl.float32
+ if a.dtype in [torch.float16, torch.bfloat16, torch.float32]
+ else tl.int32
+ )
+
+ # launch kernel
+ def grid(META):
+ return (
+ triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
+ META["SPLIT_K"],
+ B,
+ )
+
+ # grid = lambda META: (
+ # triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
+ # META["SPLIT_K"],
+ # B,
+ # )
+
+ # autotuner = _kernel[grid].kernel
+ _kernel[grid](a, b, c, M, N, K, K, 1, N, 1, N, 1, GROUP_M=8, ACC_TYPE=ACC_TYPE)
+ # print(autotuner.best_config)
+ # print(autotuner.configs_timings)
diff --git a/torch/_inductor/triton_ops/conv.py b/torch/_inductor/triton_ops/conv.py
new file mode 100644
index 0000000000000..62d7123174a5b
--- /dev/null
+++ b/torch/_inductor/triton_ops/conv.py
@@ -0,0 +1,744 @@
+import torch
+
+from ..utils import has_triton
+
+if has_triton():
+ import triton
+ import triton.language as tl
+
+ from .autotune import conv_heuristics
+ from .utils import _unpack
+
+ @conv_heuristics()
+ @triton.jit
+ def _kernel_delta_x_hwc(
+ x,
+ w,
+ y,
+ # stride of tensor
+ stride_xn,
+ stride_xc,
+ stride_xh,
+ stride_xw,
+ stride_wn,
+ stride_wc,
+ stride_wh,
+ stride_ww,
+ stride_yn,
+ stride_yc,
+ stride_yh,
+ stride_yw,
+ stride_biasn,
+ # pointer inc for x
+ delta_xh_ptr,
+ delta_xw_ptr,
+ delta_xc_ptr,
+ # Tensor dimensions
+ BATCH,
+ IN_C,
+ IN_H,
+ IN_W,
+ KERNEL_N,
+ KERNEL_H,
+ KERNEL_W,
+ OUT_H,
+ OUT_W,
+ # parameters of conv
+ stride_h,
+ stride_w,
+ padding_h,
+ padding_w,
+ dilation_h,
+ dilation_w,
+ output_padding_h,
+ output_padding_w,
+ groups,
+ # Metaparameters
+ ACC_TYPE: tl.constexpr,
+ CONV1X1_NHWC: tl.constexpr,
+ # blocks in different dimension
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ # reduction tiling parameter for matmul
+ BLOCK_K: tl.constexpr,
+ # Super-blocking for better L2 peformance
+ GROUP_H: tl.constexpr,
+ ):
+ """
+ each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y
+ """
+ # -----------------------------------------------------------
+ # Map program ids `pid` to the block of y it should compute.
+ pid_nhw = tl.program_id(0)
+ pid_k = tl.program_id(1)
+
+ # offset for output y
+ off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
+ off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_y_n = off_y_nhw // (OUT_H * OUT_W)
+ off_y_hw = off_y_nhw % (OUT_H * OUT_W)
+ off_y_h = off_y_hw // OUT_W + output_padding_h
+ off_y_w = off_y_hw % OUT_W + output_padding_w
+
+ # offset for the initial ptr for x
+ off_x_n = off_y_n
+ off_x_h = off_y_h * stride_h - padding_h
+ off_x_w = off_y_w * stride_w - padding_w
+ off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
+ off_x_crs = tl.arange(0, BLOCK_K)
+
+ CRS = IN_C * KERNEL_H * KERNEL_W
+ # load inc ptr of x, upade x_ptrs
+ if not CONV1X1_NHWC:
+ delta_xh_ptrs = delta_xh_ptr + off_x_crs
+ delta_xw_ptrs = delta_xw_ptr + off_x_crs
+ delta_xc_ptrs = delta_xc_ptr + off_x_crs
+ delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0)
+ delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0)
+ delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0)
+ off_x_crs_unpacked = (
+ delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc
+ )
+ x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
+ else:
+ x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]
+ delta_xh = 0
+ delta_xw = 0
+
+ mask_x = (
+ (off_x_n < BATCH)[:, None]
+ & (off_x_crs < CRS)[None, :]
+ & (off_x_h[:, None] + delta_xh[None, :] >= 0)
+ & (off_x_h[:, None] + delta_xh[None, :] < IN_H)
+ & (off_x_w[:, None] + delta_xw[None, :] >= 0)
+ & (off_x_w[:, None] + delta_xw[None, :] < IN_W)
+ )
+
+ # offset for the inital ptr for w
+ off_w_crs = tl.arange(0, BLOCK_K)
+ off_w_k = off_y_k
+ w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
+ mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
+
+ # ------ load x ------
+ matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
+ # ------ load w ------
+ matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
+
+ # -----------------------------------------------------------
+ # allocate accumulator
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for crs in range(0, CRS, BLOCK_K):
+
+ # ------ matrix multiplication ------
+ acc += tl.dot(matrix_x, matrix_w)
+ # ------ update ptrs ------
+ w_ptrs += BLOCK_K
+ # load inc ptr of x, upade x_ptrs
+ off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
+ if not CONV1X1_NHWC:
+ delta_xh_ptrs += BLOCK_K
+ delta_xw_ptrs += BLOCK_K
+ delta_xc_ptrs += BLOCK_K
+ delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0)
+ delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0)
+ delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0)
+ off_x_crs_unpacked = (
+ delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc
+ )
+ x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
+ else:
+ x_ptrs += BLOCK_K
+
+ mask_x = (
+ (off_x_n < BATCH)[:, None]
+ & (off_x_crs < CRS)[None, :]
+ & (off_x_h[:, None] + delta_xh[None, :] >= 0)
+ & (off_x_h[:, None] + delta_xh[None, :] < IN_H)
+ & (off_x_w[:, None] + delta_xw[None, :] >= 0)
+ & (off_x_w[:, None] + delta_xw[None, :] < IN_W)
+ )
+ mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
+ # ------ prefetch ------
+ # ------ load x ------
+ matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
+ # ------ load w ------
+ matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
+
+ acc = acc.to(y.dtype.element_ty)
+
+ # rematerialize -- this saves some registers
+ # offset for output y
+ off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
+ off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_y_n = off_y_nhw // (OUT_H * OUT_W)
+ off_y_hw = off_y_nhw % (OUT_H * OUT_W)
+ # consider output padding
+ off_y_h = off_y_hw // OUT_W + output_padding_h
+ off_y_w = off_y_hw % OUT_W + output_padding_w
+
+ # y ptrs in the block of [BLOCK_M, BLOCK_N]
+ y_ptrs = (
+ y
+ + off_y_n[:, None] * stride_yn
+ + off_y_h[:, None] * stride_yh
+ + off_y_w[:, None] * stride_yw
+ + off_y_k[None, :] * stride_yc
+ )
+
+ # out-of-bounds check
+ mask_y = (
+ (off_y_n < BATCH)[:, None]
+ & (off_y_h < OUT_H + output_padding_h)[:, None]
+ & (off_y_w < OUT_W + output_padding_w)[:, None]
+ & (off_y_k < KERNEL_N)[None, :]
+ )
+
+ tl.store(y_ptrs, acc, mask=mask_y)
+
+ return
+
+ @conv_heuristics()
+ @triton.jit
+ def _kernel_delta_x(
+ x,
+ w,
+ y,
+ # stride of tensor
+ stride_xn,
+ stride_xc,
+ stride_xh,
+ stride_xw,
+ stride_wn,
+ stride_wc,
+ stride_wh,
+ stride_ww,
+ stride_yn,
+ stride_yc,
+ stride_yh,
+ stride_yw,
+ stride_biasn,
+ # pointer inc for x
+ delta_x_ptr,
+ # Tensor dimensions
+ BATCH,
+ IN_C,
+ IN_H,
+ IN_W,
+ KERNEL_N,
+ KERNEL_H,
+ KERNEL_W,
+ OUT_H,
+ OUT_W,
+ # parameters of conv
+ stride_h,
+ stride_w,
+ padding_h,
+ padding_w,
+ dilation_h,
+ dilation_w,
+ output_padding_h,
+ output_padding_w,
+ groups,
+ # Metaparameters
+ ACC_TYPE: tl.constexpr,
+ CONV1X1_NHWC: tl.constexpr,
+ # blocks in different dimension
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ # reduction tiling parameter for matmul
+ BLOCK_K: tl.constexpr,
+ # Super-blocking for better L2 peformance
+ GROUP_H: tl.constexpr,
+ ):
+ """
+ each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y
+ """
+ # -----------------------------------------------------------
+ # Map program ids `pid` to the block of y it should compute.
+ pid_nhw = tl.program_id(0)
+ pid_k = tl.program_id(1)
+
+ # offset for output y
+ off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
+ off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_y_n = off_y_nhw // (OUT_H * OUT_W)
+ off_y_hw = off_y_nhw % (OUT_H * OUT_W)
+ off_y_h = off_y_hw // OUT_W + output_padding_h
+ off_y_w = off_y_hw % OUT_W + output_padding_w
+
+ # offset for the initial ptr for x
+ off_x_n = off_y_n
+ off_x_h = off_y_h * stride_h - padding_h
+ off_x_w = off_y_w * stride_w - padding_w
+ off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
+ off_x_crs = tl.arange(0, BLOCK_K)
+
+ CRS = IN_C * KERNEL_H * KERNEL_W
+ # load inc ptr of x, upade x_ptrs
+ if not CONV1X1_NHWC:
+ delta_x_ptrs = delta_x_ptr + off_x_crs
+ off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS)
+ x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
+ else:
+ x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]
+
+ mask_x = (
+ (off_x_n < BATCH)
+ & (off_x_h >= 0)
+ & (off_x_h < IN_H)
+ & (off_x_w >= 0)
+ & (off_x_w < IN_W)
+ )[:, None] & (off_x_crs < CRS)[None, :]
+
+ # offset for the inital ptr for w
+ off_w_crs = tl.arange(0, BLOCK_K)
+ off_w_k = off_y_k
+ w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
+ mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
+
+ # ------ load x ------
+ matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
+ # ------ load w ------
+ matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
+
+ # -----------------------------------------------------------
+ # allocate accumulator
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for crs in range(0, CRS, BLOCK_K):
+
+ # ------ matrix multiplication ------
+ acc += tl.dot(matrix_x, matrix_w)
+ # ------ update ptrs ------
+ w_ptrs += BLOCK_K
+ # load inc ptr of x, upade x_ptrs
+ if not CONV1X1_NHWC:
+ delta_x_ptrs += BLOCK_K
+ off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
+ off_x_crs_unpacked = tl.load(
+ delta_x_ptrs, mask=off_x_crs < CRS, other=0
+ )
+ x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
+ else:
+ off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
+ x_ptrs += BLOCK_K
+
+ mask_x = (
+ (off_x_n < BATCH)
+ & (off_x_h >= 0)
+ & (off_x_h < IN_H)
+ & (off_x_w >= 0)
+ & (off_x_w < IN_W)
+ )[:, None] & (off_x_crs < CRS)[None, :]
+ mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
+ # ------ prefetch ------
+ # ------ load x ------
+ matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
+ # ------ load w ------
+ matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
+
+ acc = acc.to(y.dtype.element_ty)
+
+ # rematerialize -- this saves some registers
+ # offset for output y
+ off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
+ off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_y_n = off_y_nhw // (OUT_H * OUT_W)
+ off_y_hw = off_y_nhw % (OUT_H * OUT_W)
+ # consider output padding
+ off_y_h = off_y_hw // OUT_W + output_padding_h
+ off_y_w = off_y_hw % OUT_W + output_padding_w
+
+ # y ptrs in the block of [BLOCK_M, BLOCK_N]
+ y_ptrs = (
+ y
+ + off_y_n[:, None] * stride_yn
+ + off_y_h[:, None] * stride_yh
+ + off_y_w[:, None] * stride_yw
+ + off_y_k[None, :] * stride_yc
+ )
+
+ # out-of-bounds check
+ mask_y = (
+ (off_y_n < BATCH)[:, None]
+ & (off_y_h < OUT_H + output_padding_h)[:, None]
+ & (off_y_w < OUT_W + output_padding_w)[:, None]
+ & (off_y_k < KERNEL_N)[None, :]
+ )
+
+ tl.store(y_ptrs, acc, mask=mask_y)
+
+ return
+
+ class _conv:
+ kernel = _kernel_delta_x_hwc
+
+ # for the contigous order of w ptr, what"s the corresponding
+ # ptr changes for x in a sliding window
+ @staticmethod
+ def _delta_x_ptr_hwc(
+ IN_C,
+ KERNEL_H,
+ KERNEL_W,
+ dilation_h,
+ dilation_w,
+ stride_wc,
+ stride_wh,
+ stride_ww,
+ stride_xc,
+ stride_xh,
+ stride_xw,
+ device,
+ ):
+ # get the order of axes in w, innermost dimension outward
+ stride_w_3d = [stride_wc, stride_wh, stride_ww]
+ order = sorted(range(len(stride_w_3d)), key=stride_w_3d.__getitem__)
+ window_size = IN_C * KERNEL_H * KERNEL_W
+
+ r_window = torch.arange(0, window_size, 1, device=device)
+ window_unpack = _unpack(r_window, order, [IN_C, KERNEL_H, KERNEL_W])
+ window_unpack_c = window_unpack[order[0]]
+ window_unpack_h = window_unpack[order[1]]
+ window_unpack_w = window_unpack[order[2]]
+ r_dilation_h = dilation_h * window_unpack_h
+ r_dilation_w = dilation_w * window_unpack_w
+ r_inc = window_unpack_c
+ # delta_x = (
+ # r_dilation_h * stride_xh + r_dilation_w * stride_xw + r_inc * stride_xc
+ # )
+ # return delta_x
+ return (
+ r_dilation_h,
+ r_dilation_w,
+ r_inc,
+ )
+
+ @staticmethod
+ def _delta_x_ptr(
+ IN_C,
+ KERNEL_H,
+ KERNEL_W,
+ dilation_h,
+ dilation_w,
+ stride_wc,
+ stride_wh,
+ stride_ww,
+ stride_xc,
+ stride_xh,
+ stride_xw,
+ device,
+ ):
+ # get the order of axes in w, innermost dimension outward
+ stride_w_3d = [stride_wc, stride_wh, stride_ww]
+ order = sorted(range(len(stride_w_3d)), key=stride_w_3d.__getitem__)
+ window_size = IN_C * KERNEL_H * KERNEL_W
+
+ r_window = torch.arange(0, window_size, 1, device=device)
+ window_unpack = _unpack(r_window, order, [IN_C, KERNEL_H, KERNEL_W])
+ window_unpack_c = window_unpack[order[0]]
+ window_unpack_h = window_unpack[order[1]]
+ window_unpack_w = window_unpack[order[2]]
+ r_dilation_h = dilation_h * window_unpack_h
+ r_dilation_w = dilation_w * window_unpack_w
+ r_inc = window_unpack_c
+ delta_x = (
+ r_dilation_h * stride_xh + r_dilation_w * stride_xw + r_inc * stride_xc
+ )
+ return delta_x
+
+ @staticmethod
+ def _call(
+ x,
+ w,
+ bias,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ ):
+ # Q: should we check x, w, bias dtypes?
+ device = x.device
+ # input shapes
+ shape_x = x.shape
+ shape_w = w.shape
+ shape_bias = bias.shape if bias is not None else None
+
+ # indicies for the layeout
+ xn, xc, xh, xw = 0, 1, 2, 3
+ yn, yc, yh, yw = 0, 1, 2, 3
+ wn, wc, wh, ww = 0, 1, 2, 3
+
+ # out_channel, in_channel, kernel_height, kernel_width
+ kernel_size = [shape_w[wh], shape_w[ww]]
+ input_size = [shape_x[xh], shape_x[xw]]
+ assert (
+ not shape_bias or shape_bias[0] == shape_w[wn]
+ ), f"bias shape did not match{shape_bias} != {shape_w[wn]}"
+ in_channel = shape_w[wc] * groups
+
+ assert shape_x[xc] % groups == 0, "in_channels must be divisible by groups"
+ assert shape_w[wn] % groups == 0, "out_channels must be divisible by groups"
+ assert (
+ shape_x[xc] == in_channel
+ ), f"in_channel did not match {shape_x[xc]} != {in_channel}"
+
+ assert (
+ len(stride)
+ == len(padding)
+ == len(dilation)
+ == len(output_padding)
+ == len(kernel_size)
+ == len(input_size)
+ )
+
+ # output shape
+ shape_y = [0] * 4
+ shape_y[yn] = shape_x[xn]
+ shape_y[yc] = shape_w[wn]
+ shape_y[yh] = (
+ input_size[0]
+ + 2 * padding[0]
+ - dilation[0] * (kernel_size[0] - 1)
+ - 1
+ + stride[0]
+ ) // stride[0] + 2 * output_padding[0]
+ shape_y[yw] = (
+ input_size[1]
+ + 2 * padding[1]
+ - dilation[1] * (kernel_size[1] - 1)
+ - 1
+ + stride[1]
+ ) // stride[1] + 2 * output_padding[1]
+
+ BATCH = shape_x[xn]
+ IN_C = shape_x[xc]
+ IN_H = shape_x[xh]
+ IN_W = shape_x[xw]
+ KERNEL_N = shape_w[wn]
+ KERNEL_H = shape_w[wh]
+ KERNEL_W = shape_w[ww]
+ OUT_H = shape_y[yh]
+ OUT_W = shape_y[yw]
+
+ # allocate output
+ y = torch.empty(shape_y, device=device, dtype=x.dtype)
+
+ # get strides for tensors
+ stride_x = x.stride()
+ stride_w = w.stride()
+ stride_bias = bias.stride() if shape_bias else None
+ stride_biasn = stride_bias[0] if stride_bias else None
+
+ # output layout should be the same as x
+ if stride_x[xc] < stride_x[xh] and stride_x[xc] < stride_x[xw]:
+ y = y.to(memory_format=torch.channels_last)
+ stride_y = y.stride()
+
+ # allocate tmp
+ # WINDOW_SIZE = KERNEL_H * KERNEL_W * IN_C
+ # tmp_x = torch.empty((BATCH * OUT_H * OUT_W, WINDOW_SIZE), device=device, dtype=x.dtype)
+ # tmp_w = torch.empty((WINDOW_SIZE, KERNEL_N), device=device, dtype=w.dtype)
+ # accumulator types
+ ACC_TYPE = (
+ tl.float32
+ if x.dtype in [torch.float16, torch.bfloat16, torch.float32]
+ else tl.int32
+ )
+ # if stride_x[xc] == 1 and stride_x > 1 and stride_y > 1:
+ CONV1X1_NHWC = False
+ if stride_x[xc] == 1 and KERNEL_H == 1 and KERNEL_W == 1:
+ CONV1X1_NHWC = True
+ # do we need delta x ptr for h, w, c dimension each or not
+ DELTA_X_PTR_HWC = (
+ False
+ if (
+ (padding[0] == 0 and padding[1] == 0)
+ or (KERNEL_H == 1 and KERNEL_W == 1)
+ )
+ else True
+ )
+ if not CONV1X1_NHWC:
+ if DELTA_X_PTR_HWC:
+ delta_xh, delta_xw, delta_xc = _conv._delta_x_ptr_hwc(
+ IN_C,
+ KERNEL_H,
+ KERNEL_W,
+ dilation[0],
+ dilation[1],
+ stride_w[wc],
+ stride_w[wh],
+ stride_w[ww],
+ stride_x[xc],
+ stride_x[xh],
+ stride_x[xw],
+ device,
+ )
+ else:
+ delta_x = _conv._delta_x_ptr(
+ IN_C,
+ KERNEL_H,
+ KERNEL_W,
+ dilation[0],
+ dilation[1],
+ stride_w[wc],
+ stride_w[wh],
+ stride_w[ww],
+ stride_x[xc],
+ stride_x[xh],
+ stride_x[xw],
+ device,
+ )
+ else:
+ delta_x = None
+ delta_xh, delta_xw, delta_xc = None, None, None
+
+ # launch kernel, 2-dim, batch*h*w, kernel
+ def grid(META):
+ return (
+ triton.cdiv(BATCH * OUT_H * OUT_W, META["BLOCK_M"]),
+ triton.cdiv(KERNEL_N, META["BLOCK_N"]),
+ )
+
+ # conv1x1 or padding==0
+ if CONV1X1_NHWC or not DELTA_X_PTR_HWC:
+ _kernel_delta_x[grid](
+ x,
+ w,
+ y,
+ # stride nchw for x,w,y tensor
+ stride_x[xn],
+ stride_x[xc],
+ stride_x[xh],
+ stride_x[xw],
+ stride_w[wn],
+ stride_w[wc],
+ stride_w[wh],
+ stride_w[ww],
+ stride_y[yn],
+ stride_y[yc],
+ stride_y[yh],
+ stride_y[yw],
+ stride_biasn,
+ # pointer inc for x
+ delta_x,
+ # Tensor dimensions
+ BATCH,
+ IN_C,
+ IN_H,
+ IN_W,
+ KERNEL_N,
+ KERNEL_H,
+ KERNEL_W,
+ OUT_H,
+ OUT_W,
+ # conv parameters
+ stride[0],
+ stride[1],
+ padding[0],
+ padding[1],
+ dilation[0],
+ dilation[1],
+ output_padding[0],
+ output_padding[1],
+ groups,
+ # Metaparameters
+ ACC_TYPE=ACC_TYPE,
+ CONV1X1_NHWC=CONV1X1_NHWC,
+ # BLOCK_M=128,
+ # BLOCK_N=32,
+ # BLOCK_K=32,
+ GROUP_H=1,
+ )
+ # need to know ptr update for each dimension to check if
+ # the sliding window is out of bounds
+ else:
+ # kernel = _kernel_delta_x_hwc
+ _kernel_delta_x_hwc[grid](
+ x,
+ w,
+ y,
+ # stride nchw for x,w,y tensor
+ stride_x[xn],
+ stride_x[xc],
+ stride_x[xh],
+ stride_x[xw],
+ stride_w[wn],
+ stride_w[wc],
+ stride_w[wh],
+ stride_w[ww],
+ stride_y[yn],
+ stride_y[yc],
+ stride_y[yh],
+ stride_y[yw],
+ stride_biasn,
+ # pointer inc for x
+ delta_xh,
+ delta_xw,
+ delta_xc,
+ # Tensor dimensions
+ BATCH,
+ IN_C,
+ IN_H,
+ IN_W,
+ KERNEL_N,
+ KERNEL_H,
+ KERNEL_W,
+ OUT_H,
+ OUT_W,
+ # conv parameters
+ stride[0],
+ stride[1],
+ padding[0],
+ padding[1],
+ dilation[0],
+ dilation[1],
+ output_padding[0],
+ output_padding[1],
+ groups,
+ # Metaparameters
+ ACC_TYPE=ACC_TYPE,
+ CONV1X1_NHWC=CONV1X1_NHWC,
+ # BLOCK_M=128,
+ # BLOCK_N=32,
+ # BLOCK_K=32,
+ GROUP_H=1,
+ )
+
+ if bias is not None:
+ if len(bias.shape) == 1:
+ bias = bias.reshape([1, bias.shape[0], 1, 1])
+ y += bias
+ return y
+
+ @staticmethod
+ def forward(
+ x,
+ w,
+ bias,
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ transposed=False,
+ output_padding=(0, 0),
+ groups=1,
+ ):
+ if groups != 1:
+ print(f"Do not support groups = {groups}")
+ return
+ if transposed:
+ print("Do not support transposed")
+ return _conv._call(
+ x,
+ w,
+ bias,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ )
+
+ conv = _conv.forward
diff --git a/torch/_inductor/triton_ops/conv1x1.py b/torch/_inductor/triton_ops/conv1x1.py
new file mode 100644
index 0000000000000..c7b79f004a5a9
--- /dev/null
+++ b/torch/_inductor/triton_ops/conv1x1.py
@@ -0,0 +1,195 @@
+import torch
+
+from ..utils import has_triton
+
+if has_triton():
+
+ import triton
+
+ class _conv1x1:
+ @staticmethod
+ def _call(
+ x,
+ w,
+ bias,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ ):
+ # Q: should we check x, w, bias dtypes?
+ device = x.device
+ # input shapes
+ shape_x = x.shape
+ shape_w = w.shape
+ shape_bias = bias.shape if bias is not None else None
+
+ # indicies for the layeout
+ xn, xc, xh, xw = 0, 1, 2, 3
+ yn, yc, yh, yw = 0, 1, 2, 3
+ wn, wc, wh, ww = 0, 1, 2, 3
+
+ # out_channel, in_channel, kernel_height, kernel_width
+ kernel_size = [shape_w[wh], shape_w[ww]]
+ input_size = [shape_x[xh], shape_x[xw]]
+ assert (
+ not shape_bias or shape_bias[0] == shape_w[wn]
+ ), f"bias shape did not match{shape_bias} != {shape_w[wn]}"
+ in_channel = shape_w[wc] * groups
+
+ assert shape_x[xc] % groups == 0, "in_channels must be divisible by groups"
+ assert shape_w[wn] % groups == 0, "out_channels must be divisible by groups"
+ assert (
+ shape_x[xc] == in_channel
+ ), f"in_channel did not match {shape_x[xc]} != {in_channel}"
+
+ assert (
+ len(stride)
+ == len(padding)
+ == len(dilation)
+ == len(output_padding)
+ == len(kernel_size)
+ == len(input_size)
+ )
+
+ # output shape
+ shape_y = [0] * 4
+ shape_y[yn] = shape_x[xn]
+ shape_y[yc] = shape_w[wn]
+ shape_y[yh] = (
+ input_size[0]
+ + 2 * padding[0]
+ - dilation[0] * (kernel_size[0] - 1)
+ - 1
+ + stride[0]
+ ) // stride[0] + 2 * output_padding[0]
+ shape_y[yw] = (
+ input_size[1]
+ + 2 * padding[1]
+ - dilation[1] * (kernel_size[1] - 1)
+ - 1
+ + stride[1]
+ ) // stride[1] + 2 * output_padding[1]
+
+ BATCH = shape_x[xn]
+ IN_C = shape_x[xc]
+ # IN_H = shape_x[xh]
+ # IN_W = shape_x[xw]
+ KERNEL_N = shape_w[wn]
+ KERNEL_H = shape_w[wh]
+ KERNEL_W = shape_w[ww]
+ OUT_H = shape_y[yh]
+ OUT_W = shape_y[yw]
+
+ assert KERNEL_H == 1 and KERNEL_W == 1, "only support 1x1 conv"
+ channels_last = x.stride()[1] == 1
+
+ if padding == (0, 0):
+ # nchw -> nhwc
+ x = x.permute(0, 2, 3, 1)
+ # select every stride's element (for stride > 1)
+ x = x[:, :: stride[0], :: stride[1], :]
+ # 2d matrix
+ mat_x = x.reshape(-1, IN_C)
+ # 2d matrix
+ mat_w = w.view(KERNEL_N, IN_C)
+ mat_w = mat_w.permute(1, 0)
+ # 2d matrix y, (BATCH * OUT_H * OUT_W, KERNEL_N)
+ mat_y = triton.ops.matmul(mat_x, mat_w)
+ # mat_y = torch.empty((BATCH * OUT_H * OUT_W, KERNEL_N), device=device, dtype=x.dtype,)
+ y = mat_y.view(BATCH, OUT_H, OUT_W, KERNEL_N)
+ if bias is not None:
+ y += bias
+ # convert back to the original layout of y
+ # nhwc -> nchw
+ y = y.permute(0, 3, 1, 2)
+ if not channels_last:
+ y = y.to(memory_format=torch.contiguous_format)
+ return y
+
+ else:
+ y = torch.empty(
+ (shape_y[yn], shape_y[yh], shape_y[yw], shape_y[yc]),
+ device=device,
+ dtype=x.dtype,
+ )
+ if channels_last:
+ y = y.to(memory_format=torch.channels_last)
+ # y = bias.repeat((shape_y[yn], shape_y[yh], shape_y[yw], 1)).to(device).type(x.dtype)
+ # convert x to channel-last layout;
+ # don't care w layout since kernel size is 1
+ x = x.permute(0, 2, 3, 1)
+ # select every stride"s element (for stride > 1)
+ x = x[:, :: stride[0], :: stride[1], :]
+ # 2d matrix
+ mat_x = x.view(-1, IN_C)
+ # 2d matrix
+ mat_w = w.view(KERNEL_N, IN_C)
+ mat_w = mat_w.permute(1, 0)
+ # 2d matrix y, (BATCH * (OUT_H-2*padding) * (OUT_W-2*padding), KERNEL_N)
+ mat_y = triton.ops.matmul(mat_x, mat_w)
+ mat_y = mat_y.view(
+ BATCH, OUT_H - 2 * padding[0], OUT_W - 2 * padding[1], KERNEL_N
+ )
+ # consider padding > 0
+ if bias is not None:
+ y[
+ :,
+ padding[0] : OUT_H - padding[0],
+ padding[1] : OUT_W - padding[1],
+ :,
+ ] = (
+ mat_y + bias
+ )
+ y[:, : padding[0], :, :] = bias
+ y[:, :, : padding[1], :] = bias
+ y[:, OUT_H - padding[0] :, :, :] = bias
+ y[:, :, OUT_W - padding[1] :, :] = bias
+ else:
+ y[
+ :,
+ padding[0] : OUT_H - padding[0],
+ padding[1] : OUT_W - padding[1],
+ :,
+ ] = mat_y
+ y[:, : padding[0], :, :] = 0
+ y[:, :, : padding[1], :] = 0
+ y[:, OUT_H - padding[0] :, :, :] = 0
+ y[:, :, OUT_W - padding[1] :, :] = 0
+ # convert back to the original layout of y
+ # nhwc -> nchw
+ y = y.permute(0, 3, 1, 2)
+ return y
+
+ @staticmethod
+ def forward(
+ x,
+ w,
+ bias,
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ transposed=False,
+ output_padding=(0, 0),
+ groups=1,
+ ):
+ if groups != 1:
+ print(f"Do not support groups = {groups}")
+ return
+ if transposed:
+ print("Do not support transposed")
+ return _conv1x1._call(
+ x,
+ w,
+ bias,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ )
+
+ conv1x1 = _conv1x1.forward
diff --git a/torch/_inductor/triton_ops/conv_perf_model.py b/torch/_inductor/triton_ops/conv_perf_model.py
new file mode 100644
index 0000000000000..0369e35ec6cac
--- /dev/null
+++ b/torch/_inductor/triton_ops/conv_perf_model.py
@@ -0,0 +1,165 @@
+import heapq
+
+import torch
+
+
+def estimate_conv_time(
+ # backend, device,
+ num_warps,
+ num_stages,
+ x,
+ BATCH,
+ IN_C,
+ IN_H,
+ IN_W,
+ KERNEL_N,
+ KERNEL_H,
+ KERNEL_W,
+ OUT_H,
+ OUT_W,
+ BLOCK_M,
+ BLOCK_K,
+ BLOCK_N,
+ debug=False,
+ **kwargs,
+):
+ """return estimated running time in ms
+ = max(compute, loading) + store"""
+ import triton
+ import triton._C.libtriton.triton as _triton
+ from triton.ops.matmul_perf_model import (
+ get_dram_gbps as get_dram_gbps,
+ get_tflops as get_tflops,
+ )
+
+ backend = _triton.runtime.backend.CUDA
+ device = torch.cuda.current_device()
+ dtype = x.dtype
+ dtsize = x.element_size()
+
+ M = BATCH * OUT_H * OUT_W
+ N = KERNEL_N
+ K = KERNEL_H * KERNEL_W * IN_C
+ num_cta_m = triton.cdiv(M, BLOCK_M)
+ num_cta_n = triton.cdiv(N, BLOCK_N)
+ num_cta_k = 1
+ num_ctas = num_cta_m * num_cta_n * num_cta_k
+
+ # If the input is smaller than the block size
+ M, N = max(M, BLOCK_M), max(N, BLOCK_N)
+
+ # time to compute
+ total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS
+ tput = get_tflops(backend, device, num_ctas, num_warps, dtype)
+ compute_ms = total_ops / tput
+
+ # time to load data
+ num_sm = _triton.runtime.num_sm(backend, device)
+ active_cta_ratio = min(1, num_ctas / num_sm)
+ active_cta_ratio_bw1 = min(
+ 1, num_ctas / 32
+ ) # 32 active ctas are enough to saturate
+ active_cta_ratio_bw2 = max(
+ min(1, (num_ctas - 32) / (108 - 32)), 0
+ ) # 32-108, remaining 5%
+ dram_bw = get_dram_gbps(backend, device) * (
+ active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05
+ ) # in GB/s
+ l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
+ # assume 80% of (following) loads are in L2 cache
+ load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1))
+ load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1)
+ load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1))
+ load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1)
+ # total
+ total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB
+ total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
+ # loading time in ms
+ load_ms = total_dram / dram_bw + total_l2 / l2_bw
+
+ # estimate storing time
+ store_bw = dram_bw * 0.6 # :o
+ store_c_dram = M * N * dtsize / (1024 * 1024) # MB
+ store_ms = store_c_dram / store_bw
+
+ total_time_ms = max(compute_ms, load_ms) + store_ms
+ if debug:
+ print(
+ f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, "
+ f"loading time: {load_ms}ms, store time: {store_ms}ms, "
+ f"Activate CTAs: {active_cta_ratio*100}%"
+ )
+ return total_time_ms
+
+
+def early_config_prune(configs, named_args):
+ import triton._C.libtriton.triton as _triton
+
+ backend = _triton.runtime.backend.CUDA
+ device = torch.cuda.current_device()
+ cc = _triton.runtime.cc(backend, device)
+ # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
+ dtsize = named_args["x"].element_size()
+ # dtype = named_args["x"].dtype
+
+ # 1. make sure we have enough smem
+ pruned_configs = []
+ for config in configs:
+ kw = config.kwargs
+ BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
+ kw["BLOCK_M"],
+ kw["BLOCK_N"],
+ kw["BLOCK_K"],
+ config.num_stages,
+ )
+ max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
+ required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
+ if required_shared_memory <= max_shared_memory:
+ pruned_configs.append(config)
+ configs = pruned_configs
+
+ # group configs by (BLOCK_M,_N,_K, num_warps)
+ configs_map = {}
+ for config in configs:
+ kw = config.kwargs
+ BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages = (
+ kw["BLOCK_M"],
+ kw["BLOCK_N"],
+ kw["BLOCK_K"],
+ config.num_warps,
+ config.num_stages,
+ )
+
+ key = (BLOCK_M, BLOCK_N, BLOCK_K, num_warps)
+ if key in configs_map:
+ configs_map[key].append((config, num_stages))
+ else:
+ configs_map[key] = [(config, num_stages)]
+
+ pruned_configs = []
+ for k, v in configs_map.items():
+ BLOCK_M, BLOCK_N, BLOCK_K, num_warps = k
+ if cc >= 80:
+ # compute cycles (only works for ampere GPUs)
+ mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
+ mma_cycles = mmas / min(4, num_warps) * 8
+
+ ldgsts_latency = 300 # Does this matter?
+ optimal_num_stages = ldgsts_latency / mma_cycles
+
+ # nearest stages, prefer large #stages
+ nearest = heapq.nsmallest(
+ 2,
+ v,
+ key=lambda x: 10 + abs(x[1] - optimal_num_stages)
+ if (x[1] - optimal_num_stages) < 0
+ else x[1] - optimal_num_stages,
+ )
+
+ for n in nearest:
+ pruned_configs.append(n[0])
+ else: # Volta & Turing only supports num_stages <= 2
+ random_config = v[0][0]
+ random_config.num_stages = 2
+ pruned_configs.append(random_config)
+ return pruned_configs
diff --git a/torch/_inductor/triton_ops/matmul.py b/torch/_inductor/triton_ops/matmul.py
new file mode 100644
index 0000000000000..c120b8c0b2773
--- /dev/null
+++ b/torch/_inductor/triton_ops/matmul.py
@@ -0,0 +1,136 @@
+import torch
+
+from ..utils import has_triton
+
+if has_triton():
+
+ import triton
+ import triton.language as tl
+
+ from .autotune import mm_autotune, mm_heuristics
+
+ @mm_heuristics()
+ @mm_autotune(get_io_bound_configs=True)
+ @triton.jit
+ def _kernel(
+ A,
+ B,
+ C,
+ M,
+ N,
+ K,
+ stride_am,
+ stride_ak,
+ stride_bk,
+ stride_bn,
+ stride_cm,
+ stride_cn,
+ allow_tf32: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+ SPLIT_K: tl.constexpr,
+ EVEN_K: tl.constexpr,
+ ACC_TYPE: tl.constexpr,
+ ):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_z = tl.program_id(1)
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
+ # re-order program ID for better L2 performance
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // (group_size)
+ # do matrix multiplication
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(K, 0, -BLOCK_K * SPLIT_K):
+ if EVEN_K:
+ a = tl.load(A)
+ b = tl.load(B)
+ else:
+ a = tl.load(A, mask=rk[None, :] < k, other=0.0)
+ b = tl.load(B, mask=rk[:, None] < k, other=0.0)
+ acc += tl.dot(a, b, allow_tf32=allow_tf32)
+ A += BLOCK_K * SPLIT_K * stride_ak
+ B += BLOCK_K * SPLIT_K * stride_bk
+ acc = acc.to(C.dtype.element_ty)
+ # rematerialize rm and rn to save registers
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
+ # handles write-back with reduction-splitting
+ if SPLIT_K == 1:
+ tl.store(C, acc, mask=mask)
+ else:
+ tl.atomic_add(C, acc, mask=mask)
+
+ class _matmul_out:
+ kernel = _kernel
+
+ @staticmethod
+ def _call(a, b, out, allow_tf32=True):
+ # handle non-contiguous inputs if necessary
+ if a.stride(0) > 1 and a.stride(1) > 1:
+ a = a.contiguous()
+ if b.stride(0) > 1 and b.stride(1) > 1:
+ b = b.contiguous()
+ # checks constraints
+ assert a.shape[1] == b.shape[0], "incompatible dimensions"
+ M, K = a.shape
+ _, N = b.shape
+ # allocates output
+ c = out
+ # accumulator types
+ ACC_TYPE = (
+ tl.float32
+ if a.dtype in [torch.float16, torch.bfloat16, torch.float32]
+ else tl.int32
+ )
+
+ # launch kernel (grid defined as using def instead of lambda to pass `make lint`)
+ def grid(META):
+ return (
+ triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
+ META["SPLIT_K"],
+ )
+
+ # grid = lambda META: (
+ # triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
+ # META["SPLIT_K"],
+ # )
+ _kernel[grid](
+ a,
+ b,
+ c,
+ M,
+ N,
+ K,
+ a.stride(0),
+ a.stride(1),
+ b.stride(0),
+ b.stride(1),
+ c.stride(0),
+ c.stride(1),
+ allow_tf32=allow_tf32,
+ GROUP_M=8,
+ ACC_TYPE=ACC_TYPE,
+ )
+
+ @staticmethod
+ def forward(a, b, out, allow_tf32=True):
+ return _matmul_out._call(a, b, out, allow_tf32)
+
+ matmul_out = _matmul_out.forward
diff --git a/torch/_inductor/triton_ops/mm_perf_model.py b/torch/_inductor/triton_ops/mm_perf_model.py
new file mode 100644
index 0000000000000..fd3a6904213ea
--- /dev/null
+++ b/torch/_inductor/triton_ops/mm_perf_model.py
@@ -0,0 +1,90 @@
+import torch
+
+
+def estimate_matmul_time(
+ # backend, device,
+ num_warps,
+ num_stages,
+ A,
+ B,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ SPLIT_K,
+ debug=False,
+ **kwargs,
+):
+ """return estimated running time in ms
+ = max(compute, loading) + store"""
+ import triton
+ import triton._C.libtriton.triton as _triton
+ from triton.ops.matmul_perf_model import (
+ get_dram_gbps as get_dram_gbps,
+ get_tflops as get_tflops,
+ )
+
+ backend = _triton.runtime.backend.CUDA
+ device = torch.cuda.current_device()
+ dtype = A.dtype
+ dtsize = A.element_size()
+
+ num_cta_m = triton.cdiv(M, BLOCK_M)
+ num_cta_n = triton.cdiv(N, BLOCK_N)
+ num_cta_k = SPLIT_K
+ num_ctas = num_cta_m * num_cta_n * num_cta_k
+
+ # If the input is smaller than the block size
+ M, N = max(M, BLOCK_M), max(N, BLOCK_N)
+
+ # time to compute
+ total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS
+ tput = get_tflops(backend, device, num_ctas, num_warps, dtype)
+ compute_ms = total_ops / tput
+
+ # time to load data
+ num_sm = _triton.runtime.num_sm(backend, device)
+ active_cta_ratio = min(1, num_ctas / num_sm)
+ active_cta_ratio_bw1 = min(
+ 1, num_ctas / 32
+ ) # 32 active ctas are enough to saturate
+ active_cta_ratio_bw2 = max(
+ min(1, (num_ctas - 32) / (108 - 32)), 0
+ ) # 32-108, remaining 5%
+ dram_bw = get_dram_gbps(backend, device) * (
+ active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05
+ ) # in GB/s
+ l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
+ # assume 80% of (following) loads are in L2 cache
+ load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1))
+ load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1)
+ load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1))
+ load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1)
+ # total
+ total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB
+ total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
+ # loading time in ms
+ load_ms = total_dram / dram_bw + total_l2 / l2_bw
+
+ # estimate storing time
+ store_bw = dram_bw * 0.6 # :o
+ store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB
+ if SPLIT_K == 1:
+ store_ms = store_c_dram / store_bw
+ else:
+ reduce_bw = store_bw
+ store_ms = store_c_dram / reduce_bw
+ # c.zero_()
+ zero_ms = M * N * 2 / (1024 * 1024) / store_bw
+ store_ms += zero_ms
+
+ total_time_ms = max(compute_ms, load_ms) + store_ms
+ if debug:
+ print(
+ f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, "
+ f"loading time: {load_ms}ms, store time: {store_ms}ms, "
+ f"Activate CTAs: {active_cta_ratio*100}%"
+ )
+ return total_time_ms
diff --git a/torch/_inductor/triton_ops/utils.py b/torch/_inductor/triton_ops/utils.py
new file mode 100644
index 0000000000000..2bc98ae29c4fe
--- /dev/null
+++ b/torch/_inductor/triton_ops/utils.py
@@ -0,0 +1,31 @@
+import torch
+
+
+def _extract_strides(shape):
+ rank = len(shape)
+ ret = [1] * rank
+ for i in range(rank - 1, 0, -1):
+ ret[i - 1] = ret[i] * shape[i]
+ return ret
+
+
+def _roundup(x, div):
+ return (x + div - 1) // div * div
+
+
+# unpack the given idx given the order of axis of the desired 3-dim tensor
+# You could view it as the reverse of flatten the idx of 3 axis in a tensor to 1-dim idx.
+# order is the order of axes in tensor, innermost dimension outward
+# shape is the 3D tensor's shape
+def _unpack(idx, order, shape):
+ if torch.is_tensor(idx):
+ _12 = torch.div(idx, shape[order[0]], rounding_mode="trunc")
+ _0 = idx % shape[order[0]]
+ _2 = torch.div(_12, shape[order[1]], rounding_mode="trunc")
+ _1 = _12 % shape[order[1]]
+ else:
+ _12 = idx // shape[order[0]]
+ _0 = idx % shape[order[0]]
+ _2 = _12 // shape[order[1]]
+ _1 = _12 % shape[order[1]]
+ return _0, _1, _2
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
new file mode 100644
index 0000000000000..922a5a765c4ec
--- /dev/null
+++ b/torch/_inductor/utils.py
@@ -0,0 +1,259 @@
+import collections
+import contextlib
+import functools
+import operator
+import os
+import tempfile
+import time
+from importlib import import_module
+from typing import Any, Dict, List
+from unittest import mock
+
+import numpy as np
+import sympy
+
+import torch
+from torch.fx.immutable_collections import immutable_dict, immutable_list
+
+from . import config
+
+VarRanges = Dict[sympy.Expr, sympy.Expr]
+
+# We import torchdynamo modules indirectly to allow a future rename to torch.dynamo
+dynamo_config = import_module(f"{config.dynamo_import}.config")
+dynamo_debug_utils = import_module(f"{config.dynamo_import}.debug_utils")
+dynamo_logging = import_module(f"{config.dynamo_import}.logging")
+dynamo_optimizations = import_module(f"{config.dynamo_import}.optimizations")
+dynamo_testing = import_module(f"{config.dynamo_import}.testing")
+dynamo_utils = import_module(f"{config.dynamo_import}.utils")
+
+
+@functools.lru_cache(None)
+def has_triton():
+ if not torch.cuda.is_available():
+ return False
+ try:
+ import triton
+
+ return triton is not None
+ except ImportError:
+ return False
+
+
+@functools.lru_cache(None)
+def has_torchvision_roi_align():
+ try:
+ from torchvision.ops import roi_align # noqa: F401
+
+ return roi_align is not None and hasattr(
+ getattr(torch.ops, "torchvision", None), "roi_align"
+ )
+ except ImportError:
+ return False
+
+
+def conditional_product(*args):
+ return functools.reduce(operator.mul, [x for x in args if x])
+
+
+def sympy_product(it):
+ return functools.reduce(operator.mul, it, sympy.Integer(1))
+
+
+def sympy_dot(seq1, seq2):
+ assert len(seq1) == len(seq2)
+ return sympy.expand(sum(a * b for a, b in zip(seq1, seq2)))
+
+
+def unique(it):
+ return {id(x): x for x in it}.values()
+
+
+def ceildiv(numer: int, denom: int):
+ assert isinstance(numer, int) and isinstance(denom, int)
+ return -(numer // -denom)
+
+
+def gen_gm_and_inputs(target, args, kwargs):
+ g = torch.fx.Graph()
+ g_args = []
+ a_args = []
+ for n, arg in enumerate(args):
+ if isinstance(arg, torch.Tensor):
+ g_args.append(g.placeholder(f"arg{n}"))
+ a_args.append(arg)
+ else:
+ g_args.append(arg)
+ assert all(not isinstance(x, torch.Tensor) for x in kwargs.values())
+ node = g.call_function(target, tuple(g_args), kwargs)
+ if (
+ len(target._schema.returns) == 1
+ and str(target._schema.returns[0].type) == "Tensor"
+ ):
+ node = (node,)
+ g.output(node)
+
+ gm = torch.fx.GraphModule({}, g)
+ return gm, a_args
+
+
+def synchronize():
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+
+def timed(model, example_inputs, times=1):
+ synchronize()
+ torch.manual_seed(1337)
+ t0 = time.perf_counter()
+ for _ in range(times):
+ result = model(*example_inputs)
+ synchronize()
+ t1 = time.perf_counter()
+ # GC the result after timing
+ assert result is not None
+ return t1 - t0
+
+
+def print_performance(fn, args=(), times=10, repeat=10, baseline=1.0):
+ timings = [timed(fn, args, times) for _ in range(repeat)]
+ took = np.median(timings)
+ print(f"{took/baseline:.6f}")
+ return took
+
+
+immutable_dict.__hash__ = lambda self: hash(tuple(self.items()))
+immutable_list.__hash__ = lambda self: hash(tuple(self))
+
+
+def freeze_inputs(f):
+ """
+ Useful for wrapping lists in tuples for caching purposes
+ """
+
+ def freeze_value(x):
+ if isinstance(x, (immutable_dict, immutable_list)):
+ return x
+ if isinstance(x, list):
+ return immutable_list(x)
+ if isinstance(x, dict):
+ return immutable_dict(x)
+ return x
+
+ @functools.wraps(f)
+ def wrapped(*args):
+ args = [freeze_value(x) for x in args]
+ return f(*args)
+
+ wrapped.cache_info = f.cache_info
+ return wrapped
+
+
+def precompute_method(obj: Any, method: str):
+ """Replace obj.method() with a new method that returns a precomputed constant."""
+ result = getattr(obj, method)()
+ setattr(obj, method, lambda: result)
+
+
+def precompute_methods(obj: Any, methods: List[str]):
+ """Replace methods with new methods that returns a precomputed constants."""
+ for method in methods:
+ precompute_method(obj, method)
+
+
+def cmp(a, b):
+ return int(a > b) - int(a < b)
+
+
+def cache_on_self(fn):
+ key = f"__{fn.__name__}_cache"
+
+ @functools.wraps(fn)
+ def wrapper(self):
+ if not hasattr(self, key):
+ setattr(self, key, fn(self))
+ return getattr(self, key)
+
+ return wrapper
+
+
+def sympy_str(expr: sympy.Expr):
+ """
+ Normal sympy str is very slow, this is a lot faster. The result are
+ somewhat worse, as it doesn't do as much simplification. So don't
+ use this for final codegen.
+ """
+ if isinstance(expr, sympy.Symbol):
+ return expr.name
+ if isinstance(expr, sympy.Add):
+ return " + ".join(map(sympy_str, expr.args))
+ if isinstance(expr, sympy.Mul):
+ return " * ".join(map(sympy_str, expr.args))
+
+ from .ir import CleanDiv, IndexingDiv, ModularIndexing
+
+ if isinstance(expr, (ModularIndexing, CleanDiv, IndexingDiv)):
+ return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})"
+ return str(expr)
+
+
+def sympy_subs(expr: sympy.Expr, replacements: Dict[Any, Any]):
+ """
+ xreplace is faster than subs, but is way more picky
+ """
+
+ def promote_strings(key):
+ if isinstance(key, str):
+ return sympy.Symbol(key)
+ return key
+
+ return expr.xreplace(
+ {promote_strings(k): promote_strings(v) for k, v in replacements.items()}
+ )
+
+
+def free_symbol_startswith(index: sympy.Expr, prefix: str):
+ return any(v.name.startswith(prefix) for v in index.free_symbols)
+
+
+def has_incompatible_cudagraph_ops(gm):
+ forbidden_list = set(
+ [
+ "aten._fused_moving_avg_obs_fq_helper.default",
+ "aten._fused_moving_avg_obs_fq_helper_functional.default",
+ "fbgemm.dense_to_jagged.default",
+ "fbgemm.jagged_to_padded_dense.default",
+ ]
+ )
+ for node in gm.graph.nodes:
+ if str(node.target) in forbidden_list:
+ return True
+ return False
+
+
+instance_descriptor = collections.namedtuple(
+ "instance_descriptor", ["divisible_by_16", "equal_to_1"]
+)
+
+
+@contextlib.contextmanager
+def fresh_triton_cache(cache_entries=None):
+ """
+ Contextmanager that provides a clean tmp cachedir for triton.
+
+ Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes
+ generated with this cache instance.
+ """
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": tmpdirname}):
+ yield
+ if isinstance(cache_entries, dict):
+ assert len(cache_entries) == 0, "expected empty cache_entries dict"
+ files = os.listdir(tmpdirname)
+ cache_entries.update(
+ {
+ f: os.path.getsize(os.path.join(tmpdirname, f))
+ for f in files
+ if ".lock" not in f
+ }
+ )
diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py
new file mode 100644
index 0000000000000..35109aba271e2
--- /dev/null
+++ b/torch/_inductor/virtualized.py
@@ -0,0 +1,136 @@
+from contextlib import contextmanager
+from itertools import chain
+from threading import local
+
+import sympy
+
+from torch.fx.graph import inplace_methods, magic_methods
+
+from .utils import sympy_str
+
+threadlocal = local()
+
+
+class Virtualized:
+ """
+ A global variable that redirects via thread local variable
+
+ This allows us to swap in different op implementations in codegen.
+ """
+
+ def __init__(self, vname, default):
+ self._key = f"__torchinductor_{vname}"
+ self._default = default
+
+ def _set_handler(self, value):
+ prior = self._get_handler()
+ setattr(threadlocal, self._key, value)
+
+ @contextmanager
+ def ctx():
+ try:
+ yield
+ finally:
+ self._set_handler(prior)
+
+ return ctx()
+
+ def _get_handler(self):
+ try:
+ return getattr(threadlocal, self._key)
+ except AttributeError:
+ return self._default()
+
+ def __getattr__(self, name):
+ return getattr(self._get_handler(), name)
+
+
+class NullHandler:
+ pass
+
+
+def _arg_str(a):
+ if isinstance(a, sympy.Expr):
+ return sympy_str(a)
+ return str(a)
+
+
+class MockHandler:
+ def __getattr__(self, name):
+ def inner(*args, **kwargs):
+ fargs = [_arg_str(a) for a in args]
+ fargs.extend(f"{k}={v}" for k, v in kwargs.items())
+ return f"{name}({', '.join(fargs)})"
+
+ return inner
+
+ @staticmethod
+ def masked(mask, body, other):
+ return f"masked({mask}, {body()}, {other})"
+
+ @staticmethod
+ def indirect_indexing(index_var):
+ return sympy.Symbol(str(index_var))
+
+ @classmethod
+ def _init_cls(cls):
+ def make_handler(format_string):
+ @staticmethod
+ def inner(*args):
+ return format_string.format(*args)
+
+ return inner
+
+ for name, format_string in chain(
+ magic_methods.items(), inplace_methods.items()
+ ):
+ setattr(cls, name, make_handler(format_string))
+
+
+class WrapperHandler:
+ def __init__(self, inner):
+ self._inner = inner
+
+ def __getattr__(self, item):
+ return getattr(self._inner, item)
+
+
+MockHandler._init_cls()
+
+ops = Virtualized("ops", MockHandler)
+_graph = Virtualized("graph", NullHandler)
+_kernel = Virtualized("kernel", NullHandler)
+_debug = Virtualized("debug", NullHandler)
+
+
+class _V:
+ MockHandler = MockHandler
+ WrapperHandler = WrapperHandler
+
+ set_ops_handler = ops._set_handler
+ get_ops_handler = ops._get_handler
+ set_graph_handler = _graph._set_handler
+ set_kernel_handler = _kernel._set_handler
+ set_debug_handler = _debug._set_handler
+
+ @property
+ def ops(self) -> MockHandler:
+ """The operator handler specific to the current codegen task"""
+ return ops._get_handler()
+
+ @property
+ def graph(self):
+ """The graph currently being generated"""
+ return _graph._get_handler()
+
+ @property
+ def kernel(self):
+ """The kernel currently being generated"""
+ return _kernel._get_handler()
+
+ @property
+ def debug(self):
+ return _debug._get_handler()
+
+
+V = _V()
diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c
index 020f3742c2066..e81457e4a2487 100644
--- a/torch/csrc/dynamo/eval_frame.c
+++ b/torch/csrc/dynamo/eval_frame.c
@@ -1,5 +1,6 @@
#define PY_SSIZE_T_CLEAN
#include
+#include
// Only Python 3.7 through 3.10 supported
#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 11
@@ -15,13 +16,6 @@
#undef Py_BUILD_CORE
#endif
-// C doesn't have bool types
-#ifndef bool
-#define bool char
-#endif
-#define false 0
-#define true 1
-
#ifdef _WIN32
#define unlikely(x) (x)
#else