Avoid code duplication in plugin registration
Summary: Introduce a small "framework" for the plugin registration system in order to avoid the copy-pasting that has been going on until now.

Reviewed By: adamlerer

Differential Revision: D17153537

fbshipit-source-id: 700cccd4488933aa962174805000d7df43449484
lw authored and facebook-github-bot committed Sep 3, 2019
1 parent b3bc415 commit f4a819b
Showing 11 changed files with 103 additions and 187 deletions.
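
The heart of the commit is a new torchbiggraph/plugin.py module, whose diff is among those that did not load on this page. The following is a hedged reconstruction, inferred only from the call sites visible below (PluginRegistry[...](), URLPluginRegistry[...](), .register_as(...), .make_instance(...)) and from the error messages of the deleted per-module helpers; it is not the actual source:

# Hypothetical sketch of torchbiggraph/plugin.py, reconstructed from the
# call sites in this diff; not the file as committed.
from typing import Callable, Dict, Generic, Type, TypeVar
from urllib.parse import urlparse

T = TypeVar("T")


class PluginRegistry(Generic[T]):

    """A name-to-class mapping that replaces the hand-rolled registration dicts."""

    def __init__(self) -> None:
        self.registry: Dict[str, Type[T]] = {}

    def register_as(self, name: str) -> Callable[[Type[T]], Type[T]]:
        """Class decorator that registers the class under the given name."""
        def decorator(class_: Type[T]) -> Type[T]:
            reg_class = self.registry.setdefault(name, class_)
            if reg_class is not class_:
                raise RuntimeError(
                    f"Attempting to re-register plugin {name} "
                    f"which was already set to {reg_class!r}")
            return class_
        return decorator

    def get_class(self, name: str) -> Type[T]:
        return self.registry[name]


class URLPluginRegistry(PluginRegistry[T]):

    """A registry keyed by URL scheme that builds instances from full URLs."""

    def make_instance(self, url: str) -> T:
        scheme = urlparse(url).scheme
        try:
            class_ = self.registry[scheme]
        except KeyError:
            raise RuntimeError(
                f"Couldn't find any plugin for scheme {scheme} used by {url}")
        # Plugins are constructed from the URL itself.
        return class_(url)

Each module below then declares a single registry (e.g. LOSS_FUNCTIONS, CHECKPOINT_STORAGES) instead of carrying its own copy of the dict, the decorator factory, and the lookup function.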
6 changes: 3 additions & 3 deletions docs/source/scoring.rst
@@ -91,8 +91,8 @@ All the operators' parameters are learned during training.

To define an additional operator, one must subclass the :class:`torchbiggraph.model.AbstractOperator` class
(or the :class:`torchbiggraph.model.AbstractDynamicOperator` one when using :ref:`dynamic relations <dynamic-relations>`;
-their docstrings explain what must be implemented) and decorate it with the :func:`torchbiggraph.model.register_operator_as`
-decorator (respectively the :func:`torchbiggraph.model.register_dynamic_operator_as` one), specifying a new
+their docstrings explain what must be implemented) and decorate it with the :func:`torchbiggraph.model.OPERATORS.register_as`
+decorator (respectively the :func:`torchbiggraph.model.DYNAMIC_OPERATORS.register_as` one), specifying a new
name that can then be used in the config to select that comparator.
All of the above can be done inside the config file itself.

@@ -113,7 +113,7 @@ The available comparators are:
* ``squared_l2``, the *negative* squared L2 distance.

Custom comparators need to extend the :class:`torchbiggraph.model.AbstractComparator` class
-(its docstring explains how) and decorate it with the :func:`torchbiggraph.model.register_comparator_as`
+(its docstring explains how) and decorate it with the :func:`torchbiggraph.model.COMPARATORS.register_as`
decorator, specifying a new name that can then be used in the config to select that comparator.
All of the above can be done inside the config file itself.

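The updated docs boil registration down to one decorator. A sketch of a custom operator under the new API (the ScaleOperator class, its body, and the assumption that AbstractOperator's constructor takes the embedding dimension are illustrative; they are not part of this commit):

# Illustrative only: a custom operator registered as scoring.rst describes.
import torch
from torch import nn

from torchbiggraph.model import OPERATORS, AbstractOperator


@OPERATORS.register_as("scale")  # usable as an operator name in the config
class ScaleOperator(AbstractOperator):

    """Multiplies every embedding by a single learned scalar."""

    def __init__(self, dim: int):
        super().__init__(dim)  # assumed signature, per the abstract class
        self.scale = nn.Parameter(torch.ones(()))

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        return self.scale * embeddings

Since, per the docs, all of this can live in the config file itself, no change to the library is needed to try a new operator.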
8 changes: 1 addition & 7 deletions torchbiggraph/checkpoint_manager.py
@@ -15,7 +15,6 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
-from urllib.parse import urlparse

import numpy as np
import torch
@@ -218,12 +217,7 @@ def __init__(
        subprocess_name: Optional[str] = None,
        subprocess_init: Optional[Callable[[], None]] = None,
    ) -> None:
-        scheme = urlparse(url).scheme
-        try:
-            self.storage: AbstractCheckpointStorage = CHECKPOINT_STORAGES[scheme](url)
-        except LookupError:
-            raise RuntimeError(f"Couldn't find any checkpoint storage "
-                               f"for scheme {scheme} used by {url}")
+        self.storage: AbstractCheckpointStorage = CHECKPOINT_STORAGES.make_instance(url)
        self.dirty: Set[Tuple[EntityName, Partition]] = set()
        self.rank: Rank = rank
        self.num_machines: int = num_machines
20 changes: 4 additions & 16 deletions torchbiggraph/checkpoint_storage.py
@@ -17,6 +17,7 @@
import numpy as np
import torch

+from torchbiggraph.plugin import URLPluginRegistry
from torchbiggraph.types import (
    EntityName,
    FloatTensorType,
@@ -141,20 +142,7 @@ def copy_version_to_snapshot(self, version: int, epoch_idx: int) -> None:
        pass


-CHECKPOINT_STORAGES: Dict[str, Type[AbstractCheckpointStorage]] = {}
-
-
-def register_checkpoint_storage_for_scheme(
-    scheme: str,
-) -> Callable[[Type[AbstractCheckpointStorage]], Type[AbstractCheckpointStorage]]:
-    def decorator(class_: Type[AbstractCheckpointStorage]) -> Type[AbstractCheckpointStorage]:
-        reg_class = CHECKPOINT_STORAGES.setdefault(scheme, class_)
-        if reg_class is not class_:
-            raise RuntimeError(
-                f"Attempting to re-register a checkpoint storage for scheme "
-                f"{scheme} which was already set to {reg_class!r}")
-        return class_
-    return decorator
+CHECKPOINT_STORAGES = URLPluginRegistry[AbstractCheckpointStorage]()


NP_VOID_DTYPE = np.dtype("V1")
@@ -227,8 +215,8 @@ def process_dataset(public_name, dataset) -> None:
    return state_dict


-@register_checkpoint_storage_for_scheme("")  # No scheme
-@register_checkpoint_storage_for_scheme("file")
+@CHECKPOINT_STORAGES.register_as("")  # No scheme
+@CHECKPOINT_STORAGES.register_as("file")
class FileCheckpointStorage(AbstractCheckpointStorage):

    """Reads and writes checkpoint data to/from disk.
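For comparison, here is all a new storage backend would need under the registry (the "mem" scheme and the class below are invented for illustration; only register_as and make_instance come from this commit):

# Hypothetical: hooking a new backend into the registry.
from torchbiggraph.checkpoint_storage import (
    CHECKPOINT_STORAGES,
    AbstractCheckpointStorage,
)


@CHECKPOINT_STORAGES.register_as("mem")
class MemoryCheckpointStorage(AbstractCheckpointStorage):

    def __init__(self, url: str) -> None:
        self.url = url
    # ... the remaining abstract methods would be implemented here ...


# CheckpointManager now resolves the backend by scheme; a scheme-less path
# still maps to FileCheckpointStorage, as the decorators above show.
storage = CHECKPOINT_STORAGES.make_instance("/tmp/checkpoints")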
33 changes: 4 additions & 29 deletions torchbiggraph/edgelist_reader.py
@@ -9,8 +9,6 @@
import logging
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import Callable, Dict, Type
-from urllib.parse import urlparse

import h5py
import numpy as np
@@ -19,6 +17,7 @@

from torchbiggraph.edgelist import EdgeList
from torchbiggraph.entitylist import EntityList
+from torchbiggraph.plugin import URLPluginRegistry
from torchbiggraph.types import Partition


@@ -42,40 +41,16 @@ def read_edgelist(
        pass


-EDGELIST_READERS: Dict[str, Type[AbstractEdgelistReader]] = {}
-
-
-def register_edgelist_reader_for_scheme(
-    scheme: str,
-) -> Callable[[Type[AbstractEdgelistReader]], Type[AbstractEdgelistReader]]:
-    def decorator(class_: Type[AbstractEdgelistReader]) -> Type[AbstractEdgelistReader]:
-        reg_class = EDGELIST_READERS.setdefault(scheme, class_)
-        if reg_class is not class_:
-            raise RuntimeError(
-                f"Attempting to re-register an edgelist reader for scheme "
-                f"{scheme} which was already set to {reg_class!r}")
-        return class_
-    return decorator
-
-
-def get_edgelist_reader_for_url(url: str) -> AbstractEdgelistReader:
-    scheme = urlparse(url).scheme
-    try:
-        class_: Type[AbstractEdgelistReader] = EDGELIST_READERS[scheme]
-    except LookupError:
-        raise RuntimeError(f"Couldn't find any edgelist reader "
-                           f"for scheme {scheme} used by {url}")
-    reader = class_(url)
-    return reader
+EDGELIST_READERS = URLPluginRegistry[AbstractEdgelistReader]()


# Names and values of metadata attributes for the HDF5 files.
FORMAT_VERSION_ATTR = "format_version"
FORMAT_VERSION = 1


-@register_edgelist_reader_for_scheme("")  # No scheme
-@register_edgelist_reader_for_scheme("file")
+@EDGELIST_READERS.register_as("")  # No scheme
+@EDGELIST_READERS.register_as("file")
class FileEdgelistReader(AbstractEdgelistReader):
    """Reads partitioned edgelists from disk, in the format
    created by edge_downloader.py.
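A usage sketch of the registry that replaces the deleted get_edgelist_reader_for_url (the path is a placeholder; the read_edgelist call mirrors the one in filtered_eval.py below):

# Sketch: resolving a reader by URL scheme, as eval.py now does.
from torchbiggraph.edgelist_reader import EDGELIST_READERS
from torchbiggraph.types import Partition

reader = EDGELIST_READERS.make_instance("train_edges")  # "" scheme -> FileEdgelistReader
edges = reader.read_edgelist(Partition(0), Partition(0))  # assuming unpartitioned data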
35 changes: 4 additions & 31 deletions torchbiggraph/entity_count_reader.py
@@ -9,9 +9,8 @@
import logging
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import Callable, Dict, Type
-from urllib.parse import urlparse

+from torchbiggraph.plugin import URLPluginRegistry
from torchbiggraph.types import EntityName, Partition


@@ -33,42 +32,16 @@ def read_entity_count(
        pass


-ENTITY_COUNT_READERS: Dict[str, Type[AbstractEntityCountReader]] = {}
-
-
-def register_entity_count_reader_for_scheme(
-    scheme: str,
-) -> Callable[[Type[AbstractEntityCountReader]], Type[AbstractEntityCountReader]]:
-    def decorator(
-        class_: Type[AbstractEntityCountReader],
-    ) -> Type[AbstractEntityCountReader]:
-        reg_class = ENTITY_COUNT_READERS.setdefault(scheme, class_)
-        if reg_class is not class_:
-            raise RuntimeError(
-                f"Attempting to re-register an entity count reader for scheme "
-                f"{scheme} which was already set to {reg_class!r}")
-        return class_
-    return decorator
-
-
-def get_entity_count_reader_for_url(url: str) -> AbstractEntityCountReader:
-    scheme = urlparse(url).scheme
-    try:
-        class_: Type[AbstractEntityCountReader] = ENTITY_COUNT_READERS[scheme]
-    except LookupError:
-        raise RuntimeError(f"Couldn't find any edgelist reader "
-                           f"for scheme {scheme} used by {url}")
-    reader = class_(url)
-    return reader
+ENTITY_COUNT_READERS = URLPluginRegistry[AbstractEntityCountReader]()


# Names and values of metadata attributes for the HDF5 files.
FORMAT_VERSION_ATTR = "format_version"
FORMAT_VERSION = 1


-@register_entity_count_reader_for_scheme("")  # No scheme
-@register_entity_count_reader_for_scheme("file")
+@ENTITY_COUNT_READERS.register_as("")  # No scheme
+@ENTITY_COUNT_READERS.register_as("file")
class FileEntityCountReader(AbstractEntityCountReader):

    def __init__(self, path: str) -> None:
4 changes: 2 additions & 2 deletions torchbiggraph/eval.py
@@ -26,7 +26,7 @@
from torchbiggraph.checkpoint_manager import CheckpointManager
from torchbiggraph.config import add_to_sys_path, ConfigFileLoader, ConfigSchema
from torchbiggraph.edgelist import EdgeList
-from torchbiggraph.edgelist_reader import get_edgelist_reader_for_url
+from torchbiggraph.edgelist_reader import EDGELIST_READERS
from torchbiggraph.model import MultiRelationEmbedder, Scores, make_model
from torchbiggraph.stats import Stats, average_of_sums
from torchbiggraph.types import Bucket, EntityName, Partition, Side
@@ -139,7 +139,7 @@ def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
        logger.info(
            f"Starting edge path {edge_path_idx + 1} / {len(config.edge_paths)} "
            f"({edge_path})")
-        edgelist_reader = get_edgelist_reader_for_url(edge_path)
+        edgelist_reader = EDGELIST_READERS.make_instance(edge_path)

all_edge_path_stats = []
last_lhs, last_rhs = None, None
4 changes: 2 additions & 2 deletions torchbiggraph/filtered_eval.py
@@ -12,7 +12,7 @@

from torchbiggraph.config import ConfigSchema
from torchbiggraph.edgelist import EdgeList
-from torchbiggraph.edgelist_reader import get_edgelist_reader_for_url
+from torchbiggraph.edgelist_reader import EDGELIST_READERS
from torchbiggraph.eval import RankingEvaluator
from torchbiggraph.model import Scores
from torchbiggraph.stats import Stats
@@ -48,7 +48,7 @@ def __init__(self, config: ConfigSchema, filter_paths: List[str]):
        self.rhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list)
        for path in filter_paths:
            logger.info(f"Building links map from path {path}")
-            e_reader = get_edgelist_reader_for_url(path)
+            e_reader = EDGELIST_READERS.make_instance(path)
            # Assume unpartitioned.
            edges = e_reader.read_edgelist(Partition(0), Partition(0))
            for idx in range(len(edges)):
22 changes: 5 additions & 17 deletions torchbiggraph/losses.py
@@ -14,6 +14,7 @@
from torch.nn import functional as F

from torchbiggraph.model import match_shape
+from torchbiggraph.plugin import PluginRegistry
from torchbiggraph.types import FloatTensorType


@@ -40,23 +41,10 @@ def forward(
        pass


-LOSS_FUNCTIONS: Dict[str, Type[AbstractLossFunction]] = {}
+LOSS_FUNCTIONS = PluginRegistry[AbstractLossFunction]()


-def register_loss_function_as(
-    name: str,
-) -> Callable[[Type[AbstractLossFunction]], Type[AbstractLossFunction]]:
-    def decorator(class_: Type[AbstractLossFunction]) -> Type[AbstractLossFunction]:
-        reg_class = LOSS_FUNCTIONS.setdefault(name, class_)
-        if reg_class is not class_:
-            raise RuntimeError(
-                f"Attempting to re-register loss function {name} which was "
-                f"already set to {reg_class!r}")
-        return class_
-    return decorator
-
-
-@register_loss_function_as("logistic")
+@LOSS_FUNCTIONS.register_as("logistic")
class LogisticLossFunction(AbstractLossFunction):

    def forward(
@@ -83,7 +71,7 @@ def forward(
        return loss


-@register_loss_function_as("ranking")
+@LOSS_FUNCTIONS.register_as("ranking")
class RankingLossFunction(AbstractLossFunction):

    def __init__(self, margin):
@@ -113,7 +101,7 @@ def forward(
        return loss


-@register_loss_function_as("softmax")
+@LOSS_FUNCTIONS.register_as("softmax")
class SoftmaxLossFunction(AbstractLossFunction):

    def forward(
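Adding a loss is now a single decorator. A hedged sketch (the hinge formulation, the score shapes, and the "hinge" name are assumptions extrapolated from the ranking loss above; they are not part of the commit):

# Hedged sketch of a custom loss under the new registry. The forward
# signature and the (num_pos,) / (num_pos, num_neg) score shapes are
# assumed from the losses in this file.
from torchbiggraph.losses import LOSS_FUNCTIONS, AbstractLossFunction
from torchbiggraph.types import FloatTensorType


@LOSS_FUNCTIONS.register_as("hinge")
class HingeLossFunction(AbstractLossFunction):

    def forward(
        self,
        pos_scores: FloatTensorType,
        neg_scores: FloatTensorType,
    ) -> FloatTensorType:
        # Penalize every negative that scores within 1.0 of its positive.
        return (1.0 + neg_scores - pos_scores.unsqueeze(1)).clamp_(min=0).sum()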
(Diffs for the remaining three changed files, including the new torchbiggraph/plugin.py, did not load.)
