Skip to content

Commit

Permalink
Merge pull request #4179 from RosettaCommons/klimaj/init_distributed
Browse files Browse the repository at this point in the history
Adding a pyrosetta.distributed.init() function

Old repository SHA1: e7209884267213089de8e6409f3e9554dfae0b3e
  • Loading branch information
klimaj authored Aug 26, 2019
1 parent be71ccc commit 104b8e8
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# :noTabs=true:
#
# (c) Copyright Rosetta Commons Member Institutions.
# (c) This file is part of the Rosetta software suite and is made available under license.
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
# (c) For more information, see http://www.rosettacommons.org.
# (c) Questions about this can be addressed to University of Washington CoMotion, email: [email protected].


import functools
import logging
import threading
Expand All @@ -9,13 +18,20 @@

_logger = logging.getLogger("pyrosetta.distributed")

__all__ = ["maybe_init", "requires_init", "with_lock"]
__all__ = ["init", "maybe_init", "requires_init", "with_lock"]

# Access lock for any non-threadsafe calls into rosetta internals.
# Intended to provide a threadsafe api surface area to `distributed`.
_access_lock = threading.RLock()


def _normflags(flags):
"""Normalize tuple/list/str of flags into str."""
if not isinstance(flags, str):
flags = " ".join(flags)
return " ".join(" ".join([line.split("#")[0] for line in flags.split("\n")]).split())


def with_lock(func):
"""Function decorator that protects access to rosetta internals."""
@functools.wraps(func)
Expand Down Expand Up @@ -71,3 +87,10 @@ def fwrap(*args, **kwargs):
return func(*args, **kwargs)

return fwrap


def init(options=None, **kwargs):
"""Initialize PyRosetta with command line options."""
if options and ("extra_options" not in kwargs):
kwargs["extra_options"] = _normflags(options)
maybe_init(**kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
import pyrosetta.distributed


def _normflags(flags):
"""Normalize tuple/list/str of flags into str."""
if not isinstance(flags, str):
flags = " ".join(flags)
return " ".join(" ".join([line.split("#")[0] for line in flags.split("\n")]).split())


def worker_extra(init_flags=None, local_directory=None):
"""Format flags and local directory for dask worker preload.
Expand All @@ -39,7 +32,6 @@ def worker_extra(init_flags=None, local_directory=None):
extra=pyrosetta.distributed.dask.worker_extra(init_flags, local_directory)
)
"""

extras = []
if local_directory:
extras.extend(["--local-directory", local_directory])
Expand All @@ -51,7 +43,7 @@ def worker_extra(init_flags=None, local_directory=None):
extras.extend(
[
"--preload pyrosetta.distributed.dask.worker ' {0}'".format(
_normflags(init_flags)
pyrosetta.distributed._normflags(init_flags)
)
]
)
Expand All @@ -67,11 +59,4 @@ def init_notebook(init_flags=None):
import pyrosetta.distributed.dask
pyrosetta.distributed.dask.init_notebook(init_flags)
"""
kwargs = {}

if init_flags:
kwargs["extra_options"] = _normflags(init_flags)
else:
kwargs["extra_options"] = ""

pyrosetta.distributed.maybe_init(**kwargs)
pyrosetta.distributed.init(options=init_flags)
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ def test_score_smoke_test(self):
Which is to say, turn on the power and look for magic smoke.
"""
flags = """
-ignore_unrecognized_res 1
-ex4 # Test comment 1
-out:level 300 ### Test comment 2
"""
pyrosetta.distributed.init(flags)

score_task = score.ScorePoseTask()
rs_task = rosetta_scripts.SingleoutputRosettaScriptsTask(self.min_rs)
Expand Down

0 comments on commit 104b8e8

Please sign in to comment.