From 66550246b605a458d1411c5492dd2846b0b8a797 Mon Sep 17 00:00:00 2001 From: hoogeveen Date: Wed, 6 Apr 2016 10:46:59 +1000 Subject: [PATCH] Added a method to compute ROC curves for classification experiments, a method to get the true label of a question pair, and a method to do some basic url cleaning. --- query_cqadupstack.py | 68 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/query_cqadupstack.py b/query_cqadupstack.py index c1302c5..72cf903 100644 --- a/query_cqadupstack.py +++ b/query_cqadupstack.py @@ -63,7 +63,7 @@ def _unzip_and_load(self, zipped_catfile): answerfile = ziplocation + '/' + cat + '/' + cat + '_answers.json' commentfile = ziplocation + '/' + cat + '/' + cat + '_comments.json' userfile = ziplocation + '/' + cat + '/' + cat + '_users.json' - if os.path.exists(questionfile) and os.path.exists(answerfile) and os.path.exists(commentfile) and os.path.exists(userfile): + if os.path.exists(questionfile) and os.path.exists(answerfile) and os.path.exists(userfile) and os.path.exists(commentfile): pass # All good, we don't need to unzip anything else: zip_ref = zipfile.ZipFile(zipped_catfile, 'r') @@ -149,6 +149,18 @@ def get_all_postids(self): ''' Takes no input and returns a list of ALL post ids. ''' return self.postdict.keys() + def get_true_label(self, postid1, postid2): + ''' Takes two postids as input and returns the true label, which is one of "dup", "nodup" or "related". ''' + if postid1 in self.postdict[postid2]['dups']: + return "dup" + elif postid1 in self.postdict[postid2]['related']: + return "related" + elif postid2 in self.postdict[postid1]['dups']: + return "dups" + elif postid2 in self.postdict[postid1]['related']: + return "related" + else: + return "nodup" ########################### # PARTICULAR POST METHODS # @@ -427,6 +439,18 @@ def perform_cleaning(self, s, remove_stopwords=False, remove_punct=False, stem=F return s + def url_cleaning(self, s): + ''' Takes a string as input and removes references to possible duplicate posts, and other stackexchange urls. ''' + + posduppat = re.compile('
(.|\n)+Possible Duplicate(.|\n)+
', re.MULTILINE) + s = re.sub(posduppat, '', s) + + s = re.sub(']+stackexchange[^>]+>([^<]+)', 'stackexchange-url ("\1")', s) + s = re.sub(']+stackoverflow[^>]+>([^<]+)', 'stackexchange-url ("\1")', s) + + return s + + def _remove_stopwords(self, s): ''' Takes a string as input, removes the stop words in the current stop word list, and returns the result. The current stop word list can be accessed via self.stopwords, or altered by calling supply_stopwords(). ''' @@ -1224,6 +1248,48 @@ def evaluate_classification(self, scorefile): rec_pos, rec_neg = self._compute_recall_oneclass(truenegatives, truepositives, falsenegatives, falsepositives) return {'precision': precision, 'recall': recall, 'fscore': fscore, 'accuracy': accuracy, 'precision_positive_class': prec_pos, 'precision_negative_class': prec_neg, 'recall_positive_class': rec_pos, 'recall_negative_class': rec_neg} + def plot_roc(self, scorefile, plotfilename): + ''' Takes a file with scores and a the name of a plot file (png) as input and returns the false positive rates (list), true positive rates (list), thresholds at which they were computed (list) and the area under the curve (float). The plot will be written to the supplied plot file. + The scores can either be probability estimates of the positive class, confidence values, or binary decisions. + This method requires scikit-learn to be installed: http://scikit-learn.org/stable/install.html + This method only computes the ROC curve for the positive class. See http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html for an example on how to make curves for multiple classes (for instance when you have a third class for the related questions). ''' + # A simple example on how to use roc_curve: http://scikit-learn.org/stable/modules/model_evaluation.html#roc-metrics + + import numpy as np + from sklearn.metrics import roc_curve + from sklearn.metrics import roc_auc_score + import matplotlib.pyplot as plt + + y_list = [] + scores_list = [] + with open(scorefile) as fileobject: + for line in fileobject: + postid, compareid, verdict = line.split() + scores_list.append(float(verdict)) + if compareid in self.get_duplicates(postid): + y_list.append(1) + else: + y_list.append(0) + + y = np.array(y_list) # y need to contains the true binary values + scores = np.array(scores_list) # the scores can either be probability estimates of the positive class, confidence values, or binary decisions. + fpr, tpr, thresholds = roc_curve(y, scores, pos_label=1) + auc = roc_auc_score(y, scores) + + plt.figure() + plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % auc) + plt.plot([0, 1], [0, 1], 'k--') + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title('Receiver Operating Characteristic curve') + plt.legend(loc="lower right") + plt.savefig(plotfilename) + #plt.show() + + return fpr, tpr, thresholds, auc + def _compute_precision(self, truenegatives, truepositives, falsenegatives, falsepositives): ''' Takes the nr of truenegatives, truepositives, falsenegatives, and falsepositives as input and returns the precision. ''' predicted_positives = float(truepositives + falsepositives)