Skip to content

Commit 036476b

Browse files
committed
Completes pinning managed memory prior to native evaluation.
Applies CR feedback
1 parent 1104a7e commit 036476b

File tree

3 files changed

+121
-62
lines changed

3 files changed

+121
-62
lines changed

Source/Extensibility/CSEvalClient/Program.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ private static void EvaluateExtendedNetworkSingleLayerNoInput()
269269
string workingDirectory = Path.Combine(initialDirectory, @"..\..\Examples\Other\Simple2d\Config");
270270
Environment.CurrentDirectory = initialDirectory;
271271

272-
using (var model = new IEvaluateModelExtendedManagedF())
272+
using (var model = new ModelEvaluationExtendedF())
273273
{
274274
// Create the network
275275
// This network (AddOperatorConstantNoInput.cntk) is a simple network consisting of a single binary operator (Plus)
@@ -279,7 +279,7 @@ private static void EvaluateExtendedNetworkSingleLayerNoInput()
279279

280280
VariableSchema outputSchema = model.GetOutputSchema();
281281

282-
model.StartForwardEvaluation(outputSchema.Select(s => s.m_name).ToList<string>());
282+
model.StartForwardEvaluation(outputSchema.Select(s => s.Name).ToList<string>());
283283

284284
List<ValueBuffer<float>> outputBuffer = outputSchema.CreateBuffers<float>();
285285
List<ValueBuffer<float>> inputBuffer = new List<ValueBuffer<float>>();
@@ -291,7 +291,7 @@ private static void EvaluateExtendedNetworkSingleLayerNoInput()
291291
float[][] expected = {new float[]{2}, new float[]{3}};
292292

293293
Console.WriteLine("Expected values: {0}", string.Join(" - ", expected.Select(b => string.Join(", ", b)).ToList<string>()));
294-
Console.WriteLine("Actual Values : {0}", string.Join(" - ", outputBuffer.Select(b => string.Join(", ", b.m_buffer)).ToList<string>()));
294+
Console.WriteLine("Actual Values : {0}", string.Join(" - ", outputBuffer.Select(b => string.Join(", ", b.Buffer)).ToList<string>()));
295295
}
296296
}
297297
catch (CNTKException ex)

Source/Extensibility/EvalWrapper/EvalExtendedWrapper.cpp

+94-41
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,29 @@ generic<class ElemType>
5656
m_colIndices = gcnew array<int>(size);
5757
m_size = size;
5858
}
59+
60+
property int Size
61+
{
62+
int get() { return m_size; }
63+
}
64+
65+
property array<ElemType>^ Buffer
66+
{
67+
array<ElemType>^ get() { return m_buffer; }
68+
}
69+
70+
property array<int>^ Indices
71+
{
72+
array<int>^ get() { return m_indices; }
73+
}
74+
75+
property array<int>^ ColIndices
76+
{
77+
array<int>^ get() { return m_colIndices;
78+
}
79+
}
80+
81+
private:
5982

6083
int m_size;
6184

@@ -111,15 +134,15 @@ public ref struct VariableLayout
111134
};
112135

113136
// Name of the input
114-
String^ m_name;
137+
property String^ Name;
115138

116-
DataType m_dataType;
139+
property DataType DataKind;
117140

118-
StorageType m_storageType;
141+
property StorageType StorageKind;
119142

120143
// Dimension of the tensor, flattened to 1 dimension, for one entry on the dynamic axis.
121144
// E.g. for a tensor [2,3,*] this would be 6.
122-
int m_numElements;
145+
property int NumElements;
123146
};
124147

125148
public ref class VariableSchema : List<VariableLayout^>
@@ -136,7 +159,7 @@ public ref class VariableSchema : List<VariableLayout^>
136159
List<ValueBuffer<ElemType>^>^ buffers = gcnew List<ValueBuffer<ElemType>^>(this->Count);
137160
for (int i = 0; i < this->Count; i++)
138161
{
139-
buffers->Add(gcnew ValueBuffer<ElemType>(this[i]->m_numElements * maxLengths[i]));
162+
buffers->Add(gcnew ValueBuffer<ElemType>(this[i]->NumElements * maxLengths[i]));
140163
}
141164

142165
return buffers;
@@ -149,7 +172,7 @@ public ref class VariableSchema : List<VariableLayout^>
149172
List<ValueBuffer<ElemType>^>^ buffers = gcnew List<ValueBuffer<ElemType>^>(this->Count);
150173
for (int i = 0; i < this->Count; i++)
151174
{
152-
buffers->Add(gcnew ValueBuffer<ElemType>(this[i]->m_numElements));
175+
buffers->Add(gcnew ValueBuffer<ElemType>(this[i]->NumElements));
153176
}
154177

155178
return buffers;
@@ -158,14 +181,14 @@ public ref class VariableSchema : List<VariableLayout^>
158181

