Skip to content

Commit

Permalink
Update experiments.py
Browse files Browse the repository at this point in the history
  • Loading branch information
TonyBagnall committed Sep 5, 2019
1 parent 6b3bd8b commit e5ab738
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions sktime/contrib/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
import sktime.classifiers.dictionary_based.boss as db
import sktime.classifiers.frequency_based.rise as fb
import sktime.classifiers.interval_based.tsf as ib
#import sktime.classifiers.distance_based.elastic_ensemble as dist
#import sktime.classifiers.distance_based.proximity_forest as pf
import sktime.classifiers.distance_based.elastic_ensemble as dist
import sktime.classifiers.distance_based.proximity_forest as pf
import sktime.classifiers.shapelet_based.stc as st
from sktime.utils.load_data import load_from_tsfile_to_dataframe as load_ts

Expand Down Expand Up @@ -314,9 +314,11 @@ def run_experiment(problem_path, results_path, cls_name, dataset, classifier=Non
print(cls_name + " on " + dataset + " resample number " + str(resampleID) + ' test acc: ' + str(ac)
+ ' time: ' + str(test_time))
# print(str(classifier.findEnsembleTrainAcc(trainX, trainY)))
second = str(classifier.get_params())
second.replace('\n',' ').replace('\r',' ')

if "Composite" in cls_name:
second="Para info too long!"
else:
second = str(classifier.get_params())
print(second)
third = str(ac)+","+str(build_time)+","+str(test_time)+",-1,-1,"+str(len(classifier.classes_))+ "," + str(classifier.classes_)
write_results_to_uea_format(second_line=second, third_line=third, output_path=results_path, classifier_name=cls_name, resample_seed= resampleID,
predicted_class_vals=preds, actual_probas=probs, dataset_name=dataset, actual_class_vals=testY, split='TEST')
Expand All @@ -331,8 +333,10 @@ def run_experiment(problem_path, results_path, cls_name, dataset, classifier=Non
train_acc = accuracy_score(trainY,train_preds)
print(cls_name + " on " + dataset + " resample number " + str(resampleID) + ' train acc: ' + str(train_acc)
+ ' time: ' + str(train_time))
second = str(classifier.get_params())
second.replace('\n',' ').replace('\r',' ')
if "Composite" in cls_name:
second="Para info too long!"
else:
second = str(classifier.get_params())
third = str(train_acc)+","+str(train_time)+",-1,-1,-1,"+str(len(classifier.classes_)) + "," + str(classifier.classes_)
write_results_to_uea_format(second_line=second, third_line=third, output_path=results_path, classifier_name=cls_name, resample_seed= resampleID,
predicted_class_vals=train_preds, actual_probas=train_probs, dataset_name=dataset, actual_class_vals=trainY, split='TRAIN')
Expand Down

0 comments on commit e5ab738

Please sign in to comment.