-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' of github.com:nitish-kulkarni/Question-Relevanc…
…e-in-VQA
- Loading branch information
Showing
2 changed files
with
383 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
Initialized word embeddings | ||
Obtained embeddings | ||
_________________________________________________________________ | ||
Layer (type) Output Shape Param # | ||
================================================================= | ||
merge_1 (Merge) (None, 10, 600) 0 | ||
_________________________________________________________________ | ||
lstm_2 (LSTM) (None, 1000) 6404000 | ||
_________________________________________________________________ | ||
dense_3 (Dense) (None, 100) 100100 | ||
_________________________________________________________________ | ||
dense_4 (Dense) (None, 50) 5050 | ||
_________________________________________________________________ | ||
dense_5 (Dense) (None, 1) 51 | ||
================================================================= | ||
Total params: 13,184,569 | ||
Trainable params: 13,184,569 | ||
Non-trainable params: 0 | ||
_________________________________________________________________ | ||
None | ||
Built Model | ||
Training now... | ||
Epoch 1/20 | ||
14105/14105 [==============================] - 1441s 102ms/step - loss: 0.2184 - acc: 0.9310 - val_loss: 0.1882 - val_acc: 0.9419 | ||
Epoch 2/20 | ||
14105/14105 [==============================] - 1438s 102ms/step - loss: 0.1755 - acc: 0.9467 - val_loss: 0.1710 - val_acc: 0.9482 | ||
Epoch 3/20 | ||
14105/14105 [==============================] - 1438s 102ms/step - loss: 0.1631 - acc: 0.9505 - val_loss: 0.1636 - val_acc: 0.9504 | ||
Epoch 4/20 | ||
14105/14105 [==============================] - 1437s 102ms/step - loss: 0.1547 - acc: 0.9531 - val_loss: 0.1605 - val_acc: 0.9511 | ||
Epoch 5/20 | ||
14105/14105 [==============================] - 1438s 102ms/step - loss: 0.1498 - acc: 0.9547 - val_loss: 0.1549 - val_acc: 0.9539 | ||
Epoch 6/20 | ||
14105/14105 [==============================] - 1436s 102ms/step - loss: 0.1466 - acc: 0.9557 - val_loss: 0.1550 - val_acc: 0.9533 | ||
Epoch 7/20 | ||
14105/14105 [==============================] - 1436s 102ms/step - loss: 0.1440 - acc: 0.9565 - val_loss: 0.1525 - val_acc: 0.9547 | ||
Epoch 8/20 | ||
14105/14105 [==============================] - 1436s 102ms/step - loss: 0.1417 - acc: 0.9572 - val_loss: 0.1507 - val_acc: 0.9555 | ||
Epoch 9/20 | ||
14105/14105 [==============================] - 1435s 102ms/step - loss: 0.1399 - acc: 0.9579 - val_loss: 0.1507 - val_acc: 0.9555 | ||
Epoch 10/20 | ||
14105/14105 [==============================] - 1436s 102ms/step - loss: 0.1387 - acc: 0.9582 - val_loss: 0.1520 - val_acc: 0.9555 | ||
Epoch 11/20 | ||
14105/14105 [==============================] - 1438s 102ms/step - loss: 0.1378 - acc: 0.9583 - val_loss: 0.1501 - val_acc: 0.9557 | ||
Epoch 12/20 | ||
14105/14105 [==============================] - 1437s 102ms/step - loss: 0.1368 - acc: 0.9586 - val_loss: 0.1478 - val_acc: 0.9567 | ||
Epoch 13/20 | ||
14105/14105 [==============================] - 1436s 102ms/step - loss: 0.1367 - acc: 0.9586 - val_loss: 0.1506 - val_acc: 0.9553 | ||
Epoch 14/20 | ||
14105/14105 [==============================] - 1437s 102ms/step - loss: 0.1370 - acc: 0.9585 - val_loss: 0.1517 - val_acc: 0.9551 | ||
Epoch 15/20 | ||
14105/14105 [==============================] - 1437s 102ms/step - loss: 0.1360 - acc: 0.9590 - val_loss: 0.1521 - val_acc: 0.9552 | ||
Epoch 16/20 | ||
14105/14105 [==============================] - 1435s 102ms/step - loss: 0.1358 - acc: 0.9589 - val_loss: 0.1490 - val_acc: 0.9561 | ||
Epoch 17/20 | ||
14105/14105 [==============================] - 1436s 102ms/step - loss: 0.1364 - acc: 0.9587 - val_loss: 0.1500 - val_acc: 0.9557 | ||
Epoch 18/20 | ||
14105/14105 [==============================] - 1436s 102ms/step - loss: 0.1359 - acc: 0.9588 - val_loss: 0.1543 - val_acc: 0.9552 | ||
Epoch 19/20 | ||
14105/14105 [==============================] - 1435s 102ms/step - loss: 0.1358 - acc: 0.9589 - val_loss: 0.1542 - val_acc: 0.9544 | ||
Epoch 20/20 | ||
14105/14105 [==============================] - 1436s 102ms/step - loss: 0.1363 - acc: 0.9587 - val_loss: 0.1514 - val_acc: 0.9554 | ||
|
||
Metrics on first order train dataset | ||
precision: [ 0.9534671 0.98037139] | ||
recall: [ 0.99424028 0.85567199] | ||
fscore: [ 0.97342692 0.91378706] | ||
support: [1055607 354893] | ||
[[1049527 6080] | ||
[ 51221 303672]] | ||
Accuracy = (1049527.0 + 303672.0) / (1049527.0 + 303672.0 + 6080.0 + 51221.0) = 0.9593753987947536 | ||
|
||
Metrics on first order val dataset | ||
precision: [ 0.9495827 0.97592667] | ||
recall: [ 0.99298987 0.84351368] | ||
fscore: [ 0.97080131 0.90490188] | ||
support: [263761 88864] | ||
[[261912 1849] | ||
[ 13906 74958]] | ||
Accuracy = (261912.0 + 74958.0) / (74958.0 + 261912.0 + 1849.0 + 13906.0) = 0.955320808224034 | ||
|
||
Metrics on first order test dataset | ||
precision: [ 0.94509108 0.71661531] | ||
recall: [ 0.89842234 0.83111115] | ||
fscore: [ 0.921166 0.76962822] | ||
support: [693558 214354] | ||
[[623108 70450] | ||
[ 36202 178152]] | ||
Accuracy = (623108.0 + 178152.0) / (623108.0 + 178152.0 + 70450.0 + 36202.0) = 0.8825304655076703 | ||
|
||
Metrics on qrpe train dataset | ||
precision: [ 0.70698281 0.89866632] | ||
recall: [ 0.93430719 0.6007167 ] | ||
fscore: [ 0.80490239 0.72008801] | ||
support: [35392 34324] | ||
[[33067 2325] | ||
[13705 20619]] | ||
Accuracy = (33067.0 + 20619.0) / (33067.0 + 20619.0 + 2325.0 + 13705.0) = 0.7700671294968157 | ||
|
||
Metrics on qrpe test dataset | ||
precision: [ 0.63524626 0.70323154] | ||
recall: [ 0.78205966 0.53489683] | ||
fscore: [ 0.70104904 0.60762088] | ||
support: [18372 17738] | ||
[[14368 4004] | ||
[ 8250 9488]] | ||
Accuracy = (14368.0 + 9488.0) / (14368.0 + 9488.0 + 4004.0 + 8250.0) = 0.6606480199390751 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,276 @@ | ||
import argparse | ||
import numpy as np | ||
import pandas as pd | ||
from keras.preprocessing import sequence, text | ||
from keras.models import Sequential, Model | ||
from keras.layers import Dense, LSTM, Dropout, Merge, Input, RepeatVector, TimeDistributed | ||
from keras.layers.embeddings import Embedding | ||
from keras.regularizers import l2 | ||
from keras.callbacks import ModelCheckpoint | ||
from sklearn.metrics import precision_recall_fscore_support as score | ||
import json | ||
from sklearn.metrics import confusion_matrix | ||
import torchfile | ||
|
||
EMBEDDING_LEN = 300 | ||
MAX_LEN_SENTENCE = 10 | ||
BATCH_SIZE = 128 | ||
|
||
DATA_PATH = "./Data/" | ||
QUESTIONS_MAP = json.load(open(DATA_PATH + 'vqa2_questions.json', 'r')) | ||
QUESTIONS_MAP_QRPE = json.load(open(DATA_PATH + 'vqa1_questions.json', 'r')) | ||
COCO_TRAIN_IMG_MAP = json.load(open("%s/coco_train_image_map.json" % DATA_PATH, "r")) | ||
COCO_VAL_IMG_MAP = json.load(open("%s/coco_val_image_map.json" % DATA_PATH, "r")) | ||
VG_IMG_MAP = json.load(open("%s/vg_image_map.json" % DATA_PATH, "r")) | ||
### Collecting fc7 features for images | ||
COCO_TRAIN_FEAT = torchfile.load("%s/train_fc7.t7" % DATA_PATH) | ||
COCO_VAL_FEAT = torchfile.load("%s/val_fc7.t7" % DATA_PATH) | ||
VG_FEAT = torchfile.load("%s/vg_fc7.t7" % DATA_PATH) | ||
TRAIN_PATH = DATA_PATH + "train_firstorder_split_data.txt" | ||
VAL_PATH = DATA_PATH + "val_firstorder_split_data.txt" | ||
TEST_PATH = DATA_PATH + "val_firstorder_data.txt" | ||
TRAIN_QRPE_PATH = DATA_PATH + "train_data.txt" | ||
TEST_QRPE_PATH = DATA_PATH + "val_data.txt" | ||
def get_image_feature(image_id, coco=True): | ||
if coco: | ||
return COCO_TRAIN_FEAT[COCO_TRAIN_IMG_MAP[image_id]] if image_id in COCO_TRAIN_IMG_MAP else COCO_VAL_FEAT[COCO_VAL_IMG_MAP[image_id]] | ||
else: | ||
return VG_FEAT[VG_IMG_MAP[image_id]] | ||
def get_question(question_id, qrpe_flag): | ||
if qrpe_flag: | ||
return QUESTIONS_MAP_QRPE[question_id] | ||
return QUESTIONS_MAP[question_id] | ||
def get_image_features(image_ids, coco): | ||
return np.array([get_image_feature(image_id, c) for image_id, c in zip(image_ids, coco)]) | ||
def get_questions(question_ids, qrpe_flag): | ||
return [get_question(qid, qrpe_flag) for qid in question_ids] | ||
|
||
class Dataset(): | ||
tokenizer = None | ||
data_train = None | ||
data_val = None | ||
data_test = None | ||
word_to_idx = None | ||
total_words = None | ||
qrpe_train = None | ||
qrpe_test = None | ||
|
||
def __init__(self, datapath): | ||
self.tokenizer = text.Tokenizer() | ||
self.data_train = pd.read_csv(TRAIN_PATH, sep="\t", header=None) | ||
self.data_train.columns = ["imgid","qid","rel","src"] | ||
self.data_val = pd.read_csv(VAL_PATH, sep="\t", header=None) | ||
self.data_val.columns = ["imgid","qid","rel","src"] | ||
self.data_test = pd.read_csv(TEST_PATH, sep="\t", header=None) | ||
self.data_test.columns = ["imgid","qid","rel","src"] | ||
self.qrpe_train = pd.read_csv(TRAIN_QRPE_PATH, sep="\t", header=None) | ||
self.qrpe_train.columns = ["imgid","qid","rel","src"] | ||
self.qrpe_test = pd.read_csv(TEST_QRPE_PATH, sep="\t", header=None) | ||
self.qrpe_test.columns = ["imgid","qid","rel","src"] | ||
# imgid for train, val, test | ||
self.data_train.imgid = self.data_train.imgid.astype(str) | ||
self.data_val.imgid = self.data_val.imgid.astype(str) | ||
self.data_test.imgid = self.data_test.imgid.astype(str) | ||
#imgid for qrpe train and test | ||
self.qrpe_train.imgid = self.qrpe_train.imgid.astype(str) | ||
self.qrpe_test.imgid = self.qrpe_test.imgid.astype(str) | ||
# qid for train, val, test | ||
self.data_train.qid = get_questions(self.data_train.qid.astype(str), False) | ||
self.data_val.qid = get_questions(self.data_val.qid.astype(str), False) | ||
self.data_test.qid = get_questions(self.data_test.qid.astype(str), False) | ||
#qid for qrpe train and test | ||
self.qrpe_train.qid = get_questions(self.qrpe_train.qid.astype(str), True) | ||
self.qrpe_test.qid = get_questions(self.qrpe_test.qid.astype(str), True) | ||
# src for train, val, test | ||
self.data_train.src = self.data_train.src.astype(int) | ||
self.data_val.src = self.data_val.src.astype(int) | ||
self.data_test.src = self.data_test.src.astype(int) | ||
# src for qrpe | ||
self.qrpe_train.src = self.qrpe_train.src.astype(int) | ||
self.qrpe_test.src = self.qrpe_test.src.astype(int) | ||
self.tokenizer.fit_on_texts(list(self.data_train.qid.astype(str)) + list(self.data_val.qid.astype(str)) + list(self.data_test.qid.astype(str)) + list(self.qrpe_train.qid.astype(str)) + list(self.qrpe_test.qid.astype(str))) | ||
self.word_to_idx = self.tokenizer.word_index | ||
self.total_words = len(self.word_to_idx) | ||
|
||
def create_embedding_matrix(self, embeddings_path): | ||
embeddings = {} | ||
with open(embeddings_path) as f: | ||
for line in f: | ||
values = line.split() | ||
embedding = np.asarray(values[1:], dtype='float32') | ||
embeddings[values[0]] = embedding | ||
sz_embedding_mat = self.total_words + 1 | ||
embedding_matrix = np.zeros((sz_embedding_mat, EMBEDDING_LEN)) | ||
for key in self.word_to_idx: | ||
if key in embeddings: | ||
embedding_matrix[self.word_to_idx[key]] = embeddings[key] | ||
print "Initialized word embeddings" | ||
return embedding_matrix | ||
|
||
class DataGenerator(): | ||
def __init__(self, batch_size = 128, shuffle = True): | ||
self.batch_size = batch_size | ||
self.shuffle = shuffle | ||
|
||
def __get_exploration_order(self, list_IDs): | ||
indexes = np.arange(len(list_IDs)) | ||
if self.shuffle == True: | ||
np.random.shuffle(indexes) | ||
return indexes | ||
|
||
def __data_generation(self, list_IDs_temp, list_QIDs_temp, list_src_temp, tokenizer): | ||
X_img = get_image_features(list_IDs_temp, list_src_temp) | ||
X = tokenizer.texts_to_sequences(list_QIDs_temp) | ||
X_lang = sequence.pad_sequences(X, maxlen=MAX_LEN_SENTENCE) | ||
return X_lang, X_img | ||
|
||
def generate(self, labels, list_IDs, list_QIDs, list_src, tokenizer): | ||
while 1: | ||
indexes = self.__get_exploration_order(list_IDs) | ||
# Generate batches | ||
imax = int(len(indexes)/self.batch_size) | ||
for i in range(imax): | ||
indexes_temp = indexes[i*self.batch_size:(i+1)*self.batch_size] | ||
# Find list of IDs | ||
labels_temp = [] | ||
list_IDs_temp = [] | ||
list_QIDs_tmp = [] | ||
list_src_temp = [] | ||
for k in indexes_temp: | ||
list_IDs_temp.append(list_IDs[k]) | ||
list_QIDs_tmp.append(list_QIDs[k]) | ||
list_src_temp.append(int(list_src[k])) | ||
labels_temp.append(np.array(int(labels[k]))) | ||
# Generate data | ||
X_lang, X_img = self.__data_generation(list_IDs_temp, list_QIDs_tmp, list_src_temp, tokenizer) | ||
yield [X_lang, X_img], np.array(labels_temp) | ||
|
||
class LSTMModel(): | ||
def build_model(self, num_vocab, embedding_matrix, max_len): | ||
lang_model = Sequential() | ||
lang_model.add(Embedding(input_dim=num_vocab, output_dim=EMBEDDING_LEN, weights=[embedding_matrix], input_length=max_len)) | ||
lang_model.add(LSTM(256,return_sequences=True)) | ||
lang_model.add(TimeDistributed(Dense(EMBEDDING_LEN))) | ||
|
||
image_model = Sequential() | ||
image_model.add(Dense(EMBEDDING_LEN, input_dim = 4096, activation='relu')) | ||
image_model.add(RepeatVector(max_len)) | ||
|
||
model = Sequential() | ||
model.add(Merge([lang_model, image_model], mode='concat')) | ||
model.add(LSTM(1000,return_sequences=False, input_shape=())) | ||
model.add(Dense(100, activation='relu', W_regularizer=l2(0.0001), b_regularizer=l2(0.0001))) | ||
model.add(Dense(50, activation='relu', W_regularizer=l2(0.0001), b_regularizer=l2(0.0001))) | ||
model.add(Dense(1, activation='sigmoid', W_regularizer=l2(0.0001), b_regularizer=l2(0.0001))) | ||
|
||
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) | ||
print(model.summary()) | ||
return model | ||
|
||
### Main function which trains the model, tests the model and report metrics | ||
def main(params): | ||
datapath = params["datapath"] | ||
train_data_split = params["train_data_split"] | ||
embeddings_path = params["embeddings_path"] | ||
model_path = params["model_path"] | ||
|
||
Ds = Dataset(datapath) | ||
embedding_matrix = Ds.create_embedding_matrix(embeddings_path) | ||
print "Obtained embeddings" | ||
num_vocab = Ds.total_words + 1 | ||
|
||
lm = LSTMModel() | ||
model = lm.build_model(num_vocab, embedding_matrix, MAX_LEN_SENTENCE) | ||
print "Built Model" | ||
print "Training now..." | ||
training_generator = DataGenerator(batch_size=100).generate(Ds.data_train.rel, Ds.data_train.imgid, Ds.data_train.qid, Ds.data_train.src, Ds.tokenizer) | ||
validation_generator = DataGenerator(batch_size=125).generate(Ds.data_val.rel, Ds.data_val.imgid, Ds.data_val.qid, Ds.data_val.src, Ds.tokenizer) | ||
#model.fit_generator(epochs=1, generator=training_generator, validation_data=validation_generator, steps_per_epoch=(len(Ds.data_train.imgid)/100), validation_steps = len(Ds.data_val.imgid)/125, verbose=1) | ||
model.fit_generator(epochs=20, generator=training_generator, validation_data=validation_generator, steps_per_epoch=(len(Ds.data_train.imgid)/100), validation_steps = len(Ds.data_val.imgid)/125, verbose=1) | ||
|
||
### Defining generators for prediction | ||
training_generator = DataGenerator(batch_size=100, shuffle=False).generate(Ds.data_train.rel, Ds.data_train.imgid, Ds.data_train.qid, Ds.data_train.src, Ds.tokenizer) | ||
validation_generator = DataGenerator(batch_size=125, shuffle=False).generate(Ds.data_val.rel, Ds.data_val.imgid, Ds.data_val.qid, Ds.data_val.src, Ds.tokenizer) | ||
testing_generator = DataGenerator(batch_size=8, shuffle=False).generate(Ds.data_test.rel, Ds.data_test.imgid, Ds.data_test.qid, Ds.data_test.src, Ds.tokenizer) | ||
qrpe_training_generator = DataGenerator(batch_size=116, shuffle=False).generate(Ds.qrpe_train.rel, Ds.qrpe_train.imgid, Ds.qrpe_train.qid, Ds.qrpe_train.src, Ds.tokenizer) | ||
qrpe_testing_generator = DataGenerator(batch_size=115, shuffle=False).generate(Ds.qrpe_test.rel, Ds.qrpe_test.imgid, Ds.qrpe_test.qid, Ds.qrpe_test.src, Ds.tokenizer) | ||
|
||
### Testing on first order train dataset | ||
pred = model.predict_generator(generator=training_generator, steps=(len(Ds.data_train.imgid)/100), verbose=1) | ||
fid = open(datapath + "firstorder_train_pred.txt", "w") | ||
for p in pred: | ||
fid.write(str(p[0])+"\n") | ||
fid.close() | ||
precision, recall, fscore, support = score(Ds.data_train.rel, pred.round(), labels=[0, 1]) | ||
print "Metrics on first order train dataset" | ||
print('precision: {}'.format(precision)) | ||
print('recall: {}'.format(recall)) | ||
print('fscore: {}'.format(fscore)) | ||
print('support: {}'.format(support)) | ||
print confusion_matrix(Ds.data_train.rel, pred.round()) | ||
|
||
### Testing on first order val dataset | ||
pred = model.predict_generator(generator=validation_generator, steps=(len(Ds.data_val.imgid)/125), verbose=1) | ||
fid = open(datapath + "firstorder_val_pred.txt", "w") | ||
for p in pred: | ||
fid.write(str(p[0])+"\n") | ||
fid.close() | ||
precision, recall, fscore, support = score(Ds.data_val.rel, pred.round(), labels=[0, 1]) | ||
print "Metrics on first order val dataset" | ||
print('precision: {}'.format(precision)) | ||
print('recall: {}'.format(recall)) | ||
print('fscore: {}'.format(fscore)) | ||
print('support: {}'.format(support)) | ||
print confusion_matrix(Ds.data_val.rel, pred.round()) | ||
|
||
### Testing on first order test dataset | ||
pred = model.predict_generator(generator=testing_generator, steps=(len(Ds.data_test.imgid)/8), verbose=1) | ||
fid = open(datapath + "firstorder_test_pred.txt", "w") | ||
for p in pred: | ||
fid.write(str(p[0])+"\n") | ||
fid.close() | ||
precision, recall, fscore, support = score(Ds.data_test.rel, pred.round(), labels=[0, 1]) | ||
print "Metrics on first order test dataset" | ||
print('precision: {}'.format(precision)) | ||
print('recall: {}'.format(recall)) | ||
print('fscore: {}'.format(fscore)) | ||
print('support: {}'.format(support)) | ||
print confusion_matrix(Ds.data_test.rel, pred.round()) | ||
|
||
### Testing on qrpe train dataset | ||
pred = model.predict_generator(generator=qrpe_training_generator, steps=(len(Ds.qrpe_train.imgid)/116), verbose=1) | ||
fid = open(datapath + "qrpe_train_pred.txt", "w") | ||
for p in pred: | ||
fid.write(str(p[0])+"\n") | ||
fid.close() | ||
precision, recall, fscore, support = score(Ds.qrpe_train.rel, pred.round(), labels=[0, 1]) | ||
print "Metrics on qrpe train dataset" | ||
print('precision: {}'.format(precision)) | ||
print('recall: {}'.format(recall)) | ||
print('fscore: {}'.format(fscore)) | ||
print('support: {}'.format(support)) | ||
print confusion_matrix(Ds.qrpe_train.rel, pred.round()) | ||
|
||
### Testing on qrpe test dataset | ||
pred = model.predict_generator(generator=qrpe_testing_generator, steps=(len(Ds.qrpe_test.imgid)/115), verbose=1) | ||
fid = open(datapath + "qrpe_test_pred.txt", "w") | ||
for p in pred: | ||
fid.write(str(p[0])+"\n") | ||
fid.close() | ||
precision, recall, fscore, support = score(Ds.qrpe_test.rel, pred.round(), labels=[0, 1]) | ||
print "Metrics on qrpe test dataset" | ||
print('precision: {}'.format(precision)) | ||
print('recall: {}'.format(recall)) | ||
print('fscore: {}'.format(fscore)) | ||
print('support: {}'.format(support)) | ||
print confusion_matrix(Ds.qrpe_test.rel, pred.round()) | ||
|
||
if __name__=='__main__': | ||
### Read user inputs | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--datapath", dest="datapath", type=str, default="./") | ||
parser.add_argument("--train_data_split", dest="train_data_split", type=float, default=0.8) | ||
parser.add_argument("--embeddings_path", dest="embeddings_path", type=str, default="./glove.840B.300d.txt") | ||
parser.add_argument("--model_path", dest="model_path", type=str, default="./models/") | ||
params = vars(parser.parse_args()) | ||
main(params) |