159182
/// Managed wrapper for the native evaluation model
160183
template<typename ElemType>
161-
public ref class IEvaluateModelExtendedManaged : IDisposable
184+
public ref class ModelEvaluationExtended : IDisposable
162185
{
163186
typedef std::pair<std::wstring, std::vector<ElemType>*> MapEntry;
164187

165188
public:
166-
/// <summary>Initializes a new instance of the <see cref="IEvaluateModelExtendedManaged"> class.</summary>
189+
/// <summary>Initializes a new instance of the <see cref="ModelEvaluationExtended"> class.</summary>
167190
/// <param name="funcName">Factory function name for retrieving the native model from the dll.</param>
168-
IEvaluateModelExtendedManaged(String^ funcName)
191+
ModelEvaluationExtended(String^ funcName)
169192
{
170193
pin_ptr<const WCHAR> dllname = PtrToStringChars("evaldll.dll");
171194
auto hModule = LoadLibrary(dllname);
@@ -228,9 +251,10 @@ public ref class IEvaluateModelExtendedManaged : IDisposable
228251
for (auto& lay : outputLayout)
229252
{
230253
VariableLayout^ layout = gcnew VariableLayout();
231-
layout->m_name = gcnew String(lay.m_name.c_str());
232-
layout->m_dataType = GetDataType(lay.m_dataType);
233-
layout->m_numElements = lay.m_numElements;
254+
layout->Name = gcnew String(lay.m_name.c_str());
255+
layout->DataKind = GetDataKind(lay.m_dataType);
256+
layout->NumElements = lay.m_numElements;
257+
layout->StorageKind = GetStorageKind(lay.m_storageType);
234258

235259
outputSchema->Add(layout);
236260
}
@@ -275,9 +299,10 @@ public ref class IEvaluateModelExtendedManaged : IDisposable
275299
for (auto& lay : inputLayout)
276300
{
277301
VariableLayout^ layout = gcnew VariableLayout();
278-
layout->m_name = gcnew String(lay.m_name.c_str());
279-
layout->m_dataType = GetDataType(lay.m_dataType);
280-
layout->m_numElements = lay.m_numElements;
302+
layout->Name = gcnew String(lay.m_name.c_str());
303+
layout->DataKind = GetDataKind(lay.m_dataType);
304+
layout->NumElements = lay.m_numElements;
305+
layout->StorageKind = GetStorageKind(lay.m_storageType);
281306

282307
inputSchema->Add(layout);
283308
}
@@ -307,26 +332,54 @@ public ref class IEvaluateModelExtendedManaged : IDisposable
307332
Native::ValueRefs<ElemType> stdOutputs;
308333
Native::ValueBuffer<ElemType, Native::VectorRef>* vb = new Native::ValueBuffer<ElemType, Native::VectorRef>();
309334

335+
// Hold gc objects in the stack, while performing native actions
336+
vector<gcroot<array<ElemType>^>*> pinBuffers;
337+
vector<gcroot<array<int>^>*> pinIndices;
338+
310339
// Map the managed space into the native space, results will be written directly into the managed memory space
340+
// https://msdn.microsoft.com/en-us/library/1dz8byfh.aspx
341+
311342
for each (auto item in inputs)
312343
{
313-
pin_ptr<ElemType> pb = &(item->m_buffer[0]);
314-
pin_ptr<int> pi = &(item->m_indices[0]);
315-
pin_ptr<int> pci = &(item->m_colIndices[0]);
316-
vb->m_buffer.InitFrom(pb, item->m_size, item->m_size);
317-
vb->m_indices.InitFrom(pi, item->m_size, item->m_size);
318-
vb->m_colIndices.InitFrom(pci, item->m_size, item->m_size);
344+
int size = item->Size;
345+
// gcroot object manages the pointer so that it always corresponds to the correct managed location (even after gc relocation)
346+
gcroot<array<ElemType>^>* pBuf = new gcroot<array<ElemType>^>(item->Buffer);
347+
gcroot<array<int>^>* pInd = new gcroot<array<int>^>(item->Indices);
348+
gcroot<array<int>^>* pColInd = new gcroot<array<int>^>(item->ColIndices);
349+
350+
pinBuffers.push_back(pBuf);
351+
pinIndices.push_back(pInd);
352+
pinIndices.push_back(pColInd);
353+
354+
pin_ptr<ElemType> pp = &(*pBuf)[0];
355+
pin_ptr<int> pi = &(*pInd)[0];
356+
pin_ptr<int> pci = &(*pColInd)[0];
357+
358+
vb->m_buffer.InitFrom(pp, size, size);
359+
vb->m_indices.InitFrom(pi, size, size);
360+
vb->m_colIndices.InitFrom(pci, size, size);
361+
319362
stdInputs.push_back(*vb);
320363
}
321364

322365
for each (auto item in outputs)
323366
{
324-
pin_ptr<ElemType> pb = &(item->m_buffer[0]);
325-
pin_ptr<int> pi = &(item->m_indices[0]);
326-
pin_ptr<int> pci = &(item->m_colIndices[0]);
327-
vb->m_buffer.InitFrom(pb, item->m_size, item->m_size);
328-
vb->m_indices.InitFrom(pi, item->m_size, item->m_size);
329-
vb->m_colIndices.InitFrom(pci, item->m_size, item->m_size);
367+
int size = item->Size;
368+
gcroot<array<ElemType>^>* pBuf = new gcroot<array<ElemType>^>(item->Buffer);
369+
gcroot<array<int>^>* pInd = new gcroot<array<int>^>(item->Indices);
370+
gcroot<array<int>^>* pColInd = new gcroot<array<int>^>(item->ColIndices);
371+
372+
pin_ptr<ElemType> pp = &(*pBuf)[0];
373+
pin_ptr<int> pi = &(*pInd)[0];
374+
pin_ptr<int> pci = &(*pColInd)[0];
375+
376+
pinBuffers.push_back(pBuf);
377+
pinIndices.push_back(pInd);
378+
pinIndices.push_back(pColInd);
379+
vb->m_buffer.InitFrom(pp, size, size);
380+
vb->m_indices.InitFrom(pi, size, size);
381+
vb->m_colIndices.InitFrom(pci, size, size);
382+
330383
stdOutputs.push_back(*vb);
331384
}
332385

@@ -345,18 +398,18 @@ public ref class IEvaluateModelExtendedManaged : IDisposable
345398
}
346399
}
347400

