forked from btheodorou99/HALO_Inpatient
evaluate_privacy_nearest.py
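# Nearest-neighbor privacy evaluation for HALO-generated synthetic inpatient EHR data.
# Samples real training, held-out test, and synthetic patient records, measures
# Hamming-style nearest-neighbor distances within and across the cohorts, and reports
# the nearest-neighbor adversarial accuracy gap (saved under the key "NNAAE") as a
# privacy proxy for memorization of individual training records.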
import torch
import pickle
import random
import numpy as np
from tqdm import tqdm
from sklearn import metrics
from config import HALOConfig

SEED = 4
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

NUM_SAMPLES = 5000
config = HALOConfig()
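# Load the real train/test records and the HALO-generated synthetic records, then draw
# NUM_SAMPLES patients from each (np.random.choice samples with replacement by default).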
train_ehr_dataset = pickle.load(open('./data/trainDataset.pkl', 'rb'))
train_ehr_dataset = np.random.choice(train_ehr_dataset, NUM_SAMPLES)
test_ehr_dataset = pickle.load(open('./data/testDataset.pkl', 'rb'))
test_ehr_dataset = np.random.choice(test_ehr_dataset, NUM_SAMPLES)
synthetic_ehr_dataset = pickle.load(open('./results/datasets/haloDataset.pkl', 'rb'))
synthetic_ehr_dataset = np.random.choice([p for p in synthetic_ehr_dataset if len(p['visits']) > 0], NUM_SAMPLES)
synthetic_ehr_dataset = [{'labels': p['labels'], 'visits': [set(v) for v in p['visits']]} for p in synthetic_ehr_dataset]


def find_hamming(ehr, dataset):
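    """Minimum Hamming-style distance from `ehr` to any other record in `dataset`.

    Each record is assumed to hold a binary numpy label vector under 'labels' and a
    sequence of per-visit code sets under 'visits'. The distance sums (1) a flag for
    differing visit counts, (2) the Hamming distance between the label vectors, and
    (3) the size of the symmetric difference between the code sets of aligned visits;
    visits of the compared record beyond len(ehr['visits']) are ignored. Exact matches
    (distance 0) are skipped so a record never counts itself as its own nearest neighbor.
    """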
    min_d = 1e10
    visits = ehr['visits']
    labels = ehr['labels']
    for p in dataset:
        # Penalize differing visit counts, then add the label-vector Hamming distance.
        d = 0 if len(visits) == len(p['visits']) else 1
        l = p['labels']
        d += ((labels + l) == 1).sum()
        for i in range(len(visits)):
            v = visits[i]
            if i >= len(p['visits']):
                # No corresponding visit: every code in this visit counts as a mismatch.
                d += len(v)
            else:
                # Size of the symmetric difference between the two visits' code sets.
                v2 = p['visits'][i]
                d += len(v) + len(v2) - (2 * len(v.intersection(v2)))
        # Keep the smallest non-zero distance (distance 0 would be the record itself
        # or an exact duplicate).
        min_d = d if d < min_d and d > 0 else min_d
    return min_d


def calc_nnaar(train, evaluation, synthetic):
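    """Nearest-neighbor adversarial accuracy gap between (test, synthetic) and (train, synthetic).

    For each pair of cohorts, a 1-NN "adversary" labels a record as real or synthetic
    according to which cohort contains its nearest neighbor; aaes is that adversary's
    accuracy on test vs. synthetic and aaet on train vs. synthetic. The returned value
    aaes - aaet grows when synthetic records sit closer to the training data than to
    held-out data, which points to potential memorization of training records.
    """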
    val1 = 0
    val2 = 0
    val3 = 0
    val4 = 0

    # Test records whose nearest neighbor is another test record rather than a synthetic one.
    for p in tqdm(evaluation):
        des = find_hamming(p, synthetic)
        dee = find_hamming(p, evaluation)
        if des > dee:
            val1 += 1

    # Train records whose nearest neighbor is another train record rather than a synthetic one.
    for p in tqdm(train):
        dts = find_hamming(p, synthetic)
        dtt = find_hamming(p, train)
        if dts > dtt:
            val3 += 1

    # Synthetic records whose nearest neighbor is another synthetic record rather than a
    # test record (val2) or a train record (val4).
    for p in tqdm(synthetic):
        dse = find_hamming(p, evaluation)
        dst = find_hamming(p, train)
        dss = find_hamming(p, synthetic)
        if dse > dss:
            val2 += 1
        if dst > dss:
            val4 += 1

    # Convert counts to rates and average them into the two adversarial accuracies.
    val1 = val1 / NUM_SAMPLES
    val2 = val2 / NUM_SAMPLES
    val3 = val3 / NUM_SAMPLES
    val4 = val4 / NUM_SAMPLES

    aaes = (0.5 * val1) + (0.5 * val2)
    aaet = (0.5 * val3) + (0.5 * val4)
    return aaes - aaet
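
# Compute the nearest-neighbor adversarial accuracy gap over the sampled cohorts and
# persist it under results/privacy_evaluation/.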
nnaar = calc_nnaar(train_ehr_dataset, test_ehr_dataset, synthetic_ehr_dataset)
results = {
    "NNAAE": nnaar
}
pickle.dump(results, open("results/privacy_evaluation/nnaar.pkl", "wb"))
print(results)
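
# Assumptions: the pickled datasets must already exist at the hard-coded paths above, and
# the results/privacy_evaluation/ directory must exist beforehand, since the
# open(..., "wb") call does not create missing directories.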