Commit a52cdd4
fix a bug in preprocessing
Henry1iu committed May 19, 2022
1 parent 3f8b3f5 commit a52cdd4
Showing 13 changed files with 15 additions and 35 deletions.
Binary file modified core/__pycache__/__init__.cpython-38.pyc
8 changes: 0 additions & 8 deletions core/dataloader/argoverse_loader_v2.py
@@ -140,13 +140,8 @@ def get(self, idx):

# pad feature with zero nodes
data.x = torch.cat([data.x, torch.zeros((index_to_pad - valid_len, feature_len), dtype=data.x.dtype)])
-<<<<<<< HEAD
data.cluster = torch.cat([data.cluster, torch.arange(valid_len, index_to_pad, dtype=data.cluster.dtype)])
data.identifier = torch.cat([data.identifier, torch.zeros((index_to_pad - valid_len, 2), dtype=data.identifier.dtype)])
-=======
data.cluster = torch.cat([data.cluster, torch.arange(valid_len, index_to_pad)]).long()
data.identifier = torch.cat([data.identifier, torch.zeros((index_to_pad - valid_len, 2), dtype=data.x.dtype)])
->>>>>>> 5b83e30e2f9960bef563ebac6a4390083b751b26

# pad candidate and candidate_gt
num_cand_max = data.candidate_len_max[0].item()
@@ -209,8 +204,6 @@ def _get_x(data_seq):
continue # skip if only 1 node
if cluster_idc < traj_cnt:
edge_index = np.hstack([edge_index, get_fc_edge_index(indices)])
-<<<<<<< HEAD
-=======
else:
edge_index = np.hstack([edge_index, get_fc_edge_index(indices)])
return feats, cluster, edge_index, identifier
@@ -374,7 +367,6 @@ def _get_x(data_seq):
continue # skip if only 1 node
if cluster_idc < traj_cnt:
edge_index = np.hstack([edge_index, get_fc_edge_index(indices)])
->>>>>>> 5b83e30e2f9960bef563ebac6a4390083b751b26
else:
edge_index = np.hstack([edge_index, get_fc_edge_index(indices)])
return feats, cluster, edge_index, identifier
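
Note on this file: the commit only deletes leftover merge-conflict artifacts, i.e. the conflict marker lines (shown with a leading "-" above) and one of the two duplicated cluster/identifier padding pairs in the first hunk (0 additions, 8 deletions in total); with the markers present the module is not valid Python. As background, here is a minimal self-contained sketch of what the padding in get() achieves, using made-up sizes and the .long() variant of the cluster padding shown above:

import torch

# Made-up sizes; in the loader these come from the stored sample itself.
valid_len, index_to_pad, feature_len = 3, 5, 4

x = torch.randn(valid_len, feature_len)   # features of the valid nodes
cluster = torch.arange(valid_len)         # polyline (cluster) id of each node
identifier = torch.randn(valid_len, 2)    # per-node 2-D identifier, as in the dataset

# Pad every sample to the same node count so samples can be batched;
# each padded node gets zero features and its own fresh cluster id.
x = torch.cat([x, torch.zeros((index_to_pad - valid_len, feature_len), dtype=x.dtype)])
cluster = torch.cat([cluster, torch.arange(valid_len, index_to_pad)]).long()
identifier = torch.cat([identifier, torch.zeros((index_to_pad - valid_len, 2), dtype=x.dtype)])

assert x.shape[0] == cluster.shape[0] == identifier.shape[0] == index_to_pad
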
Binary file modified core/util/__pycache__/__init__.cpython-38.pyc
Binary file modified core/util/__pycache__/cubic_spline.cpython-38.pyc
5 additional binary files modified (file names not shown).
Binary file modified core/util/preprocessor/__pycache__/__init__.cpython-38.pyc
Binary file modified core/util/preprocessor/__pycache__/base.cpython-38.pyc
11 changes: 5 additions & 6 deletions core/util/preprocessor/argoverse_preprocess_v2.py
@@ -6,6 +6,7 @@
import argparse
from os.path import join as pjoin
import copy
+import sys
import numpy as np
import pandas as pd
from tqdm import tqdm
@@ -148,9 +149,7 @@ def get_obj_feats(self, data):
for i, _ in enumerate(ctr_line_candts):
ctr_line_candts[i] = np.matmul(rot, (ctr_line_candts[i] - orig.reshape(-1, 2)).T).T

-        tar_candts = self.lane_candidate_sampling(ctr_line_candts, viz=False)
-        if self.normalized and len(tar_candts[tar_candts[:, 1] >= 0, :]) != 0:
-            tar_candts = tar_candts[tar_candts[:, 1] >= 0, :]
+        tar_candts = self.lane_candidate_sampling(ctr_line_candts, [0, 0], viz=True)

if self.split == "test":
tar_candts_gt, tar_offse_gt = np.zeros((tar_candts.shape[0], 1)), np.zeros((1, 2))
@@ -426,16 +425,16 @@ def ref_copy(data):
parser.add_argument("-s", "--small", action='store_true', default=False)
args = parser.parse_args()

-    args.root = "/Users/jb/projects/trajectory_prediction_algorithms/yet-another-vectornet/data/"
# args.root = "/home/jb/projects/Code/trajectory-prediction/TNT-Trajectory-Predition/dataset"
raw_dir = os.path.join(args.root, "raw_data")
interm_dir = os.path.join(args.dest, "interm_data" if not args.small else "interm_data_small")

for split in ["train", "val", "test"]:
# construct the preprocessor and dataloader
argoverse_processor = ArgoversePreprocessor(root_dir=raw_dir, split=split, save_dir=interm_dir)
loader = DataLoader(argoverse_processor,
-                            batch_size=16,
-                            num_workers=16,
+                            batch_size=1 if sys.gettrace() else 16,     # 1 batch in debug mode
+                            num_workers=0 if sys.gettrace() else 16,    # use only 0 worker in debug mode
shuffle=False,
pin_memory=False,
drop_last=False)
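
The new DataLoader arguments use sys.gettrace() to detect whether a Python debugger is attached: debuggers such as pdb or the PyCharm/VS Code debuggers install a global trace function, so sys.gettrace() returns a non-None value under a debugger and None otherwise. A small self-contained illustration of the same pattern (the helper name is invented here; only the 1/0-versus-16 values mirror the diff):

import sys

def loader_kwargs():
    """Shrink batch size and worker count when a debugger is attached."""
    debugging = sys.gettrace() is not None
    return {
        "batch_size": 1 if debugging else 16,    # single sample per batch while debugging
        "num_workers": 0 if debugging else 16,   # 0 workers keeps loading in the main process
        "shuffle": False,
        "pin_memory": False,
        "drop_last": False,
    }

print(loader_kwargs())
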
31 changes: 10 additions & 21 deletions core/util/preprocessor/base.py
@@ -13,6 +13,8 @@

from argoverse.utils.mpl_plotting_utils import visualize_centerline

+from core.util.cubic_spline import Spline2D


class Preprocessor(Dataset):
"""
@@ -113,29 +115,16 @@ def uniform_candidate_sampling(sampling_range, rate=30):
return np.stack(np.meshgrid(x, x), -1).reshape(-1, 2)

# implement a candidate sampling with equal distance;
-    def lane_candidate_sampling(self, centerline_list, distance=0.5, viz=False):
+    def lane_candidate_sampling(self, centerline_list, orig, distance=0.5, viz=False):
"""the input are list of lines, each line containing"""
candidates = []
-        for line in centerline_list:
-            for i in range(len(line) - 1):
-                if np.any(np.isnan(line[i])) or np.any(np.isnan(line[i+1])):
-                    continue
-                [x_diff, y_diff] = line[i+1] - line[i]
-                if x_diff == 0.0 and y_diff == 0.0:
-                    continue
-                candidates.append(line[i])
-
-                # compute displacement along each coordinate
-                den = np.hypot(x_diff, y_diff) + np.finfo(float).eps
-                d_x = distance * (x_diff / den)
-                d_y = distance * (y_diff / den)
-
-                num_c = np.floor(den / distance).astype(np.int)
-                pt = copy.deepcopy(line[i])
-                for j in range(num_c):
-                    pt += np.array([d_x, d_y])
-                    candidates.append(copy.deepcopy(pt))
-        candidates = np.unique(np.asarray(candidates), axis=0)
+        for lane_id, line in enumerate(centerline_list):
+            sp = Spline2D(x=line[:, 0], y=line[:, 1])
+            s_o, d_o = sp.calc_frenet_position(orig[0], orig[1])
+            s = np.arange(s_o, sp.s[-1], distance)
+            ix, iy = sp.calc_global_position_online(s)
+            candidates.append(np.stack([ix, iy], axis=1))
+        candidates = np.unique(np.concatenate(candidates), axis=0)

if viz:
fig = plt.figure(0, figsize=(8, 7))
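
The rewritten lane_candidate_sampling drops the hand-rolled segment-walking loop and instead fits each candidate centerline with the repository's Spline2D helper: the agent origin orig is projected into the lane's Frenet frame with calc_frenet_position to get its arc-length offset s_o, and candidate points are then sampled every distance meters (0.5 m by default) along the spline from s_o to the end of the lane via calc_global_position_online. At the call site in get_obj_feats the origin is passed as [0, 0] because the candidate centerlines have already been translated to the target agent's position and rotated into its frame a few lines earlier. A rough usage sketch, assuming Spline2D exposes exactly the methods used in the diff; the centerline and origin below are fabricated for illustration:

import numpy as np
from core.util.cubic_spline import Spline2D   # repo-local helper imported by the diff

# Fabricated centerline: an N x 2 array of map coordinates with gentle curvature.
x = np.linspace(0.0, 20.0, 21)
line = np.stack([x, np.sin(x / 10.0)], axis=1)
orig = np.array([5.0, 0.5])                   # illustrative agent origin

sp = Spline2D(x=line[:, 0], y=line[:, 1])
s_o, d_o = sp.calc_frenet_position(orig[0], orig[1])   # arc length / lateral offset of orig
s = np.arange(s_o, sp.s[-1], 0.5)                      # one sample every 0.5 m ahead of orig
ix, iy = sp.calc_global_position_online(s)
candidates = np.stack([ix, iy], axis=1)                # candidate target points on this lane
print(candidates.shape)
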
