This repository provides the official PyTorch code for CLTP-MAN.
- Tested OS: Linux / RTX 2080
- Python == 3.7.15
- PyTorch == 1.3.1
- torchvision == 0.4.2
Install the dependencies from requirements.txt:
pip install -r requirements.txt
A pretrained model (autoencoder) can be found in pretrained_models/model_AE/xx
A pretrained model (writing controller) can be found in pretrained_models/model_controller/xx
For the continual learning experiments, we divide the ETH/UCY dataset into three tasks, as follows:
----datasets\ # ETH/UCY datasets
| |----train\
| |----biwi_eth_train
| |----biwi_hotel_train
| |----val\
| |----biwi_eth_val
| |----biwi_hotel_val
| |----train\
| |----students001_train
| |----students003_train
| |----uni_examples_train
| |----val\
| |----students001_val
| |----students003_val
| |----uni_examples_val
| |----train\
| |----crowds_zara01_train
| |----crowds_zara02_train
| |----crowds_zara03_train
| |----val\
| |----crowds_zara01_val
| |----crowds_zara02_val
| |----crowds_zara03_val
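The split above can be sketched in Python as a mapping from task to scene files. This is a minimal illustration, not the repository's data loader: the names TASK_SCENES, files_for_task and iterate_tasks are hypothetical, and the task labels 'ETH', 'STU' and 'ZARA' are assumed to correspond to the three groups in the order they are listed.

```python
# Hypothetical mapping from task name to the ETH/UCY scenes listed above.
# The task labels are an assumption based on the --task_order example.
TASK_SCENES = {
    "ETH": ["biwi_eth", "biwi_hotel"],
    "STU": ["students001", "students003", "uni_examples"],
    "ZARA": ["crowds_zara01", "crowds_zara02", "crowds_zara03"],
}


def files_for_task(task, split):
    """Return the file names of one task for the split 'train' or 'val'."""
    if split not in ("train", "val"):
        raise ValueError("split must be 'train' or 'val'")
    return ["{}_{}".format(scene, split) for scene in TASK_SCENES[task]]


def iterate_tasks(task_order, split="train"):
    """Yield (task, files) pairs in the order the tasks are learned."""
    for task in task_order:
        yield task, files_for_task(task, split)


if __name__ == "__main__":
    # Walk the tasks sequentially, as a continual learning run would.
    for task, files in iterate_tasks(["ETH", "STU", "ZARA"]):
        print(task, files)
```

Changing the list passed to iterate_tasks changes the order in which the tasks are visited.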
Training CLTP-MAN takes three steps: first train the autoencoder, then the writing controller, and finally the Trajectory Prediction Module (TP). The pretrained_models folder contains pretrained models of the first two components (autoencoder, controller).
The autoencoder is trained with its own training script. During training, the model is saved into the folder traing/trainig_ae/. A pretrained model can be found in pretrained_models/model_AE/
python --model pretrained_autoencoder_model_path
The writing controller for the memory is trained together with the pretrained autoencoder. During training, the model is saved into the folder traing/trainig_controller/. A pretrained model (autoencoder + writing controller) can be found in pretrained_models/model_controller/
python --model pretrained_autoencoder+controller_model_path
This script trains the TP module that generates the final prediction. During training, the model is saved into the folder traing/trainig_TP/.
--cuda Enable/Disable GPU device (default=True).
--batch_size Number of samples that will be fed to CLTP-MAN in one iteration (default=32).
--past_len Past length (default=8).
--future_len Future length (default=12).
--best_k Number of predictions generated by the CLTP-MAN model (default=20).
--model Path of the pretrained model used for evaluation (default='pretrained_models/XX/XXX').
--CL_flag Whether to use sparse experience replay during continual learning (default=True). With CL_flag=True, the model learns the tasks continually with sparse experience replay; with CL_flag=False, it learns them continually without sparse experience replay.
--task_order The order of the tasks, e.g. ['ETH','STU','ZARA'] means ETH->STU->ZARA.
--saved_memory If True, new memories will be generated; the past-future pairs to store are decided by the model's writing controller.
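As an illustration of how the options above could be declared, here is a minimal argparse sketch. It is not the repository's actual parser: build_parser and str2bool are hypothetical helpers, and the defaults for --task_order and --saved_memory are assumptions, since the list above does not state them.

```python
import argparse


def str2bool(value):
    """Parse 'True'/'False'-style strings.

    argparse's type=bool would treat any non-empty string as True,
    so boolean options need an explicit converter.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


def build_parser():
    """Build a parser mirroring the documented options (illustrative only)."""
    parser = argparse.ArgumentParser(description="CLTP-MAN options (sketch)")
    parser.add_argument("--cuda", type=str2bool, default=True)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--past_len", type=int, default=8)
    parser.add_argument("--future_len", type=int, default=12)
    parser.add_argument("--best_k", type=int, default=20)
    parser.add_argument("--model", default="pretrained_models/XX/XXX")
    parser.add_argument("--CL_flag", type=str2bool, default=True)
    # Assumed default: the example order given in the option description.
    parser.add_argument("--task_order", nargs="+",
                        default=["ETH", "STU", "ZARA"])
    # Assumed default: not stated in the option description.
    parser.add_argument("--saved_memory", type=str2bool, default=True)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(vars(args))
```

With such a parser, passing `--CL_flag False --task_order ZARA ETH STU` would disable sparse experience replay and change the task order.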
We borrow the framework and interface from the MANTRA code (Marchetz/MANTRA-CVPR20), the source code of the work MANTRA published at CVPR 2020. We thank its authors for providing the framework.
If you use our code or find it useful in your research, please cite our paper:
title={Continual learning-based trajectory prediction with memory augmented networks},
author={Yang, Biao and Fan, Fucheng and Ni, Rongrong and Li, Jie and Kiong, Loochu and Liu, Xiaofeng},
journal={Knowledge-Based Systems},