Skip to content

Commit da35f09

Browse files
aruggero, alessandrobenedetti
authored and committed
SOLR-17760: solving bug in LTR dense/sparse format (#3354)
* Fixed field value feature (cherry picked from commit ba981cd)
1 parent c790b54 commit da35f09

File tree

6 files changed

+153
-83
lines changed

6 files changed

+153
-83
lines changed

solr/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ Bug Fixes
132132

133133
* SOLR-17726: MoreLikeThis to support copy-fields (Ilaria Petreti via Alessandro Benedetti)
134134

135+
* SOLR-16667: Fixed dense/sparse representation in LTR module. (Anna Ruggero, Alessandro Benedetti)
136+
135137
Dependency Upgrades
136138
---------------------
137139
* SOLR-17471: Upgrade Lucene to 9.12.1. (Pierre Salagnac, Christine Poerschke)

solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public String makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo) {
4444
StringBuilder sb = new StringBuilder(featuresInfo.length * 3);
4545
boolean isDense = featureFormat.equals(FeatureFormat.DENSE);
4646
for (LTRScoringQuery.FeatureInfo featInfo : featuresInfo) {
47-
if (featInfo != null && (isDense || featInfo.isUsed())) {
47+
if (featInfo != null && (isDense || !featInfo.isDefaultValue())) {
4848
sb.append(featInfo.getName())
4949
.append(keyValueSep)
5050
.append(featInfo.getValue())

solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -331,12 +331,12 @@ public long ramBytesUsed() {
331331
public static class FeatureInfo {
332332
private final String name;
333333
private float value;
334-
private boolean used;
334+
private boolean isDefaultValue;
335335

336-
FeatureInfo(String n, float v, boolean u) {
337-
name = n;
338-
value = v;
339-
used = u;
336+
FeatureInfo(String name, float value, boolean isDefaultValue) {
337+
this.name = name;
338+
this.value = value;
339+
this.isDefaultValue = isDefaultValue;
340340
}
341341

342342
public void setValue(float value) {
@@ -351,12 +351,12 @@ public float getValue() {
351351
return value;
352352
}
353353

354-
public boolean isUsed() {
355-
return used;
354+
public boolean isDefaultValue() {
355+
return isDefaultValue;
356356
}
357357

358-
public void setUsed(boolean used) {
359-
this.used = used;
358+
public void setIsDefaultValue(boolean isDefaultValue) {
359+
this.isDefaultValue = isDefaultValue;
360360
}
361361
}
362362

@@ -408,7 +408,7 @@ private void setFeaturesInfo() {
408408
String featName = extractedFeatureWeights[i].getName();
409409
int featId = extractedFeatureWeights[i].getIndex();
410410
float value = extractedFeatureWeights[i].getDefaultValue();
411-
featuresInfo[featId] = new FeatureInfo(featName, value, false);
411+
featuresInfo[featId] = new FeatureInfo(featName, value, true);
412412
}
413413
}
414414

@@ -440,12 +440,7 @@ private float makeNormalizedFeaturesAndScore() {
440440
for (final Feature.FeatureWeight feature : modelFeatureWeights) {
441441
final int featureId = feature.getIndex();
442442
FeatureInfo fInfo = featuresInfo[featureId];
443-
// not checking for finfo == null as that would be a bug we should catch
444-
if (fInfo.isUsed()) {
445-
modelFeatureValuesNormalized[pos] = fInfo.getValue();
446-
} else {
447-
modelFeatureValuesNormalized[pos] = feature.getDefaultValue();
448-
}
443+
modelFeatureValuesNormalized[pos] = fInfo.getValue();
449444
pos++;
450445
}
451446
ltrScoringModel.normalizeFeaturesInPlace(modelFeatureValuesNormalized);
@@ -480,7 +475,7 @@ protected void reset() {
480475
// need to set default value every time as the default value is used in 'dense'
481476
// mode even if isDefaultValue=true
482477
featuresInfo[featId].setValue(value);
483-
featuresInfo[featId].setUsed(false);
478+
featuresInfo[featId].setIsDefaultValue(true);
484479
}
485480
}
486481

@@ -598,7 +593,9 @@ public float score() throws IOException {
598593
Feature.FeatureWeight scFW = (Feature.FeatureWeight) subScorer.getWeight();
599594
final int featureId = scFW.getIndex();
600595
featuresInfo[featureId].setValue(subScorer.score());
601-
featuresInfo[featureId].setUsed(true);
596+
if (featuresInfo[featureId].getValue() != scFW.getDefaultValue()) {
597+
featuresInfo[featureId].setIsDefaultValue(false);
598+
}
602599
}
603600
}
604601
return makeNormalizedFeaturesAndScore();
@@ -683,7 +680,9 @@ public float score() throws IOException {
683680
Feature.FeatureWeight scFW = (Feature.FeatureWeight) scorer.getWeight();
684681
final int featureId = scFW.getIndex();
685682
featuresInfo[featureId].setValue(scorer.score());
686-
featuresInfo[featureId].setUsed(true);
683+
if (featuresInfo[featureId].getValue() != scFW.getDefaultValue()) {
684+
featuresInfo[featureId].setIsDefaultValue(false);
685+
}
687686
}
688687
}
689688
}

solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,10 @@ public void testScoringQueryWeightCreation() throws IOException, ModelException
142142
assertEquals("11", searcher.storedFields().document(hits.scoreDocs[1].doc).get("id"));
143143

144144
List<Feature> features = makeFeatures(new int[] {0, 1, 2});
145+
List<Feature> expectedNonDefaultFeatures = makeFeatures(new int[] {1, 2});
145146
final List<Feature> allFeatures = makeFeatures(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
147+
List<Feature> expectedNonDefaultAllFeatures =
148+
makeFeatures(new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9});
146149
final List<Normalizer> norms = new ArrayList<>();
147150
for (int k = 0; k < features.size(); ++k) {
148151
norms.add(IdentityNormalizer.INSTANCE);
@@ -167,13 +170,13 @@ public void testScoringQueryWeightCreation() throws IOException, ModelException
167170
LTRScoringQuery.FeatureInfo[] featuresInfo = modelWeight.getFeaturesInfo();
168171

169172
assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length);
170-
int validFeatures = 0;
173+
int nonDefaultFeatures = 0;
171174
for (int i = 0; i < featuresInfo.length; ++i) {
172-
if (featuresInfo[i] != null && featuresInfo[i].isUsed()) {
173-
validFeatures += 1;
175+
if (featuresInfo[i] != null && !featuresInfo[i].isDefaultValue()) {
176+
nonDefaultFeatures += 1;
174177
}
175178
}
176-
assertEquals(validFeatures, features.size());
179+
assertEquals(expectedNonDefaultFeatures.size(), nonDefaultFeatures);
177180

178181
// when features are requested in the response, weights should be created for all features
179182
final LTRScoringModel ltrScoringModel2 =
@@ -194,13 +197,13 @@ public void testScoringQueryWeightCreation() throws IOException, ModelException
194197
assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length);
195198
assertEquals(allFeatures.size(), modelWeight.getExtractedFeatureWeights().length);
196199

197-
validFeatures = 0;
200+
nonDefaultFeatures = 0;
198201
for (int i = 0; i < featuresInfo.length; ++i) {
199-
if (featuresInfo[i] != null && featuresInfo[i].isUsed()) {
200-
validFeatures += 1;
202+
if (featuresInfo[i] != null && !featuresInfo[i].isDefaultValue()) {
203+
nonDefaultFeatures += 1;
201204
}
202205
}
203-
assertEquals(validFeatures, allFeatures.size());
206+
assertEquals(expectedNonDefaultAllFeatures.size(), nonDefaultFeatures);
204207

205208
assertU(delI("10"));
206209
assertU(delI("11"));

0 commit comments

Comments
 (0)