forked from microsoft/CNTK
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Eval.cpp
138 lines (119 loc) · 4.01 KB
/
Eval.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
//
// <copyright file="Eval.cpp" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// Eval.cpp : Defines the exported functions for the DLL application.
//
#include "stdafx.h"
#define EVAL_LOCAL
#include "Eval.h"
#include "basetypes.h"
namespace Microsoft { namespace MSR { namespace CNTK {
// GetEvalName - name of the evaluator factory function for an element type
// The argument only selects which version is called; its value is ignored.
// Returns an empty string for any element type other than float or double.
template<class ElemType>
std::string GetEvalName(ElemType)
{
    return std::string();
}
// float evaluators use "GetEvalF"
template<>
std::string GetEvalName<float>(float)
{
    return std::string("GetEvalF");
}
// double evaluators use "GetEvalD"
template<>
std::string GetEvalName<double>(double)
{
    return std::string("GetEvalD");
}
// Init - not supported on this wrapper
// config - ignored
// Always throws std::logic_error; the constructor performs initialization instead.
template<class ElemType>
void Eval<ElemType>::Init(const std::string& /*config*/)
{
    // not implemented, calls the underlying class instead
    throw std::logic_error("Init shouldn't be called, use constructor");
}
// Destroy - cleanup and release the underlying evaluator
// NOTE: this destroys the underlying evaluator, and this object can't be used
// for evaluation past this point. Calling Destroy() more than once is a no-op.
template<class ElemType>
void Eval<ElemType>::Destroy()
{
    if (m_eval != NULL)
    {
        m_eval->Destroy();
        // clear the pointer so the destructor does not call Destroy() a second
        // time on an evaluator that has already been destroyed
        m_eval = NULL;
    }
}
// GetEvalClass - load the evaluator DLL and obtain the evaluator object from it
// config - configuration string; an optional "evaluator=<name>" entry selects the
//          DLL to load (name given without the ".dll" extension). When the entry
//          is absent or its value is empty, "CNTKEval" is used.
// On success m_hModule holds the loaded module and m_eval the evaluator.
// Throws std::runtime_error if the DLL cannot be loaded, if it does not export
// the factory function for this ElemType, or if the factory yields no evaluator.
// On failure the module is unloaded again and m_hModule/m_eval are left NULL.
template<class ElemType>
void Eval<ElemType>::GetEvalClass(const std::string& config)
{
    typedef void (*GetEvalProc)(IEvaluateModel<ElemType>** peval);

    // initialize just in case
    m_hModule = NULL;
    m_eval = NULL;
    m_dllName = L"CNTKEval";

    // get the name for the dll we want to use, default to CNTKEval.dll
    const std::string key = "evaluator=";
    const std::string::size_type found = config.find(key);
    if (found != std::string::npos)
    {
        // the value starts right after the key and runs up to the next
        // whitespace character, or to the end of the string if there is none
        const std::string::size_type start = found + key.length();
        const std::string::size_type end = config.find_first_of("\n \t", start);
        const std::string::size_type count = (end == std::string::npos) ? std::string::npos : end - start;
        const std::string name = config.substr(start, count);
        if (!name.empty())
        {
            m_dllName = msra::strfun::utf16(name);
        }
    }
    m_dllName += L".dll";

    m_hModule = LoadLibrary(m_dllName.c_str());
    if (m_hModule == NULL)
    {
        std::string message = "Eval not found: ";
        message += msra::strfun::utf8(m_dllName);
        throw std::runtime_error(message);
    }

    // create a value of the element type just to call the proper templated version
    const std::string procName = GetEvalName(ElemType());
    GetEvalProc getEvalProc = reinterpret_cast<GetEvalProc>(GetProcAddress(m_hModule, procName.c_str()));
    if (getEvalProc != NULL)
    {
        getEvalProc(&m_eval);
    }
    if (m_eval == NULL)
    {
        // the caller's destructor does not run when construction fails,
        // so the module has to be released here
        FreeLibrary(m_hModule);
        m_hModule = NULL;

        std::string message = (getEvalProc == NULL) ? "Eval entry point not found: " : "Eval could not be created: ";
        message += procName;
        message += " in ";
        message += msra::strfun::utf8(m_dllName);
        throw std::runtime_error(message);
    }
}
// Eval Constructor - obtains the evaluator and initializes it
// config - [in] configuration string; passed to GetEvalClass and then to the
//          evaluator's Init
// NOTE(review): if Init throws, the destructor does not run for this object,
// so the evaluator and module acquired by GetEvalClass are not released here.
template<class ElemType>
Eval<ElemType>::Eval(const std::string& config)
{
    this->GetEvalClass(config);
    this->m_eval->Init(config);
}
// destructor - destroys the evaluator, then unloads the module
template<class ElemType>
Eval<ElemType>::~Eval()
{
    // the evaluator goes first, before the module is freed
    if (m_eval)
    {
        m_eval->Destroy();
        m_eval = NULL;
    }

    // now the module itself can be released
    if (m_hModule)
    {
        FreeLibrary(m_hModule);
        m_hModule = NULL;
    }
}
// LoadModel - load a model from the specified path
// modelFileName - file holding the model to load
// The call is forwarded unchanged to the underlying evaluator.
template<class ElemType>
void Eval<ElemType>::LoadModel(const std::wstring& modelFileName)
{
    this->m_eval->LoadModel(modelFileName);
}
// GetNodeDimensions - Get the node dimensions of the specified nodes
// dimensions - map from name of node to dimension of the node
// nodeGroup - type of node we are requesting (input/output/specified)
// The call is forwarded unchanged to the underlying evaluator.
template<class ElemType>
void Eval<ElemType>::GetNodeDimensions(std::map<std::wstring, size_t>& dimensions, NodeGroup nodeGroup)
{
    this->m_eval->GetNodeDimensions(dimensions, nodeGroup);
}
// Evaluate - Evaluate using the model with the given inputs and outputs
// inputs - map from node name to input vector
// outputs - map from node name to output vector; the output vectors need to be
//           preallocated by the caller, sizing will happen during evaluation
// The call is forwarded unchanged to the underlying evaluator.
template<class ElemType>
void Eval<ElemType>::Evaluate(std::map<std::wstring, std::vector<ElemType>*>& inputs, std::map<std::wstring, std::vector<ElemType>*>& outputs)
{
    this->m_eval->Evaluate(inputs, outputs);
}
// ResetState - Reset the cell state when we get the start of an utterance
// The call is forwarded unchanged to the underlying evaluator.
template<class ElemType>
void Eval<ElemType>::ResetState()
{
    this->m_eval->ResetState();
}
// Explicit instantiations: Eval is provided for float and double
template class Eval<float>;
template class Eval<double>;
}}}