Skip to content

Commit

Permalink
Add two-variable scenario in Tensor shape inference for TensorflowTra…
Browse files Browse the repository at this point in the history
…nsform (dotnet#5257)

* init checkin

* remove

* modify comments

* fix

* add tests

* update

* remove uncessary lines

* fix one of the comment

* update tensorflow test data version

* remove uncessary depencency

* fix comments

* readability
  • Loading branch information
wangyems authored Jul 2, 2020
1 parent 81357ba commit c459af0
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
2 changes: 1 addition & 1 deletion build/Dependencies.props
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
<MicrosoftExtensionsTestPackageVersion>3.0.1</MicrosoftExtensionsTestPackageVersion>
<MicrosoftMLTestDatabasesPackageVersion>0.0.6-test</MicrosoftMLTestDatabasesPackageVersion>
<MicrosoftMLTestModelsPackageVersion>0.0.6-test</MicrosoftMLTestModelsPackageVersion>
<MicrosoftMLTensorFlowTestModelsVersion>0.0.11-test</MicrosoftMLTensorFlowTestModelsVersion>
<MicrosoftMLTensorFlowTestModelsVersion>0.0.12-test</MicrosoftMLTensorFlowTestModelsVersion>
<MicrosoftMLOnnxTestModelsVersion>0.0.6-test</MicrosoftMLOnnxTestModelsVersion>
<SystemDataSqlClientVersion>4.6.1</SystemDataSqlClientVersion>
<XunitCombinatorialVersion>1.2.7</XunitCombinatorialVersion>
Expand Down
14 changes: 12 additions & 2 deletions src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,18 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
if (typeValueCount % valCount != 0)
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is of length {typeValueCount}.");

// This cover the 2-variable senario e.g. [?, ?, ?, C] where we can assume typeDims provides the information of [W, H, C]
// The shape will become [?, W, H, C]
var originalShapeDims = originalShape.dims;
var originalShapeNdim = originalShape.ndim;
if (numOfUnkDim == 3 && colTypeDims.Length == 3 && originalShapeNdim == numOfUnkDim + 1 && originalShapeDims[1] == -1)
{
originalShapeDims[1] = colTypeDims[0];
originalShapeDims[2] = colTypeDims[1];
valCount *= originalShapeDims[1] * originalShapeDims[2];
numOfUnkDim -= 2;
}

// If the shape is multi-dimensional, we should be able to create the length of the vector by plugging
// in a single value for the unknown shapes. For example, if the shape is [?,?,3], then there should exist a value
// d such that d*d*3 is equal to the length of the input column.
Expand All @@ -537,8 +549,6 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is of length {typeValueCount}.");

// Fill in the unknown dimensions.
var originalShapeNdim = originalShape.ndim;
var originalShapeDims = originalShape.dims;
var l = new int[originalShapeNdim];
for (int ishape = 0; ishape < originalShapeNdim; ishape++)
l[ishape] = originalShapeDims[ishape] == -1 ? (int)d : originalShapeDims[ishape];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1900,5 +1900,40 @@ private static string GetTemporaryDirectory()
Directory.CreateDirectory(tempDirectory);
return tempDirectory;
}

[TensorFlowFact]
public void TensorflowPlaceholderShapeInferenceTest()
{
//frozen_model_variadic_input_shape.pb is modified by frozen_model.pb
//the shape of placeholder is changed from [?, w, h, c] to [?, ?, ?, c]
string modelLocation = "cifar_model/frozen_model_variadic_input_shape.pb";

int imageHeight = 32;
int imageWidth = 32;
string dataFile = GetDataPath("images/images.tsv");
string imageFolder = Path.GetDirectoryName(dataFile);

IDataView data = _mlContext.Data.LoadFromTextFile(dataFile, new[] {
new TextLoader.Column("imagePath", DataKind.String, 0),
new TextLoader.Column("name", DataKind.String, 1)
});

Tensorflow.TensorShape[] tfInputShape;

using (var tfModel = _mlContext.Model.LoadTensorFlowModel(modelLocation))
{
var pipeline = _mlContext.Transforms.LoadImages("Input", imageFolder, "imagePath")
.Append(_mlContext.Transforms.ResizeImages("Input", imageHeight, imageWidth))
.Append(_mlContext.Transforms.ExtractPixels("Input", interleavePixelColors: true))
.Append(tfModel.ScoreTensorFlowModel("Output", "Input"));

var transformer = pipeline.Fit(data);

tfInputShape = transformer.LastTransformer.TFInputShapes;
}

Assert.Equal(imageHeight, tfInputShape.ElementAt(0)[1].dims[0]);
Assert.Equal(imageWidth, tfInputShape.ElementAt(0)[2].dims[0]);
}
}
}

0 comments on commit c459af0

Please sign in to comment.