Skip to content

Commit

Permalink
set seed
Browse files Browse the repository at this point in the history
  • Loading branch information
nagadomi committed Dec 9, 2022
1 parent 427a4ea commit 876f4e3
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 8 deletions.
5 changes: 3 additions & 2 deletions nunif/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from . import models
from . import modules
from . import transforms
from . import utils
from . import training
from . import initializer

__version__ = "0.0.1"

__all__ = ["models", "modules", "transforms", "utils"]
__all__ = ["models", "modules", "transforms", "training", "initializer"]
14 changes: 8 additions & 6 deletions nunif/initialize.py → nunif/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@
import numpy as np


def global_initialize():
def disable_image_lib_threads():
# Disable OpenMP
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OMP_THREAD_LIMIT'] = '1'
#os.environ['OMP_NUM_THREADS'] = '1'
#os.environ['OMP_THREAD_LIMIT'] = '1'

# Disable ImageMagick's Threading
os.environ['MAGICK_THREAD_LIMIT'] = '1'
try:
from wand.resource import limits
limits["thread"] = 1
except ModuleNotFoundError:
pass

# Disable OpenCV's Threading/OpenCL
try:
Expand All @@ -27,6 +32,3 @@ def set_seed(seed):
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)


global_initialize()
2 changes: 2 additions & 0 deletions nunif/training/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .env import BaseEnv, SoftMaxEnv, I2IEnv, RGBPSNREnv, LuminancePSNREnv
from .confusion_matrix import SoftMaxConfusionMatrix
6 changes: 6 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
from nunif.addon import load_addons
from multiprocessing import cpu_count
from nunif.initializer import set_seed, disable_image_lib_threads


def add_default_options(parser):
Expand Down Expand Up @@ -30,6 +31,7 @@ def add_default_options(parser):
parser.add_argument("--amp", action="store_true", help="with AMP")
parser.add_argument("--resume", action="store_true", help="resume training from the latest checkpoint")
parser.add_argument("--reset-state", action="store_true", help="reset optimizer,scheduler states for --resume")
parser.add_argument("--seed", type=int, default=71, help="random seed")


def main():
Expand All @@ -40,6 +42,10 @@ def main():
if subparser is not None:
add_default_options(subparser)
args = parser.parse_args()

disable_image_lib_threads()
set_seed(args.seed)

assert(args.handler is not None)
print(vars(args))
args.handler(args)
Expand Down

0 comments on commit 876f4e3

Please sign in to comment.