Skip to content

Commit

Permalink
mv
Browse files Browse the repository at this point in the history
  • Loading branch information
ShinYwings committed Dec 30, 2021
1 parent 90836dc commit dc8bf4c
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions class10_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from datetime import datetime as dt
import optimizer_alexnet
import threading
import progressbar
# import progressbar
import math

# Hyper parameters
Expand All @@ -16,13 +16,13 @@
MOMENTUM = 0.9 # SGD + MOMENTUM
BATCH_SIZE = 64 # 128 batches occurs OOM in my computer

DATASET_DIR = r"D:\ILSVRC2012"
DATASET_DIR = "/media/shin/2nd_m.2/ILSVRC2012"

# Input으로 넣을 데이터 선택
indexsub = 440
# indexsub = 441 # q95로 하면 1 더 빼줌
TRAIN_DATASET = r"D:\ILSVRC2012\class10_tfrecord_train"
TEST_DATASET = r"D:\ILSVRC2012\class10_tfrecord_val"
TRAIN_DATASET = "/media/shin/2nd_m.2/ILSVRC2012/class10_tfrecord_train"
TEST_DATASET = "/media/shin/2nd_m.2/ILSVRC2012/class10_tfrecord_val"

RUN_TRAIN_DATASET = TRAIN_DATASET
RUN_TEST_DATASET = TEST_DATASET
Expand All @@ -38,12 +38,12 @@
ENCODING_STYLE = "utf-8"
AUTO = tf.data.experimental.AUTOTUNE

widgets = [' [',
progressbar.Timer(format= 'elapsed time: %(elapsed)s'),
'] ',
progressbar.Bar('/'),' (',
progressbar.ETA(), ') ',
]
# widgets = [' [',
# progressbar.Timer(format= 'elapsed time: %(elapsed)s'),
# '] ',
# progressbar.Bar('/'),' (',
# progressbar.ETA(), ') ',
# ]
with tf.device('/CPU:0'):
def img_preprocessing(q, images, labels, train = None):
test_images = list()
Expand Down Expand Up @@ -149,10 +149,10 @@ def _parse_function(example_proto):
if not os.path.isdir(filewriter_path):
os.mkdir(filewriter_path)

root_logdir = os.path.join(filewriter_path, "logs\\fit\\")
root_logdir = os.path.join(filewriter_path, "logs/fit")
logdir = get_logdir(root_logdir)
train_logdir = os.path.join(logdir, "train\\")
val_logdir = os.path.join(logdir, "val\\")
train_logdir = os.path.join(logdir, "train")
val_logdir = os.path.join(logdir, "val")

train_tfrecord_list = list()
test_tfrecord_list = list()
Expand Down Expand Up @@ -276,10 +276,10 @@ def termination_lr_scheduling():
top5_train_accuracy.reset_states()
top5_test_accuracy.reset_states()

bar = progressbar.ProgressBar(max_value= math.ceil(train_buf_size/BATCH_SIZE), widgets=widgets)
test_bar = progressbar.ProgressBar(max_value= math.ceil(test_buf_size/BATCH_SIZE), widgets=widgets)
bar.start()
test_bar.start()
# bar = progressbar.ProgressBar(max_value= math.ceil(train_buf_size/BATCH_SIZE), widgets=widgets)
# test_bar = progressbar.ProgressBar(max_value= math.ceil(test_buf_size/BATCH_SIZE), widgets=widgets)
# bar.start()
# test_bar.start()

q = list()
isFirst = True
Expand All @@ -305,7 +305,7 @@ def termination_lr_scheduling():

if (epoch == (NUM_EPOCHS -1)) and step == 126:
termination_lr_scheduling()
bar.update(step)
# bar.update(step)

# Last step
train_images, train_labels = q.pop()
Expand Down Expand Up @@ -345,7 +345,7 @@ def termination_lr_scheduling():
test_step(batch_test_images, batch_test_labels)
t.join()

test_bar.update(step)
# test_bar.update(step)

# Last step
test_images, test_labels = q2.pop()
Expand Down

0 comments on commit dc8bf4c

Please sign in to comment.