From 0a84d0b645a5acefc8c2c6269a6e7a5dfa0ebac6 Mon Sep 17 00:00:00 2001
From: Yangqing Jia
Date: Tue, 1 Oct 2013 16:27:48 -0700
Subject: [PATCH] misc update

---
 .gitignore                           |   3 +-
 src/Makefile                         |   4 +-
 src/caffe/layers/loss_layer.cu       |   1 +
 src/caffe/layers/softmax_layer.cpp   |  91 --------------
 src/caffe/layers/softmax_layer.cu    | 181 +++++++++++++++++++++++++++
 src/caffe/test/test_solver_mnist.cpp | 108 ----------------
 src/caffe/util/io.cpp                |  22 ++++
 src/caffe/util/io.hpp                |   2 +
 src/caffe/vision_layers.hpp          |   8 +-
 src/programs/convert_dataset.cpp     |  66 ++++++++++
 10 files changed, 280 insertions(+), 206 deletions(-)
 delete mode 100644 src/caffe/layers/softmax_layer.cpp
 create mode 100644 src/caffe/layers/softmax_layer.cu
 delete mode 100644 src/caffe/test/test_solver_mnist.cpp
 create mode 100644 src/programs/convert_dataset.cpp

diff --git a/.gitignore b/.gitignore
index 14428f6da..bc38afc9d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,8 +19,9 @@
 *.pb.cc
 *_pb2.py
 
-# test files
+# bin files
 *.testbin
+*.bin
 
 # vim swp files
 *.swp
diff --git a/src/Makefile b/src/Makefile
index 05b7bc042..31d225e94 100644
--- a/src/Makefile
+++ b/src/Makefile
@@ -79,8 +79,8 @@ $(PROTO_GEN_CC): $(PROTO_SRCS)
 	protoc $(PROTO_SRCS) --cpp_out=. --python_out=.
 
 clean:
-	@- $(RM) $(NAME) $(TEST_BINS)
-	@- $(RM) $(OBJS) $(TEST_OBJS)
+	@- $(RM) $(NAME) $(TEST_BINS) $(PROGRAM_BINS)
+	@- $(RM) $(OBJS) $(TEST_OBJS) $(PROGRAM_OBJS)
 	@- $(RM) $(PROTO_GEN_HEADER) $(PROTO_GEN_CC) $(PROTO_GEN_PY)
 
 distclean: clean
diff --git a/src/caffe/layers/loss_layer.cu b/src/caffe/layers/loss_layer.cu
index 1ea0626c3..737f1a239 100644
--- a/src/caffe/layers/loss_layer.cu
+++ b/src/caffe/layers/loss_layer.cu
@@ -47,6 +47,7 @@ Dtype MultinomialLogisticLossLayer<Dtype>::Backward_cpu(const vector
 
 // TODO: implement the GPU version for multinomial loss
 
+
 template <typename Dtype>
 void EuclideanLossLayer<Dtype>::SetUp(
     const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
diff --git a/src/caffe/layers/softmax_layer.cpp b/src/caffe/layers/softmax_layer.cpp
deleted file mode 100644
index a263ad326..000000000
--- a/src/caffe/layers/softmax_layer.cpp
+++ /dev/null
@@ -1,91 +0,0 @@
-// Copyright 2013 Yangqing Jia
-
-#include <algorithm>
-#include <vector>
-
-#include "caffe/layer.hpp"
-#include "caffe/vision_layers.hpp"
-#include "caffe/util/math_functions.hpp"
-
-using std::max;
-
-namespace caffe {
-
-template <typename Dtype>
-void SoftmaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom.size(), 1) << "Softmax Layer takes a single blob as input.";
-  CHECK_EQ(top->size(), 1) << "Softmax Layer takes a single blob as output.";
-  (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
-      bottom[0]->height(), bottom[0]->width());
-  sum_multiplier_.Reshape(1, bottom[0]->channels(),
-      bottom[0]->height(), bottom[0]->width());
-  Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
-  for (int i = 0; i < sum_multiplier_.count(); ++i) {
-    multiplier_data[i] = 1.;
-  }
-  scale_.Reshape(bottom[0]->num(), 1, 1, 1);
-};
-
-template <typename Dtype>
-void SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-    vector<Blob<Dtype>*>* top) {
-  const Dtype* bottom_data = bottom[0]->cpu_data();
-  Dtype* top_data = (*top)[0]->mutable_cpu_data();
-  Dtype* scale_data = scale_.mutable_cpu_data();
-  int num = bottom[0]->num();
-  int dim = bottom[0]->count() / bottom[0]->num();
-  memcpy(top_data, bottom_data, sizeof(Dtype) * bottom[0]->count());
-  // we need to subtract the max to avoid numerical issues, compute the exp,
-  // and then normalize.
-  // Compute sum
-  for (int i = 0; i < num; ++i) {
-    scale_data[i] = bottom_data[i*dim];
-    for (int j = 0; j < dim; ++j) {
-      scale_data[i] = max(scale_data[i], bottom_data[i * dim + j]);
-    }
-  }
-  // subtraction
-  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
-      scale_data, sum_multiplier_.cpu_data(), 1., top_data);
-  // Perform exponentiation
-  caffe_exp<Dtype>(num * dim, top_data, top_data);
-  // sum after exp
-  caffe_cpu_gemv<Dtype>(CblasNoTrans, num, dim, 1., top_data,
-      sum_multiplier_.cpu_data(), 0., scale_data);
-  // Do division
-  for (int i = 0; i < num; ++i) {
-    caffe_scal<Dtype>(dim, Dtype(1.) / scale_data[i], top_data + i * dim);
-  }
-}
-
-template <typename Dtype>
-Dtype SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
-    vector<Blob<Dtype>*>* bottom) {
-  const Dtype* top_diff = top[0]->cpu_diff();
-  const Dtype* top_data = top[0]->cpu_data();
-  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
-  Dtype* scale_data = scale_.mutable_cpu_data();
-  int num = top[0]->num();
-  int dim = top[0]->count() / top[0]->num();
-  memcpy(bottom_diff, top_diff, sizeof(Dtype) * top[0]->count());
-  // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
-  for (int i = 0; i < num; ++i) {
-    scale_data[i] = caffe_cpu_dot<Dtype>(dim, top_diff + i * dim,
-        top_data + i * dim);
-  }
-  // subtraction
-  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
-      scale_data, sum_multiplier_.cpu_data(), 1., bottom_diff);
-  // elementwise multiplication
-  caffe_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
-  return Dtype(0);
-}
-
-// TODO(Yangqing): implement the GPU version of softmax.
-
-INSTANTIATE_CLASS(SoftmaxLayer);
-
-
-}  // namespace caffe
diff --git a/src/caffe/layers/softmax_layer.cu b/src/caffe/layers/softmax_layer.cu
new file mode 100644
index 000000000..a7659697a
--- /dev/null
+++ b/src/caffe/layers/softmax_layer.cu
@@ -0,0 +1,181 @@
+// Copyright 2013 Yangqing Jia
+
+#include <algorithm>
+#include <cfloat>
+#include <cmath>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+using std::max;
+
+namespace caffe {
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 1) << "Softmax Layer takes a single blob as input.";
+  CHECK_EQ(top->size(), 1) << "Softmax Layer takes a single blob as output.";
+  (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+  sum_multiplier_.Reshape(1, bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+  Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
+  for (int i = 0; i < sum_multiplier_.count(); ++i) {
+    multiplier_data[i] = 1.;
+  }
+  scale_.Reshape(bottom[0]->num(), 1, 1, 1);
+};
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  Dtype* scale_data = scale_.mutable_cpu_data();
+  int num = bottom[0]->num();
+  int dim = bottom[0]->count() / bottom[0]->num();
+  memcpy(top_data, bottom_data, sizeof(Dtype) * bottom[0]->count());
+  // we need to subtract the max to avoid numerical issues, compute the exp,
+  // and then normalize.
+  for (int i = 0; i < num; ++i) {
+    scale_data[i] = bottom_data[i*dim];
+    for (int j = 0; j < dim; ++j) {
+      scale_data[i] = max(scale_data[i], bottom_data[i * dim + j]);
+    }
+  }
+  // subtraction
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+      scale_data, sum_multiplier_.cpu_data(), 1., top_data);
+  // Perform exponentiation
+  caffe_exp<Dtype>(num * dim, top_data, top_data);
+  // sum after exp
+  caffe_cpu_gemv<Dtype>(CblasNoTrans, num, dim, 1., top_data,
+      sum_multiplier_.cpu_data(), 0., scale_data);
+  // Do division
+  for (int i = 0; i < num; ++i) {
+    caffe_scal<Dtype>(dim, Dtype(1.) / scale_data[i], top_data + i * dim);
+  }
+}
+
+template <typename Dtype>
+__global__ void kernel_get_max(const int num, const int dim,
+    const Dtype* data, Dtype* out) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < num) {
+    Dtype maxval = -FLT_MAX;
+    for (int i = 0; i < dim; ++i) {
+      maxval = max(data[index * dim + i], maxval);
+    }
+    out[index] = maxval;
+  }
+}
+
+template <typename Dtype>
+__global__ void kernel_softmax_div(const int num, const int dim,
+    const Dtype* scale, Dtype* data) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < num * dim) {
+    int n = index / dim;
+    data[index] /= scale[n];
+  }
+}
+
+template <typename Dtype>
+__global__ void kernel_exp(const int num, const Dtype* data, Dtype* out) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < num) {
+    out[index] = exp(data[index]);
+  }
+}
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  Dtype* scale_data = scale_.mutable_gpu_data();
+  int num = bottom[0]->num();
+  int dim = bottom[0]->count() / bottom[0]->num();
+  CUDA_CHECK(cudaMemcpy(top_data, bottom_data,
+      sizeof(Dtype) * bottom[0]->count(), cudaMemcpyDeviceToDevice));
+  // we need to subtract the max to avoid numerical issues, compute the exp,
+  // and then normalize.
+  // Compute max
+  kernel_get_max<Dtype><<<CAFFE_GET_BLOCKS(num), CAFFE_CUDA_NUM_THREADS>>>(
+      num, dim, bottom_data, scale_data);
+  // subtraction
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+      scale_data, sum_multiplier_.gpu_data(), 1., top_data);
+  // Perform exponentiation
+  kernel_exp<Dtype><<<CAFFE_GET_BLOCKS(num * dim), CAFFE_CUDA_NUM_THREADS>>>(
+      num * dim, top_data, top_data);
+  // sum after exp
+  caffe_gpu_gemv<Dtype>(CblasNoTrans, num, dim, 1., top_data,
+      sum_multiplier_.gpu_data(), 0., scale_data);
+  // Do division
+  kernel_softmax_div<Dtype><<<CAFFE_GET_BLOCKS(num * dim), CAFFE_CUDA_NUM_THREADS>>>(
+      num, dim, scale_data, top_data);
+}
+
+template <typename Dtype>
+Dtype SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  const Dtype* top_diff = top[0]->cpu_diff();
+  const Dtype* top_data = top[0]->cpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  Dtype* scale_data = scale_.mutable_cpu_data();
+  int num = top[0]->num();
+  int dim = top[0]->count() / top[0]->num();
+  memcpy(bottom_diff, top_diff, sizeof(Dtype) * top[0]->count());
+  // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
+  for (int i = 0; i < num; ++i) {
+    scale_data[i] = caffe_cpu_dot<Dtype>(dim, top_diff + i * dim,
+        top_data + i * dim);
+  }
+  // subtraction
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+      scale_data, sum_multiplier_.cpu_data(), 1., bottom_diff);
+  // elementwise multiplication
+  caffe_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
+  return Dtype(0);
+}
+
+// TODO(Yangqing): implement the GPU version of softmax.
+template <typename Dtype>
+Dtype SoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  const Dtype* top_diff = top[0]->gpu_diff();
+  const Dtype* top_data = top[0]->gpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+  int num = top[0]->num();
+  int dim = top[0]->count() / top[0]->num();
+  CUDA_CHECK(cudaMemcpy(bottom_diff, top_diff,
+      sizeof(Dtype) * top[0]->count(), cudaMemcpyDeviceToDevice));
+  // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
+  // cuda dot returns the result to cpu, so we temporarily change the pointer
+  // mode
+  CUBLAS_CHECK(cublasSetPointerMode(Caffe::cublas_handle(),
+      CUBLAS_POINTER_MODE_DEVICE));
+  Dtype* scale_data = scale_.mutable_gpu_data();
+  for (int i = 0; i < num; ++i) {
+    caffe_gpu_dot<Dtype>(dim, top_diff + i * dim,
+        top_data + i * dim, scale_data + i);
+  }
+  CUBLAS_CHECK(cublasSetPointerMode(Caffe::cublas_handle(),
+      CUBLAS_POINTER_MODE_HOST));
+  // subtraction
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+      scale_.gpu_data(), sum_multiplier_.gpu_data(), 1., bottom_diff);
+  // elementwise multiplication
+  caffe_gpu_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
+  return Dtype(0);
+}
+
+INSTANTIATE_CLASS(SoftmaxLayer);
+
+
+}  // namespace caffe
diff --git a/src/caffe/test/test_solver_mnist.cpp b/src/caffe/test/test_solver_mnist.cpp
deleted file mode 100644
index 4c8d3fd2a..000000000
--- a/src/caffe/test/test_solver_mnist.cpp
+++ /dev/null
@@ -1,108 +0,0 @@
-// Copyright 2013 Yangqing Jia
-
-#include <cstring>
-#include <cuda_runtime.h>
-#include <fcntl.h>
-#include <google/protobuf/text_format.h>
-#include <gtest/gtest.h>
-
-#include <vector>
-
-#include "caffe/blob.hpp"
-#include "caffe/common.hpp"
-#include "caffe/net.hpp"
-#include "caffe/filler.hpp"
-#include "caffe/proto/caffe.pb.h"
-#include "caffe/util/io.hpp"
-#include "caffe/optimization/solver.hpp"
-
-#include "caffe/test/test_caffe_main.hpp"
-
-namespace caffe {
-
-template <typename Dtype>
-class MNISTSolverTest : public ::testing::Test {};
-
-typedef ::testing::Types<float> Dtypes;
-TYPED_TEST_CASE(MNISTSolverTest, Dtypes);
-
-TYPED_TEST(MNISTSolverTest, TestSolve) {
-  Caffe::set_mode(Caffe::GPU);
-
-  NetParameter net_param;
-  ReadProtoFromTextFile("caffe/test/data/lenet.prototxt",
-      &net_param);
-  vector<Blob<TypeParam>*> bottom_vec;
-  Net<TypeParam> caffe_net(net_param, bottom_vec);
-
-  // Run the network without training.
-  LOG(ERROR) << "Performing Forward";
-  caffe_net.Forward(bottom_vec);
-  LOG(ERROR) << "Performing Backward";
-  LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
-
-  SolverParameter solver_param;
-  solver_param.set_base_lr(0.01);
-  solver_param.set_display(0);
-  solver_param.set_max_iter(6000);
-  solver_param.set_lr_policy("inv");
-  solver_param.set_gamma(0.0001);
-  solver_param.set_power(0.75);
-  solver_param.set_momentum(0.9);
-
-  LOG(ERROR) << "Starting Optimization";
-  SGDSolver<TypeParam> solver(solver_param);
-  solver.Solve(&caffe_net);
-  LOG(ERROR) << "Optimization Done.";
-
-  // Run the network after training.
-  LOG(ERROR) << "Performing Forward";
-  caffe_net.Forward(bottom_vec);
-  LOG(ERROR) << "Performing Backward";
-  TypeParam loss = caffe_net.Backward();
-  LOG(ERROR) << "Final loss: " << loss;
-  EXPECT_LE(loss, 0.5);
-
-  NetParameter trained_net_param;
-  caffe_net.ToProto(&trained_net_param);
-  // LOG(ERROR) << "Writing to disk.";
-  // WriteProtoToBinaryFile(trained_net_param,
-  //     "caffe/test/data/lenet_trained.prototxt");
-
-  NetParameter traintest_net_param;
-  ReadProtoFromTextFile("caffe/test/data/lenet_traintest.prototxt",
-      &traintest_net_param);
-  Net<TypeParam> caffe_traintest_net(traintest_net_param, bottom_vec);
-  caffe_traintest_net.CopyTrainedLayersFrom(trained_net_param);
-
-  // Test run
-  double train_accuracy = 0;
-  int batch_size = traintest_net_param.layers(0).layer().batchsize();
-  for (int i = 0; i < 60000 / batch_size; ++i) {
-    const vector<Blob<TypeParam>*>& result =
-        caffe_traintest_net.Forward(bottom_vec);
-    train_accuracy += result[0]->cpu_data()[0];
-  }
-  train_accuracy /= 60000 / batch_size;
-  LOG(ERROR) << "Train accuracy:" << train_accuracy;
-  EXPECT_GE(train_accuracy, 0.98);
-
-  NetParameter test_net_param;
-  ReadProtoFromTextFile("caffe/test/data/lenet_test.prototxt", &test_net_param);
-  Net<TypeParam> caffe_test_net(test_net_param, bottom_vec);
-  caffe_test_net.CopyTrainedLayersFrom(trained_net_param);
-
-  // Test run
-  double test_accuracy = 0;
-  batch_size = test_net_param.layers(0).layer().batchsize();
-  for (int i = 0; i < 10000 / batch_size; ++i) {
-    const vector<Blob<TypeParam>*>& result =
-        caffe_test_net.Forward(bottom_vec);
-    test_accuracy += result[0]->cpu_data()[0];
-  }
-  test_accuracy /= 10000 / batch_size;
-  LOG(ERROR) << "Test accuracy:" << test_accuracy;
-  EXPECT_GE(test_accuracy, 0.98);
-}
-
-}  // namespace caffe
diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp
index 0d4f9bb1e..b7a830bbc 100644
--- a/src/caffe/util/io.cpp
+++ b/src/caffe/util/io.cpp
@@ -47,6 +47,28 @@ void ReadImageToProto(const string& filename, BlobProto* proto) {
   }
 }
 
+void ReadImageToDatum(const string& filename, const int label, Datum* datum) {
+  Mat cv_img;
+  cv_img = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
+  CHECK(cv_img.data) << "Could not open or find the image.";
+  DCHECK_EQ(cv_img.channels(), 3);
+  datum->set_channels(3);
+  datum->set_height(cv_img.rows);
+  datum->set_width(cv_img.cols);
+  datum->set_label(label);
+  datum->clear_data();
+  datum->clear_float_data();
+  string* datum_string = datum->mutable_data();
+  for (int c = 0; c < 3; ++c) {
+    for (int h = 0; h < cv_img.rows; ++h) {
+      for (int w = 0; w < cv_img.cols; ++w) {
+        datum_string->push_back(static_cast<char>(cv_img.at<cv::Vec3b>(h, w)[c]));
+      }
+    }
+  }
+}
+
+
 void WriteProtoToImage(const string& filename, const BlobProto& proto) {
   CHECK_EQ(proto.num(), 1);
   CHECK(proto.channels() == 3 || proto.channels() == 1);
diff --git a/src/caffe/util/io.hpp b/src/caffe/util/io.hpp
index 57beef1dc..ab4593668 100644
--- a/src/caffe/util/io.hpp
+++ b/src/caffe/util/io.hpp
@@ -33,6 +33,8 @@ inline void WriteBlobToImage(const string& filename, const Blob<Dtype>& blob) {
   WriteProtoToImage(filename, proto);
 }
 
+void ReadImageToDatum(const string& filename, const int label, Datum* datum);
+
 void ReadProtoFromTextFile(const char* filename, Message* proto);
 
 inline void ReadProtoFromTextFile(const string& filename,
diff --git a/src/caffe/vision_layers.hpp b/src/caffe/vision_layers.hpp
index 3aa43b2ba..74c597825 100644
--- a/src/caffe/vision_layers.hpp
+++ b/src/caffe/vision_layers.hpp
@@ -267,12 +267,12 @@ class SoftmaxLayer : public Layer<Dtype> {
  protected:
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
-  // virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-  //     vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
-  //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
 
   // sum_multiplier is just used to carry out sum using blas
   Blob<Dtype> sum_multiplier_;
diff --git a/src/programs/convert_dataset.cpp b/src/programs/convert_dataset.cpp
new file mode 100644
index 000000000..7fb6a0455
--- /dev/null
+++ b/src/programs/convert_dataset.cpp
@@ -0,0 +1,66 @@
+// Copyright 2013 Yangqing Jia
+// This program converts a set of images to a leveldb by storing them as Datum
+// proto buffers.
+// Usage:
+//    convert_dataset ROOTFOLDER LISTFILE DB_NAME
+// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
+// should be a list of files as well as their labels, in the format as
+//    subfolder1/file1.JPEG 0
+//    ....
+
+#include <glog/logging.h>
+#include <leveldb/db.h>
+
+#include <cstdlib>
+#include <fstream>
+#include <string>
+
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/util/io.hpp"
+
+using namespace caffe;
+using std::string;
+
+// A utility function to generate random strings
+void GenerateRandomPrefix(const int n, string* key) {
+  const char* kCHARS = "abcdefghijklmnopqrstuvwxyz";
+  key->clear();
+  for (int i = 0; i < n; ++i) {
+    key->push_back(kCHARS[rand() % 26]);
+  }
+  key->push_back('_');
+}
+
+int main(int argc, char** argv) {
+  ::google::InitGoogleLogging(argv[0]);
+  std::ifstream infile(argv[2]);
+  leveldb::DB* db;
+  leveldb::Options options;
+  options.error_if_exists = true;
+  options.create_if_missing = true;
+  LOG(INFO) << "Opening leveldb " << argv[3];
+  leveldb::Status status = leveldb::DB::Open(
+      options, argv[3], &db);
+  CHECK(status.ok()) << "Failed to open leveldb " << argv[3];
+
+  string root_folder(argv[1]);
+  string filename;
+  int label;
+  Datum datum;
+  string key;
+  string value;
+  while (infile >> filename >> label) {
+    ReadImageToDatum(root_folder + filename, label, &datum);
+    // get the key, and add a random string so the leveldb will have permuted
+    // data
+    GenerateRandomPrefix(8, &key);
+    key += filename;
+    // get the value
+    datum.SerializeToString(&value);
+    db->Put(leveldb::WriteOptions(), key, value);
+    LOG(ERROR) << "Writing " << key;
+  }
+
+  delete db;
+  return 0;
+}
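
Notes (illustrative, not part of the patch itself):

The softmax forward pass above (both Forward_cpu and Forward_gpu) subtracts the per-row maximum before exponentiating. For a row x with entries x_1 .. x_D it computes

    softmax(x)_j = exp(x_j - max_k x_k) / sum_j' exp(x_j' - max_k x_k)

which is mathematically identical to the unshifted form but keeps every argument of exp() non-positive, so large activations do not overflow.

The new convert_dataset tool reads the LISTFILE line by line, stores one Datum per image in a leveldb, and prefixes each key with a random string so that records are stored in a permuted order. As a hypothetical example (filenames invented for illustration, following the "subfolder1/file1.JPEG 0" format described in the file header), a LISTFILE containing

    train/cat_0001.JPEG 0
    train/dog_0001.JPEG 1

could be converted with

    convert_dataset /path/to/images/ train_list.txt train-leveldb

where the trailing integer on each line is the label written into the Datum's label field.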