[QNN EP] Add model description into context binary file metadata for validation (microsoft#16248)

### Description
Add model description into context binary file metadata for validation

### Motivation and Context
Dump more information into the context binary file metadata for validation
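
To make the validation story concrete, here is a minimal standalone sketch of a checker that reads this metadata back, assuming the on-disk layout implied by DumpQnnContext()/GetMetadataFromOrtContextFile() below: the QNN_PROVIDER flag string (its exact value is assumed here), a uint16_t header length excluding the flag, then length-prefixed model name, graph name, and (new in this change) model description, followed by the raw QNN blob. Length fields are assumed to be in native byte order. The checker itself is hypothetical, not part of this commit:

#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>

int main(int argc, char** argv) {
  if (argc != 2) { std::cerr << "usage: " << argv[0] << " <context_cache.bin>\n"; return 1; }
  std::ifstream in(argv[1], std::ifstream::binary);
  if (!in) { std::cerr << "cannot open " << argv[1] << "\n"; return 1; }

  const std::string flag = "QNN_PROVIDER";  // assumed value of the QNN_PROVIDER macro
  std::string head(flag.size(), '\0');
  in.read(&head[0], static_cast<std::streamsize>(head.size()));
  if (!in || head != flag) {
    std::cout << "no ORT metadata; treated as a QNN-toolchain-generated binary\n";
    return 0;
  }

  auto read_u16 = [&in]() { uint16_t v = 0; in.read(reinterpret_cast<char*>(&v), sizeof(v)); return v; };
  auto read_str = [&in](size_t n) { std::string s(n, '\0'); in.read(&s[0], static_cast<std::streamsize>(n)); return s; };

  const uint16_t header_length = read_u16();                    // metadata length, excluding the flag
  const std::string model_name = read_str(read_u16());
  const std::string graph_name = read_str(read_u16());
  const std::string model_description = read_str(read_u16());   // new field added by this change

  std::cout << "metadata bytes:    " << flag.size() + header_length << "\n"
            << "model name:        " << model_name << "\n"
            << "graph name:        " << graph_name << "\n"
            << "model description: " << model_description << "\n";
  return 0;
}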

---------

Co-authored-by: Adrian Lizarraga <[email protected]>
HectorSVC and adrianlizarraga authored Jun 9, 2023
1 parent d1e8d4a commit a9d47f7
Showing 5 changed files with 88 additions and 61 deletions.
87 changes: 66 additions & 21 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -316,6 +316,26 @@ Status QnnBackendManager::ReleaseContext() {
return Status::OK();
}

bool QnnBackendManager::IsContextCacheFileExists(const std::string& customer_context_cache_path,
const std::string& model_description,
const onnxruntime::PathString& model_pathstring) {
// Avoid duplicate work
if (!context_cache_path_.empty()) {
return ctx_file_exists_;
}
model_description_ = model_description;
// Use the user-provided context cache file path if provided; otherwise default to model_file.onnx.bin
if (customer_context_cache_path.empty()) {
context_cache_path_ = PathToUTF8String(model_pathstring) + ".bin";
} else {
context_cache_path_ = customer_context_cache_path;
}

ctx_file_exists_ = std::filesystem::exists(context_cache_path_);

return ctx_file_exists_;
}

Status WriteInt16ToBinaryFile(std::ofstream& of_stream, uint16_t value) {
const std::vector<uint16_t> data{value};
std::vector<unsigned char> data_bytes(sizeof(uint16_t) / sizeof(unsigned char));
@@ -324,9 +344,7 @@ Status WriteInt16ToBinaryFile(std::ofstream& of_stream, uint16_t value) {
return Status::OK();
}

Status QnnBackendManager::DumpQnnContext(const onnxruntime::PathString& context_cache_pathstring,
const std::string& model_name,
const std::string& graph_name) {
Status QnnBackendManager::DumpQnnContext(const std::string& model_name, const std::string& graph_name) {
if (nullptr == qnn_interface_.contextGetBinarySize ||
nullptr == qnn_interface_.contextGetBinary) {
LOGS(*logger_, ERROR) << "Failed to get valid function pointer.";
@@ -362,7 +380,7 @@ Status QnnBackendManager::DumpQnnContext(const onnxruntime::PathString& context_
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Context written buffer exceeds allocated buffer size.");
}

std::ofstream of_stream(context_cache_pathstring.c_str(), std::ofstream::binary);
std::ofstream of_stream(context_cache_path_.c_str(), std::ofstream::binary);
if (!of_stream) {
LOGS(*logger_, ERROR) << "Failed to open cached context file.";
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to open context cache file.");
@@ -371,7 +389,10 @@ Status QnnBackendManager::DumpQnnContext(const onnxruntime::PathString& context_
// Write Ort metadata into context binary file
uint16_t model_name_length = static_cast<uint16_t>(model_name.length());
uint16_t graph_name_length = static_cast<uint16_t>(graph_name.length());
uint16_t header_length = 3 * sizeof(uint16_t) + model_name_length + graph_name_length;
uint16_t model_description_length = static_cast<uint16_t>(model_description_.length());

// Header: uint16_t(total_length)|uint16_t(model_name_length)|model_name|uint16_t(graph_name_length)|graph_name|uint16_t(model_description_length)|model_description
uint16_t header_length = 4 * sizeof(uint16_t) + model_name_length + graph_name_length + model_description_length;
uint16_t total_length = header_length + static_cast<uint16_t>(strlen(QNN_PROVIDER));
of_stream.write(QNN_PROVIDER, strlen(QNN_PROVIDER));

@@ -382,6 +403,11 @@ Status QnnBackendManager::DumpQnnContext(const onnxruntime::PathString& context_

ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, graph_name_length));
of_stream.write(graph_name.c_str(), graph_name_length);

ORT_RETURN_IF_ERROR(WriteInt16ToBinaryFile(of_stream, model_description_length));
of_stream.write(model_description_.c_str(), model_description_length);
model_description_.clear();

LOGS(*logger_, VERBOSE) << "Dump metadata with length: " << total_length;

of_stream.write(reinterpret_cast<char*>(context_buffer.get()), written_buffer_size);
@@ -390,14 +416,16 @@ Status QnnBackendManager::DumpQnnContext(const onnxruntime::PathString& context_
return Status::OK();
}

Status QnnBackendManager::LoadCachedQnnContext(const onnxruntime::PathString& context_cache_pathstring, QnnModel& qnn_model) {
Status QnnBackendManager::LoadCachedQnnContext(QnnModel& qnn_model) {
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
nullptr == qnn_sys_interface_.systemContextFree;
ORT_RETURN_IF(result, "Failed to get valid function pointer.");

ORT_RETURN_IF(!ctx_file_exists_, "QNN context binary file does not exist!");

uint64_t buffer_size{0};
std::ifstream cache_file(context_cache_pathstring.c_str(), std::ifstream::binary);
std::ifstream cache_file(context_cache_path_.c_str(), std::ifstream::binary);
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file.");
cache_file.seekg(0, cache_file.end);
buffer_size = cache_file.tellg();
@@ -466,6 +494,8 @@ Status QnnBackendManager::LoadCachedQnnContext(const onnxruntime::PathString& co
ORT_RETURN_IF_ERROR(ExtractBackendProfilingInfo());
context_created_ = true;

model_description_.clear();
model_description_from_ctx_cache_.clear();
LOGS(*logger_, VERBOSE) << "Load from cached QNN Context completed.";
return Status::OK();
}
@@ -502,20 +532,19 @@ Status ReadInt16FromBinaryFile(std::ifstream& binary_file, uint16_t& value) {
}

/* \brief: Try to get metadata from an ORT-generated context cache binary file.
* A context binary file generated by ORT carries metadata that can be validated against the model,
* so that the user does not pick a wrong context binary file that was not generated for this model.
* The file is treated as a QNN-generated context binary if no metadata is found in it.
*/
Status QnnBackendManager::GetMetadataFromOrtContextFile(const onnxruntime::PathString& context_cache_pathstring) {
Status QnnBackendManager::GetMetadataFromOrtContextFile() {
// Only try to parse the metadata once
if (ctx_metadata_tried_) {
return Status::OK();
}
ctx_metadata_tried_ = true;

uint64_t buffer_size = 0;
std::ifstream cache_file(context_cache_pathstring.c_str(), std::ifstream::binary);
std::ifstream cache_file(context_cache_path_.c_str(), std::ifstream::binary);
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open context cache file.");
cache_file.seekg(0, cache_file.end);
buffer_size = cache_file.tellg();
@@ -533,17 +562,18 @@ Status QnnBackendManager::GetMetadataFromOrtContextFile(const onnxruntime::PathS
}
ort_generated_ctx_cache_ = true;

uint16_t header_length = 0;
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, header_length));
ort_ctx_metadata_length_ = header_length + static_cast<uint16_t>(ort_flag_length);
uint16_t str_length = 0;
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
ort_ctx_metadata_length_ = str_length + static_cast<uint16_t>(ort_flag_length);

ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, model_name_from_ctx_cache_, static_cast<size_t>(str_length)));

uint16_t model_name_length = 0;
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, model_name_length));
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, model_name_from_ctx_cache_, static_cast<size_t>(model_name_length)));
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, graph_name_from_ctx_cache_, static_cast<size_t>(str_length)));

uint16_t graph_name_length = 0;
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, graph_name_length));
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, graph_name_from_ctx_cache_, static_cast<size_t>(graph_name_length)));
ORT_RETURN_IF_ERROR(ReadInt16FromBinaryFile(cache_file, str_length));
ORT_RETURN_IF_ERROR(ReadStringFromBinaryFile(cache_file, model_description_from_ctx_cache_, static_cast<size_t>(str_length)));

return Status::OK();
}
@@ -555,15 +585,30 @@ Status QnnBackendManager::GetMetadataFromOrtContextFile(const onnxruntime::PathS
* so only validate the graph name on the 2nd call
*/
Status QnnBackendManager::ValidateWithContextFile(const std::string& model_name, const std::string& graph_name) {
ORT_RETURN_IF(!ctx_file_exists_, "QNN context binary file does not exist!");

// Get metadata from cached context binary file
ORT_RETURN_IF_ERROR(GetMetadataFromOrtContextFile());

// The context binary file doesn't have ORT metadata, so it was generated by the QNN toolchain, not by ORT
if (!ort_generated_ctx_cache_) {
return Status::OK();
}

ORT_RETURN_IF(model_name != model_name_from_ctx_cache_,
"Model file name from context cache metadata: " + model_name_from_ctx_cache_ + " does not match the target: " + model_name);
"Model file name from context cache metadata: " + model_name_from_ctx_cache_ +
" does not match the target: " + model_name +
". Please make sure the context binary file matches the model.");

ORT_RETURN_IF(model_description_ != model_description_from_ctx_cache_,
"Model description from context cache metadata: " + model_description_from_ctx_cache_ +
" does not match the target: " + model_description_ +
". Please make sure the context binary file matches the model.");

ORT_RETURN_IF(graph_name != graph_name_from_ctx_cache_ && get_capability_round_2_,
"Graph name from context cache metadata: " + graph_name_from_ctx_cache_ + " does not match the target: " + graph_name);
"Graph name from context cache metadata: " + graph_name_from_ctx_cache_ +
" does not match the target: " + graph_name +
". You may need to re-generate the context binary file.");

get_capability_round_2_ = true;
return Status::OK();
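Taken together, the refactoring changes the expected call order: the path and description now flow in once through IsContextCacheFileExists(), and the later calls consume the cached state. A simplified caller sketch under that assumption (method signatures taken from the header diff below; error handling trimmed, and the CompileWithContextCache wrapper is hypothetical, not a verbatim excerpt of the provider code):

// Sketch of the intended call sequence against the refactored QnnBackendManager.
Status CompileWithContextCache(qnn::QnnBackendManager& mgr, qnn::QnnModel& qnn_model,
                               const std::string& user_ctx_path,  // may be empty
                               const std::string& model_description,
                               const onnxruntime::PathString& model_path,
                               const std::string& model_file_name,
                               const std::string& graph_name) {
  // Must run first: caches context_cache_path_, model_description_, and ctx_file_exists_.
  const bool cached = mgr.IsContextCacheFileExists(user_ctx_path, model_description, model_path);
  if (cached) {
    // Reject a context binary whose metadata doesn't match this model.
    ORT_RETURN_IF_ERROR(mgr.ValidateWithContextFile(model_file_name, graph_name));
    ORT_RETURN_IF_ERROR(mgr.LoadCachedQnnContext(qnn_model));
  } else {
    // ... compile the graph normally, then persist it with its metadata.
    ORT_RETURN_IF_ERROR(mgr.DumpQnnContext(model_file_name, graph_name));
  }
  return Status::OK();
}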
16 changes: 11 additions & 5 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
@@ -69,13 +69,11 @@ class QnnBackendManager {
return CreateContext();
}

Status DumpQnnContext(const onnxruntime::PathString& context_cache_pathstring,
const std::string& model_name,
const std::string& graph_name);
Status DumpQnnContext(const std::string& model_name, const std::string& graph_name);

Status LoadCachedQnnContext(const onnxruntime::PathString& context_cache_pathstring, QnnModel& qnn_model);
Status LoadCachedQnnContext(QnnModel& qnn_model);

Status GetMetadataFromOrtContextFile(const onnxruntime::PathString& model_path);
Status GetMetadataFromOrtContextFile();

Status ValidateWithContextFile(const std::string& model_name, const std::string& graph_name);

@@ -133,6 +131,10 @@ class QnnBackendManager {
// NPU backend requires quantized model
bool IsNpuBackend() { return is_npu_backend_; }

bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
const std::string& model_description,
const onnxruntime::PathString& model_pathstring);

private:
void* LoadLib(const char* file_name, int flags, std::string& error_msg);

@@ -197,6 +199,10 @@ class QnnBackendManager {
HtpPerformanceMode htp_performance_mode_;
std::string model_name_from_ctx_cache_ = "";
std::string graph_name_from_ctx_cache_ = "";
std::string model_description_from_ctx_cache_ = "";
std::string model_description_ = "";
std::string context_cache_path_ = "";
bool ctx_file_exists_ = false;
bool ctx_metadata_tried_ = false;
bool ort_generated_ctx_cache_ = false;
bool get_capability_round_2_ = false;
42 changes: 11 additions & 31 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -276,18 +276,9 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
const auto& logger = *GetLogger();
bool load_from_cached_context = false;
if (context_cache_enabled_) {
onnxruntime::PathString context_cache_pathstring;
load_from_cached_context = IsContextCacheFileExists(graph_viewer.ModelPath().ToPathString(),
context_cache_pathstring);

// Get metadata from cached context binary file
if (load_from_cached_context) {
auto rt = qnn_backend_manager_->GetMetadataFromOrtContextFile(context_cache_pathstring);
if (Status::OK() != rt) {
LOGS(logger, ERROR) << "Failed to get metadata from cached context binary file. " << rt.ErrorMessage();
return result;
}
}
load_from_cached_context = qnn_backend_manager_->IsContextCacheFileExists(context_cache_path_,
graph_viewer.Description(),
graph_viewer.ModelPath().ToPathString());
}

// Loading from a cached context will load the QnnSystem lib and skip the QNN context creation
@@ -444,19 +435,6 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector<FusedNodeAndG
return Status::OK();
}

bool QNNExecutionProvider::IsContextCacheFileExists(const onnxruntime::PathString& model_pathstring,
onnxruntime::PathString& context_cache_pathstring) const {
// Use the user-provided context cache file path if provided; otherwise try model_file.onnx.bin by default
if (context_cache_path_.empty()) {
context_cache_pathstring = model_pathstring + ToPathString(".bin");
} else {
context_cache_pathstring = ToPathString(context_cache_path_);
}
bool context_cache_file_exist = std::filesystem::exists(context_cache_pathstring.c_str());

return context_cache_file_exist;
}

Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
const auto& logger = *GetLogger();
@@ -466,15 +444,18 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support single partition for the context cache feature.");
Node& fused_node = fused_nodes_and_graphs[0].fused_node;
const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[0].filtered_graph);
onnxruntime::PathString context_cache_pathstring;
bool load_from_cached_context = IsContextCacheFileExists(graph_viewer.ModelPath().ToPathString(),
context_cache_pathstring);
// The dummy_model_description won't be used since the IsContextCacheFileExists call cached the result
// The graph_viewer.Description() here is not the same as the original model's description
std::string dummy_model_description = "";
bool load_from_cached_context = qnn_backend_manager_->IsContextCacheFileExists(context_cache_path_,
dummy_model_description,
graph_viewer.ModelPath().ToPathString());
// Load and execute from the cached context if it exists
if (load_from_cached_context) {
std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(logger,
qnn_backend_manager_.get(),
is_npu_backend);
ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnContext(context_cache_pathstring, *(qnn_model.get())));
ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnContext(*(qnn_model.get())));
ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());

@@ -490,8 +471,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger));
// graph_viewer.Name() is generated in GetCapability, e.g. QNN_[hash_id]_[id]
// Dump graph_viewer.Name() as metadata into the context cache binary file so that it can be validated in GetCapability
ORT_RETURN_IF_ERROR(qnn_backend_manager_->DumpQnnContext(context_cache_pathstring,
GetFileNameFromModelPath(graph_viewer.ModelPath()),
ORT_RETURN_IF_ERROR(qnn_backend_manager_->DumpQnnContext(GetFileNameFromModelPath(graph_viewer.ModelPath()),
graph_viewer.Name()));
}
return Status::OK();
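One subtlety worth spelling out: Compile() can pass a dummy description precisely because GetCapability() already called IsContextCacheFileExists() with the real one, and the early-return on a non-empty context_cache_path_ makes the second call ignore its description argument. A tiny hypothetical illustration (the demo function is not part of this commit):

#include <cassert>
#include <string>

void DemoDummyDescriptionIsIgnored(qnn::QnnBackendManager& mgr,
                                   const onnxruntime::PathString& model_path) {
  // First call (as in GetCapability): caches path, description, and the existence check.
  const bool first = mgr.IsContextCacheFileExists("", "real model description", model_path);
  // Second call (as in Compile): hits the early-return, so the description is never read.
  std::string dummy_model_description;
  const bool second = mgr.IsContextCacheFileExists("", dummy_model_description, model_path);
  assert(first == second);  // the cached result is simply returned
}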
3 changes: 0 additions & 3 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.h
@@ -54,9 +54,6 @@ class QNNExecutionProvider : public IExecutionProvider {
std::vector<NodeComputeInfo>& node_compute_funcs,
const logging::Logger& logger);

bool IsContextCacheFileExists(const onnxruntime::PathString& model_pathstring,
onnxruntime::PathString& context_cache_pathstring) const;

void ParseHtpPerformanceMode(std::string htp_performance_mode_string);

private:
1 change: 0 additions & 1 deletion onnxruntime/test/providers/qnn/simple_op_htp_test.cc
@@ -156,7 +156,6 @@ TEST_F(QnnHTPBackendTests, TestQDQAtanTest) {
// 1st run will generate the Qnn context cache binary file
// 2nd run will load and run from Qnn context cache binary file
TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) {
RunQDQSingleInputOpTest({1, 2, 3}, "Atan", "TestQDQGeluTest", 11, ExpectedEPNodeAssignment::All, 1);
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
