Skip to content

Commit

Permalink
ASR confidence bug fix for older Python versions (NVIDIA#5180)
Browse files Browse the repository at this point in the history
* math.prod fix and rebase leftover trim

Signed-off-by: Aleksandr Laptev <[email protected]>

* remove unused import

Signed-off-by: Aleksandr Laptev <[email protected]>

Signed-off-by: Aleksandr Laptev <[email protected]>
Co-authored-by: Aleksandr Laptev <[email protected]>
  • Loading branch information
GNroy and GNroy authored Oct 17, 2022
1 parent d933ee8 commit ee98f8d
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 34 deletions.
37 changes: 5 additions & 32 deletions nemo/collections/asr/metrics/rnnt_wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import copy
import math
from abc import abstractmethod
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union
Expand Down Expand Up @@ -97,7 +96,7 @@ class AbstractRNNTDecoding(ConfidenceMixin):
The length of the list corresponds to the number of recognized words.
exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded
from the `token_confidence`.
reduction: Which reduction type to use for collapsing per-token confidence into per-word confidence.
aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence.
Valid options are `mean`, `min`, `max`, `prod`.
method_cfg: A dict-like object which contains the method name and settings to compute per-frame
confidence scores.
Expand Down Expand Up @@ -202,32 +201,6 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
self.joint_fused_batch_size = self.cfg.get('fused_batch_size', None)
self.compute_timestamps = self.cfg.get('compute_timestamps', None)
self.word_seperator = self.cfg.get('word_seperator', ' ')
self.confidence_cfg = self.cfg.get('confidence_cfg', None)
if self.confidence_cfg is not None:
self.preserve_word_confidence = self.confidence_cfg.get('preserve_word_confidence', False)
# set preserve_frame_confidence and preserve_token_confidence to True
# if preserve_word_confidence is True
self.preserve_token_confidence = (
self.confidence_cfg.get('preserve_token_confidence', False) | self.preserve_word_confidence
)
# set preserve_frame_confidence to True if preserve_token_confidence is True
self.preserve_frame_confidence = (
self.confidence_cfg.get('preserve_frame_confidence', False) | self.preserve_token_confidence
)
self.exclude_blank_from_confidence = self.confidence_cfg.get('exclude_blank', True)
self.word_confidence_reduction = self.confidence_cfg.get('reduction', "min")
self.confidence_method_cfg = self.confidence_cfg.get('method_cfg', None)
else:
self.preserve_frame_confidence = False
self.preserve_token_confidence = False
self.preserve_word_confidence = False
self.exclude_blank_from_confidence = True
self.word_confidence_reduction = "min"
self.confidence_method_cfg = None

# define reduction functions
self.reduction_function_bank = {"mean": (lambda x: sum(x) / len(x)), "min": min, "max": max, "prod": math.prod}
self.reduction_function = self.reduction_function_bank[self.word_confidence_reduction]

possible_strategies = ['greedy', 'greedy_batch', 'beam', 'tsd', 'alsd', 'maes']
if self.cfg.strategy not in possible_strategies:
Expand All @@ -253,6 +226,9 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
if self.compute_timestamps is True and self.preserve_alignments is False:
raise ValueError("If `compute_timesteps` flag is set, then `preserve_alignments` flag must also be set.")

# initialize confidence-related fields
self._init_confidence(self.cfg.get('confidence_cfg', None))

# Update preserve frame confidence
if self.preserve_frame_confidence is False:
if self.cfg.strategy in ['greedy', 'greedy_batch']:
Expand All @@ -263,9 +239,6 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
# Not implemented
pass

# initialize confidence-related fields
self._init_confidence(self.cfg.get('confidence_cfg', None))

if self.cfg.strategy == 'greedy':

self.decoding = greedy_decode.GreedyRNNTInfer(
Expand Down Expand Up @@ -932,7 +905,7 @@ class RNNTDecoding(AbstractRNNTDecoding):
The length of the list corresponds to the number of recognized words.
exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded
from the `token_confidence`.
reduction: Which reduction type to use for collapsing per-token confidence into per-word confidence.
aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence.
Valid options are `mean`, `min`, `max`, `prod`.
method_cfg: A dict-like object which contains the method name and settings to compute per-frame
confidence scores.
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/rnnt_wer_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class RNNTBPEDecoding(AbstractRNNTDecoding):
The length of the list corresponds to the number of recognized words.
exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded
from the `token_confidence`.
reduction: Which reduction type to use for collapsing per-token confidence into per-word confidence.
aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence.
Valid options are `mean`, `min`, `max`, `prod`.
method_cfg: A dict-like object which contains the method name and settings to compute per-frame
confidence scores.
Expand Down
11 changes: 10 additions & 1 deletion nemo/collections/asr/parts/utils/asr_confidence_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,16 @@ def get_confidence_aggregation_bank():
Returns:
dictionary with functions.
"""
return {"mean": (lambda x: sum(x) / len(x)), "min": min, "max": max, "prod": math.prod}
confidence_aggregation_bank = {"mean": lambda x: sum(x) / len(x), "min": min, "max": max}
# math.prod was added in Python 3.8; on Python 3.7 and earlier, fall back to functools.reduce
if hasattr(math, "prod"):
confidence_aggregation_bank["prod"] = math.prod
else:
from functools import reduce
import operator

confidence_aggregation_bank["prod"] = lambda x: reduce(operator.mul, x, 1)
return confidence_aggregation_bank


class ConfidenceMeasureMixin(ABC):
Expand Down

0 comments on commit ee98f8d

Please sign in to comment.