Skip to content

Commit cb802e2

Browse files
feiyun0112feiyun0112Copilot
authored
ImageClassificationTrainer PredictedLabelColumnName bug when the name is not default (#7458)
* set PredictedLabelColumnName * Update src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: feiyun0112 <feiyun0112@hotmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 4b00b34 commit cb802e2

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ private protected sealed class BindingsImpl : BindingsBase
3838
public readonly int ScoreColumnIndex;
3939
// The type of the derived column.
4040
public readonly DataViewType PredColType;
41+
/// <summary>
42+
/// The name of the column that contains the predicted labels.
43+
/// This field is used in the scoring process to store or reference the predicted label column.
44+
/// </summary>
45+
public readonly string PredictedLabelColumnName;
4146
// The ScoreColumnKind metadata value for all score columns.
4247
public readonly string ScoreColumnKind;
4348

@@ -54,6 +59,7 @@ private BindingsImpl(DataViewSchema input, ISchemaBoundRowMapper mapper, string
5459
ScoreColumnIndex = scoreColIndex;
5560
ScoreColumnKind = scoreColumnKind;
5661
PredColType = predColType;
62+
PredictedLabelColumnName = predictedLabelColumnName;
5763

5864
_getScoreColumnKind = GetScoreColumnKind;
5965
_getScoreValueKind = GetScoreValueKind;
@@ -113,7 +119,7 @@ public BindingsImpl ApplyToSchema(DataViewSchema input, ISchemaBindableMapper bi
113119
bool tmp = rowMapper.OutputSchema.TryGetColumnIndex(scoreCol, out mapperScoreColumn);
114120
env.Check(tmp, "Mapper doesn't have expected score column");
115121

116-
return new BindingsImpl(input, rowMapper, Suffix, ScoreColumnKind, true, mapperScoreColumn, PredColType);
122+
return new BindingsImpl(input, rowMapper, Suffix, ScoreColumnKind, true, mapperScoreColumn, PredColType, PredictedLabelColumnName);
117123
}
118124

119125
public static BindingsImpl Create(ModelLoadContext ctx, DataViewSchema input,

0 commit comments

Comments
 (0)