[autodiff] Allocate dual and adjoint snode (taichi-dev#5083)
* allocate dual and decouple grad and adjoint

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* update

* update the adjoint name

* fix matrix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* recover the grad name

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
erizmr and pre-commit-ci[bot] authored Jun 2, 2022
1 parent 4c42fc9 commit a645b99
Showing 15 changed files with 126 additions and 59 deletions.
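Editorial note (not part of the commit): the change splits the old single "grad" slot into an adjoint (reverse mode) and a dual (forward mode) slot, while the public `.grad` handle keeps its name and is simply bound to the adjoint. A minimal sketch of the resulting Python-level behavior, using only names that appear in the diffs below (`adjoint`, `dual`, `has_adjoint`):

```python
import taichi as ti

ti.init(arch=ti.cpu)

x = ti.field(dtype=ti.f32, shape=8, needs_grad=True)

# ".grad" is kept for backward compatibility; after this commit it is an
# alias that Field._set_grad(..., reverse_mode=True) binds to the adjoint.
assert x.grad is x.adjoint
# No public API allocates the forward-mode dual yet, so it stays unbound.
assert x.dual is None
# The SNode side now reports the adjoint explicitly (exported in export_lang.cpp).
assert x.snode.ptr.has_adjoint()
```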
30 changes: 18 additions & 12 deletions cpp_examples/autograd.cpp
@@ -55,7 +55,10 @@ void autograd() {
     bool is_primal() const override {
       return true;
     }
-    SNode *grad_snode() const override {
+    SNode *adjoint_snode() const override {
       return snode;
     }
+    SNode *dual_snode() const override {
+      return snode;
+    }
   };
@@ -66,7 +69,10 @@ void autograd() {
     bool is_primal() const override {
       return false;
     }
-    SNode *grad_snode() const override {
+    SNode *adjoint_snode() const override {
       return nullptr;
     }
+    SNode *dual_snode() const override {
+      return nullptr;
+    }
   };
@@ -76,8 +82,8 @@ void autograd() {
     snode->dt = PrimitiveType::f32;
     snode->grad_info = std::make_unique<GradInfoPrimal>(
        &root->dense(Axis(0), n, false).insert_children(SNodeType::place));
-    snode->get_grad()->dt = PrimitiveType::f32;
-    snode->get_grad()->grad_info = std::make_unique<GradInfoAdjoint>();
+    snode->get_adjoint()->dt = PrimitiveType::f32;
+    snode->get_adjoint()->grad_info = std::make_unique<GradInfoAdjoint>();
     return snode;
   };
   auto *a = get_snode_grad(), *b = get_snode_grad(), *c = get_snode_grad();
@@ -100,12 +106,12 @@
                                 builder.create_add(i, one));
     builder.create_global_store(builder.create_global_ptr(c, {i}), zero);
 
-    builder.create_global_store(builder.create_global_ptr(a->get_grad(), {i}),
-                                zero);
-    builder.create_global_store(builder.create_global_ptr(b->get_grad(), {i}),
-                                zero);
-    builder.create_global_store(builder.create_global_ptr(c->get_grad(), {i}),
-                                one);
+    builder.create_global_store(
+        builder.create_global_ptr(a->get_adjoint(), {i}), zero);
+    builder.create_global_store(
+        builder.create_global_ptr(b->get_adjoint(), {i}), zero);
+    builder.create_global_store(
+        builder.create_global_ptr(c->get_adjoint(), {i}), one);
   }
 
   kernel_init =
@@ -141,13 +147,13 @@
     auto *ext_a = builder.create_external_ptr(
         builder.create_arg_load(0, PrimitiveType::f32, true), {i});
     auto *a_grad_i = builder.create_global_load(
-        builder.create_global_ptr(a->get_grad(), {i}));
+        builder.create_global_ptr(a->get_adjoint(), {i}));
     builder.create_global_store(ext_a, a_grad_i);
 
     auto *ext_b = builder.create_external_ptr(
         builder.create_arg_load(1, PrimitiveType::f32, true), {i});
     auto *b_grad_i = builder.create_global_load(
-        builder.create_global_ptr(b->get_grad(), {i}));
+        builder.create_global_ptr(b->get_adjoint(), {i}));
     builder.create_global_store(ext_b, b_grad_i);
 
     auto *ext_c = builder.create_external_ptr(
30 changes: 27 additions & 3 deletions python/taichi/lang/field.py
@@ -19,6 +19,8 @@ def __init__(self, _vars):
         self.vars = _vars
         self.host_accessors = None
         self.grad = None
+        self.adjoint = None
+        self.dual = None
 
     @property
     def snode(self):
@@ -92,13 +94,35 @@ def _loop_range(self):
         """
         return self.vars[0].ptr
 
-    def _set_grad(self, grad):
-        """Sets corresponding gradient field.
+    def _set_grad(self, grad, reverse_mode=True):
+        """Binds corresponding gradient field to adjoint or dual.
         Args:
             grad (Field): Corresponding gradient field.
+            reverse_mode (Bool): set for reverse or forward mode
         """
-        self.grad = grad
+        if reverse_mode:
+            self._set_adjoint(grad)
+            self.grad = self.adjoint
+        else:
+            self._set_dual(grad)
+            self.grad = self.dual
+
+    def _set_adjoint(self, adjoint):
+        """Sets corresponding adjoint field (reverse mode).
+        Args:
+            adjoint (Field): Corresponding adjoint field.
+        """
+        self.adjoint = adjoint
+
+    def _set_dual(self, dual):
+        """Sets corresponding dual field (forward mode).
+        Args:
+            dual (Field): Corresponding dual field.
+        """
+        self.dual = dual
 
     @python_scope
     def fill(self, val):
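A purely illustrative sketch of the new binding helpers (nothing in this commit calls `_set_grad` with `reverse_mode=False` yet; the `x_dual` field below is a hypothetical stand-in): the method now routes the handle either to the adjoint slot or to the dual slot.

```python
import taichi as ti

ti.init(arch=ti.cpu)

x = ti.field(ti.f32, shape=4)
x_dual = ti.field(ti.f32, shape=4)  # hypothetical storage for a forward-mode dual

# reverse_mode=False takes the new forward-mode branch: the handle is stored
# in .dual and .grad is re-pointed at it (internal API, shown for illustration).
x._set_grad(x_dual, reverse_mode=False)
assert x.dual is x_dual
assert x.grad is x.dual
```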
22 changes: 11 additions & 11 deletions python/taichi/lang/impl.py
@@ -496,16 +496,16 @@ def create_field_member(dtype, name):
     x.ptr.set_is_primal(True)
     pytaichi.global_vars.append(x)
 
-    x_grad = None
+    x_adjoint = None
     if _ti_core.needs_grad(dtype):
         # adjoint
-        x_grad = Expr(get_runtime().prog.make_id_expr(""))
-        x_grad.ptr = _ti_core.global_new(x_grad.ptr, dtype)
-        x_grad.ptr.set_name(name + ".grad")
-        x_grad.ptr.set_is_primal(False)
-        x.ptr.set_grad(x_grad.ptr)
+        x_adjoint = Expr(get_runtime().prog.make_id_expr(""))
+        x_adjoint.ptr = _ti_core.global_new(x_adjoint.ptr, dtype)
+        x_adjoint.ptr.set_name(name + ".grad")
+        x_adjoint.ptr.set_is_primal(False)
+        x.ptr.set_adjoint(x_adjoint.ptr)
 
-    return x, x_grad
+    return x, x_adjoint
 
 
 @python_scope
@@ -552,15 +552,15 @@ def field(dtype, shape=None, name="", offset=None, needs_grad=False):
     assert (offset is None or shape
             is not None), 'The shape cannot be None when offset is being set'
 
-    x, x_grad = create_field_member(dtype, name)
-    x, x_grad = ScalarField(x), ScalarField(x_grad)
-    x._set_grad(x_grad)
+    x, x_adjoint = create_field_member(dtype, name)
+    x, x_adjoint = ScalarField(x), ScalarField(x_adjoint)
+    x._set_grad(x_adjoint, reverse_mode=True)
 
     if shape is not None:
         dim = len(shape)
         root.dense(index_nd(dim), shape).place(x, offset=offset)
         if needs_grad:
-            root.dense(index_nd(dim), shape).place(x_grad)
+            root.dense(index_nd(dim), shape).place(x_adjoint)
     return x
 
 
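The renaming is internal; the established reverse-mode workflow is unchanged. A hedged end-to-end sketch (not part of the diff) that exercises the adjoint storage allocated by `needs_grad=True`, assuming the usual `kernel.grad()` entry point:

```python
import taichi as ti

ti.init(arch=ti.cpu)

x = ti.field(ti.f32, shape=4, needs_grad=True)
loss = ti.field(ti.f32, shape=(), needs_grad=True)

@ti.kernel
def compute():
    for i in x:
        loss[None] += 3.0 * x[i]

x.fill(2.0)
compute()                 # forward pass
loss.grad[None] = 1.0     # seed the adjoint of the output
compute.grad()            # run the generated adjoint kernel
print(x.grad.to_numpy())  # expected: [3. 3. 3. 3.]
```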
13 changes: 7 additions & 6 deletions python/taichi/lang/matrix.py
@@ -1117,10 +1117,10 @@ def field(cls,
         else:
             for _ in range(n * m):
                 entries.append(impl.create_field_member(dtype, name=name))
-        entries, entries_grad = zip(*entries)
-        entries, entries_grad = MatrixField(entries, n, m), MatrixField(
-            entries_grad, n, m)
-        entries._set_grad(entries_grad)
+        entries, entries_adjoint = zip(*entries)
+        entries, entries_adjoint = MatrixField(entries, n, m), MatrixField(
+            entries_adjoint, n, m)
+        entries._set_grad(entries_adjoint, reverse_mode=True)
         impl.get_runtime().matrix_fields.append(entries)
 
         if shape is None:
@@ -1143,7 +1143,7 @@ def field(cls,
                 impl.root.dense(impl.index_nd(dim),
                                 shape).place(ScalarField(e), offset=offset)
             if needs_grad:
-                for e in entries_grad._get_field_members():
+                for e in entries_adjoint._get_field_members():
                     impl.root.dense(impl.index_nd(dim),
                                     shape).place(ScalarField(e),
                                                  offset=offset)
@@ -1152,7 +1152,8 @@ def field(cls,
                                              offset=offset)
             if needs_grad:
                 impl.root.dense(impl.index_nd(dim),
-                                shape).place(entries_grad, offset=offset)
+                                shape).place(entries_adjoint,
+                                             offset=offset)
         return entries
 
     @classmethod
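Matrix and vector fields take the same path: each component gets its own adjoint member, and the `MatrixField` handle is bound through `_set_grad(..., reverse_mode=True)`. A small usage sketch (editorial, not from the diff):

```python
import taichi as ti

ti.init(arch=ti.cpu)

pos = ti.Vector.field(3, dtype=ti.f32, shape=128, needs_grad=True)
loss = ti.field(ti.f32, shape=(), needs_grad=True)

@ti.kernel
def energy():
    for i in pos:
        loss[None] += pos[i].dot(pos[i])

pos.fill(1.0)
with ti.Tape(loss=loss):
    energy()
print(pos.grad[0])  # adjoint of pos[0]; expected [2. 2. 2.] for this quadratic energy
```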
2 changes: 1 addition & 1 deletion python/taichi/lang/misc.py
@@ -692,7 +692,7 @@ def Tape(loss, clear_gradients=True):
     if len(loss.shape) != 0:
         raise RuntimeError(
             'The loss of `Tape` must be a 0-D field, i.e. scalar')
-    if not loss.snode.ptr.has_grad():
+    if not loss.snode.ptr.has_adjoint():
         raise RuntimeError(
             'Gradients of loss are not allocated, please use ti.field(..., needs_grad=True)'
             ' for all fields that are required by autodiff.')
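The only user-visible contact point is this check: `ti.Tape` now asks the loss SNode for an adjoint instead of a generic grad. A sketch of what the guard accepts and rejects (editorial example; the error text is taken from the code above):

```python
import taichi as ti

ti.init(arch=ti.cpu)

good_loss = ti.field(ti.f32, shape=(), needs_grad=True)
bad_loss = ti.field(ti.f32, shape=())  # no needs_grad -> no adjoint SNode is placed

assert good_loss.snode.ptr.has_adjoint()

try:
    with ti.Tape(loss=bad_loss):
        pass
except RuntimeError as e:
    # "Gradients of loss are not allocated, please use ti.field(..., needs_grad=True) ..."
    print(e)
```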
1 change: 1 addition & 0 deletions taichi/analysis/gen_offline_cache_key.cpp
@@ -136,6 +136,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
     emit(expr->ambient_value);
     emit(expr->is_primal);
     emit(expr->adjoint);
+    emit(expr->dual);
   }
 
   void visit(GlobalPtrExpression *expr) override {
7 changes: 5 additions & 2 deletions taichi/analysis/offline_cache_util.cpp
@@ -112,8 +112,11 @@ static void get_offline_cache_key_of_snode_impl(
     serializer(snode->ambient_val.stringify());
   }
   if (snode->grad_info && !snode->grad_info->is_primal()) {
-    if (auto *grad_snode = snode->grad_info->grad_snode()) {
-      get_offline_cache_key_of_snode_impl(grad_snode, serializer, visited);
+    if (auto *adjoint_snode = snode->grad_info->adjoint_snode()) {
+      get_offline_cache_key_of_snode_impl(adjoint_snode, serializer, visited);
     }
+    if (auto *dual_snode = snode->grad_info->dual_snode()) {
+      get_offline_cache_key_of_snode_impl(dual_snode, serializer, visited);
+    }
   }
   if (snode->exp_snode) {
6 changes: 5 additions & 1 deletion taichi/ir/expr.cpp
@@ -50,10 +50,14 @@ SNode *Expr::snode() const {
   return cast<GlobalVariableExpression>()->snode;
 }
 
-void Expr::set_grad(const Expr &o) {
+void Expr::set_adjoint(const Expr &o) {
   this->cast<GlobalVariableExpression>()->adjoint.set(o);
 }
 
+void Expr::set_dual(const Expr &o) {
+  this->cast<GlobalVariableExpression>()->dual.set(o);
+}
+
 Expr::Expr(int16 x) : Expr() {
   expr = std::make_shared<ConstExpression>(PrimitiveType::i16, x);
 }
4 changes: 3 additions & 1 deletion taichi/ir/expr.h
@@ -93,7 +93,9 @@ class Expr {
   // traceback for type checking error message
   void set_tb(const std::string &tb);
 
-  void set_grad(const Expr &o);
+  void set_adjoint(const Expr &o);
+
+  void set_dual(const Expr &o);
 
   void set_attribute(const std::string &key, const std::string &value);
 
1 change: 1 addition & 0 deletions taichi/ir/frontend_ir.h
@@ -439,6 +439,7 @@ class GlobalVariableExpression : public Expression {
   TypedConstant ambient_value;
   bool is_primal{true};
   Expr adjoint;
+  Expr dual;
 
   GlobalVariableExpression(DataType dt, const Identifier &ident)
       : ident(ident), dt(dt) {
19 changes: 14 additions & 5 deletions taichi/ir/snode.cpp
@@ -304,13 +304,22 @@ bool SNode::is_primal() const {
   return grad_info->is_primal();
 }
 
-bool SNode::has_grad() const {
-  return is_primal() && (grad_info->grad_snode() != nullptr);
+bool SNode::has_adjoint() const {
+  return is_primal() && (grad_info->adjoint_snode() != nullptr);
 }
 
-SNode *SNode::get_grad() const {
-  TI_ASSERT(has_grad());
-  return grad_info->grad_snode();
+bool SNode::has_dual() const {
+  return is_primal() && (grad_info->dual_snode() != nullptr);
 }
+
+SNode *SNode::get_adjoint() const {
+  TI_ASSERT(has_adjoint());
+  return grad_info->adjoint_snode();
+}
+
+SNode *SNode::get_dual() const {
+  TI_ASSERT(has_dual());
+  return grad_info->dual_snode();
+}
 
 void SNode::set_snode_tree_id(int id) {
11 changes: 8 additions & 3 deletions taichi/ir/snode.h
@@ -93,7 +93,8 @@ class SNode {
   public:
    virtual ~GradInfoProvider() = default;
    virtual bool is_primal() const = 0;
-   virtual SNode *grad_snode() const = 0;
+   virtual SNode *adjoint_snode() const = 0;
+   virtual SNode *dual_snode() const = 0;
 
    template <typename T>
    T *cast() {
@@ -286,9 +287,13 @@
 
   bool is_scalar() const;
 
-  bool has_grad() const;
+  bool has_adjoint() const;
 
-  SNode *get_grad() const;
+  SNode *get_adjoint() const;
+
+  bool has_dual() const;
+
+  SNode *get_dual() const;
 
   SNode *get_least_sparse_ancestor() const;
 
13 changes: 11 additions & 2 deletions taichi/program/snode_expr_utils.cpp
@@ -16,14 +16,22 @@ class GradInfoImpl final : public SNode::GradInfoProvider {
     return glb_var_->is_primal;
   }
 
-  SNode *grad_snode() const override {
+  SNode *adjoint_snode() const override {
     auto &adj = glb_var_->adjoint;
     if (adj.expr == nullptr) {
       return nullptr;
     }
     return adj.snode();
   }
 
+  SNode *dual_snode() const override {
+    auto &dual = glb_var_->dual;
+    if (dual.expr == nullptr) {
+      return nullptr;
+    }
+    return dual.snode();
+  }
+
  private:
   GlobalVariableExpression *glb_var_;
 };
@@ -102,8 +110,9 @@ void make_lazy_grad(SNode *snode, SNodeGlobalVarExprMap *snode_to_exprs) {
   }
   std::vector<Expr> new_grads;
   for (auto &c : snode->ch) {
+    // TODO: handle the dual SNode
     if (c->type == SNodeType::place && c->is_primal() && needs_grad(c->dt) &&
-        !c->has_grad()) {
+        !c->has_adjoint()) {
       new_grads.push_back(snode_to_exprs->at(c.get())->adjoint);
     }
   }
6 changes: 4 additions & 2 deletions taichi/python/export_lang.cpp
@@ -512,7 +512,8 @@ void export_lang(py::module &m) {
       .def("read_int", &SNode::read_int)
       .def("read_uint", &SNode::read_uint)
       .def("read_float", &SNode::read_float)
-      .def("has_grad", &SNode::has_grad)
+      .def("has_adjoint", &SNode::has_adjoint)
+      .def("has_dual", &SNode::has_dual)
       .def("is_primal", &SNode::is_primal)
      .def("is_place", &SNode::is_place)
      .def("get_expr", &SNode::get_expr)
@@ -662,7 +663,8 @@ void export_lang(py::module &m) {
           [&](Expr *expr, bool v) {
             expr->cast<GlobalVariableExpression>()->is_primal = v;
           })
-      .def("set_grad", &Expr::set_grad)
+      .def("set_adjoint", &Expr::set_adjoint)
+      .def("set_dual", &Expr::set_dual)
       .def("set_attribute", &Expr::set_attribute)
       .def("get_ret_type", &Expr::get_ret_type)
       .def("type_check", &Expr::type_check)