Skip to content

Commit

Permalink
Bazel: make modules/prediction build passed. (ApolloAuto#11394)
Browse files Browse the repository at this point in the history
  • Loading branch information
changsh726 authored Jun 8, 2020
1 parent ccb6ebf commit b40b6e9
Show file tree
Hide file tree
Showing 25 changed files with 126 additions and 89 deletions.
5 changes: 5 additions & 0 deletions external/libtorch_cpu.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ cc_library(
linkstatic = False,
linkopts = [
"-L/usr/local/libtorch_cpu/lib",
"-lc10",
"-ltorch",
"-ltorch_cpu",
],
deps = [
"@python3",
],
)
7 changes: 6 additions & 1 deletion external/libtorch_gpu.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ cc_library(
linkstatic = False,
linkopts = [
"-L/usr/local/libtorch_gpu/lib",
"-ltorch",
"-lc10",
"-ltorch",
"-ltorch_cpu",
"-ltorch_cuda",
],
deps = [
"@python3",
],
)
5 changes: 4 additions & 1 deletion modules/prediction/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ cc_library(
"//modules/prediction/common:message_process",
"//modules/prediction/evaluator:evaluator_manager",
"//modules/prediction/predictor:predictor_manager",
"//modules/prediction/proto:offline_features_proto",
"//modules/prediction/proto:offline_features_cc_proto",
"//modules/prediction/scenario:scenario_manager",
"//modules/prediction/submodules:evaluator_submodule_lib",
"//modules/prediction/submodules:predictor_submodule_lib",
Expand All @@ -35,6 +35,9 @@ cc_test(
":prediction_data",
":prediction_testdata",
],
linkopts = [
"-lgomp",
],
deps = [
":prediction_component_lib",
],
Expand Down
16 changes: 8 additions & 8 deletions modules/prediction/common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ cc_library(
"//modules/common/util",
"//modules/prediction/common:prediction_gflags",
"//modules/prediction/container/obstacles:obstacle",
"//modules/prediction/proto:offline_features_proto",
"//modules/prediction/proto:prediction_proto",
"//modules/prediction/proto:offline_features_cc_proto",
"//modules/prediction/proto:prediction_obstacle_cc_proto",
],
)

Expand All @@ -110,7 +110,7 @@ cc_library(
"//modules/prediction/common:prediction_constants",
"//modules/prediction/common:prediction_gflags",
"//modules/prediction/common:prediction_system_gflags",
"//modules/prediction/proto:lane_graph_proto",
"//modules/prediction/proto:lane_graph_cc_proto",
],
)

Expand Down Expand Up @@ -141,7 +141,7 @@ cc_library(
deps = [
"//modules/common/math",
"//modules/prediction/common:prediction_gflags",
"//modules/prediction/proto:lane_graph_proto",
"//modules/prediction/proto:lane_graph_cc_proto",
],
)

Expand All @@ -165,7 +165,7 @@ cc_library(
hdrs = ["environment_features.h"],
deps = [
"//cyber",
"//modules/common/proto:geometry_proto",
"//modules/common/proto:geometry_cc_proto",
],
)

Expand All @@ -188,7 +188,7 @@ cc_library(
],
deps = [
":prediction_map",
"//modules/prediction/proto:feature_proto",
"//modules/prediction/proto:feature_cc_proto",
],
)

Expand Down Expand Up @@ -221,7 +221,7 @@ cc_library(
"//modules/common/adapters:adapter_gflags",
"//modules/prediction/evaluator:evaluator_manager",
"//modules/prediction/predictor:predictor_manager",
"//modules/prediction/proto:offline_features_proto",
"//modules/prediction/proto:offline_features_cc_proto",
"//modules/prediction/scenario:scenario_manager",
"//modules/prediction/util:data_extraction",
],
Expand Down Expand Up @@ -264,7 +264,7 @@ cc_library(
"//modules/common/util",
"//modules/prediction/container:container_manager",
"//modules/prediction/container/pose:pose_container",
"//modules/prediction/proto:feature_proto",
"//modules/prediction/proto:feature_cc_proto",
"@opencv",
],
)
Expand Down
68 changes: 28 additions & 40 deletions modules/prediction/common/message_process.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "cyber/common/file.h"
#include "cyber/record/record_reader.h"
#include "cyber/record/record_writer.h"

