Skip to content

Commit

Permalink
Prediction: add semantic lstm pedestrian model
Browse files Browse the repository at this point in the history
  • Loading branch information
kechxu authored and storypku committed May 11, 2020
1 parent 107c4d8 commit 26480c4
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 7 deletions.
8 changes: 8 additions & 0 deletions modules/prediction/common/prediction_gflags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,14 @@ DEFINE_string(torch_pedestrian_interaction_prediction_layer_file,
"/apollo/modules/prediction/data/"
"pedestrian_interaction_prediction_layer.pt",
"pedestrian interaction prediction layer");
DEFINE_string(
torch_pedestrian_semantic_lstm_file,
"/apollo/modules/prediction/data/semantic_lstm_pedestrian_model.pt",
"Pedestrian semantic lstm model file, default for gpu");
DEFINE_string(
torch_pedestrian_semantic_lstm_cpu_file,
"/apollo/modules/prediction/data/semantic_lstm_pedestrian_cpu_model.pt",
"Pedestrian semantic lstm cpu model file");
DEFINE_string(torch_lane_aggregating_obstacle_encoding_file,
"/apollo/modules/prediction/data/"
"traced_online_obs_enc.pt",
Expand Down
2 changes: 2 additions & 0 deletions modules/prediction/common/prediction_gflags.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ DECLARE_string(torch_pedestrian_interaction_position_embedding_file);
DECLARE_string(torch_pedestrian_interaction_social_embedding_file);
DECLARE_string(torch_pedestrian_interaction_single_lstm_file);
DECLARE_string(torch_pedestrian_interaction_prediction_layer_file);
DECLARE_string(torch_pedestrian_semantic_lstm_file);
DECLARE_string(torch_pedestrian_semantic_lstm_cpu_file);
DECLARE_string(torch_lane_aggregating_obstacle_encoding_file);
DECLARE_string(torch_lane_aggregating_lane_encoding_file);
DECLARE_string(torch_lane_aggregating_prediction_layer_file);
Expand Down
1 change: 1 addition & 0 deletions modules/prediction/conf/prediction_conf.pb.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ obstacle_conf {
}
obstacle_conf {
obstacle_type: PEDESTRIAN
obstacle_status: MOVING
evaluator_type: PEDESTRIAN_INTERACTION_EVALUATOR
predictor_type: FREE_MOVE_PREDICTOR
}
Expand Down
4 changes: 4 additions & 0 deletions modules/prediction/container/obstacles/obstacle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ bool IsClosed(const double x0, const double y0, const double theta0,

PerceptionObstacle::Type Obstacle::type() const { return type_; }

bool Obstacle::IsPedestrian() {
return type_ == PerceptionObstacle::PEDESTRIAN;
}

int Obstacle::id() const { return id_; }

double Obstacle::timestamp() const {
Expand Down
2 changes: 2 additions & 0 deletions modules/prediction/container/obstacles/obstacle.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class Obstacle {
*/
perception::PerceptionObstacle::Type type() const;

bool IsPedestrian();

/**
* @brief Get the obstacle's ID.
* @return The obstacle's ID.
Expand Down
Binary file not shown.
Binary file not shown.
25 changes: 19 additions & 6 deletions modules/prediction/evaluator/vehicle/semantic_lstm_evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,14 @@ bool SemanticLSTMEvaluator::Evaluate(Obstacle* obstacle_ptr,
std::vector<double> pred_traj;

auto start_time = std::chrono::system_clock::now();
at::Tensor torch_output_tensor =
torch_model_.forward(torch_inputs).toTensor().to(torch::kCPU);
at::Tensor torch_output_tensor = torch_default_output_tensor_;
if (obstacle_ptr->IsPedestrian()) {
torch_output_tensor = torch_pedestrian_model_.forward(torch_inputs).
toTensor().to(torch::kCPU);
} else {
torch_output_tensor =
torch_vehicle_model_.forward(torch_inputs).toTensor().to(torch::kCPU);
}

auto end_time = std::chrono::system_clock::now();
std::chrono::duration<double> diff = end_time - start_time;
Expand Down Expand Up @@ -231,11 +237,15 @@ void SemanticLSTMEvaluator::LoadModel() {
if (FLAGS_use_cuda && torch::cuda::is_available()) {
ADEBUG << "CUDA is available";
device_ = torch::Device(torch::kCUDA);
torch_model_ =
torch_vehicle_model_ =
torch::jit::load(FLAGS_torch_vehicle_semantic_lstm_file, device_);
torch_pedestrian_model_ =
torch::jit::load(FLAGS_torch_pedestrian_semantic_lstm_file, device_);
} else {
torch_model_ =
torch_vehicle_model_ =
torch::jit::load(FLAGS_torch_vehicle_semantic_lstm_cpu_file, device_);
torch_pedestrian_model_ = torch::jit::load(
FLAGS_torch_pedestrian_semantic_lstm_cpu_file, device_);
}
torch::set_num_threads(1);

Expand All @@ -249,8 +259,11 @@ void SemanticLSTMEvaluator::LoadModel() {
std::move(obstacle_pos_step.to(device_))},
c10::TupleType::create(
std::vector<c10::TypePtr>(3, c10::TensorType::create()))));
at::Tensor torch_output_tensor =
torch_model_.forward(torch_inputs).toTensor().to(torch::kCPU);
// Run one inference to avoid very slow first inference later
torch_default_output_tensor_ =
torch_vehicle_model_.forward(torch_inputs).toTensor().to(torch::kCPU);
torch_default_output_tensor_ =
torch_pedestrian_model_.forward(torch_inputs).toTensor().to(torch::kCPU);
}

} // namespace prediction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ class SemanticLSTMEvaluator : public Evaluator {
void LoadModel();

private:
torch::jit::script::Module torch_model_;
torch::jit::script::Module torch_vehicle_model_;
torch::jit::script::Module torch_pedestrian_model_;
at::Tensor torch_default_output_tensor_;
torch::Device device_;
};

Expand Down

0 comments on commit 26480c4

Please sign in to comment.