forked from anuragranj/spynet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.lua
executable file
·148 lines (121 loc) · 4.21 KB
/
dataset.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
-- All rights reserved.
-- This software is provided for research purposes only.
-- By using this software you agree to the terms of the license file
-- in the root folder.
-- For commercial use, please contact [email protected].
require 'torch'
torch.setdefaulttensortype('torch.FloatTensor')
local ffi = require 'ffi'
local class = require('pl.class')
local dir = require 'pl.dir'
local tablex = require 'pl.tablex'
local argcheck = require 'argcheck'
require 'sys'
require 'xlua'
require 'image'
-- Register the Torch class 'dataLoader'; methods defined on `dataset` below
-- become methods of that class.
local dataset = torch.class('dataLoader')
-- Argument validator for dataset:__init. argcheck type-checks each field,
-- fills in `default` values, and (pack=true) hands the arguments back as a
-- single table, which __init copies onto the instance.
-- NOTE(review): the help text below describes a folder-scanning loader; this
-- version instead samples from `samplingIds` — confirm and update the text.
local initcheck = argcheck{
pack=true,
help=[[
A dataset class for images in a flat folder structure (folder-name is class-name).
Optimized for extremely large datasets (upwards of 14 million images).
Tested only on Linux (as it uses command-line linux utilities to scale up)
]],
-- Per-sample input size; tableToOutput reads entries 1..3 and checks entry 1
-- against each sample's first dimension.
{name="inputSize",
type="table",
help="the size of the input images"},
-- Per-sample output size; used by tableToOutput like inputSize.
{name="outputSize",
type="table",
help="the size of the network output"},
-- NOTE(review): not read anywhere in this file.
{name="split",
type="number",
help="Percentage of split to go to Training"
},
-- NOTE(review): not read in this file; sample() always draws ids uniformly
-- at random regardless of this setting.
{name="samplingMode",
type="string",
help="Sampling mode: random | balanced ",
default = "balanced"},
-- Enables informational prints in __init.
{name="verbose",
type="boolean",
help="Verbose mode during initialization",
default = false},
-- Defaults to inputSize in __init; not otherwise read in this file.
{name="loadSize",
type="table",
help="a size to load the images to, initially",
opt = true},
-- Required in practice despite opt=true: __init reads its size
-- unconditionally. Rows are accessed as samplingIds[i][1], so it is expected
-- to be 2-D with the id in the first column.
{name="samplingIds",
type="torch.LongTensor",
help="the ids of training or testing images",
opt = true},
-- Called by sample() as self:sampleHookTrain(id) and expected to return
-- (input, output) tensors.
-- NOTE(review): it receives an id from samplingIds, not an image path as the
-- help string says.
{name="sampleHookTrain",
type="function",
help="applied to sample during training(ex: for lighting jitter). "
.. "It takes the image path as input",
opt = true},
-- Called by get() as self:sampleHookTest(id); same contract as
-- sampleHookTrain.
{name="sampleHookTest",
type="function",
help="applied to sample during testing",
opt = true},
}
--- Construct a dataLoader from the arguments accepted by `initcheck`.
-- Every validated argument is copied onto the instance. `loadSize` falls back
-- to `inputSize`. A missing sample hook falls back to `self.defaultSampleHook`.
-- That method is not defined in this file; presumably a subclass or caller
-- provides it — TODO confirm.
-- @param ... arguments forwarded to `initcheck`
function dataset:__init(...)
  local args = initcheck(...)
  -- Dump the parsed arguments only in verbose mode (previously printed on
  -- every construction).
  if args.verbose then print(args) end
  for k, v in pairs(args) do self[k] = v end

  if not self.loadSize then self.loadSize = self.inputSize end
  if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end
  if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end

  -- samplingIds is declared optional in initcheck, but all sampling relies on
  -- it; fail here with a readable message instead of a nil-index error.
  assert(self.samplingIds ~= nil, "samplingIds must be provided")
  -- Rows are read as samplingIds[i][1], so the sample count is the size of
  -- dimension 1. An empty tensor has dim() == 0 and no dimension 1 to query.
  if self.samplingIds:dim() > 0 then
    self.numSamples = self.samplingIds:size(1)
  else
    self.numSamples = 0
  end
  assert(self.numSamples > 0, "Could not find any sample in the given input paths")
  if self.verbose then print(self.numSamples .. ' samples found.') end
