Commit

fix preprocess data local error
ilyes319 committed Jun 17, 2024
1 parent 7b5ef19 commit 219f749
Showing 1 changed file with 54 additions and 45 deletions.
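The change below moves the per-split HDF5 writer functions out of run() to module level and binds their fixed arguments with functools.partial before handing them to mp.Process. The reason is that multiprocessing must pickle the worker target when the start method is "spawn" (the default on macOS and Windows), and locally defined functions cannot be pickled. A minimal sketch of that pattern follows; the names write_shard, shards, and prefix are illustrative stand-ins, not code from the repository.

# Minimal sketch (not repository code): the worker is defined at module level so
# multiprocessing can pickle it under the "spawn" start method, and its fixed
# arguments are bound with functools.partial, mirroring the pattern in the diff below.
import multiprocessing as mp
from functools import partial


def write_shard(process, shards, prefix, drop_last):
    # Module-level target: picklable, unlike a function defined inside run().
    with open(f"{prefix}_{process}.txt", "w") as f:
        f.write(f"drop_last={drop_last}\n")
        f.writelines(f"{item}\n" for item in shards[process])


def run(num_process=2):
    shards = [[1, 2, 3], [4, 5, 6]]
    worker = partial(write_shard, shards=shards, prefix="shard", drop_last=False)
    processes = []
    for i in range(num_process):
        p = mp.Process(target=worker, args=[i])
        p.start()
        processes.append(p)
    for p in processes:
        p.join()


if __name__ == "__main__":
    run()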
mace/cli/preprocess_data.py: 99 changes (54 additions, 45 deletions)
@@ -3,6 +3,7 @@

import argparse
import ast
from functools import partial
import json
import logging
import multiprocessing as mp
@@ -92,6 +93,27 @@ def get_prime_factors(n: int):
return factors


# Define task for multiprocessing
def multi_train_hdf5(process, args, split_train, drop_last):
with h5py.File(args.h5_prefix + "train/train_" + str(process) + ".h5", "w") as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_train[process], process, f)


def multi_valid_hdf5(process, args, split_valid, drop_last):
with h5py.File(args.h5_prefix + "val/val_" + str(process) + ".h5", "w") as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_valid[process], process, f)


def multi_test_hdf5(process, name, args, split_test, drop_last):
with h5py.File(
args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w"
) as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_test[process], process, f)


def main() -> None:
"""
This script loads an xyz dataset and prepares
@@ -172,47 +194,42 @@ def run(args: argparse.Namespace):
if len(collections.train) % 2 == 1:
drop_last = True

# Define Task for Multiprocessiing
def multi_train_hdf5(process):
with h5py.File(args.h5_prefix + "train/train_" + str(process)+".h5", "w") as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_train[process], process, f)

multi_train_hdf5_ = partial(multi_train_hdf5, args=args, split_train=split_train, drop_last=drop_last)
processes = []
for i in range(args.num_process):
p = mp.Process(target=multi_train_hdf5, args=[i])
p = mp.Process(target=multi_train_hdf5_, args=[i])
p.start()
processes.append(p)

for i in processes:
i.join()


logging.info("Computing statistics")
if len(atomic_energies_dict) == 0:
atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table)
atomic_energies: np.ndarray = np.array(
[atomic_energies_dict[z] for z in z_table.zs]
)
logging.info(f"Atomic energies: {atomic_energies.tolist()}")
_inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process]
avg_num_neighbors, mean, std=pool_compute_stats(_inputs)
logging.info(f"Average number of neighbors: {avg_num_neighbors}")
logging.info(f"Mean: {mean}")
logging.info(f"Standard deviation: {std}")

# save the statistics as a json
statistics = {
"atomic_energies": str(atomic_energies_dict),
"avg_num_neighbors": avg_num_neighbors,
"mean": mean,
"std": std,
"atomic_numbers": str(z_table.zs),
"r_max": args.r_max,
}

with open(args.h5_prefix + "statistics.json", "w") as f: # pylint: disable=W1514
json.dump(statistics, f)
if args.compute_statistics:
logging.info("Computing statistics")
if len(atomic_energies_dict) == 0:
atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table)
atomic_energies: np.ndarray = np.array(
[atomic_energies_dict[z] for z in z_table.zs]
)
logging.info(f"Atomic energies: {atomic_energies.tolist()}")
_inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process]
avg_num_neighbors, mean, std=pool_compute_stats(_inputs)
logging.info(f"Average number of neighbors: {avg_num_neighbors}")
logging.info(f"Mean: {mean}")
logging.info(f"Standard deviation: {std}")

# save the statistics as a json
statistics = {
"atomic_energies": str(atomic_energies_dict),
"avg_num_neighbors": avg_num_neighbors,
"mean": mean,
"std": std,
"atomic_numbers": str(z_table.zs),
"r_max": args.r_max,
}

with open(args.h5_prefix + "statistics.json", "w") as f: # pylint: disable=W1514
json.dump(statistics, f)

logging.info("Preparing validation set")
if args.shuffle:
@@ -222,26 +239,18 @@ def multi_train_hdf5(process):
if len(collections.valid) % 2 == 1:
drop_last = True

def multi_valid_hdf5(process):
with h5py.File(args.h5_prefix + "val/val_" + str(process)+".h5", "w") as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_valid[process], process, f)

multi_valid_hdf5_ = partial(multi_valid_hdf5, args=args, split_valid=split_valid, drop_last=drop_last)
processes = []
for i in range(args.num_process):
p = mp.Process(target=multi_valid_hdf5, args=[i])
p = mp.Process(target=multi_valid_hdf5_, args=[i])
p.start()
processes.append(p)

for i in processes:
i.join()

if args.test_file is not None:
def multi_test_hdf5(process, name):
with h5py.File(args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w") as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_test[process], process, f)

multi_test_hdf5_ = partial(multi_test_hdf5, args=args, split_test=split_test, drop_last=drop_last)
logging.info("Preparing test sets")
for name, subset in collections.tests:
drop_last = False
@@ -251,7 +260,7 @@ def multi_test_hdf5(process, name):

processes = []
for i in range(args.num_process):
p = mp.Process(target=multi_test_hdf5, args=[i, name])
p = mp.Process(target=multi_test_hdf5_, args=[i, name])
p.start()
processes.append(p)


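For contrast, here is a short reconstruction of the failure mode the commit title refers to (an illustrative sketch, not output captured from this repository): with the "spawn" start method, a worker defined inside run() is a local object that cannot be pickled, so mp.Process fails at start-up.

# Illustrative failing pattern (the shape of the code this commit removes):
import multiprocessing as mp


def run():
    data = [[1, 2], [3, 4]]

    def worker(i):  # local function: not picklable
        print(data[i])

    p = mp.Process(target=worker, args=[0])
    p.start()  # on spawn-based platforms this raises an error along the lines of
               # AttributeError: Can't pickle local object 'run.<locals>.worker'
    p.join()


if __name__ == "__main__":
    run()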