Skip to content

Commit

Permalink
planning: loading test learning model scenario config.
Browse files Browse the repository at this point in the history
  • Loading branch information
ycool authored and jinghaomiao committed Mar 25, 2020
1 parent 10e2bfa commit 297ec9a
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 2 deletions.
1 change: 1 addition & 0 deletions modules/planning/scenarios/learning_model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ cc_library(
"//modules/planning/proto:planning_proto",
"//modules/planning/reference_line",
"//modules/planning/scenarios:scenario",
"//third_party:libtorch",
"@com_github_gflags_gflags//:gflags",
"@eigen",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,29 @@

#include "modules/planning/scenarios/learning_model/test_learning_model_scenario.h"

#include <algorithm>
#include <iterator>

#include "cyber/common/log.h"

namespace apollo {
namespace planning {
namespace scenario {

TestLearningModelScenario::TestLearningModelScenario(
const ScenarioConfig& scenario_config,
const ScenarioContext* context)
: Scenario(scenario_config, context), device_(torch::kCPU) {
const auto& config = scenario_config.test_learning_model_config();
AINFO << "Loading learning model:" << config.model_file();
model_ = torch::jit::load(config.model_file(), device_);

std::copy(config.input_shape().begin(), config.input_shape().end(),
std::back_inserter(input_shapes_));
std::copy(config.output_shape().begin(), config.output_shape().end(),
std::back_inserter(output_shapes_));
}

Scenario::ScenarioStatus TestLearningModelScenario::Process(
const common::TrajectoryPoint& planning_init_point,
Frame* frame) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
#pragma once

#include <memory>
#include <vector>

#include "torch/script.h"
#include "torch/torch.h"

#include "modules/common/status/status.h"
#include "modules/common/util/factory.h"
Expand All @@ -34,16 +38,22 @@ namespace scenario {
class TestLearningModelScenario : public Scenario {
public:
TestLearningModelScenario(const ScenarioConfig& config,
const ScenarioContext* context)
: Scenario(config, context) {}
const ScenarioContext* context);

// TODO(all): continue to refactor scenario framework to
// make output more clear
ScenarioStatus Process(
const common::TrajectoryPoint& planning_init_point,
Frame* frame) override;

std::unique_ptr<Stage> CreateStage(
const ScenarioConfig::StageConfig& stage_config) override;

private:
torch::jit::script::Module model_;
torch::Device device_;
std::vector<int> input_shapes_;
std::vector<int> output_shapes_;
};

} // namespace scenario
Expand Down

0 comments on commit 297ec9a

Please sign in to comment.