diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..182d6b0 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +checkpoint* diff --git a/EPECriterion.lua b/EPECriterion.lua new file mode 100755 index 0000000..c97534e --- /dev/null +++ b/EPECriterion.lua @@ -0,0 +1,68 @@ + +-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. +-- All rights reserved. +-- This software is provided for research purposes only. +-- By using this software you agree to the terms of the license file +-- in the root folder. +-- For commercial use, please contact ps-license@tue.mpg.de. + +local EPECriterion, parent = torch.class('nn.EPECriterion', 'nn.Criterion') + +-- Computes average endpoint error for batchSize x ChannelSize x Height x Width +-- flow fields or general multidimensional matrices. + +local eps = 1e-12 + +function EPECriterion:__init() + parent.__init(self) + self.sizeAverage = true +end + +function EPECriterion:updateOutput(input, target) + assert( input:nElement() == target:nElement(), + "input and target size mismatch") + + self.buffer = self.buffer or input.new() + + local buffer = self.buffer + local output + local npixels + + buffer:resizeAs(input) + npixels = input:nElement()/2 -- 2 channel flow fields + + buffer:add(input, -1, target):pow(2) + output = torch.sum(buffer,2):sqrt() -- second channel is flow + output = output:sum() + + output = output / npixels + + self.output = output + + return self.output +end + +function EPECriterion:updateGradInput(input, target) + + assert( input:nElement() == target:nElement(), + "input and target size mismatch") + + self.buffer = self.buffer or input.new() + + local buffer = self.buffer + local gradInput = self.gradInput + local npixels + local loss + + buffer:resizeAs(input) + npixels = input:nElement()/2 + + buffer:add(input, -1, target):pow(2) + loss = torch.sum(buffer,2):sqrt():add(eps) -- forms the denominator + loss = torch.cat(loss, loss, 2) -- Repeat tensor to scale the gradients + + gradInput:resizeAs(input) + gradInput:add(input, -1, target):cdiv(loss) + gradInput = gradInput / npixels + return gradInput +end \ No newline at end of file diff --git a/README.md b/README.md index 984868b..6b5ec6b 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,17 @@ You can also use batch-mode, if your images `im` are a tensor of size `Bx6xHxW`, ```lua flow = computeFlow(im) ``` + +## Training +Training sequentially is faster than training end-to-end since you need to learn small number of parameters at each level. To train a level `N`, we need the trained models at levels `1` to `N-1`. You also initialize the model with a pretrained model at `N-1`. + +E.g. To train level 3, we need trained models at `L1` and `L2`, and we initialize it `modelL2_3.t7`. +```bash +th main.lua -fineWidth 128 -fineHeight 96 -level 3 -netType volcon \ +-cache checkpoint -data FLYING_CHAIRS_DIR \ +-L1 models/modelL1_3.t7 -L2 models/modelL2_3.t7 \ +-retrain models/modelL2_3.t7 +``` ## Timing Benchmarks Our timing benchmark is set up on Flying chair dataset. To test it, you need to download ```bash diff --git a/data.lua b/data.lua new file mode 100755 index 0000000..7f7ca11 --- /dev/null +++ b/data.lua @@ -0,0 +1,51 @@ +-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. +-- All rights reserved. +-- This software is provided for research purposes only. +-- By using this software you agree to the terms of the license file +-- in the root folder. +-- For commercial use, please contact ps-license@tue.mpg.de. +-- +-- Copyright (c) 2014, Facebook, Inc. 
+-- All rights reserved. +-- +-- This source code is licensed under the BSD-style license found in the +-- LICENSE file in the root directory of this source tree. An additional grant +-- of patent rights can be found in the PATENTS file in the same directory. +-- +local ffi = require 'ffi' +local Threads = require 'threads' +Threads.serialization('threads.sharedserialize') + +-- This script contains the logic to create K threads for parallel data-loading. +-- For the data-loading details, look at donkey.lua +------------------------------------------------------------------------------- +do -- start K datathreads (donkeys) + if opt.nDonkeys > 0 then + local options = opt -- make an upvalue to serialize over to donkey threads + donkeys = Threads( + opt.nDonkeys, + function() + require 'torch' + end, + function(idx) + opt = options -- pass to all donkeys via upvalue + tid = idx + local seed = opt.manualSeed + idx + torch.manualSeed(seed) + print(string.format('Starting donkey with id: %d seed: %d', tid, seed)) + paths.dofile('donkey.lua') + end + ); + else -- single threaded data loading. useful for debugging + paths.dofile('donkey.lua') + donkeys = {} + function donkeys:addjob(f1, f2) f2(f1()) end + function donkeys:synchronize() end + end +end + +nTest = 0 +donkeys:addjob(function() return testLoader:size() end, function(c) nTest = c end) +donkeys:synchronize() +assert(nTest > 0, "Failed to get nTest") +print('nTest: ', nTest) diff --git a/dataset.lua b/dataset.lua new file mode 100755 index 0000000..0df7633 --- /dev/null +++ b/dataset.lua @@ -0,0 +1,148 @@ + +-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. +-- All rights reserved. +-- This software is provided for research purposes only. +-- By using this software you agree to the terms of the license file +-- in the root folder. +-- For commercial use, please contact ps-license@tue.mpg.de. + +require 'torch' +torch.setdefaulttensortype('torch.FloatTensor') +local ffi = require 'ffi' +local class = require('pl.class') +local dir = require 'pl.dir' +local tablex = require 'pl.tablex' +local argcheck = require 'argcheck' +require 'sys' +require 'xlua' +require 'image' + +local dataset = torch.class('dataLoader') + +local initcheck = argcheck{ + pack=true, + help=[[ + A dataset class for images in a flat folder structure (folder-name is class-name). + Optimized for extremely large datasets (upwards of 14 million images). + Tested only on Linux (as it uses command-line linux utilities to scale up) +]], + {name="inputSize", + type="table", + help="the size of the input images"}, + + {name="outputSize", + type="table", + help="the size of the network output"}, + + {name="split", + type="number", + help="Percentage of split to go to Training" + }, + + {name="samplingMode", + type="string", + help="Sampling mode: random | balanced ", + default = "balanced"}, + + {name="verbose", + type="boolean", + help="Verbose mode during initialization", + default = false}, + + {name="loadSize", + type="table", + help="a size to load the images to, initially", + opt = true}, + + {name="samplingIds", + type="torch.LongTensor", + help="the ids of training or testing images", + opt = true}, + + {name="sampleHookTrain", + type="function", + help="applied to sample during training(ex: for lighting jitter). " + .. "It takes the image path as input", + opt = true}, + + {name="sampleHookTest", + type="function", + help="applied to sample during testing", + opt = true}, +} + +function dataset:__init(...) + + -- argcheck + local args = initcheck(...) 
+ print(args) + for k,v in pairs(args) do self[k] = v end + + if not self.loadSize then self.loadSize = self.inputSize; end + + if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end + if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end + + local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end + + self.numSamples = self.samplingIds:size()[1] + assert(self.numSamples > 0, "Could not find any sample in the given input paths") + + if self.verbose then print(self.numSamples .. ' samples found.') end +end + +function dataset:size(class, list) + return self.numSamples +end + +-- converts a table of samples (and corresponding labels) to a clean tensor +local function tableToOutput(self, imgTable, outputTable) + local images, outputs + local quantity = #imgTable + assert(imgTable[1]:size()[1] == self.inputSize[1]) + assert(outputTable[1]:size()[1] == self.outputSize[1]) + + images = torch.Tensor(quantity, + self.inputSize[1], self.inputSize[2], self.inputSize[3]) + outputs = torch.Tensor(quantity, + self.outputSize[1], self.outputSize[2], self.outputSize[3]) + + for i=1,quantity do + images[i]:copy(imgTable[i]) + outputs[i]:copy(outputTable[i]) + end + return images, outputs +end + +-- sampler, samples from the training set. +function dataset:sample(quantity) + assert(quantity) + local imgTable = {} + local outputTable = {} + for i=1,quantity do + local id = torch.random(1, self.numSamples) + local img, output = self:sampleHookTrain(self.samplingIds[id][1]) -- single element[not tensor] from a row + + table.insert(imgTable, img) + table.insert(outputTable, output) + end + local images, outputs = tableToOutput(self, imgTable, outputTable) + return images, outputs +end + +function dataset:get(i1, i2) + local indices = self.samplingIds[{{i1, i2}}]; + local quantity = i2 - i1 + 1; + assert(quantity > 0) + local imgTable = {} + local outputTable = {} + for i=1,quantity do + local img, output = self:sampleHookTest(indices[i][1]) + table.insert(imgTable, img) + table.insert(outputTable, output) + end + local images, outputs = tableToOutput(self, imgTable, outputTable) + return images, outputs +end + +return dataset diff --git a/donkey.lua b/donkey.lua new file mode 100755 index 0000000..9b4cf83 --- /dev/null +++ b/donkey.lua @@ -0,0 +1,412 @@ +-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. +-- All rights reserved. +-- This software is provided for research purposes only. +-- By using this software you agree to the terms of the license file +-- in the root folder. +-- For commercial use, please contact ps-license@tue.mpg.de. + +require 'image' +require 'nn' +require 'cunn' +require 'cudnn' +require 'nngraph' +require 'stn' +require 'spy' + +local flowX = require 'flowExtensions' +local TF = require 'transforms' + +paths.dofile('dataset.lua') +paths.dofile('util.lua') + +-- This file contains the data-loading logic and details. +-- It is run by each data-loader thread. 
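
The `dataLoader` class above is normally constructed inside each donkey thread by `donkey.lua`, which also installs the real `sampleHookTrain`/`sampleHookTest` loaders. Purely as an illustration of its interface, here is a self-contained sketch with a toy hook and synthetic sampling ids; both are placeholders for what `donkey.lua` actually wires up and are not part of the repo.

```lua
-- Illustrative sketch: exercise dataset.lua's dataLoader with fabricated data.
-- The toy hook and the synthetic samplingIds stand in for real image/flow
-- loading and the train/validation split handled by donkey.lua.
dofile('dataset.lua')   -- registers the global 'dataLoader' torch class

local function toyHook(self, id)
   -- donkey.lua's hooks load and augment a real sample for this id;
   -- here we just fabricate tensors of the expected shapes.
   return torch.randn(8, 96, 128), torch.randn(2, 96, 128)
end

local loader = dataLoader{
   inputSize       = {8, 96, 128},   -- level-3 resolution from the README example
   outputSize      = {2, 96, 128},
   split           = 100,
   verbose         = true,
   samplingIds     = torch.range(1, 64):long():view(64, 1),
   sampleHookTrain = toyHook,
}

local inputs, flows = loader:sample(4)   -- 4 x 8 x 96 x 128 and 4 x 2 x 96 x 128
print(inputs:size(), flows:size())
```
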
+------------------------------------------ +local eps = 1e-6 +-- a cache file of the training metadata (if doesnt exist, will be created) +local trainCache = paths.concat(opt.cache, 'trainCache.t7') +local testCache = paths.concat(opt.cache, 'testCache.t7') + +local meanstd = { + mean = { 0.485, 0.456, 0.406 }, + std = { 0.229, 0.224, 0.225 }, +} +local pca = { + eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 }, + eigvec = torch.Tensor{ + { -0.5675, 0.7192, 0.4009 }, + { -0.5808, -0.0045, -0.8140 }, + { -0.5836, -0.6948, 0.4203 }, + }, +} + +local mean = meanstd.mean +local std = meanstd.std +------------------------------------------ +-- Warping Function: +local function createWarpModel() + local imgData = nn.Identity()() + local floData = nn.Identity()() + + local imgOut = nn.Transpose({2,3},{3,4})(imgData) + local floOut = nn.Transpose({2,3},{3,4})(floData) + + local warpImOut = nn.Transpose({3,4},{2,3})(nn.BilinearSamplerBHWD()({imgOut, floOut})) + local model = nn.gModule({imgData, floData}, {warpImOut}) + + return model +end + +local modelL1, modelL2, modelL3, modelL4 +local modelL1path, modelL2path, modelL3path, modelL4path +local down1, down2, down3, down4, up2, up3, up4 +local warpmodel2, warpmodel3, warpmodel4 + +modelL1path = opt.L1 +modelL2path = opt.L2 +modelL3path = opt.L3 +modelL4path = opt.L4 + +if opt.level > 1 then + -- Load modelL1 + modelL1 = torch.load(modelL1path) + if torch.type(modelL1) == 'nn.DataParallelTable' then + modelL1 = modelL1:get(1) + end + modelL1:evaluate() + down1 = nn.SpatialAveragePooling(2,2,2,2):cuda() + down1:evaluate() +end + +if opt.level > 2 then +-- Load modelL2 + modelL2 = torch.load(modelL2path) + if torch.type(modelL2) == 'nn.DataParallelTable' then + modelL2 = modelL2:get(1) + end + modelL2:evaluate() + + down2 = nn.SpatialAveragePooling(2,2,2,2):cuda() + up2 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda() + warpmodel2 = createWarpModel():cuda() + + down2:evaluate() + up2:evaluate() + warpmodel2:evaluate() +end + +if opt.level > 3 then + -- Load modelL3 + modelL3 = torch.load(modelL3path) + if torch.type(modelL3) == 'nn.DataParallelTable' then + modelL3 = modelL3:get(1) + end + modelL3:evaluate() + + down3 = nn.SpatialAveragePooling(2,2,2,2):cuda() + up3 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda() + warpmodel3 = createWarpModel():cuda() + + down3:evaluate() + up3:evaluate() + warpmodel3:evaluate() +end + +if opt.level > 4 then + -- Load modelL4 + modelL4 = torch.load(modelL4path) + if torch.type(modelL4) == 'nn.DataParallelTable' then + modelL4 = modelL4:get(1) + end + modelL4:evaluate() + + down4 = nn.SpatialAveragePooling(2,2,2,2):cuda() + up4 = nn.Sequential():add(nn.Transpose({2,3},{3,4})):add(nn.ScaleBHWD(2)):add(nn.Transpose({3,4},{2,3})):cuda() + warpmodel4 = createWarpModel():cuda() + + down4:evaluate() + up4:evaluate() + warpmodel4:evaluate() +end + +-- Check for existence of opt.data +if not os.execute('cd ' .. opt.data) then + error(("could not chdir to '%s'"):format(opt.data)) +end + +local loadSize = opt.loadSize +local inputSize = {8, opt.fineHeight, opt.fineWidth} +local outputSize = {2, opt.fineHeight, opt.fineWidth} + +local function getTrainValidationSplits(path) + local numSamples = sys.fexecute( "ls " .. opt.data .. 
"| wc -l")/3 + local ff = torch.DiskFile(path, 'r') + local trainValidationSamples = torch.IntTensor(numSamples) + ff:readInt(trainValidationSamples:storage()) + ff:close() + + local train_samples = trainValidationSamples:eq(1):nonzero() + local validation_samples = trainValidationSamples:eq(2):nonzero() + + return train_samples, validation_samples +end + +local train_samples, validation_samples = getTrainValidationSplits(opt.trainValidationSplit) + +local function loadImage(path) + local input = image.load(path, 3, 'float') + return input +end + +local function rotateFlow(flow, angle) + local flow_rot = image.rotate(flow, angle) + local fu = torch.mul(flow_rot[1], math.cos(-angle)) - torch.mul(flow_rot[2], math.sin(-angle)) + local fv = torch.mul(flow_rot[1], math.sin(-angle)) + torch.mul(flow_rot[2], math.cos(-angle)) + flow_rot[1]:copy(fu) + flow_rot[2]:copy(fv) + + return flow_rot +end + +local function scaleFlow(flow, height, width) + -- scale the original flow to a flow of size height x width + local sc = height/flow:size(2) + assert(torch.abs(width/flow:size(3) - sc) Creating model from file: models/' .. opt.netType .. '.lua') + model = createModel(opt.nGPU) -- for the model creation code, check the models/ folder + if opt.backend == 'cudnn' then + require 'cudnn' + cudnn.convert(model, cudnn) + elseif opt.backend ~= 'nn' then + error'Unsupported backend' + end +end + +-- 2. Create Criterion +criterion = nn.EPECriterion() + +print('=> Model') +print(model) + +print('=> Criterion') +print(criterion) + +criterion:cuda() + +collectgarbage() diff --git a/models/volcon.lua b/models/volcon.lua new file mode 100644 index 0000000..0217852 --- /dev/null +++ b/models/volcon.lua @@ -0,0 +1,30 @@ +-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. +-- All rights reserved. +-- This software is provided for research purposes only. +-- By using this software you agree to the terms of the license file +-- in the root folder. +-- For commercial use, please contact ps-license@tue.mpg.de. + +require 'nn' +require 'cutorch' +require 'cunn' +require 'cudnn' +function createModel(nGPU) + local model = nn.Sequential() + model:add(nn.SpatialConvolution(8,32,7,7,1,1,3,3)) + model:add(nn.ReLU()) + model:add(nn.SpatialConvolution(32,64,7,7,1,1,3,3)) + model:add(nn.ReLU()) + model:add(nn.SpatialConvolution(64,32,7,7,1,1,3,3)) + model:add(nn.ReLU()) + model:add(nn.SpatialConvolution(32,16,7,7,1,1,3,3)) + model:add(nn.ReLU()) + model:add(nn.SpatialConvolution(16,2,7,7,1,1,3,3)) + + if nGPU>0 then + model:cuda() + model = makeDataParallel(model, nGPU) + end + + return model +end diff --git a/opts.lua b/opts.lua new file mode 100755 index 0000000..4bb9538 --- /dev/null +++ b/opts.lua @@ -0,0 +1,61 @@ +-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. +-- All rights reserved. +-- This software is provided for research purposes only. +-- By using this software you agree to the terms of the license file +-- in the root folder. +-- For commercial use, please contact ps-license@tue.mpg.de. 
+ +local M = { } + +function M.parse(arg) + local cmd = torch.CmdLine() + cmd:text() + cmd:text('SPyNet Coarse-to-Fine Optical Flow Training') + cmd:text() + cmd:text('Options:') + ------------ General options -------------------- + + cmd:option('-cache', 'checkpoint/', 'subdirectory in which to save/log experiments') + cmd:option('-data', 'flying_chairs/data', 'Home of Flying Chairs dataset') + cmd:option('-trainValidationSplit', 'train_val_split.txt', 'File containing training and validation split') + cmd:option('-manualSeed', 2, 'Manually set RNG seed') + cmd:option('-GPU', 1, 'Default preferred GPU') + cmd:option('-nGPU', 1, 'Number of GPUs to use by default') + cmd:option('-backend', 'cudnn', 'Options: cudnn | ccn2 | cunn') + ------------- Data options ------------------------ + cmd:option('-nDonkeys', 4, 'number of donkeys to initialize (data loading threads)') + cmd:option('-fineWidth', 512, 'the length of the fine flow field') + cmd:option('-fineHeight', 384, 'the width of the fine flow field') + cmd:option('-level', 1, 'Options: 1,2,3.., wheather to initialize flow to zero' ) + ------------- Training options -------------------- + cmd:option('-augment', 1, 'augment the data') + cmd:option('-nEpochs', 1000, 'Number of total epochs to run') + cmd:option('-epochSize', 1000, 'Number of batches per epoch') + cmd:option('-epochNumber', 1, 'Manual epoch number (useful on restarts)') + cmd:option('-batchSize', 32, 'mini-batch size (1 = pure stochastic)') + ---------- Optimization options ---------------------- + cmd:option('-LR', 0.0, 'learning rate; if set, overrides default LR/WD recipe') + cmd:option('-momentum', 0.9, 'momentum') + cmd:option('-weightDecay', 5e-4, 'weight decay') + cmd:option('-optimizer', 'adam', 'adam or sgd') + ---------- Model options ---------------------------------- + cmd:option('-L1', 'models/modelL1_4.t7', 'Trained Level 1 model') + cmd:option('-L2', 'models/modelL2_4.t7', 'Trained Level 2 model') + cmd:option('-L3', 'models/modelL3_4.t7', 'Trained Level 3 model') + cmd:option('-L4', 'models/modelL4_4.t7', 'Trained Level 4 model') + + cmd:option('-netType', 'volcon', 'Lua network file') + cmd:option('-retrain', 'none', 'provide path to model to retrain with') + cmd:option('-optimState', 'none', 'provide path to an optimState to reload from') + cmd:text() + + local opt = cmd:parse(arg or {}) + opt.save = paths.concat(opt.cache) + -- add date/time + opt.save = paths.concat(opt.save, '' .. os.date():gsub(' ','')) + + opt.loadSize = {8, 384, 512} + return opt +end + +return M diff --git a/test.lua b/test.lua new file mode 100755 index 0000000..e174c42 --- /dev/null +++ b/test.lua @@ -0,0 +1,84 @@ +-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. +-- All rights reserved. +-- This software is provided for research purposes only. +-- By using this software you agree to the terms of the license file +-- in the root folder. +-- For commercial use, please contact ps-license@tue.mpg.de. +-- +-- Copyright (c) 2014, Facebook, Inc. +-- All rights reserved. +-- +-- This source code is licensed under the BSD-style license found in the +-- LICENSE file in the root directory of this source tree. An additional grant +-- of patent rights can be found in the PATENTS file in the same directory. +-- +testLogger = optim.Logger(paths.concat(opt.save, 'test.log')) + +local batchNumber +local error_center, loss +local timer = torch.Timer() + +function test() + print('==> doing epoch on validation data:') + print("==> online epoch # " .. 
epoch) + + batchNumber = 0 + cutorch.synchronize() + timer:reset() + + -- set the dropouts to evaluate mode + model:evaluate() + + error_center = 0 + loss = 0 + for i=1,nTest/opt.batchSize do -- nTest is set in 1_data.lua + local indexStart = (i-1) * opt.batchSize + 1 + local indexEnd = (indexStart + opt.batchSize - 1) + donkeys:addjob( + -- work to be done by donkey thread + function() + local inputs, labels = testLoader:get(indexStart, indexEnd) + return inputs, labels + end, + -- callback that is run in the main thread once the work is done + testBatch + ) + end + + donkeys:synchronize() + cutorch.synchronize() + + error_center = error_center * 100 / nTest + loss = loss / (nTest/opt.batchSize) -- because loss is calculated per batch + testLogger:add{ + ['% top1 accuracy (test set) (center crop)'] = error_center, + ['avg loss (test set)'] = loss + } + print(string.format('Epoch: [%d][TESTING SUMMARY] Total Time(s): %.2f \t' + .. 'average loss (per batch): %.2f \t ' + .. 'accuracy [Center](%%):\t top-1 %.2f\t ', + epoch, timer:time().real, loss, error_center)) + + print('\n') + + +end -- of test() +----------------------------------------------------------------------------- +local inputs = torch.CudaTensor() +local labels = torch.CudaTensor() + +function testBatch(inputsCPU, labelsCPU) + batchNumber = batchNumber + opt.batchSize + + inputs:resize(inputsCPU:size()):copy(inputsCPU) + labels:resize(labelsCPU:size()):copy(labelsCPU) + + local outputs = model:forward(inputs) + local err = criterion:forward(outputs, labels) + cutorch.synchronize() + local pred = outputs:float() + + loss = loss + err + + print(('Epoch: Testing [%d][%d/%d]'):format(epoch, batchNumber, nTest)) +end diff --git a/train.lua b/train.lua new file mode 100755 index 0000000..9dafecb --- /dev/null +++ b/train.lua @@ -0,0 +1,193 @@ +-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. +-- All rights reserved. +-- This software is provided for research purposes only. +-- By using this software you agree to the terms of the license file +-- in the root folder. +-- For commercial use, please contact ps-license@tue.mpg.de. +-- +-- Copyright (c) 2014, Facebook, Inc. +-- All rights reserved. +-- +-- This source code is licensed under the BSD-style license found in the +-- LICENSE file in the root directory of this source tree. An additional grant +-- of patent rights can be found in the PATENTS file in the same directory. +-- +require 'optim' + +--[[ + 1. Setup SGD optimization state and learning rate schedule + 2. Create loggers. + 3. train - this function handles the high-level training loop, + i.e. load data, train model, save model and state to disk + 4. trainBatch - Used by train() to train a single batch after the data is loaded. +]]-- + +-- Setup a reused optimization state (for sgd). If needed, reload it from disk +local optimState = { + learningRate = opt.LR, + learningRateDecay = 0.0, + momentum = opt.momentum, + dampening = 0.0, + weightDecay = opt.weightDecay +} + +if opt.optimState ~= 'none' then + assert(paths.filep(opt.optimState), 'File not found: ' .. opt.optimState) + print('Loading optimState from file: ' .. opt.optimState) + optimState = torch.load(opt.optimState) +end + +-- Learning rate annealing schedule. We will build a new optimizer for +-- each epoch. +-- +-- By default we follow a known recipe for a 55-epoch training. 
If +-- the learningRate command-line parameter has been specified, though, +-- we trust the user is doing something manual, and will use her +-- exact settings for all optimization. +-- +-- Return values: +-- diff to apply to optimState, +-- true IFF this is the first epoch of a new regime +local function paramsForEpoch(epoch) + if opt.LR ~= 0.0 then -- if manually specified + return { } + end + local regimes = { + -- start, end, LR, WD, + { 1, 10, 5e-3, 0 }, + { 11, 80, 1e-4, 0 }, + { 81, 120, 1e-4, 0 }, + { 121, 160, 1e-4, 0 }, + { 161, 1e8, 5e-5, 0 }, + } + + for _, row in ipairs(regimes) do + if epoch >= row[1] and epoch <= row[2] then + return { learningRate=row[3], weightDecay=row[4] }, epoch == row[1] + end + end +end + +-- 2. Create loggers. +trainLogger = optim.Logger(paths.concat(opt.save, 'train.log')) +local batchNumber +local top1_epoch, loss_epoch + +-- 3. train - this function handles the high-level training loop, +-- i.e. load data, train model, save model and state to disk +function train() + print('==> doing epoch on training data:') + print("==> online epoch # " .. epoch) + + local params, newRegime = paramsForEpoch(epoch) + if newRegime then + optimState = { + learningRate = params.learningRate, + learningRateDecay = 0.0, + momentum = opt.momentum, + dampening = 0.0, + weightDecay = params.weightDecay + } + end + batchNumber = 0 + cutorch.synchronize() + + -- set the dropouts to training mode + model:training() + + local tm = torch.Timer() + top1_epoch = 0 + loss_epoch = 0 + for i=1,opt.epochSize do + -- queue jobs to data-workers + donkeys:addjob( + -- the job callback (runs in data-worker thread) + function() + local inputs, labels = trainLoader:sample(opt.batchSize) + return inputs, labels + end, + -- the end callback (runs in the main thread) + trainBatch + ) + end + + donkeys:synchronize() + cutorch.synchronize() + + top1_epoch = top1_epoch * 100 / (opt.batchSize * opt.epochSize) + loss_epoch = loss_epoch / opt.epochSize + + trainLogger:add{ + ['% top1 accuracy (train set)'] = top1_epoch, + ['avg loss (train set)'] = loss_epoch + } + print(string.format('Epoch: [%d][TRAINING SUMMARY] Total Time(s): %.2f\t' + .. 'average loss (per batch): %.2f \t ' + .. 'accuracy(%%):\t top-1 %.2f\t', + epoch, tm:time().real, loss_epoch, top1_epoch)) + print('\n') + + -- save model + collectgarbage() + + -- clear the intermediate states in the model before saving to disk + -- this saves lots of disk space + model:clearState() + saveDataParallel(paths.concat(opt.save, 'model_' .. epoch .. '.t7'), model) -- defined in util.lua + torch.save(paths.concat(opt.save, 'optimState_' .. epoch .. '.t7'), optimState) +end -- of train() +------------------------------------------------------------------------------------------- +-- GPU inputs (preallocate) +local inputs = torch.CudaTensor() +local labels = torch.CudaTensor() + +local timer = torch.Timer() +local dataTimer = torch.Timer() + +local parameters, gradParameters = model:getParameters() + +-- 4. trainBatch - Used by train() to train a single batch after the data is loaded. 
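
For context, the closure that `trainBatch` below passes to `optim.adam` or `optim.sgd` follows the standard Torch pattern: zero the gradient buffers, run the forward pass through the model and criterion, backpropagate, and return the loss together with the flattened gradient vector from `model:getParameters()`. A toy CPU-only version of that same pattern, using a single convolution and the `EPECriterion` from this diff (loaded here with a plain `dofile`, which assumes the repo root as the working directory), would be:

```lua
-- Toy illustration of the feval/optimizer pattern used by trainBatch.
-- Assumes it is run from the repository root so EPECriterion.lua is found.
require 'nn'
require 'optim'
dofile('EPECriterion.lua')

local model     = nn.SpatialConvolution(8, 2, 7, 7, 1, 1, 3, 3)
local criterion = nn.EPECriterion()
local params, gradParams = model:getParameters()
local optimState = { learningRate = 1e-4 }

local inputs = torch.randn(4, 8, 24, 32)   -- random stand-in for an 8-channel input batch
local labels = torch.randn(4, 2, 24, 32)   -- random stand-in for ground-truth flow

local feval = function(x)
   model:zeroGradParameters()
   local outputs = model:forward(inputs)
   local err = criterion:forward(outputs, labels)
   model:backward(inputs, criterion:backward(outputs, labels))
   return err, gradParams
end

optim.adam(feval, params, optimState)      -- one optimization step
```
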
+function trainBatch(inputsCPU, labelsCPU) + cutorch.synchronize() + collectgarbage() + local dataLoadingTime = dataTimer:time().real + timer:reset() + + -- transfer over to GPU + inputs:resize(inputsCPU:size()):copy(inputsCPU) + labels:resize(labelsCPU:size()):copy(labelsCPU) + + local err, outputs + feval = function(x) + model:zeroGradParameters() + outputs = model:forward(inputs) + err = criterion:forward(outputs, labels) + local gradOutputs = criterion:backward(outputs, labels) + model:backward(inputs, gradOutputs) + return err, gradParameters + end + + if opt.optimizer == 'adam' then + optim.adam(feval, parameters, optimState) + elseif opt.optimizer == 'sgd' then + optim.sgd(feval, parameters, optimState) + else + error("Specify Optimizer") + end + + -- DataParallelTable's syncParameters + if model.needsSync then + model:syncParameters() + end + + cutorch.synchronize() + batchNumber = batchNumber + 1 + loss_epoch = loss_epoch + err + + -- Calculate top-1 error, and print information + print(('Epoch: [%d][%d/%d]\tTime %.3f Err %.4f LR %.0e DataLoadingTime %.3f'):format( + epoch, batchNumber, opt.epochSize, timer:time().real, err, + optimState.learningRate, dataLoadingTime)) + + dataTimer:reset() +end diff --git a/util.lua b/util.lua new file mode 100755 index 0000000..d7249b5 --- /dev/null +++ b/util.lua @@ -0,0 +1,69 @@ +require 'cunn' +local ffi=require 'ffi' + +function makeDataParallel(model, nGPU) + if nGPU > 1 then + print('converting module to nn.DataParallelTable') + assert(nGPU <= cutorch.getDeviceCount(), 'number of GPUs less than nGPU specified') + local model_single = model + model = nn.DataParallelTable(1) + for i=1, nGPU do + cutorch.setDevice(i) + model:add(model_single:clone():cuda(), i) + end + end + cutorch.setDevice(opt.GPU) + + return model +end + +local function cleanDPT(module) + -- This assumes this DPT was created by the function above: all the + -- module.modules are clones of the same network on different GPUs + -- hence we only need to keep one when saving the model to the disk. + local newDPT = nn.DataParallelTable(1) + cutorch.setDevice(opt.GPU) + newDPT:add(module:get(1), opt.GPU) + return newDPT +end + +function saveDataParallel(filename, model) + if torch.type(model) == 'nn.DataParallelTable' then + torch.save(filename, cleanDPT(model)) + elseif torch.type(model) == 'nn.Sequential' then + local temp_model = nn.Sequential() + for i, module in ipairs(model.modules) do + if torch.type(module) == 'nn.DataParallelTable' then + temp_model:add(cleanDPT(module)) + else + temp_model:add(module) + end + end + torch.save(filename, temp_model) + elseif torch.type(model) == 'nn.gModule' then + torch.save(filename, model) + else + error('This saving function only works with Sequential or DataParallelTable modules.') + end +end + +function loadDataParallel(filename, nGPU) + if opt.backend == 'cudnn' then + require 'cudnn' + end + local model = torch.load(filename) + if torch.type(model) == 'nn.DataParallelTable' then + return makeDataParallel(model:get(1), nGPU) + elseif torch.type(model) == 'nn.Sequential' then + for i,module in ipairs(model.modules) do + if torch.type(module) == 'nn.DataParallelTable' then + model.modules[i] = makeDataParallel(module:get(1):float(), nGPU) + end + end + return model + elseif torch.type(model) == 'nn.gModule' then + return model + else + error('The loaded model is not a Sequential or DataParallelTable module.') + end +end
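
The helpers in `util.lua` are designed as a round trip: `makeDataParallel` wraps a single-GPU network in an `nn.DataParallelTable` when `nGPU > 1`, `saveDataParallel` keeps only one replica when writing a checkpoint, and `loadDataParallel` re-wraps it on load. A hedged sketch of that round trip follows; it assumes a CUDA machine and that the global `opt`, `createModel`, and the `util.lua` functions are already in scope (as they are after the training scripts' setup), and the checkpoint name is made up for illustration.

```lua
-- Round-trip sketch for the util.lua helpers; 'model_demo.t7' is a made-up name.
require 'cunn'

local model = createModel(opt.nGPU)   -- nn.DataParallelTable when opt.nGPU > 1
saveDataParallel(paths.concat(opt.save, 'model_demo.t7'), model)
local reloaded = loadDataParallel(paths.concat(opt.save, 'model_demo.t7'), opt.nGPU)
print(torch.type(reloaded))           -- DataParallelTable or Sequential, matching nGPU
```
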