-
Notifications
You must be signed in to change notification settings - Fork 32
/
utils.lua
189 lines (167 loc) · 4.2 KB
/
utils.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
local clock = os.clock

--- Busy-wait for about `n` seconds.
-- Time is measured with os.clock, i.e. CPU time used by this process rather
-- than wall-clock time, and the loop spins the CPU for the whole wait.
-- @param n number of seconds
function sleep(n) -- seconds
  local started_at = clock()
  repeat
    local keep_waiting = clock() - started_at <= n
  until not keep_waiting
end
--- Collect every substring of `s` matched by the Lua pattern `pattern`.
-- NOTE: `pattern` describes the pieces to KEEP, not the separator, e.g.
-- split("a,b", "[^,]+") gives {"a", "b"}. When the pattern contains
-- captures, only the first capture of each match is kept.
-- @param s string to scan
-- @param pattern Lua pattern matching one piece
-- @return array of matched pieces (possibly empty)
function split(s, pattern)
  local pieces, count = {}, 0
  for piece in string.gmatch(s, pattern) do
    count = count + 1
    pieces[count] = piece
  end
  return pieces
end
--- Return true when `String` begins with `Start` (plain comparison, no patterns).
-- An empty `Start` always matches.
function string.starts(String, Start)
  local prefix_len = string.len(Start)
  local head = string.sub(String, 1, prefix_len)
  return head == Start
end
--- Return true when `String` ends with `End` (plain comparison, no patterns).
-- The empty suffix is special-cased because string.sub(s, -0) would return
-- the whole string rather than "".
function string.ends(String, End)
  if End == '' then
    return true
  end
  local suffix_len = string.len(End)
  local tail = string.sub(String, -suffix_len)
  return tail == End
end
--- Strip leading non-alphanumeric characters and trailing whitespace.
-- This is NOT a symmetric whitespace trim: the leading side uses %W, so
-- punctuation such as "> " in front of the first letter/digit is removed as
-- well, while trailing punctuation is kept. (A plain whitespace trim would use
-- the pattern "^%s*(.-)%s*$".)
-- @param s string
-- @return the trimmed string
function string.trim(s)
  local leading_junk_then_body = "^%W*(.-)%s*$"
  local body = s:match(leading_junk_then_body)
  return body
end
--- Return a new tensor holding the elements of `tensor` in reverse order.
-- Only dimension 1 is read, so the input is expected to be 1-D. The result is
-- allocated with torch.Tensor(n) regardless of the input's tensor type.
-- @param tensor 1-D tensor
-- @return reversed copy
function reverse_tensor(tensor)
  local len = tensor:size(1)
  local reversed = torch.Tensor(len)
  for src = len, 1, -1 do
    reversed[len - src + 1] = tensor[src]
  end
  return reversed
end
--- Build the 0/1 mask used as the available_objects tensor (project-specific).
-- @param t table whose values are the indices to set to 1 (presumably
--          integers in 1..N -- TODO confirm against callers), or nil/false
--          meaning "everything is available"
-- @param N length of the mask
-- @return torch.zeros(N) with t's values set to 1, or torch.ones(N) when t is falsy
function table_to_binary_tensor(t, N)
  if not t then
    -- No restriction given: mark every entry as available.
    return torch.ones(N)
  end
  local mask = torch.zeros(N)
  for _, index in pairs(t) do
    mask[index] = 1
  end
  return mask
end
--- Turn the body of a Lua table constructor into a table.
-- A table argument is returned unchanged. A non-empty string such as
-- "a=1, b='x'" is compiled as `return {<str>}` and evaluated. Any other value
-- (nil, numbers, ...) and the empty string yield a new empty table.
-- NOTE(review): the string is compiled and executed as Lua code with access
-- to the global environment -- never pass untrusted input here.
-- @param str table | string | any
-- @return table
-- @raise error (mentioning "str_to_table") when the string is not valid Lua
function str_to_table(str)
  if type(str) == 'table' then
    return str
  end
  if type(str) ~= 'string' or str == '' then
    return {}
  end
  -- `loadstring` exists in Lua 5.1/LuaJIT; Lua 5.2+ only has `load`, which
  -- accepts a source string there.
  local compile = loadstring or load
  -- Return the table from the chunk instead of assigning it to a global, so
  -- no global variable has to be saved and restored around the call.
  local chunk, err = compile('return {' .. str .. '}')
  if not chunk then
    error("str_to_table: cannot parse '" .. str .. "': " .. tostring(err), 2)
  end
  -- Keep exactly one value even if `str` makes the chunk return several.
  local result = chunk()
  return result
