// convert_imageset.cpp (from a fork of BVLC/caffe)
// Copyright 2013 Yangqing Jia
// This program converts a set of images to a leveldb by storing them as Datum
// proto buffers.
// Usage:
// convert_imageset ROOTFOLDER/ LISTFILE DB_NAME [0/1]
// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
// lists the image files together with their labels, one per line, in the format
//    subfolder1/file1.JPEG 7
//    ....
// If the last argument is 1, the list is randomly shuffled before the lines are
// processed; otherwise the lines are kept in their given order, and you are
// responsible for shuffling the files yourself beforehand if needed.
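//
// A concrete invocation might look like the following sketch; the folder,
// list file, and db names below are hypothetical and only illustrate the
// expected layout:
//    convert_imageset /path/to/imagenet/train/ train_list.txt imagenet_train_leveldb 1
// with train_list.txt containing lines such as
//    n01440764/example_image_0.JPEG 0
//    n01440765/example_image_1.JPEG 1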
#include <glog/logging.h>
#include <leveldb/db.h>
#include <leveldb/write_batch.h>

#include <algorithm>  // std::random_shuffle
#include <cstdio>     // printf, snprintf
#include <fstream>    // std::ifstream
#include <iostream>
#include <string>
#include <utility>    // std::pair, std::make_pair
#include <vector>     // std::vector

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"

using namespace caffe;
using std::pair;
using std::string;
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  if (argc < 4) {
    printf("Convert a set of images to the leveldb format used\n"
        "as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset ROOTFOLDER/ LISTFILE DB_NAME"
        " RANDOM_SHUFFLE_DATA[0 or 1]\n"
        "The ImageNet dataset for the training demo is at\n"
        "    http://www.image-net.org/download-images\n");
    return 0;
  }

  // Read the list of "filename label" pairs.
  std::ifstream infile(argv[2]);
  std::vector<std::pair<string, int> > lines;
  string filename;
  int label;
  while (infile >> filename >> label) {
    lines.push_back(std::make_pair(filename, label));
  }
  if (argc == 5 && argv[4][0] == '1') {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    std::random_shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  // Open (and create) the output leveldb with a large write buffer.
  leveldb::DB* db;
  leveldb::Options options;
  options.error_if_exists = true;
  options.create_if_missing = true;
  options.write_buffer_size = 268435456;  // 256 MB
  LOG(INFO) << "Opening leveldb " << argv[3];
  leveldb::Status status = leveldb::DB::Open(
      options, argv[3], &db);
  CHECK(status.ok()) << "Failed to open leveldb " << argv[3];

  string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  const int maxKeyLength = 256;
  char key_cstr[maxKeyLength];
  leveldb::WriteBatch* batch = new leveldb::WriteBatch();
  int data_size;
  bool data_size_initialized = false;
  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    if (!ReadImageToDatum(root_folder + lines[line_id].first,
        lines[line_id].second, &datum)) {
      // Skip images that could not be read.
      continue;
    }
    // All images are expected to have the same dimensions: check each datum
    // against the size of the first one that was read successfully.
    if (!data_size_initialized) {
      data_size = datum.channels() * datum.height() * datum.width();
      data_size_initialized = true;
    } else {
      const string& data = datum.data();
      CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
          << data.size();
    }
    // The key is the zero-padded line index followed by the image path, so
    // iterating the db in key order replays the (possibly shuffled) list.
    snprintf(key_cstr, maxKeyLength, "%08d_%s", line_id,
        lines[line_id].first.c_str());
    string value;
    // Serialize the Datum proto into the value.
    datum.SerializeToString(&value);
    batch->Put(string(key_cstr), value);
    if (++count % 1000 == 0) {
      // Flush the accumulated batch every 1000 images.
      db->Write(leveldb::WriteOptions(), batch);
      LOG(ERROR) << "Processed " << count << " files.";
      delete batch;
      batch = new leveldb::WriteBatch();
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    db->Write(leveldb::WriteOptions(), batch);
    LOG(ERROR) << "Processed " << count << " files.";
  }
  delete batch;
  delete db;
  return 0;
}
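
// For reference only: a minimal sketch, not part of the original tool, of how
// a leveldb produced above could be read back and its Datum protos decoded.
// The function name and logging are illustrative assumptions; it relies only
// on the leveldb, glog, and caffe headers already included.
void ReadBackSketch(const char* db_name) {
  leveldb::DB* db;
  leveldb::Options options;
  leveldb::Status status = leveldb::DB::Open(options, db_name, &db);
  CHECK(status.ok()) << "Failed to open leveldb " << db_name;
  leveldb::Iterator* it = db->NewIterator(leveldb::ReadOptions());
  for (it->SeekToFirst(); it->Valid(); it->Next()) {
    Datum datum;
    // Each value is a serialized Datum; the key encodes line index and path.
    CHECK(datum.ParseFromString(it->value().ToString()));
    LOG(INFO) << it->key().ToString() << " label=" << datum.label();
  }
  delete it;
  delete db;
}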