forked from microsoft/CNTK
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
VS forgot to add the new files to the git repo...
- Loading branch information
1 parent
7fc742e
commit 94e49de
Showing
3 changed files
with
138 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
// TensorView.cpp -- main CPP file that contains all functions exported by the CNTKMath.dll | ||
// | ||
// <copyright file="Matrix.cpp" company="Microsoft"> | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// </copyright> | ||
// | ||
|
||
// This implements the TensorView class, which is a layer around Matrix that reinterprets its content as a generic tensor. | ||
|
||
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings | ||
|
||
#include "Basics.h" | ||
#include "TensorView.h" | ||
#include <array> | ||
|
||
#ifndef let | ||
#define let const auto | ||
#endif | ||
|
||
namespace Microsoft { | ||
namespace MSR { | ||
namespace CNTK { | ||
|
||
using namespace std; | ||
|
||
// cast a matrix as a tensor | ||
template<class ElemType> | ||
TensorView<ElemType>::TensorView(const Matrix<ElemType> & sob) : | ||
m_sob(sob), m_shape(TensorShape(array<size_t, 2> { sob.GetNumRows(), sob.GetNumCols() })) | ||
{ } | ||
template<class ElemType> | ||
TensorView<ElemType>::TensorView(const TensorView<ElemType> & other, const TensorShape & shape) : | ||
m_sob(other.m_sob), m_shape(shape) | ||
{ | ||
// for now we enforce that tensor dimensions match dimensions of the underlying matrix storage object | ||
// This is for sanity checks. In the future, it may appropriate to reduce this check to just checking the total number of elements, to allow abuses. | ||
// TODO: Use the multipliers instead? | ||
size_t i; | ||
size_t rowDim = 1; | ||
for (i = 0; i < m_shape.size() && rowDim < m_sob.GetNumRows(); i++) | ||
rowDim *= m_shape[i]; | ||
// first i dimensions match matrix row dimension | ||
size_t colDim = 1; | ||
for (; i < m_shape.size(); i++) | ||
colDim *= m_shape[i]; | ||
if (rowDim != m_sob.GetNumRows() || colDim != m_sob.GetNumCols()) | ||
LogicError("TensorView: Tensor dimensions %s do not match storage-object dims %d x %d", string(m_shape).c_str(), (int)m_sob.GetNumRows(), (int)m_sob.GetNumCols()); | ||
} | ||
|
||
// simple test function for testing stuff | ||
template<class ElemType> | ||
/*static*/ void TensorView<ElemType>::Test() | ||
{ | ||
Matrix<ElemType> m(0); | ||
m.Resize(13, 42); | ||
TensorShape s1(13, 2, 22); | ||
TensorShape s2(13, 2, 21); | ||
s1;//let t1 = TensorView<ElemType>(m, s1); t1; | ||
let t2 = TensorView<ElemType>(m, s2); t2; | ||
} | ||
|
||
template class TensorView<float>; | ||
template class TensorView<double>; | ||
|
||
}}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
// | ||
// <copyright file="TensorView.h" company="Microsoft"> | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// </copyright> | ||
// | ||
|
||
// This implements the TensorView class, which is a layer around Matrix that reinterprets its content as a generic tensor. | ||
|
||
#pragma once | ||
|
||
#include "Basics.h" | ||
#include "Matrix.h" | ||
#include "DataTensor.h" | ||
|
||
#pragma warning (push) | ||
#pragma warning (disable: 4251) // needs to have dll-interface to be used by clients of... caused by TensorView::m_shape which is only private. We use the same compiler everywhere. | ||
|
||
// This class is exported from the Math.dll. | ||
namespace Microsoft { namespace MSR { namespace CNTK { | ||
|
||
template<class ElemType> | ||
class MATH_API TensorView | ||
{ | ||
public: | ||
// ------------------------------------------------------------------- | ||
// construction | ||
// ------------------------------------------------------------------- | ||
|
||
// cast a matrix as a TensorView (without shape change) | ||
TensorView(const Matrix<ElemType> & sob); | ||
// reshape a TensorView | ||
TensorView(const TensorView<ElemType> & sob, const TensorShape & shape); | ||
// reinterpret a Matrix as a TensorView with reshaping | ||
TensorView(const Matrix<ElemType> & sob, const TensorShape & shape) : | ||
TensorView(TensorView(sob)/*cast as a TensorView*/, shape/*with a shape*/) | ||
{ } | ||
// copy constructor | ||
TensorView(const TensorView<ElemType> & other) : | ||
TensorView(other.m_sob, other.m_shape) | ||
{ } | ||
// assignment is forbidden since we contain a reference | ||
// If you ever need this, change the reference to a pointer. | ||
void operator=(const TensorView & other) = delete; // since we have a reference | ||
|
||
// ------------------------------------------------------------------- | ||
// operations | ||
// ------------------------------------------------------------------- | ||
|
||
static void Test(); | ||
|
||
private: | ||
|
||
// ------------------------------------------------------------------- | ||
// sob members | ||
// ------------------------------------------------------------------- | ||
|
||
const Matrix<ElemType> & m_sob; // Storage OBject that holds the data that is being viewed with this TensorView | ||
TensorShape m_shape; // the meta-data that describes the data's shape and/or access pattern | ||
}; | ||
|
||
}}} | ||
|
||
#pragma warning (pop) |