end
-- IMP: very specific function - do not use for arbitrary tensors

-- Gather columns first_col .. first_col + state_dim - 1 of a 2-D `tensor`
-- into a table keyed 1..state_dim. With `row` given, each entry is that
-- single cell reshaped to 1 element; with `row` nil, each entry is the whole
-- column reshaped to `rows` elements.
local function gather_step_columns(tensor, first_col, state_dim, row, rows)
  local step = {}
  for offset = 1, state_dim do
    local col = first_col + offset - 1
    if row then
      step[offset] = tensor[{{row}, {col}}]:reshape(1)
    else
      step[offset] = tensor[{{}, {col}}]:reshape(rows)
    end
  end
  return step
end

--- Split a 2-D history tensor into a table of per-step column groups.
-- Every `state_dim` consecutive columns form one step; each step becomes a
-- table keyed 1..state_dim holding the corresponding column slices.
-- * If tensor:size(1) == hist_len (the "testing" layout) each row is split
--   separately, giving one step per `state_dim` columns per row, each entry
--   holding 1 element.
-- * Otherwise (the "training" layout) whole columns are taken, giving one
--   step per `state_dim` columns, each entry holding tensor:size(1) elements.
-- Side effect: zeros in `tensor` are overwritten IN PLACE with #symbols + 1
-- (reads the global `symbols`).
-- NOTE(review): assumes tensor:size(2) is a multiple of state_dim; any
-- leftover columns are ignored.
-- @param tensor 2-D tensor (modified in place)
-- @param state_dim number of columns per step
-- @param hist_len row count that identifies the testing layout
-- @return array of steps
function tensor_to_table(tensor, state_dim, hist_len)
  local batch_size = tensor:size(1)
  local NULL_INDEX = #symbols + 1
  -- convert 0 to NULL_INDEX (this happens when hist doesn't go back as far
  -- as hist_len in chain)
  for i = 1, tensor:size(1) do
    for j = 1, tensor:size(2) do
      if tensor[i][j] == 0 then
        tensor[i][j] = NULL_INDEX
      end
    end
  end

  local steps_per_row = tensor:size(2) / state_dim
  local t2 = {}
  if tensor:size(1) == hist_len then
    -- hacky: this is the testing case. The representation is not consistent
    -- between testing and training, so each row is handled on its own.
    for j = 1, tensor:size(1) do
      for k = 1, steps_per_row do
        local first_col = (k - 1) * state_dim + 1
        t2[#t2 + 1] = gather_step_columns(tensor, first_col, state_dim, j)
      end
    end
  else
    -- training case: take whole columns across the batch.
    for k = 1, steps_per_row do
      local first_col = (k - 1) * state_dim + 1
      t2[#t2 + 1] = gather_step_columns(tensor, first_col, state_dim, nil, batch_size)
    end
  end
  return t2
end
-- Deep-copy `src` (a table, or nil). `active` maps every table currently on
-- the recursion path to its in-progress copy. A reference back to such an
-- ancestor (a cycle, e.g. a class table whose __index is itself) is wired to
-- that copy instead of recursing forever. Tables that are merely shared
-- between two branches (not cycles) are still copied separately.
local function deep_copy_table(src, active)
  if src == nil then return nil end
  local in_progress = active[src]
  if in_progress then return in_progress end
  local dst = {}
  active[src] = dst
  for k, v in pairs(src) do
    if type(v) == 'table' then
      dst[k] = deep_copy_table(v, active)
    else
      dst[k] = v
    end
  end
  -- The metatable is deep-copied too, and attached after the fields are
  -- filled so that any __newindex on it does not fire during the copy.
  setmetatable(dst, deep_copy_table(getmetatable(src), active))
  active[src] = nil
  return dst
end

--- Deep-copy a table, including (recursively) its metatable.
-- Table values are copied. Keys and non-table values (functions, userdata
-- such as tensors, ...) are shared with the original. Self-referencing
-- structures are supported.
-- @param t table or nil
-- @return the copy, or nil when t is nil
function table.copy(t)
  return deep_copy_table(t, {})
end
--- Append the elements t2[1..#t2] to the end of t1, in place.
-- @param t1 array that receives the elements (mutated)
-- @param t2 array whose elements are appended
-- @return t1
function TableConcat(t1, t2)
  local count = #t2
  for idx = 1, count do
    local item = t2[idx]
    t1[#t1 + 1] = item
  end
  return t1
end
--- Render a value as a Lua literal.
-- Strings are quoted so the result is valid Lua source. Backslashes,
-- newlines and carriage returns are escaped. A string whose only quote
-- characters are double quotes is wrapped in single quotes. Any other string
-- is wrapped in double quotes, with embedded double quotes escaped.
-- Tables are delegated to table.tostring; everything else goes through tostring().
-- @param v any value
-- @return string
function table.val_to_str ( v )
  if "string" == type( v ) then
    -- Escape backslashes first so the escapes added below are not doubled.
    v = string.gsub( v, "\\", "\\\\" )
    v = string.gsub( v, "\n", "\\n" )
    -- A raw carriage return inside a quoted literal is a Lua syntax error.
    v = string.gsub( v, "\r", "\\r" )
    if string.match( string.gsub(v,"[^'\"]",""), '^"+$' ) then
      return "'" .. v .. "'"
    end
    return '"' .. string.gsub(v,'"', '\\"' ) .. '"'
  else
    return "table" == type( v ) and table.tostring( v ) or
      tostring( v )
  end
end
-- Lua reserved words. They match the identifier pattern below but cannot be
-- used as bare field names in a table constructor (`{end=1}` is a syntax
-- error), so they must be emitted in bracket form. `goto` is reserved from
-- Lua 5.2 on; bracketing it is valid in every version.
local LUA_RESERVED_WORDS = {
  ["and"] = true, ["break"] = true, ["do"] = true, ["else"] = true,
  ["elseif"] = true, ["end"] = true, ["false"] = true, ["for"] = true,
  ["function"] = true, ["goto"] = true, ["if"] = true, ["in"] = true,
  ["local"] = true, ["nil"] = true, ["not"] = true, ["or"] = true,
  ["repeat"] = true, ["return"] = true, ["then"] = true, ["true"] = true,
  ["until"] = true, ["while"] = true,
}

--- Render a table key as it should appear in a Lua table constructor.
-- Identifier-like, non-reserved string keys are emitted bare (`name`).
-- Every other key is emitted in bracket form (`["a b"]`, `[1]`) via
-- table.val_to_str.
-- @param k any key
-- @return string
function table.key_to_str ( k )
  if "string" == type( k ) and string.match( k, "^[_%a][_%a%d]*$" )
      and not LUA_RESERVED_WORDS[ k ] then
    return k
  else
    return "[" .. table.val_to_str( k ) .. "]"
  end
end
--- Serialize a table as a Lua table-constructor string.
-- The sequence part (as walked by ipairs) is emitted first as positional
-- values. Every remaining key follows as `key=value`, in pairs() order
-- (unspecified). Nested tables recurse through table.val_to_str.
-- @param tbl table
-- @return string such as {1,2,name="x"}
function table.tostring( tbl )
  local out = {}
  local function emit(piece)
    out[#out + 1] = piece
  end

  local positional = {}
  for index, value in ipairs( tbl ) do
    positional[index] = true
    emit( table.val_to_str( value ) )
  end

  for key, value in pairs( tbl ) do
    if not positional[key] then
      emit( table.key_to_str( key ) .. "=" .. table.val_to_str( value ) )
    end
  end

  return "{" .. table.concat( out, "," ) .. "}"
end