diff --git a/train.py b/train.py index 3b0e4d6b93..a470606813 100644 --- a/train.py +++ b/train.py @@ -54,6 +54,7 @@ def train_net(net, for epoch in range(epochs): print('Starting epoch {}/{}.'.format(epoch + 1, epochs)) + net.train() # reset the generators train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale)