From c3aaadcb165a7e188d808cc36507d8bd10a94524 Mon Sep 17 00:00:00 2001 From: Andrej Date: Thu, 9 Jun 2022 20:55:27 +0000 Subject: [PATCH] split out train,test,new separately when reporting on sampling word identity --- makemore.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/makemore.py b/makemore.py index 2026f15..53a98aa 100644 --- a/makemore.py +++ b/makemore.py @@ -252,8 +252,7 @@ def print_samples(num=10): top_k = args.top_k if args.top_k != -1 else None steps = train_dataset.get_output_length() - 1 # -1 because we already start with token (index 0) X_samp = sample(model, X_init, steps, top_k=top_k).to('cpu') - unique_samples = [] - had_samples = [] + train_samples, test_samples, new_samples = [], [], [] for i in range(X_samp.size(0)): # get the i'th row of sampled integers, as python list row = X_samp[i, 1:].tolist() # note: we need to crop out the first token @@ -262,17 +261,17 @@ def print_samples(num=10): row = row[:crop_index] word_samp = train_dataset.decode(row) # separately track samples that we have and have not seen before - word_have = train_dataset.contains(word_samp) or test_dataset.contains(word_samp) - sample_list = had_samples if word_have else unique_samples - sample_list.append(word_samp) - + if train_dataset.contains(word_samp): + train_samples.append(word_samp) + elif test_dataset.contains(word_samp): + test_samples.append(word_samp) + else: + new_samples.append(word_samp) print('-'*80) - print(f'{len(had_samples)} Samples that were found in input dataset:') - for word in had_samples: - print(word) - print(f'{len(unique_samples)} Samples that were NOT found in input dataset:') - for word in unique_samples: - print(word) + for lst, desc in [(train_samples, 'in train'), (test_samples, 'in test'), (new_samples, 'new')]: + print(f"{len(lst)} samples that are {desc}:") + for word in lst: + print(word) print('-'*80) @torch.inference_mode()