forked from RosettaCommons/rosetta
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add unit tests from pyrosetta.tests to PyRosetta testing server: Pull…
… request #3808 Add unit tests from pyrosetta.tests to PyRosetta testing server ---- An additional 14 commit messages were squashed into this commit: -- Removing pyrosetta.tests.protocols.indexed_structure_store.test_search unit test due to missing imports -- PEP8 updates to T900_distributed.py -- Add pip-installable 'scipy' to temporary testing virtual environment; PEP8 -- Resort to declaring explicit list of unit tests available in pyrosetta.tests for testing on the PyRosetta testing server -- Update T900_distributed.py to print list of automatically discovered unit tests to PyRosetta testing server tracer -- Implementing automatic unit test discovery in pyrosetta.tests, skipping test_gil.py unit test -- Testing updated test_gil.py unit test on PyRosetta testing server. -- Assertion bug fix and updating docstrings. -- Update test_dask.py unit test to run in temporary directory, and update unittest framework in other unit tests. -- Add pip-installable 'dask' and 'distributed' to temporary testing virtual environment -- Temporarily remove pyrosetta.tests.distributed.test_gil unit tests from PyRosetta testing server -- Add pip-installable 'blosc' to temporary testing virtual environment -- Removing dependency on unittest.TestLoader().discover() and declaring pyrosetta.tests.distributed unit tests explicitly -- Updating T900 tests covering pyrosetta.tests.distributed test suites with unittest framework -- Old repository SHA1: 1e0f021a01a3fe9a27617789f896bdd30ec32798
- Loading branch information
Showing
8 changed files
with
338 additions
and
97 deletions.
There are no files selected for viewing
175 changes: 175 additions & 0 deletions
175
source/src/python/PyRosetta/src/pyrosetta/tests/bindings/core/test_pose.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
# :noTabs=true: | ||
# | ||
# (c) Copyright Rosetta Commons Member Institutions. | ||
# (c) This file is part of the Rosetta software suite and is made available under license. | ||
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons. | ||
# (c) For more information, see http://www.rosettacommons.org. | ||
# (c) Questions about this can be addressed to University of Washington CoMotion, email: [email protected]. | ||
|
||
import pyrosetta | ||
import pyrosetta.rosetta.core.pose as pose | ||
import unittest | ||
|
||
|
||
pyrosetta.init(extra_options="-constant_seed", set_logging_handler="logging") | ||
|
||
class TestPoseResidueAccessor(unittest.TestCase): | ||
|
||
def test_residues(self): | ||
|
||
pose1 = pyrosetta.pose_from_sequence('ACDEFGHI') | ||
|
||
# Test __len__ | ||
self.assertEqual(0, len(pose.Pose().residues)) | ||
self.assertEqual(8, len(pose1.residues)) | ||
self.assertEqual(8, len(pose1)) # Deprecated | ||
|
||
# Test __iter__ | ||
self.assertEqual(0, len(list(pose.Pose().residues))) | ||
self.assertEqual(8, len(list(pose1.residues))) | ||
self.assertEqual(8, len(list(pose1))) # Deprecated | ||
|
||
# Test __getitem__ | ||
# assert(pose1.residues[0] == ValueError) | ||
# assert(pose1.residues[0:] == ValueError) | ||
self.assertEqual( | ||
pose1.residues[1].annotated_name(), 'A[ALA:NtermProteinFull]') | ||
|
||
self.assertEqual(pose1.residues[6].annotated_name(), 'G') | ||
self.assertEqual(pose1.residues[8].annotated_name(), 'I[ILE:CtermProteinFull]') | ||
self.assertEqual(pose1.residues[-1].annotated_name(), 'I[ILE:CtermProteinFull]') | ||
self.assertEqual(pose1.residues[-3].annotated_name(), 'G') | ||
self.assertEqual(pose1.residues[-8].annotated_name(), 'A[ALA:NtermProteinFull]') | ||
self.assertEqual( | ||
''.join([res.annotated_name() for res in pose.Pose().residues[:]]), '') | ||
self.assertEqual( | ||
''.join([res.annotated_name() for res in pose1.residues[:]]), | ||
'A[ALA:NtermProteinFull]CDEFGHI[ILE:CtermProteinFull]') | ||
self.assertEqual( | ||
''.join([res.annotated_name() for res in pose1.residues[1:9]]), | ||
'A[ALA:NtermProteinFull]CDEFGHI[ILE:CtermProteinFull]') | ||
self.assertEqual( | ||
''.join([res.annotated_name() for res in pose1.residues[:-3]]), 'A[ALA:NtermProteinFull]CDEF') | ||
self.assertEqual( | ||
''.join([res.annotated_name() for res in pose1.residues[3:]]), 'DEFGHI[ILE:CtermProteinFull]') | ||
self.assertEqual( | ||
''.join([res.annotated_name() for res in pose1.residues[-6:8]]), 'DEFGH') | ||
self.assertEqual( | ||
''.join([res.annotated_name() for res in pose1.residues[-6:8:2]]), 'DFH') | ||
self.assertEqual( | ||
''.join([res.annotated_name() for res in pose1.residues[-6:8:3]]), 'DG') | ||
self.assertEqual( | ||
''.join([res.annotated_name() for res in pose1[1:9]]), 'A[ALA:NtermProteinFull]CDEFGHI[ILE:CtermProteinFull]') # Deprecated | ||
self.assertEqual( | ||
''.join([res.annotated_name() for res in pose1[-6:8]]), 'DEFGH') # Deprecated | ||
|
||
# #Test __iadd__ | ||
# gly_residue = pose1.residues[6] | ||
# pose1.residues += gly_residue | ||
# self.assertEqual( | ||
# ''.join([res.annotated_name() for res in pose1.residues]), 'A[ALA:NtermProteinFull]CDEFGHIG[GLY:CtermProteinFull]') | ||
|
||
# pose2 = Pose() | ||
# pose2.residues += pose1.residues[1] | ||
# for _ in range(3): | ||
# pose2.residues += gly_residue | ||
# pose2.residues += pose1.residues[-1] | ||
# pose2.residues += gly_residue | ||
# self.assertEqual( | ||
# ''.join([res.annotated_name() for res in pose2.residues]), 'A[ALA:NtermProteinFull]GGGGG[GLY:CtermProteinFull]') | ||
|
||
# #Test __imul__ | ||
# pose3 = Pose() | ||
# pose3.residues *= pose1.residues[5] | ||
# self.assertEqual(''.join([res.annotated_name() for res in pose3.residues]), 'F') | ||
# pose3.residues *= pose1.residues[5] | ||
# self.assertEqual(''.join([res.annotated_name() for res in pose3.residues]), 'FF') | ||
# pose3.residues *= pose1.residues[5] | ||
# self.assertEqual(''.join([res.annotated_name() for res in pose3.residues]), 'FFF') | ||
|
||
|
||
class TestPoseScoresAccessor(unittest.TestCase): | ||
|
||
def test_scores(self): | ||
|
||
test_pose = pyrosetta.pose_from_sequence("TESTTESTTEST") | ||
|
||
self.assertDictEqual(dict(test_pose.scores), {}) | ||
|
||
# Test proper overwrite of extra scores of varying types. | ||
test_pose.scores["foo"] = "bar" | ||
self.assertDictEqual(dict(test_pose.scores), {"foo" : "bar"}) | ||
test_pose.scores["foo"] = 1 | ||
self.assertDictEqual(dict(test_pose.scores), {"foo" : 1.0}) | ||
|
||
test_pose.scores["bar"] = 2 | ||
self.assertDictEqual(dict(test_pose.scores), {"foo" : 1.0, "bar" : 2.0}) | ||
|
||
# Test score deletion | ||
del test_pose.scores["foo"] | ||
self.assertDictEqual(dict(test_pose.scores), {"bar" : 2.0}) | ||
|
||
# Test exception when setting reserved names | ||
with self.assertRaises(ValueError): | ||
test_pose.scores["fa_atr"] = "invalid" | ||
|
||
# Test score update after scoring | ||
self.assertNotIn("fa_atr", test_pose.scores) | ||
pyrosetta.get_score_function()(test_pose) | ||
self.assertIn("fa_atr", test_pose.scores) | ||
|
||
# Test clear w/ energies | ||
with self.assertRaises(ValueError): | ||
del test_pose.scores["fa_atr"] | ||
|
||
test_pose.energies().clear() | ||
self.assertNotIn("fa_atr", test_pose.scores) | ||
self.assertDictEqual(dict(test_pose.scores), {"bar" : 2.0}) | ||
|
||
test_pose.scores.clear() | ||
self.assertDictEqual(dict(test_pose.scores), dict()) | ||
|
||
|
||
class TestPoseResidueLabelAccessor(unittest.TestCase): | ||
|
||
def test_labels(self): | ||
|
||
test_pose = pyrosetta.pose_from_sequence("TESTTESTTEST") | ||
|
||
self.assertSequenceEqual( | ||
list(test_pose.reslabels), [set()] * len(test_pose.residues)) | ||
|
||
test_pose.reslabels[1].add("foo") | ||
test_pose.reslabels[-1].add("bar") | ||
test_pose.reslabels[-1].add("blah") | ||
self.assertSequenceEqual( | ||
list(test_pose.reslabels), | ||
[{"foo"}] + [set()] * (len(test_pose.residues) - 2) + [{"bar", "blah"}]) | ||
|
||
self.assertSequenceEqual( | ||
test_pose.reslabels.mask["foo"], | ||
[True] + [False] * (len(test_pose.residues) - 1)) | ||
self.assertSequenceEqual( | ||
test_pose.reslabels.mask["bar"], | ||
[False] * (len(test_pose.residues) - 1) + [True]) | ||
self.assertSequenceEqual( | ||
test_pose.reslabels.mask["blah"], | ||
[False] * (len(test_pose.residues) - 1) + [True]) | ||
|
||
self.assertSetEqual(set(test_pose.reslabels.mask), {"foo", "bar", "blah"}) | ||
|
||
test_pose.reslabels[-1].clear() | ||
self.assertSequenceEqual( | ||
list(test_pose.reslabels), | ||
[{"foo"}] + [set()] * (len(test_pose.residues) - 1)) | ||
|
||
test_pose.reslabels[-1].add("bar") | ||
test_pose.reslabels[-1].add("blah") | ||
test_pose.reslabels[-1].discard("blah") | ||
self.assertSequenceEqual( | ||
list(test_pose.reslabels), | ||
[{"foo"}] + [set()] * (len(test_pose.residues) - 2) + [{"bar"}]) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,33 @@ | ||
import unittest | ||
import threading | ||
import time | ||
# :noTabs=true: | ||
# | ||
# (c) Copyright Rosetta Commons Member Institutions. | ||
# (c) This file is part of the Rosetta software suite and is made available under license. | ||
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons. | ||
# (c) For more information, see http://www.rosettacommons.org. | ||
# (c) Questions about this can be addressed to University of Washington CoMotion, email: [email protected]. | ||
|
||
import numpy | ||
import time | ||
import threading | ||
import unittest | ||
|
||
import pyrosetta.distributed.io as io | ||
import pyrosetta.distributed.tasks.rosetta_scripts as rosetta_scripts | ||
import pyrosetta.distributed.tasks.score as score | ||
|
||
|
||
class TestConcurrentScripts(unittest.TestCase): | ||
|
||
def test_concurrent_on_task(self): | ||
|
||
protocol = rosetta_scripts.SingleoutputRosettaScriptsTask(""" | ||
<ROSETTASCRIPTS> | ||
<MOVERS> | ||
<FastRelax name="score" repeats="1" /> | ||
</MOVERS> | ||
<PROTOCOLS> | ||
<Add mover_name="score"/> | ||
</PROTOCOLS> | ||
<MOVERS> | ||
<FastRelax name="score" repeats="1"/> | ||
</MOVERS> | ||
<PROTOCOLS> | ||
<Add mover_name="score"/> | ||
</PROTOCOLS> | ||
</ROSETTASCRIPTS> | ||
""") | ||
|
||
|
@@ -37,21 +43,17 @@ def run_task(seq): | |
test_pose = io.pose_from_sequence(seq) | ||
protocol = rosetta_scripts.SingleoutputRosettaScriptsTask(""" | ||
<ROSETTASCRIPTS> | ||
<MOVERS> | ||
<FastRelax name="score" repeats="1" /> | ||
</MOVERS> | ||
<PROTOCOLS> | ||
<Add mover_name="score"/> | ||
</PROTOCOLS> | ||
<MOVERS> | ||
<FastRelax name="score" repeats="1" /> | ||
</MOVERS> | ||
<PROTOCOLS> | ||
<Add mover_name="score"/> | ||
</PROTOCOLS> | ||
</ROSETTASCRIPTS> | ||
""") | ||
|
||
return protocol(test_pose) | ||
|
||
|
||
import concurrent.futures | ||
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as p: | ||
result = list(p.map(run_task, ["TEST"] * 3)) | ||
|
100 changes: 63 additions & 37 deletions
100
source/src/python/PyRosetta/src/pyrosetta/tests/distributed/test_dask.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,64 +1,90 @@ | ||
import logging | ||
# :noTabs=true: | ||
# | ||
# (c) Copyright Rosetta Commons Member Institutions. | ||
# (c) This file is part of the Rosetta software suite and is made available under license. | ||
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons. | ||
# (c) For more information, see http://www.rosettacommons.org. | ||
# (c) Questions about this can be addressed to University of Washington CoMotion, email: [email protected]. | ||
|
||
import argparse | ||
import unittest | ||
|
||
import dask.distributed | ||
import logging | ||
import tempfile | ||
import unittest | ||
|
||
import pyrosetta.distributed.io as io | ||
import pyrosetta.distributed.packed_pose as packed_pose | ||
import pyrosetta.distributed.tasks.rosetta_scripts as rosetta_scripts | ||
|
||
|
||
class TestDaskDistribution(unittest.TestCase): | ||
|
||
_dask_scheduler = None | ||
|
||
with tempfile.TemporaryDirectory() as workdir: | ||
|
||
def setUp(self, local_dir=workdir): | ||
if not self._dask_scheduler: | ||
self.local_cluster = dask.distributed.LocalCluster( | ||
n_workers=2, threads_per_worker=2, diagnostics_port=None, local_dir=local_dir | ||
) | ||
cluster = self.local_cluster | ||
else: | ||
self.local_cluster = None | ||
cluster = self._dask_scheduler | ||
|
||
def setUp(self): | ||
if not self._dask_scheduler: | ||
self.local_cluster = dask.distributed.LocalCluster(n_workers=2, threads_per_worker=2) | ||
cluster = self.local_cluster | ||
else: | ||
self.local_cluster = None | ||
cluster = self._dask_scheduler | ||
self.client = dask.distributed.Client(cluster) | ||
|
||
self.client = dask.distributed.Client(cluster) | ||
def tearDown(self): | ||
self.client.close() | ||
def tearDown(self): | ||
self.client.close() | ||
|
||
if self.local_cluster: | ||
self.local_cluster.close() | ||
|
||
if self.local_cluster: | ||
self.local_cluster.close() | ||
def test_rosetta_scripts(self): | ||
test_protocol = """ | ||
<ROSETTASCRIPTS> | ||
<TASKOPERATIONS> | ||
<RestrictToRepacking name="repack"/> | ||
</TASKOPERATIONS> | ||
<MOVERS> | ||
<PackRotamersMover name="pack" task_operations="repack"/> | ||
</MOVERS> | ||
<PROTOCOLS> | ||
<Add mover="pack"/> | ||
</PROTOCOLS> | ||
</ROSETTASCRIPTS> | ||
""" | ||
|
||
def test_rosetta_scripts(self): | ||
test_protocol = """ | ||
<ROSETTASCRIPTS> | ||
<MOVERS> | ||
<PackRotamersMover name="pack"/> | ||
</MOVERS> | ||
<PROTOCOLS> | ||
<Add mover="pack" /> | ||
</PROTOCOLS> | ||
</ROSETTASCRIPTS> | ||
""" | ||
test_pose = io.pose_from_sequence("TEST") | ||
test_task = rosetta_scripts.SingleoutputRosettaScriptsTask(test_protocol) | ||
|
||
test_pose = io.pose_from_sequence("TESTTESTTEST") | ||
test_task = rosetta_scripts.SingleoutputRosettaScriptsTask(test_protocol) | ||
logging.info("dask client: %s", self.client) | ||
task = self.client.submit(test_task, test_pose) | ||
result = task.result() | ||
self.assertEqual( | ||
packed_pose.to_pose(result).sequence(), | ||
packed_pose.to_pose(test_pose).sequence() | ||
) | ||
|
||
logging.info("dask client: %s", self.client) | ||
task = self.client.submit(test_task, test_pose) | ||
result = task.result() | ||
def test_basic(self): | ||
logging.info("dask client: %s", self.client) | ||
task = self.client.submit(lambda a: a + 1, 1) | ||
self.assertEqual(task.result(), 2) | ||
|
||
def test_basic(self): | ||
logging.info("dask client: %s", self.client) | ||
task = self.client.submit(lambda a: a + 1, 1) | ||
self.assertEqual(task.result(), 2) | ||
|
||
if __name__ == "__main__": | ||
|
||
logging.basicConfig( | ||
level=logging.INFO, | ||
format="%(asctime)s.%(msecs).03d %(name)s %(message)s", | ||
datefmt='%Y-%m-%dT%H:%M:%S' | ||
) | ||
|
||
parser = argparse.ArgumentParser(description='Run initial pyrosetta.distributed smoke test over given scheduler.') | ||
parser.add_argument('scheduler', type=str, nargs='?', help='Target scheduler endpoint for test.') | ||
parser = argparse.ArgumentParser( | ||
description="Run initial pyrosetta.distributed smoke test over given scheduler." | ||
) | ||
parser.add_argument("scheduler", type=str, nargs="?", help="Target scheduler endpoint for test.") | ||
args = parser.parse_args() | ||
|
||
TestDaskDistribution._dask_scheduler = args.scheduler | ||
|
Oops, something went wrong.