forked from liuzhuang13/DenseNet
Commit
Showing 16 changed files with 1,159 additions and 4 deletions.
@@ -0,0 +1,69 @@
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
local checkpoint = {}

local function deepCopy(tbl)
   -- creates a copy of a network with new modules and the same tensors
   local copy = {}
   for k, v in pairs(tbl) do
      if type(v) == 'table' then
         copy[k] = deepCopy(v)
      else
         copy[k] = v
      end
   end
   if torch.typename(tbl) then
      torch.setmetatable(copy, torch.typename(tbl))
   end
   return copy
end

function checkpoint.latest(opt)
   if opt.resume == 'none' then
      return nil
   end

   local latestPath = paths.concat(opt.resume, 'latest.t7')
   if not paths.filep(latestPath) then
      return nil
   end

   print('=> Loading checkpoint ' .. latestPath)
   local latest = torch.load(latestPath)
   local optimState = torch.load(paths.concat(opt.resume, latest.optimFile))

   return latest, optimState
end

function checkpoint.save(epoch, model, optimState, isBestModel, opt)
   -- don't save the DataParallelTable for easier loading on other machines
   if torch.type(model) == 'nn.DataParallelTable' then
      model = model:get(1)
   end

   -- create a clean copy on the CPU without modifying the original network
   model = deepCopy(model):float():clearState()

   local modelFile = 'model_' .. epoch .. '.t7'
   local optimFile = 'optimState_' .. epoch .. '.t7'

   torch.save(paths.concat(opt.save, modelFile), model)
   torch.save(paths.concat(opt.save, optimFile), optimState)
   torch.save(paths.concat(opt.save, 'latest.t7'), {
      epoch = epoch,
      modelFile = modelFile,
      optimFile = optimFile,
   })

   if isBestModel then
      torch.save(paths.concat(opt.save, 'model_best.t7'), model)
   end
end

return checkpoint
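
For context, a minimal sketch of how a training script might consume this module, assuming the usual fb.resnet.torch layout; the module path and the `opt` fields below are assumptions, not part of this commit:

```lua
-- Hypothetical usage from a training script; names are assumptions.
local checkpoints = require 'checkpoints'

-- Returns nil when opt.resume == 'none' or when no latest.t7 exists yet.
local latest, optimState = checkpoints.latest(opt)
local startEpoch = latest and latest.epoch + 1 or 1

for epoch = startEpoch, opt.nEpochs do
   -- ... train and evaluate one epoch, producing model, optimState, isBest ...
   checkpoints.save(epoch, model, optimState, isBest, opt)
end
```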
@@ -0,0 +1,127 @@
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- Multi-threaded data loader
--

local datasets = require 'datasets/init'
local Threads = require 'threads'
Threads.serialization('threads.sharedserialize')

local M = {}
local DataLoader = torch.class('resnet.DataLoader', M)

function DataLoader.create(opt)
   -- The train and val loader
   local loaders = {}

   for i, split in ipairs{'train', 'val'} do
      local dataset = datasets.create(opt, split)
      loaders[i] = M.DataLoader(dataset, opt, split)
   end

   return table.unpack(loaders)
end

function DataLoader:__init(dataset, opt, split)
   local manualSeed = opt.manualSeed
   local function init()
      require('datasets/' .. opt.dataset)
   end
   local function main(idx)
      if manualSeed ~= 0 then
         torch.manualSeed(manualSeed + idx)
      end
      torch.setnumthreads(1)
      _G.dataset = dataset
      _G.preprocess = dataset:preprocess()
      return dataset:size()
   end

   local threads, sizes = Threads(opt.nThreads, init, main)
   self.nCrops = (split == 'val' and opt.tenCrop) and 10 or 1
   self.threads = threads
   self.__size = sizes[1][1]
   self.batchSize = math.floor(opt.batchSize / self.nCrops)
   local function getCPUType(tensorType)
      if tensorType == 'torch.CudaHalfTensor' then
         return 'HalfTensor'
      elseif tensorType == 'torch.CudaDoubleTensor' then
         return 'DoubleTensor'
      else
         return 'FloatTensor'
      end
   end
   self.cpuType = getCPUType(opt.tensorType)
end

function DataLoader:size()
   return math.ceil(self.__size / self.batchSize)
end

function DataLoader:run()
   local threads = self.threads
   local size, batchSize = self.__size, self.batchSize
   local perm = torch.randperm(size)

   local idx, sample = 1, nil
   local function enqueue()
      while idx <= size and threads:acceptsjob() do
         local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1))
         threads:addjob(
            function(indices, nCrops, cpuType)
               local sz = indices:size(1)
               local batch, imageSize
               local target = torch.IntTensor(sz)
               for i, idx in ipairs(indices:totable()) do
                  local sample = _G.dataset:get(idx)
                  local input = _G.preprocess(sample.input)
                  if not batch then
                     imageSize = input:size():totable()
                     if nCrops > 1 then table.remove(imageSize, 1) end
                     batch = torch[cpuType](sz, nCrops, table.unpack(imageSize))
                  end
                  batch[i]:copy(input)
                  target[i] = sample.target
               end
               collectgarbage()
               return {
                  input = batch:view(sz * nCrops, table.unpack(imageSize)),
                  target = target,
               }
            end,
            function(_sample_)
               sample = _sample_
            end,
            indices,
            self.nCrops,
            self.cpuType
         )
         idx = idx + batchSize
      end
   end

   local n = 0
   local function loop()
      enqueue()
      if not threads:hasjob() then
         return nil
      end
      threads:dojob()
      if threads:haserror() then
         threads:synchronize()
      end
      enqueue()
      n = n + 1
      return n, sample
   end

   return loop
end

return M.DataLoader
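
As a usage note, `DataLoader:run()` returns an iterator suitable for Lua's generic `for`. A sketch of how a training loop might consume it; the module path and surrounding variables are assumptions, not part of this commit:

```lua
-- Hypothetical consumer of the loader; treat as a sketch.
local DataLoader = require 'dataloader'

-- create() returns one loader per split, in the order {'train', 'val'}.
local trainLoader, valLoader = DataLoader.create(opt)

for n, sample in trainLoader:run() do
   -- sample.input:  CPU tensor with batchSize * nCrops images in the first dimension
   -- sample.target: IntTensor of 1-based class indices
   -- ... copy to the GPU, run forward/backward, update parameters ...
end
```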
@@ -0,0 +1,66 @@
## Datasets

Each dataset consists of two files: `dataset-gen.lua` and `dataset.lua`. The `dataset-gen.lua` script is responsible for one-time setup, while `dataset.lua` handles the actual data loading.

If you want to be able to use the new dataset from `main.lua`, you should also modify `opts.lua` to handle the new dataset name.
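
For illustration, a minimal sketch of the kind of change this implies; the option name, defaults, and validation below are assumptions about `opts.lua`, not code from this commit:

```lua
-- Hypothetical excerpt from opts.lua (option name and defaults are assumed).
cmd:option('-dataset', 'imagenet', 'Options: imagenet | cifar10 | mynewdataset')

-- Wherever the parsed options are validated, accept the new name:
local validDatasets = { imagenet = true, cifar10 = true, mynewdataset = true }
assert(validDatasets[opt.dataset], 'unknown dataset: ' .. tostring(opt.dataset))
```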

### `dataset-gen.lua`

The `dataset-gen.lua` performs any necessary one-time setup. For example, the [`cifar10-gen.lua`](cifar10-gen.lua) file downloads the CIFAR-10 dataset, and the [`imagenet-gen.lua`](imagenet-gen.lua) file indexes all the training and validation data.

The module should have a single function `exec(opt, cacheFile)`.
- `opt`: the command line options
- `cacheFile`: path to output

```lua
local M = {}
function M.exec(opt, cacheFile)
   local imageInfo = {}
   -- preprocess dataset, store results in imageInfo, save to cacheFile
   torch.save(cacheFile, imageInfo)
end
return M
```

### `dataset.lua`

The `dataset.lua` should return a class that implements three functions:
- `get(i)`: returns a table containing two entries, `input` and `target`
  - `input`: the training or validation image as a Torch tensor
  - `target`: the image category as a number 1-N
- `size()`: returns the number of entries in the dataset
- `preprocess()`: returns a function that transforms the `input` for data augmentation or input normalization

```lua
local t = require 'datasets/transforms'

local M = {}
local FakeDataset = torch.class('resnet.FakeDataset', M)

function FakeDataset:__init(imageInfo, opt, split)
   -- imageInfo: result from dataset-gen.lua
   -- opt: command-line arguments
   -- split: "train" or "val"
end

function FakeDataset:get(i)
   return {
      input = torch.Tensor(3, 800, 600):uniform(),
      target = 42,
   }
end

function FakeDataset:size()
   -- size of dataset
   return 2000
end

function FakeDataset:preprocess()
   -- Scale smaller side to 256 and take 224x224 center-crop
   return t.Compose{
      t.Scale(256),
      t.CenterCrop(224),
   }
end

return M.FakeDataset
```
@@ -0,0 +1,67 @@
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- Script to download and preprocess the CIFAR-10 dataset
--
-- This automatically downloads the CIFAR-10 dataset from
-- http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz
--

local URL = 'http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz'

local M = {}

local function convertToTensor(files)
   local data, labels

   for _, file in ipairs(files) do
      local m = torch.load(file, 'ascii')
      if not data then
         data = m.data:t()
         labels = m.labels:squeeze()
      else
         data = torch.cat(data, m.data:t(), 1)
         labels = torch.cat(labels, m.labels:squeeze())
      end
   end

   -- This is *very* important. The downloaded files have labels 0-9, which do
   -- not work with CrossEntropyCriterion
   labels:add(1)

   return {
      data = data:contiguous():view(-1, 3, 32, 32),
      labels = labels,
   }
end

function M.exec(opt, cacheFile)
   print("=> Downloading CIFAR-10 dataset from " .. URL)
   local ok = os.execute('curl ' .. URL .. ' | tar xz -C gen/')
   assert(ok == true or ok == 0, 'error downloading CIFAR-10')

   print(" | combining dataset into a single file")
   local trainData = convertToTensor({
      'gen/cifar-10-batches-t7/data_batch_1.t7',
      'gen/cifar-10-batches-t7/data_batch_2.t7',
      'gen/cifar-10-batches-t7/data_batch_3.t7',
      'gen/cifar-10-batches-t7/data_batch_4.t7',
      'gen/cifar-10-batches-t7/data_batch_5.t7',
   })
   local testData = convertToTensor({
      'gen/cifar-10-batches-t7/test_batch.t7',
   })

   print(" | saving CIFAR-10 dataset to " .. cacheFile)
   torch.save(cacheFile, {
      train = trainData,
      val = testData,
   })
end

return M
@@ -0,0 +1,57 @@
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- CIFAR-10 dataset loader
--

local t = require 'datasets/transforms'

local M = {}
local CifarDataset = torch.class('resnet.CifarDataset', M)

function CifarDataset:__init(imageInfo, opt, split)
   assert(imageInfo[split], split)
   self.imageInfo = imageInfo[split]
   self.split = split
end

function CifarDataset:get(i)
   local image = self.imageInfo.data[i]:float()
   local label = self.imageInfo.labels[i]

   return {
      input = image,
      target = label,
   }
end

function CifarDataset:size()
   return self.imageInfo.data:size(1)
end

-- Computed from entire CIFAR-10 training set
local meanstd = {
   mean = {125.3, 123.0, 113.9},
   std = {63.0, 62.1, 66.7},
}

function CifarDataset:preprocess()
   if self.split == 'train' then
      return t.Compose{
         t.ColorNormalize(meanstd),
         t.HorizontalFlip(0.5),
         t.RandomCrop(32, 4),
      }
   elseif self.split == 'val' then
      return t.ColorNormalize(meanstd)
   else
      error('invalid split: ' .. self.split)
   end
end

return M.CifarDataset