Commit
train alpha version
Anurag Ranjan authored and Anurag Ranjan committed Jan 6, 2017
1 parent 0c6cb69 commit 0d3ccda
Showing 11 changed files with 1,246 additions and 0 deletions.
68 changes: 68 additions & 0 deletions EPECriterion.lua
@@ -0,0 +1,68 @@

-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
-- All rights reserved.
-- This software is provided for research purposes only.
-- By using this software you agree to the terms of the license file
-- in the root folder.
-- For commercial use, please contact [email protected].

local EPECriterion, parent = torch.class('nn.EPECriterion', 'nn.Criterion')

-- Computes average endpoint error for batchSize x ChannelSize x Height x Width
-- flow fields or general multidimensional matrices.
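-- For predicted flow (u, v) and ground truth (u', v'), the mean EPE over N pixels is
--   EPE = (1/N) * sum_p sqrt((u_p - u'_p)^2 + (v_p - v'_p)^2)
-- where N = nElement()/2 for a 2-channel flow field.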

local eps = 1e-12

function EPECriterion:__init()
parent.__init(self)
self.sizeAverage = true
end

function EPECriterion:updateOutput(input, target)
assert( input:nElement() == target:nElement(),
"input and target size mismatch")

self.buffer = self.buffer or input.new()

local buffer = self.buffer
local output
local npixels

buffer:resizeAs(input)
npixels = input:nElement()/2 -- 2 channel flow fields

buffer:add(input, -1, target):pow(2) -- (input - target)^2, elementwise
output = torch.sum(buffer,2):sqrt() -- L2 norm over the flow channels (dim 2)
output = output:sum()

output = output / npixels

self.output = output

return self.output
end

function EPECriterion:updateGradInput(input, target)

assert( input:nElement() == target:nElement(),
"input and target size mismatch")

self.buffer = self.buffer or input.new()

local buffer = self.buffer
local gradInput = self.gradInput
local npixels
local loss

buffer:resizeAs(input)
npixels = input:nElement()/2
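-- Gradient of the mean EPE w.r.t. the input:
--   d(EPE)/d(input) = (input - target) / (npixels * (||input - target|| + eps))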

buffer:add(input, -1, target):pow(2)
loss = torch.sum(buffer,2):sqrt():add(eps) -- per-pixel flow magnitude; eps avoids division by zero
loss = torch.cat(loss, loss, 2) -- replicate the norm across both flow channels for elementwise division

gradInput:resizeAs(input)
gradInput:add(input, -1, target):cdiv(loss) -- (input - target) / (norm + eps)
gradInput:div(npixels) -- scale in place so self.gradInput holds the final gradient
return self.gradInput
end
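
A minimal usage sketch (not part of this commit) of the criterion on a random 4x2x8x8 flow batch; it assumes the torch, nn, and paths packages are available:

require 'nn'
require 'paths'
paths.dofile('EPECriterion.lua')

local crit   = nn.EPECriterion()
local pred   = torch.rand(4, 2, 8, 8)  -- batch x (u, v) x height x width
local target = torch.rand(4, 2, 8, 8)
local loss = crit:forward(pred, target)  -- average endpoint error (a number)
local grad = crit:backward(pred, target) -- gradient, same shape as pred
print(loss, grad:size())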
51 changes: 51 additions & 0 deletions data.lua
@@ -0,0 +1,51 @@
-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
-- All rights reserved.
-- This software is provided for research purposes only.
-- By using this software you agree to the terms of the license file
-- in the root folder.
-- For commercial use, please contact [email protected].
--
-- Copyright (c) 2014, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
local ffi = require 'ffi'
local Threads = require 'threads'
Threads.serialization('threads.sharedserialize')

-- This script contains the logic to create K threads for parallel data-loading.
-- For the data-loading details, look at donkey.lua
-------------------------------------------------------------------------------
do -- start K datathreads (donkeys)
if opt.nDonkeys > 0 then
local options = opt -- make an upvalue to serialize over to donkey threads
donkeys = Threads(
opt.nDonkeys,
function()
require 'torch'
end,
function(idx)
opt = options -- pass to all donkeys via upvalue
tid = idx
local seed = opt.manualSeed + idx
torch.manualSeed(seed)
print(string.format('Starting donkey with id: %d seed: %d', tid, seed))
paths.dofile('donkey.lua')
end
);
else -- single-threaded data loading, useful for debugging
paths.dofile('donkey.lua')
donkeys = {}
function donkeys:addjob(f1, f2) f2(f1()) end
function donkeys:synchronize() end
end
end

nTest = 0
donkeys:addjob(function() return testLoader:size() end, function(c) nTest = c end)
donkeys:synchronize()
assert(nTest > 0, "Failed to get nTest")
print('nTest: ', nTest)
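
For context, a sketch of how this addjob/synchronize pattern is used for training; nBatches, trainLoader, opt.batchSize, and train are placeholders here, not defined in this excerpt:

for i = 1, nBatches do
   donkeys:addjob(
      -- first function runs on a donkey thread: load one mini-batch
      function() return trainLoader:sample(opt.batchSize) end,
      -- second function runs on the main thread with the returned values
      function(inputs, targets) train(inputs, targets) end
   )
end
donkeys:synchronize() -- block until every queued job has finished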
148 changes: 148 additions & 0 deletions dataset.lua
@@ -0,0 +1,148 @@

-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
-- All rights reserved.
-- This software is provided for research purposes only.
-- By using this software you agree to the terms of the license file
-- in the root folder.
-- For commercial use, please contact [email protected].

require 'torch'
torch.setdefaulttensortype('torch.FloatTensor')
local ffi = require 'ffi'
local class = require('pl.class')
local dir = require 'pl.dir'
local tablex = require 'pl.tablex'
local argcheck = require 'argcheck'
require 'sys'
require 'xlua'
require 'image'

local dataset = torch.class('dataLoader')

local initcheck = argcheck{
pack=true,
help=[[
A dataset class for loading (input, target) sample pairs by numeric id.
Optimized for extremely large datasets (upwards of 14 million images).
Tested only on Linux (as it uses command-line linux utilities to scale up)
]],
{name="inputSize",
type="table",
help="the size of the input images"},

{name="outputSize",
type="table",
help="the size of the network output"},

{name="split",
type="number",
help="Percentage of split to go to Training"
},

{name="samplingMode",
type="string",
help="Sampling mode: random | balanced ",
default = "balanced"},

{name="verbose",
type="boolean",
help="Verbose mode during initialization",
default = false},

{name="loadSize",
type="table",
help="a size to load the images to, initially",
opt = true},

{name="samplingIds",
type="torch.LongTensor",
help="the ids of training or testing images",
opt = true},

{name="sampleHookTrain",
type="function",
help="applied to sample during training(ex: for lighting jitter). "
.. "It takes the image path as input",
opt = true},

{name="sampleHookTest",
type="function",
help="applied to sample during testing",
opt = true},
}

function dataset:__init(...)

-- argcheck
local args = initcheck(...)
print(args)
for k,v in pairs(args) do self[k] = v end

if not self.loadSize then self.loadSize = self.inputSize; end

if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end
if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end

local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end

self.numSamples = self.samplingIds:size()[1]
assert(self.numSamples > 0, "Could not find any sample in the given input paths")

if self.verbose then print(self.numSamples .. ' samples found.') end
end

function dataset:size(class, list)
return self.numSamples
end

-- converts a table of samples (and corresponding labels) to a clean tensor
local function tableToOutput(self, imgTable, outputTable)
local images, outputs
local quantity = #imgTable
assert(imgTable[1]:size()[1] == self.inputSize[1])
assert(outputTable[1]:size()[1] == self.outputSize[1])

images = torch.Tensor(quantity,
self.inputSize[1], self.inputSize[2], self.inputSize[3])
outputs = torch.Tensor(quantity,
self.outputSize[1], self.outputSize[2], self.outputSize[3])

for i=1,quantity do
images[i]:copy(imgTable[i])
outputs[i]:copy(outputTable[i])
end
return images, outputs
end

-- sampler, samples from the training set.
function dataset:sample(quantity)
assert(quantity)
local imgTable = {}
local outputTable = {}
for i=1,quantity do
local id = torch.random(1, self.numSamples)
local img, output = self:sampleHookTrain(self.samplingIds[id][1]) -- pass a scalar id (a Lua number, not a tensor) taken from the ids column

table.insert(imgTable, img)
table.insert(outputTable, output)
end
local images, outputs = tableToOutput(self, imgTable, outputTable)
return images, outputs
end

function dataset:get(i1, i2)
local indices = self.samplingIds[{{i1, i2}}];
local quantity = i2 - i1 + 1;
assert(quantity > 0)
local imgTable = {}
local outputTable = {}
for i=1,quantity do
local img, output = self:sampleHookTest(indices[i][1])
table.insert(imgTable, img)
table.insert(outputTable, output)
end
local images, outputs = tableToOutput(self, imgTable, outputTable)
return images, outputs
end

return dataset
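
A construction sketch (not part of this commit); the sizes, ids, and sample hook below are illustrative assumptions:

require 'paths'
paths.dofile('dataset.lua')

local loader = dataLoader{
   inputSize   = {6, 384, 512},  -- e.g. two stacked RGB frames
   outputSize  = {2, 384, 512},  -- u, v flow channels
   split       = 100,
   samplingIds = torch.range(1, 1000):long():view(-1, 1), -- one id per row
   sampleHookTrain = function(self, id)
      -- hypothetical hook: load and return (input, flow) tensors for this id
      return torch.rand(6, 384, 512), torch.rand(2, 384, 512)
   end,
}
local images, flows = loader:sample(8) -- 8x6x384x512 and 8x2x384x512
print(images:size(), flows:size())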