forked from microsoft/CNTK
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Eval.cpp
125 lines (106 loc) · 3.82 KB
/
Eval.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
//
// <copyright file="Eval.cpp" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// Eval.cpp : Defines the exported functions for the DLL application.
//
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings
#include "stdafx.h"
#include "Basics.h"
#define EVAL_LOCAL
#include "Eval.h"
namespace Microsoft { namespace MSR { namespace CNTK {

// GetEvalName - name of the factory function exported by the evaluator plugin for an element type.
// The argument only selects the specialization; its value is ignored.
// Element types without a specialization get an empty name.
template<class ElemType>
std::string GetEvalName(ElemType)
{
    return std::string();
}
template<> std::string GetEvalName(float)  { return "GetEvalF"; }
template<> std::string GetEvalName(double) { return "GetEvalD"; }

// Init - not supported on this wrapper: configuration must be passed to the constructor,
// which forwards it to the plugin evaluator's own Init.
template<class ElemType>
void Eval<ElemType>::Init(const std::string& /*config*/)
{
    throw std::logic_error("Init shouldn't be called, use constructor");
}

// Destroy - release the underlying plugin evaluator
// NOTE: after this call the object can no longer be used for loading or evaluation.
// m_eval is cleared so that the destructor (or a repeated Destroy call) does not
// release the same evaluator a second time.
template<class ElemType>
void Eval<ElemType>::Destroy()
{
    if (m_eval != NULL)
    {
        m_eval->Destroy();
        m_eval = NULL;
    }
}

// GetEvalClass - load the evaluator plugin module and obtain an evaluator instance from it
// config - [in] configuration string; an optional "evaluator=<module>" entry, terminated by
//          whitespace or the end of the string, names the module to load (default "CNTKEval")
// On success m_eval holds the instance created by the plugin's factory function.
// Throws std::runtime_error if the factory cannot be resolved or creates no evaluator.
template<class ElemType>
void Eval<ElemType>::GetEvalClass(const std::string& config)
{
    typedef void (*GetEvalProc)(IEvaluateModel<ElemType>** peval);
    // initialize just in case
    m_eval = NULL;

    // get the name for the module we want to use, default to CNTKEval
    std::wstring module = L"CNTKEval";
    const std::string key = "evaluator=";
    const std::string::size_type found = config.find(key);
    if (found != std::string::npos)
    {
        // the module name is the text after "evaluator=" (the key itself is not part of it),
        // up to the next whitespace character or the end of the string
        const std::string::size_type valueStart = found + key.size();
        const std::string::size_type valueEnd = config.find_first_of("\r\n \t", valueStart);
        const std::string value = config.substr(valueStart,
            valueEnd == std::string::npos ? std::string::npos : valueEnd - valueStart);
        if (!value.empty())
        {
            module = msra::strfun::utf16(value);
        }
    }

    // GetEvalName is selected by element type; pass a dummy value to pick the matching export
    ElemType elemType = ElemType();
    const std::string procName = GetEvalName(elemType);
    GetEvalProc getEvalProc = reinterpret_cast<GetEvalProc>(Plugin::Load(module, procName));
    if (getEvalProc == NULL)
    {
        throw std::runtime_error("Eval: could not load evaluator factory " + procName);
    }
    getEvalProc(&m_eval);
    if (m_eval == NULL)
    {
        throw std::runtime_error("Eval: evaluator factory " + procName + " did not create an evaluator");
    }
}

// Eval Constructor - load the evaluator plugin and initialize it
// config - [in] configuration string; used by GetEvalClass to choose the plugin module and
//          then forwarded unchanged to the plugin evaluator's Init
template<class ElemType>
Eval<ElemType>::Eval(const std::string& config)
{
    GetEvalClass(config);
    try
    {
        m_eval->Init(config);
    }
    catch (...)
    {
        // the destructor does not run when a constructor throws, so release the evaluator here
        m_eval->Destroy();
        m_eval = NULL;
        throw;
    }
}

// destructor - release the plugin evaluator if it is still held (i.e. Destroy was not called)
template<class ElemType>
Eval<ElemType>::~Eval()
{
    // free up resources
    if (m_eval != NULL)
    {
        m_eval->Destroy();
        m_eval = NULL;
    }
}

// LoadModel - load a model from the specified path (forwarded to the plugin evaluator)
// modelFileName - [in] file holding the model to load
template<class ElemType>
void Eval<ElemType>::LoadModel(const std::wstring& modelFileName)
{
    m_eval->LoadModel(modelFileName);
}

// GetNodeDimensions - Get the node dimensions of the specified nodes (forwarded to the plugin evaluator)
// dimensions - map from name of node to dimension of the node
// nodeGroup - type of node we are requesting (input/output/specified)
template<class ElemType>
void Eval<ElemType>::GetNodeDimensions(std::map<std::wstring, size_t>& dimensions, NodeGroup nodeGroup)
{
    m_eval->GetNodeDimensions(dimensions, nodeGroup);
}

// Evaluate - Evaluate using the model with the given inputs and outputs (forwarded to the plugin evaluator)
// inputs - map from node name to input vector
// outputs - map from node name to output vector, outputs vectors need to be preallocated by caller, sizing will happen during evaluation
template<class ElemType>
void Eval<ElemType>::Evaluate(std::map<std::wstring, std::vector<ElemType>*>& inputs, std::map<std::wstring, std::vector<ElemType>*>& outputs)
{
    m_eval->Evaluate(inputs, outputs);
}

// ResetState - Reset the cell state when we get the start of an utterance (forwarded to the plugin evaluator)
template<class ElemType>
void Eval<ElemType>::ResetState()
{
    m_eval->ResetState();
}

// The explicit instantiations for the supported element types
template class Eval<double>;
template class Eval<float>;

}}}