Merge pull request lucidrains#62 from gurvindersingh/main
Support top_a sampling method
lucidrains authored Sep 18, 2021
2 parents 97e66d6 + c60cbc0 commit 2eb20cf
Showing 1 changed file with 14 additions and 1 deletion.
x_transformers/autoregressive_wrapper.py (14 additions, 1 deletion)
@@ -31,6 +31,15 @@ def top_k(logits, thres = 0.9):
     probs.scatter_(1, ind, val)
     return probs
 
+# top_a
+
+def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
+    probs = F.softmax(logits, dim=-1)
+    limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
+    logits[probs < limit] = -float("Inf")
+    logits[probs >= limit] = 1
+    return logits
+
 # entmax
 
 ENTMAX_ALPHA = 1.3
@@ -46,7 +55,7 @@ def __init__(self, net, ignore_index = -100, pad_value = 0):
         self.max_seq_len = net.max_seq_len
 
     @torch.no_grad()
-    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
+    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, min_p_pow=2.0, min_p_ratio=0.02, **kwargs):
         device = start_tokens.device
         was_training = self.net.training
         num_dims = len(start_tokens.shape)
@@ -73,6 +82,10 @@ def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
                 filtered_logits = filter_logits_fn(logits, thres = filter_thres)
                 probs = F.softmax(filtered_logits / temperature, dim=-1)
 
+            elif filter_logits_fn is top_a:
+                filtered_logits = filter_logits_fn(logits, min_p_pow = min_p_pow, min_p_ratio = min_p_ratio)
+                probs = F.softmax(filtered_logits / temperature, dim=-1)
+
             elif filter_logits_fn is entmax:
                 probs = entmax(logits / temperature, alpha = ENTMAX_ALPHA, dim=-1)
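
For context: top_a filters adaptively against the most likely token. It computes a cutoff limit = max(p)^min_p_pow * min_p_ratio, masks every token whose probability falls below it, and sets all surviving logits to a constant. With the defaults, the cutoff is 2% of the square of the top probability, so the more confident the model is, the more aggressively it prunes. A minimal sketch of the filter in isolation (the logit values below are made up for illustration):

```python
import torch
import torch.nn.functional as F

# same function as in the diff above, reproduced so this snippet runs standalone
def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
    probs = F.softmax(logits, dim=-1)
    limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio  # max(p)^pow * ratio
    logits[probs < limit] = -float("Inf")  # drop tokens below the adaptive cutoff
    logits[probs >= limit] = 1             # flatten survivors to a shared logit
    return logits

logits = torch.tensor([[5.0, 2.0, 0.0, -2.0]])  # toy 4-token vocabulary
print(top_a(logits))  # tensor([[1., 1., -inf, -inf]])
```

Note that because the kept logits are all set to 1, temperature no longer differentiates among surviving tokens; sampling among them is uniform.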

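To opt in at generation time, pass the filter through filter_logits_fn; the new min_p_pow and min_p_ratio arguments are forwarded to it. A hypothetical usage sketch (model size, prompt, and sequence length are placeholder values, not from this commit):

```python
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper, top_a

# placeholder model configuration
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)
model = AutoregressiveWrapper(model)

prompt = torch.randint(0, 20000, (1, 4))  # dummy prompt tokens

sample = model.generate(
    prompt, 256,
    filter_logits_fn = top_a,  # selects the new elif branch above
    min_p_pow = 2.0,
    min_p_ratio = 0.02
)
```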