[LogSoftmax] add LogSoftmax-1 and LogSoftmax-11 operator support
jianjunjiang committed Nov 23, 2020
1 parent 84a9883 commit 6b4d14b
Showing 2 changed files with 230 additions and 34 deletions.
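For reference, LogSoftmax computes y = x - max(x) - log(sum(exp(x - max(x)))) along the normalization axis, the numerically stable form of log(softmax(x)). Below is a minimal standalone sketch of that formulation, illustrative only and not part of this commit:

```c
#include <math.h>
#include <stdio.h>

/* Numerically stable log-softmax over one row of n floats.
 * Subtracting the row maximum before expf() prevents overflow;
 * the shift cancels out of the final result. */
static void log_softmax(const float * x, float * y, int n)
{
	float maxv = x[0];
	float sum = 0.0f;
	int i;

	for(i = 1; i < n; i++)
	{
		if(x[i] > maxv)
			maxv = x[i];
	}
	for(i = 0; i < n; i++)
		sum += expf(x[i] - maxv);
	for(i = 0; i < n; i++)
		y[i] = x[i] - maxv - logf(sum);
}

int main(void)
{
	float x[3] = { 1.0f, 2.0f, 3.0f };
	float y[3];
	int i;

	log_softmax(x, y, 3);
	for(i = 0; i < 3; i++)
		printf("%f\n", y[i]); /* -2.407606 -1.407606 -0.407606 */
	return 0;
}
```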
documents/the-supported-operator-table.md (4 changes: 2 additions & 2 deletions)
@@ -367,8 +367,8 @@
|GreaterOrEqual-12||
|LessOrEqual-12||
|LogSoftmax-13||
-|LogSoftmax-11||
-|LogSoftmax-1||
+|LogSoftmax-11|√|
+|LogSoftmax-1|√|
|MeanVarianceNormalization-13||
|MeanVarianceNormalization-9||
|NegativeLogLikelihoodLoss-13||
src/default/LogSoftmax.c (260 changes: 228 additions & 32 deletions)
@@ -1,6 +1,6 @@
#include <onnx.h>

-struct operator_pdata_t
+struct operator_13_pdata_t
{
	int axis;

@@ -10,13 +10,13 @@ struct operator_pdata_t
	int inner;
};

-static int LogSoftmax_init(struct onnx_node_t * n)
+static int LogSoftmax_13_init(struct onnx_node_t * n)
{
-	struct operator_pdata_t * pdat;
+	struct operator_13_pdata_t * pdat;

	if((n->ninput == 1) && (n->noutput == 1))
	{
-		pdat = malloc(sizeof(struct operator_pdata_t));
+		pdat = malloc(sizeof(struct operator_13_pdata_t));
		if(pdat)
		{
			pdat->axis = onnx_attribute_read_int(n, "axis", -1);
@@ -27,18 +27,18 @@ static int LogSoftmax_init(struct onnx_node_t * n)
	return 0;
}

-static int LogSoftmax_exit(struct onnx_node_t * n)
+static int LogSoftmax_13_exit(struct onnx_node_t * n)
{
-	struct operator_pdata_t * pdat = (struct operator_pdata_t *)n->priv;
+	struct operator_13_pdata_t * pdat = (struct operator_13_pdata_t *)n->priv;

	if(pdat)
		free(pdat);
	return 1;
}

-static int LogSoftmax_reshape(struct onnx_node_t * n)
+static int LogSoftmax_13_reshape(struct onnx_node_t * n)
{
-	struct operator_pdata_t * pdat = (struct operator_pdata_t *)n->priv;
+	struct operator_13_pdata_t * pdat = (struct operator_13_pdata_t *)n->priv;
	struct onnx_tensor_t * x = n->inputs[0];
	struct onnx_tensor_t * y = n->outputs[0];
	int i;
@@ -60,9 +60,9 @@ static int LogSoftmax_reshape(struct onnx_node_t * n)
	return onnx_tensor_reshape_identity(y, x, x->type);
}

-static void LogSoftmax_bfloat16(struct onnx_node_t * n)
+static void LogSoftmax_13_bfloat16(struct onnx_node_t * n)
{
-	struct operator_pdata_t * pdat = (struct operator_pdata_t *)n->priv;
+	struct operator_13_pdata_t * pdat = (struct operator_13_pdata_t *)n->priv;
	struct onnx_tensor_t * x = n->inputs[0];
	struct onnx_tensor_t * y = n->outputs[0];
	uint16_t * px = (uint16_t *)x->datas;
@@ -103,9 +103,9 @@ static void LogSoftmax_bfloat16(struct onnx_node_t * n)
	}
}

-static void LogSoftmax_float16(struct onnx_node_t * n)
+static void LogSoftmax_13_float16(struct onnx_node_t * n)
{
-	struct operator_pdata_t * pdat = (struct operator_pdata_t *)n->priv;
+	struct operator_13_pdata_t * pdat = (struct operator_13_pdata_t *)n->priv;
	struct onnx_tensor_t * x = n->inputs[0];
	struct onnx_tensor_t * y = n->outputs[0];
	uint16_t * px = (uint16_t *)x->datas;
@@ -146,9 +146,9 @@ static void LogSoftmax_float16(struct onnx_node_t * n)
	}
}

-static void LogSoftmax_float32(struct onnx_node_t * n)
+static void LogSoftmax_13_float32(struct onnx_node_t * n)
{
-	struct operator_pdata_t * pdat = (struct operator_pdata_t *)n->priv;
+	struct operator_13_pdata_t * pdat = (struct operator_13_pdata_t *)n->priv;
	struct onnx_tensor_t * x = n->inputs[0];
	struct onnx_tensor_t * y = n->outputs[0];
	float * px = (float *)x->datas;
@@ -186,9 +186,9 @@ static void LogSoftmax_float32(struct onnx_node_t * n)
	}
}

-static void LogSoftmax_float64(struct onnx_node_t * n)
+static void LogSoftmax_13_float64(struct onnx_node_t * n)
{
-	struct operator_pdata_t * pdat = (struct operator_pdata_t *)n->priv;
+	struct operator_13_pdata_t * pdat = (struct operator_13_pdata_t *)n->priv;
	struct onnx_tensor_t * x = n->inputs[0];
	struct onnx_tensor_t * y = n->outputs[0];
	double * px = (double *)x->datas;
@@ -226,44 +226,240 @@ static void LogSoftmax_float64(struct onnx_node_t * n)
	}
}

+struct operator_1_11_pdata_t {
+	int axis;
+
+	int N;
+	int D;
+};
+
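/* Note: for opset 1 and 11, "axis" (default 1) coerces the input into a
 * 2D view [N, D], where N is the product of the dimensions before axis
 * and D the product of the dimensions from axis onward; log-softmax is
 * then taken over each row of length D. Opset 13 instead normalizes
 * along the single dimension named by axis (default -1). */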
+static int LogSoftmax_1_11_init(struct onnx_node_t * n)
+{
+	struct operator_1_11_pdata_t * pdat;
+
+	if((n->ninput == 1) && (n->noutput == 1))
+	{
+		pdat = malloc(sizeof(struct operator_1_11_pdata_t));
+		if(pdat)
+		{
+			pdat->axis = onnx_attribute_read_int(n, "axis", 1);
+			n->priv = pdat;
+			return 1;
+		}
+	}
+	return 0;
+}
+
+static int LogSoftmax_1_11_exit(struct onnx_node_t * n)
+{
+	struct operator_1_11_pdata_t * pdat = (struct operator_1_11_pdata_t *)n->priv;
+
+	if(pdat)
+		free(pdat);
+	return 1;
+}
+
+static int LogSoftmax_1_11_reshape(struct onnx_node_t * n)
+{
+	struct operator_1_11_pdata_t * pdat = (struct operator_1_11_pdata_t *)n->priv;
+	struct onnx_tensor_t * x = n->inputs[0];
+	struct onnx_tensor_t * y = n->outputs[0];
+	int axis = pdat->axis;
+	int i;
+
+	if(axis < 0)
+		axis += x->ndim;
+	if(axis < 0 || axis >= x->ndim)
+		return 0;
+	for(i = 0, pdat->N = 1, pdat->D = 1; i < x->ndim; i++)
+	{
+		if(i < axis)
+			pdat->N *= x->dims[i];
+		else
+			pdat->D *= x->dims[i];
+	}
+	return onnx_tensor_reshape_identity(y, x, x->type);
+}
+
+static void LogSoftmax_1_11_float16(struct onnx_node_t * n)
+{
+	struct operator_1_11_pdata_t * pdat = (struct operator_1_11_pdata_t *)n->priv;
+	struct onnx_tensor_t * x = n->inputs[0];
+	struct onnx_tensor_t * y = n->outputs[0];
+	uint16_t * px = (uint16_t *)x->datas;
+	uint16_t * py = (uint16_t *)y->datas;
+	float maxv, sum, v;
+	int i, j, o;
+
+	for(i = 0, o = 0; i < pdat->N; i++, o += pdat->D)
+	{
+		for(j = 0, maxv = FLT_MIN; j < pdat->D; j++)
+		{
+			v = float16_to_float32(px[o + j]);
+			if(v > maxv)
+				maxv = v;
+		}
+		for(j = 0, sum = 0; j < pdat->D; j++)
+		{
+			v = expf(float16_to_float32(px[o + j]) - maxv);
+			py[o + j] = float32_to_float16(v);
+			sum += v;
+		}
+		if(sum != 0)
+		{
+			for(j = 0; j < pdat->D; j++)
+			{
+				v = float16_to_float32(py[o + j]);
+				py[o + j] = float32_to_float16(logf(v / sum));
+			}
+		}
+	}
+}
+
+static void LogSoftmax_1_11_float32(struct onnx_node_t * n)
+{
+	struct operator_1_11_pdata_t * pdat = (struct operator_1_11_pdata_t *)n->priv;
+	struct onnx_tensor_t * x = n->inputs[0];
+	struct onnx_tensor_t * y = n->outputs[0];
+	float * px = (float *)x->datas;
+	float * py = (float *)y->datas;
+	float maxv, sum;
+	int i, j, o;
+
+	for(i = 0, o = 0; i < pdat->N; i++, o += pdat->D)
+	{
+		for(j = 0, maxv = FLT_MIN; j < pdat->D; j++)
+		{
+			if(px[o + j] > maxv)
+				maxv = px[o + j];
+		}
+		for(j = 0, sum = 0; j < pdat->D; j++)
+		{
+			py[o + j] = expf(px[o + j] - maxv);
+			sum += py[o + j];
+		}
+		if(sum != 0)
+		{
+			for(j = 0; j < pdat->D; j++)
+				py[o + j] = logf(py[o + j] / sum);
+		}
+	}
+}
+
+static void LogSoftmax_1_11_float64(struct onnx_node_t * n)
+{
+	struct operator_1_11_pdata_t * pdat = (struct operator_1_11_pdata_t *)n->priv;
+	struct onnx_tensor_t * x = n->inputs[0];
+	struct onnx_tensor_t * y = n->outputs[0];
+	double * px = (double *)x->datas;
+	double * py = (double *)y->datas;
+	double maxv, sum;
+	int i, j, o;
+
+	for(i = 0, o = 0; i < pdat->N; i++, o += pdat->D)
+	{
+		for(j = 0, maxv = DBL_MIN; j < pdat->D; j++)
+		{
+			if(px[o + j] > maxv)
+				maxv = px[o + j];
+		}
+		for(j = 0, sum = 0; j < pdat->D; j++)
+		{
+			py[o + j] = exp(px[o + j] - maxv);
+			sum += py[o + j];
+		}
+		if(sum != 0)
+		{
+			for(j = 0; j < pdat->D; j++)
+				py[o + j] = log(py[o + j] / sum);
+		}
+	}
+}
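/* The three kernels above share a two-pass scheme: py first holds
 * exp(x - maxv), then is rewritten in place as log(py / sum), which
 * equals x - maxv - log(sum); the output buffer doubles as scratch
 * space for the exponentials. */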

void resolver_default_op_LogSoftmax(struct onnx_node_t * n)
{
+	if(n->opset >= 13)
+	{
		switch(n->inputs[0]->type)
		{
		case ONNX_TENSOR_TYPE_BFLOAT16:
-			n->init = LogSoftmax_init;
-			n->exit = LogSoftmax_exit;
-			n->reshape = LogSoftmax_reshape;
-			n->operator = LogSoftmax_bfloat16;
+			n->init = LogSoftmax_13_init;
+			n->exit = LogSoftmax_13_exit;
+			n->reshape = LogSoftmax_13_reshape;
+			n->operator = LogSoftmax_13_bfloat16;
			break;
		case ONNX_TENSOR_TYPE_FLOAT16:
-			n->init = LogSoftmax_init;
-			n->exit = LogSoftmax_exit;
-			n->reshape = LogSoftmax_reshape;
-			n->operator = LogSoftmax_float16;
+			n->init = LogSoftmax_13_init;
+			n->exit = LogSoftmax_13_exit;
+			n->reshape = LogSoftmax_13_reshape;
+			n->operator = LogSoftmax_13_float16;
			break;
		case ONNX_TENSOR_TYPE_FLOAT32:
-			n->init = LogSoftmax_init;
-			n->exit = LogSoftmax_exit;
-			n->reshape = LogSoftmax_reshape;
-			n->operator = LogSoftmax_float32;
+			n->init = LogSoftmax_13_init;
+			n->exit = LogSoftmax_13_exit;
+			n->reshape = LogSoftmax_13_reshape;
+			n->operator = LogSoftmax_13_float32;
			break;
		case ONNX_TENSOR_TYPE_FLOAT64:
-			n->init = LogSoftmax_init;
-			n->exit = LogSoftmax_exit;
-			n->reshape = LogSoftmax_reshape;
-			n->operator = LogSoftmax_float64;
+			n->init = LogSoftmax_13_init;
+			n->exit = LogSoftmax_13_exit;
+			n->reshape = LogSoftmax_13_reshape;
+			n->operator = LogSoftmax_13_float64;
			break;
		default:
			break;
		}
+	}
+	else if(n->opset >= 11)
+	{
+		switch(n->inputs[0]->type)
+		{
+		case ONNX_TENSOR_TYPE_FLOAT16:
+			n->init = LogSoftmax_1_11_init;
+			n->exit = LogSoftmax_1_11_exit;
+			n->reshape = LogSoftmax_1_11_reshape;
+			n->operator = LogSoftmax_1_11_float16;
+			break;
+		case ONNX_TENSOR_TYPE_FLOAT32:
+			n->init = LogSoftmax_1_11_init;
+			n->exit = LogSoftmax_1_11_exit;
+			n->reshape = LogSoftmax_1_11_reshape;
+			n->operator = LogSoftmax_1_11_float32;
+			break;
+		case ONNX_TENSOR_TYPE_FLOAT64:
+			n->init = LogSoftmax_1_11_init;
+			n->exit = LogSoftmax_1_11_exit;
+			n->reshape = LogSoftmax_1_11_reshape;
+			n->operator = LogSoftmax_1_11_float64;
+			break;
+		default:
+			break;
+		}
+	}
+	else if(n->opset >= 1)
+	{
+		switch(n->inputs[0]->type)
+		{
+		case ONNX_TENSOR_TYPE_FLOAT16:
+			n->init = LogSoftmax_1_11_init;
+			n->exit = LogSoftmax_1_11_exit;
+			n->reshape = LogSoftmax_1_11_reshape;
+			n->operator = LogSoftmax_1_11_float16;
+			break;
+		case ONNX_TENSOR_TYPE_FLOAT32:
+			n->init = LogSoftmax_1_11_init;
+			n->exit = LogSoftmax_1_11_exit;
+			n->reshape = LogSoftmax_1_11_reshape;
+			n->operator = LogSoftmax_1_11_float32;
+			break;
+		case ONNX_TENSOR_TYPE_FLOAT64:
+			n->init = LogSoftmax_1_11_init;
+			n->exit = LogSoftmax_1_11_exit;
+			n->reshape = LogSoftmax_1_11_reshape;
+			n->operator = LogSoftmax_1_11_float64;
+			break;
+		default:
+			break;
+		}
+	}
}

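The resolver keeps the opset semantics apart: models imported at opset 13 get per-axis normalization, while opsets 1 through 11 get the flattened 2D behavior. A small illustrative sketch (the shape and axis below are hypothetical, chosen only to show the N/D coercion that LogSoftmax_1_11_reshape performs):

```c
#include <stdio.h>

/* For a hypothetical input of shape [2, 3, 4] with axis = 1 (the
 * opset 1/11 default), the 1/11 path flattens the tensor to a 2D
 * view [N, D] and normalizes rows of length D = 12; the opset 13
 * path would instead normalize along dimension 1 alone (length 3). */
int main(void)
{
	int dims[3] = { 2, 3, 4 };
	int ndim = 3;
	int axis = 1;
	int N = 1, D = 1;
	int i;

	for(i = 0; i < ndim; i++)
	{
		if(i < axis)
			N *= dims[i];
		else
			D *= dims[i];
	}
	printf("N = %d, D = %d\n", N, D); /* prints "N = 2, D = 12" */
	return 0;
}
```

One detail worth noting in the new kernels: maxv is seeded with FLT_MIN (or DBL_MIN), the smallest positive normalized value, rather than -FLT_MAX. The result is still correct, since any constant shift cancels out of log-softmax, and the overflow protection still engages whenever a row contains large positive values; for all-negative rows the max subtraction is simply a no-op.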