Skip to content

Commit

Permalink
disable AVOID_PENALTY_TOKENS
Browse files Browse the repository at this point in the history
  • Loading branch information
josStorer committed Feb 28, 2024
1 parent 225abc5 commit 18ab8b1
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions backend-python/utils/rwkv.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,14 +378,14 @@ def __init__(self, model, pipeline) -> None:
dd = self.pipeline.encode(i)
assert len(dd) == 1
self.AVOID_REPEAT_TOKENS.add(dd[0])
self.AVOID_PENALTY_TOKENS = set()
AVOID_PENALTY = (
"\n" # \n,.:?!,。:?!"“”<>[]{}/\\|;;~`@#$%^&*()_+-=0123456789
)
for i in AVOID_PENALTY:
dd = self.pipeline.encode(i)
assert len(dd) == 1
self.AVOID_PENALTY_TOKENS.add(dd[0])
# self.AVOID_PENALTY_TOKENS = set()
# AVOID_PENALTY = (
# "\n" # \n,.:?!,。:?!"“”<>[]{}/\\|;;~`@#$%^&*()_+-=0123456789
# )
# for i in AVOID_PENALTY:
# dd = self.pipeline.encode(i)
# assert len(dd) == 1
# self.AVOID_PENALTY_TOKENS.add(dd[0])

self.__preload()

Expand All @@ -399,11 +399,11 @@ def adjust_occurrence(self, occurrence: Dict, token: int):

def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
for n in occurrence:
if n not in self.AVOID_PENALTY_TOKENS:
logits[n] -= (
self.penalty_alpha_presence
+ occurrence[n] * self.penalty_alpha_frequency
)
# if n not in self.AVOID_PENALTY_TOKENS:
logits[n] -= (
self.penalty_alpha_presence
+ occurrence[n] * self.penalty_alpha_frequency
)

if i == 0:
for token in self.model_tokens:
Expand Down

0 comments on commit 18ab8b1

Please sign in to comment.