Skip to content

Commit

Permalink
Fix for MulticlassNaiveBayesTrainer export to Onnx (dotnet#4928)
Browse files Browse the repository at this point in the history
* Add support for a batch input dimension
  • Loading branch information
Lynx1820 authored Mar 11, 2020
1 parent f6cdf57 commit ed481b6
Showing 1 changed file with 41 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
}

var one = ctx.AddInitializer(1.0f, "one");
var oneInt = ctx.AddInitializer(1, typeof(int), "oneInt");
var zero = ctx.AddInitializer(0.0f, "zero");
var labelCount = ctx.AddInitializer((float)_labelCount, "labelCount");
var trainingCount = ctx.AddInitializer((float)_totalTrainingCount, "totalTrainingCount");
Expand All @@ -414,108 +415,119 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
var labelHistogramName = ctx.AddInitializer(labelHistogramExpanded, new long[] { _featureHistogram[0].Length, _labelHistogram.Length }, "labelHistogramExpanded");
var learnedAbsentFeatureLogProb = ctx.AddInitializer(_absentFeaturesLogProb, new long[] { _absentFeaturesLogProb.Length, 1 }, "absentFeaturesLogProb");

var greaterOutput = ctx.AddIntermediateVariable(null, "greaterOutput", true);
var typeOne = new VectorDataViewType(NumberDataViewType.Single, 1);
var typeFea = new VectorDataViewType(NumberDataViewType.Single, _featureHistogram[0].Length);
var typeLabelByFea = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, _featureHistogram[0].Length);
var typeLabelByOne = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, 1);

var greaterOutput = ctx.AddIntermediateVariable(new VectorDataViewType(BooleanDataViewType.Instance, _featureHistogram[0].Length), "greaterOutput");
var opType = "Greater";
ctx.CreateNode(opType, new[] { featureColumn, zero }, new[] { greaterOutput }, ctx.GetNodeName(opType), "");

opType = "Cast";
var isFeaturePresent = ctx.AddIntermediateVariable(null, "isFeaturePresent", true);
var node = ctx.CreateNode(opType, greaterOutput, isFeaturePresent, ctx.GetNodeName(opType), "");
var castOutput = ctx.AddIntermediateVariable(typeFea, "CastOutput");
var node = ctx.CreateNode(opType, greaterOutput, castOutput, ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
node.AddAttribute("to", t);

opType = "ExpandDims";
var isFeaturePresent = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1, _featureHistogram[0].Length), "isFeaturePresent");
ctx.CreateNode(opType, new[] { castOutput, oneInt }, new[] { isFeaturePresent }, ctx.GetNodeName(opType), "com.microsoft");

//initialize logProb
opType = "Div";
var divOutput = ctx.AddIntermediateVariable(null, "DivOutput", true);
var divOutput = ctx.AddIntermediateVariable(typeOne, "DivOutput");
ctx.CreateNode(opType, new[] { labelHistogram, trainingCount }, new[] { divOutput }, ctx.GetNodeName(opType), "");

opType = "Log";
var logOutput = ctx.AddIntermediateVariable(null, "LogOutput", true);
var logOutput = ctx.AddIntermediateVariable(typeOne, "LogOutput");
ctx.CreateNode(opType, divOutput, logOutput, ctx.GetNodeName(opType), "");

