-
Notifications
You must be signed in to change notification settings - Fork 43
/
batch_example.py
53 lines (43 loc) · 1.36 KB
/
batch_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from collections import defaultdict
from torch.utils.data import DataLoader
from tqdm import tqdm
from trajdata import AgentBatch, AgentType, UnifiedDataset
from trajdata.augmentation import NoiseHistories
from trajdata.visualization.vis import plot_agent_batch
def main():
    """Build an agent-centric dataset, batch it, and plot one sample per batch.

    Constructs a ``UnifiedDataset`` over ``nusc_mini-mini_train`` with a
    ``NoiseHistories`` augmentation and raster maps enabled, prints the
    number of samples, wraps the dataset in a shuffled ``DataLoader``
    (batch size 4, 4 loader workers), and calls ``plot_agent_batch`` on
    index 0 of every batch produced.
    """
    history_noise = NoiseHistories()

    # Any key looked up in this mapping falls back to 30.0.
    interaction_distances = defaultdict(lambda: 30.0)

    raster_params = {
        "px_per_m": 2,
        "map_size_px": 224,
        "offset_frac_xy": (-0.5, 0.0),
    }

    # Remember to change this to match your filesystem!
    dataset_roots = {
        "nusc_mini": "~/datasets/nuScenes",
    }

    unified_dataset = UnifiedDataset(
        desired_data=["nusc_mini-mini_train"],
        centric="agent",
        desired_dt=0.1,
        history_sec=(3.2, 3.2),
        future_sec=(4.8, 4.8),
        only_predict=[AgentType.VEHICLE],
        agent_interaction_distances=interaction_distances,
        incl_robot_future=False,
        incl_raster_map=True,
        raster_map_params=raster_params,
        augmentations=[history_noise],
        num_workers=0,
        verbose=True,
        data_dirs=dataset_roots,
    )

    sample_count = len(unified_dataset)
    print(f"# Data Samples: {sample_count:,}")

    # The dataset supplies its own collate function for assembling batches.
    loader = DataLoader(
        unified_dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=unified_dataset.get_collate_fn(),
        num_workers=4,
    )

    agent_batch: AgentBatch
    for agent_batch in tqdm(loader):
        # Only the first element of each batch is visualized.
        plot_agent_batch(agent_batch, batch_idx=0)
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()