Skip to content

Commit 22342df

Browse files
Continue training on OOM error && add subsampling support for trainValidationDatasetManager (dotnet#6714)
* Update AutoMLExperiment.cs * implement subsampling for train-validation dataset manager * fix test * fix comments * fix comment * revert tests
1 parent b28710a commit 22342df

File tree

12 files changed

+193
-54
lines changed

12 files changed

+193
-54
lines changed

src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using System.Text.Json.Serialization;
1111
using Microsoft.Extensions.DependencyInjection;
1212
using Microsoft.ML.Runtime;
13+
using Microsoft.ML.SearchSpace.Option;
1314
using Newtonsoft.Json;
1415
using static Microsoft.ML.DataOperationsCatalog;
1516

@@ -24,14 +25,18 @@ public static class AutoMLExperimentExtension
2425
/// <param name="experiment"><see cref="AutoMLExperiment"/></param>
2526
/// <param name="train">dataset for training a model.</param>
2627
/// <param name="validation">dataset for validating a model during training.</param>
28+
/// <param name="subSamplingTrainDataset">determine if subsampling <paramref name="train"/> to train. This will be useful if <paramref name="train"/> is too large to be held in memory.</param>
2729
/// <returns><see cref="AutoMLExperiment"/></returns>
28-
public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView train, IDataView validation)
30+
public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView train, IDataView validation, bool subSamplingTrainDataset = false)
2931
{
30-
var datasetManager = new TrainValidateDatasetManager()
32+
var datasetManager = new TrainValidateDatasetManager(train, validation);
33+
34+
if (subSamplingTrainDataset)
3135
{
32-
TrainDataset = train,
33-
ValidateDataset = validation
34-
};
36+
var searchSpace = new SearchSpace.SearchSpace();
37+
searchSpace.Add(datasetManager.SubSamplingKey, new UniformSingleOption(0, 1, false, 0.1f));
38+
experiment.AddSearchSpace(nameof(TrainValidateDatasetManager), searchSpace);
39+
}
3540

