Skip to content

Commit

Permalink
Refactorize enforece_test.cc
Browse files Browse the repository at this point in the history
  • Loading branch information
wangkuiyi committed Aug 9, 2017
1 parent 9a52056 commit 54cda76
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 97 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ RUN apt-get update && \
wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \
curl sed grep graphviz libjpeg-dev zlib1g-dev \
python-matplotlib gcc-4.8 g++-4.8 \
automake locales clang-format-3.8 swig doxygen cmake \
automake locales clang-format swig doxygen cmake \
liblapack-dev liblapacke-dev libboost-dev \
clang-3.8 llvm-3.8 libclang-3.8-dev \
net-tools && \
Expand Down
2 changes: 1 addition & 1 deletion paddle/platform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)

add_subdirectory(dynload)

cc_test(enforce_test SRCS enforce_test.cc)
cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece)

IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
Expand Down
144 changes: 49 additions & 95 deletions paddle/platform/enforce_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ limitations under the License. */

#include "gtest/gtest.h"
#include "paddle/platform/enforce.h"
#include "paddle/string/piece.h"

using StringPiece = paddle::string::Piece;
using paddle::string::HasPrefix;

TEST(ENFORCE, OK) {
PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
Expand All @@ -22,19 +26,15 @@ TEST(ENFORCE, OK) {
}

TEST(ENFORCE, FAILED) {
bool in_catch = false;
bool caught_exception = false;
try {
PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
} catch (paddle::platform::EnforceNotMet error) {
// your error handling code here
in_catch = true;
std::string msg = "Enforce is not ok 123 at all";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
EXPECT_TRUE(
HasPrefix(StringPiece(error.what()), "Enforce is not ok 123 at all"));
}
ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}

TEST(ENFORCE, NO_ARG_OK) {
Expand All @@ -47,84 +47,60 @@ TEST(ENFORCE, NO_ARG_OK) {

TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) {
int a = 2;
bool in_catch = false;

bool caught_exception = false;
try {
PADDLE_ENFORCE_EQ(a, 1 + 3);

} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "enforce a == 1 + 3 failed, 2 != 4";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
HasPrefix(StringPiece(error.what()), "enforce a == 1 + 3 failed, 2 != 4");
}

ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}

TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) {
int a = 2;
bool in_catch = false;

bool caught_exception = false;
try {
PADDLE_ENFORCE_EQ(a, 1 + 3, "%s size not match", "their");

} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg =
"enforce a == 1 + 3 failed, 2 != 4\ntheir size not match";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
HasPrefix(StringPiece(error.what()),
"enforce a == 1 + 3 failed, 2 != 4\ntheir size not match");
}

ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}

TEST(ENFORCE_NE, OK) {
PADDLE_ENFORCE_NE(1, 2);
PADDLE_ENFORCE_NE(1.0, 2UL);
}
TEST(ENFORCE_NE, FAIL) {
bool in_catch = false;
bool caught_exception = false;

try {
// 2UL here to check data type compatible
PADDLE_ENFORCE_NE(1.0, 1UL);

} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "enforce 1.0 != 1UL failed, 1.000000 == 1";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
EXPECT_TRUE(HasPrefix(StringPiece(error.what()),
"enforce 1.0 != 1UL failed, 1.000000 == 1"))
<< error.what() << " does not have expected prefix";
}

ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}

TEST(ENFORCE_GT, OK) { PADDLE_ENFORCE_GT(2, 1); }
TEST(ENFORCE_GT, FAIL) {
bool in_catch = false;

bool caught_exception = false;
try {
// 2UL here to check data type compatible
PADDLE_ENFORCE_GT(1, 2UL);

} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "enforce 1 > 2UL failed, 1 <= 2";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
EXPECT_TRUE(
HasPrefix(StringPiece(error.what()), "enforce 1 > 2UL failed, 1 <= 2"));
}

ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}

TEST(ENFORCE_GE, OK) {
Expand All @@ -134,21 +110,16 @@ TEST(ENFORCE_GE, OK) {
PADDLE_ENFORCE_GE(3.21, 2UL);
}
TEST(ENFORCE_GE, FAIL) {
bool in_catch = false;

bool caught_exception = false;
try {
PADDLE_ENFORCE_GE(1, 2UL);

} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "enforce 1 >= 2UL failed, 1 < 2";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
EXPECT_TRUE(
HasPrefix(StringPiece(error.what()), "enforce 1 >= 2UL failed, 1 < 2"));
}

ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}

TEST(ENFORCE_LE, OK) {
Expand All @@ -159,21 +130,16 @@ TEST(ENFORCE_LE, OK) {
PADDLE_ENFORCE_LE(2UL, 3.2);
}
TEST(ENFORCE_LE, FAIL) {
bool in_catch = false;

bool caught_exception = false;
try {
PADDLE_ENFORCE_GT(1, 2UL);

} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "enforce 1 > 2UL failed, 1 <= 2";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
EXPECT_TRUE(
HasPrefix(StringPiece(error.what()), "enforce 1 > 2UL failed, 1 <= 2"));
}

ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}

TEST(ENFORCE_LT, OK) {
Expand All @@ -182,21 +148,15 @@ TEST(ENFORCE_LT, OK) {
PADDLE_ENFORCE_LT(2UL, 3);
}
TEST(ENFORCE_LT, FAIL) {
bool in_catch = false;

bool caught_exception = false;
try {
PADDLE_ENFORCE_LT(1UL, 0.12);

} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "enforce 1UL < 0.12 failed, 1 >= 0.12";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
EXPECT_TRUE(HasPrefix(StringPiece(error.what()),
"enforce 1UL < 0.12 failed, 1 >= 0.12"));
}

ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}

TEST(ENFORCE_NOT_NULL, OK) {
Expand All @@ -205,20 +165,14 @@ TEST(ENFORCE_NOT_NULL, OK) {
delete a;
}
TEST(ENFORCE_NOT_NULL, FAIL) {
bool in_catch = false;
int* a{nullptr};

bool caught_exception = false;
try {
int* a = nullptr;
PADDLE_ENFORCE_NOT_NULL(a);

} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "a should not be null";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
caught_exception = true;
EXPECT_TRUE(HasPrefix(StringPiece(error.what()), "a should not be null"));
}

ASSERT_TRUE(in_catch);
EXPECT_TRUE(caught_exception);
}

0 comments on commit 54cda76

Please sign in to comment.