Skip to content

Commit

Permalink
add all supporting files
Browse files Browse the repository at this point in the history
  • Loading branch information
gaohuang committed Aug 23, 2017
1 parent 1846d73 commit 4b5cc63
Show file tree
Hide file tree
Showing 16 changed files with 1,159 additions and 4 deletions.
69 changes: 69 additions & 0 deletions checkpoints.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
local checkpoint = {}

-- Recursively clones `tbl`: nested tables are cloned in turn, while every
-- non-table value is shared by reference (per the original note, this is what
-- keeps the tensors shared between the network and its copy). When `tbl`
-- carries a torch type name, the clone is given that same torch metatable.
local function deepCopy(tbl)
   local clone = {}
   for key, value in pairs(tbl) do
      -- deepCopy always returns a table (truthy), so and/or is safe here
      clone[key] = (type(value) == 'table') and deepCopy(value) or value
   end

   local typeName = torch.typename(tbl)
   if typeName then
      torch.setmetatable(clone, typeName)
   end
   return clone
end

-- Loads the most recent checkpoint from the directory `opt.resume`.
-- Returns nil when `opt.resume` is 'none' or when that directory holds no
-- `latest.t7`. Otherwise returns the table stored in `latest.t7`, followed by
-- the optimizer state loaded from the file named by its `optimFile` field.
function checkpoint.latest(opt)
   local resumeDir = opt.resume
   if resumeDir == 'none' then
      return nil
   end

   local latestPath = paths.concat(resumeDir, 'latest.t7')
   if paths.filep(latestPath) then
      print('=> Loading checkpoint ' .. latestPath)
      local latest = torch.load(latestPath)
      local optimPath = paths.concat(resumeDir, latest.optimFile)
      local optimState = torch.load(optimPath)
      return latest, optimState
   end

   return nil
end

-- Writes the model, the optimizer state and a `latest.t7` index for `epoch`
-- into the directory `opt.save`. When `isBestModel` is truthy the model is
-- additionally written to `model_best.t7`.
function checkpoint.save(epoch, model, optimState, isBestModel, opt)
   -- unwrap the DataParallelTable so the file loads easily on other machines
   local isParallel = torch.type(model) == 'nn.DataParallelTable'
   if isParallel then
      model = model:get(1)
   end

   -- take a clean float copy so the original network is not modified
   local snapshot = deepCopy(model)
   snapshot = snapshot:float()
   snapshot = snapshot:clearState()

   local saveDir = opt.save
   local modelFile = 'model_' .. epoch .. '.t7'
   local optimFile = 'optimState_' .. epoch .. '.t7'
   local index = {
      epoch = epoch,
      modelFile = modelFile,
      optimFile = optimFile,
   }

   torch.save(paths.concat(saveDir, modelFile), snapshot)
   torch.save(paths.concat(saveDir, optimFile), optimState)
   torch.save(paths.concat(saveDir, 'latest.t7'), index)

   if isBestModel then
      torch.save(paths.concat(saveDir, 'model_best.t7'), snapshot)
   end
end

return checkpoint
127 changes: 127 additions & 0 deletions dataloader.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- Multi-threaded data loader
--

local datasets = require 'datasets/init'
local Threads = require 'threads'
Threads.serialization('threads.sharedserialize')

local M = {}
local DataLoader = torch.class('resnet.DataLoader', M)

-- Builds one DataLoader per split and returns them as two values,
-- in the order (train, val).
function DataLoader.create(opt)
   local splits = {'train', 'val'}
   local loaders = {}

   for i = 1, #splits do
      local split = splits[i]
      loaders[i] = M.DataLoader(datasets.create(opt, split), opt, split)
   end

   return table.unpack(loaders)
end

-- Creates the worker pool for one split and records the batching parameters.
--   dataset: object providing preprocess() and size(); both are called from
--            the worker setup function below
--   opt:     options table; this constructor reads manualSeed, dataset,
--            nThreads, tenCrop, batchSize and tensorType
--   split:   split name; ten-crop batching is only enabled for 'val'
-- Raises an error when opt.batchSize is smaller than the number of crops per
-- image, because the per-job batch size would otherwise be zero.
function DataLoader:__init(dataset, opt, split)
   -- With ten-crop evaluation every image expands to 10 inputs, so fewer
   -- images fit into one batch of opt.batchSize inputs.
   local nCrops = (split == 'val' and opt.tenCrop) and 10 or 1
   local batchSize = math.floor(opt.batchSize / nCrops)
   -- A zero batch size would make size() divide by zero and run() request
   -- zero-length slices; fail early (before any thread is started) instead.
   assert(batchSize >= 1, string.format(
      'opt.batchSize (%s) must be at least the number of crops per image (%d)',
      tostring(opt.batchSize), nCrops))

   local manualSeed = opt.manualSeed
   -- first pool callback: load the dataset module inside the worker
   local function init()
      require('datasets/' .. opt.dataset)
   end
   -- second pool callback: seed the worker, publish the dataset and its
   -- preprocessing function through _G, and report the dataset size
   local function main(idx)
      if manualSeed ~= 0 then
         -- offset by the callback argument so each worker gets its own seed
         torch.manualSeed(manualSeed + idx)
      end
      torch.setnumthreads(1)
      _G.dataset = dataset
      _G.preprocess = dataset:preprocess()
      return dataset:size()
   end

   local threads, sizes = Threads(opt.nThreads, init, main)
   self.nCrops = nCrops
   self.threads = threads
   -- presumably sizes[k] holds the return values of main for worker k;
   -- the first worker's reported dataset size is used
   self.__size = sizes[1][1]
   self.batchSize = batchSize

   -- Maps the configured tensor type name to the constructor name used for
   -- the batch buffer in run(); anything unrecognised falls back to float.
   local function getCPUType(tensorType)
      if tensorType == 'torch.CudaHalfTensor' then
         return 'HalfTensor'
      elseif tensorType == 'torch.CudaDoubleTensor' then
         return 'DoubleTensor'
      else
         return 'FloatTensor'
      end
   end
   self.cpuType = getCPUType(opt.tensorType)
