Skip to content

Commit

Permalink
Added a method to compute ROC curves for classification experiments, …
Browse files Browse the repository at this point in the history
…a method to get the true label of a question pair, and a method to do some basic url cleaning.
  • Loading branch information
hoogeveen committed Apr 6, 2016
1 parent 6e55d38 commit 6655024
Showing 1 changed file with 67 additions and 1 deletion.
68 changes: 67 additions & 1 deletion query_cqadupstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _unzip_and_load(self, zipped_catfile):
answerfile = ziplocation + '/' + cat + '/' + cat + '_answers.json'
commentfile = ziplocation + '/' + cat + '/' + cat + '_comments.json'
userfile = ziplocation + '/' + cat + '/' + cat + '_users.json'
if os.path.exists(questionfile) and os.path.exists(answerfile) and os.path.exists(commentfile) and os.path.exists(userfile):
if os.path.exists(questionfile) and os.path.exists(answerfile) and os.path.exists(userfile) and os.path.exists(commentfile):
pass # All good, we don't need to unzip anything
else:
zip_ref = zipfile.ZipFile(zipped_catfile, 'r')
Expand Down Expand Up @@ -149,6 +149,18 @@ def get_all_postids(self):
''' Takes no input and returns a list of ALL post ids. '''
return self.postdict.keys()

def get_true_label(self, postid1, postid2):
''' Takes two postids as input and returns the true label, which is one of "dup", "nodup" or "related". '''
if postid1 in self.postdict[postid2]['dups']:
return "dup"
elif postid1 in self.postdict[postid2]['related']:
return "related"
elif postid2 in self.postdict[postid1]['dups']:
return "dups"
elif postid2 in self.postdict[postid1]['related']:
return "related"
else:
return "nodup"

###########################
# PARTICULAR POST METHODS #
Expand Down Expand Up @@ -427,6 +439,18 @@ def perform_cleaning(self, s, remove_stopwords=False, remove_punct=False, stem=F
return s


def url_cleaning(self, s):
''' Takes a string as input and removes references to possible duplicate posts, and other stackexchange urls. '''

posduppat = re.compile('<blockquote>(.|\n)+Possible Duplicate(.|\n)+</blockquote>', re.MULTILINE)
s = re.sub(posduppat, '', s)

s = re.sub('<a[^>]+stackexchange[^>]+>([^<]+)</a>', 'stackexchange-url ("\1")', s)
s = re.sub('<a[^>]+stackoverflow[^>]+>([^<]+)</a>', 'stackexchange-url ("\1")', s)

return s


def _remove_stopwords(self, s):
''' Takes a string as input, removes the stop words in the current stop word list, and returns the result.
The current stop word list can be accessed via self.stopwords, or altered by calling supply_stopwords(). '''
Expand Down Expand Up @@ -1224,6 +1248,48 @@ def evaluate_classification(self, scorefile):
rec_pos, rec_neg = self._compute_recall_oneclass(truenegatives, truepositives, falsenegatives, falsepositives)
return {'precision': precision, 'recall': recall, 'fscore': fscore, 'accuracy': accuracy, 'precision_positive_class': prec_pos, 'precision_negative_class': prec_neg, 'recall_positive_class': rec_pos, 'recall_negative_class': rec_neg}

def plot_roc(self, scorefile, plotfilename):
''' Takes a file with scores and a the name of a plot file (png) as input and returns the false positive rates (list), true positive rates (list), thresholds at which they were computed (list) and the area under the curve (float). The plot will be written to the supplied plot file.
The scores can either be probability estimates of the positive class, confidence values, or binary decisions.
This method requires scikit-learn to be installed: http://scikit-learn.org/stable/install.html
This method only computes the ROC curve for the positive class. See http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html for an example on how to make curves for multiple classes (for instance when you have a third class for the related questions). '''
# A simple example on how to use roc_curve: http://scikit-learn.org/stable/modules/model_evaluation.html#roc-metrics

import numpy as np
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt

y_list = []
scores_list = []
with open(scorefile) as fileobject:
for line in fileobject:
postid, compareid, verdict = line.split()
scores_list.append(float(verdict))
if compareid in self.get_duplicates(postid):
y_list.append(1)
else:
y_list.append(0)

y = np.array(y_list) # y need to contains the true binary values
scores = np.array(scores_list) # the scores can either be probability estimates of the positive class, confidence values, or binary decisions.
fpr, tpr, thresholds = roc_curve(y, scores, pos_label=1)
auc = roc_auc_score(y, scores)

plt.figure()
plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % auc)
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic curve')
plt.legend(loc="lower right")
plt.savefig(plotfilename)
#plt.show()

return fpr, tpr, thresholds, auc

def _compute_precision(self, truenegatives, truepositives, falsenegatives, falsepositives):
''' Takes the nr of truenegatives, truepositives, falsenegatives, and falsepositives as input and returns the precision. '''
predicted_positives = float(truepositives + falsepositives)
Expand Down

0 comments on commit 6655024

Please sign in to comment.