-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathequal_odds.R
171 lines (154 loc) · 7.25 KB
/
equal_odds.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#' @title Equalized Odds
#'
#' @description
#' This function computes the Equalized Odds metric
#'
#' Formula: TP / (TP + FN)
#'
#' @details
#' This function computes the Equalized Odds metric (also known as Equal Opportunity, Positive Rate Parity or Separation). Equalized Odds are calculated
#' by dividing the number of true positives by the number of all actual positives (irrespective of predicted values). This metric equals
#' what is traditionally known as sensitivity. In the returned
#' metric table, the reference group will be assigned 1, while all other groups will be assigned values
#' according to whether their sensitivities are lower or higher compared to the reference group. Lower
#' sensitivities will be reflected in numbers lower than 1 in the returned metric table, thus numbers
#' lower than 1 mean WORSE prediction for the subgroup.
#'
#' When \code{probs} is supplied, an observation is predicted positive when its probability is strictly
#' greater than \code{cutoff}. The resulting 0/1 predictions are mapped, in order, onto the levels of the
#' outcome factor (0 to the first level, 1 to the second level).
#'
#' @param data Data.frame that contains the necessary columns.
#' @param group Column name indicating the sensitive group (character).
#' @param base Base level of the sensitive group (character).
#' @param group_breaks If group is continuous (e.g., age): either a numeric vector of two or more unique cut points or a single number >= 2 giving the number of intervals into which group feature is to be cut.
#' @param outcome Column name indicating the binary outcome variable (character).
#' @param outcome_base Base level of the outcome variable (i.e., negative class). Default is the first level of the outcome variable.
#' @param probs Column name or vector with the predicted probabilities (numeric between 0 - 1). Either probs or preds need to be supplied.
#' @param preds Column name or vector with the predicted binary outcome (0 or 1). Either probs or preds need to be supplied.
#' @param cutoff Cutoff to generate predicted outcomes from predicted probabilities. Default set to 0.5.
#'
#' @name equal_odds
#'
#' @return
#' \item{Metric}{Matrix with one column per group (base group first) and rows 'Sensitivity' (raw sensitivities), 'Equalized odds' (sensitivities divided by the base group's sensitivity) and 'Group size'. Lower values compared to the reference group mean lower sensitivities in the selected subgroups}
#' \item{Metric_plot}{Bar plot of Equalized Odds metric}
#' \item{Probability_plot}{Density plot of predicted probabilities per subgroup. Only plotted if probabilities are defined}
#'
#'
#' @examples
#' data(compas)
#' compas$Two_yr_Recidivism_01 <- ifelse(compas$Two_yr_Recidivism == 'yes', 1, 0)
#' equal_odds(data = compas, outcome = 'Two_yr_Recidivism_01', group = 'ethnicity',
#' probs = 'probability', cutoff = 0.4, base = 'Caucasian')
#' equal_odds(data = compas, outcome = 'Two_yr_Recidivism_01', group = 'ethnicity',
#' preds = 'predicted', cutoff = 0.5, base = 'Hispanic')
#'
#' @export
equal_odds <- function(data, outcome, group,
                       probs = NULL,
                       preds = NULL,
                       outcome_base = NULL,
                       cutoff = 0.5,
                       base = NULL,
                       group_breaks = NULL) {
  # Non-data.frame inputs (e.g. tibbles) are converted so that data[, col]
  # below returns a plain vector rather than a one-column table.
  if (class(data)[1] != 'data.frame') {
    warning(paste0('Converting ', class(data)[1], ' to data.frame'))
    data <- as.data.frame(data)
  }
  if (is.null(probs) && is.null(preds)) {
    stop('Either probs or preds have to be supplied')
  }

  # Outcome factor is computed before any group binning touches `data`.
  outcome_status <- as.factor(data[, outcome])

  # Predictions: a length-one probs/preds is treated as a column name.
  if (is.null(probs)) {
    if (length(preds) == 1) {
      preds <- data[, preds]
    }
    preds_status <- as.factor(preds)
  } else {
    if (length(probs) == 1) {
      probs <- data[, probs]
    }
    # Declare both 0/1 levels explicitly. Without this, when every
    # probability falls on the same side of the cutoff the factor has a single
    # level, and the positional relabelling below would map all predictions
    # to the FIRST outcome level, even if they were all positive (1).
    preds_status <- factor(as.numeric(probs > cutoff), levels = c(0, 1))
    levels(preds_status) <- levels(outcome_status)
  }

  # Check the group feature and cut it into intervals if requested.
  if (length(unique(data[, group])) > 10 && is.null(group_breaks)) {
    warning('Number of unique group levels exceeds 10. Consider specifying `group_breaks`.')
  }
  if (!is.null(group_breaks)) {
    if (is.numeric(data[, group])) {
      data[, group] <- cut(data[, group], breaks = group_breaks)
    } else {
      warning('Attempting to bin a non-numeric group feature.')
    }
  }
  group_status <- as.factor(data[, group])

  # Predictions and outcome must share exactly the same level set.
  if (!identical(levels(outcome_status), levels(preds_status))) {
    warn_preds <- paste0(levels(preds_status), collapse = ', ')
    warn_outcome <- paste0(levels(outcome_status), collapse = ', ')
    stop(paste0('Levels of predictions and outcome do not match. ',
                'Please relevel predictions or outcome.\n',
                'Outcome levels: ', warn_outcome, '\n',
                'Preds levels: ', warn_preds))
  }

  # Put the negative class first; the second level is treated as positive.
  if (is.null(outcome_base)) {
    outcome_base <- levels(outcome_status)[1]
  } else {
    outcome_base <- as.character(outcome_base)
  }
  outcome_status <- relevel(outcome_status, outcome_base)
  preds_status <- relevel(preds_status, outcome_base)
  outcome_positive <- levels(outcome_status)[2]

  if (length(outcome_status) != length(preds_status) ||
      length(outcome_status) != length(group_status)) {
    stop('Outcomes, predictions/probabilities and group status must be of the same length')
  }

  # Base group goes first so that column 1 of the result is the reference.
  if (is.null(base)) {
    base <- levels(group_status)[1]
  }
  group_status <- relevel(group_status, base)

  # Sensitivity and group size per group.
  groups <- levels(group_status)
  val <- stats::setNames(rep(NA_real_, length(groups)), groups)
  sample_size <- val
  for (i in groups) {
    idx <- which(group_status == i)
    cm <- caret::confusionMatrix(preds_status[idx],
                                 outcome_status[idx],
                                 mode = 'everything',
                                 positive = outcome_positive)
    val[i] <- cm$byClass['Sensitivity']
    sample_size[i] <- sum(cm$table)
  }

  # Row 2 is each group's sensitivity relative to the base group (column 1).
  res_table <- rbind(val, val / val[[1]], sample_size)
  rownames(res_table) <- c('Sensitivity', 'Equalized odds', 'Group size')

  # Group names come from colnames(), not from res_table[2, ]: with a single
  # group that row collapses to an unnamed scalar and its names are lost.
  val_df <- data.frame(val = unname(res_table[2, ]),
                       groupst = relevel(factor(colnames(res_table)), base))
  p <- ggplot(val_df, aes(x = groupst, weight = val, fill = groupst)) + geom_bar(alpha = 0.5) +
    coord_flip() + theme(legend.position = 'none') + labs(x = '', y = 'Equalized Odds')

  if (is.null(probs)) {
    return(list(Metric = res_table, Metric_plot = p))
  }

  # Dedicated plotting frame so columns in `data` named `probs` or
  # `group_status` cannot shadow the local vectors in aes().
  prob_df <- data.frame(probs = probs, group_status = group_status)
  q <- ggplot(prob_df, aes(x = probs, fill = group_status)) + geom_density(alpha = 0.5) +
    labs(x = 'Predicted probabilities') + guides(fill = guide_legend(title = '')) +
    theme(plot.title = element_text(hjust = 0.5)) + xlim(0, 1) +
    geom_vline(xintercept = cutoff, linetype = 'dashed')
  list(Metric = res_table, Metric_plot = p, Probability_plot = q)
}