forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathim2rec.cc
284 lines (277 loc) · 10.7 KB
/
im2rec.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
/*!
* Copyright (c) 2015 by Contributors
* \file im2rec.cc
* \brief convert images into image recordio format
* Image Record Format: zeropad[64bit] imid[64bit] img-binary-content
* The 64bit zero pad was reserved for future purposes
*
* Image List Format: unique-image-index label[s] path-to-image
* \sa dmlc/recordio.h
*/
#include <cctype>
#include <cstring>
#include <string>
#include <vector>
#include <iomanip>
#include <sstream>
#include <dmlc/base.h>
#include <dmlc/io.h>
#include <dmlc/timer.h>
#include <dmlc/logging.h>
#include <dmlc/recordio.h>
#include <opencv2/opencv.hpp>
#include "../src/io/image_recordio.h"
#include <random>
/*!
 * \brief Resolve the OpenCV interpolation code to use for a single resize.
 *
 * Codes accepted in inter_method:
 *   0 - CV_INTER_NN, 1 - CV_INTER_LINEAR, 2 - CV_INTER_CUBIC,
 *   3 - CV_INTER_AREA, 4 - CV_INTER_LANCZOS4,
 *   9 - AUTO: cubic when both edges grow, area when both edges shrink,
 *       bilinear for every other case,
 *   10 - RAND: uniformly random choice among codes 0-4.
 * Any other value is passed through unchanged.
 *
 * \param inter_method requested mode (see above)
 * \param old_width    source image width
 * \param old_height   source image height
 * \param new_width    target image width
 * \param new_height   target image height
 * \param prnd         random engine, only consumed in RAND mode
 * \return interpolation code to hand to cv::resize
 */
int GetInterMethod(int inter_method, int old_width, int old_height, int new_width, int new_height, std::mt19937& prnd) {
  switch (inter_method) {
    case 9: {
      const bool grows = old_width < new_width && old_height < new_height;
      if (grows) {
        return 2;  // CV_INTER_CUBIC when enlarging in both directions
      }
      const bool shrinks = old_width > new_width && old_height > new_height;
      return shrinks ? 3 : 1;  // CV_INTER_AREA when shrinking, else CV_INTER_LINEAR
    }
    case 10: {
      // Same distribution type as before so a given engine state yields the same pick.
      std::uniform_int_distribution<size_t> pick(0, 4);
      return static_cast<int>(pick(prnd));
    }
    default:
      return inter_method;
  }
}
/*!
 * \brief Entry point: pack the images named in an image list into a RecordIO file.
 *
 * Usage: <image.lst> <image_root_dir> <output.rec> [key=value ...]
 * Each list line is "index label[s] relative-path". The image bytes are read from
 * image_root_dir + relative-path, optionally decoded / center-cropped / resized /
 * re-encoded, prefixed with an ImageRecordIO header (plus the full float label vector
 * when pack_label is set) and written as one record. With nsplit > 1 only the requested
 * part of the list is processed and the output name gets a ".partNNN" suffix.
 *
 * \return 0 on success (also when only the usage text is printed). Invalid parameters
 *         are reported through LOG(FATAL).
 */
