diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..e3507e9 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,21 @@ +name: test +on: + push: + branches: + - '**' + pull_request: + branches: + - '**' + +jobs: + test: + runs-on: ubuntu-latest + container: python:3.8-slim + steps: + - uses: actions/checkout@v2 + - name: Install prerequisites (for OpenCV) + run: apt-get update && apt-get install ffmpeg libsm6 libxext6 -y + - name: Install trajdata base version + run: python -m pip install . + - name: Run tests + run: python -m unittest tests/test_state.py diff --git a/.github/workflows/pypi_publish.yml b/.github/workflows/pypi_publish.yml new file mode 100644 index 0000000..7e0316f --- /dev/null +++ b/.github/workflows/pypi_publish.yml @@ -0,0 +1,40 @@ +# This workflow will upload a Python Package using Twine when a release is created +# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries + +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +name: Upload Python Package + +on: + release: + types: [published] + workflow_dispatch: + +permissions: + contents: read + +jobs: + deploy: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: '3.x' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install build + - name: Build package + run: python -m build + - name: Publish package + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.gitignore b/.gitignore index cda38db..1e55dd9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,8 @@ public/ .vscode/ +*.html +*.mp4 +*.avi # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/CITATION.cff b/CITATION.cff index dbf28a0..e793eb6 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -5,7 +5,23 @@ authors: given-names: "Boris" orcid: "https://orcid.org/0000-0002-8698-202X" title: "trajdata: A unified interface to many trajectory forecasting datasets" -version: 1.0.3 +version: 1.3.3 doi: 10.5281/zenodo.6671548 -date-released: 2022-06-20 -url: "https://github.com/nvr-avg/trajdata" \ No newline at end of file +date-released: 2023-08-22 +url: "https://github.com/nvr-avg/trajdata" +preferred-citation: + type: conference-paper + authors: + - family-names: "Ivanovic" + given-names: "Boris" + orcid: "https://orcid.org/0000-0002-8698-202X" + - family-names: "Song" + given-names: "Guanyu" + - family-names: "Gilitschenski" + given-names: "Igor" + - family-names: "Pavone" + given-names: "Marco" + journal: "Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks" + month: 12 + title: "trajdata: A Unified Interface to Multiple Human Trajectory Datasets" + year: 2023 \ No newline at end of file diff --git a/DATASETS.md b/DATASETS.md index e1169e3..528bae2 100644 --- a/DATASETS.md +++ b/DATASETS.md @@ -1,7 +1,18 @@ # Supported Datasets and Required Formats +## View-of-Delft +Nothing special needs to be done for the View-of-Delft Prediction dataset, simply download it as per [the instructions in the devkit README](https://github.com/tudelft-iv/view-of-delft-prediction-devkit?tab=readme-ov-file#vod-p-setup). 
+ +It should look like this after downloading: +``` +/path/to/VoD/ + ├── maps/ + ├── v1.0-test/ + └── v1.0-trainval/ +``` + ## nuScenes -Nothing special needs to be done for the nuScenes dataset, simply install it as per [the instructions in the devkit README](https://github.com/nutonomy/nuscenes-devkit#nuscenes-setup). +Nothing special needs to be done for the nuScenes dataset, simply download it as per [the instructions in the devkit README](https://github.com/nutonomy/nuscenes-devkit#nuscenes-setup). It should look like this after downloading: ``` @@ -16,8 +27,68 @@ It should look like this after downloading: **Note**: At a minimum, only the annotations need to be downloaded (not the raw radar/camera/lidar/etc data). +## nuPlan +Nothing special needs to be done for the nuPlan dataset, simply download v1.1 as per [the instructions in the devkit documentation](https://nuplan-devkit.readthedocs.io/en/latest/dataset_setup.html). + +It should look like this after downloading: +``` +/path/to/nuPlan/ + └── dataset + ├── maps + │ ├── nuplan-maps-v1.0.json + │ ├── sg-one-north + │ │ └── 9.17.1964 + │ │ └── map.gpkg + │ ├── us-ma-boston + │ │ └── 9.12.1817 + │ │ └── map.gpkg + │ ├── us-nv-las-vegas-strip + │ │ └── 9.15.1915 + │ │ ├── drivable_area.npy.npz + │ │ ├── Intensity.npy.npz + │ │ └── map.gpkg + │ └── us-pa-pittsburgh-hazelwood + │ └── 9.17.1937 + │ └── map.gpkg + └── nuplan-v1.1 + ├── mini + │ ├── 2021.05.12.22.00.38_veh-35_01008_01518.db + │ ├── 2021.06.09.17.23.18_veh-38_00773_01140.db + │ ├── ... + │ └── 2021.10.11.08.31.07_veh-50_01750_01948.db + └── trainval + ├── 2021.05.12.22.00.38_veh-35_01008_01518.db + ├── 2021.06.09.17.23.18_veh-38_00773_01140.db + ├── ... + └── 2021.10.11.08.31.07_veh-50_01750_01948.db +``` + +**Note**: Not all dataset splits need to be downloaded. For example, you can download only the nuPlan Mini Split in case you only need a small sample dataset. + +## Waymo Open Motion Dataset +Nothing special needs to be done for the Waymo Open Motion Dataset, simply download v1.1 as per [the instructions on the dataset website](https://waymo.com/intl/en_us/open/download/). + +It should look like this after downloading: +``` +/path/to/waymo/ + ├── training/ + | ├── training.tfrecord-00000-of-01000 + | ├── training.tfrecord-00001-of-01000 + | └── ... + ├── validation/ + │   ├── validation.tfrecord-00000-of-00150 + | ├── validation.tfrecord-00001-of-00150 + | └── ... + └── testing/ + ├── testing.tfrecord-00000-of-00150 + ├── testing.tfrecord-00001-of-00150 + └── ... +``` + +**Note**: Not all the dataset parts need to be downloaded, only the necessary directories in [the Google Cloud Bucket](https://console.cloud.google.com/storage/browser/waymo_open_dataset_motion_v_1_1_0/uncompressed/scenario) need to be downloaded (e.g., `validation` for the validation dataset). + ## Lyft Level 5 -Nothing special needs to be done for the Lyft Level 5 dataset, simply install it as per [the instructions on the dataset website](https://woven-planet.github.io/l5kit/dataset.html). +Nothing special needs to be done for the Lyft Level 5 dataset, simply download it as per [the instructions on the dataset website](https://woven-planet.github.io/l5kit/dataset.html). It should look like this after downloading: ``` @@ -36,6 +107,56 @@ It should look like this after downloading: **Note**: Not all the dataset parts need to be downloaded, only the necessary `.zarr` files need to be downloaded (e.g., `sample.zarr` for the small sample dataset). 
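+
+Once downloaded, a dataset is hooked into trajdata by passing its location to the `data_dirs` argument of `UnifiedDataset`. As a minimal sketch for the sample split above (the path is an example, adjust it to match your filesystem):
+```py
+from trajdata import UnifiedDataset
+
+dataset = UnifiedDataset(
+    desired_data=["lyft_sample"],
+    data_dirs={"lyft_sample": "~/datasets/lyft/scenes/sample.zarr"},
+)
+```
+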
+## INTERACTION Dataset
+Nothing special needs to be done for the INTERACTION Dataset, simply download it as per [the instructions on the dataset website](http://interaction-dataset.com/).
+
+It should look like this after downloading:
+```
+/path/to/interaction_single/
+    ├── maps/
+    │   ├── DR_CHN_Merging_ZS0.osm
+    │   ├── DR_CHN_Merging_ZS0.osm_xy
+    │   └── ...
+    ├── test_conditional-single-agent/
+    │   ├── DR_CHN_Merging_ZS0_obs.csv
+    │   ├── DR_CHN_Merging_ZS2_obs.csv
+    │   └── ...
+    ├── test_single-agent/
+    │   ├── DR_CHN_Merging_ZS0_obs.csv
+    │   ├── DR_CHN_Merging_ZS2_obs.csv
+    │   └── ...
+    ├── train/
+    │   ├── DR_CHN_Merging_ZS0_train.csv
+    │   ├── DR_CHN_Merging_ZS2_train.csv
+    │   └── ...
+    └── val/
+        ├── DR_CHN_Merging_ZS0_val.csv
+        ├── DR_CHN_Merging_ZS2_val.csv
+        └── ...
+
+/path/to/interaction_multi/
+    ├── maps/
+    │   ├── DR_CHN_Merging_ZS0.osm
+    │   ├── DR_CHN_Merging_ZS0.osm_xy
+    │   └── ...
+    ├── test_conditional-multi-agent/
+    │   ├── DR_CHN_Merging_ZS0_obs.csv
+    │   ├── DR_CHN_Merging_ZS2_obs.csv
+    │   └── ...
+    ├── test_multi-agent/
+    │   ├── DR_CHN_Merging_ZS0_obs.csv
+    │   ├── DR_CHN_Merging_ZS2_obs.csv
+    │   └── ...
+    ├── train/
+    │   ├── DR_CHN_Merging_ZS0_train.csv
+    │   ├── DR_CHN_Merging_ZS2_train.csv
+    │   └── ...
+    └── val/
+        ├── DR_CHN_Merging_ZS0_val.csv
+        ├── DR_CHN_Merging_ZS2_val.csv
+        └── ...
+```
+
 ## ETH/UCY Pedestrians
 The raw data can be found in many places online, ranging from [research projects' data download scripts](https://github.com/agrimgupta92/sgan/blob/master/scripts/download_data.sh) to [copies of the original data itself](https://github.com/StanfordASL/Trajectron-plus-plus/tree/master/experiments/pedestrians/raw/raw/all_data) on GitHub. In this data loader, we assume the data was sourced from the latter.
 
@@ -51,3 +172,57 @@ It should look like this after downloading:
     ├── students003.txt
     └── uni_examples.txt
 ```
+
+## Stanford Drone Dataset
+The raw data can be found in many places online; the easiest is probably [this space-optimized version](https://www.kaggle.com/datasets/aryashah2k/stanford-drone-dataset) on Kaggle.
+
+It should look like this after downloading:
+```
+/path/to/sdd/
+    ├── bookstore/
+    |   ├── video0/
+    |   |   ├── annotations.txt
+    |   |   └── reference.jpg
+    |   ├── video1/
+    |   |   ├── annotations.txt
+    |   |   └── reference.jpg
+    |   └── ...
+    ├── coupa/
+    |   ├── video0/
+    |   |   ├── annotations.txt
+    |   |   └── reference.jpg
+    |   ├── video1/
+    |   |   ├── annotations.txt
+    |   |   └── reference.jpg
+    |   └── ...
+    └── ...
+```
+
+**Note**: Only the annotations need to be downloaded (not the videos).
+
+
+## Argoverse 2 Motion Forecasting
+The dataset can be downloaded from [here](https://www.argoverse.org/av2.html#download-link).
+
+It should look like this after downloading:
+```
+/path/to/av2mf/
+    ├── train/
+    |   ├── 0000b0f9-99f9-4a1f-a231-5be9e4c523f7/
+    |   |   ├── log_map_archive_0000b0f9-99f9-4a1f-a231-5be9e4c523f7.json
+    |   |   └── scenario_0000b0f9-99f9-4a1f-a231-5be9e4c523f7.parquet
+    |   ├── 0000b6ab-e100-4f6b-aee8-b520b57c0530/
+    |   |   ├── log_map_archive_0000b6ab-e100-4f6b-aee8-b520b57c0530.json
+    |   |   └── scenario_0000b6ab-e100-4f6b-aee8-b520b57c0530.parquet
+    |   └── ...
+    ├── val/
+    |   ├── 00010486-9a07-48ae-b493-cf4545855937/
+    |   |   ├── log_map_archive_00010486-9a07-48ae-b493-cf4545855937.json
+    |   |   └── scenario_00010486-9a07-48ae-b493-cf4545855937.parquet
+    |   └── ...
+    └── test/
+        ├── 0000b329-f890-4c2b-93f2-7e2413d4ca5b/
+        |   ├── log_map_archive_0000b329-f890-4c2b-93f2-7e2413d4ca5b.json
+        |   └── scenario_0000b329-f890-4c2b-93f2-7e2413d4ca5b.parquet
+        └── ...
+``` \ No newline at end of file diff --git a/README.md b/README.md index 50d615c..db5678d 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Unified Trajectory Data Loader +# trajdata: A Unified Interface to Multiple Human Trajectory Datasets [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) @@ -6,6 +6,10 @@ [![DOI](https://zenodo.org/badge/488789438.svg)](https://zenodo.org/badge/latestdoi/488789438) [![PyPI version](https://badge.fury.io/py/trajdata.svg)](https://badge.fury.io/py/trajdata) +### Announcements + +**Sept 2023**: [Our paper about trajdata](https://arxiv.org/abs/2307.13924) has been accepted to the NeurIPS 2023 Datasets and Benchmarks Track! + ## Installation The easiest way to install trajdata is through PyPI with @@ -13,7 +17,7 @@ The easiest way to install trajdata is through PyPI with pip install trajdata ``` -In case you would also like to use datasets such as nuScenes and Lyft Level 5 (which require their own devkits to access raw data), the following will also install the respective devkits. +In case you would also like to use datasets such as nuScenes, Lyft Level 5, View-of-Delft, or Waymo Open Motion Dataset (which require their own devkits to access raw data or additional package dependencies), the following will also install the respective devkits and/or package dependencies. ```sh # For nuScenes pip install "trajdata[nusc]" @@ -21,10 +25,19 @@ pip install "trajdata[nusc]" # For Lyft pip install "trajdata[lyft]" -# Both -pip install "trajdata[nusc,lyft]" +# For Waymo +pip install "trajdata[waymo]" + +# For INTERACTION +pip install "trajdata[interaction]" + +# For View-of-Delft +pip install "trajdata[vod]" + +# All +pip install "trajdata[nusc,lyft,waymo,interaction,vod]" ``` -Then, download the raw datasets (nuScenes, Lyft Level 5, ETH/UCY, etc) in case you do not already have them. For more information about how to structure dataset folders/files, please see [`DATASETS.md`](./DATASETS.md). +Then, download the raw datasets (nuScenes, Lyft Level 5, View-of-Delft, ETH/UCY, etc.) in case you do not already have them. For more information about how to structure dataset folders/files, please see [`DATASETS.md`](./DATASETS.md). 
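+
+As a quick sanity check of an installation, the following minimal sketch (assuming, for example, that the nuScenes mini split has been downloaded to `~/datasets/nuScenes`) builds a dataset and prints its size:
+```py
+from trajdata import UnifiedDataset
+
+dataset = UnifiedDataset(
+    desired_data=["nusc_mini"],
+    data_dirs={"nusc_mini": "~/datasets/nuScenes"},
+)
+print(f"# Data Samples: {len(dataset):,}")
+```
+Note that the first run can take a while, since trajdata preprocesses and caches the raw data.
+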
### Package Developer Installation @@ -81,23 +94,49 @@ For more information on all of the possible `UnifiedDataset` constructor argumen ## Supported Datasets Currently, the dataloader supports interfacing with the following datasets: -| Dataset | ID | Splits | Add'l Tags | Description | dt | Maps | +| Dataset | ID | Splits | Locations | Description | dt | Maps | |---------|----|--------|------------|-------------|----|------| | nuScenes Train/TrainVal/Val | `nusc_trainval` | `train`, `train_val`, `val` | `boston`, `singapore` | nuScenes prediction challenge training/validation/test splits (500/200/150 scenes) | 0.5s (2Hz) | :white_check_mark: | -| nuScenes Test | `nusc_test` | `test` | `boston`, `singapore` | nuScenes' test split, no annotations (150 scenes) | 0.5s (2Hz) | :white_check_mark: | +| nuScenes Test | `nusc_test` | `test` | `boston`, `singapore` | nuScenes test split, no annotations (150 scenes) | 0.5s (2Hz) | :white_check_mark: | | nuScenes Mini | `nusc_mini` | `mini_train`, `mini_val` | `boston`, `singapore` | nuScenes mini training/validation splits (8/2 scenes) | 0.5s (2Hz) | :white_check_mark: | +| nuPlan Train | `nuplan_train` | N/A | `boston`, `singapore`, `pittsburgh`, `las_vegas` | nuPlan training split (947.42 GB) | 0.05s (20Hz) | :white_check_mark: | +| nuPlan Validation | `nuplan_val` | N/A | `boston`, `singapore`, `pittsburgh`, `las_vegas` | nuPlan validation split (90.30 GB) | 0.05s (20Hz) | :white_check_mark: | +| nuPlan Test | `nuplan_test` | N/A | `boston`, `singapore`, `pittsburgh`, `las_vegas` | nuPlan testing split (89.33 GB) | 0.05s (20Hz) | :white_check_mark: | +| nuPlan Mini | `nuplan_mini` | `mini_train`, `mini_val`, `mini_test` | `boston`, `singapore`, `pittsburgh`, `las_vegas` | nuPlan mini training/validation/test splits (942/197/224 scenes, 7.96 GB) | 0.05s (20Hz) | :white_check_mark: | +| View-of-Delft Train/TrainVal/Val | `vod_trainval` | `train`, `train_val`, `val` | `delft` | View-of-Delft Prediction training and validation splits | 0.1s (10Hz) | :white_check_mark: | +| View-of-Delft Test | `vod_test` | `test` | `delft` | View-of-Delft Prediction test split | 0.1s (10Hz) | :white_check_mark: | +| Waymo Open Motion Training | `waymo_train` | `train` | N/A | Waymo Open Motion Dataset `training` split | 0.1s (10Hz) | :white_check_mark: | +| Waymo Open Motion Validation | `waymo_val` | `val` | N/A | Waymo Open Motion Dataset `validation` split | 0.1s (10Hz) | :white_check_mark: | +| Waymo Open Motion Testing | `waymo_test` | `test` | N/A | Waymo Open Motion Dataset `testing` split | 0.1s (10Hz) | :white_check_mark: | | Lyft Level 5 Train | `lyft_train` | `train` | `palo_alto` | Lyft Level 5 training data - part 1/2 (8.4 GB) | 0.1s (10Hz) | :white_check_mark: | | Lyft Level 5 Train Full | `lyft_train_full` | `train` | `palo_alto` | Lyft Level 5 training data - part 2/2 (70 GB) | 0.1s (10Hz) | :white_check_mark: | | Lyft Level 5 Validation | `lyft_val` | `val` | `palo_alto` | Lyft Level 5 validation data (8.2 GB) | 0.1s (10Hz) | :white_check_mark: | | Lyft Level 5 Sample | `lyft_sample` | `mini_train`, `mini_val` | `palo_alto` | Lyft Level 5 sample data (100 scenes, randomly split 80/20 for training/validation) | 0.1s (10Hz) | :white_check_mark: | +| Argoverse 2 Motion Forecasting | `av2_motion_forecasting` | `train`, `val`, `test` | N/A | 250,000 motion forecasting scenarios of 11s each | 0.1s (10Hz) | :white_check_mark: | +| INTERACTION Dataset Single-Agent | `interaction_single` | `train`, `val`, `test`, `test_conditional` | `usa`, `china`, 
`germany`, `bulgaria` | Single-agent split of the INTERACTION Dataset (where the goal is to predict one target agent's future motion) | 0.1s (10Hz) | :white_check_mark: |
+| INTERACTION Dataset Multi-Agent | `interaction_multi` | `train`, `val`, `test`, `test_conditional` | `usa`, `china`, `germany`, `bulgaria` | Multi-agent split of the INTERACTION Dataset (where the goal is to jointly predict multiple agents' future motion) | 0.1s (10Hz) | :white_check_mark: |
 | ETH - Univ | `eupeds_eth` | `train`, `val`, `train_loo`, `val_loo`, `test_loo` | `zurich` | The ETH (University) scene from the ETH BIWI Walking Pedestrians dataset | 0.4s (2.5Hz) | |
 | ETH - Hotel | `eupeds_hotel` | `train`, `val`, `train_loo`, `val_loo`, `test_loo` | `zurich` | The Hotel scene from the ETH BIWI Walking Pedestrians dataset | 0.4s (2.5Hz) | |
 | UCY - Univ | `eupeds_univ` | `train`, `val`, `train_loo`, `val_loo`, `test_loo` | `cyprus` | The University scene from the UCY Pedestrians dataset | 0.4s (2.5Hz) | |
 | UCY - Zara1 | `eupeds_zara1` | `train`, `val`, `train_loo`, `val_loo`, `test_loo` | `cyprus` | The Zara1 scene from the UCY Pedestrians dataset | 0.4s (2.5Hz) | |
 | UCY - Zara2 | `eupeds_zara2` | `train`, `val`, `train_loo`, `val_loo`, `test_loo` | `cyprus` | The Zara2 scene from the UCY Pedestrians dataset | 0.4s (2.5Hz) | |
+| Stanford Drone Dataset | `sdd` | `train`, `val`, `test` | `stanford` | Stanford Drone Dataset (60 scenes, randomly split 42/9/9 (70%/15%/15%) for training/validation/test) | 0.0333...s (30Hz) | |
+
+### Adding New Datasets
+The code that interfaces with the original datasets (dealing with their unique formats) can be found in `src/trajdata/dataset_specific`.
+
+To add a new dataset, one needs to:
+- Create a new folder under `src/trajdata/dataset_specific` which will contain all the code specific to a particular dataset (e.g., for extracting data into our canonical format). In particular, there must be:
+  - An `__init__.py` file.
+  - A file that defines a subclass of `RawDataset` and implements some of its functions. Reference implementations can be found in the `nusc/nusc_dataset.py`, `lyft/lyft_dataset.py`, and `eth_ucy_peds/eupeds_dataset.py` files.
+- Add a subclass of `NamedTuple` to `src/trajdata/dataset_specific/scene_records.py` which contains the minimal set of information sufficient to describe a scene. This "scene record" will be used in conjunction with the raw dataset class above and relates to how scenes are stored and efficiently accessed later.
+- Add a section to the `DATASETS.md` file which outlines how users should store the raw dataset locally.
+- Add a section to `src/trajdata/utils/env_utils.py` which allows users to get the raw dataset via its name, specify whether the dataset is a good candidate for parallel processing (e.g., whether its native dataset object has a large memory footprint that would prevent it from being loaded in multiple processes, as is the case for nuScenes), and specify whether it has maps.
 
 ## Examples
 
+Please see the `examples/` folder for more examples; below are just a few demonstrations of core capabilities.
+
 ### Multiple Datasets
 The following will load data from both the nuScenes mini dataset as well as the ETH - University scene from the ETH BIWI Walking Pedestrians dataset.
 
@@ -114,10 +153,18 @@ dataset = UnifiedDataset(
 **Note**: Be careful about loading multiple datasets without an associated `desired_dt` argument; many datasets do not share the same underlying data annotation frequency.
To address this, we've implemented timestep interpolation to a common frequency which will ensure that all batched data shares the same dt. Interpolation can only be performed to integer multiples of the original data annotation frequency. For example, nuScenes' `dt=0.5` and the ETH BIWI dataset's `dt=0.4` can be interpolated to a common `desired_dt=0.1`. -## Adding New Datasets -The code that interfaces raw datasets can be found in `src/trajdata/dataset_specific`. +## Map API +`trajdata` also provides an API to access the raw vector map information from datasets that provide it. -To add a new dataset, ... +```py +from pathlib import Path +from trajdata import MapAPI, VectorMap + +cache_path = Path("~/.unified_data_cache").expanduser() +map_api = MapAPI(cache_path) + +vector_map: VectorMap = map_api.get_map("nusc_mini:boston-seaport") +``` ## Simulation Interface One additional feature of trajdata is that it can be used to initialize simulations from real data and track resulting agent motion, metrics, etc. @@ -151,7 +198,7 @@ sim_scene = SimulationScene( ) obs: AgentBatch = sim_scene.reset() -for t in range(1, sim_scene.scene_info.length_timesteps): +for t in range(1, sim_scene.scene.length_timesteps): new_xyh_dict: Dict[str, np.ndarray] = dict() # Everything inside the forloop just sets @@ -170,8 +217,20 @@ for t in range(1, sim_scene.scene_info.length_timesteps): `examples/sim_example.py` contains a more comprehensive example which initializes a simulation from a scene in the nuScenes mini dataset, steps through it by replaying agents' GT motions, and computes metrics based on scene statistics (e.g., displacement error from the original GT data, velocity/acceleration/jerk histograms). +## Citation + +If you use this software, please cite it as follows: +``` +@Inproceedings{ivanovic2023trajdata, + author = {Ivanovic, Boris and Song, Guanyu and Gilitschenski, Igor and Pavone, Marco}, + title = {{trajdata}: A Unified Interface to Multiple Human Trajectory Datasets}, + booktitle = {{Proceedings of the Neural Information Processing Systems (NeurIPS) Track on Datasets and Benchmarks}}, + month = dec, + year = {2023}, + address = {New Orleans, USA}, + url = {https://arxiv.org/abs/2307.13924} +} +``` + ## TODO - Create a method like finalize() which writes all the batch information to a TFRecord/WebDataset/some other format which is (very) fast to read from for higher epoch training. -- Add more examples to the README. -- Finish README section about how to add a new dataset. 
-
diff --git a/examples/batch_example.py b/examples/batch_example.py
index fa8f88b..385861b 100644
--- a/examples/batch_example.py
+++ b/examples/batch_example.py
@@ -20,8 +20,12 @@ def main():
         only_predict=[AgentType.VEHICLE],
         agent_interaction_distances=defaultdict(lambda: 30.0),
         incl_robot_future=False,
-        incl_map=True,
-        map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)},
+        incl_raster_map=True,
+        raster_map_params={
+            "px_per_m": 2,
+            "map_size_px": 224,
+            "offset_frac_xy": (-0.5, 0.0),
+        },
         augmentations=[noise_hists],
         num_workers=0,
         verbose=True,
diff --git a/examples/cache_and_filter_example.py b/examples/cache_and_filter_example.py
new file mode 100644
index 0000000..64c13d1
--- /dev/null
+++ b/examples/cache_and_filter_example.py
@@ -0,0 +1,97 @@
+import os
+from collections import defaultdict
+
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from trajdata import AgentBatch, AgentType, UnifiedDataset
+from trajdata.augmentation import NoiseHistories
+from trajdata.data_structures.batch_element import AgentBatchElement
+from trajdata.visualization.vis import plot_agent_batch
+
+
+def main():
+    noise_hists = NoiseHistories()
+
+    create_dataset = lambda: UnifiedDataset(
+        desired_data=["nusc_mini-mini_val"],
+        centric="agent",
+        desired_dt=0.5,
+        history_sec=(2.0, 2.0),
+        future_sec=(4.0, 4.0),
+        only_predict=[AgentType.VEHICLE],
+        agent_interaction_distances=defaultdict(lambda: 30.0),
+        incl_robot_future=False,
+        incl_raster_map=False,
+        # raster_map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)},
+        augmentations=[noise_hists],
+        num_workers=0,
+        verbose=True,
+        data_dirs={  # Remember to change this to match your filesystem!
+            "nusc_mini": "~/datasets/nuScenes",
+        },
+    )
+
+    dataset = create_dataset()
+
+    print(f"# Data Samples: {len(dataset):,}")
+
+    print(
+        "To demonstrate how to use caching, we will first save the "
+        "entire dataset (all BatchElements) to a cache file and then load from "
+        "the cache file. Note that for large datasets and/or high time resolution "
+        "this will create a large file and will use a lot of RAM."
+    )
+    cache_path = "./temp_cache_file.dill"
+
+    print(
+        "We also use a custom filter function that only keeps elements with more "
+        "than 5 neighbors."
+    )
+
+    def my_filter(el: AgentBatchElement) -> bool:
+        return el.num_neighbors > 5
+
+    print(
+        f"In the first run, we will iterate through the entire dataset and save all "
+        f"BatchElements to the cache file {cache_path}."
+    )
+    print("This may take several minutes.")
+    dataset.load_or_create_cache(
+        cache_path=cache_path, num_workers=0, filter_fn=my_filter
+    )
+    assert os.path.isfile(cache_path)
+
+    print(
+        "To demonstrate a consecutive run, we create a new dataset and load elements "
+        "from the cache file."
+    )
+    del dataset
+    dataset = create_dataset()
+
+    dataset.load_or_create_cache(
+        cache_path=cache_path, num_workers=0, filter_fn=my_filter
+    )
+
+    # Remove the temp cache file; we don't need it anymore.
+    os.remove(cache_path)
+
+    print(
+        "We can iterate through the dataset the same way as usual, but this "
+        "time it will be much faster because all BatchElements are in memory."
+ ) + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + collate_fn=dataset.get_collate_fn(), + num_workers=0, + ) + + batch: AgentBatch + for batch in tqdm(dataloader): + plot_agent_batch(batch, batch_idx=0) + + +if __name__ == "__main__": + main() diff --git a/examples/custom_batch_data.py b/examples/custom_batch_data.py index 24828eb..3b98c74 100644 --- a/examples/custom_batch_data.py +++ b/examples/custom_batch_data.py @@ -11,9 +11,7 @@ from tqdm import tqdm from trajdata import AgentBatch, AgentType, UnifiedDataset -from trajdata.augmentation import NoiseHistories from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement -from trajdata.visualization.vis import plot_agent_batch def custom_random_data( @@ -27,7 +25,7 @@ def custom_goal_location( batch_elem: Union[AgentBatchElement, SceneBatchElement] ) -> np.ndarray: # simply access existing element attributes - return batch_elem.agent_future_np[:, :2] + return batch_elem.agent_future_np.position def custom_min_distance_from_others( @@ -74,8 +72,12 @@ def main(): only_types=[AgentType.VEHICLE], agent_interaction_distances=defaultdict(lambda: 30.0), incl_robot_future=False, - incl_map=True, - map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)}, + incl_raster_map=True, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, num_workers=0, verbose=True, data_dirs={ # Remember to change this to match your filesystem! diff --git a/examples/lane_query_example.py b/examples/lane_query_example.py new file mode 100644 index 0000000..86a9004 --- /dev/null +++ b/examples/lane_query_example.py @@ -0,0 +1,137 @@ +""" +This is an example of how to extend a batch with lane information +""" + +import random +from collections import defaultdict + +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from trajdata import AgentBatch, AgentType, UnifiedDataset +from trajdata.data_structures.batch_element import AgentBatchElement +from trajdata.maps import VectorMap +from trajdata.utils.arr_utils import transform_angles_np, transform_coords_np +from trajdata.utils.state_utils import transform_state_np_2d +from trajdata.visualization.vis import plot_agent_batch + + +def get_closest_lane_point(element: AgentBatchElement) -> np.ndarray: + """Closest lane for predicted agent.""" + + # Transform from agent coordinate frame to world coordinate frame. 
+ vector_map: VectorMap = element.vec_map + world_from_agent_tf = np.linalg.inv(element.agent_from_world_tf) + agent_future_xyzh_world = transform_state_np_2d( + element.agent_future_np, world_from_agent_tf + ).as_format("x,y,z,h") + + # Use cached kdtree to find closest lane point + lane_points = [] + for point_xyzh in agent_future_xyzh_world: + possible_lanes = vector_map.get_current_lane(point_xyzh) + xyzh_on_lane = np.full((1, 4), np.nan) + if len(possible_lanes) > 0: + xyzh_on_lane = possible_lanes[0].center.project_onto(point_xyzh[None, :3]) + xyzh_on_lane[:, :2] = transform_coords_np( + xyzh_on_lane[:, :2], element.agent_from_world_tf + ) + xyzh_on_lane[:, -1] = transform_angles_np( + xyzh_on_lane[:, -1], element.agent_from_world_tf + ) + + lane_points.append(xyzh_on_lane) + + lane_points = np.concatenate(lane_points, axis=0) + return lane_points + + +def main(): + dataset = UnifiedDataset( + desired_data=[ + # "nusc_mini-mini_train", + "lyft_sample-mini_val", + ], + centric="agent", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_types=[AgentType.VEHICLE], + state_format="x,y,z,xd,yd,xdd,ydd,h", + obs_format="x,y,z,xd,yd,xdd,ydd,s,c", + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=False, + incl_raster_map=True, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + incl_vector_map=True, + num_workers=0, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + # "nusc_mini": "~/datasets/nuScenes", + "lyft_sample": "~/datasets/lyft/scenes/sample.zarr", + }, + # A dictionary that contains functions that generate our custom data. + # Can be any function and has access to the batch element. + extras={ + "closest_lane_point": get_closest_lane_point, + }, + ) + + print(f"# Data Samples: {len(dataset):,}") + + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=False, + collate_fn=dataset.get_collate_fn(), + num_workers=0, + ) + + # Visualize selected examples + num_plots = 3 + # batch_idxs = [10876, 10227, 1284] + batch_idxs = random.sample(range(len(dataset)), num_plots) + batch: AgentBatch = dataset.get_collate_fn(pad_format="right")( + [dataset[i] for i in batch_idxs] + ) + assert "closest_lane_point" in batch.extras + + for batch_i in range(num_plots): + ax = plot_agent_batch( + batch, batch_idx=batch_i, legend=False, show=False, close=False + ) + lane_points = batch.extras["closest_lane_point"][batch_i] + lane_points = lane_points[ + torch.logical_not(torch.any(torch.isnan(lane_points), dim=1)), : + ].numpy() + + ax.plot( + lane_points[:, 0], + lane_points[:, 1], + "o-", + markersize=3, + label="Lane points", + ) + + ax.legend(loc="best", frameon=True) + + plt.show() + plt.close("all") + + # Scan through dataset + batch: AgentBatch + for idx, batch in enumerate(tqdm(dataloader)): + assert "closest_lane_point" in batch.extras + if idx > 50: + break + + +if __name__ == "__main__": + main() diff --git a/examples/map_api_example.py b/examples/map_api_example.py new file mode 100644 index 0000000..759b2e0 --- /dev/null +++ b/examples/map_api_example.py @@ -0,0 +1,268 @@ +import time +from pathlib import Path +from typing import Dict, List, Optional + +import matplotlib.pyplot as plt +import numpy as np + +from trajdata import MapAPI, VectorMap +from trajdata.caching.df_cache import DataFrameCache +from trajdata.caching.env_cache import EnvCache +from trajdata.caching.scene_cache import SceneCache +from trajdata.data_structures.scene_metadata import 
Scene +from trajdata.maps.vec_map import Polyline, RoadLane +from trajdata.utils import map_utils + + +def load_random_scene(cache_path: Path, env_name: str, scene_dt: float) -> Scene: + env_cache = EnvCache(cache_path) + scenes_list = env_cache.load_env_scenes_list(env_name) + random_scene_name = scenes_list[np.random.randint(0, len(scenes_list))].name + + return env_cache.load_scene(env_name, random_scene_name, scene_dt) + + +def main(): + cache_path = Path("~/.unified_data_cache").expanduser() + map_api = MapAPI(cache_path) + + ### Loading random scene and initializing VectorMap. + env_name: str = np.random.choice(["nusc_mini", "lyft_sample", "nuplan_mini"]) + scene_cache: Optional[SceneCache] = None + if env_name == "nuplan_mini": + # Hardcoding scene_dt = 0.05s for now + # (using nuPlan as our traffic light data example). + random_scene: Scene = load_random_scene(cache_path, env_name, scene_dt=0.05) + scene_cache = DataFrameCache(cache_path, random_scene) + + vec_map: VectorMap = map_api.get_map( + f"{env_name}:{random_scene.location}", scene_cache=scene_cache + ) + else: + random_location: Dict[str, str] = { + "nusc_mini": np.random.choice(["boston-seaport", "singapore-onenorth"]), + "lyft_sample": "palo_alto", + } + + vec_map: VectorMap = map_api.get_map( + f"{env_name}:{random_location[env_name]}", scene_cache=scene_cache + ) + + print(f"Randomly chose {vec_map.env_name}, {vec_map.map_name} map.") + + ### Loading Lane used for the next few figures. + lane: RoadLane = vec_map.lanes[np.random.randint(0, len(vec_map.lanes))] + + ### Lane Interpolation (max_dist) + start = time.perf_counter() + interpolated: Polyline = lane.center.interpolate(max_dist=0.01) + end = time.perf_counter() + print(f"interpolate (max_dist) took {(end - start)*1000:.2f} ms") + + fig, ax = plt.subplots() + ax.scatter( + lane.center.points[:, 0], lane.center.points[:, 1], label="Original", s=80 + ) + ax.quiver( + lane.center.points[:, 0], + lane.center.points[:, 1], + np.cos(lane.center.points[:, -1]), + np.sin(lane.center.points[:, -1]), + ) + + ax.scatter( + interpolated.points[:, 0], interpolated.points[:, 1], label="Interpolated" + ) + ax.quiver( + interpolated.points[:, 0], + interpolated.points[:, 1], + np.cos(interpolated.points[:, -1]), + np.sin(interpolated.points[:, -1]), + ) + + ax.legend(loc="best") + ax.axis("equal") + + ### Lane Interpolation (num_pts) + start = time.perf_counter() + interpolated: Polyline = lane.center.interpolate(num_pts=10) + end = time.perf_counter() + print(f"interpolate (num_pts) took {(end - start)*1000:.2f} ms") + + fig, ax = plt.subplots() + ax.scatter( + lane.center.points[:, 0], lane.center.points[:, 1], label="Original", s=80 + ) + ax.quiver( + lane.center.points[:, 0], + lane.center.points[:, 1], + np.cos(lane.center.points[:, -1]), + np.sin(lane.center.points[:, -1]), + ) + + ax.scatter( + interpolated.points[:, 0], interpolated.points[:, 1], label="Interpolated" + ) + ax.quiver( + interpolated.points[:, 0], + interpolated.points[:, 1], + np.cos(interpolated.points[:, -1]), + np.sin(interpolated.points[:, -1]), + ) + + ax.legend(loc="best") + ax.axis("equal") + + ### Projection onto Lane + num_pts = 15 + orig_pts = lane.center.midpoint + np.concatenate( + [ + np.random.uniform(-3, 3, size=(num_pts, 2)), # x,y offsets + np.zeros(shape=(num_pts, 1)), # no z offsets + np.random.uniform(-np.pi, np.pi, size=(num_pts, 1)), # headings + ], + axis=-1, + ) + start = time.perf_counter() + proj_pts = lane.center.project_onto(orig_pts) + end = time.perf_counter() + 
print(f"project_onto ({num_pts} points) took {(end - start)*1000:.2f} ms") + + fig, ax = plt.subplots() + ax.plot(lane.center.points[:, 0], lane.center.points[:, 1], label="Lane") + ax.scatter(orig_pts[:, 0], orig_pts[:, 1], label="Original") + ax.quiver( + orig_pts[:, 0], + orig_pts[:, 1], + np.cos(orig_pts[:, -1]), + np.sin(orig_pts[:, -1]), + ) + + ax.scatter(proj_pts[:, 0], proj_pts[:, 1], label="Projected") + ax.quiver( + proj_pts[:, 0], + proj_pts[:, 1], + np.cos(proj_pts[:, -1]), + np.sin(proj_pts[:, -1]), + ) + + ax.legend(loc="best") + ax.axis("equal") + + ### Lane Graph Visualization (with rasterized map in background) + fig, ax = plt.subplots() + map_img, raster_from_world = vec_map.rasterize( + resolution=2, + return_tf_mat=True, + incl_centerlines=False, + area_color=(255, 255, 255), + edge_color=(0, 0, 0), + scene_ts=100, + ) + ax.imshow(map_img, alpha=0.5, origin="lower") + vec_map.visualize_lane_graph( + origin_lane=np.random.randint(0, len(vec_map.lanes)), + num_hops=5, + raster_from_world=raster_from_world, + ax=ax, + ) + ax.axis("equal") + ax.grid(None) + + ### Closest Lane Query (with rasterized map in background) + # vec_map.extent is [min_x, min_y, min_z, max_x, max_y, max_z] + min_x, min_y, _, max_x, max_y, _ = vec_map.extent + + mean_pt: np.ndarray = np.array( + [ + np.random.uniform(min_x, max_x), + np.random.uniform(min_y, max_y), + 0, + ] + ) + + start = time.perf_counter() + lane: RoadLane = vec_map.get_closest_lane(mean_pt) + end = time.perf_counter() + print(f"get_closest_lane took {(end - start)*1000:.2f} ms") + + fig, ax = plt.subplots() + map_img, raster_from_world = vec_map.rasterize( + resolution=2, + return_tf_mat=True, + incl_centerlines=False, + area_color=(255, 255, 255), + edge_color=(0, 0, 0), + ) + ax.imshow(map_img, alpha=0.5, origin="lower") + query_pt_map: np.ndarray = map_utils.transform_points( + mean_pt[None, :2], raster_from_world + )[0] + ax.scatter(*query_pt_map, label="Query Point") + vec_map.visualize_lane_graph( + origin_lane=lane, num_hops=0, raster_from_world=raster_from_world, ax=ax + ) + ax.axis("equal") + ax.grid(None) + + ### Lanes Within Range Query (with rasterized map in background) + radius: float = 20.0 + + # vec_map.extent is [min_x, min_y, min_z, max_x, max_y, max_z] + min_x, min_y, _, max_x, max_y, _ = vec_map.extent + + mean_pt: np.ndarray = np.array( + [ + np.random.uniform(min_x, max_x), + np.random.uniform(min_y, max_y), + 0, + ] + ) + + start = time.perf_counter() + lanes: List[RoadLane] = vec_map.get_lanes_within(mean_pt, radius) + end = time.perf_counter() + print(f"get_lanes_within took {(end - start)*1000:.2f} ms") + + fig, ax = plt.subplots() + img_resolution: float = 2 + map_img, raster_from_world = vec_map.rasterize( + resolution=img_resolution, + return_tf_mat=True, + incl_centerlines=False, + area_color=(255, 255, 255), + edge_color=(0, 0, 0), + ) + ax.imshow(map_img, alpha=0.5, origin="lower") + + query_pt_map: np.ndarray = map_utils.transform_points( + mean_pt[None, :2], raster_from_world + )[0] + ax.scatter(*query_pt_map, label="Query Point") + circle2 = plt.Circle( + (query_pt_map[0], query_pt_map[1]), + radius * img_resolution, + color="b", + fill=False, + ) + ax.add_patch(circle2) + + for l in lanes: + vec_map.visualize_lane_graph( + origin_lane=l, + num_hops=0, + raster_from_world=raster_from_world, + ax=ax, + legend=False, + ) + + ax.axis("equal") + ax.grid(None) + ax.legend(loc="best", frameon=True) + + plt.show() + plt.close("all") + + +if __name__ == "__main__": + main() diff --git 
a/examples/preprocess_data.py b/examples/preprocess_data.py index fc03ef5..e6faf03 100644 --- a/examples/preprocess_data.py +++ b/examples/preprocess_data.py @@ -5,7 +5,7 @@ def main(): dataset = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], + desired_data=["nusc_mini", "lyft_sample", "nuplan_mini"], rebuild_cache=True, rebuild_maps=True, num_workers=os.cpu_count(), @@ -13,6 +13,7 @@ def main(): data_dirs={ # Remember to change this to match your filesystem! "nusc_mini": "~/datasets/nuScenes", "lyft_sample": "~/datasets/lyft/scenes/sample.zarr", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", }, ) print(f"Total Data Samples: {len(dataset):,}") diff --git a/examples/preprocess_maps.py b/examples/preprocess_maps.py index 21042a2..ec10d8a 100644 --- a/examples/preprocess_maps.py +++ b/examples/preprocess_maps.py @@ -1,17 +1,18 @@ -import os - from trajdata import UnifiedDataset # @profile def main(): dataset = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], + # TODO(bivanovic@nvidia.com) Remove lyft from default examples + desired_data=["nusc_mini", "lyft_sample", "nuplan_mini"], rebuild_maps=True, data_dirs={ # Remember to change this to match your filesystem! "nusc_mini": "~/datasets/nuScenes", "lyft_sample": "~/datasets/lyft/scenes/sample.zarr", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", }, + verbose=True, ) print(f"Finished Caching Maps!") diff --git a/examples/scene_batch_example.py b/examples/scene_batch_example.py index 7320b00..9340e07 100644 --- a/examples/scene_batch_example.py +++ b/examples/scene_batch_example.py @@ -20,8 +20,12 @@ def main(): only_types=[AgentType.VEHICLE], agent_interaction_distances=defaultdict(lambda: 30.0), incl_robot_future=True, - incl_map=True, - map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)}, + incl_raster_map=True, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, augmentations=[noise_hists], max_agent_num=20, num_workers=4, diff --git a/examples/scenetimebatcher_example.py b/examples/scenetimebatcher_example.py new file mode 100644 index 0000000..d919a61 --- /dev/null +++ b/examples/scenetimebatcher_example.py @@ -0,0 +1,53 @@ +from collections import defaultdict + +from torch.utils.data import DataLoader +from tqdm import tqdm + +from trajdata import AgentBatch, AgentType, UnifiedDataset +from trajdata.utils.batch_utils import SceneTimeBatcher +from trajdata.visualization.vis import plot_agent_batch_all + + +def main(): + """ + Here, we use SceneTimeBatcher to loop through an + Agent-centric dataset with batches grouped by scene and timestep + """ + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_train"], + centric="agent", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=False, + incl_raster_map=True, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + num_workers=0, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! 
+ "nusc_mini": "~/datasets/nuScenes", + }, + ) + + print(f"# Data Samples: {len(dataset):,}") + + dataloader = DataLoader( + dataset, + batch_sampler=SceneTimeBatcher(dataset), + collate_fn=dataset.get_collate_fn(), + num_workers=4, + ) + + batch: AgentBatch + for batch in tqdm(dataloader): + plot_agent_batch_all(batch) + + +if __name__ == "__main__": + main() diff --git a/examples/sim_example.py b/examples/sim_example.py index 2a14974..c58837a 100644 --- a/examples/sim_example.py +++ b/examples/sim_example.py @@ -6,6 +6,7 @@ from trajdata import AgentBatch, AgentType, UnifiedDataset from trajdata.data_structures.scene_metadata import Scene +from trajdata.data_structures.state import StateArray from trajdata.simulation import SimulationScene, sim_metrics, sim_stats, sim_vis from trajdata.visualization.vis import plot_agent_batch @@ -20,7 +21,6 @@ def main(): # "px_per_m": 2, # "map_size_px": 224, # "offset_frac_xy": (0.0, 0.0), - # "return_rgb": True, # }, verbose=True, # desired_dt=0.1, @@ -55,32 +55,31 @@ def main(): obs: AgentBatch = sim_scene.reset() for t in trange(1, sim_scene.scene.length_timesteps): - new_xyh_dict: Dict[str, np.ndarray] = dict() + new_xyzh_dict: Dict[str, StateArray] = dict() for idx, agent_name in enumerate(obs.agent_name): - curr_yaw = obs.curr_agent_state[idx, -1] - curr_pos = obs.curr_agent_state[idx, :2] + curr_yaw = obs.curr_agent_state[idx].heading.item() + curr_pos = obs.curr_agent_state[idx].position.numpy() world_from_agent = np.array( [ [np.cos(curr_yaw), np.sin(curr_yaw)], [-np.sin(curr_yaw), np.cos(curr_yaw)], ] ) - next_state = np.zeros((3,)) + next_state = np.zeros((4,)) if obs.agent_fut_len[idx] < 1: next_state[:2] = curr_pos yaw_ac = 0 else: next_state[:2] = ( - obs.agent_fut[idx, 0, :2] @ world_from_agent + curr_pos - ) - yaw_ac = np.arctan2( - obs.agent_fut[idx, 0, -2], obs.agent_fut[idx, 0, -1] + obs.agent_fut[idx, 0].position.numpy() @ world_from_agent + + curr_pos ) + yaw_ac = obs.agent_fut[idx, 0].heading.item() - next_state[2] = curr_yaw + yaw_ac - new_xyh_dict[agent_name] = next_state + next_state[-1] = curr_yaw + yaw_ac + new_xyzh_dict[agent_name] = StateArray.from_array(next_state, "x,y,z,h") - obs = sim_scene.step(new_xyh_dict) + obs = sim_scene.step(new_xyzh_dict) metrics: Dict[str, Dict[str, float]] = sim_scene.get_metrics([ade, fde]) print(metrics) diff --git a/examples/simple_map_api_example.py b/examples/simple_map_api_example.py new file mode 100644 index 0000000..5aa6400 --- /dev/null +++ b/examples/simple_map_api_example.py @@ -0,0 +1,110 @@ +import time +from pathlib import Path +from typing import Dict + +import matplotlib.pyplot as plt +import numpy as np + +from trajdata import MapAPI, VectorMap +from trajdata.maps.vec_map_elements import MapElementType +from trajdata.utils import map_utils + + +def main(): + cache_path = Path("~/.unified_data_cache").expanduser() + map_api = MapAPI(cache_path) + + ### Loading random scene and initializing VectorMap. 
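+    # Note: maps are loaded from trajdata's local cache, so a dataset/map
+    # preprocessing pass (e.g., examples/preprocess_maps.py) must have been
+    # run beforehand for the chosen environment.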
+ env_name: str = np.random.choice(["nusc_mini", "lyft_sample", "nuplan_mini"]) + random_location_dict: Dict[str, str] = { + "nuplan_mini": np.random.choice( + ["boston", "singapore", "pittsburgh", "las_vegas"] + ), + "nusc_mini": np.random.choice(["boston-seaport", "singapore-onenorth"]), + "lyft_sample": "palo_alto", + } + + start = time.perf_counter() + vec_map: VectorMap = map_api.get_map( + f"{env_name}:{random_location_dict[env_name]}", incl_road_areas=True + ) + end = time.perf_counter() + print(f"Map loading took {(end - start)*1000:.2f} ms") + + start = time.perf_counter() + vec_map: VectorMap = map_api.get_map( + f"{env_name}:{random_location_dict[env_name]}", incl_road_areas=True + ) + end = time.perf_counter() + print(f"Repeated (cached in memory) map loading took {(end - start)*1000:.2f} ms") + + print(f"Randomly chose {vec_map.env_name}, {vec_map.map_name} map.") + + ### Lane Graph Visualization (with rasterized map in background) + fig, ax = plt.subplots() + + print(f"Rasterizing Map...") + start = time.perf_counter() + map_img, raster_from_world = vec_map.rasterize( + resolution=2, + return_tf_mat=True, + incl_centerlines=False, + area_color=(255, 255, 255), + edge_color=(0, 0, 0), + scene_ts=100, + ) + end = time.perf_counter() + print(f"Map rasterization took {(end - start)*1000:.2f} ms") + + ax.imshow(map_img, alpha=0.5, origin="lower") + + lane_idx = np.random.randint(0, len(vec_map.lanes)) + print(f"Visualizing random lane index {lane_idx}...") + start = time.perf_counter() + vec_map.visualize_lane_graph( + origin_lane=lane_idx, + num_hops=10, + raster_from_world=raster_from_world, + ax=ax, + ) + end = time.perf_counter() + print(f"Lane visualization took {(end - start)*1000:.2f} ms") + + point = vec_map.lanes[lane_idx].center.xyz[0, :] + point_raster = map_utils.transform_points( + point[None, :], transf_mat=raster_from_world + ) + ax.scatter(point_raster[:, 0], point_raster[:, 1]) + + print("Getting nearest road area...") + start = time.perf_counter() + area = vec_map.get_closest_area(point, elem_type=MapElementType.ROAD_AREA) + end = time.perf_counter() + print(f"Getting nearest area took {(end-start)*1000:.2f} ms") + + raster_pts = map_utils.transform_points(area.exterior_polygon.xy, raster_from_world) + ax.fill(raster_pts[:, 0], raster_pts[:, 1], alpha=1.0, color="C0") + + print("Getting road areas within 100m...") + start = time.perf_counter() + areas = vec_map.get_areas_within( + point, elem_type=MapElementType.ROAD_AREA, dist=100.0 + ) + end = time.perf_counter() + print(f"Getting areas within took {(end-start)*1000:.2f} ms") + + for area in areas: + raster_pts = map_utils.transform_points( + area.exterior_polygon.xy, raster_from_world + ) + ax.fill(raster_pts[:, 0], raster_pts[:, 1], alpha=0.2, color="C1") + + ax.axis("equal") + ax.grid(None) + + plt.show() + plt.close("all") + + +if __name__ == "__main__": + main() diff --git a/examples/simple_sim_example.py b/examples/simple_sim_example.py index 6ce2623..ec20778 100644 --- a/examples/simple_sim_example.py +++ b/examples/simple_sim_example.py @@ -1,9 +1,11 @@ from typing import Dict # Just for type annotations import numpy as np +from tqdm import trange from trajdata import AgentBatch, UnifiedDataset -from trajdata.data_structures.scene_metadata import Scene # Just for type annotations +from trajdata.data_structures.scene_metadata import Scene +from trajdata.data_structures.state import StateArray # Just for type annotations from trajdata.simulation import SimulationScene dataset = UnifiedDataset( @@ -24,18 
+26,18 @@ ) obs: AgentBatch = sim_scene.reset() -for t in range(1, sim_scene.scene.length_timesteps): - new_xyh_dict: Dict[str, np.ndarray] = dict() +for t in trange(1, sim_scene.scene.length_timesteps): + new_xyzh_dict: Dict[str, StateArray] = dict() # Everything inside the forloop just sets # agents' next states to their current ones. for idx, agent_name in enumerate(obs.agent_name): - curr_yaw = obs.curr_agent_state[idx, -1] - curr_pos = obs.curr_agent_state[idx, :2] + curr_yaw = obs.curr_agent_state[idx].heading.item() + curr_pos = obs.curr_agent_state[idx].position.numpy() - next_state = np.zeros((3,)) + next_state = np.zeros((4,)) next_state[:2] = curr_pos - next_state[2] = curr_yaw - new_xyh_dict[agent_name] = next_state + next_state[-1] = curr_yaw + new_xyzh_dict[agent_name] = StateArray.from_array(next_state, "x,y,z,h") - obs = sim_scene.step(new_xyh_dict) + obs = sim_scene.step(new_xyzh_dict) diff --git a/examples/speed_example.py b/examples/speed_example.py new file mode 100644 index 0000000..4710020 --- /dev/null +++ b/examples/speed_example.py @@ -0,0 +1,54 @@ +import os +from collections import defaultdict + +from torch.utils.data import DataLoader +from tqdm import tqdm + +from trajdata import AgentBatch, AgentType, UnifiedDataset +from trajdata.augmentation import NoiseHistories + + +def main(): + noise_hists = NoiseHistories() + + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_train"], + centric="agent", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=True, + incl_raster_map=True, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + incl_vector_map=True, + augmentations=[noise_hists], + num_workers=0, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + print(f"# Data Samples: {len(dataset):,}") + + dataloader = DataLoader( + dataset, + batch_size=64, + shuffle=True, + collate_fn=dataset.get_collate_fn(), + num_workers=os.cpu_count() // 2, + ) + + batch: AgentBatch + for batch in tqdm(dataloader): + pass + + +if __name__ == "__main__": + main() diff --git a/examples/state_example.py b/examples/state_example.py new file mode 100644 index 0000000..21b3735 --- /dev/null +++ b/examples/state_example.py @@ -0,0 +1,96 @@ +from collections import defaultdict + +import numpy as np +from torch.utils.data import DataLoader + +from trajdata import AgentBatch, AgentType, UnifiedDataset +from trajdata.data_structures.state import StateArray, StateTensor + + +def main(): + dataset = UnifiedDataset( + desired_data=["lyft_sample-mini_val"], + centric="agent", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + state_format="x,y,z,xd,yd,xdd,ydd,h", + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=False, + incl_raster_map=True, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + num_workers=0, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! 
+            "lyft_sample": "~/datasets/lyft_sample/scenes/sample.zarr",
+        },
+    )
+
+    print(f"# Data Samples: {len(dataset):,}")
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=4,
+        shuffle=True,
+        collate_fn=dataset.get_collate_fn(),
+        num_workers=4,
+    )
+
+    # BatchElements have properties that correspond to agent states
+    ego_state = dataset[0].curr_agent_state_np.copy()
+    print(ego_state)
+
+    # StateArray types offer easy conversion to whatever format you want your state in
+    # e.g. we want x,y position and cos/sin heading:
+    print(ego_state.as_format("x,y,c,s"))
+
+    # We can also access elements via properties
+    print(ego_state.position3d)
+    print(ego_state.velocity)
+
+    # We can set elements of states via properties. E.g., let's reset the heading to 0
+    ego_state.heading = 0
+    print(ego_state)
+
+    # We can request elements that aren't directly stored in the state, e.g. cos/sin heading
+    print(ego_state.heading_vector)
+
+    # However, we can't set properties that aren't directly stored in the state
+    try:
+        ego_state.heading_vector = 0.0
+    except AttributeError as e:
+        print(e)
+
+    # Finally, StateArrays are just np.ndarrays under the hood, and any normal np operation
+    # should convert them to a normal array
+    print(ego_state**2)
+
+    # To convert an np.array into a StateArray, we just need to specify what format it is
+    # Note that StateArrays can have an arbitrary number of batch elems
+    print(StateArray.from_array(np.random.randn(1, 2, 3), "x,y,z"))
+
+    # Analogous to StateArray wrapping np.ndarrays, the StateTensor class gives the same
+    # functionality to torch.Tensors
+    batch: AgentBatch = next(iter(dataloader))
+    ego_state_t: StateTensor = batch.curr_agent_state
+
+    print(ego_state_t.as_format("x,y,c,s"))
+    print(ego_state_t.position3d)
+    print(ego_state_t.velocity)
+    ego_state_t.heading = 0
+    print(ego_state_t)
+    print(ego_state_t.heading_vector)
+
+    # Furthermore, we can use the from_numpy() and numpy() methods to convert to and from
+    # StateTensors with the same format
+    print(ego_state_t.numpy())
+    print(StateTensor.from_numpy(ego_state))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/visualization_example.py b/examples/visualization_example.py
new file mode 100644
index 0000000..ba65f53
--- /dev/null
+++ b/examples/visualization_example.py
@@ -0,0 +1,68 @@
+from collections import defaultdict
+
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from trajdata import AgentBatch, AgentType, UnifiedDataset
+from trajdata.visualization.interactive_animation import (
+    InteractiveAnimation,
+    animate_agent_batch_interactive,
+)
+from trajdata.visualization.interactive_vis import plot_agent_batch_interactive
+from trajdata.visualization.vis import plot_agent_batch
+
+
+def main():
+    dataset = UnifiedDataset(
+        desired_data=["nusc_mini"],
+        centric="agent",
+        desired_dt=0.1,
+        # history_sec=(3.2, 3.2),
+        # future_sec=(4.8, 4.8),
+        only_predict=[AgentType.VEHICLE],
+        state_format="x,y,z,xd,yd,h",
+        obs_format="x,y,z,xd,yd,s,c",
+        # agent_interaction_distances=defaultdict(lambda: 30.0),
+        incl_robot_future=False,
+        incl_raster_map=True,
+        raster_map_params={
+            "px_per_m": 2,
+            "map_size_px": 224,
+            "offset_frac_xy": (-0.5, 0.0),
+        },
+        num_workers=4,
+        verbose=True,
+        data_dirs={  # Remember to change this to match your filesystem!
+ "nusc_mini": "~/datasets/nuScenes", + "lyft_sample": "~/datasets/lyft/scenes/sample.zarr", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, + ) + + print(f"# Data Samples: {len(dataset):,}") + + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + collate_fn=dataset.get_collate_fn(), + num_workers=0, + ) + + batch: AgentBatch + for batch in tqdm(dataloader): + plot_agent_batch_interactive(batch, batch_idx=0, cache_path=dataset.cache_path) + plot_agent_batch(batch, batch_idx=0) + + animation = InteractiveAnimation( + animate_agent_batch_interactive, + batch=batch, + batch_idx=0, + cache_path=dataset.cache_path, + ) + animation.show() + # break + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index def1233..2413b5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,8 +2,50 @@ profile = "black" [build-system] -requires = [ - "setuptools>=58", - "wheel" +requires = ["setuptools>=58", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3.8", + "License :: OSI Approved :: Apache Software License", +] +name = "trajdata" +version = "1.4.0" +authors = [{ name = "Boris Ivanovic", email = "bivanovic@nvidia.com" }] +description = "A unified interface to many trajectory forecasting datasets." +readme = "README.md" +requires-python = ">=3.8" +dependencies = [ + "numpy>=1.19", + "tqdm>=4.62", + "matplotlib>=3.5", + "dill>=0.3.4", + "pandas>=1.4.1", + "pyarrow>=7.0.0", + "torch>=1.10.2", + "zarr>=2.11.0", + "kornia>=0.6.4", + "seaborn>=0.12", + "bokeh>=3.0.3", + "geopandas>=0.13.2", + "protobuf==3.19.4", + "scipy>=1.9.0", + "opencv-python>=4.5.0", + "shapely>=2.0.0", ] -build-backend = "setuptools.build_meta" \ No newline at end of file + +[project.optional-dependencies] +av2 = ["av2==0.2.1"] +dev = ["black", "isort", "pytest", "pytest-xdist", "twine", "build"] +interaction = ["lanelet2==1.2.1"] +lyft = ["l5kit==1.5.0"] +nusc = ["nuscenes-devkit==1.1.9"] +waymo = ["tensorflow==2.11.0", "waymo-open-dataset-tf-2-11-0", "intervaltree"] +vod = ["vod-devkit==1.1.1"] + +[project.urls] +"Homepage" = "https://github.com/nvr-avg/trajdata" +"Bug Tracker" = "https://github.com/nvr-avg/trajdata/issues" diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 17d7c79..0000000 --- a/requirements.txt +++ /dev/null @@ -1,24 +0,0 @@ -numpy -tqdm -matplotlib -dill -pandas -pyarrow -torch -zarr -kornia - -# nuScenes devkit -nuscenes-devkit==1.1.9 - -# Lyft Level 5 devkit -protobuf==3.19.4 -l5kit==1.5.0 - -# Development -black -isort -pytest -pytest-xdist -twine -build diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 8ecbf21..0000000 --- a/setup.cfg +++ /dev/null @@ -1,48 +0,0 @@ -[metadata] -name = trajdata -version = 1.0.8 -author = Boris Ivanovic -author_email = bivanovic@nvidia.com -description = A unified interface to many trajectory forecasting datasets. 
-long_description = file: README.md -long_description_content_type = text/markdown -license = Apache License 2.0 -url = https://github.com/nvr-avg/trajdata -classifiers = - Development Status :: 3 - Alpha - Intended Audience :: Developers - Programming Language :: Python :: 3.8 - License :: OSI Approved :: Apache Software License - -[options] -package_dir = - = src -packages = find: -python_requires = >=3.8 -install_requires = - numpy>=1.19 - tqdm>=4.62 - matplotlib>=3.5 - dill>=0.3.4 - pandas>=1.4.1 - pyarrow>=7.0.0 - torch>=1.10.2 - zarr>=2.11.0 - kornia>=0.6.4 - -[options.packages.find] -where = src - -[options.extras_require] -dev = - black - isort - pytest - pytest-xdist - twine - build -nusc = - nuscenes-devkit==1.1.9 -lyft = - protobuf==3.19.4 - l5kit==1.5.0 diff --git a/src/trajdata/__init__.py b/src/trajdata/__init__.py index 2da9e8e..b639ce9 100644 --- a/src/trajdata/__init__.py +++ b/src/trajdata/__init__.py @@ -1,2 +1,3 @@ from .data_structures import AgentBatch, AgentType, SceneBatch from .dataset import UnifiedDataset +from .maps import MapAPI, VectorMap diff --git a/src/trajdata/caching/__init__.py b/src/trajdata/caching/__init__.py index 64ed55b..17e837a 100644 --- a/src/trajdata/caching/__init__.py +++ b/src/trajdata/caching/__init__.py @@ -1,3 +1,2 @@ -from .df_cache import DataFrameCache from .env_cache import EnvCache from .scene_cache import SceneCache diff --git a/src/trajdata/caching/df_cache.py b/src/trajdata/caching/df_cache.py index 8f9c9f9..329bb3d 100644 --- a/src/trajdata/caching/df_cache.py +++ b/src/trajdata/caching/df_cache.py @@ -1,7 +1,22 @@ +from __future__ import annotations + +import warnings +from decimal import Decimal +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from trajdata.maps import ( + RasterizedMap, + RasterizedMapMetadata, + VectorMap, + ) + from trajdata.maps.map_kdtree import MapElementKDTree + from trajdata.maps.map_strtree import MapElementSTRTree + import pickle from math import ceil, floor from pathlib import Path -from typing import Callable, Dict, Final, List, Optional, Tuple +from typing import Any, Dict, Final, List, Optional, Tuple import dill import kornia @@ -14,20 +29,23 @@ from trajdata.caching.scene_cache import SceneCache from trajdata.data_structures.agent import AgentMetadata, FixedExtent from trajdata.data_structures.scene_metadata import Scene -from trajdata.maps import RasterizedMap, RasterizedMapMetadata -from trajdata.proto.vectorized_map_pb2 import VectorizedMap -from trajdata.utils import arr_utils +from trajdata.data_structures.state import NP_STATE_TYPES, StateArray +from trajdata.maps.traffic_light_status import TrafficLightStatus +from trajdata.utils import arr_utils, df_utils, raster_utils, state_utils -STATE_COLS: Final[List[str]] = ["x", "y", "vx", "vy", "ax", "ay"] +STATE_COLS: Final[List[str]] = ["x", "y", "z", "vx", "vy", "ax", "ay"] EXTENT_COLS: Final[List[str]] = ["length", "width", "height"] +# TODO(apoorva) this is kind of serving the same purpose as STATE_COLS above +STATE_FORMAT_STR: Final[str] = "x,y,z,xd,yd,xdd,ydd,h" +RawStateArray = NP_STATE_TYPES[STATE_FORMAT_STR] + class DataFrameCache(SceneCache): def __init__( self, cache_path: Path, scene: Scene, - scene_ts: Optional[int] = 0, augmentations: Optional[List[Augmentation]] = None, ) -> None: """ @@ -36,7 +54,7 @@ def __init__( and pickle for miscellaneous supporting objects. Maps are pre-rasterized and stored as Zarr arrays. 
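        (Each scene's cache lives in its own directory; see
        SceneCache.scene_cache_dir() for the path convention.)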
""" - super().__init__(cache_path, scene, scene_ts, augmentations) + super().__init__(cache_path, scene, augmentations) agent_data_path: Path = self.scene_dir / DataFrameCache._agent_data_file( scene.dt @@ -51,7 +69,10 @@ def __init__( self._load_agent_data(scene.dt) # Setting default data transformation parameters. - self.reset_transforms() + self.reset_obs_format() + self.reset_obs_frame() + + self._kdtrees = None if augmentations: dataset_augments: List[DatasetAugmentation] = [ @@ -89,6 +110,7 @@ def _get_and_reorder_col_idxs(self) -> None: } self.pos_cols = [self.column_dict["x"], self.column_dict["y"]] + self.z_cols = [self.column_dict["z"]] self.vel_cols = [self.column_dict["vx"], self.column_dict["vy"]] self.acc_cols = [self.column_dict["ax"], self.column_dict["ay"]] if "sin_heading" in self.column_dict: @@ -100,20 +122,26 @@ def _get_and_reorder_col_idxs(self) -> None: self.heading_cols = [self.column_dict["heading"]] self.state_cols = ( - self.pos_cols + self.vel_cols + self.acc_cols + self.heading_cols + self.pos_cols + + self.z_cols + + self.vel_cols + + self.acc_cols + + self.heading_cols ) self._state_dim = len(self.state_cols) - # While it may seem that obs_dim == _state_dim, obs_dim is meant to - # track the output dimension of states (which could differ via - # returing sin and cos of heading rather than just heading itself). - self.obs_dim = self._state_dim - self.extent_cols: List[int] = list() for extent_name in ["length", "width", "height"]: if extent_name in self.column_dict: self.extent_cols.append(self.column_dict[extent_name]) + @property + def obs_dim(self): + """ + obs dim is increased by 1 if we transform heading to sin/cos repr + """ + return self.obs_type.state_dim + def _load_agent_data(self, scene_dt: float) -> None: self.scene_data_df: pd.DataFrame = pd.read_feather( self.scene_dir / DataFrameCache._agent_data_file(scene_dt), @@ -143,7 +171,9 @@ def save_agent_data( cache_path: Path, scene: Scene, ) -> None: - scene_cache_dir: Path = cache_path / scene.env_name / scene.name + scene_cache_dir: Path = DataFrameCache.scene_cache_dir( + cache_path, scene.env_name, scene.name + ) scene_cache_dir.mkdir(parents=True, exist_ok=True) index_dict: Dict[Tuple[str, int], int] = { @@ -189,72 +219,86 @@ def get_value(self, agent_id: str, scene_ts: int, attribute: str) -> float: ) return transformed_pair[0, 0].item() - def get_state(self, agent_id: str, scene_ts: int) -> np.ndarray: - state = self.scene_data_df.iloc[ - self.index_dict[(agent_id, scene_ts)], : self._state_dim - ].to_numpy() + def get_raw_state(self, agent_id: str, scene_ts: int) -> RawStateArray: + return StateArray.from_array( + self.scene_data_df.iloc[ + self.index_dict[(agent_id, scene_ts)], : self._state_dim + ] + .to_numpy() + .copy(), + STATE_FORMAT_STR, + ) + + def get_state(self, agent_id: str, scene_ts: int) -> StateArray: + state = self.get_raw_state(agent_id, scene_ts) - return self._transform_state(state) + return self._observation(state) - def get_states(self, agent_ids: List[str], scene_ts: int) -> np.ndarray: + def get_states(self, agent_ids: List[str], scene_ts: int) -> StateArray: row_idxs: List[int] = [ self.index_dict[(agent_id, scene_ts)] for agent_id in agent_ids ] states = self.scene_data_df.iloc[row_idxs, : self._state_dim].to_numpy() - return self._transform_state(states) - - def transform_data(self, **kwargs) -> None: - if "shift_mean_to" in kwargs: - # This standardizes the scene to be relative to the agent being predicted. 
-            self._transf_mean = kwargs["shift_mean_to"]
+        return self._observation(states)
 
-        if "rotate_by" in kwargs:
-            # This rotates the scene so that the predicted agent's current
-            # heading aligns with the x-axis.
-            agent_heading: float = kwargs["rotate_by"]
-            self._transf_rotmat: np.ndarray = np.array(
-                [
-                    [np.cos(agent_heading), -np.sin(agent_heading)],
-                    [np.sin(agent_heading), np.cos(agent_heading)],
-                ]
-            )
-
-        if "sincos_heading" in kwargs:
-            self._sincos_heading = True
-            self.obs_dim += 1
-
-    def reset_transforms(self) -> None:
-        self._transf_mean: Optional[np.ndarray] = None
-        self._transf_rotmat: Optional[np.ndarray] = None
-        self._sincos_heading: bool = False
+    def set_obs_frame(self, obs_frame: RawStateArray) -> None:
+        """
+        Sets the frame in which to return state observations.
+        """
+        self._obs_frame = obs_frame
 
-        self.obs_dim = self._state_dim
+        # compute rotation matrix for convenience
+        heading = -obs_frame.heading[0]
+        self._obs_rot_mat = np.array(
+            [
+                [np.cos(heading), -np.sin(heading)],
+                [np.sin(heading), np.cos(heading)],
+            ]
+        )
 
-    def _transform_state(self, data: np.ndarray) -> np.ndarray:
-        state: np.ndarray = data.copy()  # Don't want to alter the original data.
+    def reset_obs_frame(self) -> None:
+        """
+        Resets the observation frame to the world frame.
+        """
+        self._obs_frame = None
+        self._obs_rot_mat = None
 
-        if len(state.shape) == 1:
-            state = state[np.newaxis, :]
+    def set_obs_format(self, format_str: str):
+        self._obs_format = format_str
+        self.obs_type = NP_STATE_TYPES[format_str]
 
-        if self._transf_mean is not None:
-            # Shift to zero mean, but leave heading
-            # standardization for if rotation is also requested (below).
-            state[..., :-1] -= self._transf_mean[:-1]
+    def reset_obs_format(self) -> None:
+        self._obs_format = None
+        self.obs_type = RawStateArray
 
-        if self._transf_rotmat is not None:
-            state[..., self.pos_cols] = state[..., self.pos_cols] @ self._transf_rotmat
-            state[..., self.vel_cols] = state[..., self.vel_cols] @ self._transf_rotmat
-            state[..., self.acc_cols] = state[..., self.acc_cols] @ self._transf_rotmat
-            state[..., -1] -= self._transf_mean[-1]
+    def _observation(self, raw_state: np.ndarray) -> StateArray:
+        """
+        Turns a raw state into a state observation, transformed into
+        the frame specified via self.set_obs_frame()
 
-        if self._sincos_heading:
-            sin_heading: np.ndarray = np.sin(state[..., [-1]])
-            cos_heading: np.ndarray = np.cos(state[..., [-1]])
-            state = np.concatenate([state[..., :-1], sin_heading, cos_heading], axis=-1)
+        Args:
+            raw_state (np.ndarray): assumed safe to view as a RawStateArray
 
-        return state[0] if len(data.shape) == 1 else state
+        Returns:
+            StateArray: the observation, in the requested frame and format
+        """
+        obs = raw_state.copy()
+        batch_dims = raw_state.shape[:-1]
+        # apply transformations
+        if self._obs_frame is not None:
+            # we know obs is "x,y,z,xd,yd,xdd,ydd,h"
+            obs = obs - self._obs_frame
+            # batch rotate vectors
+            obs[..., (0, 1, 3, 4, 5, 6)] = (
+                obs[..., (0, 1, 3, 4, 5, 6)].reshape(-1, 3, 2)
+                @ self._obs_rot_mat.T[None, :, :]
+            ).reshape(*batch_dims, 6)
+            obs = obs.view(RawStateArray)
+        if self._obs_format is not None:
+            obs = obs.as_format(self._obs_format)
+        return obs
 
     def _transform_pair(
         self, data: np.ndarray, col_idxs: Tuple[int, int]
@@ -364,13 +408,18 @@ def get_agent_history(
         agent_info: AgentMetadata,
         scene_ts: int,
         history_sec: Tuple[Optional[float], Optional[float]],
-    ) -> Tuple[np.ndarray, np.ndarray]:
+    ) -> Tuple[StateArray, np.ndarray]:
         # We don't have to check the mins here because our data_index filtering in dataset.py 
already # took care of it. first_index_incl: int last_index_incl: int = self.index_dict[(agent_info.name, scene_ts)] if history_sec[1] is not None: - max_history: int = floor(history_sec[1] / self.dt) + # Wrapping the input floats with Decimal for exact division + # (avoiding float roundoff errors). + max_history: int = floor( + Decimal(str(history_sec[1])) / Decimal(str(self.dt)) + ) + first_index_incl = self.index_dict[ ( agent_info.name, @@ -389,15 +438,13 @@ def get_agent_history( agent_extent_np: np.ndarray if isinstance(agent_info.extent, FixedExtent): agent_extent_np = agent_info.extent.get_extents( - self.scene_ts - agent_history_df.shape[0] + 1, self.scene_ts + scene_ts - agent_history_df.shape[0] + 1, scene_ts ) else: agent_extent_np = agent_history_df.iloc[:, self.extent_cols].to_numpy() return ( - self._transform_state( - agent_history_df.iloc[:, : self._state_dim].to_numpy() - ), + self._observation(agent_history_df.iloc[:, : self._state_dim].to_numpy()), agent_extent_np, ) @@ -406,9 +453,9 @@ def get_agent_future( agent_info: AgentMetadata, scene_ts: int, future_sec: Tuple[Optional[float], Optional[float]], - ) -> Tuple[np.ndarray, np.ndarray]: - # We don't have to check the mins here because our data_index filtering in dataset.py already - # took care of it. + ) -> Tuple[StateArray, np.ndarray]: + # We don't have to check the mins here because our data_index filtering in + # dataset.py already took care of it. if scene_ts >= agent_info.last_timestep: # Extent shape = 3 return np.zeros((0, self.obs_dim)), np.zeros((0, 3)) @@ -416,7 +463,9 @@ def get_agent_future( first_index_incl: int = self.index_dict[(agent_info.name, scene_ts + 1)] last_index_incl: int if future_sec[1] is not None: - max_future = floor(future_sec[1] / self.dt) + # Wrapping the input floats with Decimal for exact division + # (avoiding float roundoff errors). + max_future = floor(Decimal(str(future_sec[1])) / Decimal(str(self.dt))) last_index_incl = self.index_dict[ (agent_info.name, min(scene_ts + max_future, agent_info.last_timestep)) ] @@ -432,15 +481,13 @@ def get_agent_future( agent_extent_np: np.ndarray if isinstance(agent_info.extent, FixedExtent): agent_extent_np: np.ndarray = agent_info.extent.get_extents( - self.scene_ts + 1, self.scene_ts + agent_future_df.shape[0] + scene_ts + 1, scene_ts + agent_future_df.shape[0] ) else: agent_extent_np = agent_future_df.iloc[:, self.extent_cols].to_numpy() return ( - self._transform_state( - agent_future_df.iloc[:, : self._state_dim].to_numpy() - ), + self._observation(agent_future_df.iloc[:, : self._state_dim].to_numpy()), agent_extent_np, ) @@ -449,12 +496,16 @@ def get_agents_history( scene_ts: int, agents: List[AgentMetadata], history_sec: Tuple[Optional[float], Optional[float]], - ) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]: + ) -> Tuple[List[StateArray], List[np.ndarray], np.ndarray]: first_timesteps = np.array( - [agent.first_timestep for agent in agents], dtype=np.long + [agent.first_timestep for agent in agents], dtype=int ) if history_sec[1] is not None: - max_history: int = floor(history_sec[1] / self.dt) + # Wrapping the input floats with Decimal for exact division + # (avoiding float roundoff errors). 
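+            # (Illustrative: floor(0.3 / 0.1) == 2 in plain float arithmetic,
+            # since 0.3 / 0.1 evaluates to 2.9999999999999996, whereas
+            # Decimal("0.3") / Decimal("0.1") is exactly 3.)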
+ max_history: int = floor( + Decimal(str(history_sec[1])) / Decimal(str(self.dt)) + ) first_timesteps = np.maximum(scene_ts - max_history, first_timesteps) first_index_incl: np.ndarray = np.array( @@ -462,10 +513,10 @@ def get_agents_history( self.index_dict[(agent.name, first_timesteps[idx])] for idx, agent in enumerate(agents) ], - dtype=np.long, + dtype=int, ) last_index_incl: np.ndarray = np.array( - [self.index_dict[(agent.name, scene_ts)] for agent in agents], dtype=np.long + [self.index_dict[(agent.name, scene_ts)] for agent in agents], dtype=int ) concat_idxs = arr_utils.vrange(first_index_incl, last_index_incl + 1) @@ -473,11 +524,11 @@ def get_agents_history( neighbor_history_lens_np = last_index_incl - first_index_incl + 1 - neighbor_histories_np = self._transform_state( + neighbor_histories_np = self._observation( neighbor_data_df.iloc[:, : self._state_dim].to_numpy() ) # The last one will always be empty because of what cumsum returns. - neighbor_histories: List[np.ndarray] = np.vsplit( + neighbor_histories: List[StateArray] = np.vsplit( neighbor_histories_np, neighbor_history_lens_np.cumsum() )[:-1] @@ -491,8 +542,8 @@ def get_agents_history( else: neighbor_extents = [ agent.extent.get_extents( - self.scene_ts - neighbor_history_lens_np[idx].item() + 1, - self.scene_ts, + scene_ts - neighbor_history_lens_np[idx].item() + 1, + scene_ts, ) for idx, agent in enumerate(agents) ] @@ -508,15 +559,15 @@ def get_agents_future( scene_ts: int, agents: List[AgentMetadata], future_sec: Tuple[Optional[float], Optional[float]], - ) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]: - last_timesteps = np.array( - [agent.last_timestep for agent in agents], dtype=np.long - ) + ) -> Tuple[List[StateArray], List[np.ndarray], np.ndarray]: + last_timesteps = np.array([agent.last_timestep for agent in agents], dtype=int) first_timesteps = np.minimum(scene_ts + 1, last_timesteps) if future_sec[1] is not None: - max_future: int = floor(future_sec[1] / self.dt) + # Wrapping the input floats with Decimal for exact division + # (avoiding float roundoff errors). + max_future: int = floor(Decimal(str(future_sec[1])) / Decimal(str(self.dt))) last_timesteps = np.minimum(scene_ts + max_future, last_timesteps) first_index_incl: np.ndarray = np.array( @@ -524,14 +575,14 @@ def get_agents_future( self.index_dict[(agent.name, first_timesteps[idx])] for idx, agent in enumerate(agents) ], - dtype=np.long, + dtype=int, ) last_index_incl: np.ndarray = np.array( [ self.index_dict[(agent.name, last_timesteps[idx])] for idx, agent in enumerate(agents) ], - dtype=np.long, + dtype=int, ) concat_idxs = arr_utils.vrange(first_index_incl, last_index_incl + 1) @@ -539,11 +590,11 @@ def get_agents_future( neighbor_future_lens_np = last_index_incl - first_index_incl + 1 - neighbor_futures_np = self._transform_state( + neighbor_futures_np = self._observation( neighbor_data_df.iloc[:, : self._state_dim].to_numpy() ) # The last one will always be empty because of what cumsum returns. 
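        # (e.g., with per-agent lengths [2, 3], cumsum() gives [2, 5] and
        # np.vsplit returns rows 0:2, rows 2:5, plus an empty tail slice that
        # the [:-1] below drops.)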
- neighbor_futures: List[np.ndarray] = np.vsplit( + neighbor_futures: List[StateArray] = np.vsplit( neighbor_futures_np, neighbor_future_lens_np.cumsum() )[:-1] @@ -557,8 +608,8 @@ def get_agents_future( else: neighbor_extents = [ agent.extent.get_extents( - self.scene_ts - neighbor_future_lens_np[idx].item() + 1, - self.scene_ts, + scene_ts - neighbor_future_lens_np[idx].item() + 1, + scene_ts, ) for idx, agent in enumerate(agents) ] @@ -569,6 +620,80 @@ def get_agents_future( neighbor_future_lens_np, ) + # TRAFFIC LIGHT INFO + @staticmethod + def _tls_data_file(scene_dt: float) -> str: + return f"tls_data_dt{scene_dt:.2f}.feather" + + @staticmethod + def save_traffic_light_data( + traffic_light_status_data: pd.DataFrame, + cache_path: Path, + scene: Scene, + dt: Optional[float] = None, + ) -> None: + """ + Assumes traffic_light_status_data is a MultiIndex dataframe with + lane_id and scene_ts as the indices, and has a column "status" with integer + values for traffic status given by the TrafficLightStatus enum + """ + scene_cache_dir: Path = DataFrameCache.scene_cache_dir( + cache_path, scene.env_name, scene.name + ) + scene_cache_dir.mkdir(parents=True, exist_ok=True) + + if dt is None: + dt = scene.dt + + traffic_light_status_data.reset_index().to_feather( + scene_cache_dir / DataFrameCache._tls_data_file(dt) + ) + + def is_traffic_light_data_cached(self, desired_dt: Optional[float] = None) -> bool: + desired_dt = self.dt if desired_dt is None else desired_dt + tls_data_path: Path = self.scene_dir / DataFrameCache._tls_data_file(desired_dt) + return tls_data_path.exists() + + def get_traffic_light_status_dict( + self, desired_dt: Optional[float] = None + ) -> Dict[Tuple[str, int], TrafficLightStatus]: + """ + Returns dict mapping Lane Id, scene_ts to traffic light status for the + particular scene. If data doesn't exist for the current dt, interpolates and + saves the interpolated data to disk for loading later. + """ + desired_dt = self.dt if desired_dt is None else desired_dt + + tls_data_path: Path = self.scene_dir / DataFrameCache._tls_data_file(desired_dt) + if not tls_data_path.exists(): + # Load the original dt traffic light data + tls_orig_dt_df: pd.DataFrame = pd.read_feather( + self.scene_dir + / DataFrameCache._tls_data_file(self.scene.env_metadata.dt), + use_threads=False, + ).set_index(["lane_id", "scene_ts"]) + + # Interpolate it to the desired dt. + tls_data_df = df_utils.interpolate_multi_index_df( + tls_orig_dt_df, self.scene.env_metadata.dt, desired_dt, method="nearest" + ) + + # Save it for the future + DataFrameCache.save_traffic_light_data( + tls_data_df, self.path, self.scene, desired_dt + ) + else: + # Load the data with the desired dt. 
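+            # (The feather file round-trips the MultiIndex frame written by
+            # save_traffic_light_data: one row per (lane_id, scene_ts) with an
+            # integer "status" column decoding to a TrafficLightStatus value.)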
+ tls_data_df: pd.DataFrame = pd.read_feather( + tls_data_path, + use_threads=False, + ).set_index(["lane_id", "scene_ts"]) + + # Return data as dict + return { + idx: TrafficLightStatus(v["status"]) for idx, v in tls_data_df.iterrows() + } + # MAPS @staticmethod def get_maps_path(cache_path: Path, env_name: str) -> Path: @@ -581,14 +706,23 @@ def are_maps_cached(cache_path: Path, env_name: str) -> bool: @staticmethod def get_map_paths( cache_path: Path, env_name: str, map_name: str, resolution: float - ) -> Tuple[Path, Path, Path, Path]: + ) -> Tuple[Path, Path, Path, Path, Path, Path]: maps_path: Path = DataFrameCache.get_maps_path(cache_path, env_name) vector_map_path: Path = maps_path / f"{map_name}.pb" + kdtrees_path: Path = maps_path / f"{map_name}_kdtrees.dill" + rtrees_path: Path = maps_path / f"{map_name}_rtrees.dill" raster_map_path: Path = maps_path / f"{map_name}_{resolution:.2f}px_m.zarr" raster_metadata_path: Path = maps_path / f"{map_name}_{resolution:.2f}px_m.dill" - return maps_path, vector_map_path, raster_map_path, raster_metadata_path + return ( + maps_path, + vector_map_path, + kdtrees_path, + rtrees_path, + raster_map_path, + raster_metadata_path, + ) @staticmethod def is_map_cached( @@ -597,75 +731,71 @@ def is_map_cached( ( maps_path, vector_map_path, + kdtrees_path, + rtrees_path, raster_map_path, raster_metadata_path, ) = DataFrameCache.get_map_paths(cache_path, env_name, map_name, resolution) + + # TODO(bivanovic): For now, rtrees are optional to have in the cache. + # In the future, they may be required (likely after we develop an + # incremental caching scheme or similar to handle additions like these). return ( maps_path.exists() and vector_map_path.exists() + and kdtrees_path.exists() + # and rtrees_path.exists() and raster_metadata_path.exists() and raster_map_path.exists() ) @staticmethod - def cache_map( - cache_path: Path, vec_map: VectorizedMap, map_obj: RasterizedMap, env_name: str + def finalize_and_cache_map( + cache_path: Path, + vector_map: VectorMap, + map_params: Dict[str, Any], ) -> None: + raster_resolution: float = map_params["px_per_m"] + ( maps_path, vector_map_path, + kdtrees_path, + rtrees_path, raster_map_path, raster_metadata_path, ) = DataFrameCache.get_map_paths( - cache_path, env_name, map_obj.metadata.name, map_obj.metadata.resolution + cache_path, vector_map.env_name, vector_map.map_name, raster_resolution ) - # Ensuring the maps directory exists. - maps_path.mkdir(parents=True, exist_ok=True) - - # Saving the vectorized map data. - with open(vector_map_path, "wb") as f: - f.write(vec_map.SerializeToString()) - - # Saving the rasterized map data. - zarr.save(raster_map_path, map_obj.data) - - # Saving the rasterized map metadata. - with open(raster_metadata_path, "wb") as f: - dill.dump(map_obj.metadata, f) - - @staticmethod - def cache_map_layers( - cache_path: Path, - vec_map: VectorizedMap, - map_info: RasterizedMapMetadata, - layer_fn: Callable[[str], np.ndarray], - env_name: str, - ) -> None: - ( - maps_path, - vector_map_path, - raster_map_path, - raster_metadata_path, - ) = DataFrameCache.get_map_paths( - cache_path, env_name, map_info.name, map_info.resolution + pbar_kwargs = {"position": 2, "leave": False, "disable": True} + rasterized_map: RasterizedMap = raster_utils.rasterize_map( + vector_map, raster_resolution, **pbar_kwargs ) + vector_map.compute_search_indices() + # Ensuring the maps directory exists. maps_path.mkdir(parents=True, exist_ok=True) # Saving the vectorized map data. 
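        # (The map is serialized via its protobuf form; the kdtrees/rtrees built
        # by compute_search_indices() above are dilled separately below, since
        # the protobuf does not store them.)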
         with open(vector_map_path, "wb") as f:
-            f.write(vec_map.SerializeToString())
+            f.write(vector_map.to_proto().SerializeToString())
+
+        # Saving precomputed map element kdtrees.
+        with open(kdtrees_path, "wb") as f:
+            dill.dump(vector_map.search_kdtrees, f)
+
+        # Saving precomputed map element rtrees.
+        with open(rtrees_path, "wb") as f:
+            dill.dump(vector_map.search_rtrees, f)
 
         # Saving the rasterized map data.
-        disk_data = zarr.open_array(raster_map_path, mode="w", shape=map_info.shape)
-        for idx, layer_name in enumerate(map_info.layers):
-            disk_data[idx] = layer_fn(layer_name)
+        zarr.save(raster_map_path, rasterized_map.data)
 
         # Saving the rasterized map metadata.
         with open(raster_metadata_path, "wb") as f:
-            dill.dump(map_info, f)
+            dill.dump(rasterized_map.metadata, f)
 
     def pad_map_patch(
         self,
@@ -698,6 +828,74 @@ def pad_map_patch(
 
         return np.pad(patch, [(0, 0), (pad_top, pad_bot), (pad_left, pad_right)])
 
+    def load_kdtrees(self) -> Dict[str, MapElementKDTree]:
+        _, _, kdtrees_path, _, _, _ = DataFrameCache.get_map_paths(
+            self.path, self.scene.env_name, self.scene.location, 0.0
+        )
+
+        with open(kdtrees_path, "rb") as f:
+            kdtrees: Dict[str, MapElementKDTree] = dill.load(f)
+
+        return kdtrees
+
+    def get_kdtrees(self, load_only_once: bool = True):
+        """Loads and returns the kdtrees dictionary from the cache file.
+
+        Args:
+            load_only_once (bool): store the kdtree dictionary in self so that we
+                don't have to load it from the cache file more than once.
+        """
+        if self._kdtrees is None:
+            kdtrees = self.load_kdtrees()
+            if load_only_once:
+                self._kdtrees = kdtrees
+
+            return kdtrees
+
+        else:
+            return self._kdtrees
+
+    def load_rtrees(self) -> MapElementSTRTree:
+        _, _, _, rtrees_path, _, _ = DataFrameCache.get_map_paths(
+            self.path, self.scene.env_name, self.scene.location, 0.0
+        )
+
+        if not rtrees_path.exists():
+            warnings.warn(
+                (
+                    "Trying to load cached RTree encoding 2D Map elements, "
+                    f"but {rtrees_path} does not exist. Earlier versions of "
+                    "trajdata did not build and cache this RTree. If area queries "
+                    "are needed, please rebuild the map cache (see "
+                    "examples/preprocess_maps.py for an example of how to do this). "
+                    "Otherwise, please ignore this warning."
+                ),
+                UserWarning,
+            )
+            return None
+
+        with open(rtrees_path, "rb") as f:
+            rtrees: MapElementSTRTree = dill.load(f)
+
+        return rtrees
+
+    def get_rtrees(self, load_only_once: bool = True):
+        """Loads and returns the rtrees object from the cache file.
+
+        Args:
+            load_only_once (bool): store the rtrees object in self so that we
+                don't have to load it from the cache file more than once.
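+
+        Returns None (with a warning from load_rtrees) if the cache was built
+        by an older trajdata version that did not save RTrees.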
+ """ + if self._rtrees is None: + rtrees = self.load_rtrees() + if load_only_once: + self._rtrees = rtrees + + return rtrees + + else: + return self._rtrees + def load_map_patch( self, world_x: float, @@ -713,6 +911,8 @@ def load_map_patch( ( maps_path, _, + _, + _, raster_map_path, raster_metadata_path, ) = DataFrameCache.get_map_paths( diff --git a/src/trajdata/caching/env_cache.py b/src/trajdata/caching/env_cache.py index 80336f1..5827e94 100644 --- a/src/trajdata/caching/env_cache.py +++ b/src/trajdata/caching/env_cache.py @@ -48,6 +48,20 @@ def save_scene(self, scene: Scene) -> Path: return scene_file + @staticmethod + def save_scene_with_path(base_path: Path, scene: Scene) -> Path: + scene_file: Path = EnvCache.scene_metadata_path( + base_path, scene.env_name, scene.name, scene.dt + ) + + scene_cache_dir: Path = scene_file.parent + scene_cache_dir.mkdir(parents=True, exist_ok=True) + + with open(scene_file, "wb") as f: + dill.dump(scene, f) + + return scene_file + def load_env_scenes_list(self, env_name: str) -> List[NamedTuple]: env_cache_dir: Path = self.path / env_name with open(env_cache_dir / "scenes_list.dill", "rb") as f: diff --git a/src/trajdata/caching/scene_cache.py b/src/trajdata/caching/scene_cache.py index 8f8b1fc..138cb2a 100644 --- a/src/trajdata/caching/scene_cache.py +++ b/src/trajdata/caching/scene_cache.py @@ -1,13 +1,19 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Type + +if TYPE_CHECKING: + from trajdata.maps import TrafficLightStatus, VectorMap + from pathlib import Path -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np from trajdata.augmentation.augmentation import Augmentation from trajdata.data_structures.agent import AgentMetadata from trajdata.data_structures.scene_metadata import Scene -from trajdata.maps import RasterizedMap, RasterizedMapMetadata -from trajdata.proto.vectorized_map_pb2 import VectorizedMap +from trajdata.data_structures.state import StateArray class SceneCache: @@ -15,7 +21,6 @@ def __init__( self, cache_path: Path, scene: Scene, - scene_ts: Optional[int] = 0, augmentations: Optional[List[Augmentation]] = None, ) -> None: """ @@ -24,13 +29,21 @@ def __init__( self.path = cache_path self.scene = scene self.dt = scene.dt - self.scene_ts = scene_ts self.augmentations = augmentations # Ensuring the scene cache folder exists - self.scene_dir: Path = self.path / self.scene.env_name / self.scene.name + self.scene_dir: Path = SceneCache.scene_cache_dir( + self.path, self.scene.env_name, self.scene.name + ) self.scene_dir.mkdir(parents=True, exist_ok=True) + self.obs_type: Type[StateArray] = None + + @staticmethod + def scene_cache_dir(cache_path: Path, env_name: str, scene_name: str) -> Path: + """Standardized convention to compute scene cache folder path""" + return cache_path / env_name / scene_name + def write_cache_to_disk(self) -> None: """Saves agent data to disk for fast loading later (just like save_agent_data), but using the class attributes for the sources of data and file paths. 
@@ -53,29 +66,45 @@ def get_value(self, agent_id: str, scene_ts: int, attribute: str) -> float: """ raise NotImplementedError() - def get_state(self, agent_id: str, scene_ts: int) -> np.ndarray: + def get_raw_state(self, agent_id: str, scene_ts: int) -> StateArray: + """ + Get an agent's raw state (without transformations applied) + """ + raise NotImplementedError() + + def get_state(self, agent_id: str, scene_ts: int) -> StateArray: """ Get an agent's state at a specific timestep. """ raise NotImplementedError() - def get_states(self, agent_ids: List[str], scene_ts: int) -> np.ndarray: + def get_states(self, agent_ids: List[str], scene_ts: int) -> StateArray: """ Get multiple agents' states at a specific timestep. """ raise NotImplementedError() - def transform_data(self, **kwargs) -> None: + def set_obs_frame(self, obs_frame: StateArray) -> None: + """ + Set frame in which to return observations + """ + raise NotImplementedError() + + def reset_obs_frame(self) -> None: + """ + Reset observation frame to be same as world frame + """ + raise NotImplementedError() + + def set_obs_format(self, format_str: str) -> None: """ - Transform the data before accessing it later, e.g., to make the mean zero or rotate the scene around an agent. - This can either be done in this function call or just stored for later lazy application. + Sets observation format (which elements to include and their order) """ raise NotImplementedError() - def reset_transforms(self) -> None: + def reset_obs_format(self) -> None: """ - Transform the data back to its original coordinate system. - This can either be done in this function call or just stored for later lazy application. + Resets observation format to default (set by subclass) """ raise NotImplementedError() @@ -93,7 +122,10 @@ def get_agent_history( agent_info: AgentMetadata, scene_ts: int, history_sec: Tuple[Optional[float], Optional[float]], - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[StateArray, np.ndarray]: + """ + Returns (agent_history_state, agent_extent) + """ raise NotImplementedError() def get_agent_future( @@ -101,7 +133,10 @@ def get_agent_future( agent_info: AgentMetadata, scene_ts: int, future_sec: Tuple[Optional[float], Optional[float]], - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[StateArray, np.ndarray]: + """ + Returns (agent_future_state, agent_extent) + """ raise NotImplementedError() def get_agents_history( @@ -109,7 +144,7 @@ def get_agents_history( scene_ts: int, agents: List[AgentMetadata], history_sec: Tuple[Optional[float], Optional[float]], - ) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]: + ) -> Tuple[List[StateArray], List[np.ndarray], np.ndarray]: raise NotImplementedError() def get_agents_future( @@ -117,7 +152,25 @@ def get_agents_future( scene_ts: int, agents: List[AgentMetadata], future_sec: Tuple[Optional[float], Optional[float]], - ) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]: + ) -> Tuple[List[StateArray], List[np.ndarray], np.ndarray]: + raise NotImplementedError() + + # TRAFFIC LIGHT INFO + @staticmethod + def save_traffic_light_data( + traffic_light_status_data: Any, cache_path: Path, scene: Scene + ) -> None: + """Saves traffic light status to disk for easy access later""" + raise NotImplementedError() + + def is_traffic_light_data_cached(self, desired_dt: Optional[float] = None) -> bool: + raise NotImplementedError() + + def get_traffic_light_status_dict( + self, desired_dt: Optional[float] = None + ) -> Dict[Tuple[str, int], TrafficLightStatus]: + """Returns lookup table for 
traffic light status in the current scene + lane_id, scene_ts -> TrafficLightStatus""" raise NotImplementedError() # MAPS @@ -132,18 +185,10 @@ def is_map_cached( raise NotImplementedError() @staticmethod - def cache_map( - cache_path: Path, vec_map: VectorizedMap, map_obj: RasterizedMap, env_name: str - ) -> None: - raise NotImplementedError() - - @staticmethod - def cache_map_layers( + def finalize_and_cache_map( cache_path: Path, - vec_map: VectorizedMap, - map_info: RasterizedMapMetadata, - layer_fn: Callable[[str], np.ndarray], - env_name: str, + vector_map: VectorMap, + map_params: Dict[str, Any], ) -> None: raise NotImplementedError() diff --git a/src/trajdata/data_structures/__init__.py b/src/trajdata/data_structures/__init__.py index 78539fb..fd2c569 100644 --- a/src/trajdata/data_structures/__init__.py +++ b/src/trajdata/data_structures/__init__.py @@ -7,3 +7,4 @@ from .scene import SceneTime, SceneTimeAgent from .scene_metadata import Scene, SceneMetadata from .scene_tag import SceneTag +from .state import NP_STATE_TYPES, TORCH_STATE_TYPES, StateArray, StateTensor diff --git a/src/trajdata/data_structures/batch.py b/src/trajdata/data_structures/batch.py index 0844ef1..c93fb8b 100644 --- a/src/trajdata/data_structures/batch.py +++ b/src/trajdata/data_structures/batch.py @@ -1,40 +1,45 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import torch from torch import Tensor from trajdata.data_structures.agent import AgentType +from trajdata.data_structures.state import StateTensor +from trajdata.maps import VectorMap from trajdata.utils.arr_utils import PadDirection @dataclass class AgentBatch: data_idx: Tensor + scene_ts: Tensor dt: Tensor agent_name: List[str] agent_type: Tensor - curr_agent_state: Tensor - agent_hist: Tensor + curr_agent_state: StateTensor + agent_hist: StateTensor agent_hist_extent: Tensor agent_hist_len: Tensor - agent_fut: Tensor + agent_fut: StateTensor agent_fut_extent: Tensor agent_fut_len: Tensor num_neigh: Tensor neigh_types: Tensor - neigh_hist: Tensor + neigh_hist: StateTensor neigh_hist_extents: Tensor neigh_hist_len: Tensor - neigh_fut: Tensor + neigh_fut: StateTensor neigh_fut_extents: Tensor neigh_fut_len: Tensor - robot_fut: Optional[Tensor] + robot_fut: Optional[StateTensor] robot_fut_len: Optional[Tensor] + map_names: Optional[List[str]] maps: Optional[Tensor] maps_resolution: Optional[Tensor] + vector_maps: Optional[List[VectorMap]] rasters_from_world_tf: Optional[Tensor] agents_from_world_tf: Tensor scene_ids: Optional[List] @@ -53,6 +58,8 @@ def to(self, device) -> None: "neigh_types", "num_neigh", "robot_fut_len", + "map_names", + "vector_maps", "scene_ids", "history_pad_dir", "extras", @@ -60,11 +67,15 @@ def to(self, device) -> None: for val in vars(self).keys(): tensor_val = getattr(self, val) if val not in excl_vals and tensor_val is not None: - tensor_val: Tensor + tensor_val: Union[Tensor, StateTensor] setattr(self, val, tensor_val.to(device, non_blocking=True)) for key, val in self.extras.items(): - self.extras[key] = val.to(device, non_blocking=True) + # Allow for custom .to() method for objects that define a __to__ function. 
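+            # (e.g., an extras object holding several tensors can define
+            # __to__(self, device, non_blocking=False) and move them itself.)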
+ if hasattr(val, "__to__"): + self.extras[key] = val.__to__(device, non_blocking=True) + else: + self.extras[key] = val.to(device, non_blocking=True) def agent_types(self) -> List[AgentType]: unique_types: Tensor = torch.unique(self.agent_type) @@ -72,69 +83,100 @@ def agent_types(self) -> List[AgentType]: def for_agent_type(self, agent_type: AgentType) -> AgentBatch: match_type = self.agent_type == agent_type + return self.filter_batch(match_type) + + def filter_batch(self, filter_mask: torch.Tensor) -> AgentBatch: + """Build a new batch with elements for which filter_mask[i] == True.""" + + # Some of the tensors might be on different devices, so we define some convenience functions + # to make sure the filter_mask is always on the same device as the tensor we are indexing. + filter_mask_dict = {} + filter_mask_dict["cpu"] = filter_mask.to("cpu") + filter_mask_dict[str(self.agent_hist.device)] = filter_mask.to( + self.agent_hist.device + ) + + _filter = lambda tensor: tensor[filter_mask_dict[str(tensor.device)]] + _filter_tensor_or_list = lambda tensor_or_list: ( + _filter(tensor_or_list) + if isinstance(tensor_or_list, torch.Tensor) + else type(tensor_or_list)( + [ + el + for idx, el in enumerate(tensor_or_list) + if filter_mask_dict["cpu"][idx] + ] + ) + ) + return AgentBatch( - data_idx=self.data_idx[match_type], - dt=self.dt[match_type], - agent_name=[ - name for idx, name in enumerate(self.agent_name) if match_type[idx] - ], - agent_type=agent_type.value, - curr_agent_state=self.curr_agent_state[match_type], - agent_hist=self.agent_hist[match_type], - agent_hist_extent=self.agent_hist_extent[match_type], - agent_hist_len=self.agent_hist_len[match_type], - agent_fut=self.agent_fut[match_type], - agent_fut_extent=self.agent_fut_extent[match_type], - agent_fut_len=self.agent_fut_len[match_type], - num_neigh=self.num_neigh[match_type], - neigh_types=self.neigh_types[match_type], - neigh_hist=self.neigh_hist[match_type], - neigh_hist_extents=self.neigh_hist_extents[match_type], - neigh_hist_len=self.neigh_hist_len[match_type], - neigh_fut=self.neigh_fut[match_type], - neigh_fut_extents=self.neigh_fut_extents[match_type], - neigh_fut_len=self.neigh_fut_len[match_type], - robot_fut=self.robot_fut[match_type] - if self.robot_fut is not None - else None, - robot_fut_len=self.robot_fut_len[match_type] + data_idx=_filter(self.data_idx), + scene_ts=_filter(self.scene_ts), + dt=_filter(self.dt), + agent_name=_filter_tensor_or_list(self.agent_name), + agent_type=_filter(self.agent_type), + curr_agent_state=_filter(self.curr_agent_state), + agent_hist=_filter(self.agent_hist), + agent_hist_extent=_filter(self.agent_hist_extent), + agent_hist_len=_filter(self.agent_hist_len), + agent_fut=_filter(self.agent_fut), + agent_fut_extent=_filter(self.agent_fut_extent), + agent_fut_len=_filter(self.agent_fut_len), + num_neigh=_filter(self.num_neigh), + neigh_types=_filter(self.neigh_types), + neigh_hist=_filter(self.neigh_hist), + neigh_hist_extents=_filter(self.neigh_hist_extents), + neigh_hist_len=_filter(self.neigh_hist_len), + neigh_fut=_filter(self.neigh_fut), + neigh_fut_extents=_filter(self.neigh_fut_extents), + neigh_fut_len=_filter(self.neigh_fut_len), + robot_fut=_filter(self.robot_fut) if self.robot_fut is not None else None, + robot_fut_len=_filter(self.robot_fut_len) if self.robot_fut_len is not None else None, - maps=self.maps[match_type] if self.maps is not None else None, - maps_resolution=self.maps_resolution[match_type] + map_names=_filter_tensor_or_list(self.map_names) + if 
self.map_names is not None + else None, + maps=_filter(self.maps) if self.maps is not None else None, + maps_resolution=_filter(self.maps_resolution) if self.maps_resolution is not None else None, - rasters_from_world_tf=self.rasters_from_world_tf[match_type] + vector_maps=_filter(self.vector_maps) + if self.vector_maps is not None + else None, + rasters_from_world_tf=_filter(self.rasters_from_world_tf) if self.rasters_from_world_tf is not None else None, - agents_from_world_tf=self.agents_from_world_tf[match_type], - scene_ids=[ - scene_id - for idx, scene_id in enumerate(self.scene_ids) - if match_type[idx] - ], + agents_from_world_tf=_filter(self.agents_from_world_tf), + scene_ids=_filter_tensor_or_list(self.scene_ids), history_pad_dir=self.history_pad_dir, - extras={key: val[match_type] for key, val in self.extras}, + extras={ + key: _filter_tensor_or_list(val) for key, val in self.extras.items() + }, ) @dataclass class SceneBatch: data_idx: Tensor + scene_ts: Tensor dt: Tensor num_agents: Tensor agent_type: Tensor - centered_agent_state: Tensor - agent_hist: Tensor + centered_agent_state: StateTensor + agent_names: List[str] + agent_hist: StateTensor agent_hist_extent: Tensor agent_hist_len: Tensor - agent_fut: Tensor + agent_fut: StateTensor agent_fut_extent: Tensor agent_fut_len: Tensor - robot_fut: Optional[Tensor] + robot_fut: Optional[StateTensor] robot_fut_len: Optional[Tensor] + map_names: Optional[Tensor] maps: Optional[Tensor] maps_resolution: Optional[Tensor] + vector_maps: Optional[List[VectorMap]] rasters_from_world_tf: Optional[Tensor] centered_agent_from_world_tf: Tensor centered_world_from_agent_tf: Tensor @@ -144,7 +186,11 @@ class SceneBatch: def to(self, device) -> None: excl_vals = { + "agent_names", + "map_names", + "vector_maps", "history_pad_dir", + "scene_ids", "extras", } @@ -154,46 +200,152 @@ def to(self, device) -> None: setattr(self, val, tensor_val.to(device)) for key, val in self.extras.items(): - self.extras[key] = val.to(device) + # Allow for custom .to() method for objects that define a __to__ function. + if hasattr(val, "__to__"): + self.extras[key] = val.__to__(device, non_blocking=True) + else: + self.extras[key] = val.to(device, non_blocking=True) def agent_types(self) -> List[AgentType]: unique_types: Tensor = torch.unique(self.agent_type) - return [AgentType(unique_type.item()) for unique_type in unique_types] + return [ + AgentType(unique_type.item()) + for unique_type in unique_types + if unique_type >= 0 + ] def for_agent_type(self, agent_type: AgentType) -> SceneBatch: match_type = self.agent_type == agent_type + return self.filter_batch(match_type) + + def filter_batch(self, filter_mask: torch.tensor) -> SceneBatch: + """Build a new batch with elements for which filter_mask[i] == True.""" + + # Some of the tensors might be on different devices, so we define some convenience functions + # to make sure the filter_mask is always on the same device as the tensor we are indexing. 
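+        # (PyTorch raises a device-mismatch error when a boolean mask lives on
+        # a different device than the tensor being indexed.)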
+        filter_mask_dict = {}
+        filter_mask_dict["cpu"] = filter_mask.to("cpu")
+        filter_mask_dict[str(self.agent_hist.device)] = filter_mask.to(
+            self.agent_hist.device
+        )
+
+        _filter = lambda tensor: tensor[filter_mask_dict[str(tensor.device)]]
+        _filter_tensor_or_list = lambda tensor_or_list: (
+            _filter(tensor_or_list)
+            if isinstance(tensor_or_list, torch.Tensor)
+            else type(tensor_or_list)(
+                [
+                    el
+                    for idx, el in enumerate(tensor_or_list)
+                    if filter_mask_dict["cpu"][idx]
+                ]
+            )
+        )
+
         return SceneBatch(
-            data_idx=self.data_idx[match_type],
-            dt=self.dt[match_type],
-            num_agents=self.num_agents[match_type],
-            agent_type=self.agent_type[match_type],
-            centered_agent_state=self.centered_agent_state[match_type],
-            agent_hist=self.agent_hist[match_type],
-            agent_hist_extent=self.agent_hist_extent[match_type],
-            agent_hist_len=self.agent_hist_len[match_type],
-            agent_fut=self.agent_fut[match_type],
-            agent_fut_extent=self.agent_fut_extent[match_type],
-            agent_fut_len=self.agent_fut_len[match_type],
-            robot_fut=self.robot_fut[match_type]
-            if self.robot_fut is not None
-            else None,
-            robot_fut_len=self.robot_fut_len[match_type]
+            data_idx=_filter(self.data_idx),
+            scene_ts=_filter(self.scene_ts),
+            dt=_filter(self.dt),
+            num_agents=_filter(self.num_agents),
+            agent_type=_filter(self.agent_type),
+            centered_agent_state=_filter(self.centered_agent_state),
+            agent_hist=_filter(self.agent_hist),
+            agent_hist_extent=_filter(self.agent_hist_extent),
+            agent_hist_len=_filter(self.agent_hist_len),
+            agent_fut=_filter(self.agent_fut),
+            agent_fut_extent=_filter(self.agent_fut_extent),
+            agent_fut_len=_filter(self.agent_fut_len),
+            robot_fut=_filter(self.robot_fut) if self.robot_fut is not None else None,
+            robot_fut_len=_filter(self.robot_fut_len)
             if self.robot_fut_len is not None
             else None,
-            maps=self.maps[match_type] if self.maps is not None else None,
-            maps_resolution=self.maps_resolution[match_type]
+            map_names=_filter_tensor_or_list(self.map_names)
+            if self.map_names is not None
+            else None,
+            maps=_filter(self.maps) if self.maps is not None else None,
+            maps_resolution=_filter(self.maps_resolution)
            if self.maps_resolution is not None
            else None,
-            rasters_from_world_tf=self.rasters_from_world_tf[match_type]
+            vector_maps=_filter(self.vector_maps)
+            if self.vector_maps is not None
+            else None,
+            rasters_from_world_tf=_filter(self.rasters_from_world_tf)
            if self.rasters_from_world_tf is not None
            else None,
-            centered_agent_from_world_tf=self.centered_agent_from_world_tf[match_type],
-            centered_world_from_agent_tf=self.centered_world_from_agent_tf[match_type],
-            scene_ids=[
-                scene_id
-                for idx, scene_id in enumerate(self.scene_ids)
-                if match_type[idx]
-            ],
+            centered_agent_from_world_tf=_filter(self.centered_agent_from_world_tf),
+            centered_world_from_agent_tf=_filter(self.centered_world_from_agent_tf),
+            scene_ids=_filter_tensor_or_list(self.scene_ids),
+            history_pad_dir=self.history_pad_dir,
+            extras={
+                key: _filter_tensor_or_list(val)
+                for key, val in self.extras.items()
+            },
+        )
+
+    def to_agent_batch(self, agent_inds: torch.Tensor) -> AgentBatch:
+        """
+        Converts a SceneBatch to an AgentBatch for the agents defined by `agent_inds`.
+
+        self.extras is simply copied over; any custom conversion must be
+        implemented externally.
+ """ + + batch_size = self.agent_hist.shape[0] + num_agents = self.agent_hist.shape[1] + + if agent_inds.ndim != 1 or agent_inds.shape[0] != batch_size: + raise ValueError("Wrong shape for agent_inds, expected [batch_size].") + + if (agent_inds < 0).any() or (agent_inds >= num_agents).any(): + raise ValueError("Invalid agent index") + + batch_inds = torch.arange(batch_size) + others_mask = torch.ones((batch_size, num_agents), dtype=torch.bool) + others_mask[batch_inds, agent_inds] = False + index_agent = lambda x: x[batch_inds, agent_inds] if x is not None else None + index_agent_list = ( + lambda xlist: [x[ind] for x, ind in zip(xlist, agent_inds)] + if xlist is not None + else None + ) + index_neighbors = lambda x: x[others_mask].reshape( + [ + batch_size, + num_agents - 1, + ] + + list(x.shape[2:]) + ) + + return AgentBatch( + data_idx=self.data_idx, + scene_ts=self.scene_ts, + dt=self.dt, + agent_name=index_agent_list(self.agent_names), + agent_type=index_agent(self.agent_type), + curr_agent_state=self.centered_agent_state, # TODO this is not actually the agent but the `global` coordinate frame + agent_hist=index_agent(self.agent_hist), + agent_hist_extent=index_agent(self.agent_hist_extent), + agent_hist_len=index_agent(self.agent_hist_len), + agent_fut=index_agent(self.agent_fut), + agent_fut_extent=index_agent(self.agent_fut_extent), + agent_fut_len=index_agent(self.agent_fut_len), + num_neigh=self.num_agents - 1, + neigh_types=index_neighbors(self.agent_type), + neigh_hist=index_neighbors(self.agent_hist), + neigh_hist_extents=index_neighbors(self.agent_hist_extent), + neigh_hist_len=index_neighbors(self.agent_hist_len), + neigh_fut=index_neighbors(self.agent_fut), + neigh_fut_extents=index_neighbors(self.agent_fut_extent), + neigh_fut_len=index_neighbors(self.agent_fut_len), + robot_fut=self.robot_fut, + robot_fut_len=self.robot_fut_len, + map_names=index_agent_list(self.map_names), + maps=index_agent(self.maps), + vector_maps=index_agent(self.vector_maps), + maps_resolution=index_agent(self.maps_resolution), + rasters_from_world_tf=index_agent(self.rasters_from_world_tf), + agents_from_world_tf=self.centered_agent_from_world_tf, + scene_ids=self.scene_ids, history_pad_dir=self.history_pad_dir, - extras={key: val[match_type] for key, val in self.extras}, + extras=self.extras, ) diff --git a/src/trajdata/data_structures/batch_element.py b/src/trajdata/data_structures/batch_element.py index 2b177f9..cf61764 100644 --- a/src/trajdata/data_structures/batch_element.py +++ b/src/trajdata/data_structures/batch_element.py @@ -7,7 +7,9 @@ from trajdata.caching import SceneCache from trajdata.data_structures.agent import AgentMetadata, AgentType from trajdata.data_structures.scene import SceneTime, SceneTimeAgent -from trajdata.maps import RasterizedMapPatch +from trajdata.data_structures.state import StateArray +from trajdata.maps import MapAPI, RasterizedMapPatch, VectorMap +from trajdata.utils.state_utils import convert_to_frame_state, transform_from_frame class AgentBatchElement: @@ -25,8 +27,11 @@ def __init__( Tuple[AgentType, AgentType], float ] = defaultdict(lambda: np.inf), incl_robot_future: bool = False, - incl_map: bool = False, - map_params: Optional[Dict[str, Any]] = None, + incl_raster_map: bool = False, + raster_map_params: Optional[Dict[str, Any]] = None, + map_api: Optional[MapAPI] = None, + vector_map_params: Optional[Dict[str, Any]] = None, + state_format: Optional[str] = None, standardize_data: bool = False, standardize_derivatives: bool = False, max_neighbor_num: 
Optional[int] = None,
@@ -43,16 +48,26 @@ def __init__(
         self.agent_type: AgentType = agent_info.type
         self.max_neighbor_num = max_neighbor_num
 
-        self.curr_agent_state_np: np.ndarray = cache.get_state(
-            agent_info.name, self.scene_ts
-        )
+        raw_state: StateArray = cache.get_raw_state(agent_info.name, self.scene_ts)
+        if state_format is not None:
+            self.curr_agent_state_np = raw_state.as_format(state_format)
+        else:
+            self.curr_agent_state_np = raw_state
 
         self.standardize_data = standardize_data
         if self.standardize_data:
-            agent_pos: np.ndarray = self.curr_agent_state_np[:2]
-            agent_heading: float = self.curr_agent_state_np[-1]
+            # Request cache to return observations relative to the current agent.
+            obs_frame: StateArray = convert_to_frame_state(
+                raw_state,
+                stationary=not standardize_derivatives,
+                grounded=True,
+            )
+            cache.set_obs_frame(obs_frame)
 
-            cos_agent, sin_agent = np.cos(agent_heading), np.sin(agent_heading)
+            # Create and store the 2D agent-from-world transformation matrix.
+            agent_pos = self.curr_agent_state_np.position
+            agent_heading_vector = self.curr_agent_state_np.heading_vector
+            cos_agent, sin_agent = agent_heading_vector[0], agent_heading_vector[1]
             world_from_agent_tf: np.ndarray = np.array(
                 [
                     [cos_agent, -sin_agent, agent_pos[0]],
@@ -61,17 +76,6 @@ def __init__(
                 ]
             )
             self.agent_from_world_tf: np.ndarray = np.linalg.inv(world_from_agent_tf)
-
-            offset = self.curr_agent_state_np
-            if not standardize_derivatives:
-                offset[2:6] = 0.0
-
-            cache.transform_data(
-                shift_mean_to=offset,
-                rotate_by=agent_heading,
-                sincos_heading=True,
-            )
-
         else:
             self.agent_from_world_tf: np.ndarray = np.eye(3)
 
@@ -85,6 +89,7 @@ def __init__(
             agent_info, future_sec
         )
         self.agent_future_len: int = self.agent_future_np.shape[0]
+        self.agent_meta_dict: Dict = get_agent_meta_dict(self.cache, agent_info)
 
         ### NEIGHBOR-SPECIFIC DATA ###
         def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray:
@@ -95,31 +100,32 @@ def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray:
                 ]
             )
 
+        nearby_agents, self.neighbor_types_np = self.get_nearby_agents(
+            scene_time_agent, agent_info, distance_limit
+        )
+
+        self.num_neighbors = len(nearby_agents)
         (
-            self.num_neighbors,
-            self.neighbor_types_np,
             self.neighbor_histories,
             self.neighbor_history_extents,
             self.neighbor_history_lens_np,
-        ) = self.get_neighbor_history(
-            scene_time_agent, agent_info, history_sec, distance_limit
-        )
+        ) = self.get_neighbor_history(history_sec, nearby_agents)
 
         (
-            _,
-            _,
             self.neighbor_futures,
             self.neighbor_future_extents,
             self.neighbor_future_lens_np,
-        ) = self.get_neighbor_future(
-            scene_time_agent, agent_info, future_sec, distance_limit
-        )
+        ) = self.get_neighbor_future(future_sec, nearby_agents)
+
+        self.neighbor_meta_dicts: List[Dict] = [
+            get_agent_meta_dict(self.cache, agent) for agent in nearby_agents
+        ]
 
         ### ROBOT DATA ###
-        self.robot_future_np: Optional[np.ndarray] = None
+        self.robot_future_np: Optional[StateArray] = None
 
         if incl_robot_future:
-            self.robot_future_np: np.ndarray = self.get_robot_current_and_future(
+            self.robot_future_np: StateArray = self.get_robot_current_and_future(
                 scene_time_agent.robot, future_sec
             )
@@ -129,9 +135,29 @@ def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray:
             self.robot_future_len: int = self.robot_future_np.shape[0] - 1
 
         ### MAP ###
+        self.map_name: Optional[str] = None
         self.map_patch: Optional[RasterizedMapPatch] = None
-        if incl_map:
-            self.map_patch = self.get_agent_map_patch(map_params)
+
+        map_name: str = (
f"{scene_time_agent.scene.env_name}:{scene_time_agent.scene.location}" + ) + if incl_raster_map: + self.map_name = map_name + self.map_patch = self.get_agent_map_patch(raster_map_params) + + self.vec_map: Optional[VectorMap] = None + if map_api is not None: + self.vec_map = map_api.get_map( + map_name, + self.cache + if self.cache.is_traffic_light_data_cached( + # Is the original dt cached? If so, we can continue by + # interpolating time to get whatever the user desires. + self.cache.scene.env_metadata.dt + ) + else None, + **vector_map_params if vector_map_params is not None else None, + ) self.scene_id = scene_time_agent.scene.name @@ -142,7 +168,7 @@ def get_agent_history( self, agent_info: AgentMetadata, history_sec: Tuple[Optional[float], Optional[float]], - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[StateArray, np.ndarray]: agent_history_np, agent_extent_history_np = self.cache.get_agent_history( agent_info, self.scene_ts, history_sec ) @@ -152,21 +178,22 @@ def get_agent_future( self, agent_info: AgentMetadata, future_sec: Tuple[Optional[float], Optional[float]], - ) -> np.ndarray: + ) -> Tuple[StateArray, np.ndarray]: agent_future_np, agent_extent_future_np = self.cache.get_agent_future( agent_info, self.scene_ts, future_sec ) return agent_future_np, agent_extent_future_np # @profile - def get_neighbor_data( + def get_nearby_agents( self, scene_time: SceneTimeAgent, agent_info: AgentMetadata, - length_sec: Tuple[Optional[float], Optional[float]], distance_limit: Callable[[np.ndarray, int], np.ndarray], - mode: str, - ) -> Tuple[int, np.ndarray, List[np.ndarray], List[np.ndarray], np.ndarray]: + ) -> Tuple[List[AgentMetadata], np.ndarray]: + """ + Returns Agent Metadata and Agent types of nearby agents + """ # The indices of the returned ndarray match the scene_time agents list # (including the index of the central agent, which would have a distance # of 0 to itself). @@ -183,82 +210,68 @@ def get_neighbor_data( nearby_agents: List[AgentMetadata] = [ scene_time.agents[idx] for idx in nb_idx if nearby_mask[idx] ] - neighbor_types_np: np.ndarray = neighbor_types[nearby_mask] if self.max_neighbor_num is not None: # Pruning nearby_agents and re-creating # neighbor_types_np with the remaining agents. nearby_agents = nearby_agents[: self.max_neighbor_num] - neighbor_types_np: np.ndarray = np.array( - [a.type.value for a in nearby_agents] - ) - num_neighbors: int = len(nearby_agents) - - if mode == "history": - ( - neighbor_data, - neighbor_extents_data, - neighbor_data_lens_np, - ) = self.cache.get_agents_history(self.scene_ts, nearby_agents, length_sec) - elif mode == "future": - ( - neighbor_data, - neighbor_extents_data, - neighbor_data_lens_np, - ) = self.cache.get_agents_future(self.scene_ts, nearby_agents, length_sec) - else: - raise ValueError(f"Unknown mode {mode} passed in!") + # Doing this here because the argsort above changes the order of agents. 
+ neighbor_types_np: np.ndarray = np.array([a.type.value for a in nearby_agents]) - return ( - num_neighbors, - neighbor_types_np, - neighbor_data, - neighbor_extents_data, - neighbor_data_lens_np, - ) + return nearby_agents, neighbor_types_np def get_neighbor_history( self, - scene_time: SceneTimeAgent, - agent_info: AgentMetadata, history_sec: Tuple[Optional[float], Optional[float]], - distance_limit: Callable[[np.ndarray, int], np.ndarray], - ) -> Tuple[int, np.ndarray, List[np.ndarray], List[np.ndarray], np.ndarray]: - return self.get_neighbor_data( - scene_time, agent_info, history_sec, distance_limit, mode="history" + nearby_agents: List[AgentMetadata], + ) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]: + ( + neighbor_data, + neighbor_extents_data, + neighbor_data_lens_np, + ) = self.cache.get_agents_history(self.scene_ts, nearby_agents, history_sec) + return ( + neighbor_data, + neighbor_extents_data, + neighbor_data_lens_np, ) def get_neighbor_future( self, - scene_time: SceneTimeAgent, - agent_info: AgentMetadata, future_sec: Tuple[Optional[float], Optional[float]], - distance_limit: Callable[[np.ndarray, int], np.ndarray], - ) -> Tuple[int, np.ndarray, List[np.ndarray], List[np.ndarray], np.ndarray]: - return self.get_neighbor_data( - scene_time, agent_info, future_sec, distance_limit, mode="future" + nearby_agents: List[AgentMetadata], + ) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]: + ( + neighbor_data, + neighbor_extents_data, + neighbor_data_lens_np, + ) = self.cache.get_agents_future(self.scene_ts, nearby_agents, future_sec) + return ( + neighbor_data, + neighbor_extents_data, + neighbor_data_lens_np, ) def get_robot_current_and_future( self, robot_info: AgentMetadata, future_sec: Tuple[Optional[float], Optional[float]], - ) -> np.ndarray: - robot_curr_np: np.ndarray = self.cache.get_state(robot_info.name, self.scene_ts) + ) -> StateArray: + robot_curr_np: StateArray = self.cache.get_state(robot_info.name, self.scene_ts) # robot_fut_extents_np, ( robot_fut_np, _, ) = self.cache.get_agent_future(robot_info, self.scene_ts, future_sec) - robot_curr_and_fut_np: np.ndarray = np.concatenate( + robot_curr_and_fut_np: StateArray = np.concatenate( (robot_curr_np[np.newaxis, :], robot_fut_np), axis=0 - ) + ).view(self.cache.obs_type) return robot_curr_and_fut_np def get_agent_map_patch(self, patch_params: Dict[str, int]) -> RasterizedMapPatch: - world_x, world_y = self.curr_agent_state_np[:2] + world_x, world_y = self.curr_agent_state_np.position desired_patch_size: int = patch_params["map_size_px"] resolution: float = patch_params["px_per_m"] offset_xy: Tuple[float, float] = patch_params.get("offset_frac_xy", (0.0, 0.0)) @@ -266,7 +279,7 @@ def get_agent_map_patch(self, patch_params: Dict[str, int]) -> RasterizedMapPatc no_map_fill_val: float = patch_params.get("no_map_fill_value", 0.0) if self.standardize_data: - heading = self.curr_agent_state_np[-1] + heading = self.curr_agent_state_np.heading[0] patch_data, raster_from_world_tf, has_data = self.cache.load_map_patch( world_x, world_y, @@ -315,8 +328,11 @@ def __init__( Tuple[AgentType, AgentType], float ] = defaultdict(lambda: np.inf), incl_robot_future: bool = False, - incl_map: bool = False, - map_params: Optional[Dict[str, Any]] = None, + incl_raster_map: bool = False, + raster_map_params: Optional[Dict[str, Any]] = None, + map_api: Optional[MapAPI] = None, + vector_map_params: Optional[Dict[str, Any]] = None, + state_format: Optional[str] = None, standardize_data: bool = False, standardize_derivatives: 
bool = False, max_agent_num: Optional[int] = None, @@ -337,14 +353,29 @@ def __init__( else: self.centered_agent = self.agents[0] - self.centered_agent_state_np: np.ndarray = cache.get_state( + raw_state: StateArray = cache.get_raw_state( self.centered_agent.name, self.scene_ts ) + + if state_format is not None: + self.centered_agent_state_np = raw_state.as_format(state_format) + else: + self.centered_agent_state_np = raw_state + self.standardize_data = standardize_data if self.standardize_data: - agent_pos: np.ndarray = self.centered_agent_state_np[:2] - agent_heading: float = self.centered_agent_state_np[-1] + # Request that the cache return observations relative to the centered agent + obs_frame: StateArray = convert_to_frame_state( + raw_state, + stationary=not standardize_derivatives, + grounded=True, + ) + cache.set_obs_frame(obs_frame) + + # Create the 2D transformation matrices between the agent and world frames + agent_pos: np.ndarray = self.centered_agent_state_np.position + agent_heading: float = self.centered_agent_state_np.heading[0] cos_agent, sin_agent = np.cos(agent_heading), np.sin(agent_heading) self.centered_world_from_agent_tf: np.ndarray = np.array( @@ -357,16 +388,6 @@ def __init__( self.centered_agent_from_world_tf: np.ndarray = np.linalg.inv( self.centered_world_from_agent_tf ) - - offset = self.centered_agent_state_np - if not standardize_derivatives: - offset[2:6] = 0.0 - - cache.transform_data( - shift_mean_to=offset, - rotate_by=agent_heading, - sincos_heading=True, - ) else: self.centered_agent_from_world_tf: np.ndarray = np.eye(3) self.centered_world_from_agent_tf: np.ndarray = np.eye(3) @@ -385,6 +406,7 @@ def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray: ) self.num_agents = len(nearby_agents) + self.agent_names = [agent.name for agent in nearby_agents] ( self.agent_histories, self.agent_history_extents, @@ -396,18 +418,34 @@ def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray: self.agent_future_lens_np, ) = self.get_agents_future(future_sec, nearby_agents) + self.agent_meta_dicts = [ + get_agent_meta_dict(self.cache, agent) for agent in nearby_agents + ] + ### MAP ### + self.map_name: Optional[str] = None self.map_patches: Optional[RasterizedMapPatch] = None - if incl_map: + + map_name: str = f"{scene_time.scene.env_name}:{scene_time.scene.location}" + if incl_raster_map: + self.map_name = map_name + self.map_patches = self.get_agents_map_patch( - map_params, self.agent_histories + raster_map_params, self.agent_histories ) - self.scene_id = scene_time.scene.name + + self.vec_map: Optional[VectorMap] = None + if map_api is not None: + self.vec_map = map_api.get_map( + map_name, + self.cache if self.cache.is_traffic_light_data_cached() else None, + **vector_map_params if vector_map_params is not None else {}, + ) + ### ROBOT DATA ### - self.robot_future_np: Optional[np.ndarray] = None + self.robot_future_np: Optional[StateArray] = None if incl_robot_future: - self.robot_future_np: np.ndarray = self.get_robot_current_and_future( + self.robot_future_np: StateArray = self.get_robot_current_and_future( self.centered_agent, future_sec ) @@ -446,7 +484,7 @@ def get_agents_history( self, history_sec: Tuple[Optional[float], Optional[float]], nearby_agents: List[AgentMetadata], - ) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]: + ) -> Tuple[List[StateArray], List[np.ndarray], np.ndarray]: # The indices of the returned ndarray match the scene_time agents list (including the index of the central agent, # which would have a distance
of 0 to itself). ( @@ -465,8 +503,7 @@ def get_agents_future( self, future_sec: Tuple[Optional[float], Optional[float]], nearby_agents: List[AgentMetadata], - ) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]: - + ) -> Tuple[List[StateArray], List[np.ndarray], np.ndarray]: ( agent_futures, agent_future_extents, @@ -482,95 +519,49 @@ def get_agents_future( def get_agents_map_patch( self, patch_params: Dict[str, int], agent_histories: List[np.ndarray] ) -> List[RasterizedMapPatch]: - world_x, world_y = self.centered_agent_state_np[:2] - heading = self.centered_agent_state_np[-1] desired_patch_size: int = patch_params["map_size_px"] resolution: float = patch_params["px_per_m"] offset_xy: Tuple[float, float] = patch_params.get("offset_frac_xy", (0.0, 0.0)) return_rgb: bool = patch_params.get("return_rgb", True) no_map_fill_val: float = patch_params.get("no_map_fill_value", 0.0) - if self.cache._sincos_heading: - if len(self.cache.heading_cols) == 2: - heading_sin_idx, heading_cos_idx = self.cache.heading_cols - else: - heading_sin_idx, heading_cos_idx = ( - self.cache.heading_cols[0], - self.cache.heading_cols[0] + 1, - ) - sincos = True - - else: - heading_idx = self.cache.heading_cols[0] - sincos = False - - x_idx, y_idx = self.cache.pos_cols - map_patches = list() curr_state = [state[-1] for state in agent_histories] - curr_state = np.stack(curr_state) + curr_state = np.stack(curr_state).view(self.cache.obs_type) + if self.standardize_data: - Rot = np.array( - [ - [np.cos(heading), -np.sin(heading)], - [np.sin(heading), np.cos(heading)], - ] + # need to transform back into world frame + obs_frame: StateArray = convert_to_frame_state( + self.centered_agent_state_np, stationary=True, grounded=True ) - if sincos: - agent_heading = ( - np.arctan2( - curr_state[:, heading_sin_idx], curr_state[:, heading_cos_idx] - ) - + heading - ) - else: - agent_heading = curr_state[:, heading_idx] + heading - world_dxy = curr_state[:, [x_idx, y_idx]] @ (Rot.T) - for i in range(curr_state.shape[0]): - patch_data, raster_from_world_tf, has_data = self.cache.load_map_patch( - world_x + world_dxy[i, 0], - world_y + world_dxy[i, 1], - desired_patch_size, - resolution, - offset_xy, - agent_heading[i], - return_rgb, - rot_pad_factor=sqrt(2), - no_map_val=no_map_fill_val, - ) - map_patches.append( - RasterizedMapPatch( - data=patch_data, - rot_angle=agent_heading[i], - crop_size=desired_patch_size, - resolution=resolution, - raster_from_world_tf=raster_from_world_tf, - has_data=has_data, - ) - ) + curr_state = transform_from_frame(curr_state, obs_frame) + heading = curr_state.heading[:, 0] else: - for i in range(curr_state.shape[0]): - patch_data, raster_from_world_tf, has_data = self.cache.load_map_patch( - curr_state[i, x_idx], - curr_state[i, y_idx], - desired_patch_size, - resolution, - offset_xy, - 0, - return_rgb, - no_map_val=no_map_fill_val, - ) - map_patches.append( - RasterizedMapPatch( - data=patch_data, - rot_angle=0, - crop_size=desired_patch_size, - resolution=resolution, - raster_from_world_tf=raster_from_world_tf, - has_data=has_data, - ) + heading = 0.0 * curr_state.heading[:, 0] + + for i in range(curr_state.shape[0]): + patch_data, raster_from_world_tf, has_data = self.cache.load_map_patch( + curr_state.get_attr("x")[i], + curr_state.get_attr("y")[i], + desired_patch_size, + resolution, + offset_xy, + heading[i], + return_rgb, + rot_pad_factor=sqrt(2), + no_map_val=no_map_fill_val, + ) + map_patches.append( + RasterizedMapPatch( + data=patch_data, + rot_angle=heading[i], + 
crop_size=desired_patch_size, + resolution=resolution, + raster_from_world_tf=raster_from_world_tf, + has_data=has_data, ) + ) return map_patches @@ -578,15 +569,33 @@ def get_robot_current_and_future( self, robot_info: AgentMetadata, future_sec: Tuple[Optional[float], Optional[float]], - ) -> np.ndarray: - robot_curr_np: np.ndarray = self.cache.get_state(robot_info.name, self.scene_ts) + ) -> StateArray: + robot_curr_np: StateArray = self.cache.get_state(robot_info.name, self.scene_ts) # robot_fut_extents_np, ( robot_fut_np, _, ) = self.cache.get_agent_future(robot_info, self.scene_ts, future_sec) - robot_curr_and_fut_np: np.ndarray = np.concatenate( + robot_curr_and_fut_np: StateArray = np.concatenate( (robot_curr_np[np.newaxis, :], robot_fut_np), axis=0 - ) + ).view(self.cache.obs_type) return robot_curr_and_fut_np + + +def is_agent_stationary(cache: SceneCache, agent_info: AgentMetadata) -> bool: + # An agent is considered stationary if it moves less than 1m between its first and last valid timestep. + first_state: StateArray = cache.get_state( + agent_info.name, agent_info.first_timestep + ) + last_state: StateArray = cache.get_state(agent_info.name, agent_info.last_timestep) + is_stationary = np.square(last_state.position - first_state.position).sum(0) < 1.0 + return is_stationary + + +def get_agent_meta_dict( + cache: SceneCache, agent_info: AgentMetadata +) -> Dict[str, np.ndarray]: + return { + "is_stationary": is_agent_stationary(cache, agent_info), + } diff --git a/src/trajdata/data_structures/collation.py b/src/trajdata/data_structures/collation.py index dfea7a1..f08b3b6 100644 --- a/src/trajdata/data_structures/collation.py +++ b/src/trajdata/data_structures/collation.py @@ -1,5 +1,4 @@ from dataclasses import asdict -from enum import IntEnum from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -12,26 +11,49 @@ from trajdata.augmentation import BatchAugmentation from trajdata.data_structures.batch import AgentBatch, SceneBatch from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement +from trajdata.data_structures.state import TORCH_STATE_TYPES +from trajdata.maps import VectorMap from trajdata.utils import arr_utils -def map_collate_fn_agent( +class CustomCollateData: + @staticmethod + def __collate__(elements: list) -> Any: + raise NotImplementedError + + def __to__(self, device, non_blocking=False): + # Example for moving all elements of a list to a device: + # return LanesList([[pts.to(device, non_blocking=non_blocking) + # for pts in lanelist] for lanelist in self]) + raise NotImplementedError + + +def _collate_data(elems): + if hasattr(elems[0], "__collate__"): + return elems[0].__collate__(elems) + else: + return torch.as_tensor(np.stack(elems)) + + +def raster_map_collate_fn_agent( batch_elems: List[AgentBatchElement], ): if batch_elems[0].map_patch is None: - return None, None, None + return None, None, None, None + + map_names = [batch_elem.map_name for batch_elem in batch_elems] # Ensuring that any empty map patches have the correct number of channels # prior to collation.
has_data: np.ndarray = np.array( [batch_elem.map_patch.has_data for batch_elem in batch_elems], - dtype=np.bool, + dtype=bool, ) no_data: np.ndarray = ~has_data patch_channels: np.ndarray = np.array( [batch_elem.map_patch.data.shape[0] for batch_elem in batch_elems], - dtype=np.int, + dtype=int, ) desired_num_channels: int @@ -148,26 +170,27 @@ def map_collate_fn_agent( ) return ( + map_names, rot_crop_patches, resolution, rasters_from_world_tf, ) -def map_collate_fn_scene( +def raster_map_collate_fn_scene( batch_elems: List[SceneBatchElement], max_agent_num: Optional[int] = None, pad_value: Any = np.nan, ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: if batch_elems[0].map_patches is None: - return None, None, None + return None, None, None, None patch_size: int = batch_elems[0].map_patches[0].crop_size assert all( batch_elem.map_patches[0].crop_size == patch_size for batch_elem in batch_elems ) + map_names: List[str] = list() num_agents: List[int] = list() agents_rasters_from_world_tfs: List[np.ndarray] = list() agents_patches: List[np.ndarray] = list() @@ -175,6 +198,7 @@ def map_collate_fn_scene( agents_res_list: List[float] = list() for elem in batch_elems: + map_names.append(elem.map_name) num_agents.append(min(elem.num_agents, max_agent_num)) agents_rasters_from_world_tfs += [ x.raster_from_world_tf for x in elem.map_patches[:max_agent_num] @@ -253,7 +277,7 @@ def map_collate_fn_scene( agents_resolution, num_agents, pad_value=0, desired_size=max_agent_num ) - return rot_crop_patches, agents_resolution, agents_rasters_from_world_tf + return map_names, rot_crop_patches, agents_resolution, agents_rasters_from_world_tf def agent_collate_fn( @@ -270,17 +294,24 @@ def agent_collate_fn( ) data_index_t: Tensor = torch.zeros((batch_size,), dtype=torch.int) + scene_ts_t: Tensor = torch.zeros((batch_size,), dtype=torch.int) dt_t: Tensor = torch.zeros((batch_size,), dtype=torch.float) agent_type_t: Tensor = torch.zeros((batch_size,), dtype=torch.int) agent_names: List[str] = list() - curr_agent_state: List[Tensor] = list() + # get agent state and obs format from first item in list + state_format = batch_elems[0].curr_agent_state_np._format + obs_format = batch_elems[0].cache.obs_type._format + AgentStateTensor = TORCH_STATE_TYPES[state_format] + AgentObsTensor = TORCH_STATE_TYPES[obs_format] + + curr_agent_state: List[AgentStateTensor] = list() - agent_history: List[Tensor] = list() + agent_history: List[AgentObsTensor] = list() agent_history_extent: List[Tensor] = list() agent_history_len: Tensor = torch.zeros((batch_size,), dtype=torch.long) - agent_future: List[Tensor] = list() + agent_future: List[AgentObsTensor] = list() agent_future_extent: List[Tensor] = list() agent_future_len: Tensor = torch.zeros((batch_size,), dtype=torch.long) @@ -290,9 +321,9 @@ def agent_collate_fn( max_num_neighbors: int = num_neighbors_t.max().item() neighbor_types: List[Tensor] = list() - neighbor_histories: List[Tensor] = list() + neighbor_histories: List[AgentObsTensor] = list() neighbor_history_extents: List[Tensor] = list() - neighbor_futures: List[Tensor] = list() + neighbor_futures: List[AgentObsTensor] = list() neighbor_future_extents: List[Tensor] = list() # Doing this one up here so that I can use it later in the loop.
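The `CustomCollateData` hook defined above lets `extras` values opt out of the default `torch.as_tensor(np.stack(...))` collation performed by `_collate_data`. A minimal sketch of a user-side extras type that plugs into this protocol (both classes here are hypothetical illustrations, not part of trajdata):

```python
import numpy as np
import torch


class RaggedTensorList(list):
    """Hypothetical collated container: a list of (N_i, 2) tensors."""

    def __to__(self, device, non_blocking=False):
        # Allows the batch's .to(...) to move the ragged payload between devices.
        return RaggedTensorList(
            t.to(device, non_blocking=non_blocking) for t in self
        )


class RaggedExtra:
    """Hypothetical per-element extras value wrapping a variable-length array."""

    def __init__(self, points: np.ndarray) -> None:
        self.points = points  # shape (N_i, 2), with N_i varying per element

    @staticmethod
    def __collate__(elements: list) -> RaggedTensorList:
        # _collate_data(elems) sees that elems[0] has __collate__ and delegates
        # here instead of calling torch.as_tensor(np.stack(elems)).
        return RaggedTensorList(
            torch.as_tensor(e.points, dtype=torch.float) for e in elements
        )
```

An `extras` entry returning `RaggedExtra` values would then survive collation even when each batch element contributes a different number of points.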
@@ -323,12 +354,13 @@ def agent_collate_fn( neighbor_future_lens_t: Tensor = torch.full((batch_size, 0), np.nan) max_neigh_future_len: int = 0 - robot_future: List[Tensor] = list() + robot_future: List[AgentObsTensor] = list() robot_future_len: Tensor = torch.zeros((batch_size,), dtype=torch.long) elem: AgentBatchElement for idx, elem in enumerate(batch_elems): data_index_t[idx] = elem.data_index + scene_ts_t[idx] = elem.scene_ts dt_t[idx] = elem.dt agent_names.append(elem.agent_name) agent_type_t[idx] = elem.agent_type.value @@ -458,14 +490,18 @@ def agent_collate_fn( # agent history state dimension (presumably they'll be the same # since they're obtained from the same cached data source). neighbor_histories.append( - torch.full((0, elem.agent_history_np.shape[-1]), np.nan) + torch.full( + (0, elem.agent_history_np.shape[-1]), np.nan, dtype=torch.float + ) ) neighbor_history_extents.append( torch.full((0, elem.agent_history_extent_np.shape[-1]), np.nan) ) neighbor_futures.append( - torch.full((0, elem.agent_future_np.shape[-1]), np.nan) + torch.full( + (0, elem.agent_future_np.shape[-1]), np.nan, dtype=torch.float + ) ) neighbor_future_extents.append( torch.full((0, elem.agent_future_extent_np.shape[-1]), np.nan) @@ -477,15 +513,17 @@ def agent_collate_fn( ) robot_future_len[idx] = elem.robot_future_len - curr_agent_state_t: Tensor = torch.stack(curr_agent_state) + curr_agent_state_t: AgentStateTensor = torch.stack(curr_agent_state).as_subclass( + AgentStateTensor + ) - agent_history_t: Tensor = arr_utils.pad_with_dir( + agent_history_t: AgentObsTensor = arr_utils.pad_with_dir( agent_history, time_dim=-2, pad_dir=history_pad_dir, batch_first=True, padding_value=np.nan, - ) + ).as_subclass(AgentObsTensor) agent_history_extent_t: Tensor = arr_utils.pad_with_dir( agent_history_extent, time_dim=-2, @@ -494,9 +532,9 @@ def agent_collate_fn( padding_value=np.nan, ) - agent_future_t: Tensor = pad_sequence( + agent_future_t: AgentObsTensor = pad_sequence( agent_future, batch_first=True, padding_value=np.nan - ) + ).as_subclass(AgentObsTensor) agent_future_extent_t: Tensor = pad_sequence( agent_future_extent, batch_first=True, padding_value=np.nan ) @@ -513,7 +551,7 @@ def agent_collate_fn( if history_pad_dir == arr_utils.PadDirection.BEFORE else (0, 0, 0, to_add), value=np.nan, - ) + ).as_subclass(AgentObsTensor) if agent_history_extent_t.shape[-2] < hist_len: to_add: int = hist_len - agent_history_extent_t.shape[-2] @@ -532,7 +570,7 @@ def agent_collate_fn( agent_future_t, (0, 0, 0, fut_len - agent_future_t.shape[-2]), value=np.nan, - ) + ).as_subclass(AgentObsTensor) if agent_future_extent_t.shape[-2] < fut_len: agent_future_extent_t = F.pad( @@ -548,15 +586,17 @@ def agent_collate_fn( neighbor_types, batch_first=True, padding_value=-1 ) - neighbor_histories_t: Tensor = pad_sequence( - neighbor_histories, batch_first=True, padding_value=np.nan - ).reshape( - ( - batch_size, - max_num_neighbors, - max_neigh_history_len, - agent_history_t.shape[-1], + neighbor_histories_t: AgentObsTensor = ( + pad_sequence(neighbor_histories, batch_first=True, padding_value=np.nan) + .reshape( + ( + batch_size, + max_num_neighbors, + max_neigh_history_len, + agent_history_t.shape[-1], + ) ) + .as_subclass(AgentObsTensor) ) neighbor_history_extents_t: Tensor = pad_sequence( neighbor_history_extents, batch_first=True, padding_value=np.nan @@ -569,15 +609,17 @@ def agent_collate_fn( ) ) - neighbor_futures_t: Tensor = pad_sequence( - neighbor_futures, batch_first=True, padding_value=np.nan - ).reshape( - ( - 
batch_size, - max_num_neighbors, - max_neigh_future_len, - agent_future_t.shape[-1], + neighbor_futures_t: AgentObsTensor = ( + pad_sequence(neighbor_futures, batch_first=True, padding_value=np.nan) + .reshape( + ( + batch_size, + max_num_neighbors, + max_neigh_future_len, + agent_future_t.shape[-1], + ) ) + .as_subclass(AgentObsTensor) ) neighbor_future_extents_t: Tensor = pad_sequence( neighbor_future_extents, batch_first=True, padding_value=np.nan @@ -592,33 +634,44 @@ def agent_collate_fn( else: neighbor_types_t: Tensor = torch.full((batch_size, 0), np.nan) - neighbor_histories_t: Tensor = torch.full( - (batch_size, 0, max_neigh_history_len, agent_history_t.shape[-1]), np.nan - ) + neighbor_histories_t: AgentObsTensor = torch.full( + (batch_size, 0, max_neigh_history_len, agent_history_t.shape[-1]), + np.nan, + dtype=torch.float, + ).as_subclass(AgentObsTensor) neighbor_history_extents_t: Tensor = torch.full( (batch_size, 0, max_neigh_history_len, agent_history_extent_t.shape[-1]), np.nan, ) - neighbor_futures_t: Tensor = torch.full( - (batch_size, 0, max_neigh_future_len, agent_future_t.shape[-1]), np.nan - ) + neighbor_futures_t: AgentObsTensor = torch.full( + (batch_size, 0, max_neigh_future_len, agent_future_t.shape[-1]), + np.nan, + dtype=torch.float, + ).as_subclass(AgentObsTensor) neighbor_future_extents_t: Tensor = torch.full( (batch_size, 0, max_neigh_future_len, agent_future_extent_t.shape[-1]), np.nan, ) - robot_future_t: Optional[Tensor] = ( - pad_sequence(robot_future, batch_first=True, padding_value=np.nan) + robot_future_t: Optional[AgentObsTensor] = ( + pad_sequence(robot_future, batch_first=True, padding_value=np.nan).as_subclass( + AgentObsTensor + ) if robot_future else None ) ( + map_names, map_patches, maps_resolution, rasters_from_world_tf, - ) = map_collate_fn_agent(batch_elems) + ) = raster_map_collate_fn_agent(batch_elems) + + vector_maps: Optional[List[VectorMap]] = None + if batch_elems[0].vec_map is not None: + vector_maps = [batch_elem.vec_map for batch_elem in batch_elems] agents_from_world_tf = torch.as_tensor( np.stack([batch_elem.agent_from_world_tf for batch_elem in batch_elems]), @@ -629,12 +682,13 @@ def agent_collate_fn( extras: Dict[str, Tensor] = {} for key in batch_elems[0].extras.keys(): - extras[key] = torch.as_tensor( - np.stack([batch_elem.extras[key] for batch_elem in batch_elems]) + extras[key] = _collate_data( + [batch_elem.extras[key] for batch_elem in batch_elems] ) batch = AgentBatch( data_idx=data_index_t, + scene_ts=scene_ts_t, dt=dt_t, agent_name=agent_names, agent_type=agent_type_t, @@ -655,8 +709,10 @@ def agent_collate_fn( neigh_fut_len=neighbor_future_lens_t, robot_fut=robot_future_t, robot_fut_len=robot_future_len, + map_names=map_names, maps=map_patches, maps_resolution=maps_resolution, + vector_maps=vector_maps, rasters_from_world_tf=rasters_from_world_tf, agents_from_world_tf=agents_from_world_tf, scene_ids=scene_ids, @@ -733,20 +789,27 @@ def scene_collate_fn( ) data_index_t: Tensor = torch.zeros((batch_size,), dtype=torch.int) + scene_ts_t: Tensor = torch.zeros((batch_size,), dtype=torch.int) dt_t: Tensor = torch.zeros((batch_size,), dtype=torch.float) + # get agent state and obs format from first item in list + state_format = batch_elems[0].centered_agent_state_np._format + obs_format = batch_elems[0].cache.obs_type._format + AgentStateTensor = TORCH_STATE_TYPES[state_format] + AgentObsTensor = TORCH_STATE_TYPES[obs_format] + max_agent_num: int = max(elem.num_agents for elem in batch_elems) - centered_agent_state: 
List[Tensor] = list() + centered_agent_state: List[AgentStateTensor] = list() agents_types: List[Tensor] = list() - agents_histories: List[Tensor] = list() + agents_histories: List[AgentObsTensor] = list() agents_history_extents: List[Tensor] = list() agents_history_len: Tensor = torch.zeros( (batch_size, max_agent_num), dtype=torch.long ) agents_futures: List[Tensor] = list() - agents_future_extents: List[Tensor] = list() + agents_future_extents: List[AgentObsTensor] = list() agents_future_len: Tensor = torch.zeros( (batch_size, max_agent_num), dtype=torch.long ) @@ -757,11 +820,12 @@ def scene_collate_fn( max_history_len: int = max(elem.agent_history_lens_np.max() for elem in batch_elems) max_future_len: int = max(elem.agent_future_lens_np.max() for elem in batch_elems) - robot_future: List[Tensor] = list() + robot_future: List[AgentObsTensor] = list() robot_future_len: Tensor = torch.zeros((batch_size,), dtype=torch.long) for idx, elem in enumerate(batch_elems): data_index_t[idx] = elem.data_index + scene_ts_t[idx] = elem.scene_ts dt_t[idx] = elem.dt centered_agent_state.append(elem.centered_agent_state_np) agents_types.append(elem.agent_types_np) @@ -856,24 +920,36 @@ def scene_collate_fn( agents_histories_t = split_pad_crop( agents_histories, num_agents, np.nan, max_agent_num - ) + ).as_subclass(AgentObsTensor) agents_history_extents_t = split_pad_crop( agents_history_extents, num_agents, np.nan, max_agent_num ) - agents_futures_t = split_pad_crop(agents_futures, num_agents, np.nan, max_agent_num) + agents_futures_t = split_pad_crop( + agents_futures, num_agents, np.nan, max_agent_num + ).as_subclass(AgentObsTensor) agents_future_extents_t = split_pad_crop( agents_future_extents, num_agents, np.nan, max_agent_num ) - centered_agent_state_t = torch.tensor(np.stack(centered_agent_state)) + centered_agent_state_t = torch.as_tensor( + np.stack(centered_agent_state), dtype=torch.float + ).as_subclass(AgentStateTensor) agents_types_t = torch.as_tensor(np.concatenate(agents_types)) agents_types_t = split_pad_crop( agents_types_t, num_agents, pad_value=-1, desired_size=max_agent_num ) - map_patches, maps_resolution, rasters_from_world_tf = map_collate_fn_scene( - batch_elems, max_agent_num - ) + ( + map_names, + map_patches, + maps_resolution, + rasters_from_world_tf, + ) = raster_map_collate_fn_scene(batch_elems, max_agent_num) + + vector_maps: Optional[List[VectorMap]] = None + if batch_elems[0].vec_map is not None: + vector_maps = [batch_elem.vec_map for batch_elem in batch_elems] + centered_agent_from_world_tf = torch.as_tensor( np.stack( [batch_elem.centered_agent_from_world_tf for batch_elem in batch_elems] @@ -888,25 +964,31 @@ def scene_collate_fn( ) robot_future_t: Optional[Tensor] = ( - pad_sequence(robot_future, batch_first=True, padding_value=np.nan) + pad_sequence(robot_future, batch_first=True, padding_value=np.nan).as_subclass( + AgentObsTensor + ) if robot_future else None ) + agent_names = [batch_elem.agent_names for batch_elem in batch_elems] + scene_ids = [batch_elem.scene_id for batch_elem in batch_elems] extras: Dict[str, Tensor] = {} for key in batch_elems[0].extras.keys(): - extras[key] = torch.as_tensor( - np.stack([batch_elem.extras[key] for batch_elem in batch_elems]) + extras[key] = _collate_data( + [batch_elem.extras[key] for batch_elem in batch_elems] ) batch = SceneBatch( data_idx=data_index_t, + scene_ts=scene_ts_t, dt=dt_t, num_agents=num_agents_t, agent_type=agents_types_t, centered_agent_state=centered_agent_state_t, + agent_names=agent_names, 
agent_hist=agents_histories_t, agent_hist_extent=agents_history_extents_t, agent_hist_len=agents_history_len, @@ -915,8 +997,10 @@ def scene_collate_fn( agent_fut_len=agents_future_len, robot_fut=robot_future_t, robot_fut_len=robot_future_len, + map_names=map_names, maps=map_patches, maps_resolution=maps_resolution, + vector_maps=vector_maps, rasters_from_world_tf=rasters_from_world_tf, centered_agent_from_world_tf=centered_agent_from_world_tf, centered_world_from_agent_tf=centered_world_from_agent_tf, diff --git a/src/trajdata/data_structures/environment.py b/src/trajdata/data_structures/environment.py index e167865..a33ef13 100644 --- a/src/trajdata/data_structures/environment.py +++ b/src/trajdata/data_structures/environment.py @@ -1,6 +1,6 @@ import itertools from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple from trajdata.data_structures.scene_tag import SceneTag @@ -13,10 +13,12 @@ def __init__( dt: float, parts: List[Tuple[str]], scene_split_map: Dict[str, str], + map_locations: Optional[Tuple[str]] = None, ) -> None: self.name = name self.data_dir = Path(data_dir).expanduser().resolve() self.dt = dt + self.map_locations = map_locations self.parts = parts self.scene_tags: List[SceneTag] = [ SceneTag(tag_tuple) diff --git a/src/trajdata/data_structures/map.py b/src/trajdata/data_structures/map.py deleted file mode 100644 index 7ab1172..0000000 --- a/src/trajdata/data_structures/map.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import List, Optional, Tuple - -import numpy as np -import torch -from torch import Tensor - - -class MapMetadata: - def __init__( - self, - name: str, - shape: Tuple[int, int], - layers: List[str], - layer_rgb_groups: Tuple[List[int], List[int], List[int]], - resolution: float, # px/m - map_from_world: np.ndarray, # Transformation from world coordinates [m] to map coordinates [px] - ) -> None: - self.name: str = name - self.shape: Tuple[int, int] = shape - self.layers: List[str] = layers - self.layer_rgb_groups: Tuple[List[int], List[int], List[int]] = layer_rgb_groups - self.resolution: float = resolution - self.map_from_world: np.ndarray = map_from_world - - -class Map: - def __init__( - self, - metadata: MapMetadata, - data: np.ndarray, - ) -> None: - assert data.shape == metadata.shape - self.metadata: MapMetadata = metadata - self.data: np.ndarray = data - - @property - def shape(self) -> Tuple[int, ...]: - return self.data.shape - - @staticmethod - def to_img( - map_arr: Tensor, - idx_groups: Optional[Tuple[List[int], List[int], List[int]]] = None, - ) -> Tensor: - if idx_groups is None: - return map_arr.permute(1, 2, 0).numpy() - - return torch.stack( - [ - torch.amax(map_arr[idx_groups[0]], dim=0), - torch.amax(map_arr[idx_groups[1]], dim=0), - torch.amax(map_arr[idx_groups[2]], dim=0), - ], - dim=-1, - ).numpy() diff --git a/src/trajdata/data_structures/map_patch.py b/src/trajdata/data_structures/map_patch.py deleted file mode 100644 index 371ae1c..0000000 --- a/src/trajdata/data_structures/map_patch.py +++ /dev/null @@ -1,17 +0,0 @@ -import numpy as np - - -class MapPatch: - def __init__( - self, - data: np.ndarray, - rot_angle: float, - crop_size: int, - resolution: float, - raster_from_world_tf: np.ndarray, - ) -> None: - self.data = data - self.rot_angle = rot_angle - self.crop_size = crop_size - self.resolution = resolution - self.raster_from_world_tf = raster_from_world_tf diff --git a/src/trajdata/data_structures/scene.py b/src/trajdata/data_structures/scene.py index f2251e9..2011a4e 
100644 --- a/src/trajdata/data_structures/scene.py +++ b/src/trajdata/data_structures/scene.py @@ -7,6 +7,7 @@ from trajdata.caching import SceneCache from trajdata.data_structures.agent import Agent, AgentMetadata, AgentType from trajdata.data_structures.scene_metadata import Scene +from trajdata.data_structures.state import StateArray class SceneTime: @@ -41,9 +42,9 @@ def from_cache( return cls(scene, scene_ts, filtered_agents, cache) def get_agent_distances_to(self, agent: Agent) -> np.ndarray: - agent_pos: np.ndarray = self.cache.get_state(agent.name, self.ts)[:2] + agent_pos: StateArray = self.cache.get_state(agent.name, self.ts).position nb_pos: np.ndarray = np.stack( - [self.cache.get_state(nb.name, self.ts)[:2] for nb in self.agents] + [self.cache.get_state(nb.name, self.ts).position for nb in self.agents] ) return np.linalg.norm(nb_pos - agent_pos, axis=1) @@ -108,9 +109,9 @@ def from_cache( # @profile def get_agent_distances_to(self, agent_info: AgentMetadata) -> np.ndarray: - agent_pos: np.ndarray = self.cache.get_state(agent_info.name, self.ts)[:2] + agent_pos: StateArray = self.cache.get_state(agent_info.name, self.ts).position - curr_poses: np.ndarray = self.cache.get_states( + curr_poses: StateArray = self.cache.get_states( [a.name for a in self.agents], self.ts - )[:, :2] + ).position return np.linalg.norm(curr_poses - agent_pos, axis=1) diff --git a/src/trajdata/data_structures/scene_tag.py b/src/trajdata/data_structures/scene_tag.py index a046455..6d57a53 100644 --- a/src/trajdata/data_structures/scene_tag.py +++ b/src/trajdata/data_structures/scene_tag.py @@ -1,3 +1,4 @@ +import re from typing import Set, Tuple @@ -8,6 +9,9 @@ def __init__(self, tag_tuple: Tuple[str, ...]) -> None: def contains(self, query: Set[str]) -> bool: return query.issubset(self._tag_tuple) + def matches_any(self, regex: re.Pattern) -> bool: + return any(regex.search(x) is not None for x in self._tag_tuple) + def __contains__(self, item) -> bool: return item in self._tag_tuple diff --git a/src/trajdata/data_structures/state.py b/src/trajdata/data_structures/state.py new file mode 100644 index 0000000..9dcc09b --- /dev/null +++ b/src/trajdata/data_structures/state.py @@ -0,0 +1,466 @@ +""" +Defines subclasses of np.array and torch.Tensor which give +property access to different state elements and allow for easy conversion +between types to help make code that works with state elements more readable +and more robust to future changes in state format (e.g. adding additional dimensions) + +Currently, these subclasses are designed to be lightweight and ephemeral: +any np/torch operation on a State subclass will drop the format metadata. +TODO: we could make this more robust by making exceptions for operations which +preserve the semantic meanings of the elements. 
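To make the interface described above concrete, here is a small usage sketch (the format string "x,y,h" is just an illustrative example; the names come from the definitions below):

```python
import numpy as np
from trajdata.data_structures.state import NP_STATE_TYPES

# One StateArray subclass is created (and memoized) per format string.
StateArrayXYH = NP_STATE_TYPES["x,y,h"]

state = np.array([1.0, 2.0, np.pi / 2]).view(StateArrayXYH)
state.position              # array([1., 2.]): property access by element name
state.heading               # array([1.5708]): slice of the underlying data
state.as_format("x,y,c,s")  # cos/sin computed from "h" via STATE_ELEMS_REQUIREMENTS
```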
+TODO: implement setters for all properties +""" +from abc import abstractclassmethod +from collections import defaultdict +from typing import Callable, ClassVar, Dict, List, Set, Type, TypeVar + +import numpy as np +import torch +from torch import Tensor + +STATE_ELEMS_REQUIREMENTS = { + "x": None, # x position in world frame (m) + "y": None, # y position in world frame (m) + "z": None, # z position in world frame (m) + "xd": ("x_component", "v_lon", "v_lat", "c", "s"), # x vel in world frame (m/s) + "yd": ("y_component", "v_lon", "v_lat", "c", "s"), # y vel in world frame (m/s) + "zd": None, # z velocity in world frame (m/s) + "xdd": None, # x acceleration in world frame (m/s^2) + "ydd": None, # y acceleration in world frame (m/s^2) + "zdd": None, # z acceleration in world frame (m/s^2) + "h": ("arctan", "s", "c"), # heading (rad) + "dh": None, # heading rate (rad/s) + "c": ("cos", "h"), # cos(h) + "s": ("sin", "h"), # sin(h) + "v_lon": ("lon_component", "xd", "yd", "c", "s"), # longitudinal velocity (m/s) + "v_lat": ("lat_component", "xd", "yd", "c", "s"), # lateral velocity (m/s) +} + +# How many levels deep we'll try to check if requirements for certain attributes +# themselves need to be computed and are not directly available +MAX_RECURSION_LEVELS = 2 + + +Array = TypeVar("Array", np.ndarray, torch.Tensor) + + +def lon_component(x, y, c, s): + """ + Returns the component of (x, y) that is parallel + to the unit vector (c, s) + """ + return x * c + y * s + + +def lat_component(x, y, c, s): + """ + Returns the component of (x, y) that is orthogonal to + the unit vector (c, s) (i.e., parallel to (-s, c)) + """ + return -x * s + y * c + + +def x_component(long, lat, c, s): + """ + Returns the x component given the longitudinal and lateral components + and the cos and sin of the heading + """ + return long * c - lat * s + + +def y_component(long, lat, c, s): + """ + Returns the y component given the longitudinal and lateral components + and the cos and sin of the heading + """ + return long * s + lat * c + + +class State: + """ + Base class implementing property access to state elements. + Needs to be subclassed for concrete underlying datatypes (e.g., + torch.Tensor vs np.ndarray) to equip the self object with + indexing support. + """ + + _format: str = "" + + # set upon subclass init + state_dim: int = 0 + + # needs to be defined in subclass + _FUNCS: ClassVar[Dict[str, Callable]] = {} + + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + # use subclass _format string to initialize class specific _format_dict + cls._format_dict: Dict[str, int] = {} + for j, attr in enumerate(cls._format.split(",")): + cls._format_dict[attr] = j + + # initialize properties + cls.position = cls._init_property("x,y") + cls.position3d = cls._init_property("x,y,z") + cls.velocity = cls._init_property("xd,yd") + cls.acceleration = cls._init_property("xdd,ydd") + cls.heading = cls._init_property("h") + cls.heading_vector = cls._init_property("c,s") + + # initialize state_dim + cls.state_dim = len(cls._format_dict) + + @abstractclassmethod + def from_array(cls, array: Array, format: str) -> "State": + """ + Returns a State instance given an Array with the correct format.
+ + Args: + array (Array): the underlying data (last dimension must match the format) + format (str): comma-separated format string, e.g., "x,y,h" + """ + raise NotImplementedError + + @abstractclassmethod + def _combine(cls, arrays: List[Array]): + """ + Concatenates the given arrays along their last dimension + and returns the result. + + Args: + arrays (List[Array]): the arrays to concatenate + """ + raise NotImplementedError + + def as_format(self, new_format: str, create_type=True): + """ + Returns a new State object with the specified format, + constructed using data in the current format + """ + requested_attrs = new_format.split(",") + components = [] # contains either indices into self, or attrs + index_list = None + for j, attr in enumerate(requested_attrs): + if attr in self._format_dict: + if index_list is None: + # start a new block of indices + index_list = [] + components.append(index_list) + index_list.append(self._format_dict[attr]) + else: + if index_list is not None: + # if we had been pulling indices, stop + index_list = None + components.append(attr) + # assemble + arrays = [] + for component in components: + if isinstance(component, list): + arrays.append(self[..., component]) + elif isinstance(component, str): + arrays.append(self._compute_attr(component)[..., None]) + else: + raise ValueError + + result = self._combine(arrays) + if create_type: + return self.from_array(result, new_format) + else: + return result + + def _compute_attr(self, attr: str, _depth: int = MAX_RECURSION_LEVELS): + """ + Tries to compute an attr that isn't directly part of the tensor + given the information available. + + If a requirement for the attr isn't directly part of the tensor + either, then we recurse to compute that attribute. + _depth controls the depth of recursion. + + If impossible, raises ValueError. + """ + if _depth == 0: + raise RecursionError + try: + formula = STATE_ELEMS_REQUIREMENTS[attr] + if formula is None: + raise KeyError(f"No formula for {attr}") + func_name, *requirements = formula + func = self._FUNCS[func_name] + args = [self.get_attr(req, _depth=_depth - 1) for req in requirements] + except KeyError as ke: + raise ValueError( + f"{attr} cannot be computed from available data at the current timestep." + ) + except RecursionError as re: + raise ValueError( + f"{attr} cannot be computed: recursion depth exceeded when trying to compute requirements." + ) + return func(*args) + + def get_attr(self, attr: str, _depth: int = MAX_RECURSION_LEVELS): + """ + Returns the slice of the tensor corresponding to attr. + """ + if attr in self._format_dict: + return self[..., self._format_dict[attr]] + else: + return self._compute_attr(attr, _depth=_depth) + + def set_attr(self, attr: str, val: Tensor): + if attr in self._format_dict: + self[..., self._format_dict[attr]] = val + else: + raise ValueError(f"{attr} is not part of this State's format") + + @classmethod + def _init_property(cls, format: str) -> property: + split_format = format.split(",") + try: + index_list = tuple(cls._format_dict[attr] for attr in split_format) + + def getter(self: State) -> Array: + return self[..., index_list] + + def setter(self: State, val: Array) -> None: + self[..., index_list] = val + + except KeyError: + # getter is nontrivial, let as_format handle the logic + def getter(self: State) -> Array: + return self.as_format(format, create_type=False) + + # can't set this property since not all elements are part of format + setter = None + + return property( + getter, + setter, + doc=f""" + Returns: + Array: shape [..., {len(split_format)}] corresponding to {split_format}.
+ """, + ) + + +class StateArray(State, np.ndarray): + _FUNCS = { + "cos": np.cos, + "sin": np.sin, + "arctan": np.arctan2, + "lon_component": lon_component, + "lat_component": lat_component, + "x_component": x_component, + "y_component": y_component, + } + + def __str__(self) -> str: + return f"{self.__class__.__name__}({super().__str__()})" + + def __array_ufunc__(self, function, method, *inputs, **kwargs): + args = [] + for i, input_ in enumerate(inputs): + if isinstance(input_, type(self)): + args.append(input_.view(np.ndarray)) + else: + args.append(input_) + + outputs = kwargs.get("out", None) + if outputs: + out_args = [] + for j, output in enumerate(outputs): + if isinstance(output, type(self)): + out_args.append(output.view(np.ndarray)) + else: + out_args.append(output) + kwargs["out"] = tuple(out_args) + else: + outputs = (None,) * function.nout + + # call original function + results = super().__array_ufunc__(function, method, *args, **kwargs) + + return results + + def __getitem__(self, key) -> np.ndarray: + """ + StateArray[key] always returns an np.ndarray, as we can't + be sure that key isn't indexing into the state elements + without adding logic in python which adds significant overhead + to the base numpy implementation which is in C. + + In cases where we just want to index batch dimensions, use + StateArray.at(key). We add logic for slice indexing + """ + return_type = np.ndarray + if isinstance(key, (int, slice)) and self.ndim > 1: + return_type = type(self) + return super().__getitem__(key).view(return_type) + + def at(self, key) -> "StateArray": + """ + Equivalent to self[key], but assumes (without checking!) + that key selects only batch dimensions, so return type + is the same as type(self) + """ + return super().__getitem__(key) + + def as_ndarray(self) -> np.ndarray: + """Convenience function to convert to default ndarray type + Applying np operations to StateArrays can silently convert them + to basic np.ndarrays, so making this conversion explicit + can improve code readability. 
+ + Returns: + np.ndarray: pointing to same data as self + """ + return self.view(np.ndarray) + + @classmethod + def from_array(cls, array: Array, format: str): + return array.view(NP_STATE_TYPES[format]) + + @classmethod + def _combine(cls, arrays: List[Array]): + """ + Concatenates arrays along last dimension, and returns result + according to new format string + """ + return np.concatenate(arrays, axis=-1) + + +class StateTensor(State, Tensor): + """ + Convenience class which offers property access to state elements + Standardizes order of state dimensions + """ + + _FUNCS = { + "cos": torch.cos, + "sin": torch.sin, + "arctan": torch.atan2, + "lon_component": lon_component, + "lat_component": lat_component, + "x_component": x_component, + "y_component": y_component, + } + + CAPTURED_FUNCS: Set[Callable] = { + Tensor.cpu, + Tensor.cuda, + Tensor.add, + Tensor.add_, + Tensor.__deepcopy__, + } + + @classmethod + def new_empty(cls, *args, **kwargs): + return torch.empty(*args, **kwargs).as_subclass(cls) + + def clone(self, *args, **kwargs): + return super().clone(*args, **kwargs).as_subclass(type(self)) + + def to(self, *args, **kwargs): + new_obj = self.__class__() + tempTensor = super().to(*args, **kwargs) + new_obj.data = tempTensor.data + new_obj.requires_grad = tempTensor.requires_grad + return new_obj + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + # overriding this to ensure operations yield base Tensors + if kwargs is None: + kwargs = {} + + new_class = Tensor + result = super().__torch_function__(func, types, args, kwargs) + + if func in StateTensor.CAPTURED_FUNCS: + new_class = cls + + if func == Tensor.__getitem__: + self = args[0] + indices = args[1] + if isinstance(indices, int): + if self.ndim > 1: + new_class = cls + elif isinstance(indices, slice): + if self.ndim > 1: + new_class = cls + elif indices == slice(None): + new_class = cls + elif isinstance(indices, tuple): + if len(indices) < self.ndim: + new_class = cls + elif len(indices) == self.ndim and indices[-1] == slice(None): + new_class = cls + + if isinstance(result, Tensor) and new_class != cls: + result = result.as_subclass(new_class) + + if func == Tensor.numpy: + result: np.ndarray + result = result.view(NP_STATE_TYPES[cls._format]) + + return result + + def as_tensor(self) -> Tensor: + """Convenience function to convert to default tensor type + Applying torch operations to StateTensors can silently convert them + to basic torch.Tensors, so making this conversion explicit + can improve code readability. + + Returns: + Tensor: pointing to same data as self + """ + return self.as_subclass(Tensor) + + @classmethod + def from_numpy(cls, state: StateArray, **kwargs): + return torch.from_numpy(state, **kwargs).as_subclass( + TORCH_STATE_TYPES[state._format] + ) + + @classmethod + def from_array(cls, array: Array, format: str): + return array.as_subclass(TORCH_STATE_TYPES[format]) + + @classmethod + def _combine(cls, arrays: List[Array]): + """ + Concatenates arrays along last dimension, and returns result + according to new format string + """ + return torch.cat(arrays, dim=-1) + + +def createStateType(format: str, base: Type[State]) -> Type[State]: + name = base.__name__ + "".join(map(str.capitalize, format.split(","))) + cls = type( + name, + (base,), + { + "_format": format, + }, + ) + # This is needed so that these dynamically created classes are understood + # by pickle, which is used in multiprocessing, e.g. 
in a dataset + globals()[name] = cls + return cls + + +class StateTypeFactory(defaultdict): + def __init__(self, base: Type[State]): + self.base_type = base + + def __missing__(self, format: str) -> Type[State]: + self[format] = createStateType(format, self.base_type) + return self[format] + + +# DEFINE STATE TYPES +TORCH_STATE_TYPES: Dict[str, Type[StateTensor]] = StateTypeFactory(StateTensor) +NP_STATE_TYPES: Dict[str, Type[StateArray]] = StateTypeFactory(StateArray) diff --git a/src/trajdata/dataset.py b/src/trajdata/dataset.py index 5979b0b..c99cedb 100644 --- a/src/trajdata/dataset.py +++ b/src/trajdata/dataset.py @@ -1,18 +1,28 @@ import gc +import json +import random +import re +import time +import warnings from collections import defaultdict from functools import partial from itertools import chain +from os.path import isfile from pathlib import Path -from typing import Any, Callable, Dict, Final, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union +import dill import numpy as np +from torch import distributed from torch.utils.data import DataLoader, Dataset from tqdm import tqdm from trajdata import filtering from trajdata.augmentation.augmentation import Augmentation, BatchAugmentation -from trajdata.caching import DataFrameCache, EnvCache, SceneCache +from trajdata.caching import EnvCache, SceneCache, df_cache from trajdata.data_structures import ( + NP_STATE_TYPES, + TORCH_STATE_TYPES, AgentBatchElement, AgentDataIndex, AgentMetadata, @@ -29,15 +39,17 @@ scene_collate_fn, ) from trajdata.dataset_specific import RawDataset -from trajdata.parallel import ( - ParallelDatasetPreprocessor, - parallel_iapply, - scene_paths_collate_fn, +from trajdata.maps.map_api import MapAPI +from trajdata.parallel import ParallelDatasetPreprocessor, scene_paths_collate_fn +from trajdata.utils import ( + agent_utils, + env_utils, + py_utils, + raster_utils, + scene_utils, + string_utils, ) -from trajdata.utils import agent_utils, env_utils, scene_utils, string_utils - -# TODO(bivanovic): Move this to a better place in the codebase. 
-DEFAULT_PX_PER_M: Final[float] = 2.0 +from trajdata.utils.parallel_utils import parallel_iapply class UnifiedDataset(Dataset): @@ -60,11 +72,16 @@ def __init__( Tuple[AgentType, AgentType], float ] = defaultdict(lambda: np.inf), incl_robot_future: bool = False, - incl_map: bool = False, - map_params: Optional[Dict[str, Any]] = None, + incl_raster_map: bool = False, + raster_map_params: Optional[Dict[str, Any]] = None, + incl_vector_map: bool = False, + vector_map_params: Optional[Dict[str, Any]] = None, + require_map_cache: bool = True, only_types: Optional[List[AgentType]] = None, only_predict: Optional[List[AgentType]] = None, no_types: Optional[List[AgentType]] = None, + state_format: str = "x,y,xd,yd,xdd,ydd,h", + obs_format: str = "x,y,xd,yd,xdd,ydd,s,c", standardize_data: bool = True, standardize_derivatives: bool = False, augmentations: Optional[List[Augmentation]] = None, @@ -72,26 +89,32 @@ def __init__( max_neighbor_num: Optional[int] = None, ego_only: Optional[bool] = False, data_dirs: Dict[str, str] = { - # "nusc_trainval": "~/datasets/nuScenes", - # "nusc_test": "~/datasets/nuScenes", "eupeds_eth": "~/datasets/eth_ucy_peds", "eupeds_hotel": "~/datasets/eth_ucy_peds", "eupeds_univ": "~/datasets/eth_ucy_peds", "eupeds_zara1": "~/datasets/eth_ucy_peds", "eupeds_zara2": "~/datasets/eth_ucy_peds", "nusc_mini": "~/datasets/nuScenes", + # "nusc_trainval": "~/datasets/nuScenes", + # "nusc_test": "~/datasets/nuScenes", "lyft_sample": "~/datasets/lyft/scenes/sample.zarr", # "lyft_train": "~/datasets/lyft/scenes/train.zarr", # "lyft_train_full": "~/datasets/lyft/scenes/train_full.zarr", # "lyft_val": "~/datasets/lyft/scenes/validate.zarr", + # "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", }, cache_type: str = "dataframe", cache_location: str = "~/.unified_data_cache", rebuild_cache: bool = False, rebuild_maps: bool = False, + save_index: bool = False, num_workers: int = 0, verbose: bool = False, extras: Dict[str, Callable[..., np.ndarray]] = dict(), + transforms: Iterable[ + Callable[..., Union[AgentBatchElement, SceneBatchElement]] + ] = (), + rank: int = 0, ) -> None: """Instantiates a PyTorch Dataset object which aggregates data from multiple trajectory forecasting datasets. @@ -105,11 +128,16 @@ def __init__( future_sec (Tuple[Optional[float], Optional[float]], optional): A tuple containing (the minimum seconds of future data each batch element must contain, the maximum seconds of future data to return). Both inclusive. Defaults to ( None, None, ). agent_interaction_distances: (Dict[Tuple[AgentType, AgentType], float]): A dictionary mapping agent-agent interaction distances in meters (determines which agents are included as neighbors to the predicted agent). Defaults to infinity for all types. incl_robot_future (bool, optional): Include the ego agent's future trajectory in batches (accordingly, never predict the ego's future). Defaults to False. - incl_map (bool, optional): Include a local cropping of the rasterized map (if the dataset provides a map) per agent. Defaults to False. - map_params (Optional[Dict[str, Any]], optional): Local map cropping parameters, must be specified if incl_map is True. Must contain keys {"px_per_m", "map_size_px"} and can optionally contain {"offset_frac_xy"}. Defaults to None. + incl_raster_map (bool, optional): Include a local cropping of the rasterized map (if the dataset provides a map) per agent. Defaults to False. 
+ raster_map_params (Optional[Dict[str, Any]], optional): Local map cropping parameters, must be specified if incl_raster_map is True. Must contain keys {"px_per_m", "map_size_px"} and can optionally contain {"offset_frac_xy"}. Defaults to None. + incl_vector_map (bool, optional): Include information about the scene's vector map (e.g., for use in nearest lane queries as an `extras` batch element function). Defaults to False. + vector_map_params (Optional[Dict[str, Any]], optional): Vector map loading parameters. Defaults to None (by default only road lanes will be loaded as part of the map). + require_map_cache (bool, optional): Cache map objects (if the dataset provides a map) regardless of the value of incl_raster_map. Defaults to True. only_types (Optional[List[AgentType]], optional): Filter out all agents EXCEPT for those of the specified types. Defaults to None. only_predict (Optional[List[AgentType]], optional): Only predict the specified types of agents. Importantly, this keeps other agent types in the scene, e.g., as neighbors of the agent to be predicted. Defaults to None. no_types (Optional[List[AgentType]], optional): Filter out all agents with the specified types. Defaults to None. + state_format (str, optional): Ordered, comma-separated list of elements to return for the current/centered agent state. Defaults to "x,y,xd,yd,xdd,ydd,h". + obs_format (str, optional): Ordered, comma-separated list of elements to return for history and future agent state arrays. Defaults to "x,y,xd,yd,xdd,ydd,s,c". standardize_data (bool, optional): Standardize all data such that (1) the predicted agent's orientation at the current timestep is 0, (2) all data is made relative to the predicted agent's current position, and (3) the agent's heading value is replaced with its sin, cos values. Defaults to True. standardize_derivatives (bool, optional): Make agent velocities and accelerations relative to the agent being predicted. Defaults to False. augmentations (Optional[List[Augmentation]], optional): Perform the specified augmentations to the batch or dataset. Defaults to None. @@ -121,57 +149,106 @@ def __init__( cache_location (str, optional): Where to store and load preprocessed, cached data. Defaults to "~/.unified_data_cache". rebuild_cache (bool, optional): If True, process and cache trajectory data even if it is already cached. Defaults to False. rebuild_maps (bool, optional): If True, process and cache maps even if they are already cached. Defaults to False. + save_index (bool, optional): If True, save the resulting agent (or scene) data index after it is computed (speeding up subsequent initializations with the same argument values). Defaults to False. num_workers (int, optional): Number of parallel workers to use for dataset preprocessing and loading. Defaults to 0. verbose (bool, optional): If True, print internal data loading information. Defaults to False. extras (Dict[str, Callable[..., np.ndarray]], optional): Adds extra data to each batch element. Each Callable must take as input a filled {Agent,Scene}BatchElement and return an ndarray which will subsequently be added to the batch element's `extras` dict. + transforms (Iterable[Callable], optional): Allows for custom modifications of batch elements. Each Callable must take in a filled {Agent,Scene}BatchElement and return a {Agent,Scene}BatchElement. + rank (int, optional): Process rank when using torch DistributedDataParallel for multi-GPU training. Only the rank 0 process will be used for caching.
""" + self.desired_data: List[str] = desired_data + self.scene_description_contains: Optional[ + List[str] + ] = scene_description_contains self.centric: str = centric self.desired_dt: float = desired_dt if cache_type == "dataframe": - self.cache_class = DataFrameCache + self.cache_class = df_cache.DataFrameCache self.rebuild_cache: bool = rebuild_cache self.cache_path: Path = Path(cache_location).expanduser().resolve() self.cache_path.mkdir(parents=True, exist_ok=True) self.env_cache: EnvCache = EnvCache(self.cache_path) - if incl_map: + if incl_raster_map: assert ( - map_params is not None + raster_map_params is not None ), r"Path size information, i.e., {'px_per_m': ..., 'map_size_px': ...}, must be provided if incl_map=True" assert ( - map_params["map_size_px"] % 2 == 0 + raster_map_params["map_size_px"] % 2 == 0 ), "Patch parameter 'map_size_px' must be divisible by 2" + require_map_cache = require_map_cache or incl_raster_map + self.history_sec = history_sec self.future_sec = future_sec self.agent_interaction_distances = agent_interaction_distances self.incl_robot_future = incl_robot_future - self.incl_map = incl_map - self.map_params = ( - map_params if map_params is not None else {"px_per_m": DEFAULT_PX_PER_M} + + self.incl_raster_map = incl_raster_map + self.raster_map_params = ( + raster_map_params + if raster_map_params is not None + # Allowing for parallel map processing in case the user specifies num_workers. + else {"px_per_m": raster_utils.DEFAULT_PX_PER_M, "num_workers": num_workers} + ) + + self.incl_vector_map = incl_vector_map + self.vector_map_params = ( + vector_map_params + if vector_map_params is not None + else { + "incl_road_lanes": True, + "incl_road_areas": False, + "incl_ped_crosswalks": False, + "incl_ped_walkways": False, + # Collation can be quite slow if vector maps are included, + # so we do not unless the user requests it. + "collate": False, + # Whether loaded maps should be stored in memory (memoized) for later re-use. + # For datasets which provide full maps ahead-of-time (i.e., all except Waymo), + # this should be True. However, for Waymo it should be False because maps + # are already partitioned geographically and keeping them around significantly grows memory. 
+ "keep_in_memory": True, + } ) + if self.desired_dt is not None: + self.vector_map_params["desired_dt"] = desired_dt + self.only_types = None if only_types is None else set(only_types) self.only_predict = None if only_predict is None else set(only_predict) self.no_types = None if no_types is None else set(no_types) + self.state_format = state_format + self.obs_format = obs_format self.standardize_data = standardize_data self.standardize_derivatives = standardize_derivatives self.augmentations = augmentations self.extras = extras + self.transforms = transforms self.verbose = verbose self.max_agent_num = max_agent_num + self.rank = rank self.max_neighbor_num = max_neighbor_num self.ego_only = ego_only + # Create requested state types now so pickling works + # (Needed for multiprocess dataloading) + self.np_state_type = NP_STATE_TYPES[state_format] + self.np_obs_type = NP_STATE_TYPES[obs_format] + self.torch_state_type = TORCH_STATE_TYPES[state_format] + self.torch_obs_type = TORCH_STATE_TYPES[obs_format] + # Ensuring scene description queries are all lowercase - if scene_description_contains is not None: - scene_description_contains = [s.lower() for s in scene_description_contains] + if self.scene_description_contains is not None: + self.scene_description_contains = [ + s.lower() for s in self.scene_description_contains + ] self.envs: List[RawDataset] = env_utils.get_raw_datasets(data_dirs) self.envs_dict: Dict[str, RawDataset] = {env.name: env for env in self.envs} - matching_datasets: List[SceneTag] = self.get_matching_scene_tags(desired_data) + matching_datasets: List[SceneTag] = self._get_matching_scene_tags(desired_data) if self.verbose: print( "Loading data for matched scene tags:", @@ -179,15 +256,23 @@ def __init__( flush=True, ) + self.check_args_combinations(matching_datasets) + + self._map_api: Optional[MapAPI] = None + if self.incl_vector_map: + self._map_api = MapAPI( + self.cache_path, + keep_in_memory=self.vector_map_params.get("keep_in_memory", True), + ) + all_scenes_list: Union[List[SceneMetadata], List[Scene]] = list() for env in self.envs: if any(env.name in dataset_tuple for dataset_tuple in matching_datasets): all_data_cached: bool = False - all_maps_cached: bool = not env.has_maps or not self.incl_map - + all_maps_cached: bool = not env.has_maps or not require_map_cache if self.env_cache.env_is_cached(env.name) and not self.rebuild_cache: - scenes_list: List[Scene] = self.get_desired_scenes_from_env( - matching_datasets, scene_description_contains, env + scenes_list: List[Scene] = self._get_desired_scenes_from_env( + matching_datasets, env ) all_data_cached: bool = all( @@ -199,13 +284,13 @@ def __init__( all_maps_cached: bool = ( not env.has_maps - or not self.incl_map + or not require_map_cache or all( self.cache_class.is_map_cached( self.cache_path, env.name, scene.location, - self.map_params["px_per_m"], + self.raster_map_params["px_per_m"], ) for scene in scenes_list ) @@ -228,20 +313,38 @@ def __init__( self.cache_path, env.name ) ): - env.cache_maps( - self.cache_path, - self.cache_class, - self.map_params, - ) + # Use only rank 0 process for caching when using multi-GPU torch training. + if rank == 0: + env.cache_maps( + self.cache_path, + self.cache_class, + self.raster_map_params, + ) - scenes_list: List[SceneMetadata] = self.get_desired_scenes_from_env( - matching_datasets, scene_description_contains, env - ) + # Wait for rank 0 process to be done with caching. 
+ if ( + distributed.is_initialized() + and distributed.get_world_size() > 1 + ): + distributed.barrier() + + scenes_list: List[ + SceneMetadata + ] = self._get_desired_scenes_from_env(matching_datasets, env) + + if self.incl_vector_map and env.metadata.map_locations is not None: + # env.metadata.map_locations can be None for map-containing + # datasets if they have a huge number of maps + # (or map crops, like Waymo). + for map_name in env.metadata.map_locations: + self._map_api.get_map( + f"{env.name}:{map_name}", **self.vector_map_params + ) all_scenes_list += scenes_list # List of cached scene paths. - scene_paths: List[Path] = self.preprocess_scene_data( + scene_paths: List[Path] = self._preprocess_scene_data( all_scenes_list, num_workers ) if self.verbose: @@ -254,7 +357,11 @@ def __init__( data_index: Union[ List[Tuple[str, int, np.ndarray]], List[Tuple[str, int, List[Tuple[str, np.ndarray]]]], - ] = self.get_data_index(num_workers, scene_paths) + ] + if self._index_cache_path().exists(): + data_index = self._load_data_index() + else: + data_index = self._get_data_index(num_workers, scene_paths) # Done with this list. Cutting memory usage because # of multiprocessing later on. @@ -271,7 +378,282 @@ def __init__( ) self._data_len: int = len(self._data_index) - def get_data_index( + # Use only rank 0 process for caching when using multi-GPU torch training. + if save_index and rank == 0: + if self._index_cache_path().exists(): + print( + "WARNING: Overwriting already-cached data index (since save_index is True).", + flush=True, + ) + + self._cache_data_index(data_index) + + # Wait for rank 0 process to be done with caching. + if distributed.is_initialized() and distributed.get_world_size() > 1: + distributed.barrier() + + self._cached_batch_elements = None + + def check_args_combinations(self, chosen_datasets: List[SceneTag]) -> None: + """Warn users about potential "gotcha" combinations of arguments, + usually involving fundamental limits in datasets. + """ + waymo_warning_given: bool = False + waymo_pattern: re.Pattern = re.compile("waymo") + + nuplan_warning_given: bool = False + nuplan_pattern: re.Pattern = re.compile("nuplan") + + dataset: SceneTag + for dataset in chosen_datasets: + if ( + not waymo_warning_given + and self.vector_map_params["incl_road_areas"] + and dataset.matches_any(waymo_pattern) + ): + warnings.warn( + ( + "\n\n############ WARNING! ############\n" + "Waymo has many gaps in the associations between " + "lane centerlines and boundaries,\nmaking it difficult " + "to construct lane edge polylines or road area polygons.\n" + "The ones currently provided by trajdata should be considered " + "low quality!" + "\n#################################\n" + ) + ) + waymo_warning_given = True + + elif ( + not nuplan_warning_given + and self.incl_vector_map + and dataset.matches_any(nuplan_pattern) + ): + warnings.warn( + ( + "\n\n############ WARNING! ############\n" + "nuPlan uses Shapely to represent its map. " + "Shapely only supports 2D coordinates.\nHowever, " + "nuPlan's agent trajectories are provided in 3D " + "(with a non-zero z-coordinate).\nThus, any " + "spatial queries (e.g., nearest lane or road area) " + "should be executed with the z-coordinate " + "set to 0.0 manually." + "\n#################################\n" + ) + ) + nuplan_warning_given = True + + def _index_cache_path( + self, ret_args: bool = False + ) -> Union[Path, Tuple[Path, Dict[str, Any]]]: + # Whichever UnifiedDataset arguments affect data indexing are captured + # and hashed together here.
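The rank-0-then-barrier idiom above (one process does the expensive write, the rest synchronize) generalizes beyond map and index caching; a minimal sketch, with a hypothetical helper name:
```
import torch.distributed as distributed

def do_on_rank_zero_then_wait(rank: int, expensive_fn) -> None:
    # Hypothetical helper mirroring the pattern above: only rank 0 performs
    # the expensive write; all ranks then synchronize at a barrier.
    if rank == 0:
        expensive_fn()
    if distributed.is_initialized() and distributed.get_world_size() > 1:
        distributed.barrier()
```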
+ impactful_args: Dict[str, Any] = { + "desired_data": tuple(self.desired_data), + "scene_description_contains": tuple(self.scene_description_contains) + if self.scene_description_contains is not None + else None, + "centric": self.centric, + "desired_dt": self.desired_dt, + "history_sec": self.history_sec, + "future_sec": self.future_sec, + "incl_robot_future": self.incl_robot_future, + "only_types": tuple(t.name for t in self.only_types) + if self.only_types is not None + else None, + "only_predict": tuple(t.name for t in self.only_predict) + if self.only_predict is not None + else None, + "no_types": tuple(t.name for t in self.no_types) + if self.no_types is not None + else None, + "ego_only": self.ego_only, + } + index_hash: str = py_utils.hash_dict(impactful_args) + index_cache_path: Path = self.cache_path / "data_indexes" / index_hash + + if ret_args: + return index_cache_path, impactful_args + else: + return index_cache_path + + def _cache_data_index( + self, + data_index: Union[ + List[Tuple[str, int, np.ndarray]], + List[Tuple[str, int, List[Tuple[str, np.ndarray]]]], + ], + ) -> None: + index_cache_dir, index_args = self._index_cache_path(ret_args=True) + + # Create it if it doesn't exist yet. + index_cache_dir.mkdir(parents=True, exist_ok=True) + + index_cache_file: Path = index_cache_dir / "data_index.dill" + with open(index_cache_file, "wb") as f: + dill.dump(data_index, f) + + args_file: Path = index_cache_dir / "index_args.json" + with open(args_file, "w") as f: + json.dump(index_args, f, indent=4) + + print( + f"Cached data index to {str(index_cache_file)}", + flush=True, + ) + + def _load_data_index( + self, + ) -> Union[ + List[Tuple[str, int, np.ndarray]], + List[Tuple[str, int, List[Tuple[str, np.ndarray]]]], + ]: + index_cache_file: Path = self._index_cache_path() / "data_index.dill" + with open(index_cache_file, "rb") as f: + data_index = dill.load(f) + + if self.verbose: + print( + f"Loaded data index from {str(index_cache_file)}", + flush=True, + ) + + return data_index + + def load_or_create_cache( + self, cache_path: str, num_workers=0, filter_fn=None + ) -> None: + if isfile(cache_path): + print(f"Loading cache from {cache_path} ...", end="") + t = time.time() + with open(cache_path, "rb") as f: + self._cached_batch_elements, keep_ids = dill.load(f, encoding="latin1") + print(f" done in {time.time() - t:.1f}s.") + + else: + # Build cache + cached_batch_elements = [] + keep_ids = [] + + if num_workers <= 0: + cache_data_iterator = self + else: + # Use DataLoader as a generic multiprocessing framework. + # We set batch_size=1 and a custom collate function. + # In effect this will just call self.__getitem__ in parallel. + cache_data_iterator = DataLoader( + self, + batch_size=1, + num_workers=num_workers, + shuffle=False, + collate_fn=lambda xlist: xlist[0], + ) + + for element in tqdm( + cache_data_iterator, + desc=f"Caching batch elements ({num_workers} CPUs): ", + disable=False, + ): + if filter_fn is None or filter_fn(element): + cached_batch_elements.append(element) + keep_ids.append(element.data_index) + + # Just deletes the variable cache_data_iterator, + # not self (in case it is set to that)!
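For intuition, the argument hashing could plausibly look like the sketch below; the actual py_utils.hash_dict implementation may differ, but any deterministic serialize-then-digest scheme yields the same caching behavior:
```
import hashlib
import json

def hash_dict(d: dict) -> str:
    # Deterministic serialization (sorted keys), then a stable digest, so
    # equal index-affecting arguments map to the same cache directory.
    return hashlib.sha256(json.dumps(d, sort_keys=True).encode()).hexdigest()
```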
+ del cache_data_iterator + + print(f"Saving cache to {cache_path} ...", end="") + t = time.time() + with open(cache_path, "wb") as f: + dill.dump((cached_batch_elements, keep_ids), f) + print(f" done in {time.time() - t:.1f}s.") + + self._cached_batch_elements = cached_batch_elements + + # Remove unwanted elements + self.remove_elements(keep_ids=keep_ids) + + # Verify + if len(self._cached_batch_elements) != self._data_len: + raise ValueError("Current data and cached data lengths do not match!") + + def apply_filter( + self, + filter_fn: Callable[[Union[AgentBatchElement, SceneBatchElement]], bool], + num_workers: int = 0, + max_count: Optional[int] = None, + all_gather: Optional[Callable] = None, + ) -> None: + keep_ids = [] + keep_count = 0 + + if filter_fn is None: + return + + # Only do filtering with rank=0 + if self.rank == 0: + if num_workers <= 0: + cache_data_iterator = self + else: + # Multiply num_workers by the number of torch processes, because + # we will only be using the rank 0 process, whereas + # num_workers is typically defined per torch process. + if distributed.is_initialized(): + num_workers = num_workers * distributed.get_world_size() + + # Use DataLoader as a generic multiprocessing framework. + # We set batch_size=1 and a custom collate function. + # In effect this will just call self.__getitem__ in parallel. + cache_data_iterator = DataLoader( + self, + batch_size=1, + num_workers=num_workers, + shuffle=False, + collate_fn=lambda xlist: xlist[0], + ) + + # Iterate over data + for element in tqdm( + cache_data_iterator, + desc=f"Filtering dataset ({num_workers} CPUs): ", + disable=False, + ): + if filter_fn(element): + keep_ids.append(element.data_index) + keep_count += 1 + if max_count is not None and keep_count >= max_count: + # Reached the maximum count; stop iterating early. + print( + f"Reached maximum number of {max_count} elements, terminating early." + ) + break + + del cache_data_iterator + + # Remove unwanted elements + self.remove_elements(keep_ids=keep_ids) + + # Wait for rank 0 process to be done with filtering. + # Note that the default timeout is 30 minutes. If filtering is expected to exceed this, the timeout can be + # increased when initializing the process group, i.e., torch.distributed.init_process_group(timeout=...) + if distributed.is_initialized() and distributed.get_world_size() > 1: + gathered_values = all_gather(self._data_index) + # All processes use the indices from rank 0 + self._data_index = gathered_values[0] + self._data_len = len(self._data_index) + print(f"Rank {self.rank} has {self._data_len} elements.") + + def remove_elements(self, keep_ids: Union[np.ndarray, List[int]]): + old_len = self._data_len + self._data_index = [self._data_index[i] for i in keep_ids] + self._data_len = len(self._data_index) + + print( + f"Kept {self._data_len}/{old_len} elements, {self._data_len/old_len*100.0:.2f}%."
) + + def _get_data_index( self, num_workers: int, scene_paths: List[Path] ) -> Union[ List[Tuple[str, int, np.ndarray]], List[Tuple[str, int, List[Tuple[str, np.ndarray]]]], ] @@ -325,7 +707,7 @@ def get_data_index( if len(index_elems) > 0: data_index.append((str(orig_path), index_elems_len, index_elems)) else: - for (_, orig_path, index_elems_len, index_elems) in parallel_iapply( + for _, orig_path, index_elems_len, index_elems in parallel_iapply( data_index_fn, scene_paths, num_workers=num_workers, @@ -377,7 +759,7 @@ def _get_data_index_scene( (scene if ret_scene_info else None), scene_info_path, len(index_elems), - np.array(index_elems, dtype=np.int), + np.array(index_elems, dtype=int), ) @staticmethod @@ -422,7 +804,7 @@ def _get_data_index_agent( num_agent_ts: int = valid_ts[1] - valid_ts[0] + 1 if num_agent_ts > 0: index_elems_len += num_agent_ts - index_elems.append((agent_info.name, np.array(valid_ts, dtype=np.int))) + index_elems.append((agent_info.name, np.array(valid_ts, dtype=int))) return ( (scene if ret_scene_info else None), @@ -464,7 +846,7 @@ def get_collate_fn( return collate_fn - def get_matching_scene_tags(self, queries: List[str]) -> List[SceneTag]: + def _get_matching_scene_tags(self, queries: List[str]) -> List[SceneTag]: # if queries is None: # return list(chain.from_iterable(env.components for env in self.envs)) @@ -485,25 +867,26 @@ def get_matching_scene_tags(self, queries: List[str]) -> List[SceneTag]: return matching_scene_tags - def get_desired_scenes_from_env( + def _get_desired_scenes_from_env( self, scene_tags: List[SceneTag], - scene_description_contains: Optional[List[str]], env: RawDataset, ) -> Union[List[Scene], List[SceneMetadata]]: scenes_list: Union[List[Scene], List[SceneMetadata]] = list() - for scene_tag in scene_tags: + for scene_tag in tqdm( + scene_tags, desc=f"Getting Scenes from {env.name}", disable=not self.verbose + ): if env.name in scene_tag: scenes_list += env.get_matching_scenes( scene_tag, - scene_description_contains, + self.scene_description_contains, self.env_cache, self.rebuild_cache, ) return scenes_list - def preprocess_scene_data( + def _preprocess_scene_data( self, scenes_list: Union[List[SceneMetadata], List[Scene]], num_workers: int, @@ -530,6 +913,10 @@ def preprocess_scene_data( for scene_info in scenes_list if self.envs_dict[scene_info.env_name].parallelizable ] + + # Fixing the seed for random shuffling (for debugging and reproducibility).
+ shuffle_rng = random.Random(123) + shuffle_rng.shuffle(parallel_scenes) else: serial_scenes = scenes_list parallel_scenes = list() @@ -548,7 +935,7 @@ def preprocess_scene_data( scene_dt: float = ( self.desired_dt if self.desired_dt is not None else scene_info.dt ) - if self.env_cache.scene_is_cached( + if not self.rebuild_cache and self.env_cache.scene_is_cached( scene_info.env_name, scene_info.name, scene_dt ): # This is a fast path in case we don't need to @@ -639,7 +1026,11 @@ def preprocess_scene_data( desc=f"Calculating Agent Data ({num_workers} CPUs)", disable=not self.verbose, ): - scene_paths += [Path(path_str) for path_str in processed_scene_paths] + scene_paths += [ + Path(path_str) + for path_str in processed_scene_paths + if path_str is not None + ] return scene_paths @@ -655,10 +1046,17 @@ def scenes(self) -> Scene: for scene_idx in range(self.num_scenes()): yield self.get_scene(scene_idx) + def __iter__(self): + for i in range(len(self)): + yield self[i] + def __len__(self) -> int: return self._data_len def __getitem__(self, idx: int) -> Union[SceneBatchElement, AgentBatchElement]: + if self._cached_batch_elements is not None: + return self._cached_batch_elements[idx] + if self.centric == "scene": scene_path, ts = self._data_index[idx] elif self.centric == "agent": @@ -667,8 +1065,9 @@ def __getitem__(self, idx: int) -> Union[SceneBatchElement, AgentBatchElement]: scene: Scene = EnvCache.load(scene_path) scene_utils.enforce_desired_dt(scene, self.desired_dt) scene_cache: SceneCache = self.cache_class( - self.cache_path, scene, ts, self.augmentations + self.cache_path, scene, self.augmentations ) + scene_cache.set_obs_format(self.obs_format) if self.centric == "scene": scene_time: SceneTime = SceneTime.from_cache( @@ -687,8 +1086,11 @@ def __getitem__(self, idx: int) -> Union[SceneBatchElement, AgentBatchElement]: self.future_sec, self.agent_interaction_distances, self.incl_robot_future, - self.incl_map, - self.map_params, + self.incl_raster_map, + self.raster_map_params, + self._map_api, + self.vector_map_params, + self.state_format, self.standardize_data, self.standardize_derivatives, self.max_agent_num, @@ -712,8 +1114,11 @@ def __getitem__(self, idx: int) -> Union[SceneBatchElement, AgentBatchElement]: self.future_sec, self.agent_interaction_distances, self.incl_robot_future, - self.incl_map, - self.map_params, + self.incl_raster_map, + self.raster_map_params, + self._map_api, + self.vector_map_params, + self.state_format, self.standardize_data, self.standardize_derivatives, self.max_neighbor_num, @@ -722,4 +1127,10 @@ def __getitem__(self, idx: int) -> Union[SceneBatchElement, AgentBatchElement]: for key, extra_fn in self.extras.items(): batch_element.extras[key] = extra_fn(batch_element) + for transform_fn in self.transforms: + batch_element = transform_fn(batch_element) + + if not self.vector_map_params.get("collate", False): + batch_element.vec_map = None + return batch_element diff --git a/src/trajdata/dataset_specific/__init__.py b/src/trajdata/dataset_specific/__init__.py index 85a62ef..5293df9 100644 --- a/src/trajdata/dataset_specific/__init__.py +++ b/src/trajdata/dataset_specific/__init__.py @@ -1,2 +1 @@ from .raw_dataset import RawDataset -from .scene_records import EUPedsRecord, LyftSceneRecord, NuscSceneRecord diff --git a/src/trajdata/dataset_specific/argoverse2/__init__.py b/src/trajdata/dataset_specific/argoverse2/__init__.py new file mode 100644 index 0000000..5a94b22 --- /dev/null +++ b/src/trajdata/dataset_specific/argoverse2/__init__.py @@ -0,0 +1 
@@ +from .av2_dataset import Av2Dataset diff --git a/src/trajdata/dataset_specific/argoverse2/av2_dataset.py b/src/trajdata/dataset_specific/argoverse2/av2_dataset.py new file mode 100644 index 0000000..24ada32 --- /dev/null +++ b/src/trajdata/dataset_specific/argoverse2/av2_dataset.py @@ -0,0 +1,191 @@ +from pathlib import Path +from typing import Any, Dict, List, Tuple, Type + +import pandas as pd +import tqdm +from av2.datasets.motion_forecasting.constants import ( + AV2_SCENARIO_OBS_TIMESTEPS, + AV2_SCENARIO_STEP_HZ, + AV2_SCENARIO_TOTAL_TIMESTEPS, +) + +from trajdata.caching.env_cache import EnvCache +from trajdata.caching.scene_cache import SceneCache +from trajdata.data_structures import AgentMetadata, EnvMetadata, Scene, SceneMetadata +from trajdata.data_structures.scene_tag import SceneTag +from trajdata.dataset_specific.argoverse2.av2_utils import ( + AV2_SPLITS, + Av2Object, + Av2ScenarioIds, + av2_map_to_vector_map, + get_track_metadata, + scenario_name_to_split, +) +from trajdata.dataset_specific.raw_dataset import RawDataset +from trajdata.dataset_specific.scene_records import Argoverse2Record +from trajdata.utils import arr_utils + +AV2_MOTION_FORECASTING = "av2_motion_forecasting" +AV2_DT = 1 / AV2_SCENARIO_STEP_HZ + + +class Av2Dataset(RawDataset): + + def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: + if env_name != AV2_MOTION_FORECASTING: + raise ValueError(f"Unknown Argoverse 2 env name: {env_name}") + + scenario_ids = Av2ScenarioIds.create(Path(data_dir)) + + return EnvMetadata( + name=env_name, + data_dir=data_dir, + dt=AV2_DT, + parts=[AV2_SPLITS], + scene_split_map=scenario_ids.scene_split_map, + map_locations=None, + ) + + def load_dataset_obj(self, verbose: bool = False) -> None: + if verbose: + print(f"Loading {self.name} dataset...", flush=True) + self.dataset_obj = Av2Object(self.metadata.data_dir) + + def _get_matching_scenes_from_obj( + self, + scene_tag: SceneTag, + scene_desc_contains: List[str] | None, + env_cache: EnvCache, + ) -> List[SceneMetadata]: + """Compute SceneMetadata for all samples from self.dataset_obj. + + Also saves records to env_cache for later reuse. 
+ """ + if scene_desc_contains: + raise ValueError("Argoverse dataset does not support scene descriptions.") + + record_list = [] + metadata_list = [] + + for idx, scenario_name in enumerate(self.dataset_obj.scenario_names): + record_list.append(Argoverse2Record(scenario_name, idx)) + metadata_list.append( + SceneMetadata( + env_name=self.metadata.name, + name=scenario_name, + dt=AV2_DT, + raw_data_idx=idx, + ) + ) + + self.cache_all_scenes_list(env_cache, record_list) + return metadata_list + + def _get_matching_scenes_from_cache( + self, + scene_tag: SceneTag, + scene_desc_contains: List[str] | None, + env_cache: EnvCache, + ) -> List[Scene]: + """Computes Scene data for all samples by reading data from env_cache.""" + if scene_desc_contains: + raise ValueError("Argoverse dataset does not support scene descriptions.") + + record_list: List[Argoverse2Record] = env_cache.load_env_scenes_list(self.name) + return [ + self._create_scene(record.name, record.data_idx) for record in record_list + ] + + def get_scene(self, scene_info: SceneMetadata) -> Scene: + return self._create_scene(scene_info.name, scene_info.raw_data_idx) + + def _create_scene(self, scenario_name: str, data_idx: int) -> Scene: + data_split = scenario_name_to_split(scenario_name) + return Scene( + env_metadata=self.metadata, + name=scenario_name, + location=scenario_name, + data_split=data_split, + length_timesteps=( + AV2_SCENARIO_OBS_TIMESTEPS + if data_split == "test" + else AV2_SCENARIO_TOTAL_TIMESTEPS + ), + raw_data_idx=data_idx, + data_access_info=None, + ) + + def get_agent_info( + self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] + ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: + """ + Get frame-level information from source dataset, caching it + to cache_path. + + Always called after cache_maps, can load map if needed + to associate map information to positions. + """ + scenario = self.dataset_obj.load_scenario(scene.name) + + agent_list: List[AgentMetadata] = [] + agent_presence: List[List[AgentMetadata]] = [[] for _ in scenario.timestamps_ns] + + df_records = [] + + for track in scenario.tracks: + track_metadata = get_track_metadata(track) + if track_metadata is None: + continue + + agent_list.append(track_metadata) + + for object_state in track.object_states: + agent_presence[int(object_state.timestep)].append(track_metadata) + + df_records.append( + { + "agent_id": track_metadata.name, + "scene_ts": object_state.timestep, + "x": object_state.position[0], + "y": object_state.position[1], + "z": 0.0, + "vx": object_state.velocity[0], + "vy": object_state.velocity[1], + "heading": object_state.heading, + } + ) + + df = pd.DataFrame.from_records(df_records) + df.set_index(["agent_id", "scene_ts"], inplace=True) + df.sort_index(inplace=True) + + df[["ax", "ay"]] = ( + arr_utils.agent_aware_diff( + df[["vx", "vy"]].to_numpy(), df.index.get_level_values(0) + ) + / AV2_DT + ) + cache_class.save_agent_data(df, cache_path, scene) + + return agent_list, agent_presence + + def cache_maps( + self, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + ) -> None: + """ + Get static, scene-level info from the source dataset, caching it + to cache_path. (Primarily this is info needed to construct VectorMap) + + Resolution is in pixels per meter. 
+ """ + for scenario_name in tqdm.tqdm( + self.dataset_obj.scenario_names, + desc=f"{self.name} cache maps", + dynamic_ncols=True, + ): + av2_map = self.dataset_obj.load_map(scenario_name) + vector_map = av2_map_to_vector_map(f"{self.name}:{scenario_name}", av2_map) + map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) diff --git a/src/trajdata/dataset_specific/argoverse2/av2_utils.py b/src/trajdata/dataset_specific/argoverse2/av2_utils.py new file mode 100644 index 0000000..7d060c2 --- /dev/null +++ b/src/trajdata/dataset_specific/argoverse2/av2_utils.py @@ -0,0 +1,255 @@ +import dataclasses +import os +from pathlib import Path +from typing import Dict, Literal, Optional, Tuple + +import numpy as np +from av2.datasets.motion_forecasting.data_schema import ( + ArgoverseScenario, + ObjectType, + Track, +) +from av2.datasets.motion_forecasting.scenario_serialization import ( + load_argoverse_scenario_parquet, +) +from av2.datasets.motion_forecasting.viz.scenario_visualization import ( + _ESTIMATED_CYCLIST_LENGTH_M, + _ESTIMATED_CYCLIST_WIDTH_M, + _ESTIMATED_VEHICLE_LENGTH_M, + _ESTIMATED_VEHICLE_WIDTH_M, +) +from av2.geometry.interpolate import compute_midpoint_line +from av2.map.map_api import ArgoverseStaticMap + +from trajdata.data_structures.agent import AgentMetadata, AgentType, FixedExtent +from trajdata.maps.vec_map import VectorMap +from trajdata.maps.vec_map_elements import PedCrosswalk, Polyline, RoadArea, RoadLane + +AV2_SPLITS = ("train", "val", "test") +DELIM = "_" +T_Split = Literal["train", "val", "test"] + + +# {ObjectType: (AgentType, length, width, height)}. +# Uses av2 constants where possible. +OBJECT_TYPE_DATA: Dict[str, Tuple[AgentType, float, float, float]] = { + ObjectType.VEHICLE: ( + AgentType.VEHICLE, + _ESTIMATED_VEHICLE_LENGTH_M, + _ESTIMATED_VEHICLE_WIDTH_M, + 2, + ), + ObjectType.PEDESTRIAN: (AgentType.PEDESTRIAN, 0.7, 0.7, 2), + ObjectType.MOTORCYCLIST: ( + AgentType.MOTORCYCLE, + _ESTIMATED_CYCLIST_LENGTH_M, + _ESTIMATED_CYCLIST_WIDTH_M, + 2, + ), + ObjectType.CYCLIST: ( + AgentType.BICYCLE, + _ESTIMATED_CYCLIST_LENGTH_M, + _ESTIMATED_CYCLIST_WIDTH_M, + 2, + ), + ObjectType.BUS: (AgentType.VEHICLE, 9, 3, 4), +} + + +@dataclasses.dataclass +class Av2ScenarioIds: + train: list[str] + val: list[str] + test: list[str] + + @staticmethod + def create(dataset_path: Path) -> "Av2ScenarioIds": + train = os.listdir(dataset_path / "train") + val = os.listdir(dataset_path / "val") + test = os.listdir(dataset_path / "test") + return Av2ScenarioIds(train=train, val=val, test=test) + + @property + def scene_split_map(self) -> Dict[str, T_Split]: + """Compute a map of {scenario_name: split}.""" + return { + _pack_av2_scenario_name(split, scenario_id): split + for split, scenario_ids in dataclasses.asdict(self).items() + for scenario_id in scenario_ids + } + + +def scenario_name_to_split(scenario_name: str) -> T_Split: + split, _ = _unpack_av2_scenario_name(scenario_name) + return split + + +def _pack_av2_scenario_name(split: T_Split, scenario_id: str) -> str: + return split + DELIM + scenario_id + + +def _unpack_av2_scenario_name(scenario_name: str) -> Tuple[T_Split, str]: + return tuple(scenario_name.split(DELIM, maxsplit=1)) + + +def _scenario_df_filename(scenario_id: str) -> str: + return f"scenario_{scenario_id}.parquet" + + +def _scenario_map_filename(scenario_id: str) -> str: + return f"log_map_archive_{scenario_id}.json" + + +class Av2Object: + """Object for interfacing with Av2 data on disk.""" + + def __init__(self, dataset_path: Path) -> 
None: + self.dataset_path = dataset_path + self.scenario_ids = Av2ScenarioIds.create(dataset_path) + + @property + def scenario_names(self) -> list[str]: + return list(self.scenario_ids.scene_split_map) + + def _parse_scenario_name(self, scenario_name: str) -> Tuple[Path, str]: + split, scenario_id = _unpack_av2_scenario_name(scenario_name) + del scenario_name + + scenario_dir = self.dataset_path / split / scenario_id + if not scenario_dir.exists(): + raise FileNotFoundError(f"Scenario path {scenario_dir} not found") + return scenario_dir, scenario_id + + def load_scenario(self, scenario_name: str) -> ArgoverseScenario: + scenario_dir, scenario_id = self._parse_scenario_name(scenario_name) + return load_argoverse_scenario_parquet( + scenario_dir / _scenario_df_filename(scenario_id) + ) + + def load_map(self, scenario_name: str) -> ArgoverseStaticMap: + scenario_dir, scenario_id = self._parse_scenario_name(scenario_name) + return ArgoverseStaticMap.from_json( + scenario_dir / _scenario_map_filename(scenario_id) + ) + + +def av2_map_to_vector_map(map_id: str, av2_map: ArgoverseStaticMap) -> VectorMap: + vector_map = VectorMap(map_id) + + extents: Optional[Tuple[np.ndarray, np.ndarray]] = None + + for lane_segment in av2_map.vector_lane_segments.values(): + lane_max = np.maximum( + lane_segment.left_lane_boundary.xyz.max(0), + lane_segment.right_lane_boundary.xyz.max(0), + ) + lane_min = np.minimum( + lane_segment.left_lane_boundary.xyz.min(0), + lane_segment.right_lane_boundary.xyz.min(0), + ) + + if extents is None: + extents = (lane_min, lane_max) + else: + extents = ( + np.minimum(lane_min, extents[0]), + np.maximum(lane_max, extents[1]), + ) + + center, _ = compute_midpoint_line( + lane_segment.left_lane_boundary.xyz, lane_segment.right_lane_boundary.xyz + ) + vector_map.add_map_element( + RoadLane( + id=_road_lane_id(lane_segment.id), + center=Polyline(center), + left_edge=Polyline(lane_segment.left_lane_boundary.xyz), + right_edge=Polyline(lane_segment.right_lane_boundary.xyz), + adj_lanes_left=_adj_lanes_set(lane_segment.left_neighbor_id), + adj_lanes_right=_adj_lanes_set(lane_segment.right_neighbor_id), + next_lanes={_road_lane_id(i) for i in lane_segment.successors}, + prev_lanes={_road_lane_id(i) for i in lane_segment.predecessors}, + ) + ) + + for drivable_area in av2_map.vector_drivable_areas.values(): + assert extents is not None + extents = ( + np.minimum(drivable_area.xyz.min(0), extents[0]), + np.maximum(drivable_area.xyz.max(0), extents[1]), + ) + + vector_map.add_map_element( + RoadArea( + id=_road_area_id(drivable_area.id), + exterior_polygon=Polyline(drivable_area.xyz), + ) + ) + + for ped_crossing in av2_map.vector_pedestrian_crossings.values(): + assert extents is not None + extents = ( + np.minimum(ped_crossing.polygon.min(0), extents[0]), + np.maximum(ped_crossing.polygon.max(0), extents[1]), + ) + vector_map.add_map_element( + PedCrosswalk( + id=_ped_crosswalk_id(ped_crossing.id), + polygon=Polyline(ped_crossing.polygon), + ) + ) + + # extent is [min_x, min_y, min_z, max_x, max_y, max_z] + vector_map.extent = np.concatenate(extents) + + return vector_map + + +def _adj_lanes_set(neighbor_id: Optional[int]) -> set[str]: + if neighbor_id is None: + return set() + return {_road_lane_id(neighbor_id)} + + +def _road_lane_id(lane_segment_id: int) -> str: + return f"RoadLane{lane_segment_id}" + + +def _road_area_id(drivable_area_id: int) -> str: + return f"RoadArea{drivable_area_id}" + + +def _ped_crosswalk_id(ped_crossing_id: int) -> str: + return
f"PedCrosswalk{ped_crossing_id}" + + +def get_track_metadata(track: Track) -> Optional[AgentMetadata]: + agent_data = OBJECT_TYPE_DATA.get(track.object_type) + if agent_data is None: + return None + + agent_type, length, width, height = agent_data + + timesteps = [_to_int(state.timestep) for state in track.object_states] + if not timesteps: + return None + + # Av2 uses the name "AV" for the robot. Trajdata expects the name "ego" for the robot. + name = track.track_id + if name == "AV": + name = "ego" + + return AgentMetadata( + name=name, + agent_type=agent_type, + first_timestep=min(timesteps), + last_timestep=max(timesteps), + extent=FixedExtent(length=length, width=width, height=height), + ) + + +def _to_int(x: float) -> int: + """Safe convert floats like 42.0 to 42.""" + y = int(x) + assert x == y + return y diff --git a/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py b/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py index 2bca501..4234df4 100644 --- a/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py +++ b/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py @@ -247,6 +247,9 @@ def get_agent_info( agent_ids: np.ndarray = scene_data.index.get_level_values(0).to_numpy() + # Add in zero for z value + scene_data["z"] = np.zeros_like(scene_data["x"]) + ### Calculating agent velocities scene_data[["vx", "vy"]] = ( arr_utils.agent_aware_diff(scene_data[["x", "y"]].to_numpy(), agent_ids) @@ -278,8 +281,6 @@ def get_agent_info( last_frame: int = frames.iat[-1].item() if frames.shape[0] < last_frame - start_frame + 1: - # Fun fact: this is never hit which means Lyft has no missing - # timesteps (which could be caused by, e.g., occlusion). raise ValueError("ETH/UCY indeed can have missing frames :(") agent_metadata = AgentMetadata( diff --git a/src/trajdata/dataset_specific/interaction/__init__.py b/src/trajdata/dataset_specific/interaction/__init__.py new file mode 100644 index 0000000..c5fe2e4 --- /dev/null +++ b/src/trajdata/dataset_specific/interaction/__init__.py @@ -0,0 +1 @@ +from .interaction_dataset import InteractionDataset diff --git a/src/trajdata/dataset_specific/interaction/interaction_dataset.py b/src/trajdata/dataset_specific/interaction/interaction_dataset.py new file mode 100644 index 0000000..89cc29a --- /dev/null +++ b/src/trajdata/dataset_specific/interaction/interaction_dataset.py @@ -0,0 +1,530 @@ +import os +import time +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, Final, List, Optional, Tuple, Type + +import lanelet2 +import numpy as np +import pandas as pd +from tqdm import tqdm + +from trajdata.caching import EnvCache, SceneCache +from trajdata.data_structures.agent import AgentMetadata, AgentType, FixedExtent +from trajdata.data_structures.environment import EnvMetadata +from trajdata.data_structures.scene_metadata import Scene, SceneMetadata +from trajdata.data_structures.scene_tag import SceneTag +from trajdata.dataset_specific.raw_dataset import RawDataset +from trajdata.dataset_specific.scene_records import InteractionRecord +from trajdata.maps import VectorMap +from trajdata.maps.vec_map_elements import Polyline, RoadLane +from trajdata.utils import arr_utils + +# SDD was captured at 10 frames per second. +INTERACTION_DT: Final[float] = 0.1 + +# For training, 1 second of history is used to predict 3 seconds into the future. +# For testing, only 1 second of observations are provided. 
+INTERACTION_TRAIN_LENGTH: Final[int] = 40 +INTERACTION_TEST_LENGTH: Final[int] = 10 +INTERACTION_NUM_FILES: Final[int] = 56 +INTERACTION_LOCATIONS: Final[Tuple[str, str, str, str]] = ( + "usa", + "china", + "germany", + "bulgaria", +) + +INTERACTION_DETAILED_LOCATIONS: Final[Tuple[str, ...]] = ( + "CHN_Merging_ZS0", + "CHN_Merging_ZS2", + "CHN_Roundabout_LN", + "DEU_Merging_MT", + "DEU_Roundabout_OF", + "Intersection_CM", + "LaneChange_ET0", + "LaneChange_ET1", + "Merging_TR0", + "Merging_TR1", + "Roundabout_RW", + "USA_Intersection_EP0", + "USA_Intersection_EP1", + "USA_Intersection_GL", + "USA_Intersection_MA", + "USA_Roundabout_EP", + "USA_Roundabout_FT", + "USA_Roundabout_SR", +) + + +def get_split(scene_name: str, no_case: bool = False) -> str: + if no_case: + case_id_str = "" + else: + case_id_str = f"_{int(scene_name.split('_')[-1])}" + if scene_name.endswith(f"test_condition{case_id_str}"): + return "test_condition" + else: + if no_case: + return scene_name.split("_")[-1] + else: + return scene_name.split("_")[-2] + + +def get_location(scene_name: str) -> Tuple[str, str]: + if scene_name.startswith("DR_DEU"): + country = "germany" + elif scene_name.startswith("DR_CHN"): + country = "china" + elif scene_name.startswith("DR_USA"): + country = "usa" + else: + country = "bulgaria" + + if country != "bulgaria": + detailed_loc = "_".join(scene_name.split("_")[1:4]) + else: + detailed_loc = "_".join(scene_name.split("_")[1:3]) + + return country, detailed_loc + + +def interaction_type_to_unified_type(label: str) -> AgentType: + if label == "car": + return AgentType.VEHICLE + elif label == "pedestrian/bicycle": + return AgentType.PEDESTRIAN + raise ValueError(f"Unknown INTERACTION agent type: {label}") + + +def get_last_line(file_path: Path) -> str: + with open(file_path, "rb") as file: + # Go to the end of the file, just before the final line break. + file.seek(-2, os.SEEK_END) + + # Keep reading backward until you find the next line break. + while file.read(1) != b"\n": + file.seek(-2, os.SEEK_CUR) + + return file.readline().decode() + + +class InteractionDataset(RawDataset): + def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: + # INTERACTION dataset possibilities are the Cartesian product of these. + dataset_parts = [ + ("train", "val", "test", "test_conditional"), + INTERACTION_LOCATIONS, + ] + + if env_name not in {"interaction_multi", "interaction_single"}: + raise ValueError( + f"{env_name} not found in INTERACTION dataset. Options are {'interaction_multi', 'interaction_single'}" + ) + + return EnvMetadata( + name=env_name, + data_dir=data_dir, + dt=INTERACTION_DT, + parts=dataset_parts, + # No need since we'll have it in the scene name (and the scene names + # are not unique between the two test types). + scene_split_map=None, + # The location names should match the map names used in + # the unified data cache. + map_locations=INTERACTION_DETAILED_LOCATIONS, + ) + + def load_dataset_obj(self, verbose: bool = False) -> None: + if verbose: + print(f"Loading {self.name} dataset...", flush=True) + + data_dir_path = Path(self.metadata.data_dir) + + # Just storing the filepath and scene length (number of frames). + # One could load the entire dataset here, but there's no need: some of + # these files are large, and we can easily process them in parallel later.
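To make the naming conventions concrete, here are hedged examples against the parsers defined above; the scene name is illustrative but follows the DR_* pattern used by INTERACTION:
```
# Split is the second-to-last token; location is derived from the prefix.
assert get_split("DR_DEU_Merging_MT_val_7") == "val"
assert get_location("DR_DEU_Merging_MT_val_7") == ("germany", "DEU_Merging_MT")
```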
+ self.dataset_obj: Dict[str, Tuple[Path, int, int]] = dict() + for scene_path in tqdm( + data_dir_path.glob("**/*.csv"), + disable=not verbose, + total=INTERACTION_NUM_FILES, + ): + scene_name = scene_path.stem + + scene_split: str = "" + if scene_name.endswith("obs"): + scene_split = f"_{scene_path.parent.stem[:-len('-multi-agent')-1]}" + + num_scenarios = int(float(get_last_line(scene_path).split(",")[0])) + + self.dataset_obj[f"{scene_name}{scene_split}"] = ( + scene_path, + INTERACTION_TRAIN_LENGTH + if len(scene_split) == 0 + else INTERACTION_TEST_LENGTH, + num_scenarios, + ) + + if verbose: + print( + "The first ~60 iterations might be slow; don't worry, the following ones will be fast.", + flush=True, + ) + + def _get_matching_scenes_from_obj( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[SceneMetadata]: + all_scenes_list: List[InteractionRecord] = list() + + scenes_list: List[SceneMetadata] = list() + idx: int = 0 + for scene_name, (_, scene_length, num_scenarios) in self.dataset_obj.items(): + scene_split: str = get_split(scene_name, no_case=True) + country, _ = get_location(scene_name) + + for scenario_num in range(num_scenarios): + scene_name_with_num = f"{scene_name}_{scenario_num}" + + # Saving all scene records for later caching. + all_scenes_list.append( + InteractionRecord(scene_name_with_num, scene_length, idx) + ) + + if ( + country in scene_tag + and scene_split in scene_tag + and scene_desc_contains is None + ): + scene_metadata = SceneMetadata( + env_name=self.metadata.name, + name=scene_name_with_num, + dt=self.metadata.dt, + raw_data_idx=idx, + ) + scenes_list.append(scene_metadata) + + idx += 1 + + self.cache_all_scenes_list(env_cache, all_scenes_list) + return scenes_list + + def _get_matching_scenes_from_cache( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[Scene]: + all_scenes_list: List[InteractionRecord] = env_cache.load_env_scenes_list( + self.name + ) + + scenes_list: List[Scene] = list() + for scene_record in all_scenes_list: + scene_name, scene_length, data_idx = scene_record + + scene_split: str = get_split(scene_name) + country, scene_location = get_location(scene_name) + + if ( + country in scene_tag + and scene_split in scene_tag + and scene_desc_contains is None + ): + scene_metadata = Scene( + self.metadata, + scene_name, + scene_location, + scene_split, + scene_length, + data_idx, + None, # This isn't used if everything is already cached. + ) + scenes_list.append(scene_metadata) + + return scenes_list + + def get_scene(self, scene_info: SceneMetadata) -> Scene: + _, scene_name, _, data_idx = scene_info + + scene_length = ( + INTERACTION_TRAIN_LENGTH + if scene_name.split("_")[-2] in {"train", "val"} + else INTERACTION_TEST_LENGTH + ) + scene_split: str = get_split(scene_name) + _, scene_location = get_location(scene_name) + + return Scene( + self.metadata, + scene_name, + scene_location, + scene_split, + scene_length, + data_idx, + None, # No data access info necessary for the INTERACTION dataset.
+ ) + + def get_agent_info( + self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] + ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: + scene_name_parts: List[str] = scene.name.split("_") + base_scene_name: str = "_".join(scene_name_parts[:-1]) + orig_scenario_num = int(scene_name_parts[-1]) + + scene_metadata_path = EnvCache.scene_metadata_path( + cache_path, scene.env_name, scene.name, scene.dt + ) + if scene_metadata_path.exists(): + # Try repeatedly to open the file because it might still be + # being created in another process. + while True: + try: + already_done_scene = EnvCache.load(scene_metadata_path) + break + except: + time.sleep(1) + + # Already processed, so we can immediately return our cached results. + return ( + already_done_scene.agents, + already_done_scene.agent_presence, + ) + + scene_filepath, _, num_scenarios = self.dataset_obj[base_scene_name] + + data_df: pd.DataFrame = pd.read_csv( + scene_filepath, index_col=False, dtype={"case_id": int} + ) + + # The first frame and case IDs of INTERACTION data are always "1". + data_df["frame_id"] -= 1 + data_df["case_id"] -= 1 + + # Ensuring case_ids are kept within track_ids. + data_df["track_id"] = ( + data_df["case_id"].astype(str) + "_" + data_df["track_id"].astype(str) + ) + + # Don't need these columns anymore. + data_df.drop( + columns=["timestamp_ms"], + inplace=True, + ) + + # Add in zero for z value + data_df["z"] = np.zeros_like(data_df["x"]) + + # Renaming columns to match our usual names. + data_df.rename( + columns={ + "frame_id": "scene_ts", + "psi_rad": "heading", + "track_id": "agent_id", + }, + inplace=True, + ) + + # Ensuring data is sorted by agent ID and scene timestep. + data_df.set_index(["agent_id", "scene_ts"], inplace=True) + data_df.sort_index(inplace=True) + data_df.reset_index(level=1, inplace=True) + + agent_ids: np.ndarray = data_df.index.get_level_values(0).to_numpy() + + ### Calculating agent classes + agent_class: Dict[int, str] = ( + data_df.groupby("agent_id")["agent_type"].first().to_dict() + ) + + ### Calculating agent extents + agent_length: Dict[int, float] = ( + data_df.groupby("agent_id")["length"].first().to_dict() + ) + + agent_width: Dict[int, float] = ( + data_df.groupby("agent_id")["width"].first().to_dict() + ) + + # This is likely to be very noisy... Unfortunately, INTERACTION only + # provides headings for vehicles, so we estimate the rest from velocity. + non_car_mask = data_df["agent_type"] != "car" + data_df.loc[non_car_mask, "heading"] = np.arctan2( + data_df.loc[non_car_mask, "vy"], data_df.loc[non_car_mask, "vx"] + ) + + del data_df["agent_type"] + + ### Calculating agent accelerations + data_df[["ax", "ay"]] = ( + arr_utils.agent_aware_diff(data_df[["vx", "vy"]].to_numpy(), agent_ids) + / INTERACTION_DT + ) + + agent_list: Dict[int, List[AgentMetadata]] = defaultdict(list) + agent_presence: Dict[int, List[List[AgentMetadata]]] = dict() + for agent_id, frames in data_df.groupby("agent_id")["scene_ts"]: + case_id = int(agent_id.split("_")[0]) + start_frame: int = frames.iat[0].item() + last_frame: int = frames.iat[-1].item() + + agent_type: AgentType = interaction_type_to_unified_type( + agent_class[agent_id] + ) + + agent_metadata = AgentMetadata( + name=str(agent_id), + agent_type=agent_type, + first_timestep=start_frame, + last_timestep=last_frame, + # These values are as ballpark as it gets... + # The vehicle height here is just taken to be 6 feet (1.83 m).
+ extent=FixedExtent(0.75, 0.75, 1.5) + if agent_type != AgentType.VEHICLE + else FixedExtent(agent_length[agent_id], agent_width[agent_id], 1.83), + ) + + if case_id not in agent_presence: + agent_presence[case_id] = [[] for _ in range(scene.length_timesteps)] + + agent_list[case_id].append(agent_metadata) + for frame in frames: + agent_presence[case_id][frame].append(agent_metadata) + + # Changing the agent_id dtype to str + data_df.reset_index(inplace=True) + data_df["agent_id"] = data_df["agent_id"].astype(str) + data_df.set_index(["agent_id", "scene_ts"], inplace=True) + + for case_id, case_df in data_df.groupby("case_id"): + case_scene = Scene( + env_metadata=scene.env_metadata, + name=base_scene_name + f"_{case_id}", + location=scene.location, + data_split=scene.data_split, + length_timesteps=scene.length_timesteps, + raw_data_idx=scene.raw_data_idx, + data_access_info=scene.data_access_info, + description=scene.description, + agents=agent_list[case_id], + agent_presence=agent_presence[case_id], + ) + cache_class.save_agent_data( + case_df.loc[ + :, + [ + "x", + "y", + "z", + "vx", + "vy", + "ax", + "ay", + "heading", + ], + ], + cache_path, + case_scene, + ) + EnvCache.save_scene_with_path(cache_path, case_scene) + + return agent_list[orig_scenario_num], agent_presence[orig_scenario_num] + + def cache_map( + self, + map_path: Path, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + ) -> None: + vector_map = VectorMap( + map_id=f"{self.name}:{'_'.join(map_path.stem.split('_')[1:])}" + ) + + map_projector = lanelet2.projection.UtmProjector(lanelet2.io.Origin(0.0, 0.0)) + + laneletmap = lanelet2.io.load(str(map_path), map_projector) + traffic_rules = lanelet2.traffic_rules.create( + # TODO(bivanovic): lanelet2 has only implemented Germany so far... + # Thankfully all countries here drive on the right-hand side like + # Germany, so maybe we can get away with it. + lanelet2.traffic_rules.Locations.Germany, + lanelet2.traffic_rules.Participants.Vehicle, + ) + lane_graph = lanelet2.routing.RoutingGraph(laneletmap, traffic_rules) + + maximum_bound: np.ndarray = np.full((3,), np.nan) + minimum_bound: np.ndarray = np.full((3,), np.nan) + + for lanelet in tqdm( + laneletmap.laneletLayer, desc="Creating Vectorized Map", leave=False + ): + left_pts: np.ndarray = np.array( + [(p.x, p.y, p.z) for p in lanelet.leftBound] + ) + right_pts: np.ndarray = np.array( + [(p.x, p.y, p.z) for p in lanelet.rightBound] + ) + center_pts: np.ndarray = np.array( + [(p.x, p.y, p.z) for p in lanelet.centerline] + ) + + # Adding the element to the map. + new_lane = RoadLane( + id=str(lanelet.id), + center=Polyline(center_pts), + left_edge=Polyline(left_pts), + right_edge=Polyline(right_pts), + ) + + new_lane.next_lanes.update( + [str(l.id) for l in lane_graph.following(lanelet)] + ) + + new_lane.prev_lanes.update( + [str(l.id) for l in lane_graph.previous(lanelet)] + ) + + left_lane_change = lane_graph.left(lanelet) + if left_lane_change: + new_lane.adj_lanes_left.add(str(left_lane_change.id)) + + right_lane_change = lane_graph.right(lanelet) + if right_lane_change: + new_lane.adj_lanes_right.add(str(right_lane_change.id)) + + vector_map.add_map_element(new_lane) + + # Computing the maximum and minimum map coordinates. 
+ maximum_bound = np.fmax(maximum_bound, left_pts.max(axis=0)) + minimum_bound = np.fmin(minimum_bound, left_pts.min(axis=0)) + + maximum_bound = np.fmax(maximum_bound, right_pts.max(axis=0)) + minimum_bound = np.fmin(minimum_bound, right_pts.min(axis=0)) + + maximum_bound = np.fmax(maximum_bound, center_pts.max(axis=0)) + minimum_bound = np.fmin(minimum_bound, center_pts.min(axis=0)) + + # Setting the map bounds. + # vector_map.extent is [min_x, min_y, min_z, max_x, max_y, max_z] + vector_map.extent = np.concatenate((minimum_bound, maximum_bound)) + + map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) + + def cache_maps( + self, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + ) -> None: + data_dir_path = Path(self.metadata.data_dir) + file_paths = list(data_dir_path.glob("**/*.osm")) + for map_path in tqdm( + file_paths, + desc=f"Caching {self.name} Maps at {map_params['px_per_m']:.2f} px/m", + position=0, + ): + self.cache_map(map_path, cache_path, map_cache_class, map_params) diff --git a/src/trajdata/dataset_specific/lyft/lyft_dataset.py b/src/trajdata/dataset_specific/lyft/lyft_dataset.py index aeb8634..61ad7d7 100644 --- a/src/trajdata/dataset_specific/lyft/lyft_dataset.py +++ b/src/trajdata/dataset_specific/lyft/lyft_dataset.py @@ -1,19 +1,14 @@ -import warnings from collections import defaultdict from functools import partial -from math import ceil from pathlib import Path from random import Random from typing import Any, Dict, List, Optional, Tuple, Type -import l5kit.data.proto.road_network_pb2 as l5_pb2 import numpy as np import pandas as pd from l5kit.configs.config import load_metadata from l5kit.data import ChunkedDataset, LocalDataManager -from l5kit.data.map_api import InterpolationMethod, MapAPI -from l5kit.rasterization import RenderContext -from tqdm import tqdm +from l5kit.data.map_api import MapAPI from trajdata.caching import EnvCache, SceneCache from trajdata.data_structures import ( @@ -25,16 +20,9 @@ ) from trajdata.data_structures.agent import Agent, AgentType, VariableExtent from trajdata.dataset_specific.lyft import lyft_utils -from trajdata.dataset_specific.lyft.rasterizer import MapSemanticRasterizer from trajdata.dataset_specific.raw_dataset import RawDataset from trajdata.dataset_specific.scene_records import LyftSceneRecord -from trajdata.maps import RasterizedMap, RasterizedMapMetadata, map_utils -from trajdata.proto.vectorized_map_pb2 import ( - MapElement, - PedCrosswalk, - RoadLane, - VectorizedMap, -) +from trajdata.maps import VectorMap from trajdata.utils import arr_utils @@ -82,6 +70,8 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: ] scene_split_map = defaultdict(partial(const_lambda, const_val="val")) + else: + raise ValueError(f"Unknown Lyft environment name: {env_name}") return EnvMetadata( name=env_name, @@ -89,6 +79,9 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: dt=lyft_utils.LYFT_DT, parts=dataset_parts, scene_split_map=scene_split_map, + # The location names should match the map names used in + # the unified data cache. 
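Since these location names key the unified map cache, a cached map can later be fetched by "{env_name}:{location}". A hedged usage sketch; the cache path, map name, and exact query signature are assumptions based on trajdata's map API:
```
from pathlib import Path

import numpy as np

from trajdata.maps.map_api import MapAPI

# ~/.unified_data_cache is assumed to be the default cache location.
map_api = MapAPI(Path("~/.unified_data_cache").expanduser())
vec_map = map_api.get_map("interaction_single:DEU_Merging_MT")

# Query road lanes within 20 m of an (x, y, z) point.
lanes = vec_map.get_lanes_within(np.zeros(3), 20.0)
```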
+ map_locations=("palo_alto",), ) def load_dataset_obj(self, verbose: bool = False) -> None: @@ -256,6 +249,11 @@ def get_agent_info( all_agent_data_df.sort_index(inplace=True) all_agent_data_df.reset_index(level=1, inplace=True) + # Add in z data based on ego data + all_agent_data_df = all_agent_data_df.join( + ego_agent.data.xs("ego", level=0)["z"], on="scene_ts" + ) + ### Calculating agent classes agent_class: Dict[int, float] = ( all_agent_data_df.groupby("agent_id")["class_id"] @@ -306,6 +304,7 @@ def get_agent_info( final_cols = [ "x", "y", + "z", "vx", "vy", "ax", @@ -329,80 +328,6 @@ def get_agent_info( return agent_list, agent_presence - def extract_vectorized(self, mapAPI: MapAPI) -> VectorizedMap: - vec_map = VectorizedMap() - maximum_bound: np.ndarray = np.full((3,), np.nan) - minimum_bound: np.ndarray = np.full((3,), np.nan) - for l5_element in tqdm(mapAPI.elements, desc="Creating Vectorized Map"): - if mapAPI.is_lane(l5_element): - l5_element_id: str = mapAPI.id_as_str(l5_element.id) - l5_lane: l5_pb2.Lane = l5_element.element.lane - - lane_dict = mapAPI.get_lane_coords(l5_element_id) - left_pts = lane_dict["xyz_left"] - right_pts = lane_dict["xyz_right"] - - # Ensuring the left and right bounds have the same numbers of points. - if len(left_pts) < len(right_pts): - left_pts = mapAPI.interpolate( - left_pts, len(right_pts), InterpolationMethod.INTER_ENSURE_LEN - ) - elif len(right_pts) < len(left_pts): - right_pts = mapAPI.interpolate( - right_pts, len(left_pts), InterpolationMethod.INTER_ENSURE_LEN - ) - - midlane_pts: np.ndarray = (left_pts + right_pts) / 2 - - # Computing the maximum and minimum map coordinates. - maximum_bound = np.fmax(maximum_bound, left_pts.max(axis=0)) - minimum_bound = np.fmin(minimum_bound, left_pts.min(axis=0)) - - maximum_bound = np.fmax(maximum_bound, right_pts.max(axis=0)) - minimum_bound = np.fmin(minimum_bound, right_pts.min(axis=0)) - - maximum_bound = np.fmax(maximum_bound, midlane_pts.max(axis=0)) - minimum_bound = np.fmin(minimum_bound, midlane_pts.min(axis=0)) - - # Adding the element to the map. - new_element: MapElement = vec_map.elements.add() - new_element.id = l5_element.id.id - - new_lane: RoadLane = new_element.road_lane - map_utils.populate_lane_polylines( - new_lane, midlane_pts, left_pts, right_pts - ) - - new_lane.exit_lanes.extend([gid.id for gid in l5_lane.lanes_ahead]) - new_lane.adjacent_lanes_left.append( - l5_lane.adjacent_lane_change_left.id - ) - new_lane.adjacent_lanes_right.append( - l5_lane.adjacent_lane_change_right.id - ) - - if mapAPI.is_crosswalk(l5_element): - l5_element_id: str = mapAPI.id_as_str(l5_element.id) - crosswalk_pts: np.ndarray = mapAPI.get_crosswalk_coords(l5_element_id)[ - "xyz" - ] - - # Computing the maximum and minimum map coordinates. - maximum_bound = np.fmax(maximum_bound, crosswalk_pts.max(axis=0)) - minimum_bound = np.fmin(minimum_bound, crosswalk_pts.min(axis=0)) - - new_element: MapElement = vec_map.elements.add() - new_element.id = l5_element.id.id - - new_crosswalk: PedCrosswalk = new_element.ped_crosswalk - map_utils.populate_polygon(new_crosswalk.polygon, crosswalk_pts) - - # Setting the map bounds. 
- vec_map.max_pt.x, vec_map.max_pt.y, vec_map.max_pt.z = maximum_bound - vec_map.min_pt.x, vec_map.min_pt.y, vec_map.min_pt.z = minimum_bound - - return vec_map - def cache_maps( self, cache_path: Path, @@ -421,75 +346,7 @@ def cache_maps( world_to_ecef = np.array(dataset_meta["world_to_ecef"], dtype=np.float64) mapAPI = MapAPI(semantic_map_filepath, world_to_ecef) - if map_params.get("original_format", False): - warnings.warn( - "Using a dataset's original map format is deprecated, and will be removed in the next version of trajdata!", - FutureWarning, - ) - - mins = np.stack( - [ - map_elem["bounds"][:, 0].min(axis=0) - for map_elem in mapAPI.bounds_info.values() - ] - ).min(axis=0) - maxs = np.stack( - [ - map_elem["bounds"][:, 1].max(axis=0) - for map_elem in mapAPI.bounds_info.values() - ] - ).max(axis=0) - - world_right, world_top = maxs - world_left, world_bottom = mins - - world_center: np.ndarray = np.array( - [(world_left + world_right) / 2, (world_bottom + world_top) / 2] - ) - raster_size_px: np.ndarray = np.array( - [ - ceil((world_right - world_left) * resolution), - ceil((world_top - world_bottom) * resolution), - ] - ) - - render_context = RenderContext( - raster_size_px=raster_size_px, - pixel_size_m=np.array([1 / resolution, 1 / resolution]), - center_in_raster_ratio=np.array([0.5, 0.5]), - set_origin_to_bottom=False, - ) - - map_from_world: np.ndarray = render_context.raster_from_world( - world_center, 0.0 - ) - - rasterizer = MapSemanticRasterizer( - render_context, semantic_map_filepath, world_to_ecef - ) - - print("Rendering palo_alto Map...", flush=True, end=" ") - map_data: np.ndarray = rasterizer.render_semantic_map( - world_center, map_from_world - ) - print("done!", flush=True) - - vectorized_map = VectorizedMap() - else: - vectorized_map: VectorizedMap = self.extract_vectorized(mapAPI) - map_data, map_from_world = map_utils.rasterize_map( - vectorized_map, resolution - ) + vector_map = VectorMap(map_id=f"{self.name}:{map_name}") + lyft_utils.populate_vector_map(vector_map, mapAPI) - rasterized_map_info: RasterizedMapMetadata = RasterizedMapMetadata( - name=map_name, - shape=map_data.shape, - layers=["drivable_area", "lane_divider", "ped_area"], - layer_rgb_groups=([0], [1], [2]), - resolution=resolution, - map_from_world=map_from_world, - ) - rasterized_map_obj: RasterizedMap = RasterizedMap(rasterized_map_info, map_data) - map_cache_class.cache_map( - cache_path, vectorized_map, rasterized_map_obj, self.name - ) + map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) diff --git a/src/trajdata/dataset_specific/lyft/lyft_utils.py b/src/trajdata/dataset_specific/lyft/lyft_utils.py index 0fc6528..92e01f0 100644 --- a/src/trajdata/dataset_specific/lyft/lyft_utils.py +++ b/src/trajdata/dataset_specific/lyft/lyft_utils.py @@ -1,9 +1,12 @@ -from typing import Final, List +from typing import Dict, Final, List +import l5kit.data.proto.road_network_pb2 as l5_pb2 import numpy as np import pandas as pd from l5kit.data import ChunkedDataset +from l5kit.data.map_api import InterpolationMethod, MapAPI from l5kit.geometry import rotation33_as_yaw +from tqdm import tqdm from trajdata.data_structures import ( Agent, @@ -13,6 +16,9 @@ Scene, VariableExtent, ) +from trajdata.maps.vec_map import VectorMap +from trajdata.maps.vec_map_elements import PedCrosswalk, Polyline, RoadLane +from trajdata.utils import map_utils LYFT_DT: Final[float] = 0.1 @@ -23,12 +29,16 @@ def agg_ego_data(lyft_obj: ChunkedDataset, scene: Scene) -> Agent: ego_translations = 
lyft_obj.frames[scene_frame_start:scene_frame_end][ "ego_translation" - ][:, :2] + ][:, :3] # Doing this prepending so that the first velocity isn't zero (rather it's just the first actual velocity duplicated) - prepend_pos = ego_translations[0] - (ego_translations[1] - ego_translations[0]) + prepend_pos = ego_translations[0, :2] - ( + ego_translations[1, :2] - ego_translations[0, :2] + ) ego_velocities = ( - np.diff(ego_translations, axis=0, prepend=np.expand_dims(prepend_pos, axis=0)) + np.diff( + ego_translations[:, :2], axis=0, prepend=np.expand_dims(prepend_pos, axis=0) + ) / LYFT_DT ) @@ -61,7 +71,7 @@ def agg_ego_data(lyft_obj: ChunkedDataset, scene: Scene) -> Agent: ) ego_data_df = pd.DataFrame( ego_data_np, - columns=["x", "y", "vx", "vy", "ax", "ay", "heading"] + extent_cols, + columns=["x", "y", "z", "vx", "vy", "ax", "ay", "heading"] + extent_cols, index=pd.MultiIndex.from_tuples( [("ego", idx) for idx in range(ego_data_np.shape[0])], names=["agent_id", "scene_ts"], @@ -93,3 +103,82 @@ def lyft_type_to_unified_type(lyft_type: int) -> AgentType: return AgentType.MOTORCYCLE elif lyft_type == 14: return AgentType.PEDESTRIAN + + +def populate_vector_map(vector_map: VectorMap, mapAPI: MapAPI) -> None: + maximum_bound: np.ndarray = np.full((3,), np.nan) + minimum_bound: np.ndarray = np.full((3,), np.nan) + for l5_element in tqdm(mapAPI.elements, desc="Creating Vectorized Map"): + if mapAPI.is_lane(l5_element): + l5_element_id: str = mapAPI.id_as_str(l5_element.id) + l5_lane: l5_pb2.Lane = l5_element.element.lane + + lane_dict = mapAPI.get_lane_coords(l5_element_id) + left_pts = lane_dict["xyz_left"] + right_pts = lane_dict["xyz_right"] + + # Ensuring the left and right bounds have the same numbers of points. + if len(left_pts) < len(right_pts): + left_pts = mapAPI.interpolate( + left_pts, len(right_pts), InterpolationMethod.INTER_ENSURE_LEN + ) + elif len(right_pts) < len(left_pts): + right_pts = mapAPI.interpolate( + right_pts, len(left_pts), InterpolationMethod.INTER_ENSURE_LEN + ) + + midlane_pts: np.ndarray = (left_pts + right_pts) / 2 + + # Computing the maximum and minimum map coordinates. + maximum_bound = np.fmax(maximum_bound, left_pts.max(axis=0)) + minimum_bound = np.fmin(minimum_bound, left_pts.min(axis=0)) + + maximum_bound = np.fmax(maximum_bound, right_pts.max(axis=0)) + minimum_bound = np.fmin(minimum_bound, right_pts.min(axis=0)) + + maximum_bound = np.fmax(maximum_bound, midlane_pts.max(axis=0)) + minimum_bound = np.fmin(minimum_bound, midlane_pts.min(axis=0)) + + # Adding the element to the map. + new_lane = RoadLane( + id=l5_element_id, + center=Polyline(midlane_pts), + left_edge=Polyline(left_pts), + right_edge=Polyline(right_pts), + ) + + new_lane.next_lanes.update( + [mapAPI.id_as_str(gid) for gid in l5_lane.lanes_ahead] + ) + + left_lane_change_id: str = mapAPI.id_as_str( + l5_lane.adjacent_lane_change_left + ) + if left_lane_change_id: + new_lane.adj_lanes_left.add(left_lane_change_id) + + right_lane_change_id: str = mapAPI.id_as_str( + l5_lane.adjacent_lane_change_right + ) + if right_lane_change_id: + new_lane.adj_lanes_right.add(right_lane_change_id) + + vector_map.add_map_element(new_lane) + + if mapAPI.is_crosswalk(l5_element): + l5_element_id: str = mapAPI.id_as_str(l5_element.id) + crosswalk_pts: np.ndarray = mapAPI.get_crosswalk_coords(l5_element_id)[ + "xyz" + ] + + # Computing the maximum and minimum map coordinates. 
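One detail worth calling out in this bounds accumulation: np.fmax and np.fmin ignore NaNs, which is why the running bounds are seeded with NaN rather than with finite values. A quick illustration:
```
import numpy as np

bounds_max = np.full((3,), np.nan)
# fmax ignores the NaN seed, so the first real point initializes the bounds.
bounds_max = np.fmax(bounds_max, np.array([1.0, 2.0, 0.0]))  # -> [1., 2., 0.]
```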
+ maximum_bound = np.fmax(maximum_bound, crosswalk_pts.max(axis=0)) + minimum_bound = np.fmin(minimum_bound, crosswalk_pts.min(axis=0)) + + vector_map.add_map_element( + PedCrosswalk(id=l5_element_id, polygon=Polyline(crosswalk_pts)) + ) + + # Setting the map bounds. + # vector_map.extent is [min_x, min_y, min_z, max_x, max_y, max_z] + vector_map.extent = np.concatenate((minimum_bound, maximum_bound)) diff --git a/src/trajdata/dataset_specific/lyft/rasterizer.py b/src/trajdata/dataset_specific/lyft/rasterizer.py deleted file mode 100644 index c3bbe53..0000000 --- a/src/trajdata/dataset_specific/lyft/rasterizer.py +++ /dev/null @@ -1,124 +0,0 @@ -from collections import defaultdict -from typing import Dict - -import cv2 -import numpy as np -from l5kit.data.map_api import InterpolationMethod -from l5kit.geometry import transform_points -from l5kit.rasterization.semantic_rasterizer import ( - CV2_SUB_VALUES, - INTERPOLATION_POINTS, - RasterEls, - SemanticRasterizer, - cv2_subpixel, -) - - -def indices_in_bounds( - center: np.ndarray, bounds: np.ndarray, half_extent: float -) -> np.ndarray: - """ - Get indices of elements for which the bounding box described by bounds intersects the one defined around - center (square with side 2*half_side) - - Args: - center (float): XY of the center - bounds (np.ndarray): array of shape Nx2x2 [[x_min,y_min],[x_max, y_max]] - half_extent (float): half the side of the bounding box centered around center - - Returns: - np.ndarray: indices of elements inside radius from center - """ - return np.arange(bounds.shape[0], dtype=np.long) - - -class MapSemanticRasterizer(SemanticRasterizer): - def render_semantic_map( - self, center_in_world: np.ndarray, raster_from_world: np.ndarray - ) -> np.ndarray: - """Renders the semantic map at given x,y coordinates. 
- - Args: - center_in_world (np.ndarray): XY of the image center in world ref system - raster_from_world (np.ndarray): - Returns: - np.ndarray: RGB raster - - """ - lane_area_img: np.ndarray = np.zeros( - shape=(self.raster_size[1], self.raster_size[0], 3), dtype=np.uint8 - ) - lane_line_img: np.ndarray = np.zeros( - shape=(self.raster_size[1], self.raster_size[0], 3), dtype=np.uint8 - ) - ped_area_img: np.ndarray = np.zeros( - shape=(self.raster_size[1], self.raster_size[0], 3), dtype=np.uint8 - ) - - # filter using half a radius from the center - raster_radius = float(np.linalg.norm(self.raster_size * self.pixel_size)) / 2 - - # get all lanes as interpolation so that we can transform them all together - lane_indices = indices_in_bounds( - center_in_world, self.mapAPI.bounds_info["lanes"]["bounds"], raster_radius - ) - lanes_mask: Dict[str, np.ndarray] = defaultdict( - lambda: np.zeros(len(lane_indices) * 2, dtype=np.bool) - ) - lanes_area = np.zeros((len(lane_indices) * 2, INTERPOLATION_POINTS, 2)) - - for idx, lane_idx in enumerate(lane_indices): - lane_idx = self.mapAPI.bounds_info["lanes"]["ids"][lane_idx] - - # interpolate over polyline to always have the same number of points - lane_coords = self.mapAPI.get_lane_as_interpolation( - lane_idx, INTERPOLATION_POINTS, InterpolationMethod.INTER_ENSURE_LEN - ) - lanes_area[idx * 2] = lane_coords["xyz_left"][:, :2] - lanes_area[idx * 2 + 1] = lane_coords["xyz_right"][::-1, :2] - - lanes_mask[RasterEls.LANE_NOTL.name][idx * 2 : idx * 2 + 2] = True - - if len(lanes_area): - lanes_area = cv2_subpixel( - transform_points(lanes_area.reshape((-1, 2)), raster_from_world) - ) - - for lane_area in lanes_area.reshape((-1, INTERPOLATION_POINTS * 2, 2)): - # need to for-loop otherwise some of them are empty - cv2.fillPoly(lane_area_img, [lane_area], (255, 0, 0), **CV2_SUB_VALUES) - - lanes_area = lanes_area.reshape((-1, INTERPOLATION_POINTS, 2)) - for ( - name, - mask, - ) in lanes_mask.items(): # draw each type of lane with its own color - cv2.polylines( - lane_line_img, - lanes_area[mask], - False, - (0, 255, 0), - **CV2_SUB_VALUES - ) - - # plot crosswalks - crosswalks = [] - for idx in indices_in_bounds( - center_in_world, - self.mapAPI.bounds_info["crosswalks"]["bounds"], - raster_radius, - ): - crosswalk = self.mapAPI.get_crosswalk_coords( - self.mapAPI.bounds_info["crosswalks"]["ids"][idx] - ) - xy_cross = cv2_subpixel( - transform_points(crosswalk["xyz"][:, :2], raster_from_world) - ) - crosswalks.append(xy_cross) - - cv2.fillPoly(ped_area_img, crosswalks, (0, 0, 255), **CV2_SUB_VALUES) - - map_img: np.ndarray = (lane_area_img + lane_line_img + ped_area_img).astype( - np.float32 - ) / 255 - return map_img.transpose(2, 0, 1) diff --git a/src/trajdata/dataset_specific/nuplan/__init__.py b/src/trajdata/dataset_specific/nuplan/__init__.py new file mode 100644 index 0000000..022a1a0 --- /dev/null +++ b/src/trajdata/dataset_specific/nuplan/__init__.py @@ -0,0 +1 @@ +from .nuplan_dataset import NuplanDataset diff --git a/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py b/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py new file mode 100644 index 0000000..1c4df3f --- /dev/null +++ b/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py @@ -0,0 +1,409 @@ +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Type + +import numpy as np +import pandas as pd +from nuplan.common.maps.nuplan_map import map_factory +from nuplan.common.maps.nuplan_map.nuplan_map import NuPlanMap +from tqdm import tqdm + +from 
trajdata.caching import EnvCache, SceneCache
+from trajdata.data_structures.agent import (
+    Agent,
+    AgentMetadata,
+    AgentType,
+    FixedExtent,
+    VariableExtent,
+)
+from trajdata.data_structures.environment import EnvMetadata
+from trajdata.data_structures.scene_metadata import Scene, SceneMetadata
+from trajdata.data_structures.scene_tag import SceneTag
+from trajdata.dataset_specific.nuplan import nuplan_utils
+from trajdata.dataset_specific.raw_dataset import RawDataset
+from trajdata.dataset_specific.scene_records import NuPlanSceneRecord
+from trajdata.maps.vec_map import VectorMap
+from trajdata.utils import arr_utils
+
+
+class NuplanDataset(RawDataset):
+    def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata:
+        all_log_splits: Dict[str, List[str]] = nuplan_utils.create_splits_logs()
+
+        nup_log_splits: Dict[str, List[str]]
+        if env_name == "nuplan_mini":
+            nup_log_splits = {
+                k: all_log_splits[k[5:]]
+                for k in ["mini_train", "mini_val", "mini_test"]
+            }
+
+            # nuPlan dataset parts are the Cartesian product of these.
+            dataset_parts: List[Tuple[str, ...]] = [
+                ("mini_train", "mini_val", "mini_test"),
+                nuplan_utils.NUPLAN_LOCATIONS,
+            ]
+        elif env_name.startswith("nuplan"):
+            split_str = env_name.split("_")[-1]
+            nup_log_splits = {split_str: all_log_splits[split_str]}
+
+            # nuPlan dataset parts are the Cartesian product of these.
+            dataset_parts: List[Tuple[str, ...]] = [
+                (split_str,),
+                nuplan_utils.NUPLAN_LOCATIONS,
+            ]
+        else:
+            raise ValueError(f"Unknown nuPlan environment name: {env_name}")
+
+        # Inverting the dict from above, associating every log with its data split.
+        nup_log_split_map: Dict[str, str] = {
+            v_elem: k for k, v in nup_log_splits.items() for v_elem in v
+        }
+
+        return EnvMetadata(
+            name=env_name,
+            data_dir=data_dir,
+            dt=nuplan_utils.NUPLAN_DT,
+            parts=dataset_parts,
+            scene_split_map=nup_log_split_map,
+            # The location names should match the map names used in
+            # the unified data cache.
+            map_locations=nuplan_utils.NUPLAN_LOCATIONS,
+        )
+
+    def load_dataset_obj(self, verbose: bool = False) -> None:
+        if verbose:
+            print(f"Loading {self.name} dataset...", flush=True)
+
+        if self.name == "nuplan_mini":
+            subfolder = "mini"
+        elif self.name.startswith("nuplan"):
+            subfolder = "trainval"
+
+        self.dataset_obj = nuplan_utils.NuPlanObject(self.metadata.data_dir, subfolder)
+
+    def _get_matching_scenes_from_obj(
+        self,
+        scene_tag: SceneTag,
+        scene_desc_contains: Optional[List[str]],
+        env_cache: EnvCache,
+    ) -> List[SceneMetadata]:
+        all_scenes_list: List[NuPlanSceneRecord] = list()
+
+        default_split = "mini_train" if "mini" in self.metadata.name else "train"
+
+        scenes_list: List[SceneMetadata] = list()
+        for idx, scene_record in enumerate(self.dataset_obj.scenes):
+            scene_name: str = scene_record["name"]
+            originating_log: str = scene_name.split("=")[0]
+            # scene_desc: str = scene_record["description"].lower()
+            scene_location: str = scene_record["location"]
+            scene_split: str = self.metadata.scene_split_map.get(
+                originating_log, default_split
+            )
+            scene_length: int = scene_record["num_timesteps"]
+
+            if scene_length == 1:
+                # nuPlan has scenes with only a single frame of data which we
+                # can't do much with in terms of prediction/planning/etc. As a
+                # result, we skip it.
+                # As an example, nuplan_mini scene e958b276c7a65197
+                # from log 2021.06.14.19.22.11_veh-38_01480_01860.
+                continue
+
+            # Saving all scene records for later caching.
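(The scene-record append that this comment refers to follows just below.) One detail from `compute_metadata` above worth a concrete illustration is the split-map inversion, which turns a split-to-logs mapping into a per-log split lookup. A self-contained sketch with invented log names:

```python
from typing import Dict, List

# Hypothetical splits, mimicking the shape of nuplan_utils.create_splits_logs().
nup_log_splits: Dict[str, List[str]] = {
    "train": ["log_a", "log_b"],
    "val": ["log_c"],
}

# Invert: associate every log with its data split, as in compute_metadata.
nup_log_split_map: Dict[str, str] = {
    v_elem: k for k, v in nup_log_splits.items() for v_elem in v
}

assert nup_log_split_map == {"log_a": "train", "log_b": "train", "log_c": "val"}
```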
+ all_scenes_list.append( + NuPlanSceneRecord( + scene_name, + scene_location, + scene_length, + scene_split, + # scene_desc, + idx, + ) + ) + + if ( + scene_location in scene_tag + and scene_split in scene_tag + and scene_desc_contains is None + ): + # if scene_desc_contains is not None and not any( + # desc_query in scene_desc for desc_query in scene_desc_contains + # ): + # continue + + scene_metadata = SceneMetadata( + env_name=self.metadata.name, + name=scene_name, + dt=self.metadata.dt, + raw_data_idx=idx, + ) + scenes_list.append(scene_metadata) + + self.cache_all_scenes_list(env_cache, all_scenes_list) + return scenes_list + + def _get_matching_scenes_from_cache( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[Scene]: + all_scenes_list: List[NuPlanSceneRecord] = env_cache.load_env_scenes_list( + self.name + ) + + scenes_list: List[SceneMetadata] = list() + for scene_record in all_scenes_list: + ( + scene_name, + scene_location, + scene_length, + scene_split, + # scene_desc, + data_idx, + ) = scene_record + + if ( + scene_location in scene_tag + and scene_split in scene_tag + and scene_desc_contains is None + ): + # if scene_desc_contains is not None and not any( + # desc_query in scene_desc for desc_query in scene_desc_contains + # ): + # continue + + scene_metadata = Scene( + self.metadata, + scene_name, + scene_location, + scene_split, + scene_length, + data_idx, + None, # This isn't used if everything is already cached. + # scene_desc, + ) + scenes_list.append(scene_metadata) + + return scenes_list + + def get_scene(self, scene_info: SceneMetadata) -> Scene: + _, _, _, data_idx = scene_info + default_split = "mini_train" if "mini" in self.metadata.name else "train" + + scene_record: Dict[str, str] = self.dataset_obj.scenes[data_idx] + + scene_name: str = scene_record["name"] + originating_log: str = scene_name.split("=")[0] + # scene_desc: str = scene_record["description"].lower() + scene_location: str = scene_record["location"] + scene_split: str = self.metadata.scene_split_map.get( + originating_log, default_split + ) + scene_length: int = scene_record["num_timesteps"] + + return Scene( + self.metadata, + scene_name, + scene_location, + scene_split, + scene_length, + data_idx, + scene_record, + # scene_desc, + ) + + def get_agent_info( + self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] + ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: + # instantiate VectorMap from map_api if necessary + self.dataset_obj.open_db(scene.name.split("=")[0] + ".db") + + ego_agent_info: AgentMetadata = AgentMetadata( + name="ego", + agent_type=AgentType.VEHICLE, + first_timestep=0, + last_timestep=scene.length_timesteps - 1, + # From https://github.com/motional/nuplan-devkit/blob/761cdbd52d699560629c79ba1b10b29c18ebc068/nuplan/common/actor_state/vehicle_parameters.py#L125 + extent=FixedExtent(length=4.049 + 1.127, width=1.1485 * 2.0, height=1.777), + ) + + agent_list: List[AgentMetadata] = [ego_agent_info] + agent_presence: List[List[AgentMetadata]] = [ + [ego_agent_info] for _ in range(scene.length_timesteps) + ] + + all_frames: pd.DataFrame = self.dataset_obj.get_scene_frames(scene) + + ego_df = ( + all_frames[ + ["ego_x", "ego_y", "ego_z", "ego_vx", "ego_vy", "ego_ax", "ego_ay"] + ] + .rename(columns=lambda name: name[4:]) + .reset_index(drop=True) + ) + ego_df["heading"] = arr_utils.quaternion_to_yaw( + all_frames[["ego_qw", "ego_qx", "ego_qy", "ego_qz"]].values + ) + ego_df["scene_ts"] = 
np.arange(len(ego_df)) + ego_df["agent_id"] = "ego" + + lpc_tokens: List[bytearray] = all_frames.index.tolist() + agents_df: pd.DataFrame = self.dataset_obj.get_detected_agents(lpc_tokens) + tls_df: pd.DataFrame = self.dataset_obj.get_traffic_light_status(lpc_tokens) + + self.dataset_obj.close_db() + + agents_df["scene_ts"] = agents_df["lidar_pc_token"].map( + {lpc_token: scene_ts for scene_ts, lpc_token in enumerate(lpc_tokens)} + ) + agents_df["agent_id"] = agents_df["track_token"].apply(lambda x: x.hex()) + + # Recording agent metadata for later. + agent_metadata_dict: Dict[str, Dict[str, Any]] = dict() + for agent_id, agent_data in agents_df.groupby("agent_id").first().iterrows(): + if agent_id not in agent_metadata_dict: + agent_metadata_dict[agent_id] = { + "type": nuplan_utils.nuplan_type_to_unified_type( + agent_data["category_name"] + ), + "width": agent_data["width"], + "length": agent_data["length"], + "height": agent_data["height"], + } + + agents_df = agents_df.drop( + columns=[ + "lidar_pc_token", + "track_token", + "category_name", + "width", + "length", + "height", + ], + ).rename(columns={"yaw": "heading"}) + + # Sorting the agents' combined DataFrame here. + agents_df.set_index(["agent_id", "scene_ts"], inplace=True) + agents_df.sort_index(inplace=True) + agents_df.reset_index(level=1, inplace=True) + + one_detection_agents: List[str] = list() + for agent_id in agent_metadata_dict: + agent_metadata_entry = agent_metadata_dict[agent_id] + + agent_specific_df = agents_df.loc[agent_id] + if len(agent_specific_df.shape) <= 1 or agent_specific_df.shape[0] <= 1: + # Removing agents that are only observed once. + one_detection_agents.append(agent_id) + continue + + first_timestep: int = agent_specific_df.iat[0, 0].item() + last_timestep: int = agent_specific_df.iat[-1, 0].item() + agent_info: AgentMetadata = AgentMetadata( + name=agent_id, + agent_type=agent_metadata_entry["type"], + first_timestep=first_timestep, + last_timestep=last_timestep, + extent=FixedExtent( + length=agent_metadata_entry["length"], + width=agent_metadata_entry["width"], + height=agent_metadata_entry["height"], + ), + ) + + agent_list.append(agent_info) + for timestep in range( + agent_info.first_timestep, agent_info.last_timestep + 1 + ): + agent_presence[timestep].append(agent_info) + + # Removing agents with only one detection. 
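(The corresponding `drop` call follows just below.) The single-detection filtering above is interleaved with metadata collection, so as a self-contained illustration of just the filtering idea, here is an equivalent vectorized sketch on a toy DataFrame shaped like `agents_df` at this point; the agent IDs and timesteps are invented:

```python
import pandas as pd

# Toy agent observations, indexed by agent_id with scene_ts as a column;
# agent "b" is only detected once.
agents_df = pd.DataFrame(
    {
        "agent_id": ["a", "a", "b", "c", "c", "c"],
        "scene_ts": [0, 1, 3, 0, 1, 2],
    }
).set_index("agent_id")

# Count detections per agent and collect those observed only once.
detection_counts = agents_df.groupby(level=0).size()
one_detection_agents = detection_counts[detection_counts <= 1].index.tolist()

agents_df = agents_df.drop(index=one_detection_agents)
assert "b" not in agents_df.index
```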
+        agents_df.drop(index=one_detection_agents, inplace=True)
+
+        ### Calculating agent accelerations
+        agent_ids: np.ndarray = agents_df.index.get_level_values(0).to_numpy()
+        if len(agent_ids) > 0:
+            agents_df[["ax", "ay"]] = (
+                arr_utils.agent_aware_diff(
+                    agents_df[["vx", "vy"]].to_numpy(), agent_ids
+                )
+                / nuplan_utils.NUPLAN_DT
+            )
+        else:
+            agents_df[["ax", "ay"]] = agents_df[["vx", "vy"]]
+
+        # for agent_id, frames in agents_df.groupby("agent_id")["scene_ts"]:
+        #     if frames.shape[0] <= 1:
+        #         raise ValueError("nuPlan can have one-detection agents :(")
+
+        #     start_frame: int = frames.iat[0].item()
+        #     last_frame: int = frames.iat[-1].item()
+
+        #     if frames.shape[0] < last_frame - start_frame + 1:
+        #         raise ValueError("nuPlan indeed can have missing frames :(")
+
+        overall_agents_df = pd.concat([ego_df, agents_df.reset_index()]).set_index(
+            ["agent_id", "scene_ts"]
+        )
+        cache_class.save_agent_data(overall_agents_df, cache_path, scene)
+
+        # Similar cleanup process for the traffic light data.
+        tls_df["scene_ts"] = tls_df["lidar_pc_token"].map(
+            {lpc_token: scene_ts for scene_ts, lpc_token in enumerate(lpc_tokens)}
+        )
+        tls_df = tls_df.drop(columns=["lidar_pc_token"]).set_index(
+            ["lane_id", "scene_ts"]
+        )
+
+        cache_class.save_traffic_light_data(tls_df, cache_path, scene)
+
+        return agent_list, agent_presence
+
+    def cache_map(
+        self,
+        map_name: str,
+        cache_path: Path,
+        map_cache_class: Type[SceneCache],
+        map_params: Dict[str, Any],
+    ) -> None:
+        nuplan_map: NuPlanMap = map_factory.get_maps_api(
+            map_root=str(self.metadata.data_dir.parent / "maps"),
+            map_version=nuplan_utils.NUPLAN_MAP_VERSION,
+            map_name=nuplan_utils.NUPLAN_FULL_MAP_NAME_DICT[map_name],
+        )
+
+        # Loading all layer geometries.
+        nuplan_map.initialize_all_layers()
+
+        # This df has the normal lane_connectors with additional boundary information,
+        # which we want to use. However, the default index is not the lane_connector_fid
+        # (although it is a 1:1 mapping), so we instead create another index with the
+        # lane_connector_fids as the keys and the resulting integer indices as the values.
+        lane_connector_fids: pd.Series = nuplan_map._vector_map[
+            "gen_lane_connectors_scaled_width_polygons"
+        ]["lane_connector_fid"]
+        lane_connector_idxs: pd.Series = pd.Series(
+            index=lane_connector_fids, data=range(len(lane_connector_fids))
+        )
+
+        vector_map = VectorMap(map_id=f"{self.name}:{map_name}")
+        nuplan_utils.populate_vector_map(vector_map, nuplan_map, lane_connector_idxs)
+
+        map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params)
+
+    def cache_maps(
+        self,
+        cache_path: Path,
+        map_cache_class: Type[SceneCache],
+        map_params: Dict[str, Any],
+    ) -> None:
+        """
+        Stores processed maps to disk for later retrieval.
+ """ + for map_name in tqdm( + nuplan_utils.NUPLAN_LOCATIONS, + desc=f"Caching {self.name} Maps at {map_params['px_per_m']:.2f} px/m", + position=0, + ): + self.cache_map(map_name, cache_path, map_cache_class, map_params) diff --git a/src/trajdata/dataset_specific/nuplan/nuplan_utils.py b/src/trajdata/dataset_specific/nuplan/nuplan_utils.py new file mode 100644 index 0000000..c98c748 --- /dev/null +++ b/src/trajdata/dataset_specific/nuplan/nuplan_utils.py @@ -0,0 +1,410 @@ +import glob +import sqlite3 +from collections import defaultdict +from pathlib import Path +from typing import Dict, Final, Generator, Iterable, List, Optional, Tuple + +import numpy as np +import nuplan.planning.script.config.common as common_cfg +import pandas as pd +import yaml +from nuplan.common.maps.nuplan_map.nuplan_map import NuPlanMap +from tqdm import tqdm + +from trajdata.data_structures.agent import AgentType +from trajdata.data_structures.scene_metadata import Scene +from trajdata.maps import TrafficLightStatus, VectorMap +from trajdata.maps.vec_map_elements import ( + MapElementType, + PedCrosswalk, + PedWalkway, + Polyline, + RoadArea, + RoadLane, +) +from trajdata.utils import map_utils + +NUPLAN_DT: Final[float] = 0.05 +NUPLAN_FULL_MAP_NAME_DICT: Final[Dict[str, str]] = { + "boston": "us-ma-boston", + "singapore": "sg-one-north", + "las_vegas": "us-nv-las-vegas-strip", + "pittsburgh": "us-pa-pittsburgh-hazelwood", +} +_NUPLAN_SQL_MAP_FRIENDLY_NAMES_DICT: Final[Dict[str, str]] = { + "us-ma-boston": "boston", + "sg-one-north": "singapore", + "las_vegas": "las_vegas", + "us-pa-pittsburgh-hazelwood": "pittsburgh", +} +NUPLAN_LOCATIONS: Final[Tuple[str, str, str, str]] = tuple( + NUPLAN_FULL_MAP_NAME_DICT.keys() +) +NUPLAN_MAP_VERSION: Final[str] = "nuplan-maps-v1.0" + +NUPLAN_TRAFFIC_STATUS_DICT: Final[Dict[str, TrafficLightStatus]] = { + "green": TrafficLightStatus.GREEN, + "red": TrafficLightStatus.RED, + "unknown": TrafficLightStatus.UNKNOWN, +} + + +class NuPlanObject: + def __init__(self, dataset_path: Path, subfolder: str) -> None: + self.base_path: Path = dataset_path / subfolder + + self.connection: sqlite3.Connection = None + self.cursor: sqlite3.Cursor = None + + self.scenes: List[Dict[str, str]] = self._load_scenes() + + def open_db(self, db_filename: str) -> None: + self.connection = sqlite3.connect(str(self.base_path / db_filename)) + self.connection.row_factory = sqlite3.Row + self.cursor = self.connection.cursor() + + def execute_query_one( + self, query_text: str, query_params: Optional[Iterable] = None + ) -> sqlite3.Row: + self.cursor.execute( + query_text, query_params if query_params is not None else [] + ) + return self.cursor.fetchone() + + def execute_query_all( + self, query_text: str, query_params: Optional[Iterable] = None + ) -> List[sqlite3.Row]: + self.cursor.execute( + query_text, query_params if query_params is not None else [] + ) + return self.cursor.fetchall() + + def execute_query_iter( + self, query_text: str, query_params: Optional[Iterable] = None + ) -> Generator[sqlite3.Row, None, None]: + self.cursor.execute( + query_text, query_params if query_params is not None else [] + ) + + for row in self.cursor: + yield row + + def _load_scenes(self) -> List[Dict[str, str]]: + scene_info_query = """ + SELECT sc.token AS scene_token, + log.location, + log.logfile, + ( + SELECT count(*) + FROM lidar_pc AS lpc + WHERE lpc.scene_token = sc.token + ) AS num_timesteps + FROM scene AS sc + LEFT JOIN log ON sc.log_token = log.token + """ + scenes: List[Dict[str, str]] = [] + + for 
log_filename in glob.glob(str(self.base_path / "*.db")): + self.open_db(log_filename) + + for row in self.execute_query_iter(scene_info_query): + scenes.append( + { + "name": f"{row['logfile']}={row['scene_token'].hex()}", + "location": _NUPLAN_SQL_MAP_FRIENDLY_NAMES_DICT[ + row["location"] + ], + "num_timesteps": row["num_timesteps"], + } + ) + + self.close_db() + + return scenes + + def get_scene_frames(self, scene: Scene) -> pd.DataFrame: + query = """ + SELECT lpc.token AS lpc_token, + ep.x AS ego_x, + ep.y AS ego_y, + ep.z AS ego_z, + ep.qw AS ego_qw, + ep.qx AS ego_qx, + ep.qy AS ego_qy, + ep.qz AS ego_qz, + ep.vx AS ego_vx, + ep.vy AS ego_vy, + ep.acceleration_x AS ego_ax, + ep.acceleration_y AS ego_ay + FROM lidar_pc AS lpc + LEFT JOIN ego_pose AS ep ON lpc.ego_pose_token = ep.token + WHERE scene_token = ? + ORDER BY lpc.timestamp ASC; + """ + log_filename, scene_token_str = scene.name.split("=") + scene_token = bytearray.fromhex(scene_token_str) + + return pd.read_sql_query( + query, self.connection, index_col="lpc_token", params=(scene_token,) + ) + + def get_detected_agents(self, binary_lpc_tokens: List[bytearray]) -> pd.DataFrame: + query = f""" + SELECT lb.lidar_pc_token, + lb.track_token, + (SELECT category.name FROM category WHERE category.token = tr.category_token) AS category_name, + tr.width, + tr.length, + tr.height, + lb.x, + lb.y, + lb.z, + lb.vx, + lb.vy, + lb.yaw + FROM lidar_box AS lb + LEFT JOIN track AS tr ON lb.track_token = tr.token + + WHERE lidar_pc_token IN ({('?,'*len(binary_lpc_tokens))[:-1]}) AND category_name IN ('vehicle', 'bicycle', 'pedestrian') + """ + return pd.read_sql_query(query, self.connection, params=binary_lpc_tokens) + + def get_traffic_light_status( + self, binary_lpc_tokens: List[bytearray] + ) -> pd.DataFrame: + query = f""" + SELECT tls.lidar_pc_token AS lidar_pc_token, + tls.lane_connector_id AS lane_id, + tls.status AS raw_status + FROM traffic_light_status AS tls + WHERE lidar_pc_token IN ({('?,'*len(binary_lpc_tokens))[:-1]}); + """ + df = pd.read_sql_query(query, self.connection, params=binary_lpc_tokens) + df["status"] = df["raw_status"].map(NUPLAN_TRAFFIC_STATUS_DICT) + df["lane_id"] = df["lane_id"].astype(str) + return df.drop(columns=["raw_status"]) + + def close_db(self) -> None: + self.cursor.close() + self.connection.close() + + +def nuplan_type_to_unified_type(nuplan_type: str) -> AgentType: + if nuplan_type == "pedestrian": + return AgentType.PEDESTRIAN + elif nuplan_type == "bicycle": + return AgentType.BICYCLE + elif nuplan_type == "vehicle": + return AgentType.VEHICLE + else: + return AgentType.UNKNOWN + + +def create_splits_logs() -> Dict[str, List[str]]: + yaml_filepath = Path(common_cfg.__path__[0]) / "splitter" / "nuplan.yaml" + with open(yaml_filepath, "r") as stream: + splits = yaml.safe_load(stream) + + return splits["log_splits"] + + +def extract_lane_and_edges( + nuplan_map: NuPlanMap, lane_record, lane_connector_idxs: pd.Series +) -> Tuple[str, np.ndarray, np.ndarray, np.ndarray, Tuple[str, str]]: + lane_midline = np.stack(lane_record["geometry"].xy, axis=-1) + + # Getting the bounding polygon vertices. 
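(The boundary-polygon extraction that this comment begins continues just below.) Stepping back to the SQL helpers above: the `({('?,'*len(binary_lpc_tokens))[:-1]})` expression builds one `?` placeholder per token, which keeps the `IN` clause parameterized rather than string-interpolated. A minimal sketch of the same trick against an in-memory SQLite database, with an invented table and values:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE lidar_box (lidar_pc_token BLOB, x REAL)")
conn.executemany(
    "INSERT INTO lidar_box VALUES (?, ?)",
    [(b"\x01", 1.0), (b"\x02", 2.0), (b"\x03", 3.0)],
)

tokens = [b"\x01", b"\x03"]
# One '?' placeholder per token, comma-joined, trailing comma sliced off.
placeholders = ("?," * len(tokens))[:-1]
rows = conn.execute(
    f"SELECT x FROM lidar_box WHERE lidar_pc_token IN ({placeholders}) ORDER BY x",
    tokens,
).fetchall()
assert [r[0] for r in rows] == [1.0, 3.0]
```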
+ boundary_df = nuplan_map._vector_map["boundaries"] + if np.isfinite(lane_record["lane_fid"]): + fid = str(int(lane_record["lane_fid"])) + lane_info = nuplan_map._vector_map["lanes_polygons"].loc[fid] + elif np.isfinite(lane_record["lane_connector_fid"]): + fid = int(lane_record["lane_connector_fid"]) + lane_info = nuplan_map._vector_map[ + "gen_lane_connectors_scaled_width_polygons" + ].iloc[lane_connector_idxs[fid]] + else: + raise ValueError("Both lane_fid and lane_connector_fid are NaN!") + + lane_fid = str(fid) + boundary_info = ( + str(lane_info["left_boundary_fid"]), + str(lane_info["right_boundary_fid"]), + ) + + left_pts = np.stack(boundary_df.loc[boundary_info[0]]["geometry"].xy, axis=-1) + right_pts = np.stack(boundary_df.loc[boundary_info[1]]["geometry"].xy, axis=-1) + + # Final ordering check, ensuring that left_pts and right_pts can be combined + # into a polygon without the endpoints intersecting. + # Reversing the one lane edge that does not match the ordering of the midline. + if map_utils.endpoints_intersect(left_pts, right_pts): + if not map_utils.order_matches(left_pts, lane_midline): + left_pts = left_pts[::-1] + else: + right_pts = right_pts[::-1] + + # Ensuring that left and right have the same number of points. + # This is necessary, not for data storage but for later rasterization. + if left_pts.shape[0] < right_pts.shape[0]: + left_pts = map_utils.interpolate(left_pts, num_pts=right_pts.shape[0]) + elif right_pts.shape[0] < left_pts.shape[0]: + right_pts = map_utils.interpolate(right_pts, num_pts=left_pts.shape[0]) + + return (lane_fid, lane_midline, left_pts, right_pts, boundary_info) + + +def extract_area(nuplan_map: NuPlanMap, area_record) -> np.ndarray: + return np.stack(area_record["geometry"].exterior.xy, axis=-1) + + +def populate_vector_map( + vector_map: VectorMap, nuplan_map: NuPlanMap, lane_connector_idxs: pd.Series +) -> None: + # Setting the map bounds. + # NOTE: min_pt is especially important here since the world coordinates of nuPlan + # are quite large in magnitude. We make them relative to the bottom-left by + # subtracting all positions by min_pt and registering that offset as part of + # the map_from_world (and related) transforms later. + min_pt = np.min( + [ + layer_df["geometry"].total_bounds[:2] + for layer_df in nuplan_map._vector_map.values() + ], + axis=0, + ) + max_pt = np.max( + [ + layer_df["geometry"].total_bounds[2:] + for layer_df in nuplan_map._vector_map.values() + ], + axis=0, + ) + + # vector_map.extent is [min_x, min_y, min_z, max_x, max_y, max_z] + vector_map.extent = np.array( + [ + min_pt[0], + min_pt[1], + 0.0, + max_pt[0], + max_pt[1], + 0.0, + ] + ) + + overall_pbar = tqdm( + total=len(nuplan_map._vector_map["baseline_paths"]) + + len(nuplan_map._vector_map["drivable_area"]) + + len(nuplan_map._vector_map["crosswalks"]) + + len(nuplan_map._vector_map["walkways"]), + desc=f"Getting {nuplan_map.map_name} Elements", + position=1, + leave=False, + ) + + # This dict stores boundary IDs and which lanes are to the left and right of them. + boundary_connectivity_dict: Dict[str, Dict[str, List[str]]] = defaultdict( + lambda: defaultdict(list) + ) + + # This dict stores lanes' boundary IDs. 
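(The `lane_boundary_dict` that this comment describes is created just below.) The boundary bookkeeping is easier to follow on a toy example: two lanes sharing a boundary are adjacent, and each boundary's "left"/"right" lists record which lanes sit on either side of it. A minimal sketch with invented lane and boundary IDs, mirroring how the dict is populated in the loop that follows:

```python
from collections import defaultdict
from typing import Dict, List, Tuple

boundary_connectivity_dict: Dict[str, Dict[str, List[str]]] = defaultdict(
    lambda: defaultdict(list)
)

# Hypothetical lanes: "A" has boundaries (left "b0", right "b1");
# "B" has (left "b1", right "b2"), i.e. A and B share boundary "b1".
lane_boundaries: Dict[str, Tuple[str, str]] = {"A": ("b0", "b1"), "B": ("b1", "b2")}
for lane_id, (left_b, right_b) in lane_boundaries.items():
    # A lane is to the right of its own left boundary...
    boundary_connectivity_dict[left_b]["right"].append(lane_id)
    # ...and to the left of its own right boundary.
    boundary_connectivity_dict[right_b]["left"].append(lane_id)

# Lane "A"'s right neighbor is whatever sits to the right of its right boundary.
assert boundary_connectivity_dict["b1"]["right"] == ["B"]
assert boundary_connectivity_dict["b1"]["left"] == ["A"]
```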
+ lane_boundary_dict: Dict[str, Tuple[str, str]] = dict() + for _, lane_info in nuplan_map._vector_map["baseline_paths"].iterrows(): + ( + lane_id, + center_pts, + left_pts, + right_pts, + boundary_info, + ) = extract_lane_and_edges(nuplan_map, lane_info, lane_connector_idxs) + + lane_boundary_dict[lane_id] = boundary_info + left_boundary_id, right_boundary_id = boundary_info + + # The left boundary of Lane A has Lane A to its right. + boundary_connectivity_dict[left_boundary_id]["right"].append(lane_id) + + # The right boundary of Lane A has Lane A to its left. + boundary_connectivity_dict[right_boundary_id]["left"].append(lane_id) + + # "partial" because we aren't adding lane connectivity until later. + partial_new_lane = RoadLane( + id=lane_id, + center=Polyline(center_pts), + left_edge=Polyline(left_pts), + right_edge=Polyline(right_pts), + ) + vector_map.add_map_element(partial_new_lane) + overall_pbar.update() + + for fid, polygon_info in nuplan_map._vector_map["drivable_area"].iterrows(): + polygon_pts = extract_area(nuplan_map, polygon_info) + + new_road_area = RoadArea(id=fid, exterior_polygon=Polyline(polygon_pts)) + for hole in polygon_info["geometry"].interiors: + hole_pts = extract_area(nuplan_map, hole) + new_road_area.interior_holes.append(Polyline(hole_pts)) + + vector_map.add_map_element(new_road_area) + overall_pbar.update() + + for fid, ped_area_record in nuplan_map._vector_map["crosswalks"].iterrows(): + polygon_pts = extract_area(nuplan_map, ped_area_record) + + new_ped_crosswalk = PedCrosswalk(id=fid, polygon=Polyline(polygon_pts)) + vector_map.add_map_element(new_ped_crosswalk) + overall_pbar.update() + + for fid, ped_area_record in nuplan_map._vector_map["walkways"].iterrows(): + polygon_pts = extract_area(nuplan_map, ped_area_record) + + new_ped_walkway = PedWalkway(id=fid, polygon=Polyline(polygon_pts)) + vector_map.add_map_element(new_ped_walkway) + overall_pbar.update() + + overall_pbar.close() + + # Lane connectivity + lane_connectivity_exit_dict = defaultdict(list) + lane_connectivity_entry_dict = defaultdict(list) + for lane_connector_fid, lane_connector in tqdm( + nuplan_map._vector_map["lane_connectors"].iterrows(), + desc="Getting Lane Connectivity", + total=len(nuplan_map._vector_map["lane_connectors"]), + position=1, + leave=False, + ): + lane_connectivity_exit_dict[str(lane_connector["exit_lane_fid"])].append( + lane_connector_fid + ) + lane_connectivity_entry_dict[lane_connector_fid].append( + str(lane_connector["exit_lane_fid"]) + ) + + lane_connectivity_exit_dict[lane_connector_fid].append( + str(lane_connector["entry_lane_fid"]) + ) + lane_connectivity_entry_dict[str(lane_connector["entry_lane_fid"])].append( + lane_connector_fid + ) + + map_elem: RoadLane + for map_elem in tqdm( + vector_map.elements[MapElementType.ROAD_LANE].values(), + desc="Storing Lane Connectivity", + position=1, + leave=False, + ): + map_elem.prev_lanes.update(lane_connectivity_entry_dict[map_elem.id]) + map_elem.next_lanes.update(lane_connectivity_exit_dict[map_elem.id]) + + lane_id: str = map_elem.id + left_boundary_id, right_boundary_id = lane_boundary_dict[lane_id] + + map_elem.adj_lanes_left.update( + boundary_connectivity_dict[left_boundary_id]["left"] + ) + map_elem.adj_lanes_right.update( + boundary_connectivity_dict[right_boundary_id]["right"] + ) diff --git a/src/trajdata/dataset_specific/nusc/nusc_dataset.py b/src/trajdata/dataset_specific/nusc/nusc_dataset.py index 401658a..209824e 100644 --- a/src/trajdata/dataset_specific/nusc/nusc_dataset.py +++ 
b/src/trajdata/dataset_specific/nusc/nusc_dataset.py @@ -3,14 +3,11 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Type, Union -import numpy as np import pandas as pd from nuscenes.eval.prediction.splits import NUM_IN_TRAIN_VAL -from nuscenes.map_expansion import arcline_path_utils from nuscenes.map_expansion.map_api import NuScenesMap, locations from nuscenes.nuscenes import NuScenes from nuscenes.utils.splits import create_splits_scenes -from scipy.spatial.distance import cdist from tqdm import tqdm from trajdata.caching import EnvCache, SceneCache @@ -27,16 +24,7 @@ from trajdata.dataset_specific.nusc import nusc_utils from trajdata.dataset_specific.raw_dataset import RawDataset from trajdata.dataset_specific.scene_records import NuscSceneRecord -from trajdata.maps import RasterizedMap, RasterizedMapMetadata, map_utils -from trajdata.proto.vectorized_map_pb2 import ( - MapElement, - PedCrosswalk, - PedWalkway, - Polyline, - RoadArea, - RoadLane, - VectorizedMap, -) +from trajdata.maps import VectorMap class NuscDataset(RawDataset): @@ -82,6 +70,8 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: ("mini_train", "mini_val"), ("boston", "singapore"), ] + else: + raise ValueError(f"Unknown nuScenes environment name: {env_name}") # Inverting the dict from above, associating every scene with its data split. nusc_scene_split_map: Dict[str, str] = { @@ -94,6 +84,9 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: dt=nusc_utils.NUSC_DT, parts=dataset_parts, scene_split_map=nusc_scene_split_map, + # The location names should match the map names used in + # the unified data cache. + map_locations=tuple(locations), ) def load_dataset_obj(self, verbose: bool = False) -> None: @@ -274,223 +267,6 @@ def get_agent_info( return agent_list, agent_presence - def extract_lane_and_edges( - self, nusc_map: NuScenesMap, lane_record - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - # Getting the bounding polygon vertices. - lane_polygon_obj = nusc_map.get("polygon", lane_record["polygon_token"]) - polygon_nodes = [ - nusc_map.get("node", node_token) - for node_token in lane_polygon_obj["exterior_node_tokens"] - ] - polygon_pts: np.ndarray = np.array( - [(node["x"], node["y"]) for node in polygon_nodes] - ) - - # Getting the lane center's points. - curr_lane = nusc_map.arcline_path_3.get(lane_record["token"], []) - lane_midline: np.ndarray = np.array( - arcline_path_utils.discretize_lane(curr_lane, resolution_meters=0.5) - )[:, :2] - - # For some reason, nuScenes duplicates a few entries - # (likely how they're building their arcline representation). - # We delete those duplicate entries here. - duplicate_check: np.ndarray = np.where( - np.linalg.norm(np.diff(lane_midline, axis=0, prepend=0), axis=1) < 1e-10 - )[0] - if duplicate_check.size > 0: - lane_midline = np.delete(lane_midline, duplicate_check, axis=0) - - # Computing the closest lane center point to each bounding polygon vertex. - closest_midlane_pt: np.ndarray = np.argmin( - cdist(polygon_pts, lane_midline), axis=1 - ) - # Computing the local direction of the lane at each lane center point. - direction_vectors: np.ndarray = np.diff( - lane_midline, - axis=0, - prepend=lane_midline[[0]] - (lane_midline[[1]] - lane_midline[[0]]), - ) - - # Selecting the direction vectors at the closest lane center point per polygon vertex. 
- local_dir_vecs: np.ndarray = direction_vectors[closest_midlane_pt] - # Calculating the vectors from the the closest lane center point per polygon vertex to the polygon vertex. - origin_to_polygon_vecs: np.ndarray = ( - polygon_pts - lane_midline[closest_midlane_pt] - ) - - # Computing the perpendicular dot product. - # See https://www.xarg.org/book/linear-algebra/2d-perp-product/ - # If perp_dot_product < 0, then the associated polygon vertex is - # on the right edge of the lane. - perp_dot_product: np.ndarray = ( - local_dir_vecs[:, 0] * origin_to_polygon_vecs[:, 1] - - local_dir_vecs[:, 1] * origin_to_polygon_vecs[:, 0] - ) - - # Determining which indices are on the right of the lane center. - on_right: np.ndarray = perp_dot_product < 0 - # Determining the boundary between the left/right polygon vertices - # (they will be together in blocks due to the ordering of the polygon vertices). - idx_changes: int = np.where(np.roll(on_right, 1) < on_right)[0].item() - - if idx_changes > 0: - # If the block of left/right points spreads across the bounds of the array, - # roll it until the boundary between left/right points is at index 0. - # This is important so that the following index selection orders points - # without jumps. - polygon_pts = np.roll(polygon_pts, shift=-idx_changes, axis=0) - on_right = np.roll(on_right, shift=-idx_changes) - - left_pts: np.ndarray = polygon_pts[~on_right] - right_pts: np.ndarray = polygon_pts[on_right] - - # Final ordering check, ensuring that the beginning of left_pts/right_pts - # matches the beginning of the lane. - left_order_correct: bool = np.linalg.norm( - left_pts[0] - lane_midline[0] - ) < np.linalg.norm(left_pts[0] - lane_midline[-1]) - right_order_correct: bool = np.linalg.norm( - right_pts[0] - lane_midline[0] - ) < np.linalg.norm(right_pts[0] - lane_midline[-1]) - - # Reversing left_pts/right_pts in case their first index is - # at the end of the lane. - if not left_order_correct: - left_pts = left_pts[::-1] - if not right_order_correct: - right_pts = right_pts[::-1] - - # Ensuring that left and right have the same number of points. - # This is necessary, not for data storage but for later rasterization. - if left_pts.shape[0] < right_pts.shape[0]: - left_pts = map_utils.interpolate(left_pts, right_pts.shape[0]) - elif right_pts.shape[0] < left_pts.shape[0]: - right_pts = map_utils.interpolate(right_pts, left_pts.shape[0]) - - return ( - lane_midline, - left_pts, - right_pts, - ) - - def extract_area(self, nusc_map: NuScenesMap, area_record) -> np.ndarray: - token_key: str - if "exterior_node_tokens" in area_record: - token_key = "exterior_node_tokens" - elif "node_tokens" in area_record: - token_key = "node_tokens" - - polygon_nodes = [ - nusc_map.get("node", node_token) for node_token in area_record[token_key] - ] - - return np.array([(node["x"], node["y"]) for node in polygon_nodes]) - - def extract_vectorized(self, nusc_map: NuScenesMap) -> VectorizedMap: - vec_map = VectorizedMap() - - # Setting the map bounds. 
- vec_map.max_pt.x, vec_map.max_pt.y, vec_map.max_pt.z = ( - nusc_map.explorer.canvas_max_x, - nusc_map.explorer.canvas_max_y, - 0.0, - ) - vec_map.min_pt.x, vec_map.min_pt.y, vec_map.min_pt.z = ( - nusc_map.explorer.canvas_min_x, - nusc_map.explorer.canvas_min_y, - 0.0, - ) - - overall_pbar = tqdm( - total=len(nusc_map.lane) - + len(nusc_map.drivable_area[0]["polygon_tokens"]) - + len(nusc_map.ped_crossing) - + len(nusc_map.walkway), - desc=f"Getting {nusc_map.map_name} Elements", - position=1, - leave=False, - ) - - for lane_record in nusc_map.lane: - center_pts, left_pts, right_pts = self.extract_lane_and_edges( - nusc_map, lane_record - ) - - lane_record_token: str = lane_record["token"] - - # Adding the element to the map. - new_element: MapElement = vec_map.elements.add() - new_element.id = lane_record_token.encode() - - new_lane: RoadLane = new_element.road_lane - map_utils.populate_lane_polylines(new_lane, center_pts, left_pts, right_pts) - - new_lane.entry_lanes.extend( - lane_id.encode() - for lane_id in nusc_map.get_incoming_lane_ids(lane_record_token) - ) - new_lane.exit_lanes.extend( - lane_id.encode() - for lane_id in nusc_map.get_outgoing_lane_ids(lane_record_token) - ) - - # new_lane.adjacent_lanes_left.append( - # l5_lane.adjacent_lane_change_left.id - # ) - # new_lane.adjacent_lanes_right.append( - # l5_lane.adjacent_lane_change_right.id - # ) - - overall_pbar.update() - - for polygon_token in nusc_map.drivable_area[0]["polygon_tokens"]: - polygon_record = nusc_map.get("polygon", polygon_token) - polygon_pts = self.extract_area(nusc_map, polygon_record) - - # Adding the element to the map. - new_element: MapElement = vec_map.elements.add() - new_element.id = lane_record["token"].encode() - - new_area: RoadArea = new_element.road_area - map_utils.populate_polygon(new_area.exterior_polygon, polygon_pts) - - for hole in polygon_record["holes"]: - polygon_pts = self.extract_area(nusc_map, hole) - new_hole: Polyline = new_area.interior_holes.add() - map_utils.populate_polygon(new_hole, polygon_pts) - - overall_pbar.update() - - for ped_area_record in nusc_map.ped_crossing: - polygon_pts = self.extract_area(nusc_map, ped_area_record) - - # Adding the element to the map. - new_element: MapElement = vec_map.elements.add() - new_element.id = ped_area_record["token"].encode() - - new_crosswalk: PedCrosswalk = new_element.ped_crosswalk - map_utils.populate_polygon(new_crosswalk.polygon, polygon_pts) - - overall_pbar.update() - - for ped_area_record in nusc_map.walkway: - polygon_pts = self.extract_area(nusc_map, ped_area_record) - - # Adding the element to the map. 
- new_element: MapElement = vec_map.elements.add() - new_element.id = ped_area_record["token"].encode() - - new_walkway: PedWalkway = new_element.ped_walkway - map_utils.populate_polygon(new_walkway.polygon, polygon_pts) - - overall_pbar.update() - - overall_pbar.close() - - return vec_map - def cache_map( self, map_name: str, @@ -498,81 +274,14 @@ def cache_map( map_cache_class: Type[SceneCache], map_params: Dict[str, Any], ) -> None: - resolution: float = map_params["px_per_m"] - nusc_map: NuScenesMap = NuScenesMap( dataroot=self.metadata.data_dir, map_name=map_name ) - if map_params.get("original_format", False): - warnings.warn( - "Using a dataset's original map format is deprecated, and will be removed in the next version of trajdata!", - FutureWarning, - ) - - width_m, height_m = nusc_map.canvas_edge - height_px, width_px = round(height_m * resolution), round( - width_m * resolution - ) + vector_map = VectorMap(map_id=f"{self.name}:{map_name}") + nusc_utils.populate_vector_map(vector_map, nusc_map) - def layer_fn(layer_name: str) -> np.ndarray: - # Getting rid of the channels dim by accessing index [0] - return nusc_map.get_map_mask( - patch_box=None, - patch_angle=0, - layer_names=[layer_name], - canvas_size=(height_px, width_px), - )[0].astype(np.bool) - - map_from_world: np.ndarray = np.array( - [[resolution, 0.0, 0.0], [0.0, resolution, 0.0], [0.0, 0.0, 1.0]] - ) - - layer_names: List[str] = [ - "lane", - "road_segment", - "drivable_area", - "road_divider", - "lane_divider", - "ped_crossing", - "walkway", - ] - map_info: RasterizedMapMetadata = RasterizedMapMetadata( - name=map_name, - shape=(len(layer_names), height_px, width_px), - layers=layer_names, - layer_rgb_groups=([0, 1, 2], [3, 4], [5, 6]), - resolution=resolution, - map_from_world=map_from_world, - ) - - map_cache_class.cache_map_layers( - cache_path, VectorizedMap(), map_info, layer_fn, self.name - ) - else: - vectorized_map: VectorizedMap = self.extract_vectorized(nusc_map) - - pbar_kwargs = {"position": 2, "leave": False} - map_data, map_from_world = map_utils.rasterize_map( - vectorized_map, resolution, **pbar_kwargs - ) - - rasterized_map_info: RasterizedMapMetadata = RasterizedMapMetadata( - name=map_name, - shape=map_data.shape, - layers=["drivable_area", "lane_divider", "ped_area"], - layer_rgb_groups=([0], [1], [2]), - resolution=resolution, - map_from_world=map_from_world, - ) - - rasterized_map_obj: RasterizedMap = RasterizedMap( - rasterized_map_info, map_data - ) - - map_cache_class.cache_map( - cache_path, vectorized_map, rasterized_map_obj, self.name - ) + map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) def cache_maps( self, diff --git a/src/trajdata/dataset_specific/nusc/nusc_utils.py b/src/trajdata/dataset_specific/nusc/nusc_utils.py index 8cd4cb6..356debd 100644 --- a/src/trajdata/dataset_specific/nusc/nusc_utils.py +++ b/src/trajdata/dataset_specific/nusc/nusc_utils.py @@ -1,12 +1,25 @@ -from typing import Any, Dict, Final, List, Union +from typing import Any, Dict, Final, List, Tuple, Union import numpy as np import pandas as pd +from nuscenes.map_expansion import arcline_path_utils +from nuscenes.map_expansion.map_api import NuScenesMap from nuscenes.nuscenes import NuScenes from pyquaternion import Quaternion +from scipy.spatial.distance import cdist +from tqdm import tqdm from trajdata.data_structures import Agent, AgentMetadata, AgentType, FixedExtent, Scene -from trajdata.utils import arr_utils +from trajdata.maps import VectorMap +from trajdata.maps.vec_map_elements 
import ( + MapElementType, + PedCrosswalk, + PedWalkway, + Polyline, + RoadArea, + RoadLane, +) +from trajdata.utils import arr_utils, map_utils NUSC_DT: Final[float] = 0.5 @@ -49,7 +62,7 @@ def agg_agent_data( if agent_data["prev"]: print("WARN: This is not the first frame of this agent!") - translation_list = [np.array(agent_data["translation"][:2])[np.newaxis]] + translation_list = [np.array(agent_data["translation"][:3])[np.newaxis]] agent_size = agent_data["size"] yaw_list = [Quaternion(agent_data["rotation"]).yaw_pitch_roll[0]] @@ -58,7 +71,7 @@ def agg_agent_data( while curr_sample_ann_token: agent_data = nusc_obj.get("sample_annotation", curr_sample_ann_token) - translation = np.array(agent_data["translation"][:2]) + translation = np.array(agent_data["translation"][:3]) heading = Quaternion(agent_data["rotation"]).yaw_pitch_roll[0] curr_idx: int = frame_idx_dict[agent_data["sample_token"]] if curr_idx > prev_idx + 1: @@ -73,6 +86,11 @@ def agg_agent_data( xp=[prev_idx, curr_idx], fp=[translation_list[-1][0, 1], translation[1]], ) + zs = np.interp( + x=fill_time, + xp=[prev_idx, curr_idx], + fp=[translation_list[-1][0, 2], translation[2]], + ) headings: np.ndarray = arr_utils.angle_wrap( np.interp( x=fill_time, @@ -80,7 +98,7 @@ def agg_agent_data( fp=np.unwrap([yaw_list[-1], heading]), ) ) - translation_list.append(np.stack([xs, ys], axis=1)) + translation_list.append(np.stack([xs, ys, zs], axis=1)) yaw_list.extend(headings.tolist()) translation_list.append(translation[np.newaxis]) @@ -93,9 +111,13 @@ def agg_agent_data( translations_np = np.concatenate(translation_list, axis=0) # Doing this prepending so that the first velocity isn't zero (rather it's just the first actual velocity duplicated) - prepend_pos = translations_np[0] - (translations_np[1] - translations_np[0]) + prepend_pos = translations_np[0, :2] - ( + translations_np[1, :2] - translations_np[0, :2] + ) velocities_np = ( - np.diff(translations_np, axis=0, prepend=np.expand_dims(prepend_pos, axis=0)) + np.diff( + translations_np[:, :2], axis=0, prepend=np.expand_dims(prepend_pos, axis=0) + ) / NUSC_DT ) @@ -142,7 +164,7 @@ def agg_agent_data( last_timestep = curr_scene_index + agent_data_np.shape[0] - 1 agent_data_df = pd.DataFrame( agent_data_np, - columns=["x", "y", "vx", "vy", "ax", "ay", "heading"], + columns=["x", "y", "z", "vx", "vy", "ax", "ay", "heading"], index=pd.MultiIndex.from_tuples( [ (agent_data["instance_token"], idx) @@ -187,16 +209,18 @@ def agg_ego_data(nusc_obj: NuScenes, scene: Scene) -> Agent: for frame_info in frame_iterator(nusc_obj, scene): ego_pose = get_ego_pose(nusc_obj, frame_info) yaw_list.append(Quaternion(ego_pose["rotation"]).yaw_pitch_roll[0]) - translation_list.append(ego_pose["translation"][:2]) + translation_list.append(ego_pose["translation"]) translations_np: np.ndarray = np.stack(translation_list, axis=0) # Doing this prepending so that the first velocity isn't zero (rather it's just the first actual velocity duplicated) - prepend_pos: np.ndarray = translations_np[0] - ( - translations_np[1] - translations_np[0] + prepend_pos: np.ndarray = translations_np[0, :2] - ( + translations_np[1, :2] - translations_np[0, :2] ) velocities_np: np.ndarray = ( - np.diff(translations_np, axis=0, prepend=np.expand_dims(prepend_pos, axis=0)) + np.diff( + translations_np[:, :2], axis=0, prepend=np.expand_dims(prepend_pos, axis=0) + ) / NUSC_DT ) @@ -215,7 +239,7 @@ def agg_ego_data(nusc_obj: NuScenes, scene: Scene) -> Agent: ) ego_data_df = pd.DataFrame( ego_data_np, - columns=["x", "y", "vx", 
"vy", "ax", "ay", "heading"], + columns=["x", "y", "z", "vx", "vy", "ax", "ay", "heading"], index=pd.MultiIndex.from_tuples( [("ego", idx) for idx in range(ego_data_np.shape[0])], names=["agent_id", "scene_ts"], @@ -233,3 +257,251 @@ def agg_ego_data(nusc_obj: NuScenes, scene: Scene) -> Agent: metadata=ego_metadata, data=ego_data_df, ) + + +def extract_lane_center(nusc_map: NuScenesMap, lane_record) -> np.ndarray: + # Getting the lane center's points. + curr_lane = nusc_map.arcline_path_3.get(lane_record["token"], []) + lane_midline: np.ndarray = np.array( + arcline_path_utils.discretize_lane(curr_lane, resolution_meters=0.5) + )[:, :2] + + # For some reason, nuScenes duplicates a few entries + # (likely how they're building their arcline representation). + # We delete those duplicate entries here. + duplicate_check: np.ndarray = np.where( + np.linalg.norm(np.diff(lane_midline, axis=0, prepend=0), axis=1) < 1e-10 + )[0] + if duplicate_check.size > 0: + lane_midline = np.delete(lane_midline, duplicate_check, axis=0) + + return lane_midline + + +def extract_lane_and_edges( + nusc_map: NuScenesMap, lane_record +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + # Getting the bounding polygon vertices. + lane_polygon_obj = nusc_map.get("polygon", lane_record["polygon_token"]) + polygon_nodes = [ + nusc_map.get("node", node_token) + for node_token in lane_polygon_obj["exterior_node_tokens"] + ] + polygon_pts: np.ndarray = np.array( + [(node["x"], node["y"]) for node in polygon_nodes] + ) + + # Getting the lane center's points. + lane_midline: np.ndarray = extract_lane_center(nusc_map, lane_record) + + # Computing the closest lane center point to each bounding polygon vertex. + closest_midlane_pt: np.ndarray = np.argmin(cdist(polygon_pts, lane_midline), axis=1) + # Computing the local direction of the lane at each lane center point. + direction_vectors: np.ndarray = np.diff( + lane_midline, + axis=0, + prepend=lane_midline[[0]] - (lane_midline[[1]] - lane_midline[[0]]), + ) + + # Selecting the direction vectors at the closest lane center point per polygon vertex. + local_dir_vecs: np.ndarray = direction_vectors[closest_midlane_pt] + # Calculating the vectors from the the closest lane center point per polygon vertex to the polygon vertex. + origin_to_polygon_vecs: np.ndarray = polygon_pts - lane_midline[closest_midlane_pt] + + # Computing the perpendicular dot product. + # See https://www.xarg.org/book/linear-algebra/2d-perp-product/ + # If perp_dot_product < 0, then the associated polygon vertex is + # on the right edge of the lane. + perp_dot_product: np.ndarray = ( + local_dir_vecs[:, 0] * origin_to_polygon_vecs[:, 1] + - local_dir_vecs[:, 1] * origin_to_polygon_vecs[:, 0] + ) + + # Determining which indices are on the right of the lane center. + on_right: np.ndarray = perp_dot_product < 0 + # Determining the boundary between the left/right polygon vertices + # (they will be together in blocks due to the ordering of the polygon vertices). + idx_changes: int = np.where(np.roll(on_right, 1) < on_right)[0].item() + + if idx_changes > 0: + # If the block of left/right points spreads across the bounds of the array, + # roll it until the boundary between left/right points is at index 0. + # This is important so that the following index selection orders points + # without jumps. 
+ polygon_pts = np.roll(polygon_pts, shift=-idx_changes, axis=0) + on_right = np.roll(on_right, shift=-idx_changes) + + left_pts: np.ndarray = polygon_pts[~on_right] + right_pts: np.ndarray = polygon_pts[on_right] + + # Final ordering check, ensuring that left_pts and right_pts can be combined + # into a polygon without the endpoints intersecting. + # Reversing the one lane edge that does not match the ordering of the midline. + if map_utils.endpoints_intersect(left_pts, right_pts): + if not map_utils.order_matches(left_pts, lane_midline): + left_pts = left_pts[::-1] + else: + right_pts = right_pts[::-1] + + # Ensuring that left and right have the same number of points. + # This is necessary, not for data storage but for later rasterization. + if left_pts.shape[0] < right_pts.shape[0]: + left_pts = map_utils.interpolate(left_pts, num_pts=right_pts.shape[0]) + elif right_pts.shape[0] < left_pts.shape[0]: + right_pts = map_utils.interpolate(right_pts, num_pts=left_pts.shape[0]) + + return ( + lane_midline, + left_pts, + right_pts, + ) + + +def extract_area(nusc_map: NuScenesMap, area_record) -> np.ndarray: + token_key: str + if "exterior_node_tokens" in area_record: + token_key = "exterior_node_tokens" + elif "node_tokens" in area_record: + token_key = "node_tokens" + + polygon_nodes = [ + nusc_map.get("node", node_token) for node_token in area_record[token_key] + ] + + return np.array([(node["x"], node["y"]) for node in polygon_nodes]) + + +def populate_vector_map(vector_map: VectorMap, nusc_map: NuScenesMap) -> None: + # Setting the map bounds. + vector_map.extent = np.array( + [ + nusc_map.explorer.canvas_min_x, + nusc_map.explorer.canvas_min_y, + 0.0, + nusc_map.explorer.canvas_max_x, + nusc_map.explorer.canvas_max_y, + 0.0, + ] + ) + + overall_pbar = tqdm( + total=len(nusc_map.lane) + + len(nusc_map.lane_connector) + + len(nusc_map.drivable_area) + + len(nusc_map.ped_crossing) + + len(nusc_map.walkway), + desc=f"Getting {nusc_map.map_name} Elements", + position=1, + leave=False, + ) + + for lane_record in nusc_map.lane: + center_pts, left_pts, right_pts = extract_lane_and_edges(nusc_map, lane_record) + + lane_record_token: str = lane_record["token"] + + new_lane = RoadLane( + id=lane_record_token, + center=Polyline(center_pts), + left_edge=Polyline(left_pts), + right_edge=Polyline(right_pts), + ) + + for lane_id in nusc_map.get_incoming_lane_ids(lane_record_token): + # Need to do this because some incoming/outgoing lane_connector IDs + # do not exist as lane_connectors... + if lane_id in nusc_map._token2ind["lane_connector"]: + new_lane.prev_lanes.add(lane_id) + + for lane_id in nusc_map.get_outgoing_lane_ids(lane_record_token): + # Need to do this because some incoming/outgoing lane_connector IDs + # do not exist as lane_connectors... + if lane_id in nusc_map._token2ind["lane_connector"]: + new_lane.next_lanes.add(lane_id) + + # new_lane.adjacent_lanes_left.append( + # l5_lane.adjacent_lane_change_left.id + # ) + # new_lane.adjacent_lanes_right.append( + # l5_lane.adjacent_lane_change_right.id + # ) + + # Adding the element to the map. + vector_map.add_map_element(new_lane) + overall_pbar.update() + + for lane_record in nusc_map.lane_connector: + # Unfortunately lane connectors in nuScenes have very simple exterior + # polygons which make extracting their edges quite difficult, so we + # only extract the centerline. + center_pts = extract_lane_center(nusc_map, lane_record) + + lane_record_token: str = lane_record["token"] + + # Adding the element to the map. 
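(The centerline-only `RoadLane` is constructed just below.) Relatedly, the consecutive-duplicate removal performed by `extract_lane_center` above can be seen on a toy polyline. One subtlety, noted here as an observation rather than a behavior guaranteed by the source: with `prepend=0`, a first vertex sitting exactly at the origin would also be flagged, so the example deliberately starts away from it.

```python
import numpy as np

# Toy midline with one duplicated vertex, as discretize_lane sometimes produces.
lane_midline = np.array([[5.0, 0.0], [6.0, 0.0], [6.0, 0.0], [7.0, 0.0]])

# A vertex is a duplicate if it coincides with its predecessor.
duplicate_check = np.where(
    np.linalg.norm(np.diff(lane_midline, axis=0, prepend=0), axis=1) < 1e-10
)[0]
if duplicate_check.size > 0:
    lane_midline = np.delete(lane_midline, duplicate_check, axis=0)

assert lane_midline.shape == (3, 2)
```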
+ new_lane = RoadLane( + id=lane_record_token, + center=Polyline(center_pts), + ) + + new_lane.prev_lanes.update(nusc_map.get_incoming_lane_ids(lane_record_token)) + new_lane.next_lanes.update(nusc_map.get_outgoing_lane_ids(lane_record_token)) + + # new_lane.adjacent_lanes_left.append( + # l5_lane.adjacent_lane_change_left.id + # ) + # new_lane.adjacent_lanes_right.append( + # l5_lane.adjacent_lane_change_right.id + # ) + + # Adding the element to the map. + vector_map.add_map_element(new_lane) + overall_pbar.update() + + for drivable_area in nusc_map.drivable_area: + for polygon_token in drivable_area["polygon_tokens"]: + if ( + polygon_token is None + and str(None) in vector_map.elements[MapElementType.ROAD_AREA] + ): + # See below, but essentially nuScenes has two None polygon_tokens + # back-to-back, so we don't need the second one. + continue + + polygon_record = nusc_map.get("polygon", polygon_token) + polygon_pts = extract_area(nusc_map, polygon_record) + + # NOTE: nuScenes has some polygon_tokens that are None, although that + # doesn't stop the above get(...) function call so it's fine, + # just have to be mindful of this when creating the id. + new_road_area = RoadArea( + id=str(polygon_token), exterior_polygon=Polyline(polygon_pts) + ) + + for hole in polygon_record["holes"]: + polygon_pts = extract_area(nusc_map, hole) + new_road_area.interior_holes.append(Polyline(polygon_pts)) + + # Adding the element to the map. + vector_map.add_map_element(new_road_area) + overall_pbar.update() + + for ped_area_record in nusc_map.ped_crossing: + polygon_pts = extract_area(nusc_map, ped_area_record) + + # Adding the element to the map. + vector_map.add_map_element( + PedCrosswalk(id=ped_area_record["token"], polygon=Polyline(polygon_pts)) + ) + overall_pbar.update() + + for ped_area_record in nusc_map.walkway: + polygon_pts = extract_area(nusc_map, ped_area_record) + + # Adding the element to the map. + vector_map.add_map_element( + PedWalkway(id=ped_area_record["token"], polygon=Polyline(polygon_pts)) + ) + overall_pbar.update() + + overall_pbar.close() diff --git a/src/trajdata/dataset_specific/raw_dataset.py b/src/trajdata/dataset_specific/raw_dataset.py index b26f392..57188d3 100644 --- a/src/trajdata/dataset_specific/raw_dataset.py +++ b/src/trajdata/dataset_specific/raw_dataset.py @@ -81,6 +81,13 @@ def get_scene(self, scene_info: SceneMetadata) -> Scene: def get_agent_info( self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: + """ + Get frame-level information from source dataset, caching it + to cache_path. + + Always called after cache_maps, can load map if needed + to associate map information to positions. + """ raise NotImplementedError() def cache_maps( @@ -90,6 +97,9 @@ def cache_maps( map_params: Dict[str, Any], ) -> None: """ - resolution is in pixels per meter. + Get static, scene-level info from the source dataset, caching it + to cache_path. (Primarily this is info needed to construct VectorMap) + + Resolution is in pixels per meter. 
""" raise NotImplementedError() diff --git a/src/trajdata/dataset_specific/scene_records.py b/src/trajdata/dataset_specific/scene_records.py index 60d3abb..b785415 100644 --- a/src/trajdata/dataset_specific/scene_records.py +++ b/src/trajdata/dataset_specific/scene_records.py @@ -1,6 +1,11 @@ from typing import NamedTuple +class Argoverse2Record(NamedTuple): + name: str + data_idx: int + + class EUPedsRecord(NamedTuple): name: str location: str @@ -9,6 +14,18 @@ class EUPedsRecord(NamedTuple): data_idx: int +class SDDPedsRecord(NamedTuple): + name: str + length: str + data_idx: int + + +class InteractionRecord(NamedTuple): + name: str + length: str + data_idx: int + + class NuscSceneRecord(NamedTuple): name: str location: str @@ -17,7 +34,31 @@ class NuscSceneRecord(NamedTuple): data_idx: int +class VODSceneRecord(NamedTuple): + token: str + name: str + location: str + length: str + desc: str + data_idx: int + + class LyftSceneRecord(NamedTuple): name: str length: str data_idx: int + + +class WaymoSceneRecord(NamedTuple): + name: str + length: str + data_idx: int + + +class NuPlanSceneRecord(NamedTuple): + name: str + location: str + length: str + split: str + # desc: str + data_idx: int diff --git a/src/trajdata/dataset_specific/sdd_peds/__init__.py b/src/trajdata/dataset_specific/sdd_peds/__init__.py new file mode 100644 index 0000000..46268b0 --- /dev/null +++ b/src/trajdata/dataset_specific/sdd_peds/__init__.py @@ -0,0 +1 @@ +from .sddpeds_dataset import SDDPedsDataset diff --git a/src/trajdata/dataset_specific/sdd_peds/estimated_homography.py b/src/trajdata/dataset_specific/sdd_peds/estimated_homography.py new file mode 100644 index 0000000..6c20523 --- /dev/null +++ b/src/trajdata/dataset_specific/sdd_peds/estimated_homography.py @@ -0,0 +1,68 @@ +from typing import Dict, Final + +# Please see https://github.com/crowdbotp/OpenTraj/tree/master/datasets/SDD for more information. +# These homographies (transformations from pixel values to world coordinates) were estimated, +# albeit most of them with high certainty. The certainty values indicate how reliable the +# estimate is (or is not). Some of these scales were estimated using google maps, others are a pure guess. 
+SDD_HOMOGRAPHY_SCALES: Final[Dict[str, Dict[str, float]]] = { + "bookstore_0": {"certainty": 1.0, "scale": 0.038392063}, + "bookstore_1": {"certainty": 1.0, "scale": 0.039892913}, + "bookstore_2": {"certainty": 1.0, "scale": 0.04062433}, + "bookstore_3": {"certainty": 1.0, "scale": 0.039098596}, + "bookstore_4": {"certainty": 1.0, "scale": 0.0396}, + "bookstore_5": {"certainty": 0.9, "scale": 0.0396}, + "bookstore_6": {"certainty": 0.9, "scale": 0.0413}, + "coupa_0": {"certainty": 1.0, "scale": 0.027995674}, + "coupa_1": {"certainty": 1.0, "scale": 0.023224545}, + "coupa_2": {"certainty": 1.0, "scale": 0.024}, + "coupa_3": {"certainty": 1.0, "scale": 0.025524906}, + "deathCircle_0": {"certainty": 1.0, "scale": 0.04064}, + "deathCircle_1": {"certainty": 1.0, "scale": 0.039076923}, + "deathCircle_2": {"certainty": 1.0, "scale": 0.03948382}, + "deathCircle_3": {"certainty": 1.0, "scale": 0.028478209}, + "deathCircle_4": {"certainty": 1.0, "scale": 0.038980137}, + "gates_0": {"certainty": 1.0, "scale": 0.03976968}, + "gates_1": {"certainty": 1.0, "scale": 0.03770837}, + "gates_2": {"certainty": 1.0, "scale": 0.037272793}, + "gates_3": {"certainty": 1.0, "scale": 0.034515323}, + "gates_4": {"certainty": 1.0, "scale": 0.04412268}, + "gates_5": {"certainty": 1.0, "scale": 0.0342392}, + "gates_6": {"certainty": 1.0, "scale": 0.0342392}, + "gates_7": {"certainty": 1.0, "scale": 0.04540353}, + "gates_8": {"certainty": 1.0, "scale": 0.045191525}, + "hyang_0": {"certainty": 1.0, "scale": 0.034749693}, + "hyang_1": {"certainty": 1.0, "scale": 0.0453136}, + "hyang_10": {"certainty": 1.0, "scale": 0.054460944}, + "hyang_11": {"certainty": 1.0, "scale": 0.054992233}, + "hyang_12": {"certainty": 1.0, "scale": 0.054104065}, + "hyang_13": {"certainty": 0.0, "scale": 0.0541}, + "hyang_14": {"certainty": 0.0, "scale": 0.0541}, + "hyang_2": {"certainty": 1.0, "scale": 0.054992233}, + "hyang_3": {"certainty": 1.0, "scale": 0.056642}, + "hyang_4": {"certainty": 1.0, "scale": 0.034265612}, + "hyang_5": {"certainty": 1.0, "scale": 0.029655497}, + "hyang_6": {"certainty": 1.0, "scale": 0.052936449}, + "hyang_7": {"certainty": 1.0, "scale": 0.03540125}, + "hyang_8": {"certainty": 1.0, "scale": 0.034592381}, + "hyang_9": {"certainty": 1.0, "scale": 0.038031423}, + "little_0": {"certainty": 1.0, "scale": 0.028930169}, + "little_1": {"certainty": 1.0, "scale": 0.028543144}, + "little_2": {"certainty": 1.0, "scale": 0.028543144}, + "little_3": {"certainty": 1.0, "scale": 0.028638926}, + "nexus_0": {"certainty": 1.0, "scale": 0.043986494}, + "nexus_1": {"certainty": 1.0, "scale": 0.043316805}, + "nexus_10": {"certainty": 1.0, "scale": 0.043991753}, + "nexus_11": {"certainty": 1.0, "scale": 0.043766154}, + "nexus_2": {"certainty": 1.0, "scale": 0.042247434}, + "nexus_3": {"certainty": 1.0, "scale": 0.045883871}, + "nexus_4": {"certainty": 1.0, "scale": 0.045883871}, + "nexus_5": {"certainty": 1.0, "scale": 0.045395745}, + "nexus_6": {"certainty": 1.0, "scale": 0.037929168}, + "nexus_7": {"certainty": 1.0, "scale": 0.037106087}, + "nexus_8": {"certainty": 1.0, "scale": 0.037106087}, + "nexus_9": {"certainty": 1.0, "scale": 0.044917895}, + "quad_0": {"certainty": 1.0, "scale": 0.043606807}, + "quad_1": {"certainty": 1.0, "scale": 0.042530206}, + "quad_2": {"certainty": 1.0, "scale": 0.043338169}, + "quad_3": {"certainty": 1.0, "scale": 0.044396842}, +} diff --git a/src/trajdata/dataset_specific/sdd_peds/sddpeds_dataset.py b/src/trajdata/dataset_specific/sdd_peds/sddpeds_dataset.py new file mode 100644 index 0000000..a94bf88 
--- /dev/null
+++ b/src/trajdata/dataset_specific/sdd_peds/sddpeds_dataset.py
@@ -0,0 +1,361 @@
+from pathlib import Path
+from random import Random
+from typing import Any, Dict, Final, List, Optional, Tuple, Type
+
+import numpy as np
+import pandas as pd
+
+from trajdata.caching import EnvCache, SceneCache
+from trajdata.data_structures.agent import AgentMetadata, AgentType, FixedExtent
+from trajdata.data_structures.environment import EnvMetadata
+from trajdata.data_structures.scene_metadata import Scene, SceneMetadata
+from trajdata.data_structures.scene_tag import SceneTag
+from trajdata.dataset_specific.raw_dataset import RawDataset
+from trajdata.dataset_specific.scene_records import SDDPedsRecord
+from trajdata.utils import arr_utils
+
+from .estimated_homography import SDD_HOMOGRAPHY_SCALES
+
+# SDD was captured at 30 frames per second.
+SDDPEDS_DT: Final[float] = 1.0 / 30.0
+
+
+# There are 60 scenes in total.
+SDDPEDS_SCENE_COUNTS: Final[Dict[str, int]] = {
+    "bookstore": 7,
+    "coupa": 4,
+    "deathCircle": 5,
+    "gates": 9,
+    "hyang": 15,
+    "little": 4,
+    "nexus": 12,
+    "quad": 4,
+}
+
+
+def sdd_type_to_unified_type(label: str) -> AgentType:
+    if label == "Pedestrian":
+        return AgentType.PEDESTRIAN
+    elif label == "Biker":
+        return AgentType.BICYCLE
+    elif label in {"Cart", "Car", "Bus"}:
+        return AgentType.VEHICLE
+    elif label == "Skater":
+        return AgentType.UNKNOWN
+    else:
+        # Guarding against unexpected labels (otherwise this function
+        # would implicitly return None).
+        return AgentType.UNKNOWN
+
+
+class SDDPedsDataset(RawDataset):
+    def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata:
+        # Using seeded randomness to assign 42 scenes (70% of all scenes) to "train",
+        # 9 (15%) to "val", and 9 (15%) to "test".
+        rng = Random(0)
+        scene_split = ["train"] * 42 + ["val"] * 9 + ["test"] * 9
+        rng.shuffle(scene_split)
+
+        scene_list: List[str] = []
+        for scene_name, video_count in SDDPEDS_SCENE_COUNTS.items():
+            scene_list += [f"{scene_name}_{idx}" for idx in range(video_count)]
+
+        scene_split_map: Dict[str, str] = {
+            scene_list[idx]: scene_split[idx] for idx in range(len(scene_split))
+        }
+
+        # SDD possibilities are the Cartesian product of these.
+        dataset_parts: List[Tuple[str, ...]] = [
+            ("train", "val", "test"),
+            ("stanford",),
+        ]
+
+        env_metadata = EnvMetadata(
+            name=env_name,
+            data_dir=data_dir,
+            dt=SDDPEDS_DT,
+            parts=dataset_parts,
+            scene_split_map=scene_split_map,
+        )
+        return env_metadata
+
+    def load_dataset_obj(self, verbose: bool = False) -> None:
+        if verbose:
+            print(f"Loading {self.name} dataset...", flush=True)
+
+        # Just storing the filepath and scene length (number of frames).
+        # One could load the entire dataset here, but there's no need
+        # since it's ~500 MB in size and we can easily process it in parallel later.
+        self.dataset_obj: Dict[str, Tuple[Path, int]] = dict()
+        for scene_name, video_count in SDDPEDS_SCENE_COUNTS.items():
+            for video_num in range(video_count):
+                data_filepath: Path = (
+                    Path(self.metadata.data_dir)
+                    / scene_name
+                    / f"video{video_num}"
+                    / "annotations.txt"
+                )
+
+                csv_columns = [
+                    "agent_id",
+                    "x_min",
+                    "y_min",
+                    "x_max",
+                    "y_max",
+                    "frame_id",
+                    "lost",
+                    "occluded",
+                    "generated",
+                    "label",
+                ]
+                data = pd.read_csv(
+                    data_filepath,
+                    sep=" ",
+                    index_col=False,
+                    header=None,
+                    names=csv_columns,
+                    usecols=["frame_id", "generated"],
+                    dtype={"frame_id": int, "generated": bool},
+                )
+                # Ignoring generated frames in the count here since
+                # we will remove them later (we'll do our own interpolation).
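+                # (For reference, each space-separated row of annotations.txt
+                # has the form: agent_id x_min y_min x_max y_max frame_id lost
+                # occluded generated label.)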
+                data = data[~data["generated"]]
+                data["frame_id"] -= data["frame_id"].min()
+
+                self.dataset_obj[f"{scene_name}_{video_num}"] = (
+                    data_filepath,
+                    data["frame_id"].max().item() + 1,
+                )
+
+    def _get_matching_scenes_from_obj(
+        self,
+        scene_tag: SceneTag,
+        scene_desc_contains: Optional[List[str]],
+        env_cache: EnvCache,
+    ) -> List[SceneMetadata]:
+        all_scenes_list: List[SDDPedsRecord] = list()
+
+        scenes_list: List[SceneMetadata] = list()
+        for idx, (scene_name, (scene_filepath, scene_length)) in enumerate(
+            self.dataset_obj.items()
+        ):
+            if scene_name not in self.metadata.scene_split_map:
+                raise ValueError(
+                    f"{scene_name} is missing from the scene split map!"
+                )
+
+            scene_split: str = self.metadata.scene_split_map[scene_name]
+
+            # Saving all scene records for later caching.
+            all_scenes_list.append(SDDPedsRecord(scene_name, scene_length, idx))
+
+            if (
+                "stanford" in scene_tag
+                and scene_split in scene_tag
+                and scene_desc_contains is None
+            ):
+                scene_metadata = SceneMetadata(
+                    env_name=self.metadata.name,
+                    name=scene_name,
+                    dt=self.metadata.dt,
+                    raw_data_idx=idx,
+                )
+                scenes_list.append(scene_metadata)
+
+        self.cache_all_scenes_list(env_cache, all_scenes_list)
+        return scenes_list
+
+    def _get_matching_scenes_from_cache(
+        self,
+        scene_tag: SceneTag,
+        scene_desc_contains: Optional[List[str]],
+        env_cache: EnvCache,
+    ) -> List[Scene]:
+        all_scenes_list: List[SDDPedsRecord] = env_cache.load_env_scenes_list(self.name)
+
+        scenes_list: List[Scene] = list()
+        for scene_record in all_scenes_list:
+            scene_name, scene_length, data_idx = scene_record
+            scene_split: str = self.metadata.scene_split_map[scene_name]
+
+            if (
+                "stanford" in scene_tag
+                and scene_split in scene_tag
+                and scene_desc_contains is None
+            ):
+                scene_metadata = Scene(
+                    self.metadata,
+                    scene_name,
+                    "stanford",
+                    scene_split,
+                    scene_length,
+                    data_idx,
+                    None,  # This isn't used if everything is already cached.
+                )
+                scenes_list.append(scene_metadata)
+
+        return scenes_list
+
+    def get_scene(self, scene_info: SceneMetadata) -> Scene:
+        _, scene_name, _, data_idx = scene_info
+
+        _, scene_length = self.dataset_obj[scene_name]
+        scene_split: str = self.metadata.scene_split_map[scene_name]
+
+        return Scene(
+            self.metadata,
+            scene_name,
+            "stanford",
+            scene_split,
+            scene_length,
+            data_idx,
+            None,  # No data access info necessary for the SDD dataset.
+        )
+
+    def get_agent_info(
+        self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache]
+    ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]:
+        scene_filepath, _ = self.dataset_obj[scene.name]
+
+        csv_columns = [
+            "agent_id",
+            "x_min",
+            "y_min",
+            "x_max",
+            "y_max",
+            "frame_id",
+            "lost",
+            "occluded",
+            "generated",
+            "label",
+        ]
+        data_df: pd.DataFrame = pd.read_csv(
+            scene_filepath,
+            sep=" ",
+            index_col=False,
+            header=None,
+            names=csv_columns,
+            dtype={"generated": bool},
+        )
+
+        # Setting generated frames to NaN; we'll do our own interpolation later.
+        data_df.loc[data_df["generated"], ["x_min", "y_min"]] = np.nan
+        data_df["frame_id"] -= data_df["frame_id"].min()
+
+        scale: float = SDD_HOMOGRAPHY_SCALES[scene.name]["scale"]
+        data_df["x"] = scale * (data_df["x_min"] + data_df["x_max"]) / 2.0
+        data_df["y"] = scale * (data_df["y_min"] + data_df["y_max"]) / 2.0
+
+        # Don't need these columns anymore.
+        data_df.drop(
+            columns=[
+                "x_min",
+                "y_min",
+                "x_max",
+                "y_max",
+                "lost",
+                "occluded",
+                "generated",
+            ],
+            inplace=True,
+        )
+
+        # Renaming columns to match our usual names.
+ data_df.rename( + columns={"frame_id": "scene_ts", "label": "agent_type"}, + inplace=True, + ) + + # Ensuring data is sorted by agent ID and scene timestep. + data_df.set_index(["agent_id", "scene_ts"], inplace=True) + data_df.sort_index(inplace=True) + + # Re-interpolating because the original SDD interpolation yielded discrete position steps, + # which is not very natural. Also, the data is already sorted by agent and time so + # we can safely do this without worrying about contaminating position data across agents. + data_df.interpolate( + method="linear", axis="index", inplace=True, limit_area="inside" + ) + + data_df.reset_index(level=1, inplace=True) + + agent_ids: np.ndarray = data_df.index.get_level_values(0).to_numpy() + + # Add in zero for z value + data_df["z"] = np.zeros_like(data_df["x"]) + + ### Calculating agent classes + agent_class: Dict[int, str] = ( + data_df.groupby("agent_id")["agent_type"].first().to_dict() + ) + + ### Calculating agent velocities + data_df[["vx", "vy"]] = ( + arr_utils.agent_aware_diff(data_df[["x", "y"]].to_numpy(), agent_ids) + / SDDPEDS_DT + ) + + ### Calculating agent accelerations + data_df[["ax", "ay"]] = ( + arr_utils.agent_aware_diff(data_df[["vx", "vy"]].to_numpy(), agent_ids) + / SDDPEDS_DT + ) + + # This is likely to be very noisy... Unfortunately, SDD only + # provides center of mass data. + data_df["heading"] = np.arctan2(data_df["vy"], data_df["vx"]) + + agent_list: List[AgentMetadata] = list() + agent_presence: List[List[AgentMetadata]] = [ + [] for _ in range(scene.length_timesteps) + ] + for agent_id, frames in data_df.groupby("agent_id")["scene_ts"]: + start_frame: int = frames.iat[0].item() + last_frame: int = frames.iat[-1].item() + + agent_type: AgentType = sdd_type_to_unified_type(agent_class[agent_id]) + + agent_metadata = AgentMetadata( + name=str(agent_id), + agent_type=agent_type, + first_timestep=start_frame, + last_timestep=last_frame, + # These values are as ballpark as it gets... It's not super reliable to use + # the pixel extents in the annotations since they are all always axis-aligned. + extent=FixedExtent(0.75, 0.75, 1.5), + ) + + agent_list.append(agent_metadata) + for frame in frames: + agent_presence[frame].append(agent_metadata) + + # Changing the agent_id dtype to str + data_df.reset_index(inplace=True) + data_df["agent_id"] = data_df["agent_id"].astype(str) + data_df.set_index(["agent_id", "scene_ts"], inplace=True) + + cache_class.save_agent_data( + data_df, + cache_path, + scene, + ) + + return agent_list, agent_presence + + def cache_map( + self, + map_name: str, + layer_names: List[str], + cache_path: Path, + map_cache_class: Type[SceneCache], + resolution: float, + ) -> None: + """ + No maps in this dataset! + """ + pass + + def cache_maps( + self, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + ) -> None: + """ + No maps in this dataset! 
+ """ + pass diff --git a/src/trajdata/dataset_specific/vod/__init__.py b/src/trajdata/dataset_specific/vod/__init__.py new file mode 100644 index 0000000..ab3e325 --- /dev/null +++ b/src/trajdata/dataset_specific/vod/__init__.py @@ -0,0 +1 @@ +from .vod_dataset import VODDataset diff --git a/src/trajdata/dataset_specific/vod/vod_dataset.py b/src/trajdata/dataset_specific/vod/vod_dataset.py new file mode 100644 index 0000000..be0dc9b --- /dev/null +++ b/src/trajdata/dataset_specific/vod/vod_dataset.py @@ -0,0 +1,315 @@ +import warnings +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import pandas as pd +from tqdm import tqdm +from vod.map_expansion.map_api import VODMap, locations +from vod.utils.splits import create_splits_scenes +from vod.vod import VOD + +from trajdata.caching import EnvCache, SceneCache +from trajdata.data_structures.agent import ( + Agent, + AgentMetadata, + AgentType, + FixedExtent, + VariableExtent, +) +from trajdata.data_structures.environment import EnvMetadata +from trajdata.data_structures.scene_metadata import Scene, SceneMetadata +from trajdata.data_structures.scene_tag import SceneTag +from trajdata.dataset_specific.raw_dataset import RawDataset +from trajdata.dataset_specific.scene_records import VODSceneRecord +from trajdata.dataset_specific.vod import vod_utils +from trajdata.maps import VectorMap + + +class VODDataset(RawDataset): + def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: + # See https://github.com/tudelft-iv/view-of-delft-prediction-devkit/blob/main/src/vod/utils/splits.py + # for full details on how the splits are obtained below. + all_scene_splits: Dict[str, List[str]] = create_splits_scenes() + + train_scenes: List[str] = deepcopy(all_scene_splits["train"]) + + if env_name == "vod_trainval": + vod_scene_splits: Dict[str, List[str]] = { + k: all_scene_splits[k] for k in ["train", "train_val", "val"] + } + + # VoD possibilities are the Cartesian product of these + dataset_parts: List[Tuple[str, ...]] = [ + ("train", "train_val", "val"), + ("delft",), + ] + elif env_name == "vod_test": + vod_scene_splits: Dict[str, List[str]] = { + k: all_scene_splits[k] for k in ["test"] + } + + # VoD possibilities are the Cartesian product of these + dataset_parts: List[Tuple[str, ...]] = [ + ("test",), + ("delft",), + ] + + # elif env_name == "vod_mini": + # vod_scene_splits: Dict[str, List[str]] = { + # k: all_scene_splits[k] for k in ["mini_train", "mini_val"] + # } + + # # VoD possibilities are the Cartesian product of these + # dataset_parts: List[Tuple[str, ...]] = [ + # ("mini_train", "mini_val"), + # ("delft",), + # ] + else: + raise ValueError(f"Unknown VoD environment name: {env_name}") + + # Inverting the dict from above, associating every scene with its data split. + vod_scene_split_map: Dict[str, str] = { + v_elem: k for k, v in vod_scene_splits.items() for v_elem in v + } + + return EnvMetadata( + name=env_name, + data_dir=data_dir, + dt=vod_utils.VOD_DT, + parts=dataset_parts, + scene_split_map=vod_scene_split_map, + # The location names should match the map names used in + # the unified data cache. 
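+            # (For VoD, `locations` is imported from vod.map_expansion.map_api
+            # and is expected to contain just Delft.)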
+ map_locations=tuple(locations), + ) + + def load_dataset_obj(self, verbose: bool = False) -> None: + if verbose: + print(f"Loading {self.name} dataset...", flush=True) + + if self.name == "vod_trainval": + version_str = "v1.0-trainval" + elif self.name == "vod_test": + version_str = "v1.0-test" + # elif self.name == "vod_mini": + # version_str = "v1.0-mini" + + self.dataset_obj = VOD(version=version_str, dataroot=self.metadata.data_dir) + + def _get_matching_scenes_from_obj( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[SceneMetadata]: + all_scenes_list: List[VODSceneRecord] = list() + + scenes_list: List[SceneMetadata] = list() + for idx, scene_record in enumerate(self.dataset_obj.scene): + scene_token: str = scene_record["token"] + scene_name: str = scene_record["name"] + scene_desc: str = scene_record["description"].lower() + scene_location: str = self.dataset_obj.get( + "log", scene_record["log_token"] + )["location"] + scene_split: str = self.metadata.scene_split_map[scene_token] + scene_length: int = scene_record["nbr_samples"] + + # Saving all scene records for later caching. + all_scenes_list.append( + VODSceneRecord( + scene_token, + scene_name, + scene_location, + scene_length, + scene_desc, + idx, + ) + ) + + if scene_location in scene_tag and scene_split in scene_tag: + if scene_desc_contains is not None and not any( + desc_query in scene_desc for desc_query in scene_desc_contains + ): + continue + + scene_metadata = SceneMetadata( + env_name=self.metadata.name, + name=scene_name, + dt=self.metadata.dt, + raw_data_idx=idx, + ) + scenes_list.append(scene_metadata) + + self.cache_all_scenes_list(env_cache, all_scenes_list) + return scenes_list + + def _get_matching_scenes_from_cache( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[Scene]: + all_scenes_list: List[VODSceneRecord] = env_cache.load_env_scenes_list( + self.name + ) + + scenes_list: List[SceneMetadata] = list() + for scene_record in all_scenes_list: + ( + scene_token, + scene_name, + scene_location, + scene_length, + scene_desc, + data_idx, + ) = scene_record + scene_split: str = self.metadata.scene_split_map[scene_token] + + if scene_location.split("-")[0] in scene_tag and scene_split in scene_tag: + if scene_desc_contains is not None and not any( + desc_query in scene_desc for desc_query in scene_desc_contains + ): + continue + + scene_metadata = Scene( + self.metadata, + scene_name, + scene_location, + scene_split, + scene_length, + data_idx, + None, # This isn't used if everything is already cached. 
+ scene_desc, + ) + scenes_list.append(scene_metadata) + + return scenes_list + + def get_scene(self, scene_info: SceneMetadata) -> Scene: + _, _, _, data_idx = scene_info + + scene_record = self.dataset_obj.scene[data_idx] + scene_token: str = scene_record["token"] + scene_name: str = scene_record["name"] + scene_desc: str = scene_record["description"].lower() + scene_location: str = self.dataset_obj.get("log", scene_record["log_token"])[ + "location" + ] + scene_split: str = self.metadata.scene_split_map[scene_token] + scene_length: int = scene_record["nbr_samples"] + + return Scene( + self.metadata, + scene_name, + scene_location, + scene_split, + scene_length, + data_idx, + scene_record, + scene_desc, + ) + + def get_agent_info( + self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] + ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: + ego_agent_info: AgentMetadata = AgentMetadata( + name="ego", + agent_type=AgentType.VEHICLE, + first_timestep=0, + last_timestep=scene.length_timesteps - 1, + extent=FixedExtent(length=4.084, width=1.730, height=1.562), + ) + + agent_presence: List[List[AgentMetadata]] = [ + [ego_agent_info] for _ in range(scene.length_timesteps) + ] + + agent_data_list: List[pd.DataFrame] = list() + existing_agents: Dict[str, AgentMetadata] = dict() + + all_frames: List[Dict[str, Union[str, int]]] = list( + vod_utils.frame_iterator(self.dataset_obj, scene) + ) + frame_idx_dict: Dict[str, int] = { + frame_dict["token"]: idx for idx, frame_dict in enumerate(all_frames) + } + for frame_idx, frame_info in enumerate(all_frames): + for agent_info in vod_utils.agent_iterator(self.dataset_obj, frame_info): + if agent_info["instance_token"] in existing_agents: + continue + + if agent_info["category_name"] == "vehicle.ego": + # Do not double-count the ego vehicle + continue + + if not agent_info["next"]: + # There are some agents with only a single detection to them, we don't care about these. + continue + + agent: Agent = vod_utils.agg_agent_data( + self.dataset_obj, agent_info, frame_idx, frame_idx_dict + ) + + for scene_ts in range( + agent.metadata.first_timestep, agent.metadata.last_timestep + 1 + ): + agent_presence[scene_ts].append(agent.metadata) + + existing_agents[agent.name] = agent.metadata + + agent_data_list.append(agent.data) + + ego_agent: Agent = vod_utils.agg_ego_data(self.dataset_obj, scene) + agent_data_list.append(ego_agent.data) + + agent_list: List[AgentMetadata] = [ego_agent_info] + list( + existing_agents.values() + ) + + cache_class.save_agent_data(pd.concat(agent_data_list), cache_path, scene) + + return agent_list, agent_presence + + def cache_map( + self, + map_name: str, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + ) -> None: + vod_map: VODMap = VODMap(dataroot=self.metadata.data_dir, map_name=map_name) + + vector_map = VectorMap(map_id=f"{self.name}:{map_name}") + vod_utils.populate_vector_map(vector_map, vod_map) + + map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) + + def cache_maps( + self, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + ) -> None: + """ + Stores rasterized maps to disk for later retrieval. + + Below are the map origins (south western corner, in [lat, lon]) for each of + the 4 maps in VoD: + delft: [] + + The dimensions of the maps are as follows ([width, height] in meters). 
They + can also be found in vod_utils.py + delft: [] + The rasterized semantic maps published with VoD v1.0 have a scale of 10px/m, + hence the above numbers are the image dimensions divided by 10. + + VoD uses the same WGS 84 Web Mercator (EPSG:3857) projection as Google Maps/Earth. + """ + for map_name in tqdm( + locations, + desc=f"Caching {self.name} Maps at {map_params['px_per_m']:.2f} px/m", + position=0, + ): + self.cache_map(map_name, cache_path, map_cache_class, map_params) diff --git a/src/trajdata/dataset_specific/vod/vod_utils.py b/src/trajdata/dataset_specific/vod/vod_utils.py new file mode 100644 index 0000000..a4050e8 --- /dev/null +++ b/src/trajdata/dataset_specific/vod/vod_utils.py @@ -0,0 +1,508 @@ +from typing import Any, Dict, Final, List, Tuple, Union + +import numpy as np +import pandas as pd +from pyquaternion import Quaternion +from scipy.spatial.distance import cdist +from tqdm import tqdm +from vod.map_expansion import arcline_path_utils +from vod.map_expansion.map_api import VODMap +from vod.vod import VOD + +from trajdata.data_structures import Agent, AgentMetadata, AgentType, FixedExtent, Scene +from trajdata.maps import VectorMap +from trajdata.maps.vec_map_elements import ( + MapElementType, + PedCrosswalk, + PedWalkway, + Polyline, + RoadArea, + RoadLane, +) +from trajdata.utils import arr_utils, map_utils + +VOD_DT: Final[float] = 0.1 + + +def frame_iterator(vod_obj: VOD, scene: Scene) -> Dict[str, Union[str, int]]: + """Loops through all frames in a scene and yields them for the caller to deal with the information.""" + curr_scene_token: str = scene.data_access_info["first_sample_token"] + while curr_scene_token: + frame = vod_obj.get("sample", curr_scene_token) + + yield frame + + curr_scene_token = frame["next"] + + +def agent_iterator(vod_obj: VOD, frame_info: Dict[str, Any]) -> Dict[str, Any]: + """Loops through all annotations (agents) in a frame and yields them for the caller to deal with the information.""" + ann_token: str + for ann_token in frame_info["anns"]: + ann_record = vod_obj.get("sample_annotation", ann_token) + + agent_category: str = ann_record["category_name"] + if agent_category.startswith("vehicle") or agent_category.startswith("human"): + yield ann_record + + +def get_ego_pose(vod_obj: VOD, frame_info: Dict[str, Any]) -> Dict[str, Any]: + cam_front_data = vod_obj.get("sample_data", frame_info["sample_data_token"]) + ego_pose = vod_obj.get("ego_pose", cam_front_data["ego_pose_token"]) + return ego_pose + + +def agg_agent_data( + vod_obj: VOD, + agent_data: Dict[str, Any], + curr_scene_index: int, + frame_idx_dict: Dict[str, int], +) -> Agent: + """Loops through all annotations of a specific agent in a scene and aggregates their data into an Agent object.""" + if agent_data["prev"]: + print("WARN: This is not the first frame of this agent!") + + translation_list = [np.array(agent_data["translation"][:3])[np.newaxis]] + agent_size = agent_data["size"] + yaw_list = [Quaternion(agent_data["rotation"]).yaw_pitch_roll[0]] + + prev_idx: int = curr_scene_index + curr_sample_ann_token: str = agent_data["next"] + while curr_sample_ann_token: + agent_data = vod_obj.get("sample_annotation", curr_sample_ann_token) + + translation = np.array(agent_data["translation"][:3]) + heading = Quaternion(agent_data["rotation"]).yaw_pitch_roll[0] + curr_idx: int = frame_idx_dict[agent_data["sample_token"]] + if curr_idx > prev_idx + 1: + fill_time = np.arange(prev_idx + 1, curr_idx) + xs = np.interp( + x=fill_time, + xp=[prev_idx, curr_idx], + 
fp=[translation_list[-1][0, 0], translation[0]], + ) + ys = np.interp( + x=fill_time, + xp=[prev_idx, curr_idx], + fp=[translation_list[-1][0, 1], translation[1]], + ) + zs = np.interp( + x=fill_time, + xp=[prev_idx, curr_idx], + fp=[translation_list[-1][0, 2], translation[2]], + ) + headings: np.ndarray = arr_utils.angle_wrap( + np.interp( + x=fill_time, + xp=[prev_idx, curr_idx], + fp=np.unwrap([yaw_list[-1], heading]), + ) + ) + translation_list.append(np.stack([xs, ys, zs], axis=1)) + yaw_list.extend(headings.tolist()) + + translation_list.append(translation[np.newaxis]) + # size_list.append(agent_data['size']) + yaw_list.append(heading) + + prev_idx = curr_idx + curr_sample_ann_token = agent_data["next"] + + translations_np = np.concatenate(translation_list, axis=0) + + # Doing this prepending so that the first velocity isn't zero (rather it's just the first actual velocity duplicated) + prepend_pos = translations_np[0, :2] - ( + translations_np[1, :2] - translations_np[0, :2] + ) + velocities_np = ( + np.diff( + translations_np[:, :2], axis=0, prepend=np.expand_dims(prepend_pos, axis=0) + ) + / VOD_DT + ) + + # Doing this prepending so that the first acceleration isn't zero (rather it's just the first actual acceleration duplicated) + prepend_vel = velocities_np[0] - (velocities_np[1] - velocities_np[0]) + accelerations_np = ( + np.diff(velocities_np, axis=0, prepend=np.expand_dims(prepend_vel, axis=0)) + / VOD_DT + ) + + anno_yaws_np = np.expand_dims(np.stack(yaw_list, axis=0), axis=1) + # yaws_np = np.expand_dims( + # np.arctan2(velocities_np[:, 1], velocities_np[:, 0]), axis=1 + # ) + # sizes_np = np.stack(size_list, axis=0) + + # import matplotlib.pyplot as plt + + # fig, ax = plt.subplots() + # ax.plot(translations_np[:, 0], translations_np[:, 1], color="blue") + # ax.quiver( + # translations_np[:, 0], + # translations_np[:, 1], + # np.cos(anno_yaws_np), + # np.sin(anno_yaws_np), + # color="green", + # label="annotated heading" + # ) + # ax.quiver( + # translations_np[:, 0], + # translations_np[:, 1], + # np.cos(yaws_np), + # np.sin(yaws_np), + # color="orange", + # label="velocity heading" + # ) + # ax.scatter([translations_np[0, 0]], [translations_np[0, 1]], color="red", label="Start", zorder=20) + # ax.legend(loc='best') + # plt.show() + + agent_data_np = np.concatenate( + [translations_np, velocities_np, accelerations_np, anno_yaws_np], axis=1 + ) + last_timestep = curr_scene_index + agent_data_np.shape[0] - 1 + agent_data_df = pd.DataFrame( + agent_data_np, + columns=["x", "y", "z", "vx", "vy", "ax", "ay", "heading"], + index=pd.MultiIndex.from_tuples( + [ + (agent_data["instance_token"], idx) + for idx in range(curr_scene_index, last_timestep + 1) + ], + names=["agent_id", "scene_ts"], + ), + ) + + agent_type = vod_type_to_unified_type(agent_data["category_name"]) + agent_metadata = AgentMetadata( + name=agent_data["instance_token"], + agent_type=agent_type, + first_timestep=curr_scene_index, + last_timestep=last_timestep, + extent=FixedExtent( + length=agent_size[1], width=agent_size[0], height=agent_size[2] + ), + ) + return Agent( + metadata=agent_metadata, + data=agent_data_df, + ) + + +def vod_type_to_unified_type(vod_type: str) -> AgentType: + if vod_type.startswith("human"): + return AgentType.PEDESTRIAN + elif vod_type == "vehicle.bicycle": + return AgentType.BICYCLE + elif vod_type == "vehicle.motorcycle": + return AgentType.MOTORCYCLE + elif vod_type.startswith("vehicle"): + return AgentType.VEHICLE + else: + return AgentType.UNKNOWN + + +def 
agg_ego_data(vod_obj: VOD, scene: Scene) -> Agent: + translation_list: List[np.ndarray] = list() + yaw_list: List[float] = list() + for frame_info in frame_iterator(vod_obj, scene): + ego_pose = get_ego_pose(vod_obj, frame_info) + yaw_list.append(Quaternion(ego_pose["rotation"]).yaw_pitch_roll[0]) + translation_list.append(ego_pose["translation"]) + + translations_np: np.ndarray = np.stack(translation_list, axis=0) + + # Doing this prepending so that the first velocity isn't zero (rather it's just the first actual velocity duplicated) + prepend_pos: np.ndarray = translations_np[0, :2] - ( + translations_np[1, :2] - translations_np[0, :2] + ) + velocities_np: np.ndarray = ( + np.diff( + translations_np[:, :2], axis=0, prepend=np.expand_dims(prepend_pos, axis=0) + ) + / VOD_DT + ) + + # Doing this prepending so that the first acceleration isn't zero (rather it's just the first actual acceleration duplicated) + prepend_vel: np.ndarray = velocities_np[0] - (velocities_np[1] - velocities_np[0]) + accelerations_np: np.ndarray = ( + np.diff(velocities_np, axis=0, prepend=np.expand_dims(prepend_vel, axis=0)) + / VOD_DT + ) + + yaws_np: np.ndarray = np.expand_dims(np.stack(yaw_list, axis=0), axis=1) + # yaws_np = np.expand_dims(np.arctan2(velocities_np[:, 1], velocities_np[:, 0]), axis=1) + + ego_data_np: np.ndarray = np.concatenate( + [translations_np, velocities_np, accelerations_np, yaws_np], axis=1 + ) + ego_data_df = pd.DataFrame( + ego_data_np, + columns=["x", "y", "z", "vx", "vy", "ax", "ay", "heading"], + index=pd.MultiIndex.from_tuples( + [("ego", idx) for idx in range(ego_data_np.shape[0])], + names=["agent_id", "scene_ts"], + ), + ) + + ego_metadata = AgentMetadata( + name="ego", + agent_type=AgentType.VEHICLE, + first_timestep=0, + last_timestep=ego_data_np.shape[0] - 1, + extent=FixedExtent(length=4.084, width=1.730, height=1.562), + ) + return Agent( + metadata=ego_metadata, + data=ego_data_df, + ) + + +def extract_lane_center(vod_map: VODMap, lane_record) -> np.ndarray: + # Getting the lane center's points. + curr_lane = vod_map.arcline_path_3.get(lane_record["token"], []) + # TBD: temporarily change resolution_meters from 0.5 to 1.0. Need to add this as a parameter into vector_map_params. + lane_midline: np.ndarray = np.array( + arcline_path_utils.discretize_lane(curr_lane, resolution_meters=1.0) + )[:, :2] + + # For some reason, VoD duplicates a few entries + # (likely how they're building their arcline representation). + # We delete those duplicate entries here. + duplicate_check: np.ndarray = np.where( + np.linalg.norm(np.diff(lane_midline, axis=0, prepend=0), axis=1) < 1e-10 + )[0] + if duplicate_check.size > 0: + lane_midline = np.delete(lane_midline, duplicate_check, axis=0) + + return lane_midline + + +def extract_lane_and_edges( + vod_map: VODMap, lane_record +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + # Getting the bounding polygon vertices. + lane_polygon_obj = vod_map.get("polygon", lane_record["polygon_token"]) + polygon_nodes = [ + vod_map.get("node", node_token) + for node_token in lane_polygon_obj["exterior_node_tokens"] + ] + polygon_pts: np.ndarray = np.array( + [(node["x"], node["y"]) for node in polygon_nodes] + ) + + # Getting the lane center's points. + lane_midline: np.ndarray = extract_lane_center(vod_map, lane_record) + + # Computing the closest lane center point to each bounding polygon vertex. + closest_midlane_pt: np.ndarray = np.argmin(cdist(polygon_pts, lane_midline), axis=1) + # Computing the local direction of the lane at each lane center point. 
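+    # (The prepended, backwards-extrapolated first point keeps
+    # direction_vectors the same length as lane_midline.)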
+    direction_vectors: np.ndarray = np.diff(
+        lane_midline,
+        axis=0,
+        prepend=lane_midline[[0]] - (lane_midline[[1]] - lane_midline[[0]]),
+    )
+
+    # Selecting the direction vectors at the closest lane center point per polygon vertex.
+    local_dir_vecs: np.ndarray = direction_vectors[closest_midlane_pt]
+    # Calculating, for each polygon vertex, the vector from its closest lane center point to the vertex.
+    origin_to_polygon_vecs: np.ndarray = polygon_pts - lane_midline[closest_midlane_pt]
+
+    # Computing the perpendicular dot product.
+    # See https://www.xarg.org/book/linear-algebra/2d-perp-product/
+    # If perp_dot_product < 0, then the associated polygon vertex is
+    # on the right edge of the lane.
+    perp_dot_product: np.ndarray = (
+        local_dir_vecs[:, 0] * origin_to_polygon_vecs[:, 1]
+        - local_dir_vecs[:, 1] * origin_to_polygon_vecs[:, 0]
+    )
+
+    # Determining which indices are on the right of the lane center.
+    on_right: np.ndarray = perp_dot_product < 0
+    # Determining the boundary between the left/right polygon vertices
+    # (they will be together in blocks due to the ordering of the polygon vertices).
+    idx_changes: int = np.where(np.roll(on_right, 1) < on_right)[0].item()
+
+    if idx_changes > 0:
+        # If the block of left/right points wraps around the ends of the array,
+        # roll it until the boundary between left/right points is at index 0.
+        # This is important so that the following index selection orders points
+        # without jumps.
+        polygon_pts = np.roll(polygon_pts, shift=-idx_changes, axis=0)
+        on_right = np.roll(on_right, shift=-idx_changes)
+
+    left_pts: np.ndarray = polygon_pts[~on_right]
+    right_pts: np.ndarray = polygon_pts[on_right]
+
+    # Final ordering check, ensuring that left_pts and right_pts can be combined
+    # into a polygon without the endpoints intersecting.
+    # Reversing the one lane edge that does not match the ordering of the midline.
+    if map_utils.endpoints_intersect(left_pts, right_pts):
+        if not map_utils.order_matches(left_pts, lane_midline):
+            left_pts = left_pts[::-1]
+        else:
+            right_pts = right_pts[::-1]
+
+    # Ensuring that left and right have the same number of points.
+    # This is necessary, not for data storage but for later rasterization.
+    if left_pts.shape[0] < right_pts.shape[0]:
+        left_pts = map_utils.interpolate(left_pts, num_pts=right_pts.shape[0])
+    elif right_pts.shape[0] < left_pts.shape[0]:
+        right_pts = map_utils.interpolate(right_pts, num_pts=left_pts.shape[0])
+
+    return (
+        lane_midline,
+        left_pts,
+        right_pts,
+    )
+
+
+def extract_area(vod_map: VODMap, area_record) -> np.ndarray:
+    token_key: str
+    if "exterior_node_tokens" in area_record:
+        token_key = "exterior_node_tokens"
+    elif "node_tokens" in area_record:
+        token_key = "node_tokens"
+    else:
+        # Guarding against unexpected record schemas; without this,
+        # token_key would be unbound below.
+        raise ValueError("Area record has no node tokens!")
+
+    polygon_nodes = [
+        vod_map.get("node", node_token) for node_token in area_record[token_key]
+    ]
+
+    return np.array([(node["x"], node["y"]) for node in polygon_nodes])
+
+
+def populate_vector_map(vector_map: VectorMap, vod_map: VODMap) -> None:
+    # Setting the map bounds.
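+    # vector_map.extent is [min_x, min_y, min_z, max_x, max_y, max_z].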
+ vector_map.extent = np.array( + [ + vod_map.explorer.canvas_min_x, + vod_map.explorer.canvas_min_y, + 0.0, + vod_map.explorer.canvas_max_x, + vod_map.explorer.canvas_max_y, + 0.0, + ] + ) + + overall_pbar = tqdm( + total=len(vod_map.lane) + + len(vod_map.lane_connector) + + len(vod_map.drivable_area) + + len(vod_map.ped_crossing) + + len(vod_map.walkway), + desc=f"Getting {vod_map.map_name} Elements", + position=1, + leave=False, + ) + + for lane_record in vod_map.lane: + center_pts, left_pts, right_pts = extract_lane_and_edges(vod_map, lane_record) + + lane_record_token: str = lane_record["token"] + + new_lane = RoadLane( + id=lane_record_token, + center=Polyline(center_pts), + left_edge=Polyline(left_pts), + right_edge=Polyline(right_pts), + ) + + for lane_id in vod_map.get_incoming_lane_ids(lane_record_token): + # Need to do this because some incoming/outgoing lane_connector IDs + # do not exist as lane_connectors... + if lane_id in vod_map._token2ind["lane_connector"]: + new_lane.prev_lanes.add(lane_id) + + for lane_id in vod_map.get_outgoing_lane_ids(lane_record_token): + # Need to do this because some incoming/outgoing lane_connector IDs + # do not exist as lane_connectors... + if lane_id in vod_map._token2ind["lane_connector"]: + new_lane.next_lanes.add(lane_id) + + # new_lane.adjacent_lanes_left.append( + # l5_lane.adjacent_lane_change_left.id + # ) + # new_lane.adjacent_lanes_right.append( + # l5_lane.adjacent_lane_change_right.id + # ) + + # Adding the element to the map. + vector_map.add_map_element(new_lane) + overall_pbar.update() + + for lane_record in vod_map.lane_connector: + # Unfortunately lane connectors in VoD have very simple exterior + # polygons which make extracting their edges quite difficult, so we + # only extract the centerline. + center_pts = extract_lane_center(vod_map, lane_record) + + lane_record_token: str = lane_record["token"] + + # Adding the element to the map. + new_lane = RoadLane( + id=lane_record_token, + center=Polyline(center_pts), + ) + + new_lane.prev_lanes.update(vod_map.get_incoming_lane_ids(lane_record_token)) + new_lane.next_lanes.update(vod_map.get_outgoing_lane_ids(lane_record_token)) + + # new_lane.adjacent_lanes_left.append( + # l5_lane.adjacent_lane_change_left.id + # ) + # new_lane.adjacent_lanes_right.append( + # l5_lane.adjacent_lane_change_right.id + # ) + + # Adding the element to the map. + vector_map.add_map_element(new_lane) + overall_pbar.update() + + for drivable_area in vod_map.drivable_area: + for polygon_token in drivable_area["polygon_tokens"]: + if ( + polygon_token is None + and str(None) in vector_map.elements[MapElementType.ROAD_AREA] + ): + # See below, but essentially VoD has two None polygon_tokens + # back-to-back, so we don't need the second one. + continue + + polygon_record = vod_map.get("polygon", polygon_token) + polygon_pts = extract_area(vod_map, polygon_record) + + # NOTE: VoD has some polygon_tokens that are None, although that + # doesn't stop the above get(...) function call so it's fine, + # just have to be mindful of this when creating the id. + new_road_area = RoadArea( + id=str(polygon_token), exterior_polygon=Polyline(polygon_pts) + ) + + for hole in polygon_record["holes"]: + polygon_pts = extract_area(vod_map, hole) + new_road_area.interior_holes.append(Polyline(polygon_pts)) + + # Adding the element to the map. 
+ vector_map.add_map_element(new_road_area) + overall_pbar.update() + + for ped_area_record in vod_map.ped_crossing: + polygon_pts = extract_area(vod_map, ped_area_record) + + # Adding the element to the map. + vector_map.add_map_element( + PedCrosswalk(id=ped_area_record["token"], polygon=Polyline(polygon_pts)) + ) + overall_pbar.update() + + for ped_area_record in vod_map.walkway: + polygon_pts = extract_area(vod_map, ped_area_record) + + # Adding the element to the map. + vector_map.add_map_element( + PedWalkway(id=ped_area_record["token"], polygon=Polyline(polygon_pts)) + ) + overall_pbar.update() + + overall_pbar.close() diff --git a/src/trajdata/dataset_specific/waymo/__init__.py b/src/trajdata/dataset_specific/waymo/__init__.py new file mode 100644 index 0000000..4a43e59 --- /dev/null +++ b/src/trajdata/dataset_specific/waymo/__init__.py @@ -0,0 +1 @@ +from .waymo_dataset import WaymoDataset diff --git a/src/trajdata/dataset_specific/waymo/waymo_dataset.py b/src/trajdata/dataset_specific/waymo/waymo_dataset.py new file mode 100644 index 0000000..4497a79 --- /dev/null +++ b/src/trajdata/dataset_specific/waymo/waymo_dataset.py @@ -0,0 +1,374 @@ +import os +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Type + +import numpy as np +import pandas as pd +import tensorflow as tf +import tqdm +from waymo_open_dataset.protos.scenario_pb2 import Scenario + +from trajdata.caching import EnvCache, SceneCache +from trajdata.data_structures import ( + AgentMetadata, + EnvMetadata, + Scene, + SceneMetadata, + SceneTag, +) +from trajdata.data_structures.agent import ( + Agent, + AgentMetadata, + AgentType, + FixedExtent, + VariableExtent, +) +from trajdata.data_structures.scene_tag import SceneTag +from trajdata.dataset_specific.raw_dataset import RawDataset +from trajdata.dataset_specific.scene_records import WaymoSceneRecord +from trajdata.dataset_specific.waymo import waymo_utils +from trajdata.dataset_specific.waymo.waymo_utils import ( + WaymoScenarios, + interpolate_array, + translate_agent_type, +) +from trajdata.maps import VectorMap +from trajdata.proto.vectorized_map_pb2 import ( + MapElement, + PedCrosswalk, + RoadLane, + VectorizedMap, +) +from trajdata.utils import arr_utils +from trajdata.utils.parallel_utils import parallel_apply + + +def const_lambda(const_val: Any) -> Any: + return const_val + + +class WaymoDataset(RawDataset): + def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: + if env_name == "waymo_train": + # Waymo possibilities are the Cartesian product of these + dataset_parts = [("train",)] + scene_split_map = defaultdict(partial(const_lambda, const_val="train")) + + elif env_name == "waymo_val": + # Waymo possibilities are the Cartesian product of these + dataset_parts = [("val",)] + scene_split_map = defaultdict(partial(const_lambda, const_val="val")) + + elif env_name == "waymo_test": + # Waymo possibilities are the Cartesian product of these + dataset_parts = [("test",)] + scene_split_map = defaultdict(partial(const_lambda, const_val="test")) + + return EnvMetadata( + name=env_name, + data_dir=data_dir, + dt=waymo_utils.WAYMO_DT, + parts=dataset_parts, + scene_split_map=scene_split_map, + ) + + def load_dataset_obj(self, verbose: bool = False) -> None: + if verbose: + print(f"Loading {self.name} dataset...", flush=True) + dataset_name: str = "" + if self.name == "waymo_train": + dataset_name = "training" + elif self.name == "waymo_val": + 
dataset_name = "validation" + elif self.name == "waymo_test": + dataset_name = "testing" + self.dataset_obj = WaymoScenarios( + dataset_name=dataset_name, source_dir=self.metadata.data_dir + ) + + def _get_matching_scenes_from_obj( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[SceneMetadata]: + all_scenes_list: List[WaymoSceneRecord] = list() + + scenes_list: List[SceneMetadata] = list() + for idx in range(self.dataset_obj.num_scenarios): + scene_name: str = "scene_" + str(idx) + scene_split: str = self.metadata.scene_split_map[scene_name] + scene_length: int = self.dataset_obj.scene_length + + # Saving all scene records for later caching. + all_scenes_list.append(WaymoSceneRecord(scene_name, str(scene_length), idx)) + + if scene_split in scene_tag and scene_desc_contains is None: + scene_metadata = SceneMetadata( + env_name=self.metadata.name, + name=scene_name, + dt=self.metadata.dt, + raw_data_idx=idx, + ) + scenes_list.append(scene_metadata) + + self.cache_all_scenes_list(env_cache, all_scenes_list) + return scenes_list + + def get_scene(self, scene_info: SceneMetadata) -> Scene: + _, name, _, data_idx = scene_info + scene_name: str = name + scene_split: str = self.metadata.scene_split_map[scene_name] + scene_length: int = self.dataset_obj.scene_length + + return Scene( + self.metadata, + scene_name, + f"{self.name}_{data_idx}", + scene_split, + scene_length, + data_idx, + None, + ) + + def _get_matching_scenes_from_cache( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[Scene]: + all_scenes_list: List[WaymoSceneRecord] = env_cache.load_env_scenes_list( + self.name + ) + + scenes_list: List[SceneMetadata] = list() + for scene_record in all_scenes_list: + scene_name, scene_length, data_idx = scene_record + scene_split: str = self.metadata.scene_split_map[scene_name] + + if scene_split in scene_tag and scene_desc_contains is None: + scene_metadata = Scene( + self.metadata, + scene_name, + # Unfortunately necessary as Waymo does not + # associate each scenario with a location. + f"{self.name}_{data_idx}", + scene_split, + scene_length, + data_idx, + None, # This isn't used if everything is already cached. 
+ ) + scenes_list.append(scene_metadata) + + return scenes_list + + def get_agent_info( + self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] + ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: + agent_list: List[AgentMetadata] = [] + agent_presence: List[List[AgentMetadata]] = [ + [] for _ in range(scene.length_timesteps) + ] + + dataset = tf.data.TFRecordDataset( + [str(self.dataset_obj.get_filename(scene.raw_data_idx))], + compression_type="", + ) + scenario: Scenario = Scenario() + for data in dataset: + scenario.ParseFromString(bytearray(data.numpy())) + break + + agent_ids = [] + # agent_ml_class = [] + all_agent_data = [] + agents_to_remove = [] + ego_id = None + for index, track in enumerate(scenario.tracks): + agent_type: AgentType = translate_agent_type(track.object_type) + if agent_type == -1: + continue + + agent_id: int = track.id + agent_ids.append(agent_id) + + # agent_ml_class.append(agent_type) + states = track.states + translations = [] + velocities = [] + sizes = [] + yaws = [] + for state in states: + if state.valid: + translations.append( + (state.center_x, state.center_y, state.center_z) + ) + velocities.append((state.velocity_x, state.velocity_y)) + sizes.append((state.length, state.width, state.height)) + yaws.append(state.heading) + else: + translations.append((np.nan, np.nan, np.nan)) + velocities.append((np.nan, np.nan)) + sizes.append((np.nan, np.nan, np.nan)) + yaws.append(np.nan) + + curr_agent_data = np.concatenate( + ( + translations, + velocities, + np.expand_dims(yaws, axis=1), + sizes, + ), + axis=1, + ) + + curr_agent_data = interpolate_array(curr_agent_data) + all_agent_data.append(curr_agent_data) + + first_timestep = pd.Series(curr_agent_data[:, 0]).first_valid_index() + last_timestep = pd.Series(curr_agent_data[:, 0]).last_valid_index() + if first_timestep is None or last_timestep is None: + first_timestep = 0 + last_timestep = 0 + + agent_name = str(agent_id) + if index == scenario.sdc_track_index: + ego_id = agent_id + agent_name = "ego" + + agent_info = AgentMetadata( + name=agent_name, + agent_type=agent_type, + first_timestep=first_timestep, + last_timestep=last_timestep, + extent=VariableExtent(), + ) + if last_timestep - first_timestep > 0: + agent_list.append(agent_info) + for timestep in range(first_timestep, last_timestep + 1): + agent_presence[timestep].append(agent_info) + else: + agents_to_remove.append(agent_id) + + # agent_ml_class = np.repeat(agent_ml_class, scene.length_timesteps) + # all_agent_data = np.insert(all_agent_data, 6, agent_ml_class, axis=1) + agent_ids = np.repeat(agent_ids, scene.length_timesteps) + traj_cols = ["x", "y", "z", "vx", "vy", "heading"] + # class_cols = ["class_id"] + extent_cols = ["length", "width", "height"] + agent_frame_ids = np.resize( + np.arange(scene.length_timesteps), + len(scenario.tracks) * scene.length_timesteps, + ) + + all_agent_data_df = pd.DataFrame( + np.concatenate(all_agent_data), + columns=traj_cols + extent_cols, + index=[agent_ids, agent_frame_ids], + ) + + all_agent_data_df.index.names = ["agent_id", "scene_ts"] + + # This does exactly the same as dropna(...), but we're keeping the mask around + # for later use with agent_ids. 
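+        # (agent_ids was repeated per-timestep above, so it is row-aligned with
+        # the pre-filtering DataFrame and can be indexed by the same mask.)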
+ mask = pd.notna(all_agent_data_df).all(axis=1, bool_only=False) + all_agent_data_df = all_agent_data_df.loc[mask] + + all_agent_data_df.sort_index(inplace=True) + all_agent_data_df.reset_index(level=1, inplace=True) + + all_agent_data_df[["ax", "ay"]] = ( + arr_utils.agent_aware_diff( + all_agent_data_df[["vx", "vy"]].to_numpy(), agent_ids[mask] + ) + / waymo_utils.WAYMO_DT + ) + final_cols = [ + "x", + "y", + "z", + "vx", + "vy", + "ax", + "ay", + "heading", + ] + extent_cols + + # Removing agents with only one detection. + all_agent_data_df.drop(index=agents_to_remove, inplace=True) + + # Changing the agent_id dtype to str + all_agent_data_df.reset_index(inplace=True) + all_agent_data_df["agent_id"] = all_agent_data_df["agent_id"].astype(str) + all_agent_data_df.set_index(["agent_id", "scene_ts"], inplace=True) + all_agent_data_df.rename( + index={str(ego_id): "ego"}, inplace=True, level="agent_id" + ) + + cache_class.save_agent_data( + all_agent_data_df.loc[:, final_cols], + cache_path, + scene, + ) + + tls_dict = waymo_utils.extract_traffic_lights( + dynamic_states=scenario.dynamic_map_states + ) + tls_df = pd.DataFrame( + tls_dict.values(), + index=pd.MultiIndex.from_tuples( + tls_dict.keys(), names=["lane_id", "scene_ts"] + ), + columns=["status"], + ) + cache_class.save_traffic_light_data(tls_df, cache_path, scene) + + return agent_list, agent_presence + + def cache_map( + self, + data_idx: int, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + ): + dataset = tf.data.TFRecordDataset( + [str(self.dataset_obj.get_filename(data_idx))], compression_type="" + ) + + scenario: Scenario = Scenario() + for data in dataset: + scenario.ParseFromString(bytearray(data.numpy())) + break + + vector_map: VectorMap = waymo_utils.extract_vectorized( + map_features=scenario.map_features, + map_name=f"{self.name}:{self.name}_{data_idx}", + ) + + map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) + + def cache_maps( + self, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + ) -> None: + num_workers: int = map_params.get("num_workers", 0) + if num_workers > 1: + parallel_apply( + partial( + self.cache_map, + cache_path=cache_path, + map_cache_class=map_cache_class, + map_params=map_params, + ), + range(self.dataset_obj.num_scenarios), + num_workers=num_workers, + ) + + else: + for i in tqdm.trange(self.dataset_obj.num_scenarios): + self.cache_map(i, cache_path, map_cache_class, map_params) diff --git a/src/trajdata/dataset_specific/waymo/waymo_utils.py b/src/trajdata/dataset_specific/waymo/waymo_utils.py new file mode 100644 index 0000000..83cc210 --- /dev/null +++ b/src/trajdata/dataset_specific/waymo/waymo_utils.py @@ -0,0 +1,566 @@ +import os +from pathlib import Path +from subprocess import check_call, check_output +from typing import Dict, Final, List, Optional, Tuple + +import numpy as np +import pandas as pd +import tensorflow as tf +from intervaltree import Interval, IntervalTree +from tqdm import tqdm +from waymo_open_dataset.protos import map_pb2 as waymo_map_pb2 +from waymo_open_dataset.protos import scenario_pb2 + +from trajdata.maps import TrafficLightStatus, VectorMap +from trajdata.maps.vec_map_elements import PedCrosswalk, Polyline, RoadLane + +WAYMO_DT: Final[float] = 0.1 +WAYMO_DATASET_NAMES = [ + "testing", + "testing_interactive", + "training", + "training_20s", + "validation", + "validation_interactive", +] + +TRAIN_SCENE_LENGTH = 91 +VAL_SCENE_LENGTH = 91 +TEST_SCENE_LENGTH = 11 
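+# Scenarios in the "training_20s" split span 201 timesteps
+# (20 s at WAYMO_DT = 0.1 s).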
+TRAIN_20S_SCENE_LENGTH = 201 + +GREEN = [ + waymo_map_pb2.TrafficSignalLaneState.State.LANE_STATE_ARROW_GO, + waymo_map_pb2.TrafficSignalLaneState.State.LANE_STATE_GO, +] +RED = [ + waymo_map_pb2.TrafficSignalLaneState.State.LANE_STATE_ARROW_CAUTION, + waymo_map_pb2.TrafficSignalLaneState.State.LANE_STATE_ARROW_STOP, + waymo_map_pb2.TrafficSignalLaneState.State.LANE_STATE_STOP, + waymo_map_pb2.TrafficSignalLaneState.State.LANE_STATE_CAUTION, + waymo_map_pb2.TrafficSignalLaneState.State.LANE_STATE_FLASHING_STOP, + waymo_map_pb2.TrafficSignalLaneState.State.LANE_STATE_FLASHING_CAUTION, +] + +from trajdata.data_structures.agent import ( + Agent, + AgentMetadata, + AgentType, + FixedExtent, + VariableExtent, +) + + +class WaymoScenarios: + def __init__( + self, + dataset_name: str, + source_dir: Path, + download: bool = False, + split: bool = False, + ): + if dataset_name not in WAYMO_DATASET_NAMES: + raise RuntimeError( + "Wrong dataset name. Please choose name from " + + str(WAYMO_DATASET_NAMES) + ) + + self.name = dataset_name + self.source_dir = source_dir + if dataset_name in ["training"]: + self.scene_length = TRAIN_SCENE_LENGTH + elif dataset_name in ["validation", "validation_interactive"]: + self.scene_length = VAL_SCENE_LENGTH + elif dataset_name in ["testing", "testing_interactive"]: + self.scene_length = TEST_SCENE_LENGTH + elif dataset_name in ["training_20s"]: + self.scene_length = TRAIN_20S_SCENE_LENGTH + + if download: + self.download_dataset() + + split_path = self.source_dir / (self.name + "_splitted") + if split or not split_path.is_dir(): + self.split_scenarios() + else: + self.num_scenarios = len(os.listdir(split_path)) + + def download_dataset(self) -> None: + # check_call("snap install google-cloud-sdk --classic".split()) + gsutil = check_output(["which", "gsutil"]) + download_cmd = ( + str(gsutil.decode("utf-8")) + + "-m cp -r gs://waymo_open_dataset_motion_v_1_1_0/uncompressed/scenario/" + + str(self.name) + + " " + + str(self.source_dir) + ).split() + check_call(download_cmd) + + def split_scenarios( + self, num_parallel_reads: int = 20, verbose: bool = True + ) -> None: + source_it: Path = (self.source_dir / self.name).glob("*") + file_names: List[str] = [str(file_name) for file_name in source_it] + if verbose: + print("Loading tfrecord files...") + dataset = tf.data.TFRecordDataset( + file_names, compression_type="", num_parallel_reads=num_parallel_reads + ) + + if verbose: + print("Splitting tfrecords...") + + splitted_dir: Path = self.source_dir / f"{self.name}_splitted" + if not splitted_dir.exists(): + splitted_dir.mkdir(parents=True) + + scenario_num: int = 0 + for data in tqdm(dataset): + file_name: Path = ( + splitted_dir / f"{self.name}_splitted_{scenario_num}.tfrecords" + ) + with tf.io.TFRecordWriter(str(file_name)) as file_writer: + file_writer.write(data.numpy()) + + scenario_num += 1 + + self.num_scenarios = scenario_num + if verbose: + print( + str(self.num_scenarios) + + " scenarios from " + + str(len(file_names)) + + " file(s) have been split into " + + str(self.num_scenarios) + + " files." 
+            )
+
+    def get_filename(self, data_idx: int) -> Path:
+        return (
+            self.source_dir
+            / f"{self.name}_splitted"
+            / f"{self.name}_splitted_{data_idx}.tfrecords"
+        )
+
+
+def extract_vectorized(
+    map_features: List[waymo_map_pb2.MapFeature], map_name: str, verbose: bool = False
+) -> VectorMap:
+    vec_map = VectorMap(map_id=map_name)
+
+    max_pt = np.array([np.nan, np.nan, np.nan])
+    min_pt = np.array([np.nan, np.nan, np.nan])
+
+    boundaries: Dict[int, Polyline] = {}
+    for map_feature in tqdm(
+        map_features, desc="Extracting road boundaries", disable=not verbose
+    ):
+        if map_feature.WhichOneof("feature_data") == "road_line":
+            boundaries[map_feature.id] = Polyline(
+                np.array([(pt.x, pt.y, pt.z) for pt in map_feature.road_line.polyline])
+            )
+        elif map_feature.WhichOneof("feature_data") == "road_edge":
+            boundaries[map_feature.id] = Polyline(
+                np.array([(pt.x, pt.y, pt.z) for pt in map_feature.road_edge.polyline])
+            )
+
+    lane_id_remap_dict = {}
+    for map_feature in tqdm(
+        map_features, desc="Extracting map elements", disable=not verbose
+    ):
+        if map_feature.WhichOneof("feature_data") == "lane":
+            if len(map_feature.lane.polyline) == 1:
+                # TODO: Why does Waymo have single-point polylines that
+                # aren't interpolating between others??
+                continue
+
+            road_lanes, modified_lane_ids = translate_lane(map_feature, boundaries)
+            if modified_lane_ids:
+                lane_id_remap_dict.update(modified_lane_ids)
+
+            for road_lane in road_lanes:
+                vec_map.add_map_element(road_lane)
+
+                max_pt = np.fmax(max_pt, road_lane.center.xyz.max(axis=0))
+                min_pt = np.fmin(min_pt, road_lane.center.xyz.min(axis=0))
+
+                if road_lane.left_edge:
+                    max_pt = np.fmax(max_pt, road_lane.left_edge.xyz.max(axis=0))
+                    min_pt = np.fmin(min_pt, road_lane.left_edge.xyz.min(axis=0))
+
+                if road_lane.right_edge:
+                    max_pt = np.fmax(max_pt, road_lane.right_edge.xyz.max(axis=0))
+                    min_pt = np.fmin(min_pt, road_lane.right_edge.xyz.min(axis=0))
+
+        elif map_feature.WhichOneof("feature_data") == "crosswalk":
+            crosswalk = PedCrosswalk(
+                id=str(map_feature.id),
+                polygon=Polyline(
+                    np.array(
+                        [(pt.x, pt.y, pt.z) for pt in map_feature.crosswalk.polygon]
+                    )
+                ),
+            )
+            vec_map.add_map_element(crosswalk)
+
+            max_pt = np.fmax(max_pt, crosswalk.polygon.xyz.max(axis=0))
+            min_pt = np.fmin(min_pt, crosswalk.polygon.xyz.min(axis=0))
+
+        else:
+            continue
+
+    for elem in vec_map.iter_elems():
+        if not isinstance(elem, RoadLane):
+            continue
+
+        to_remove = set()
+        to_add = set()
+        for l_id in elem.adj_lanes_left:
+            if l_id in lane_id_remap_dict:
+                # Remove the original lanes, replace them with our chunked versions.
+                to_remove.add(l_id)
+                to_add.update(lane_id_remap_dict[l_id])
+
+        elem.adj_lanes_left -= to_remove
+        elem.adj_lanes_left |= to_add
+
+        to_remove = set()
+        to_add = set()
+        for l_id in elem.adj_lanes_right:
+            if l_id in lane_id_remap_dict:
+                # Remove the original lanes, replace them with our chunked versions.
+                to_remove.add(l_id)
+                to_add.update(lane_id_remap_dict[l_id])
+
+        elem.adj_lanes_right -= to_remove
+        elem.adj_lanes_right |= to_add
+
+        to_remove = set()
+        to_add = set()
+        for l_id in elem.prev_lanes:
+            if l_id in lane_id_remap_dict:
+                # Remove the original prev lanes, replace them with
+                # the tail of our equivalent chunked version.
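+                # (A chunked predecessor connects to this lane through its
+                # last chunk, hence taking index [-1] below.)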
+                to_remove.add(l_id)
+                to_add.add(lane_id_remap_dict[l_id][-1])
+
+        elem.prev_lanes -= to_remove
+        elem.prev_lanes |= to_add
+
+        to_remove = set()
+        to_add = set()
+        for l_id in elem.next_lanes:
+            if l_id in lane_id_remap_dict:
+                # Remove the original next lanes, replace them with
+                # the first chunk of our equivalent chunked version.
+                to_remove.add(l_id)
+                to_add.add(lane_id_remap_dict[l_id][0])
+
+        elem.next_lanes -= to_remove
+        elem.next_lanes |= to_add
+
+    # Setting the map bounds.
+    # vec_map.extent is [min_x, min_y, min_z, max_x, max_y, max_z]
+    vec_map.extent = np.concatenate((min_pt, max_pt))
+
+    return vec_map
+
+
+def translate_agent_type(agent_type):
+    if agent_type == scenario_pb2.Track.ObjectType.TYPE_VEHICLE:
+        return AgentType.VEHICLE
+    elif agent_type == scenario_pb2.Track.ObjectType.TYPE_PEDESTRIAN:
+        return AgentType.PEDESTRIAN
+    elif agent_type == scenario_pb2.Track.ObjectType.TYPE_CYCLIST:
+        return AgentType.BICYCLE
+    elif agent_type == scenario_pb2.Track.ObjectType.TYPE_OTHER:
+        return AgentType.UNKNOWN
+    return -1
+
+
+def is_full_boundary(lane_boundaries, num_lane_indices: int) -> bool:
+    """Returns True if a given boundary is connected (there are no gaps)
+    and every lane center index has a corresponding boundary point.
+
+    Returns:
+        bool
+    """
+    covers_all: bool = lane_boundaries[0].lane_start_index == 0 and lane_boundaries[
+        -1
+    ].lane_end_index == (num_lane_indices - 1)
+    for idx in range(1, len(lane_boundaries)):
+        if (
+            lane_boundaries[idx].lane_start_index
+            != lane_boundaries[idx - 1].lane_end_index + 1
+        ):
+            covers_all = False
+            break
+
+    return covers_all
+
+
+def _merge_interval_data(data1: str, data2: str) -> str:
+    if data1 == data2 == "none":
+        return "none"
+
+    if data1 == "none" and data2 != "none":
+        return data2
+
+    if data1 != "none" and data2 == "none":
+        return data1
+
+    if data1 != "none" and data2 != "none":
+        return "both"
+
+
+def split_lane_into_chunks(
+    lane: waymo_map_pb2.LaneCenter, boundaries: Dict[int, Polyline]
+) -> List[Tuple[Polyline, Optional[Polyline], Optional[Polyline]]]:
+    boundary_intervals = IntervalTree.from_tuples(
+        [
+            (b.lane_start_index, b.lane_end_index + 1, "left")
+            for b in lane.left_boundaries
+        ]
+        + [
+            (b.lane_start_index, b.lane_end_index + 1, "right")
+            for b in lane.right_boundaries
+        ]
+        + [(0, len(lane.polyline), "none")]
+    )
+
+    boundary_intervals.split_overlaps()
+
+    boundary_intervals.merge_equals(data_reducer=_merge_interval_data)
+    intervals: List[Interval] = sorted(boundary_intervals)
+
+    if len(intervals) > 1:
+        merged_intervals: List[Interval] = [intervals.pop(0)]
+        while intervals:
+            last_interval: Interval = merged_intervals[-1]
+            curr_interval: Interval = intervals.pop(0)
+
+            if last_interval.end != curr_interval.begin:
+                raise ValueError("Non-consecutive intervals in merging!")
+
+            if last_interval.data == curr_interval.data:
+                # Simple merging of same-data neighbors.
+                merged_intervals[-1] = Interval(
+                    last_interval.begin,
+                    curr_interval.end,
+                    last_interval.data,
+                )
+            elif (
+                last_interval.end - last_interval.begin == 1
+                or curr_interval.end - curr_interval.begin == 1
+            ):
+                # Trying to remove 1-length chunks by merging them with neighbors.
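+                # The merged interval inherits the data label of the longer of
+                # the two neighbors (data_to_keep below).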
+                data_to_keep: str = (
+                    curr_interval.data
+                    if curr_interval.end - curr_interval.begin
+                    > last_interval.end - last_interval.begin
+                    else last_interval.data
+                )
+
+                merged_intervals[-1] = Interval(
+                    last_interval.begin,
+                    curr_interval.end,
+                    data_to_keep,
+                )
+            else:
+                merged_intervals.append(curr_interval)
+
+        intervals = merged_intervals
+
+    left_boundary_tree = IntervalTree.from_tuples(
+        [
+            (b.lane_start_index, b.lane_end_index + 1, b.boundary_feature_id)
+            for b in lane.left_boundaries
+        ]
+    )
+    right_boundary_tree = IntervalTree.from_tuples(
+        [
+            (b.lane_start_index, b.lane_end_index + 1, b.boundary_feature_id)
+            for b in lane.right_boundaries
+        ]
+    )
+    lane_chunk_data: List[Tuple[Polyline, Optional[Polyline], Optional[Polyline]]] = []
+    for interval in intervals:
+        center_chunk = Polyline(
+            np.array(
+                [
+                    (point.x, point.y, point.z)
+                    for point in lane.polyline[interval.begin : interval.end]
+                ]
+            )
+        )
+        if interval.data == "none":
+            lane_chunk_data.append((center_chunk, None, None))
+        elif interval.data == "left":
+            left_chunk = subselect_boundary(
+                boundaries, center_chunk, interval, left_boundary_tree
+            )
+            lane_chunk_data.append((center_chunk, left_chunk, None))
+        elif interval.data == "right":
+            right_chunk = subselect_boundary(
+                boundaries, center_chunk, interval, right_boundary_tree
+            )
+            lane_chunk_data.append((center_chunk, None, right_chunk))
+        elif interval.data == "both":
+            left_chunk = subselect_boundary(
+                boundaries, center_chunk, interval, left_boundary_tree
+            )
+            right_chunk = subselect_boundary(
+                boundaries, center_chunk, interval, right_boundary_tree
+            )
+            lane_chunk_data.append((center_chunk, left_chunk, right_chunk))
+        else:
+            raise ValueError(f"Unknown boundary interval data: {interval.data}")
+
+    return lane_chunk_data
+
+
+def subselect_boundary(
+    boundaries: Dict[int, Polyline],
+    lane_center: Polyline,
+    chunk_interval: Interval,
+    boundary_tree: IntervalTree,
+) -> Polyline:
+    relevant_boundaries: List[Interval] = sorted(
+        boundary_tree[chunk_interval.begin : chunk_interval.end]
+    )
+
+    if (
+        len(relevant_boundaries) == 1
+        and relevant_boundaries[0].begin == chunk_interval.begin
+        and relevant_boundaries[0].end == chunk_interval.end
+    ):
+        # Return immediately for an exact match.
+        return boundaries[relevant_boundaries[0].data]
+
+    polyline_pts: List[Polyline] = []
+    for boundary_interval in relevant_boundaries:
+        # Below we are trying to find relevant boundary regions to the current lane chunk center
+        # by projecting the boundary onto the lane and seeing where it stops following the center.
+        # After that point, the projections will cease to change
+        # (they will typically all be the last point of the center line).
+        boundary = boundaries[boundary_interval.data]
+
+        if boundary.points.shape[0] == 1:
+            polyline_pts.append(boundary.points)
+            continue
+
+        proj = lane_center.project_onto(boundary.points)
+        local_diffs = np.diff(proj, axis=0, append=proj[[-1]] - proj[[-2]])
+
+        nonzero_mask = (local_diffs != 0.0).any(axis=1)
+        nonzero_idxs = np.nonzero(nonzero_mask)[0]
+        marker_idx = np.nonzero(np.ediff1d(nonzero_idxs, to_begin=[2]) > 1)[0]
+
+        # TODO(bivanovic): Only taking the first group. Adding 1 to the
+        # first ends because it otherwise ignores the first element of
+        # the repeated value group.
+        start = np.minimum.reduceat(nonzero_idxs, marker_idx)[0]
+        end = np.maximum.reduceat(nonzero_idxs, marker_idx)[0] + 1
+
+        # TODO(bivanovic): This may or may not end up being a problem, but
+        # polyline_pts[0][-1] and polyline_pts[1][0] can be exactly identical.
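+        # (If it does become a problem, deduplicating consecutive identical
+        # rows after the concatenation below would be an easy fix.)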
+        polyline_pts.append(boundary.points[start : end + 1])
+
+    return Polyline(points=np.concatenate(polyline_pts, axis=0))
+
+
+def translate_lane(
+    map_feature: waymo_map_pb2.MapFeature,
+    boundaries: Dict[int, Polyline],
+) -> Tuple[List[RoadLane], Optional[Dict[str, List[str]]]]:
+    lane: waymo_map_pb2.LaneCenter = map_feature.lane
+
+    if lane.left_boundaries or lane.right_boundaries:
+        # Waymo lane boundaries are... complicated. See
+        # https://github.com/waymo-research/waymo-open-dataset/issues/389
+        # for more information. For now, we split lanes into chunks which
+        # have consistent lane boundaries (either both left and right,
+        # one of them, or none).
+        lane_chunks = split_lane_into_chunks(lane, boundaries)
+        road_lanes: List[RoadLane] = []
+        new_ids: List[str] = []
+        for idx, (lane_center, left_edge, right_edge) in enumerate(lane_chunks):
+            road_lane = RoadLane(
+                id=f"{map_feature.id}_{idx}"
+                if len(lane_chunks) > 1
+                else str(map_feature.id),
+                center=lane_center,
+                left_edge=left_edge,
+                right_edge=right_edge,
+            )
+            new_ids.append(road_lane.id)
+
+            if idx == 0:
+                road_lane.prev_lanes.update([str(eid) for eid in lane.entry_lanes])
+            else:
+                road_lane.prev_lanes.add(f"{map_feature.id}_{idx-1}")
+
+            if idx == len(lane_chunks) - 1:
+                road_lane.next_lanes.update([str(eid) for eid in lane.exit_lanes])
+            else:
+                road_lane.next_lanes.add(f"{map_feature.id}_{idx+1}")
+
+            # We'll take care of reassigning these IDs to the chunked versions later.
+            for neighbor in lane.left_neighbors:
+                road_lane.adj_lanes_left.add(str(neighbor.feature_id))
+
+            for neighbor in lane.right_neighbors:
+                road_lane.adj_lanes_right.add(str(neighbor.feature_id))
+
+            road_lanes.append(road_lane)
+
+        if len(lane_chunks) > 1:
+            return road_lanes, {str(map_feature.id): new_ids}
+        else:
+            return road_lanes, None
+
+    else:
+        road_lane = RoadLane(
+            id=str(map_feature.id),
+            center=Polyline(np.array([(pt.x, pt.y, pt.z) for pt in lane.polyline])),
+        )
+
+        road_lane.prev_lanes.update([str(eid) for eid in lane.entry_lanes])
+        road_lane.next_lanes.update([str(eid) for eid in lane.exit_lanes])
+
+        for neighbor in lane.left_neighbors:
+            road_lane.adj_lanes_left.add(str(neighbor.feature_id))
+
+        for neighbor in lane.right_neighbors:
+            road_lane.adj_lanes_right.add(str(neighbor.feature_id))
+
+        return [road_lane], None
+
+
+def extract_traffic_lights(
+    dynamic_states: List[scenario_pb2.DynamicMapState],
+) -> Dict[Tuple[str, int], TrafficLightStatus]:
+    ret: Dict[Tuple[str, int], TrafficLightStatus] = {}
+    for i, dynamic_state in enumerate(dynamic_states):
+        for lane_state in dynamic_state.lane_states:
+            ret[(str(lane_state.lane), i)] = translate_traffic_state(lane_state.state)
+
+    return ret
+
+
+def translate_traffic_state(
+    state: waymo_map_pb2.TrafficSignalLaneState.State,
+) -> TrafficLightStatus:
+    # TODO(bivanovic): The traffic light type doesn't align between waymo and trajdata,
+    # since trajdata's TrafficLightStatus does not include a yellow light yet.
+    # For now, we set caution = red.
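+    # (This is why the CAUTION and FLASHING_CAUTION states are grouped into
+    # the RED list above.)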
+    if state in GREEN:
+        return TrafficLightStatus.GREEN
+    if state in RED:
+        return TrafficLightStatus.RED
+    return TrafficLightStatus.UNKNOWN
+
+
+def interpolate_array(data: List) -> np.ndarray:
+    return pd.DataFrame(data).interpolate(limit_area="inside").to_numpy()
diff --git a/src/trajdata/filtering/filters.py b/src/trajdata/filtering/filters.py
index de198da..156ac48 100644
--- a/src/trajdata/filtering/filters.py
+++ b/src/trajdata/filtering/filters.py
@@ -1,3 +1,4 @@
+from decimal import Decimal
 from math import ceil
 from typing import List, Optional, Set, Tuple
 
@@ -53,12 +54,12 @@ def get_valid_ts(
     """
     first_valid_ts = agent_info.first_timestep
     if history_sec[0] is not None:
-        min_history = ceil(history_sec[0] / dt)
+        min_history = ceil(Decimal(str(history_sec[0])) / Decimal(str(dt)))
        first_valid_ts += min_history
 
     last_valid_ts = agent_info.last_timestep
     if future_sec[0] is not None:
-        min_future = ceil(future_sec[0] / dt)
+        min_future = ceil(Decimal(str(future_sec[0])) / Decimal(str(dt)))
         last_valid_ts -= min_future
 
     return first_valid_ts, last_valid_ts
@@ -71,7 +72,7 @@ def satisfies_history(
     history_sec: Tuple[Optional[float], Optional[float]],
 ) -> bool:
     if history_sec[0] is not None:
-        min_history = ceil(history_sec[0] / dt)
+        min_history = ceil(Decimal(str(history_sec[0])) / Decimal(str(dt)))
         agent_history_satisfies = ts - agent_info.first_timestep >= min_history
     else:
         agent_history_satisfies = True
@@ -86,7 +87,7 @@ def satisfies_future(
     future_sec: Tuple[Optional[float], Optional[float]],
 ) -> bool:
     if future_sec[0] is not None:
-        min_future = ceil(future_sec[0] / dt)
+        min_future = ceil(Decimal(str(future_sec[0])) / Decimal(str(dt)))
         agent_future_satisfies = agent_info.last_timestep - ts >= min_future
     else:
         agent_future_satisfies = True
diff --git a/src/trajdata/maps/__init__.py b/src/trajdata/maps/__init__.py
index ea04c8d..2ef0bf6 100644
--- a/src/trajdata/maps/__init__.py
+++ b/src/trajdata/maps/__init__.py
@@ -1,2 +1,4 @@
-from .map import RasterizedMap, RasterizedMapMetadata
-from .map_patch import RasterizedMapPatch
+from .map_api import MapAPI
+from .raster_map import RasterizedMap, RasterizedMapMetadata, RasterizedMapPatch
+from .traffic_light_status import TrafficLightStatus
+from .vec_map import VectorMap
diff --git a/src/trajdata/maps/lane_route.py b/src/trajdata/maps/lane_route.py
new file mode 100644
index 0000000..fe70388
--- /dev/null
+++ b/src/trajdata/maps/lane_route.py
@@ -0,0 +1,7 @@
+from dataclasses import dataclass
+from typing import Set
+
+
+@dataclass
+class LaneRoute:
+    lane_idxs: Set[int]
diff --git a/src/trajdata/maps/map_api.py b/src/trajdata/maps/map_api.py
new file mode 100644
index 0000000..6e58aab
--- /dev/null
+++ b/src/trajdata/maps/map_api.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from trajdata.maps.map_kdtree import MapElementKDTree
+    from trajdata.caching.scene_cache import SceneCache
+
+from pathlib import Path
+from typing import Dict
+
+from trajdata.maps.vec_map import VectorMap
+from trajdata.proto.vectorized_map_pb2 import VectorizedMap
+from trajdata.utils import map_utils
+
+
+class MapAPI:
+    def __init__(self, unified_cache_path: Path, keep_in_memory: bool = False) -> None:
+        """A simple interface for loading trajdata's vector maps which does not require
+        instantiation of a `UnifiedDataset` object.
+
+        Args:
+            unified_cache_path (Path): Path to trajdata's local cache on disk.
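+                This is the folder given as `cache_location` when creating a
+                `UnifiedDataset` (by default, `~/.unified_data_cache`).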
+            keep_in_memory (bool): Whether loaded maps should be stored
+                in memory (memoized) for later re-use. For most cases (e.g., batched dataloading),
+                this is a good idea. However, this can cause rapid memory usage growth for some
+                datasets (e.g., Waymo) and it can be better to disable this. Defaults to False.
+        """
+        self.unified_cache_path: Path = unified_cache_path
+        self.maps: Dict[str, VectorMap] = dict()
+        self._keep_in_memory = keep_in_memory
+
+    def get_map(
+        self, map_id: str, scene_cache: Optional[SceneCache] = None, **kwargs
+    ) -> VectorMap:
+        if map_id not in self.maps:
+            env_name, map_name = map_id.split(":")
+            env_maps_path: Path = self.unified_cache_path / env_name / "maps"
+            stored_vec_map: VectorizedMap = map_utils.load_vector_map(
+                env_maps_path / f"{map_name}.pb"
+            )
+
+            vec_map: VectorMap = VectorMap.from_proto(stored_vec_map, **kwargs)
+            vec_map.search_kdtrees = map_utils.load_kdtrees(
+                env_maps_path / f"{map_name}_kdtrees.dill"
+            )
+            vec_map.search_rtrees = map_utils.load_rtrees(
+                env_maps_path / f"{map_name}_rtrees.dill"
+            )
+
+            if self._keep_in_memory:
+                self.maps[map_id] = vec_map
+        else:
+            vec_map = self.maps[map_id]
+
+        if scene_cache is not None:
+            vec_map.associate_scene_data(
+                scene_cache.get_traffic_light_status_dict(
+                    kwargs.get("desired_dt", None)
+                )
+            )
+
+        return vec_map
diff --git a/src/trajdata/maps/map_kdtree.py b/src/trajdata/maps/map_kdtree.py
new file mode 100644
index 0000000..c9f36d0
--- /dev/null
+++ b/src/trajdata/maps/map_kdtree.py
@@ -0,0 +1,185 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from typing import TYPE_CHECKING, Dict
+
+if TYPE_CHECKING:
+    from trajdata.maps.vec_map import VectorMap
+
+from typing import Optional, Tuple
+
+import numpy as np
+from scipy.spatial import KDTree
+from tqdm import tqdm
+
+from trajdata.maps.vec_map_elements import MapElement, MapElementType, Polyline
+from trajdata.utils.arr_utils import angle_wrap
+
+
+class MapElementKDTree:
+    """
+    Constructs a KDTree of MapElements and exposes fast lookup functions.
+
+    Inheriting classes need to implement the _extract_points_and_metadata function
+    that defines for a MapElement the coordinates we want to store in the KDTree.
+    """
+
+    def __init__(self, vector_map: VectorMap, verbose: bool = False) -> None:
+        # Build kd-tree
+        self.kdtree, self.polyline_inds, self.metadata = self._build_kdtree(
+            vector_map, verbose
+        )
+
+    def _build_kdtree(self, vector_map: VectorMap, verbose: bool = False):
+        polylines = []
+        polyline_inds = []
+        metadata = defaultdict(list)
+
+        map_elem: MapElement
+        for map_elem in tqdm(
+            vector_map.iter_elems(),
+            desc=f"Building K-D Trees",
+            leave=False,
+            total=len(vector_map),
+            disable=not verbose,
+        ):
+            result = self._extract_points_and_metadata(map_elem)
+            if result is not None:
+                points, extras = result
+                polyline_inds.extend([len(polylines)] * points.shape[0])
+
+                # Apply any map offsets to ensure we're in the same coordinate area as the
+                # original world map.
+                polylines.append(points)
+
+                for k, v in extras.items():
+                    metadata[k].append(v)
+
+        points = np.concatenate(polylines, axis=0)
+        polyline_inds = np.array(polyline_inds)
+        metadata = {k: np.concatenate(v) for k, v in metadata.items()}
+
+        kdtree = KDTree(points)
+        return kdtree, polyline_inds, metadata
+
+    def _extract_points_and_metadata(
+        self, map_element: MapElement
+    ) -> Optional[Tuple[np.ndarray, Dict[str, np.ndarray]]]:
+        """Defines the coordinates we want to store in the KDTree for a MapElement.
+
+        Args:
+            map_element (MapElement): the MapElement to store in the KDTree.
+
+        Returns:
+            Optional[Tuple[np.ndarray, Dict[str, np.ndarray]]]: None if the
+                MapElement should not be stored. Otherwise, a tuple of
+                    np.ndarray: [B,d] set of B d-dimensional points to add,
+                    Dict[str, np.ndarray]: names mapped to meta-information about the points.
+        """
+        raise NotImplementedError()
+
+    def closest_point(self, query_points: np.ndarray) -> np.ndarray:
+        """Find the closest KDTree points to (a batch of) query points.
+
+        Args:
+            query_points: np.ndarray of shape (..., data_dim).
+
+        Returns:
+            np.ndarray of shape (..., data_dim), the KDTree points closest to query_point.
+        """
+        _, data_inds = self.kdtree.query(query_points, k=1)
+        pts = self.kdtree.data[data_inds]
+        return pts
+
+    def closest_polyline_ind(self, query_points: np.ndarray) -> np.ndarray:
+        """Find the index of the closest polyline(s) in self.polylines."""
+        _, data_ind = self.kdtree.query(query_points, k=1)
+        return self.polyline_inds[data_ind]
+
+    def polyline_inds_in_range(self, point: np.ndarray, range: float) -> np.ndarray:
+        """Find the index of polylines in self.polylines within 'range' distance to 'point'."""
+        data_inds = self.kdtree.query_ball_point(point, range)
+        return np.unique(self.polyline_inds[data_inds], axis=0)
+
+
+class LaneCenterKDTree(MapElementKDTree):
+    """KDTree for lane center polylines."""
+
+    def __init__(
+        self, vector_map: VectorMap, max_segment_len: Optional[float] = None
+    ) -> None:
+        """
+        Args:
+            vector_map: the VectorMap object to build the KDTree for
+            max_segment_len (float, optional): if specified, we will insert extra points into the KDTree
+                such that all polyline segments are shorter than max_segment_len.
+        """
+        self.max_segment_len = max_segment_len
+        super().__init__(vector_map)
+
+    def _extract_points_and_metadata(
+        self, map_element: MapElement
+    ) -> Optional[Tuple[np.ndarray, Dict[str, np.ndarray]]]:
+        if map_element.elem_type == MapElementType.ROAD_LANE:
+            pts: Polyline = map_element.center
+            if self.max_segment_len is not None:
+                pts = pts.interpolate(max_dist=self.max_segment_len)
+
+            # We only want to store xyz in the kdtree, not heading.
+            return pts.xyz, {"heading": pts.h}
+        else:
+            return None
+
+    def current_lane_inds(
+        self,
+        xyzh: np.ndarray,
+        distance_threshold: float,
+        heading_threshold: float,
+        sorted: bool = True,
+        dist_weight: float = 1.0,
+        heading_weight: float = 0.1,
+    ) -> np.ndarray:
+        """
+        Args:
+            xyzh (np.ndarray): [...,d]: (batch of) position and heading in world frame
+            distance_threshold (float): maximum distance from the query position
+                for a lane center point to be considered.
+            heading_threshold (float): maximum absolute heading error in radians,
+                e.g., np.pi / 8.
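+            sorted (bool, optional): if True, order the results by a combined
+                distance/heading cost. Defaults to True.
+            dist_weight (float, optional): weight of the positional error in the
+                combined cost. Defaults to 1.0.
+            heading_weight (float, optional): weight of the heading error in the
+                combined cost. Defaults to 0.1.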
+
+        Returns:
+            np.ndarray: indices of polylines that could be considered the current
+                lane for the provided position and heading, ordered by combined
+                distance/heading cost when `sorted` is True.
+        """
+        query_point = xyzh[..., :3]  # query on xyz
+        heading = xyzh[..., 3]
+        data_inds = np.array(
+            self.kdtree.query_ball_point(query_point, distance_threshold)
+        )
+
+        if len(data_inds) == 0:
+            return []
+        possible_points = self.kdtree.data[data_inds]
+        possible_headings = self.metadata["heading"][data_inds]
+
+        heading_errs = np.abs(angle_wrap(heading - possible_headings))
+        dist_errs = np.linalg.norm(
+            query_point[None, :] - possible_points, ord=2, axis=-1
+        )
+
+        under_thresh = heading_errs < heading_threshold
+        lane_inds = self.polyline_inds[data_inds[under_thresh]]
+
+        # we don't want to return duplicates of lanes
+        unique_lane_inds = np.unique(lane_inds)
+
+        if not sorted:
+            return unique_lane_inds
+
+        # if we are sorting results, evaluate cost:
+        costs = (
+            dist_weight * dist_errs[under_thresh]
+            + heading_weight * heading_errs[under_thresh]
+        )
+
+        # cost for a lane is minimum over all possible points for that lane
+        min_costs = [np.min(costs[lane_inds == ind]) for ind in unique_lane_inds]
+
+        return unique_lane_inds[np.argsort(min_costs)]
diff --git a/src/trajdata/maps/map_patch.py b/src/trajdata/maps/map_patch.py
deleted file mode 100644
index 4416ae8..0000000
--- a/src/trajdata/maps/map_patch.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import numpy as np
-
-
-class RasterizedMapPatch:
-    def __init__(
-        self,
-        data: np.ndarray,
-        rot_angle: float,
-        crop_size: int,
-        resolution: float,
-        raster_from_world_tf: np.ndarray,
-        has_data: bool,
-    ) -> None:
-        self.data = data
-        self.rot_angle = rot_angle
-        self.crop_size = crop_size
-        self.resolution = resolution
-        self.raster_from_world_tf = raster_from_world_tf
-        self.has_data = has_data
diff --git a/src/trajdata/maps/map_strtree.py b/src/trajdata/maps/map_strtree.py
new file mode 100644
index 0000000..e083fff
--- /dev/null
+++ b/src/trajdata/maps/map_strtree.py
@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
+
+if TYPE_CHECKING:
+    from trajdata.maps.vec_map import VectorMap
+
+import numpy as np
+from shapely import LinearRing, Polygon, STRtree, linearrings, points, polygons
+from tqdm import tqdm
+
+from trajdata.maps.vec_map_elements import (
+    MapElement,
+    MapElementType,
+    PedCrosswalk,
+    PedWalkway,
+    RoadArea,
+)
+
+
+def polygon_with_holes_geometry(map_element: MapElement) -> Polygon:
+    assert isinstance(map_element, RoadArea)
+    exterior = linearrings(map_element.exterior_polygon.xy)
+    holes: Optional[List[LinearRing]] = None
+    if len(map_element.interior_holes) > 0:
+        holes = [linearrings(hole.xy) for hole in map_element.interior_holes]
+
+    return polygons(exterior, holes=holes)
+
+
+def polygon_geometry(map_element: MapElement) -> Polygon:
+    assert isinstance(map_element, (PedWalkway, PedCrosswalk))
+    return polygons(map_element.polygon.xy)
+
+
+# Dictionary mapping map_elem_type to function returning
+# shapely polygon for that map element
+MAP_ELEM_TO_GEOMETRY: Dict[MapElementType, Callable[[MapElement], Polygon]] = {
+    MapElementType.ROAD_AREA: polygon_with_holes_geometry,
+    MapElementType.PED_CROSSWALK: polygon_geometry,
+    MapElementType.PED_WALKWAY: polygon_geometry,
+}
+
+
+class MapElementSTRTree:
+    """
+    Constructs an STRtree of polygonal MapElements and exposes fast lookup functions.
+
+    The geometry stored for each supported MapElement type is produced by the
+    MAP_ELEM_TO_GEOMETRY dispatch table above.
+    """
+
+    def __init__(
+        self,
+        vector_map: VectorMap,
+        elem_type: MapElementType,
+        verbose: bool = False,
+    ) -> None:
+        # Build R-tree
+        self.strtree, self.elem_ids = self._build_strtree(
+            vector_map, elem_type, verbose
+        )
+
+    def _build_strtree(
+        self,
+        vector_map: VectorMap,
+        elem_type: MapElementType,
+        verbose: bool = False,
+    ) -> Tuple[STRtree, np.ndarray]:
+        geometries: List[Polygon] = []
+        ids: List[str] = []
+        geometry_fn = MAP_ELEM_TO_GEOMETRY[elem_type]
+
+        map_elem: MapElement
+        for id, map_elem in tqdm(
+            vector_map.elements[elem_type].items(),
+            desc=f"Building STR Tree for {elem_type.name} elements",
+            leave=False,
+            disable=not verbose,
+        ):
+            ids.append(id)
+            geometries.append(geometry_fn(map_elem))
+
+        return STRtree(geometries), np.array(ids)
+
+    def query_point(
+        self,
+        point: np.ndarray,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Returns the IDs of all elements of this tree's element type
+        that intersect with the query point.
+
+        Args:
+            point (np.ndarray): point to query
+            kwargs: passed on to STRtree.query(), see
+                https://pygeos.readthedocs.io/en/latest/strtree.html
+                Can be used for predicate based queries, e.g.
+                predicate="dwithin", distance=100.
+                returns all elements which are within 100m of query point
+
+        Returns:
+            np.ndarray[str]: 1d array of ids of all elements matching query
+        """
+        indices = self.strtree.query(points(point), **kwargs)
+        return self.elem_ids[indices]
+
+    def nearest_area(
+        self,
+        point: np.ndarray,
+        **kwargs,
+    ) -> str:
+        """
+        Returns the ID of the element of this tree's element type
+        that is closest to the query point.
+
+        Args:
+            point (np.ndarray): point to query
+            kwargs: passed on to STRtree.nearest(), see
+                https://pygeos.readthedocs.io/en/latest/strtree.html
+
+        Returns:
+            str: element_id of the closest element
+        """
+        idx = self.strtree.nearest(points(point), **kwargs)
+        return self.elem_ids[idx]
diff --git a/src/trajdata/maps/map_utils.py b/src/trajdata/maps/map_utils.py
deleted file mode 100644
index b40e14d..0000000
--- a/src/trajdata/maps/map_utils.py
+++ /dev/null
@@ -1,291 +0,0 @@
-from math import ceil
-from typing import Any, Dict, Final, List, Optional, Tuple
-
-import cv2
-import numpy as np
-from tqdm import tqdm
-
-from trajdata.proto.vectorized_map_pb2 import (
-    MapElement,
-    Polyline,
-    RoadLane,
-    VectorizedMap,
-)
-
-# Sub-pixel drawing precision constants.
-# See https://github.com/woven-planet/l5kit/blob/master/l5kit/l5kit/rasterization/semantic_rasterizer.py#L16
-CV2_SUB_VALUES = {"shift": 9, "lineType": cv2.LINE_AA}
-CV2_SHIFT_VALUE = 2 ** CV2_SUB_VALUES["shift"]
-
-MM_PER_M: Final[float] = 1000
-
-
-def cv2_subpixel(coords: np.ndarray) -> np.ndarray:
-    """
-    Cast coordinates to numpy.int but keep fractional part by previously multiplying by 2**CV2_SHIFT
-    cv2 calls will use shift to restore original values with higher precision
-
-    Args:
-        coords (np.ndarray): XY coords as float
-
-    Returns:
-        np.ndarray: XY coords as int for cv2 shift draw
-    """
-    return (coords * CV2_SHIFT_VALUE).astype(np.int)
-
-
-def decompress_values(data: np.ndarray) -> np.ndarray:
-    # From https://github.com/woven-planet/l5kit/blob/master/l5kit/l5kit/data/proto/road_network.proto#L446
-    # The delta for the first point is just its coordinates tuple, i.e. it is a "delta" from
-    # the origin. For subsequent points, this field stores the difference between the point's
-    # coordinates and the previous point's coordinates. This is for representation efficiency.
-    return np.cumsum(data, axis=0, dtype=np.float) / MM_PER_M
-
-
-def compress_values(data: np.ndarray) -> np.ndarray:
-    return (np.diff(data, axis=0, prepend=0.0) * MM_PER_M).astype(np.int32)
-
-
-def populate_lane_polylines(
-    new_lane: RoadLane,
-    midlane_pts: np.ndarray,
-    left_pts: np.ndarray,
-    right_pts: np.ndarray,
-) -> None:
-    """Fill a Lane object's polyline attributes.
-    All points should be in world coordinates.
-
-    Args:
-        new_lane (Lane): _description_
-        midlane_pts (np.ndarray): _description_
-        left_pts (np.ndarray): _description_
-        right_pts (np.ndarray): _description_
-    """
-    compressed_mid_pts: np.ndarray = compress_values(midlane_pts)
-    compressed_left_pts: np.ndarray = compress_values(left_pts)
-    compressed_right_pts: np.ndarray = compress_values(right_pts)
-
-    new_lane.center.dx_mm.extend(compressed_mid_pts[:, 0].tolist())
-    new_lane.center.dy_mm.extend(compressed_mid_pts[:, 1].tolist())
-
-    new_lane.left_boundary.dx_mm.extend(compressed_left_pts[:, 0].tolist())
-    new_lane.left_boundary.dy_mm.extend(compressed_left_pts[:, 1].tolist())
-
-    new_lane.right_boundary.dx_mm.extend(compressed_right_pts[:, 0].tolist())
-    new_lane.right_boundary.dy_mm.extend(compressed_right_pts[:, 1].tolist())
-
-    if compressed_mid_pts.shape[-1] == 3:
-        new_lane.center.dz_mm.extend(compressed_mid_pts[:, 2].tolist())
-        new_lane.left_boundary.dz_mm.extend(compressed_left_pts[:, 2].tolist())
-        new_lane.right_boundary.dz_mm.extend(compressed_right_pts[:, 2].tolist())
-
-
-def populate_polygon(
-    polygon: Polyline,
-    polygon_pts: np.ndarray,
-) -> None:
-    """Fill a Crosswalk object's polygon attribute.
-    All points should be in world coordinates.
-
-    Args:
-        new_crosswalk (Lane): _description_
-        polygon_pts (np.ndarray): _description_
-    """
-
-    compressed_pts: np.ndarray = compress_values(polygon_pts)
-
-    polygon.dx_mm.extend(compressed_pts[:, 0].tolist())
-    polygon.dy_mm.extend(compressed_pts[:, 1].tolist())
-
-    if compressed_pts.shape[-1] == 3:
-        polygon.dz_mm.extend(compressed_pts[:, 2].tolist())
-
-
-def proto_to_np(polyline: Polyline) -> np.ndarray:
-    dx: np.ndarray = np.asarray(polyline.dx_mm)
-    dy: np.ndarray = np.asarray(polyline.dy_mm)
-
-    if len(polyline.dz_mm) > 0:
-        dz: np.ndarray = np.asarray(polyline.dz_mm)
-        pts: np.ndarray = np.stack([dx, dy, dz], axis=1)
-    else:
-        pts: np.ndarray = np.stack([dx, dy], axis=1)
-
-    return decompress_values(pts)
-
-
-def transform_points(points: np.ndarray, transf_mat: np.ndarray):
-    n_dim = points.shape[-1]
-    return points @ transf_mat[:n_dim, :n_dim] + transf_mat[:n_dim, -1]
-
-
-def interpolate(pts: np.ndarray, num_pts: int) -> np.ndarray:
-    """
-    Interpolate points based on cumulative distances from the first one. In particular,
-    interpolate using a variable step such that we always get step values.
-
-    Args:
-        xyz (np.ndarray): XYZ coords.
-        num_pts (int): How many points to interpolate to.
-
-    Returns:
-        np.ndarray: The new interpolated coordinates.
- """ - cum_dist = np.cumsum(np.linalg.norm(np.diff(pts, axis=0), axis=-1)) - cum_dist = np.insert(cum_dist, 0, 0) - - assert num_pts > 1, f"num_pts must be at least 2, but got {num_pts}" - steps = np.linspace(cum_dist[0], cum_dist[-1], num_pts) - - xyz_inter = np.empty((len(steps), pts.shape[-1]), dtype=pts.dtype) - xyz_inter[:, 0] = np.interp(steps, xp=cum_dist, fp=pts[:, 0]) - xyz_inter[:, 1] = np.interp(steps, xp=cum_dist, fp=pts[:, 1]) - if pts.shape[-1] == 3: - xyz_inter[:, 2] = np.interp(steps, xp=cum_dist, fp=pts[:, 2]) - - return xyz_inter - - -def rasterize_map( - vec_map: VectorizedMap, resolution: float, **pbar_kwargs -) -> np.ndarray: - """Renders the semantic map at the given resolution. - - Args: - vec_map (VectorizedMap): _description_ - resolution (float): The rasterized image's resolution in pixels per meter. - - Returns: - np.ndarray: The rasterized RGB image. - """ - world_center_m: Tuple[float, float] = ( - (vec_map.max_pt.x + vec_map.min_pt.x) / 2, - (vec_map.max_pt.y + vec_map.min_pt.y) / 2, - ) - - raster_size_x: int = ceil((vec_map.max_pt.x - vec_map.min_pt.x) * resolution) - raster_size_y: int = ceil((vec_map.max_pt.y - vec_map.min_pt.y) * resolution) - - raster_from_local: np.ndarray = np.array( - [ - [resolution, 0, raster_size_x / 2], - [0, resolution, raster_size_y / 2], - [0, 0, 1], - ] - ) - - # Compute pose from its position and rotation - pose_from_world: np.ndarray = np.array( - [ - [1, 0, -world_center_m[0]], - [0, 1, -world_center_m[1]], - [0, 0, 1], - ] - ) - - raster_from_world: np.ndarray = raster_from_local @ pose_from_world - - lane_area_img: np.ndarray = np.zeros( - shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8 - ) - lane_line_img: np.ndarray = np.zeros( - shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8 - ) - ped_area_img: np.ndarray = np.zeros( - shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8 - ) - - map_elem: MapElement - for map_elem in tqdm( - vec_map.elements, - desc=f"Rasterizing Map at {resolution:.2f} px/m", - **pbar_kwargs, - ): - if map_elem.HasField("road_lane"): - left_pts: np.ndarray = proto_to_np(map_elem.road_lane.left_boundary) - right_pts: np.ndarray = proto_to_np(map_elem.road_lane.right_boundary) - - lane_area: np.ndarray = cv2_subpixel( - transform_points( - np.concatenate([left_pts[:, :2], right_pts[::-1, :2]], axis=0), - raster_from_world, - ) - ) - - # Need to for-loop because doing it all at once can make holes. - cv2.fillPoly( - img=lane_area_img, - pts=[lane_area], - color=(255, 0, 0), - **CV2_SUB_VALUES, - ) - - # Drawing lane lines. - cv2.polylines( - img=lane_line_img, - pts=lane_area.reshape((2, -1, 2)), - isClosed=False, - color=(0, 255, 0), - **CV2_SUB_VALUES, - ) - - elif map_elem.HasField("road_area"): - xyz_pts: np.ndarray = proto_to_np(map_elem.road_area.exterior_polygon) - road_area: np.ndarray = cv2_subpixel( - transform_points(xyz_pts[:, :2], raster_from_world) - ) - - # Drawing general road areas. - cv2.fillPoly( - img=lane_area_img, - pts=[road_area], - color=(255, 0, 0), - **CV2_SUB_VALUES, - ) - - for interior_hole in map_elem.road_area.interior_holes: - xyz_pts: np.ndarray = proto_to_np(interior_hole) - road_area: np.ndarray = cv2_subpixel( - transform_points(xyz_pts[:, :2], raster_from_world) - ) - - # Removing holes. 
-                cv2.fillPoly(
-                    img=lane_area_img,
-                    pts=[road_area],
-                    color=(0, 0, 0),
-                    **CV2_SUB_VALUES,
-                )
-
-        elif map_elem.HasField("ped_crosswalk"):
-            xyz_pts: np.ndarray = proto_to_np(map_elem.ped_crosswalk.polygon)
-            crosswalk_area: np.ndarray = cv2_subpixel(
-                transform_points(xyz_pts[:, :2], raster_from_world)
-            )
-
-            # Drawing crosswalks.
-            cv2.fillPoly(
-                img=ped_area_img,
-                pts=[crosswalk_area],
-                color=(0, 0, 255),
-                **CV2_SUB_VALUES,
-            )
-
-        elif map_elem.HasField("ped_walkway"):
-            xyz_pts: np.ndarray = proto_to_np(map_elem.ped_walkway.polygon)
-            walkway_area: np.ndarray = cv2_subpixel(
-                transform_points(xyz_pts[:, :2], raster_from_world)
-            )
-
-            # Drawing walkways.
-            cv2.fillPoly(
-                img=ped_area_img,
-                pts=[walkway_area],
-                color=(0, 0, 255),
-                **CV2_SUB_VALUES,
-            )
-
-    map_img: np.ndarray = (lane_area_img + lane_line_img + ped_area_img).astype(
-        np.float32
-    ) / 255
-    return map_img.transpose(2, 0, 1), raster_from_world
diff --git a/src/trajdata/maps/map.py b/src/trajdata/maps/raster_map.py
similarity index 71%
rename from src/trajdata/maps/map.py
rename to src/trajdata/maps/raster_map.py
index ad0ee07..4309896 100644
--- a/src/trajdata/maps/map.py
+++ b/src/trajdata/maps/raster_map.py
@@ -9,14 +9,14 @@ class RasterizedMapMetadata:
     def __init__(
         self,
         name: str,
-        shape: Tuple[int, int],
+        shape: Tuple[int, int, int],
         layers: List[str],
         layer_rgb_groups: Tuple[List[int], List[int], List[int]],
         resolution: float,  # px/m
         map_from_world: np.ndarray,  # Transformation from world coordinates [m] to map coordinates [px]
     ) -> None:
         self.name: str = name
-        self.shape: Tuple[int, int] = shape
+        self.shape: Tuple[int, int, int] = shape
         self.layers: List[str] = layers
         self.layer_rgb_groups: Tuple[List[int], List[int], List[int]] = layer_rgb_groups
         self.resolution: float = resolution
@@ -34,7 +34,7 @@ def __init__(
         self.data: np.ndarray = data
 
     @property
-    def shape(self) -> Tuple[int, ...]:
+    def shape(self) -> Tuple[int, int, int]:
         return self.data.shape
 
     @staticmethod
@@ -53,3 +53,21 @@ def to_img(
             ],
             dim=-1,
         ).numpy()
+
+
+class RasterizedMapPatch:
+    def __init__(
+        self,
+        data: np.ndarray,
+        rot_angle: float,
+        crop_size: int,
+        resolution: float,
+        raster_from_world_tf: np.ndarray,
+        has_data: bool,
+    ) -> None:
+        self.data = data
+        self.rot_angle = rot_angle
+        self.crop_size = crop_size
+        self.resolution = resolution
+        self.raster_from_world_tf = raster_from_world_tf
+        self.has_data = has_data
diff --git a/src/trajdata/maps/traffic_light_status.py b/src/trajdata/maps/traffic_light_status.py
new file mode 100644
index 0000000..260fe46
--- /dev/null
+++ b/src/trajdata/maps/traffic_light_status.py
@@ -0,0 +1,8 @@
+from enum import IntEnum
+
+
+class TrafficLightStatus(IntEnum):
+    NO_DATA = -1
+    UNKNOWN = 0
+    GREEN = 1
+    RED = 2
diff --git a/src/trajdata/maps/vec_map.py b/src/trajdata/maps/vec_map.py
new file mode 100644
index 0000000..6b720cc
--- /dev/null
+++ b/src/trajdata/maps/vec_map.py
@@ -0,0 +1,648 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from trajdata.maps.map_kdtree import MapElementKDTree, LaneCenterKDTree
+    from trajdata.maps.map_strtree import MapElementSTRTree
+
+from collections import defaultdict
+from dataclasses import dataclass, field
+from math import ceil
+from typing import (
+    DefaultDict,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+    overload,
+)
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.axes import Axes
+from tqdm import tqdm
+
+import trajdata.proto.vectorized_map_pb2 as map_proto
+from trajdata.maps.map_kdtree import LaneCenterKDTree
+from trajdata.maps.map_strtree import MapElementSTRTree
+from trajdata.maps.traffic_light_status import TrafficLightStatus
+from trajdata.maps.vec_map_elements import (
+    MapElement,
+    MapElementType,
+    PedCrosswalk,
+    PedWalkway,
+    Polyline,
+    RoadArea,
+    RoadLane,
+)
+from trajdata.utils import map_utils, raster_utils
+
+
+@dataclass(repr=False)
+class VectorMap:
+    map_id: str
+    extent: Optional[
+        np.ndarray
+    ] = None  # extent is [min_x, min_y, min_z, max_x, max_y, max_z]
+    elements: DefaultDict[MapElementType, Dict[str, MapElement]] = field(
+        default_factory=lambda: defaultdict(dict)
+    )
+    search_kdtrees: Optional[Dict[MapElementType, MapElementKDTree]] = None
+    search_rtrees: Optional[Dict[MapElementType, MapElementSTRTree]] = None
+    traffic_light_status: Optional[Dict[Tuple[str, int], TrafficLightStatus]] = None
+
+    def __post_init__(self) -> None:
+        self.env_name, self.map_name = self.map_id.split(":")
+
+        self.lanes: Optional[List[RoadLane]] = None
+        if MapElementType.ROAD_LANE in self.elements:
+            self.lanes = list(self.elements[MapElementType.ROAD_LANE].values())
+
+    def add_map_element(self, map_elem: MapElement) -> None:
+        self.elements[map_elem.elem_type][map_elem.id] = map_elem
+
+    def compute_search_indices(self) -> None:
+        # TODO(bivanovic@nvidia.com): merge tree dicts?
+        self.search_kdtrees = {MapElementType.ROAD_LANE: LaneCenterKDTree(self)}
+        self.search_rtrees = {
+            elem_type: MapElementSTRTree(self, elem_type)
+            for elem_type in [
+                MapElementType.ROAD_AREA,
+                MapElementType.PED_CROSSWALK,
+                MapElementType.PED_WALKWAY,
+            ]
+        }
+
+    def iter_elems(self) -> Iterator[MapElement]:
+        for elems_dict in self.elements.values():
+            for elem in elems_dict.values():
+                yield elem
+
+    def get_road_lane(self, lane_id: str) -> RoadLane:
+        return self.elements[MapElementType.ROAD_LANE][lane_id]
+
+    def __len__(self) -> int:
+        return sum(len(elems_dict) for elems_dict in self.elements.values())
+
+    def _write_road_lanes(
+        self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray
+    ) -> None:
+        road_lane: RoadLane
+        for elem_id, road_lane in self.elements[MapElementType.ROAD_LANE].items():
+            new_element: map_proto.MapElement = vectorized_map.elements.add()
+            new_element.id = elem_id.encode()
+
+            new_lane: map_proto.RoadLane = new_element.road_lane
+            map_utils.populate_lane_polylines(new_lane, road_lane, shifted_origin)
+
+            new_lane.entry_lanes.extend(
+                [lane_id.encode() for lane_id in road_lane.prev_lanes]
+            )
+            new_lane.exit_lanes.extend(
+                [lane_id.encode() for lane_id in road_lane.next_lanes]
+            )
+
+            new_lane.adjacent_lanes_left.extend(
+                [lane_id.encode() for lane_id in road_lane.adj_lanes_left]
+            )
+            new_lane.adjacent_lanes_right.extend(
+                [lane_id.encode() for lane_id in road_lane.adj_lanes_right]
+            )
+
+    def _write_road_areas(
+        self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray
+    ) -> None:
+        road_area: RoadArea
+        for elem_id, road_area in self.elements[MapElementType.ROAD_AREA].items():
+            new_element: map_proto.MapElement = vectorized_map.elements.add()
+            new_element.id = elem_id.encode()
+
+            new_area: map_proto.RoadArea = new_element.road_area
+            map_utils.populate_polygon(
+                new_area.exterior_polygon,
+                road_area.exterior_polygon.xyz,
+                shifted_origin,
+            )
+
+            hole: Polyline
+            for hole in road_area.interior_holes:
+                new_hole: map_proto.Polyline = new_area.interior_holes.add()
+                map_utils.populate_polygon(
+                    new_hole,
+                    hole.xyz,
+                    shifted_origin,
+                )
+
+    def _write_ped_crosswalks(
+        self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray
+    ) -> None:
+        ped_crosswalk: PedCrosswalk
+        for elem_id, ped_crosswalk in self.elements[
+            MapElementType.PED_CROSSWALK
+        ].items():
+            new_element: map_proto.MapElement = vectorized_map.elements.add()
+            new_element.id = elem_id.encode()
+
+            new_crosswalk: map_proto.PedCrosswalk = new_element.ped_crosswalk
+            map_utils.populate_polygon(
+                new_crosswalk.polygon,
+                ped_crosswalk.polygon.xyz,
+                shifted_origin,
+            )
+
+    def _write_ped_walkways(
+        self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray
+    ) -> None:
+        ped_walkway: PedWalkway
+        for elem_id, ped_walkway in self.elements[MapElementType.PED_WALKWAY].items():
+            new_element: map_proto.MapElement = vectorized_map.elements.add()
+            new_element.id = elem_id.encode()
+
+            new_walkway: map_proto.PedWalkway = new_element.ped_walkway
+            map_utils.populate_polygon(
+                new_walkway.polygon,
+                ped_walkway.polygon.xyz,
+                shifted_origin,
+            )
+
+    def to_proto(self) -> map_proto.VectorizedMap:
+        output_map = map_proto.VectorizedMap()
+        output_map.name = self.map_id
+
+        (
+            output_map.min_pt.x,
+            output_map.min_pt.y,
+            output_map.min_pt.z,
+            output_map.max_pt.x,
+            output_map.max_pt.y,
+            output_map.max_pt.z,
+        ) = self.extent
+
+        shifted_origin: np.ndarray = self.extent[:3]
+        (
+            output_map.shifted_origin.x,
+            output_map.shifted_origin.y,
+            output_map.shifted_origin.z,
+        ) = shifted_origin
+
+        # Populating the elements in the vectorized map protobuf.
+        self._write_road_lanes(output_map, shifted_origin)
+        self._write_road_areas(output_map, shifted_origin)
+        self._write_ped_crosswalks(output_map, shifted_origin)
+        self._write_ped_walkways(output_map, shifted_origin)
+
+        return output_map
+
+    @classmethod
+    def from_proto(cls, vec_map: map_proto.VectorizedMap, **kwargs):
+        # Options for which map elements to include.
+        incl_road_lanes: bool = kwargs.get("incl_road_lanes", True)
+        incl_road_areas: bool = kwargs.get("incl_road_areas", False)
+        incl_ped_crosswalks: bool = kwargs.get("incl_ped_crosswalks", False)
+        incl_ped_walkways: bool = kwargs.get("incl_ped_walkways", False)
+
+        # Add any map offset in case the map origin was shifted for storage efficiency.
+        shifted_origin: np.ndarray = np.array(
+            [
+                vec_map.shifted_origin.x,
+                vec_map.shifted_origin.y,
+                vec_map.shifted_origin.z,
+                0.0,  # Some polylines also have heading so we're adding
+                # this (zero) coordinate to account for that.
+            ]
+        )
+
+        map_elem_dict: Dict[str, Dict[str, MapElement]] = defaultdict(dict)
+
+        map_elem: MapElement
+        for map_elem in vec_map.elements:
+            elem_id: str = map_elem.id.decode()
+            if incl_road_lanes and map_elem.HasField("road_lane"):
+                road_lane_obj: map_proto.RoadLane = map_elem.road_lane
+
+                center_pl: Polyline = Polyline(
+                    map_utils.proto_to_np(road_lane_obj.center) + shifted_origin
+                )
+
+                # We do not care for the heading of the left and right edges
+                # (only the center matters).
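+                # Hence, the boundary polylines below are shifted only by the
+                # xyz components of shifted_origin.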
+                left_pl: Optional[Polyline] = None
+                if road_lane_obj.HasField("left_boundary"):
+                    left_pl = Polyline(
+                        map_utils.proto_to_np(
+                            road_lane_obj.left_boundary, incl_heading=False
+                        )
+                        + shifted_origin[:3]
+                    )
+
+                right_pl: Optional[Polyline] = None
+                if road_lane_obj.HasField("right_boundary"):
+                    right_pl = Polyline(
+                        map_utils.proto_to_np(
+                            road_lane_obj.right_boundary, incl_heading=False
+                        )
+                        + shifted_origin[:3]
+                    )
+
+                adj_lanes_left: Set[str] = set(
+                    [iden.decode() for iden in road_lane_obj.adjacent_lanes_left]
+                )
+                adj_lanes_right: Set[str] = set(
+                    [iden.decode() for iden in road_lane_obj.adjacent_lanes_right]
+                )
+
+                next_lanes: Set[str] = set(
+                    [iden.decode() for iden in road_lane_obj.exit_lanes]
+                )
+                prev_lanes: Set[str] = set(
+                    [iden.decode() for iden in road_lane_obj.entry_lanes]
+                )
+
+                # Double-using the connectivity attributes for lane IDs now (will
+                # replace them with Lane objects after all Lane objects have been created).
+                curr_lane = RoadLane(
+                    elem_id,
+                    center_pl,
+                    left_pl,
+                    right_pl,
+                    adj_lanes_left,
+                    adj_lanes_right,
+                    next_lanes,
+                    prev_lanes,
+                )
+                map_elem_dict[MapElementType.ROAD_LANE][elem_id] = curr_lane
+
+            elif incl_road_areas and map_elem.HasField("road_area"):
+                road_area_obj: map_proto.RoadArea = map_elem.road_area
+
+                exterior: Polyline = Polyline(
+                    map_utils.proto_to_np(
+                        road_area_obj.exterior_polygon, incl_heading=False
+                    )
+                    + shifted_origin[:3]
+                )
+
+                interior_holes: List[Polyline] = list()
+                interior_hole: map_proto.Polyline
+                for interior_hole in road_area_obj.interior_holes:
+                    interior_holes.append(
+                        Polyline(
+                            map_utils.proto_to_np(interior_hole, incl_heading=False)
+                            + shifted_origin[:3]
+                        )
+                    )
+
+                curr_area = RoadArea(elem_id, exterior, interior_holes)
+                map_elem_dict[MapElementType.ROAD_AREA][elem_id] = curr_area
+
+            elif incl_ped_crosswalks and map_elem.HasField("ped_crosswalk"):
+                ped_crosswalk_obj: map_proto.PedCrosswalk = map_elem.ped_crosswalk
+
+                polygon_vertices: Polyline = Polyline(
+                    map_utils.proto_to_np(ped_crosswalk_obj.polygon, incl_heading=False)
+                    + shifted_origin[:3]
+                )
+
+                curr_area = PedCrosswalk(elem_id, polygon_vertices)
+                map_elem_dict[MapElementType.PED_CROSSWALK][elem_id] = curr_area
+
+            elif incl_ped_walkways and map_elem.HasField("ped_walkway"):
+                ped_walkway_obj: map_proto.PedWalkway = map_elem.ped_walkway
+
+                polygon_vertices: Polyline = Polyline(
+                    map_utils.proto_to_np(ped_walkway_obj.polygon, incl_heading=False)
+                    + shifted_origin[:3]
+                )
+
+                curr_area = PedWalkway(elem_id, polygon_vertices)
+                map_elem_dict[MapElementType.PED_WALKWAY][elem_id] = curr_area
+
+        return cls(
+            map_id=vec_map.name,
+            extent=np.array(
+                [
+                    vec_map.min_pt.x,
+                    vec_map.min_pt.y,
+                    vec_map.min_pt.z,
+                    vec_map.max_pt.x,
+                    vec_map.max_pt.y,
+                    vec_map.max_pt.z,
+                ]
+            ),
+            elements=map_elem_dict,
+            search_kdtrees=None,
+            search_rtrees=None,
+            traffic_light_status=None,
+        )
+
+    def associate_scene_data(
+        self, traffic_light_status_dict: Dict[Tuple[str, int], TrafficLightStatus]
+    ) -> None:
+        """Associates vector map with scene-specific data like traffic light information"""
+        self.traffic_light_status = traffic_light_status_dict
+
+    def get_current_lane(
+        self,
+        xyzh: np.ndarray,
+        max_dist: float = 2.0,
+        max_heading_error: float = np.pi / 8,
+    ) -> List[RoadLane]:
+        """
+        Args:
+            xyzh (np.ndarray): 3d position and heading of agent in world coordinates
+
+        Returns:
+            List[RoadLane]: List of possible road lanes that agent could be on
+        """
+        assert (
+            self.search_kdtrees is not None
+        ), "Search KDTree not found, please rebuild cache."
+        lane_kdtree: LaneCenterKDTree = self.search_kdtrees[MapElementType.ROAD_LANE]
+        return [
+            self.lanes[idx]
+            for idx in lane_kdtree.current_lane_inds(xyzh, max_dist, max_heading_error)
+        ]
+
+    def get_closest_lane(self, xyz: np.ndarray) -> RoadLane:
+        assert (
+            self.search_kdtrees is not None
+        ), "Search KDTree not found, please rebuild cache."
+        lane_kdtree: LaneCenterKDTree = self.search_kdtrees[MapElementType.ROAD_LANE]
+        return self.lanes[lane_kdtree.closest_polyline_ind(xyz)]
+
+    def get_closest_unique_lanes(self, xyz_vec: np.ndarray) -> List[RoadLane]:
+        assert (
+            self.search_kdtrees is not None
+        ), "Search KDTree not found, please rebuild cache."
+        assert xyz_vec.ndim == 2  # xyz_vec is assumed to be (*, 3)
+        lane_kdtree: LaneCenterKDTree = self.search_kdtrees[MapElementType.ROAD_LANE]
+        closest_inds = lane_kdtree.closest_polyline_ind(xyz_vec)
+        unique_inds = np.unique(closest_inds)
+        return [self.lanes[ind] for ind in unique_inds]
+
+    def get_lanes_within(self, xyz: np.ndarray, dist: float) -> List[RoadLane]:
+        assert (
+            self.search_kdtrees is not None
+        ), "Search KDTree not found, please rebuild cache."
+        lane_kdtree: LaneCenterKDTree = self.search_kdtrees[MapElementType.ROAD_LANE]
+        return [
+            self.lanes[idx] for idx in lane_kdtree.polyline_inds_in_range(xyz, dist)
+        ]
+
+    def get_closest_area(
+        self, xy: np.ndarray, elem_type: MapElementType
+    ) -> Union[RoadArea, PedCrosswalk, PedWalkway]:
+        """
+        Returns the 2D MapElement closest to the query point.
+
+        Args:
+            xy (np.ndarray): query point
+            elem_type (MapElementType): type of map element desired
+
+        Returns:
+            Union[RoadArea, PedCrosswalk, PedWalkway]: closest map elem of desired type to xy point
+        """
+        assert (
+            self.search_rtrees is not None
+        ), "Search RTree not found, please rebuild cache."
+        elem_id = self.search_rtrees[elem_type].nearest_area(xy)
+        return self.elements[elem_type][elem_id]
+
+    def get_areas_within(
+        self, xy: np.ndarray, elem_type: MapElementType, dist: float
+    ) -> List[Union[RoadArea, PedCrosswalk, PedWalkway]]:
+        """
+        Returns all 2D MapElements within dist of the query xy point.
+
+        Args:
+            xy (np.ndarray): query point
+            elem_type (MapElementType): type of map element desired
+            dist (float): distance threshold
+
+        Returns:
+            List[Union[RoadArea, PedCrosswalk, PedWalkway]]: List of areas matching query
+        """
+        assert (
+            self.search_rtrees is not None
+        ), "Search RTree not found, please rebuild cache."
+        ids = self.search_rtrees[elem_type].query_point(
+            xy, predicate="dwithin", distance=dist
+        )
+        return [self.elements[elem_type][id] for id in ids]
+
+    def get_traffic_light_status(
+        self, lane_id: str, scene_ts: int
+    ) -> TrafficLightStatus:
+        return (
+            self.traffic_light_status.get(
+                (lane_id, scene_ts), TrafficLightStatus.NO_DATA
+            )
+            if self.traffic_light_status is not None
+            else TrafficLightStatus.NO_DATA
+        )
+
+    def rasterize(
+        self, resolution: float = 2, **kwargs
+    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+        """Renders this vector map at the specified resolution.
+
+        Args:
+            resolution (float): The rasterized image's resolution in pixels per meter.
+
+        Returns:
+            np.ndarray: The rasterized RGB image.
+        """
+        return_tf_mat: bool = kwargs.get("return_tf_mat", False)
+        incl_centerlines: bool = kwargs.get("incl_centerlines", True)
+        incl_lane_edges: bool = kwargs.get("incl_lane_edges", True)
+        incl_lane_area: bool = kwargs.get("incl_lane_area", True)
+
+        scene_ts: Optional[int] = kwargs.get("scene_ts", None)
+
+        # (255, 102, 99) also looks nice.
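+        # Colors are (R, G, B) tuples; the commented alternates are drop-in
+        # replacements for the kwargs defaults below.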
+        center_color: Tuple[int, int, int] = kwargs.get("center_color", (129, 51, 255))
+        # (86, 203, 249) also looks nice.
+        edge_color: Tuple[int, int, int] = kwargs.get("edge_color", (118, 185, 0))
+        # (191, 215, 234) also looks nice.
+        area_color: Tuple[int, int, int] = kwargs.get("area_color", (214, 232, 181))
+
+        min_x, min_y, _, max_x, max_y, _ = self.extent
+
+        world_center_m: Tuple[float, float] = (
+            (max_x + min_x) / 2,
+            (max_y + min_y) / 2,
+        )
+
+        raster_size_x: int = ceil((max_x - min_x) * resolution)
+        raster_size_y: int = ceil((max_y - min_y) * resolution)
+
+        raster_from_local: np.ndarray = np.array(
+            [
+                [resolution, 0, raster_size_x / 2],
+                [0, resolution, raster_size_y / 2],
+                [0, 0, 1],
+            ]
+        )
+
+        # Compute pose from its position and rotation.
+        pose_from_world: np.ndarray = np.array(
+            [
+                [1, 0, -world_center_m[0]],
+                [0, 1, -world_center_m[1]],
+                [0, 0, 1],
+            ]
+        )
+
+        raster_from_world: np.ndarray = raster_from_local @ pose_from_world
+
+        map_img: np.ndarray = np.zeros(
+            shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8
+        )
+
+        lane_edges: List[np.ndarray] = list()
+        centerlines: List[np.ndarray] = list()
+        lane: RoadLane
+        for lane in tqdm(
+            self.elements[MapElementType.ROAD_LANE].values(),
+            desc=f"Rasterizing Map at {resolution:.2f} px/m",
+            leave=False,
+        ):
+            centerlines.append(
+                raster_utils.world_to_subpixel(
+                    lane.center.points[:, :2], raster_from_world
+                )
+            )
+            if lane.left_edge is not None and lane.right_edge is not None:
+                left_pts: np.ndarray = lane.left_edge.points[:, :2]
+                right_pts: np.ndarray = lane.right_edge.points[:, :2]
+
+                lane_edges += [
+                    raster_utils.world_to_subpixel(left_pts, raster_from_world),
+                    raster_utils.world_to_subpixel(right_pts, raster_from_world),
+                ]
+
+                lane_color = area_color
+                status = self.get_traffic_light_status(lane.id, scene_ts)
+                if status == TrafficLightStatus.GREEN:
+                    lane_color = [0, 200, 0]
+                elif status == TrafficLightStatus.RED:
+                    lane_color = [200, 0, 0]
+                elif status == TrafficLightStatus.UNKNOWN:
+                    lane_color = [150, 150, 0]
+
+                # Drawing lane areas. Need to do per loop because doing it all at once can
+                # create lots of wonky holes in the image.
+                # See https://stackoverflow.com/questions/69768620/cv2-fillpoly-failing-for-intersecting-polygons
+                if incl_lane_area:
+                    lane_area: np.ndarray = np.concatenate(
+                        [left_pts, right_pts[::-1]], axis=0
+                    )
+                    raster_utils.rasterize_world_polygon(
+                        lane_area,
+                        map_img,
+                        raster_from_world,
+                        color=lane_color,
+                    )
+
+        # Drawing all lane edge lines at the same time.
+        if incl_lane_edges:
+            raster_utils.cv2_draw_polylines(lane_edges, map_img, color=edge_color)
+
+        # Drawing centerlines last (on top of everything else).
+        if incl_centerlines:
+            raster_utils.cv2_draw_polylines(centerlines, map_img, color=center_color)
+
+        if return_tf_mat:
+            return map_img.astype(float) / 255, raster_from_world
+        else:
+            return map_img.astype(float) / 255
+
+    @overload
+    def visualize_lane_graph(
+        self,
+        origin_lane: RoadLane,
+        num_hops: int,
+        **kwargs,
+    ) -> Axes:
+        ...
+
+    @overload
+    def visualize_lane_graph(self, origin_lane: str, num_hops: int, **kwargs) -> Axes:
+        ...
+
+    @overload
+    def visualize_lane_graph(self, origin_lane: int, num_hops: int, **kwargs) -> Axes:
+        ...
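+
+    # The @overload stubs above only narrow type hints; the implementation below
+    # accepts a RoadLane object, a lane ID string, or an integer index into
+    # self.lanes. For example (with a hypothetical lane ID):
+    #     ax = vec_map.visualize_lane_graph("lane_42", num_hops=2)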
+
+    def visualize_lane_graph(
+        self, origin_lane: Union[RoadLane, str, int], num_hops: int, **kwargs
+    ) -> Axes:
+        ax = kwargs.get("ax", None)
+        if ax is None:
+            fig, ax = plt.subplots()
+
+        origin: str
+        if isinstance(origin_lane, RoadLane):
+            origin = origin_lane.id
+        elif isinstance(origin_lane, str):
+            origin = origin_lane
+        elif isinstance(origin_lane, int):
+            origin = self.lanes[origin_lane].id
+
+        viridis = mpl.colormaps[kwargs.get("cmap", "rainbow")].resampled(num_hops + 1)
+
+        already_seen: Set[str] = set()
+        lanes_to_plot: List[Tuple[str, int]] = [(origin, 0)]
+
+        if kwargs.get("legend", True):
+            ax.scatter([], [], label=f"Lane Endpoints", color="k")
+            ax.plot([], [], label=f"Origin Lane ({origin})", color=viridis(0))
+            for h in range(1, num_hops + 1):
+                ax.plot(
+                    [],
+                    [],
+                    label=f"{h} Lane{'s' if h > 1 else ''} Away",
+                    color=viridis(h),
+                )
+
+        raster_from_world = kwargs.get("raster_from_world", None)
+        while len(lanes_to_plot) > 0:
+            lane_id, curr_hops = lanes_to_plot.pop(0)
+            already_seen.add(lane_id)
+            lane: RoadLane = self.get_road_lane(lane_id)
+
+            center: np.ndarray = lane.center.points[..., :2]
+            first_pt_heading: float = lane.center.points[0, -1]
+            mdpt: np.ndarray = lane.center.midpoint[..., :2]
+
+            if raster_from_world is not None:
+                center = map_utils.transform_points(center, raster_from_world)
+                mdpt = map_utils.transform_points(mdpt[None, :], raster_from_world)[0]
+
+            ax.plot(center[:, 0], center[:, 1], color=viridis(curr_hops))
+            ax.scatter(center[[0, -1], 0], center[[0, -1], 1], color=viridis(curr_hops))
+            ax.quiver(
+                [center[0, 0]],
+                [center[0, 1]],
+                [np.cos(first_pt_heading)],
+                [np.sin(first_pt_heading)],
+                color=viridis(curr_hops),
+            )
+            ax.text(mdpt[0], mdpt[1], s=lane_id)
+
+            if curr_hops < num_hops:
+                lanes_to_plot += [
+                    (l, curr_hops + 1)
+                    for l in lane.reachable_lanes
+                    if l not in already_seen
+                ]
+
+        if kwargs.get("legend", True):
+            ax.legend(loc="best", frameon=True)
+
+        return ax
diff --git a/src/trajdata/maps/vec_map_elements.py b/src/trajdata/maps/vec_map_elements.py
new file mode 100644
index 0000000..fd7a9f1
--- /dev/null
+++ b/src/trajdata/maps/vec_map_elements.py
@@ -0,0 +1,171 @@
+from dataclasses import dataclass, field
+from enum import IntEnum
+from typing import List, Optional, Set
+
+import numpy as np
+
+from trajdata.utils import map_utils
+
+
+class MapElementType(IntEnum):
+    ROAD_LANE = 1
+    ROAD_AREA = 2
+    PED_CROSSWALK = 3
+    PED_WALKWAY = 4
+
+
+@dataclass
+class Polyline:
+    points: np.ndarray
+
+    def __post_init__(self) -> None:
+        if self.points.shape[-1] < 2:
+            raise ValueError(
+                f"Polylines are expected to have 2 (xy), 3 (xyz), or 4 (xyzh) dimensions, but received {self.points.shape[-1]}."
+            )
+
+        if self.points.shape[-1] == 2:
+            # If only xy are passed in, then append zero to the end for z.
+            self.points = np.append(
+                self.points, np.zeros_like(self.points[:, [0]]), axis=-1
+            )
+
+    @property
+    def midpoint(self) -> np.ndarray:
+        num_pts: int = self.points.shape[0]
+        return self.points[num_pts // 2]
+
+    @property
+    def has_heading(self) -> bool:
+        return self.points.shape[-1] == 4
+
+    @property
+    def xy(self) -> np.ndarray:
+        return self.points[..., :2]
+
+    @property
+    def xyz(self) -> np.ndarray:
+        return self.points[..., :3]
+
+    @property
+    def xyzh(self) -> np.ndarray:
+        if self.has_heading:
+            return self.points[..., :4]
+        else:
+            raise ValueError(
+                f"This Polyline only has {self.points.shape[-1]} coordinates, expected 4."
+            )
+ ) + + @property + def h(self) -> np.ndarray: + return self.points[..., 3] + + def interpolate( + self, num_pts: Optional[int] = None, max_dist: Optional[float] = None + ) -> "Polyline": + return Polyline( + map_utils.interpolate(self.points, num_pts=num_pts, max_dist=max_dist) + ) + + def project_onto(self, xyz_or_xyzh: np.ndarray) -> np.ndarray: + """Project the given points onto this Polyline. + + Args: + xyzh (np.ndarray): Points to project, of shape (M, D) + + Returns: + np.ndarray: The projected points, of shape (M, D) + + Note: + D = 4 if this Polyline has headings, otherwise D = 3 + """ + # xyzh is now (M, 1, 3), we do not use heading for projection. + xyz = xyz_or_xyzh[:, np.newaxis, :3] + + # p0, p1 are (1, N, 3) + p0: np.ndarray = self.points[np.newaxis, :-1, :3] + p1: np.ndarray = self.points[np.newaxis, 1:, :3] + + # 1. Compute projections of each point to each line segment in a + # batched manner. + line_seg_diffs: np.ndarray = p1 - p0 + point_seg_diffs: np.ndarray = xyz - p0 + + dot_products: np.ndarray = (point_seg_diffs * line_seg_diffs).sum( + axis=-1, keepdims=True + ) + norms: np.ndarray = np.linalg.norm(line_seg_diffs, axis=-1, keepdims=True) ** 2 + + # Clip ensures that the projected point stays within the line segment boundaries. + projs: np.ndarray = ( + p0 + np.clip(dot_products / norms, a_min=0, a_max=1) * line_seg_diffs + ) + + # 2. Find the nearest projections to the original points. + closest_proj_idxs: int = np.linalg.norm(xyz - projs, axis=-1).argmin(axis=-1) + + if self.has_heading: + # Adding in the heading of the corresponding p0 point (which makes + # sense as p0 to p1 is a line => same heading along it). + return np.concatenate( + [ + projs[range(xyz.shape[0]), closest_proj_idxs], + np.expand_dims(self.points[closest_proj_idxs, -1], axis=-1), + ], + axis=-1, + ) + else: + return projs[range(xyz.shape[0]), closest_proj_idxs] + + +@dataclass +class MapElement: + id: str + + +@dataclass +class RoadLane(MapElement): + center: Polyline + left_edge: Optional[Polyline] = None + right_edge: Optional[Polyline] = None + adj_lanes_left: Set[str] = field(default_factory=lambda: set()) + adj_lanes_right: Set[str] = field(default_factory=lambda: set()) + next_lanes: Set[str] = field(default_factory=lambda: set()) + prev_lanes: Set[str] = field(default_factory=lambda: set()) + elem_type: MapElementType = MapElementType.ROAD_LANE + + def __post_init__(self) -> None: + if not self.center.has_heading: + self.center = Polyline( + np.append( + self.center.xyz, + map_utils.get_polyline_headings(self.center.xyz), + axis=-1, + ) + ) + + def __hash__(self) -> int: + return hash(self.id) + + @property + def reachable_lanes(self) -> Set[str]: + return self.adj_lanes_left | self.adj_lanes_right | self.next_lanes + + +@dataclass +class RoadArea(MapElement): + exterior_polygon: Polyline + interior_holes: List[Polyline] = field(default_factory=lambda: list()) + elem_type: MapElementType = MapElementType.ROAD_AREA + + +@dataclass +class PedCrosswalk(MapElement): + polygon: Polyline + elem_type: MapElementType = MapElementType.PED_CROSSWALK + + +@dataclass +class PedWalkway(MapElement): + polygon: Polyline + elem_type: MapElementType = MapElementType.PED_WALKWAY diff --git a/src/trajdata/parallel/__init__.py b/src/trajdata/parallel/__init__.py index 8d98449..7f88dbd 100644 --- a/src/trajdata/parallel/__init__.py +++ b/src/trajdata/parallel/__init__.py @@ -1,2 +1 @@ from .data_preprocessor import ParallelDatasetPreprocessor, scene_paths_collate_fn -from .parallel_utils import 
parallel_apply, parallel_iapply diff --git a/src/trajdata/parallel/data_preprocessor.py b/src/trajdata/parallel/data_preprocessor.py index 2a0909f..6c61353 100644 --- a/src/trajdata/parallel/data_preprocessor.py +++ b/src/trajdata/parallel/data_preprocessor.py @@ -1,13 +1,12 @@ from pathlib import Path -from typing import Dict, List, Optional, Type +from typing import Dict, List, Optional, Tuple, Type import numpy as np from torch.utils.data import Dataset from trajdata.caching import EnvCache, SceneCache from trajdata.data_structures import Scene, SceneMetadata -from trajdata.utils import agent_utils -from trajdata.utils.env_utils import get_raw_dataset +from trajdata.utils import agent_utils, env_utils def scene_paths_collate_fn(filled_scenes: List) -> List: @@ -30,7 +29,10 @@ def __init__( self.rebuild_cache = rebuild_cache env_names: List[str] = list(envs_dir_dict.keys()) - scene_names: List[str] = [scene_info.name for scene_info in scene_info_list] + scene_idxs_names: List[Tuple[int, str]] = [ + (idx, scene_info.name) for idx, scene_info in enumerate(scene_info_list) + ] + scene_name_idxs, scene_names = zip(*scene_idxs_names) self.scene_idxs = np.array( [scene_info.raw_data_idx for scene_info in scene_info_list], dtype=int @@ -39,11 +41,8 @@ def __init__( [env_names.index(scene_info.env_name) for scene_info in scene_info_list], dtype=int, ) - self.scene_name_idxs = np.array( - [scene_names.index(scene_info.name) for scene_info in scene_info_list], - dtype=int, - ) + self.scene_name_idxs = np.array(scene_name_idxs, dtype=int) self.env_names_arr = np.array(env_names).astype(np.string_) self.scene_names_arr = np.array(scene_names).astype(np.string_) self.data_dir_arr = np.array(list(envs_dir_dict.values())).astype(np.string_) @@ -61,7 +60,7 @@ def __getitem__(self, idx: int) -> str: scene_idx: int = self.scene_name_idxs[idx] env_name: str = str(self.env_names_arr[env_idx], encoding="utf-8") - raw_dataset = get_raw_dataset( + raw_dataset = env_utils.get_raw_dataset( env_name, str(self.data_dir_arr[env_idx], encoding="utf-8") ) @@ -84,6 +83,13 @@ def __getitem__(self, idx: int) -> str: ) raw_dataset.del_dataset_obj() + if scene is None: + # This provides an escape hatch in case there's a reason we + # don't want to add a scene to the list of scenes. As an example, + # nuPlan has a scene with only a single frame of data which we + # can't do much with in terms of prediction/planning/etc. 
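+            # Downstream consumers of this Dataset should therefore be
+            # prepared to receive a None item (e.g., by dropping such
+            # entries when collating scene paths).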
+ return None + scene_path: Path = EnvCache.scene_metadata_path( env_cache.path, scene.env_name, scene.name, scene.dt ) diff --git a/src/trajdata/parallel/temp_cache.py b/src/trajdata/parallel/temp_cache.py deleted file mode 100644 index 9495e7d..0000000 --- a/src/trajdata/parallel/temp_cache.py +++ /dev/null @@ -1,45 +0,0 @@ -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import List, Optional, Union - -import dill - -from trajdata.data_structures.scene_metadata import Scene - - -class TemporaryCache: - def __init__(self, temp_dir: Optional[str] = None) -> None: - self.temp_dir: Optional[TemporaryDirectory] = None - if temp_dir is None: - self.temp_dir: TemporaryDirectory = TemporaryDirectory() - self.path: Path = Path(self.temp_dir.name) - else: - self.path: Path = Path(temp_dir) - - def cache(self, scene: Scene, ret_str: bool = False) -> Union[Path, str]: - tmp_file_path: Path = self.path / TemporaryCache.get_file_path(scene) - with open(tmp_file_path, "wb") as f: - dill.dump(scene, f) - - if ret_str: - return str(tmp_file_path) - else: - return tmp_file_path - - def cache_scenes(self, scenes: List[Scene]) -> List[str]: - paths: List[str] = list() - for scene in scenes: - tmp_file_path: Path = self.path / TemporaryCache.get_file_path(scene) - with open(tmp_file_path, "wb") as f: - dill.dump(scene, f) - - paths.append(str(tmp_file_path)) - - return paths - - def cleanup(self) -> None: - self.temp_dir.cleanup() - - @staticmethod - def get_file_path(scene_info: Scene) -> Path: - return f"{scene_info.env_name}_{scene_info.name}.dill" diff --git a/src/trajdata/proto/vectorized_map.proto b/src/trajdata/proto/vectorized_map.proto index 2a9ca7c..e1cd502 100644 --- a/src/trajdata/proto/vectorized_map.proto +++ b/src/trajdata/proto/vectorized_map.proto @@ -3,13 +3,20 @@ syntax = "proto3"; package trajdata; message VectorizedMap { + // The name of this map in the format environment_name:map_name + string name = 1; + // The full set of map elements. - repeated MapElement elements = 1; + repeated MapElement elements = 2; // The coordinates of the cuboid (in m) // containing all elements in this map. - optional Point max_pt = 2; - optional Point min_pt = 3; + Point max_pt = 3; + Point min_pt = 4; + + // The original world coordinates (in m) of the bottom-left of the map + // (account for a change in the origin for storage efficiency). + Point shifted_origin = 5; } message MapElement { @@ -26,9 +33,9 @@ message MapElement { } message Point { - optional double x = 1; - optional double y = 2; - optional double z = 3; + double x = 1; + double y = 2; + double z = 3; } message Polyline { @@ -40,6 +47,7 @@ message Polyline { repeated sint32 dx_mm = 1; repeated sint32 dy_mm = 2; repeated sint32 dz_mm = 3; + repeated double h_rad = 4; } message RoadLane { @@ -47,11 +55,11 @@ message RoadLane { // segments defined between consecutive points. Polyline center = 1; - // The polyline data for the left boundary of this lane. - Polyline left_boundary = 2; + // The polyline data for the (optional) left boundary of this lane. + optional Polyline left_boundary = 2; - // The polyline data for the right boundary of this lane. - Polyline right_boundary = 3; + // The polyline data for the (optional) right boundary of this lane. + optional Polyline right_boundary = 3; // A list of IDs for lanes that this lane may be entered from. 
repeated bytes entry_lanes = 4; diff --git a/src/trajdata/proto/vectorized_map_pb2.py b/src/trajdata/proto/vectorized_map_pb2.py index 96bb4f1..acdd9cc 100644 --- a/src/trajdata/proto/vectorized_map_pb2.py +++ b/src/trajdata/proto/vectorized_map_pb2.py @@ -14,7 +14,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x14vectorized_map.proto\x12\x08trajdata"\x99\x01\n\rVectorizedMap\x12&\n\x08\x65lements\x18\x01 \x03(\x0b\x32\x14.trajdata.MapElement\x12$\n\x06max_pt\x18\x02 \x01(\x0b\x32\x0f.trajdata.PointH\x00\x88\x01\x01\x12$\n\x06min_pt\x18\x03 \x01(\x0b\x32\x0f.trajdata.PointH\x01\x88\x01\x01\x42\t\n\x07_max_ptB\t\n\x07_min_pt"\xd8\x01\n\nMapElement\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\'\n\troad_lane\x18\x02 \x01(\x0b\x32\x12.trajdata.RoadLaneH\x00\x12\'\n\troad_area\x18\x03 \x01(\x0b\x32\x12.trajdata.RoadAreaH\x00\x12/\n\rped_crosswalk\x18\x04 \x01(\x0b\x32\x16.trajdata.PedCrosswalkH\x00\x12+\n\x0bped_walkway\x18\x05 \x01(\x0b\x32\x14.trajdata.PedWalkwayH\x00\x42\x0e\n\x0c\x65lement_data"I\n\x05Point\x12\x0e\n\x01x\x18\x01 \x01(\x01H\x00\x88\x01\x01\x12\x0e\n\x01y\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12\x0e\n\x01z\x18\x03 \x01(\x01H\x02\x88\x01\x01\x42\x04\n\x02_xB\x04\n\x02_yB\x04\n\x02_z"7\n\x08Polyline\x12\r\n\x05\x64x_mm\x18\x01 \x03(\x11\x12\r\n\x05\x64y_mm\x18\x02 \x03(\x11\x12\r\n\x05\x64z_mm\x18\x03 \x03(\x11"\xe9\x01\n\x08RoadLane\x12"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12)\n\rleft_boundary\x18\x02 \x01(\x0b\x32\x12.trajdata.Polyline\x12*\n\x0eright_boundary\x18\x03 \x01(\x0b\x32\x12.trajdata.Polyline\x12\x13\n\x0b\x65ntry_lanes\x18\x04 \x03(\x0c\x12\x12\n\nexit_lanes\x18\x05 \x03(\x0c\x12\x1b\n\x13\x61\x64jacent_lanes_left\x18\x06 \x03(\x0c\x12\x1c\n\x14\x61\x64jacent_lanes_right\x18\x07 \x03(\x0c"d\n\x08RoadArea\x12,\n\x10\x65xterior_polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12*\n\x0einterior_holes\x18\x02 \x03(\x0b\x32\x12.trajdata.Polyline"3\n\x0cPedCrosswalk\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline"1\n\nPedWalkway\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polylineb\x06proto3' + b'\n\x14vectorized_map.proto\x12\x08trajdata"\xb0\x01\n\rVectorizedMap\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x08\x65lements\x18\x02 \x03(\x0b\x32\x14.trajdata.MapElement\x12\x1f\n\x06max_pt\x18\x03 \x01(\x0b\x32\x0f.trajdata.Point\x12\x1f\n\x06min_pt\x18\x04 \x01(\x0b\x32\x0f.trajdata.Point\x12\'\n\x0eshifted_origin\x18\x05 \x01(\x0b\x32\x0f.trajdata.Point"\xd8\x01\n\nMapElement\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\'\n\troad_lane\x18\x02 \x01(\x0b\x32\x12.trajdata.RoadLaneH\x00\x12\'\n\troad_area\x18\x03 \x01(\x0b\x32\x12.trajdata.RoadAreaH\x00\x12/\n\rped_crosswalk\x18\x04 \x01(\x0b\x32\x16.trajdata.PedCrosswalkH\x00\x12+\n\x0bped_walkway\x18\x05 \x01(\x0b\x32\x14.trajdata.PedWalkwayH\x00\x42\x0e\n\x0c\x65lement_data"(\n\x05Point\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\x12\t\n\x01z\x18\x03 \x01(\x01"F\n\x08Polyline\x12\r\n\x05\x64x_mm\x18\x01 \x03(\x11\x12\r\n\x05\x64y_mm\x18\x02 \x03(\x11\x12\r\n\x05\x64z_mm\x18\x03 \x03(\x11\x12\r\n\x05h_rad\x18\x04 \x03(\x01"\x98\x02\n\x08RoadLane\x12"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12.\n\rleft_boundary\x18\x02 \x01(\x0b\x32\x12.trajdata.PolylineH\x00\x88\x01\x01\x12/\n\x0eright_boundary\x18\x03 \x01(\x0b\x32\x12.trajdata.PolylineH\x01\x88\x01\x01\x12\x13\n\x0b\x65ntry_lanes\x18\x04 \x03(\x0c\x12\x12\n\nexit_lanes\x18\x05 \x03(\x0c\x12\x1b\n\x13\x61\x64jacent_lanes_left\x18\x06 
\x03(\x0c\x12\x1c\n\x14\x61\x64jacent_lanes_right\x18\x07 \x03(\x0c\x42\x10\n\x0e_left_boundaryB\x11\n\x0f_right_boundary"d\n\x08RoadArea\x12,\n\x10\x65xterior_polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12*\n\x0einterior_holes\x18\x02 \x03(\x0b\x32\x12.trajdata.Polyline"3\n\x0cPedCrosswalk\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline"1\n\nPedWalkway\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polylineb\x06proto3' ) @@ -115,22 +115,21 @@ _sym_db.RegisterMessage(PedWalkway) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None _VECTORIZEDMAP._serialized_start = 35 - _VECTORIZEDMAP._serialized_end = 188 - _MAPELEMENT._serialized_start = 191 - _MAPELEMENT._serialized_end = 407 - _POINT._serialized_start = 409 - _POINT._serialized_end = 482 - _POLYLINE._serialized_start = 484 - _POLYLINE._serialized_end = 539 - _ROADLANE._serialized_start = 542 - _ROADLANE._serialized_end = 775 - _ROADAREA._serialized_start = 777 - _ROADAREA._serialized_end = 877 - _PEDCROSSWALK._serialized_start = 879 - _PEDCROSSWALK._serialized_end = 930 - _PEDWALKWAY._serialized_start = 932 - _PEDWALKWAY._serialized_end = 981 + _VECTORIZEDMAP._serialized_end = 211 + _MAPELEMENT._serialized_start = 214 + _MAPELEMENT._serialized_end = 430 + _POINT._serialized_start = 432 + _POINT._serialized_end = 472 + _POLYLINE._serialized_start = 474 + _POLYLINE._serialized_end = 544 + _ROADLANE._serialized_start = 547 + _ROADLANE._serialized_end = 827 + _ROADAREA._serialized_start = 829 + _ROADAREA._serialized_end = 929 + _PEDCROSSWALK._serialized_start = 931 + _PEDCROSSWALK._serialized_end = 982 + _PEDWALKWAY._serialized_start = 984 + _PEDWALKWAY._serialized_end = 1033 # @@protoc_insertion_point(module_scope) diff --git a/src/trajdata/simulation/sim_cache.py b/src/trajdata/simulation/sim_cache.py index fae625f..5abff81 100644 --- a/src/trajdata/simulation/sim_cache.py +++ b/src/trajdata/simulation/sim_cache.py @@ -3,6 +3,7 @@ import numpy as np from trajdata.caching.scene_cache import SceneCache +from trajdata.data_structures.state import StateArray from trajdata.simulation.sim_metrics import SimMetric from trajdata.simulation.sim_stats import SimStatistic @@ -11,7 +12,7 @@ class SimulationCache(SceneCache): def reset(self) -> None: raise NotImplementedError() - def append_state(self, xyh_dict: Dict[str, np.ndarray]) -> None: + def append_state(self, xyzh_dict: Dict[str, StateArray]) -> None: raise NotImplementedError() def add_agents(self, agent_data: List[Tuple]) -> None: diff --git a/src/trajdata/simulation/sim_df_cache.py b/src/trajdata/simulation/sim_df_cache.py index d6e571e..07bdbd1 100644 --- a/src/trajdata/simulation/sim_df_cache.py +++ b/src/trajdata/simulation/sim_df_cache.py @@ -9,6 +9,7 @@ from trajdata.caching.df_cache import DataFrameCache from trajdata.data_structures.agent import AgentMetadata from trajdata.data_structures.scene_metadata import Scene +from trajdata.data_structures.state import StateArray from trajdata.simulation.sim_cache import SimulationCache from trajdata.simulation.sim_metrics import SimMetric from trajdata.simulation.sim_stats import SimStatistic @@ -22,7 +23,8 @@ def __init__( scene_ts: int, augmentations: Optional[List[Augmentation]] = None, ) -> None: - super().__init__(cache_path, scene, scene_ts, augmentations) + super().__init__(cache_path, scene, augmentations) + self.scene_ts = scene_ts agent_names: List[str] = [agent.name for agent in scene.agents] in_index: np.ndarray = self.scene_data_df.index.isin(agent_names, level=0) @@ -64,9 +66,7 
@@ def get_agents_future( agents: List[AgentMetadata], future_sec: Tuple[Optional[float], Optional[float]], ) -> Tuple[np.ndarray, np.ndarray]: - last_timesteps = np.array( - [agent.last_timestep for agent in agents], dtype=np.long - ) + last_timesteps = np.array([agent.last_timestep for agent in agents], dtype=int) if np.all(np.greater(scene_ts, last_timesteps)): return ( @@ -77,33 +77,40 @@ def get_agents_future( return super().get_agents_future(scene_ts, agents, future_sec) - def append_state(self, xyh_dict: Dict[str, np.ndarray]) -> None: + def append_state(self, xyzh_dict: Dict[str, StateArray]) -> None: self.scene_ts += 1 sim_dict: Dict[str, List[Union[str, float, int]]] = defaultdict(list) - prev_states: np.ndarray = self.get_states( - list(xyh_dict.keys()), self.scene_ts - 1 + prev_states: StateArray = self.get_states( + list(xyzh_dict.keys()), self.scene_ts - 1 ) - for idx, (agent, new_xyh) in enumerate(xyh_dict.items()): - prev_state = prev_states[idx] + + new_xyzh: StateArray + for idx, (agent, new_xyzh) in enumerate(xyzh_dict.items()): + prev_state: StateArray = prev_states[idx] sim_dict["agent_id"].append(agent) sim_dict["scene_ts"].append(self.scene_ts) - sim_dict["x"].append(new_xyh[0]) - sim_dict["y"].append(new_xyh[1]) + old_x, old_y = prev_state.position + new_x, new_y = new_xyzh.position - vx: float = (new_xyh[0] - prev_state[0]) / self.scene.dt - vy: float = (new_xyh[1] - prev_state[1]) / self.scene.dt + sim_dict["x"].append(new_x) + sim_dict["y"].append(new_y) + + vx: float = (new_x - old_x) / self.scene.dt + vy: float = (new_y - old_y) / self.scene.dt sim_dict["vx"].append(vx) sim_dict["vy"].append(vy) - ax: float = (vx - prev_state[2]) / self.scene.dt - ay: float = (vy - prev_state[3]) / self.scene.dt + old_vx, old_vy = prev_state.velocity + + ax: float = (vx - old_vx) / self.scene.dt + ay: float = (vy - old_vy) / self.scene.dt sim_dict["ax"].append(ax) sim_dict["ay"].append(ay) - sim_dict["heading"].append(new_xyh[2]) + sim_dict["heading"].append(new_xyzh.heading.item()) if self.extent_cols: sim_dict["length"].append( diff --git a/src/trajdata/simulation/sim_scene.py b/src/trajdata/simulation/sim_scene.py index 56da477..7ad79dd 100644 --- a/src/trajdata/simulation/sim_scene.py +++ b/src/trajdata/simulation/sim_scene.py @@ -12,6 +12,7 @@ from trajdata.data_structures.collation import agent_collate_fn from trajdata.data_structures.scene import SceneTimeAgent from trajdata.data_structures.scene_metadata import Scene +from trajdata.data_structures.state import StateArray from trajdata.dataset import UnifiedDataset from trajdata.simulation.sim_cache import SimulationCache from trajdata.simulation.sim_df_cache import SimulationDataFrameCache @@ -94,12 +95,12 @@ def reset(self) -> Union[AgentBatch, Dict[str, Any]]: def step( self, - new_xyh_dict: Dict[str, np.ndarray], + new_xyzh_dict: Dict[str, StateArray], return_obs=True, ) -> Union[AgentBatch, Dict[str, Any]]: self.scene_ts += 1 - self.cache.append_state(new_xyh_dict) + self.cache.append_state(new_xyzh_dict) if not self.freeze_agents: agents_present: List[AgentMetadata] = self.scene.agent_presence[ @@ -120,30 +121,43 @@ def get_obs( self, collate: bool = True, get_map: bool = True ) -> Union[AgentBatch, Dict[str, Any]]: agent_data_list: List[AgentBatchElement] = list() + self.cache.set_obs_format(self.dataset.obs_format) + for agent in self.agents: scene_time_agent = SceneTimeAgent( self.scene, self.scene_ts, self.agents, agent, self.cache ) - agent_data_list.append( - AgentBatchElement( - self.cache, - -1, # Not used 
- scene_time_agent, - history_sec=self.dataset.history_sec, - future_sec=self.dataset.future_sec, - agent_interaction_distances=self.dataset.agent_interaction_distances, - incl_robot_future=False, - incl_map=get_map and self.dataset.incl_map, - map_params=self.dataset.map_params, - standardize_data=self.dataset.standardize_data, - standardize_derivatives=self.dataset.standardize_derivatives, - max_neighbor_num=self.dataset.max_neighbor_num, - ) + batch_element: AgentBatchElement = AgentBatchElement( + self.cache, + -1, # Not used + scene_time_agent, + history_sec=self.dataset.history_sec, + future_sec=self.dataset.future_sec, + agent_interaction_distances=self.dataset.agent_interaction_distances, + incl_robot_future=False, + incl_raster_map=get_map and self.dataset.incl_raster_map, + raster_map_params=self.dataset.raster_map_params, + map_api=self.dataset._map_api, + vector_map_params=self.dataset.vector_map_params, + state_format=self.dataset.state_format, + standardize_data=self.dataset.standardize_data, + standardize_derivatives=self.dataset.standardize_derivatives, + max_neighbor_num=self.dataset.max_neighbor_num, ) + agent_data_list.append(batch_element) + + for key, extra_fn in self.dataset.extras.items(): + batch_element.extras[key] = extra_fn(batch_element) + + for transform_fn in self.dataset.transforms: + batch_element = transform_fn(batch_element) + + if not self.dataset.vector_map_params.get("collate", False): + batch_element.vec_map = None # Need to reset transformations for each agent since each # AgentBatchElement transforms (standardizes) the cache. - self.cache.reset_transforms() + self.cache.reset_obs_frame() if collate: return agent_collate_fn( diff --git a/src/trajdata/simulation/sim_stats.py b/src/trajdata/simulation/sim_stats.py index 74f7fb5..499a900 100644 --- a/src/trajdata/simulation/sim_stats.py +++ b/src/trajdata/simulation/sim_stats.py @@ -83,12 +83,22 @@ def calc_stats( """ velocity: Tensor = ( - torch.diff(positions, dim=1, prepend=positions[:, [1]] - positions[:, [0]]) / dt + torch.diff( + positions, + dim=1, + prepend=positions[:, [0]] - (positions[:, [1]] - positions[:, [0]]), + ) + / dt ) velocity_norm: Tensor = torch.linalg.vector_norm(velocity, dim=-1) accel: Tensor = ( - torch.diff(positions, dim=1, prepend=velocity[:, [1]] - velocity[:, [0]]) / dt + torch.diff( + velocity, + dim=1, + prepend=velocity[:, [0]] - (velocity[:, [1]] - velocity[:, [0]]), + ) + / dt ) accel_norm: Tensor = torch.linalg.vector_norm(accel, dim=-1) @@ -96,7 +106,11 @@ def calc_stats( lat_acc: Tensor = accel_norm * torch.sin(heading.squeeze(-1)) jerk: Tensor = ( - torch.diff(accel_norm, dim=1, prepend=accel_norm[:, [1]] - accel_norm[:, [0]]) + torch.diff( + accel_norm, + dim=1, + prepend=accel_norm[:, [0]] - (accel_norm[:, [1]] - accel_norm[:, [0]]), + ) / dt ) diff --git a/src/trajdata/utils/agent_utils.py b/src/trajdata/utils/agent_utils.py index c457d90..1052a96 100644 --- a/src/trajdata/utils/agent_utils.py +++ b/src/trajdata/utils/agent_utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Type, Union +from typing import Optional, Type from trajdata.caching import EnvCache, SceneCache from trajdata.data_structures import Scene, SceneMetadata @@ -47,6 +47,8 @@ def get_agent_data( agent_list, agent_presence = raw_dataset.get_agent_info( scene, env_cache.path, cache_class ) + if agent_list is None and agent_presence is None: + raise ValueError(f"Scene {scene_info.name} contains no agents!") scene.update_agent_info(agent_list, agent_presence) env_cache.save_scene(scene) diff 
--git a/src/trajdata/utils/arr_utils.py b/src/trajdata/utils/arr_utils.py
index ef381e3..e76a678 100644
--- a/src/trajdata/utils/arr_utils.py
+++ b/src/trajdata/utils/arr_utils.py
@@ -1,5 +1,5 @@
 from enum import IntEnum
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -105,28 +105,36 @@ def angle_wrap(radians: np.ndarray) -> np.ndarray:
     return (radians + np.pi) % (2 * np.pi) - np.pi


-def rotation_matrix(angle: float) -> np.ndarray:
-    """Creates a 2D rotation matrix.
+def rotation_matrix(angle: Union[float, np.ndarray]) -> np.ndarray:
+    """Creates one or many 2D rotation matrices.

     Args:
-        angle (float): The angle to rotate points by.
+        angle (Union[float, np.ndarray]): The angle(s) to rotate points by, in radians.
+            If a float, a single 2x2 matrix is returned.
+            If an np.ndarray of shape [...], an array of shape [..., 2, 2] is returned.

     Returns:
-        np.ndarray: The 2x2 rotation matrix.
+        np.ndarray: The rotation matrix (2x2), or a [..., 2, 2] stack of rotation matrices.
     """
-    return np.array(
+    batch_dims = 0
+    if isinstance(angle, np.ndarray):
+        batch_dims = angle.ndim
+
+    rotmat: np.ndarray = np.array(
         [
             [np.cos(angle), -np.sin(angle)],
             [np.sin(angle), np.cos(angle)],
         ]
     )
+    return rotmat.transpose(*np.arange(2, batch_dims + 2), 0, 1)


 def transform_matrices(angles: Tensor, translations: Tensor) -> Tensor:
     """Creates a 3x3 transformation matrix for each angle and translation in the input.

     Args:
-        angles (Tensor): The (N,)-shaped angles tensor to rotate points by.
+        angles (Tensor): The (N,)-shaped angles tensor to rotate points by (in radians).
         translations (Tensor): The (N,2)-shaped translations to shift points by.

     Returns:
@@ -147,21 +155,93 @@ def transform_matrices(angles: Tensor, translations: Tensor) -> Tensor:
     )


-def batch_nd_transform_points_np(points, Mat):
-    ndim = Mat.shape[-1] - 1
-    batch = list(range(Mat.ndim - 2)) + [Mat.ndim - 1] + [Mat.ndim - 2]
-    Mat = np.transpose(Mat, batch)
-    if points.ndim == Mat.ndim - 1:
-        return (points[..., np.newaxis, :] @ Mat[..., :ndim, :ndim]).squeeze(-2) + Mat[
-            ..., -1:, :ndim
-        ].squeeze(-2)
-    elif points.ndim == Mat.ndim:
-        return (
-            (points[..., np.newaxis, :] @ Mat[..., np.newaxis, :ndim, :ndim])
-            + Mat[..., np.newaxis, -1:, :ndim]
-        ).squeeze(-2)
+def transform_coords_2d_np(
+    coords: np.ndarray,
+    offset: Optional[np.ndarray] = None,
+    angle: Optional[np.ndarray] = None,
+    rot_mat: Optional[np.ndarray] = None,
+) -> np.ndarray:
+    """
+    Args:
+        coords (np.ndarray): [..., 2] coordinates
+        offset (Optional[np.ndarray], optional): [..., 2] offset to translate. Defaults to None.
+        angle (Optional[np.ndarray], optional): [...] angle to rotate by. Defaults to None.
+        rot_mat (Optional[np.ndarray], optional): [..., 2,2] rotation matrix to apply. Defaults to None.
+            If rot_mat is given, angle is ignored.
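+
+    Example (illustrative):
+        transform_coords_2d_np(np.array([1.0, 0.0]), angle=np.pi / 2)
+        returns approximately [0.0, 1.0].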
+ + Returns: + np.ndarray: transformed coords + """ + if rot_mat is None and angle is not None: + rot_mat = rotation_matrix(angle) + + if rot_mat is not None: + coords = np.einsum("...ij,...j->...i", rot_mat, coords) + + if offset is not None: + coords += offset + + return coords + + +def transform_coords_np( + coords: np.ndarray, tf_mat: np.ndarray, translate: bool = True +) -> np.ndarray: + """ + Returns coords after transforming them according to the transformation matrix tf_mat + + Args: + coords (np.ndarray): batch of points [..., d] + tf_mat (np.ndarray): nd affine transformation matrix [..., d+1, d+1] + or [d+1, d+1] if the same transformation should be applied to all points + + Returns: + np.ndarray: transformed points [..., d] + """ + if coords.ndim == (tf_mat.ndim - 1): + transformed = np.einsum("...jk,...k->...j", tf_mat[..., :-1, :-1], coords) + if translate: + transformed += tf_mat[..., :-1, -1] + elif tf_mat.ndim == 2: + transformed = np.einsum("jk,...k->...j", tf_mat[:-1, :-1], coords) + if translate: + transformed += tf_mat[None, :-1, -1] else: - raise Exception("wrong shape") + raise ValueError("Batch dims of tf_mat must match coords") + + return transformed + + +def transform_angles_np(angles: np.ndarray, tf_mat: np.ndarray) -> np.ndarray: + """ + Returns angles after transforming them according to the transformation matrix tf_mat + + Args: + angles (np.ndarray): batch of angles [...] + tf_mat (np.ndarray): nd affine transformation matrix [..., d+1, d+1] + or [d+1, d+1] if the same transformation should be applied to all points + + Returns: + np.ndarray: transformed angles [...] + """ + cos_vals, sin_vals = tf_mat[..., 0, 0], tf_mat[..., 1, 0] + rot_angle = np.arctan2(sin_vals, cos_vals) + transformed_angles = angles + rot_angle + transformed_angles = angle_wrap(transformed_angles) + return transformed_angles + + +def transform_xyh_np(xyh: np.ndarray, tf_mat: np.ndarray) -> np.ndarray: + """ + Returns transformed set of xyh points + + Args: + xyh (np.ndarray): shape [...,3] + tf_mat (np.ndarray): shape [...,3,3] + """ + transformed_xy = transform_coords_np(xyh[..., :2], tf_mat) + transformed_angles = transform_angles_np(xyh[..., 2], tf_mat) + return np.concatenate([transformed_xy, transformed_angles[..., None]], axis=-1) def agent_aware_diff(values: np.ndarray, agent_ids: np.ndarray) -> np.ndarray: @@ -214,6 +294,7 @@ def batch_proj(x, line): delta_y, torch.unsqueeze(delta_psi, dim=-1), ) + elif isinstance(x, np.ndarray): delta = line[..., 0:2] - np.repeat( x[..., np.newaxis, 0:2], line_length, axis=-2 @@ -236,3 +317,11 @@ def batch_proj(x, line): delta_y, np.expand_dims(delta_psi, axis=-1), ) + + +def quaternion_to_yaw(q: np.ndarray): + # From https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L1025 + return np.arctan2( + 2 * (q[..., 0] * q[..., 3] - q[..., 1] * q[..., 2]), + 1 - 2 * (q[..., 2] ** 2 + q[..., 3] ** 2), + ) diff --git a/src/trajdata/utils/batch_utils.py b/src/trajdata/utils/batch_utils.py new file mode 100644 index 0000000..cf63009 --- /dev/null +++ b/src/trajdata/utils/batch_utils.py @@ -0,0 +1,175 @@ +from collections import defaultdict +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import numpy as np +from torch.utils.data import Sampler + +from trajdata import UnifiedDataset +from trajdata.data_structures import ( + AgentBatch, + AgentBatchElement, + AgentDataIndex, + AgentType, + SceneBatchElement, + SceneTimeAgent, +) +from trajdata.data_structures.collation import agent_collate_fn + + +class 
SceneTimeBatcher(Sampler):
+    _agent_data_index: AgentDataIndex
+    _agent_idx: int
+
+    def __init__(
+        self, agent_centric_dataset: UnifiedDataset, agent_idx_to_follow: int = 0
+    ) -> None:
+        """
+        A sampler (to be used in a torch.utils.data.DataLoader)
+        which works with an agent-centric UnifiedDataset, yielding
+        batches consisting of whole scenes (AgentBatchElements for all agents
+        in a particular scene at a particular time).
+
+        Args:
+            agent_centric_dataset (UnifiedDataset): The agent-centric dataset to sample from.
+            agent_idx_to_follow (int): index of agent to return batches for. Defaults to 0,
+                meaning we include all scene frames where the ego agent appears, which
+                usually covers the entire dataset.
+        """
+        super().__init__(agent_centric_dataset)
+        self._agent_data_index = agent_centric_dataset._data_index
+        self._agent_idx = agent_idx_to_follow
+        self._cumulative_lengths = np.concatenate(
+            [
+                [0],
+                np.cumsum(
+                    [
+                        cumulative_scene_length[self._agent_idx + 1]
+                        - cumulative_scene_length[self._agent_idx]
+                        for cumulative_scene_length in self._agent_data_index._cumulative_scene_lengths
+                    ]
+                ),
+            ]
+        )
+
+    def __len__(self):
+        return self._cumulative_lengths[-1]
+
+    def __iter__(self) -> Iterator[List[int]]:
+        for idx in range(len(self)):
+            # TODO(apoorvas) May not need to do this search, since we only support an iterable style access?
+            scene_idx: int = (
+                np.searchsorted(self._cumulative_lengths, idx, side="right").item() - 1
+            )
+
+            # offset into dataset index to reach current scene
+            scene_offset = self._agent_data_index._cumulative_lengths[scene_idx].item()
+
+            # how far along we are in the current scene
+            scene_elem_index = idx - self._cumulative_lengths[scene_idx].item()
+
+            # convert to scene-timestep for the tracked agent
+            scene_ts = (
+                scene_elem_index
+                + self._agent_data_index._agent_times[scene_idx][self._agent_idx, 0]
+            )
+
+            # build a list of indices into the agent-centric dataset for all agents that exist in this scene at this timestep
+            indices = []
+            for agent_idx, agent_times in enumerate(
+                self._agent_data_index._agent_times[scene_idx]
+            ):
+                if scene_ts > agent_times[1]:
+                    # we are past the last timestep for this agent (times are inclusive)
+                    continue
+                agent_offset = scene_ts - agent_times[0]
+                if agent_offset < 0:
+                    # this agent hasn't entered the scene yet
+                    continue
+
+                # compute index into original dataset, first into scene, then into this agent's part in scene, and then the offset
+                index_to_add = (
+                    scene_offset
+                    + self._agent_data_index._cumulative_scene_lengths[scene_idx][
+                        agent_idx
+                    ]
+                    + agent_offset
+                )
+                indices.append(index_to_add)
+
+            yield indices
+
+
+def convert_to_agent_batch(
+    scene_batch_element: SceneBatchElement,
+    only_types: Optional[List[AgentType]] = None,
+    no_types: Optional[List[AgentType]] = None,
+    agent_interaction_distances: Dict[Tuple[AgentType, AgentType], float] = defaultdict(
+        lambda: np.inf
+    ),
+    incl_map: bool = False,
+    map_params: Optional[Dict[str, Any]] = None,
+    max_neighbor_num: Optional[int] = None,
+    state_format: Optional[str] = None,
+    standardize_data: bool = True,
+    standardize_derivatives: bool = False,
+    pad_format: str = "outside",
+) -> AgentBatch:
+    """
+    Converts a SceneBatchElement into an AgentBatch consisting of
+    AgentBatchElements for all agents present in the given scene at the given
+    time step.
+
+    Args:
+        scene_batch_element (SceneBatchElement): element to process
+        only_types (Optional[List[AgentType]], optional): AgentTypes to consider. Defaults to None.
+        no_types (Optional[List[AgentType]], optional): AgentTypes to ignore. Defaults to None.
+        agent_interaction_distances (Dict[Tuple[AgentType, AgentType], float], optional): Distance threshold for interaction. Defaults to defaultdict(lambda: np.inf).
+        incl_map (bool, optional): Whether to include map info. Defaults to False.
+        map_params (Optional[Dict[str, Any]], optional): Map params. Defaults to None.
+        max_neighbor_num (Optional[int], optional): Max number of neighbors to allow. Defaults to None.
+        standardize_data (bool): Whether to return data relative to current agent state. Defaults to True.
+        standardize_derivatives: Whether to transform relative velocities and accelerations as well. Defaults to False.
+        pad_format (str, optional): Pad format when collating agent trajectories. Defaults to "outside".
+
+    Returns:
+        AgentBatch: batch of AgentBatchElements corresponding to all agents in the SceneBatchElement
+    """
+    data_idx = scene_batch_element.data_index
+    cache = scene_batch_element.cache
+    scene = cache.scene
+    dt = scene_batch_element.dt
+    ts = scene_batch_element.scene_ts
+    state_format = scene_batch_element.centered_agent_state_np._format
+
+    batch_elems: List[AgentBatchElement] = []
+    for j, agent_name in enumerate(scene_batch_element.agent_names):
+        history_sec = dt * (scene_batch_element.agent_histories[j].shape[0] - 1)
+        future_sec = dt * (scene_batch_element.agent_futures[j].shape[0])
+        cache.reset_obs_frame()
+        scene_time_agent: SceneTimeAgent = SceneTimeAgent.from_cache(
+            scene,
+            ts,
+            agent_name,
+            cache,
+            only_types=only_types,
+            no_types=no_types,
+        )
+
+        batch_elems.append(
+            AgentBatchElement(
+                cache=cache,
+                data_index=data_idx,
+                scene_time_agent=scene_time_agent,
+                history_sec=(history_sec, history_sec),
+                future_sec=(future_sec, future_sec),
+                agent_interaction_distances=agent_interaction_distances,
+                incl_raster_map=incl_map,
+                raster_map_params=map_params,
+                state_format=state_format,
+                standardize_data=standardize_data,
+                standardize_derivatives=standardize_derivatives,
+                max_neighbor_num=max_neighbor_num,
+            )
+        )
+
+    return agent_collate_fn(batch_elems, return_dict=False, pad_format=pad_format)
diff --git a/src/trajdata/utils/df_utils.py b/src/trajdata/utils/df_utils.py
new file mode 100644
index 0000000..697369a
--- /dev/null
+++ b/src/trajdata/utils/df_utils.py
@@ -0,0 +1,98 @@
+from typing import Callable, Optional
+
+import numpy as np
+import pandas as pd
+
+
+def downsample_multi_index_df(
+    df: pd.DataFrame, downsample_dt_factor: int
+) -> pd.DataFrame:
+    """
+    Downsamples MultiIndex dataframe, assuming level=1 of the index
+    corresponds to the scene timestep.
+    """
+    subsampled_df = df.groupby(level=0).apply(
+        lambda g: g.reset_index(level=0, drop=True)
+        .iloc[::downsample_dt_factor]
+        .rename(index=lambda ts: ts // downsample_dt_factor)
+    )
+
+    return subsampled_df
+
+
+def upsample_ts_index_df(
+    df: pd.DataFrame,
+    upsample_dt_factor: int,
+    method: str,
+    preprocess: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
+    postprocess: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
+):
+    """
+    Upsamples a time indexed dataframe, applying specified method.
+    Calls preprocess and postprocess before and after upsampling, respectively.
+ + If original data is at frames 2,3,4,5, and upsample_dt_factor is 3, then + the original data will live at frames 6,9,12,15, and new data will + be generated according to method for frames 7,8, 10,11, 13,14 (frames after the last frame are not generated) + """ + if preprocess: + df = preprocess(df) + + # first, we multiply ts index by upsample factor + df = df.rename(index=lambda ts: ts * upsample_dt_factor) + + # get the index by adding the number of frames needed per original index + new_index = pd.Index( + (df.index.to_numpy()[:, None] + np.arange(upsample_dt_factor)).flatten()[ + : -(upsample_dt_factor - 1) + ], + name=df.index.name, + ) + + # reindex and interpolate according to method + df = df.reindex(new_index).interpolate(method=method, limit_area="inside") + + if postprocess: + df = postprocess(df) + + return df + + +def upsample_multi_index_df( + df: pd.DataFrame, + upsample_dt_factor: int, + method: str, + preprocess: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None, + postprocess: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None, +) -> pd.DataFrame: + return df.groupby(level=[0]).apply( + lambda g: upsample_ts_index_df( + g.reset_index(level=[0], drop=True), + upsample_dt_factor, + method, + preprocess, + postprocess, + ) + ) + + +def interpolate_multi_index_df( + df: pd.DataFrame, data_dt: float, desired_dt: float, method: str = "linear" +) -> pd.DataFrame: + """ + Interpolates the given dataframe indexed with (elem_id, scene_ts) + where scene_ts corresponds to timesteps with increment data_dt to a new + desired_dt. + """ + upsample_dt_ratio: float = data_dt / desired_dt + downsample_dt_ratio: float = desired_dt / data_dt + if not upsample_dt_ratio.is_integer() and not downsample_dt_ratio.is_integer(): + raise ValueError( + f"Data's dt of {data_dt}s " + f"is not integer divisible by the desired dt {desired_dt}s." + ) + + if upsample_dt_ratio >= 1: + return upsample_multi_index_df(df, int(upsample_dt_ratio), method) + elif downsample_dt_ratio >= 1: + return downsample_multi_index_df(df, int(downsample_dt_ratio)) diff --git a/src/trajdata/utils/env_utils.py b/src/trajdata/utils/env_utils.py index 61980a0..49c22ec 100644 --- a/src/trajdata/utils/env_utils.py +++ b/src/trajdata/utils/env_utils.py @@ -1,35 +1,60 @@ from typing import Dict, List from trajdata.dataset_specific import RawDataset -from trajdata.dataset_specific.eth_ucy_peds import EUPedsDataset - -try: - from trajdata.dataset_specific.lyft import LyftDataset -except ModuleNotFoundError: - # This can happen if the user did not install trajdata - # with the "trajdata[lyft]" option. - pass - -try: - from trajdata.dataset_specific.nusc import NuscDataset -except ModuleNotFoundError: - # This can happen if the user did not install trajdata - # with the "trajdata[nusc]" option. 
- pass def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset: if "nusc" in dataset_name: + from trajdata.dataset_specific.nusc import NuscDataset + return NuscDataset(dataset_name, data_dir, parallelizable=False, has_maps=True) + if "vod" in dataset_name: + from trajdata.dataset_specific.vod import VODDataset + + return VODDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) + if "lyft" in dataset_name: + from trajdata.dataset_specific.lyft import LyftDataset + return LyftDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) if "eupeds" in dataset_name: + from trajdata.dataset_specific.eth_ucy_peds import EUPedsDataset + return EUPedsDataset( dataset_name, data_dir, parallelizable=True, has_maps=False ) + if "sdd" in dataset_name: + from trajdata.dataset_specific.sdd_peds import SDDPedsDataset + + return SDDPedsDataset( + dataset_name, data_dir, parallelizable=True, has_maps=False + ) + + if "nuplan" in dataset_name: + from trajdata.dataset_specific.nuplan import NuplanDataset + + return NuplanDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) + + if "waymo" in dataset_name: + from trajdata.dataset_specific.waymo import WaymoDataset + + return WaymoDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) + + if "interaction" in dataset_name: + from trajdata.dataset_specific.interaction import InteractionDataset + + return InteractionDataset( + dataset_name, data_dir, parallelizable=True, has_maps=True + ) + + if "av2" in dataset_name: + from trajdata.dataset_specific.argoverse2 import Av2Dataset + + return Av2Dataset(dataset_name, data_dir, parallelizable=True, has_maps=True) + raise ValueError(f"Dataset with name '{dataset_name}' is not supported") diff --git a/src/trajdata/utils/map_utils.py b/src/trajdata/utils/map_utils.py new file mode 100644 index 0000000..9856972 --- /dev/null +++ b/src/trajdata/utils/map_utils.py @@ -0,0 +1,313 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from trajdata.maps import map_kdtree, vec_map, map_strtree + from trajdata.maps.vec_map_elements import MapElementType + +from pathlib import Path +from typing import Dict, Final, Optional + +import dill +import numpy as np +from scipy.stats import circmean + +import trajdata.proto.vectorized_map_pb2 as map_proto +from trajdata.utils import arr_utils + +MM_PER_M: Final[float] = 1000 + + +def decompress_values(data: np.ndarray) -> np.ndarray: + # From https://github.com/woven-planet/l5kit/blob/master/l5kit/l5kit/data/proto/road_network.proto#L446 + # The delta for the first point is just its coordinates tuple, i.e. it is a "delta" from + # the origin. For subsequent points, this field stores the difference between the point's + # coordinates and the previous point's coordinates. This is for representation efficiency. + return np.cumsum(data, axis=0, dtype=float) / MM_PER_M + + +def compress_values(data: np.ndarray) -> np.ndarray: + return (np.diff(data, axis=0, prepend=0.0) * MM_PER_M).astype(np.int32) + + +def get_polyline_headings(points: np.ndarray) -> np.ndarray: + """Get approximate heading angles for points in a polyline. 
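+
+    Interior points get the circular mean of their two adjacent segment
+    directions; the endpoints use their single adjacent segment's direction.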
+
+    Args:
+        points: XY points (extra columns are ignored), np.ndarray of shape [N, >=2]
+
+    Returns:
+        np.ndarray: approximate heading angles in radians, shape [N, 1]
+    """
+    if points.ndim < 2 or points.shape[-1] < 2 or points.shape[-2] <= 1:
+        raise ValueError("Unexpected shape")
+
+    vectors = points[..., 1:, :] - points[..., :-1, :]
+    vec_headings = np.arctan2(vectors[..., 1], vectors[..., 0])  # -pi..pi
+
+    # For internal points compute the mean heading of consecutive segments.
+    # Need to use circular mean to average directions.
+    # TODO(pkarkus) this would be more accurate if weighted with the distance to the neighbor
+    if vec_headings.shape[-1] <= 1:
+        # Handle special case because circmean unfortunately returns nan for such input.
+        mean_consec_headings = np.zeros(
+            list(vec_headings.shape[:-1]) + [0], dtype=vec_headings.dtype
+        )
+    else:
+        mean_consec_headings = circmean(
+            np.stack([vec_headings[..., :-1], vec_headings[..., 1:]], axis=-1),
+            high=np.pi,
+            low=-np.pi,
+            axis=-1,
+        )
+
+    headings = np.concatenate(
+        [
+            vec_headings[..., :1],  # heading of first segment
+            mean_consec_headings,  # mean heading of consecutive segments
+            vec_headings[..., -1:],  # heading of last segment
+        ],
+        axis=-1,
+    )
+    return headings[..., np.newaxis]
+
+
+def populate_lane_polylines(
+    new_lane_proto: map_proto.RoadLane,
+    road_lane_py: vec_map.RoadLane,
+    origin: np.ndarray,
+) -> None:
+    """Fill a RoadLane protobuf message's polyline fields.
+    All points should be in world coordinates.
+
+    Args:
+        new_lane_proto (map_proto.RoadLane): The lane message whose polylines to fill.
+        road_lane_py (vec_map.RoadLane): The Python lane object to read polylines from.
+        origin (np.ndarray): The world origin subtracted from all points before compression.
+    """
+    compressed_mid_pts: np.ndarray = compress_values(road_lane_py.center.xyz - origin)
+    new_lane_proto.center.dx_mm.extend(compressed_mid_pts[:, 0].tolist())
+    new_lane_proto.center.dy_mm.extend(compressed_mid_pts[:, 1].tolist())
+    new_lane_proto.center.dz_mm.extend(compressed_mid_pts[:, 2].tolist())
+    new_lane_proto.center.h_rad.extend(road_lane_py.center.h.tolist())
+
+    if road_lane_py.left_edge is not None:
+        compressed_left_pts: np.ndarray = compress_values(
+            road_lane_py.left_edge.xyz - origin
+        )
+        new_lane_proto.left_boundary.dx_mm.extend(compressed_left_pts[:, 0].tolist())
+        new_lane_proto.left_boundary.dy_mm.extend(compressed_left_pts[:, 1].tolist())
+        new_lane_proto.left_boundary.dz_mm.extend(compressed_left_pts[:, 2].tolist())
+
+    if road_lane_py.right_edge is not None:
+        compressed_right_pts: np.ndarray = compress_values(
+            road_lane_py.right_edge.xyz - origin
+        )
+        new_lane_proto.right_boundary.dx_mm.extend(compressed_right_pts[:, 0].tolist())
+        new_lane_proto.right_boundary.dy_mm.extend(compressed_right_pts[:, 1].tolist())
+        new_lane_proto.right_boundary.dz_mm.extend(compressed_right_pts[:, 2].tolist())
+
+
+def populate_polygon(
+    polygon_proto: map_proto.Polyline,
+    polygon_pts: np.ndarray,
+    origin: np.ndarray,
+) -> None:
+    """Fill an object's polygon.
+    All points should be in world coordinates.
+
+    Args:
+        polygon_proto (map_proto.Polyline): The polyline message to fill.
+        polygon_pts (np.ndarray): The polygon points, of shape (N, 3).
+        origin (np.ndarray): The world origin subtracted from all points before compression.
+    """
+    compressed_pts: np.ndarray = compress_values(polygon_pts - origin)
+
+    polygon_proto.dx_mm.extend(compressed_pts[:, 0].tolist())
+    polygon_proto.dy_mm.extend(compressed_pts[:, 1].tolist())
+    polygon_proto.dz_mm.extend(compressed_pts[:, 2].tolist())
+
+
+def proto_to_np(polyline: map_proto.Polyline, incl_heading: bool = True) -> np.ndarray:
+    dx: np.ndarray = np.asarray(polyline.dx_mm)
+    dy: np.ndarray = np.asarray(polyline.dy_mm)
+
+    if len(polyline.dz_mm) > 0:
+        dz: np.ndarray = np.asarray(polyline.dz_mm)
+        pts: np.ndarray = np.stack([dx, dy, dz], axis=1)
+    else:
+        # Default z is all zeros.
+        pts: np.ndarray = np.stack([dx, dy, np.zeros_like(dx)], axis=1)
+
+    ret_pts: np.ndarray = decompress_values(pts)
+
+    if incl_heading and len(polyline.h_rad) > 0:
+        headings: np.ndarray = np.asarray(polyline.h_rad)
+        ret_pts = np.concatenate((ret_pts, headings[:, np.newaxis]), axis=1)
+    elif incl_heading:
+        raise ValueError(
+            "Polyline must have heading, but it does not (polyline.h_rad is empty)."
+        )
+
+    return ret_pts
+
+
+def transform_points(points: np.ndarray, transf_mat: np.ndarray):
+    n_dim = points.shape[-1]
+    return points @ transf_mat[:n_dim, :n_dim] + transf_mat[:n_dim, -1]
+
+
+def order_matches(pts: np.ndarray, ref: np.ndarray) -> bool:
+    """Evaluate whether `pts` is ordered the same as `ref`, based on the distance from
+    `pts`'s start and end points to `ref`'s start point.
+
+    Args:
+        pts (np.ndarray): The first array of points, of shape (N, D).
+        ref (np.ndarray): The reference array of points, of shape (M, D).
+
+    Returns:
+        bool: True if `pts`'s first point is closest to `ref`'s first point,
+            False if `pts`'s endpoint is closer (e.g., they are flipped relative to each other).
+    """
+    return np.linalg.norm(pts[0] - ref[0]) <= np.linalg.norm(pts[-1] - ref[0])
+
+
+def endpoints_intersect(left_edge: np.ndarray, right_edge: np.ndarray) -> bool:
+    def ccw(A, B, C):
+        return (C[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (C[0] - A[0])
+
+    A, B = left_edge[-1], right_edge[-1]
+    C, D = right_edge[0], left_edge[0]
+    return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D)
+
+
+def interpolate(
+    pts: np.ndarray, num_pts: Optional[int] = None, max_dist: Optional[float] = None
+) -> np.ndarray:
+    """
+    Interpolate points either based on cumulative distances from the first one (`num_pts`)
+    or by adding extra points until neighboring points are within `max_dist` of each other.
+
+    In particular, `num_pts` will interpolate using a variable step such that we always get
+    the requested number of points.
+
+    Args:
+        pts (np.ndarray): XYZ(H) coords.
+        num_pts (int, optional): Desired number of total points.
+        max_dist (float, optional): Maximum distance between points of the polyline.
+
+    Note:
+        Only one of `num_pts` or `max_dist` can be specified.
+
+    Returns:
+        np.ndarray: The new interpolated coordinates.
+    """
+    if num_pts is not None and max_dist is not None:
+        raise ValueError("Only one of num_pts or max_dist can be used!")
+
+    if pts.ndim != 2:
+        raise ValueError("pts is expected to be 2 dimensional")
+
+    # At most 3 positional dimensions (XYZ); heading does not count as a positional distance.
+    pos_dim: int = min(pts.shape[-1], 3)
+    has_heading: bool = pts.shape[-1] == 4
+
+    if num_pts is not None:
+        assert num_pts > 1, f"num_pts must be at least 2, but got {num_pts}"
+
+        if pts.shape[0] == num_pts:
+            return pts
+
+        cum_dist: np.ndarray = np.cumsum(
+            np.linalg.norm(np.diff(pts[..., :pos_dim], axis=0), axis=-1)
+        )
+        cum_dist = np.insert(cum_dist, 0, 0)
+
+        steps: np.ndarray = np.linspace(cum_dist[0], cum_dist[-1], num_pts)
+        xyz_inter: np.ndarray = np.empty((num_pts, pts.shape[-1]), dtype=pts.dtype)
+        for i in range(pos_dim):
+            xyz_inter[:, i] = np.interp(steps, xp=cum_dist, fp=pts[:, i])
+
+        if has_heading:
+            # Heading, so make sure to unwrap, interpolate, and wrap it.
+            xyz_inter[:, 3] = arr_utils.angle_wrap(
+                np.interp(steps, xp=cum_dist, fp=np.unwrap(pts[:, 3]))
+            )
+
+        return xyz_inter
+
+    elif max_dist is not None:
+        unwrapped_pts: np.ndarray = pts
+        if has_heading:
+            # Copy first so that unwrapping does not mutate the caller's array.
+            unwrapped_pts = pts.copy()
+            unwrapped_pts[..., 3] = np.unwrap(unwrapped_pts[..., 3])
+
+        segments = unwrapped_pts[..., 1:, :] - unwrapped_pts[..., :-1, :]
+        seg_lens = np.linalg.norm(segments[..., :pos_dim], axis=-1)
+        new_pts = [unwrapped_pts[..., 0:1, :]]
+        for i in range(segments.shape[-2]):
+            num_extra_points = seg_lens[..., i] // max_dist
+            if num_extra_points > 0:
+                step_vec = segments[..., i, :] / (num_extra_points + 1)
+                new_pts.append(
+                    unwrapped_pts[..., i, np.newaxis, :]
+                    + step_vec[..., np.newaxis, :]
+                    * np.arange(1, num_extra_points + 1)[:, np.newaxis]
+                )
+
+            new_pts.append(unwrapped_pts[..., i + 1 : i + 2, :])
+
+        new_pts = np.concatenate(new_pts, axis=-2)
+        if has_heading:
+            new_pts[..., 3] = arr_utils.angle_wrap(new_pts[..., 3])
+
+        return new_pts
+
+
+def load_vector_map(vector_map_path: Path) -> map_proto.VectorizedMap:
+    if not vector_map_path.exists():
+        raise ValueError(f"{vector_map_path} does not exist!")
+
+    vec_map = map_proto.VectorizedMap()
+
+    # Loading the vectorized map data.
+    with open(vector_map_path, "rb") as f:
+        vec_map.ParseFromString(f.read())
+
+    return vec_map
+
+
+def load_kdtrees(
+    kdtrees_path: Path,
+) -> Dict[MapElementType, map_kdtree.MapElementKDTree]:
+    if not kdtrees_path.exists():
+        raise ValueError(f"{kdtrees_path} does not exist!")
+
+    with open(kdtrees_path, "rb") as f:
+        kdtrees: Dict[MapElementType, map_kdtree.MapElementKDTree] = dill.load(f)
+
+    return kdtrees
+
+
+def load_rtrees(
+    rtrees_path: Path,
+) -> Optional[Dict[MapElementType, map_strtree.MapElementSTRTree]]:
+    if not rtrees_path.exists():
+        warnings.warn(
+            (
+                "Trying to load cached RTree encoding 2D Map elements, "
+                f"but {rtrees_path} does not exist. Earlier versions of "
+                "trajdata did not build and cache this RTree. If area queries "
+                "are needed, please rebuild the map cache (see "
+                "examples/preprocess_maps.py for an example of how to do this). "
+                "Otherwise, please ignore this warning."
+ ), + UserWarning, + ) + return None + + with open(rtrees_path, "rb") as f: + rtrees: Dict[MapElementType, map_strtree.MapElementSTRTree] = dill.load(f) + + return rtrees diff --git a/src/trajdata/parallel/parallel_utils.py b/src/trajdata/utils/parallel_utils.py similarity index 100% rename from src/trajdata/parallel/parallel_utils.py rename to src/trajdata/utils/parallel_utils.py diff --git a/src/trajdata/utils/py_utils.py b/src/trajdata/utils/py_utils.py new file mode 100644 index 0000000..c302aab --- /dev/null +++ b/src/trajdata/utils/py_utils.py @@ -0,0 +1,13 @@ +import hashlib +import json +from typing import Dict, List, Set, Tuple, Union + + +def hash_dict(o: Union[Dict, List, Tuple, Set]) -> str: + """ + Makes a hash from a dictionary, list, tuple or set to any level, that contains + only other hashable types (including any lists, tuples, sets, and + dictionaries). + """ + string_rep: str = json.dumps(o) + return hashlib.sha1(str.encode(string_rep)).hexdigest() diff --git a/src/trajdata/utils/raster_utils.py b/src/trajdata/utils/raster_utils.py new file mode 100644 index 0000000..d9e77dc --- /dev/null +++ b/src/trajdata/utils/raster_utils.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Final + +if TYPE_CHECKING: + from trajdata.maps import VectorMap + + +from math import ceil +from typing import List, Tuple + +import cv2 +import numpy as np +from tqdm import tqdm + +from trajdata.maps.raster_map import RasterizedMap, RasterizedMapMetadata +from trajdata.maps.vec_map import MapElement, MapElementType +from trajdata.utils import map_utils + +# Sub-pixel drawing precision constants. +# See https://github.com/woven-planet/l5kit/blob/master/l5kit/l5kit/rasterization/semantic_rasterizer.py#L16 +CV2_SUB_VALUES: Final[Dict[str, Any]] = {"shift": 9, "lineType": cv2.LINE_AA} +CV2_SHIFT_VALUE: Final[int] = 2 ** CV2_SUB_VALUES["shift"] +DEFAULT_PX_PER_M: Final[float] = 2.0 + + +def cv2_subpixel(coords: np.ndarray) -> np.ndarray: + """ + Cast coordinates to numpy.int but keep fractional part by previously multiplying by 2**CV2_SHIFT + cv2 calls will use shift to restore original values with higher precision + + Args: + coords (np.ndarray): XY coords as float + + Returns: + np.ndarray: XY coords as int for cv2 shift draw + """ + return (coords * CV2_SHIFT_VALUE).astype(int) + + +def world_to_subpixel(pts: np.ndarray, raster_from_world: np.ndarray): + return cv2_subpixel(map_utils.transform_points(pts, raster_from_world)) + + +def cv2_draw_polygons( + polygon_pts: List[np.ndarray], + onto_img: np.ndarray, + color: Tuple[int, int, int], +) -> None: + cv2.fillPoly( + img=onto_img, + pts=polygon_pts, + color=color, + **CV2_SUB_VALUES, + ) + + +def cv2_draw_polylines( + polyline_pts: List[np.ndarray], + onto_img: np.ndarray, + color: Tuple[int, int, int], +) -> None: + cv2.polylines( + img=onto_img, + pts=polyline_pts, + isClosed=False, + color=color, + **CV2_SUB_VALUES, + ) + + +def rasterize_world_polygon( + polygon_pts: np.ndarray, + onto_img: np.ndarray, + raster_from_world: np.ndarray, + color: Tuple[int, int, int], +) -> None: + subpixel_area: np.ndarray = world_to_subpixel( + polygon_pts[..., :2], raster_from_world + ) + + # Drawing general road areas. 
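+    # (fillPoly receives fixed-point coordinates here; the "shift" entry in
+    # CV2_SUB_VALUES tells OpenCV how many fractional bits to interpret,
+    # which is what gives the drawing its sub-pixel accuracy.)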
+ cv2_draw_polygons(polygon_pts=[subpixel_area], onto_img=onto_img, color=color)
+
+
+def rasterize_world_polylines(
+ polyline_pts: List[np.ndarray],
+ onto_img: np.ndarray,
+ raster_from_world: np.ndarray,
+ color: Tuple[int, int, int],
+) -> None:
+ subpixel_pts: List[np.ndarray] = [
+ world_to_subpixel(pts[..., :2], raster_from_world) for pts in polyline_pts
+ ]
+
+ # Drawing line.
+ cv2_draw_polylines(
+ polyline_pts=subpixel_pts,
+ onto_img=onto_img,
+ color=color,
+ )
+
+
+def rasterize_lane(
+ left_edge: np.ndarray,
+ right_edge: np.ndarray,
+ onto_img_area: np.ndarray,
+ onto_img_line: np.ndarray,
+ raster_from_world: np.ndarray,
+ area_color: Tuple[int, int, int],
+ line_color: Tuple[int, int, int],
+) -> None:
+ lane_edges: List[np.ndarray] = [left_edge[:, :2], right_edge[::-1, :2]]
+
+ # Drawing lane area.
+ rasterize_world_polygon(
+ np.concatenate(lane_edges, axis=0),
+ onto_img_area,
+ raster_from_world,
+ color=area_color,
+ )
+
+ # Drawing lane lines.
+ rasterize_world_polylines(lane_edges, onto_img_line, raster_from_world, line_color)
+
+
+def rasterize_map(
+ vec_map: VectorMap, resolution: float, **pbar_kwargs
+) -> RasterizedMap:
+ """Renders the semantic map at the given resolution.
+
+ Args:
+ vec_map (VectorMap): The vector map to rasterize.
+ resolution (float): The rasterized image's resolution in pixels per meter.
+
+ Returns:
+ RasterizedMap: The rasterized map (an RGB image plus its metadata).
+ """
+ # extents is [min_x, min_y, min_z, max_x, max_y, max_z]
+ min_x, min_y, _, max_x, max_y, _ = vec_map.extent
+ world_center_m: Tuple[float, float] = (
+ (min_x + max_x) / 2,
+ (min_y + max_y) / 2,
+ )
+
+ raster_size_x: int = ceil((max_x - min_x) * resolution)
+ raster_size_y: int = ceil((max_y - min_y) * resolution)
+
+ raster_from_local: np.ndarray = np.array(
+ [
+ [resolution, 0, raster_size_x / 2],
+ [0, resolution, raster_size_y / 2],
+ [0, 0, 1],
+ ]
+ )
+
+ # Translating world coordinates so that the map center lies at the origin.
+ pose_from_world: np.ndarray = np.array(
+ [
+ [1, 0, -world_center_m[0]],
+ [0, 1, -world_center_m[1]],
+ [0, 0, 1],
+ ]
+ )
+
+ raster_from_world: np.ndarray = raster_from_local @ pose_from_world
+
+ lane_area_img: np.ndarray = np.zeros(
+ shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8
+ )
+ lane_line_img: np.ndarray = np.zeros(
+ shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8
+ )
+ ped_area_img: np.ndarray = np.zeros(
+ shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8
+ )
+
+ map_elem: MapElement
+ for map_elem in tqdm(
+ vec_map.iter_elems(),
+ desc=f"Rasterizing Map at {resolution:.2f} px/m",
+ total=len(vec_map),
+ **pbar_kwargs,
+ ):
+ if map_elem.elem_type == MapElementType.ROAD_LANE:
+ if map_elem.left_edge is not None and map_elem.right_edge is not None:
+ # Heading doesn't matter for rasterization.
+ left_pts: np.ndarray = map_elem.left_edge.xyz
+ right_pts: np.ndarray = map_elem.right_edge.xyz
+
+ # Need to for-loop because doing it all at once can make holes.
+ # Drawing lane.
+ rasterize_lane(
+ left_pts,
+ right_pts,
+ lane_area_img,
+ lane_line_img,
+ raster_from_world,
+ area_color=(255, 0, 0),
+ line_color=(0, 255, 0),
+ )
+
+ # # This code helps visualize centerlines to check if the inferred headings are correct.
+ # center_pts = cv2_subpixel(
+ # transform_points(
+ # proto_to_np(map_elem.road_lane.center, incl_heading=False),
+ # raster_from_world,
+ # )
+ # )[..., :2]
+
+ # # Drawing lane centerlines.
+ # cv2.polylines(
+ # img=lane_line_img,
+ # pts=center_pts[None, :, :],
+ # isClosed=False,
+ # color=(255, 0, 0),
+ # **CV2_SUB_VALUES,
+ # )
+
+ # headings = np.asarray(map_elem.road_lane.center.h_rad)
+ # delta = cv2_subpixel(30*np.array([np.cos(headings[0]), np.sin(headings[0])]))
+ # cv2.arrowedLine(img=lane_line_img, pt1=tuple(center_pts[0]), pt2=tuple(center_pts[0] + 10*(center_pts[1] - center_pts[0])), color=(255, 0, 0), shift=9, line_type=cv2.LINE_AA)
+ # cv2.arrowedLine(img=lane_line_img, pt1=tuple(center_pts[0]), pt2=tuple(center_pts[0] + delta), color=(0, 255, 0), shift=9, line_type=cv2.LINE_AA)
+
+ elif map_elem.elem_type == MapElementType.ROAD_AREA:
+ # Drawing general road areas.
+ rasterize_world_polygon(
+ map_elem.exterior_polygon.xy,
+ lane_area_img,
+ raster_from_world,
+ color=(255, 0, 0),
+ )
+
+ for interior_hole in map_elem.interior_holes:
+ # Removing holes.
+ rasterize_world_polygon(
+ interior_hole.xy, lane_area_img, raster_from_world, color=(0, 0, 0)
+ )
+
+ elif map_elem.elem_type in {
+ MapElementType.PED_CROSSWALK,
+ MapElementType.PED_WALKWAY,
+ }:
+ # Drawing crosswalks and walkways.
+ rasterize_world_polygon(
+ map_elem.polygon.xy, ped_area_img, raster_from_world, color=(0, 0, 255)
+ )
+
+ map_data: np.ndarray = (lane_area_img + lane_line_img + ped_area_img).astype(
+ np.float32
+ ).transpose(2, 0, 1) / 255
+
+ rasterized_map_info = RasterizedMapMetadata(
+ name=vec_map.map_name,
+ shape=map_data.shape,
+ layers=["drivable_area", "lane_divider", "ped_area"],
+ layer_rgb_groups=([0], [1], [2]),
+ resolution=resolution,
+ map_from_world=raster_from_world,
+ )
+
+ return RasterizedMap(rasterized_map_info, map_data)
diff --git a/src/trajdata/utils/state_utils.py b/src/trajdata/utils/state_utils.py
new file mode 100644
index 0000000..cbf0f08
--- /dev/null
+++ b/src/trajdata/utils/state_utils.py
@@ -0,0 +1,164 @@
+from typing import Optional
+
+import numpy as np
+
+from trajdata.data_structures.state import StateArray, StateTensor
+from trajdata.utils.arr_utils import (
+ angle_wrap,
+ rotation_matrix,
+ transform_angles_np,
+ transform_coords_2d_np,
+ transform_coords_np,
+)
+
+
+def transform_state_np_2d(state: StateArray, tf_mat_2d: np.ndarray):
+ """
+ Transforms a state into another coordinate frame given a 3x3 homogeneous
+ 2D transform matrix. Positions are translated and rotated; derivatives
+ (velocities, accelerations) and heading vectors are rotated only.
+ """
+ new_state = state.copy()
+ attributes = state._format_dict.keys()
+ if "x" in attributes and "y" in attributes:
+ # transform xy position with translation and rotation
+ new_state.position = transform_coords_np(state.position, tf_mat_2d)
+ if "xd" in attributes and "yd" in attributes:
+ # transform velocities
+ new_state.velocity = transform_coords_np(
+ state.velocity, tf_mat_2d, translate=False
+ )
+ if "xdd" in attributes and "ydd" in attributes:
+ # transform acceleration
+ new_state.acceleration = transform_coords_np(
+ state.acceleration, tf_mat_2d, translate=False
+ )
+ if "c" in attributes and "s" in attributes:
+ new_state.heading_vector = transform_coords_np(
+ state.heading_vector, tf_mat_2d, translate=False
+ )
+ if "h" in attributes:
+ new_state.heading = transform_angles_np(state.heading, tf_mat_2d)
+
+ return new_state
+
+
+def convert_to_frame_state(
+ state: StateArray,
+ stationary: bool = True,
+ grounded: bool = True,
+) -> StateArray:
+ """
+ Returns a StateArray corresponding to a frame centered on the passed-in state.
+ """
+ frame: StateArray = state.copy()
+ attributes = state._format_dict.keys()
+ if stationary:
+ if "xd" in attributes and "yd" in attributes:
+ frame.velocity = 0
+ if "xdd" in attributes and "ydd" in attributes:
+ frame.acceleration = 0
+ if grounded:
+ if "z" in attributes:
+ frame.set_attr("z", 0)
+
+ return frame
+
+
+def transform_to_frame(
+ state: StateArray, frame_state: StateArray, rot_mat: Optional[np.ndarray] = None
+) -> StateArray:
+ """
+ Returns state with coordinates relative to a frame with state frame_state.
+ Does not modify state in place.
+
+ Args:
+ state (StateArray): state to transform, given in world coordinates
+ frame_state (StateArray): state of the frame, in world coordinates
+ rot_mat (Optional[np.ndarray]): rotation matrix A such that c = A @ b maps world coordinates b to frame coordinates c;
+ if not given, it is computed from frame_state
+ """
+ new_state = state.copy()
+ attributes = state._format_dict.keys()
+
+ frame_heading = frame_state.heading[..., 0]
+ if rot_mat is None:
+ rot_mat = rotation_matrix(-frame_heading)
+
+ if "x" in attributes and "y" in attributes:
+ # transform xy position with translation and rotation
+ new_state.position = transform_coords_2d_np(
+ state.position, offset=-frame_state.position, rot_mat=rot_mat
+ )
+ if "xd" in attributes and "yd" in attributes:
+ # transform velocities
+ new_state.velocity = transform_coords_2d_np(
+ state.velocity, offset=-frame_state.velocity, rot_mat=rot_mat
+ )
+ if "xdd" in attributes and "ydd" in attributes:
+ # transform acceleration
+ new_state.acceleration = transform_coords_2d_np(
+ state.acceleration, offset=-frame_state.acceleration, rot_mat=rot_mat
+ )
+ if "c" in attributes and "s" in attributes:
+ new_state.heading_vector = transform_coords_2d_np(
+ state.heading_vector, rot_mat=rot_mat
+ )
+ if "h" in attributes:
+ new_state.heading = angle_wrap(state.heading - frame_heading)
+
+ return new_state
+
+
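+# A usage sketch (illustrative only; `ego_state` and `agent_state` stand in for
+# StateArrays in world coordinates with an "h" attribute): the two transforms
+# below undo one another when given the same frame.
+#
+# frame = convert_to_frame_state(ego_state)
+# ego_centric = transform_to_frame(agent_state, frame)
+# recovered = transform_from_frame(ego_centric, frame)
+# # `recovered` matches `agent_state` up to floating-point error.
+
+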
+def transform_from_frame(
+ state: StateArray, frame_state: StateArray, rot_mat: Optional[np.ndarray] = None
+) -> StateArray:
+ """
+ Returns state with coordinates in the world frame.
+ Does not modify state in place.
+
+ Args:
+ state (StateArray): state to transform, given in frame coordinates
+ frame_state (StateArray): state of the frame, in world coordinates
+ rot_mat (Optional[np.ndarray]): rotation matrix A such that c = A @ b maps frame coordinates b back to world coordinates c;
+ if not given, it is computed from frame_state
+ """
+ new_state = state.copy()
+ attributes = state._format_dict.keys()
+
+ frame_heading = frame_state.heading[..., 0]
+ if rot_mat is None:
+ rot_mat = rotation_matrix(frame_heading)
+
+ if "x" in attributes and "y" in attributes:
+ # transform xy position with translation and rotation
+ new_state.position = (
+ transform_coords_2d_np(state.position, rot_mat=rot_mat)
+ + frame_state.position
+ )
+ if "xd" in attributes and "yd" in attributes:
+ # transform velocities
+ new_state.velocity = (
+ transform_coords_2d_np(
+ state.velocity,
+ angle=frame_heading,
+ )
+ + frame_state.velocity
+ )
+ if "xdd" in attributes and "ydd" in attributes:
+ # transform acceleration
+ new_state.acceleration = (
+ transform_coords_2d_np(
+ state.acceleration,
+ angle=frame_heading,
+ )
+ + frame_state.acceleration
+ )
+ if "c" in attributes and "s" in attributes:
+ new_state.heading_vector = transform_coords_2d_np(
+ state.heading_vector,
+ angle=frame_heading,
+ )
+ if "h" in attributes:
+ new_state.heading = angle_wrap(state.heading + frame_heading)
+
+ return new_state
diff --git a/src/trajdata/utils/vis_utils.py b/src/trajdata/utils/vis_utils.py
new file mode 100644
index 0000000..cf8f146
--- /dev/null
+++ b/src/trajdata/utils/vis_utils.py
@@ -0,0 +1,543 @@
+from collections import defaultdict
+from typing import List, Optional, Tuple
+
+import geopandas as gpd
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from bokeh.models import ColumnDataSource, GlyphRenderer
+from bokeh.plotting import figure
+from shapely.geometry import LineString, Polygon
+
+from trajdata.data_structures.agent import AgentType
+from trajdata.data_structures.batch import AgentBatch
+from trajdata.data_structures.state import StateArray
+from trajdata.maps.vec_map import VectorMap
+from trajdata.maps.vec_map_elements import (
+ MapElementType,
+ PedCrosswalk,
+ PedWalkway,
+ RoadArea,
+ RoadLane,
+)
+from trajdata.utils.arr_utils import transform_coords_2d_np
+
+
+def apply_default_settings(fig: figure) -> None:
+ # Pixel dimensions match data dimensions,
+ # a 1x1 area in data space is a square in pixels.
+ fig.match_aspect = True
+
+ # No gridlines.
+ fig.grid.visible = False
+
+ # Setting the scroll wheel to active by default.
+ fig.toolbar.active_scroll = fig.tools[1]
+
+ # Set autohide to true to only show the toolbar when mouse is over plot.
+ fig.toolbar.autohide = True
+
+ # Setting the match_aspect property of bokeh's default BoxZoomTool.
+ fig.tools[2].match_aspect = True
+
+ fig.xaxis.axis_label_text_font_size = "10pt"
+ fig.xaxis.major_label_text_font_size = "10pt"
+
+ fig.yaxis.axis_label_text_font_size = "10pt"
+ fig.yaxis.major_label_text_font_size = "10pt"
+
+ fig.title.text_font_size = "13pt"
+
+
+def calculate_figure_sizes(
+ data_bbox: Tuple[float, float, float, float],
+ data_margin: float = 10,
+ aspect_ratio: float = 16 / 9,
+) -> Tuple[float, float, float, float]:
+ """Computes a visualization bounding box with the desired aspect ratio and margin.
+
+ Args:
+ data_bbox (Tuple[float, float, float, float]): x_min, x_max, y_min, y_max (in data units)
+ data_margin (float, optional): Extra margin (in data units) kept clear around the data. Defaults to 10.
+ aspect_ratio (float, optional): Desired width/height ratio of the view. Defaults to 16/9.
+ + Returns: + Tuple[float, float, float, float]: Visualization x_min, x_max, y_min, y_max (in data units) matching the desired aspect ratio and clear margin around data points. + """ + x_min, x_max, y_min, y_max = data_bbox + + x_range = x_max - x_min + x_center = (x_min + x_max) / 2 + + y_range = y_max - y_min + y_center = (y_min + y_max) / 2 + + radius = (x_range / 2 if x_range > y_range else y_range / 2) + data_margin + return ( + x_center - radius, + x_center + radius, + y_center - radius / aspect_ratio, + y_center + radius / aspect_ratio, + ) + + +def pretty_print_agent_type(agent_type: AgentType): + return str(agent_type)[len("AgentType.") :].capitalize() + + +def agent_type_to_str(agent_type_int: int) -> str: + return pretty_print_agent_type(AgentType(agent_type_int)) + + +def get_agent_type_color(agent_type: AgentType) -> str: + palette = sns.color_palette("husl", 4).as_hex() + if agent_type == AgentType.VEHICLE: + return palette[0] + elif agent_type == AgentType.PEDESTRIAN: + return "darkorange" + elif agent_type == AgentType.BICYCLE: + return palette[2] + elif agent_type == AgentType.MOTORCYCLE: + return palette[3] + else: + return "#A9A9A9" + + +def get_map_patch_color(map_elem_type: MapElementType) -> str: + if map_elem_type == MapElementType.ROAD_AREA: + return "gray" + elif map_elem_type == MapElementType.ROAD_LANE: + return "red" + elif map_elem_type == MapElementType.PED_CROSSWALK: + return "gold" # "blue" + elif map_elem_type == MapElementType.PED_WALKWAY: + return "green" + else: + raise ValueError() + + +def get_multi_line_bbox( + lines_data: ColumnDataSource, +) -> Tuple[float, float, float, float]: + """_summary_ + + Args: + lines_data (ColumnDataSource): _description_ + + Returns: + Tuple[float, float, float, float]: x_min, x_max, y_min, y_max + """ + all_xs = np.concatenate(lines_data.data["xs"], axis=0) + all_ys = np.concatenate(lines_data.data["ys"], axis=0) + all_xy = np.stack((all_xs, all_ys), axis=1) + x_min, y_min = np.nanmin(all_xy, axis=0) + x_max, y_max = np.nanmax(all_xy, axis=0) + return ( + x_min.item(), + x_max.item(), + y_min.item(), + y_max.item(), + ) + + +def compute_agent_rect_coords( + agent_type: int, heading: float, length: float, width: float +) -> Tuple[np.ndarray, np.ndarray]: + agent_rect_coords = transform_coords_2d_np( + np.array( + [ + [-length / 2, -width / 2], + [-length / 2, width / 2], + [length / 2, width / 2], + [length / 2, -width / 2], + ] + ), + angle=heading, + ) + + size = 1.0 + if agent_type == AgentType.PEDESTRIAN or agent_type == AgentType.BICYCLE: + size = 0.25 + + dir_patch_coords = transform_coords_2d_np( + np.array( + [ + [0, np.sqrt(3) / 3], + [-1 / 2, -np.sqrt(3) / 6], + [1 / 2, -np.sqrt(3) / 6], + ] + ) + * size, + angle=heading - np.pi / 2, + ) + + return agent_rect_coords, dir_patch_coords + + +def compute_agent_rects_coords( + agent_type: int, hs: np.ndarray, lengths: np.ndarray, widths: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + raw_rect_coords = np.stack( + ( + np.stack((-lengths / 2, -widths / 2), axis=-1), + np.stack((-lengths / 2, widths / 2), axis=-1), + np.stack((lengths / 2, widths / 2), axis=-1), + np.stack((lengths / 2, -widths / 2), axis=-1), + ), + axis=-2, + ) + + agent_rect_coords = transform_coords_2d_np( + raw_rect_coords, + angle=hs[:, None].repeat(raw_rect_coords.shape[-2], axis=-1), + ) + + size = 1.0 + if agent_type == AgentType.PEDESTRIAN or agent_type == AgentType.BICYCLE: + size = 0.25 + + raw_tri_coords = size * np.array( + [ + [ + [0, np.sqrt(3) / 3], + [-1 / 2, -np.sqrt(3) / 6], + 
[1 / 2, -np.sqrt(3) / 6], + ] + ] + ).repeat(hs.shape[0], axis=0) + + dir_patch_coords = transform_coords_2d_np( + raw_tri_coords, + angle=hs[:, None].repeat(raw_tri_coords.shape[-2], axis=-1) - np.pi / 2, + ) + + return agent_rect_coords, dir_patch_coords + + +def extract_full_agent_data_df(batch: AgentBatch, batch_idx: int) -> pd.DataFrame: + main_data_dict = defaultdict(list) + + # Historical information + ## Agent + H = batch.agent_hist_len[batch_idx].item() + agent_type = batch.agent_type[batch_idx].item() + agent_extent: np.ndarray = batch.agent_hist_extent[batch_idx, -H:].cpu().numpy() + agent_hist_np: StateArray = batch.agent_hist[batch_idx, -H:].cpu().numpy() + + speed_mps = np.linalg.norm(agent_hist_np.velocity, axis=1) + + xs = agent_hist_np.get_attr("x") + ys = agent_hist_np.get_attr("y") + hs = agent_hist_np.get_attr("h") + + lengths = agent_extent[:, 0] + widths = agent_extent[:, 1] + + agent_rect_coords, dir_patch_coords = compute_agent_rects_coords( + agent_type, hs, lengths, widths + ) + + main_data_dict["id"].extend([0] * H) + main_data_dict["t"].extend(range(-H + 1, 1)) + main_data_dict["x"].extend(xs) + main_data_dict["y"].extend(ys) + main_data_dict["h"].extend(hs) + main_data_dict["rect_xs"].extend(agent_rect_coords[..., 0] + xs[:, None]) + main_data_dict["rect_ys"].extend(agent_rect_coords[..., 1] + ys[:, None]) + main_data_dict["dir_patch_xs"].extend(dir_patch_coords[..., 0] + xs[:, None]) + main_data_dict["dir_patch_ys"].extend(dir_patch_coords[..., 1] + ys[:, None]) + main_data_dict["speed_mps"].extend(speed_mps) + main_data_dict["speed_kph"].extend(speed_mps * 3.6) + main_data_dict["type"].extend([agent_type_to_str(agent_type)] * H) + main_data_dict["length"].extend(lengths) + main_data_dict["width"].extend(widths) + main_data_dict["pred_agent"].extend([True] * H) + main_data_dict["color"].extend([get_agent_type_color(agent_type)] * H) + + ## Neighbors + num_neighbors: int = batch.num_neigh[batch_idx].item() + + for n_neigh in range(num_neighbors): + H = batch.neigh_hist_len[batch_idx, n_neigh].item() + agent_type = batch.neigh_types[batch_idx, n_neigh].item() + agent_extent: np.ndarray = ( + batch.neigh_hist_extents[batch_idx, n_neigh, -H:].cpu().numpy() + ) + agent_hist_np: StateArray = ( + batch.neigh_hist[batch_idx, n_neigh, -H:].cpu().numpy() + ) + + speed_mps = np.linalg.norm(agent_hist_np.velocity, axis=1) + + xs = agent_hist_np.get_attr("x") + ys = agent_hist_np.get_attr("y") + hs = agent_hist_np.get_attr("h") + + lengths = agent_extent[:, 0] + widths = agent_extent[:, 1] + + agent_rect_coords, dir_patch_coords = compute_agent_rects_coords( + agent_type, hs, lengths, widths + ) + + main_data_dict["id"].extend([n_neigh + 1] * H) + main_data_dict["t"].extend(range(-H + 1, 1)) + main_data_dict["x"].extend(xs) + main_data_dict["y"].extend(ys) + main_data_dict["h"].extend(hs) + main_data_dict["rect_xs"].extend(agent_rect_coords[..., 0] + xs[:, None]) + main_data_dict["rect_ys"].extend(agent_rect_coords[..., 1] + ys[:, None]) + main_data_dict["dir_patch_xs"].extend(dir_patch_coords[..., 0] + xs[:, None]) + main_data_dict["dir_patch_ys"].extend(dir_patch_coords[..., 1] + ys[:, None]) + main_data_dict["speed_mps"].extend(speed_mps) + main_data_dict["speed_kph"].extend(speed_mps * 3.6) + main_data_dict["type"].extend([agent_type_to_str(agent_type)] * H) + main_data_dict["length"].extend(lengths) + main_data_dict["width"].extend(widths) + main_data_dict["pred_agent"].extend([False] * H) + main_data_dict["color"].extend([get_agent_type_color(agent_type)] * H) + + # 
Future information + ## Agent + T = batch.agent_fut_len[batch_idx].item() + agent_type = batch.agent_type[batch_idx].item() + agent_extent: np.ndarray = batch.agent_fut_extent[batch_idx, :T].cpu().numpy() + agent_fut_np: StateArray = batch.agent_fut[batch_idx, :T].cpu().numpy() + + speed_mps = np.linalg.norm(agent_fut_np.velocity, axis=1) + + xs = agent_fut_np.get_attr("x") + ys = agent_fut_np.get_attr("y") + hs = agent_fut_np.get_attr("h") + + lengths = agent_extent[:, 0] + widths = agent_extent[:, 1] + + agent_rect_coords, dir_patch_coords = compute_agent_rects_coords( + agent_type, hs, lengths, widths + ) + + main_data_dict["id"].extend([0] * T) + main_data_dict["t"].extend(range(1, T + 1)) + main_data_dict["x"].extend(xs) + main_data_dict["y"].extend(ys) + main_data_dict["h"].extend(hs) + main_data_dict["rect_xs"].extend(agent_rect_coords[..., 0] + xs[:, None]) + main_data_dict["rect_ys"].extend(agent_rect_coords[..., 1] + ys[:, None]) + main_data_dict["dir_patch_xs"].extend(dir_patch_coords[..., 0] + xs[:, None]) + main_data_dict["dir_patch_ys"].extend(dir_patch_coords[..., 1] + ys[:, None]) + main_data_dict["speed_mps"].extend(speed_mps) + main_data_dict["speed_kph"].extend(speed_mps * 3.6) + main_data_dict["type"].extend([agent_type_to_str(agent_type)] * T) + main_data_dict["length"].extend(lengths) + main_data_dict["width"].extend(widths) + main_data_dict["pred_agent"].extend([True] * T) + main_data_dict["color"].extend([get_agent_type_color(agent_type)] * T) + + ## Neighbors + num_neighbors: int = batch.num_neigh[batch_idx].item() + + for n_neigh in range(num_neighbors): + T = batch.neigh_fut_len[batch_idx, n_neigh].item() + agent_type = batch.neigh_types[batch_idx, n_neigh].item() + agent_extent: np.ndarray = ( + batch.neigh_fut_extents[batch_idx, n_neigh, :T].cpu().numpy() + ) + agent_fut_np: StateArray = batch.neigh_fut[batch_idx, n_neigh, :T].cpu().numpy() + + speed_mps = np.linalg.norm(agent_fut_np.velocity, axis=1) + + xs = agent_fut_np.get_attr("x") + ys = agent_fut_np.get_attr("y") + hs = agent_fut_np.get_attr("h") + + lengths = agent_extent[:, 0] + widths = agent_extent[:, 1] + + agent_rect_coords, dir_patch_coords = compute_agent_rects_coords( + agent_type, hs, lengths, widths + ) + + main_data_dict["id"].extend([n_neigh + 1] * T) + main_data_dict["t"].extend(range(1, T + 1)) + main_data_dict["x"].extend(xs) + main_data_dict["y"].extend(ys) + main_data_dict["h"].extend(hs) + main_data_dict["rect_xs"].extend(agent_rect_coords[..., 0] + xs[:, None]) + main_data_dict["rect_ys"].extend(agent_rect_coords[..., 1] + ys[:, None]) + main_data_dict["dir_patch_xs"].extend(dir_patch_coords[..., 0] + xs[:, None]) + main_data_dict["dir_patch_ys"].extend(dir_patch_coords[..., 1] + ys[:, None]) + main_data_dict["speed_mps"].extend(speed_mps) + main_data_dict["speed_kph"].extend(speed_mps * 3.6) + main_data_dict["type"].extend([agent_type_to_str(agent_type)] * T) + main_data_dict["length"].extend(lengths) + main_data_dict["width"].extend(widths) + main_data_dict["pred_agent"].extend([False] * T) + main_data_dict["color"].extend([get_agent_type_color(agent_type)] * T) + + return pd.DataFrame(main_data_dict) + + +def convert_to_gpd(vec_map: VectorMap) -> gpd.GeoDataFrame: + geo_data = defaultdict(list) + for elem in vec_map.iter_elems(): + geo_data["id"].append(elem.id) + geo_data["type"].append(elem.elem_type) + if isinstance(elem, RoadLane): + geo_data["geometry"].append(LineString(elem.center.xyz)) + elif isinstance(elem, PedCrosswalk) or isinstance(elem, PedWalkway): + 
geo_data["geometry"].append(Polygon(shell=elem.polygon.xyz)) + elif isinstance(elem, RoadArea): + geo_data["geometry"].append( + Polygon( + shell=elem.exterior_polygon.xyz, + holes=[hole.xyz for hole in elem.interior_holes], + ) + ) + + return gpd.GeoDataFrame(geo_data) + + +def get_map_cds( + map_from_world_tf: np.ndarray, + vec_map: VectorMap, + bbox: Optional[Tuple[float, float, float, float]] = None, +) -> Tuple[ + ColumnDataSource, + ColumnDataSource, + ColumnDataSource, + ColumnDataSource, + ColumnDataSource, +]: + road_lane_data = defaultdict(list) + lane_center_data = defaultdict(list) + ped_crosswalk_data = defaultdict(list) + ped_walkway_data = defaultdict(list) + road_area_data = defaultdict(list) + + map_gpd = convert_to_gpd(vec_map) + affine_tf_params = ( + map_from_world_tf[:2, :2].flatten().tolist() + + map_from_world_tf[:2, -1].flatten().tolist() + ) + map_gpd["geometry"] = map_gpd["geometry"].affine_transform(affine_tf_params) + + elems_gdf: gpd.GeoDataFrame + if bbox is not None: + elems_gdf = map_gpd.cx[bbox[0] : bbox[1], bbox[2] : bbox[3]] + else: + elems_gdf = map_gpd + + for row_idx, row in elems_gdf.iterrows(): + if row["type"] == MapElementType.PED_CROSSWALK: + xy = np.stack(row["geometry"].exterior.xy, axis=1) + ped_crosswalk_data["xs"].append(xy[..., 0]) + ped_crosswalk_data["ys"].append(xy[..., 1]) + if row["type"] == MapElementType.PED_WALKWAY: + xy = np.stack(row["geometry"].exterior.xy, axis=1) + ped_walkway_data["xs"].append(xy[..., 0]) + ped_walkway_data["ys"].append(xy[..., 1]) + elif row["type"] == MapElementType.ROAD_LANE: + xy = np.stack(row["geometry"].xy, axis=1) + lane_center_data["xs"].append(xy[..., 0]) + lane_center_data["ys"].append(xy[..., 1]) + lane_obj: RoadLane = vec_map.elements[MapElementType.ROAD_LANE][row["id"]] + if lane_obj.left_edge is not None and lane_obj.right_edge is not None: + left_xy = lane_obj.left_edge.xy + right_xy = lane_obj.right_edge.xy[::-1] + patch_xy = np.concatenate((left_xy, right_xy), axis=0) + + transformed_xy: np.ndarray = transform_coords_2d_np( + patch_xy, + offset=map_from_world_tf[:2, -1], + rot_mat=map_from_world_tf[:2, :2], + ) + + road_lane_data["xs"].append(transformed_xy[..., 0]) + road_lane_data["ys"].append(transformed_xy[..., 1]) + elif row["type"] == MapElementType.ROAD_AREA: + xy = np.stack(row["geometry"].exterior.xy, axis=1) + holes_xy: List[np.ndarray] = [ + np.stack(interior.xy, axis=1) for interior in row["geometry"].interiors + ] + + road_area_data["xs"].append( + [[xy[..., 0]] + [hole[..., 0] for hole in holes_xy]] + ) + road_area_data["ys"].append( + [[xy[..., 1]] + [hole[..., 1] for hole in holes_xy]] + ) + return ( + ColumnDataSource(data=lane_center_data), + ColumnDataSource(data=road_lane_data), + ColumnDataSource(data=ped_crosswalk_data), + ColumnDataSource(data=ped_walkway_data), + ColumnDataSource(data=road_area_data), + ) + + +def draw_map_elems( + fig: figure, + vec_map: VectorMap, + map_from_world_tf: np.ndarray, + bbox: Optional[Tuple[float, float, float, float]] = None, + **kwargs, +) -> Tuple[GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer]: + """_summary_ + + Args: + fig (Figure): _description_ + vec_map (VectorMap): _description_ + map_from_world_tf (np.ndarray): _description_ + bbox (Tuple[float, float, float, float]): x_min, x_max, y_min, y_max + + Returns: + Tuple[GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer]: _description_ + """ + ( + lane_center_cds, + road_lane_cds, + ped_crosswalk_cds, + ped_walkway_cds, + 
road_area_cds, + ) = get_map_cds(map_from_world_tf, vec_map, bbox) + + road_areas = fig.multi_polygons( + source=road_area_cds, + line_color="black", + line_width=0.3, + fill_alpha=0.1, + fill_color=get_map_patch_color(MapElementType.ROAD_AREA), + ) + + road_lanes = fig.patches( + source=road_lane_cds, + line_color="black", + line_width=0.3, + fill_alpha=0.1, + fill_color=get_map_patch_color(MapElementType.ROAD_LANE), + ) + + ped_crosswalks = fig.patches( + source=ped_crosswalk_cds, + line_color="black", + line_width=0.3, + fill_alpha=0.5, + fill_color=get_map_patch_color(MapElementType.PED_CROSSWALK), + ) + + ped_walkways = fig.patches( + source=ped_walkway_cds, + line_color="black", + line_width=0.3, + fill_alpha=0.3, + fill_color=get_map_patch_color(MapElementType.PED_WALKWAY), + ) + + lane_centers = fig.multi_line( + source=lane_center_cds, + line_color="gray", + line_alpha=0.5, + ) + + return road_areas, road_lanes, ped_crosswalks, ped_walkways, lane_centers diff --git a/src/trajdata/visualization/__init__.py b/src/trajdata/visualization/__init__.py index 5951e91..4b3301c 100644 --- a/src/trajdata/visualization/__init__.py +++ b/src/trajdata/visualization/__init__.py @@ -1 +1,3 @@ +from .interactive_figure import InteractiveFigure +from .interactive_vis import plot_agent_batch_interactive from .vis import plot_agent_batch, plot_scene_batch diff --git a/src/trajdata/visualization/interactive_animation.py b/src/trajdata/visualization/interactive_animation.py new file mode 100644 index 0000000..bf71b2c --- /dev/null +++ b/src/trajdata/visualization/interactive_animation.py @@ -0,0 +1,534 @@ +import logging +import socket +import threading +import time +import warnings +from collections import defaultdict +from contextlib import closing +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional + +import cv2 +import numpy as np +import pandas as pd +from bokeh.application import Application +from bokeh.application.handlers import FunctionHandler +from bokeh.document import Document, without_document_lock +from bokeh.io.export import get_screenshot_as_png +from bokeh.layouts import column, row +from bokeh.models import ( + BooleanFilter, + Button, + CDSView, + ColumnDataSource, + HoverTool, + Legend, + LegendItem, + RangeSlider, + Select, + Slider, +) +from bokeh.plotting import figure +from bokeh.server.server import Server +from selenium import webdriver +from tornado import gen +from tornado.ioloop import IOLoop +from tqdm import trange + +from trajdata.data_structures.agent import AgentType +from trajdata.data_structures.batch import AgentBatch +from trajdata.data_structures.state import StateArray +from trajdata.maps.map_api import MapAPI +from trajdata.utils import vis_utils + + +class InteractiveAnimation: + def __init__( + self, + main_func: Callable[[Document, IOLoop], None], + port: Optional[int] = None, + **kwargs, + ) -> None: + self.main_func = main_func + self.port = port + self.kwargs = kwargs + + def get_open_port(self) -> int: + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + def show(self) -> None: + io_loop = IOLoop() + + if self.port is None: + self.port = self.get_open_port() + + def kill_on_tab_close(session_context): + io_loop.stop() + + def app_init(doc: Document): + doc.on_session_destroyed(kill_on_tab_close) + self.main_func(doc=doc, io_loop=io_loop, **self.kwargs) + return doc + + 
server = Server( + {"/": Application(FunctionHandler(app_init))}, + io_loop=io_loop, + port=self.port, + check_unused_sessions_milliseconds=500, + unused_session_lifetime_milliseconds=500, + ) + server.start() + + # print(f"Opening Bokeh application on http://localhost:{self.port}/") + server.io_loop.add_callback(server.show, "/") + server.io_loop.start() + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning) + server.io_loop.close() + + +def animate_agent_batch_interactive( + doc: Document, io_loop: IOLoop, batch: AgentBatch, batch_idx: int, cache_path: Path +) -> None: + agent_data_df = vis_utils.extract_full_agent_data_df(batch, batch_idx) + + # Figure creation and a few initial settings. + width: int = 1280 + aspect_ratio: float = 16 / 9 + data_vis_margin: float = 10.0 + + x_min = agent_data_df["x"].min() + x_max = agent_data_df["x"].max() + + y_min = agent_data_df["y"].min() + y_max = agent_data_df["y"].max() + + ( + x_range_min, + x_range_max, + y_range_min, + y_range_max, + ) = vis_utils.calculate_figure_sizes( + data_bbox=(x_min, x_max, y_min, y_max), + data_margin=data_vis_margin, + aspect_ratio=aspect_ratio, + ) + + kwargs = { + "x_range": (x_range_min, x_range_max), + "y_range": (y_range_min, y_range_max), + } + + fig = figure( + width=width, + height=int(width / aspect_ratio), + output_backend="canvas", + **kwargs, + ) + vis_utils.apply_default_settings(fig) + + agent_name: str = batch.agent_name[batch_idx] + agent_type: AgentType = AgentType(batch.agent_type[batch_idx].item()) + current_state: StateArray = batch.curr_agent_state[batch_idx].cpu().numpy() + map_id: str = batch.map_names[batch_idx] + env_name, map_name = map_id.split(":") + scene_id: str = batch.scene_ids[batch_idx] + fig.title = ( + f"Dataset: {env_name}, Location: {map_name}, Scene: {scene_id}" + + "\n" + + f"Agent ID: {agent_name} ({vis_utils.pretty_print_agent_type(agent_type)}) at x = {current_state[0]:.2f} m, y = {current_state[1]:.2f} m, heading = {current_state[-1]:.2f} rad ({np.rad2deg(current_state[-1]):.2f} deg)" + ) + + # Map plotting. + if batch.map_names is not None: + mapAPI = MapAPI(cache_path) + + vec_map = mapAPI.get_map( + batch.map_names[batch_idx], + incl_road_lanes=True, + incl_road_areas=True, + incl_ped_crosswalks=True, + incl_ped_walkways=True, + ) + + ( + road_areas, + road_lanes, + ped_crosswalks, + ped_walkways, + lane_centers, + ) = vis_utils.draw_map_elems( + fig, + vec_map, + batch.agents_from_world_tf[batch_idx].cpu().numpy(), + bbox=( + x_min - data_vis_margin, + x_max + data_vis_margin, + y_min - data_vis_margin, + y_max + data_vis_margin, + ), + ) + + # Preparing agent information for fast slicing with the time_slider. + agent_cds = ColumnDataSource(agent_data_df) + curr_time_view = CDSView(filter=BooleanFilter((agent_cds.data["t"] == 0).tolist())) + + # Some neighbors can have more history than the agent to be predicted + # (the data-collecting agent has observed the neighbors for longer). 
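+ # full_H/full_T are the longest history/future lengths over the predicted
+ # agent and all of its neighbors; shorter trajectories are NaN-padded to
+ # these common lengths below.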
+ full_H = max(
+ batch.agent_hist_len[batch_idx].item(),
+ *batch.neigh_hist_len[batch_idx].tolist(),
+ )
+ full_T = max(
+ batch.agent_fut_len[batch_idx].item(), *batch.neigh_fut_len[batch_idx].tolist()
+ )
+
+ def create_multi_line_data(agents_df: pd.DataFrame) -> Dict[str, List]:
+ lines_data = defaultdict(list)
+ for agent_id, agent_df in agents_df.groupby(by="id"):
+ xs, ys, color = (
+ agent_df.x.to_numpy(),
+ agent_df.y.to_numpy(),
+ agent_df.color.iat[0],
+ )
+
+ if agent_id == 0:
+ pad_before = full_H - batch.agent_hist_len[batch_idx].item()
+ pad_after = full_T - batch.agent_fut_len[batch_idx].item()
+
+ else:
+ pad_before = (
+ full_H - batch.neigh_hist_len[batch_idx, agent_id - 1].item()
+ )
+ pad_after = full_T - batch.neigh_fut_len[batch_idx, agent_id - 1].item()
+
+ xs = np.pad(xs, (pad_before, pad_after), constant_values=np.nan)
+ ys = np.pad(ys, (pad_before, pad_after), constant_values=np.nan)
+
+ lines_data["xs"].append(xs)
+ lines_data["ys"].append(ys)
+ lines_data["color"].append(color)
+
+ return lines_data
+
+ def slice_multi_line_data(
+ multi_line_df: Dict[str, Any], slice_obj, check_idx: int
+ ) -> Dict[str, List]:
+ lines_data = defaultdict(list)
+ for i in range(len(multi_line_df["xs"])):
+ sliced_xs = multi_line_df["xs"][i][slice_obj]
+ sliced_ys = multi_line_df["ys"][i][slice_obj]
+ if (
+ sliced_xs.shape[0] > 0
+ and sliced_ys.shape[0] > 0
+ and np.isfinite(sliced_xs[check_idx])
+ and np.isfinite(sliced_ys[check_idx])
+ ):
+ lines_data["xs"].append(sliced_xs)
+ lines_data["ys"].append(sliced_ys)
+ lines_data["color"].append(multi_line_df["color"][i])
+
+ return lines_data
+
+ # Getting initial historical and future trajectory information ready for plotting.
+ history_line_data_df = create_multi_line_data(agent_data_df)
+ history_lines_cds = ColumnDataSource(
+ slice_multi_line_data(history_line_data_df, slice(None, full_H), check_idx=-1)
+ )
+ future_line_data_df = history_line_data_df.copy()
+ future_lines_cds = ColumnDataSource(
+ slice_multi_line_data(future_line_data_df, slice(full_H, None), check_idx=0)
+ )
+
+ history_lines = fig.multi_line(
+ xs="xs",
+ ys="ys",
+ line_color="color",
+ line_dash="dashed",
+ line_width=2,
+ source=history_lines_cds,
+ )
+
+ future_lines = fig.multi_line(
+ xs="xs",
+ ys="ys",
+ line_color="color",
+ line_dash="solid",
+ line_width=2,
+ source=future_lines_cds,
+ )
+
+ # Agent rectangles/directional arrows at the current timestep.
+ agent_rects = fig.patches(
+ xs="rect_xs",
+ ys="rect_ys",
+ fill_color="color",
+ line_color="black",
+ # fill_alpha=0.7,
+ source=agent_cds,
+ view=curr_time_view,
+ )
+
+ agent_dir_patches = fig.patches(
+ xs="dir_patch_xs",
+ ys="dir_patch_ys",
+ fill_color="color",
+ line_color="black",
+ # fill_alpha=0.7,
+ source=agent_cds,
+ view=curr_time_view,
+ )
+
+ scene_ts: int = batch.scene_ts[batch_idx].item()
+
+ # Controlling the timestep shown to users.
+ time_slider = Slider(
+ start=agent_cds.data["t"].min(),
+ end=agent_cds.data["t"].max(),
+ step=1,
+ value=0,
+ title=f"Current Timestep (scene timestep {scene_ts})",
+ )
+
+ dt: float = batch.dt[batch_idx].item()
+
+ # Ensuring that information gets updated upon a change in the slider value.
+ def time_callback(attr, old, new) -> None: + curr_time_view.filter = BooleanFilter((agent_cds.data["t"] == new).tolist()) + history_lines_cds.data = slice_multi_line_data( + history_line_data_df, slice(None, new + full_H), check_idx=-1 + ) + future_lines_cds.data = slice_multi_line_data( + future_line_data_df, slice(new + full_H, None), check_idx=0 + ) + + if new == 0: + time_slider.title = f"Current Timestep (scene timestep {scene_ts})" + else: + n_steps = abs(new) + time_slider.title = f"{n_steps} timesteps ({n_steps * dt:.2f} s) into the {'future' if new > 0 else 'past'}" + + time_slider.on_change("value", time_callback) + + # Adding tooltips on mouse hover. + fig.add_tools( + HoverTool( + tooltips=[ + ("Class", "@type"), + ("Position", "(@x, @y) m"), + ("Speed", "@speed_mps m/s (@speed_kph km/h)"), + ], + renderers=[agent_rects], + ) + ) + + def button_callback(): + # Stop the server. + io_loop.stop() + + exit_button = Button(label="Exit", button_type="danger", width=60) + exit_button.on_click(button_callback) + + # Writing animation callback functions so that the play/pause button animate the + # data according to its native dt. + def animate_update(): + t = time_slider.value + 1 + + if t > time_slider.end: + # If slider value + 1 is above max, reset to 0. + t = 0 + + time_slider.value = t + + play_cb_manager = [None] + + def animate(): + if play_button.label.startswith("►"): + play_button.label = "❚❚ Pause" + + play_cb_manager[0] = doc.add_periodic_callback( + animate_update, period_milliseconds=int(dt * 1000) + ) + else: + play_button.label = "► Play" + doc.remove_periodic_callback(play_cb_manager[0]) + + play_button = Button(label="► Play", width=100) + play_button.on_click(animate) + + # Creating the legend elements and connecting them to their original elements + # (allows us to hide them on click later!) + agent_legend_elems = [ + fig.rect( + fill_color=vis_utils.get_agent_type_color(x), + line_color="black", + name=vis_utils.agent_type_to_str(x), + ) + for x in AgentType + ] + + map_legend_elems = [LegendItem(label="Lane Center", renderers=[lane_centers])] + + map_area_legend_elems = [ + LegendItem(label="Road Area", renderers=[road_areas]), + LegendItem(label="Road Lanes", renderers=[road_lanes]), + LegendItem(label="Crosswalks", renderers=[ped_crosswalks]), + LegendItem(label="Sidewalks", renderers=[ped_walkways]), + ] + + hist_future_legend_elems = [ + LegendItem( + label="Past Motion", + renderers=[ + history_lines, + fig.multi_line( + line_color="black", line_dash="dashed", line_alpha=1.0, line_width=2 + ), + ], + ), + LegendItem( + label="Future Motion", + renderers=[ + future_lines, + fig.multi_line( + line_color="black", line_dash="solid", line_alpha=1.0, line_width=2 + ), + ], + ), + ] + + # Adding the legend to the figure. + legend = Legend( + items=[ + LegendItem(label=legend_item.name, renderers=[legend_item]) + for legend_item in agent_legend_elems + ] + + hist_future_legend_elems + + map_legend_elems + + map_area_legend_elems, + click_policy="hide", + label_text_font_size="15pt", + spacing=10, + ) + fig.add_layout(legend, "right") + + # Video rendering functions. 
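+ # Overall flow (as implemented below): save_animation disables the UI and
+ # spawns a worker thread; execute_save_animation then steps through the
+ # requested frames, screenshotting the figure with Selenium via
+ # get_screenshot_as_png, and finally encodes them with OpenCV's VideoWriter.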
+ video_button = Button( + label="Render Video", + width=120, + ) + + render_range_slider = RangeSlider( + value=(0, time_slider.end), + start=time_slider.start, + end=time_slider.end, + title=f"Timesteps to Render", + ) + + filetype_select = Select( + title="Filetype:", value=".mp4", options=[".mp4", ".avi"], width=80 + ) + + def reset_buttons() -> None: + video_button.label = "Render Video" + video_button.disabled = False + time_slider.disabled = False + render_range_slider.disabled = False + filetype_select.disabled = False + play_button.disabled = False + fig.toolbar_location = "right" + + logging.basicConfig(level=logging.WARNING, force=True) + + def after_frame_save(label: str) -> None: + video_button.label = label + animate_update() + + def execute_save_animation(file_path: Path) -> None: + images = [] + + chrome_options = webdriver.ChromeOptions() + chrome_options.headless = True + driver = webdriver.Chrome(chrome_options=chrome_options) + + n_frames = render_range_slider.value[1] - render_range_slider.value[0] + 1 + for frame_index in trange(n_frames, desc="Rendering Video"): + # Giving the doc a chance to update the figure. + time.sleep(0.1) + + image = get_screenshot_as_png(fig, driver=driver) + shape = image.size + images.append(image) + + doc.add_next_tick_callback( + partial( + after_frame_save, + label=f"Rendering... ({100*(frame_index+1)/n_frames:.0f}%)", + ) + ) + + if file_path.suffix == ".mp4": + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + elif file_path.suffix == ".avi": + fourcc = cv2.VideoWriter_fourcc("M", "J", "P", "G") + + video_obj = cv2.VideoWriter( + filename=str(file_path), fourcc=fourcc, fps=1.0 / dt, frameSize=shape + ) + for image in images: + cv2_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + video_obj.write(cv2_image) + video_obj.release() + + doc.add_next_tick_callback(reset_buttons) + + @gen.coroutine + @without_document_lock + def save_animation(filename: str) -> None: + video_button.label = "Rendering..." + video_button.disabled = True + time_slider.disabled = True + render_range_slider.disabled = True + filetype_select.disabled = True + play_button.disabled = True + fig.toolbar_location = None + + # Bokeh logs a lot of warnings here related to some figure elements not having + # 'x', 'y', etc attributes set (most of these are legend items for which this + # is intentional). Accordingly, ignore WARNINGs and ERRORs now and re-enable + # them after. + logging.basicConfig(level=logging.CRITICAL, force=True) + + # Stop any ongoing animation. + if play_button.label.startswith("❚❚"): + animate() + + # Reset the current timestep to the left end of the range. 
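+ # Each subsequent frame is then produced by after_frame_save() calling
+ # animate_update(), advancing the slider one step per captured screenshot.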
+ time_slider.value = render_range_slider.value[0] + + threading.Thread( + target=execute_save_animation, + args=(Path(filename + filetype_select.value),), + ).start() + + video_button.on_click( + partial( + save_animation, + filename=( + "_".join([env_name, map_name, scene_id, f"t{scene_ts}", agent_name]) + ), + ) + ) + + doc.add_root( + column( + fig, + row(play_button, time_slider, exit_button), + row(video_button, render_range_slider, filetype_select), + ) + ) diff --git a/src/trajdata/visualization/interactive_figure.py b/src/trajdata/visualization/interactive_figure.py new file mode 100644 index 0000000..b61ed59 --- /dev/null +++ b/src/trajdata/visualization/interactive_figure.py @@ -0,0 +1,169 @@ +from typing import Optional, Tuple + +import bokeh.plotting as plt +import numpy as np +import torch +from bokeh.models import ColumnDataSource, Range1d +from bokeh.models.renderers import GlyphRenderer +from bokeh.plotting import figure +from torch import Tensor + +from trajdata.data_structures.agent import AgentType +from trajdata.data_structures.state import StateTensor +from trajdata.maps import VectorMap +from trajdata.utils import vis_utils + + +class InteractiveFigure: + def __init__(self, **kwargs) -> None: + self.aspect_ratio: float = kwargs.get("aspect_ratio", 16 / 9) + self.width: int = kwargs.get("width", 1280) + self.height: int = kwargs.get("height", int(self.width / self.aspect_ratio)) + + # We'll be tracking the maxes and mins of data with these. + self.x_min = np.inf + self.x_max = -np.inf + self.y_min = np.inf + self.y_max = -np.inf + + self.raw_figure = figure(width=self.width, height=self.height, **kwargs) + vis_utils.apply_default_settings(self.raw_figure) + + def update_mins_maxs(self, x_min, x_max, y_min, y_max) -> None: + self.x_min = min(self.x_min, x_min) + self.x_max = max(self.x_max, x_max) + self.y_min = min(self.y_min, y_min) + self.y_max = max(self.y_max, y_max) + + def show(self) -> None: + if np.isfinite((self.x_min, self.x_max, self.y_min, self.y_max)).all(): + ( + x_range_min, + x_range_max, + y_range_min, + y_range_max, + ) = vis_utils.calculate_figure_sizes( + data_bbox=(self.x_min, self.x_max, self.y_min, self.y_max), + aspect_ratio=self.aspect_ratio, + ) + + self.raw_figure.x_range = Range1d(x_range_min, x_range_max) + self.raw_figure.y_range = Range1d(y_range_min, y_range_max) + + plt.show(self.raw_figure) + + def add_line(self, states: StateTensor, **kwargs) -> GlyphRenderer: + xy_pos = states.position.cpu().numpy() + + x_min, y_min = np.nanmin(xy_pos, axis=0) + x_max, y_max = np.nanmax(xy_pos, axis=0) + self.update_mins_maxs(x_min.item(), x_max.item(), y_min.item(), y_max.item()) + + return self.raw_figure.line(xy_pos[:, 0], xy_pos[:, 1], **kwargs) + + def add_lines(self, lines_data: ColumnDataSource, **kwargs) -> GlyphRenderer: + self.update_mins_maxs(*vis_utils.get_multi_line_bbox(lines_data)) + return self.raw_figure.multi_line( + source=lines_data, + # This is to ensure that the columns given in the + # ColumnDataSource are respected (e.g., "line_color"). 
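+ # The identity mapping expands to, e.g., xs="xs", line_color="line_color",
+ # so every glyph property is driven by the same-named CDS column.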
+ **{x: x for x in lines_data.column_names}, + **kwargs, + ) + + def add_map( + self, + map_from_world_tf: np.ndarray, + vec_map: VectorMap, + bbox: Optional[Tuple[float, float, float, float]] = None, + **kwargs, + ) -> Tuple[ + GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer + ]: + """_summary_ + + Args: + map_from_world_tf (np.ndarray): _description_ + vec_map (VectorMap): _description_ + bbox (Tuple[float, float, float, float]): x_min, x_max, y_min, y_max + + Returns: + Tuple[ GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer ]: _description_ + """ + return vis_utils.draw_map_elems( + self.raw_figure, vec_map, map_from_world_tf, bbox, **kwargs + ) + + def add_agent( + self, + agent_type: AgentType, + agent_state: StateTensor, + agent_extent: Tensor, + **kwargs, + ) -> Tuple[GlyphRenderer, GlyphRenderer]: + """Draws an agent at the given location, heading, and dimensions. + + Args: + agent_type (AgentType): _description_ + agent_state (Tensor): _description_ + agent_extent (Tensor): _description_ + """ + if torch.any(torch.isnan(agent_extent)): + raise ValueError("Agent extents cannot be NaN!") + + length = agent_extent[0].item() + width = agent_extent[1].item() + + x, y = agent_state.position.cpu().numpy() + heading = agent_state.heading.cpu().numpy() + + agent_rect_coords, dir_patch_coords = vis_utils.compute_agent_rect_coords( + agent_type, heading, length, width + ) + + source = { + "x": agent_rect_coords[:, 0] + x, + "y": agent_rect_coords[:, 1] + y, + "type": [vis_utils.pretty_print_agent_type(agent_type)], + "speed": [torch.linalg.norm(agent_state.velocity).item()], + } + + r = self.raw_figure.patch( + x="x", + y="y", + source=source, + **kwargs, + ) + p = self.raw_figure.patch( + x=dir_patch_coords[:, 0] + x, y=dir_patch_coords[:, 1] + y, **kwargs + ) + + return r, p + + def add_agents( + self, + agent_rects_data: ColumnDataSource, + dir_patches_data: ColumnDataSource, + **kwargs, + ) -> Tuple[GlyphRenderer, GlyphRenderer]: + r = self.raw_figure.patches( + source=agent_rects_data, + # This is to ensure that the columns given in the + # ColumnDataSource are respected (e.g., "line_color"). + xs="xs", + ys="ys", + fill_alpha="fill_alpha", + fill_color="fill_color", + line_color="line_color", + **kwargs, + ) + + p = self.raw_figure.patches( + source=dir_patches_data, + # This is to ensure that the columns given in the + # ColumnDataSource are respected (e.g., "line_color"). 
+ **{x: x for x in dir_patches_data.column_names}, + **kwargs, + ) + + return r, p diff --git a/src/trajdata/visualization/interactive_vis.py b/src/trajdata/visualization/interactive_vis.py new file mode 100644 index 0000000..85a263a --- /dev/null +++ b/src/trajdata/visualization/interactive_vis.py @@ -0,0 +1,195 @@ +from pathlib import Path + +import numpy as np +from bokeh.models import ColumnDataSource + +from trajdata.data_structures.agent import AgentType +from trajdata.data_structures.batch import AgentBatch, SceneBatch +from trajdata.data_structures.state import StateArray, StateTensor +from trajdata.maps.map_api import MapAPI +from trajdata.utils import vis_utils +from trajdata.utils.arr_utils import transform_coords_2d_np +from trajdata.visualization.interactive_figure import InteractiveFigure + + +def plot_agent_batch_interactive(batch: AgentBatch, batch_idx: int, cache_path: Path): + fig = InteractiveFigure( + tooltips=[ + ("Class", "@type"), + ("Position", "(@x, @y) m"), + ("Speed", "@speed_mps m/s (@speed_kph km/h)"), + ] + ) + + agent_type: int = batch.agent_type[batch_idx].item() + num_neighbors: int = batch.num_neigh[batch_idx].item() + agent_hist_np: StateArray = batch.agent_hist[batch_idx].cpu().numpy() + neigh_hist_np: StateArray = batch.neigh_hist[batch_idx].cpu().numpy() + neigh_types = batch.neigh_types[batch_idx].cpu().numpy() + agent_histories = ColumnDataSource( + data={ + "xs": [agent_hist_np.get_attr("x")] + + [ + neigh_hist_np[n_neigh].get_attr("x") for n_neigh in range(num_neighbors) + ], + "ys": [agent_hist_np.get_attr("y")] + + [ + neigh_hist_np[n_neigh].get_attr("y") for n_neigh in range(num_neighbors) + ], + "line_dash": ["dashed"] * (num_neighbors + 1), + "line_color": [vis_utils.get_agent_type_color(agent_type)] + + [ + vis_utils.get_agent_type_color(neigh_types[n_neigh]) + for n_neigh in range(num_neighbors) + ], + } + ) + + agent_fut_np: StateArray = batch.agent_fut[batch_idx].cpu().numpy() + neigh_fut_np: StateArray = batch.neigh_fut[batch_idx].cpu().numpy() + agent_futures = ColumnDataSource( + data={ + "xs": [agent_fut_np.get_attr("x")] + + [neigh_fut_np[n_neigh].get_attr("x") for n_neigh in range(num_neighbors)], + "ys": [agent_fut_np.get_attr("y")] + + [neigh_fut_np[n_neigh].get_attr("y") for n_neigh in range(num_neighbors)], + "line_dash": ["solid"] * (num_neighbors + 1), + "line_color": [vis_utils.get_agent_type_color(agent_type)] + + [ + vis_utils.get_agent_type_color(neigh_types[n_neigh]) + for n_neigh in range(num_neighbors) + ], + } + ) + + agent_state: StateArray = batch.agent_hist[batch_idx, -1].cpu().numpy() + x, y = agent_state.position + + if batch.map_names is not None: + map_vis_radius: float = 50.0 + mapAPI = MapAPI(cache_path) + fig.add_map( + batch.agents_from_world_tf[batch_idx].cpu().numpy(), + mapAPI.get_map( + batch.map_names[batch_idx], + incl_road_lanes=True, + incl_road_areas=True, + incl_ped_crosswalks=True, + incl_ped_walkways=True, + ), + # x_min, x_max, y_min, y_max + bbox=( + x - map_vis_radius, + x + map_vis_radius, + y - map_vis_radius, + y + map_vis_radius, + ), + ) + + fig.add_lines(agent_histories) + fig.add_lines(agent_futures) + + agent_extent: np.ndarray = batch.agent_hist_extent[batch_idx, -1] + if agent_extent.isnan().any(): + raise ValueError("Agent extents cannot be NaN!") + + length = agent_extent[0].item() + width = agent_extent[1].item() + + heading: float = agent_state.heading.item() + speed_mps: float = np.linalg.norm(agent_state.velocity).item() + + agent_rect_coords = transform_coords_2d_np( + 
np.array( + [ + [-length / 2, -width / 2], + [-length / 2, width / 2], + [length / 2, width / 2], + [length / 2, -width / 2], + ] + ), + angle=heading, + ) + + agent_rects_data = { + "x": [x], + "y": [y], + "xs": [agent_rect_coords[:, 0] + x], + "ys": [agent_rect_coords[:, 1] + y], + "fill_color": [vis_utils.get_agent_type_color(agent_type)], + "line_color": ["black"], + "fill_alpha": [0.7], + "type": [str(AgentType(agent_type))[len("AgentType.") :]], + "speed_mps": [speed_mps], + "speed_kph": [speed_mps * 3.6], + } + + size = 1.0 + if agent_type == AgentType.PEDESTRIAN: + size = 0.25 + + dir_patch_coords = transform_coords_2d_np( + np.array( + [ + [0, np.sqrt(3) / 3], + [-1 / 2, -np.sqrt(3) / 6], + [1 / 2, -np.sqrt(3) / 6], + ] + ) + * size, + angle=heading - np.pi / 2, + ) + dir_patches_data = { + "xs": [dir_patch_coords[:, 0] + x], + "ys": [dir_patch_coords[:, 1] + y], + "fill_color": [vis_utils.get_agent_type_color(agent_type)], + "line_color": ["black"], + "alpha": [0.7], + } + + for n_neigh in range(num_neighbors): + agent_type: int = batch.neigh_types[batch_idx, n_neigh].item() + agent_state: StateArray = batch.neigh_hist[batch_idx, n_neigh, -1].cpu().numpy() + agent_extent: np.ndarray = batch.neigh_hist_extents[batch_idx, n_neigh, -1] + + if agent_extent.isnan().any(): + raise ValueError("Agent extents cannot be NaN!") + + length = agent_extent[0].item() + width = agent_extent[1].item() + + x, y = agent_state.position + heading: float = agent_state.heading.item() + speed_mps: float = np.linalg.norm(agent_state.velocity).item() + + agent_rect_coords, dir_patch_coords = vis_utils.compute_agent_rect_coords( + agent_type, heading, length, width + ) + + agent_rects_data["x"].append(x) + agent_rects_data["y"].append(y) + agent_rects_data["xs"].append(agent_rect_coords[:, 0] + x) + agent_rects_data["ys"].append(agent_rect_coords[:, 1] + y) + agent_rects_data["fill_color"].append( + vis_utils.get_agent_type_color(agent_type) + ) + agent_rects_data["line_color"].append("black") + agent_rects_data["fill_alpha"].append(0.7) + agent_rects_data["type"].append(str(AgentType(agent_type))[len("AgentType.") :]) + agent_rects_data["speed_mps"].append(speed_mps) + agent_rects_data["speed_kph"].append(speed_mps * 3.6) + + dir_patches_data["xs"].append(dir_patch_coords[:, 0] + x) + dir_patches_data["ys"].append(dir_patch_coords[:, 1] + y) + dir_patches_data["fill_color"].append( + vis_utils.get_agent_type_color(agent_type) + ) + dir_patches_data["line_color"].append("black") + dir_patches_data["alpha"].append(0.7) + + rects, _ = fig.add_agents( + ColumnDataSource(data=agent_rects_data), ColumnDataSource(data=dir_patches_data) + ) + + fig.raw_figure.hover.renderers = [rects] + fig.show() diff --git a/src/trajdata/visualization/vis.py b/src/trajdata/visualization/vis.py index 3976722..5fff19d 100644 --- a/src/trajdata/visualization/vis.py +++ b/src/trajdata/visualization/vis.py @@ -1,19 +1,245 @@ -from typing import Optional +from typing import List, Optional, Tuple +from warnings import warn import matplotlib.pyplot as plt +import matplotlib.transforms as mtransforms +import numpy as np +import seaborn as sns import torch from matplotlib.axes import Axes +from matplotlib.patches import FancyBboxPatch, Polygon from torch import Tensor from trajdata.data_structures.agent import AgentType from trajdata.data_structures.batch import AgentBatch, SceneBatch +from trajdata.data_structures.state import StateTensor from trajdata.maps import RasterizedMap +from trajdata.maps.vec_map_elements import RoadLane + + 
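+# Usage sketch (illustrative; assumes a `UnifiedDataset` named `dataset` has
+# been constructed elsewhere, as in trajdata's README):
+#
+# from torch.utils.data import DataLoader
+# loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.get_collate_fn())
+# batch: AgentBatch = next(iter(loader))
+# plot_agent_batch(batch, batch_idx=0)
+
+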
+def draw_agent( + ax: Axes, + agent_type: AgentType, + agent_state: StateTensor, + agent_extent: Tensor, + agent_to_world_tf: Tensor, + **kwargs, +) -> None: + """Draws a path with the correct location, heading, and dimensions onto the given axes + + Args: + ax (Axes): _description_ + agent_type (AgentType): _description_ + agent_state (Tensor): _description_ + agent_extent (Tensor): _description_ + agent_to_world_tf (Tensor): _description_ + """ + + if torch.any(torch.isnan(agent_extent)): + if agent_type == AgentType.VEHICLE: + length = 4.3 + width = 1.8 + elif agent_type == AgentType.PEDESTRIAN: + length = 0.5 + width = 0.5 + elif agent_type == AgentType.BICYCLE: + length = 1.9 + width = 0.5 + else: + length = 1.0 + width = 1.0 + else: + length = agent_extent[0].item() + width = agent_extent[1].item() + + xy = agent_state.position + heading = agent_state.heading + + patch = FancyBboxPatch([-length / 2, -width / 2], length, width, **kwargs) + transform = ( + mtransforms.Affine2D().rotate(heading[0].item()).translate(xy[0], xy[1]) + + mtransforms.Affine2D(matrix=agent_to_world_tf.cpu().numpy()) + + ax.transData + ) + patch.set_transform(transform) + + kwargs["label"] = None + size = 1.0 + angles = [0, 2 * np.pi / 3, np.pi, 4 * np.pi / 3] + pts = np.stack([size * np.cos(angles), size * np.sin(angles)], axis=-1) + center_patch = Polygon(pts, zorder=10.0, **kwargs) + center_patch.set_transform(transform) + + ax.add_patch(patch) + ax.add_patch(center_patch) + + +def draw_history( + ax: Axes, + agent_type: AgentType, + agent_history: StateTensor, + agent_extent: Tensor, + agent_to_world_tf: Tensor, + start_alpha: float = 0.2, + end_alpha: float = 0.5, + **kwargs, +): + T = agent_history.shape[0] + alphas = np.linspace(start_alpha, end_alpha, T) + for t in range(T): + draw_agent( + ax, + agent_type, + agent_history[t], + agent_extent, + agent_to_world_tf, + alpha=alphas[t], + **kwargs, + ) + + +def draw_map( + ax: Axes, map: Tensor, base_frame_from_map_tf: Tensor, alpha=1.0, **kwargs +): + patch_size: int = map.shape[-1] + map_array = RasterizedMap.to_img(map.cpu()) + brightened_map_array = map_array * 0.2 + 0.8 + + im = ax.imshow( + brightened_map_array, + extent=[0, patch_size, patch_size, 0], + clip_on=True, + **kwargs, + ) + transform = ( + mtransforms.Affine2D(matrix=base_frame_from_map_tf.cpu().numpy()) + ax.transData + ) + im.set_transform(transform) + + coords = np.array( + [[0, 0, 1], [patch_size, 0, 1], [patch_size, patch_size, 1], [0, patch_size, 1]] + ) + world_frame_corners = base_frame_from_map_tf.cpu().numpy() @ coords[:, :, None] + xmin = np.min(world_frame_corners[:, 0, 0]) + xmax = np.max(world_frame_corners[:, 0, 0]) + ymin = np.min(world_frame_corners[:, 1, 0]) + ymax = np.max(world_frame_corners[:, 1, 0]) + ax.set_xlim(xmin, xmax) + ax.set_ylim(ymin, ymax) + + +def draw_lanes( + ax: Axes, + lanes: List[RoadLane], + centered_agent_from_world_tf: Tensor, + color: Tuple[float, float, float] = (0.5, 0.5, 0.5), +): + transform = ( + mtransforms.Affine2D(matrix=centered_agent_from_world_tf.cpu().numpy()) + + ax.transData + ) + for lane in lanes: + ax.plot( + lane.center.xy[:, 0], + lane.center.xy[:, 1], + linestyle="--", + color=color, + transform=transform, + ) + + +def plot_agent_batch_all( + batch: AgentBatch, + ax: Optional[Axes] = None, + show: bool = True, + close: bool = True, +) -> None: + if ax is None: + _, ax = plt.subplots() + + # Use first agent as common reference frame + base_frame_from_world_tf = batch.agents_from_world_tf[0].cpu() + + # plot maps over each other 
with proper transformations: + for i in range(len(batch.agent_name)): + base_frame_from_map_tf = base_frame_from_world_tf @ torch.linalg.inv( + batch.rasters_from_world_tf[i].cpu() + ) + draw_map(ax, batch.maps[i], base_frame_from_map_tf, alpha=1.0) + + for i in range(len(batch.agent_name)): + agent_type = batch.agent_type[i] + agent_name = batch.agent_name[i] + agent_hist = batch.agent_hist[i, :, :].cpu() + agent_fut = batch.agent_fut[i, :, :].cpu() + agent_extent = batch.agent_hist_extent[i, -1, :].cpu() + base_frame_from_agent_tf = base_frame_from_world_tf @ torch.linalg.inv( + batch.agents_from_world_tf[i].cpu() + ) + + palette = sns.color_palette("husl", 4) + if agent_type == AgentType.VEHICLE: + color = palette[0] + elif agent_type == AgentType.PEDESTRIAN: + color = palette[1] + elif agent_type == AgentType.BICYCLE: + color = palette[2] + else: + color = palette[3] + + transform = ( + mtransforms.Affine2D(matrix=base_frame_from_agent_tf.numpy()) + ax.transData + ) + draw_history( + ax, + agent_type, + agent_hist[:-1, :], + agent_extent, + base_frame_from_agent_tf, + facecolor="None", + edgecolor=color, + linewidth=0, + ) + ax.plot( + agent_hist[:, 0], + agent_hist[:, 1], + linestyle="--", + color=color, + transform=transform, + ) + draw_agent( + ax, + agent_type, + agent_hist[-1, :], + agent_extent, + base_frame_from_agent_tf, + facecolor=color, + edgecolor="k", + ) + ax.plot( + agent_fut[:, 0], + agent_fut[:, 1], + linestyle="-", + color=color, + transform=transform, + ) + + ax.set_ylim(-30, 40) + ax.set_xlim(-30, 40) + ax.grid(False) + + if show: + plt.show() + + if close: + plt.close() def plot_agent_batch( batch: AgentBatch, batch_idx: int, ax: Optional[Axes] = None, + legend: bool = True, show: bool = True, close: bool = True, ) -> None: @@ -22,98 +248,118 @@ def plot_agent_batch( agent_name: str = batch.agent_name[batch_idx] agent_type: AgentType = AgentType(batch.agent_type[batch_idx].item()) - ax.set_title(f"{str(agent_type)}/{agent_name}") + current_state = batch.curr_agent_state[batch_idx].cpu().numpy() + ax.set_title( + f"{str(agent_type)}/{agent_name}\nat x={current_state[0]:.2f},y={current_state[1]:.2f},h={current_state[-1]:.2f}" + ) - history_xy: Tensor = batch.agent_hist[batch_idx].cpu() - center_xy: Tensor = batch.agent_hist[batch_idx, -1, :2].cpu() - future_xy: Tensor = batch.agent_fut[batch_idx, :, :2].cpu() + agent_from_world_tf: Tensor = batch.agents_from_world_tf[batch_idx].cpu() if batch.maps is not None: - agent_from_world_tf: Tensor = batch.agents_from_world_tf[batch_idx].cpu() world_from_raster_tf: Tensor = torch.linalg.inv( batch.rasters_from_world_tf[batch_idx].cpu() ) agent_from_raster_tf: Tensor = agent_from_world_tf @ world_from_raster_tf - patch_size: int = batch.maps[batch_idx].shape[-1] - - left_extent: float = (agent_from_raster_tf @ torch.tensor([0.0, 0.0, 1.0]))[ - 0 - ].item() - right_extent: float = ( - agent_from_raster_tf @ torch.tensor([patch_size, 0.0, 1.0]) - )[0].item() - bottom_extent: float = ( - agent_from_raster_tf @ torch.tensor([0.0, patch_size, 1.0]) - )[1].item() - top_extent: float = (agent_from_raster_tf @ torch.tensor([0.0, 0.0, 1.0]))[ - 1 - ].item() - - ax.imshow( - RasterizedMap.to_img( - batch.maps[batch_idx].cpu(), - # [[0], [1], [2]] - # [[0, 1, 2], [3, 4], [5, 6]], - ), - extent=( - left_extent, - right_extent, - bottom_extent, - top_extent, - ), - alpha=0.3, - ) - + draw_map(ax, batch.maps[batch_idx], agent_from_raster_tf, alpha=1.0) + + agent_hist = batch.agent_hist[batch_idx].cpu() + agent_fut = 
batch.agent_fut[batch_idx].cpu() + agent_extent = batch.agent_hist_extent[batch_idx, -1, :].cpu() + base_frame_from_agent_tf = torch.eye(3) + + palette = sns.color_palette("husl", 4) + if agent_type == AgentType.VEHICLE: + color = palette[0] + elif agent_type == AgentType.PEDESTRIAN: + color = palette[1] + elif agent_type == AgentType.BICYCLE: + color = palette[2] + else: + color = palette[3] + + draw_history( + ax, + agent_type, + agent_hist[:-1], + agent_extent, + base_frame_from_agent_tf, + facecolor=color, + edgecolor=None, + linewidth=0, + ) ax.plot( - history_xy[..., 0], - history_xy[..., 1], - c="orange", - ls="--", + agent_hist.get_attr("x"), + agent_hist.get_attr("y"), + linestyle="--", + color=color, label="Agent History", ) - ax.quiver( - history_xy[..., 0], - history_xy[..., 1], - history_xy[..., -1], - history_xy[..., -2], - color="k", + draw_agent( + ax, + agent_type, + agent_hist[-1], + agent_extent, + base_frame_from_agent_tf, + facecolor=color, + edgecolor="k", + label="Agent Current", + ) + ax.plot( + agent_fut.get_attr("x"), + agent_fut.get_attr("y"), + linestyle="-", + color=color, + label="Agent Future", ) - ax.plot(future_xy[..., 0], future_xy[..., 1], c="violet", label="Agent Future") - ax.scatter(center_xy[0], center_xy[1], s=20, c="orangered", label="Agent Current") num_neigh = batch.num_neigh[batch_idx] if num_neigh > 0: - neighbor_hist = batch.neigh_hist[batch_idx] - neighbor_fut = batch.neigh_fut[batch_idx] + neighbor_hist = batch.neigh_hist[batch_idx].cpu() + neighbor_fut = batch.neigh_fut[batch_idx].cpu() + neighbor_extent = batch.neigh_hist_extents[batch_idx, :, -1, :].cpu() + neighbor_type = batch.neigh_types[batch_idx].cpu() ax.plot([], [], c="olive", ls="--", label="Neighbor History") - for n in range(num_neigh): - ax.plot(neighbor_hist[n, :, 0], neighbor_hist[n, :, 1], c="olive", ls="--") - ax.plot([], [], c="darkgreen", label="Neighbor Future") - for n in range(num_neigh): - ax.plot(neighbor_fut[n, :, 0], neighbor_fut[n, :, 1], c="darkgreen") - ax.scatter( - neighbor_hist[:num_neigh, -1, 0], - neighbor_hist[:num_neigh, -1, 1], - s=20, - c="gold", - label="Neighbor Current", - ) + for n in range(num_neigh): + if torch.isnan(neighbor_hist[n, -1, :]).any(): + # this neighbor does not exist at the current timestep + continue + ax.plot( + neighbor_hist.get_attr("x")[n, :], + neighbor_hist.get_attr("y")[n, :], + c="olive", + ls="--", + ) + draw_agent( + ax, + neighbor_type[n], + neighbor_hist[n, -1], + neighbor_extent[n, :], + base_frame_from_agent_tf, + facecolor="olive", + edgecolor="k", + alpha=0.7, + ) + ax.plot( + neighbor_fut.get_attr("x")[n, :], + neighbor_fut.get_attr("y")[n, :], + c="darkgreen", + ) if batch.robot_fut is not None and batch.robot_fut.shape[1] > 0: ax.plot( - batch.robot_fut[batch_idx, 1:, 0], - batch.robot_fut[batch_idx, 1:, 1], + batch.robot_fut.get_attr("x")[batch_idx, 1:], + batch.robot_fut.get_attr("y")[batch_idx, 1:], label="Ego Future", c="blue", ) ax.scatter( - batch.robot_fut[batch_idx, 0, 0], - batch.robot_fut[batch_idx, 0, 1], + batch.robot_fut.get_attr("x")[batch_idx, 0], + batch.robot_fut.get_attr("y")[batch_idx, 0], s=20, c="blue", label="Ego Current", @@ -123,8 +369,14 @@ def plot_agent_batch( ax.set_ylabel("y (m)") ax.grid(False) - ax.legend(loc="best", frameon=True) - ax.axis("equal") + ax.set_aspect("equal", adjustable="box") + + # Doing this because the imshow above makes the map origin at the top. + # TODO(pkarkus) we should just modify imshow not to change the origin instead. 
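+    # (draw_map passes extent=[0, patch_size, patch_size, 0] to imshow, which
+    # places the raster origin at the top-left, hence the flip below.)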
+ ax.invert_yaxis() + + if legend: + ax.legend(loc="best", frameon=True) if show: plt.show() @@ -132,103 +384,96 @@ def plot_agent_batch( if close: plt.close() + return ax + def plot_scene_batch( batch: SceneBatch, batch_idx: int, ax: Optional[Axes] = None, + plot_vec_map: bool = False, + vec_map_search_radius: float = 100, show: bool = True, close: bool = True, -) -> None: +) -> Axes: if ax is None: _, ax = plt.subplots() num_agents: int = batch.num_agents[batch_idx].item() - history_xy: Tensor = batch.agent_hist[batch_idx].cpu() - center_xy: Tensor = batch.agent_hist[batch_idx, ..., -1, :2].cpu() - future_xy: Tensor = batch.agent_fut[batch_idx, ..., :2].cpu() - - if batch.maps is not None: - centered_agent_id: int = 0 - agent_from_world_tf: Tensor = batch.centered_agent_from_world_tf[ - batch_idx - ].cpu() + agent_from_world_tf: Tensor = batch.centered_agent_from_world_tf[batch_idx].cpu() + + if plot_vec_map and batch.vector_maps is not None: + try: + search_point = ( + batch.centered_agent_state.position3d[batch_idx].cpu().numpy() + ) + except ValueError: + warn( + "could not compute 3d position. try adding 'z' component to state format, " + "e.g. state_format='x,y,z,xd,yd,xdd,ydd,h'" + ) + raise + vec_map = batch.vector_maps[batch_idx] + lanes = vec_map.get_lanes_within(search_point, vec_map_search_radius) + draw_lanes(ax, lanes, agent_from_world_tf) + elif batch.maps is not None: + centered_agent_id = 0 world_from_raster_tf: Tensor = torch.linalg.inv( batch.rasters_from_world_tf[batch_idx, centered_agent_id].cpu() ) agent_from_raster_tf: Tensor = agent_from_world_tf @ world_from_raster_tf - patch_size: int = batch.maps[batch_idx, centered_agent_id].shape[-1] - - left_extent: float = (agent_from_raster_tf @ torch.tensor([0.0, 0.0, 1.0]))[ - 0 - ].item() - right_extent: float = ( - agent_from_raster_tf @ torch.tensor([patch_size, 0.0, 1.0]) - )[0].item() - bottom_extent: float = ( - agent_from_raster_tf @ torch.tensor([0.0, patch_size, 1.0]) - )[1].item() - top_extent: float = (agent_from_raster_tf @ torch.tensor([0.0, 0.0, 1.0]))[ - 1 - ].item() - - ax.imshow( - RasterizedMap.to_img( - batch.maps[batch_idx, centered_agent_id].cpu(), - # [[0], [1], [2]] - # [[0, 1, 2], [3, 4], [5, 6]], - ), - extent=( - left_extent, - right_extent, - bottom_extent, - top_extent, - ), - alpha=0.3, + draw_map( + ax, + batch.maps[batch_idx, centered_agent_id], + agent_from_raster_tf, + alpha=1.0, ) + base_frame_from_agent_tf = torch.eye(3) + agent_hist = batch.agent_hist[batch_idx] + agent_type = batch.agent_type[batch_idx] + agent_extent = batch.agent_hist_extent[batch_idx, :, -1] + agent_fut = batch.agent_fut[batch_idx] + for agent_id in range(num_agents): ax.plot( - history_xy[agent_id, ..., 0], - history_xy[agent_id, ..., 1], + agent_hist.get_attr("x")[agent_id], + agent_hist.get_attr("y")[agent_id], c="orange", ls="--", label="Agent History" if agent_id == 0 else None, ) - ax.quiver( - history_xy[agent_id, ..., 0], - history_xy[agent_id, ..., 1], - history_xy[agent_id, ..., -1], - history_xy[agent_id, ..., -2], - color="k", + draw_agent( + ax, + agent_type[agent_id], + agent_hist[agent_id, -1], + agent_extent[agent_id], + base_frame_from_agent_tf, + facecolor="olive", + edgecolor="k", + alpha=0.7, + label="Agent Current" if agent_id == 0 else None, ) ax.plot( - future_xy[agent_id, ..., 0], - future_xy[agent_id, ..., 1], + agent_fut.get_attr("x")[agent_id], + agent_fut.get_attr("y")[agent_id], c="violet", label="Agent Future" if agent_id == 0 else None, ) - ax.scatter( - center_xy[agent_id, 0], - 
center_xy[agent_id, 1], - s=20, - c="orangered", - label="Agent Current" if agent_id == 0 else None, - ) if batch.robot_fut is not None and batch.robot_fut.shape[1] > 0: ax.plot( - batch.robot_fut[batch_idx, 1:, 0], - batch.robot_fut[batch_idx, 1:, 1], + batch.robot_fut.get_attr("x")[batch_idx, 1:], + batch.robot_fut.get_attr("y")[batch_idx, 1:], label="Ego Future", c="blue", ) ax.scatter( - batch.robot_fut[batch_idx, 0, 0], - batch.robot_fut[batch_idx, 0, 1], + batch.robot_fut.get_attr("x")[batch_idx, 0], + batch.robot_fut.get_attr("y")[batch_idx, 0], s=20, c="blue", label="Ego Current", @@ -238,11 +483,16 @@ def plot_scene_batch( ax.set_ylabel("y (m)") ax.grid(False) + ax.set_aspect("equal", adjustable="box") ax.legend(loc="best", frameon=True) - ax.axis("equal") + + # Doing this because the imshow above makes the map origin at the top. + ax.invert_yaxis() if show: plt.show() if close: plt.close() + + return ax diff --git a/tests/test_batch_conversion.py b/tests/test_batch_conversion.py new file mode 100644 index 0000000..40d158d --- /dev/null +++ b/tests/test_batch_conversion.py @@ -0,0 +1,281 @@ +import unittest +from collections import defaultdict + +import torch + +from trajdata import AgentType, UnifiedDataset +from trajdata.caching.env_cache import EnvCache +from trajdata.data_structures import AgentBatch +from trajdata.utils.batch_utils import SceneTimeBatcher, convert_to_agent_batch + + +class TestSceneToAgentBatchConversion(unittest.TestCase): + def __init__(self, methodName: str = "batchConversion") -> None: + super().__init__(methodName) + + data_source = "nusc_mini" + history_sec = 2.0 + prediction_sec = 6.0 + + attention_radius = defaultdict( + lambda: 20.0 + ) # Default range is 20m unless otherwise specified. + attention_radius[(AgentType.PEDESTRIAN, AgentType.PEDESTRIAN)] = 10.0 + attention_radius[(AgentType.PEDESTRIAN, AgentType.VEHICLE)] = 20.0 + attention_radius[(AgentType.VEHICLE, AgentType.PEDESTRIAN)] = 20.0 + attention_radius[(AgentType.VEHICLE, AgentType.VEHICLE)] = 30.0 + + map_params = {"px_per_m": 2, "map_size_px": 100, "offset_frac_xy": (-0.75, 0.0)} + + self._scene_dataset = UnifiedDataset( + centric="scene", + desired_data=[data_source], + history_sec=(history_sec, history_sec), + future_sec=(prediction_sec, prediction_sec), + agent_interaction_distances=attention_radius, + incl_robot_future=False, + incl_raster_map=True, + raster_map_params=map_params, + only_predict=[AgentType.VEHICLE, AgentType.PEDESTRIAN], + no_types=[AgentType.UNKNOWN], + num_workers=0, + standardize_data=True, + data_dirs={ + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + self._agent_dataset = UnifiedDataset( + centric="agent", + desired_data=[data_source], + history_sec=(history_sec, history_sec), + future_sec=(prediction_sec, prediction_sec), + agent_interaction_distances=attention_radius, + incl_robot_future=False, + incl_raster_map=True, + raster_map_params=map_params, + only_predict=[AgentType.VEHICLE, AgentType.PEDESTRIAN], + no_types=[AgentType.UNKNOWN], + num_workers=0, + standardize_data=True, + data_dirs={ + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + def _assert_allclose_with_nans(self, tensor1, tensor2): + """ + asserts that the two tensors have nans in the same locations, and the non-nan + elements all are close. 
+ """ + # Check nans are in the same place + self.assertFalse( + torch.any( # True if there's any mismatch + torch.logical_xor( # True where either tensor1 or tensor 2 has nans, but not both (mismatch) + torch.isnan(tensor1), # True where tensor1 has nans + torch.isnan(tensor2), # True where tensor2 has nans + ) + ), + msg="Nans occur in different places.", + ) + valid_mask = torch.logical_not(torch.isnan(tensor1)) + self.assertTrue( + torch.allclose(tensor1[valid_mask], tensor2[valid_mask]), + msg="Non-nan values don't match.", + ) + + def _test_agent_idx(self, agent_dataset_idx: int, verbose=False): + for offset in range(50): + agent_batch_element = self._agent_dataset[agent_dataset_idx] + agent_scene_path, _, _ = self._agent_dataset._data_index[agent_dataset_idx] + agent_batch = self._agent_dataset.get_collate_fn(pad_format="right")( + [agent_batch_element] + ) + scene_ts = agent_batch_element.scene_ts + scene_id = agent_batch_element.scene_id + agent_name = agent_batch_element.agent_name + if verbose: + print( + f"From the agent-centric dataset at index {agent_dataset_idx}, we're looking at:\nAgent {agent_name} in {scene_id} at timestep {scene_ts}" + ) + + # find same scene and ts in scene-centric dataset + scene_dataset_idx = 0 + for scene_dataset_idx in range(len(self._scene_dataset)): + scene_path, ts = self._scene_dataset._data_index[scene_dataset_idx] + if ts == scene_ts and scene_path == agent_scene_path: + # load scene to check scene name + scene = EnvCache.load(scene_path) + if scene.name == scene_id: + break + + if verbose: + print( + f"We found a matching scene in the scene-centric dataset at index {scene_dataset_idx}" + ) + + scene_batch_element = self._scene_dataset[scene_dataset_idx] + converted_agent_batch = convert_to_agent_batch( + scene_batch_element, + self._scene_dataset.only_types, + self._scene_dataset.no_types, + self._scene_dataset.agent_interaction_distances, + self._scene_dataset.incl_raster_map, + self._scene_dataset.raster_map_params, + self._scene_dataset.max_neighbor_num, + self._scene_dataset.state_format, + self._scene_dataset.standardize_data, + self._scene_dataset.standardize_derivatives, + pad_format="right", + ) + + agent_idx = -1 + for j, name in enumerate(converted_agent_batch.agent_name): + if name == agent_name: + agent_idx = j + + if agent_idx < 0: + if verbose: + print("no matching scene containing agent, checking next index") + agent_dataset_idx += 1 + else: + break + + self.assertTrue( + agent_idx >= 0, "Matching scene not found in scene-centric dataset!" + ) + + if verbose: + print( + f"Agent {converted_agent_batch.agent_name[agent_idx]} appears in {scene_batch_element.scene_id} at timestep {scene_batch_element.scene_ts}, as agent number {agent_idx}" + ) + + attrs_to_ignore = ["data_idx", "extras", "history_pad_dir"] + + variable_length_keys = { + "neigh_types": "num_neigh", + "neigh_hist": "num_neigh", + "neigh_hist_extents": "num_neigh", + "neigh_hist_len": "num_neigh", + "neigh_fut": "num_neigh", + "neigh_fut_extents": "num_neigh", + "neigh_fut_len": "num_neigh", + } + + for attr, val in converted_agent_batch.__dict__.items(): + if attr in attrs_to_ignore: + continue + if verbose: + print(f"Checking {attr}") + + if val is None: + self.assertTrue(agent_batch.__dict__[attr] is None) + elif isinstance(val[agent_idx], torch.Tensor): + if attr in variable_length_keys: + attr_len = converted_agent_batch.__dict__[ + variable_length_keys[attr] + ][agent_idx] + convertedTensor = val[agent_idx, :attr_len, ...] 
+ targetTensor = agent_batch.__dict__[attr][0, :attr_len, ...] + else: + convertedTensor = val[agent_idx] + targetTensor = agent_batch.__dict__[attr][0] + try: + self._assert_allclose_with_nans(convertedTensor, targetTensor) + except RuntimeError as e: + print(f"Error at {attr=}") + raise e + else: + self.assertTrue( + val[agent_idx] == agent_batch.__dict__[attr][0], + f"Failed at {attr=}", + ) + + def test_index_1(self): + self._test_agent_idx(0, verbose=False) + + def test_index_2(self): + self._test_agent_idx(116, verbose=False) + + def test_index_3(self): + self._test_agent_idx(222, verbose=False) + + +class TestSceneSampler(unittest.TestCase): + def setUp(self) -> None: + self._scene_dataset = UnifiedDataset( + centric="scene", + desired_data=["nusc_mini-mini_val"], + data_dirs={ + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + self._agent_dataset = UnifiedDataset( + centric="agent", + desired_data=["nusc_mini-mini_val"], + data_dirs={ + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + self._scene_sampler = SceneTimeBatcher(self._agent_dataset) + + def _test_len(self, agent_idx=0): + """ + Len of dataset should be equal to number of timesteps + the agent appears in the dataset + """ + sampler = SceneTimeBatcher(self._agent_dataset, agent_idx) + + total_len = sum( + lengths[agent_idx + 1] - lengths[agent_idx] + for lengths in self._agent_dataset._data_index._cumulative_scene_lengths + ) + + dl = torch.utils.data.DataLoader(self._agent_dataset, batch_sampler=sampler) + + self.assertEqual(len(dl), total_len) + + def test_len_ego(self): + self.assertEqual(len(self._scene_sampler), len(self._scene_dataset)) + + def test_len_nonego_1(self): + return self._test_len(15) + + def test_len_nonego_2(self): + return self._test_len(30) + + def test_consistency(self): + dl = torch.utils.data.DataLoader( + self._agent_dataset, + batch_sampler=self._scene_sampler, + collate_fn=self._agent_dataset.get_collate_fn(pad_format="right"), + ) + scene_idx = 0 + agent_batch: AgentBatch + for agent_batch in dl: + scene_batch_elem = self._scene_dataset[scene_idx] + for scene_id in agent_batch.scene_ids: + self.assertEqual(scene_batch_elem.scene_id, scene_id) + + # ensure all elements have the same scene id + self.assertEqual( + torch.abs(agent_batch.scene_ts - scene_batch_elem.scene_ts) + .float() + .mean() + .item(), + 0, + ) + + for agent_name in agent_batch.agent_name: + self.assertIn(agent_name, scene_batch_elem.agent_names) + + scene_idx += 1 + + if scene_idx == 10: + break + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_collation.py b/tests/test_collation.py index 9da0499..5ae625a 100644 --- a/tests/test_collation.py +++ b/tests/test_collation.py @@ -1,8 +1,12 @@ import unittest +from collections import defaultdict import numpy as np import torch +from torch.utils.data import DataLoader +from trajdata.data_structures.agent import AgentType +from trajdata.dataset import UnifiedDataset from trajdata.utils import arr_utils @@ -129,3 +133,96 @@ def test_pad_sequences(self): equal_nan=True, ) ) + + def test_zero_neighbor_dict_collation(self): + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_val"], + centric="agent", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 0.0), + incl_robot_future=True, + incl_raster_map=True, + standardize_data=False, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + num_workers=0, + 
verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + collate_fn=dataset.get_collate_fn(return_dict=True), + num_workers=0, + ) + + i = 0 + for batch in dataloader: + i += 1 + + self.assertIsInstance(batch["curr_agent_state"], dataset.torch_state_type) + self.assertIsInstance(batch["agent_hist"], dataset.torch_obs_type) + self.assertIsInstance(batch["agent_fut"], dataset.torch_obs_type) + self.assertIsInstance(batch["robot_fut"], dataset.torch_obs_type) + + if i == 5: + break + + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_val"], + centric="scene", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 0.0), + incl_robot_future=True, + incl_raster_map=True, + standardize_data=False, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + num_workers=0, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + collate_fn=dataset.get_collate_fn(return_dict=True), + num_workers=0, + ) + + i = 0 + for batch in dataloader: + i += 1 + + self.assertIsInstance( + batch["centered_agent_state"], dataset.torch_state_type + ) + self.assertIsInstance(batch["agent_hist"], dataset.torch_obs_type) + self.assertIsInstance(batch["agent_fut"], dataset.torch_obs_type) + self.assertIsInstance(batch["robot_fut"], dataset.torch_obs_type) + + if i == 5: + break + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000..9386429 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,212 @@ +import unittest +from collections import defaultdict + +from torch.utils.data import DataLoader + +from trajdata.data_structures.agent import AgentType +from trajdata.data_structures.batch import AgentBatch +from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement +from trajdata.data_structures.state import NP_STATE_TYPES, TORCH_STATE_TYPES +from trajdata.dataset import UnifiedDataset + + +class TestDataset(unittest.TestCase): + def test_dataloading(self): + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_val"], + centric="agent", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=True, + incl_raster_map=True, + standardize_data=False, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + num_workers=4, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! 
+ "nusc_mini": "~/datasets/nuScenes", + }, + ) + + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + collate_fn=dataset.get_collate_fn(), + num_workers=0, + ) + + i = 0 + batch: AgentBatch + for batch in dataloader: + i += 1 + + batch.to("cuda") + + self.assertIsInstance(batch.curr_agent_state, dataset.torch_state_type) + self.assertIsInstance(batch.agent_hist, dataset.torch_obs_type) + self.assertIsInstance(batch.agent_fut, dataset.torch_obs_type) + self.assertIsInstance(batch.robot_fut, dataset.torch_obs_type) + + if i == 5: + break + + def test_dict_dataloading(self): + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_val"], + centric="agent", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=True, + incl_raster_map=True, + standardize_data=False, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + num_workers=4, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + collate_fn=dataset.get_collate_fn(return_dict=True), + num_workers=0, + ) + + i = 0 + for batch in dataloader: + i += 1 + + self.assertIsInstance(batch["curr_agent_state"], dataset.torch_state_type) + self.assertIsInstance(batch["agent_hist"], dataset.torch_obs_type) + self.assertIsInstance(batch["agent_fut"], dataset.torch_obs_type) + self.assertIsInstance(batch["robot_fut"], dataset.torch_obs_type) + + if i == 5: + break + + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_val"], + centric="scene", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=True, + incl_raster_map=True, + standardize_data=False, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + num_workers=4, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + collate_fn=dataset.get_collate_fn(return_dict=True), + num_workers=0, + ) + + i = 0 + for batch in dataloader: + i += 1 + + self.assertIsInstance( + batch["centered_agent_state"], dataset.torch_state_type + ) + self.assertIsInstance(batch["agent_hist"], dataset.torch_obs_type) + self.assertIsInstance(batch["agent_fut"], dataset.torch_obs_type) + self.assertIsInstance(batch["robot_fut"], dataset.torch_obs_type) + + if i == 5: + break + + def test_default_datatypes_agent(self): + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_val"], + centric="agent", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=True, + incl_raster_map=True, + standardize_data=False, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + num_workers=4, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! 
+ "nusc_mini": "~/datasets/nuScenes", + }, + ) + + elem: AgentBatchElement = dataset[0] + self.assertIsInstance(elem.curr_agent_state_np, dataset.np_state_type) + self.assertIsInstance(elem.agent_history_np, dataset.np_obs_type) + self.assertIsInstance(elem.agent_future_np, dataset.np_obs_type) + self.assertIsInstance(elem.robot_future_np, dataset.np_obs_type) + + def test_default_datatypes_scene(self): + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_val"], + centric="scene", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=True, + incl_raster_map=True, + standardize_data=False, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + num_workers=4, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + elem: SceneBatchElement = dataset[0] + self.assertIsInstance(elem.centered_agent_state_np, dataset.np_state_type) + self.assertIsInstance(elem.agent_histories[0], dataset.np_obs_type) + self.assertIsInstance(elem.agent_futures[0], dataset.np_obs_type) + self.assertIsInstance(elem.robot_future_np, dataset.np_obs_type) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_datasizes.py b/tests/test_datasizes.py index 1283230..94e2222 100644 --- a/tests/test_datasizes.py +++ b/tests/test_datasizes.py @@ -7,10 +7,15 @@ class TestDatasetSizes(unittest.TestCase): def test_two_datasets(self): dataset = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], centric="agent" + desired_data=["nusc_mini", "nuplan_mini"], + centric="agent", + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, ) - self.assertEqual(len(dataset), 1_924_196) + self.assertEqual(len(dataset), 27_054_719) def test_splits(self): dataset = UnifiedDataset(desired_data=["nusc_mini-mini_train"], centric="agent") @@ -30,57 +35,95 @@ def test_geography(self): self.assertEqual(len(dataset), 6_111) - dataset = UnifiedDataset(desired_data=["palo_alto"], centric="agent") + dataset = UnifiedDataset( + desired_data=["pittsburgh"], + centric="agent", + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, + ) - self.assertEqual(len(dataset), 1_909_120) + self.assertEqual(len(dataset), 846_605) - dataset = UnifiedDataset(desired_data=["boston", "palo_alto"], centric="agent") + dataset = UnifiedDataset( + desired_data=["boston", "pittsburgh"], + centric="agent", + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, + ) - self.assertEqual(len(dataset), 1_915_231) + self.assertEqual(len(dataset), 2_381_216) def test_exclusion(self): dataset = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], + desired_data=["nusc_mini", "nuplan_mini"], centric="agent", - no_types=[AgentType.UNKNOWN], + data_dirs={ # Remember to change this to match your filesystem! 
+ "nusc_mini": "~/datasets/nuScenes", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, + no_types=[AgentType.VEHICLE], ) - self.assertEqual(len(dataset), 610_074) + self.assertEqual(len(dataset), 13_099_040) dataset = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], + desired_data=["nusc_mini", "nuplan_mini"], centric="agent", - no_types=[AgentType.UNKNOWN, AgentType.BICYCLE], + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, + no_types=[AgentType.VEHICLE, AgentType.BICYCLE], ) - self.assertEqual(len(dataset), 603_089) + self.assertEqual(len(dataset), 12_989_300) def test_inclusion(self): dataset = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], + desired_data=["nusc_mini", "nuplan_mini"], centric="agent", - only_types=[AgentType.VEHICLE], + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, + only_types=[AgentType.PEDESTRIAN], ) - self.assertEqual(len(dataset), 554_880) + self.assertEqual(len(dataset), 12_988_830) dataset = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], + desired_data=["nusc_mini", "nuplan_mini"], centric="agent", - only_types=[AgentType.VEHICLE, AgentType.UNKNOWN], + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, + only_types=[AgentType.VEHICLE, AgentType.PEDESTRIAN], ) - self.assertEqual(len(dataset), 1_869_002) + self.assertEqual(len(dataset), 26_944_509) def test_prediction_inclusion(self): unfiltered_dataset = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], + desired_data=["nusc_mini", "nuplan_mini"], centric="agent", + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, ) filtered_dataset = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], + desired_data=["nusc_mini", "nuplan_mini"], centric="agent", + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, only_predict=[AgentType.VEHICLE], ) @@ -91,8 +134,12 @@ def test_prediction_inclusion(self): self.assertEqual(filtered_dataset[sample_idx].agent_type, AgentType.VEHICLE) filtered_dataset2 = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], + desired_data=["nusc_mini", "nuplan_mini"], centric="agent", + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, only_predict=[AgentType.VEHICLE, AgentType.PEDESTRIAN], ) @@ -106,31 +153,43 @@ def test_prediction_inclusion(self): def test_history_future(self): dataset = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], + desired_data=["nusc_mini", "nuplan_mini"], centric="agent", + data_dirs={ # Remember to change this to match your filesystem! 
+ "nusc_mini": "~/datasets/nuScenes", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, history_sec=(0.1, 2.0), future_sec=(0.1, 2.0), ) - self.assertEqual(len(dataset), 1_685_896) + self.assertEqual(len(dataset), 26_283_199) dataset = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], + desired_data=["nusc_mini", "nuplan_mini"], centric="agent", + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, history_sec=(0.5, 2.0), future_sec=(0.5, 3.0), ) - self.assertEqual(len(dataset), 1_155_704) + self.assertEqual(len(dataset), 23_489_292) dataset = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], + desired_data=["nusc_mini", "nuplan_mini"], centric="agent", + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, history_sec=(0.5, 1.0), future_sec=(0.5, 0.7), ) - self.assertEqual(len(dataset), 1_155_704) + self.assertEqual(len(dataset), 23_489_292) def test_interpolation(self): dataset = UnifiedDataset( @@ -141,8 +200,8 @@ def test_interpolation(self): future_sec=(4.8, 4.8), only_types=[AgentType.VEHICLE], incl_robot_future=False, - incl_map=False, - map_params={ + incl_raster_map=False, + raster_map_params={ "px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0), @@ -174,6 +233,64 @@ def test_simple_scene(self): self.assertEqual(len(dataset), 943) + def test_hist_fut_len(self): + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_train"], + centric="agent", + desired_dt=0.1, + history_sec=(2.3, 2.3), + future_sec=(2.4, 2.4), + only_types=[AgentType.VEHICLE], + max_agent_num=20, + num_workers=0, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + self.assertEqual(dataset[0].agent_history_np.shape[0], 24) + self.assertEqual(dataset[0].agent_future_np.shape[0], 24) + self.assertEqual(dataset[0].scene_ts, 23) + + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_train"], + centric="agent", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_types=[AgentType.VEHICLE], + max_agent_num=20, + num_workers=0, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + self.assertEqual(dataset[0].agent_history_np.shape[0], 33) + self.assertEqual(dataset[0].agent_future_np.shape[0], 48) + self.assertEqual(dataset[0].scene_ts, 32) + + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_train"], + centric="scene", + desired_dt=0.1, + history_sec=(2.3, 2.3), + future_sec=(2.4, 2.4), + only_types=[AgentType.VEHICLE], + max_agent_num=20, + num_workers=0, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! 
+ "nusc_mini": "~/datasets/nuScenes", + }, + ) + + self.assertEqual(dataset[0].agent_histories[0].shape[0], 24) + self.assertEqual(dataset[0].agent_futures[0].shape[0], 24) + self.assertEqual(dataset[0].scene_ts, 23) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_description_matching.py b/tests/test_description_matching.py index 0327a50..cf98c52 100644 --- a/tests/test_description_matching.py +++ b/tests/test_description_matching.py @@ -22,8 +22,12 @@ def test_intersection(self): def test_intersection_more_initial(self): dataset = UnifiedDataset( - desired_data=["nusc_mini", "lyft_sample"], + desired_data=["nusc_mini", "nuplan_mini"], scene_description_contains=["intersection"], + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, ) for scene_info in dataset.scenes(): diff --git a/tests/test_state.py b/tests/test_state.py new file mode 100644 index 0000000..f0d0c12 --- /dev/null +++ b/tests/test_state.py @@ -0,0 +1,181 @@ +import unittest + +import numpy as np +import torch + +from trajdata.data_structures.state import NP_STATE_TYPES, TORCH_STATE_TYPES + +AgentStateArray = NP_STATE_TYPES["x,y,z,xd,yd,xdd,ydd,h"] +AgentObsArray = NP_STATE_TYPES["x,y,z,xd,yd,xdd,ydd,s,c"] +AgentStateTensor = TORCH_STATE_TYPES["x,y,z,xd,yd,xdd,ydd,h"] +AgentObsTensor = TORCH_STATE_TYPES["x,y,z,xd,yd,xdd,ydd,s,c"] + + +class TestStateTensor(unittest.TestCase): + def test_construction(self): + a = AgentStateTensor(torch.rand(2, 8)) + b = torch.rand(8).as_subclass(AgentStateTensor) + c = AgentObsTensor(torch.rand(5, 9)) + + def test_class_propagation(self): + a = AgentStateTensor(torch.rand(2, 8)) + self.assertTrue(isinstance(a.to("cpu"), AgentStateTensor)) + + a = AgentStateTensor(torch.rand(2, 8)) + self.assertTrue(isinstance(a.cpu(), AgentStateTensor)) + + b = AgentStateTensor(torch.rand(2, 8)) + self.assertTrue(isinstance(a + b, AgentStateTensor)) + + b = torch.rand(2, 8) + self.assertTrue(isinstance(a + b, AgentStateTensor)) + + a += 1 + self.assertTrue(isinstance(a, AgentStateTensor)) + + def test_property_access(self): + a = AgentStateTensor(torch.rand(2, 8)) + position = a[..., :3] + velocity = a[..., 3:5] + acc = a[..., 5:7] + h = a[..., 7:] + + self.assertTrue(torch.allclose(a.position3d, position)) + self.assertTrue(torch.allclose(a.velocity, velocity)) + self.assertTrue(torch.allclose(a.acceleration, acc)) + self.assertTrue(torch.allclose(a.heading, h)) + + def test_heading_conversion(self): + a = AgentStateTensor(torch.rand(2, 8)) + h = a[..., 7:] + hv = a.heading_vector + self.assertTrue(torch.allclose(torch.atan2(hv[..., 1], hv[..., 0])[:, None], h)) + + def test_long_lat_velocity(self): + a = AgentStateTensor(torch.rand(8)) + velocity = a[3:5] + h = a[7] + lonlat_v = a.as_format("v_lon,v_lat") + lonlat_v_correct = ( + torch.tensor([[np.cos(h), np.sin(h)], [-np.sin(h), np.cos(h)]])[None, ...] + @ velocity[..., None] + )[..., 0] + + self.assertTrue(torch.allclose(lonlat_v, lonlat_v_correct)) + + b = a.as_format("x,y,xd,yd,s,c") + s = b[-2] + c = b[-1] + lonlat_v = b.as_format("v_lon,v_lat") + lonlat_v_correct = ( + torch.tensor([[c, s], [-s, c]])[None, ...] 
@ velocity[..., None] + )[..., 0] + + self.assertTrue(torch.allclose(lonlat_v, lonlat_v_correct)) + + def test_long_lat_conversion(self): + a = AgentStateTensor(torch.rand(2, 8)) + b = a.as_format("xd,yd,h") + c = b.as_format("v_lon,v_lat,h") + d = c.as_format("xd,yd,h") + self.assertTrue(torch.allclose(b, d)) + + def test_as_format(self): + a = AgentStateTensor(torch.rand(2, 8)) + b = a.as_format("x,y,z,xd,yd,xdd,ydd,s,c") + self.assertTrue(isinstance(b, AgentObsTensor)) + self.assertTrue(torch.allclose(a, b.as_format(a._format))) + + def test_as_tensor(self): + a = AgentStateTensor(torch.rand(2, 8)) + b = a.as_tensor() + self.assertTrue(isinstance(b, torch.Tensor)) + self.assertFalse(isinstance(b, AgentStateTensor)) + + def test_tensor_ops(self): + a = AgentStateTensor(torch.rand(2, 8)) + b = a[0] + a[1] + c = torch.mean(b) + self.assertFalse(isinstance(c, AgentStateTensor)) + self.assertTrue(isinstance(c, torch.Tensor)) + + +class TestStateArray(unittest.TestCase): + def test_construction(self): + a = np.random.rand(2, 8).view(AgentStateArray) + c = np.random.rand(5, 9).view(AgentObsArray) + + def test_property_access(self): + a = np.random.rand(2, 8).view(AgentStateArray) + position = a[..., :3] + velocity = a[..., 3:5] + acc = a[..., 5:7] + h = a[..., 7:] + + self.assertTrue(np.allclose(a.position3d, position)) + self.assertTrue(np.allclose(a.velocity, velocity)) + self.assertTrue(np.allclose(a.acceleration, acc)) + self.assertTrue(np.allclose(a.heading, h)) + + def test_property_setting(self): + a = np.random.rand(2, 8).view(AgentStateArray) + a.heading = 0.0 + self.assertTrue(np.allclose(a[..., -1], np.zeros([2, 1]))) + + def test_heading_conversion(self): + a = np.random.rand(2, 8).view(AgentStateArray) + h = a[..., 7:] + hv = a.heading_vector + self.assertTrue(np.allclose(np.arctan2(hv[..., 1], hv[..., 0])[:, None], h)) + + def test_long_lat_velocity(self): + a = np.random.rand(8).view(AgentStateArray) + velocity = a[3:5] + h = a[7] + lonlat_v = a.as_format("v_lon,v_lat") + lonlat_v_correct = ( + np.array([[np.cos(h), np.sin(h)], [-np.sin(h), np.cos(h)]])[None, ...] + @ velocity[..., None] + )[..., 0] + + self.assertTrue(np.allclose(lonlat_v, lonlat_v_correct)) + + b = a.as_format("x,y,xd,yd,s,c") + s = b[-2] + c = b[-1] + lonlat_v = b.as_format("v_lon,v_lat") + lonlat_v_correct = ( + np.array([[c, s], [-s, c]])[None, ...] 
@ velocity[..., None] + )[..., 0] + + self.assertTrue(np.allclose(lonlat_v, lonlat_v_correct)) + + def test_long_lat_conversion(self): + a = np.random.rand(2, 8).view(AgentStateArray) + b = a.as_format("xd,yd,h") + c = b.as_format("v_lon,v_lat,h") + d = c.as_format("xd,yd,h") + self.assertTrue(np.allclose(b, d)) + + def test_as_format(self): + a = np.random.rand(2, 8).view(AgentStateArray) + b = a.as_format("x,y,z,xd,yd,xdd,ydd,s,c") + self.assertTrue(isinstance(b, AgentObsArray)) + self.assertTrue(np.allclose(a, b.as_format(a._format))) + + def test_as_ndarray(self): + a: AgentStateArray = np.random.rand(2, 8).view(AgentStateArray) + b = a.as_ndarray() + self.assertTrue(isinstance(b, np.ndarray)) + self.assertFalse(isinstance(b, AgentStateArray)) + + def test_tensor_ops(self): + a = np.random.rand(2, 8).view(AgentStateArray) + b = a[0] + a[1] + c = np.mean(b) + self.assertFalse(isinstance(c, AgentStateArray)) + self.assertTrue(isinstance(c, float)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_traffic_data.py b/tests/test_traffic_data.py new file mode 100644 index 0000000..e20c7a0 --- /dev/null +++ b/tests/test_traffic_data.py @@ -0,0 +1,158 @@ +import unittest + +from trajdata import UnifiedDataset +from trajdata.caching.df_cache import DataFrameCache + + +class TestTrafficLightData(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + kwargs = { + "desired_data": ["nuplan_mini-mini_val"], + "centric": "scene", + "history_sec": (3.2, 3.2), + "future_sec": (4.8, 4.8), + "incl_robot_future": False, + "incl_raster_map": True, + "cache_location": "~/.unified_data_cache", + "raster_map_params": { + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + "num_workers": 64, + "verbose": True, + "data_dirs": { # Remember to change this to match your filesystem! 
+ "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", + }, + } + + cls.dataset = UnifiedDataset( + **kwargs, + desired_dt=0.05, + ) + + cls.downsampled_dataset = UnifiedDataset( + **kwargs, + desired_dt=0.1, + ) + + cls.upsampled_dataset = UnifiedDataset( + **kwargs, + desired_dt=0.025, + ) + + cls.scene_num: int = 100 + + def test_traffic_light_loading(self): + # get random scene + scene = self.dataset.get_scene(self.scene_num) + scene_cache = DataFrameCache(self.dataset.cache_path, scene) + traffic_light_status = scene_cache.get_traffic_light_status_dict() + + # just check if the loading works without errors + self.assertTrue(traffic_light_status is not None) + + def test_downsampling(self): + # get random scene from both datasets + scene = self.dataset.get_scene(self.scene_num) + downsampled_scene = self.downsampled_dataset.get_scene(self.scene_num) + + self.assertEqual(scene.name, downsampled_scene.name) + + scene_cache = DataFrameCache(self.dataset.cache_path, scene) + downsampled_scene_cache = DataFrameCache( + self.downsampled_dataset.cache_path, downsampled_scene + ) + traffic_light_status = scene_cache.get_traffic_light_status_dict() + downsampled_traffic_light_status = ( + downsampled_scene_cache.get_traffic_light_status_dict() + ) + + orig_lane_ids = set(key[0] for key in traffic_light_status.keys()) + downsampled_lane_ids = set( + key[0] for key in downsampled_traffic_light_status.keys() + ) + self.assertSetEqual(orig_lane_ids, downsampled_lane_ids) + + # check that matching indices match + for ( + lane_id, + scene_ts, + ), downsampled_status in downsampled_traffic_light_status.items(): + if scene_ts % 2 == 0: + try: + prev_status = traffic_light_status[lane_id, scene_ts * 2] + except KeyError: + prev_status = None + + try: + next_status = traffic_light_status[lane_id, scene_ts * 2 + 1] + except KeyError: + next_status = None + + self.assertTrue( + prev_status is not None or next_status is not None, + f"Lane {lane_id} at t={scene_ts} has status {downsampled_status} " + f"in the downsampled dataset, but neither t={2*scene_ts} nor " + f"t={2*scene_ts + 1} were found in the original dataset.", + ) + self.assertTrue( + downsampled_status == prev_status + or downsampled_status == next_status, + f"Lane {lane_id} at t={scene_ts*2, scene_ts*2 + 1} in the original dataset " + f"had status {prev_status, next_status}, but in the downsampled dataset, " + f"{lane_id} at t={scene_ts} had status {downsampled_status}", + ) + + def test_upsampling(self): + # get random scene from both datasets + scene = self.dataset.get_scene(self.scene_num) + upsampled_scene = self.upsampled_dataset.get_scene(self.scene_num) + scene_cache = DataFrameCache(self.dataset.cache_path, scene) + upsampled_scene_cache = DataFrameCache( + self.upsampled_dataset.cache_path, upsampled_scene + ) + traffic_light_status = scene_cache.get_traffic_light_status_dict() + upsampled_traffic_light_status = ( + upsampled_scene_cache.get_traffic_light_status_dict() + ) + + # check that matching indices match + for (lane_id, scene_ts), status in upsampled_traffic_light_status.items(): + if scene_ts % 2 == 0: + orig_status = traffic_light_status[lane_id, scene_ts // 2] + self.assertEqual( + status, + orig_status, + f"Lane {lane_id} at t={scene_ts // 2} in the original dataset " + f"had status {orig_status}, but in the upsampled dataset, " + f"{lane_id} at t={scene_ts} had status {status}", + ) + else: + try: + prev_status = traffic_light_status[lane_id, scene_ts // 2] + except KeyError: + prev_status = None + try: + next_status = 
traffic_light_status[lane_id, scene_ts // 2 + 1] + except KeyError as k: + next_status = None + + self.assertTrue( + prev_status is not None or next_status is not None, + f"Lane {lane_id} at t={scene_ts} has status {status} " + f"in the upsampled dataset, but neither t={scene_ts // 2} nor " + f"t={scene_ts // 2 + 1} were found in the original dataset.", + ) + + self.assertTrue( + status == prev_status or status == next_status, + f"Lane {lane_id} at t={scene_ts // 2, scene_ts // 2 + 1} in the original dataset " + f"had status {prev_status, next_status}, but in the upsampled dataset, " + f"{lane_id} at t={scene_ts} had status {status}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vec_map.py b/tests/test_vec_map.py new file mode 100644 index 0000000..714f70f --- /dev/null +++ b/tests/test_vec_map.py @@ -0,0 +1,116 @@ +import unittest +from pathlib import Path +from typing import Dict, List + +import numpy as np +from shapely import contains_xy, dwithin, linearrings, points, polygons + +from trajdata import MapAPI, VectorMap +from trajdata.maps.vec_map_elements import MapElementType + + +class TestVectorMap(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cache_path = Path("~/.unified_data_cache").expanduser() + cls.map_api = MapAPI(cache_path) + cls.proto_loading_kwargs = { + "incl_road_lanes": True, + "incl_road_areas": True, + "incl_ped_crosswalks": True, + "incl_ped_walkways": True, + } + + cls.location_dict: Dict[str, List[str]] = { + "nusc_mini": ["boston-seaport", "singapore-onenorth"], + } + + # TODO(pkarkus) this assumes we already have the maps cached. It would be better + # to attempt to cache them if the cache does not yet exists. + def test_map_existence(self): + for env_name, map_names in self.location_dict.items(): + for map_name in map_names: + vec_map: VectorMap = self.map_api.get_map( + f"{env_name}:{map_name}", **self.proto_loading_kwargs + ) + assert vec_map is not None + + def test_proto_equivalence(self): + for env_name, map_names in self.location_dict.items(): + for map_name in map_names: + vec_map: VectorMap = self.map_api.get_map( + f"{env_name}:{map_name}", **self.proto_loading_kwargs + ) + + assert maps_equal( + VectorMap.from_proto( + vec_map.to_proto(), **self.proto_loading_kwargs + ), + vec_map, + ) + + def test_road_area_queries(self): + env_name = next(self.location_dict.keys().__iter__()) + map_name = self.location_dict[env_name][0] + + vec_map: VectorMap = self.map_api.get_map( + f"{env_name}:{map_name}", **self.proto_loading_kwargs + ) + + if vec_map.search_rtrees is None: + return + + point = vec_map.lanes[0].center.xy[0, :] + closest_area = vec_map.get_closest_area( + point, elem_type=MapElementType.ROAD_AREA + ) + holes = closest_area.interior_holes + if len(holes) == 0: + holes = None + closest_area_polygon = polygons(closest_area.exterior_polygon.xy, holes=holes) + self.assertTrue(contains_xy(closest_area_polygon, point[None, :2])) + + rnd_points = np.random.uniform( + low=vec_map.extent[:2], high=vec_map.extent[3:5], size=(10, 2) + ) + + NEARBY_DIST = 150.0 + for point in rnd_points: + nearby_areas = vec_map.get_areas_within( + point, elem_type=MapElementType.ROAD_AREA, dist=NEARBY_DIST + ) + for area in nearby_areas: + holes = [linearrings(hole.xy) for hole in area.interior_holes] + if len(holes) == 0: + holes = None + area_polygon = polygons(area.exterior_polygon.xy, holes=holes) + point_pt = points(point) + self.assertTrue(dwithin(area_polygon, point_pt, distance=NEARBY_DIST)) + + for elem_type in 
[
+            MapElementType.PED_CROSSWALK,
+            MapElementType.PED_WALKWAY,
+        ]:
+            for point in rnd_points:
+                nearby_areas = vec_map.get_areas_within(
+                    point, elem_type=elem_type, dist=NEARBY_DIST
+                )
+                for area in nearby_areas:
+                    area_polygon = polygons(area.polygon.xy)
+                    point_pt = points(point)
+                    if not dwithin(area_polygon, point_pt, distance=NEARBY_DIST):
+                        print(
+                            f"{elem_type.name} at {area_polygon} is not within {NEARBY_DIST} of {point_pt}",
+                        )
+
+    # TODO(bivanovic): Add more!
+
+
+def maps_equal(map1: VectorMap, map2: VectorMap) -> bool:
+    # Two maps are considered equal here if they contain the same element ids.
+    elements1_set = {elem.id for elem in map1.iter_elems()}
+    elements2_set = {elem.id for elem in map2.iter_elems()}
+    return elements1_set == elements2_set
+
+
+if __name__ == "__main__":
+    unittest.main()
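
For context, here is a minimal usage sketch of the plotting utilities the tests above exercise. The dataset location and the `trajdata.visualization.vis` import path are assumptions; adjust both to your setup:

```python
from torch.utils.data import DataLoader

from trajdata import AgentType, UnifiedDataset
from trajdata.visualization.vis import plot_agent_batch  # assumed import path

# Assumes the nuScenes mini split is cached locally (path is an example).
dataset = UnifiedDataset(
    desired_data=["nusc_mini-mini_val"],
    centric="agent",
    only_types=[AgentType.VEHICLE],
    data_dirs={"nusc_mini": "~/datasets/nuScenes"},
)
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=dataset.get_collate_fn(),
)

# Draws the map patch, fading history boxes, the current bounding box, and
# the future path for the first element of the batch.
batch = next(iter(dataloader))
plot_agent_batch(batch, batch_idx=0)
```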