# test_throughput.py
import torch
import time
import random
import argparse
import matplotlib.pyplot as plt
import pandas as pd
from huggingface_hub import login
from transformers import AutoModelForMaskedLM, EsmTokenizer
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig
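
# This script benchmarks forward-pass wall time and peak GPU memory for ESM2,
# ESMC, and ESM++ checkpoints across a grid of sequence lengths and batch
# sizes, then writes model_benchmarks.csv plus throughput and memory plots.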
parser = argparse.ArgumentParser()
parser.add_argument('--model_paths', nargs='+', type=str, default=[
    #'facebook/esm2_t6_8M_UR50D',
    'Synthyra/FastESM2_650',
    'facebook/esm2_t12_35M_UR50D',
    'facebook/esm2_t30_150M_UR50D',
    'facebook/esm2_t33_650M_UR50D',
    'esmc_300m',  # ESMC model
    'esmc_600m',  # ESMC model
    'Synthyra/ESMplusplus_small',
    'Synthyra/ESMplusplus_large'
])
parser.add_argument('--token', type=str, default=None)
parser.add_argument('--test', action='store_true', help='Generate random results for testing')
args = parser.parse_args()
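
# Example invocations (model path and token are illustrative):
#   python test_throughput.py --test
#   python test_throughput.py --model_paths facebook/esm2_t12_35M_UR50D --token <hf_token>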

if args.token:
    login(args.token)

model_paths = args.model_paths
canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
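
# ESMC's SDK exposes encode()/logits() rather than an HF-style forward();
# this thin wrapper lets the timing loops below call an ESMC model one
# sequence at a time like any other torch module.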
class ESMCForEmbedding(torch.nn.Module):
    def __init__(self, esm):
        super().__init__()
        self.esm = esm

    def forward(self, seq):
        protein = ESMProtein(sequence=seq)
        protein_tensor = self.esm.encode(protein)
        embeddings = self.esm.logits(
            protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
        ).embeddings.cpu()
        return embeddings
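
# Each sequence starts with 'M' followed by length - 3 random residues; the
# offset presumably leaves room for the tokenizer's special tokens (CLS/EOS)
# so the tokenized input lands near the target length.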
def generate_random_sequence(length: int) -> str:
    return 'M' + "".join(random.choices(canonical_amino_acids, k=length - 3))

def generate_batch_sequences(length: int, batch_size: int, num_batches: int = 100) -> list:
    all_sequences = []
    for _ in range(num_batches):
        batch_sequences = [generate_random_sequence(length) for _ in range(batch_size)]
        all_sequences.append(batch_sequences)
    return all_sequences
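
# A few warmup passes amortize one-time costs (CUDA context setup, kernel
# autotuning) before the timed loop.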
def time_model(model, inputs, warmup=4):
    model.eval()
    with torch.no_grad():
        # Warmup
        for _ in range(warmup):
            _ = model(**inputs[0])
        # CUDA kernels launch asynchronously, so synchronize before starting
        # and stopping the timer to bracket the actual GPU work.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start_time = time.time()
        for input_batch in inputs:
            _ = model(**input_batch)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.time() - start_time
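
# The ESMC SDK handles one sequence per call, so batches are flattened and
# each sequence is forwarded individually; the wrapper's .cpu() copy already
# synchronizes the device after every call.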
def time_model_esmc(model, sequences, warmup=10):
    model.eval()
    with torch.no_grad():
        # Warmup
        for _ in range(warmup):
            for seq in sequences[0]:
                _ = model(seq)
        start_time = time.time()
        for batch in sequences:
            for seq in batch:
                _ = model(seq)
        return time.time() - start_time

def get_gpu_memory():
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**2  # Convert to MB

# Test different sequence lengths and batch sizes
lengths = [32, 64, 128, 256, 512, 1024, 2048]
batch_sizes = [1, 2, 4, 8, 16, 32]
num_batches = 16
results = []

if not args.test:
    # Generate all test sequences first
    all_sequences = {}
    for length in lengths:
        for batch_size in batch_sizes:
            print(f"\nGenerating sequences for length={length}, batch_size={batch_size}")
            all_sequences[(length, batch_size)] = generate_batch_sequences(length, batch_size, num_batches)

    # Test each model
    for model_path in model_paths:
        print(f"\nTesting {model_path}...")
        if 'esmc' in model_path.lower():
            esm = ESMC.from_pretrained(model_path, device=device).to(device)
            model = ESMCForEmbedding(esm).to(device)
            tokenizer = None
        elif 'synthyra' in model_path.lower():
            model = AutoModelForMaskedLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16).to(device)
            tokenizer = model.tokenizer
        else:
            model = AutoModelForMaskedLM.from_pretrained(model_path).to(device)
            tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')

        for length in lengths:
            for batch_size in batch_sizes:
                print(f"\nTesting length={length}, batch_size={batch_size}")
                sequences = all_sequences[(length, batch_size)]
                torch.cuda.reset_peak_memory_stats()
                if isinstance(model, ESMCForEmbedding):
                    model_time = time_model_esmc(model, sequences)
                else:
                    inputs = [tokenizer(batch_seq, padding=True, return_tensors="pt").to(device) for batch_seq in sequences]
                    model_time = time_model(model, inputs)
                model_memory = get_gpu_memory()
                results.append({
                    'Model': model_path,
                    'Length': length,
                    'Batch Size': batch_size,
                    'Time': model_time,
                    'Memory': model_memory
                })
                print(f"Time: {model_time:.2f}s, memory: {model_memory:.0f}MB")
                torch.cuda.empty_cache()

        model.cpu()
        del model
        torch.cuda.empty_cache()
else:
    # Generate random test results (for exercising the CSV/plotting code)
    for model_path in model_paths:
        for length in lengths:
            for batch_size in batch_sizes:
                # Random time between 0.1 and 10 seconds, scaled with length and batch size
                model_time = random.uniform(0.1, 10) * (length / 2) * (batch_size / 1)
                # Random memory between 100 and 5000 MB, scaled with length and batch size
                model_memory = random.uniform(100, 5000) * (length / 2) * (batch_size / 1)
                results.append({
                    'Model': model_path,
                    'Length': length,
                    'Batch Size': batch_size,
                    'Time': model_time,
                    'Memory': model_memory
                })
                print(f"Generated random - Time: {model_time:.2f}s, memory: {model_memory:.0f}MB")

# Save results to CSV
df = pd.DataFrame(results)
df.to_csv('model_benchmarks.csv', index=False)
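
# Each CSV row records the model path, sequence length, batch size, total wall
# time over num_batches batches (seconds), and peak GPU memory (MB).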

# Create visualization for throughput
num_batch_sizes = len(batch_sizes)
plt.figure(figsize=(15, 5 * num_batch_sizes))
for i, batch_size in enumerate(batch_sizes):
    plt.subplot(num_batch_sizes, 1, i + 1)
    for model_path in model_paths:
        model_results = [(r['Length'], r['Time']) for r in results
                         if r['Model'] == model_path and r['Batch Size'] == batch_size]
        if model_results:
            # Avoid shadowing the global `lengths` and the `len`/`time` builtins
            seq_lengths, times = zip(*model_results)
            # Tokens/second: each config processed num_batches batches of
            # batch_size sequences of roughly seq_len tokens each.
            throughput = [batch_size * seq_len * num_batches / elapsed
                          for seq_len, elapsed in zip(seq_lengths, times)]
            plt.plot(seq_lengths, throughput, marker='o', label=model_path)
    plt.title(f'Model Throughput vs Sequence Length (Batch Size = {batch_size})')
    plt.xlabel('Sequence Length')
    plt.ylabel('Throughput (tokens/second)')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
plt.tight_layout()
plt.savefig('model_throughput.png', bbox_inches='tight', dpi=300)
plt.close()

# Create visualization for memory usage
plt.figure(figsize=(15, 5 * num_batch_sizes))
for i, batch_size in enumerate(batch_sizes):
    plt.subplot(num_batch_sizes, 1, i + 1)
    for model_path in model_paths:
        model_results = [(r['Length'], r['Memory']) for r in results
                         if r['Model'] == model_path and r['Batch Size'] == batch_size]
        if model_results:
            seq_lengths, memory = zip(*model_results)
            plt.plot(seq_lengths, memory, marker='o', label=model_path)
    plt.title(f'GPU Memory Usage vs Sequence Length (Batch Size = {batch_size})')
    plt.xlabel('Sequence Length')
    plt.ylabel('Memory Usage (MB)')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
plt.tight_layout()
plt.savefig('model_memory.png', bbox_inches='tight', dpi=300)
plt.close()