Commit
Merge pull request intel#120 from sguada/images_layer
Images layer: A data provider layer directly from images
shelhamer committed Mar 13, 2014
2 parents c965bc1 + 587eeab commit b1765ce
Showing 6 changed files with 454 additions and 0 deletions.
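As an overview of what the new layer expects, here is a minimal, hypothetical C++ sketch that fills a LayerParameter with the type string "images" dispatched in layer_factory.cpp and the fields this patch reads (source, batchsize, new_height, new_width, cropsize, mirror, shuffle_images). The header path is the one generated from src/caffe/proto/caffe.proto; the values and the list-file name are invented for illustration and are not part of the commit.

#include <iostream>

#include "caffe/proto/caffe.pb.h"  // generated from src/caffe/proto/caffe.proto

int main() {
  caffe::LayerParameter param;
  param.set_type("images");            // matched by GetLayer() in layer_factory.cpp
  param.set_source("train_list.txt");  // hypothetical list file, one "filename label" pair per line
  param.set_batchsize(32);
  param.set_new_height(256);           // resize on load; 0 keeps the original size
  param.set_new_width(256);
  param.set_cropsize(227);             // random crop in TRAIN, center crop otherwise
  param.set_mirror(true);              // requires cropsize to be set as well
  param.set_shuffle_images(true);      // reshuffle the file list at every epoch
  std::cout << param.DebugString();
  return 0;
}

The source file format itself is inferred from the infile >> filename >> label loop in ImagesLayer::SetUp below.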
Binary file added data/cat.jpg
38 changes: 38 additions & 0 deletions include/caffe/vision_layers.hpp
@@ -377,6 +377,44 @@ class DataLayer : public Layer<Dtype> {
Blob<Dtype> data_mean_;
};

// This function is used to create a pthread that prefetches the data.
template <typename Dtype>
void* ImagesLayerPrefetch(void* layer_pointer);

template <typename Dtype>
class ImagesLayer : public Layer<Dtype> {
// The function used to perform prefetching.
friend void* ImagesLayerPrefetch<Dtype>(void* layer_pointer);

public:
explicit ImagesLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual ~ImagesLayer();
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);

vector<std::pair<std::string, int> > lines_;
int lines_id_;
int datum_channels_;
int datum_height_;
int datum_width_;
int datum_size_;
pthread_t thread_;
shared_ptr<Blob<Dtype> > prefetch_data_;
shared_ptr<Blob<Dtype> > prefetch_label_;
Blob<Dtype> data_mean_;
};


template <typename Dtype>
class SoftmaxLayer : public Layer<Dtype> {
2 changes: 2 additions & 0 deletions src/caffe/layer_factory.cpp
@@ -27,6 +27,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
return new ConvolutionLayer<Dtype>(param);
} else if (type == "data") {
return new DataLayer<Dtype>(param);
} else if (type == "images") {
return new ImagesLayer<Dtype>(param);
} else if (type == "dropout") {
return new DropoutLayer<Dtype>(param);
} else if (type == "euclidean_loss") {
274 changes: 274 additions & 0 deletions src/caffe/layers/images_layer.cpp
@@ -0,0 +1,274 @@
// Copyright 2013 Yangqing Jia

#include <stdint.h>
#include <leveldb/db.h>
#include <pthread.h>

#include <algorithm>  // std::random_shuffle
#include <cstdlib>    // rand()
#include <string>
#include <utility>    // std::pair
#include <vector>
#include <iostream>
#include <fstream>

#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"

using std::string;
using std::pair;

namespace caffe {

template <typename Dtype>
void* ImagesLayerPrefetch(void* layer_pointer) {
CHECK(layer_pointer);
ImagesLayer<Dtype>* layer = reinterpret_cast<ImagesLayer<Dtype>*>(layer_pointer);
CHECK(layer);
Datum datum;
CHECK(layer->prefetch_data_);
Dtype* top_data = layer->prefetch_data_->mutable_cpu_data();
Dtype* top_label = layer->prefetch_label_->mutable_cpu_data();
const Dtype scale = layer->layer_param_.scale();
const int batchsize = layer->layer_param_.batchsize();
const int cropsize = layer->layer_param_.cropsize();
const bool mirror = layer->layer_param_.mirror();
const int new_height = layer->layer_param_.new_height();
const int new_width = layer->layer_param_.new_width();

if (mirror && cropsize == 0) {
LOG(FATAL) << "Current implementation requires mirror and cropsize to be "
<< "set at the same time.";
}
// datum scales
const int channels = layer->datum_channels_;
const int height = layer->datum_height_;
const int width = layer->datum_width_;
const int size = layer->datum_size_;
const int lines_size = layer->lines_.size();
const Dtype* mean = layer->data_mean_.cpu_data();
for (int itemid = 0; itemid < batchsize; ++itemid) {
// get a blob
CHECK_GT(lines_size, layer->lines_id_);
if (!ReadImageToDatum(layer->lines_[layer->lines_id_].first,
layer->lines_[layer->lines_id_].second,
new_height, new_width, &datum)) {
continue;
}
const string& data = datum.data();
if (cropsize) {
CHECK(data.size()) << "Image cropping only supports uint8 data";
int h_off, w_off;
// We only do random crop when we do training.
if (Caffe::phase() == Caffe::TRAIN) {
h_off = rand() % (height - cropsize);
w_off = rand() % (width - cropsize);
} else {
h_off = (height - cropsize) / 2;
w_off = (width - cropsize) / 2;
}
if (mirror && rand() % 2) {
// Copy mirrored version
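// The destination index below is the flattened (num, channel, row, col)
// offset ((itemid * channels + c) * cropsize + h) * cropsize + w, with the
// column reflected to (cropsize - 1 - w) to flip the crop horizontally.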
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < cropsize; ++h) {
for (int w = 0; w < cropsize; ++w) {
top_data[((itemid * channels + c) * cropsize + h) * cropsize
+ cropsize - 1 - w] =
(static_cast<Dtype>(
(uint8_t)data[(c * height + h + h_off) * width
+ w + w_off])
- mean[(c * height + h + h_off) * width + w + w_off])
* scale;
}
}
}
} else {
// Normal copy
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < cropsize; ++h) {
for (int w = 0; w < cropsize; ++w) {
top_data[((itemid * channels + c) * cropsize + h) * cropsize + w]
= (static_cast<Dtype>(
(uint8_t)data[(c * height + h + h_off) * width
+ w + w_off])
- mean[(c * height + h + h_off) * width + w + w_off])
* scale;
}
}
}
}
} else {
// Just copy the whole data
if (data.size()) {
for (int j = 0; j < size; ++j) {
top_data[itemid * size + j] =
(static_cast<Dtype>((uint8_t)data[j]) - mean[j]) * scale;
}
} else {
for (int j = 0; j < size; ++j) {
top_data[itemid * size + j] =
(datum.float_data(j) - mean[j]) * scale;
}
}
}

top_label[itemid] = datum.label();
// go to the next iter
layer->lines_id_++;
if (layer->lines_id_ >= lines_size) {
// We have reached the end. Restart from the first.
DLOG(INFO) << "Restarting data prefetching from start.";
layer->lines_id_ = 0;
if (layer->layer_param_.shuffle_images()) {
std::random_shuffle(layer->lines_.begin(), layer->lines_.end());
}
}
}

return (void*)NULL;
}

template <typename Dtype>
ImagesLayer<Dtype>::~ImagesLayer<Dtype>() {
// Finally, join the thread
CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
}

template <typename Dtype>
void ImagesLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
CHECK_EQ(bottom.size(), 0) << "Input Layer takes no input blobs.";
CHECK_EQ(top->size(), 2) << "Input Layer takes two blobs as output.";
const int new_height = this->layer_param_.new_height();
const int new_width = this->layer_param_.new_width();
CHECK((new_height == 0 && new_width == 0) || (new_height > 0 && new_width > 0))
<< "Current implementation requires new_height and new_width "
<< "to be set at the same time.";
// Read the file with filenames and labels
LOG(INFO) << "Opening file " << this->layer_param_.source();
std::ifstream infile(this->layer_param_.source().c_str());
string filename;
int label;
while (infile >> filename >> label) {
lines_.push_back(std::make_pair(filename, label));
}

if (this->layer_param_.shuffle_images()) {
// randomly shuffle data
LOG(INFO) << "Shuffling data";
std::random_shuffle(lines_.begin(), lines_.end());
}
LOG(INFO) << "A total of " << lines_.size() << " images.";

lines_id_ = 0;
// Check if we would need to randomly skip a few data points
if (this->layer_param_.rand_skip()) {
unsigned int skip = rand() % this->layer_param_.rand_skip();
LOG(INFO) << "Skipping first " << skip << " data points.";
CHECK_GT(lines_.size(), skip) << "Not enough points to skip";
lines_id_ = skip;
}
// Read a data point, and use it to initialize the top blob.
Datum datum;
CHECK(ReadImageToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
new_height, new_width, &datum));
// image
int cropsize = this->layer_param_.cropsize();
if (cropsize > 0) {
(*top)[0]->Reshape(
this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize);
prefetch_data_.reset(new Blob<Dtype>(
this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize));
} else {
(*top)[0]->Reshape(
this->layer_param_.batchsize(), datum.channels(), datum.height(),
datum.width());
prefetch_data_.reset(new Blob<Dtype>(
this->layer_param_.batchsize(), datum.channels(), datum.height(),
datum.width()));
}
LOG(INFO) << "output data size: " << (*top)[0]->num() << ","
<< (*top)[0]->channels() << "," << (*top)[0]->height() << ","
<< (*top)[0]->width();
// label
(*top)[1]->Reshape(this->layer_param_.batchsize(), 1, 1, 1);
prefetch_label_.reset(
new Blob<Dtype>(this->layer_param_.batchsize(), 1, 1, 1));
// datum size
datum_channels_ = datum.channels();
datum_height_ = datum.height();
datum_width_ = datum.width();
datum_size_ = datum.channels() * datum.height() * datum.width();
CHECK_GT(datum_height_, cropsize);
CHECK_GT(datum_width_, cropsize);
// check if we want to have mean
if (this->layer_param_.has_meanfile()) {
BlobProto blob_proto;
LOG(INFO) << "Loading mean file from " << this->layer_param_.meanfile();
ReadProtoFromBinaryFile(this->layer_param_.meanfile().c_str(), &blob_proto);
data_mean_.FromProto(blob_proto);
CHECK_EQ(data_mean_.num(), 1);
CHECK_EQ(data_mean_.channels(), datum_channels_);
CHECK_EQ(data_mean_.height(), datum_height_);
CHECK_EQ(data_mean_.width(), datum_width_);
} else {
// Simply initialize an all-empty mean.
data_mean_.Reshape(1, datum_channels_, datum_height_, datum_width_);
}
// Now, start the prefetch thread. Before calling prefetch, we make the
// cpu_data calls below so that the prefetch thread does not accidentally
// make simultaneous cudaMalloc calls when the main thread is running. On
// some GPUs this seems to cause failures if we do not do so.
prefetch_data_->mutable_cpu_data();
prefetch_label_->mutable_cpu_data();
data_mean_.cpu_data();
DLOG(INFO) << "Initializing prefetch";
CHECK(!pthread_create(&thread_, NULL, ImagesLayerPrefetch<Dtype>,
reinterpret_cast<void*>(this))) << "Pthread execution failed.";
DLOG(INFO) << "Prefetch initialized.";
}

template <typename Dtype>
void ImagesLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// First, join the thread
CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
// Copy the data
memcpy((*top)[0]->mutable_cpu_data(), prefetch_data_->cpu_data(),
sizeof(Dtype) * prefetch_data_->count());
memcpy((*top)[1]->mutable_cpu_data(), prefetch_label_->cpu_data(),
sizeof(Dtype) * prefetch_label_->count());
// Start a new prefetch thread
CHECK(!pthread_create(&thread_, NULL, ImagesLayerPrefetch<Dtype>,
reinterpret_cast<void*>(this))) << "Pthread execution failed.";
}

template <typename Dtype>
void ImagesLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// First, join the thread
CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
// Copy the data
CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
cudaMemcpyHostToDevice));
// Start a new prefetch thread
CHECK(!pthread_create(&thread_, NULL, ImagesLayerPrefetch<Dtype>,
reinterpret_cast<void*>(this))) << "Pthread execution failed.";
}

// The backward operations are no-ops; they do not carry out any computation.
template <typename Dtype>
Dtype ImagesLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
return Dtype(0.);
}

template <typename Dtype>
Dtype ImagesLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
return Dtype(0.);
}

INSTANTIATE_CLASS(ImagesLayer);

} // namespace caffe
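The SetUp/Forward pair above implements a simple double-buffering scheme: SetUp launches the first prefetch thread, and each Forward joins it, copies the prefetched blobs into the top blobs, and immediately relaunches it so the next batch is loaded while the network computes. Below is a standalone sketch of that join-consume-relaunch cycle; the buffer, function, and value names are invented for illustration and are not taken from the patch.

#include <pthread.h>

#include <cstdio>
#include <vector>

static std::vector<float> buffer(4);  // stands in for prefetch_data_

// Worker thread: fill the buffer (stands in for ImagesLayerPrefetch).
void* Prefetch(void*) {
  for (int i = 0; i < static_cast<int>(buffer.size()); ++i) {
    buffer[i] = static_cast<float>(i);
  }
  return NULL;
}

int main() {
  pthread_t thread;
  // SetUp: start the first prefetch before any Forward call.
  pthread_create(&thread, NULL, Prefetch, NULL);
  for (int step = 0; step < 3; ++step) {
    // Forward: wait for the worker, consume its buffer, then relaunch it.
    pthread_join(thread, NULL);
    std::printf("step %d: buffer[0] = %f\n", step, buffer[0]);
    pthread_create(&thread, NULL, Prefetch, NULL);
  }
  // Destructor: join the outstanding worker before shutting down.
  pthread_join(thread, NULL);
  return 0;
}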
11 changes: 11 additions & 0 deletions src/caffe/proto/caffe.proto
@@ -91,6 +91,17 @@ message LayerParameter {
// point would be set as rand_skip * rand(0,1). Note that rand_skip should not
// be larger than the number of keys in the leveldb.
optional uint32 rand_skip = 53 [ default = 0 ];

// For the Reshape layer, one needs to specify the new dimensions.
optional int32 new_num = 60 [default = 0];
optional int32 new_channels = 61 [default = 0];
optional int32 new_height = 62 [default = 0];
optional int32 new_width = 63 [default = 0];

// Used by ImagesLayer to shuffle the list of files at every epoch; the layer
// will also resize images if new_height or new_width is not zero.
optional bool shuffle_images = 64 [default = false];

}

message LayerConnection {
