inference.py (forked from vl2g/floco)
from tqdm.auto import tqdm
import json
import torch
from torch.utils.data import Dataset
from evaluate import load
from calc_code_bleu import compute_codebleu
from transformers import RobertaTokenizer, T5ForConditionalGeneration
import warnings
warnings.filterwarnings('ignore')


def data_visualisation(code_pth, encodings_pth):
    # Load the flowchart encodings: a JSON object mapping image ids to encoding strings
    with open(encodings_pth) as file:
        js = json.load(file)
    image_ids = list(js.keys())
    encodings = list(js.values())
    python_codes = []
    # Load the ground-truth Python program for each image id
    for image_id in image_ids:
        cdp = code_pth + str(image_id) + '.py'
        with open(cdp, 'r') as file:
            lines = file.read()
        # Skip a leading comment line if the file starts with one
        if lines.startswith('#'):
            lines = lines.split('\n', 1)[1] if '\n' in lines else ''
        python_codes.append(lines)
    return image_ids, encodings, python_codes
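
# Illustrative usage sketch for data_visualisation (an assumption inferred from the
# code above, not part of the original repo): the encodings JSON is expected to map
# image ids to flowchart encoding strings, and each id to have a matching ground-truth
# program at <code_pth><id>.py. The paths and encoding string below are hypothetical
# placeholders only:
#
#   image_ids, encodings, codes = data_visualisation('data/test/codes/', 'data/test/encodings.json')
#   # encodings.json might look like: {"1001": "OVAL start [SEP] PARALLELOGRAM input n [SEP] ..."}
#   assert len(image_ids) == len(encodings) == len(codes)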


def CodeT5_tokenize():
    tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base')
    return tokenizer


class CustomDataset(Dataset):
    def __init__(self, input_ids, attention_mask, output, imageids):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.output = output
        self.imageids = imageids

    def __len__(self):
        return self.input_ids.shape[0]

    def __getitem__(self, idx):
        return (self.imageids[idx], self.input_ids[idx], self.attention_mask[idx], self.output[idx])


def data_loading(test_set, batch_size):
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return test_loader


def CodeT5_model():
    model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base')
    return model


def writing_results(results_pth, test_loader, tokenizer, model, device):
    bleu_score = 0.0
    codebleu_score = 0.0
    exact_match = 0.0
    # Load the exact-match metric once, outside the batch loop
    exact_match_metric = load("exact_match")
    for image_id, input_id, attention_mask, code in tqdm(test_loader):
        input_id = input_id.to(device)
        attention_mask = attention_mask.to(device)
        code = code.to(device)
        # Generating the code from the model
        with torch.no_grad():
            outputs = model.generate(input_ids=input_id, attention_mask=attention_mask,
                                     return_dict_in_generate=True, output_scores=True, max_length=1024)
        out = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
        program = tokenizer.batch_decode(code, skip_special_tokens=True)
        encod = tokenizer.batch_decode(input_id, skip_special_tokens=True)
        # Calculating the metrics for each generated code
        for i in range(len(program)):
            codebleu_result = compute_codebleu([out[i]], [[program[i]]], 'python')
            codebleu = codebleu_result[0]
            bleu = codebleu_result[1][0]
            EM = exact_match_metric.compute(predictions=[out[i]], references=[program[i]])['exact_match']
            bleu_score += bleu
            exact_match += EM
            codebleu_score += codebleu
        # Writing the results of this batch to the file
        with open(results_pth, 'a') as file:
            for i in range(len(image_id)):
                file.write(str(image_id[i]) + ".png\n")
                file.write("Encoding from tokenizer : \n " + encod[i])
                file.write("\n")
                file.write("Original Python Program : \n " + program[i])
                file.write("\n")
                file.write("Output : \n " + out[i])
                file.write("\n \n ")
    # Average the metrics over the full test set
    bleu = bleu_score / len(test_loader.dataset)
    EM = exact_match / len(test_loader.dataset)
    codebleu = codebleu_score / len(test_loader.dataset)
    print(bleu, codebleu, EM)


def run():
    # Path to the test codes
    test_code_pth = ''
    # Path to the test encodings
    test_encodings_pth = ''
    # Path to the trained model checkpoint
    trained_model_pth = ''
    # Path to the file where the generated codes will be stored
    results_pth = ''
    # Batch size for the test data
    batch_size = 16
    # Loading the test data
    test_image_ids, test_encodings, test_codes = data_visualisation(test_code_pth, test_encodings_pth)
    print(len(test_image_ids))
    # Tokenize the test data with the CodeT5 tokenizer
    tokenizer = CodeT5_tokenize()
    tokenizer.add_tokens(['[SEP]', 'PARALLELOGRAM', 'RECTANGLE', 'OVAL', 'DIAMOND'], special_tokens=True)
    test_input = tokenizer(test_encodings, padding='max_length', truncation=True, return_tensors='pt', max_length=512)
    with tokenizer.as_target_tokenizer():
        test_labels = tokenizer(test_codes, padding='max_length', truncation=True, return_tensors='pt', max_length=512)
    # Create the test dataset and dataloader
    test_set = CustomDataset(test_input['input_ids'], test_input['attention_mask'], test_labels['input_ids'], test_image_ids)
    test_loader = data_loading(test_set, batch_size)
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    # Load the pre-trained CodeT5 model from HuggingFace
    model = CodeT5_model()
    model.resize_token_embeddings(len(tokenizer))
    model = model.to(device)
    # Load the fine-tuned weights for inference on the test data
    model.load_state_dict(torch.load(trained_model_pth, map_location=device))
    model.eval()
    # Generate the results
    writing_results(results_pth, test_loader, tokenizer, model, device)


run()
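
# Note: the path variables in run() are left empty above and must be filled in before
# running the script. The values below are hypothetical examples only, not paths from
# the original repo:
#   test_code_pth      = 'data/test/codes/'
#   test_encodings_pth = 'data/test/encodings.json'
#   trained_model_pth  = 'checkpoints/codet5_finetuned.pt'
#   results_pth        = 'results/test_generations.txt'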