diff --git a/evaluation.py b/evaluation.py index 3f39796..c2447b5 100644 --- a/evaluation.py +++ b/evaluation.py @@ -99,7 +99,12 @@ def create_dataset(indexed_labels, origin_file_path): def evaluate(): - model = C3D2(n_labels=100, num_channels=1) + model_path = '/Users/leonidas/Downloads/model_14_percent_best_so_far.pt' + + if not torch.cuda.is_available(): + model = C3D2(100, 1).load_checkpoint(torch.load(model_path, map_location=lambda storage,loc: storage)) + else: + model = C3D2(100, 1).load_checkpoint(torch.load(model_path)) dir_path = os.path.join(c.ROOT, 'speaker_models') test_set = os.path.join(c.ROOT, '50_first_ids.txt')