-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1fba02e
commit 9b4ccff
Showing
5 changed files
with
209 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
train: | ||
epochs: 10 | ||
model: | ||
conv_units: 16 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
tensorflow>=2.5,<2.6 | ||
ruamel.yaml>=0.17,<0.18 | ||
imageio>=2.9,<3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import tensorflow as tf | ||
import numpy as np | ||
from util import load_params, read_labeled_images | ||
import os | ||
import json | ||
|
||
# Root directory of the labeled images; main() reads its train/ and test/
# subdirectories.
INPUT_DIR = "data/images"
# NOTE(review): not referenced anywhere in the visible code -- presumably a
# switch for resuming from a saved model; confirm before relying on it.
RESUME_PREVIOUS_MODEL = False
# Directory in which main() saves the trained model as model.h5.
OUTPUT_DIR = "models"

# File that main() writes the evaluation metrics to, as JSON.
METRICS_FILE = "metrics.json"
# NOTE(review): not referenced anywhere in the visible code -- presumably
# meant as a random seed; confirm.
SEED = 20210715

# Batch size used by main() for both fitting and evaluation.
BATCH_SIZE = 128
|
||
|
||
def get_model(dense_units=128,
              conv_kernel=(3, 3),
              conv_units=32,
              dropout=0.5,
              activation="relu"):
    """Build and compile a small convolutional classifier for 28x28 inputs.

    The network reshapes each (28, 28) input to (28, 28, 1), applies one
    convolution, 2x2 max pooling and dropout, flattens the result and feeds
    it through a hidden dense layer into a 10-way softmax output.

    Args:
        dense_units: Number of units in the hidden dense layer.
        conv_kernel: Kernel size of the convolution layer.
        conv_units: Number of filters in the convolution layer.
        dropout: Dropout rate applied after pooling.
        activation: Activation used by the convolution and hidden dense layer.

    Returns:
        The compiled Keras Sequential model (Adam optimizer, categorical
        cross-entropy loss, categorical accuracy reported as "acc").
    """
    keras_layers = tf.keras.layers

    # Layers are created in the same order in which they are stacked.
    to_image = keras_layers.Reshape(input_shape=(28, 28),
                                    target_shape=(28, 28, 1))
    convolution = keras_layers.Conv2D(conv_units,
                                      kernel_size=conv_kernel,
                                      activation=activation)
    pooling = keras_layers.MaxPooling2D(pool_size=(2, 2))
    regularizer = keras_layers.Dropout(dropout)
    flatten = keras_layers.Flatten()
    hidden = keras_layers.Dense(dense_units, activation=activation)
    output = keras_layers.Dense(10, activation="softmax")

    network = tf.keras.models.Sequential(
        [to_image, convolution, pooling, regularizer, flatten, hidden, output]
    )

    network.compile(
        optimizer="Adam",
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=[tf.keras.metrics.CategoricalAccuracy(name="acc")],
    )

    return network
|
||
|
||
def normalize(images_array):
    """Scale pixel values by dividing them by 255.

    Args:
        images_array: Array (or scalar) of pixel values.

    Returns:
        The input divided by 255; for 8-bit pixel data this lies in [0, 1].
    """
    max_pixel_value = 255
    scaled = images_array / max_pixel_value
    return scaled
|
||
|
||
def history_to_csv(history):
    """Render a training history as CSV text.

    Args:
        history: Object whose ``history`` attribute maps each metric name to
            a list holding one value per epoch (as returned by Keras ``fit``).

    Returns:
        A string with a header line ``epoch, <metric>, ...`` followed by one
        line per epoch (epochs are numbered from 1). Fields are separated by
        ", " and every line ends with a newline. An empty history yields only
        the header line ``epoch``.
    """
    metrics = history.history
    keys = list(metrics)

    lines = [", ".join(["epoch"] + keys)]
    # zip(*columns) walks the metrics epoch by epoch; with no metrics at all
    # it yields nothing, so an empty history no longer raises IndexError.
    columns = [metrics[k] for k in keys]
    for epoch, values in enumerate(zip(*columns), start=1):
        lines.append(", ".join([str(epoch)] + [str(v) for v in values]))

    return "\n".join(lines) + "\n"
