Skip to content

Commit

Permalink
[onert] Infer shapes in TrainingCompiler (Samsung#10942)
Browse files Browse the repository at this point in the history
This commit allows TrainingCompiler to infer shapes based on batch size.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Jul 6, 2023
1 parent d884ec0 commit 5f85e93
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 12 deletions.
2 changes: 2 additions & 0 deletions runtime/onert/core/include/compiler/ILoweredGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ struct ILoweredGraph
virtual const ir::Graph &graph() const = 0;
virtual const compiler::GraphLowerInfo &lower_info() const = 0;
virtual compiler::GraphLowerInfo &lower_info() = 0;
virtual void setHasDynamicTensor(ir::OperationIndex ind, bool val) = 0;
virtual bool getHasDynamicTensor(ir::OperationIndex ind) const = 0;
};

} // namespace compiler
Expand Down
4 changes: 2 additions & 2 deletions runtime/onert/core/include/compiler/LoweredGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ class LoweredGraph : public ILoweredGraph
compiler::GraphLowerInfo &lower_info() override { return _lower_info_map; }
std::shared_ptr<ir::OperationIndexMap<int64_t>> indexed_ranks() { return _indexed_ranks; }

void setHasDynamicTensor(ir::OperationIndex ind, bool val)
void setHasDynamicTensor(ir::OperationIndex ind, bool val) override
{
_has_dynamic_tensor_map.emplace(ind, val);
}
bool getHasDynamicTensor(ir::OperationIndex ind) const
bool getHasDynamicTensor(ir::OperationIndex ind) const override
{
auto itr = _has_dynamic_tensor_map.find(ind);
return (itr == _has_dynamic_tensor_map.end()) ? false : itr->second;
Expand Down
10 changes: 5 additions & 5 deletions runtime/onert/core/include/compiler/StaticShapeInferer.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class OperandObserver
class StaticShapeInferer : public ir::OperationVisitor
{
public:
StaticShapeInferer(compiler::LoweredGraph *lowered_subg)
StaticShapeInferer(compiler::ILoweredGraph *lowered_subg)
: _lowered_subg{lowered_subg}, _subg_input_observers{}, _controlflow_output_observer{nullptr},
_child_inferers{}
{
Expand Down Expand Up @@ -102,13 +102,13 @@ class StaticShapeInferer : public ir::OperationVisitor
void dump();

/**
* @brief Create a lowered model shape inferer map
* @param[in] lowered_subgs lowered model subgraph map
* @brief Create a shape inferer map for a lowered model
* @param[in] lowered_subgs lowered model map
* @return Shape inferer map
*/
static std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>>
createStaticShapeInferers(
const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraph>> &lowered_subgs);
const std::unordered_map<ir::SubgraphIndex, ILoweredGraph *> &lowered_subgs);

private:
bool checkDynamicInput(const ir::IOperation &op);
Expand Down Expand Up @@ -178,7 +178,7 @@ class StaticShapeInferer : public ir::OperationVisitor
void handleSimpleUnaryOp(const ir::Operation &op, const ir::OperandIndex input_idx);

private:
compiler::LoweredGraph *_lowered_subg;
compiler::ILoweredGraph *_lowered_subg;
std::unordered_map<ir::SubgraphIndex, std::unique_ptr<OperandObserver>>
_subg_input_observers; // child subg input
std::unique_ptr<OperandObserver> _controlflow_output_observer; // parent controlflow op output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ class LoweredTrainableGraph : public ILoweredGraph
compiler::GraphLowerInfo &lower_info() override { return _lower_info_map; }
std::shared_ptr<ir::OperationIndexMap<int64_t>> indexed_ranks() { return _indexed_ranks; }

void setHasDynamicTensor(ir::OperationIndex, bool has_dynamic) override
{
if (has_dynamic)
throw std::runtime_error("LoweredTrainableGraph does not support dynamic tensors yet");
}
bool getHasDynamicTensor(ir::OperationIndex) const override { return false; }

private:
void makeLowerInfo(const compiler::BackendResolver &backend_resolver);
void dumpLowerInfo();
Expand Down
3 changes: 2 additions & 1 deletion runtime/onert/core/src/compiler/Compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "compiler/Compiler.h"

#include "CompilerHelpers.h"
#include "ExecutorFactory.h"
#include "ShapeValidator.h"
#include "pass/ConstantOutputPass.h"
Expand Down Expand Up @@ -137,7 +138,7 @@ std::shared_ptr<CompilerArtifact> Compiler::compile(void)
// Run the StaticShapeInfer of primary subg. All child StaticShapeInferers are called
// recursively
std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers =
StaticShapeInferer::createStaticShapeInferers(lowered_subgs);
createStaticShapeInferers(lowered_subgs);

const auto primary_subg_idx = ir::SubgraphIndex{0};
inferers.at(primary_subg_idx)->infer();
Expand Down
52 changes: 52 additions & 0 deletions runtime/onert/core/src/compiler/CompilerHelpers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_COMPILER_COMPILER_HELPERS_H__
#define __ONERT_COMPILER_COMPILER_HELPERS_H__

#include <compiler/ILoweredGraph.h>
#include <compiler/StaticShapeInferer.h>
#include <ir/Index.h>

#include <memory>
#include <unordered_map>

namespace onert
{
namespace compiler
{

/**
* @brief Create a shape inferer map for a lowered model
* @param[in] lowered_subgs lowered model map
* @return Shape inferer map
*/
template <typename LoweredGraphType,
typename = std::enable_if_t<std::is_base_of<ILoweredGraph, LoweredGraphType>::value>>
static std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>>
createStaticShapeInferers(
const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraphType>> &lowered_subgs)
{
std::unordered_map<ir::SubgraphIndex, ILoweredGraph *> lsubgs;
for (auto &&e : lowered_subgs)
lsubgs[e.first] = e.second.get();
return StaticShapeInferer::createStaticShapeInferers(lsubgs);
}

} // namespace compiler
} // namespace onert

#endif // __ONERT_COMPILER_COMPILER_HELPERS_H__
3 changes: 2 additions & 1 deletion runtime/onert/core/src/compiler/MultiModelCompiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "MultiModelCompiler.h"

#include "CompilerHelpers.h"
#include "ExecutorFactory.h"
#include "ShapeValidator.h"
#include "pass/ConstantOutputPass.h"
Expand Down Expand Up @@ -169,7 +170,7 @@ std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void)
// Run the StaticShapeInfer of primary subg. All child StaticShapeInferers are called
// recursively
std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers =
StaticShapeInferer::createStaticShapeInferers(model_lsubgs);
createStaticShapeInferers(model_lsubgs);

const auto primary_subg_idx = ir::SubgraphIndex{0};
inferers.at(primary_subg_idx)->infer();
Expand Down
4 changes: 2 additions & 2 deletions runtime/onert/core/src/compiler/StaticShapeInferer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,15 @@ void StaticShapeInferer::dump()

std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>>
StaticShapeInferer::createStaticShapeInferers(
const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraph>> &lowered_subgs)
const std::unordered_map<ir::SubgraphIndex, ILoweredGraph *> &lowered_subgs)
{
// Allocate StaticShapeInferer per each subgraph
std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers;
for (auto &&pair : lowered_subgs)
{
const auto &subg_index = pair.first;
auto &lowered_subg = pair.second;
inferers[subg_index] = std::make_unique<StaticShapeInferer>(lowered_subg.get());
inferers[subg_index] = std::make_unique<StaticShapeInferer>(lowered_subg);
}

// Append observers in all StaticShapeInferers
Expand Down
20 changes: 19 additions & 1 deletion runtime/onert/core/src/compiler/train/TrainingCompiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "TrainableOperationConverter.h"
#include "pass/LossInsertionPass.h"
#include "../CompilerHelpers.h"
#include "../ExecutorFactory.h"
#include "../pass/ConstantOutputPass.h"
#include "../pass/OddOutputPass.h"
Expand Down Expand Up @@ -175,7 +176,24 @@ std::shared_ptr<CompilerArtifact> TrainingCompiler::compile(void)
dot_dumper.dump(*lowered_subg, nnfw::misc::str("after_lower_subg-", subg_index.value()));
}

// TODO Shape inference for applying batch size.
// Shape inference.
{
// Run the StaticShapeInfer of primary subg. All child StaticShapeInferers are called
// recursively
std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers =
createStaticShapeInferers(lowered_subgs);

const auto primary_subg_idx = ir::SubgraphIndex{0};
inferers.at(primary_subg_idx)->infer();

for (const auto &pair_inferer : inferers)
{
const auto inferer = pair_inferer.second.get();
inferer->dump();
}
}

// TODO Infer shapes for gradient

// Shape validation
for (const auto &pair : lowered_subgs)
Expand Down

0 comments on commit 5f85e93

Please sign in to comment.