Skip to content

Commit

Permalink
Plots: Better confusion matrix, and normalized version to (iterative#…
Browse files Browse the repository at this point in the history
…4775)

* Add normalized confusion plot, and add text labels

* Impute missing XY combinations
  • Loading branch information
sjawhar authored Nov 4, 2020
1 parent d1441ee commit 6494bd0
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 16 deletions.
171 changes: 157 additions & 14 deletions dvc/repo/plots/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,164 @@ class DefaultConfusionTemplate(Template):
"$schema": "https://vega.github.io/schema/vega-lite/v4.json",
"data": {"values": Template.anchor("data")},
"title": Template.anchor("title"),
"mark": "rect",
"encoding": {
"x": {
"field": Template.anchor("x"),
"type": "nominal",
"sort": "ascending",
"title": Template.anchor("x_label"),
"facet": {"field": "rev", "type": "nominal"},
"spec": {
"transform": [
{
"aggregate": [{"op": "count", "as": "xy_count"}],
"groupby": [Template.anchor("y"), Template.anchor("x")],
},
{
"impute": "xy_count",
"groupby": ["rev", Template.anchor("y")],
"key": Template.anchor("x"),
"value": 0,
},
{
"impute": "xy_count",
"groupby": ["rev", Template.anchor("x")],
"key": Template.anchor("y"),
"value": 0,
},
{
"joinaggregate": [
{"op": "max", "field": "xy_count", "as": "max_count"}
],
"groupby": [],
},
{
"calculate": "datum.xy_count / datum.max_count",
"as": "percent_of_max",
},
],
"encoding": {
"x": {
"field": Template.anchor("x"),
"type": "nominal",
"sort": "ascending",
"title": Template.anchor("x_label"),
},
"y": {
"field": Template.anchor("y"),
"type": "nominal",
"sort": "ascending",
"title": Template.anchor("y_label"),
},
},
"y": {
"field": Template.anchor("y"),
"type": "nominal",
"sort": "ascending",
"title": Template.anchor("y_label"),
"layer": [
{
"mark": "rect",
"width": 300,
"height": 300,
"encoding": {
"color": {
"field": "xy_count",
"type": "quantitative",
"title": "",
"scale": {"domainMin": 0, "nice": True},
}
},
},
{
"mark": "text",
"encoding": {
"text": {"field": "xy_count", "type": "quantitative"},
"color": {
"condition": {
"test": "datum.percent_of_max > 0.5",
"value": "white",
},
"value": "black",
},
},
},
],
},
}


class NormalizedConfusionTemplate(Template):
DEFAULT_NAME = "confusion_normalized"
DEFAULT_CONTENT = {
"$schema": "https://vega.github.io/schema/vega-lite/v4.json",
"data": {"values": Template.anchor("data")},
"title": Template.anchor("title"),
"facet": {"field": "rev", "type": "nominal"},
"spec": {
"transform": [
{
"aggregate": [{"op": "count", "as": "xy_count"}],
"groupby": [Template.anchor("y"), Template.anchor("x")],
},
{
"impute": "xy_count",
"groupby": ["rev", Template.anchor("y")],
"key": Template.anchor("x"),
"value": 0,
},
{
"impute": "xy_count",
"groupby": ["rev", Template.anchor("x")],
"key": Template.anchor("y"),
"value": 0,
},
{
"joinaggregate": [
{"op": "sum", "field": "xy_count", "as": "sum_y"}
],
"groupby": [Template.anchor("y")],
},
{
"calculate": "datum.xy_count / datum.sum_y",
"as": "percent_of_y",
},
],
"encoding": {
"x": {
"field": Template.anchor("x"),
"type": "nominal",
"sort": "ascending",
"title": Template.anchor("x_label"),
},
"y": {
"field": Template.anchor("y"),
"type": "nominal",
"sort": "ascending",
"title": Template.anchor("y_label"),
},
},
"color": {"aggregate": "count", "type": "quantitative"},
"facet": {"field": "rev", "type": "nominal"},
"layer": [
{
"mark": "rect",
"width": 300,
"height": 300,
"encoding": {
"color": {
"field": "percent_of_y",
"type": "quantitative",
"title": "",
"scale": {"domain": [0, 1]},
}
},
},
{
"mark": "text",
"encoding": {
"text": {
"field": "percent_of_y",
"type": "quantitative",
"format": ".2f",
},
"color": {
"condition": {
"test": "datum.percent_of_y > 0.5",
"value": "white",
},
"value": "black",
},
},
},
],
},
}

Expand Down Expand Up @@ -219,6 +361,7 @@ class PlotTemplates:
TEMPLATES = [
DefaultLinearTemplate,
DefaultConfusionTemplate,
NormalizedConfusionTemplate,
DefaultScatterTemplate,
SmoothLinearTemplate,
]
Expand Down
42 changes: 40 additions & 2 deletions tests/func/metrics/plots/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,46 @@ def test_plot_confusion(tmp_dir, dvc, run_copy_metrics):
{"predicted": "B", "actual": "A", "rev": "workspace"},
{"predicted": "A", "actual": "A", "rev": "workspace"},
]
assert plot_content["encoding"]["x"]["field"] == "predicted"
assert plot_content["encoding"]["y"]["field"] == "actual"
assert plot_content["spec"]["transform"][0]["groupby"] == [
"actual",
"predicted",
]
assert plot_content["spec"]["encoding"]["x"]["field"] == "predicted"
assert plot_content["spec"]["encoding"]["y"]["field"] == "actual"


def test_plot_confusion_normalized(tmp_dir, dvc, run_copy_metrics):
confusion_matrix = [
{"predicted": "B", "actual": "A"},
{"predicted": "A", "actual": "A"},
]
_write_json(tmp_dir, confusion_matrix, "metric_t.json")
run_copy_metrics(
"metric_t.json",
"metric.json",
plots_no_cache=["metric.json"],
commit="first run",
)

props = {
"template": "confusion_normalized",
"x": "predicted",
"y": "actual",
}
plot_string = dvc.plots.show(props=props)["metric.json"]

plot_content = json.loads(plot_string)
assert plot_content["data"]["values"] == [
{"predicted": "B", "actual": "A", "rev": "workspace"},
{"predicted": "A", "actual": "A", "rev": "workspace"},
]
assert plot_content["spec"]["transform"][0]["groupby"] == [
"actual",
"predicted",
]
assert plot_content["spec"]["transform"][1]["groupby"] == ["rev", "actual"]
assert plot_content["spec"]["encoding"]["x"]["field"] == "predicted"
assert plot_content["spec"]["encoding"]["y"]["field"] == "actual"


def test_plot_multiple_revs_default(tmp_dir, scm, dvc, run_copy_metrics):
Expand Down

0 comments on commit 6494bd0

Please sign in to comment.