
Continual learning-based trajectory prediction with memory augmented networks


fanf21/CLTP-MAN



In this repository, we provide the official PyTorch code for CLTP-MAN.

Installation

Environment

  • Tested on: Linux, NVIDIA RTX 2080 GPU
  • Python == 3.7.15
  • PyTorch == 1.3.1
  • torchvision == 0.4.2
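Before installing, it can help to compare your interpreter and PyTorch versions against the tested environment listed above. The following is an optional, hypothetical helper (not part of the repository); it checks only package versions, not the OS or GPU.

```python
import importlib
import platform

# Versions listed under "Environment" above.
TESTED = {"python": "3.7.15", "torch": "1.3.1", "torchvision": "0.4.2"}


def installed_versions():
    """Return the installed Python, PyTorch and torchvision versions (None if unavailable)."""
    found = {"python": platform.python_version()}
    for name in ("torch", "torchvision"):
        try:
            module = importlib.import_module(name)
        except Exception:
            # Any import failure (missing or broken install) counts as unavailable.
            found[name] = None
            continue
        # Drop local build tags such as "+cu101" before comparing.
        found[name] = module.__version__.split("+")[0]
    return found


def mismatches(found, tested=TESTED):
    """Map each package whose version differs from the tested one to (installed, tested)."""
    return {
        name: (found.get(name), version)
        for name, version in tested.items()
        if found.get(name) != version
    }


if __name__ == "__main__":
    diff = mismatches(installed_versions())
    for name, (have, want) in diff.items():
        print(f"{name}: installed {have}, tested with {want}")
    if not diff:
        print("Environment matches the tested configuration.")
```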

Dependencies

Install the dependencies listed in requirements.txt:

pip install -r requirements.txt

Pretrained Models

A pretrained model (autoencoder) can be found in pretrained_models/model_AE/xx

A pretrained model (autoencoder + writing controller) can be found in pretrained_models/model_controller/xx

Dataset

To carry out the continual learning experiments, we divide the ETH/UCY datasets into three tasks (ETH, STU and ZARA), organized as follows:

----datasets\                      # ETH/UCY datasets
    |----ETH\
    |    |----train\
    |    |    |----biwi_eth_train
    |    |    |----biwi_hotel_train
    |    |----val\
    |         |----biwi_eth_val
    |         |----biwi_hotel_val
    |----STU\
    |    |----train\
    |    |    |----students001_train
    |    |    |----students003_train
    |    |    |----uni_examples_train
    |    |----val\
    |         |----students001_val
    |         |----students003_val
    |         |----uni_examples_val
    |----ZARA\
         |----train\
         |    |----crowds_zara01_train
         |    |----crowds_zara02_train
         |    |----crowds_zara03_train
         |----val\
              |----crowds_zara01_val
              |----crowds_zara02_val
              |----crowds_zara03_val
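The task split above can be expressed as a mapping from task name to scenes. The helper names task_paths and task_sequence are hypothetical, and the paths are built exactly from the tree (without any file extension the real files may have), so they may need adjusting to the actual files and to how the repository's loaders read them.

```python
import os

# Task split taken from the directory tree: each continual-learning task
# groups several ETH/UCY scenes.
TASKS = {
    "ETH": ["biwi_eth", "biwi_hotel"],
    "STU": ["students001", "students003", "uni_examples"],
    "ZARA": ["crowds_zara01", "crowds_zara02", "crowds_zara03"],
}


def task_paths(task, split, root="datasets"):
    """Return the scene paths of one task for split 'train' or 'val'."""
    if split not in ("train", "val"):
        raise ValueError("split must be 'train' or 'val'")
    return [os.path.join(root, task, split, f"{scene}_{split}") for scene in TASKS[task]]


def task_sequence(task_order):
    """Yield (task, train_paths, val_paths) in the order the tasks are learned."""
    for task in task_order:
        yield task, task_paths(task, "train"), task_paths(task, "val")
```

For example, task_sequence(["ETH", "STU", "ZARA"]) walks the tasks in the order ETH->STU->ZARA.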

Training

Training CLTP-MAN has three stages: first the autoencoder, then the writing controller, and finally the Trajectory Prediction Module (TP). The pretrained_models folder contains pretrained models of the first two components (autoencoder and writing controller).
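The stage order can be sketched as a small driver script. It only builds and runs the commands documented in the following subsections; stage_commands and run_pipeline are hypothetical names, and the checkpoint paths are placeholders you supply (for instance the pretrained models). Because the exact checkpoint file each stage writes is not specified here, running the stages one at a time may be more practical when starting from scratch.

```python
import subprocess
import sys


def stage_commands(ae_model_path, controller_model_path):
    """Build the three training commands in the required order.

    ae_model_path: autoencoder checkpoint (stage 1 output or the pretrained one).
    controller_model_path: autoencoder + controller checkpoint
    (stage 2 output or the pretrained one).
    """
    python = sys.executable  # the interpreter running this script
    return [
        [python, "train_ae.py"],
        [python, "train_controllerMem.py", "--model", ae_model_path],
        [python, "train_TP.py", "--model", controller_model_path],
    ]


def run_pipeline(ae_model_path, controller_model_path, first_stage=1):
    """Run stages first_stage..3; start at 2 or 3 to reuse a pretrained model.

    check=True raises CalledProcessError as soon as a stage fails, because
    every later stage needs the checkpoint produced by the earlier one.
    """
    for cmd in stage_commands(ae_model_path, controller_model_path)[first_stage - 1:]:
        subprocess.run(cmd, check=True)
```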

Training encoder-decoder model (autoencoder)

python train_ae.py

The autoencoder is trained with the train_ae.py script, which calls trainer_ae.py. During training, the model is saved to the folder traing/trainig_ae/. A pretrained model can be found in pretrained_models/model_AE/.
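The trajectories used for training are split into a past part of --past_len steps (default 8) and a future part of --future_len steps (default 12), as listed under Command line arguments. A minimal sketch of cutting such pairs from one trajectory with a sliding window follows; the stride of one time step and the name split_windows are assumptions, not necessarily what the repository's data loader does.

```python
def split_windows(trajectory, past_len=8, future_len=12):
    """Cut one trajectory into overlapping (past, future) pairs.

    trajectory: list of (x, y) positions, one per time step.
    Each pair spans past_len + future_len consecutive steps; the window
    moves forward one step at a time (an assumed stride).
    """
    total = past_len + future_len
    pairs = []
    # A trajectory shorter than `total` yields no pairs (empty range).
    for start in range(len(trajectory) - total + 1):
        window = trajectory[start:start + total]
        pairs.append((window[:past_len], window[past_len:]))
    return pairs
```

A 21-step trajectory, for example, yields two pairs with the defaults.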

Training writing controller

python train_controllerMem.py --model pretrained_autoencoder_model_path

The writing controller for the memory is trained on top of the pretrained autoencoder with train_controllerMem.py, which calls trainer_controllerMem.py. During training, the model is saved to the folder traing/trainig_controller/. A pretrained model (autoencoder + writing controller) can be found in pretrained_models/model_controller/.
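The writing controller decides which past-future pairs are written into memory (see --saved_memory). The following is a toy illustration under stated assumptions: the memory is a plain list of feature pairs, and the controller is replaced by a logistic function of the prediction error with placeholder parameters, whereas in the actual model it is a trained network. EpisodicMemory, write_probability and maybe_write are hypothetical names. With the default parameters, a pair is written when its error exceeds 1.0.

```python
import math


class EpisodicMemory:
    """Toy key-value memory of (past_feature, future_feature) pairs."""

    def __init__(self):
        self.past_keys = []
        self.future_values = []

    def __len__(self):
        return len(self.past_keys)

    def write(self, past_feat, future_feat):
        self.past_keys.append(past_feat)
        self.future_values.append(future_feat)


def write_probability(error, scale=1.0, bias=-1.0):
    """Stand-in controller: logistic function of the prediction error.

    A large error suggests the memory does not yet cover this sample, so it
    gets a higher write probability. scale and bias are placeholders for
    the parameters a trained controller would learn.
    """
    return 1.0 / (1.0 + math.exp(-(scale * error + bias)))


def maybe_write(memory, past_feat, future_feat, error, threshold=0.5):
    """Write the pair if the write probability exceeds threshold; report whether it did."""
    if write_probability(error) > threshold:
        memory.write(past_feat, future_feat)
        return True
    return False
```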

Training Trajectory Prediction Module (TP)

python train_TP.py --model pretrained_autoencoder+controller_model_path

train_TP.py calls trainer_TP.py and trains the TP module, which generates the final prediction. During training, the model is saved to the folder traing/trainig_TP/.
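The --best_k option (default 20) sets how many predictions the model generates. One way a memory-augmented predictor can produce several predictions is to read the best_k memory entries whose past keys are most similar to the encoded past and decode one future per entry. The sketch below covers only that retrieval step, with cosine similarity as an assumed similarity measure; read_top_k is a hypothetical name and the decoder is omitted.

```python
import math


def cosine(a, b):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def read_top_k(query, past_keys, future_values, best_k=20):
    """Return the future entries of the best_k keys most similar to query, best first."""
    ranked = sorted(
        range(len(past_keys)),
        key=lambda i: cosine(query, past_keys[i]),
        reverse=True,
    )
    return [future_values[i] for i in ranked[:best_k]]
```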

Command line arguments

    --cuda            Enable/disable the GPU device (default=True).
    --batch_size      Number of samples fed to CLTP-MAN in one iteration (default=32).
    --past_len        Number of past (observed) time steps (default=8).
    --future_len      Number of future (predicted) time steps (default=12).
    --best_k          Number of predictions generated by the CLTP-MAN model (default=20).
    --model           Path of the pretrained model used for evaluation (default='pretrained_models/XX/XXX').
    --CL_flag         Whether to use sparse experience replay during continual learning (default=True).
                      CL_flag=True trains on the task sequence with sparse experience replay;
                      CL_flag=False trains on the same sequence without it.
    --task_order      Order of the tasks, e.g. ['ETH','STU','ZARA'] represents ETH->STU->ZARA.
    --saved_memory    If True, a new memory is generated; the writing controller of the model decides which past-future pairs are written into it.
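For reference, the options above could be declared with argparse as sketched below. This is an illustration, not the repository's parser: the str2bool helper, the space-separated syntax for --task_order, and the defaults for --task_order and --saved_memory (which the list above does not state) are assumptions.

```python
import argparse


def str2bool(value):
    """Parse 'True'/'False'-style strings (argparse's type=bool treats 'False' as True)."""
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got {!r}".format(value))


def build_parser():
    """Hypothetical parser mirroring the documented CLTP-MAN options."""
    parser = argparse.ArgumentParser(description="CLTP-MAN options (illustrative sketch)")
    parser.add_argument("--cuda", type=str2bool, default=True)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--past_len", type=int, default=8)
    parser.add_argument("--future_len", type=int, default=12)
    parser.add_argument("--best_k", type=int, default=20)
    parser.add_argument("--model", default="pretrained_models/XX/XXX")
    parser.add_argument("--CL_flag", type=str2bool, default=True)
    # Assumed default: the README only gives ['ETH','STU','ZARA'] as an example.
    parser.add_argument("--task_order", nargs="+", default=["ETH", "STU", "ZARA"])
    # Assumed default: the README does not state one for --saved_memory.
    parser.add_argument("--saved_memory", type=str2bool, default=False)
    return parser
```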
    
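--CL_flag switches sparse experience replay on or off. The sketch below shows the general idea of sparse experience replay, under assumptions that may not match CLTP-MAN: samples from finished tasks are kept in a buffer, and only every replay_every steps a small random batch of them is added to the current update. SparseReplay, train_task_sequence and their parameters are hypothetical; the real code may replay memories or features rather than raw samples, and at different rates.

```python
import random


class SparseReplay:
    """Buffer of samples from earlier tasks, replayed only every replay_every steps.

    replay_every and replay_size are illustrative values, not CLTP-MAN settings.
    """

    def __init__(self, replay_every=100, replay_size=32, seed=0):
        self.buffer = []
        self.replay_every = replay_every
        self.replay_size = replay_size
        self.rng = random.Random(seed)  # seeded so replay batches are reproducible

    def store(self, samples):
        """Keep the samples of a finished task for later replay."""
        self.buffer.extend(samples)

    def replay_batch(self, step):
        """Return old samples to add at this step, or [] on non-replay steps."""
        if not self.buffer or step % self.replay_every != 0:
            return []
        return self.rng.sample(self.buffer, min(self.replay_size, len(self.buffer)))


def train_task_sequence(task_order, tasks, replay, train_step, cl_flag=True):
    """Hypothetical continual-learning loop.

    task_order: e.g. ["ETH", "STU", "ZARA"]; tasks: task name -> list of batches,
    each batch a list of samples; train_step: callback standing in for one
    optimization step on the (possibly augmented) batch.
    """
    step = 0
    for name in task_order:
        for batch in tasks[name]:
            step += 1
            extra = replay.replay_batch(step) if cl_flag else []
            train_step(list(batch) + extra)
        if cl_flag:
            # Only after a task is finished do its samples become replayable.
            replay.store([sample for batch in tasks[name] for sample in batch])
```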

Acknowledgement

We borrow the framework and interface from the MANTRA code. We thank Marchetz/MANTRA-CVPR20 for providing this framework, which is the source code of MANTRA, published at CVPR 2020.

Citation

If you use our code or find it useful in your research, please cite our paper:

@article{yang2022continual,
  title={Continual learning-based trajectory prediction with memory augmented networks},
  author={Yang, Biao and Fan, Fucheng and Ni, Rongrong and Li, Jie and Kiong, Loochu and Liu, Xiaofeng},
  journal={Knowledge-Based Systems},
  volume={258},
  pages={110022},
  year={2022},
  publisher={Elsevier}
}
