Skip to content

Commit

Permalink
Add different sampling methods for the dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Jul 2, 2024
1 parent ddd0813 commit d2f1c89
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
29 changes: 24 additions & 5 deletions Core/CDatasetLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@
from Core.CDataSampler import CDataSampler
import numpy as np
import tensorflow as tf
from enum import Enum

# Sampling strategies for building CDatasetLoader's dataset-index array.
# Built via the functional Enum API; members and values are identical to
# the equivalent `class ESampling(Enum)` form, so `ESampling('as_is')`
# and `ESampling.UNIFORM.value` behave exactly the same.
ESampling = Enum(
  'ESampling',
  [('AS_IS', 'as_is'), ('UNIFORM', 'uniform')],
)

class CDatasetLoader:
def __init__(self, folder, samplerArgs, stats):
def __init__(self, folder, samplerArgs, sampling, stats):
# recursively find all 'train.npz' files
trainFiles = glob.glob(os.path.join(folder, '**', 'train.npz'), recursive=True)
if 0 == len(trainFiles):
Expand Down Expand Up @@ -40,10 +45,24 @@ def __init__(self, folder, samplerArgs, stats):
}
dtype = np.uint8 if len(self._datasets) < 256 else np.uint32
# create an array of dataset indices to sample from
self._indices = np.concatenate([
np.full((v, ), k, dtype=dtype) # id of the dataset
for k, v in validSamples.items()
])
sampling = ESampling(sampling)
if ESampling.AS_IS == sampling: # just concatenate the indices
self._indices = np.concatenate([
np.full((v, ), k, dtype=dtype) # id of the dataset
for k, v in validSamples.items()
])
if ESampling.UNIFORM == sampling:
maxSize = max(validSamples.values())
chunks = []
for k, size in validSamples.items():
# all datasets should have the same number of samples represented in the indices
# so that the sampling is uniform
chunk = np.full((maxSize, ), k, dtype=np.uint32) % size
chunk = chunk.astype(dtype)
chunks.append(chunk)
continue
self._indices = np.concatenate(chunks)

self._currentId = 0

self._batchSize = samplerArgs.get('batch_size', 16)
Expand Down
5 changes: 5 additions & 0 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def main(args):
trainDataset = CDatasetLoader(
os.path.join(folder, 'remote'),
stats=stats,
sampling=args.sampling,
samplerArgs=dict(
batch_size=args.batch_size,
minFrames=timesteps,
Expand Down Expand Up @@ -344,6 +345,10 @@ def performRandomSearch(epoch=0):
parser.add_argument(
'--with-enconders', default=False, action='store_true',
)
parser.add_argument(
'--sampling', type=str, default='uniform',
choices=['uniform', 'as_is'],
)

main(parser.parse_args())
pass

0 comments on commit d2f1c89

Please sign in to comment.