Black formatting
PhilippThoelke committed Jul 27, 2021
1 parent 0332657 commit cdeba74
Showing 20 changed files with 766 additions and 423 deletions.
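Black rewrites Python sources deterministically; nearly every hunk below is one of two changes: single quotes normalized to double quotes, or a statement longer than Black's default 88-character limit exploded across lines with a trailing comma added. A minimal sketch of the same transformations through Black's programmatic API (assuming the black package is installed; behavior as of Black 21.x):

    import black

    # Quote normalization: single quotes become double quotes
    print(black.format_str("data.setup('fit')", mode=black.Mode()), end="")
    # -> data.setup("fit")

    # A call exceeding the line-length limit is split one argument per
    # line, and Black appends a trailing comma after the last argument
    long_call = "f(" + ", ".join(f"argument_{i}=value_{i}" for i in range(6)) + ")"
    print(black.format_str(long_call, mode=black.Mode()))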
31 changes: 17 additions & 14 deletions scripts/train.py
@@ -92,13 +92,13 @@ def get_args():
     args = parser.parse_args()
 
     if args.redirect:
-        sys.stdout = open(os.path.join(args.log_dir, 'log'), 'w')
+        sys.stdout = open(os.path.join(args.log_dir, "log"), "w")
         sys.stderr = sys.stdout
 
     if args.inference_batch_size is None:
         args.inference_batch_size = args.batch_size
 
-    save_argparse(args, os.path.join(args.log_dir, 'input.yaml'), exclude=['conf'])
+    save_argparse(args, os.path.join(args.log_dir, "input.yaml"), exclude=["conf"])
 
     return args

@@ -111,12 +111,14 @@ def main():
     # initialize data module
     data = DataModule(args)
     data.prepare_data()
-    data.setup('fit')
+    data.setup("fit")
 
     prior = None
     if args.prior_model:
-        assert hasattr(priors, args.prior_model), (f'Unknown prior model {args["prior_model"]}. '
-                                                   f'Available models are {", ".join(priors.__all__)}')
+        assert hasattr(priors, args.prior_model), (
+            f'Unknown prior model {args["prior_model"]}. '
+            f'Available models are {", ".join(priors.__all__)}'
+        )
         # initialize the prior model
         prior = getattr(priors, args.prior_model)(dataset=data.dataset)
         args.prior_args = prior.get_init_args()
@@ -126,19 +128,20 @@ def main():
 
     checkpoint_callback = ModelCheckpoint(
         dirpath=args.log_dir,
-        monitor='val_loss',
-        save_top_k=10, # -1 to save all
+        monitor="val_loss",
+        save_top_k=10,  # -1 to save all
         period=args.save_interval,
-        filename='{epoch}-{val_loss:.4f}-{test_loss:.4f}'
+        filename="{epoch}-{val_loss:.4f}-{test_loss:.4f}",
     )
-    early_stopping = EarlyStopping('val_loss', patience=args.early_stopping_patience)
+    early_stopping = EarlyStopping("val_loss", patience=args.early_stopping_patience)
 
-    tb_logger = pl.loggers.TensorBoardLogger(args.log_dir, name='tensorbord', version='',
-                                             default_hp_metric=False)
-    csv_logger = CSVLogger(args.log_dir, name='', version='')
+    tb_logger = pl.loggers.TensorBoardLogger(
+        args.log_dir, name="tensorbord", version="", default_hp_metric=False
+    )
+    csv_logger = CSVLogger(args.log_dir, name="", version="")
 
     ddp_plugin = None
-    if 'ddp' in args.distributed_backend:
+    if "ddp" in args.distributed_backend:
         ddp_plugin = DDPPlugin(find_unused_parameters=False, num_nodes=args.num_nodes)
 
     trainer = pl.Trainer(

@@ -153,7 +156,7 @@ def main():
         logger=[tb_logger, csv_logger],
         reload_dataloaders_every_epoch=False,
         precision=args.precision,
-        plugins=[ddp_plugin]
+        plugins=[ddp_plugin],
     )
 
     trainer.fit(model, data)
16 changes: 10 additions & 6 deletions setup.py
@@ -2,17 +2,21 @@
 from setuptools import setup, find_packages
 
 try:
-    version = subprocess.check_output(['git', 'describe', '--abbrev=0', '--tags']).strip().decode('utf-8')
+    version = (
+        subprocess.check_output(["git", "describe", "--abbrev=0", "--tags"])
+        .strip()
+        .decode("utf-8")
+    )
 except:
-    print('Failed to retrieve the current version, defaulting to 0')
-    version = '0'
+    print("Failed to retrieve the current version, defaulting to 0")
+    version = "0"
 
-with open('requirements.txt') as f:
+with open("requirements.txt") as f:
     requirements = f.read().splitlines()
 
 setup(
-    name='torchmd-net',
+    name="torchmd-net",
     version=version,
     packages=find_packages(),
-    install_requires=requirements
+    install_requires=requirements,
 )
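For context, the chained calls that Black reflowed above just clean up the captured tag name: check_output returns bytes that still carry a trailing newline. A tiny illustration with a hypothetical tag value:

    raw = b"v0.2.0\n"  # what subprocess.check_output might capture from git describe
    version = raw.strip().decode("utf-8")
    print(version)  # -> v0.2.0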
6 changes: 4 additions & 2 deletions torchmdnet/calculators.py
@@ -3,12 +3,14 @@
 
 
 class External:
-    def __init__(self, netfile, embeddings, device='cpu'):
+    def __init__(self, netfile, embeddings, device="cpu"):
         self.model = load_model(netfile, device=device)
         self.device = device
         self.n_atoms = embeddings.size(1)
         self.embeddings = embeddings.reshape(-1).to(device)
-        self.batch = torch.arange(embeddings.size(0), device=device).repeat_interleave(embeddings.size(1))
+        self.batch = torch.arange(embeddings.size(0), device=device).repeat_interleave(
+            embeddings.size(1)
+        )
         self.model.eval()
 
     def calculate(self, pos, box):
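The reflowed repeat_interleave call above builds the flat batch-assignment vector used by the model: molecule i contributes n_atoms consecutive entries equal to i. A self-contained sketch with made-up sizes:

    import torch

    n_molecules, n_atoms = 2, 3  # hypothetical batch of 2 molecules, 3 atoms each
    batch = torch.arange(n_molecules).repeat_interleave(n_atoms)
    print(batch)  # -> tensor([0, 0, 0, 1, 1, 1])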
42 changes: 24 additions & 18 deletions torchmdnet/data.py
@@ -17,17 +17,16 @@ def __init__(self, hparams):
         self._saved_dataloaders = dict()
 
     def setup(self, stage):
-        if self.hparams.dataset == 'Custom':
+        if self.hparams.dataset == "Custom":
             self.dataset = datasets.Custom(
                 self.hparams.coord_files,
                 self.hparams.embed_files,
                 self.hparams.energy_files,
-                self.hparams.force_files
+                self.hparams.force_files,
             )
         else:
             self.dataset = getattr(datasets, self.hparams.dataset)(
-                self.hparams.dataset_root,
-                dataset_arg=self.hparams.dataset_arg
+                self.hparams.dataset_root, dataset_arg=self.hparams.dataset_arg
             )
 
         idx_train, idx_val, idx_test = make_splits(
@@ -36,10 +35,10 @@ def setup(self, stage):
             self.hparams.val_size,
             self.hparams.test_size,
             self.hparams.seed,
-            join(self.hparams.log_dir, 'splits.npz'),
+            join(self.hparams.log_dir, "splits.npz"),
             self.hparams.splits,
         )
-        print(f'train {len(idx_train)}, val {len(idx_val)}, test {len(idx_test)}')
+        print(f"train {len(idx_train)}, val {len(idx_val)}, test {len(idx_test)}")
 
         self.train_dataset = Subset(self.dataset, idx_train)
         self.val_dataset = Subset(self.dataset, idx_val)
@@ -49,20 +48,23 @@ def setup(self, stage):
             self._standardize()
 
     def train_dataloader(self):
-        return self._get_dataloader(self.train_dataset, 'train')
+        return self._get_dataloader(self.train_dataset, "train")
 
     def val_dataloader(self):
-        loaders = [self._get_dataloader(self.val_dataset, 'val')]
-        if len(self.test_dataset) > 0 and self.trainer.current_epoch % self.hparams.test_interval == 0:
-            loaders.append(self._get_dataloader(self.test_dataset, 'test'))
+        loaders = [self._get_dataloader(self.val_dataset, "val")]
+        if (
+            len(self.test_dataset) > 0
+            and self.trainer.current_epoch % self.hparams.test_interval == 0
+        ):
+            loaders.append(self._get_dataloader(self.test_dataset, "test"))
         return loaders
 
     def test_dataloader(self):
-        return self._get_dataloader(self.test_dataset, 'test')
+        return self._get_dataloader(self.test_dataset, "test")
 
     @property
     def atomref(self):
-        if hasattr(self.dataset, 'get_atomref'):
+        if hasattr(self.dataset, "get_atomref"):
             return self.dataset.get_atomref()
         return None
 
@@ -75,16 +77,18 @@ def std(self):
         return self._std
 
     def _get_dataloader(self, dataset, stage, store_dataloader=True):
-        store_dataloader = store_dataloader and not self.trainer.reload_dataloaders_every_epoch
+        store_dataloader = (
+            store_dataloader and not self.trainer.reload_dataloaders_every_epoch
+        )
         if stage in self._saved_dataloaders and store_dataloader:
             # storing the dataloaders like this breaks calls to trainer.reload_train_val_dataloaders
             # but makes it possible that the dataloaders are not recreated on every testing epoch
             return self._saved_dataloaders[stage]
 
-        if stage == 'train':
+        if stage == "train":
             batch_size = self.hparams.batch_size
             shuffle = True
-        elif stage in ['val', 'test']:
+        elif stage in ["val", "test"]:
             batch_size = self.hparams.inference_batch_size
             shuffle = False
 
@@ -93,16 +97,18 @@ def _get_dataloader(self, dataset, stage, store_dataloader=True):
             batch_size=batch_size,
             shuffle=shuffle,
             num_workers=self.hparams.num_workers,
-            pin_memory=True
+            pin_memory=True,
         )
 
         if store_dataloader:
             self._saved_dataloaders[stage] = dl
         return dl
 
     def _standardize(self):
-        data = tqdm(self._get_dataloader(self.train_dataset, 'val', store_dataloader=False),
-                    desc='computing mean and std')
+        data = tqdm(
+            self._get_dataloader(self.train_dataset, "val", store_dataloader=False),
+            desc="computing mean and std",
+        )
         ys = torch.cat([batch.y.clone() for batch in data])
 
         self._mean = ys.mean()
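The _standardize method reformatted above iterates the training set once and reduces all targets to a scalar mean and standard deviation. The same computation on made-up batch targets:

    import torch

    # stand-ins for the batch.y tensors yielded by the dataloader
    batches = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
    ys = torch.cat([y.clone() for y in batches])
    print(ys.mean().item(), ys.std().item())  # -> 2.5 and ~1.291 (unbiased std)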
2 changes: 1 addition & 1 deletion torchmdnet/datasets/__init__.py
@@ -4,4 +4,4 @@
 from .custom import Custom
 from .hdf import HDF5
 
-__all__ = ['QM9', 'MD17', 'ANI1', 'Custom', 'HDF5']
+__all__ = ["QM9", "MD17", "ANI1", "Custom", "HDF5"]
50 changes: 28 additions & 22 deletions torchmdnet/datasets/ani1.py
@@ -9,22 +9,17 @@
 
 class ANI1(InMemoryDataset):
 
-    raw_url = 'https://ndownloader.figshare.com/files/9057631'
+    raw_url = "https://ndownloader.figshare.com/files/9057631"
 
-    element_numbers = {
-        'H': 1,
-        'C': 6,
-        'N': 7,
-        'O': 8
-    }
+    element_numbers = {"H": 1, "C": 6, "N": 7, "O": 8}
 
     HAR2EV = 27.211386246
 
     self_energies = {
-        'H': -0.500607632585 * HAR2EV,
-        'C': -37.8302333826 * HAR2EV,
-        'N': -54.5680045287 * HAR2EV,
-        'O': -75.0362229210 * HAR2EV
+        "H": -0.500607632585 * HAR2EV,
+        "C": -37.8302333826 * HAR2EV,
+        "N": -54.5680045287 * HAR2EV,
+        "O": -75.0362229210 * HAR2EV,
     }
 
     def __init__(self, root, transform=None, pre_transform=None, **kwargs):
@@ -33,29 +28,38 @@ def __init__(self, root, transform=None, pre_transform=None, **kwargs):
 
     @property
     def raw_file_names(self):
-        return [f'ANI-1_release/ani_gdb_s{i + 1:02d}.h5' for i in range(8)]
+        return [f"ANI-1_release/ani_gdb_s{i + 1:02d}.h5" for i in range(8)]
 
     @property
     def processed_file_names(self):
-        return ['ani1.pt']
+        return ["ani1.pt"]
 
     def download(self):
-        raw_archive = join(self.raw_dir, 'ANI1_release.tar.gz')
-        print(f'Downloading {self.raw_url}')
+        raw_archive = join(self.raw_dir, "ANI1_release.tar.gz")
+        print(f"Downloading {self.raw_url}")
         request.urlretrieve(self.raw_url, raw_archive)
         extract_tar(raw_archive, self.raw_dir)
         os.remove(raw_archive)
 
     def process(self):
         data_list = []
-        for path in tqdm(self.raw_paths, desc='raw h5 files'):
-            data = h5py.File(path, 'r')
+        for path in tqdm(self.raw_paths, desc="raw h5 files"):
+            data = h5py.File(path, "r")
             for file_name in data:
-                for molecule_name in tqdm(data[file_name], desc='molecules', leave=False):
+                for molecule_name in tqdm(
+                    data[file_name], desc="molecules", leave=False
+                ):
                     group = data[file_name][molecule_name]
-                    elements = torch.tensor([self.element_numbers[str(elem)[-2]] for elem in group['species']])
-                    positions = torch.from_numpy(group['coordinates'][:])
-                    energies = torch.from_numpy(group['energies'][:] * self.HAR2EV).float()
+                    elements = torch.tensor(
+                        [
+                            self.element_numbers[str(elem)[-2]]
+                            for elem in group["species"]
+                        ]
+                    )
+                    positions = torch.from_numpy(group["coordinates"][:])
+                    energies = torch.from_numpy(
+                        group["energies"][:] * self.HAR2EV
+                    ).float()
 
                     elements = elements.expand(positions.size(0), -1)
                     for z, pos, energy in zip(elements, positions, energies):
@@ -72,5 +76,7 @@ def process(self):
 
     def get_atomref(self, max_z=100):
         out = torch.zeros(max_z)
-        out[list(self.element_numbers.values())] = torch.tensor(list(self.self_energies.values()))
+        out[list(self.element_numbers.values())] = torch.tensor(
+            list(self.self_energies.values())
+        )
         return out.view(-1, 1)
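The get_atomref assignment reflowed above scatters the per-element self-energies into a table indexed by atomic number; rows for elements that never appear stay zero. A self-contained sketch with hypothetical energies for a two-element subset:

    import torch

    element_numbers = {"H": 1, "C": 6}
    self_energies = {"H": -13.5, "C": -1029.75}  # made-up values in eV

    out = torch.zeros(100)
    out[list(element_numbers.values())] = torch.tensor(list(self_energies.values()))
    print(out[1].item(), out[6].item())  # -> -13.5 -1029.75
    print(out.view(-1, 1).shape)         # -> torch.Size([100, 1])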
(Diffs for the remaining 14 of the 20 changed files were not rendered on this page.)
