forked from microsoft/CNTK
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cudalattice.h
131 lines (120 loc) · 7.38 KB
/
cudalattice.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// cudalattice.h -- lattice forward/backward functions CUDA execution
//
// F. Seide, V-hansu
#pragma once
#include <stdexcept> // (for NOCUDA version only)
#include "latticestorage.h" // basic data types for storing lattices
#define __kernel_emulation__ // allow the compilation of CUDA kernels on the CPU
#include "latticefunctionskernels.h" // for data types passed to kernel functions
#undef __kernel_emulation__
#include "Matrix.h"
namespace msra { namespace cuda {
// ---------------------------------------------------------------------------
// how to add vector types here:
// - for types with methods that operate on the vector, follow the example of 'somedatavector'
// - for types without extra methods, follow the example of 'ushortvector'
// - the total set of classes that need to be defined is:
// - XXXvector --the interface that the calling code sees (in this header file)
// - XXXvectorimpl --glue for each method: switch CUDA device, then call kernel (in XXXvectorops, cudalattice.cpp)
// - XXXvectorops --contains the actual kernel functions (cudalatticeops.h, cudalatticeops.cu.h)
// - comments:
// - follow precisely this pattern
// - for non-vector elements you need to pass by reference, you can still use this with a size-1 vector
// - and for functions that have no 'this', use a size-0 vector
// - note that you cannot add data members to XXXvector, as there is no mechanism to sync between GPU and CPU; only vector elements are synced
// - remember to Get Latest in both projects to see the change, for cudalattice.h and latticestorage.h
// ---------------------------------------------------------------------------
// a vector type; use this as a basetype
// Note that this will only be instantiated for types known inside this lib, and a newvector<> function must be exported for each.
// Abstract interface of a vector of ELEMTYPE elements.
// Implementations supply the four primitive operations (allocate/size/assign/fetch);
// the two container-based overloads below are convenience wrappers built on top of them.
// NOTE(review): the 'synchronize' flag is only forwarded here; presumably it controls whether
// the call waits for the CPU<->GPU transfer to complete -- confirm against the implementation.
template <typename ELEMTYPE>
struct /*interface*/ vectorbase
{
    // element type, re-exported so generic code can query it
    typedef ELEMTYPE elemtype;

    virtual ~vectorbase() {}

    // --- primitives, to be provided by the implementation ---

    // set the number of elements to 'n'
    virtual void allocate(size_t n) = 0;

    // current number of elements
    virtual size_t size() const throw() = 0;

    // take 'n' elements from the raw array 'p'
    virtual void assign(const ELEMTYPE* p, size_t n, bool synchronize) = 0;

    // Convenience overload: take the content of a whole container.
    // CONTAINER must offer size(), empty() and operator[] over contiguous storage (e.g. std::vector).
    // The vector is always re-allocated to the container's size; the element transfer
    // is skipped for an empty container, since &src[0] would not be valid then.
    template <class CONTAINER>
    void assign(const CONTAINER& src, bool synchronize)
    {
        allocate(src.size());
        if (src.empty())
            return;
        assign(&src[0], src.size(), synchronize);
    }

    // write 'n' elements into the raw array 'p'
    virtual void fetch(ELEMTYPE* p, size_t n, bool synchronize) const = 0;

    // Convenience overload: retrieve the content into a whole container.
    // The container is resized to size() first; nothing is transferred if that leaves it empty.
    template <class CONTAINER>
    void fetch(CONTAINER& dst, bool synchronize) const
    {
        dst.resize(size());
        if (dst.empty())
            return;
        fetch(&dst[0], dst.size(), synchronize);
    }
};
// ---------------------------------------------------------------------------
// vectors of these with custom functions
// The XXXvector classes must derive from vectorbase<XXX>.
// ---------------------------------------------------------------------------
// Example of a vector type that carries an additional operation on top of vectorbase<>.
// (struct inheritance is public by default)
struct somedatavector : vectorbase<msra::lattices::somedata>
{
    // An implementation must provide all members of vectorbase<>;
    // type-specific operations such as this one are declared here.
    virtual int somedataoperation(size_t arg) = 0;
};
// vector types without extra methods are plain aliases of vectorbase<>

// ... over built-in element types
typedef vectorbase<unsigned short> ushortvector;
typedef vectorbase<unsigned int> uintvector;
typedef vectorbase<size_t> sizetvector;
typedef vectorbase<float> floatvector;
typedef vectorbase<double> doublevector;

// ... over the lattice data types declared in namespace msra::lattices
typedef vectorbase<msra::lattices::lrhmmdef> lrhmmdefvector;
typedef vectorbase<msra::lattices::lr3transP> lr3transPvector;
typedef vectorbase<msra::lattices::nodeinfo> nodeinfovector;
typedef vectorbase<msra::lattices::edgeinfowithscores> edgeinfowithscoresvector;
typedef vectorbase<msra::lattices::aligninfo> aligninfovector;
// Interface bundling the lattice-processing operations.
// It derives from a vectorbase<> over msra::lattices::empty; presumably this is the
// "size-0 vector" pattern for functions without their own data -- confirm against the implementation.
// In every signature below, const parameters are inputs, and non-const references are
// the arguments an implementation is able to write to.
struct latticefunctions : vectorbase<msra::lattices::empty>
{
    virtual void edgealignment(
        // inputs
        const lrhmmdefvector& hmms,
        const lr3transPvector& transPs,
        const size_t spalignunitid,
        const size_t silalignunitid,
        const Microsoft::MSR::CNTK::Matrix<float>& logLLs,
        const nodeinfovector& nodes,
        const edgeinfowithscoresvector& edges,
        const aligninfovector& aligns,
        const uintvector& alignoffsets,
        // writable storage, with its (read-only) offsets
        ushortvector& backptrstorage,
        const sizetvector& backptroffsets,
        // marked as output in the original source
        ushortvector& alignresult,
        floatvector& edgeacscores) = 0;