end
--- Number of samples this loader can draw from (set in __init).
-- The `class` and `list` arguments are accepted but ignored.
-- @treturn number
function dataset:size(class, list)
  local count = self.numSamples
  return count
end
-- Converts parallel tables of per-sample inputs and outputs into two batched
-- tensors. The shapes are (quantity, inputSize[1..3]) and
-- (quantity, outputSize[1..3]).
-- @param self the dataLoader (supplies inputSize / outputSize)
-- @param imgTable sequence of input tensors, one per sample
-- @param outputTable sequence of output tensors, same length as imgTable
-- @return images, outputs
local function tableToOutput(self, imgTable, outputTable)
  local quantity = #imgTable
  assert(quantity > 0, "tableToOutput: no samples to batch")
  assert(#outputTable == quantity, "tableToOutput: image/output count mismatch")

  local inSize, outSize = self.inputSize, self.outputSize
  local images = torch.Tensor(quantity, inSize[1], inSize[2], inSize[3])
  local outputs = torch.Tensor(quantity, outSize[1], outSize[2], outSize[3])

  for i = 1, quantity do
    -- copy() only requires equal element counts, not equal shapes, so check
    -- the leading dimension of every sample rather than just the first.
    assert(imgTable[i]:size(1) == inSize[1],
      "tableToOutput: input sample " .. i .. " has wrong leading dimension")
    assert(outputTable[i]:size(1) == outSize[1],
      "tableToOutput: output sample " .. i .. " has wrong leading dimension")
    images[i]:copy(imgTable[i])
    outputs[i]:copy(outputTable[i])
  end
  return images, outputs
end
-- Sampler: draws `quantity` training samples, each id picked independently
-- and uniformly from 1..numSamples (with replacement), through sampleHookTrain.
-- NOTE(review): `samplingMode` is not consulted here.
-- @param quantity number of samples to draw (positive)
-- @return images, outputs batched tensors
function dataset:sample(quantity)
  local n = tonumber(quantity)
  assert(n and n >= 1, "sample: quantity must be a positive number")

  local imgTable, outputTable = {}, {}
  for i = 1, n do
    local id = torch.random(1, self.numSamples)
    -- samplingIds[id] is a row; [1] extracts its first element as a number.
    local img, output = self:sampleHookTrain(self.samplingIds[id][1])
    -- A nil here would otherwise be dropped silently and shrink the batch.
    assert(img ~= nil and output ~= nil, "sample: sampleHookTrain returned nil")
    imgTable[i] = img
    outputTable[i] = output
  end
  return tableToOutput(self, imgTable, outputTable)
end
--- Return samples i1..i2 (inclusive) of samplingIds, in order, using the
-- test hook.
-- @param i1 first row index (1-based)
-- @param i2 last row index (inclusive, <= numSamples)
-- @return images, outputs batched tensors
function dataset:get(i1, i2)
  local quantity = i2 - i1 + 1
  -- Validate before slicing. Otherwise the range-indexing call reports a bad
  -- range first, and the old post-slice assert never got a chance to run.
  assert(quantity > 0, "get: i2 must be >= i1")
  assert(i1 >= 1 and i2 <= self.numSamples, "get: index range out of bounds")

  local indices = self.samplingIds[{{i1, i2}}]
  local imgTable, outputTable = {}, {}
  for i = 1, quantity do
    local img, output = self:sampleHookTest(indices[i][1])
    table.insert(imgTable, img)
    table.insert(outputTable, output)
  end
  return tableToOutput(self, imgTable, outputTable)
end
-- Module value: the dataLoader class table whose methods are defined above.
return dataset