Skip to content

Commit

Permalink
Use absolute imports and format them in Black-style
Browse files Browse the repository at this point in the history
Summary:
This was achieved automatically by
- search-and-replacing `^from \.` with `from torchbiggraph.`
- fixed manually that one late import
- running `isort -p torchbiggraph -o attr -o torch -o numpy -o torch_extensions -o tqdm -o h5py -m 3`

Black-style is "longer" for imports that don't fit on one line but has the advantage of making future diffs that remove or add imports more compact (in terms of number of lines touched) and easier to understand.

Reviewed By: adamlerer

Differential Revision: D15149757

fbshipit-source-id: 542b50e1c79bf500b4b4e0abd5ff4bb4aecd412e
  • Loading branch information
lw authored and facebook-github-bot committed May 1, 2019
1 parent f65a963 commit 93577cc
Show file tree
Hide file tree
Showing 29 changed files with 213 additions and 109 deletions.
4 changes: 1 addition & 3 deletions examples/fb15k.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
import attr

import torchbiggraph.converters.utils as utils
from filtered_eval import FilteredRankingEvaluator
from torchbiggraph.config import parse_config
from torchbiggraph.converters.import_from_tsv import convert_input_data
from torchbiggraph.eval import do_eval
from torchbiggraph.train import train

from filtered_eval import FilteredRankingEvaluator


