Skip to content

Commit

Permalink
[feature] CLI commands for PyPI package along with user_dir option (f…
Browse files Browse the repository at this point in the history
…acebookresearch#104)

Summary:
- Allows users to specify their own directory, a step towards using
MMF as a library
- mmf_run command to run MMF training from virtually anywhere
- Add functions for easy imports on user end
- Fixes test_results upload issue on circleci
- Address some comments from previous PR
- User dir can also be specified via MMF_USER_DIR
Pull Request resolved: https://github.com/fairinternal/mmf-internal/pull/104

Test Plan: Install with `python setup.py develop` and use mmf_run instead of `python -u tools/run.py` to run all of your commands now.

Reviewed By: vedanuj

Differential Revision: D21173056

Pulled By: apsdehal

fbshipit-source-id: de24b990e5c18e478f413a7c3f6b23b6abce6949
  • Loading branch information
apsdehal committed May 8, 2020
1 parent 95340d1 commit 4517e13
Show file tree
Hide file tree
Showing 15 changed files with 130 additions and 13 deletions.
4 changes: 2 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ jobs:
- <<: *run_unittests

- store_test_results:
path: test-results
path: tests/test-results

gpu_tests:
<<: *gpu
Expand Down Expand Up @@ -166,7 +166,7 @@ jobs:
- <<: *run_unittests

- store_test_results:
path: test-results
path: tests/test-results


workflows:
Expand Down
2 changes: 2 additions & 0 deletions mmf/configs/datasets/coco/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ dataset_config:
fast_read: false
use_images: false
use_features: true
# annotation_style can be coco or textcaps which allows us to override
# the dataset class
annotation_style: coco
features:
train:
Expand Down
2 changes: 2 additions & 0 deletions mmf/configs/datasets/textcaps/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ dataset_config:
use_images: false
use_features: true
use_order_vectors: true
# annotation_style can be coco or textcaps which allows us to override
# the dataset class
annotation_style: textcaps
features:
train:
Expand Down
5 changes: 5 additions & 0 deletions mmf/configs/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,11 @@ env:
# Use MMF_TENSORBOARD_LOGDIR or env.tensorboard_logdir to override
tensorboard_logdir: ${env:MMF_TENSORBOARD_LOGDIR,}

# User directory where user can keep their own models independent of MMF
# This allows users to create projects which only include MMF as dependency
# Use MMF_USER_DIR or env.user_dir to specify
user_dir: ${env:MMF_USER_DIR,}

# Configuration for the distributed setup
distributed:
# Typically tcp://hostname:port that will be used to establish initial connection
Expand Down
12 changes: 11 additions & 1 deletion mmf/datasets/databases/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import mmf.datasets.databases.readers # noqa

from .annotation_database import AnnotationDatabase
from .features_database import FeaturesDatabase
from .image_database import ImageDatabase
from .scene_graph_database import SceneGraphDatabase

__all__ = ["AnnotationDatabase"]
__all__ = [
"AnnotationDatabase",
"FeaturesDatabase",
"ImageDatabase",
"SceneGraphDatabase",
]
Empty file.
9 changes: 7 additions & 2 deletions mmf/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,13 @@ def load(self):
self.dataset_loader = DatasetLoader(self.config)
self._datasets = self.config.datasets

self.writer = Logger(self.config)
registry.register("writer", self.writer)
# Check if writer is already registered, else init it
writer = registry.get("writer", no_warning=True)
if writer:
self.writer = writer
else:
self.writer = Logger(self.config)
registry.register("writer", self.writer)

self.configuration.pretty_print()

Expand Down
28 changes: 25 additions & 3 deletions mmf/utils/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from omegaconf import OmegaConf

from mmf.common.registry import registry
from mmf.utils.env import import_user_module
from mmf.utils.file_io import PathManager
from mmf.utils.general import get_mmf_root

Expand Down Expand Up @@ -190,15 +191,14 @@ def __init__(self, args=None, default_only=False):
self.args = args
self._register_resolvers()

default_config = self._build_default_config()
self._default_config = self._build_default_config()

if default_only:
other_configs = {}
else:
other_configs = self._build_other_configs()

self._default_config = default_config
self.config = OmegaConf.merge(default_config, other_configs)
self.config = OmegaConf.merge(self._default_config, other_configs)

self.config = self._merge_with_dotlist(self.config, args.opts)
self._update_specific(self.config)
Expand All @@ -213,8 +213,11 @@ def _build_other_configs(self):
opts_config = self._build_opt_list(self.args.opts)
user_config = self._build_user_config(opts_config)

self._opts_config = opts_config
self._user_config = user_config

self.import_user_dir()

model_config = self._build_model_config(opts_config)
dataset_config = self._build_dataset_config(opts_config)
args_overrides = self._build_demjson_config(self.args.config_override)
Expand All @@ -238,6 +241,25 @@ def _build_user_config(self, opts):

return user_config

def import_user_dir(self):
    """Import the user directory, honoring MMF's configuration hierarchy.

    Candidate values are read from the default config (which can be set
    via the environment as well), then the user's config, then the opts.
    The last non-empty candidate wins and is passed to
    ``import_user_module``. Nothing is imported when every candidate is
    empty.
    """
    # Ordered from lowest to highest precedence.
    candidates = (
        self._default_config.env.user_dir,
        self._user_config.get("env", {}).get("user_dir", None),
        self._opts_config.get("env", {}).get("user_dir", None),
    )

    selected = candidates[0]
    for override in candidates[1:]:
        # Only a non-empty value at a higher level replaces the current pick.
        if override:
            selected = override

    if selected:
        import_user_module(selected)

def _build_model_config(self, config):
model = config.model
if model is None:
Expand Down
61 changes: 61 additions & 0 deletions mmf/utils/env.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import importlib
import os
import random
import sys
from datetime import datetime

import numpy as np
import torch

from mmf.utils.general import get_absolute_path


def set_seed(seed):
if seed:
Expand All @@ -20,3 +24,60 @@ def set_seed(seed):
random.seed(seed)

return seed


def import_user_module(user_dir: str, no_print: bool = False):
    """Given a user dir, import it as a module.

    The user module is expected to have an ``__init__.py`` at its root.
    You can use ``import_files`` to import your python files easily in
    that ``__init__.py``.

    Args:
        user_dir (str): directory which has to be imported. Nothing is done
            when this is empty.
        no_print (bool): This function won't print anything if set to true
    """
    if not user_dir:
        return

    user_dir = get_absolute_path(user_dir)
    # The path helper is ``os.path.split``; ``os`` itself has no ``split``.
    # Trailing separators are dropped first so that a path such as
    # ``my_project/`` does not produce an empty module name.
    module_parent, module_name = os.path.split(user_dir.rstrip(os.sep))

    if module_name in sys.modules:
        # Already imported (or shadowed by a module of the same name).
        return

    # Temporarily expose the parent directory so the package is importable.
    sys.path.insert(0, module_parent)
    try:
        if not no_print:
            print(f"Importing user_dir from {user_dir}")
        importlib.import_module(module_name)
    finally:
        # Undo the temporary sys.path entry even if the import fails.
        sys.path.pop(0)


def import_files(file_path: str, module_name: str = None):
    """Import every public python file that sits next to ``file_path``.

    Handy for end users who do not want to list each file by hand in their
    ``__init__.py``. Files whose names start with an underscore and files
    that do not end in ``.py`` are skipped. When ``module_name`` is given it
    is the full dotted path of the module under which the files are
    imported.

    Example layout::

        my_project/
            my_models/
                my_model.py
                __init__.py

    Contents of ``__init__.py``::

        from mmf.utils.env import import_files
        import_files(__file__, "my_project.my_models")

    This will then allow you to import ``my_project.my_models.my_model``
    anywhere.

    Args:
        file_path (str): Path to file in whose directory everything will be
            imported
        module_name (str): Module name if this file is under some specified
            structure
    """
    directory = os.path.dirname(file_path)
    # Dotted prefix is only used when a parent module name was supplied.
    prefix = f"{module_name}." if module_name else ""

    for entry in os.listdir(directory):
        is_public_python = entry.endswith(".py") and not entry.startswith("_")
        if not is_public_python:
            continue
        # Keep everything before the first ".py" occurrence in the file name.
        stem = entry.partition(".py")[0]
        importlib.import_module(f"{prefix}{stem}")
2 changes: 2 additions & 0 deletions mmf/utils/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ def get_parser(self):

def add_core_args(self):
self.parser.add_argument_group("Core Arguments")
# TODO: Add Help flag here describing MMF Configuration
# and point to configuration documentation
self.parser.add_argument(
"-co",
"--config_override",
Expand Down
8 changes: 5 additions & 3 deletions mmf/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class Logger:
def __init__(self, config):
def __init__(self, config, name=None):
self.logger = None
self._is_master = is_master()

Expand Down Expand Up @@ -45,8 +45,10 @@ def __init__(self, config):

logging.captureWarnings(True)

self.logger = logging.getLogger(__name__)
self._file_only_logger = logging.getLogger(__name__)
if not name:
name = __name__
self.logger = logging.getLogger(name)
self._file_only_logger = logging.getLogger(name)
warnings_logger = logging.getLogger("py.warnings")

# Set level
Expand Down
2 changes: 1 addition & 1 deletion mmf/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys

__version__ = "0.9.alpha2"
__version__ = "0.9.alpha4"

msg = "MMF is only compatible with Python 3.6 and newer."

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import os
import re
import shutil
import sys
from glob import glob

import setuptools
Expand Down Expand Up @@ -113,4 +112,5 @@ def run(self):
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Operating System :: OS Independent",
],
entry_points={"console_scripts": ["mmf_run = tools.run:run"]},
)
Empty file added tools/__init__.py
Empty file.
6 changes: 6 additions & 0 deletions tools/run.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
import random

Expand All @@ -10,10 +11,13 @@
from mmf.utils.env import set_seed
from mmf.utils.flags import flags
from mmf.utils.general import setup_imports
from mmf.utils.logger import Logger


def main(configuration, init_distributed=False):
# A reload might be needed for imports
setup_imports()
configuration.import_user_dir()
config = configuration.get_config()

if torch.cuda.is_available():
Expand All @@ -27,6 +31,8 @@ def main(configuration, init_distributed=False):
registry.register("seed", config.training.seed)
print("Using seed {}".format(config.training.seed))

registry.register("writer", Logger(config, name="mmf.train"))

trainer = build_trainer(configuration)
trainer.load()
trainer.train()
Expand Down

0 comments on commit 4517e13

Please sign in to comment.