#include "modules/common/adapters/adapter_gflags.h"
#include "modules/prediction/common/feature_output.h"
#include "modules/prediction/common/junction_analyzer.h"
Expand Down Expand Up @@ -297,51 +298,38 @@ void MessageProcess::ProcessOfflineData(
message.channel_name, perception_obstacles, message.time);
}
PredictionObstacles prediction_obstacles;
OnPerception(perception_obstacles, &prediction_obstacles);
OnPerception(perception_obstacles, container_manager, evaluator_manager,
predictor_manager, scenario_manager,
&prediction_obstacles);
if (FLAGS_prediction_offline_mode == PredictionConstants::kDumpRecord) {
SingleMessage single_message;
std::string content = "";
prediction_obstacles.SerializeToString(&content);
single_message.set_content(content);
single_message.set_time(message.time);
single_message.set_channel_name(FLAGS_prediction_topic);
writer.WriteMessage(RecordMessageToSingleMessage(message));
=======
OnPerception(perception_obstacles, container_manager,
evaluator_manager, predictor_manager, scenario_manager,
&prediction_obstacles);
if (FLAGS_prediction_offline_mode ==
PredictionConstants::kDumpRecord) {
writer.WriteMessage<PredictionObstacles>(
prediction_conf.topic_conf().perception_obstacle_topic(),
prediction_obstacles, message.time);
AINFO << "Generated a new prediction message.";
>>>>>>> master
}
}
} else if (message.channel_name ==
prediction_conf.topic_conf().localization_topic()) {
LocalizationEstimate localization;
if (localization.ParseFromString(message.content)) {
if (FLAGS_prediction_offline_mode ==
PredictionConstants::kDumpRecord) {
writer.WriteMessage<LocalizationEstimate>(
message.channel_name, localization, message.time);
}
OnLocalization(container_manager.get(), localization);
writer.WriteMessage<PredictionObstacles>(
prediction_conf.topic_conf().perception_obstacle_topic(),
prediction_obstacles, message.time);
AINFO << "Generated a new prediction message.";
}
} else if (message.channel_name ==
prediction_conf.topic_conf().planning_trajectory_topic()) {
ADCTrajectory adc_trajectory;
if (adc_trajectory.ParseFromString(message.content)) {
OnPlanning(container_manager.get(), adc_trajectory);
}
} else if (message.channel_name ==
prediction_conf.topic_conf().localization_topic()) {
LocalizationEstimate localization;
if (localization.ParseFromString(message.content)) {
if (FLAGS_prediction_offline_mode == PredictionConstants::kDumpRecord) {
writer.WriteMessage<LocalizationEstimate>(message.channel_name,
localization, message.time);
}
OnLocalization(container_manager.get(), localization);
}
} else if (message.channel_name ==
prediction_conf.topic_conf().planning_trajectory_topic()) {
ADCTrajectory adc_trajectory;
if (adc_trajectory.ParseFromString(message.content)) {
OnPlanning(container_manager.get(), adc_trajectory);
}
}
if (FLAGS_prediction_offline_mode == PredictionConstants::kDumpRecord) {
writer.Close();
}
}
if (FLAGS_prediction_offline_mode == PredictionConstants::kDumpRecord) {
writer.Close();
}
}

} // namespace prediction
} // namespace prediction
} // namespace apollo
2 changes: 1 addition & 1 deletion modules/prediction/container/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ cc_library(
"-DMODULE_NAME=\\\"prediction\\\"",
],
deps = [
"//modules/common/adapters/proto:adapter_config_proto",
"//modules/common/adapters/proto:adapter_config_cc_proto",
"//modules/prediction/container/adc_trajectory:adc_trajectory_container",
"//modules/prediction/container/obstacles:obstacles_container",
"//modules/prediction/container/pose:pose_container",
Expand Down
3 changes: 2 additions & 1 deletion modules/prediction/container/adc_trajectory/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ cc_library(
"-DMODULE_NAME=\\\"prediction\\\"",
],
deps = [
"//modules/planning/proto:planning_proto",
"//modules/planning/proto:planning_cc_proto",
"//modules/prediction/common:prediction_map",
"//modules/prediction/container",
"//modules/prediction/proto:lane_graph_cc_proto",
],
)

Expand Down
8 changes: 4 additions & 4 deletions modules/prediction/container/obstacles/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ cc_library(
"//modules/prediction/common:prediction_constants",
"//modules/prediction/container",
"//modules/prediction/container/obstacles:obstacle",
"//modules/prediction/proto:prediction_proto",
"//modules/prediction/proto:prediction_obstacle_cc_proto",
"//modules/prediction/submodules:submodule_output",
],
)
Expand All @@ -33,8 +33,8 @@ cc_library(
"//modules/prediction/common:junction_analyzer",
"//modules/prediction/container/obstacles:obstacle_clusters",
"//modules/prediction/network/rnn_model",
"//modules/prediction/proto:prediction_conf_proto",
"//modules/prediction/proto:prediction_proto",
"//modules/prediction/proto:prediction_conf_cc_proto",
"//modules/prediction/proto:prediction_obstacle_cc_proto",
],
)

Expand Down Expand Up @@ -77,7 +77,7 @@ cc_library(
],
deps = [
"//modules/prediction/common:road_graph",
"//modules/prediction/proto:feature_proto",
"//modules/prediction/proto:feature_cc_proto",
],
)