|
||
|
||
def main():
    """Train the classifier and write the training log, model and metrics.

    Reads the parameters with load_params(), builds the model, loads the
    labeled images from INPUT_DIR/train and INPUT_DIR/test, trains for
    params["train"]["epochs"] epochs using the test set as validation data,
    and then writes:

    * ``logs.csv`` -- per-epoch training history,
    * ``OUTPUT_DIR/model.h5`` -- the trained model,
    * ``METRICS_FILE`` -- evaluation metrics on the test set, as JSON.

    Raises:
        ValueError: If the train and test sets together do not hold exactly
            70000 images and 70000 labels.
    """
    params = load_params()

    # params.yaml defines model.conv_units; forward it to the model so that
    # changing the parameter actually changes the network. Fall back to the
    # get_model() default when the parameter is absent.
    conv_units = (params.get("model") or {}).get("conv_units")
    if conv_units is None:
        m = get_model()
    else:
        m = get_model(conv_units=conv_units)
    m.summary()

    training_images, training_labels = read_labeled_images(
        os.path.join(INPUT_DIR, 'train/'))
    testing_images, testing_labels = read_labeled_images(
        os.path.join(INPUT_DIR, 'test/')
    )

    # Explicit checks instead of assert: asserts are stripped under
    # `python -O`, which would let an incomplete dataset through silently.
    expected_total = 70000
    total_images = training_images.shape[0] + testing_images.shape[0]
    if total_images != expected_total:
        raise ValueError(
            f"Expected {expected_total} images in total, found {total_images}")
    total_labels = training_labels.shape[0] + testing_labels.shape[0]
    if total_labels != expected_total:
        raise ValueError(
            f"Expected {expected_total} labels in total, found {total_labels}")

    print(f"Training Dataset Shape: {training_images.shape}")
    print(f"Testing Dataset Shape: {testing_images.shape}")
    print(f"Training Labels: {training_labels}")
    print(f"Testing Labels: {testing_labels}")

    training_images = normalize(training_images)
    testing_images = normalize(testing_images)

    # One-hot encode the labels to match the categorical cross-entropy loss.
    training_labels = tf.keras.utils.to_categorical(
        training_labels, num_classes=10, dtype="float32")
    testing_labels = tf.keras.utils.to_categorical(
        testing_labels, num_classes=10, dtype="float32")

    # We use the test set as validation for simplicity
    x_train = training_images
    x_valid = testing_images
    y_train = training_labels
    y_valid = testing_labels

    history = m.fit(
        x_train,
        y_train,
        batch_size=BATCH_SIZE,
        epochs=params["train"]["epochs"],
        verbose=1,
        validation_data=(x_valid, y_valid),
    )

    with open("logs.csv", "w") as f:
        f.write(history_to_csv(history))

    model_file = os.path.join(OUTPUT_DIR, "model.h5")
    m.save(model_file)

    metrics_dict = m.evaluate(
        testing_images,
        testing_labels,
        batch_size=BATCH_SIZE,
        return_dict=True,
    )

    with open(METRICS_FILE, "w") as f:
        json.dump(metrics_dict, f)