3641
experiment.ServiceCollection.AddSingleton<IDatasetManager>(datasetManager);
3742
experiment.ServiceCollection.AddSingleton(datasetManager);
@@ -62,13 +67,7 @@ public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, Trai
6267
/// <returns><see cref="AutoMLExperiment"/></returns>
6368
public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView dataset, int fold = 10, string samplingKeyColumnName = null)
6469
{
65-
var datasetManager = new CrossValidateDatasetManager()
66-
{
67-
Dataset = dataset,
68-
Fold = fold,
69-
SamplingKeyColumnName = samplingKeyColumnName,
70-
};
71-
70+
var datasetManager = new CrossValidateDatasetManager(dataset, fold, samplingKeyColumnName);
7271
experiment.ServiceCollection.AddSingleton<IDatasetManager>(datasetManager);
7372
experiment.ServiceCollection.AddSingleton(datasetManager);
7473

src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ public TrialResult Run(TrialSettings settings)
391391
{
392392
var stopWatch = new Stopwatch();
393393
stopWatch.Start();
394-
var fold = datasetManager.Fold ?? 5;
394+
var fold = datasetManager.Fold;
395395
var metrics = _context.BinaryClassification.CrossValidateNonCalibrated(datasetManager.Dataset, pipeline, fold, metricManager.LabelColumn);
396396

397397
// now we just randomly pick a model, but a better way is to provide option to pick a model which score is the cloest to average or the best.
@@ -420,8 +420,8 @@ public TrialResult Run(TrialSettings settings)
420420
{
421421
var stopWatch = new Stopwatch();
422422
stopWatch.Start();
423-
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
424-
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
423+
var model = pipeline.Fit(trainTestDatasetManager.LoadTrainDataset(_context, settings));
424+
var eval = model.Transform(trainTestDatasetManager.LoadValidateDataset(_context, settings));
425425
var metrics = _context.BinaryClassification.EvaluateNonCalibrated(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn);
426426
var metric = GetMetric(metricManager.Metric, metrics);
427427
var loss = metricManager.IsMaximize ? -metric : metric;

src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ public TrialResult Run(TrialSettings settings)
369369
{
370370
var stopWatch = new Stopwatch();
371371
stopWatch.Start();
372-
var fold = datasetManager.Fold ?? 5;
372+
var fold = datasetManager.Fold;
373373
var metrics = _context.MulticlassClassification.CrossValidate(datasetManager.Dataset, pipeline, fold, metricManager.LabelColumn);
374374

375375
// now we just randomly pick a model, but a better way is to provide option to pick a model which score is the cloest to average or the best.
@@ -398,8 +398,8 @@ public TrialResult Run(TrialSettings settings)
398398
{
399399
var stopWatch = new Stopwatch();
400400
stopWatch.Start();
401-
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
402-
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
401+
var model = pipeline.Fit(trainTestDatasetManager.LoadTrainDataset(_context, settings));
402+
var eval = model.Transform(trainTestDatasetManager.LoadValidateDataset(_context, settings));
403403
var metrics = _context.MulticlassClassification.Evaluate(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn);
404404
var metric = GetMetric(metricManager.Metric, metrics);
405405
var loss = metricManager.IsMaximize ? -metric : metric;

src/Microsoft.ML.AutoML/API/RegressionExperiment.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
396396
{
397397
var stopWatch = new Stopwatch();
398398
stopWatch.Start();
399-
var fold = datasetManager.Fold ?? 5;
399+
var fold = datasetManager.Fold;
400400
var metrics = _context.Regression.CrossValidate(datasetManager.Dataset, pipeline, fold, metricManager.LabelColumn);
401401

402402
// now we just randomly pick a model, but a better way is to provide option to pick a model which score is the cloest to average or the best.
@@ -425,8 +425,8 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
425425
{
426426
var stopWatch = new Stopwatch();
427427
stopWatch.Start();
428-
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
429-
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
428+
var model = pipeline.Fit(trainTestDatasetManager.LoadTrainDataset(_context, settings));
429+
var eval = model.Transform(trainTestDatasetManager.LoadValidateDataset(_context, settings));
430430
var metrics = _context.Regression.Evaluate(eval, metricManager.LabelColumn, scoreColumnName: metricManager.ScoreColumn);
431431
var metric = GetMetric(metricManager.Metric, metrics);
432432
var loss = metricManager.IsMaximize ? -metric : metric;

src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ Abandoning Trial {trialSettings.TrialId} and continue training.
315315
trialResultManager?.AddOrUpdateTrialResult(trialResult);
316316
aggregateTrainingStopManager.Update(trialResult);
317317

318-
if (ex is not OperationCanceledException && _bestTrialResult == null)
318+
if (ex is not OperationCanceledException && ex is not OutOfMemoryException && _bestTrialResult == null)
319319
{
320320
logger.Trace($"trial fatal error - {JsonSerializer.Serialize(trialSettings)}, stop training");
321321

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
4+
#nullable enable
5+
6+
using Microsoft.ML.SearchSpace;
47

58
namespace Microsoft.ML.AutoML
69
{
@@ -12,34 +15,102 @@ public interface IDatasetManager
1215
{
1316
}
1417

15-
internal interface ICrossValidateDatasetManager
18+
/// <summary>
19+
/// Inferface for cross validate dataset manager.
20+
/// </summary>
21+
public interface ICrossValidateDatasetManager : IDatasetManager
1622
{
17-
int? Fold { get; set; }
23+
/// <summary>
24+
/// Cross validate fold.
25+
/// </summary>
26+
int Fold { get; set; }
1827

28+
/// <summary>
29+
/// The dataset to cross validate.
30+
/// </summary>
1931
IDataView Dataset { get; set; }
2032

21-
string SamplingKeyColumnName { get; set; }
33+
/// <summary>
34+
/// The dataset column used for grouping rows.
35+
/// </summary>
36+
string? SamplingKeyColumnName { get; set; }
2237
}
2338

24-
internal interface ITrainValidateDatasetManager
39+
public interface ITrainValidateDatasetManager : IDatasetManager
2540
{
26-
IDataView TrainDataset { get; set; }
41+
IDataView LoadTrainDataset(MLContext context, TrialSettings? settings);
2742

28-
IDataView ValidateDataset { get; set; }
43+
IDataView LoadValidateDataset(MLContext context, TrialSettings? settings);
2944
}
3045

3146
internal class TrainValidateDatasetManager : IDatasetManager, ITrainValidateDatasetManager
3247
{
33-
public IDataView TrainDataset { get; set; }
48+
private ulong _rowCount;
49+
private IDataView _trainDataset;
50+
private readonly IDataView _validateDataset;
51+
private readonly string _subSamplingKey = "TrainValidateDatasetSubsamplingKey";
52+
private bool _isInitialized = false;
53+
public TrainValidateDatasetManager(IDataView trainDataset, IDataView validateDataset, string? subSamplingKey = null)
54+
{
55+
_trainDataset = trainDataset;
56+
_validateDataset = validateDataset;
57+
_subSamplingKey = subSamplingKey ?? _subSamplingKey;
58+
}
59+
60+
public string SubSamplingKey => _subSamplingKey;
61+
62+
/// <summary>
63+
/// Load Train Dataset. If <see cref="TrialSettings.Parameter"/> contains <see cref="_subSamplingKey"/> then the train dataset will be subsampled.
64+
/// </summary>
65+
/// <param name="context">MLContext.</param>
66+
/// <param name="settings">trial settings. If null, return entire train dataset.</param>
67+
/// <returns>train dataset.</returns>
68+
public IDataView LoadTrainDataset(MLContext context, TrialSettings? settings)
69+
{
70+
if (!_isInitialized)
71+
{
72+
InitializeTrainDataset(context);
73+
_isInitialized = true;
74+
}
75+
var trainTestSplitParameter = settings?.Parameter.ContainsKey(nameof(TrainValidateDatasetManager)) is true ? settings.Parameter[nameof(TrainValidateDatasetManager)] : null;
76+
if (trainTestSplitParameter is Parameter parameter)
77+
{
78+
var subSampleRatio = parameter.ContainsKey(_subSamplingKey) ? parameter[_subSamplingKey].AsType<double>() : 1;
79+
if (subSampleRatio < 1.0)
80+
{
81+
var subSampledTrainDataset = context.Data.TakeRows(_trainDataset, (long)(subSampleRatio * _rowCount));
82+
return subSampledTrainDataset;
83+
}
84+
}
3485

35-
public IDataView ValidateDataset { get; set; }
86+
return _trainDataset;
87+
}
88+
89+
public IDataView LoadValidateDataset(MLContext context, TrialSettings? settings)
90+
{
91+
return _validateDataset;
92+
}
93+
94+
private void InitializeTrainDataset(MLContext context)
95+
{
96+
_rowCount = DatasetDimensionsUtil.CountRows(_trainDataset, ulong.MaxValue);
97+
_trainDataset = context.Data.ShuffleRows(_trainDataset);
98+
}
3699
}
37100

38101
internal class CrossValidateDatasetManager : IDatasetManager, ICrossValidateDatasetManager
39102
{
103+
public CrossValidateDatasetManager(IDataView dataset, int fold, string? samplingKeyColumnName = null)
104+
{
105+
Dataset = dataset;
106+
Fold = fold;
107+
SamplingKeyColumnName = samplingKeyColumnName;
108+
}
109+
40110
public IDataView Dataset { get; set; }
41111

42-
public int? Fold { get; set; }
43-
public string SamplingKeyColumnName { get; set; }
112+
public int Fold { get; set; }
113+
114+
public string? SamplingKeyColumnName { get; set; }
44115
}
45116
}

src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public TrialResult Run(TrialSettings settings)
4040
var mlnetPipeline = _pipeline.BuildFromOption(_mLContext, parameter);
4141
if (_datasetManager is ICrossValidateDatasetManager crossValidateDatasetManager)
4242
{
43-
var datasetSplit = _mLContext!.Data.CrossValidationSplit(crossValidateDatasetManager.Dataset, crossValidateDatasetManager.Fold ?? 5, crossValidateDatasetManager.SamplingKeyColumnName);
43+
var datasetSplit = _mLContext!.Data.CrossValidationSplit(crossValidateDatasetManager.Dataset, crossValidateDatasetManager.Fold, crossValidateDatasetManager.SamplingKeyColumnName);
4444
var metrics = new List<double>();
4545
var models = new List<ITransformer>();
4646
foreach (var split in datasetSplit)
@@ -68,8 +68,8 @@ public TrialResult Run(TrialSettings settings)
6868

6969
if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
7070
{
71-
var model = mlnetPipeline.Fit(trainTestDatasetManager.TrainDataset);
72-
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
71+
var model = mlnetPipeline.Fit(trainTestDatasetManager.LoadTrainDataset(_mLContext!, settings));
72+
var eval = model.Transform(trainTestDatasetManager.LoadValidateDataset(_mLContext!, settings));
7373
var metric = _metricManager.Evaluate(_mLContext, eval);
7474
stopWatch.Stop();
7575
var loss = _metricManager.IsMaximize ? -metric : metric;

src/Microsoft.ML.AutoML/Tuner/EciCfoTuner.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ public EciCostFrugalTuner(SweepablePipeline sweepablePipeline, AutoMLExperiment.
3232
_tuners = pipelineSchemas.ToDictionary(schema => schema, schema =>
3333
{
3434
var searchSpace = sweepablePipeline.BuildSweepableEstimatorPipeline(schema).SearchSpace;
35-
return new CostFrugalTuner(searchSpace, searchSpace.SampleFromFeatureSpace(searchSpace.Default), seed: settings.Seed) as ITuner;
35+
var aggregateSearchSpace = new SearchSpace.SearchSpace(settings.SearchSpace);
36+
aggregateSearchSpace[AutoMLExperiment.PipelineSearchspaceName] = searchSpace;
37+
return new CostFrugalTuner(aggregateSearchSpace, aggregateSearchSpace.SampleFromFeatureSpace(aggregateSearchSpace.Default), seed: settings.Seed) as ITuner;
3638
});
3739

3840
if (trialResultManager != null)
@@ -57,22 +59,18 @@ public Parameter Propose(TrialSettings settings)
5759
parameter[k.Key] = _defaultParameter[k.Key];
5860
}
5961
}
60-
settings.Parameter[AutoMLExperiment.PipelineSearchspaceName] = parameter;
62+
settings.Parameter = parameter;
6163

6264
return settings.Parameter;
6365
}
6466

6567
public void Update(TrialResult result)
6668
{
67-
var originalParameter = result.TrialSettings.Parameter;
6869
var schema = result.TrialSettings.Parameter[AutoMLExperiment.PipelineSearchspaceName]["_SCHEMA_"].AsType<string>();
6970
_pipelineProposer.Update(result, schema);
7071
if (_tuners.TryGetValue(schema, out var tuner))
7172
{
72-
var parameter = result.TrialSettings.Parameter[AutoMLExperiment.PipelineSearchspaceName];
73-
result.TrialSettings.Parameter = parameter;
7473
tuner.Update(result);
75-
result.TrialSettings.Parameter = originalParameter;
7674
}
7775
}
7876
}

src/Microsoft.ML.Fairlearn/AutoML/AutoMLExperimentExtension.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using Microsoft.Extensions.DependencyInjection;
1010
using Microsoft.ML.AutoML;
1111
using Microsoft.ML.Data;
12+
using Microsoft.ML.SearchSpace;
1213

1314
namespace Microsoft.ML.Fairlearn.AutoML
1415
{
@@ -55,9 +56,14 @@ public static AutoMLExperiment SetBinaryClassificationMetricWithFairLearn(
5556
{
5657
var datasetManager = serviceProvider.GetRequiredService<TrainValidateDatasetManager>();
5758
var moment = new UtilityParity();
58-
var sensitiveFeature = DataFrameColumn.Create("group_id", datasetManager.TrainDataset.GetColumn<string>(sensitiveColumnName));
59-
var label = DataFrameColumn.Create("label", datasetManager.TrainDataset.GetColumn<bool>(labelColumn));
60-
moment.LoadData(datasetManager.TrainDataset, label, sensitiveFeature);
59+
var context = serviceProvider.GetRequiredService<MLContext>();
60+
var trainData = datasetManager.LoadTrainDataset(context, new TrialSettings
61+
{
62+
Parameter = Parameter.CreateNestedParameter(),
63+
});
64+
var sensitiveFeature = DataFrameColumn.Create("group_id", trainData.GetColumn<string>(sensitiveColumnName));
65+
var label = DataFrameColumn.Create("label", trainData.GetColumn<bool>(labelColumn));
66+
moment.LoadData(trainData, label, sensitiveFeature);
6167
var lambdaSearchSpace = Utilities.GenerateBinaryClassificationLambdaSearchSpace(moment, gridLimit, negativeAllowed);
6268
experiment.AddSearchSpace("_lambda_search_space", lambdaSearchSpace);
6369

@@ -70,8 +76,9 @@ public static AutoMLExperiment SetBinaryClassificationMetricWithFairLearn(
7076
var moment = serviceProvider.GetRequiredService<ClassificationMoment>();
7177
var datasetManager = serviceProvider.GetRequiredService<TrainValidateDatasetManager>();
7278
var pipeline = serviceProvider.GetRequiredService<SweepablePipeline>();
73-
return new GridSearchTrailRunner(context, datasetManager.TrainDataset, datasetManager.ValidateDataset, labelColumn, sensitiveColumnName, pipeline, moment);
79+
return new GridSearchTrailRunner(context, datasetManager, labelColumn, sensitiveColumnName, pipeline, moment);
7480
});
81+
7582
experiment.SetRandomSearchTuner();
7683

7784
return experiment;

0 commit comments

Comments
 (0)