torch.hub

PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.

Publishing models

PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights) to a GitHub repository by adding a simple hubconf.py file.

hubconf.py can have multiple entrypoints. Each entrypoint is defined as a Python function with the following signature:

def entrypoint_name(pretrained=False, *args, **kwargs):
    ...

How to implement an entrypoint?

Here is a code snippet from the pytorch/vision repository that specifies an entrypoint for the resnet18 model. You can see the full script in the pytorch/vision repo.

import torch.utils.model_zoo as model_zoo

# Packages that must be installed to run the models in this hubconf.py
dependencies = ['torch', 'math']

def resnet18(pretrained=False, *args, **kwargs):
    """
    Resnet18 model
    pretrained (bool): a recommended kwargs for all entrypoints
    args & kwargs are arguments for the function
    """
    ######## Call the model in the repo ###############
    from torchvision.models.resnet import resnet18 as _resnet18
    model = _resnet18(*args, **kwargs)
    ######## End of call ##############################
    # The following logic is REQUIRED
    if pretrained:
        # For weights saved in local repo
        # model.load_state_dict(torch.load(<path_to_saved_file>))

        # For weights saved elsewhere
        checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
        model.load_state_dict(model_zoo.load_url(checkpoint, progress=False))
    return model
  • The dependencies variable is a list of package names required to run the model.
  • Pretrained weights can either be stored locally in the GitHub repo or loaded with model_zoo.load_url().
  • pretrained controls whether to load the pre-trained weights provided by the repo owners.
  • args and kwargs are passed along to the real callable function.
  • The docstring of the function works as a help message, explaining what the model does and which arguments are allowed.
  • An entrypoint function should ALWAYS return a model (nn.Module).
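
To illustrate how a dependencies list like the one above could be checked before an entrypoint is called, here is a minimal sketch using only the standard library. The helper name missing_dependencies and the made-up package name are hypothetical; this is not necessarily how torch.hub performs the check.

```python
import importlib.util


def missing_dependencies(dependencies):
    """Return the names in `dependencies` that cannot be imported here.

    Hypothetical helper for illustration; only top-level package names
    are expected, matching the style of the `dependencies` list above.
    """
    missing = []
    for name in dependencies:
        # find_spec returns None when no top-level module of that name is found.
        if importlib.util.find_spec(name) is None:
            missing.append(name)
    return missing


# 'math' and 'json' ship with Python; the last name is a made-up package.
print(missing_dependencies(['math', 'json', 'surely_not_installed_pkg_xyz']))
```

A loader could raise an error listing the missing names before it imports hubconf.py, so users see which packages to install.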

Important Notice

  • Published models must be available on at least a branch or tag. They cannot be referenced by an arbitrary commit.

Loading models from Hub

Users can load the pre-trained models using the torch.hub.load() API.

.. automodule:: torch.hub
.. autofunction:: load

Here's an example that loads the resnet18 entrypoint from the pytorch/vision repo.

from torch import hub

hub_model = hub.load(
    'pytorch/vision:master',  # repo_owner/repo_name:branch
    'resnet18',  # entrypoint
    # 1234,  # positional args for the callable would go here [not applicable to resnet]
    pretrained=True)  # kwargs for callable
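
The first argument follows the repo_owner/repo_name:branch pattern noted in the comment. As a rough sketch of how such a string could be split into its parts, assuming the branch part may be omitted (in which case the sketch returns None and does not guess a default), consider the following. RepoSpec and parse_repo_spec are hypothetical names and are not part of torch.hub.

```python
from typing import NamedTuple, Optional


class RepoSpec(NamedTuple):
    owner: str
    name: str
    ref: Optional[str]  # branch or tag; None when the string has no ':ref' part


def parse_repo_spec(spec: str) -> RepoSpec:
    """Split 'repo_owner/repo_name[:branch]' into its components."""
    repo, colon, ref = spec.partition(':')
    owner, slash, name = repo.partition('/')
    # Require exactly one '/' with non-empty text on both sides.
    if not (owner and slash and name) or '/' in name:
        raise ValueError(f"expected 'repo_owner/repo_name[:branch]', got {spec!r}")
    # A trailing ':' with nothing after it is treated as malformed.
    if colon and not ref:
        raise ValueError(f"empty branch or tag in {spec!r}")
    return RepoSpec(owner, name, ref if colon else None)


print(parse_repo_spec('pytorch/vision:master'))
```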

Where are my downloaded model & weights saved?

The locations are checked in the following order:

  • hub_dir: user-specified path. It can be set in the following ways:
      - Setting the environment variable TORCH_HUB_DIR
      - Calling hub.set_dir(<PATH_TO_HUB_DIR>)
  • ~/.torch/hub
.. autofunction:: set_dir
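
The lookup order above can be sketched as a small resolver. This is only an illustration with hypothetical names (resolve_hub_dir, DEFAULT_HUB_DIR). It also assumes that a path passed explicitly, standing in for hub.set_dir, takes precedence over TORCH_HUB_DIR, which the text above does not specify.

```python
import os

# Fallback location named in the list above.
DEFAULT_HUB_DIR = os.path.join('~', '.torch', 'hub')


def resolve_hub_dir(explicit_dir=None, environ=None):
    """Pick the hub directory: explicit path, then TORCH_HUB_DIR, then the default.

    `explicit_dir` stands in for a path given via hub.set_dir(); the
    precedence between it and the environment variable is an assumption.
    """
    environ = os.environ if environ is None else environ
    if explicit_dir:
        return os.path.expanduser(explicit_dir)
    env_dir = environ.get('TORCH_HUB_DIR')
    if env_dir:
        return os.path.expanduser(env_dir)
    return os.path.expanduser(DEFAULT_HUB_DIR)


print(resolve_hub_dir())
```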

Caching logic

By default, downloaded files are not cleaned up after loading. Hub uses the cached copy if it already exists in hub_dir.

Users can force a reload by calling hub.load(..., force_reload=True). This deletes the existing GitHub folder and downloaded weights and starts a fresh download. This is useful when updates are published to the same branch, so that users can keep up with the latest release.
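
The caching behaviour described above can be approximated with the following sketch. It assumes a hypothetical download callback and a one-folder-per-repo layout, and it runs against a temporary directory. It is not torch.hub's actual code.

```python
import os
import shutil
import tempfile


def load_cached(hub_dir, name, download, force_reload=False):
    """Return the cached folder for `name`, downloading only when needed.

    `download` is a hypothetical callback that fills the target folder.
    """
    target = os.path.join(hub_dir, name)
    if force_reload and os.path.exists(target):
        # Drop the stale copy so the next step fetches a fresh one.
        shutil.rmtree(target)
    if not os.path.exists(target):
        os.makedirs(target)
        download(target)
    return target


def fake_download(target):
    # Stand-in for fetching the repo and weights: writes an empty file.
    with open(os.path.join(target, 'weights.pth'), 'w'):
        pass
    print('downloaded into', os.path.basename(target))


with tempfile.TemporaryDirectory() as hub_dir:
    load_cached(hub_dir, 'pytorch_vision_master', fake_download)  # downloads
    load_cached(hub_dir, 'pytorch_vision_master', fake_download)  # cache hit, no download
    load_cached(hub_dir, 'pytorch_vision_master', fake_download, force_reload=True)  # downloads again
```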