[BE] Enable more flake8-comprehensions checks (pytorch#94601)

I applied more flake8-comprehensions fixes and enabled the corresponding checks in the linter. I also enabled checks for the fixes from my previous comprehensions PR.

This is a follow-up to pytorch#94323, where I made the original fixes; here I enable the flake8 checks for those fixes and fix a few more occurrences.

Pull Request resolved: pytorch#94601
Approved by: https://github.com/ezyang
Skylion007 authored and pytorchmergebot committed Feb 10, 2023
1 parent 0b31ebf commit 3d82d8d
Showing 30 changed files with 71 additions and 82 deletions.
.flake8 (1 addition, 1 deletion)

@@ -11,7 +11,7 @@ ignore =
     # these ignores are from flake8-bugbear; please fix!
     B007,B008,
     # these ignores are from flake8-comprehensions; please fix!
-    C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
+    C400,C401,C402,C405,C407
 per-file-ignores =
     __init__.py: F401
     torch/utils/cpp_extension.py: B950
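In other words, C403, C404, C411, C413, C414, and C415 come off the ignore list and are now enforced repo-wide (they can be run locally with something like `flake8 --select=C4`), while C400, C401, C402, C405, and C407 remain ignored for now. As a quick reference, here is a minimal sketch of what each newly enabled rule flags, based on the flake8-comprehensions documentation; the data values are illustrative placeholders, not code from this PR:

    xs = [3, 1, 2]                      # illustrative data
    kvs = [("a", 1), ("b", 2)]

    set([x for x in xs])                # C403 -> {x for x in xs}
    dict([(k, v) for k, v in kvs])      # C404 -> {k: v for k, v in kvs}
    list([x for x in xs])               # C411 -> [x for x in xs]
    list(sorted(xs))                    # C413 -> sorted(xs); sorted() already returns a list
    sorted(list(xs))                    # C414 -> sorted(xs)
    sorted(xs[::-1])                    # C415 -> sorted(xs); input order does not matter
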
benchmarks/dynamo/microbenchmarks/operator_inp_utils.py (1 addition, 1 deletion)

@@ -181,7 +181,7 @@ def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None):
         return out

     def log_to_file(self, output_filename, *, skip_non_compute_operators=True):
-        sorted_operators = sorted(list(self.func_db.keys()))
+        sorted_operators = sorted(self.func_db.keys())
         with open(output_filename, "w") as f:
             for operator in sorted_operators:
                 if skip_non_compute_operators and non_compute_operator(eval(operator)):
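This is the C414 pattern that recurs throughout the commit: sorted() copies its input into a new list before sorting, so materializing an intermediate list() first is pure overhead. A minimal illustration with made-up values:

    d = {"b": 1, "a": 2}
    # both forms produce the same result; the list() call adds nothing
    assert sorted(list(d.keys())) == sorted(d.keys()) == ["a", "b"]
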
scripts/model_zoo/update-models-from-caffe2.py (1 addition, 1 deletion)

@@ -163,7 +163,7 @@ def tensortype_to_ndarray(tensor_type):


 def generate_test_input_data(onnx_model, scale):
-    real_inputs_names = list(set([input.name for input in onnx_model.graph.input]) - set([init.name for init in onnx_model.graph.initializer]))
+    real_inputs_names = list({input.name for input in onnx_model.graph.input} - {init.name for init in onnx_model.graph.initializer})
     real_inputs = []
     for name in real_inputs_names:
         for input in onnx_model.graph.input:
test/distributed/test_c10d_nccl.py (1 addition, 1 deletion)

@@ -2297,7 +2297,7 @@ def test_ddp_packed_sequence(self):
             store=store,
         )
         seqs = ["sequence_sequence", "seq", "sequence"]
-        vocab = ["<pad>"] + sorted(set([ch for seq in seqs for ch in seq]))
+        vocab = ["<pad>"] + sorted({ch for seq in seqs for ch in seq})
         vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]
         # Set the seed to make the embedding and LSTM deterministic (even
         # across ranks since DDP broadcasts parameters from rank 0)
test/functorch/discover_coverage.py (6 additions, 6 deletions)

@@ -426,7 +426,7 @@ def remove_torch(name):

 def get_list_of_all_tests():
     all_tests = list(tested_overridable_outplace_ops.keys())
-    return set([remove_torch(test) for test in all_tests])
+    return {remove_torch(test) for test in all_tests}


 mytest = {

@@ -459,11 +459,11 @@ def get_jvp_coverage(subset=None):
     supports_forwardad_ops_dct = {name: op_to_opinfo[fn] for name, fn in ops_dct.items()
                                   if op_to_opinfo[fn][0].supports_forward_ad}

-    ops = set([remove_torch(test) for test in list(ops_dct.keys())])
-    supports_autograd = set([remove_torch(test)
-                             for test in list(supports_autograd_ops_dct.keys())])
-    supports_forward_ad = set([remove_torch(test)
-                               for test in list(supports_forwardad_ops_dct.keys())])
+    ops = {remove_torch(test) for test in list(ops_dct.keys())}
+    supports_autograd = {remove_torch(test)
+                         for test in list(supports_autograd_ops_dct.keys())}
+    supports_forward_ad = {remove_torch(test)
+                           for test in list(supports_forwardad_ops_dct.keys())}
     assert supports_forward_ad.issubset(supports_autograd)
     assert supports_autograd.issubset(ops)
test/functorch/test_aotdispatch.py (2 additions, 2 deletions)

@@ -169,12 +169,12 @@ def f(x):
             return torch.tanh(x).sum()

         fx_f = make_fx(grad(f))(torch.randn(5))
-        ops = set([i.target for i in fx_f.graph.nodes])
+        ops = {i.target for i in fx_f.graph.nodes}

         self.assertEqual(torch.ops.aten.tanh_backward in ops, True)

         fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5))
-        ops = set([i.target for i in fx_f.graph.nodes])
+        ops = {i.target for i in fx_f.graph.nodes}
         self.assertEqual(torch.ops.aten.tanh_backward in ops, False)

     def test_nnc_jit(self, device):
test/functorch/test_minifier.py (2 additions, 2 deletions)

@@ -18,7 +18,7 @@ def failing_f(x, y):
         failing_f = make_fx(failing_f)(*inps)

         def has_mul(fx_g, inps):
-            return (torch.ops.aten.mul.Tensor in set([i.target for i in fx_g.graph.nodes]))
+            return (torch.ops.aten.mul.Tensor in (i.target for i in fx_g.graph.nodes))

         min_f, inps = minifier(failing_f, inps, has_mul)
         self.assertEqual(len(min_f.graph.nodes), 4)

@@ -74,7 +74,7 @@ def f(a, b):
         inps = [torch.randn(3), torch.randn(3)]

         def has_add(fx_g, inps):
-            return (torch.ops.aten.add.Tensor in set([i.target for i in fx_g.graph.nodes]))
+            return (torch.ops.aten.add.Tensor in (i.target for i in fx_g.graph.nodes))

         failing_f = make_fx(f)(*inps)
         min_f, inps = minifier(failing_f, inps, has_add)
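These two hunks go a step further than a set comprehension: the result is only used for a membership test, so a bare generator expression suffices, and `in` short-circuits at the first match instead of materializing every node target. A minimal illustration with made-up values:

    targets = (n for n in [1, 2, 3, 4])
    assert 2 in targets          # consumes only 1 and 2
    assert next(targets) == 3    # the rest of the generator is untouched
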
test/functorch/xfail_suggester.py (1 addition, 1 deletion)

@@ -114,7 +114,7 @@ def get_suggested_xfails(base, tests):
     tests = [test[len(base):] for test in tests if
              belongs_to_base(test, base)]

-    base_tests = set([remove_device_dtype(test) for test in tests])
+    base_tests = {remove_device_dtype(test) for test in tests}
     tests = set(tests)
     for base in base_tests:
         cpu_variant = base + '_cpu_float32'
test/jit/test_list_dict.py (1 addition, 1 deletion)

@@ -226,7 +226,7 @@ def foo2():
         self.checkScript(foo2, ())

         def foo3():
-            return list(list("abc"))
+            return list(list("abc"))  # noqa: C414

         self.checkScript(foo3, ())
         FileCheck().check_count("aten::list", 2, exactly=True).run(torch.jit.script(foo3).graph)
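The `# noqa: C414` here marks a deliberate suppression rather than a missed fix: the FileCheck assertion requires TorchScript to emit exactly two aten::list ops, so the doubled list() call is the behavior under test and must not be simplified away.
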
test/mobile/model_test/gen_test_model.py (1 addition, 1 deletion)

@@ -140,7 +140,7 @@ def calcOpsCoverage(ops):
             "_coverage": round(coverage, 2),
             "uncovered_ops": uncovered_ops_dict,
             "covered_ops": covered_ops_dict,
-            "all_generated_ops": sorted(list(all_generated_ops)),
+            "all_generated_ops": sorted(all_generated_ops),
         },
         f,
     )
test/onnx/onnx_test_common.py (1 addition, 1 deletion)

@@ -40,7 +40,7 @@ def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
     if hasattr(test_suite, "check_dtype"):
         options.check_dtype = test_suite.check_dtype

-    names = set([f.name for f in dataclasses.fields(options)])
+    names = {f.name for f in dataclasses.fields(options)}
     keywords_to_pop = []
     for k, v in kwargs.items():
         if k in names:
test/package/test_digraph.py (1 addition, 1 deletion)

@@ -116,7 +116,7 @@ def test_all_paths(self):

         result = g.all_paths("1", "3")
         # to get rid of indeterminism
-        actual = set([i.strip("\n") for i in result.split(";")[2:-1]])
+        actual = {i.strip("\n") for i in result.split(";")[2:-1]}
         expected = {
             '"2" -> "3"',
             '"1" -> "7"',
test/quantization/eager/test_quantize_eager_ptq.py (4 additions, 4 deletions)

@@ -365,10 +365,10 @@ def checkQuantized(model):
         # test one line API - out of place version
         base = AnnotatedSingleLayerLinearModel(qengine)
         base.qconfig = qconfig
-        keys_before = set(list(base.state_dict().keys()))
+        keys_before = set(base.state_dict().keys())
         model = quantize(base, test_only_eval_fn, [self.calib_data])
         checkQuantized(model)
-        keys_after = set(list(base.state_dict().keys()))
+        keys_after = set(base.state_dict().keys())
         self.assertEqual(keys_before, keys_after)  # simple check that nothing changed

         # in-place version

@@ -1107,10 +1107,10 @@ def checkQuantized(model):

         # test one line API - out of place version
         base = SingleLayerLinearDynamicModel()
-        keys_before = set(list(base.state_dict().keys()))
+        keys_before = set(base.state_dict().keys())
         model = quantize_dynamic(base, qconfig_dict)
         checkQuantized(model)
-        keys_after = set(list(base.state_dict().keys()))
+        keys_after = set(base.state_dict().keys())
         self.assertEqual(keys_before, keys_after)  # simple check that nothing changed

         # in-place version
test/quantization/fx/test_model_report_fx.py (2 additions, 2 deletions)

@@ -900,7 +900,7 @@ def test_constructor(self):
         model_report = ModelReport(model_prep, test_detector_set)

         # make sure internal valid reports matches
-        detector_name_set = set([detector.get_detector_name() for detector in test_detector_set])
+        detector_name_set = {detector.get_detector_name() for detector in test_detector_set}
         self.assertEqual(model_report.get_desired_reports_names(), detector_name_set)

         # now attempt with no valid reports, should raise error

@@ -1329,7 +1329,7 @@ def test_input_weight_equalization_determine_points(self):
         mods_to_check = set([nn.Linear, nn.Conv2d])

         # get the set of all nodes in the graph their fqns
-        node_fqns = set([node.target for node in prepared_for_callibrate_model.graph.nodes])
+        node_fqns = {node.target for node in prepared_for_callibrate_model.graph.nodes}

         # there should be 4 node fqns that have the observer inserted
         correct_number_of_obs_inserted = 4
test/test_namedtuple_return_api.py (1 addition, 1 deletion)

@@ -167,7 +167,7 @@ def check_torch_return_type(f, names):
             ret3 = meth(*op.input)
             check_namedtuple(ret3, op.names)

-        all_covered_operators = set([x for y in operators for x in y.operators])
+        all_covered_operators = {x for y in operators for x in y.operators}

         self.assertEqual(all_operators_with_namedtuple_return, all_covered_operators, textwrap.dedent('''
         The set of covered operators does not match the `all_operators_with_namedtuple_return` of
test/test_proxy_tensor.py (1 addition, 1 deletion)

@@ -579,7 +579,7 @@ def forward(mod_self, x):  # noqa: B902


         gm = make_fx(Emformer())(torch.randn(16, 1, 256))
-        ops = set([n.target for n in gm.graph.nodes if n.op == 'call_function'])
+        ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'}
         self.assertEqual(len(ops), 2)
test/test_sparse.py (1 addition, 1 deletion)

@@ -264,7 +264,7 @@ def _test_coalesce(t):
             else:
                 value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val

-        new_indices = sorted(list(value_map.keys()))
+        new_indices = sorted(value_map.keys())
         _new_values = [value_map[idx] for idx in new_indices]
         if t._values().ndimension() < 2:
             new_values = t._values().new(_new_values)
torch/_dynamo/skipfiles.py (5 additions, 7 deletions)

@@ -130,13 +130,11 @@ def _module_dir(m: types.ModuleType):
 }

 # Include optimizer code for tracing
-FILENAME_ALLOWLIST |= set(
-    [
-        inspect.getfile(obj)
-        for obj in torch.optim.__dict__.values()
-        if inspect.isclass(obj)
-    ]
-)
+FILENAME_ALLOWLIST |= {
+    inspect.getfile(obj)
+    for obj in torch.optim.__dict__.values()
+    if inspect.isclass(obj)
+}
 FILENAME_ALLOWLIST |= {torch.optim._functional.__file__}

 if HAS_PRIMS_REFS:
torch/_dynamo/utils.py (2 additions, 2 deletions)

@@ -760,7 +760,7 @@ def enum_repr(value):


 def dict_param_key_ids(value):
-    return set([id(k) for k in value.keys() if isinstance(k, torch.nn.Parameter)])
+    return {id(k) for k in value.keys() if isinstance(k, torch.nn.Parameter)}


 def dict_const_keys(value):

@@ -771,7 +771,7 @@ def dict_const_keys_repr(const_keys):
     if any(isinstance(k, enum.Enum) for k in const_keys):
         # To workaround repr(Enum) returning invalid global reference before python 3.11
         # by calling enum_repr and removing quotes to render enum in guard code.
-        const_keys_str = f"{set([enum_repr(k) if isinstance(k, enum.Enum) else repr(k) for k in const_keys])}".replace(
+        const_keys_str = f"{set(enum_repr(k) if isinstance(k, enum.Enum) else repr(k) for k in const_keys)}".replace(
             "'", ""
         )
     else:
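A subtlety in the second hunk: the fix keeps set(...) around a generator instead of rewriting to a set-comprehension literal. Nested braces inside an f-string are easy to misread, and the resulting set(generator) form stays lint-clean because C401 is still on the ignore list in .flake8.
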
torch/_dynamo/variables/builder.py (6 additions, 11 deletions)

@@ -304,17 +304,12 @@ def index_source(key):
         else:
             return key

-        result = dict(
-            [
-                (
-                    k,
-                    VariableBuilder(
-                        self.tx, GetItemSource(self.get_source(), index_source(k))
-                    )(value[k]).add_guards(guards),
-                )
-                for k in value.keys()
-            ]
-        )
+        result = {
+            k: VariableBuilder(
+                self.tx, GetItemSource(self.get_source(), index_source(k))
+            )(value[k]).add_guards(guards)
+            for k in value.keys()
+        }

         if istype(value, collections.defaultdict):
             result = DefaultDictVariable(
torch/_functorch/partitioners.py (3 additions, 3 deletions)

@@ -393,7 +393,7 @@ def is_tensor_node(x):
             for node in joint_module.graph.nodes
             if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
         )
-        ops_ignored = joint_module_ops - set([str(i) for i in recomputable_ops])
+        ops_ignored = joint_module_ops - {str(i) for i in recomputable_ops}
         print("Ops banned from rematerialization: ", ops_ignored)
         print()

@@ -522,8 +522,8 @@ def get_node_weight(node) -> int:
         joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs)
     if AOT_PARTITIONER_DEBUG:
         print("Theoretical Activations Stored: ", sum([_size_of(i) for i in saved_values]) / 1e9)
-        fw_module_nodes = set([node.name for node in fw_module.graph.nodes if node.op == 'call_function'])
-        bw_module_nodes = set([node.name for node in bw_module.graph.nodes if node.op == 'call_function'])
+        fw_module_nodes = {node.name for node in fw_module.graph.nodes if node.op == 'call_function'}
+        bw_module_nodes = {node.name for node in bw_module.graph.nodes if node.op == 'call_function'}
         remat_nodes = fw_module_nodes & bw_module_nodes

         counts = defaultdict(int)
torch/_inductor/graph.py (1 addition, 3 deletions)

@@ -535,9 +535,7 @@ def get_read_write_buffers_sizes(node):
     writes = set(dep.name for dep in node.read_writes.writes)

     def is_materialized(buf):
-        buf_uses = set(
-            [user.node for user in scheduler.name_to_node[buf].users]
-        )
+        buf_uses = {user.node for user in scheduler.name_to_node[buf].users}
         return len(buf_uses - set(node.snodes)) > 0

     if isinstance(node, FusedSchedulerNode):
torch/_inductor/utils.py (3 additions, 1 deletion)

@@ -344,7 +344,9 @@ def fresh_inductor_cache(cache_entries=None):

 def argsort(seq):
     # preserve original order for equal strides
-    return list(reversed(sorted(range(len(seq)), key=seq.__getitem__, reverse=True)))
+    getter = seq.__getitem__
+    a_r = range(len(seq))
+    return list(reversed(sorted(a_r, key=getter, reverse=True)))  # noqa: C413


 @functools.lru_cache(8)
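The `# noqa: C413` is load-bearing: C413 would suggest collapsing list(reversed(sorted(seq, reverse=True))) into sorted(seq), but the two are not equivalent when the key produces ties, because Python's sort is stable. A minimal sketch of the difference, with made-up values:

    xs = ["ab", "cd"]                    # equal lengths, so the keys tie
    plain = sorted(xs, key=len)          # ties keep input order: ['ab', 'cd']
    flipped = list(reversed(sorted(xs, key=len, reverse=True)))
    assert plain == ["ab", "cd"] and flipped == ["cd", "ab"]

Hence the explicit suppression instead of applying the lint's rewrite.
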
torch/ao/quantization/fx/_model_report/model_report.py (1 addition, 1 deletion)

@@ -120,7 +120,7 @@ def __init__(self, model: GraphModule, desired_report_detectors: Set[DetectorBase]):

         # keep the reports private so they can't be modified
         self._desired_report_detectors = desired_report_detectors
-        self._desired_detector_names = set([detector.get_detector_name() for detector in desired_report_detectors])
+        self._desired_detector_names = {detector.get_detector_name() for detector in desired_report_detectors}

         # keep a mapping of desired reports to observers of interest
         # this is to get the readings, and to remove them, can create a large set
torch/distributed/fsdp/_optim_utils.py (1 addition, 1 deletion)

@@ -1598,7 +1598,7 @@ def _all_gather_optim_state(
     gathered_state: Dict[str, Any] = {}

     all_tensor_states = sorted(
-        set([n for state in object_list for n in state.tensors.keys()])
+        {n for state in object_list for n in state.tensors.keys()}
     )
     empty_ranks: Set[int] = set()
     for name in all_tensor_states:
torch/fx/_symbolic_trace.py (1 addition, 1 deletion)

@@ -264,7 +264,7 @@ def __init__(
             for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
             if not name.startswith("_") and callable(value)
         }
-        self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions]))
+        self._autowrap_function_ids.update({id(f) for f in autowrap_functions})

         # Python modules to apply autowrap to at the start, in addition to
         # modules we see while tracing
torch/testing/_internal/common_utils.py (2 additions, 2 deletions)

@@ -3611,8 +3611,8 @@ def random_sparse_pd_matrix(matrix_size, density=0.01, **kwargs):
     torch = kwargs.get('torch', globals()['torch'])
     dtype = kwargs.get('dtype', torch.double)
     device = kwargs.get('device', 'cpu')
-    data = dict([((i, i), float(i + 1) / matrix_size)
-                 for i in range(matrix_size)])
+    data = {(i, i): float(i + 1) / matrix_size
+            for i in range(matrix_size)}


     def multiply(data, N, i, j, cs, sn, left=True):
torchgen/gen_backend_stubs.py (16 additions, 20 deletions)

@@ -377,29 +377,25 @@ def gen_dispatchkey_nativefunc_headers(
     # Convert to a set first to remove duplicate kernel names.
     # Backends are allowed to repeat kernel names; only generate the declaration once!
     # Sort for deterministic output.
-    backend_declarations = list(
-        sorted(
-            set(
-                concatMap(
-                    lambda f: dest.compute_native_function_declaration(
-                        f, backend_indices[backend_dispatch_key]
-                    ),
-                    grouped_native_functions,
-                )
+    backend_declarations = sorted(
+        set(
+            concatMap(
+                lambda f: dest.compute_native_function_declaration(
+                    f, backend_indices[backend_dispatch_key]
+                ),
+                grouped_native_functions,
             )
         )
     )
-    autograd_declarations = list(
-        sorted(
-            set(
-                concatMap(
-                    lambda f: []
-                    if autograd_dispatch_key is None
-                    else dest.compute_native_function_declaration(
-                        f, backend_indices[autograd_dispatch_key]
-                    ),
-                    grouped_native_functions,
-                )
+    autograd_declarations = sorted(
+        set(
+            concatMap(
+                lambda f: []
+                if autograd_dispatch_key is None
+                else dest.compute_native_function_declaration(
+                    f, backend_indices[autograd_dispatch_key]
+                ),
+                grouped_native_functions,
+            )
         )
     )
(Diff truncated: the remaining 2 of the 30 changed files are not shown.)