Skip to content

Commit

Permalink
avoid import model from file in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreaCasalino committed Apr 26, 2022
1 parent eb1899d commit 7571840
Show file tree
Hide file tree
Showing 10 changed files with 163 additions and 56 deletions.
6 changes: 0 additions & 6 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,4 @@ target_link_libraries(${TEST_NAME} PUBLIC
EFG
)

target_compile_definitions(${TEST_NAME}
PUBLIC
-D TEST_FOLDER="${CMAKE_CURRENT_SOURCE_DIR}/"
-D SAMPLE_FOLDER="${CMAKE_CURRENT_SOURCE_DIR}/../samples/"
)

install(TARGETS ${TEST_NAME})
4 changes: 0 additions & 4 deletions tests/FactorDescription

This file was deleted.

80 changes: 80 additions & 0 deletions tests/ModelLibrary.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#include "ModelLibrary.h"
#include "Utils.h"

namespace EFG::test::library {
const float SimpleTree::alfa = 1.f;
const float SimpleTree::beta = 2.f;
const float SimpleTree::gamma = 1.f;
const float SimpleTree::eps = 1.5f;

SimpleTree::SimpleTree() {
auto A = categoric::make_variable(2, "A");
auto B = categoric::make_variable(2, "B");
auto C = categoric::make_variable(2, "C");
auto D = categoric::make_variable(2, "D");
auto E = categoric::make_variable(2, "E");

addTunableFactor(make_corr_expfactor2(A, B, alfa));
addTunableFactor(make_corr_expfactor2(B, C, beta));
addTunableFactor(make_corr_expfactor2(B, D, gamma));
addTunableFactor(make_corr_expfactor2(D, E, eps));
}

ComplexTree::ComplexTree() {
categoric::VariablesSoup vars;
vars.push_back(nullptr);
for (std::size_t k = 1; k <= 13; ++k) {
vars.push_back(categoric::make_variable(2, "v" + std::to_string(k)));
}

const float w = 1.f;
addTunableFactor(make_corr_expfactor2(vars[1], vars[4], w));
addTunableFactor(make_corr_expfactor2(vars[2], vars[4], w));
addTunableFactor(make_corr_expfactor2(vars[3], vars[5], w));
addTunableFactor(make_corr_expfactor2(vars[4], vars[6], w));
addTunableFactor(make_corr_expfactor2(vars[4], vars[7], w));
addTunableFactor(make_corr_expfactor2(vars[5], vars[7], w));
addTunableFactor(make_corr_expfactor2(vars[5], vars[8], w));
addTunableFactor(make_corr_expfactor2(vars[6], vars[9], w));
addTunableFactor(make_corr_expfactor2(vars[6], vars[10], w));
addTunableFactor(make_corr_expfactor2(vars[7], vars[11], w));
addTunableFactor(make_corr_expfactor2(vars[8], vars[12], w));
addTunableFactor(make_corr_expfactor2(vars[8], vars[13], w));
}

const float SimpleLoopy::w = 1.f;

SimpleLoopy::SimpleLoopy() {
auto A = categoric::make_variable(2, "A");
auto B = categoric::make_variable(2, "B");
auto C = categoric::make_variable(2, "C");
auto D = categoric::make_variable(2, "D");
auto E = categoric::make_variable(2, "E");

addTunableFactor(make_corr_expfactor2(A, B, w));
addTunableFactor(make_corr_expfactor2(B, C, w));
addTunableFactor(make_corr_expfactor2(B, D, w));
addTunableFactor(make_corr_expfactor2(C, D, w));
addTunableFactor(make_corr_expfactor2(E, D, w));
}

ComplexLoopy::ComplexLoopy() {
categoric::VariablesSoup vars;
vars.push_back(nullptr);
for (std::size_t k = 1; k <= 8; ++k) {
vars.push_back(categoric::make_variable(2, "v" + std::to_string(k)));
}

const float w = 1.f;
addTunableFactor(make_corr_expfactor2(vars[1], vars[2], w));
addTunableFactor(make_corr_expfactor2(vars[2], vars[4], w));
addTunableFactor(make_corr_expfactor2(vars[2], vars[3], w));
addTunableFactor(make_corr_expfactor2(vars[3], vars[4], w));
addTunableFactor(make_corr_expfactor2(vars[4], vars[5], w));
addTunableFactor(make_corr_expfactor2(vars[3], vars[5], w));
addTunableFactor(make_corr_expfactor2(vars[4], vars[6], w));
addTunableFactor(make_corr_expfactor2(vars[5], vars[7], w));
addTunableFactor(make_corr_expfactor2(vars[6], vars[7], w));
addTunableFactor(make_corr_expfactor2(vars[7], vars[8], w));
}
} // namespace EFG::test::library
36 changes: 36 additions & 0 deletions tests/ModelLibrary.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#pragma once

