diff --git a/donkey.lua b/donkey.lua index 9b4cf83..14eb609 100755 --- a/donkey.lua +++ b/donkey.lua @@ -243,7 +243,7 @@ local function makeData(images, flows) elseif opt.level == 3 then local coarseImages = image.scale(images, opt.fineWidth/2, opt.fineHeight/2) - initFlow = computeInitFlowL1(coarseImages:resize(1,coarseImages:size(1), + initFlow = computeInitFlowL2(coarseImages:resize(1,coarseImages:size(1), coarseImages:size(2), coarseImages:size(3)):cuda()) initFlow = scaleFlow(initFlow:squeeze():float(), opt.fineHeight, opt.fineWidth) @@ -252,7 +252,7 @@ local function makeData(images, flows) elseif opt.level == 4 then local coarseImages = image.scale(images, opt.fineWidth/2, opt.fineHeight/2) - initFlow = computeInitFlowL1(coarseImages:resize(1,coarseImages:size(1), + initFlow = computeInitFlowL3(coarseImages:resize(1,coarseImages:size(1), coarseImages:size(2), coarseImages:size(3)):cuda()) initFlow = scaleFlow(initFlow:squeeze():float(), opt.fineHeight, opt.fineWidth) @@ -261,7 +261,7 @@ local function makeData(images, flows) elseif opt.level == 5 then local coarseImages = image.scale(images, opt.fineWidth/2, opt.fineHeight/2) - initFlow = computeInitFlowL1(coarseImages:resize(1,coarseImages:size(1), + initFlow = computeInitFlowL4(coarseImages:resize(1,coarseImages:size(1), coarseImages:size(2), coarseImages:size(3)):cuda()) initFlow = scaleFlow(initFlow:squeeze():float(), opt.fineHeight, opt.fineWidth) diff --git a/spynet.lua b/spynet.lua index 8793eea..f020031 100644 --- a/spynet.lua +++ b/spynet.lua @@ -389,14 +389,54 @@ local easyComputeFlow = function(im1, im2) end -local function easy_setup() - modelL1path = paths.concat('models', 'modelL1_F.t7') - modelL2path = paths.concat('models', 'modelL2_F.t7') - modelL3path = paths.concat('models', 'modelL3_F.t7') - modelL4path = paths.concat('models', 'modelL4_F.t7') - modelL5path = paths.concat('models', 'modelL5_F.t7') - modelL6path = paths.concat('models', 'modelL6_F.t7') +local function easy_setup(opt) + opt = opt or 'sintelFinal' + if opt=="sintelFinal" then + modelL1path = paths.concat('models', 'modelL1_F.t7') + modelL2path = paths.concat('models', 'modelL2_F.t7') + modelL3path = paths.concat('models', 'modelL3_F.t7') + modelL4path = paths.concat('models', 'modelL4_F.t7') + modelL5path = paths.concat('models', 'modelL5_F.t7') + modelL6path = paths.concat('models', 'modelL6_F.t7') + end + + if opt=="sintelClean" then + modelL1path = paths.concat('models', 'modelL1_C.t7') + modelL2path = paths.concat('models', 'modelL2_C.t7') + modelL3path = paths.concat('models', 'modelL3_C.t7') + modelL4path = paths.concat('models', 'modelL4_C.t7') + modelL5path = paths.concat('models', 'modelL5_C.t7') + modelL6path = paths.concat('models', 'modelL6_C.t7') + end + + if opt=="chairsClean" then + modelL1path = paths.concat('models', 'modelL1_4.t7') + modelL2path = paths.concat('models', 'modelL2_4.t7') + modelL3path = paths.concat('models', 'modelL3_4.t7') + modelL4path = paths.concat('models', 'modelL4_4.t7') + modelL5path = paths.concat('models', 'modelL5_4.t7') + modelL6path = paths.concat('models', 'modelL5_4.t7') + end + + if opt=="chairsFinal" then + modelL1path = paths.concat('models', 'modelL1_3.t7') + modelL2path = paths.concat('models', 'modelL2_3.t7') + modelL3path = paths.concat('models', 'modelL3_3.t7') + modelL4path = paths.concat('models', 'modelL4_3.t7') + modelL5path = paths.concat('models', 'modelL5_3.t7') + modelL6path = paths.concat('models', 'modelL5_3.t7') + end + + if opt=="kittiFinal" then + modelL1path = paths.concat('models', 'modelL1_K.t7') + modelL2path = paths.concat('models', 'modelL2_K.t7') + modelL3path = paths.concat('models', 'modelL3_K.t7') + modelL4path = paths.concat('models', 'modelL4_K.t7') + modelL5path = paths.concat('models', 'modelL5_K.t7') + modelL6path = paths.concat('models', 'modelL6_K.t7') + end + modelL1 = torch.load(modelL1path) if torch.type(modelL1) == 'nn.DataParallelTable' then modelL1 = modelL1:get(1)