
Commit

Drop NULLs from ground truth and fix lookup of values with only NULL co-occurring values.
richardwu committed Feb 15, 2019
1 parent 3d58ace commit 1e573a0
Showing 7 changed files with 85 additions and 53 deletions.
8 changes: 7 additions & 1 deletion dataset/dataset.py
@@ -209,6 +209,12 @@ def get_statistics(self):
<val1>: all values of <attr1>
<val2>: values of <attr2> that appear at least once with <val1>.
<count>: frequency (# of entities) where attr1=val1 AND attr2=val2
NB: neither single_attr_stats nor pair_attr_stats contains frequencies
for values that are NULL (NULL_REPR). One would need to explicitly
check if the value is NULL before lookup.
Also, values that only co-occur with NULLs will NOT be in pair_attr_stats.
"""
if not self.stats_ready:
logging.debug('computing frequency and co-occurrence statistics from raw data...')
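For reference, the documented shape of these statistics and the NULL guard a caller needs look roughly like the following sketch (not part of the commit; the values are made up, and NULL_REPR is the '_nan_' placeholder used elsewhere in the repo):

    NULL_REPR = '_nan_'  # the repo's NULL placeholder

    # single_attr_stats[attr][val] -> frequency of attr=val (NULL values excluded)
    single_attr_stats = {'city': {'boston': 3, 'nyc': 2}}
    # pair_attr_stats[attr1][attr2][val1][val2] -> co-occurrence count (NULLs excluded)
    pair_attr_stats = {'city': {'state': {'boston': {'MA': 3}, 'nyc': {'NY': 2}}}}

    def cooccur_count(attr1, val1, attr2, val2):
        # Callers must guard NULLs and values that only co-occur with NULLs.
        if val1 == NULL_REPR or val2 == NULL_REPR:
            return 0
        return pair_attr_stats.get(attr1, {}).get(attr2, {}) \
                              .get(val1, {}).get(val2, 0)

    assert cooccur_count('city', 'boston', 'state', 'MA') == 3
    assert cooccur_count('city', NULL_REPR, 'state', 'MA') == 0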
@@ -260,7 +266,7 @@ def get_stats_pair(self, first_attr, second_attr):
<first_val>: all possible values for first_attr
<second_val>: all values for second_attr that appear at least once with <first_val>
<count>: frequency (# of entities) where first_attr=<first_val> AND second_attr=<second_val>
Filters out NULL values so no entries in the dictionary would have nulls.
Filters out NULL values so no entries in the dictionary would have NULLs.
"""
data_df = self.get_raw_data()
tmp_df = data_df[[first_attr, second_attr]]\
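A self-contained pandas approximation of the NULL filtering that get_stats_pair describes (column names and data are made up):

    import pandas as pd

    NULL_REPR = '_nan_'  # the repo's NULL placeholder
    df = pd.DataFrame({'city':  ['boston', 'boston', NULL_REPR, 'nyc'],
                       'state': ['MA',     NULL_REPR, 'MA',     'NY']})

    # Drop any pair where either side is NULL before counting.
    tmp = df[(df['city'] != NULL_REPR) & (df['state'] != NULL_REPR)]
    counts = tmp.groupby(['city', 'state']).size().to_dict()
    # {('boston', 'MA'): 1, ('nyc', 'NY'): 1}: 'boston' survives despite one NULL
    # state, but a value seen only with NULL states would be absent entirely.
    print(counts)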
63 changes: 35 additions & 28 deletions domain/domain.py
@@ -205,10 +205,8 @@ def get_corr_attributes(self, attr, thres):
def generate_domain(self):
"""
Generates the domain for each cell in the active attributes as well
as assigns variable IDs (_vid_) (an incrementing key from 0 onwards that depends on the
iteration order of rows/entities in the raw data and attributes).
Note that _vid_ has a 1-1 correspondence with _cid_.
as assigns a random variable ID (_vid_) for cells that have
a domain of size >= 2.
See get_domain_cell for how the domain is generated from co-occurrence
and correlated attributes.
@@ -218,8 +216,8 @@ def generate_domain(self):
:return: DataFrame with columns
_tid_: entity/tuple ID
_cid_: cell ID (unique for every entity-attribute)
_vid_: variable ID (1-1 correspondence with _cid_)
_cid_: cell ID (one for every cell in the raw data in active attributes)
_vid_: random variable ID (one for every cell with a domain of at least size 2)
attribute: attribute name
domain: ||| separated string of domain values
domain_size: length of domain
@@ -243,10 +241,6 @@ def generate_domain(self):
tid = row['_tid_']
for attr in self.active_attributes:
init_value, init_value_idx, dom = self.get_domain_cell(attr, row)
# If init_value is NULL, we would not have it in the domain since we filtered it out.
init_value_idx = -1
if init_value in dom:
init_value_idx = dom.index(init_value)
# We will use an estimator model for additional weak labelling
# below, which requires an initial pruned domain first.
# Weak labels will be trained on the init values.
@@ -259,11 +253,10 @@ def generate_domain(self):
if len(dom) <= 1:
# Not enough domain values, we need to get some random
# values (other than 'init_value') for training. However,
# this might still get us zero domain values. We handle it
# next.
# this might still get us zero domain values.
rand_dom_values = self.get_random_domain(attr, init_value)

# the rand_dom_values might still be empty. In this case,
# rand_dom_values might still be empty. In this case,
# there are no other possible values for this cell. There
# is no point in using this cell for training and there is no
# point in running inference on it since we cannot even generate
@@ -339,8 +332,10 @@ def generate_domain(self):
# update our memoized domain values for this row again
row['domain'] = '|||'.join(domain_values)
row['domain_size'] = len(domain_values)
if row['init_value'] != NULL_REPR:
# update init index based on new domain
if row['init_value'] in domain_values:
row['init_index'] = domain_values.index(row['init_value'])
# update weak label index based on new domain
if row['weak_label'] != NULL_REPR:
row['weak_label_idx'] = domain_values.index(row['weak_label'])
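Taken together, the changes to this loop amount to: pad trivial domains with random values and drop cells that stay trivial, then recompute the init/weak-label indices after the estimator re-prunes the domain. A rough sketch under those assumptions (pad_or_drop and reindex are hypothetical helper names, not functions in the repo):

    NULL_REPR = '_nan_'  # the repo's NULL placeholder

    def pad_or_drop(dom, rand_dom_values):
        # If the co-occurrence domain is trivial (<= 1 value), pad it with random
        # candidates; if it is still trivial, the cell carries no signal for
        # training or inference and is dropped (None).
        if len(dom) <= 1:
            dom = sorted(set(dom) | set(rand_dom_values))
        return dom if len(dom) > 1 else None

    def reindex(row, domain_values):
        # After re-pruning, recompute indices. A NULL init value is never in the
        # domain, so its index stays -1; weak labels are drawn from the domain.
        row['domain'] = '|||'.join(domain_values)
        row['domain_size'] = len(domain_values)
        row['init_index'] = (domain_values.index(row['init_value'])
                             if row['init_value'] in domain_values else -1)
        if row['weak_label'] != NULL_REPR:
            row['weak_label_idx'] = domain_values.index(row['weak_label'])
        return row

    assert pad_or_drop(['boston'], []) is None
    assert pad_or_drop(['boston'], ['nyc']) == ['boston', 'nyc']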

@@ -390,7 +385,7 @@ def get_domain_cell(self, attr, row):
"""

domain = set()
attr_val = row[attr]
init_value = row[attr]
correlated_attributes = self.get_corr_attributes(attr, self.cor_strength)
# Iterate through all correlated attributes and take the top K co-occurrence values
# for 'attr' with the current row's 'cond_attr' value.
@@ -402,25 +397,35 @@ def get_domain_cell(self, attr, row):
logging.warning("domain generation could not find pair_statistics between attributes: {}, {}".format(cond_attr, attr))
continue
cond_val = row[cond_attr]
# Ignore co-occurrences if any value is null.
if cond_val == NULL_REPR or attr_val == NULL_REPR or pd.isnull(cond_val) or pd.isnull(attr_val):
# Ignore co-occurrences with a NULL conditioning value since we do not
# store them.
# It also does not make sense to retrieve the top co-occurring
# values for a NULL value.
# It is possible for cond_val to not be in pair stats if it only co-occurs
# with NULL values.
if cond_val == NULL_REPR or cond_val not in self.pair_stats[cond_attr][attr]:
continue
s = self.pair_stats[cond_attr][attr]
candidates = s[cond_val]

# Update the domain with the top co-occurring values for the conditioning value.
candidates = self.pair_stats[cond_attr][attr][cond_val]
domain.update(candidates)

# Remove NULL_REPR (_nan_) if added due to correlated attributes.
domain.discard(NULL_REPR)
# Add the initial value to the domain if it is not null.
if attr_val != NULL_REPR:
domain.update({attr_val})
# We should not have any NULLs since we do not store co-occurring NULL
# values.
assert NULL_REPR not in domain

# Add the initial value to the domain if it is not NULL.
if init_value != NULL_REPR:
domain.add(init_value)

# Convert to a sorted list for a deterministic ordering.
domain_lst = sorted(list(domain))

# Get the index of the initial value. This should never raise a ValueError since we made sure
# that 'init_value' was added.
init_value_idx = domain_lst.index(init_value)
# Get the index of the initial value.
# NULL values are not in the domain so we set their index to -1.
init_value_idx = -1
if init_value != NULL_REPR:
init_value_idx = domain_lst.index(init_value)

return init_value, init_value_idx, domain_lst
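The crux of the fix is the second guard above: a conditioning value that only ever co-occurred with NULLs has a single-attribute count but no pair_stats entry, so the old unguarded lookup raised a KeyError. A tiny illustration with made-up statistics:

    NULL_REPR = '_nan_'  # the repo's NULL placeholder

    # 'sf' appears in the data, but only ever alongside a NULL state, so it has a
    # single-attribute count yet no pair_stats entry.
    single_stats = {'city': {'boston': 3, 'sf': 1}}
    pair_stats = {'city': {'state': {'boston': {'MA': 3}}}}

    def candidates_for(cond_val):
        if cond_val == NULL_REPR or cond_val not in pair_stats['city']['state']:
            return set()   # nothing to contribute to the domain
        return set(pair_stats['city']['state'][cond_val])

    assert candidates_for('boston') == {'MA'}
    assert candidates_for('sf') == set()        # the old lookup raised KeyError here
    assert candidates_for(NULL_REPR) == set()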

@@ -430,8 +435,10 @@ def get_random_domain(self, attr, cur_value):
'self.max_sample' of domain values for 'attr' that is NOT 'cur_value'.
"""
domain_pool = set(self.single_stats[attr].keys())
# We should not have any NULLs since we do not keep track of their
# counts.
assert NULL_REPR not in domain_pool
domain_pool.discard(cur_value)
domain_pool.discard(NULL_REPR)
domain_pool = sorted(list(domain_pool))
size = len(domain_pool)
if size > 0:
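For completeness, a sketch of what get_random_domain's sampling amounts to, assuming candidates are drawn uniformly without replacement (the real implementation may sample differently):

    import random

    NULL_REPR = '_nan_'  # the repo's NULL placeholder

    def random_domain(attr_freq, cur_value, max_sample):
        # attr_freq is the single-attribute frequency dict; NULL never appears in
        # it because NULL counts are not stored.
        pool = set(attr_freq.keys())
        assert NULL_REPR not in pool
        pool.discard(cur_value)
        pool = sorted(pool)
        if not pool:
            return []
        return random.sample(pool, min(max_sample, len(pool)))

    print(random_domain({'boston': 3, 'nyc': 2, 'sf': 1}, 'boston', 2))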
7 changes: 7 additions & 0 deletions domain/estimators/naive_bayes.py
@@ -3,6 +3,7 @@
from tqdm import tqdm

from ..estimator import Estimator
from utils import NULL_REPR


class NaiveBayes(Estimator):
@@ -41,6 +42,12 @@ def predict_pp(self, row, attr, values):
if at == attr or at == '_tid_':
continue
val2 = row[at]
# Since we do not have co-occurrence stats with NULL values,
# we skip them.
# It also doesn't make sense for our likelihood to be conditioned
# on a NULL value.
if val2 == NULL_REPR:
continue
val2_val1_count = 0.1
if val1 in self._cooccur_freq[attr][at]:
if val2 in self._cooccur_freq[attr][at][val1]:
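The skipped NULL evidence feeds a standard naive Bayes score, roughly prior times a product of smoothed conditional ratios. A sketch of that idea (not the estimator's exact arithmetic), using the 0.1 smoothing count from above:

    NULL_REPR = '_nan_'  # the repo's NULL placeholder

    def nb_score(val1, row, attr, freq, cooccur_freq, smoothing=0.1):
        # row: dict of attribute -> value for the tuple being scored.
        # Score a candidate val1 for cell (row, attr): prior frequency times the
        # product of conditional co-occurrence ratios over the other attributes,
        # skipping NULL evidence since no NULL co-occurrence stats exist.
        score = freq[attr][val1]
        for at, val2 in row.items():
            if at in (attr, '_tid_') or val2 == NULL_REPR:
                continue
            count = cooccur_freq[attr][at].get(val1, {}).get(val2, smoothing)
            score *= count / freq[attr][val1]
        return score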
6 changes: 6 additions & 0 deletions evaluate/eval.py
@@ -51,6 +51,12 @@ def load_data(self, name, fpath, tid_col, attr_col, val_col, na_values=None):
tic = time.clock()
try:
raw_data = pd.read_csv(fpath, na_values=na_values, encoding='utf-8')
# We drop any ground truth values that are NULLs since we follow
# the closed-world assumption (if it's not there it's wrong).
# TODO: revisit this once we allow users to specify which
# attributes may be NULL.
raw_data.dropna(subset=[val_col], inplace=True)
raw_data.fillna(NULL_REPR, inplace=True)
raw_data.rename({tid_col: '_tid_',
attr_col: '_attribute_',
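The ordering of the two pandas calls matters: NULL ground-truth rows must be dropped before fillna, otherwise they would survive as NULL_REPR and be scored as real labels. A small illustration with a made-up clean file:

    import pandas as pd

    NULL_REPR = '_nan_'  # the repo's NULL placeholder

    raw = pd.DataFrame({'tid': [0, 1],
                        'attribute': ['city', 'city'],
                        'correct_val': ['boston', None]})

    raw = raw.dropna(subset=['correct_val'])   # drop NULL ground truth first
    raw = raw.fillna(NULL_REPR)                # then normalize remaining NULLs
    assert len(raw) == 1 and raw.iloc[0]['correct_val'] == 'boston'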
29 changes: 14 additions & 15 deletions examples/holoclean_repair_example.py
@@ -31,21 +31,20 @@
hc.load_dcs('../testdata/hospital_constraints.txt')
hc.ds.set_constraints(hc.get_dcs())

# 3. Detect erroneous cells using these two detectors.
detectors = [NullDetector(), ViolationDetector()]
hc.detect_errors(detectors)

# 4. Repair errors utilizing the defined features.
hc.setup_domain()
featurizers = [
InitAttrFeaturizer(),
OccurAttrFeaturizer(),
FreqFeaturizer(),
ConstraintFeaturizer(),
LangModelFeaturizer(),
]

hc.repair_errors(featurizers)
# 3. Detect erroneous cells using these two detectors.
detectors = [NullDetector(), ViolationDetector()]
hc.detect_errors(detectors)

# 4. Repair errors utilizing the defined features.
hc.setup_domain()
featurizers = [
InitAttrFeaturizer(),
OccurAttrFeaturizer(),
FreqFeaturizer(),
ConstraintFeaturizer(),
]

hc.repair_errors(featurizers)

# 5. Evaluate the correctness of the results.
hc.evaluate(fpath='../testdata/hospital_clean.csv',
19 changes: 12 additions & 7 deletions repair/featurize/featurized_dataset.py
@@ -8,6 +8,7 @@
import torch.nn.functional as F

from dataset import AuxTables, CellStatus
from utils import NULL_REPR

FeatInfo = namedtuple('FeatInfo', ['name', 'size', 'learnable', 'init_weight', 'feature_names'])

@@ -65,13 +66,17 @@ def generate_weak_labels(self):
contains the domain index of the initial value for the i-th
variable/VID.
"""
# Trains with clean cells AND cells that have been weak labelled.
query = 'SELECT _vid_, weak_label_idx, fixed, (t2._cid_ IS NULL) AS clean ' \
'FROM {} AS t1 LEFT JOIN {} AS t2 ON t1._cid_ = t2._cid_ ' \
'WHERE t2._cid_ is NULL ' \
' OR t1.fixed != {};'.format(AuxTables.cell_domain.name,
AuxTables.dk_cells.name,
CellStatus.NOT_SET.value)
# Generate weak labels for clean cells AND cells that have been weak
# labelled. Do not train on cells with NULL weak labels (i.e.
# NULL init values that were not weak labelled).
query = """
SELECT _vid_, weak_label_idx, fixed, (t2._cid_ IS NULL) AS clean
FROM {cell_domain} AS t1 LEFT JOIN {dk_cells} AS t2 ON t1._cid_ = t2._cid_
WHERE weak_label != '{null_repr}' AND (t2._cid_ is NULL OR t1.fixed != {cell_status});
""".format(cell_domain=AuxTables.cell_domain.name,
dk_cells=AuxTables.dk_cells.name,
null_repr=NULL_REPR,
cell_status=CellStatus.NOT_SET.value)
res = self.ds.engine.execute_query(query)
if len(res) == 0:
raise Exception("No weak labels available. Reduce pruning threshold.")
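A sketch of how rows returned by this query might populate per-variable training labels; the tensor layout is an assumption for illustration, not necessarily what the surrounding code builds:

    import torch

    # Hypothetical query results: (_vid_, weak_label_idx, fixed, clean).
    res = [(0, 2, 0, True), (2, 0, 1, False)]

    total_vars = 3
    # -1 marks variables with no usable weak label; NULL weak labels never show
    # up here because the SQL above filters them out.
    labels = torch.full((total_vars, 1), -1, dtype=torch.long)
    for vid, weak_label_idx, fixed, clean in res:
        labels[vid] = weak_label_idx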
6 changes: 4 additions & 2 deletions repair/featurize/occurattrfeat.py
@@ -50,15 +50,17 @@ def gen_feat_tensor(self, row, tuple):
rv_attr = row['attribute']
domain = row['domain'].split('|||')
rv_domain_idx = {val: idx for idx, val in enumerate(domain)}
# We should not have any NULLs in our domain.
assert NULL_REPR not in rv_domain_idx
rv_attr_idx = self.ds.attr_to_idx[rv_attr]
for attr in self.all_attrs:
val = tuple[attr]

# Ignore co-occurrences of same attribute or with null values.
# It's possible a value is not in pair_stats if it only co-occurred
# with NULL values.
if attr == rv_attr \
or pd.isnull(val) \
or val == NULL_REPR \
or val not in self.single_stats[attr] \
or val not in self.pair_stats[attr][rv_attr]:
continue
attr_idx = self.ds.attr_to_idx[attr]
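The feature filled in here is, roughly, the empirical probability of a candidate value conditioned on each context attribute's value, left at zero when the context value is NULL or only ever co-occurred with NULLs. A small sketch with made-up statistics:

    NULL_REPR = '_nan_'  # the repo's NULL placeholder

    single_stats = {'state': {'MA': 3}}
    pair_stats = {'state': {'city': {'MA': {'boston': 2, 'cambridge': 1}}}}

    def cooccur_prob(context_attr, context_val, rv_attr, candidate):
        # Zero when the context value is NULL or missing from pair_stats
        # (i.e. it was only ever seen together with NULLs).
        if context_val == NULL_REPR or context_val not in pair_stats[context_attr][rv_attr]:
            return 0.0
        count = pair_stats[context_attr][rv_attr][context_val].get(candidate, 0)
        return count / single_stats[context_attr][context_val]

    assert cooccur_prob('state', 'MA', 'city', 'boston') == 2 / 3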
