Delete dynamo_import and inductor_import (pytorch#93851)
Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#93851
Approved by: https://github.com/albanD, https://github.com/jansel
ezyang authored and pytorchmergebot committed Feb 2, 2023
1 parent 74592a4 commit ca9ebf9
Showing 15 changed files with 49 additions and 65 deletions.
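
Every hunk below applies the same mechanical change: module paths that were previously assembled from the config strings dynamo_import and inductor_import are written out directly as torch._dynamo and torch._inductor. A minimal sketch of the before/after pattern, for illustration only (it is not code from this commit, the variable names are invented, and it assumes a PyTorch build that ships torch._inductor):

    # Illustrative sketch of the pattern this commit removes and its replacement.
    from importlib import import_module

    # Before: the package path came from config values such as
    #   dynamo_import = __name__.replace(".config", "")
    #   inductor_import = dynamo_import.replace("dynamo", "inductor")
    inductor_import = "torch._inductor"  # stand-in for the old config.inductor_import
    inductor_config_old = import_module(f"{inductor_import}.config")

    # After: the canonical module path is spelled out directly.
    import torch._inductor.config as inductor_config_new

    # Both routes resolve to the same module object.
    assert inductor_config_old is inductor_config_new

The same substitution drives every change below, including the generated repro scripts, Triton header templates, and logger names.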
6 changes: 0 additions & 6 deletions torch/_dynamo/config.py
@@ -174,12 +174,6 @@
 # If True, raise when aot autograd is unsafe to use
 raise_on_unsafe_aot_autograd = False
 
-# How to import torchdynamo, either torchdynamo or torch._dynamo
-dynamo_import = __name__.replace(".config", "")
-
-# How to import torchinductor, either torchinductor or torch.inductor
-inductor_import = dynamo_import.replace("dynamo", "inductor")
-
 # If true, error with a better message if we symbolically trace over a
 # dynamo-optimized function. If false, silently suppress dynamo.
 error_on_nested_fx_trace = True
2 changes: 1 addition & 1 deletion torch/_dynamo/convert_frame.py
@@ -245,7 +245,7 @@ def format_guard_failures(code):
 
 assert code in guard_failures, "TODO(whc) any other recompile reasons?"
 log.warning(
-f"{config.dynamo_import} hit config.cache_size_limit ({config.cache_size_limit})\n"
+f"torch._dynamo hit config.cache_size_limit ({config.cache_size_limit})\n"
 + f" function: {format_func_info(code)}\n"
 + f" reasons: {format_guard_failures(code)}\n"
 + f"to diagnose recompilation issues, see {troubleshooting_url}."
40 changes: 20 additions & 20 deletions torch/_dynamo/debug_utils.py
@@ -22,7 +22,7 @@
 log = logging.getLogger(__name__)
 
 
-inductor_config = import_module(f"{config.inductor_import}.config")
+inductor_config = import_module("torch._inductor.config")
 use_buck = inductor_config.is_fbcode()
 
 
@@ -224,10 +224,10 @@ def generate_config_string():
 
 return textwrap.dedent(
 f"""\
-import {config.dynamo_import}.config
-import {config.inductor_import}.config
-{config.dynamo_import}.config.load_config({repr(torch._dynamo.config.save_config())})
-{config.inductor_import}.config.load_config({repr(torch._inductor.config.save_config())})
+import torch._dynamo.config
+import torch._inductor.config
+torch._dynamo.config.load_config({repr(torch._dynamo.config.save_config())})
+torch._inductor.config.load_config({repr(torch._inductor.config.save_config())})
 """
 )
 
@@ -241,7 +241,7 @@ def generate_compiler_repro_string(gm, args):
 import torch
 from torch import tensor, device
 import torch.fx as fx
-from {config.dynamo_import}.testing import rand_strided
+from torch._dynamo.testing import rand_strided
 from math import inf
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -273,9 +273,9 @@ def generate_compiler_repro_string(gm, args):
 return model_str
 
 
-INDUCTOR_IMPORT = f"""
-from {config.inductor_import}.compile_fx import compile_fx_inner
-from {config.dynamo_import}.debug_utils import same_two_models
+INDUCTOR_IMPORT = """
+from torch._inductor.compile_fx import compile_fx_inner
+from torch._dynamo.debug_utils import same_two_models
 """
 
 COMPILER_REPRO_OPTIONS = {
@@ -316,7 +316,7 @@ def save_graph_repro(fd, gm, args, compiler_name):
 break
 
 if "inductor" in compiler_name:
-fd.write(f"import {config.inductor_import}.overrides\n")
+fd.write("import torch._inductor.overrides\n")
 fd.write(generate_compiler_repro_string(gm, args))
 fd.write(COMPILER_REPRO_OPTIONS[compiler_name][0])
 if "_accuracy" in compiler_name:
@@ -757,10 +757,10 @@ class AccuracyError(Exception):
 import torch
 from torch import tensor, device
 import torch.fx as fx
-import {config.dynamo_import}
-from {config.dynamo_import}.testing import rand_strided
-from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd
-from {config.dynamo_import}.debug_utils import same_two_models
+import torch._dynamo
+from torch._dynamo.testing import rand_strided
+from torch._dynamo.debug_utils import run_fwd_maybe_bwd
+from torch._dynamo.debug_utils import same_two_models
 {generate_config_string()}
@@ -773,7 +773,7 @@
 {model_str}
 mod = Repro()
-opt_mod = {config.dynamo_import}.optimize("{compiler_name}")(mod)
+opt_mod = torch._dynamo.optimize("{compiler_name}")(mod)
 {run_code}
 """
@@ -954,10 +954,10 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name):
 from torch import tensor, device
 import torch.fx as fx
 import functools
-import {config.dynamo_import}
-from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd
-from {config.dynamo_import}.optimizations.backends import BACKENDS
-from {config.dynamo_import}.testing import rand_strided
+import torch._dynamo
+from torch._dynamo.debug_utils import run_fwd_maybe_bwd
+from torch._dynamo.optimizations.backends import BACKENDS
+from torch._dynamo.testing import rand_strided
 {generate_config_string()}
@@ -978,7 +978,7 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name):
 compiler_fn,
 compiler_name="{compiler_name}",
 )
-opt_mod = {config.dynamo_import}.optimize(dynamo_minifier_backend)(mod)
+opt_mod = torch._dynamo.optimize(dynamo_minifier_backend)(mod)
 with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}):
 opt_mod(*args)
11 changes: 3 additions & 8 deletions torch/_dynamo/exc.py
@@ -101,12 +101,10 @@ def augment_exc_message(exc, msg="\n"):
 
 if config.replay_record_enabled and hasattr(exc, "record_filename"):
 msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
-{config.dynamo_import}.replay('{exc.record_filename}').\n"
+torch._dynamo.replay('{exc.record_filename}').\n"
 
 if not config.verbose:
-msg += (
-f"\nSet {config.dynamo_import}.config.verbose=True for more information\n"
-)
+msg += "\nSet torch._dynamo.config.verbose=True for more information\n"
 
 if hasattr(exc, "inner_exception") and hasattr(
 exc.inner_exception, "minifier_path"
@@ -143,10 +141,7 @@ def filter_stack(stack):
 for frame in stack:
 if "convert_frame" in frame.filename:
 break
-if (
-"eval_frame" in frame.filename
-or f"{config.dynamo_import}.optimize(" in frame.line
-):
+if "eval_frame" in frame.filename or "torch._dynamo.optimize(" in frame.line:
 continue
 user_stack.append(frame)
 
2 changes: 1 addition & 1 deletion torch/_dynamo/optimizations/training.py
@@ -97,7 +97,7 @@ def _wrapped_bw_compiler(*args, **kwargs):
 bw_compiler=nop,
 # NB: lambda here is to delay import of inductor
 decompositions=lambda: import_module(
-f"{config.inductor_import}.compile_fx"
+"torch._inductor.compile_fx"
 ).select_decomp_table(),
 partition_fn=functools.partial(
 min_cut_rematerialization_partition, compiler="inductor"
4 changes: 3 additions & 1 deletion torch/_dynamo/skipfiles.py
@@ -220,7 +220,9 @@ def is_torch_inline_allowed(filename):
 
 @functools.lru_cache(None)
 def dynamo_dir():
-return _module_dir(importlib.import_module(config.dynamo_import))
+import torch._dynamo
+
+return _module_dir(torch._dynamo)
 
 
 def is_torch(filename):
4 changes: 2 additions & 2 deletions torch/_dynamo/utils.py
@@ -1061,7 +1061,7 @@ def recompile_reasons(code):
 rpt += "\n"
 rpt += "The following conditions caused torchdynamo to break out of tracing and fall back to python.\n"
 rpt += (
-f"You may gain additional insight by passing `nopython=True` to {config.dynamo_import}.optimize, "
+"You may gain additional insight by passing `nopython=True` to torch._dynamo.optimize, "
 "to break on the first condition.\n"
 )
 graph_breaks = counters["graph_break"]
@@ -1086,7 +1086,7 @@ def recompile_reasons(code):
 )
 rpt += "\n"
 rpt += (
-f"Set {config.dynamo_import}.config.cache_size_limit to "
+f"Set torch._dynamo.config.cache_size_limit to "
 f"{max_recompiles} to avoid being cache limited.\n"
 )
 else:
6 changes: 2 additions & 4 deletions torch/_dynamo/variables/misc.py
@@ -6,7 +6,7 @@
 import torch._C
 from torch._guards import Guard, GuardSource
 
-from .. import config, variables
+from .. import variables
 from ..bytecode_transformation import create_instruction
 from ..exc import unimplemented
 from ..guards import GuardBuilder
@@ -716,9 +716,7 @@ def call_function(
 self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
 ) -> "VariableTracker":
 if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
-unimplemented(
-f"call {config.dynamo_import}.disable() wrapped function {self.value}"
-)
+unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
 else:
 try:
 path = inspect.getfile(self.value)
8 changes: 4 additions & 4 deletions torch/_inductor/codegen/triton.py
@@ -1066,10 +1066,10 @@ def codegen_kernel(self, name=None):
 f"""
 import triton
 import triton.language as tl
-from {config.inductor_import}.ir import ReductionHint
-from {config.inductor_import}.ir import TileHint
-from {config.inductor_import}.triton_ops.autotune import {heuristics}
-from {config.inductor_import}.utils import instance_descriptor
+from torch._inductor.ir import ReductionHint
+from torch._inductor.ir import TileHint
+from torch._inductor.triton_ops.autotune import {heuristics}
+from torch._inductor.utils import instance_descriptor
 """
 )
 
12 changes: 6 additions & 6 deletions torch/_inductor/codegen/wrapper.py
@@ -286,20 +286,20 @@ def __init__(self):
 
 if has_triton():
 self.header.splice(
-f"""
+"""
 import triton
 import triton.language as tl
-from {config.inductor_import}.triton_ops.autotune import grid
+from torch._inductor.triton_ops.autotune import grid
 from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
 """
 )
 
 if config.triton.convolution != "aten":
 self.header.splice(
-f"""
-from {config.inductor_import}.triton_ops.conv_perf_model import early_config_prune
-from {config.inductor_import}.triton_ops.conv_perf_model import estimate_conv_time
-from {config.inductor_import}.triton_ops.autotune import conv_heuristics
+"""
+from torch._inductor.triton_ops.conv_perf_model import early_config_prune
+from torch._inductor.triton_ops.conv_perf_model import estimate_conv_time
+from torch._inductor.triton_ops.autotune import conv_heuristics
 """
 )
 
3 changes: 0 additions & 3 deletions torch/_inductor/config.py
@@ -100,9 +100,6 @@ def is_fbcode():
 # for larger kernels limit this
 kernel_name_max_ops = 10
 
-# How to import torchinductor, either torchinductor or torch.inductor
-inductor_import = __name__.replace(".config", "")
-
 # Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs
 shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "0") == "1"
 
4 changes: 2 additions & 2 deletions torch/_inductor/debug.py
@@ -290,7 +290,7 @@ def upload_tar(self):
 config.trace.upload_tar(tar_file)
 
 def __enter__(self):
-log = logging.getLogger(config.inductor_import)
+log = logging.getLogger("torch._inductor")
 if not log.handlers:
 init_logging()
 
@@ -318,7 +318,7 @@ def reset_log_level(level):
 self._prof.enable()
 
 def _setup_log_capture(self, filename, level):
-log = logging.getLogger(config.inductor_import)
+log = logging.getLogger("torch._inductor")
 fd = self._stack.enter_context(self.fopen(filename))
 ch = logging.StreamHandler(fd)
 ch.setLevel(level)
4 changes: 1 addition & 3 deletions torch/_inductor/exc.py
@@ -3,8 +3,6 @@
 import textwrap
 from functools import lru_cache
 
-from . import config
-
 if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1":
 
 @lru_cache(None)
@@ -45,7 +43,7 @@ def __init__(self, target, args, kwargs):
 There is a decomposition available for {target} in
 torch._decomp.get_decompositions(). Please add this operator to the
-`decompositions` list in {config.inductor_import}.decompositions
+`decompositions` list in torch._inductor.decompositions
 """
 )
 )
2 changes: 1 addition & 1 deletion torch/_inductor/ir.py
@@ -3038,7 +3038,7 @@ def __init__(
 
 def codegen(self, wrapper):
 if self.kernel.startswith("triton_ops."):
-wrapper.header.writeline(f"from {config.inductor_import} import triton_ops")
+wrapper.header.writeline("from torch._inductor import triton_ops")
 wrapper.writeline(
 f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
 )
6 changes: 3 additions & 3 deletions torch/_inductor/select_algorithm.py
@@ -16,7 +16,7 @@
 from torch._dynamo.testing import rand_strided
 from torch._dynamo.utils import counters, identity
 
-from . import config, ir
+from . import ir
 from .codecache import code_hash, DiskCache, PyCodeCache
 
 from .codegen.common import IndentedBuffer
@@ -134,8 +134,8 @@ def def_kernel(self, *argnames):
 [
 "import triton.language as tl",
 "import triton",
-f"from {config.inductor_import}.triton_ops.autotune import template",
-f"from {config.inductor_import}.utils import instance_descriptor",
+"from torch._inductor.triton_ops.autotune import template",
+"from torch._inductor.utils import instance_descriptor",
 "",
 self.jit_line(),
 f"def {self.kernel_name}({', '.join(arg_defs)}):",
