In this repository, we provide the official PyTorch code for CLTP-MAN.
- Tested OS: Linux (with an NVIDIA RTX 2080 GPU)
- Python == 3.7.15
- PyTorch == 1.3.1
- torchvision == 0.4.2
Install the dependencies from requirements.txt:
pip install -r requirements.txt
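To verify that the environment matches the pinned versions above, a quick sanity check (this snippet is only illustrative, not part of the repo):

```python
import torch
import torchvision

print(torch.__version__)          # expected: 1.3.1
print(torchvision.__version__)    # expected: 0.4.2
print(torch.cuda.is_available())  # True if the GPU is visible to PyTorch
```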
A pretrained model (autoencoder) can be found in pretrained_models/model_AE/xx
A pretrained model (writing controller) can be found in pretrained_models/model_controller/xx
To carry out the continual learning experiments, we divide the ETH/UCY dataset into three tasks, organized as follows (a short sketch of iterating the tasks follows the directory tree):
----datasets\ # ETH/UCY datasets
|----ETH\
| |----train\
| |----biwi_eth_train
| |----biwi_hotel_train
| |----val\
| |----biwi_eth_val
| |----biwi_hotel_val
|----STU\
| |----train\
| |----students001_train
| |----students003_train
| |----uni_examples_train
| |----val\
| |----students001_val
| |----students003_val
| |----uni_examples_val
|----ZARA\
| |----train\
| |----crowds_zara01_train
| |----crowds_zara02_train
| |----crowds_zara03_train
| |----val\
| |----crowds_zara01_val
| |----crowds_zara02_val
| |----crowds_zara03_val
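The tasks are then visited sequentially during continual learning. A minimal sketch of iterating them in the default order; the directory layout is taken from the tree above, everything else is illustrative:

```python
import os

# Task order for continual learning (matches --task_order below): ETH -> STU -> ZARA
TASK_ORDER = ['ETH', 'STU', 'ZARA']

for task in TASK_ORDER:
    train_dir = os.path.join('datasets', task, 'train')
    val_dir = os.path.join('datasets', task, 'val')
    # Each split holds one or more scene files, e.g. biwi_eth_train.
    print(task, sorted(os.listdir(train_dir)), sorted(os.listdir(val_dir)))
```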
To train CLTP-MAN, first train the autoencoder, then the writing controller, and finally the Trajectory Prediction (TP) module. The pretrained_models folder contains pretrained models of the first two components (autoencoder, controller).
python train_ae.py
The autoencoder can be trained with the train_ae.py script, which calls trainer_ae.py. During training, the model is saved into the folder traing/trainig_ae/. A pretrained model can be found in pretrained_models/model_AE/
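For intuition, a minimal sketch of such a past/future trajectory autoencoder; the layer sizes, names, and exact architecture here are assumptions, not the repo's actual model:

```python
import torch
import torch.nn as nn

class TrajectoryAE(nn.Module):
    """Encodes past and future tracklets and reconstructs the future."""
    def __init__(self, input_dim=2, hidden_dim=48, future_len=12):
        super().__init__()
        self.future_len = future_len
        self.enc_past = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.enc_future = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.decoder = nn.GRU(2 * hidden_dim, 2 * hidden_dim, batch_first=True)
        self.out = nn.Linear(2 * hidden_dim, input_dim)

    def forward(self, past, future):
        _, h_p = self.enc_past(past)            # (1, B, H) past encoding
        _, h_f = self.enc_future(future)        # (1, B, H) future encoding
        state = torch.cat([h_p, h_f], dim=-1)   # joint latent, (1, B, 2H)
        # Feed the latent at every future step and regress coordinates.
        dec_in = state.transpose(0, 1).repeat(1, self.future_len, 1)
        dec_out, _ = self.decoder(dec_in)
        return self.out(dec_out)                # (B, future_len, 2)
```

Training minimizes a reconstruction loss (e.g. MSE) between the decoded and ground-truth future; in MANTRA-style models the past encoding later serves as the memory key and the future encoding as the stored value.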
python train_controllerMem.py --model pretrained_autoencoder_model_path
The writing controller for the memory is trained together with the autoencoder using train_controllerMem.py, which calls trainer_controllerMem.py. During training, the model is saved into the folder traing/trainig_controller/. A pretrained model (autoencoder + writing controller) can be found in pretrained_models/model_controller/
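Conceptually, the controller looks at how well the current memory already explains a sample and emits a write probability, so only informative past-future pairs get stored. A hedged sketch (all names and the exact inputs are assumptions):

```python
import torch
import torch.nn as nn

class WritingController(nn.Module):
    """Decides whether a (past, future) pair should be written to memory."""
    def __init__(self, encoding_dim=48):
        super().__init__()
        self.fc = nn.Linear(encoding_dim + 1, 1)

    def forward(self, past_encoding, pred_error):
        # pred_error: (B, 1), e.g. the distance of the best memory-based
        # prediction from the ground-truth future.
        x = torch.cat([past_encoding, pred_error], dim=-1)
        return torch.sigmoid(self.fc(x))  # write probability in [0, 1]

# Usage sketch: store only the samples the controller deems informative.
# write_prob = controller(h_past, err)
# mask = write_prob.squeeze(-1) > 0.5
# memory_past = torch.cat([memory_past, h_past[mask]])
# memory_future = torch.cat([memory_future, h_fut[mask]])
```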
python train_TP.py --model pretrained_autoencoder+controller_model_path
train_TP.py calls trainer_TP.py and trains the TP module, which generates the final prediction. During training, the model is saved into the folder traing/trainig_TP/.
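At inference time the model encodes the observed past, retrieves the most similar entries from memory, and decodes one candidate future per entry, yielding best_k predictions. An illustrative sketch of the retrieval step (function and tensor names are assumptions):

```python
import torch
import torch.nn.functional as F

def predict_top_k(h_past, memory_past, memory_future, decode_fn, k=20):
    """h_past: (B, H) query encodings; memory_past/memory_future: (N, H)."""
    # Similarity between each query past and every stored past.
    sim = F.cosine_similarity(h_past.unsqueeze(1),
                              memory_past.unsqueeze(0), dim=-1)  # (B, N)
    _, idx = sim.topk(k, dim=-1)                                 # (B, k)
    futures = memory_future[idx]                                 # (B, k, H)
    # decode_fn maps (past encoding, retrieved future encoding) -> trajectory.
    return [decode_fn(h_past, futures[:, j]) for j in range(k)]
```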
--cuda Enable/Disable GPU device (default=True).
--batch_size Number of samples that will be fed to CLTP-MAN in one iteration (default=32).
--past_len Past length (default=8).
--future_len Future length (default=12).
--best_k Number of predictions generated by the CLTP-MAN model (default=20).
--model Path to the pretrained model used for evaluation (default='pretrained_models/XX/XXX').
--CL_flag Whether to conduct continual learning (default=True). CL_flag=True runs continual learning with sparse experience replay; CL_flag=False runs continual learning without sparse experience replay.
--task_order The order of tasks, e.g. ['ETH','STU','ZARA'] represents ETH->STU->ZARA.
--saved_memory If True, new memories are generated; the past-future pairs to write are decided by the model's writing controller.
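For reference, the options above could be declared as follows. Defaults are taken from this README; the actual scripts may differ, and both the boolean handling and the --saved_memory default are assumptions:

```python
import argparse

def str2bool(v):
    # argparse does not parse the string 'False' to False with type=bool.
    return str(v).lower() in ('true', '1', 'yes')

parser = argparse.ArgumentParser()
parser.add_argument('--cuda', type=str2bool, default=True)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--past_len', type=int, default=8)
parser.add_argument('--future_len', type=int, default=12)
parser.add_argument('--best_k', type=int, default=20)
parser.add_argument('--model', type=str, default='pretrained_models/XX/XXX')
parser.add_argument('--CL_flag', type=str2bool, default=True)
parser.add_argument('--task_order', nargs='+', default=['ETH', 'STU', 'ZARA'])
parser.add_argument('--saved_memory', type=str2bool, default=False)  # default assumed
args = parser.parse_args()
```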
We borrow the framework and interface from the MANTRA code. Thanks for the framework provided by Marchetz/MANTRA-CVPR20, the source code of the published work MANTRA (CVPR 2020).
If you use our code or find it useful in your research, please cite our paper:
@article{yang2022continual,
  title={Continual learning-based trajectory prediction with memory augmented networks},
  author={Yang, Biao and Fan, Fucheng and Ni, Rongrong and Li, Jie and Kiong, Loochu and Liu, Xiaofeng},
  journal={Knowledge-Based Systems},
  volume={258},
  pages={110022},
  year={2022},
  publisher={Elsevier}
}