forked from anuragranj/spynet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request anuragranj#5 from anuragranj/train
merging train branch
- Loading branch information
Showing
14 changed files
with
1,288 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
checkpoint* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
|
||
-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. | ||
-- All rights reserved. | ||
-- This software is provided for research purposes only. | ||
-- By using this software you agree to the terms of the license file | ||
-- in the root folder. | ||
-- For commercial use, please contact [email protected]. | ||
|
||
local EPECriterion, parent = torch.class('nn.EPECriterion', 'nn.Criterion') | ||
|
||
-- Computes average endpoint error for batchSize x ChannelSize x Height x Width | ||
-- flow fields or general multidimensional matrices. | ||
|
||
local eps = 1e-12 | ||
|
||
function EPECriterion:__init() | ||
parent.__init(self) | ||
self.sizeAverage = true | ||
end | ||
|
||
function EPECriterion:updateOutput(input, target) | ||
assert( input:nElement() == target:nElement(), | ||
"input and target size mismatch") | ||
|
||
self.buffer = self.buffer or input.new() | ||
|
||
local buffer = self.buffer | ||
local output | ||
local npixels | ||
|
||
buffer:resizeAs(input) | ||
npixels = input:nElement()/2 -- 2 channel flow fields | ||
|
||
buffer:add(input, -1, target):pow(2) | ||
output = torch.sum(buffer,2):sqrt() -- second channel is flow | ||
output = output:sum() | ||
|
||
output = output / npixels | ||
|
||
self.output = output | ||
|
||
return self.output | ||
end | ||
|
||
function EPECriterion:updateGradInput(input, target) | ||
|
||
assert( input:nElement() == target:nElement(), | ||
"input and target size mismatch") | ||
|
||
self.buffer = self.buffer or input.new() | ||
|
||
local buffer = self.buffer | ||
local gradInput = self.gradInput | ||
local npixels | ||
local loss | ||
|
||
buffer:resizeAs(input) | ||
npixels = input:nElement()/2 | ||
|
||
buffer:add(input, -1, target):pow(2) | ||
loss = torch.sum(buffer,2):sqrt():add(eps) -- forms the denominator | ||
loss = torch.cat(loss, loss, 2) -- Repeat tensor to scale the gradients | ||
|
||
gradInput:resizeAs(input) | ||
gradInput:add(input, -1, target):cdiv(loss) | ||
gradInput = gradInput / npixels | ||
return gradInput | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. | ||
-- All rights reserved. | ||
-- This software is provided for research purposes only. | ||
-- By using this software you agree to the terms of the license file | ||
-- in the root folder. | ||
-- For commercial use, please contact [email protected]. | ||
-- | ||
-- Copyright (c) 2014, Facebook, Inc. | ||
-- All rights reserved. | ||
-- | ||
-- This source code is licensed under the BSD-style license found in the | ||
-- LICENSE file in the root directory of this source tree. An additional grant | ||
-- of patent rights can be found in the PATENTS file in the same directory. | ||
-- | ||
local ffi = require 'ffi' | ||
local Threads = require 'threads' | ||
Threads.serialization('threads.sharedserialize') | ||
|
||
-- This script contains the logic to create K threads for parallel data-loading. | ||
-- For the data-loading details, look at donkey.lua | ||
------------------------------------------------------------------------------- | ||
do -- start K datathreads (donkeys) | ||
if opt.nDonkeys > 0 then | ||
local options = opt -- make an upvalue to serialize over to donkey threads | ||
donkeys = Threads( | ||
opt.nDonkeys, | ||
function() | ||
require 'torch' | ||
end, | ||
function(idx) | ||
opt = options -- pass to all donkeys via upvalue | ||
tid = idx | ||
local seed = opt.manualSeed + idx | ||
torch.manualSeed(seed) | ||
print(string.format('Starting donkey with id: %d seed: %d', tid, seed)) | ||
paths.dofile('donkey.lua') | ||
end | ||
); | ||
else -- single threaded data loading. useful for debugging | ||
paths.dofile('donkey.lua') | ||
donkeys = {} | ||
function donkeys:addjob(f1, f2) f2(f1()) end | ||
function donkeys:synchronize() end | ||
end | ||
end | ||
|
||
nTest = 0 | ||
donkeys:addjob(function() return testLoader:size() end, function(c) nTest = c end) | ||
donkeys:synchronize() | ||
assert(nTest > 0, "Failed to get nTest") | ||
print('nTest: ', nTest) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
|
||
-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft. | ||
-- All rights reserved. | ||
-- This software is provided for research purposes only. | ||
-- By using this software you agree to the terms of the license file | ||
-- in the root folder. | ||
-- For commercial use, please contact [email protected]. | ||
|
||
require 'torch' | ||
torch.setdefaulttensortype('torch.FloatTensor') | ||
local ffi = require 'ffi' | ||
local class = require('pl.class') | ||
local dir = require 'pl.dir' | ||
local tablex = require 'pl.tablex' | ||
local argcheck = require 'argcheck' | ||
require 'sys' | ||
require 'xlua' | ||
require 'image' | ||
|
||
local dataset = torch.class('dataLoader') | ||
|
||
local initcheck = argcheck{ | ||
pack=true, | ||
help=[[ | ||
A dataset class for images in a flat folder structure (folder-name is class-name). | ||
Optimized for extremely large datasets (upwards of 14 million images). | ||
Tested only on Linux (as it uses command-line linux utilities to scale up) | ||
]], | ||
{name="inputSize", | ||
type="table", | ||
help="the size of the input images"}, | ||
|
||
{name="outputSize", | ||
type="table", | ||
help="the size of the network output"}, | ||
|
||
{name="split", | ||
type="number", | ||
help="Percentage of split to go to Training" | ||
}, | ||
|
||
{name="samplingMode", | ||
type="string", | ||
help="Sampling mode: random | balanced ", | ||
default = "balanced"}, | ||
|
||
{name="verbose", | ||
type="boolean", | ||
help="Verbose mode during initialization", | ||
default = false}, | ||
|
||
{name="loadSize", | ||
type="table", | ||
help="a size to load the images to, initially", | ||
opt = true}, | ||
|
||
{name="samplingIds", | ||
type="torch.LongTensor", | ||
help="the ids of training or testing images", | ||
opt = true}, | ||
|
||
{name="sampleHookTrain", | ||
type="function", | ||
help="applied to sample during training(ex: for lighting jitter). " | ||
.. "It takes the image path as input", | ||
opt = true}, | ||
|
||
{name="sampleHookTest", | ||
type="function", | ||
help="applied to sample during testing", | ||
opt = true}, | ||
} | ||
|
||
function dataset:__init(...) | ||
|
||
-- argcheck | ||
local args = initcheck(...) | ||
print(args) | ||
for k,v in pairs(args) do self[k] = v end | ||
|
||
if not self.loadSize then self.loadSize = self.inputSize; end | ||
|
||
if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end | ||
if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end | ||
|
||
local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end | ||
|
||
self.numSamples = self.samplingIds:size()[1] | ||
assert(self.numSamples > 0, "Could not find any sample in the given input paths") | ||
|
||
if self.verbose then print(self.numSamples .. ' samples found.') end | ||
end | ||
|
||
function dataset:size(class, list) | ||
return self.numSamples | ||
end | ||
|
||
-- converts a table of samples (and corresponding labels) to a clean tensor | ||
local function tableToOutput(self, imgTable, outputTable) | ||
local images, outputs | ||
local quantity = #imgTable | ||
assert(imgTable[1]:size()[1] == self.inputSize[1]) | ||
assert(outputTable[1]:size()[1] == self.outputSize[1]) | ||
|
||
images = torch.Tensor(quantity, | ||
self.inputSize[1], self.inputSize[2], self.inputSize[3]) | ||
outputs = torch.Tensor(quantity, | ||
self.outputSize[1], self.outputSize[2], self.outputSize[3]) | ||
|
||
for i=1,quantity do | ||
images[i]:copy(imgTable[i]) | ||
outputs[i]:copy(outputTable[i]) | ||
end | ||
return images, outputs | ||
end | ||
|
||
-- sampler, samples from the training set. | ||
function dataset:sample(quantity) | ||
assert(quantity) | ||
local imgTable = {} | ||
local outputTable = {} | ||
for i=1,quantity do | ||
local id = torch.random(1, self.numSamples) | ||
local img, output = self:sampleHookTrain(self.samplingIds[id][1]) -- single element[not tensor] from a row | ||
|
||
table.insert(imgTable, img) | ||
table.insert(outputTable, output) | ||
end | ||
local images, outputs = tableToOutput(self, imgTable, outputTable) | ||
return images, outputs | ||
end | ||
|
||
function dataset:get(i1, i2) | ||
local indices = self.samplingIds[{{i1, i2}}]; | ||
local quantity = i2 - i1 + 1; | ||
assert(quantity > 0) | ||
local imgTable = {} | ||
local outputTable = {} | ||
for i=1,quantity do | ||
local img, output = self:sampleHookTest(indices[i][1]) | ||
table.insert(imgTable, img) | ||
table.insert(outputTable, output) | ||
end | ||
local images, outputs = tableToOutput(self, imgTable, outputTable) | ||
return images, outputs | ||
end | ||
|
||
return dataset |
Oops, something went wrong.