diff --git a/scripts/preprocess-remote.py b/scripts/preprocess-remote.py index 44f097c..4bf3f62 100644 --- a/scripts/preprocess-remote.py +++ b/scripts/preprocess-remote.py @@ -153,15 +153,6 @@ def processFolder( if 0 < testPadding: testing = dropPadding(testing, testPadding) - def saveSubset(filename, idx): - print('%s: %d frames' % (filename, len(idx))) - subset = {k: v[idx] for k, v in dataset.items()} - time = subset['time'] - diff = np.diff(time) - assert np.all(diff >= 0), 'Time is not monotonically increasing!' - np.savez(os.path.join(folder, filename), **subset) - return - # remove the npz files files = os.listdir(folder) for fn in files: @@ -169,10 +160,19 @@ def saveSubset(filename, idx): print('Removed', len(files), 'files') totalFrames = len(testing) + len(training) - if minFrames < totalFrames: + if totalFrames < minFrames: print('Not enough frames: %d < %d' % (totalFrames, minFrames)) return 0, 0 # save training and testing sets + def saveSubset(filename, idx): + print('%s: %d frames' % (filename, len(idx))) + subset = {k: v[idx] for k, v in dataset.items()} + time = subset['time'] + diff = np.diff(time) + assert np.all(diff >= 0), 'Time is not monotonically increasing!' + np.savez(os.path.join(folder, filename), **subset) + return + saveSubset('train.npz', training) saveSubset('test.npz', testing)