added evaluation during training loop
GuyTevet committed Oct 9, 2022
1 parent 474d005 commit 3c0ce16
Showing 5 changed files with 79 additions and 20 deletions.
13 changes: 12 additions & 1 deletion README.md
@@ -184,11 +184,22 @@ python -m train.train_mdm --save_dir save/my_kit_trans_enc_512 --dataset kit
* Use `--device` to define GPU id.
* Use `--arch` to choose one of the architectures reported in the paper `{trans_enc, trans_dec, gru}` (`trans_enc` is default).
* Add `--train_platform_type {ClearmlPlatform, TensorboardPlatform}` to track results with either [ClearML](https://clear.ml/) or [Tensorboard](https://www.tensorflow.org/tensorboard).
* Add `--eval_during_training` to run a short evaluation (about 90 minutes) for each saved checkpoint.
  This slows down training but gives you better monitoring; see the example below.
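
For example, to train on HumanML3D with checkpoint-time evaluation enabled (the save directory name below is just a placeholder):
```shell
python -m train.train_mdm --save_dir save/my_humanml_trans_enc_512 --dataset humanml --eval_during_training
```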

## Evaluate
* Takes about 20 hours (on a single GPU)
* The output of this script is provided in the checkpoints zip file.

ETA - Nov 22
**HumanML3D**
```shell
python -m eval.eval_humanml --model_path ./save/humanml_trans_enc_512/model000475000.pt
```

**KIT**
```shell
python -m eval.eval_humanml --model_path ./save/kit_trans_enc_512/model000400000.pt
```


## Acknowledgments
@@ -145,7 +145,7 @@ def __getitem__(self, item):

class CompMDMGeneratedDataset(Dataset):

def __init__(self, model, diffusion, dataloader, mm_num_samples, mm_num_repeats, max_motion_length, num_samples_limit, scale=None):
def __init__(self, model, diffusion, dataloader, mm_num_samples, mm_num_repeats, max_motion_length, num_samples_limit, scale=1.):
self.dataloader = dataloader
self.dataset = dataloader.dataset
assert mm_num_samples < len(dataloader.dataset)
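
The only change in this hunk is the default guidance scale, from `None` to `1.`. Under classifier-free guidance, the scale blends conditional and unconditional predictions, and a scale of exactly 1 reduces to the plain conditional output, so the new default keeps sampling well-defined when no scale is passed explicitly. A minimal illustration of that convention (not the repository's exact sampling code):

```python
import torch

def apply_cfg(cond_out: torch.Tensor, uncond_out: torch.Tensor, scale: float = 1.) -> torch.Tensor:
    # Classifier-free guidance: extrapolate from the unconditional prediction
    # toward the conditional one. With scale == 1. this returns cond_out
    # unchanged, which is why 1. is a safer default than None.
    return uncond_out + scale * (cond_out - uncond_out)
```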
1 change: 1 addition & 0 deletions environment.yml
@@ -113,6 +113,7 @@ dependencies:
- pip:
- blis==0.7.8
- chumpy==0.70
- clearml==1.7.1
- click==8.1.3
- confection==0.0.2
- filelock==3.8.0
69 changes: 58 additions & 11 deletions train/training_loop.py
@@ -15,6 +15,9 @@
from diffusion.resample import LossAwareSampler, UniformSampler
from tqdm import tqdm
from diffusion.resample import create_named_schedule_sampler
from data_loaders.humanml.networks.evaluator_wrapper import EvaluatorMDMWrapper
from eval import eval_humanml
from data_loaders.get_data import get_dataset_loader


# For ImageNet experiments, this was a good default value.
@@ -25,6 +28,7 @@

class TrainLoop:
def __init__(self, args, train_platform, model, diffusion, data):
self.args = args
self.dataset = args.dataset
self.train_platform = train_platform
self.model = model
@@ -72,15 +76,27 @@ def __init__(self, args, train_platform, model, diffusion, data):
if torch.cuda.is_available() and dist_util.dev() != 'cpu':
self.device = torch.device(dist_util.dev())

###### CODE FROM TRAIN - START
self.schedule_sampler_type = 'uniform'
self.schedule_sampler = create_named_schedule_sampler(self.schedule_sampler_type, diffusion)
self.eval_wrapper, self.eval_data, self.eval_gt_data = None, None, None # TODO - implement





self.eval_wrapper, self.eval_data, self.eval_gt_data = None, None, None
if args.dataset in ['kit', 'humanml'] and args.eval_during_training:
mm_num_samples = 0 # mm is super slow hence we won't run it during training
mm_num_repeats = 0 # mm is super slow hence we won't run it during training
gen_loader = get_dataset_loader(name=args.dataset, batch_size=args.eval_batch_size, num_frames=None,
split=args.eval_split,
hml_mode='eval')

self.eval_gt_data = get_dataset_loader(name=args.dataset, batch_size=args.eval_batch_size, num_frames=None,
split=args.eval_split,
hml_mode='gt')
self.eval_wrapper = EvaluatorMDMWrapper(args.dataset, dist_util.dev())
self.eval_data = {
'test': lambda: eval_humanml.get_mdm_loader(
model, diffusion, args.eval_batch_size,
gen_loader, mm_num_samples, mm_num_repeats, gen_loader.dataset.opt.max_motion_length,
args.eval_num_samples, scale=1.,
)
}
self.use_ddp = False
self.ddp_model = self.model
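
Note that `self.eval_data['test']` holds a zero-argument factory rather than a prebuilt loader, so the generated-motion dataset is re-sampled with the current model weights every time evaluation runs. A small sketch of how such a dict of factories can be consumed (hypothetical function, assuming the convention above):

```python
# Hypothetical consumer of a {'test': callable} evaluation dict,
# following the lazy-loader convention set up above.
def run_periodic_eval(eval_data, replication_times=3):
    for replication in range(replication_times):
        # Each call re-runs generation, so motions reflect the latest weights.
        motion_loader = eval_data['test']()
        for batch in motion_loader:
            pass  # score the freshly generated motions (FID, R-precision, ...)
```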

@@ -148,10 +164,41 @@ def run_loop(self):
self.evaluate()

def evaluate(self):
pass # TODO - implement
# start_eval = time.time()
# end_eval = time.time()
# print(f'Evaluation time: {round(end_eval-start_eval)/60}min')
start_eval = time.time()
if self.eval_wrapper is not None:
print('Running evaluation loop: [Should take about 90 min]')
log_file = os.path.join(self.save_dir, f'eval_humanml_{(self.step + self.resume_step):09d}.log')
diversity_times = 300
mm_num_times = 0 # mm is super slow hence we won't run it during training
eval_dict = eval_humanml.evaluation(
self.eval_wrapper, self.eval_gt_data, self.eval_data, log_file,
replication_times=self.args.eval_rep_times, diversity_times=diversity_times, mm_num_times=mm_num_times, run_mm=False)
print(eval_dict)
for k, v in eval_dict.items():
if k.startswith('R_precision'):
for i in range(len(v)):
self.train_platform.report_scalar(name=f'top{i + 1}_' + k, value=v[i],
iteration=self.step + self.resume_step,
group_name='Eval')
else:
self.train_platform.report_scalar(name=k, value=v, iteration=self.step + self.resume_step,
group_name='Eval')

elif self.dataset in ['humanact12', 'uestc']:
from scripts import eval_humanact12_uestc
num_seeds = 2 if self.eval_debug else self.eval_rep_times
num_samples = 64 if self.eval_debug else self.eval_num_samples
args = SimpleNamespace(num_seeds=num_seeds, num_samples=num_samples, use_ddim=False,
batch_size=self.batch_size, device=self.device, guidance_scale=False,
dataset=self.dataset, cond_mode=self.cond_mode, model_path=os.path.join(self.save_dir, self.ckpt_file_name()))
eval_dict = eval_humanact12_uestc.evaluate(args, model=self.model, diffusion=self.diffusion, data=self.data.dataset)
print(f'Evaluation results on {self.dataset}: {sorted(eval_dict["feats"].items())}')
for k, v in eval_dict["feats"].items():
self.train_platform.report_scalar(name=k, value=np.array(v).astype(float).mean(), iteration=self.step, group_name='Eval')


end_eval = time.time()
print(f'Evaluation time: {round(end_eval-start_eval)/60}min')


def run_step(self, batch, cond):
14 changes: 7 additions & 7 deletions utils/parser_util.py
@@ -108,17 +108,17 @@ def add_training_options(parser):
"Precision calculation is based on fixed batch size 32.")
group.add_argument("--eval_split", default='test', choices=['val', 'test'], type=str,
help="Which split to evaluate on during training.")
group.add_argument("--eval_debug", action='store_true',
help="If True, will make the eval loop run much faster (and output invalid results).")
group.add_argument("--eval_during_training", action='store_true',
help="If True, will run evaluation during training.")
group.add_argument("--eval_rep_times", default=3, type=int,
help="Number of repetitions for evaluation loop during training.")
group.add_argument("--eval_num_samples", default=-1, type=int,
group.add_argument("--eval_num_samples", default=1_000, type=int,
help="If -1, will use all samples in the specified split.")
group.add_argument("--log_interval", default=1000, type=int,
group.add_argument("--log_interval", default=1_000, type=int,
help="Log losses each N steps")
group.add_argument("--save_interval", default=100000, type=int,
group.add_argument("--save_interval", default=50_000, type=int,
help="Save checkpoints and run evaluation each N steps")
group.add_argument("--num_steps", default=600000, type=int,
group.add_argument("--num_steps", default=600_000, type=int,
help="Training will stop after the specified number of steps.")
group.add_argument("--num_frames", default=60, type=int,
help="Limit for the maximal number of frames. In HumanML3D and KIT this field is ignored.")
@@ -149,7 +149,7 @@ def add_sampling_options(parser):
help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")

def add_evaluation_options(parser):
group = parser.add_argument_group('dataset')
group = parser.add_argument_group('eval')
group.add_argument("--model_path", required=True, type=str,
help="Path to model####.pt file to be sampled.")
group.add_argument("--eval_mode", default='wo_mm', choices=['wo_mm', 'mm_short', 'debug'], type=str,