fix: Improve logging, restructure casting function

- Address review comments
- Improve documentation and logging messages
- Restructure casting function to allow for casting of variable data types
- Add casting for `at::kByte` segment block inputs as well as segment block outputs
gs-olive committed Dec 22, 2022
1 parent a4c2d60 · commit d74e0b5

Showing 1 changed file with 37 additions and 17 deletions: core/partitioning/shape_analysis.cpp
```diff
@@ -103,6 +103,7 @@ torch::jit::Node* createCastNode(
     SegmentedBlock& seg_block,
     size_t index,
     bool is_input,
+    at::ScalarType dtype,
     std::string device,
     bool force_create_node = false) {
   auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
```
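For convenience, the complete signature after this change, reassembled from the hunk above (declaration only; the body changes are shown in the following hunks):

```cpp
// New signature: callers now name the cast target explicitly via `dtype`
// instead of relying on hard-coded ScalarType constants.
torch::jit::Node* createCastNode(
    SegmentedBlock& seg_block,
    size_t index,
    bool is_input,
    at::ScalarType dtype,
    std::string device,
    bool force_create_node = false);
```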
```diff
@@ -115,7 +116,7 @@ torch::jit::Node* createCastNode(
     value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
     if (!is_input) {
       // if this value is output, we need to cast it to int32
-      auto const_val = g->insertConstant(3);
+      auto const_val = g->insertConstant(dtype);
       if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
         value_map.insert({cast_node->inputs()[2], const_val});
       } else {
```
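As an aside (not part of the commit): the literal 3 removed here, like the 4 in the next hunk, is a raw c10::ScalarType enum value, which is why substituting the new `dtype` parameter is behavior-preserving for the old call sites. A minimal standalone check:

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // The magic numbers the old code passed to insertConstant are
  // c10::ScalarType enum values:
  std::cout << static_cast<int>(at::kInt) << '\n';  // 3 -- old hard-coded output cast target
  std::cout << static_cast<int>(at::kLong) << '\n'; // 4 -- old hard-coded input cast target
  std::cout << static_cast<int>(at::kByte) << '\n'; // 0 -- dtype newly handled by this commit
  return 0;
}
```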
```diff
@@ -127,7 +128,7 @@
     // auto cast_node = g->prependNode(g->createClone(cast_node, env));
   } else {
     // if there is no explicit cast aten::to operation, we need to create a node
-    auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
+    auto const_type = g->insertConstant(dtype);
     auto const_zero = g->insertConstant(0);
     const_zero->setType(torch::jit::BoolType::get());
     auto cuda = g->insertConstant(device);
```
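This branch assembles the constants for an aten::to call taking a device and a dtype (`const_zero` supplies the boolean flags). A hedged eager-mode sketch of the cast the created node performs at runtime; the helper name is hypothetical and this is not code from the commit:

```cpp
#include <ATen/ATen.h>
#include <string>

// Sketch: the runtime semantics of the inserted aten::to node, expressed
// with the eager API. `device` mirrors the string argument above.
at::Tensor apply_cast(const at::Tensor& input, at::ScalarType dtype, const std::string& device) {
  // aten::to(Tensor self, Device device, ScalarType dtype, bool non_blocking, bool copy)
  return input.to(at::Device(device), dtype, /*non_blocking=*/false, /*copy=*/false);
}
```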
```diff
@@ -230,17 +231,28 @@ void getSegmentsOutputByRunning(
   // auto int64 <=> int32 conversion + int8 <=> int32 conversion for non-quantized models
   if (seg_block.target() == SegmentedBlock::kTorch) {
     // First, check if there is Int64 input
-    if (partitioning_info.truncate_long_and_double) {
-      for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
-        if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
-          auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
-          at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-          if (t == at::kLong) {
-            // we add a cast operation to cast the type to Int64
-            auto cast_node = createCastNode(seg_block, i, true, target_device);
-            seg_block.g()->prependNode(cast_node);
-            seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
-          }
+    for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
+      if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
+        auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
+        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
+        if (t == at::kLong && partitioning_info.truncate_long_and_double) {
+          LOG_DEBUG(
+              "Detected graph Long tensor input type during shape analysis, "
+              << "inserting aten::to cast to Long to ensure this Torch block receives "
+              << "a Long-type tensor input.");
+          // we add a cast operation to cast the type to Int64
+          auto cast_node = createCastNode(seg_block, i, true, at::kLong, target_device);
+          seg_block.g()->prependNode(cast_node);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+        } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
+          LOG_DEBUG(
+              "Detected graph Byte tensor input type during shape analysis, "
+              << "inserting aten::to cast to Byte to ensure this Torch block receives "
+              << "a Byte-type tensor input.");
+          // If the input has type Byte, ensure it is casted to the correct type
+          auto cast_node = createCastNode(seg_block, i, true, at::kByte, target_device, /*force_create_node=*/true);
+          seg_block.g()->prependNode(cast_node);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+        }
       }
     }
   }
```
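The restructured loop now runs for every input (rather than only when truncate_long_and_double is set) and picks a cast target per dtype. A distilled sketch of that policy; the helper is hypothetical, written only to summarize the branches above:

```cpp
#include <ATen/ATen.h>
#include <c10/util/Optional.h>

// Returns the dtype an inserted cast should target for a Torch-block
// input, or nullopt if no cast node is needed.
c10::optional<at::ScalarType> input_cast_target(
    at::ScalarType t,
    bool truncate_long_and_double,
    bool cast_int8_inputs) {
  if (t == at::kLong && truncate_long_and_double) {
    return at::kLong; // ensure the Torch block receives a Long-type tensor
  }
  if (t == at::kByte && cast_int8_inputs) {
    return at::kByte; // force-created cast guarantees a Byte-type tensor
  }
  return c10::nullopt; // all other dtypes pass through unchanged
}
```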
```diff
@@ -250,14 +262,22 @@ void getSegmentsOutputByRunning(
         auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();

-        // If the input has type Long and truncation was requested, insert truncate
+        // If the output has type Long and truncation was requested, insert truncate
         if (t == at::kLong && partitioning_info.truncate_long_and_double) {
-          auto cast_node = createCastNode(seg_block, i, false, target_device);
+          LOG_DEBUG(
+              "Detected graph Long tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device);
           seg_block.g()->appendNode(cast_node);
           seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
         } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
-          // If the input has type Byte and truncation was requested, insert Integer cast
-          auto cast_node = createCastNode(seg_block, i, false, target_device, /*force_create_node=*/true);
+          LOG_DEBUG(
+              "Detected graph Byte tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          // If the output has type Byte and casting was requested, insert Integer cast
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device, /*force_create_node=*/true);
           seg_block.g()->appendNode(cast_node);
           seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
         }
```
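The output loop mirrors the input loop, but both branches cast to at::kInt because, per the log messages, the subsequent TensorRT block expects an Int-type tensor input. A companion sketch under the same caveats (hypothetical helper, not commit code):

```cpp
#include <ATen/ATen.h>
#include <c10/util/Optional.h>

// Returns the dtype an inserted cast should target for a Torch-block
// output feeding a TensorRT block, or nullopt if no cast is needed.
c10::optional<at::ScalarType> output_cast_target(
    at::ScalarType t,
    bool truncate_long_and_double,
    bool cast_int8_inputs) {
  if (t == at::kLong && truncate_long_and_double) {
    return at::kInt; // Long -> Int truncation at the TensorRT boundary
  }
  if (t == at::kByte && cast_int8_inputs) {
    return at::kInt; // Byte -> Int widening; cast node is force-created
  }
  return c10::nullopt;
}
```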
