Skip to content

Commit

Permalink
Update planner.py
Browse files Browse the repository at this point in the history
  • Loading branch information
MCZhi authored Jul 15, 2022
1 parent 8bd92e8 commit 22633d5
Showing 1 changed file with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions model/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ def __init__(self, trajectory_len, feature_len, device, vectorize=True, test=Fal
control_variables = th.Vector(dof=100, name="control_variables")

# define prediction variable
predictions = th.Variable(data=torch.empty(1, 10, trajectory_len, 3), name="predictions")
predictions = th.Variable(torch.empty(1, 10, trajectory_len, 3), name="predictions")

# define ref_line_info
ref_line_info = th.Variable(data=torch.empty(1, 1200, 5), name="ref_line_info")
ref_line_info = th.Variable(torch.empty(1, 1200, 5), name="ref_line_info")

# define current state
current_state = th.Variable(data=torch.empty(1, 11, 8), name="current_state")
current_state = th.Variable(torch.empty(1, 11, 8), name="current_state")

# set up objective
objective = th.Objective()
Expand Down Expand Up @@ -66,34 +66,34 @@ def bicycle_model(control, current_state):

# cost functions
def acceleration(optim_vars, aux_vars):
control = optim_vars[0].data.view(-1, 50, 2)
control = optim_vars[0].tensor.view(-1, 50, 2)
acc = control[:, :, 0]

return acc

def jerk(optim_vars, aux_vars):
control = optim_vars[0].data.view(-1, 50, 2)
control = optim_vars[0].tensor.view(-1, 50, 2)
acc = control[:, :, 0]
jerk = torch.diff(acc) / 0.1

return jerk

def steering(optim_vars, aux_vars):
control = optim_vars[0].data.view(-1, 50, 2)
control = optim_vars[0].tensor.view(-1, 50, 2)
steering = control[:, :, 1]

return steering

def steering_change(optim_vars, aux_vars):
control = optim_vars[0].data.view(-1, 50, 2)
control = optim_vars[0].tensor.view(-1, 50, 2)
steering = control[:, :, 1]
steering_change = torch.diff(steering) / 0.1

return steering_change

def speed(optim_vars, aux_vars):
control = optim_vars[0].data.view(-1, 50, 2)
current_state = aux_vars[1].data[:, 0]
control = optim_vars[0].tensor.view(-1, 50, 2)
current_state = aux_vars[1].tensor[:, 0]
velocity = torch.hypot(current_state[:, 3], current_state[:, 4])
dt = 0.1

Expand All @@ -108,9 +108,9 @@ def speed(optim_vars, aux_vars):
def lane_xy(optim_vars, aux_vars):
global ref_points

control = optim_vars[0].data.view(-1, 50, 2)
ref_line = aux_vars[0].data
current_state = aux_vars[1].data[:, 0]
control = optim_vars[0].tensor.view(-1, 50, 2)
ref_line = aux_vars[0].tensor
current_state = aux_vars[1].tensor[:, 0]

traj = bicycle_model(control, current_state)
distance_to_ref = torch.cdist(traj[:, :, :2], ref_line[:, :, :2])
Expand All @@ -121,8 +121,8 @@ def lane_xy(optim_vars, aux_vars):
return lane_error

def lane_theta(optim_vars, aux_vars):
control = optim_vars[0].data.view(-1, 50, 2)
current_state = aux_vars[1].data[:, 0]
control = optim_vars[0].tensor.view(-1, 50, 2)
current_state = aux_vars[1].tensor[:, 0]

traj = bicycle_model(control, current_state)
theta = traj[:, :, 2]
Expand All @@ -131,9 +131,9 @@ def lane_theta(optim_vars, aux_vars):
return lane_error

def red_light_violation(optim_vars, aux_vars):
control = optim_vars[0].data.view(-1, 50, 2)
current_state = aux_vars[1].data[:, 0]
ref_line = aux_vars[0].data
control = optim_vars[0].tensor.view(-1, 50, 2)
current_state = aux_vars[1].tensor[:, 0]
ref_line = aux_vars[0].tensor
red_light = ref_line[..., -1]
dt = 0.1

Expand All @@ -150,10 +150,10 @@ def red_light_violation(optim_vars, aux_vars):
return red_light_error

def safety(optim_vars, aux_vars):
control = optim_vars[0].data.view(-1, 50, 2)
neighbors = aux_vars[0].data.permute(0, 2, 1, 3)
current_state = aux_vars[1].data
ref_line = aux_vars[2].data
control = optim_vars[0].tensor.view(-1, 50, 2)
neighbors = aux_vars[0].tensor.permute(0, 2, 1, 3)
current_state = aux_vars[1].tensor
ref_line = aux_vars[2].tensor

actor_mask = torch.ne(current_state, 0)[:, 1:, -1]
ego_current_state = current_state[:, 0]
Expand Down

0 comments on commit 22633d5

Please sign in to comment.