Skip to content

Commit c90223e

Browse files
committed
Fix LTR query feature with phrases (and two-phase) queries (elastic#125103)
Query features should verify that docs match the two-phase iterator.
1 parent: f9263c8 · commit: c90223e

File tree

3 files changed

+139
-108
lines changed

3 files changed

+139
-108
lines changed

docs/changelog/125103.yaml

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
pr: 125103
2+
summary: Fix LTR query feature with phrases (and two-phase) queries
3+
area: Ranking
4+
type: bug
5+
issues: []

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractor.java

Lines changed: 28 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
import org.apache.lucene.search.Weight;
1616

1717
import java.io.IOException;
18-
import java.util.ArrayList;
1918
import java.util.List;
2019
import java.util.Map;
2120

@@ -25,52 +24,52 @@
2524
* respective feature name.
2625
*/
2726
public class QueryFeatureExtractor implements FeatureExtractor {
28-
2927
private final List<String> featureNames;
3028
private final List<Weight> weights;
31-
private final List<Scorer> scorers;
32-
private DisjunctionDISIApproximation rankerIterator;
29+
30+
private final DisiPriorityQueue subScorers;
31+
private DisjunctionDISIApproximation approximation;
3332

3433
public QueryFeatureExtractor(List<String> featureNames, List<Weight> weights) {
3534
if (featureNames.size() != weights.size()) {
3635
throw new IllegalArgumentException("[featureNames] and [weights] must be the same size.");
3736
}
3837
this.featureNames = featureNames;
3938
this.weights = weights;
40-
this.scorers = new ArrayList<>(weights.size());
39+
this.subScorers = new DisiPriorityQueue(weights.size());
4140
}
4241

4342
@Override
4443
public void setNextReader(LeafReaderContext segmentContext) throws IOException {
45-
DisiPriorityQueue disiPriorityQueue = new DisiPriorityQueue(weights.size());
46-
scorers.clear();
47-
for (Weight weight : weights) {
44+
subScorers.clear();
45+
for (int i = 0; i < weights.size(); i++) {
46+
var weight = weights.get(i);
4847
if (weight == null) {
49-
scorers.add(null);
5048
continue;
5149
}
5250
Scorer scorer = weight.scorer(segmentContext);
5351
if (scorer != null) {
54-
disiPriorityQueue.add(new DisiWrapper(scorer));
52+
subScorers.add(new FeatureDisiWrapper(scorer, featureNames.get(i)));
5553
}
56-
scorers.add(scorer);
5754
}
58-
59-
rankerIterator = disiPriorityQueue.size() > 0 ? new DisjunctionDISIApproximation(disiPriorityQueue) : null;
55+
approximation = subScorers.size() > 0 ? new DisjunctionDISIApproximation(subScorers) : null;
6056
}
6157

6258
@Override
6359
public void addFeatures(Map<String, Object> featureMap, int docId) throws IOException {
64-
if (rankerIterator == null) {
60+
if (approximation == null || approximation.docID() > docId) {
6561
return;
6662
}
67-
68-
rankerIterator.advance(docId);
69-
for (int i = 0; i < featureNames.size(); i++) {
70-
Scorer scorer = scorers.get(i);
71-
// Do we have a scorer, and does it match the provided document?
72-
if (scorer != null && scorer.docID() == docId) {
73-
featureMap.put(featureNames.get(i), scorer.score());
63+
if (approximation.docID() < docId) {
64+
approximation.advance(docId);
65+
}
66+
if (approximation.docID() != docId) {
67+
return;
68+
}
69+
var w = (FeatureDisiWrapper) subScorers.topList();
70+
for (; w != null; w = (FeatureDisiWrapper) w.next) {
71+
if (w.twoPhaseView == null || w.twoPhaseView.matches()) {
72+
featureMap.put(w.featureName, w.scorable.score());
7473
}
7574
}
7675
}
@@ -80,4 +79,12 @@ public List<String> featureNames() {
8079
return featureNames;
8180
}
8281

82+
private static class FeatureDisiWrapper extends DisiWrapper {
83+
final String featureName;
84+
85+
FeatureDisiWrapper(Scorer scorer, String featureName) {
86+
super(scorer, false);
87+
this.featureName = featureName;
88+
}
89+
}
8390
}

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractorTests.java

Lines changed: 106 additions & 87 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
import org.apache.lucene.document.IntField;
1313
import org.apache.lucene.index.IndexReader;
1414
import org.apache.lucene.index.LeafReaderContext;
15-
import org.apache.lucene.search.IndexSearcher;
15+
import org.apache.lucene.index.NoMergePolicy;
1616
import org.apache.lucene.search.Query;
1717
import org.apache.lucene.search.ScoreMode;
1818
import org.apache.lucene.search.Weight;
@@ -43,13 +43,11 @@
4343

4444
public class QueryFeatureExtractorTests extends AbstractBuilderTestCase {
4545

46-
private Directory dir;
47-
private IndexReader reader;
48-
private IndexSearcher searcher;
49-
50-
private void addDocs(String[] textValues, int[] numberValues) throws IOException {
51-
dir = newDirectory();
52-
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), dir)) {
46+
private IndexReader addDocs(Directory dir, String[] textValues, int[] numberValues) throws IOException {
47+
var config = newIndexWriterConfig();
48+
// override the merge policy to ensure that docs remain in the same ingestion order
49+
config.setMergePolicy(newLogMergePolicy(random()));
50+
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), dir, config)) {
5351
for (int i = 0; i < textValues.length; i++) {
5452
Document doc = new Document();
5553
doc.add(newTextField(TEXT_FIELD_NAME, textValues[i], Field.Store.NO));
@@ -59,98 +57,119 @@ private void addDocs(String[] textValues, int[] numberValues) throws IOException
5957
indexWriter.flush();
6058
}
6159
}
62-
reader = indexWriter.getReader();
60+
return indexWriter.getReader();
6361
}
64-
searcher = newSearcher(reader);
65-
searcher.setSimilarity(new ClassicSimilarity());
6662
}
6763

68-
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/98127")
6964
public void testQueryExtractor() throws IOException {
70-
addDocs(
71-
new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
72-
new int[] { 5, 10, 12, 11 }
73-
);
74-
QueryRewriteContext ctx = createQueryRewriteContext();
75-
List<QueryExtractorBuilder> queryExtractorBuilders = List.of(
76-
new QueryExtractorBuilder("text_score", QueryProvider.fromParsedQuery(QueryBuilders.matchQuery(TEXT_FIELD_NAME, "quick fox")))
77-
.rewrite(ctx),
78-
new QueryExtractorBuilder(
79-
"number_score",
80-
QueryProvider.fromParsedQuery(QueryBuilders.rangeQuery(INT_FIELD_NAME).from(12).to(12))
81-
).rewrite(ctx),
82-
new QueryExtractorBuilder(
83-
"matching_none",
84-
QueryProvider.fromParsedQuery(QueryBuilders.termQuery(TEXT_FIELD_NAME, "never found term"))
85-
).rewrite(ctx),
86-
new QueryExtractorBuilder(
87-
"matching_missing_field",
88-
QueryProvider.fromParsedQuery(QueryBuilders.termQuery("missing_text", "quick fox"))
89-
).rewrite(ctx)
90-
);
91-
SearchExecutionContext dummySEC = createSearchExecutionContext();
92-
List<Weight> weights = new ArrayList<>();
93-
List<String> featureNames = new ArrayList<>();
94-
for (QueryExtractorBuilder qeb : queryExtractorBuilders) {
95-
Query q = qeb.query().getParsedQuery().toQuery(dummySEC);
96-
Weight weight = searcher.rewrite(q).createWeight(searcher, ScoreMode.COMPLETE, 1f);
97-
weights.add(weight);
98-
featureNames.add(qeb.featureName());
99-
}
100-
QueryFeatureExtractor queryFeatureExtractor = new QueryFeatureExtractor(featureNames, weights);
101-
List<Map<String, Object>> extractedFeatures = new ArrayList<>();
102-
for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
103-
int maxDoc = leafReaderContext.reader().maxDoc();
104-
queryFeatureExtractor.setNextReader(leafReaderContext);
105-
for (int i = 0; i < maxDoc; i++) {
106-
Map<String, Object> featureMap = new HashMap<>();
107-
queryFeatureExtractor.addFeatures(featureMap, i);
108-
extractedFeatures.add(featureMap);
65+
try (var dir = newDirectory()) {
66+
try (
67+
var reader = addDocs(
68+
dir,
69+
new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
70+
new int[] { 5, 10, 12, 11 }
71+
)
72+
) {
73+
var searcher = newSearcher(reader);
74+
searcher.setSimilarity(new ClassicSimilarity());
75+
QueryRewriteContext ctx = createQueryRewriteContext();
76+
List<QueryExtractorBuilder> queryExtractorBuilders = List.of(
77+
new QueryExtractorBuilder(
78+
"text_score",
79+
QueryProvider.fromParsedQuery(QueryBuilders.matchQuery(TEXT_FIELD_NAME, "quick fox"))
80+
).rewrite(ctx),
81+
new QueryExtractorBuilder(
82+
"number_score",
83+
QueryProvider.fromParsedQuery(QueryBuilders.rangeQuery(INT_FIELD_NAME).from(12).to(12))
84+
).rewrite(ctx),
85+
new QueryExtractorBuilder(
86+
"matching_none",
87+
QueryProvider.fromParsedQuery(QueryBuilders.termQuery(TEXT_FIELD_NAME, "never found term"))
88+
).rewrite(ctx),
89+
new QueryExtractorBuilder(
90+
"matching_missing_field",
91+
QueryProvider.fromParsedQuery(QueryBuilders.termQuery("missing_text", "quick fox"))
92+
).rewrite(ctx),
93+
new QueryExtractorBuilder(
94+
"phrase_score",
95+
QueryProvider.fromParsedQuery(QueryBuilders.matchPhraseQuery(TEXT_FIELD_NAME, "slow brown fox"))
96+
).rewrite(ctx)
97+
);
98+
SearchExecutionContext dummySEC = createSearchExecutionContext();
99+
List<Weight> weights = new ArrayList<>();
100+
List<String> featureNames = new ArrayList<>();
101+
for (QueryExtractorBuilder qeb : queryExtractorBuilders) {
102+
Query q = qeb.query().getParsedQuery().toQuery(dummySEC);
103+
Weight weight = searcher.rewrite(q).createWeight(searcher, ScoreMode.COMPLETE, 1f);
104+
weights.add(weight);
105+
featureNames.add(qeb.featureName());
106+
}
107+
QueryFeatureExtractor queryFeatureExtractor = new QueryFeatureExtractor(featureNames, weights);
108+
List<Map<String, Object>> extractedFeatures = new ArrayList<>();
109+
for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
110+
int maxDoc = leafReaderContext.reader().maxDoc();
111+
queryFeatureExtractor.setNextReader(leafReaderContext);
112+
for (int i = 0; i < maxDoc; i++) {
113+
Map<String, Object> featureMap = new HashMap<>();
114+
queryFeatureExtractor.addFeatures(featureMap, i);
115+
extractedFeatures.add(featureMap);
116+
}
117+
}
118+
assertThat(extractedFeatures, hasSize(4));
119+
// Should never add features for queries that don't match a document or on documents where the field is missing
120+
for (Map<String, Object> features : extractedFeatures) {
121+
assertThat(features, not(hasKey("matching_none")));
122+
assertThat(features, not(hasKey("matching_missing_field")));
123+
}
124+
// First two only match the text field
125+
assertThat(extractedFeatures.get(0), hasEntry("text_score", 1.7135582f));
126+
assertThat(extractedFeatures.get(0), not(hasKey("number_score")));
127+
assertThat(extractedFeatures.get(0), not(hasKey("phrase_score")));
128+
assertThat(extractedFeatures.get(1), hasEntry("text_score", 0.7554128f));
129+
assertThat(extractedFeatures.get(1), not(hasKey("number_score")));
130+
assertThat(extractedFeatures.get(1), hasEntry("phrase_score", 2.468971f));
131+
132+
// Only matches the range query
133+
assertThat(extractedFeatures.get(2), hasEntry("number_score", 1f));
134+
assertThat(extractedFeatures.get(2), not(hasKey("text_score")));
135+
assertThat(extractedFeatures.get(2), not(hasKey("phrase_score")));
136+
137+
// No query matches
138+
assertThat(extractedFeatures.get(3), anEmptyMap());
109139
}
110140
}
111-
assertThat(extractedFeatures, hasSize(4));
112-
// Should never add features for queries that don't match a document or on documents where the field is missing
113-
for (Map<String, Object> features : extractedFeatures) {
114-
assertThat(features, not(hasKey("matching_none")));
115-
assertThat(features, not(hasKey("matching_missing_field")));
116-
}
117-
// First two only match the text field
118-
assertThat(extractedFeatures.get(0), hasEntry("text_score", 1.7135582f));
119-
assertThat(extractedFeatures.get(0), not(hasKey("number_score")));
120-
assertThat(extractedFeatures.get(1), hasEntry("text_score", 0.7554128f));
121-
assertThat(extractedFeatures.get(1), not(hasKey("number_score")));
122-
// Only matches the range query
123-
assertThat(extractedFeatures.get(2), hasEntry("number_score", 1f));
124-
assertThat(extractedFeatures.get(2), not(hasKey("text_score")));
125-
// No query matches
126-
assertThat(extractedFeatures.get(3), anEmptyMap());
127-
reader.close();
128-
dir.close();
129141
}
130142

131143
public void testEmptyDisiPriorityQueue() throws IOException {
132-
addDocs(
133-
new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
134-
new int[] { 5, 10, 12, 11 }
135-
);
144+
try (var dir = newDirectory()) {
145+
var config = newIndexWriterConfig();
146+
config.setMergePolicy(NoMergePolicy.INSTANCE);
147+
try (
148+
var reader = addDocs(
149+
dir,
150+
new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
151+
new int[] { 5, 10, 12, 11 }
152+
)
153+
) {
136154

137-
// Scorers returned by weights are null
138-
List<String> featureNames = randomList(1, 10, ESTestCase::randomIdentifier);
139-
List<Weight> weights = Stream.generate(() -> mock(Weight.class)).limit(featureNames.size()).toList();
155+
var searcher = newSearcher(reader);
156+
searcher.setSimilarity(new ClassicSimilarity());
140157

141-
QueryFeatureExtractor featureExtractor = new QueryFeatureExtractor(featureNames, weights);
158+
// Scorers returned by weights are null
159+
List<String> featureNames = randomList(1, 10, ESTestCase::randomIdentifier);
160+
List<Weight> weights = Stream.generate(() -> mock(Weight.class)).limit(featureNames.size()).toList();
142161

143-
for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
144-
int maxDoc = leafReaderContext.reader().maxDoc();
145-
featureExtractor.setNextReader(leafReaderContext);
146-
for (int i = 0; i < maxDoc; i++) {
147-
Map<String, Object> featureMap = new HashMap<>();
148-
featureExtractor.addFeatures(featureMap, i);
149-
assertThat(featureMap, anEmptyMap());
162+
QueryFeatureExtractor featureExtractor = new QueryFeatureExtractor(featureNames, weights);
163+
for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
164+
int maxDoc = leafReaderContext.reader().maxDoc();
165+
featureExtractor.setNextReader(leafReaderContext);
166+
for (int i = 0; i < maxDoc; i++) {
167+
Map<String, Object> featureMap = new HashMap<>();
168+
featureExtractor.addFeatures(featureMap, i);
169+
assertThat(featureMap, anEmptyMap());
170+
}
171+
}
150172
}
151173
}
152-
153-
reader.close();
154-
dir.close();
155174
}
156175
}

0 commit comments

Comments (0)