report errors and successes of NER
tabergma committed Aug 29, 2019
1 parent 5c7e86c commit 0e54b77
Showing 2 changed files with 101 additions and 7 deletions.
102 changes: 97 additions & 5 deletions rasa/nlu/test.py
@@ -40,12 +40,11 @@
 NO_ENTITY = "no_entity"

 IntentEvaluationResult = namedtuple(
-    "IntentEvaluationResult",
-    "intent_target " "intent_prediction " "message " "confidence",
+    "IntentEvaluationResult", "intent_target intent_prediction message confidence"
 )

 EntityEvaluationResult = namedtuple(
-    "EntityEvaluationResult", "entity_targets " "entity_predictions " "tokens"
+    "EntityEvaluationResult", "entity_targets entity_predictions tokens message"
 )

 IntentMetrics = Dict[Text, List[float]]
@@ -422,11 +421,87 @@ def substitute_labels(labels: List[Text], old: Text, new: Text) -> List[Text]:
     return [new if label == old else label for label in labels]


+def collect_entity_errors(
+    entity_results: List[EntityEvaluationResult],
+    merged_targets: List[Text],
+    merged_predictions: List[Text],
+    error_filename: Text,
+):
+    errors = []
+
+    offset = 0
+    for entity_result in entity_results:
+        error = False
+        for i in range(offset, offset + len(entity_result.tokens)):
+            if merged_targets[i] != merged_predictions[i]:
+                error = True
+                break
+
+        if error:
+            errors.append(
+                {
+                    "text": entity_result.message,
+                    "entities": entity_result.entity_targets,
+                    "predicted_entities": entity_result.entity_predictions,
+                }
+            )
+        offset += len(entity_result.tokens)
+
+    if errors:
+        utils.write_json_to_file(error_filename, errors)
+        logger.info("Model prediction errors saved to {}.".format(error_filename))
+        logger.debug(
+            "\n\nThese entity examples could not be classified "
+            "correctly: \n{}".format(errors)
+        )
+    else:
+        logger.info("Your model made no entity prediction errors.")
+
+
+def collect_entity_successes(
+    entity_results: List[EntityEvaluationResult],
+    merged_targets: List[Text],
+    merged_predictions: List[Text],
+    successes_filename: Text,
+):
+    successes = []
+
+    offset = 0
+    for entity_result in entity_results:
+        # a message only counts as a success if every one of its tokens
+        # was labelled correctly
+        success = all(
+            merged_targets[i] == merged_predictions[i]
+            for i in range(offset, offset + len(entity_result.tokens))
+        )
+
+        if success:
+            successes.append(
+                {
+                    "text": entity_result.message,
+                    "entities": entity_result.entity_targets,
+                    "predicted_entities": entity_result.entity_predictions,
+                }
+            )
+        offset += len(entity_result.tokens)
+
+    if successes:
+        utils.write_json_to_file(successes_filename, successes)
+        logger.info(
+            "Model prediction successes saved to {}.".format(successes_filename)
+        )
+        logger.debug(
+            "\n\nSuccessfully predicted the following entities: \n{}".format(successes)
+        )
+    else:
+        logger.info("Your model made no successful entity predictions.")
+
+
 def evaluate_entities(
     entity_results: List[EntityEvaluationResult],
     extractors: Set[Text],
     report_folder: Optional[Text],
-    output_folder: Optional[Text] = None,
+    output_folder: Optional[Text],
+    successes_filename: Optional[Text] = None,
+    errors_filename: Optional[Text] = None,
 ) -> Dict:  # pragma: no cover
     """Creates summary statistics for each entity extractor.
     Logs precision, recall, and F1 per entity type for each extractor."""
@@ -475,6 +550,22 @@ def evaluate_entities(
"accuracy": accuracy,
}

+    if successes_filename:
+        if output_folder:
+            successes_filename = os.path.join(output_folder, successes_filename)
+        # save correctly classified samples to file for debugging
+        collect_entity_successes(
+            entity_results, merged_targets, merged_predictions, successes_filename
+        )
+
+    if errors_filename:
+        if output_folder:
+            errors_filename = os.path.join(output_folder, errors_filename)
+        # log and save misclassified samples to file for debugging
+        collect_entity_errors(
+            entity_results, merged_targets, merged_predictions, errors_filename
+        )

     return result


@@ -679,6 +770,7 @@ def get_eval_data(
                     example.get("entities", []),
                     result.get("entities", []),
                     result.get("tokens", []),
+                    result.get("text", ""),
                 )
             )

@@ -772,7 +864,7 @@ def run_evaluation(
logger.info("Entity evaluation results:")
extractors = get_entity_extractors(interpreter)
result["entity_evaluation"] = evaluate_entities(
entity_results, extractors, report, out_directory
entity_results, extractors, report, out_directory, successes, errors
)

     return result
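For orientation, here is a minimal standalone sketch (not part of the commit) of the alignment logic the two new collectors share: merged_targets and merged_predictions are flat, token-aligned label lists spanning all evaluated messages, and offset walks a window of len(tokens) through them for each message. The toy labels, messages, and string tokens below are invented for illustration; in rasa/nlu/test.py the labels come from aligning gold and predicted entities to real Token objects, with non-entity tokens labelled "no_entity".

from collections import namedtuple

# stand-in for the updated namedtuple; the real one lives in rasa/nlu/test.py
EntityEvaluationResult = namedtuple(
    "EntityEvaluationResult", "entity_targets entity_predictions tokens message"
)

# two toy messages, three tokens each; labels are flattened across all messages
merged_targets = ["no_entity", "city", "city", "no_entity", "no_entity", "time"]
merged_predictions = ["no_entity", "city", "no_entity", "no_entity", "no_entity", "time"]

results = [
    EntityEvaluationResult([], [], ["in", "new", "york"], "in new york"),
    EntityEvaluationResult([], [], ["see", "you", "tomorrow"], "see you tomorrow"),
]

offset = 0
for r in results:
    window = range(offset, offset + len(r.tokens))
    fully_correct = all(merged_targets[i] == merged_predictions[i] for i in window)
    print(r.message, "->", "success" if fully_correct else "error")
    offset += len(r.tokens)

# prints:
#   in new york -> error        (the second "city" token was missed)
#   see you tomorrow -> success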
6 changes: 4 additions & 2 deletions tests/nlu/base/test_evaluation.py
@@ -142,9 +142,11 @@ async def pretrained_interpreter(component_builder, tmpdir_factory):
     },
 ]

-EN_entity_result = EntityEvaluationResult(EN_targets, EN_predicted, EN_tokens)
+EN_entity_result = EntityEvaluationResult(
+    EN_targets, EN_predicted, EN_tokens, " ".join([t.text for t in EN_tokens])
+)

-EN_entity_result_no_tokens = EntityEvaluationResult(EN_targets, EN_predicted, [])
+EN_entity_result_no_tokens = EntityEvaluationResult(EN_targets, EN_predicted, [], "")


 def test_token_entity_intersection():
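Once an evaluation run has written the new report files, they can be inspected with nothing beyond the standard library. A sketch, assuming the run wrote entity_errors.json into a results/ folder (both names are illustrative) and that each entity dict follows Rasa's usual shape with an "entity" key:

import json

# illustrative path; the actual location depends on output_folder / errors_filename
with open("results/entity_errors.json") as f:
    entity_errors = json.load(f)

# each entry carries the "text", "entities" and "predicted_entities" keys
# written by collect_entity_errors above
for example in entity_errors[:5]:
    print(example["text"])
    print("  expected: ", [e.get("entity") for e in example["entities"]])
    print("  predicted:", [e.get("entity") for e in example["predicted_entities"]])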