//log1
opType = "Sum";
var sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
var sumOutput = ctx.AddIntermediateVariable(_inputType, "SumOutput");
ctx.CreateNode(opType, new[] { featureHistogramName, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

var logOutput1 = ctx.AddIntermediateVariable(null, "LogOutput", true);
var logOutput1 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");
LogMul(ctx, sumOutput, isFeaturePresent, logOutput1);

//log2
opType = "Transpose";
var labelHistogramTrans = ctx.AddIntermediateVariable(null, "transpose", true);
var labelHistogramTrans = ctx.AddIntermediateVariable(typeFea, "Transpose");
ctx.CreateNode(opType, labelHistogramName, labelHistogramTrans, ctx.GetNodeName(opType), "");

opType = "Sub";
var absentFeatureCount = ctx.AddIntermediateVariable(null, "AbsentFeatureCounts", true);
var absentFeatureCount = ctx.AddIntermediateVariable(typeFea, "AbsentFeatureCounts");
ctx.CreateNode(opType, new[] { labelHistogramTrans, featureHistogramName }, new[] { absentFeatureCount }, ctx.GetNodeName(opType), "");

opType = "Sum";
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
ctx.CreateNode(opType, new[] { labelHistogramTrans, labelCount }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

var logOutput2 = ctx.AddIntermediateVariable(null, "LogOutput", true);
var logOutput2 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");
LogMul(ctx, sumOutput, isFeaturePresent, logOutput2);

//log3
opType = "Sum";
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
ctx.CreateNode(opType, new[] { absentFeatureCount, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

var logOutput3 = ctx.AddIntermediateVariable(null, "LogOutput", true);
var logOutput3 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");
LogMul(ctx, sumOutput, isFeaturePresent, logOutput3);

//result
opType = "Sub";
var logProb = ctx.AddIntermediateVariable(null, "LogProb", true);
var logProb = ctx.AddIntermediateVariable(typeLabelByFea, "LogProb");
ctx.CreateNode(opType, new[] { logOutput1, logOutput2 }, new[] { logProb }, ctx.GetNodeName(opType), "");

opType = "Sub";
var absentFeatureLogProb = ctx.AddIntermediateVariable(null, "AbsentFeatureLogProb", true);
var absentFeatureLogProb = ctx.AddIntermediateVariable(typeLabelByFea, "AbsentFeatureLogProb");
ctx.CreateNode(opType, new[] { logOutput3, logOutput2 }, new[] { absentFeatureLogProb }, ctx.GetNodeName(opType), "");

opType = "ReduceSum";
var logProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true);
var logProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");
node = ctx.CreateNode(opType, new[] { logProb }, new[] { logProbReduceSum }, ctx.GetNodeName(opType), "");
long[] list = { 1 };
long[] list = { 2 };
node.AddAttribute("axes", list);

opType = "ReduceSum";
var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true);
var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");
node = ctx.CreateNode(opType, new[] { absentFeatureLogProb }, new[] { absentFeatureLogProbReduceSum }, ctx.GetNodeName(opType), "");
node.AddAttribute("axes", list);

opType = "Cast";
var castOutput = ctx.AddIntermediateVariable(null, "CastOutput2", true);
castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "CastOutput");
node = ctx.CreateNode(opType, learnedAbsentFeatureLogProb, castOutput, ctx.GetNodeName(opType), "");
t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
node.AddAttribute("to", t);

opType = "Sub";
var subOutput = ctx.AddIntermediateVariable(null, "SubOutput", true);
var subOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SubOutput");
ctx.CreateNode(opType, new[] { castOutput, absentFeatureLogProbReduceSum }, new[] { subOutput }, ctx.GetNodeName(opType), "");

opType = "Sum";
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
sumOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SumOutput");
ctx.CreateNode(opType, new[] { subOutput, logProbReduceSum, logOutput }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

opType = "Transpose";
var transposeOutput = ctx.AddIntermediateVariable(null, "TransposeOutput", true);
ctx.CreateNode(opType, new[] { sumOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), "");
opType = "Squeeze";
var squeezeNode = ctx.CreateNode(opType, sumOutput, outputNames[1], ctx.GetNodeName(opType), "");
squeezeNode.AddAttribute("axes", new long[] { 2 });

opType = "ArgMax";
var scoreIndex = ctx.AddIntermediateVariable(null, "ScoreIndex", true);
ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), "");
var scoreIndex = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, 1), "ScoreIndex");
node = ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), "");
node.AddAttribute("axis", 1);
node.AddAttribute("keepdims", 0);

opType = "Cast";
castOutput = ctx.AddIntermediateVariable(null, "CastOutput3", true);
castOutput = ctx.AddIntermediateVariable(typeOne, "CastOutput");
node = ctx.CreateNode(opType, scoreIndex, castOutput, ctx.GetNodeName(opType), "");
t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
node.AddAttribute("to", t);

//add one to the ArgMax index
opType = "Sum";
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
sumOutput = ctx.AddIntermediateVariable(typeOne, "SumOutput");
ctx.CreateNode(opType, new[] { castOutput, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

opType = "Cast";
Expand All @@ -529,7 +541,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
private void LogMul(OnnxContext ctx, string input, string isFeaturePresent, string output)
{
var opType = "Log";
var logOutput = ctx.AddIntermediateVariable(null, "LogOutput", true);
var logOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, _featureHistogram[0].Length), "LogOutput");
ctx.CreateNode(opType, input, logOutput, ctx.GetNodeName(opType), "");

opType = "Mul";
Expand Down

0 comments on commit ed481b6

Please sign in to comment.