This package contains the model implementation and training infrastructure of our AI Choreographer.
# Clone the repository together with all of its submodules.
git clone --recursive https://github.com/liruilong940607/mint
Note that `--recursive` is important: it automatically clones the `orbit` submodule as well.
# Create and activate an isolated Python 3.7 environment.
conda create -n mint python=3.7
conda activate mint
# Core Python dependencies.
conda install protobuf numpy
pip install tensorflow absl-py tensorflow-datasets librosa
# OpenEXR development headers, needed to build the OpenEXR pip package below.
sudo apt-get install libopenexr-dev
pip install --upgrade OpenEXR
pip install tensorflow-graphics tensorflow-graphics-gpu
# Install einops and the AIST++ API from source. Each `cd` runs in a
# subshell so your shell stays in the mint repository root, which the
# later repo-relative commands (protoc, tools/*.py) rely on.
git clone https://github.com/arogozhnikov/einops /tmp/einops
(cd /tmp/einops && pip install -U .)
git clone https://github.com/google/aistplusplus_api /tmp/aistplusplus_api
(cd /tmp/aistplusplus_api && pip install -r requirements.txt && pip install -U .)
Note: if you run into environment conflicts involving numpy, try `pip install numpy==1.20`.
To get the AIST++ data, see the dataset website: https://google.github.io/aistplusplus_dataset/
Download the checkpoints from Google Drive and put them in the `./checkpoints/` folder.
- compile the protocol buffers
# Generate the Python protobuf modules (*_pb2.py) alongside the .proto files.
# protoc requires an explicit output directive; without one it only reports
# "Missing output directives." and generates nothing.
protoc --python_out=. ./mint/protos/*.proto
- preprocess the dataset into TFRecord files
# Convert the annotations and music into TFRecords, once per split.
# Adjust the two paths below to where the data lives on your machine.
for split in train testval; do
  python tools/preprocessing.py \
    --anno_dir="/mnt/data/aist_plusplus_final/" \
    --audio_dir="/mnt/data/AIST/music/" \
    --split="${split}"
done
- run training
# Train with the given config, using ./checkpoints as the model directory.
python trainer.py \
  --config_path ./configs/fact_v5_deeper_t10_cm12.config \
  --model_dir ./checkpoints
Note: you may need to reduce `batch_size` in the config file if you run into out-of-memory (OOM) errors.
- run testing and evaluation
# Generate motions with the model in ./checkpoints and cache them
# (seed included) to `./outputs`.
python evaluator.py \
  --config_path ./configs/fact_v5_deeper_t10_cm12.config \
  --model_dir ./checkpoints
# Extract the features, then calculate the FID scores.
python tools/extract_aist_features.py
python tools/calculate_fid_scores.py
@inproceedings{li2021dance,
title={AI Choreographer: Music Conditioned 3D Dance Generation with AIST++},
author={Ruilong Li and Shan Yang and David A. Ross and Angjoo Kanazawa},
booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
year = {2021}
}