Skip to content

Connect Inputs and Outputs Properly and Export CopyColumnTransform to (unofficial) ONNX operator #952

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Oct 10, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ public abstract class OnnxContext
/// <returns>A name that has not yet been returned from this function, starting with <paramref name="prefix"/></returns>
public abstract string GetNodeName(string prefix);

/// <summary>
/// Determine if a string has been used as ONNX variable name somewhere.
/// </summary>
/// <param name="variableName">examined string</param>
/// <returns>True if the input argument has been used to denote an ONNX variable. Otherwise, False.</returns>
public abstract bool IsVariableDefined(string variableName);

/// <summary>
/// Looks up whether a given data view column has a mapping in the ONNX context. Once confirmed, callers can
/// safely call <see cref="GetVariableName(string)"/>.
Expand Down
19 changes: 18 additions & 1 deletion src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Model.Onnx;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(IDataTransform), typeof(CopyColumnsTransform),
Expand Down Expand Up @@ -159,11 +160,13 @@ public override void Save(ModelSaveContext ctx)
protected override IRowMapper MakeRowMapper(ISchema inputSchema)
=> new Mapper(this, inputSchema, ColumnPairs);

private sealed class Mapper : MapperBase
private sealed class Mapper : MapperBase, ISaveAsOnnx
{
private readonly ISchema _schema;
private readonly (string Source, string Name)[] _columns;

public bool CanSaveOnnx(OnnxContext ctx) => ctx.GetOnnxVersion() == OnnxVersion.Experimental;

internal Mapper(CopyColumnsTransform parent, ISchema inputSchema, (string Source, string Name)[] columns)
: base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
{
Expand Down Expand Up @@ -197,6 +200,20 @@ public override RowMapperColumnInfo[] GetOutputColumns()
}
return result;
}

public void SaveAsOnnx(OnnxContext ctx)
{
var opType = "CSharp";

foreach (var column in _columns)
{
var srcVariableName = ctx.GetVariableName(column.Source);
_schema.TryGetColumnIndex(column.Source, out int colIndex);
var dstVariableName = ctx.AddIntermediateVariable(_schema.GetColumnType(colIndex), column.Name);
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
node.AddAttribute("type", LoaderSignature);
}
}
}
}
}
15 changes: 6 additions & 9 deletions src/Microsoft.ML.Onnx/OnnxContextImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ public OnnxContextImpl(IHostEnvironment env, string name, string producerName,

public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName);

public override bool IsVariableDefined(string variableName) => _variableNames.Contains(variableName);

/// <summary>
/// Stops tracking a column. If removeVariable is true then it also removes the
/// variable associated with it, this is useful in the event where an output variable is
Expand Down Expand Up @@ -200,7 +202,7 @@ public string TryGetVariableName(string colName)
/// </summary>
/// <param name="colName">IDataView column name.</param>
/// <returns>Unique variable name.</returns>
private string AddVariable(string colName)
public string AddVariable(string colName)
{
_host.CheckNonEmpty(colName, nameof(colName));
_columnNameMap[colName] = GetUniqueName(colName, _variableNames.Contains);
Expand All @@ -226,16 +228,11 @@ public override string AddIntermediateVariable(ColumnType type, string colName,
/// <summary>
/// Adds an output variable to the list.
/// </summary>
public string AddOutputVariable(ColumnType type, string colName, List<long> dim = null)
public void AddOutputVariable(ColumnType type, string variableName, List<long> dim = null)
{
_host.CheckValue(type, nameof(type));

if (!ContainsColumn(colName))
AddVariable(colName);

colName = GetVariableName(colName);
_outputs.Add(OnnxUtils.GetModelArgs(type, colName, dim));
return colName;
_host.CheckParam(IsVariableDefined(variableName), nameof(variableName));
_outputs.Add(OnnxUtils.GetModelArgs(type, variableName, dim));
}

/// <summary>
Expand Down
12 changes: 8 additions & 4 deletions src/Microsoft.ML.Onnx/SaveOnnxCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,17 @@ private void Run(IChannel ch)
if (end.Schema.IsHidden(i))
continue;

var idataviewColumnName = end.Schema.GetColumnName(i);;
if (_outputsToDrop.Contains(idataviewColumnName) || _inputsToDrop.Contains(idataviewColumnName))
var idataviewColumnName = end.Schema.GetColumnName(i);

// Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in
// _inputToDrop should be removed too.
if (_inputsToDrop.Contains(idataviewColumnName) || _outputsToDrop.Contains(idataviewColumnName))
continue;

var variableName = ctx.TryGetVariableName(idataviewColumnName);
if (variableName != null)
ctx.AddOutputVariable(end.Schema.GetColumnType(i), variableName);
var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName, true);
ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
ctx.AddOutputVariable(end.Schema.GetColumnType(i), trueVariableName);
}