|
||
|
||
# Run the training pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from ruamel.yaml import YAML | ||
import numpy as np | ||
import os | ||
from imageio import imread | ||
|
||
|
||
def get_images_from_directory(directory):
    """Load every 28x28 grayscale image found directly in a directory.

    Files are selected by extension (.png, .jpg, .bmp, matched
    case-insensitively) and read in sorted file-name order, so the order of
    the returned images is the same on every run and every file system.

    Args:
        directory: Path of the directory to scan (not recursive).

    Returns:
        A uint8 array of shape (number_of_images, 28, 28); the first
        dimension is 0 when the directory holds no matching files. The
        array's shape is also printed.

    Raises:
        ValueError: If an image is not a two-dimensional 28x28 array.
    """
    image_file_extensions = ('.png', '.jpg', '.bmp')
    # we assume the images are 28x28 grayscale
    shape_0, shape_1 = 28, 28

    images = []
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for f in sorted(os.listdir(directory)):
        if os.path.splitext(f)[1].lower() not in image_file_extensions:
            continue
        image_path = os.path.join(directory, f)
        current_img = imread(image_path)
        if tuple(current_img.shape) != (shape_0, shape_1):
            raise ValueError(
                "Works with 28x28 grayscale images"
                f" ({image_path} has shape {tuple(current_img.shape)})")
        images.append(current_img)

    image_array = np.empty(
        shape=(len(images), shape_0, shape_1), dtype='uint8')
    for i, img in enumerate(images):
        image_array[i] = img
    print(image_array.shape)
    return image_array
|
||
|
||
def read_labeled_images(directory):
    """Read the images of all ten classes together with their labels.

    The directory must contain one sub-directory per label, named
    ``0`` through ``9``, each holding the images of that class:

        .
        ├── 0
        ├── 1
        ├── ...
        └── 9

    Returns:
        A tuple ``(image_array, label_array)``: the images stacked along the
        first axis in label order, and the matching uint8 label for each
        image.
    """
    height, width = 28, 28

    # Start from empty arrays so the result keeps the expected shape and
    # dtype even when no images are found.
    image_chunks = [np.ndarray(shape=(0, height, width), dtype='uint8')]
    label_chunks = [np.ndarray(shape=0, dtype='uint8')]

    for digit in range(10):
        digit_images = get_images_from_directory(f"{directory}/{digit}")
        digit_labels = np.ones(
            shape=(digit_images.shape[0]), dtype='uint8') * digit
        image_chunks.append(digit_images)
        label_chunks.append(digit_labels)

    all_images = np.concatenate(image_chunks, axis=0)
    all_labels = np.concatenate(label_chunks, axis=0)
    return all_images, all_labels
|
||
|
||
def load_params():
    """Load the parameters stored in params.yaml.

    The file is looked up in the current working directory and parsed with
    ruamel.yaml's safe loader.

    Returns:
        The parsed content of params.yaml (a dictionary when the file holds
        a YAML mapping).
    """
    yaml = YAML(typ="safe")
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # locale encoding.
    with open("params.yaml", encoding="utf-8") as f:
        return yaml.load(f)
|
||
|
||
def load_npz_data(filename):
    """Load the ``images`` and ``labels`` arrays from an .npz archive.

    Args:
        filename: Path of the .npz file; it must contain arrays stored under
            the keys ``images`` and ``labels``.

    Returns:
        A tuple ``(images, labels)``.
    """
    # np.load keeps the archive open until it is closed; the context manager
    # releases the file handle once both arrays have been read into memory.
    with np.load(filename) as npzfile:
        return (npzfile['images'], npzfile['labels'])
|
||
|
||
def shuffle_in_parallel(seed, array1, array2):
    """Shuffle two equally long arrays in place, keeping their rows paired.

    Both arrays are shuffled along their first axis after re-seeding NumPy's
    global random generator with the same seed, so element ``i`` of one
    array still corresponds to element ``i`` of the other afterwards.

    Note: this re-seeds the global ``np.random`` state as a side effect.

    Args:
        seed: Seed passed to ``np.random.seed`` before each shuffle.
        array1: First array; modified in place.
        array2: Second array; modified in place.

    Returns:
        The tuple ``(array1, array2)`` -- the same (now shuffled) objects.

    Raises:
        ValueError: If the arrays differ in length. They would otherwise be
            shuffled with different permutations and silently lose their
            pairing.
    """
    if len(array1) != len(array2):
        raise ValueError(
            "Arrays must have the same length to be shuffled in parallel: "
            f"{len(array1)} != {len(array2)}")

    np.random.seed(seed)
    np.random.shuffle(array1)
    # Re-seed so the second shuffle draws exactly the same permutation.
    np.random.seed(seed)
    np.random.shuffle(array2)

    return array1, array2