Skip to content

Commit

Permalink
Simplify and format the code
Browse files Browse the repository at this point in the history
  • Loading branch information
salu133445 committed Jan 11, 2020
1 parent 0f5588a commit 0290b12
Showing 1 changed file with 30 additions and 45 deletions.
75 changes: 30 additions & 45 deletions src/process_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""This script loads and saves an array to shared memory."""
import os.path
"""Load and save an array to shared memory."""
import argparse
import os.path
import sys

import numpy as np
import SharedArray as sa

Expand All @@ -11,24 +13,38 @@ def parse_arguments():
parser.add_argument("filepath", help="Path to the data file.")
parser.add_argument(
"--name",
help="File name to save in SharedArray. "
"Default to use the original file name.",
help="File name to save in SharedArray. Defaults to the original file name.",
)
parser.add_argument(
"--prefix",
help="Prefix to the file name to save in "
"SharedArray. Only effective when "
help="Prefix to the file name to save in SharedArray. Only effective when "
"`name` is not given.",
)
parser.add_argument(
"--dtype", default="bool", help="Datatype of the array. Default to bool."
"--dtype", default="bool", help="Datatype of the array. Defaults to bool."
)
args = parser.parse_args()
return args.filepath, args.name, args.prefix, args.dtype


def create_shared_array(name, shape, dtype):
"""Create shared array. Prompt if a file with the same name existed."""
try:
return sa.create(name, shape, dtype)
except FileExistsError:
response = ""
while response.lower() not in ["y", "n", "yes", "no"]:
response = input(
"Existing array (also named " + name + ") was found. Replace it? (y/n) "
)
if response.lower() in ("n", "no"):
sys.exit(0)
sa.delete(name)
return sa.create(name, shape, dtype)


def main():
"""Main function"""
"""Load and save an array to shared memory."""
filepath, name, prefix, dtype = parse_arguments()

if name is None:
Expand All @@ -40,45 +56,14 @@ def main():
if filepath.endswith(".npy"):
data = np.load(filepath)
data = data.astype(dtype)
try:
sa_array = sa.create(name, data.shape, data.dtype)
except FileExistsError:
response = ""
while response not in ["yes", "no"]:
response = input(
"Existing array (also named "
+ name
+ ") was found. Replace it? (Respond with 'yes' or 'no') "
)
if response == "yes":
sa.delete(name)
sa_array = sa.create(name, data.shape, data.dtype)
else:
raise e
finally:
print("Saving data to shared memory...")
np.copyto(sa_array, data)
sa_array = create_shared_array(name, data.shape, data.dtype)
print("Saving data to shared memory...")
np.copyto(sa_array, data)
else:
with np.load(filepath) as loaded:
try:
sa_array = sa.create(name, loaded["shape"], dtype)
except FileExistsError as e:
response = ""
while response not in ["yes", "no"]:
response = input(
"Existing array (also named "
+ name
+ ") was found. Replace it? (Respond with 'yes' or "
+ "'no') "
)
if response == "yes":
sa.delete(name)
sa_array = sa.create(name, loaded["shape"], dtype)
else:
raise e
finally:
print("Saving data to shared memory...")
sa_array[[x for x in loaded["nonzero"]]] = True
sa_array = create_shared_array(name, loaded["shape"], dtype)
print("Saving data to shared memory...")
sa_array[[x for x in loaded["nonzero"]]] = 1

print(
"Successfully saved: (name='{}', shape={}, dtype={})".format(
Expand Down

0 comments on commit 0290b12

Please sign in to comment.