end

-- Number of batches in one pass: the dataset size divided by the batch
-- size, rounded up.
function DataLoader:size()
   local nBatches = self.__size / self.batchSize
   return math.ceil(nBatches)
end

-- Returns an iterator over one randomly permuted pass of the dataset.
-- Each call of the iterator returns (batch counter, sample), where sample is
-- the { input = ..., target = ... } table produced by a worker job, and
-- returns nil once the pool has no jobs left.
function DataLoader:run()
   local pool = self.threads
   local total, batchSize = self.__size, self.batchSize
   local order = torch.randperm(total)

   -- cursor: position in `order` of the next batch to schedule
   -- latest: result delivered by the most recently completed job
   local cursor, latest = 1, nil

   -- Schedules jobs until the pool stops accepting them or the pass is done.
   local function enqueue()
      while cursor <= total and pool:acceptsjob() do
         local count = math.min(batchSize, total - cursor + 1)
         local indices = order:narrow(1, cursor, count)
         pool:addjob(
            -- job body: receives everything it needs as arguments or via _G
            function(indices, nCrops, cpuType)
               local sz = indices:size(1)
               local batch, imageSize
               local target = torch.IntTensor(sz)
               for i, sampleIdx in ipairs(indices:totable()) do
                  local item = _G.dataset:get(sampleIdx)
                  local input = _G.preprocess(item.input)
                  if not batch then
                     -- allocate the buffer lazily, once the first input's size is known
                     imageSize = input:size():totable()
                     if nCrops > 1 then table.remove(imageSize, 1) end
                     batch = torch[cpuType](sz, nCrops, table.unpack(imageSize))
                  end
                  batch[i]:copy(input)
                  target[i] = item.target
               end
               collectgarbage()
               return {
                  input = batch:view(sz * nCrops, table.unpack(imageSize)),
                  target = target,
               }
            end,
            -- end callback: stash the job's result for the iterator to return
            function(result)
               latest = result
            end,
            indices,
            self.nCrops,
            self.cpuType
         )
         cursor = cursor + batchSize
      end
   end

   local served = 0
   return function()
      enqueue()
      if pool:hasjob() then
         pool:dojob()
         if pool:haserror() then
            pool:synchronize()
         end
         enqueue()
         served = served + 1
         return served, latest
      end
      return nil
   end
end

return M.DataLoader
66 changes: 66 additions & 0 deletions datasets/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
## Datasets

Each dataset consists of two files: `dataset-gen.lua` and `dataset.lua`. The `dataset-gen.lua` file is responsible for one-time setup, while
the `dataset.lua` file handles the actual data loading.

If you want to be able to use the new dataset from main.lua, you should also modify `opts.lua` to handle the new dataset name.

### `dataset-gen.lua`

The `dataset-gen.lua` performs any necessary one-time setup. For example, the [`cifar10-gen.lua`](cifar10-gen.lua) file downloads the CIFAR-10 dataset, and the [`imagenet-gen.lua`](imagenet-gen.lua) file indexes all the training and validation data.

The module should have a single function `exec(opt, cacheFile)`.
- `opt`: the command line options
- `cacheFile`: path to output

```lua
local M = {}
function M.exec(opt, cacheFile)
local imageInfo = {}
-- preprocess dataset, store results in imageInfo, save to cacheFile
torch.save(cacheFile, imageInfo)
end
return M
```

### `dataset.lua`

The `dataset.lua` should return a class that implements three functions:
- `get(i)`: returns a table containing two entries, `input` and `target`
- `input`: the training or validation image as a Torch tensor
- `target`: the image category as a number 1-N
- `size()`: returns the number of entries in the dataset
- `preprocess()`: returns a function that transforms the `input` for data augmentation or input normalization

