forked from pytorch/TensorRT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompiler.h
30 lines (21 loc) · 933 Bytes
/
compiler.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#pragma once
#include <cuda_runtime.h>
#include <vector>
#include "core/conversion/conversion.h"
#include "core/ir/ir.h"
#include "core/partitioning/partitioning.h"
#include "torch/csrc/jit/api/module.h"
namespace trtorch {
namespace core {

/// @brief Aggregates the settings handed to the compiler entry points below.
///
/// The constructor moves the supplied input ranges into `convert_info`;
/// `partition_info` is not touched by the constructor's initializer list.
struct CompileSpec {
  /// @param input_ranges Input ranges for the graph; moved into `convert_info`.
  /// NOTE(review): this single-argument constructor is not `explicit`, so a
  /// std::vector<ir::InputRange> converts implicitly to CompileSpec. Confirm no
  /// caller relies on that before tightening it.
  CompileSpec(std::vector<ir::InputRange> input_ranges)
      : convert_info(std::move(input_ranges)) {}

  /// Settings for the conversion stage, built from the input ranges above.
  conversion::ConversionInfo convert_info;

  /// Settings for the partitioning stage (not set by the constructor).
  partitioning::PartitionInfo partition_info;
};

/// @brief Checks method `method_name` of `mod`.
/// @return A bool. Presumably true when every operator the method uses is
///         supported for conversion; confirm against the definition.
/// NOTE(review): `method_name` is taken by value. Switching to const& must be
/// done together with the out-of-header definition.
bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod,
                                std::string method_name);

/// @brief Converts method `method_name` of `mod` using the settings in `cfg`.
/// @return A std::string. Presumably the serialized TensorRT engine; confirm
///         against the definition.
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
                                    std::string method_name,
                                    CompileSpec cfg);

/// @brief Compiles `module` according to `cfg`.
/// @return A torch::jit::script::Module. Presumably a new module whose methods
///         run through TensorRT; confirm against the definition.
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module,
                                        CompileSpec cfg);

/// @brief Wraps `engine` in a freshly created torch::jit::script::Module.
/// NOTE(review): `engine` is presumably a serialized engine, such as the string
/// ConvertGraphToTRTEngine returns; verify.
torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine);

/// @brief Selects GPU `gpu_id`.
/// NOTE(review): given the <cuda_runtime.h> include, this presumably sets the
/// active CUDA device; confirm in the definition. The top-level `const` on the
/// parameter is omitted here because it does not change the function type.
void set_device(int gpu_id);

} // namespace core
} // namespace trtorch