Skip to content

Commit

Permalink
Introduce equal_assertion to reduce more boundary checks.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Apr 7, 2021
1 parent 8a18c4d commit cf1a0d5
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 80 deletions.
15 changes: 14 additions & 1 deletion lib/nnc/_ccv_nnc_micro.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ typedef struct {
int16_t id;
} ccv_nnc_micro_id_t;

typedef struct {
ccv_nnc_micro_id_t left;
ccv_nnc_micro_id_t right;
} ccv_nnc_micro_id_equal_assertion_t;

enum {
CCV_NNC_MICRO_LOOP_INDEX_TYPE_NONE,
CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID,
Expand Down Expand Up @@ -208,6 +213,7 @@ typedef uint32_t(*ccv_nnc_micro_scalar_lookup_f)(const void* const context, cons
struct ccv_nnc_micro_io_vtab_s {
void (*bind_scalars)(const ccv_nnc_micro_io_t self, ccv_nnc_micro_scalar_lookup_f lookup, const void* const context); /**< Bind scalar name to a scoped id. */
void (*numbering)(const ccv_nnc_micro_io_t self, const int id, const int var_count); /**< Assign id to the output of this micro op. */
void (*equal_assertions)(const ccv_nnc_micro_io_t self, ccv_array_t* const equal_assertions); /**< Collect assertions about id equal. */
ccv_nnc_micro_function_t (*emit)(const ccv_nnc_micro_io_t self); /**< Emit instructions for this micro op. */
ccv_nnc_micro_function_t (*emit_grad)(const ccv_nnc_micro_io_t self, const int var_count); /**< Emit backward instructions for this micro op. */
ccv_nnc_micro_tensor_t (*return_shape)(const ccv_nnc_micro_io_t self); /**< The shape of the returned tensor. */
Expand All @@ -232,6 +238,13 @@ static inline void ccv_nnc_micro_numbering(const ccv_nnc_micro_io_t self, const
self->id = id;
}

static inline void ccv_nnc_micro_equal_assertions(const ccv_nnc_micro_io_t self, ccv_array_t* const equal_assertions)
{
const ccv_nnc_micro_io_vtab_t* const isa = self->isa;
if (isa->equal_assertions)
isa->equal_assertions(self, equal_assertions);
}

static inline void ccv_nnc_micro_bind_scalars(const ccv_nnc_micro_io_t self, ccv_nnc_micro_scalar_lookup_f lookup, const void* const context)
{
const ccv_nnc_micro_io_vtab_t* const isa = self->isa;
Expand Down Expand Up @@ -463,7 +476,7 @@ static inline ccv_nnc_micro_loop_carried_t ccv_nnc_micro_loop_carried(const uint
}

// This method has to be mutable for efficiency reasons. Hence I kept it private.
void ccv_nnc_micro_program_simplify(ccv_nnc_micro_program_t* const program, const ccv_nnc_micro_io_t* const inputs, const int input_size, const ccv_nnc_micro_io_t* const outputs, const int output_size);
void ccv_nnc_micro_program_simplify(ccv_nnc_micro_program_t* const program, const ccv_nnc_micro_io_t* const inputs, const int input_size, const ccv_nnc_micro_io_t* const outputs, const int output_size, const ccv_array_t* const equal_assertions);
ccv_nnc_micro_loop_index_term_t ccv_nnc_micro_loop_index_deep_copy(const ccv_nnc_micro_loop_index_term_t* const term);
void ccv_nnc_micro_loop_index_free(ccv_nnc_micro_loop_index_term_t* const term);
void ccv_nnc_micro_loop_variable_free(ccv_nnc_micro_loop_variable_t* const var);
Expand Down
9 changes: 6 additions & 3 deletions lib/nnc/ccv_nnc_micro.c
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ CCV_WARN_UNUSED(ccv_nnc_micro_combine_t*) ccv_nnc_micro_combine_new(const ccv_nn
// Applying numbering for the inputs. Note that our variables are numbered in reverse topological order.
for (i = 0; i < input_size; i++)
ccv_nnc_micro_numbering(inputs[i], i, var_count);
// Applying numbering for the outputs.
ccv_array_t* const equal_assertions = ccv_array_new(sizeof(ccv_nnc_micro_id_equal_assertion_t), 0, 0);
// Applying numbering for the outputs and collect equal assertions.
for (i = reverse_top->rnum - 1; i >= 0; i--)
{
const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, reverse_top->rnum - 1 - i);
ccv_nnc_micro_numbering(output, i + input_size, var_count);
ccv_nnc_micro_equal_assertions(output, equal_assertions);
}
for (i = 0; i < ingrad_size; i++)
ccv_nnc_micro_numbering(ingrads[i], -1, var_count);
Expand Down Expand Up @@ -131,7 +133,7 @@ CCV_WARN_UNUSED(ccv_nnc_micro_combine_t*) ccv_nnc_micro_combine_new(const ccv_nn
memcpy(combine->forward.vars, vars, sizeof(ccv_nnc_micro_tensor_t) * var_count);
combine->forward.function_count = function_count;
combine->forward.functions = functions;
ccv_nnc_micro_program_simplify(&combine->forward, inputs, input_size, outputs, output_size);
ccv_nnc_micro_program_simplify(&combine->forward, inputs, input_size, outputs, output_size, equal_assertions);
function_count = reverse_top->rnum * 2;
functions = (ccv_nnc_micro_function_t*)ccmalloc(sizeof(ccv_nnc_micro_function_t) * function_count);
for (i = 0; i < reverse_top->rnum; i++)
Expand All @@ -156,7 +158,8 @@ CCV_WARN_UNUSED(ccv_nnc_micro_combine_t*) ccv_nnc_micro_combine_new(const ccv_nn
combine->backward.vars = vars;
combine->backward.function_count = function_count;
combine->backward.functions = functions;
ccv_nnc_micro_program_simplify(&combine->backward, ingrads, ingrad_size, outgrads, outgrad_size);
ccv_nnc_micro_program_simplify(&combine->backward, ingrads, ingrad_size, outgrads, outgrad_size, equal_assertions);
ccv_array_free(equal_assertions);
for (i = 0; i < reverse_top->rnum; i++)
{
const ccv_nnc_micro_io_t output = *(ccv_nnc_micro_io_t*)ccv_array_get(reverse_top, i);
Expand Down
83 changes: 70 additions & 13 deletions lib/nnc/ccv_nnc_micro_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ static int _index(const char** const pos, int* const remain_size, int* const id)
return 0;
}

static int _dim(const char** const pos, int* const remain_size, int* const id, int* const d)
static int _dim(const char** const pos, int* const remain_size, int* const id, int* const d, ccv_array_t* const equal_assertions)
{
if (!(*remain_size > 1 && pos[0][0] == 'd'))
return 0;
Expand All @@ -123,6 +123,32 @@ static int _dim(const char** const pos, int* const remain_size, int* const id, i
{
*remain_size -= size;
*pos += size;
while (_accept(pos, remain_size, " ", 1)) {}
if (_accept(pos, remain_size, "[", 1))
{
while (_accept(pos, remain_size, " ", 1)) {}
_expect(pos, remain_size, "=", 1);
while (_accept(pos, remain_size, " ", 1)) {}
int next_id;
int next_d;
if (!_dim(pos, remain_size, &next_id, &next_d, equal_assertions))
{ assert(0 && "unexpected symbol"); }
const ccv_nnc_micro_id_equal_assertion_t equal_assertion = {
.left = {
.type = CCV_NNC_MICRO_AXIS_SIZE_ID,
.id = -(*id + 1),
.d = *d
},
.right = {
.type = CCV_NNC_MICRO_AXIS_SIZE_ID,
.id = -(next_id + 1),
.d = next_d
}
};
ccv_array_push(equal_assertions, &equal_assertion);
while (_accept(pos, remain_size, " ", 1)) {}
_expect(pos, remain_size, "]", 1);
}
return 1;
}
return 0;
Expand Down Expand Up @@ -151,9 +177,9 @@ static int _var(const char** const pos, int* const remain_size, char** name)
return 0;
}

static CCV_WARN_UNUSED(ccv_nnc_micro_loop_index_term_t) _expression(const char** const pos, int* const remain_size);
static CCV_WARN_UNUSED(ccv_nnc_micro_loop_index_term_t) _expression(const char** const pos, int* const remain_size, ccv_array_t* const equal_assertions);

static ccv_nnc_micro_loop_index_term_t _factor(const char** const pos, int* const remain_size)
static ccv_nnc_micro_loop_index_term_t _factor(const char** const pos, int* const remain_size, ccv_array_t* const equal_assertions)
{
ccv_nnc_micro_loop_index_term_t term;
while (_accept(pos, remain_size, " ", 1)) {}
Expand All @@ -166,7 +192,7 @@ static ccv_nnc_micro_loop_index_term_t _factor(const char** const pos, int* cons
term.type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID;
term.id.type = CCV_NNC_MICRO_LOOP_ID;
term.id.id = id;
} else if (_dim(pos, remain_size, &id, &d)) {
} else if (_dim(pos, remain_size, &id, &d, equal_assertions)) {
term.type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_ID;
term.id.type = CCV_NNC_MICRO_AXIS_SIZE_ID;
term.id.d = d;
Expand All @@ -175,7 +201,7 @@ static ccv_nnc_micro_loop_index_term_t _factor(const char** const pos, int* cons
term.type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_UNBOUND_SCALAR;
term.name = name;
} else if (_accept(pos, remain_size, "(", 1)) {
term = _expression(pos, remain_size);
term = _expression(pos, remain_size, equal_assertions);
_expect(pos, remain_size, ")", 1);
} else {
assert(0 && "factor: syntax error");
Expand All @@ -184,17 +210,17 @@ static ccv_nnc_micro_loop_index_term_t _factor(const char** const pos, int* cons
return term;
}

static ccv_nnc_micro_loop_index_term_t _term(const char** const pos, int* const remain_size)
static ccv_nnc_micro_loop_index_term_t _term(const char** const pos, int* const remain_size, ccv_array_t* const equal_assertions)
{
while (_accept(pos, remain_size, " ", 1)) {}
ccv_nnc_micro_loop_index_term_t term = _factor(pos, remain_size);
ccv_nnc_micro_loop_index_term_t term = _factor(pos, remain_size, equal_assertions);
while (*remain_size > 0 && (pos[0][0] == '*' || pos[0][0] == '/'))
{
const int op = pos[0][0] == '*' ? CCV_NNC_MICRO_BINARY_OP_MUL : CCV_NNC_MICRO_BINARY_OP_DIV;
*remain_size -= 1;
*pos += 1;
const ccv_nnc_micro_loop_index_term_t left = term;
const ccv_nnc_micro_loop_index_term_t right = _factor(pos, remain_size);
const ccv_nnc_micro_loop_index_term_t right = _factor(pos, remain_size, equal_assertions);
term.type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY;
term.binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t));
term.binary->op = op;
Expand All @@ -205,7 +231,7 @@ static ccv_nnc_micro_loop_index_term_t _term(const char** const pos, int* const
return term;
}

static ccv_nnc_micro_loop_index_term_t _expression(const char** const pos, int* const remain_size)
static ccv_nnc_micro_loop_index_term_t _expression(const char** const pos, int* const remain_size, ccv_array_t* const equal_assertions)
{
while (_accept(pos, remain_size, " ", 1)) {}
int prefix_op = -1;
Expand All @@ -215,14 +241,14 @@ static ccv_nnc_micro_loop_index_term_t _expression(const char** const pos, int*
*remain_size -= 1;
*pos += 1;
}
ccv_nnc_micro_loop_index_term_t node = _term(pos, remain_size);
ccv_nnc_micro_loop_index_term_t node = _term(pos, remain_size, equal_assertions);
while (*remain_size > 0 && (pos[0][0] == '+' || pos[0][0] == '-'))
{
const int op = pos[0][0] == '+' ? CCV_NNC_MICRO_BINARY_OP_PLUS : CCV_NNC_MICRO_BINARY_OP_MINUS;
*remain_size -= 1;
*pos += 1;
const ccv_nnc_micro_loop_index_term_t left = node;
const ccv_nnc_micro_loop_index_term_t right = _term(pos, remain_size);
const ccv_nnc_micro_loop_index_term_t right = _term(pos, remain_size, equal_assertions);
node.type = CCV_NNC_MICRO_LOOP_INDEX_TYPE_BINARY;
node.binary = (ccv_nnc_micro_loop_index_binary_t*)ccmalloc(sizeof(ccv_nnc_micro_loop_index_binary_t));
node.binary->op = op;
Expand Down Expand Up @@ -282,6 +308,7 @@ struct ccv_nnc_micro_io_reindex_s {
ccv_nnc_micro_loop_index_term_t* shape;
ccv_nnc_micro_loop_index_term_t* reindex;
ccv_nnc_micro_io_t* ss;
ccv_array_t* equal_assertions;
};

static void _ccv_nnc_micro_reindex_numbering(const ccv_nnc_micro_io_t super, const int id, const int var_count)
Expand All @@ -299,6 +326,33 @@ static void _ccv_nnc_micro_reindex_numbering(const ccv_nnc_micro_io_t super, con
_sid_to_axis_size_term(&self->shape[i], sids, self->s_count);
for (i = 0; i < self->x->dimensions; i++)
_sid_to_axis_size_term(&self->reindex[i], sids, self->s_count);
for (i = 0; i < self->equal_assertions->rnum; i++)
{
ccv_nnc_micro_id_equal_assertion_t* const equal_assertion = (ccv_nnc_micro_id_equal_assertion_t*)ccv_array_get(self->equal_assertions, i);
if (equal_assertion->left.type == CCV_NNC_MICRO_AXIS_SIZE_ID && equal_assertion->left.id < 0)
{
const int id = -(equal_assertion->left.id + 1);
assert(id >= 0 && id < self->s_count);
equal_assertion->left.id = sids[id];
}
if (equal_assertion->right.type == CCV_NNC_MICRO_AXIS_SIZE_ID && equal_assertion->right.id < 0)
{
const int id = -(equal_assertion->right.id + 1);
assert(id >= 0 && id < self->s_count);
equal_assertion->right.id = sids[id];
}
}
}

static void _ccv_nnc_micro_reindex_equal_assertions(const ccv_nnc_micro_io_t super, ccv_array_t* const equal_assertions)
{
struct ccv_nnc_micro_io_reindex_s* const self = (struct ccv_nnc_micro_io_reindex_s*)super;
int i;
for (i = 0; i < self->equal_assertions->rnum; i++)
{
ccv_nnc_micro_id_equal_assertion_t* const equal_assertion = (ccv_nnc_micro_id_equal_assertion_t*)ccv_array_get(self->equal_assertions, i);
ccv_array_push(equal_assertions, equal_assertion);
}
}

static void _ccv_nnc_bind_scalars_in_term(ccv_nnc_micro_loop_index_term_t* const term, ccv_nnc_micro_scalar_lookup_f lookup, const void* const context)
Expand Down Expand Up @@ -442,10 +496,12 @@ static void _ccv_nnc_micro_reindex_deinit(const ccv_nnc_micro_io_t super)
int i;
for (i = 0; i < self->x->dimensions; i++)
ccv_nnc_micro_loop_index_free(&self->reindex[i]);
ccv_array_free(self->equal_assertions);
}

static const ccv_nnc_micro_io_vtab_t ccv_nnc_micro_io_reindex_isa = {
.numbering = _ccv_nnc_micro_reindex_numbering,
.equal_assertions = _ccv_nnc_micro_reindex_equal_assertions,
.bind_scalars = _ccv_nnc_micro_reindex_bind_scalars,
.emit = _ccv_nnc_micro_reindex_emit,
.emit_grad = _ccv_nnc_micro_reindex_emit_grad,
Expand Down Expand Up @@ -473,6 +529,7 @@ ccv_nnc_micro_io_t ccv_nnc_micro_reindex(const char* const* const shape, const i
self->super.input_size = s_count + 1;
if (s_count > 0)
memcpy(self->ss, ss, sizeof(ccv_nnc_micro_io_t) * s_count);
ccv_array_t* const equal_assertions = self->equal_assertions = ccv_array_new(sizeof(ccv_nnc_micro_id_equal_assertion_t), 0, 0);
// Parse shape into expressions and validate the grammar. Do this upfront so we don't fail on parsing
// later, which can be confusing.
// CFG:
Expand All @@ -490,7 +547,7 @@ ccv_nnc_micro_io_t ccv_nnc_micro_reindex(const char* const* const shape, const i
{
int remain_size = strlen(shape[i]);
const char* pos = shape[i];
ccv_nnc_micro_loop_index_term_t term = _expression(&pos, &remain_size);
ccv_nnc_micro_loop_index_term_t term = _expression(&pos, &remain_size, equal_assertions);
_no_index(term); // Make sure this is not index, no loop index.
self->shape[i] = term;
}
Expand All @@ -499,7 +556,7 @@ ccv_nnc_micro_io_t ccv_nnc_micro_reindex(const char* const* const shape, const i
{
int remain_size = strlen(reindex[i]);
const char* pos = reindex[i];
self->reindex[i] = _expression(&pos, &remain_size);
self->reindex[i] = _expression(&pos, &remain_size, equal_assertions);
}
return (ccv_nnc_micro_io_t)self;
}
Expand Down
Loading

0 comments on commit cf1a0d5

Please sign in to comment.