Skip to content

Commit

Permalink
Transformer for Concat (dotnet#896)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zruty0 authored Sep 14, 2018
1 parent a8cd341 commit 4cb7dd9
Show file tree
Hide file tree
Showing 22 changed files with 815 additions and 919 deletions.
25 changes: 16 additions & 9 deletions src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,15 @@ public sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMa
{
private sealed class Bindings : ColumnBindingsBase
{
private readonly RowToRowMapperTransform _parent;
private readonly IRowMapper _mapper;
public readonly RowMapperColumnInfo[] OutputColInfos;

public Bindings(ISchema inputSchema, RowToRowMapperTransform parent)
: base(inputSchema, true, Contracts.CheckRef(parent, nameof(parent))._mapper.GetOutputColumns().Select(info => info.Name).ToArray())
public Bindings(ISchema inputSchema, IRowMapper mapper)
: base(inputSchema, true, Contracts.CheckRef(mapper, nameof(mapper)).GetOutputColumns().Select(info => info.Name).ToArray())
{
Contracts.AssertValue(parent);
_parent = parent;
OutputColInfos = _parent._mapper.GetOutputColumns().ToArray();
Contracts.AssertValue(mapper);
_mapper = mapper;
OutputColInfos = _mapper.GetOutputColumns().ToArray();
}

protected override ColumnType GetColumnTypeCore(int iinfo)
Expand All @@ -168,7 +168,7 @@ public bool[] GetActive(Func<int, bool> predicate, out Func<int, bool> predicate
var predicateOut = GetActiveOutputColumns(active);

// Now map those to active input columns.
var predicateIn = _parent._mapper.GetDependencies(predicateOut);
var predicateIn = _mapper.GetDependencies(predicateOut);

// Combine the two sets of input columns.
predicateInput =
Expand Down Expand Up @@ -255,7 +255,14 @@ public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper
{
Contracts.CheckValue(mapper, nameof(mapper));
_mapper = mapper;
_bindings = new Bindings(input.Schema, this);
_bindings = new Bindings(input.Schema, mapper);
}

public static ISchema GetOutputSchema(ISchema inputSchema, IRowMapper mapper)
{
Contracts.CheckValue(inputSchema, nameof(inputSchema));
Contracts.CheckValue(mapper, nameof(mapper));
return new Bindings(inputSchema, mapper);
}

private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView input)
Expand All @@ -265,7 +272,7 @@ private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView inpu
// _mapper

ctx.LoadModel<IRowMapper, SignatureLoadRowMapper>(host, out _mapper, "Mapper", input.Schema);
_bindings = new Bindings(input.Schema, this);
_bindings = new Bindings(input.Schema, _mapper);
}

public static RowToRowMapperTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env,
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

var xf = new ConcatTransform(env, input, input.Data);
var xf = ConcatTransform.Create(env, input, input.Data);
return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf };
}

Expand Down
97 changes: 1 addition & 96 deletions src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,11 @@

using Microsoft.ML.Core.Data;
using Microsoft.ML.Data.StaticPipe.Runtime;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using System;
using System.Collections.Generic;
using System.Linq;

[assembly: LoadableClass(typeof(ConcatTransformer), null, typeof(SignatureLoadModel),
"Concat Transformer Wrapper", ConcatTransformer.LoaderSignature)]

namespace Microsoft.ML.Runtime.Data
{
public sealed class ConcatEstimator : IEstimator<ITransformer>
Expand All @@ -41,11 +34,7 @@ public ConcatEstimator(IHostEnvironment env, string name, params string[] source
public ITransformer Fit(IDataView input)
{
_host.CheckValue(input, nameof(input));

var xf = new ConcatTransform(_host, input, _name, _source);
var empty = new EmptyDataView(_host, input.Schema);
var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_host, xf, empty, input);
return new ConcatTransformer(_host, chunk);
return new ConcatTransform(_host, _name, _source);
}

private bool HasCategoricals(SchemaShape.Column col)
Expand Down Expand Up @@ -123,90 +112,6 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
}
}

// REVIEW: Note that the presence of this thing is a temporary measure only.
// If it is cleaned up by code complete so much the better, but if not we will
// have to wait a little bit.
internal sealed class ConcatTransformer : ITransformer, ICanSaveModel
{
public const string LoaderSignature = "ConcatTransformWrapper";
private const string TransformDirTemplate = "Step_{0:000}";

private readonly IHostEnvironment _env;
private readonly IDataView _xf;

internal ConcatTransformer(IHostEnvironment env, IDataView xf)
{
_env = env;
_xf = xf;
}

public ISchema GetOutputSchema(ISchema inputSchema)
{
var dv = new EmptyDataView(_env, inputSchema);
var output = ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, dv);
return output.Schema;
}

public void Save(ModelSaveContext ctx)
{
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

var dataPipe = _xf;
var transforms = new List<IDataTransform>();
while (dataPipe is IDataTransform xf)
{
// REVIEW: a malicious user could construct a loop in the Source chain, that would
// cause this method to iterate forever (and throw something when the list overflows). There's
// no way to insulate from ALL malicious behavior.
transforms.Add(xf);
dataPipe = xf.Source;
Contracts.AssertValue(dataPipe);
}
transforms.Reverse();

ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_env, c, dataPipe.Schema));

ctx.Writer.Write(transforms.Count);
for (int i = 0; i < transforms.Count; i++)
{
var dirName = string.Format(TransformDirTemplate, i);
ctx.SaveModel(transforms[i], dirName);
}
}

private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "CCATWRPR",
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature);
}

public ConcatTransformer(IHostEnvironment env, ModelLoadContext ctx)
{
ctx.CheckAtModel(GetVersionInfo());
int n = ctx.Reader.ReadInt32();

ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));

IDataView data = loader;
for (int i = 0; i < n; i++)
{
var dirName = string.Format(TransformDirTemplate, i);
ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data);
data = xf;
}

_env = env;
_xf = data;
}

public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input);
}

/// <summary>
/// The extension methods and implementation support for concatenating columns together.
/// </summary>
Expand Down
Loading

0 comments on commit 4cb7dd9

Please sign in to comment.