Skip to content

Commit

Permalink
minibmg: precompute and optimize expression trees to be used during i…
Browse files Browse the repository at this point in the history
…nference (facebookresearch#1793)

Summary:
Pull Request resolved: facebookresearch#1793

Precompute the expression trees needed to evaluate the graph for NUTS, which avoids all use of automatic differentiation (AD) during inference.

This makes minibmg slightly faster than bmg.

Reviewed By: rodrigodesalvobraz

Differential Revision: D40813973

fbshipit-source-id: 335b1789cef73f742bbac996e82884e7f111c0a3
  • Loading branch information
Neal Gafter authored and facebook-github-bot committed Nov 5, 2022
1 parent f07bc3f commit 3970cf1
Show file tree
Hide file tree
Showing 16 changed files with 438 additions and 92 deletions.
1 change: 1 addition & 0 deletions minibmg/ad/reverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#pragma once

#include <vector>
#include "beanmachine/minibmg/ad/num2.h"
#include "beanmachine/minibmg/ad/number.h"
#include "beanmachine/minibmg/topological.h"
Expand Down
2 changes: 1 addition & 1 deletion minibmg/ad/traced.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Traced {
/* implicit */ inline Traced(Real value)
: node{std::make_shared<ScalarConstantNode>(value.as_double())} {}

static Traced variable(const std::string& name, const unsigned identifier) {
static Traced variable(const std::string& name, const int identifier) {
return Traced{std::make_shared<ScalarVariableNode>(name, identifier)};
}

Expand Down
43 changes: 43 additions & 0 deletions minibmg/eval.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "beanmachine/minibmg/eval.h"
#include <stdexcept>
#include <utility>

namespace beanmachine::minibmg {

// Builds an evaluator that resolves variable nodes through `read_variable`.
//
// `read_variable` is invoked with a variable node's name and identifier and
// returns the value to use for that variable.  It is a sink parameter: it is
// taken by value and then moved into the member, so the std::function (and
// whatever state its target captured) is not copied a second time.
RecursiveNodeEvaluatorVisitor::RecursiveNodeEvaluatorVisitor(
    std::function<double(const std::string& name, const int identifier)>
        read_variable)
    // Inside the initializer the name `read_variable` refers to the parameter
    // (it shadows the member), so this moves parameter -> member.
    : read_variable{std::move(read_variable)} {}

// Evaluates a variable node: asks the read_variable callback for the value
// associated with the node's name and identifier, and stores the answer in
// `result`.
void RecursiveNodeEvaluatorVisitor::visit(const ScalarVariableNode* node) {
  const double looked_up = read_variable(node->name, node->identifier);
  result = looked_up;
}

// Sample nodes are not supported by this evaluator.  Reaching one is treated
// as a programming error and reported by throwing std::logic_error.
void RecursiveNodeEvaluatorVisitor::visit(const ScalarSampleNode* /*node*/) {
  throw std::logic_error{"recursive evaluator may not sample"};
}

// Produces the value of an input node by evaluating that node directly with
// this same visitor (i.e. by recursive descent).  Nothing is cached here.
Real RecursiveNodeEvaluatorVisitor::evaluate_input(const ScalarNodep& node) {
  return this->evaluate_scalar(node);
}

// Distribution-valued inputs are not supported by this evaluator.  Always
// throws std::logic_error; it never returns a distribution.
std::shared_ptr<const Distribution<Real>>
RecursiveNodeEvaluatorVisitor::evaluate_input_distribution(
    const DistributionNodep& /*node*/) {
  throw std::logic_error{
      "recursive evaluator may not traffic in distributions"};
}

// Evaluates `node` with the supplied evaluator and returns the resulting
// Real's underlying `value` as a double.
double eval_node(
    RecursiveNodeEvaluatorVisitor& evaluator,
    const ScalarNodep& node) {
  const Real evaluated = evaluator.evaluate_scalar(node);
  return evaluated.value;
}

} // namespace beanmachine::minibmg
61 changes: 43 additions & 18 deletions minibmg/eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@ struct SampledValue {
N log_prob;
};

} // namespace beanmachine::minibmg

namespace {

using namespace beanmachine::minibmg;

// A visitor that evaluates a single node. This method does not implement a
// particular policy for providing values for the inputs to the node being
// evaluated; the programmer must inherit from this class and implement several
Expand Down Expand Up @@ -79,7 +73,7 @@ class NodeEvaluatorVisitor : public NodeVisitor {
const DistributionNodep& node) = 0;

N result;
N evaluate_scalar(ScalarNodep& node) {
N evaluate_scalar(const ScalarNodep& node) {
node->accept(*this);
return result;
}
Expand Down Expand Up @@ -171,8 +165,7 @@ class NodeEvaluatorVisitor : public NodeVisitor {
template <class N>
requires Number<N>
class OneNodeAtATimeEvaluatorVisitor : public NodeEvaluatorVisitor<N> {
std::function<N(const std::string& name, const unsigned identifier)>
read_variable;
std::function<N(const std::string& name, const int identifier)> read_variable;
std::unordered_map<const Node*, double> observations;
N& log_prob;
std::unordered_map<Nodep, N>& data;
Expand All @@ -196,7 +189,7 @@ class OneNodeAtATimeEvaluatorVisitor : public NodeEvaluatorVisitor<N> {
public:
OneNodeAtATimeEvaluatorVisitor(
const Graph& graph,
std::function<N(const std::string& name, const unsigned identifier)>
std::function<N(const std::string& name, const int identifier)>
read_variable,
std::unordered_map<Nodep, N>& data,
std::unordered_map<Nodep, std::shared_ptr<const Distribution<N>>>&
Expand Down Expand Up @@ -245,18 +238,14 @@ class OneNodeAtATimeEvaluatorVisitor : public NodeEvaluatorVisitor<N> {
}
};

} // namespace

namespace beanmachine::minibmg {

template <class N>
requires Number<N>
struct EvalResult {
// The log probability of the overall computation.
N log_prob;

// The value of the queries.
std::vector<double> queries;
std::vector<N> queries;
};

template <class N>
Expand Down Expand Up @@ -290,7 +279,7 @@ template <class N>
requires Number<N> EvalResult<N> eval_graph(
const Graph& graph,
std::mt19937& gen,
std::function<N(const std::string& name, const unsigned identifier)>
std::function<N(const std::string& name, const int identifier)>
read_variable,
std::unordered_map<Nodep, N>& data,
bool run_queries = false,
Expand Down Expand Up @@ -330,11 +319,11 @@ requires Number<N> EvalResult<N> eval_graph(
}
}

std::vector<double> queries;
std::vector<N> queries;
if (run_queries) {
for (const auto& q : graph.queries) {
auto d = data.find(q);
double value = (d == data.end()) ? 0 : d->second.as_double();
N value = (d == data.end()) ? 0 : d->second;
queries.push_back(value);
}
}
Expand All @@ -358,4 +347,40 @@ class NodeRewriteAdapter<EvalResult<Underlying>> {
}
};

// A NodeEvaluatorVisitor over Real that evaluates the inputs of a node by
// evaluating those input nodes recursively with the same visitor.  Variable
// values come from a callback supplied at construction.  Sample nodes and
// distribution-valued inputs are not supported: the corresponding overrides
// throw std::logic_error.
class RecursiveNodeEvaluatorVisitor : public NodeEvaluatorVisitor<Real> {
private:
// Callback returning the value of a variable, given the variable node's name
// and identifier.
std::function<double(const std::string& name, const int identifier)>
read_variable;

public:
// Creates an evaluator that reads variable values through read_variable.
explicit RecursiveNodeEvaluatorVisitor(
std::function<double(const std::string& name, const int identifier)>
read_variable);

private:
// Sets the result to read_variable(node->name, node->identifier).
void visit(const ScalarVariableNode* node) override;

// Not supported: this evaluator never samples.  Throws std::logic_error.
void visit(const ScalarSampleNode* node) override;

// Evaluates an input node by recursive descent, i.e. by calling
// evaluate_scalar on it.  No per-node values are stored by this override.
Real evaluate_input(const ScalarNodep& node) override;

// Not supported: this evaluator does not handle distributions.  Throws
// std::logic_error.
std::shared_ptr<const Distribution<Real>> evaluate_input_distribution(
const DistributionNodep& node) override;
};

// Evaluate a single scalar node by recursive descent, using the given
// evaluator, and return its value as a double.  This works best if the node
// is a tree, rather than a directed acyclic graph with shared values
// (NOTE(review): no caching of input values is visible in the recursive
// evaluator, so shared subexpressions are presumably re-evaluated each time
// they are reached).  This cannot sample from a distribution or compute
// log_prob values unless that computation is already inlined into the node's
// tree; the evaluator throws std::logic_error when it visits a sample node or
// is asked for a distribution input.
double eval_node(
RecursiveNodeEvaluatorVisitor& evaluator,
const ScalarNodep& node);

} // namespace beanmachine::minibmg
2 changes: 1 addition & 1 deletion minibmg/graph_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ ScalarNodeId Graph::Factory::constant(double value) {
}
ScalarNodeId Graph::Factory::variable(
const std::string& name,
const unsigned identifier) {
const int identifier) {
ScalarNodep result = std::make_shared<ScalarVariableNode>(name, identifier);
return add_node(result);
}
Expand Down
2 changes: 1 addition & 1 deletion minibmg/graph_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ namespace beanmachine::minibmg {
class Graph::Factory {
public:
ScalarNodeId constant(double value);
ScalarNodeId variable(const std::string& name, const unsigned identifier);
ScalarNodeId variable(const std::string& name, const int identifier);
ScalarSampleNodeId sample(
DistributionNodeId distribution,
const std::string& rvid = make_fresh_rvid());
Expand Down
30 changes: 26 additions & 4 deletions minibmg/inference/global_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,28 @@

namespace beanmachine::minibmg {

MinibmgGlobalState::MinibmgGlobalState(beanmachine::minibmg::Graph& graph)
: graph{graph}, world{hmc_world_0(graph)} {
// Factory: builds a global state for `graph` whose world comes from
// hmc_world_0 (described in the header as brute-force evaluation over the
// graph).
std::unique_ptr<MinibmgGlobalState> MinibmgGlobalState::create0(
    const beanmachine::minibmg::Graph& graph) {
  std::unique_ptr<const HMCWorld> world = hmc_world_0(graph);
  // The constructor is private, so the object is created with `new` here and
  // handed straight to the owning unique_ptr.
  return std::unique_ptr<MinibmgGlobalState>(
      new MinibmgGlobalState(graph, std::move(world)));
}

// Factory: builds a global state for `graph` whose world comes from
// hmc_world_1 (described in the header as compiling the model to an
// expression tree and evaluating by interpreting that tree).
std::unique_ptr<MinibmgGlobalState> MinibmgGlobalState::create1(
    const beanmachine::minibmg::Graph& graph) {
  std::unique_ptr<const HMCWorld> world = hmc_world_1(graph);
  // The constructor is private, so the object is created with `new` here and
  // handed straight to the owning unique_ptr.
  return std::unique_ptr<MinibmgGlobalState>(
      new MinibmgGlobalState(graph, std::move(world)));
}

// Factory: builds a global state for `graph` whose world comes from
// hmc_world_2 (described in the header as compiling the model to an
// expression tree, generating code from it, and running the generated code).
std::unique_ptr<MinibmgGlobalState> MinibmgGlobalState::create2(
    const beanmachine::minibmg::Graph& graph) {
  std::unique_ptr<const HMCWorld> world = hmc_world_2(graph);
  // The constructor is private, so the object is created with `new` here and
  // handed straight to the owning unique_ptr.
  return std::unique_ptr<MinibmgGlobalState>(
      new MinibmgGlobalState(graph, std::move(world)));
}

MinibmgGlobalState::MinibmgGlobalState(
const beanmachine::minibmg::Graph& graph,
std::unique_ptr<const HMCWorld> world)
: graph{graph}, world{std::move(world)} {
samples.clear();
// Since we only support scalars, we count the unobserved samples by ones.
int num_unobserved_samples = -graph.observations.size();
Expand Down Expand Up @@ -47,6 +67,7 @@ void MinibmgGlobalState::initialize_values(
samples.push_back(result.unconstrained.as_double());
return result;
};
std::unordered_map<Nodep, Real> real_eval_data;
auto eval_result = eval_graph<Real>(
graph,
gen,
Expand Down Expand Up @@ -139,11 +160,12 @@ void MinibmgGlobalState::update_log_prob() {
}

void MinibmgGlobalState::update_backgrad() {
unconstrained_grads = world->gradients(this->unconstrained_values);
world->gradients(this->unconstrained_values, unconstrained_grads);
}

void MinibmgGlobalState::collect_sample() {
auto queries = world->queries(this->unconstrained_values);
std::vector<double> queries;
world->queries(this->unconstrained_values, queries);
std::vector<beanmachine::graph::NodeValue> compat_query;
for (auto v : queries) {
compat_query.emplace_back(v);
Expand Down
24 changes: 19 additions & 5 deletions minibmg/inference/global_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#pragma once

#include <memory>
#include "beanmachine/graph/global/global_state.h"
#include "beanmachine/graph/graph.h"
#include "beanmachine/minibmg/ad/real.h"
Expand All @@ -22,7 +23,20 @@ namespace beanmachine::minibmg {
// needed to use the NUTS api from bmg.
class MinibmgGlobalState : public beanmachine::graph::GlobalState {
public:
explicit MinibmgGlobalState(beanmachine::minibmg::Graph& graph);
// Create a global state that uses brute-force evaluation over the graph
static std::unique_ptr<MinibmgGlobalState> create0(
const beanmachine::minibmg::Graph& graph);

// Create a global state that first compiles the model to an expression tree
// and evaluates by interpreting that tree.
static std::unique_ptr<MinibmgGlobalState> create1(
const beanmachine::minibmg::Graph& graph);

// Create a global state that first compiles the model to an expression tree,
// generates code from that tree, and evaluates by running the generated code.
static std::unique_ptr<MinibmgGlobalState> create2(
const beanmachine::minibmg::Graph& graph);

void initialize_values(beanmachine::graph::InitType init_type, uint seed)
override;
void backup_unconstrained_values() override;
Expand All @@ -48,6 +62,10 @@ class MinibmgGlobalState : public beanmachine::graph::GlobalState {
void clear_samples() override;

private:
explicit MinibmgGlobalState(
const beanmachine::minibmg::Graph& graph,
std::unique_ptr<const HMCWorld> world);

const beanmachine::minibmg::Graph& graph;
const std::unique_ptr<const HMCWorld> world;
std::vector<std::vector<beanmachine::graph::NodeValue>> samples;
Expand All @@ -57,10 +75,6 @@ class MinibmgGlobalState : public beanmachine::graph::GlobalState {
std::vector<double> unconstrained_grads;
std::vector<double> saved_unconstrained_values;
std::vector<double> saved_unconstrained_grads;

// scratchpads for evaluation
std::unordered_map<Nodep, Reverse<Real>> reverse_eval_data;
std::unordered_map<Nodep, Real> real_eval_data;
};

} // namespace beanmachine::minibmg
Loading

0 comments on commit 3970cf1

Please sign in to comment.