#include <EasyFactorGraph/model/RandomField.h>

namespace EFG::test::library {
class SimpleTree : public EFG::model::RandomField {
public:
SimpleTree();

static const float alfa;
static const float beta;
static const float gamma;
static const float eps;
};
static const SimpleTree SIMPLE_TREE = SimpleTree{};

class ComplexTree : public EFG::model::RandomField {
public:
ComplexTree();
};
static const ComplexTree COMPLEX_TREE = ComplexTree{};

class SimpleLoopy : public EFG::model::RandomField {
public:
SimpleLoopy();

static const float w;
};
static const SimpleLoopy SIMPLE_LOOPY = SimpleLoopy{};

class ComplexLoopy : public EFG::model::RandomField {
public:
ComplexLoopy();
};
static const ComplexLoopy COMPLEX_LOOPY = ComplexLoopy{};
} // namespace EFG::test::library
12 changes: 9 additions & 3 deletions tests/Test03-Factor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,21 @@ TEST_CASE("operations on factor", "[factor]") {
}

#include <EasyFactorGraph/io/FactorImporter.h>
#include <fstream>
#include <math.h>

TEST_CASE("import from file", "[factor]") {
// 2,2,4
const std::string &file_name = "factor_description";
{
std::ofstream stream(file_name);
stream << "0 1 1 2.0\n";
stream << "0 0 0 3.0\n";
stream << "1 1 3 2.5\n";
stream << "1 0 2 1.4";
} // 2,2,4
categoric::VariablesSoup vars = {make_variable(2, "A"), make_variable(2, "B"),
make_variable(4, "C")};
distribution::Factor factor_ABC(categoric::Group{vars});
const std::string file_name =
std::string(TEST_FOLDER) + std::string("FactorDescription");

io::import_values(factor_ABC, file_name);

Expand Down
53 changes: 27 additions & 26 deletions tests/Test07-BeliefPropagation.cpp
Original file line number Diff line number Diff line change
@@ -1,28 +1,26 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/generators/catch_generators.hpp>

#include "ModelLibrary.h"
#include "Utils.h"
#include <EasyFactorGraph/io/xml/Importer.h>
#include <EasyFactorGraph/model/Graph.h>

using namespace EFG;
using namespace EFG::model;
using namespace EFG::strct;
using namespace EFG::io;
using namespace EFG::categoric;
using namespace EFG::distribution;
using namespace EFG::test;
using namespace EFG::test::library;

namespace {
class TestModels : public Graph {
template <typename ModelT> class TestModels : public ModelT {
public:
TestModels(const std::string &file_name) {
xml::Importer::importFromFile(*this, file_name);
}
TestModels() = default;

// check all messages were computed after propagation
bool araAllMessagesComputed() const {
for (const auto &cluster : getState().clusters) {
bool areAllMessagesComputed() const {
for (const auto &cluster : this->getState().clusters) {
for (const auto *node : cluster.nodes) {
for (const auto &[connected_node, connection] :
node->active_connections) {
Expand All @@ -38,9 +36,9 @@ class TestModels : public Graph {
bool checkMarginals(const std::string &var_name,
const std::vector<float> &expected,
const float threshold = 0.01f) {
const auto var = findVariable(var_name);
const auto var = this->findVariable(var_name);
return almost_equal(*ProbDistribution{expected},
getMarginalDistribution(var_name), threshold);
this->getMarginalDistribution(var_name), threshold);
}
};

Expand Down Expand Up @@ -71,13 +69,16 @@ bool are_equal(const PropagationResult &a, const PropagationResult &b) {
}
} // namespace

TEST_CASE("simple poly tree", "[propagation]") {
TestModels model(make_graph_path("graph_1.xml"));

float a = expf(1.f), b = expf(2.f), g = expf(1.f), e = expf(1.5f);
TEST_CASE("simple poly tree belief propagation", "[propagation]") {
TestModels<SimpleTree> model;

REQUIRE_FALSE(model.hasPropagationResult());

const float a = expf(SimpleTree::alfa);
const float b = expf(SimpleTree::beta);
const float g = expf(SimpleTree::gamma);
const float e = expf(SimpleTree::eps);

// E=1
model.setEvidence(model.findVariable("E"), 1);
CHECK(model.checkMarginals(
Expand All @@ -92,7 +93,7 @@ TEST_CASE("simple poly tree", "[propagation]") {
std::vector<ClusterInfo>{ClusterInfo{true, 4}};
REQUIRE(are_equal(propagation_expected, propagation_result));
}
REQUIRE(model.araAllMessagesComputed());
REQUIRE(model.areAllMessagesComputed());
CHECK(model.checkMarginals("B", {(g + e), (1 + g * e)}));
CHECK(model.checkMarginals(
"C", {(b * (g + e) + (1 + g * e)), ((g + e) + b * (1 + g * e))}));
Expand All @@ -113,14 +114,14 @@ TEST_CASE("simple poly tree", "[propagation]") {
std::vector<ClusterInfo>{ClusterInfo{true, 3}, ClusterInfo{true, 1}};
REQUIRE(are_equal(propagation_expected, propagation_result));
}
REQUIRE(model.araAllMessagesComputed());
REQUIRE(model.areAllMessagesComputed());
CHECK(model.checkMarginals("B", {1.f, g}));
CHECK(model.checkMarginals("C", {b + g, 1.f + b * g}));
CHECK(model.checkMarginals("E", {1.f, e}));
}

TEST_CASE("complex poly tree", "[propagation]") {
TestModels model(make_graph_path("graph_2.xml"));
TEST_CASE("complex poly tree belief propagation", "[propagation]") {
TestModels<ComplexTree> model;
model.setEvidence(model.findVariable("v1"), 1);
model.setEvidence(model.findVariable("v2"), 1);
model.setEvidence(model.findVariable("v3"), 1);
Expand All @@ -140,13 +141,13 @@ TEST_CASE("complex poly tree", "[propagation]") {
CHECK(prob[0] < prob[1]);
}

CHECK(model.araAllMessagesComputed());
CHECK(model.areAllMessagesComputed());
}

TEST_CASE("simple loopy graph", "[propagation]") {
TestModels model(make_graph_path("graph_3.xml"));
TEST_CASE("simple loopy graph belief propagation", "[propagation]") {
TestModels<SimpleLoopy> model;

float M = expf(1.f);
float M = expf(SimpleLoopy::w);
float M_alfa = powf(M, 3) + M + 2.f * powf(M, 2);
float M_beta = powf(M, 4) + 2.f * M + powf(M, 2);

Expand All @@ -164,23 +165,23 @@ TEST_CASE("simple loopy graph", "[propagation]") {
std::vector<ClusterInfo>{ClusterInfo{false, 4}};
REQUIRE(are_equal(propagation_expected, propagation_result));
}
REQUIRE(model.araAllMessagesComputed());
REQUIRE(model.areAllMessagesComputed());
CHECK(model.checkMarginals("C", {M_alfa, M_beta}, 0.045f));
CHECK(model.checkMarginals("B", {M_alfa, M_beta}, 0.045f));
CHECK(model.checkMarginals("A", {M * M_alfa + M_beta, M_alfa + M * M_beta},
0.045f));
}

TEST_CASE("complex loopy graph", "[propagation]") {
TestModels model(make_graph_path("graph_4.xml"));
TEST_CASE("complex loopy graph belief propagation", "[propagation]") {
TestModels<ComplexLoopy> model;

model.setEvidence(model.findVariable("v1"), 1);

auto threads = GENERATE(1, 2, 4);

auto prob = model.getMarginalDistribution("v8", threads);
CHECK(prob[0] < prob[1]);
CHECK(model.araAllMessagesComputed());
CHECK(model.areAllMessagesComputed());
}

#include <sstream>
Expand Down
19 changes: 11 additions & 8 deletions tests/Test08-Gibbs.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/generators/catch_generators.hpp>

#include <EasyFactorGraph/io/xml/Importer.h>
#include <EasyFactorGraph/model/Graph.h>

#include "ModelLibrary.h"
#include "Utils.h"

#include <algorithm>
Expand All @@ -13,8 +13,8 @@ using namespace EFG::categoric;
using namespace EFG::distribution;
using namespace EFG::model;
using namespace EFG::strct;
using namespace EFG::io;
using namespace EFG::test;
using namespace EFG::test::library;

namespace {
bool are_samples_valid(const std::vector<Combination> &samples,
Expand Down Expand Up @@ -205,9 +205,12 @@ TEST_CASE("simple graph gibbs sampling", "[gibbs_sampling]") {
}

TEST_CASE("polyTree gibbs sampling", "[gibbs_sampling]") {
float a = expf(1.f), b = expf(2.f), g = expf(1.f), e = expf(1.5f);
Graph model;
xml::Importer::importFromFile(model, make_graph_path("graph_1.xml"));
RandomField model(SIMPLE_TREE);

const float a = expf(SimpleTree::alfa);
const float b = expf(SimpleTree::beta);
const float g = expf(SimpleTree::gamma);
const float e = expf(SimpleTree::eps);

auto threads = GENERATE(1, 2, 3);

Expand All @@ -234,11 +237,11 @@ TEST_CASE("polyTree gibbs sampling", "[gibbs_sampling]") {
}

TEST_CASE("loopy model gibbs sampling", "[gibbs_sampling]") {
float M = expf(1.f);
RandomField model(SIMPLE_LOOPY);

float M = expf(SimpleLoopy::w);
float M_alfa = powf(M, 3) + M + 2.f * powf(M, 2);
float M_beta = powf(M, 4) + 2.f * M + powf(M, 2);
Graph model;
xml::Importer::importFromFile(model, make_graph_path("graph_3.xml"));

auto threads = GENERATE(1, 2, 3);

Expand Down
1 change: 0 additions & 1 deletion tests/Test10-Learning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "Utils.h"
#include <EasyFactorGraph/categoric/GroupRange.h>
#include <EasyFactorGraph/distribution/CombinationFinder.h>
#include <EasyFactorGraph/io/xml/Importer.h>
#include <EasyFactorGraph/model/ConditionalRandomField.h>
#include <EasyFactorGraph/model/RandomField.h>
#include <EasyFactorGraph/structure/QueryManager.h>
Expand Down
6 changes: 0 additions & 6 deletions tests/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@ make_corr_expfactor2(const categoric::VariablePtr &first,
return std::make_shared<distribution::FactorExponential>(factor, w);
}

std::string make_graph_path(const std::string &file_name) {
std::stringstream stream;
stream << SAMPLE_FOLDER << "Sample03-BeliefPropagation-B/" << file_name;
return stream.str();
}

bool almost_equal(const float a, const float b, const float tollerance) {
return fabs(a - b) < tollerance;
}
Expand Down
2 changes: 0 additions & 2 deletions tests/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ std::shared_ptr<distribution::FactorExponential>
make_corr_expfactor2(const categoric::VariablePtr &first,
const categoric::VariablePtr &second, const float w);

std::string make_graph_path(const std::string &file_name);

bool almost_equal(const float a, const float b, const float tollerance);

bool almost_equal(const std::vector<float> &a, const std::vector<float> &b,
Expand Down

0 comments on commit 7571840

Please sign in to comment.