split out train,test,new separately when reporting on sampling word identity
karpathy authored Jun 9, 2022
1 parent e0a08f2 commit c3aaadc
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions makemore.py
@@ -252,8 +252,7 @@ def print_samples(num=10):
     top_k = args.top_k if args.top_k != -1 else None
     steps = train_dataset.get_output_length() - 1 # -1 because we already start with <START> token (index 0)
     X_samp = sample(model, X_init, steps, top_k=top_k).to('cpu')
-    unique_samples = []
-    had_samples = []
+    train_samples, test_samples, new_samples = [], [], []
     for i in range(X_samp.size(0)):
         # get the i'th row of sampled integers, as python list
         row = X_samp[i, 1:].tolist() # note: we need to crop out the first <START> token
@@ -262,17 +261,17 @@ def print_samples(num=10):
         row = row[:crop_index]
         word_samp = train_dataset.decode(row)
         # separately track samples that we have and have not seen before
-        word_have = train_dataset.contains(word_samp) or test_dataset.contains(word_samp)
-        sample_list = had_samples if word_have else unique_samples
-        sample_list.append(word_samp)
-
+        if train_dataset.contains(word_samp):
+            train_samples.append(word_samp)
+        elif test_dataset.contains(word_samp):
+            test_samples.append(word_samp)
+        else:
+            new_samples.append(word_samp)
     print('-'*80)
-    print(f'{len(had_samples)} Samples that were found in input dataset:')
-    for word in had_samples:
-        print(word)
-    print(f'{len(unique_samples)} Samples that were NOT found in input dataset:')
-    for word in unique_samples:
-        print(word)
+    for lst, desc in [(train_samples, 'in train'), (test_samples, 'in test'), (new_samples, 'new')]:
+        print(f"{len(lst)} samples that are {desc}:")
+        for word in lst:
+            print(word)
     print('-'*80)
 
 @torch.inference_mode()
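
For context, the classification logic this commit introduces can be tried in isolation. The sketch below is a minimal stand-alone approximation, not the makemore.py code: it substitutes plain Python sets for the dataset objects' contains() checks and hard-codes a few "sampled" words, but it buckets and reports them the same way the updated print_samples loop does.

# Minimal sketch of the new reporting logic.
# Assumptions: plain sets stand in for the train/test dataset objects,
# and `sampled` stands in for the decoded model samples.
train_words = {'emma', 'olivia'}          # stand-in for train_dataset
test_words  = {'ava'}                     # stand-in for test_dataset
sampled     = ['emma', 'ava', 'zorblet']  # stand-in for decoded samples

train_samples, test_samples, new_samples = [], [], []
for word in sampled:
    if word in train_words:
        train_samples.append(word)
    elif word in test_words:
        test_samples.append(word)
    else:
        new_samples.append(word)

print('-' * 80)
for lst, desc in [(train_samples, 'in train'), (test_samples, 'in test'), (new_samples, 'new')]:
    print(f"{len(lst)} samples that are {desc}:")
    for word in lst:
        print(word)
print('-' * 80)

Splitting the report into three buckets makes memorization visible at a glance: words copied from the training split, words that happen to appear in the test split, and genuinely new words are no longer lumped into a single "seen before" group.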
