diff --git a/fit_stage.py b/fit_stage.py index 5011a8b..9e5f51d 100644 --- a/fit_stage.py +++ b/fit_stage.py @@ -2,8 +2,6 @@ import numpy as np def train(model, memory, params): - if len(memory) < params['batchSize']: return np.Inf - modelClone = tf.keras.models.clone_model(model) modelClone.set_weights(model.get_weights()) # use clone model for stability diff --git a/learn_environment.py b/learn_environment.py index 669816b..9cc5aa3 100644 --- a/learn_environment.py +++ b/learn_environment.py @@ -87,17 +87,18 @@ def testModel(EXPLORE_RATE): ) print('Avg. train loss: %.4f' % trainLoss) - trainLoss = fit_stage.train( - model, doomMemory, - { - 'gamma': GAMMA, - 'batchSize': BATCH_SIZE, - 'steps': BOOTSTRAPPED_STEPS, - 'episodes': params['train doom episodes'](epoch), - 'alpha': params.get('doom alpha', lambda _: alpha)(epoch) - } - ) - print('Avg. train doom loss: %.4f' % trainLoss) + if params['batchSize'] < len(doomMemory): + trainLoss = fit_stage.train( + model, doomMemory, + { + 'gamma': GAMMA, + 'batchSize': BATCH_SIZE, + 'steps': BOOTSTRAPPED_STEPS, + 'episodes': params['train doom episodes'](epoch), + 'alpha': params.get('doom alpha', lambda _: alpha)(epoch) + } + ) + print('Avg. train doom loss: %.4f' % trainLoss) ################## # test print('Testing...')