Skip to content

Commit

Permalink
Prototype of Instrumentation (onnx#750)
Browse files Browse the repository at this point in the history
* simple instrument

Signed-off-by: Tong Chen <[email protected]>

* driver

Signed-off-by: Tong Chen <[email protected]>

* fix path

Signed-off-by: Tong Chen <[email protected]>

* run instrumentation

Signed-off-by: Tong Chen <[email protected]>

* use id

Signed-off-by: Tong Chen <[email protected]>

* new output

Signed-off-by: Tong Chen <[email protected]>

* pass for instrumentation

Signed-off-by: Tong Chen <[email protected]>

* support noBroadcasting

Signed-off-by: Tong Chen <[email protected]>

* print virtual memory

Signed-off-by: Tong Chen <[email protected]>

* new opID

Signed-off-by: Tong Chen <[email protected]>

* slit instrument

Signed-off-by: Tong Chen <[email protected]>

* change API name

Signed-off-by: Tong Chen <[email protected]>

* Revert "pass for instrumentation"

This reverts commit 9e8e300.

* fixes

Signed-off-by: Tong Chen <[email protected]>

* format

Signed-off-by: Tong Chen <[email protected]>

* fix

Signed-off-by: Tong Chen <[email protected]>

* change intrface

* add control

Signed-off-by: Tong Chen <[email protected]>

* fix

Signed-off-by: Tong Chen <[email protected]>

* format

Signed-off-by: Tong Chen <[email protected]>

* format

Signed-off-by: Tong Chen <[email protected]>

* document

Signed-off-by: Tong Chen <[email protected]>

* remove temporary test script

Signed-off-by: Tong Chen <[email protected]>

Co-authored-by: Kevin O'Brien <[email protected]>
  • Loading branch information
chentong319 and caoimhinuibrian authored Jul 13, 2021
1 parent 66cc7a7 commit 7d1bad6
Show file tree
Hide file tree
Showing 17 changed files with 373 additions and 9 deletions.
41 changes: 41 additions & 0 deletions docs/Instrumentation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
<!--- SPDX-License-Identifier: Apache-2.0 -->

# Instrumentation

Instrumentation is prototyped in onnx-mlir and can be used to debug runtime issue.

## Compile for instrumentation

By default, instrumentation is turned off. To turn it on, modify the default value of `OMInstrumentEnabled' in 'src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp' and build the compiler. Command line flag will be added.

Currently, only some onnx ops are instrumented. They are Conv, element-wise binary and element-wise varadic operations.

The instrumentation is added before and after the op.o

Currently, the call of initialization, OMInstrumentInit, need to be added before you load the dynamic library. It is being considered to add it to the beginning of main_graph by compiler.

## Run with instrumentation
The instrumenation library will print out the time and virtual memory usage along at each instrumentation point. A sample output is listed below:
```
ID=Conv TAG=0 Time elapsed: 0.000966 accumulated: 0.000966
335128
ID=Conv TAG=1 Time elapsed: 0.395338 accumulated: 0.396304
335128
ID=Mul TAG=0 Time elapsed: 0.302189 accumulated: 0.698493
335128
ID=Mul TAG=1 Time elapsed: 0.021133 accumulated: 0.719626
335128
```
The output is explained here:
* ID: currently is the name (limited to up to 7 chars) of the op.
* TAG: 0 for before the op, while 1 for after the op.
* elpased: time, in second, elapsed from previous instrumentation point.
* accumulated: time, in second, from instrumentationInit.
* the following line, 33512 in this example, is the virtual memory size (in kb) used by this process.

## Control of output
* If env variable OMINSTRUMENTTIME is set, the report of time is disabled
* If env variable OMINSTRUMENTMEMORY is set, the report of virtual memory is disabled

## Used in gdb
The function for instrument point is called `OMInstrumentPoint`. Breakpoint can be set inside this function to kind of step through onnx ops.
1 change: 1 addition & 0 deletions include/OnnxMlirRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <stdint.h>
#endif

#include <onnx-mlir/Runtime/OMInstrument.h>
#include <onnx-mlir/Runtime/OMTensor.h>
#include <onnx-mlir/Runtime/OMTensorList.h>
#include <onnx-mlir/Runtime/OMSignature.h>
Expand Down
67 changes: 67 additions & 0 deletions include/onnx-mlir/Runtime/OMInstrument.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===-------------- OMInstrument.h - OM Instrument Declaration header ------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains declaration of API functions for instrumentation.
//
//===----------------------------------------------------------------------===//

#ifndef ONNX_MLIR_OMINSTRUMENT_H
#define ONNX_MLIR_OMINSTRUMENT_H

#ifdef __cplusplus
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <numeric>
#include <string>
#include <vector>
#include <cstdlib>
#else
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#endif // #ifdef __cplusplus

#ifdef __APPLE__
#include <stdlib.h>
#else
#include <malloc.h>
#endif // #ifdef __APPLE__

#ifdef __cplusplus
extern "C" {
#endif

/**
* Initialize instrument.
* Initialize counter and read env variables for control
*
*/
void OMInstrumentInit();

/**
* Create an instrument point.
* Measurement of runtime behavior will be measured and output
* In current implementation, the elapsed time from previous instrument point,
* and virtual memory size will be reported.
*
* @param id for this point. op name is used now.
* @param tag can used to give extra control of output. Used for begin/end mark now
* @return void
*
*/
void OMInstrumentPoint(int64_t id, int64_t tag);

#ifdef __cplusplus
}
#endif

#endif // ONNX_MLIR_OMINSTRUMENT_H
63 changes: 63 additions & 0 deletions src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,25 @@ static size_t getRankFromMemRefType(LLVM::LLVMStructType memRefTy) {
return memRefTy.getBody()[3].cast<LLVM::LLVMArrayType>().getNumElements();
}

// Create a function declaration for OMInstrumentPoint, the signature is:
// `void (i64, i64)`
static FlatSymbolRefAttr getOrInsertInstrument(
PatternRewriter &rewriter, ModuleOp module) {
auto *context = module.getContext();
const char funcName[] = "OMInstrumentPoint";
if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName))
return SymbolRefAttr::get(context, funcName);
auto llvmVoidTy = LLVM::LLVMVoidType::get(context);
auto llvmI64Ty = IntegerType::get(context, 64);
auto llvmFnType = LLVM::LLVMFunctionType::get(
llvmVoidTy, ArrayRef<mlir::Type>({llvmI64Ty, llvmI64Ty}), false);

PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), funcName, llvmFnType);
return SymbolRefAttr::get(context, funcName);
}

/// Return a symbol reference to the memcpy function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertMemcpy(
Expand Down Expand Up @@ -543,6 +562,48 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
}
};

class KrnlInstrumentOpLowering : public ConversionPattern {
public:
explicit KrnlInstrumentOpLowering(MLIRContext *context)
: ConversionPattern(KrnlInstrumentOp::getOperationName(), 1, context) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto *context = op->getContext();
KrnlInstrumentOpAdaptor operandAdaptor(operands);
auto loc = op->getLoc();
KrnlInstrumentOp instrumentOp = llvm::dyn_cast<KrnlInstrumentOp>(op);

// Get a symbol reference to the memcpy function, inserting it if necessary.
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
auto llvmVoidTy = LLVM::LLVMVoidType::get(context);
auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
auto llvmI64Ty = IntegerType::get(context, 64);
auto llvmFnType = LLVM::LLVMFunctionType::get(
llvmVoidTy, ArrayRef<mlir::Type>({llvmI64Ty, llvmI64Ty}), false);

auto instrumentRef = getOrInsertInstrument(rewriter, parentModule);

Value nodeName =
rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(context, 64),
rewriter.getIntegerAttr(
rewriter.getIntegerType(64), instrumentOp.opID()));
Value tag =
rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(context, 64),
rewriter.getIntegerAttr(
rewriter.getIntegerType(64), instrumentOp.tag()));
// StringRef txt = instrumentOp->op_name();
// Value nodeName = rewriter.create<LLVM::ConstantOp>(loc, llvmI8PtrTy,
// instrumentOp->op_name());

rewriter.create<CallOp>(loc, instrumentRef, ArrayRef<Type>({}),
ArrayRef<Value>({nodeName, tag}));

rewriter.eraseOp(op);
return success();
}
};

//===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlMemcpyOpLowering
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1390,6 +1451,8 @@ void mlir::populateAffineAndKrnlToLLVMConversion(RewritePatternSet &patterns,
patterns.insert<KrnlGetRefOpLowering>(ctx, typeConverter);
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(ctx);

patterns.insert<KrnlInstrumentOpLowering>(ctx);

// Math library functions.
patterns.insert<KrnlUnaryMathOpLowering<KrnlErfOp>>(ctx);
patterns.insert<KrnlUnaryMathOpLowering<KrnlAcosOp>>(ctx);
Expand Down
9 changes: 9 additions & 0 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,7 @@ struct ONNXElementwiseBinaryOpLowering : public ConversionPattern {
NameLoc::get(Identifier::get(ElementwiseBinaryOp::getOperationName(),
op->getContext()),
op->getLoc());
insertInstrumentBefore(op, rewriter, loc);
auto outputMemRefType = convertToMemRefType(*op->result_type_begin());
auto outputElementType = outputMemRefType.getElementType();
auto outputRank = outputMemRefType.getRank();
Expand All @@ -749,6 +750,7 @@ struct ONNXElementwiseBinaryOpLowering : public ConversionPattern {
BuildKrnlLoop loops(rewriter, loc, outputRank);
loops.createDefineAndIterateOp(alloc);
Block *iterationBlock = loops.getIterateBlock();
insertInstrumentAfter(op, rewriter, loc);
// Insert instructions inside the KernelIterateOp body.
rewriter.setInsertionPointToStart(iterationBlock);
// Handle the operation:
Expand Down Expand Up @@ -795,13 +797,18 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
NameLoc::get(Identifier::get(ElementwiseVariadicOp::getOperationName(),
op->getContext()),
op->getLoc());
insertInstrumentBefore(op, rewriter, loc);
auto numArgs = op->getNumOperands();
auto outputMemRefType = convertToMemRefType(*op->result_type_begin());
auto outputElementType = outputMemRefType.getElementType();
auto outputRank = outputMemRefType.getRank();

// Shape helper.
ONNXOpBroadcastedShapeHelper shapeHelper(&rewriter, loc);

// The following call is used to force no broadcasting check at runtime
// Even when the dim is unknown at compile time
// ONNXOpBroadcastedShapeHelper shapeHelper(&rewriter, loc, true, true);
LogicalResult shapecomputed = shapeHelper.Compute(operands);
assert(succeeded(shapecomputed));
IndexExprScope outerScope(rewriter, shapeHelper.scope);
Expand All @@ -818,6 +825,8 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// Create iterateOp & get block within iterate op.
BuildKrnlLoop loops(rewriter, loc, outputRank);
loops.createDefineAndIterateOp(alloc);
insertInstrumentAfter(op, rewriter, loc);

Block *iterationBlock = loops.getIterateBlock();
// Insert instructions inside the KernelIterateOp body.
rewriter.setInsertionPointToStart(iterationBlock);
Expand Down
5 changes: 5 additions & 0 deletions src/Conversion/ONNXToKrnl/NN/Conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ struct ONNXConvOpLowering : public ConversionPattern {
ONNXConvOpAdaptor operandAdaptor(operands);
ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);

insertInstrumentBefore(op, rewriter, loc);

// Read dilations attribute if the op has.
std::vector<int64_t> dilations = getDilations(convOp);
bool isDilated = !dilations.empty();
Expand Down Expand Up @@ -182,6 +184,9 @@ struct ONNXConvOpLowering : public ConversionPattern {
int mIndex = outerLoops.pushBounds(0, kernelsPerGroup);
// Outer loop iterations.
outerLoops.createIterateOp();

insertInstrumentAfter(op, rewriter, loc);

rewriter.setInsertionPointToStart(outerLoops.getIterateBlock());
{
// 2. Emit the body of the outer loop nest.
Expand Down
17 changes: 17 additions & 0 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"

// Temporarily used to control instrumentation
static bool instrumentEnabled = false;

/// Check if all operands are scalar values at compile time.
bool hasAllScalarValues(ArrayRef<Value> values) {
for (Value value : values) {
Expand Down Expand Up @@ -158,6 +161,20 @@ bool checkInsertDealloc(Operation *currentOp, int resultIndex) {
return insertDealloc;
}

// Insert an instrument function before an op
void insertInstrumentBefore(
Operation *op, PatternRewriter &rewriter, Location loc) {
if (instrumentEnabled)
rewriter.create<mlir::KrnlInstrumentOp>(loc, op, 0);
}

// Insert an instrument function after an op
void insertInstrumentAfter(
Operation *op, PatternRewriter &rewriter, Location loc) {
if (instrumentEnabled)
rewriter.create<mlir::KrnlInstrumentOp>(loc, op, 1);
}

// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
Expand Down
8 changes: 8 additions & 0 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ Value insertAllocAndDeallocSimple(PatternRewriter &rewriter, Operation *op,
// inserted.
bool checkInsertDealloc(Operation *currentOp, int resultIndex = 0);

// Insert an instrument function before an op
void insertInstrumentBefore(
Operation *currentOp, PatternRewriter &rewwriter, Location loc);

// Insert an instrument function after an op
void insertInstrumentAfter(
Operation *currentOp, PatternRewriter &rewwriter, Location loc);

// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
Expand Down
13 changes: 13 additions & 0 deletions src/Dialect/Krnl/KrnlOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,19 @@ void KrnlEntryPointOp::build(mlir::OpBuilder &builder, OperationState &state,
state.addAttribute(KrnlEntryPointOp::getSignatureAttrName(), signature);
}

void KrnlInstrumentOp::build(mlir::OpBuilder &builder, OperationState &state,
Operation *op, int tag = 0) {
const char *opName = op->getName().getStringRef().data();
int64_t opID = 0;
// getName() result is "onnx.opName"
// Put only the opName part in the opID within its size
strncpy((char *)&opID, opName + 5, sizeof(decltype(opID)) - 1);
IntegerAttr attr = builder.getI64IntegerAttr(opID);
auto tagAttr = builder.getI64IntegerAttr(tag);
state.addAttribute("opID", attr);
state.addAttribute("tag", tagAttr);
}

//===----------------------------------------------------------------------===//
// KrnlBlockOp
//===----------------------------------------------------------------------===//
Expand Down
13 changes: 13 additions & 0 deletions src/Dialect/Krnl/KrnlOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -963,3 +963,16 @@ def KrnlCopyFromBufferOp : Op<Krnl_Dialect, "copy_from_tile_buffer",
$buffer `,` $dest `[` $starts `]` attr-dict `:` type($buffer) `,` type($dest)
}];
}

def KrnlInstrumentOp : Op<Krnl_Dialect, "runtime_instrument",
[]> {
let summary = "instrumentation point.";
let description = [{
Operation that invokes the runtime instrument utility.
May be used for gdb.
}];

let arguments = (ins I64Attr:$opID, I64Attr:$tag);

let builders = [ OpBuilder<(ins "Operation *": $op, "int ": $tag)> ];
}
10 changes: 6 additions & 4 deletions src/Dialect/ONNX/ONNXShapeHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,10 @@ LogicalResult ONNXArgMaxOpShapeHelper::Compute(
//===----------------------------------------------------------------------===//

ONNXOpBroadcastedShapeHelper::ONNXOpBroadcastedShapeHelper(
ConversionPatternRewriter *rewriter, Location loc, bool uniBroadcasting)
: scope(rewriter, loc), isUniBroadcasting(uniBroadcasting) {}
ConversionPatternRewriter *rewriter, Location loc, bool uniBroadcasting,
bool noBroadcasting)
: scope(rewriter, loc), isUniBroadcasting(uniBroadcasting),
isNoBroadcasting(noBroadcasting) {}

LogicalResult ONNXOpBroadcastedShapeHelper::Compute(ArrayRef<Value> operands) {
// A temporary IndexExpr vector for the output.
Expand Down Expand Up @@ -171,7 +173,7 @@ LogicalResult ONNXOpBroadcastedShapeHelper::Compute(ArrayRef<Value> operands) {
// 1 - LiteralNot1
// 1 - 1
if (currentDimExpr.isLiteralAndIdenticalTo(1)) {
if (!isUniBroadcasting)
if (!isUniBroadcasting && !isNoBroadcasting)
dimsExpr[j] = nextDimExpr;
continue;
}
Expand Down Expand Up @@ -215,7 +217,7 @@ LogicalResult ONNXOpBroadcastedShapeHelper::Compute(ArrayRef<Value> operands) {
LogicalResult ONNXOpBroadcastedShapeHelper::GetAccessExprs(Value operand,
unsigned operandIndex, const SmallVectorImpl<IndexExpr> &outputAccessExprs,
SmallVectorImpl<IndexExpr> &operandAccessExprs) {
if (isUniBroadcasting && operandIndex == 0) {
if (isNoBroadcasting || (isUniBroadcasting && operandIndex == 0)) {
for (IndexExpr ie : outputAccessExprs)
operandAccessExprs.emplace_back(ie);
return success();
Expand Down
7 changes: 6 additions & 1 deletion src/Dialect/ONNX/ONNXShapeHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ struct ONNXOpShapeHelper {
/// be ranked in advance.
struct ONNXOpBroadcastedShapeHelper {
ONNXOpBroadcastedShapeHelper(ConversionPatternRewriter *rewriter,
Location loc, bool uniBroadcasting = false);
Location loc, bool uniBroadcasting = false, bool noBroadcasting = false);

// Compute a vector of IndexExprs to represent the output shape. Results are
// stored in 'outputDims'.
Expand Down Expand Up @@ -114,6 +114,11 @@ struct ONNXOpBroadcastedShapeHelper {
// If unidirectional broadcasting, the other operands are always
// unidirectional broadcastable to the first operand.
bool isUniBroadcasting;

// If isNoBroadcasting is true, the shape of all input is assumed to be same
// This flag is used to test dynamic shape
// There is no impact on static shape
bool isNoBroadcasting;
};

// Shape for ArgMax
Expand Down
Loading

0 comments on commit 7d1bad6

Please sign in to comment.