Skip to content

Commit

Permalink
label ambiguity notebook improvements (deepchecks#760)
Browse files Browse the repository at this point in the history
* label ambiguity notebook improvements
  • Loading branch information
benisraeldan authored Jan 26, 2022
1 parent bb5b4a3 commit dc40e86
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 170 deletions.
6 changes: 5 additions & 1 deletion deepchecks/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,8 @@ def columns_info(self) -> t.Dict[Hashable, str]:
def select(
self: TDataset,
columns: t.Union[Hashable, t.List[Hashable], None] = None,
ignore_columns: t.Union[Hashable, t.List[Hashable], None] = None
ignore_columns: t.Union[Hashable, t.List[Hashable], None] = None,
keep_label: bool = False
) -> TDataset:
"""Filter dataset columns by given params.
Expand All @@ -786,6 +787,9 @@ def select(
DeepchecksValueError
If one of the given columns does not exist, raise an error
"""
if keep_label and columns and self.label_name not in columns:
columns.append(self.label_name)

new_data = select_from_dataframe(self._data, columns, ignore_columns)
if new_data.equals(self.data):
return self
Expand Down
9 changes: 5 additions & 4 deletions deepchecks/checks/integrity/label_ambiguity.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ def run_logic(self, context: CheckRunContext, dataset_type: str = 'train') -> Ch

context.assert_classification_task()

dataset = dataset.select(self.columns, self.ignore_columns)
dataset = dataset.select(self.columns, self.ignore_columns, keep_label=True)

label_col = dataset.label_name
label_col = context.label_name

# HACK: pandas has a bug with groupby on category dtypes, so until it is fixed, change dtypes manually
df = dataset.data
Expand All @@ -87,7 +87,7 @@ def run_logic(self, context: CheckRunContext, dataset_type: str = 'train') -> Ch

group_df = group_data[1]
sample_values = dict(group_df[dataset.features].iloc[0])
labels = tuple(group_df[label_col].unique())
labels = tuple(sorted(group_df[label_col].unique()))
n_data_sample = group_df.shape[0]
num_ambiguous += n_data_sample

Expand All @@ -96,7 +96,8 @@ def run_logic(self, context: CheckRunContext, dataset_type: str = 'train') -> Ch
display = display.set_index(ambiguous_label_name)

explanation = ('Each row in the table shows an example of a data sample '
'and the its observed labels as found in the dataset.')
'and the its observed labels as found in the dataset. '
f'Showing top {self.n_to_show} of {display.shape[0]}')

display = None if display.empty else [explanation, display.head(self.n_to_show)]

Expand Down
Loading

0 comments on commit dc40e86

Please sign in to comment.