```lua
local t = require 'datasets/transforms'

local M = {}
local FakeDataset = torch.class('resnet.FakeDataset', M)

function FakeDataset:__init(imageInfo, opt, split)
-- imageInfo: result from dataset-gen.lua
-- opt: command-line arguments
-- split: "train" or "val"
end

function FakeDataset:get(i)
return {
input = torch.Tensor(3, 800, 600):uniform(),
target = 42,
}
end

function FakeDataset:size()
-- size of dataset
return 2000
end

function FakeDataset:preprocess()
-- Scale smaller side to 256 and take 224x224 center-crop
return t.Compose{
t.Scale(256),
t.CenterCrop(224),
}
end

return M.FakeDataset
```
67 changes: 67 additions & 0 deletions datasets/cifar10-gen.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- Script to download the CIFAR-10 dataset and combine it into a single file
--
-- This automatically downloads the CIFAR-10 dataset from
-- http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz
--

local URL = 'http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz'

local M = {}

-- Loads every ASCII-serialized batch listed in `files` and concatenates them.
-- Returns a table with `data` (viewed as N x 3 x 32 x 32) and `labels`
-- (shifted up by one).
local function convertToTensor(files)
   local data, labels

   for i = 1, #files do
      local batch = torch.load(files[i], 'ascii')
      local batchData = batch.data:t()
      local batchLabels = batch.labels:squeeze()
      if data == nil then
         data, labels = batchData, batchLabels
      else
         data = torch.cat(data, batchData, 1)
         labels = torch.cat(labels, batchLabels)
      end
   end

   -- This is *very* important. The downloaded files have labels 0-9, which do
   -- not work with CrossEntropyCriterion
   labels:add(1)

   local images = data:contiguous():view(-1, 3, 32, 32)
   return {
      data = images,
      labels = labels,
   }
end

-- One-time setup for CIFAR-10: downloads and unpacks the archive into gen/,
-- merges the training batches and the test batch into tensors, and saves
-- both to `cacheFile` under the keys `train` and `val`.
function M.exec(opt, cacheFile)
   print("=> Downloading CIFAR-10 dataset from " .. URL)
   local command = 'curl ' .. URL .. ' | tar xz -C gen/'
   local status = os.execute(command)
   -- os.execute reports success as 0 or as true depending on the Lua version
   assert(status == true or status == 0, 'error downloading CIFAR-10')

   print(" | combining dataset into a single file")
   local trainFiles = {
      'gen/cifar-10-batches-t7/data_batch_1.t7',
      'gen/cifar-10-batches-t7/data_batch_2.t7',
      'gen/cifar-10-batches-t7/data_batch_3.t7',
      'gen/cifar-10-batches-t7/data_batch_4.t7',
      'gen/cifar-10-batches-t7/data_batch_5.t7',
   }
   local testFiles = {
      'gen/cifar-10-batches-t7/test_batch.t7',
   }
   local trainData = convertToTensor(trainFiles)
   local testData = convertToTensor(testFiles)

   print(" | saving CIFAR-10 dataset to " .. cacheFile)
   local cache = {
      train = trainData,
      val = testData,
   }
   torch.save(cacheFile, cache)
end

return M
57 changes: 57 additions & 0 deletions datasets/cifar10.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- CIFAR-10 dataset loader
--

local t = require 'datasets/transforms'

local M = {}
local CifarDataset = torch.class('resnet.CifarDataset', M)

-- Keeps the portion of `imageInfo` that belongs to `split`, and the split
-- name itself. Fails (with the split name as the message) when `imageInfo`
-- has no entry for `split`.
function CifarDataset:__init(imageInfo, opt, split)
   local info = imageInfo[split]
   assert(info, split)
   self.split = split
   self.imageInfo = info
end

-- Returns entry `i` as { input = <image converted to float>, target = <label> }.
function CifarDataset:get(i)
   local info = self.imageInfo
   local sample = {}
   sample.input = info.data[i]:float()
   sample.target = info.labels[i]
   return sample
end

-- Number of entries in this split (length of the first dimension of `data`).
function CifarDataset:size()
   local data = self.imageInfo.data
   return data:size(1)
end

-- Computed from entire CIFAR-10 training set
-- Three values per statistic, passed to t.ColorNormalize in preprocess().
-- NOTE(review): presumably one value per colour channel on the raw 0-255
-- pixel scale -- confirm against datasets/transforms.
local meanstd = {
mean = {125.3, 123.0, 113.9},
std = {63.0, 62.1, 66.7},
}

-- Returns the transform for this split: colour normalisation only for 'val';
-- colour normalisation, horizontal flip and random crop for 'train'.
-- Any other split name raises an error.
function CifarDataset:preprocess()
   local split = self.split

   if split == 'val' then
      return t.ColorNormalize(meanstd)
   end

   if split ~= 'train' then
      error('invalid split: ' .. self.split)
   end

   return t.Compose{
      t.ColorNormalize(meanstd),
      t.HorizontalFlip(0.5),
      t.RandomCrop(32, 4),
   }
end

return M.CifarDataset
Loading

0 comments on commit 4b5cc63

Please sign in to comment.