Expand Down
4 changes: 2 additions & 2 deletions modules/prediction/container/pose/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ cc_library(
],
deps = [
"//modules/common/math:quaternion",
"//modules/localization/proto:localization_proto",
"//modules/perception/proto:perception_proto",
"//modules/localization/proto:localization_cc_proto",
"//modules/perception/proto:perception_obstacle_cc_proto",
"//modules/prediction/common:prediction_gflags",
"//modules/prediction/container",
],
Expand Down
2 changes: 1 addition & 1 deletion modules/prediction/container/storytelling/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ cc_library(
deps = [
"//modules/prediction/common:prediction_map",
"//modules/prediction/container",
"//modules/storytelling/proto:story_proto",
"//modules/storytelling/proto:story_cc_proto",
],
)

Expand Down
6 changes: 5 additions & 1 deletion modules/prediction/evaluator/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ cc_library(
"//modules/prediction/evaluator/vehicle:lane_scanning_evaluator",
"//modules/prediction/evaluator/vehicle:mlp_evaluator",
"//modules/prediction/evaluator/vehicle:semantic_lstm_evaluator",
"//modules/prediction/proto:prediction_conf_proto",
"//modules/prediction/proto:prediction_conf_cc_proto",
"//third_party:libtorch",
],
)

Expand All @@ -38,6 +39,9 @@ cc_test(
"//modules/prediction:prediction_data",
"//modules/prediction:prediction_testdata",
],
linkopts = [
"-lgomp",
],
deps = [
"//modules/prediction/common:kml_map_based_test",
"//modules/prediction/evaluator:evaluator_manager",
Expand Down
13 changes: 12 additions & 1 deletion modules/prediction/evaluator/vehicle/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ cc_library(
"//modules/prediction/common:validation_checker",
"//modules/prediction/container/obstacles:obstacles_container",
"//modules/prediction/evaluator",
"//modules/prediction/proto:fnn_vehicle_model_proto",
"//modules/prediction/proto:fnn_vehicle_model_cc_proto",
],
)

Expand Down Expand Up @@ -70,6 +70,7 @@ cc_library(
hdrs = ["junction_mlp_evaluator.h"],
copts = [
"-DMODULE_NAME=\\\"prediction\\\"",
"-fopenmp",
],
deps = [
"//modules/common/math:geometry",
Expand All @@ -90,6 +91,9 @@ cc_test(
"//modules/prediction:prediction_data",
"//modules/prediction:prediction_testdata",
],
linkopts = [
"-lgomp",
],
deps = [
"//modules/prediction/common:kml_map_based_test",
"//modules/prediction/evaluator/vehicle:junction_mlp_evaluator",
Expand All @@ -103,6 +107,7 @@ cc_library(
hdrs = ["junction_map_evaluator.h"],
copts = [
"-DMODULE_NAME=\\\"prediction\\\"",
"-fopenmp",
],
deps = [
"//modules/prediction/common:prediction_util",
Expand All @@ -119,6 +124,7 @@ cc_library(
hdrs = ["cruise_mlp_evaluator.h"],
copts = [
"-DMODULE_NAME=\\\"prediction\\\"",
"-fopenmp",
],
deps = [
"//modules/prediction/common:prediction_util",
Expand All @@ -137,6 +143,9 @@ cc_test(
"//modules/prediction:prediction_data",
"//modules/prediction:prediction_testdata",
],
linkopts = [
"-lgomp",
],
deps = [
"//modules/prediction/common:kml_map_based_test",
"//modules/prediction/evaluator/vehicle:cruise_mlp_evaluator",
Expand All @@ -149,6 +158,7 @@ cc_library(
hdrs = ["lane_scanning_evaluator.h"],
copts = [
"-DMODULE_NAME=\\\"prediction\\\"",
"-fopenmp",
],
deps = [
"//modules/prediction/container:container_manager",
Expand Down Expand Up @@ -179,6 +189,7 @@ cc_library(
hdrs = ["semantic_lstm_evaluator.h"],
copts = [
"-DMODULE_NAME=\\\"prediction\\\"",
"-fopenmp",
],
deps = [
"//modules/prediction/common:prediction_util",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ bool JunctionMapEvaluator::Evaluate(Obstacle* obstacle_ptr,
junction_exit_mask[0][i] = static_cast<float>(feature_values[i]);
}

torch_inputs.push_back(c10::ivalue::Tuple::create(
at::Tensor torch_input_tensor;
torch_inputs.push_back(c10::ivalue::Tuple::createNamed(
{std::move(img_tensor.to(device_)),
std::move(junction_exit_mask.to(device_))},
c10::TupleType::create(
std::vector<c10::TypePtr>(2, c10::TensorType::create()))));
c10::TupleType::create(std::vector<c10::TypePtr>(
2, c10::TensorType::create(torch_input_tensor)))));

// Compute probability
std::vector<double> probability;
Expand Down
Loading

0 comments on commit b40b6e9

Please sign in to comment.