forked from aladdinpersson/Machine-Learning-Collection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
75 lines (58 loc) · 2.55 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
def ask_user():
    """Prompt for an array in Python-list syntax and return the raw string.

    The user may also type "q", which the caller treats as a quit signal.
    """
    print("Write your array as a list [i,j,k..] with arbitrary positive numbers")
    return input("Input q if you want to quit \n")
def sort_array(encoder, decoder, device, arr=None):
    """Interactive demo: sort user-supplied arrays with a trained encoder/decoder.

    Repeatedly reads a Python-list-formatted string (e.g. "[3, 1, 2]"),
    runs one greedy decoding pass over the sequence, and prints the input
    elements reordered by the predicted pointer indices. Loops until the
    user enters "q".

    Args:
        encoder: nn.Module; called as encoder(source) and expected to
            return (encoder_states, hidden, cell) for a (seq_len, 1)
            float tensor — TODO confirm against the model definition.
        decoder: nn.Module; called as decoder(x, encoder_states, hidden,
            cell) returning (attention, energy, hidden, cell), where
            attention is presumably a distribution over source positions.
        device: torch device the model tensors should live on.
        arr: optional raw input string; if None the user is prompted.
    """
    if arr is None:
        arr = ask_user()
    with torch.no_grad():
        while arr != "q":
            # Avoid numerical errors by rounding to max_len
            # SECURITY NOTE(review): eval() executes arbitrary Python read
            # from stdin; for anything beyond a local demo, prefer
            # ast.literal_eval to parse the list safely.
            arr = eval(arr)
            # Decimal places of each element (0 for integers), derived from
            # its str() form; max_len is used below to round the printed
            # result back to the input precision.
            lengths = [
                len(str(elem).split(".")[1]) if len(str(elem).split(".")) > 1 else 0
                for elem in arr
            ]
            max_len = max(lengths)
            # Shape (seq_len, batch=1): unsqueeze adds the batch column.
            source = torch.tensor(arr, dtype=torch.float).to(device).unsqueeze(1)
            batch_size = source.shape[1]
            # One extra step for the <SOS> slot at t=0.
            target_len = source.shape[0] + 1
            # Per-step energies; slot t=0 is never written.
            outputs = torch.zeros(target_len, batch_size, target_len - 1).to(device)
            encoder_states, hidden, cell = encoder(source)
            # First input will be <SOS> token
            x = torch.tensor([-1], dtype=torch.float).to(device)
            # predictions[t] holds the source index chosen at step t.
            predictions = torch.zeros((target_len)).to(device)
            for t in range(1, target_len):
                # At every time step use encoder_states and update hidden, cell
                attention, energy, hidden, cell = decoder(
                    x, encoder_states, hidden, cell
                )
                # Store prediction for current time step
                outputs[t] = energy.permute(1, 0)
                # Get the best word the Decoder predicted (index in the vocabulary)
                best_guess = attention.argmax(0)
                predictions[t] = best_guess.item()
                # Greedy decoding: feed the chosen index back as next input.
                x = torch.tensor([best_guess.item()], dtype=torch.float).to(device)
            # Reorder the source by the predicted indices (skipping the unused
            # t=0 slot) and round each value to the input's decimal precision.
            output = [
                round(source[predictions[1:].long()][i, :].item(), max_len)
                for i in range(source.shape[0])
            ]
            print(f"Here's the result: {output}")
            arr = ask_user()
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Serialize a training-state dict to *filename* via torch.save.

    Args:
        state: picklable dict, e.g. {"state_dict": ..., "optimizer": ...}.
        filename: destination path for the checkpoint file.
    """
    print("=> Saving checkpoint")
    torch.save(state, filename)
def load_checkpoint(checkpoint, model, optimizer):  # , steps):
    """Restore model and optimizer state from a checkpoint dict in place.

    Args:
        checkpoint: dict with "state_dict" (model params) and "optimizer"
            (optimizer state), as produced by save_checkpoint.
        model: nn.Module to load parameters into.
        optimizer: torch optimizer to load state into.

    Note: restoring a step counter is intentionally disabled (see the
    commented-out parameter in the signature).
    """
    print("=> Loading checkpoint")
    optimizer.load_state_dict(checkpoint["optimizer"])
    model.load_state_dict(checkpoint["state_dict"])