Skip to content

Commit bbc46d9

Browse files
committed
DATAES-13 added support for minimum score
1 parent 6c5e8fe commit bbc46d9

File tree

8 files changed

+158
-0
lines changed

8 files changed

+158
-0
lines changed

src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,10 @@ public <T> Page<T> queryForPage(CriteriaQuery criteriaQuery, Class<T> clazz) {
210210
searchRequestBuilder.setQuery(QueryBuilders.matchAllQuery());
211211
}
212212

213+
if(criteriaQuery.getMinScore()>0){
214+
searchRequestBuilder.setMinScore(criteriaQuery.getMinScore());
215+
}
216+
213217
if (elasticsearchFilter != null)
214218
searchRequestBuilder.setFilter(elasticsearchFilter);
215219

@@ -520,6 +524,10 @@ private SearchRequestBuilder prepareSearch(Query query) {
520524
: SortOrder.ASC);
521525
}
522526
}
527+
528+
if(query.getMinScore()>0){
529+
searchRequestBuilder.setMinScore(query.getMinScore());
530+
}
523531
return searchRequestBuilder;
524532
}
525533

src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ abstract class AbstractQuery implements Query {
3737
protected List<String> indices = new ArrayList<String>();
3838
protected List<String> types = new ArrayList<String>();
3939
protected List<String> fields = new ArrayList<String>();
40+
protected float minScore;
4041

4142
@Override
4243
public Sort getSort() {
@@ -99,4 +100,12 @@ public final <T extends Query> T addSort(Sort sort) {
99100

100101
return (T) this;
101102
}
103+
104+
public float getMinScore() {
105+
return minScore;
106+
}
107+
108+
public void setMinScore(float minScore) {
109+
this.minScore = minScore;
110+
}
102111
}

src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package org.springframework.data.elasticsearch.core.query;
1717

1818
import org.apache.commons.collections.CollectionUtils;
19+
import org.elasticsearch.common.cache.CacheBuilder;
1920
import org.elasticsearch.index.query.FilterBuilder;
2021
import org.elasticsearch.index.query.QueryBuilder;
2122
import org.elasticsearch.search.highlight.HighlightBuilder;
@@ -45,6 +46,7 @@ public class NativeSearchQueryBuilder {
4546
private String[] indices;
4647
private String[] types;
4748
private String[] fields;
49+
private float minScore;
4850

4951
public NativeSearchQueryBuilder withQuery(QueryBuilder queryBuilder) {
5052
this.queryBuilder = queryBuilder;
@@ -91,6 +93,11 @@ public NativeSearchQueryBuilder withFields(String... fields) {
9193
return this;
9294
}
9395

96+
public NativeSearchQueryBuilder withMinScore(float minScore) {
97+
this.minScore = minScore;
98+
return this;
99+
}
100+
94101
public NativeSearchQuery build() {
95102
NativeSearchQuery nativeSearchQuery = new NativeSearchQuery(queryBuilder, filterBuilder, sortBuilder, highlightFields);
96103
if (pageable != null) {
@@ -108,6 +115,10 @@ public NativeSearchQuery build() {
108115
if (CollectionUtils.isNotEmpty(facetRequests)) {
109116
nativeSearchQuery.setFacets(facetRequests);
110117
}
118+
119+
if(minScore>0){
120+
nativeSearchQuery.setMinScore(minScore);
121+
}
111122
return nativeSearchQuery;
112123
}
113124
}

src/main/java/org/springframework/data/elasticsearch/core/query/Query.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,10 @@ public interface Query {
108108
* @return
109109
*/
110110
List<String> getFields();
111+
112+
/**
113+
* Get minimum score
114+
* @return
115+
*/
116+
float getMinScore();
111117
}

src/test/java/org/springframework/data/elasticsearch/SampleEntity.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,17 @@ public int hashCode() {
112112
return new HashCodeBuilder().append(id).append(type).append(message).append(rate).append(available).append(version)
113113
.toHashCode();
114114
}
115+
116+
@Override
117+
public String toString() {
118+
return "SampleEntity{" +
119+
"id='" + id + '\'' +
120+
", type='" + type + '\'' +
121+
", message='" + message + '\'' +
122+
", rate=" + rate +
123+
", available=" + available +
124+
", highlightedMessage='" + highlightedMessage + '\'' +
125+
", version=" + version +
126+
'}';
127+
}
115128
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package org.springframework.data.elasticsearch;
2+
3+
import org.springframework.data.elasticsearch.core.query.IndexQuery;
4+
5+
/**
6+
* User: dead
7+
* Date: 23/01/14
8+
* Time: 18:25
9+
*/
10+
public class SampleEntityBuilder {
11+
12+
private SampleEntity result;
13+
14+
public SampleEntityBuilder(String id) {
15+
result = new SampleEntity();
16+
result.setId(id);
17+
}
18+
19+
public SampleEntityBuilder type(String type) {
20+
result.setType(type);
21+
return this;
22+
}
23+
24+
public SampleEntityBuilder message(String message) {
25+
result.setMessage(message);
26+
return this;
27+
}
28+
29+
public SampleEntityBuilder rate(int rate) {
30+
result.setRate(rate);
31+
return this;
32+
}
33+
34+
public SampleEntityBuilder available(boolean available) {
35+
result.setAvailable(available);
36+
return this;
37+
}
38+
39+
public SampleEntityBuilder highlightedMessage(String highlightedMessage) {
40+
result.setHighlightedMessage(highlightedMessage);
41+
return this;
42+
}
43+
44+
public SampleEntityBuilder version(Long version) {
45+
result.setVersion(version);
46+
return this;
47+
}
48+
49+
public SampleEntity build() {
50+
return result;
51+
}
52+
53+
public IndexQuery buildIndex() {
54+
IndexQuery indexQuery = new IndexQuery();
55+
indexQuery.setId(result.getId());
56+
indexQuery.setObject(result);
57+
return indexQuery;
58+
}
59+
}

src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.data.domain.Sort;
3434
import org.springframework.data.elasticsearch.ElasticsearchException;
3535
import org.springframework.data.elasticsearch.SampleEntity;
36+
import org.springframework.data.elasticsearch.SampleEntityBuilder;
3637
import org.springframework.data.elasticsearch.SampleMappingEntity;
3738
import org.springframework.data.elasticsearch.core.query.*;
3839
import org.springframework.test.context.ContextConfiguration;
@@ -1007,4 +1008,30 @@ public void shouldReturnIds(){
10071008
assertThat(ids, is(notNullValue()));
10081009
assertThat(ids.size(), is(30));
10091010
}
1011+
1012+
@Test
1013+
public void shouldReturnDocumentAboveMinimalScoreGivenQuery() {
1014+
// given
1015+
List<IndexQuery> indexQueries = new ArrayList<IndexQuery>();
1016+
1017+
indexQueries.add(new SampleEntityBuilder("1").message("ab").buildIndex());
1018+
indexQueries.add(new SampleEntityBuilder("2").message("bc").buildIndex());
1019+
indexQueries.add(new SampleEntityBuilder("3").message("ac").buildIndex());
1020+
1021+
elasticsearchTemplate.bulkIndex(indexQueries);
1022+
elasticsearchTemplate.refresh(SampleEntity.class, true);
1023+
1024+
// when
1025+
SearchQuery searchQuery = new NativeSearchQueryBuilder()
1026+
.withQuery(boolQuery().must(wildcardQuery("message", "*a*")).should(wildcardQuery("message", "*b*")))
1027+
.withIndices("test-index")
1028+
.withTypes("test-type")
1029+
.withMinScore(0.5F)
1030+
.build();
1031+
1032+
Page<SampleEntity> page = elasticsearchTemplate.queryForPage(searchQuery, SampleEntity.class);
1033+
// then
1034+
assertThat(page.getTotalElements(),is(1L));
1035+
assertThat(page.getContent().get(0).getMessage(), is("ab"));
1036+
}
10101037
}

src/test/java/org/springframework/data/elasticsearch/core/query/CriteriaQueryTests.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@
1515
*/
1616
package org.springframework.data.elasticsearch.core.query;
1717

18+
import org.elasticsearch.search.sort.FieldSortBuilder;
19+
import org.elasticsearch.search.sort.SortOrder;
1820
import org.junit.Before;
1921
import org.junit.Ignore;
2022
import org.junit.Test;
2123
import org.junit.runner.RunWith;
2224
import org.springframework.data.domain.Page;
2325
import org.springframework.data.elasticsearch.SampleEntity;
26+
import org.springframework.data.elasticsearch.SampleEntityBuilder;
2427
import org.springframework.data.elasticsearch.core.ElasticsearchTemplate;
2528
import org.springframework.test.context.ContextConfiguration;
2629
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
@@ -30,6 +33,7 @@
3033
import java.util.List;
3134

3235
import static org.apache.commons.lang.RandomStringUtils.randomNumeric;
36+
import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery;
3337
import static org.hamcrest.Matchers.*;
3438
import static org.junit.Assert.*;
3539

@@ -700,4 +704,25 @@ public void shouldPerformBoostOperation() {
700704
// then
701705
assertThat(page.getTotalElements(), is(greaterThanOrEqualTo(1L)));
702706
}
707+
708+
@Test
709+
public void shouldReturnDocumentAboveMinimalScoreGivenCriteria() {
710+
// given
711+
List<IndexQuery> indexQueries = new ArrayList<IndexQuery>();
712+
713+
indexQueries.add(new SampleEntityBuilder("1").message("ab").buildIndex());
714+
indexQueries.add(new SampleEntityBuilder("2").message("bc").buildIndex());
715+
indexQueries.add(new SampleEntityBuilder("3").message("ac").buildIndex());
716+
717+
elasticsearchTemplate.bulkIndex(indexQueries);
718+
elasticsearchTemplate.refresh(SampleEntity.class, true);
719+
720+
// when
721+
CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria("message").contains("a").or(new Criteria("message").contains("b")));
722+
criteriaQuery.setMinScore(0.5F);
723+
Page<SampleEntity> page = elasticsearchTemplate.queryForPage(criteriaQuery, SampleEntity.class);
724+
// then
725+
assertThat(page.getTotalElements(),is(1L));
726+
assertThat(page.getContent().get(0).getMessage(), is("ab"));
727+
}
703728
}

0 commit comments

Comments
 (0)