Skip to content

Commit

Permalink
Improving and cleaning up tests
Browse files Browse the repository at this point in the history
Removing the unnecessary RankEvalTestHelper, making use of the common test infra
in ESTestCase, also hardening a few of the classes by making more fields final.
cbuescher committed Nov 21, 2017
1 parent 5c65a59 commit e278c1d
Showing 25 changed files with 448 additions and 634 deletions.
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchHit;
@@ -35,22 +35,25 @@
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsWithRatings;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;

public class DiscountedCumulativeGain implements RankedListQualityMetric {
public class DiscountedCumulativeGain implements EvaluationMetric {

/** If set to true, the dcg will be normalized (ndcg) */
private boolean normalize;
private final boolean normalize;

/**
* If set, this will be the rating for docs the user hasn't supplied an
* explicit rating for
*/
private Integer unknownDocRating;
private final Integer unknownDocRating;

public static final String NAME = "dcg";
private static final double LOG2 = Math.log(2.0);

/**
 * Creates a metric with default settings by delegating to the two-argument
 * constructor with {@code false} and {@code null}.
 * NOTE(review): presumably this means "do not normalize" and "no rating for
 * unrated documents"; the two-argument constructor is not visible here, confirm.
 */
public DiscountedCumulativeGain() {
this(false, null);
}

/**
@@ -82,27 +85,13 @@ public String getWriteableName() {
return NAME;
}

/**
* If set to true, the dcg will be normalized (ndcg)
*/
public void setNormalize(boolean normalize) {
this.normalize = normalize;
}

/**
* check whether this metric computes only dcg or "normalized" ndcg
*
* @return the value of the {@code normalize} field; {@code true} means the dcg
*         is normalized (ndcg)
*/
public boolean getNormalize() {
return this.normalize;
}

/**
* the rating for docs the user hasn't supplied an explicit rating for
*/
public void setUnknownDocRating(int unknownDocRating) {
this.unknownDocRating = unknownDocRating;
}

/**
* get the rating used for unrated documents
*/
@@ -118,10 +107,10 @@ public EvalQueryQuality evaluate(String taskId, SearchHit[] hits,
List<RatedSearchHit> ratedHits = joinHitsWithRatings(hits, ratedDocs);
List<Integer> ratingsInSearchHits = new ArrayList<>(ratedHits.size());
for (RatedSearchHit hit : ratedHits) {
// unknownDocRating might be null, which means it will be unrated
// docs are ignored in the dcg calculation
// we still need to add them as a placeholder so the rank of the
// subsequent ratings is correct
// unknownDocRating might be null, which means unrated docs are ignored in
// the dcg calculation;
// we still need to add them as a placeholder so the rank of the subsequent
// ratings is correct
ratingsInSearchHits.add(hit.getRating().orElse(unknownDocRating));
}
double dcg = computeDCG(ratingsInSearchHits);
@@ -151,12 +140,15 @@ private static double computeDCG(List<Integer> ratings) {

private static final ParseField NORMALIZE_FIELD = new ParseField("normalize");
private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
private static final ObjectParser<DiscountedCumulativeGain, Void> PARSER = new ObjectParser<>(
"dcg_at", () -> new DiscountedCumulativeGain());
private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg_at",
args -> {
Boolean normalized = (Boolean) args[0];
return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1]);
});

static {
PARSER.declareBoolean(DiscountedCumulativeGain::setNormalize, NORMALIZE_FIELD);
PARSER.declareInt(DiscountedCumulativeGain::setUnknownDocRating, UNKNOWN_DOC_RATING_FIELD);
PARSER.declareBoolean(optionalConstructorArg(), NORMALIZE_FIELD);
PARSER.declareInt(optionalConstructorArg(), UNKNOWN_DOC_RATING_FIELD);
}

public static DiscountedCumulativeGain fromXContent(XContentParser parser) {
@@ -193,6 +185,4 @@ public final boolean equals(Object obj) {
public final int hashCode() {
// hash over both configuration fields of this metric: normalize and unknownDocRating
return Objects.hash(normalize, unknownDocRating);
}

// TODO maybe also add debugging breakdown here
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;

import java.io.IOException;
import java.util.ArrayList;
@@ -91,8 +92,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject(id);
builder.field("quality_level", this.qualityLevel);
builder.startArray("unknown_docs");
for (DocumentKey key : RankedListQualityMetric.filterUnknownDocuments(hits)) {
key.toXContent(builder, params);
for (DocumentKey key : EvaluationMetric.filterUnknownDocuments(hits)) {
builder.startObject();
builder.field(RatedDocument.INDEX_FIELD.getPreferredName(), key.getIndex());
builder.field(RatedDocument.DOC_ID_FIELD.getPreferredName(), key.getDocId());
builder.endObject();
}
builder.endArray();
builder.startArray("hits");
Original file line number Diff line number Diff line change
@@ -24,7 +24,9 @@
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentParser.Token;
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;

import java.io.IOException;
import java.util.ArrayList;
@@ -35,13 +37,14 @@
import java.util.stream.Collectors;

/**
* Classes implementing this interface provide a means to compute the quality of a result list returned by some search.
* Implementations of {@link EvaluationMetric} need to provide a way to compute the quality metric for
* a result list returned by some search ({@link SearchHits}) and a list of rated documents.
*/
public interface RankedListQualityMetric extends ToXContent, NamedWriteable {
public interface EvaluationMetric extends ToXContent, NamedWriteable {

/**
* Returns a single metric representing the ranking quality of a set of returned
* documents wrt. to a set of document Ids labeled as relevant for this search.
* documents with respect to a set of document ids labeled as relevant for this search.
*
* @param taskId
* the id of the query for which the ranking is currently evaluated
@@ -55,15 +58,15 @@ public interface RankedListQualityMetric extends ToXContent, NamedWriteable {
*/
EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs);

static RankedListQualityMetric fromXContent(XContentParser parser) throws IOException {
RankedListQualityMetric rc;
static EvaluationMetric fromXContent(XContentParser parser) throws IOException {
EvaluationMetric rc;
Token token = parser.nextToken();
if (token != XContentParser.Token.FIELD_NAME) {
throw new ParsingException(parser.getTokenLocation(), "[_na] missing required metric name");
}
String metricName = parser.currentName();

// TODO maybe switch to using a plugable registry later?
// TODO switch to using a pluggable registry
switch (metricName) {
case PrecisionAtK.NAME:
rc = PrecisionAtK.fromXContent(parser);
@@ -101,13 +104,19 @@ static List<RatedSearchHit> joinHitsWithRatings(SearchHit[] hits, List<RatedDocu
return ratedSearchHits;
}

/**
 * Collects the keys of all {@link RatedSearchHit}s that don't have a rating.
 *
 * @param ratedHits the search hits, each optionally carrying a rating
 * @return a {@link DocumentKey} built from the index and id of every hit
 *         without a rating, in the order the hits were passed in
 */
static List<DocumentKey> filterUnknownDocuments(List<RatedSearchHit> ratedHits) {
    List<DocumentKey> unratedKeys = new ArrayList<>();
    for (RatedSearchHit ratedHit : ratedHits) {
        if (ratedHit.getRating().isPresent()) {
            // rated hits are known documents, skip them
            continue;
        }
        SearchHit searchHit = ratedHit.getSearchHit();
        unratedKeys.add(new DocumentKey(searchHit.getIndex(), searchHit.getId()));
    }
    return unratedKeys;
}

/**
 * Combines the evaluation results of the individual search queries into the
 * overall evaluation score. The default is the average of the quality levels
 * of the partial results.
 *
 * @param partialResults the per-query evaluation results
 * @return the sum of all quality levels divided by the number of partial results
 */
default double combine(Collection<EvalQueryQuality> partialResults) {
    double qualitySum = partialResults.stream()
            .mapToDouble(partial -> partial.getQualityLevel())
            .sum();
    int resultCount = partialResults.size();
    return qualitySum / resultCount;
}
Loading

0 comments on commit e278c1d

Please sign in to comment.