Skip to content

Commit b9f8a0a

Browse files
implement auto featurizer (#6205)
1 parent 7ae1c5d commit b9f8a0a

File tree

6 files changed

+328
-19
lines changed

6 files changed

+328
-19
lines changed

src/Microsoft.ML.AutoML/API/AutoCatalog.cs

Lines changed: 130 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.Diagnostics.Contracts;
8+
using System.Linq;
79
using Microsoft.ML.AutoML.CodeGen;
810
using Microsoft.ML.Data;
11+
using Microsoft.ML.Runtime;
912
using Microsoft.ML.SearchSpace;
1013
using Microsoft.ML.Trainers.FastTree;
1114

@@ -538,55 +541,164 @@ public SweepableEstimator[] Regression(string labelColumnName = DefaultColumnNam
538541
/// <param name="inputColumnName">input column name.</param>
539542
internal SweepableEstimator[] TextFeaturizer(string outputColumnName, string inputColumnName)
{
    // Build the single text featurizer that reads inputColumnName and writes outputColumnName.
    var featurizeText = SweepableEstimatorFactory.CreateFeaturizeText(new FeaturizeTextOption
    {
        InputColumnName = inputColumnName,
        OutputColumnName = outputColumnName,
    });

    // The result is a one-element array wrapping that estimator.
    return new[] { featurizeText };
}
543552

544553
/// <summary>
/// Create a list of <see cref="SweepableEstimator"/> for featurizing numeric columns.
/// </summary>
/// <param name="outputColumnNames">output column names. Must be non-null, non-empty and have the same length as <paramref name="inputColumnNames"/>.</param>
/// <param name="inputColumnNames">input column names. Must be non-null, non-empty and have the same length as <paramref name="outputColumnNames"/>.</param>
/// <returns>An array holding a single replace-missing-values estimator configured with the given columns.</returns>
internal SweepableEstimator[] NumericFeaturizer(string[] outputColumnNames, string[] inputColumnNames)
{
    Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames));
    Contracts.CheckValue(outputColumnNames, nameof(outputColumnNames));

    // Both arguments are arrays, so read Length directly instead of going through LINQ's Count().
    Contracts.Check(outputColumnNames.Length == inputColumnNames.Length && outputColumnNames.Length > 0, "outputColumnNames and inputColumnNames must have the same length and greater than 0");

    var replaceMissingValueOption = new ReplaceMissingValueOption
    {
        InputColumnNames = inputColumnNames,
        OutputColumnNames = outputColumnNames,
    };

    return new[] { SweepableEstimatorFactory.CreateReplaceMissingValues(replaceMissingValueOption) };
}
553571

554572
/// <summary>
/// Create a list of <see cref="SweepableEstimator"/> for featurizing catalog columns.
/// </summary>
/// <param name="outputColumnNames">output column names. Must be non-null, non-empty and have the same length as <paramref name="inputColumnNames"/>.</param>
/// <param name="inputColumnNames">input column names. Must be non-null, non-empty and have the same length as <paramref name="outputColumnNames"/>.</param>
/// <returns>An array of two estimators, one-hot encoding and one-hot hash encoding, both built from the same option.</returns>
internal SweepableEstimator[] CatalogFeaturizer(string[] outputColumnNames, string[] inputColumnNames)
{
    // Validate nulls up front, the same way NumericFeaturizer does, so that a null argument is
    // reported against the right parameter name instead of failing inside the length comparison.
    Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames));
    Contracts.CheckValue(outputColumnNames, nameof(outputColumnNames));
    Contracts.Check(outputColumnNames.Length == inputColumnNames.Length && outputColumnNames.Length > 0, "outputColumnNames and inputColumnNames must have the same length and greater than 0");

    var option = new OneHotOption
    {
        InputColumnNames = inputColumnNames,
        OutputColumnNames = outputColumnNames,
    };

    return new SweepableEstimator[] { SweepableEstimatorFactory.CreateOneHotEncoding(option), SweepableEstimatorFactory.CreateOneHotHashEncoding(option) };
}
563589

