[autodiff] Add gradient visited for global data access rule checker
ghstack-source-id: 2e48ef445fe758d22789d0acd2d56d4823cc9364
Pull Request resolved: taichi-dev#5569
erizmr authored and Ailing Zhang committed Aug 6, 2022
1 parent aefa42b commit 15448ca
Showing 16 changed files with 205 additions and 32 deletions.
6 changes: 6 additions & 0 deletions cpp_examples/autograd.cpp
@@ -61,6 +61,9 @@ void autograd() {
SNode *dual_snode() const override {
return snode;
}
SNode *adjoint_visited_snode() const override {
return nullptr;
}
};
class GradInfoAdjoint final : public SNode::GradInfoProvider {
public:
@@ -75,6 +78,9 @@ void autograd() {
SNode *dual_snode() const override {
return nullptr;
}
SNode *adjoint_visited_snode() const override {
return nullptr;
}
};

auto *snode =
6 changes: 6 additions & 0 deletions python/taichi/_snode/fields_builder.py
@@ -120,6 +120,12 @@ def lazy_grad(self):
self.empty = False
self.root.lazy_grad()

def _allocate_grad_visited(self):
"""Same as :func:`taichi.lang.snode.SNode._allocate_grad_visited`"""
self._check_not_finalized()
self.empty = False
self.root._allocate_grad_visited()

def lazy_dual(self):
"""Same as :func:`taichi.lang.snode.SNode.lazy_dual`"""
# TODO: This complicates the implementation. Figure out why we need this
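For orientation, the public sibling of this internal helper is lazy_grad; a minimal FieldsBuilder sketch follows (the field name is illustrative; _allocate_grad_visited itself is called by the runtime in debug mode, not by user code):

import taichi as ti

ti.init()

fb = ti.FieldsBuilder()
y = ti.field(dtype=ti.f32)
fb.dense(ti.i, 8).place(y)
fb.lazy_grad()        # places y.grad following y's layout
tree = fb.finalize()  # placement helpers must run before finalize()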
9 changes: 8 additions & 1 deletion python/taichi/ad/_ad.py
@@ -3,6 +3,7 @@
This module supplies two decorators for users to customize their
gradient computation task.
"""
import warnings
from functools import reduce

from taichi.lang import impl
@@ -13,7 +14,7 @@


class Tape:
def __init__(self, loss=None, clear_gradients=True):
def __init__(self, loss=None, clear_gradients=True, validation=False):
"""A context manager for reverse mode autodiff :class:`~taichi.ad.Tape`. The
context manager captures all calls to functions decorated by
:func:`~taichi.lang.kernel_impl.kernel` or
@@ -28,6 +29,7 @@ def __init__(self, loss=None, clear_gradients=True):
Args:
loss(:class:`~taichi.lang.expr.Expr`): The loss field, whose shape should be ().
clear_gradients(Bool): Whether to clear all gradients before the `with` body starts.
validation(Bool): Whether to check that the code inside the context manager is valid for autodiff, e.g., that it complies with the global data access rule.
Example::
@@ -43,7 +45,12 @@ def __init__(self, loss=None, clear_gradients=True):
self.entered = False
self.gradient_evaluated = False
self.clear_gradients = clear_gradients
self.validation = validation
self.runtime = impl.get_runtime()
if not self.runtime.prog.config.debug and self.validation:
warnings.warn(
"Debug mode is disabled, autodiff valid check will not work. Please specify `ti.init(debug=True)` to enable the check.",
Warning)
self.eval_on_exit = loss is not None
self.loss = loss

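For context, a minimal usage sketch of the new flag (hedged: the field names and kernel are illustrative, not part of this commit):

import taichi as ti

ti.init(debug=True)  # required: without debug mode the validation check is skipped

x = ti.field(dtype=ti.f32, shape=4, needs_grad=True)
loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)

@ti.kernel
def compute_loss():
    for i in x:
        loss[None] += x[i] * x[i]

for i in range(4):
    x[i] = float(i)

with ti.ad.Tape(loss=loss, validation=True):
    compute_loss()

print([x.grad[i] for i in range(4)])  # expected: [0.0, 2.0, 4.0, 6.0]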
24 changes: 23 additions & 1 deletion python/taichi/lang/impl.py
@@ -25,7 +25,7 @@
from taichi.lang.struct import Struct, StructField, _IntermediateStruct
from taichi.lang.util import (cook_dtype, get_traceback, is_taichi_class,
python_scope, taichi_scope, warning)
from taichi.types.primitive_types import (all_types, f16, f32, f64, i32, i64,
from taichi.types.primitive_types import (all_types, f16, f32, f64, i32, i64, u8,
u32, u64)


@@ -338,6 +338,12 @@ def _check_gradient_field_not_placed(self, gradient_type):
'\n\n x = ti.field(float, shape=(2, 3), needs_{gradient_type}=True)'
)

@staticmethod
def _allocate_gradient_visited():
if root.finalized:
return
root._allocate_grad_visited()

def _check_matrix_field_member_shape(self):
for _field in self.matrix_fields:
shapes = [
@@ -355,6 +361,8 @@ def _calc_matrix_field_dynamic_index_stride():
_field._calc_dynamic_index_stride()

def materialize(self):
if get_runtime().prog.config.debug:
self._allocate_gradient_visited()
self.materialize_root_fb(not self.materialized)
self.materialized = True

@@ -560,6 +568,8 @@ def create_field_member(dtype, name, needs_grad, needs_dual):

x_grad = None
x_dual = None
# x_grad_visited is used by the global data access rule checker
x_grad_visited = None
if _ti_core.is_real(dtype):
# adjoint
x_grad = Expr(get_runtime().prog.make_id_expr(""))
@@ -571,6 +581,18 @@ def create_field_member(dtype, name, needs_grad, needs_dual):
if needs_grad:
pytaichi.grad_vars.append(x_grad)

if prog.config.debug:
# adjoint flag
x_grad_visited = Expr(get_runtime().prog.make_id_expr(""))
dtype = u8
if prog.config.arch in (_ti_core.opengl, _ti_core.vulkan):
dtype = i32
x_grad_visited.ptr = _ti_core.global_new(x_grad_visited.ptr,
cook_dtype(dtype))
x_grad_visited.ptr.set_name(name + ".grad_visited")
x_grad_visited.ptr.set_is_primal(False)
x.ptr.set_adjoint_visited(x_grad_visited.ptr)

# dual
x_dual = Expr(get_runtime().prog.make_id_expr(""))
x_dual.ptr = _ti_core.global_new(x_dual.ptr, dtype)
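The backend-dependent dtype choice above is small but easy to miss; a standalone restatement of the rule (hypothetical helper, not Taichi API):

# One flag element per gradient element: u8 by default,
# widened to i32 on OpenGL and Vulkan, which cannot place u8 here.
def visited_flag_dtype(arch: str) -> str:
    return "i32" if arch in ("opengl", "vulkan") else "u8"

assert visited_flag_dtype("x64") == "u8"
assert visited_flag_dtype("vulkan") == "i32"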
9 changes: 7 additions & 2 deletions python/taichi/lang/snode.py
@@ -156,12 +156,17 @@ def lazy_grad(self):
To know more details about primal, adjoint fields and ``lazy_grad()``,
please see Page 4 and Page 13-14 of DiffTaichi Paper: https://arxiv.org/pdf/1910.00935.pdf
"""
self.ptr.lazy_grad(True, False)
self.ptr.lazy_grad()

def lazy_dual(self):
"""Automatically place the dual fields following the layout of their primal fields.
"""
self.ptr.lazy_grad(False, True)
self.ptr.lazy_dual()

def _allocate_grad_visited(self):
"""Automatically place the adjoint flag fields following the layout of their primal fields for global data access rule checker
"""
self.ptr.allocate_grad_visited()

def parent(self, n=1):
"""Gets an ancestor of `self` in the SNode tree.
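For reference, the split entry points are used the same way as the old combined one (a minimal sketch; the visited flags themselves are placed internally in debug mode, with no user-facing call):

import taichi as ti

ti.init()

x = ti.field(dtype=ti.f32)
ti.root.dense(ti.i, 4).place(x)
ti.root.lazy_grad()  # places x.grad with the same dense(ti.i, 4) layout
# lazy_dual() is the forward-mode counterpart with identical placement logic.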
4 changes: 4 additions & 0 deletions taichi/ir/expr.cpp
@@ -59,6 +59,10 @@ void Expr::set_dual(const Expr &o) {
this->cast<GlobalVariableExpression>()->dual.set(o);
}

void Expr::set_adjoint_visited(const Expr &o) {
this->cast<GlobalVariableExpression>()->adjoint_visited.set(o);
}

Expr::Expr(int16 x) : Expr() {
expr = std::make_shared<ConstExpression>(PrimitiveType::i16, x);
}
2 changes: 2 additions & 0 deletions taichi/ir/expr.h
@@ -97,6 +97,8 @@ class Expr {

void set_dual(const Expr &o);

void set_adjoint_visited(const Expr &o);

void set_attribute(const std::string &key, const std::string &value);

std::string get_attribute(const std::string &key) const;
1 change: 1 addition & 0 deletions taichi/ir/frontend_ir.h
@@ -481,6 +481,7 @@ class GlobalVariableExpression : public Expression {
bool is_primal{true};
Expr adjoint;
Expr dual;
Expr adjoint_visited;

GlobalVariableExpression(DataType dt, const Identifier &ident)
: ident(ident), dt(dt) {
43 changes: 43 additions & 0 deletions taichi/ir/snode.cpp
@@ -283,6 +283,40 @@ bool SNode::need_activation() const {
type == SNodeType::bitmasked || type == SNodeType::dynamic;
}

void SNode::lazy_grad() {
make_lazy_place(
this, snode_to_glb_var_exprs_,
[this](std::unique_ptr<SNode> &c, std::vector<Expr> &new_grads) {
if (c->type == SNodeType::place && c->is_primal() && is_real(c->dt) &&
!c->has_adjoint()) {
new_grads.push_back(snode_to_glb_var_exprs_->at(c.get())->adjoint);
}
});
}

void SNode::lazy_dual() {
make_lazy_place(
this, snode_to_glb_var_exprs_,
[this](std::unique_ptr<SNode> &c, std::vector<Expr> &new_duals) {
if (c->type == SNodeType::place && c->is_primal() && is_real(c->dt) &&
!c->has_dual()) {
new_duals.push_back(snode_to_glb_var_exprs_->at(c.get())->dual);
}
});
}

void SNode::allocate_grad_visited() {
make_lazy_place(
this, snode_to_glb_var_exprs_,
[this](std::unique_ptr<SNode> &c, std::vector<Expr> &new_grad_visiteds) {
if (c->type == SNodeType::place && c->is_primal() && is_real(c->dt) &&
c->has_adjoint()) {
new_grad_visiteds.push_back(
snode_to_glb_var_exprs_->at(c.get())->adjoint_visited);
}
});
}

bool SNode::is_primal() const {
return grad_info && grad_info->is_primal();
}
@@ -291,6 +325,10 @@ bool SNode::has_adjoint() const {
return is_primal() && (grad_info->adjoint_snode() != nullptr);
}

bool SNode::has_adjoint_visited() const {
return is_primal() && (grad_info->adjoint_visited_snode() != nullptr);
}

bool SNode::has_dual() const {
return is_primal() && (grad_info->dual_snode() != nullptr);
}
@@ -300,6 +338,11 @@ SNode *SNode::get_adjoint() const {
return grad_info->adjoint_snode();
}

SNode *SNode::get_adjoint_visited() const {
// TI_ASSERT(has_adjoint());
return grad_info->adjoint_visited_snode();
}

SNode *SNode::get_dual() const {
TI_ASSERT(has_dual());
return grad_info->dual_snode();
13 changes: 10 additions & 3 deletions taichi/ir/snode.h
@@ -95,6 +95,7 @@ class SNode {
virtual bool is_primal() const = 0;
virtual SNode *adjoint_snode() const = 0;
virtual SNode *dual_snode() const = 0;
virtual SNode *adjoint_visited_snode() const = 0;

template <typename T>
T *cast() {
@@ -283,6 +284,10 @@

SNode *get_adjoint() const;

bool has_adjoint_visited() const;

SNode *get_adjoint_visited() const;

bool has_dual() const;

SNode *get_dual() const;
@@ -324,9 +329,11 @@ class SNode {
place_child(&expr, offset, id_in_bit_struct, this, snode_to_glb_var_exprs_);
}

void lazy_grad(bool is_adjoint, bool is_dual) {
make_lazy_grad(this, snode_to_glb_var_exprs_, is_adjoint, is_dual);
}
void lazy_grad();

void lazy_dual();

void allocate_grad_visited();

int64 read_int(const std::vector<int> &i);
uint64 read_uint(const std::vector<int> &i);
35 changes: 16 additions & 19 deletions taichi/program/snode_expr_utils.cpp
@@ -32,6 +32,14 @@ class GradInfoImpl final : public SNode::GradInfoProvider {
return dual.snode();
}

SNode *adjoint_visited_snode() const override {
auto &adjoint_visited = glb_var_->adjoint_visited;
if (adjoint_visited.expr == nullptr) {
return nullptr;
}
return adjoint_visited.snode();
}

private:
GlobalVariableExpression *glb_var_;
};
@@ -73,31 +81,20 @@ void place_child(Expr *expr_arg,
}
}

void make_lazy_grad(SNode *snode,
SNodeGlobalVarExprMap *snode_to_exprs,
bool is_adjoint,
bool is_dual) {
void make_lazy_place(SNode *snode,
SNodeGlobalVarExprMap *snode_to_exprs,
const std::function<void(std::unique_ptr<SNode> &,
std::vector<Expr> &)> &collect) {
if (snode->type == SNodeType::place)
return;
for (auto &c : snode->ch) {
make_lazy_grad(c.get(), snode_to_exprs, is_adjoint, is_dual);
make_lazy_place(c.get(), snode_to_exprs, collect);
}
std::vector<Expr> new_grads;
std::vector<Expr> new_places;
for (auto &c : snode->ch) {
if (is_adjoint) {
if (c->type == SNodeType::place && c->is_primal() && is_real(c->dt) &&
!c->has_adjoint()) {
new_grads.push_back(snode_to_exprs->at(c.get())->adjoint);
}
}
if (is_dual) {
if (c->type == SNodeType::place && c->is_primal() && is_real(c->dt) &&
!c->has_dual()) {
new_grads.push_back(snode_to_exprs->at(c.get())->dual);
}
}
collect(c, new_places);
}
for (auto p : new_grads) {
for (auto p : new_places) {
place_child(&p, /*offset=*/{}, -1, snode, snode_to_exprs);
}
}
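The refactor replaces the two boolean flags with a single traversal parameterized by a collect callback; the shape of the pattern in a short, runnable Python sketch (illustrative names, not Taichi API):

# Minimal restatement of the make_lazy_place pattern (illustrative, not Taichi API).
class Node:
    def __init__(self, is_place=False, children=()):
        self.is_place = is_place
        self.children = list(children)
        self.placed = []          # stands in for place_child()

def make_lazy_place(node, collect):
    """Fixed tree walk; `collect` supplies the per-caller selection predicate."""
    if node.is_place:
        return
    for child in node.children:
        make_lazy_place(child, collect)
    new_places = []
    for child in node.children:
        collect(child, new_places)   # caller decides which companions qualify
    node.placed.extend(new_places)

# lazy_grad, lazy_dual, and allocate_grad_visited now differ only in the
# lambda they pass, e.g. "primal real places without an adjoint":
root = Node(children=[Node(is_place=True), Node(is_place=True)])
make_lazy_place(root, lambda c, out: out.append("adjoint") if c.is_place else None)
print(root.placed)   # ['adjoint', 'adjoint']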
9 changes: 5 additions & 4 deletions taichi/program/snode_expr_utils.h
@@ -2,6 +2,7 @@

#include <memory>
#include <unordered_map>
#include <functional>
#include <vector>

// This file groups the set of helpers that need the Expr associated with a
@@ -24,10 +25,10 @@ void place_child(Expr *expr_arg,
SNode *parent,
SNodeGlobalVarExprMap *snode_to_exprs);

void make_lazy_grad(SNode *snode,
SNodeGlobalVarExprMap *snode_to_exprs,
bool is_adjoint,
bool is_dual);
void make_lazy_place(SNode *snode,
SNodeGlobalVarExprMap *snode_to_exprs,
const std::function<void(std::unique_ptr<SNode> &,
std::vector<Expr> &)> &collect);

} // namespace lang
} // namespace taichi
4 changes: 4 additions & 0 deletions taichi/python/export_lang.cpp
@@ -508,10 +508,13 @@ void export_lang(py::module &m) {
[](SNode *snode, int i) -> SNode * { return snode->ch[i].get(); },
py::return_value_policy::reference)
.def("lazy_grad", &SNode::lazy_grad)
.def("lazy_dual", &SNode::lazy_dual)
.def("allocate_grad_visited", &SNode::allocate_grad_visited)
.def("read_int", &SNode::read_int)
.def("read_uint", &SNode::read_uint)
.def("read_float", &SNode::read_float)
.def("has_adjoint", &SNode::has_adjoint)
.def("has_adjoint_visited", &SNode::has_adjoint_visited)
.def("has_dual", &SNode::has_dual)
.def("is_primal", &SNode::is_primal)
.def("is_place", &SNode::is_place)
@@ -713,6 +716,7 @@
expr->cast<GlobalVariableExpression>()->is_primal = v;
})
.def("set_adjoint", &Expr::set_adjoint)
.def("set_adjoint_visited", &Expr::set_adjoint_visited)
.def("set_dual", &Expr::set_dual)
.def("set_attribute", &Expr::set_attribute)
.def(