Skip to content

Commit

Permalink
Merge pull request THU-BPM#10 from TalkIsCheap22/KGW-v2-dyc
Browse files Browse the repository at this point in the history
add more hashing schemes for KGW
  • Loading branch information
panly2003 authored Jul 4, 2024
2 parents 8a48a84 + 364f49a commit ba3477e
Showing 1 changed file with 40 additions and 7 deletions.
47 changes: 40 additions & 7 deletions watermark/kgw/kgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,57 @@ def __init__(self, config: KGWConfig, *args, **kwargs) -> None:
"""
self.config = config
self.rng = torch.Generator(device=self.config.device)
self.rng.manual_seed(self.config.hash_key)
self.prf = torch.randperm(self.config.vocab_size, device=self.config.device, generator=self.rng)

def _seed_rng(self, input_ids: torch.LongTensor) -> None:
"""Seed the RNG with the last min_prefix_len tokens of the input_ids."""
def f_time(self, input_ids: torch.LongTensor) -> int:
    """Hash the prefix by multiplying its token ids.

    Multiplies the last ``config.prefix_length`` entries of ``input_ids``
    together, reduces the product modulo ``config.vocab_size`` and uses the
    result as an index into ``self.prf``.

    Side effect: ``self.rng`` is re-seeded with ``config.hash_key`` times the
    reduced product before returning.
    NOTE(review): in the diff view this was taken from, the seeding lines sit
    between old and new code — confirm they are meant to live here.

    Args:
        input_ids: Token ids; each trailing entry is read with ``.item()``,
            so a 1-D tensor is assumed.

    Returns:
        The ``self.prf`` entry selected by the reduced product.
    """
    cfg = self.config
    product = 1
    # Walk backwards from the most recent token: offsets -1, -2, ...
    for offset in range(1, cfg.prefix_length + 1):
        product *= input_ids[-offset].item()
    bucket = product % cfg.vocab_size
    self.rng.manual_seed(cfg.hash_key * bucket)
    return self.prf[bucket]

def f_additive(self, input_ids: torch.LongTensor) -> int:
    """Hash the prefix by summing its token ids.

    Adds up the last ``config.prefix_length`` entries of ``input_ids``,
    reduces the sum modulo ``config.vocab_size`` and looks that index up in
    ``self.prf``.

    Args:
        input_ids: Token ids; each trailing entry is read with ``.item()``,
            so a 1-D tensor is assumed.

    Returns:
        The ``self.prf`` entry selected by the reduced sum.
    """
    window = range(self.config.prefix_length)
    total = sum(input_ids[-1 - offset].item() for offset in window)
    return self.prf[total % self.config.vocab_size]

def f_skip(self, input_ids: torch.LongTensor) -> int:
    """Hash the prefix using only its oldest token.

    Picks the token ``config.prefix_length`` positions from the end of
    ``input_ids`` (skipping the more recent ones) and looks its id up in
    ``self.prf``.

    Args:
        input_ids: Token ids; the selected entry is read with ``.item()``,
            so a 1-D tensor is assumed.

    Returns:
        The ``self.prf`` entry for the selected token id.

    Raises:
        IndexError: If ``input_ids`` holds fewer than
            ``config.prefix_length`` tokens.
    """
    # Index the argument, not an attribute on self: the prefix is supplied
    # per call and no ``input_ids`` attribute is set on the instance here.
    anchor_token = input_ids[-self.config.prefix_length].item()
    return self.prf[anchor_token]

def f_min(self, input_ids: torch.LongTensor) -> int:
    """Hash the prefix by taking the smallest per-token hash.

    Looks up each of the last ``config.prefix_length`` token ids in
    ``self.prf`` and returns the minimum of those entries.

    Args:
        input_ids: Token ids; each trailing entry is read with ``.item()``,
            so a 1-D tensor is assumed.

    Returns:
        The smallest ``self.prf`` entry among the trailing tokens.
    """
    window = range(self.config.prefix_length)
    hashed = (self.prf[input_ids[-1 - offset].item()] for offset in window)
    return min(hashed)

def _seed_rng(self, input_ids: torch.LongTensor) -> None:
"""Seed the RNG with the last min_prefix_len tokens of the input_ids."""
return

def get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
"""Get greenlist ids for the input_ids."""
self._seed_rng(input_ids)
def get_greenlist_ids_left(self, input_ids: torch.LongTensor) -> list[int]:
    """Get greenlist ids for the input_ids via leftHash scheme.

    Seeds ``self.rng`` with ``config.hash_key`` times the prefix hash from
    ``f_time``, draws a permutation of the vocabulary and keeps its first
    ``int(config.vocab_size * config.gamma)`` entries.

    Args:
        input_ids: Token ids of the prefix; also supplies the device on which
            the permutation is created.

    Returns:
        The leading slice of the seeded vocabulary permutation (a tensor of
        token ids, as produced by ``torch.randperm``).
    """
    # ``f_time`` returns an element of ``self.prf`` (a tensor), whereas
    # ``Generator.manual_seed`` takes a plain Python integer.
    seed = int(self.config.hash_key * self.f_time(input_ids))
    self.rng.manual_seed(seed)
    greenlist_size = int(self.config.vocab_size * self.config.gamma)
    vocab_permutation = torch.randperm(
        self.config.vocab_size, device=input_ids.device, generator=self.rng
    )
    return vocab_permutation[:greenlist_size]

def get_greenlist_ids_self(self, input_ids: torch.LongTensor) -> list[int]:
    """Get greenlist ids for the input_ids via selfHash scheme.

    For every candidate token ``k`` the generator is seeded with
    ``f_time(input_ids) * self.prf[k]``, a vocabulary permutation is drawn,
    and ``k`` is kept if it falls in the first
    ``int(config.vocab_size * config.gamma)`` entries of that permutation.

    Args:
        input_ids: Token ids of the prefix; also supplies the device on which
            each permutation is created.

    Returns:
        The candidate token ids that land in their own green list, in
        ascending order.
    """
    greenlist_size = int(self.config.vocab_size * self.config.gamma)
    greenlist_ids = []
    f_x = self.f_time(input_ids)
    for k in range(self.config.vocab_size):
        h_k = f_x * self.prf[k]
        # Seed the generator from h_k right here so each candidate's
        # permutation depends only on (prefix, k) and is reproducible at
        # detection time; manual_seed needs a plain int, hence the cast.
        self.rng.manual_seed(int(h_k))
        vocab_permutation = torch.randperm(
            self.config.vocab_size, device=input_ids.device, generator=self.rng
        )
        temp_greenlist_ids = vocab_permutation[:greenlist_size]
        if k in temp_greenlist_ids:
            greenlist_ids.append(k)
    return greenlist_ids

def _compute_z_score(self, observed_count: int , T: int) -> float:
"""Compute z-score for the given observed count and total tokens."""
expected_count = self.config.gamma
Expand Down

0 comments on commit ba3477e

Please sign in to comment.