kirchenbauer_watermarks.py
"""
Implementation of Watermarking algorithms from Kirchenbaur et. al (2023)
- Are these general? Like not specific to any model
"""
import math
import random

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2 = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
gpt2large = GPT2LMHeadModel.from_pretrained('gpt2-large').to(device)
# Both models share the GPT-2 BPE vocabulary, so one vocab size suffices
vocab_size = gpt2.config.vocab_size

def sample(model, prompt, length, masker=None):
    """
    - Sample `length` tokens of output text from `model` given input `prompt`
    - `masker` is the watermark: a function from the current token window to a
      vector of logit offsets over the vocabulary
    """
    input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
    with torch.no_grad():
        for _ in range(length):
            # Logits for the last position of the (single) sequence in the batch
            logits = model(input_ids).logits[0, -1]
            if masker:
                # For now, assume the watermark only has access to the tokens seen so far
                mask = masker(input_ids)
            else:
                mask = torch.zeros_like(logits)
            # The mask bumps up certain tokens' logits (and hence probabilities)
            logits += mask
            probs = torch.softmax(logits, dim=0)
            # TODO(ltang): move beyond multinomial decoding to beam search
            next_token = torch.multinomial(probs, 1)
            # The input window extends to include the generated token,
            # which changes the masker/watermark seed via `hash_tensor`
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
    return input_ids, logits

def hash_tensor(tensor):
    # Hashing a tuple of ints is deterministic (unlike strings under Python's
    # hash randomization), so the same token prefix always yields the same seed
    return hash(tuple(tensor.tolist()))
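
# Sanity check (illustrative, with arbitrary token ids): the sampler seeds from
# `hash_tensor` on a tensor, while the detector below seeds from `hash` on a
# tuple of the same token ids; these agree, which is what lets the detector
# recover each green list
assert hash_tensor(torch.tensor([1, 2, 3])) == hash((1, 2, 3))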

def soft_watermark(input_ids, gamma, delta):
    """
    - Randomly partition the vocabulary into a green and a red list, and bump up
      the likelihood of green-list tokens
    - `gamma` controls the green list size (specifically, what proportion of the
      vocabulary is green)
    - `delta` is the logit bump added to tokens in the green list
    """
    green_list_length = int(vocab_size * gamma)
    # Use a hash of the token values seen so far to seed the RNG
    random.seed(hash_tensor(input_ids[0]))
    # Use the RNG to re-partition the vocabulary
    indices_to_mask = random.sample(range(vocab_size), green_list_length)
    mask = torch.zeros(vocab_size).to(device)
    mask[indices_to_mask] = delta
    return mask
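
# A sketch of the paper's "hard" variant, under the same seeding scheme as
# `soft_watermark` above. This function is our addition (the name and details
# are assumptions, not the paper's reference code): instead of a finite logit
# bump, red-list tokens are ruled out entirely by sending their logits to -inf
def hard_watermark(input_ids, gamma):
    green_list_length = int(vocab_size * gamma)
    random.seed(hash_tensor(input_ids[0]))
    green_indices = random.sample(range(vocab_size), green_list_length)
    # Forbid everything, then re-allow the green list; after `logits += mask`,
    # softmax assigns zero probability to every red-list token
    mask = torch.full((vocab_size,), float('-inf')).to(device)
    mask[green_indices] = 0.0
    return mask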

# TODO(ltang): should we remove the prompt from `text`? As is, prompt tokens
# may inflate `green_count`
def detect_soft_watermark(text, gamma):
    """
    - Given (human- or machine-generated) `text`, a combination of prompt + output,
      determine whether it has been watermarked
    - Assume the detector has access to the hash function and RNG
    - `gamma` is the green list proportion
    """
    tokens = tokenizer.encode(text)
    T = len(tokens)
    green_list_length = int(vocab_size * gamma)
    # Count how many tokens in the test text are green-list tokens
    green_count = 0
    for i, token in enumerate(tokens):
        prev_tokens = tokens[:i]
        # The detector has access to the hash function and RNG
        random.seed(hash(tuple(prev_tokens)))
        # Recover the green list at each sampling step
        green_list = set(random.sample(range(vocab_size), green_list_length))
        # Check whether the current token is in the green list
        if token in green_list:
            green_count += 1
    # One-proportion z-test of the null hypothesis that the text was generated
    # with no knowledge of the green/red list rule: under the null, each token
    # is green with probability gamma
    z = (green_count - gamma * T) / math.sqrt(T * gamma * (1 - gamma))
    return z
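
# A small convenience helper we add here (the name `is_watermarked` and its
# default threshold are our assumptions, not part of the original script):
# flag text as watermarked when the z-score clears a cutoff. Kirchenbauer
# et al. (2023) use z = 4 for a very low false-positive rate
def is_watermarked(text, gamma, z_threshold=4.0):
    return detect_soft_watermark(text, gamma) > z_threshold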

gamma = 0.5
# delta = 999 is so large that sampling almost never leaves the green list,
# making this effectively a hard watermark
watermarked = tokenizer.decode(sample(gpt2large, "The quick brown", 100, masker=lambda x: soft_watermark(x, gamma=gamma, delta=999))[0][0])
unmarked = tokenizer.decode(sample(gpt2large, "The quick brown", 100)[0][0])

print("WATERMARKED TEXT:")
print(watermarked)
print("UNMARKED TEXT:")
print(unmarked)
print("Z-Score on Watermarked Text")
print(detect_soft_watermark(watermarked, gamma))
print("Z-Score on Unmarked Text")
print(detect_soft_watermark(unmarked, gamma))
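
# Demo of the assumed `is_watermarked` helper added above
print("Watermark detected on watermarked text (z > 4)?", is_watermarked(watermarked, gamma))
print("Watermark detected on unmarked text (z > 4)?", is_watermarked(unmarked, gamma))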