564590
/// <summary>
/// Create a single featurize pipeline according to <paramref name="data"/>. This function will collect all columns in <paramref name="data"/> and not in <paramref name="excludeColumns"/>,
/// featurizing them using <see cref="CatalogFeaturizer(string[], string[])"/>, <see cref="NumericFeaturizer(string[], string[])"/> or <see cref="TextFeaturizer(string, string)"/>. And combine
/// them into a single feature column as output.
/// </summary>
/// <param name="data">input data.</param>
/// <param name="outputColumnName">output feature column.</param>
/// <param name="catalogColumns">columns that should be treated as catalog. If not specified, it will automatically infer if a column is catalog or not.</param>
/// <param name="numericColumns">columns that should be treated as numeric. If not specified, it will automatically infer if a column is numeric or not.</param>
/// <param name="textColumns">columns that should be treated as text. If not specified, it will automatically infer if a column is text or not.</param>
/// <param name="excludeColumns">columns that won't be included when featurizing, like label</param>
/// <returns>A <see cref="MultiModelPipeline"/> for featurization.</returns>
public MultiModelPipeline Featurizer(IDataView data, string outputColumnName = "Features", string[] catalogColumns = null, string[] numericColumns = null, string[] textColumns = null, string[] excludeColumns = null)
{
    Contracts.CheckValue(data, nameof(data));

    // Validate that no column name appears more than once across catalogColumns, numericColumns,
    // textColumns and excludeColumns. The names are materialized once because they are counted twice.
    var overallColumns = new string[][] { catalogColumns, numericColumns, textColumns, excludeColumns }
        .Where(c => c != null)
        .SelectMany(c => c)
        .ToArray();

    // Check rather than Assert, so that the validation is also enforced outside of debug builds.
    Contracts.Check(overallColumns.Length == overallColumns.Distinct().Count(), "detect overlapping among catalogColumns, numericColumns, textColumns and excludedColumns");

    // Translate the optional name lists into a ColumnInformation and defer to the overload that takes one.
    var columnInfo = new ColumnInformation();

    if (excludeColumns != null)
    {
        foreach (var ignoreColumn in excludeColumns)
        {
            columnInfo.IgnoredColumnNames.Add(ignoreColumn);
        }
    }

    if (catalogColumns != null)
    {
        foreach (var catalogColumn in catalogColumns)
        {
            columnInfo.CategoricalColumnNames.Add(catalogColumn);
        }
    }

    if (numericColumns != null)
    {
        foreach (var column in numericColumns)
        {
            columnInfo.NumericColumnNames.Add(column);
        }
    }

    if (textColumns != null)
    {
        foreach (var column in textColumns)
        {
            columnInfo.TextColumnNames.Add(column);
        }
    }

    return this.Featurizer(data, columnInfo, outputColumnName);
}
577651