int main(int argc, char *argv[]) {
  if (argc < 4) {
    printf("Usage: <image.lst> <image_root_dir> <output.rec> [additional parameters in form key=value]\n"
           "Possible additional parameters:\n"
           "\tcolor=USE_COLOR[default=1] Force color (1), gray image (0) or keep source unchanged (-1).\n"
           "\tresize=newsize resize the shorter edge of image to the newsize, original images will be packed by default\n"
           "\tlabel_width=WIDTH[default=1] specify the label_width in the list, by default set to 1\n"
           "\tpack_label=PACK_LABEL[default=0] whether to also pack multi dimenional label in the record file\n"
           "\tnsplit=NSPLIT[default=1] used for part generation, logically split the image.list to NSPLIT parts by position\n"
           "\tpart=PART[default=0] used for part generation, pack the images from the specific part in image.list\n"
           "\tcenter_crop=CENTER_CROP[default=0] specify whether to crop the center image to make it square.\n"
           "\tquality=QUALITY[default=80] JPEG quality for encoding (1-100, default: 80) or PNG compression for encoding (1-9, default: 3).\n"
           "\tencoding=ENCODING[default='.jpg'] Encoding type. Can be '.jpg' or '.png'\n"
           "\tinter_method=INTER_METHOD[default=1] NN(0) BILINEAR(1) CUBIC(2) AREA(3) LANCZOS4(4) AUTO(9) RAND(10).\n"
           "\tunchanged=UNCHANGED[default=0] Keep the original image encoding, size and color. If set to 1, it will ignore the others parameters.\n");
    return 0;
  }
  int label_width = 1;
  int pack_label = 0;
  int new_size = -1;
  int nsplit = 1;
  int partid = 0;
  int center_crop = 0;
  int quality = 80;
  int color_mode = CV_LOAD_IMAGE_COLOR;
  int unchanged = 0;
  int inter_method = CV_INTER_LINEAR;
  std::string encoding(".jpg");
  for (int i = 4; i < argc; ++i) {
    char key[128], val[128];
    int effct_len = 0;
    // Field widths are bounded to sizeof(buffer) - 1: without them an over-long
    // command-line argument overflows key/val on the stack.
#ifdef _MSC_VER
    effct_len = sscanf_s(argv[i], "%127[^=]=%127s", key, sizeof(key), val, sizeof(val));
#else
    effct_len = sscanf(argv[i], "%127[^=]=%127s", key, val);
#endif
    if (effct_len == 2) {
      if (!strcmp(key, "resize")) new_size = atoi(val);
      if (!strcmp(key, "label_width")) label_width = atoi(val);
      if (!strcmp(key, "pack_label")) pack_label = atoi(val);
      if (!strcmp(key, "nsplit")) nsplit = atoi(val);
      if (!strcmp(key, "part")) partid = atoi(val);
      if (!strcmp(key, "center_crop")) center_crop = atoi(val);
      if (!strcmp(key, "quality")) quality = atoi(val);
      if (!strcmp(key, "color")) color_mode = atoi(val);
      if (!strcmp(key, "encoding")) encoding = std::string(val);
      if (!strcmp(key, "unchanged")) unchanged = atoi(val);
      if (!strcmp(key, "inter_method")) inter_method = atoi(val);
    }
  }
  // Check parameters ranges
  if (color_mode != -1 && color_mode != 0 && color_mode != 1) {
    LOG(FATAL) << "Color mode must be -1, 0 or 1.";
  }
  if (encoding != std::string(".jpg") && encoding != std::string(".png")) {
    LOG(FATAL) << "Encoding mode must be .jpg or .png.";
  }
  // label_buf below is sized by label_width and label_buf[0] is always written,
  // so a width of 0 (or a negative one) must be rejected up front.
  if (label_width < 1) {
    LOG(FATAL) << "label_width must be at least 1.";
  }
  if (label_width <= 1 && pack_label) {
    LOG(FATAL) << "pack_label can only be used when label_width > 1";
  }
  if (nsplit < 1 || partid < 0 || partid >= nsplit) {
    LOG(FATAL) << "nsplit must be at least 1 and part must satisfy 0 <= part < nsplit.";
  }
  if (new_size > 0) {
    LOG(INFO) << "New Image Size: Short Edge " << new_size;
  } else {
    LOG(INFO) << "Keep origin image size";
  }
  if (center_crop) {
    LOG(INFO) << "Center cropping to square";
  }
  if (color_mode == 0) {
    LOG(INFO) << "Use gray images";
  }
  if (color_mode == -1) {
    LOG(INFO) << "Keep original color mode";
  }
  LOG(INFO) << "Encoding is " << encoding;
  // The default quality (80) is a JPEG quality; PNG expects a compression level.
  if (encoding == std::string(".png") && quality > 9) {
    quality = 3;
  }
  if (inter_method != 1) {
    switch (inter_method) {
      case 0:
        LOG(INFO) << "Use inter_method CV_INTER_NN";
        break;
      case 2:
        LOG(INFO) << "Use inter_method CV_INTER_CUBIC";
        break;
      case 3:
        LOG(INFO) << "Use inter_method CV_INTER_AREA";
        break;
      case 4:
        LOG(INFO) << "Use inter_method CV_INTER_LANCZOS4";
        break;
      case 9:
        LOG(INFO) << "Use inter_method mod auto(cubic for enlarge, area for shrink)";
        break;
      case 10:
        LOG(INFO) << "Use inter_method mod rand(nn/bilinear/cubic/area/lanczos4)";
        break;
      default:
        // An invalid option is an error, not a successful run with nothing packed.
        LOG(FATAL) << "Unknown inter_method " << inter_method
                   << ", expected 0-4, 9 or 10.";
        return 1;
    }
  }
  std::random_device rd;
  std::mt19937 prnd(rd());
  using namespace dmlc;
  const static size_t kBufferSize = 1 << 20UL;
  std::string root = argv[2];
  mxnet::io::ImageRecordIO rec;
  size_t imcnt = 0;
  double tstart = dmlc::GetTime();
  dmlc::InputSplit *flist = dmlc::InputSplit::
      Create(argv[1], partid, nsplit, "text");
  std::ostringstream os;
  if (nsplit == 1) {
    os << argv[3];
  } else {
    os << argv[3] << ".part" << std::setw(3) << std::setfill('0') << partid;
  }
  LOG(INFO) << "Write to output: " << os.str();
  dmlc::Stream *fo = dmlc::Stream::Create(os.str().c_str(), "w");
  LOG(INFO) << "Output: " << os.str();
  dmlc::RecordIOWriter writer(fo);
  std::string fname, path, blob;
  std::vector<unsigned char> decode_buf;
  std::vector<unsigned char> encode_buf;
  std::vector<int> encode_params;
  if (encoding == std::string(".png")) {
    encode_params.push_back(CV_IMWRITE_PNG_COMPRESSION);
    encode_params.push_back(quality);
    LOG(INFO) << "PNG encoding compression: " << quality;
  } else {
    encode_params.push_back(CV_IMWRITE_JPEG_QUALITY);
    encode_params.push_back(quality);
    LOG(INFO) << "JPEG encoding quality: " << quality;
  }
  dmlc::InputSplit::Blob line;
  std::vector<float> label_buf(label_width, 0.f);
  while (flist->NextRecord(&line)) {
    std::string sline(static_cast<char*>(line.dptr), line.size);
    std::istringstream is(sline);
    // Lines that do not start with "index label" are skipped.
    if (!(is >> rec.header.image_id[0] >> rec.header.label)) continue;
    label_buf[0] = rec.header.label;
    for (int k = 1; k < label_width; ++k) {
      CHECK(is >> label_buf[k])
          << "Invalid ImageList, did you provide the correct label_width?";
    }
    if (pack_label) rec.header.flag = label_width;
    // Serialize the record header first; the payload is appended after it.
    rec.SaveHeader(&blob);
    if (pack_label) {
      size_t bsize = blob.size();
      blob.resize(bsize + label_buf.size()*sizeof(float));
      memcpy(BeginPtr(blob) + bsize,
             BeginPtr(label_buf), label_buf.size()*sizeof(float));
    }
    CHECK(std::getline(is, fname));
    // Drop trailing whitespace / non-printable bytes (such as a trailing '\r').
    // Bytes go through unsigned char: <cctype> is undefined for negative char values.
    while (!fname.empty()) {
      const unsigned char last = static_cast<unsigned char>(fname.back());
      if (!isspace(last) && isprint(last)) break;
      fname.pop_back();
    }
    // Skip leading whitespace between the last label and the path.
    const char *p = fname.c_str();
    while (isspace(static_cast<unsigned char>(*p))) ++p;
    path = root + p;
    // use "r" is equal to rb in dmlc::Stream
    dmlc::Stream *fi = dmlc::Stream::Create(path.c_str(), "r");
    decode_buf.clear();
    size_t imsize = 0;
    // Read the whole file in kBufferSize chunks until a short read signals EOF.
    while (true) {
      decode_buf.resize(imsize + kBufferSize);
      size_t nread = fi->Read(BeginPtr(decode_buf) + imsize, kBufferSize);
      imsize += nread;
      decode_buf.resize(imsize);
      if (nread != kBufferSize) break;
    }
    delete fi;
    if (unchanged != 1) {
      cv::Mat img = cv::imdecode(decode_buf, color_mode);
      CHECK(img.data != NULL) << "OpenCV decode fail:" << path;
      cv::Mat res = img;
      // Cropping and resizing are only applied when resize=... was given.
      if (new_size > 0) {
        if (center_crop) {
          if (img.rows > img.cols) {
            int margin = (img.rows - img.cols)/2;
            img = img(cv::Range(margin, margin+img.cols), cv::Range(0, img.cols));
          } else {
            int margin = (img.cols - img.rows)/2;
            img = img(cv::Range(0, img.rows), cv::Range(margin, margin + img.rows));
          }
        }
        int interpolation_method = 1;
        // Scale so that the shorter edge becomes new_size, keeping the aspect ratio.
        if (img.rows > img.cols) {
          if (img.cols != new_size) {
            interpolation_method = GetInterMethod(inter_method, img.cols, img.rows, new_size, img.rows * new_size / img.cols, prnd);
            cv::resize(img, res, cv::Size(new_size, img.rows * new_size / img.cols), 0, 0, interpolation_method);
          } else {
            res = img.clone();
          }
        } else {
          if (img.rows != new_size) {
            interpolation_method = GetInterMethod(inter_method, img.cols, img.rows, new_size * img.cols / img.rows, new_size, prnd);
            cv::resize(img, res, cv::Size(new_size * img.cols / img.rows, new_size), 0, 0, interpolation_method);
          } else {
            res = img.clone();
          }
        }
      }
      encode_buf.clear();
      CHECK(cv::imencode(encoding, res, encode_buf, encode_params));
      // Append the re-encoded image after the header.
      size_t bsize = blob.size();
      blob.resize(bsize + encode_buf.size());
      memcpy(BeginPtr(blob) + bsize,
             BeginPtr(encode_buf), encode_buf.size());
    } else {
      // unchanged=1: store the original file bytes verbatim.
      size_t bsize = blob.size();
      blob.resize(bsize + decode_buf.size());
      memcpy(BeginPtr(blob) + bsize,
             BeginPtr(decode_buf), decode_buf.size());
    }
    // Emit header + payload as one record.
    writer.WriteRecord(BeginPtr(blob), blob.size());
    ++imcnt;
    if (imcnt % 1000 == 0) {
      LOG(INFO) << imcnt << " images processed, " << GetTime() - tstart << " sec elapsed";
    }
  }
  LOG(INFO) << "Total: " << imcnt << " images processed, " << GetTime() - tstart << " sec elapsed";
  delete fo;
  delete flist;
  return 0;
}