var model = ctx.MakeModel();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,36 @@
}
],
"domain": "ai.onnx.ml"
},
{
"input": [
"PredictedLabel"
],
"output": [
"PredictedLabel0"
],
"name": "Identity",
"opType": "Identity"
},
{
"input": [
"Score"
],
"output": [
"Score0"
],
"name": "Identity0",
"opType": "Identity"
},
{
"input": [
"Probability"
],
"output": [
"Probability0"
],
"name": "Identity1",
"opType": "Identity"
}
],
"name": "BinaryClassificationFastTreeSaveModelToOnnxTest",
Expand Down Expand Up @@ -383,7 +413,7 @@
],
"output": [
{
"name": "PredictedLabel",
"name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "FLOAT",
Expand All @@ -401,7 +431,7 @@
}
},
{
"name": "Score",
"name": "Score0",
"type": {
"tensorType": {
"elemType": "FLOAT",
Expand All @@ -419,7 +449,7 @@
}
},
{
"name": "Probability",
"name": "Probability0",
"type": {
"tensorType": {
"elemType": "FLOAT",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,36 @@
}
],
"domain": "ai.onnx.ml"
},
{
"input": [
"PredictedLabel"
],
"output": [
"PredictedLabel0"
],
"name": "Identity",
"opType": "Identity"
},
{
"input": [
"Score"
],
"output": [
"Score0"
],
"name": "Identity0",
"opType": "Identity"
},
{
"input": [
"Probability"
],
"output": [
"Probability0"
],
"name": "Identity1",
"opType": "Identity"
}
],
"name": "BinaryClassificationLRSaveModelToOnnxTest",
Expand All @@ -167,7 +197,7 @@
],
"output": [
{
"name": "PredictedLabel",
"name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "FLOAT",
Expand All @@ -185,7 +215,7 @@
}
},
{
"name": "Score",
"name": "Score0",
"type": {
"tensorType": {
"elemType": "FLOAT",
Expand All @@ -203,7 +233,7 @@
}
},
{
"name": "Probability",
"name": "Probability0",
"type": {
"tensorType": {
"elemType": "FLOAT",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,36 @@
}
],
"domain": "ai.onnx.ml"
},
{
"input": [
"PredictedLabel"
],
"output": [
"PredictedLabel0"
],
"name": "Identity",
"opType": "Identity"
},
{
"input": [
"Score"
],
"output": [
"Score0"
],
"name": "Identity0",
"opType": "Identity"
},
{
"input": [
"Probability"
],
"output": [
"Probability0"
],
"name": "Identity1",
"opType": "Identity"
}
],
"name": "BinaryClassificationLightGBMSaveModelToOnnxTest",
Expand All @@ -218,7 +248,7 @@
],
"output": [
{
"name": "PredictedLabel",
"name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "FLOAT",
Expand All @@ -236,7 +266,7 @@
}
},
{
"name": "Score",
"name": "Score0",
"type": {
"tensorType": {
"elemType": "FLOAT",
Expand All @@ -254,7 +284,7 @@
}
},
{
"name": "Probability",
"name": "Probability0",
"type": {
"tensorType": {
"elemType": "FLOAT",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,36 @@
}
],
"domain": "ai.onnx.ml"
},
{
"input": [
"PredictedLabel"
],
"output": [
"PredictedLabel0"
],
"name": "Identity",
"opType": "Identity"
},
{
"input": [
"Score"
],
"output": [
"Score0"
],
"name": "Identity0",
"opType": "Identity"
},
{
"input": [
"Probability"
],
"output": [
"Probability0"
],
"name": "Identity1",
"opType": "Identity"
}
],
"name": "KeyToVectorBag",
Expand Down Expand Up @@ -354,7 +384,7 @@
],
"output": [
{
"name": "PredictedLabel",
"name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "FLOAT",
Expand All @@ -372,7 +402,7 @@
}
},
{
"name": "Score",
"name": "Score0",
"type": {
"tensorType": {
"elemType": "FLOAT",
Expand All @@ -390,7 +420,7 @@
}
},
{
"name": "Probability",
"name": "Probability0",
"type": {
"tensorType": {
"elemType": "FLOAT",
Expand Down
Loading