diff --git a/dedupe/core.py b/dedupe/core.py index 25ca924d5..e79bb6de7 100644 --- a/dedupe/core.py +++ b/dedupe/core.py @@ -19,7 +19,12 @@ Any, Type, Iterable) -from dedupe._typing import (RecordPairs, RecordID, Blocks, Data, Literal) +from dedupe._typing import (RecordPairs, + RecordID, + RecordDict, + Blocks, + Data, + Literal) import numpy import multiprocessing @@ -135,7 +140,9 @@ def __call__(self) -> None: def fieldDistance(self, record_pairs: RecordPairs) -> Optional[Tuple]: - record_ids, records = zip(*(zip(*record_pair) for record_pair in record_pairs)) + record_ids: Sequence[Tuple[RecordID, RecordID]] + records: Sequence[Tuple[RecordDict, RecordDict]] + record_ids, records = zip(*(zip(*record_pair) for record_pair in record_pairs)) # type: ignore if records: @@ -293,7 +300,9 @@ def __init__(self, data_model, classifier): def __call__(self, block: RecordPairs) -> numpy.ndarray: - record_ids, records = zip(*(zip(*each) for each in block)) + record_ids: Sequence[Tuple[RecordID, RecordID]] + records: Sequence[Tuple[RecordDict, RecordDict]] + record_ids, records = zip(*(zip(*each) for each in block)) # type: ignore distances = self.data_model.distances(records) scores = self.classifier.predict_proba(distances)[:, -1]