Skip to content

Commit

Permalink
filtrate by samples number, print extra info
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Sep 19, 2024
1 parent a263939 commit 0446554
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions scripts/preprocess-remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,20 +124,32 @@ def dropPadding(idx, padding):
print('Frames before: {}. Frames after: {}'.format(len(idx), len(res)))
return res

def processFolder(folder, timeDelta, testRatio, framesPerChunk, testPadding, skippedFrames):
def processFolder(folder, timeDelta, testRatio, framesPerChunk, testPadding, skippedFrames, minimumFrames):
print('Processing', folder)
dataset = loadNpz(folder)
for k, v in dataset.items():
print(k, v.shape)

if len(dataset['time']) < minimumFrames:
print('Dataset is too short. Skipping...')
return 0, 0
# split dataset into sessions
sessions = Utils.extractSessions(dataset, float(timeDelta))
# print sessions and their durations for debugging
print('Found {} sessions'.format(len(sessions)))
for i, (start, end) in enumerate(sessions):
duration = dataset['time'][end - 1] - dataset['time'][start]
idx = np.arange(start, end)
session_time = dataset['time'][idx]
delta = np.diff(session_time)
duration = session_time[-1] - session_time[0]
print('Session {}: {} - {} ({}, {})'.format(i, start, end, end - start, duration))
# print also min, max, and mean time deltas
print('Time deltas in session {}: min={}, max={}, mean={}'.format(i, np.min(delta), np.max(delta), np.mean(delta)))
continue
# print total deltas statistics
deltas = np.diff(dataset['time'])
print('Total time deltas: min={}, max={}, mean={}'.format(np.min(deltas), np.max(deltas), np.mean(deltas)))
deltas = None
######################################################
# split each session into training and testing sets
training, testing = splitDataset(
Expand Down Expand Up @@ -233,6 +245,7 @@ def main(args):
'--skipped-frames', type=str, default='train', choices=['train', 'test', 'drop'],
help='What to do with skipped frames ("train", "test", or "drop")'
)
parser.add_argument('--minimum-frames', type=int, default=0, help='Minimum number of frames in a dataset')
args = parser.parse_args()
main(args)
pass

0 comments on commit 0446554

Please sign in to comment.