forked from facebookarchive/fb.resnet.torch
-
Notifications
You must be signed in to change notification settings - Fork 3
/
init.lua
87 lines (75 loc) · 2.55 KB
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
--
-- Copyright (c) 2016, Facebook, Inc.
-- Copyright (c) 2016, Fedor Chervinskii
--
require 'nn'
require 'cunn'
require 'cudnn'
local M = {}
--- Build or load the network and create its training criterion.
-- The model comes from the first matching source:
--   1. a resume checkpoint (`opt.resume` directory + `checkpoint.modelFile`),
--   2. a serialized model file named by `opt.retrain` (when not 'none'),
--   3. a fresh model built by the factory in `models/<opt.netType>.lua`.
-- @param opt options table; fields read here: resume, retrain, netType,
--   resetClassifier, nClasses, cudnn
-- @param checkpoint optional checkpoint table; when truthy the model is
--   resumed from it and the classifier is never reset
-- @return model the network
-- @return criterion a cudnn.SpatialCrossEntropyCriterion on the GPU
function M.setup(opt, checkpoint)
   local model
   if checkpoint then
      local modelPath = paths.concat(opt.resume, checkpoint.modelFile)
      assert(paths.filep(modelPath), 'Saved model not found: ' .. modelPath)
      print('=> Resuming model from ' .. modelPath)
      model = torch.load(modelPath):cuda()
   elseif opt.retrain ~= 'none' then
      assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
      print('Loading model from file: ' .. opt.retrain)
      model = torch.load(opt.retrain):cuda()
   else
      print('=> Creating model from file: models/' .. opt.netType .. '.lua')
      -- NOTE(review): no :cuda() here, so the factory is assumed to return
      -- a model already on the GPU -- confirm in models/*.lua
      model = require('models/' .. opt.netType)(opt)
   end

   -- For resetting the classifier when fine-tuning on a different Dataset:
   -- swap the final nn.Linear for a freshly initialised one with
   -- opt.nClasses outputs (same input width, zero bias).
   if opt.resetClassifier and not checkpoint then
      print(' => Replacing classifier with ' .. opt.nClasses .. '-way classifier')
      local orig = model:get(#model.modules)
      assert(torch.type(orig) == 'nn.Linear',
         'expected last layer to be fully connected')
      local linear = nn.Linear(orig.weight:size(2), opt.nClasses)
      linear.bias:zero()
      model:remove(#model.modules)
      model:add(linear:cuda())
   end

   -- Set the CUDNN flags
   if opt.cudnn == 'fastest' then
      cudnn.fastest = true
      cudnn.benchmark = true
   elseif opt.cudnn == 'deterministic' then
      -- Use a deterministic convolution implementation on every module
      -- that exposes cudnn's setMode
      model:apply(function(m)
         if m.setMode then m:setMode(1, 1, 1) end
      end)
   end

   -- The criterion must be on the GPU like the model.
   -- Criteria are created with default (CPU) tensor buffers, and :cuda()
   -- converts them. Upstream fb.resnet.torch does the same with
   -- nn.CrossEntropyCriterion():cuda(). :cuda() returns the criterion itself.
   local criterion = cudnn.SpatialCrossEntropyCriterion():cuda()
   return model, criterion
end
--- Point module gradInput buffers at a small set of shared CUDA storages
-- for memory efficient backprop.
-- Modules are grouped by `torch.type`, refined by `m.__shareGradInputKey`
-- when that field is set. Every module in a group gets a gradInput tensor
-- built on the same 1-element CudaStorage.
-- nn.ConcatTable modules are handled separately. They alternate between two
-- storages according to the parity of their index in
-- `model:findModules('nn.ConcatTable')`.
-- @param model a container supporting `apply` and `findModules`
function M.shareGradInput(model)
   -- String keys hold the per-type storages; the numeric keys 0 and 1 hold
   -- the two ConcatTable storages. The two key sets cannot collide.
   local storages = {}

   -- Fetch the storage registered under `key`, creating a 1-element
   -- CudaStorage the first time that key is requested.
   local function storageFor(key)
      local storage = storages[key]
      if storage == nil then
         storage = torch.CudaStorage(1)
         storages[key] = storage
      end
      return storage
   end

   -- New CudaTensor over the shared storage for `key`
   -- (storage offset 1, size argument 0).
   local function sharedTensor(key)
      return torch.CudaTensor(storageFor(key), 1, 0)
   end

   model:apply(function(module)
      local typeName = torch.type(module)
      -- Only plain tensor gradInputs are shared here; ConcatTable is
      -- handled by the parity loop below.
      if not torch.isTensor(module.gradInput) or typeName == 'nn.ConcatTable' then
         return
      end
      local suffix = module.__shareGradInputKey
      local key = suffix and (typeName .. ':' .. suffix) or typeName
      module.gradInput = sharedTensor(key)
   end)

   local concatTables = model:findModules('nn.ConcatTable')
   for index = 1, #concatTables do
      concatTables[index].gradInput = sharedTensor(index % 2)
   end
end
return M