348-
~IEvaluateModelExtendedManaged()
401+
~ModelEvaluationExtended()
349402
{
350403
if (m_eval == nullptr)
351404
{
352405
return;
353406
}
354407

355-
this->!IEvaluateModelExtendedManaged();
408+
this->!ModelEvaluationExtended();
356409
}
357410

358411
protected:
359-
!IEvaluateModelExtendedManaged()
412+
!ModelEvaluationExtended()
360413
{
361414
if (m_eval != nullptr)
362415
{
@@ -417,7 +470,7 @@ public ref class IEvaluateModelExtendedManaged : IDisposable
417470
}
418471
}
419472

420-
VariableLayout::DataType GetDataType(Microsoft::MSR::CNTK::VariableLayout::DataType dataType)
473+
VariableLayout::DataType GetDataKind(Microsoft::MSR::CNTK::VariableLayout::DataType dataType)
421474
{
422475
switch ((int)dataType)
423476
{
@@ -430,7 +483,7 @@ public ref class IEvaluateModelExtendedManaged : IDisposable
430483
}
431484
}
432485

433-
VariableLayout::StorageType GetStorageType(Microsoft::MSR::CNTK::VariableLayout::StorageType storageType)
486+
VariableLayout::StorageType GetStorageKind(Microsoft::MSR::CNTK::VariableLayout::StorageType storageType)
434487
{
435488
switch ((int)storageType)
436489
{
@@ -448,22 +501,22 @@ public ref class IEvaluateModelExtendedManaged : IDisposable
448501

449502
/// <summary>Managed float-specific model evaluation class</summary>
450503
/// <remarks>This class is necessary due to how generics and templates work in CLR</remarks>
451-
public ref class IEvaluateModelExtendedManagedF : IEvaluateModelExtendedManaged<float>
504+
public ref class ModelEvaluationExtendedF : ModelEvaluationExtended<float>
452505
{
453506
public:
454-
IEvaluateModelExtendedManagedF::IEvaluateModelExtendedManagedF()
455-
: IEvaluateModelExtendedManaged("GetEvalExtendedF")
507+
ModelEvaluationExtendedF::ModelEvaluationExtendedF()
508+
: ModelEvaluationExtended("GetEvalExtendedF")
456509
{
457510
}
458511
};
459512

460513
/// <summary>Managed double-specific model evaluation class</summary>
461514
/// <remarks>This class is necessary due to how generics and templates work in CLR</remarks>
462-
public ref class IEvaluateModelExtendedManagedD : IEvaluateModelExtendedManaged<double>
515+
public ref class ModelEvaluationExtendedD : ModelEvaluationExtended<double>
463516
{
464517
public:
465-
IEvaluateModelExtendedManagedD::IEvaluateModelExtendedManagedD()
466-
: IEvaluateModelExtendedManaged("GetEvalExtendedD")
518+
ModelEvaluationExtendedD::ModelEvaluationExtendedD()
519+
: ModelEvaluationExtended("GetEvalExtendedD")
467520
{
468521
}
469522
};
@@ -472,16 +525,16 @@ public ref class IEvaluateModelExtendedManagedD : IEvaluateModelExtendedManaged<
472525
// This method tricks the compiler into emitting the methods of the classes
473526
// Refer to https://msdn.microsoft.com/en-us/library/ms177213.aspx for an
474527
// explanation to this behavior
475-
void emitExtended()
528+
void EmitExtended()
476529
{
477-
IEvaluateModelExtendedManagedF f;
530+
ModelEvaluationExtendedF f;
478531
f.CreateNetwork("");
479532
f.GetOutputSchema();
480533
f.GetInputSchema();
481534
f.StartForwardEvaluation(nullptr);
482535
f.ForwardPass(nullptr, nullptr);
483536

484-
IEvaluateModelExtendedManagedD d;
537+
ModelEvaluationExtendedD d;
485538
d.CreateNetwork("");
486539
d.GetOutputSchema();
487540
d.GetInputSchema();

0 commit comments

Comments
 (0)