Skip to content

Commit

Permalink
[infrt] Add linear cpu demo (PaddlePaddle#40715)
Browse files Browse the repository at this point in the history
  • Loading branch information
DannyIsFunny authored Mar 22, 2022
1 parent c29f85b commit fcf8758
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 4 deletions.
4 changes: 4 additions & 0 deletions paddle/infrt/dialect/phi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ endif()
add_subdirectory(ir)
add_subdirectory(pass)

add_executable(phi-ir-exec phi_ir_exec.cc)
target_link_libraries(phi-ir-exec infrt)


add_executable(phi-exec phi_exec.cc)
target_link_libraries(phi-exec infrt)

Expand Down
4 changes: 2 additions & 2 deletions paddle/infrt/dialect/phi/ir/infrt_phi_base.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def PHI_Dialect : Dialect {

def PhiOpTrait : NativeOpTrait<"PhiOpTrait">;

class PHI_Type<string type, list<Trait> traits = []>
: TypeDef<PHI_Dialect, type, !listconcat(traits, [PhiOpTrait, IsolatedFromAbove])> {}
class PHI_Type<string type, list<Trait> traits = [], string baseCppClass = "::mlir::Type">
: TypeDef<PHI_Dialect, type, !listconcat(traits, [PhiOpTrait, IsolatedFromAbove]), baseCppClass> {}

def Allocator : PHI_Type<"Allocator"> {
let mnemonic = "allocator";
Expand Down
7 changes: 5 additions & 2 deletions paddle/infrt/host_context/paddle_mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "paddle/infrt/dialect/infrt/ir/basic_kernels.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/pd/common/pd_ops_info.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"

MLIRModelGenImpl::MLIRModelGenImpl()
: context_(infrt::Global::getMLIRContext()), builder_(context_) {
Expand All @@ -24,6 +25,8 @@ MLIRModelGenImpl::MLIRModelGenImpl()
context_->getOrLoadDialect<infrt::dt::DTDialect>();
context_->getOrLoadDialect<infrt::pd::PaddleDialect>();
context_->getOrLoadDialect<::infrt::InfrtDialect>();
context_->getOrLoadDialect<::infrt::phi::PHIDialect>();
context_->getOrLoadDialect<::infrt::phi::PHIDenseTensorDialect>();
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(context_));
}

Expand Down Expand Up @@ -79,7 +82,7 @@ mlir::FuncOp MLIRModelGenImpl::UpdateModelModule(
llvm::SmallVector<mlir::Type, 4> MLIRModelGenImpl::GetModelInputsType(
const infrt::paddle::framework_proto::ProgramDesc &program) {
llvm::SmallVector<mlir::Type, 4> operandTypes;
operandTypes.push_back(infrt::DenseHostTensorMapType::get(context_));
operandTypes.push_back(infrt::phi::DenseTensorMapType::get(context_));
for (auto &op_desc : main_block_.ops()) {
if (op_desc.type() != "feed") continue;
for (int var_idx = 0; var_idx < op_desc.outputs_size(); ++var_idx) {
Expand Down Expand Up @@ -180,7 +183,7 @@ void MLIRModelGenImpl::UpdateModelParams(
&precision_);
mlir::Type type_ = infrt::DenseTensorType::get(
context_, infrt::TargetType::CPU, precision_, infrt::LayoutType::ANY);
auto op = builder_.create<infrt::dt::TensorMapGetTensorOp>(
auto op = builder_.create<::infrt::phi::TensorMapGetTensorOp>(
mlir::UnknownLoc::get(context_), type_, map, name);
params_map_.insert(std::pair<std::string, mlir::Value>(
var_desc.name(), op.getOperation()->getResult(0)));
Expand Down
1 change: 1 addition & 0 deletions paddle/infrt/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ add_test(NAME test_infrt_by_lit COMMAND sh -c "lit -v ${CMAKE_SOURCE_DIR}/paddle
DEPENDS infrtopt infrtexec)

configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/phi/linear_cpu.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/phi/linear_cpu.mlir)
19 changes: 19 additions & 0 deletions paddle/infrt/tests/dialect/phi/linear_cpu.mlir.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: infrtexec -i %s
module {
func @main_graph(%arg0: !phi.dense_tensor_map, %arg1: !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW> {
%0 = phi_dt.tensor_map_get_tensor(%arg0) {name = "linear_0.w_0"} -> !infrt.dense_tensor<CPU, FP32, NCHW>
%1 = phi_dt.tensor_map_get_tensor(%arg0) {name = "linear_0.b_0"} -> !infrt.dense_tensor<CPU, FP32, NCHW>
%2 = "phi_dt.create_context.cpu"() : () -> !phi.context<CPU>
%5 = "phi_cpu.matmul.float32.any"(%2, %arg1, %0) {trans_x = false, trans_y = false} : (!phi.context<CPU>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
%7 = "phi_cpu.add.float32.any"(%2, %5, %1): (!phi.context<CPU>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
infrt.return %7 : !infrt.dense_tensor<CPU, FP32, NCHW>
}
func @main() {
%ctx = "phi_dt.create_context.cpu" (): () -> !phi.context<CPU>
%1 = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision<FP32>, layout=#infrt.layout<NCHW>, lod=[1:i64], dims=[16:i64, 784:i64]}: (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
%map = phi_dt.load_combined_params(){model_path="@CMAKE_BINARY_DIR@/linear/linear.pdmodel",params_path="@CMAKE_BINARY_DIR@/linear/linear.pdiparams"}
%2 = infrt.call@main_graph(%map, %1) : (!phi.dense_tensor_map, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
phi_dt.print_tensor (%2 : !infrt.dense_tensor<CPU, FP32, NCHW>)
infrt.return
}
}
80 changes: 80 additions & 0 deletions paddle/infrt/tests/model/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# example 1: save layer
import numpy as np
import paddle
import paddle.nn as nn
import paddle.optimizer as opt

BATCH_SIZE = 16
BATCH_NUM = 4
EPOCH_NUM = 4

IMAGE_SIZE = 784
CLASS_NUM = 10


# define a random dataset
class RandomDataset(paddle.io.Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples

def __getitem__(self, idx):
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
return image, label

def __len__(self):
return self.num_samples


class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)

@paddle.jit.to_static
def forward(self, x):
return self._linear(x)


def train(layer, loader, loss_fn, opt):
for epoch_id in range(EPOCH_NUM):
for batch_id, (image, label) in enumerate(loader()):
out = layer(image)
loss = loss_fn(out, label)
loss.backward()
opt.step()
opt.clear_grad()


# 1. train & save model.

# create network
layer = LinearNet()
loss_fn = nn.CrossEntropyLoss()
adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())

# create data loader
dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
loader = paddle.io.DataLoader(
dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2)

# train
train(layer, loader, loss_fn, adam)

# save
path = "linear/linear"
paddle.jit.save(layer, path)
1 change: 1 addition & 0 deletions paddle/scripts/infrt_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ function create_fake_models() {
python3 -m pip install *whl
cd ${PADDLE_ROOT}/build
python3 ${PADDLE_ROOT}/tools/infrt/fake_models/multi_fc.py
python3 ${PADDLE_ROOT}/paddle/infrt/tests/model/linear.py
}

function test_infrt() {
Expand Down

0 comments on commit fcf8758

Please sign in to comment.