Skip to content

Commit

Permalink
MAJOR BUG FIX: donkey.lua makeData function
Browse files Browse the repository at this point in the history
  • Loading branch information
Anurag Ranjan committed Mar 21, 2017
1 parent 5ce5aa0 commit cd47649
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 10 deletions.
6 changes: 3 additions & 3 deletions donkey.lua
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ local function makeData(images, flows)

elseif opt.level == 3 then
local coarseImages = image.scale(images, opt.fineWidth/2, opt.fineHeight/2)
initFlow = computeInitFlowL1(coarseImages:resize(1,coarseImages:size(1),
initFlow = computeInitFlowL2(coarseImages:resize(1,coarseImages:size(1),
coarseImages:size(2), coarseImages:size(3)):cuda())
initFlow = scaleFlow(initFlow:squeeze():float(), opt.fineHeight, opt.fineWidth)

Expand All @@ -252,7 +252,7 @@ local function makeData(images, flows)

elseif opt.level == 4 then
local coarseImages = image.scale(images, opt.fineWidth/2, opt.fineHeight/2)
initFlow = computeInitFlowL1(coarseImages:resize(1,coarseImages:size(1),
initFlow = computeInitFlowL3(coarseImages:resize(1,coarseImages:size(1),
coarseImages:size(2), coarseImages:size(3)):cuda())
initFlow = scaleFlow(initFlow:squeeze():float(), opt.fineHeight, opt.fineWidth)

Expand All @@ -261,7 +261,7 @@ local function makeData(images, flows)

elseif opt.level == 5 then
local coarseImages = image.scale(images, opt.fineWidth/2, opt.fineHeight/2)
initFlow = computeInitFlowL1(coarseImages:resize(1,coarseImages:size(1),
initFlow = computeInitFlowL4(coarseImages:resize(1,coarseImages:size(1),
coarseImages:size(2), coarseImages:size(3)):cuda())
initFlow = scaleFlow(initFlow:squeeze():float(), opt.fineHeight, opt.fineWidth)

Expand Down
54 changes: 47 additions & 7 deletions spynet.lua
Original file line number Diff line number Diff line change
Expand Up @@ -389,14 +389,54 @@ local easyComputeFlow = function(im1, im2)

end

local function easy_setup()
modelL1path = paths.concat('models', 'modelL1_F.t7')
modelL2path = paths.concat('models', 'modelL2_F.t7')
modelL3path = paths.concat('models', 'modelL3_F.t7')
modelL4path = paths.concat('models', 'modelL4_F.t7')
modelL5path = paths.concat('models', 'modelL5_F.t7')
modelL6path = paths.concat('models', 'modelL6_F.t7')
local function easy_setup(opt)
opt = opt or 'sintelFinal'

if opt=="sintelFinal" then
modelL1path = paths.concat('models', 'modelL1_F.t7')
modelL2path = paths.concat('models', 'modelL2_F.t7')
modelL3path = paths.concat('models', 'modelL3_F.t7')
modelL4path = paths.concat('models', 'modelL4_F.t7')
modelL5path = paths.concat('models', 'modelL5_F.t7')
modelL6path = paths.concat('models', 'modelL6_F.t7')
end

if opt=="sintelClean" then
modelL1path = paths.concat('models', 'modelL1_C.t7')
modelL2path = paths.concat('models', 'modelL2_C.t7')
modelL3path = paths.concat('models', 'modelL3_C.t7')
modelL4path = paths.concat('models', 'modelL4_C.t7')
modelL5path = paths.concat('models', 'modelL5_C.t7')
modelL6path = paths.concat('models', 'modelL6_C.t7')
end

if opt=="chairsClean" then
modelL1path = paths.concat('models', 'modelL1_4.t7')
modelL2path = paths.concat('models', 'modelL2_4.t7')
modelL3path = paths.concat('models', 'modelL3_4.t7')
modelL4path = paths.concat('models', 'modelL4_4.t7')
modelL5path = paths.concat('models', 'modelL5_4.t7')
modelL6path = paths.concat('models', 'modelL5_4.t7')
end

if opt=="chairsFinal" then
modelL1path = paths.concat('models', 'modelL1_3.t7')
modelL2path = paths.concat('models', 'modelL2_3.t7')
modelL3path = paths.concat('models', 'modelL3_3.t7')
modelL4path = paths.concat('models', 'modelL4_3.t7')
modelL5path = paths.concat('models', 'modelL5_3.t7')
modelL6path = paths.concat('models', 'modelL5_3.t7')
end

if opt=="kittiFinal" then
modelL1path = paths.concat('models', 'modelL1_K.t7')
modelL2path = paths.concat('models', 'modelL2_K.t7')
modelL3path = paths.concat('models', 'modelL3_K.t7')
modelL4path = paths.concat('models', 'modelL4_K.t7')
modelL5path = paths.concat('models', 'modelL5_K.t7')
modelL6path = paths.concat('models', 'modelL6_K.t7')
end

modelL1 = torch.load(modelL1path)
if torch.type(modelL1) == 'nn.DataParallelTable' then
modelL1 = modelL1:get(1)
Expand Down

0 comments on commit cd47649

Please sign in to comment.