Skip to content

Commit

Permalink
Apply pyfmt to fbcode/deeplearning
Browse files Browse the repository at this point in the history
Summary:
Formats a subset of opted-in Python files in fbsource.
Black formatting was applied first, which is guaranteed
safe as the AST will not have changed during formatting.
Pyfmt was then run, which also includes import sorting.
The changes from isort were manually reviewed, and
some potentially dangerous changes were reverted,
and the `# isort:skip_file` directive was added to those
files. A final run of pyfmt shows no more changes to
be applied.

Reviewed By: zertosh

Differential Revision: D24106710

fbshipit-source-id: 3e32224a813a256374f3856a6efda0363af0c56e
  • Loading branch information
amyreese authored and facebook-github-bot committed Oct 5, 2020
1 parent 929dae0 commit 18dc81a
Show file tree
Hide file tree
Showing 11 changed files with 32 additions and 23 deletions.
2 changes: 1 addition & 1 deletion torchbiggraph/bucket_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def create_buckets_ordered_by_affinity(
same_as_lhs = buckets_per_partition[b.lhs]
same_as_rhs = buckets_per_partition[b.rhs]
while len(same_as_lhs) > 0 or len(same_as_rhs) > 0:
chosen, = generator.choices(
(chosen,) = generator.choices(
[same_as_lhs, same_as_rhs], weights=[len(same_as_lhs), len(same_as_rhs)]
)
next_b = chosen.pop()
Expand Down
10 changes: 4 additions & 6 deletions torchbiggraph/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, private: str, public: str, fields: List[str]) -> None:


def model_state_dict_public_to_private(
public_state_dict: Optional[Dict[str, torch.Tensor]],
public_state_dict: Optional[Dict[str, torch.Tensor]]
) -> Optional[ModuleStateDict]:
if public_state_dict is None:
return None
Expand Down Expand Up @@ -198,9 +198,7 @@ def get_checkpoint_metadata(self) -> Dict[str, Any]:
return {"config/json": self.json_config_dict}


def serialize_optim_state(
optim_state: Optional[OptimizerStateDict],
) -> Optional[bytes]:
def serialize_optim_state(optim_state: Optional[OptimizerStateDict]) -> Optional[bytes]:
if optim_state is None:
return None
with io.BytesIO() as bf:
Expand Down Expand Up @@ -358,13 +356,13 @@ def append_stats(self, stats: List[Dict[str, Union[int, SerializedStats]]]) -> N
self.storage.append_stats([json.dumps(s) for s in stats])

def read_stats(
self
self,
) -> Generator[Dict[str, Union[int, SerializedStats]], None, None]:
for line in self.storage.load_stats():
yield json.loads(line)

def maybe_read_stats(
self
self,
) -> Generator[Dict[str, Union[int, SerializedStats]], None, None]:
try:
yield from self.read_stats()
Expand Down
2 changes: 1 addition & 1 deletion torchbiggraph/checkpoint_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def save_model_state_dict(hf: h5py.File, state_dict: Dict[str, ModelParameter])
dataset.attrs[STATE_DICT_KEY_ATTR] = param.private_name


def load_model_state_dict(hf: h5py.File,) -> Optional[ModuleStateDict]:
def load_model_state_dict(hf: h5py.File) -> Optional[ModuleStateDict]:
if MODEL_STATE_DICT_GROUP not in hf:
return None
g = hf[MODEL_STATE_DICT_GROUP]
Expand Down
6 changes: 3 additions & 3 deletions torchbiggraph/converters/export_to_tsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def make_tsv_for_relation_types(
print("Writing relation type parameters...")
relation_types = relation_type_storage.load_names()
if model.num_dynamic_rels > 0:
rel_t_config, = model.relations
(rel_t_config,) = model.relations
op_name = rel_t_config.operator
lhs_operator, = model.lhs_operators
rhs_operator, = model.rhs_operators
(lhs_operator,) = model.lhs_operators
(rhs_operator,) = model.rhs_operators
for side, operator in [("lhs", lhs_operator), ("rhs", rhs_operator)]:
for param_name, all_params in operator.named_parameters():
for rel_t_name, param in zip(relation_types, all_params):
Expand Down
8 changes: 7 additions & 1 deletion torchbiggraph/converters/import_from_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ def main():
loader = ConfigFileLoader()
config_dict = loader.load_raw_config(opt.config, opt.param)

entity_configs, relation_configs, entity_path, edge_paths, dynamic_relations = parse_config_partial( # noqa
(
entity_configs,
relation_configs,
entity_path,
edge_paths,
dynamic_relations,
) = parse_config_partial( # noqa
config_dict
)

Expand Down
8 changes: 7 additions & 1 deletion torchbiggraph/converters/import_from_tsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ def main():
loader = ConfigFileLoader()
config_dict = loader.load_raw_config(opt.config, opt.param)

entity_configs, relation_configs, entity_path, edge_paths, dynamic_relations = parse_config_partial( # noqa
(
entity_configs,
relation_configs,
entity_path,
edge_paths,
dynamic_relations,
) = parse_config_partial( # noqa
config_dict
)

Expand Down
2 changes: 1 addition & 1 deletion torchbiggraph/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def add_group(group_size: int) -> List[Rank]:
trainers = add_group(num_machines)
parameter_servers = add_group(num_machines)
parameter_clients = add_group(num_machines)
lock_server, = add_group(1)
(lock_server,) = add_group(1)
if num_partition_servers < 0:
# Use machines as partition servers
partition_servers = add_group(num_machines)
Expand Down
2 changes: 1 addition & 1 deletion torchbiggraph/edgelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def cat(cls, edge_lists: Sequence["EdgeList"]) -> "EdgeList":
if all(el.has_scalar_relation_type() for el in edge_lists):
rel_types = {el.get_relation_type_as_scalar() for el in edge_lists}
if len(rel_types) == 1:
rel_type, = rel_types
(rel_type,) = rel_types
return cls(cat_lhs, cat_rhs, torch.tensor(rel_type, dtype=torch.long))
cat_rel = torch.cat([el.rel.expand((len(el),)) for el in edge_lists])
return EdgeList(cat_lhs, cat_rhs, cat_rel)
Expand Down
3 changes: 1 addition & 2 deletions torchbiggraph/filtered_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ def __init__(self, config: ConfigSchema, filter_paths: List[str]) -> None:
)
if not config.relations[0].all_negs:
raise RuntimeError("Filtered Eval can only be done with all negatives.")

entity, = config.entities.values()
(entity,) = config.entities.values()
if entity.featurized:
raise RuntimeError("Entity cannot be featurized for filtered eval.")
if entity.num_partitions > 1:
Expand Down
2 changes: 1 addition & 1 deletion torchbiggraph/graph_storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def flush_buffer(self, _last: bool = False) -> None:
self.buffer_offset = 0

def append(self, tensor: torch.Tensor) -> None:
tensor_size, = tensor.shape
(tensor_size,) = tensor.shape
tensor_offset = 0
while True:
tensor_left = tensor_size - tensor_offset
Expand Down
10 changes: 5 additions & 5 deletions torchbiggraph/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# Optional[foo] is an alias for Union[foo, NoneType], but Unions are weird.
def unpack_optional(type_):
try:
candidate_arg, = set(type_.__args__) - {type(None)}
(candidate_arg,) = set(type_.__args__) - {type(None)}
except (AttributeError, LookupError, ValueError):
raise TypeError("Not an optional type")
if type_ != Optional[candidate_arg]:
Expand Down Expand Up @@ -100,7 +100,7 @@ def map_enum(self, data: Any, type_: Type[Enum]) -> Any:
def map_list(self, data: Any, type_) -> List:
if not isinstance(data, list):
raise DeepTypeError("Not a list")
element_type, = type_.__args__
(element_type,) = type_.__args__
result = []
for idx, element in enumerate(data):
try:
Expand Down Expand Up @@ -274,7 +274,7 @@ def represent_type(cls, type_):
if isclass(type_) and issubclass(type_, Enum):
return "(%s)" % "|".join(member.name.lower() for member in type_)
if has_origin(type_, list):
element_type, = type_.__args__
(element_type,) = type_.__args__
return "[%s]" % cls.represent_type(element_type)
if has_origin(type_, dict):
key_type, value_type = type_.__args__
Expand Down Expand Up @@ -309,7 +309,7 @@ def append_if_subschema(s):
pass
append_if_subschema(type_)
if has_origin(type_, list):
element_type, = type_.__args__
(element_type,) = type_.__args__
append_if_subschema(element_type)
elif has_origin(type_, dict):
_, value_type = type_.__args__
Expand Down Expand Up @@ -350,7 +350,7 @@ def extract_nested_type(type_: Any, path: List[str]) -> Any:
if len(path) == 0:
return type_
if has_origin(type_, list):
element_type, = type_.__args__
(element_type,) = type_.__args__
return extract_nested_type(element_type, path[1:])
if has_origin(type_, dict):
_, value_type = type_.__args__
Expand Down

0 comments on commit 18dc81a

Please sign in to comment.