FB15K_URL = 'https://dl.fbaipublicfiles.com/starspace/fb15k.tgz'
FILENAMES = {
'train': 'FB15k/freebase_mtr100_mte100-train.txt',
Expand Down
2 changes: 1 addition & 1 deletion examples/filtered_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from torchbiggraph.eval import RankingEvaluator
from torchbiggraph.fileio import EdgeReader
from torchbiggraph.model import Scores
from torchbiggraph.util import log
from torchbiggraph.stats import Stats
from torchbiggraph.types import Partition
from torchbiggraph.util import log


class FilteredRankingEvaluator(RankingEvaluator):
Expand Down
1 change: 0 additions & 1 deletion examples/livejournal.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from torchbiggraph.eval import do_eval
from torchbiggraph.train import train


URL = 'https://snap.stanford.edu/data/soc-LiveJournal1.txt.gz'
FILENAMES = {
'train': 'train.txt',
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from setuptools import setup, find_packages
from setuptools import find_packages, setup

with open("README.md", "rt") as f:
long_description = f.read()
Expand Down
8 changes: 4 additions & 4 deletions tests/batching_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@

import torch

from torchbiggraph.edgelist import EdgeList
from torchbiggraph.entitylist import EntityList
from torchbiggraph.batching import (
group_by_relation_type,
batch_edges_mix_relation_types,
batch_edges_group_by_relation_type,
batch_edges_mix_relation_types,
group_by_relation_type,
)
from torchbiggraph.edgelist import EdgeList
from torchbiggraph.entitylist import EntityList


class TestGroupByRelationType(TestCase):
Expand Down
2 changes: 1 addition & 1 deletion tests/bucket_scheduling_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from itertools import product
from unittest import TestCase, main

from torchbiggraph.config import BucketOrder
from torchbiggraph.bucket_scheduling import create_ordered_buckets
from torchbiggraph.config import BucketOrder


class TestCreateOrderedBuckets(TestCase):
Expand Down
4 changes: 2 additions & 2 deletions tests/fileio_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import numpy as np
import torch

from torchbiggraph.config import EntitySchema, RelationSchema, ConfigSchema
from torchbiggraph.fileio import DatasetIO, Mapping, ConfigMetadataProvider
from torchbiggraph.config import ConfigSchema, EntitySchema, RelationSchema
from torchbiggraph.fileio import ConfigMetadataProvider, DatasetIO, Mapping


class TestDatasetIO(TestCase):
Expand Down
9 changes: 7 additions & 2 deletions tests/functional_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,15 @@
import h5py
import numpy as np

from torchbiggraph.config import Operator, EntitySchema, RelationSchema, ConfigSchema
from torchbiggraph.config import (
ConfigSchema,
EntitySchema,
Operator,
RelationSchema,
)
from torchbiggraph.eval import do_eval
from torchbiggraph.partitionserver import run_partition_server
from torchbiggraph.train import train
from torchbiggraph.eval import do_eval


class Dataset(NamedTuple):
Expand Down
27 changes: 17 additions & 10 deletions tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,24 @@

from torchbiggraph.entitylist import EntityList
from torchbiggraph.model import (
AffineDynamicOperator,
AffineOperator,
BiasedComparator,
ComplexDiagonalDynamicOperator,
ComplexDiagonalOperator,
CosComparator,
DiagonalDynamicOperator,
DiagonalOperator,
DotComparator,
FeaturizedEmbedding,
IdentityDynamicOperator,
IdentityOperator,
LinearDynamicOperator,
LinearOperator,
SimpleEmbedding,
TranslationDynamicOperator,
TranslationOperator,
match_shape,
# Embeddings
SimpleEmbedding, FeaturizedEmbedding,
# Operators
IdentityOperator, DiagonalOperator, TranslationOperator, LinearOperator,
AffineOperator, ComplexDiagonalOperator,
# Dynamic operators
IdentityDynamicOperator, DiagonalDynamicOperator, TranslationDynamicOperator,
LinearDynamicOperator, AffineDynamicOperator, ComplexDiagonalDynamicOperator,
# Comparator
DotComparator, CosComparator, BiasedComparator,
)


Expand Down
11 changes: 9 additions & 2 deletions tests/schema_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,15 @@

import attr

from torchbiggraph.schema import unpack_optional, Loader, Dumper, schema, Schema, \
extract_nested_type, inject_nested_value
from torchbiggraph.schema import (
Dumper,
Loader,
Schema,
extract_nested_type,
inject_nested_value,
schema,
unpack_optional,
)


class TestUnpackOptional(TestCase):
Expand Down
2 changes: 1 addition & 1 deletion tests/util_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from unittest import TestCase, main

from torchbiggraph.util import (
split_almost_equally,
round_up_to_nearest_multiple,
split_almost_equally,
)


Expand Down
8 changes: 4 additions & 4 deletions torchbiggraph/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@

import torch

from .edgelist import EdgeList
from .model import MultiRelationEmbedder
from .stats import Stats
from .types import LongTensorType
from torchbiggraph.edgelist import EdgeList
from torchbiggraph.model import MultiRelationEmbedder
from torchbiggraph.stats import Stats
from torchbiggraph.types import LongTensorType


def group_by_relation_type(edges: EdgeList) -> List[EdgeList]:
Expand Down
9 changes: 4 additions & 5 deletions torchbiggraph/bucket_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@

from torch_extensions.rpc.rpc import Client, Server

from .config import BucketOrder
from .distributed import Startable
from .types import Bucket, EntityName, Partition, Rank, Side
from .util import log, vlog

from torchbiggraph.config import BucketOrder
from torchbiggraph.distributed import Startable
from torchbiggraph.types import Bucket, EntityName, Partition, Rank, Side
from torchbiggraph.util import log, vlog

###
### Bucket scheduling interface.
Expand Down
15 changes: 12 additions & 3 deletions torchbiggraph/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,17 @@
import attr
from attr.validators import optional

from .schema import has_origin, DeepTypeError, Schema, schema, non_negative, \
positive, non_empty, extract_nested_type, inject_nested_value
from torchbiggraph.schema import (
DeepTypeError,
Schema,
extract_nested_type,
has_origin,
inject_nested_value,
non_empty,
non_negative,
positive,
schema,
)


class Operator(Enum):
Expand Down Expand Up @@ -433,7 +442,7 @@ def parse_config(config_filename: str, overrides: Optional[List[str]] = None) ->
print(str(err), file=sys.stderr)
exit(1)
# Late import to avoid circular dependency.
from . import util
from torchbiggraph import util
util._verbosity_level = config.verbose
return config

Expand Down
8 changes: 6 additions & 2 deletions torchbiggraph/converters/import_from_tsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
import h5py
import numpy as np

from torchbiggraph.config import \
ConfigSchema, EntitySchema, RelationSchema, get_config_dict_from_module
from torchbiggraph.config import (
ConfigSchema,
EntitySchema,
RelationSchema,
get_config_dict_from_module,
)
from torchbiggraph.converters.dictionary import Dictionary


Expand Down
2 changes: 1 addition & 1 deletion torchbiggraph/converters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

import gzip
import os
import tarfile
import shutil
import tarfile
import urllib.request
from typing import Callable, Optional

Expand Down
4 changes: 2 additions & 2 deletions torchbiggraph/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import torch.distributed as td
import torch.multiprocessing as mp

from .types import Rank
from .util import log
from torchbiggraph.types import Rank
from torchbiggraph.util import log


class ProcessRanks(NamedTuple):
Expand Down
4 changes: 2 additions & 2 deletions torchbiggraph/edgelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import torch

from .entitylist import EntityList
from .types import LongTensorType
from torchbiggraph.entitylist import EntityList
from torchbiggraph.types import LongTensorType


class EdgeList:
Expand Down
2 changes: 1 addition & 1 deletion torchbiggraph/entitylist.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
from torch_extensions.tensorlist.tensorlist import TensorList

from .types import LongTensorType
from torchbiggraph.types import LongTensorType


class EntityList:
Expand Down
32 changes: 22 additions & 10 deletions torchbiggraph/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,28 @@

import torch

from .batching import call, process_in_batches, AbstractBatchProcessor
from .bucket_scheduling import create_buckets_ordered_lexicographically
from .config import parse_config, ConfigSchema
from .edgelist import EdgeList
from .fileio import CheckpointManager, EdgeReader
from .model import Scores, MultiRelationEmbedder, make_model
from .stats import average_of_sums, Stats
from .types import Side, Bucket, EntityName, Partition
from .util import log, get_partitioned_types, create_pool, compute_randomized_auc,\
split_almost_equally, get_num_workers
from torchbiggraph.batching import (
AbstractBatchProcessor,
call,
process_in_batches,
)
from torchbiggraph.bucket_scheduling import (
create_buckets_ordered_lexicographically
)
from torchbiggraph.config import ConfigSchema, parse_config
from torchbiggraph.edgelist import EdgeList
from torchbiggraph.fileio import CheckpointManager, EdgeReader
from torchbiggraph.model import MultiRelationEmbedder, Scores, make_model
from torchbiggraph.stats import Stats, average_of_sums
from torchbiggraph.types import Bucket, EntityName, Partition, Side
from torchbiggraph.util import (
compute_randomized_auc,
create_pool,
get_num_workers,
get_partitioned_types,
log,
split_almost_equally,
)


class RankingEvaluator(AbstractBatchProcessor):
Expand Down
26 changes: 15 additions & 11 deletions torchbiggraph/fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,23 @@
import numpy as np
import torch
import torch.multiprocessing as mp
from torch_extensions.rpc.rpc import _deserialize as torch_rpc_deserialize
from torch_extensions.rpc.rpc import _serialize as torch_rpc_serialize
from torch_extensions.tensorlist.tensorlist import TensorList
from torch_extensions.rpc.rpc import (
_serialize as torch_rpc_serialize,
_deserialize as torch_rpc_deserialize,
)

from .config import ConfigSchema
from .edgelist import EdgeList
from .entitylist import EntityList
from .parameter_sharing import ParameterClient
from .types import EntityName, Partition, Rank, OptimizerStateDict, ModuleStateDict, \
FloatTensorType
from .util import log, vlog, create_pool
from torchbiggraph.config import ConfigSchema
from torchbiggraph.edgelist import EdgeList
from torchbiggraph.entitylist import EntityList
from torchbiggraph.parameter_sharing import ParameterClient
from torchbiggraph.types import (
EntityName,
FloatTensorType,
ModuleStateDict,
OptimizerStateDict,
Partition,
Rank,
)
from torchbiggraph.util import create_pool, log, vlog


def maybe_old_entity_path(path: str) -> bool:
Expand Down
29 changes: 22 additions & 7 deletions torchbiggraph/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,34 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from enum import Enum
from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Type, Union
from typing import (
Dict,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
Type,
Union,
)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_extensions.tensorlist.tensorlist import TensorList

from .config import Operator, Comparator, EntitySchema, RelationSchema, ConfigSchema
from .edgelist import EdgeList
from .entitylist import EntityList
from .fileio import maybe_old_entity_path
from .types import Side, FloatTensorType, LongTensorType
from .util import log
from torchbiggraph.config import (
Comparator,
ConfigSchema,
EntitySchema,
Operator,
RelationSchema,
)
from torchbiggraph.edgelist import EdgeList
from torchbiggraph.entitylist import EntityList
from torchbiggraph.fileio import maybe_old_entity_path
from torchbiggraph.types import FloatTensorType, LongTensorType, Side
from torchbiggraph.util import log


def match_shape(tensor, *expected_shape):
Expand Down
Loading

0 comments on commit 93577cc

Please sign in to comment.