    virtual void forwardbackwardlattice(
        // inputs
        const size_t* batchsizeforward,
        const size_t* batchsizebackward,
        const size_t numlaunchforward,
        const size_t numlaunchbackward,
        const size_t spalignunitid,
        const size_t silalignunitid,
        const floatvector& edgeacscores,
        const edgeinfowithscoresvector& edges,
        const nodeinfovector& nodes,
        const aligninfovector& aligns,
        const ushortvector& alignoutput,
        const uintvector& alignoffsets,
        // writable
        doublevector& logpps,
        doublevector& logalphas,
        doublevector& logbetas,
        // scalar settings
        const float lmf,
        const float wp,
        const float amf,
        const float boostingfactor,
        const bool returnEframescorrect,
        // inputs
        const ushortvector& uids,
        const ushortvector& senone2classmap,
        // writable
        doublevector& logaccalphas,
        doublevector& logaccbetas,
        doublevector& logframescorrectedge,
        doublevector& logEframescorrect,
        doublevector& Eframescorrectbuf,
        double& logEframescorrecttotal,
        double& totalfwscore) = 0;

    virtual void sMBRerrorsignal(
        // inputs
        const ushortvector& alignstateids,
        const uintvector& alignoffsets,
        const edgeinfowithscoresvector& edges,
        const nodeinfovector& nodes,
        const doublevector& logpps,
        const float amf,
        const doublevector& logEframescorrect,
        const double logEframescorrecttotal,
        // writable
        Microsoft::MSR::CNTK::Matrix<float>& dengammas,
        Microsoft::MSR::CNTK::Matrix<float>& dengammasbuf) = 0;

    virtual void mmierrorsignal(
        // inputs
        const ushortvector& alignstateids,
        const uintvector& alignoffsets,
        const edgeinfowithscoresvector& edges,
        const nodeinfovector& nodes,
        const doublevector& logpps,
        // writable
        Microsoft::MSR::CNTK::Matrix<float>& dengammas) = 0;

    virtual void stateposteriors(
        // inputs
        const ushortvector& alignstateids,
        const uintvector& alignoffsets,
        const edgeinfowithscoresvector& edges,
        const nodeinfovector& nodes,
        const doublevector& logqs,
        // writable
        Microsoft::MSR::CNTK::Matrix<float>& logacc) = 0;
};
// ---------------------------------------------------------------------------
// factory methods
// ---------------------------------------------------------------------------
// Each function takes a device id and returns a pointer to the corresponding interface type.
// NOTE(review): ownership of the returned pointer is not visible here -- presumably the caller
// is responsible for deleting it; confirm against the implementation.

// vectors over built-in element types
ushortvector* newushortvector(size_t deviceid);
uintvector* newuintvector(size_t deviceid);
sizetvector* newsizetvector(size_t deviceid);
floatvector* newfloatvector(size_t deviceid);
doublevector* newdoublevector(size_t deviceid);

// vectors over lattice data types
somedatavector* newsomedatavector(size_t deviceid);
lrhmmdefvector* newlrhmmdefvector(size_t deviceid);
lr3transPvector* newlr3transPvector(size_t deviceid);
nodeinfovector* newnodeinfovector(size_t deviceid);
// note: the function name omits 'withscores', unlike the type it returns
edgeinfowithscoresvector* newedgeinfovector(size_t deviceid);
aligninfovector* newaligninfovector(size_t deviceid);

// the lattice function bundle
latticefunctions* newlatticefunctions(size_t deviceid);
};
};