@@ -37,6 +37,7 @@ namespace cc{
37
37
public:
38
38
operator char *(){ return get (); }
39
39
operator const char *(){ return get (); }
40
+ bool operator ==(const char * str){ return strcmp (get (), str) == 0 ; }
40
41
CCString& operator =(const char * str){ set (str); return *this ; }
41
42
CCString& operator =(char * str){ set (str); return *this ; }
42
43
CCString& operator =(const CCString& str){ set (str.get (), str.len ()); return *this ; }
@@ -171,6 +172,28 @@ namespace cc{
171
172
};
172
173
173
174
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////
175
+ // 轻量级的Blob
176
+ class Blob ;
177
+ struct CCAPI BlobData{
178
+ float * list;
179
+ int num;
180
+ int channels;
181
+ int height;
182
+ int width;
183
+ int capacity_count; // 保留空间的元素个数长度,字节数请 * sizeof(float)
184
+
185
+ BlobData ();
186
+ virtual ~BlobData ();
187
+ bool empty () const ;
188
+ int count () const ;
189
+ void reshape (int num, int channels, int height, int width);
190
+ void reshapeLike (const BlobData* other);
191
+ void copyFrom (const BlobData* other);
192
+ void copyFrom (const Blob* other);
193
+ void reshapeLike (const Blob* other);
194
+ void release ();
195
+ };
196
+
174
197
class CCAPI Blob{
175
198
public:
176
199
void setNative (void * native);
@@ -200,6 +223,7 @@ namespace cc{
200
223
void Reshape (int numShape, int * shapeDims);
201
224
void ReshapeLike (const Blob& other);
202
225
void copyFrom (const Blob& other, bool copyDiff = false , bool reshape = false );
226
+ void copyFrom (const BlobData& other);
203
227
void setDataRGB (int numIndex, const Mat& data);
204
228
CCString shapeString ();
205
229
@@ -219,6 +243,12 @@ namespace cc{
219
243
bool hasParam (const char * path);
220
244
CCString name ();
221
245
MessageHandle* param ();
246
+ int getNumBottom ();
247
+ int getNumTop ();
248
+ CCString bottomName (int index);
249
+ CCString topName (int index);
250
+ Blob* paramBlob (int index);
251
+ int getNumParamBlob ();
222
252
223
253
#ifdef USE_PROTOBUF
224
254
caffe::LayerParameter& layer_param ();
@@ -273,8 +303,8 @@ namespace cc{
273
303
void * getNative ();
274
304
int iter ();
275
305
float smooth_loss ();
276
- void Restore (const char * resume_file );
277
- void Snapshot ();
306
+ void Restore (const char * solvestate_file );
307
+ void Snapshot (const char * caffemodel_savepath = 0 , bool save_solver_state = true );
278
308
int max_iter ();
279
309
void Solve ();
280
310
void installActionSignalOperator ();
@@ -300,8 +330,31 @@ namespace cc{
300
330
CCAPI void CCCALL releaseBlob (Blob* blob);
301
331
CCAPI void CCCALL releaseSolver (Solver* solver);
302
332
CCAPI void CCCALL releaseNet (Net* net);
303
- CCAPI Solver* CCCALL loadSolverFromPrototxt (const char * solver_prototxt);
304
- CCAPI Solver* CCCALL loadSolverFromPrototxtString (const char * solver_prototxt_string);
333
+ CCAPI Solver* CCCALL loadSolverFromPrototxt (const char * solver_prototxt, const char * netstring = 0 );
334
+ CCAPI Solver* CCCALL loadSolverFromPrototxtString (const char * solver_prototxt_string, const char * netstring = 0 );
335
+
336
+ #ifdef USE_CC_PYTHON
337
+ class CCAPI CCPython{
338
+ public:
339
+ CCPython ();
340
+ virtual ~CCPython ();
341
+ bool load (const char * pyfile);
342
+ CCString callstringFunction (const CCString& name, CCString& errmsg = CCString());
343
+ CCString train_ptototxt ();
344
+ CCString deploy_prototxt ();
345
+ CCString solver ();
346
+ CCString last_error ();
347
+
348
+ private:
349
+ void * module_;
350
+ CCString lasterror_;
351
+ };
352
+
353
+ CCAPI Solver* CCCALL loadSolverFromPython (const char * pythonfile);
354
+
355
+ // phase指定加载train_prototxt还是deploy_prototxt
356
+ CCAPI Net* CCCALL loadNetFromPython (const char * pythonfile, const char * func=" deploy_prototxt" , int phase = PhaseTest);
357
+ #endif
305
358
306
359
#ifdef USE_PROTOBUF
307
360
CCAPI Solver* CCCALL newSolverFromProto (const caffe::SolverParameter* solver_param);
@@ -326,6 +379,21 @@ namespace cc{
326
379
CCAPI void CCCALL WriteProtoToBinaryFile (const google::protobuf::Message& proto, const char * filename);
327
380
#endif
328
381
382
+
383
+ // /////////////////////////////////////////////////////////////////////////////////////////////////////////////
384
+ class CCAPI LMDB{
385
+ public:
386
+ LMDB (const char * folder);
387
+ void put (const char * key, const void * data, int length);
388
+ void putAnnotatedDatum (const char * key, void * datum);
389
+ void putDatum (const char * key, void * datum);
390
+ void release ();
391
+ virtual ~LMDB ();
392
+
393
+ private:
394
+ void * native_;
395
+ };
396
+
329
397
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////
330
398
class CCAPI AbstractCustomLayer{
331
399
public:
@@ -379,39 +447,44 @@ namespace cc{
379
447
380
448
class CCAPI DataLayer : public AbstractCustomLayer{
381
449
public:
382
- DataLayer ();
450
+ DataLayer (int batchCacheSize = 3 , int watcherSize = 1 );
383
451
virtual ~DataLayer ();
384
452
385
- virtual int getBatchCacheSize ();
386
453
virtual void loadBatch (Blob** top, int numTop) = 0;
387
454
virtual void setup (const char * name, const char * type, const char * param_str, int phase, Blob** bottom, int numBottom, Blob** top, int numTop);
388
455
virtual void forward (Blob** bottom, int numBottom, Blob** top, int numTop);
389
- virtual void reshape (Blob** bottom, int numBottom, Blob** top, int numTop){}
456
+ virtual void reshape (Blob** bottom, int numBottom, Blob** top, int numTop);
390
457
void stopBatchLoader ();
458
+ int getWatcherIndex ();
391
459
virtual int waitForDataTime ();
460
+ void setPrintWaitData (bool wait);
392
461
393
462
private:
394
463
void setupBatch (Blob** top, int numTop);
395
- static void watcher (DataLayer* ptr);
464
+ static void watcher (DataLayer* ptr, int ind );
396
465
void startWatcher ();
397
466
void stopWatcher ();
398
467
void pullBatch (Blob** top, int numTop);
399
468
400
469
private:
401
470
volatile bool keep_run_watcher_;
402
- void * hsem_;
403
- bool * batch_flags_;
404
- Blob*** batch_;
471
+ void ** hsem_;
472
+ bool ** batch_flags_;
473
+ Blob**** batch_; // batch_
405
474
int numTop_;
406
475
int cacheBatchSize_;
476
+ int watcherSize_;
477
+ void * watcher_map_;
478
+ bool print_waitdata_;
407
479
};
408
480
409
481
class CCAPI SSDDataLayer : public DataLayer{
410
482
public:
411
- SSDDataLayer ();
483
+ SSDDataLayer (int batchCacheSize = 3 , int watcherSize = 1 );
412
484
virtual ~SSDDataLayer ();
413
485
414
486
virtual int getBatchCacheSize ();
487
+ virtual int getWatcherSize ();
415
488
virtual void loadBatch (Blob** top, int numTop);
416
489
virtual void setup (const char * name, const char * type, const char * param_str, int phase, Blob** bottom, int numBottom, Blob** top, int numTop);
417
490
virtual void * getAnnDatum () = 0;
@@ -435,6 +508,9 @@ namespace cc{
435
508
CCAPI void CCCALL releaseLabelMap (void * labelmap);
436
509
CCAPI void CCCALL releaseAnnDatum (void * datum);
437
510
511
+ CCAPI void * CCCALL loadDatum (const char * path, int label);
512
+ CCAPI void CCCALL releaseDatum (void * datum);
513
+
438
514
439
515
// ///////////////////////////////////////////////////////////////////////////
440
516
CCAPI MessageHandle CCCALL loadMessageNetCaffemodel (const char * filename);
0 commit comments