578652
/// <summary>
/// Build one featurization pipeline for <paramref name="data"/>, guided by <paramref name="columnInformation"/>.
/// Column purposes are inferred first; numeric columns go through <see cref="NumericFeaturizer(string[], string[])"/>,
/// catalog columns through <see cref="CatalogFeaturizer(string[], string[])"/> and each text column through
/// <see cref="TextFeaturizer(string, string)"/>. The featurized columns are then concatenated into
/// <paramref name="outputColumnName"/>.
/// </summary>
/// <param name="data">input data.</param>
/// <param name="columnInformation">column information.</param>
/// <param name="outputColumnName">output feature column.</param>
/// <returns>A <see cref="MultiModelPipeline"/> for featurization.</returns>
public MultiModelPipeline Featurizer(IDataView data, ColumnInformation columnInformation, string outputColumnName = "Features")
{
    Contracts.CheckValue(data, nameof(data));
    Contracts.CheckValue(columnInformation, nameof(columnInformation));

    var purposes = PurposeInference.InferPurposes(this._context, data, columnInformation);

    // Names of the schema columns whose inferred purpose matches the requested one.
    string[] NamesWithPurpose(ColumnPurpose purpose)
    {
        return purposes
            .Where(p => p.Purpose == purpose)
            .Select(p => data.Schema[p.ColumnIndex].Name)
            .ToArray();
    }

    var textNames = NamesWithPurpose(ColumnPurpose.TextFeature);
    var numericNames = NamesWithPurpose(ColumnPurpose.NumericFeature);
    var catalogNames = NamesWithPurpose(ColumnPurpose.CategoricalFeature);

    var result = new MultiModelPipeline();

    // Every featurizer writes back to the column it reads from (output names equal input names).
    if (numericNames.Length != 0)
    {
        result = result.Append(this.NumericFeaturizer(numericNames, numericNames));
    }

    if (catalogNames.Length != 0)
    {
        result = result.Append(this.CatalogFeaturizer(catalogNames, catalogNames));
    }

    foreach (var name in textNames)
    {
        result = result.Append(this.TextFeaturizer(name, name));
    }

    // Concatenation order is text, then numeric, then catalog columns.
    var concatInputs = textNames.Concat(numericNames).Concat(catalogNames).ToArray();
    var concatOption = new ConcatOption
    {
        InputColumnNames = concatInputs,
        OutputColumnName = outputColumnName,
    };

    // With nothing to concatenate, the pipeline is returned without a concatenate step.
    if (concatOption.InputColumnNames.Length == 0)
    {
        return result;
    }

    return result.Append(SweepableEstimatorFactory.CreateConcatenate(concatOption));
}
591703
}
592704
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
{
2+
"schema": "e0 * e1",
3+
"estimators": {
4+
"e0": {
5+
"estimatorType": "ReplaceMissingValues",
6+
"parameter": {
7+
"OutputColumnNames": [
8+
"col1",
9+
"col2",
10+
"col3",
11+
"col4"
12+
],
13+
"InputColumnNames": [
14+
"col1",
15+
"col2",
16+
"col3",
17+
"col4"
18+
]
19+
}
20+
},
21+
"e1": {
22+
"estimatorType": "Concatenate",
23+
"parameter": {
24+
"InputColumnNames": [
25+
"col1",
26+
"col2",
27+
"col3",
28+
"col4"
29+
],
30+
"OutputColumnName": "Features"
31+
}
32+
}
33+
}
34+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
{
2+
"schema": "e0 * (e1 \u002B e2) * e3",
3+
"estimators": {
4+
"e0": {
5+
"estimatorType": "ReplaceMissingValues",
6+
"parameter": {
7+
"OutputColumnNames": [
8+
"Features"
9+
],
10+
"InputColumnNames": [
11+
"Features"
12+
]
13+
}
14+
},
15+
"e1": {
16+
"estimatorType": "OneHotEncoding",
17+
"parameter": {
18+
"OutputColumnNames": [
19+
"Workclass",
20+
"education",
21+
"marital-status",
22+
"occupation",
23+
"relationship",
24+
"ethnicity",
25+
"sex",
26+
"native-country-region"
27+
],
28+
"InputColumnNames": [
29+
"Workclass",
30+
"education",
31+
"marital-status",
32+
"occupation",
33+
"relationship",
34+
"ethnicity",
35+
"sex",
36+
"native-country-region"
37+
]
38+
}
39+
},
40+
"e2": {
41+
"estimatorType": "OneHotHashEncoding",
42+
"parameter": {
43+
"OutputColumnNames": [
44+
"Workclass",
45+
"education",
46+
"marital-status",
47+
"occupation",
48+
"relationship",
49+
"ethnicity",
50+
"sex",
51+
"native-country-region"
52+
],
53+
"InputColumnNames": [
54+
"Workclass",
55+
"education",
56+
"marital-status",
57+
"occupation",
58+
"relationship",
59+
"ethnicity",
60+
"sex",
61+
"native-country-region"
62+
]
63+
}
64+
},
65+
"e3": {
66+
"estimatorType": "Concatenate",
67+
"parameter": {
68+
"InputColumnNames": [
69+
"Features",
70+
"Workclass",
71+
"education",
72+
"marital-status",
73+
"occupation",
74+
"relationship",
75+
"ethnicity",
76+
"sex",
77+
"native-country-region"
78+
],
79+
"OutputColumnName": "OutputFeature"
80+
}
81+
}
82+
}
83+
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Text;
8+
using System.Text.Json;
9+
using Microsoft.ML.TestFramework;
10+
using Xunit;
11+
using Xunit.Abstractions;
12+
using ApprovalTests;
13+
using ApprovalTests.Namers;
14+
using ApprovalTests.Reporters;
15+
using System.Text.Json.Serialization;
16+
17+
namespace Microsoft.ML.AutoML.Test
18+
{
19+
public class AutoFeaturizerTests : BaseTestClass
{
    // Serializer settings shared by every approval test in this class.
    private readonly JsonSerializerOptions _jsonSerializerOptions;

    public AutoFeaturizerTests(ITestOutputHelper output)
        : base(output)
    {
        var serializerOptions = new JsonSerializerOptions() { WriteIndented = true };
        serializerOptions.Converters.Add(new JsonStringEnumConverter());
        serializerOptions.Converters.Add(new DoubleToDecimalConverter());
        serializerOptions.Converters.Add(new FloatToDecimalConverter());
        _jsonSerializerOptions = serializerOptions;

        // When HELIX_CORRELATION_ID is set, approved files are looked up next to the test assembly.
        var hasHelixCorrelationId = Environment.GetEnvironmentVariable("HELIX_CORRELATION_ID") != null;
        if (hasHelixCorrelationId)
        {
            Approvals.UseAssemblyLocationForApprovedFiles();
        }
    }

    // Approval test: featurizer pipeline for the UCI adult data, excluding the label and using a custom output column.
    [Fact]
    [UseReporter(typeof(DiffReporter))]
    [UseApprovalSubdirectory("ApprovalTests")]
    public void AutoFeaturizer_uci_adult_test()
    {
        var mlContext = new MLContext(1);
        var adultData = DatasetUtil.GetUciAdultDataView();

        var featurizer = mlContext.Auto().Featurizer(adultData, outputColumnName: "OutputFeature", excludeColumns: new[] { "Label" });
        var json = JsonSerializer.Serialize(featurizer, _jsonSerializerOptions);

        Approvals.Verify(json);
    }

    // Approval test: featurizer pipeline for the iris data, excluding the label and using the default output column.
    [Fact]
    [UseReporter(typeof(DiffReporter))]
    [UseApprovalSubdirectory("ApprovalTests")]
    public void AutoFeaturizer_iris_test()
    {
        var mlContext = new MLContext(1);
        var irisData = DatasetUtil.GetIrisDataView();

        var featurizer = mlContext.Auto().Featurizer(irisData, excludeColumns: new[] { "Label" });
        var json = JsonSerializer.Serialize(featurizer, _jsonSerializerOptions);

        Approvals.Verify(json);
    }
}
65+
}

0 commit comments

Comments
 (0)