|
4 | 4 |
|
5 | 5 | using System;
|
6 | 6 | using System.Collections.Generic;
|
| 7 | +using System.Diagnostics.Contracts; |
| 8 | +using System.Linq; |
7 | 9 | using Microsoft.ML.AutoML.CodeGen;
|
8 | 10 | using Microsoft.ML.Data;
|
| 11 | +using Microsoft.ML.Runtime; |
9 | 12 | using Microsoft.ML.SearchSpace;
|
10 | 13 | using Microsoft.ML.Trainers.FastTree;
|
11 | 14 |
|
@@ -538,55 +541,164 @@ public SweepableEstimator[] Regression(string labelColumnName = DefaultColumnNam
|
538 | 541 | /// <param name="inputColumnName">input column name.</param>
|
539 | 542 | internal SweepableEstimator[] TextFeaturizer(string outputColumnName, string inputColumnName)
|
540 | 543 | {
|
541 |
| - throw new NotImplementedException(); |
| 544 | + var option = new FeaturizeTextOption |
| 545 | + { |
| 546 | + InputColumnName = inputColumnName, |
| 547 | + OutputColumnName = outputColumnName, |
| 548 | + }; |
| 549 | + |
| 550 | + return new[] { SweepableEstimatorFactory.CreateFeaturizeText(option) }; |
542 | 551 | }
|
543 | 552 |
|
544 | 553 | /// <summary>
|
545 | 554 | /// Create a list of <see cref="SweepableEstimator"/> for featurizing numeric columns.
|
546 | 555 | /// </summary>
|
547 |
| - /// <param name="outputColumnName">output column name.</param> |
548 |
| - /// <param name="inputColumnName">input column name.</param> |
549 |
| - internal SweepableEstimator[] NumericFeaturizer(string outputColumnName, string inputColumnName) |
| 556 | + /// <param name="outputColumnNames">output column names.</param> |
| 557 | + /// <param name="inputColumnNames">input column names.</param> |
| 558 | + internal SweepableEstimator[] NumericFeaturizer(string[] outputColumnNames, string[] inputColumnNames) |
550 | 559 | {
|
551 |
| - throw new NotImplementedException(); |
| 560 | + Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames)); |
| 561 | + Contracts.CheckValue(outputColumnNames, nameof(outputColumnNames)); |
| 562 | + Contracts.Check(outputColumnNames.Count() == inputColumnNames.Count() && outputColumnNames.Count() > 0, "outputColumnNames and inputColumnNames must have the same length and greater than 0"); |
| 563 | + var replaceMissingValueOption = new ReplaceMissingValueOption |
| 564 | + { |
| 565 | + InputColumnNames = inputColumnNames, |
| 566 | + OutputColumnNames = outputColumnNames, |
| 567 | + }; |
| 568 | + |
| 569 | + return new[] { SweepableEstimatorFactory.CreateReplaceMissingValues(replaceMissingValueOption) }; |
552 | 570 | }
|
553 | 571 |
|
554 | 572 | /// <summary>
|
555 | 573 | /// Create a list of <see cref="SweepableEstimator"/> for featurizing catalog columns.
|
556 | 574 | /// </summary>
|
557 |
| - /// <param name="outputColumnName">output column name.</param> |
558 |
| - /// <param name="inputColumnName">input column name.</param> |
559 |
| - internal SweepableEstimator[] CatalogFeaturizer(string outputColumnName, string inputColumnName) |
| 575 | + /// <param name="outputColumnNames">output column names.</param> |
| 576 | + /// <param name="inputColumnNames">input column names.</param> |
| 577 | + internal SweepableEstimator[] CatalogFeaturizer(string[] outputColumnNames, string[] inputColumnNames) |
560 | 578 | {
|
561 |
| - throw new NotImplementedException(); |
| 579 | + Contracts.Check(outputColumnNames.Count() == inputColumnNames.Count() && outputColumnNames.Count() > 0, "outputColumnNames and inputColumnNames must have the same length and greater than 0"); |
| 580 | + |
| 581 | + var option = new OneHotOption |
| 582 | + { |
| 583 | + InputColumnNames = inputColumnNames, |
| 584 | + OutputColumnNames = outputColumnNames, |
| 585 | + }; |
| 586 | + |
| 587 | + return new SweepableEstimator[] { SweepableEstimatorFactory.CreateOneHotEncoding(option), SweepableEstimatorFactory.CreateOneHotHashEncoding(option) }; |
562 | 588 | }
|
563 | 589 |
|
564 | 590 | /// <summary>
|
565 | 591 | /// Create a single featurize pipeline according to <paramref name="data"/>. This function will collect all columns in <paramref name="data"/> and not in <paramref name="excludeColumns"/>,
|
566 |
| - /// featurizing them using <see cref="CatalogFeaturizer(string, string)"/>, <see cref="NumericFeaturizer(string, string)"/> or <see cref="TextFeaturizer(string, string)"/>. And combine |
| 592 | + /// featurizing them using <see cref="CatalogFeaturizer(string[], string[])"/>, <see cref="NumericFeaturizer(string[], string[])"/> or <see cref="TextFeaturizer(string, string)"/>, and combine |
567 | 593 | /// them into a single feature column as output.
|
568 | 594 | /// </summary>
|
569 | 595 | /// <param name="data">input data.</param>
|
570 | 596 | /// <param name="catalogColumns">columns that should be treated as catalog. If not specified, it will automatically infer if a column is catalog or not.</param>
|
| 597 | + /// <param name="numericColumns">columns that should be treated as numeric. If not specified, it will automatically infer if a column is numeric or not.</param> |
| 598 | + /// <param name="textColumns">columns that should be treated as text. If not specified, it will automatically infer if a column is text or not.</param> |
571 | 599 | /// <param name="outputColumnName">output feature column.</param>
|
572 | 600 | /// <param name="excludeColumns">columns that won't be included when featurizing, like label</param>
|
573 |
| - internal MultiModelPipeline Featurizer(IDataView data, string outputColumnName = "Features", string[] catalogColumns = null, string[] excludeColumns = null) |
| 601 | + public MultiModelPipeline Featurizer(IDataView data, string outputColumnName = "Features", string[] catalogColumns = null, string[] numericColumns = null, string[] textColumns = null, string[] excludeColumns = null) |
574 | 602 | {
|
575 |
| - throw new NotImplementedException(); |
| 603 | + Contracts.CheckValue(data, nameof(data)); |
| 604 | + |
| 605 | + // validate if there's overlapping among catalogColumns, numericColumns, textColumns and excludeColumns |
| 606 | + var overallColumns = new string[][] { catalogColumns, numericColumns, textColumns, excludeColumns } |
| 607 | + .Where(c => c != null) |
| 608 | + .SelectMany(c => c); |
| 609 | + |
| 610 | + if (overallColumns != null) |
| 611 | + { |
| 612 | + Contracts.Assert(overallColumns.Count() == overallColumns.Distinct().Count(), "detect overlapping among catalogColumns, numericColumns, textColumns and excludedColumns"); |
| 613 | + } |
| 614 | + |
| 615 | + var columnInfo = new ColumnInformation(); |
| 616 | + |
| 617 | + if (excludeColumns != null) |
| 618 | + { |
| 619 | + foreach (var ignoreColumn in excludeColumns) |
| 620 | + { |
| 621 | + columnInfo.IgnoredColumnNames.Add(ignoreColumn); |
| 622 | + } |
| 623 | + } |
| 624 | + |
| 625 | + if (catalogColumns != null) |
| 626 | + { |
| 627 | + foreach (var catalogColumn in catalogColumns) |
| 628 | + { |
| 629 | + columnInfo.CategoricalColumnNames.Add(catalogColumn); |
| 630 | + } |
| 631 | + } |
| 632 | + |
| 633 | + if (numericColumns != null) |
| 634 | + { |
| 635 | + foreach (var column in numericColumns) |
| 636 | + { |
| 637 | + columnInfo.NumericColumnNames.Add(column); |
| 638 | + } |
| 639 | + } |
| 640 | + |
| 641 | + if (textColumns != null) |
| 642 | + { |
| 643 | + foreach (var column in textColumns) |
| 644 | + { |
| 645 | + columnInfo.TextColumnNames.Add(column); |
| 646 | + } |
| 647 | + } |
| 648 | + |
| 649 | + return this.Featurizer(data, columnInfo, outputColumnName); |
576 | 650 | }
|
577 | 651 |
|
578 | 652 | /// <summary>
|
579 |
| - /// Create a single featurize pipeline according to <paramref name="columnInformation"/>. This function will collect all columns in <paramref name="columnInformation"/> and not in <paramref name="excludeColumns"/>, |
580 |
| - /// featurizing them using <see cref="CatalogFeaturizer(string, string)"/>, <see cref="NumericFeaturizer(string, string)"/> or <see cref="TextFeaturizer(string, string)"/>. And combine |
| 653 | + /// Create a single featurize pipeline according to <paramref name="columnInformation"/>. This function will collect all columns in <paramref name="columnInformation"/>, |
| 654 | + /// featurizing them using <see cref="CatalogFeaturizer(string[], string[])"/>, <see cref="NumericFeaturizer(string[], string[])"/> or <see cref="TextFeaturizer(string, string)"/>, and combine |
581 | 655 | /// them into a single feature column as output.
|
582 | 656 | /// </summary>
|
| 657 | + /// <param name="data">input data.</param> |
583 | 658 | /// <param name="columnInformation">column information.</param>
|
584 | 659 | /// <param name="outputColumnName">output feature column.</param>
|
585 |
| - /// <param name="excludeColumns">columns that won't be included when featurizing, like label</param> |
586 |
| - /// <returns></returns> |
587 |
| - internal MultiModelPipeline Featurizer(ColumnInformation columnInformation, string outputColumnName = "Features", string[] excludeColumns = null) |
| 660 | + /// <returns>A <see cref="MultiModelPipeline"/> for featurization.</returns> |
| 661 | + public MultiModelPipeline Featurizer(IDataView data, ColumnInformation columnInformation, string outputColumnName = "Features") |
588 | 662 | {
|
589 |
| - throw new NotImplementedException(); |
| 663 | + Contracts.CheckValue(data, nameof(data)); |
| 664 | + Contracts.CheckValue(columnInformation, nameof(columnInformation)); |
| 665 | + |
| 666 | + var columnPurposes = PurposeInference.InferPurposes(this._context, data, columnInformation); |
| 667 | + var textFeatures = columnPurposes.Where(c => c.Purpose == ColumnPurpose.TextFeature); |
| 668 | + var numericFeatures = columnPurposes.Where(c => c.Purpose == ColumnPurpose.NumericFeature); |
| 669 | + var catalogFeatures = columnPurposes.Where(c => c.Purpose == ColumnPurpose.CategoricalFeature); |
| 670 | + var textFeatureColumnNames = textFeatures.Select(c => data.Schema[c.ColumnIndex].Name).ToArray(); |
| 671 | + var numericFeatureColumnNames = numericFeatures.Select(c => data.Schema[c.ColumnIndex].Name).ToArray(); |
| 672 | + var catalogFeatureColumnNames = catalogFeatures.Select(c => data.Schema[c.ColumnIndex].Name).ToArray(); |
| 673 | + |
| 674 | + var pipeline = new MultiModelPipeline(); |
| 675 | + if (numericFeatureColumnNames.Length > 0) |
| 676 | + { |
| 677 | + pipeline = pipeline.Append(this.NumericFeaturizer(numericFeatureColumnNames, numericFeatureColumnNames)); |
| 678 | + } |
| 679 | + |
| 680 | + if (catalogFeatureColumnNames.Length > 0) |
| 681 | + { |
| 682 | + pipeline = pipeline.Append(this.CatalogFeaturizer(catalogFeatureColumnNames, catalogFeatureColumnNames)); |
| 683 | + } |
| 684 | + |
| 685 | + foreach (var textColumn in textFeatureColumnNames) |
| 686 | + { |
| 687 | + pipeline = pipeline.Append(this.TextFeaturizer(textColumn, textColumn)); |
| 688 | + } |
| 689 | + |
| 690 | + var option = new ConcatOption |
| 691 | + { |
| 692 | + InputColumnNames = textFeatureColumnNames.Concat(numericFeatureColumnNames).Concat(catalogFeatureColumnNames).ToArray(), |
| 693 | + OutputColumnName = outputColumnName, |
| 694 | + }; |
| 695 | + |
| 696 | + if (option.InputColumnNames.Length > 0) |
| 697 | + { |
| 698 | + pipeline = pipeline.Append(SweepableEstimatorFactory.CreateConcatenate(option)); |
| 699 | + } |
| 700 | + |
| 701 | + return pipeline; |
590 | 702 | }
|
591 | 703 | }
|
592 | 704 | }
|
0 commit comments