Skip to content

Commit

Permalink
Add some constructors for generating object that only contains shape …
Browse files Browse the repository at this point in the history
…(do not contains data).
  • Loading branch information
hedaoyuan committed Jan 13, 2017
1 parent 2a20fdc commit 039c0bf
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
33 changes: 31 additions & 2 deletions paddle/function/BufferArg.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ enum SparseDataFormat { SPARSE_CSR_FORMAT = 0, SPARSE_CSC_FORMAT = 1 };
class BufferArg;
class SequenceArg;
class SparseMatrixArg;
typedef std::shared_ptr<BufferArg> BufferArgPtr;

/**
* \brief BufferArg used as the argument type of Function.
Expand All @@ -50,6 +49,11 @@ typedef std::shared_ptr<BufferArg> BufferArgPtr;
* 3. SequenceArg for a Buffer of sequence data.
* 4. SparseMatrixArg for a Buffer of sparse matrix.
*
* Buffer shape
* For most buffers, the first dimension `shape()[0]` represents
* the size of the mini-batch.
*
* Buffer argType
* There is an ArgType property for the BufferArg used as Function Output.
* Whether the result of the Function calculation is assigned to the
* output Buffer or added to the output Buffer is determined by the
Expand All @@ -71,6 +75,14 @@ class BufferArg {
ArgType getArgType() const { return argType_; }

public:
BufferArg(ValueType valueType,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
: buf_(nullptr),
valueType_(valueType),
shape_(shape),
argType_(argType) {}

BufferArg(void* buf,
ValueType valueType,
const TensorShape& shape,
Expand Down Expand Up @@ -170,6 +182,12 @@ class BufferArg {
// if a < b then value_.buf_[a] < value_.buf_[b]
class SequenceIdArg : public BufferArg {
public:
SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED)
: BufferArg(VALUE_TYPE_INT32, shape, argType) {
CHECK_EQ(shape_.ndims(), (size_t)1);
numSeqs_ = shape_[0] - 1;
}

SequenceIdArg(void* buf,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
Expand All @@ -190,9 +208,18 @@ class SequenceIdArg : public BufferArg {
size_t numSeqs_;
};

// sequence data
// sequences data
// For mini-batch calculate,
// one batch can contain more than one sequence of data.
// SequenceArg can be used to represent sequences that contain multiple
// unequal lengths.
class SequenceArg : public BufferArg {
public:
SequenceArg(ValueType valueType,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
: BufferArg(valueType, shape, argType), startPositions_(TensorShape()) {}

SequenceArg(void* buf,
ValueType valueType,
const TensorShape& shape,
Expand All @@ -210,6 +237,8 @@ class SequenceArg : public BufferArg {

void* getIdBuf() const { return startPositions_.data(); }
size_t numSeqs() const { return startPositions_.numSeqs(); }
SequenceIdArg& getSequenceId() { return startPositions_; }
const SequenceIdArg& getSequenceId() const { return startPositions_; }

private:
SequenceIdArg startPositions_;
Expand Down
18 changes: 18 additions & 0 deletions paddle/function/FunctionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ void testBufferArgs(const BufferArgs& inputs,
}
}

void testBufferArgs(const BufferArgs& inputs, const CheckBufferArg& check) {
check(inputs[0]);
}

TEST(Arguments, Matrix) {
MatrixPtr matrix = Matrix::create(100, 200);
CheckBufferArg check = [=](const BufferArg& arg) {
Expand Down Expand Up @@ -144,4 +148,18 @@ TEST(Arguments, CpuSparseMatrix) {
testBufferArgs(argments, checkFunc);
}

TEST(Arguments, BufferArg) {
BufferArg arg(nullptr, VALUE_TYPE_FLOAT, {1, 2, 3});
CheckBufferArg check = [=](const BufferArg& arg) {
EXPECT_EQ(arg.shape().ndims(), 3);
EXPECT_EQ(arg.shape()[0], 1);
EXPECT_EQ(arg.shape()[1], 2);
EXPECT_EQ(arg.shape()[2], 3);
};

BufferArgs argments;
argments.addArg(arg);
testBufferArgs(argments, check);
}

} // namespace paddle

0 comments on commit 039c0bf

Please sign in to comment.