Skip to content

Commit

Permalink
MaskApi: adding simple box/rle NMS code (matlab/lua)
Browse files Browse the repository at this point in the history
  • Loading branch information
pdollar committed Aug 2, 2016
1 parent 2934299 commit 6e41fb5
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ images/
annotations/
results/
external/
.DS_Store

MatlabAPI/analyze*/
MatlabAPI/visualize*/
Expand All @@ -11,4 +12,4 @@ PythonAPI/pycocotools/_mask.c
PythonAPI/pycocotools/_mask.so
PythonAPI/pycocotools/coco.pyc
PythonAPI/pycocotools/cocoeval.pyc
PythonAPI/pycocotools/mask.pyc
PythonAPI/pycocotools/mask.pyc
20 changes: 20 additions & 0 deletions LuaAPI/MaskApi.lua
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The following API functions are defined:
decode - Decode binary masks encoded via RLE.
merge - Compute union or intersection of encoded masks.
iou - Compute intersection over union between masks.
nms - Compute non-maximum suppression between ordered masks.
area - Compute area of encoded masks.
toBbox - Get bounding boxes surrounding encoded masks.
frBbox - Convert bounding boxes to encoded masks.
Expand All @@ -24,6 +25,7 @@ Usage:
masks = MaskApi.decode( Rs )
R = MaskApi.merge( Rs, [intersect=false] )
o = MaskApi.iou( dt, gt, [iscrowd=false] )
keep = MaskApi.nms( dt, thr )
a = MaskApi.area( Rs )
bbs = MaskApi.toBbox( Rs )
Rs = MaskApi.frBbox( bbs, h, w )
Expand Down Expand Up @@ -107,6 +109,22 @@ MaskApi.iou = function( dt, gt, iscrowd )
end
end

MaskApi.nms = function( dt, thr )
if torch.isTensor(dt) then
local n, k = dt:size(1), dt:size(2); assert(k==4)
local Q = dt:type('torch.DoubleTensor'):contiguous():data()
local kp = torch.IntTensor(n):contiguous()
libmaskapi.bbNms(Q,n,kp:data(),thr)
return kp
else
local Q, n = MaskApi._rlesFrLua(dt)
local kp = torch.IntTensor(n):contiguous()
libmaskapi.rleNms(Q,n,kp:data(),thr)
MaskApi._rlesFree(Q,n)
return kp
end
end

MaskApi.area = function( Rs )
local Qs, n, h, w = MaskApi._rlesFrLua(Rs)
local a = torch.IntTensor(n):contiguous()
Expand Down Expand Up @@ -254,7 +272,9 @@ ffi.cdef[[
void rleMerge( const RLE *R, RLE *M, siz n, int intersect );
void rleArea( const RLE *R, siz n, uint *a );
void rleIou( RLE *dt, RLE *gt, siz m, siz n, byte *iscrowd, double *o );
void rleNms( RLE *dt, siz n, uint *keep, double thr );
void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o );
void bbNms( BB dt, siz n, uint *keep, double thr );
void rleToBbox( const RLE *R, BB bb, siz n );
void rleFrBbox( RLE *R, const BB bb, siz h, siz w, siz n );
void rleFrPoly( RLE *R, const double *xy, siz k, siz h, siz w );
Expand Down
6 changes: 6 additions & 0 deletions MatlabAPI/MaskApi.m
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
% decode - Decode binary masks encoded via RLE.
% merge - Compute union or intersection of encoded masks.
% iou - Compute intersection over union between masks.
% nms - Compute non-maximum suppression between ordered masks.
% area - Compute area of encoded masks.
% toBbox - Get bounding boxes surrounding encoded masks.
% frBbox - Convert bounding boxes to encoded masks.
Expand All @@ -39,6 +40,7 @@
% masks = MaskApi.decode( Rs )
% R = MaskApi.merge( Rs, [intersect=false] )
% o = MaskApi.iou( dt, gt, [iscrowd=false] )
% keep = MaskApi.nms( dt, thr )
% a = MaskApi.area( Rs )
% bbs = MaskApi.toBbox( Rs )
% Rs = MaskApi.frBbox( bbs, h, w )
Expand Down Expand Up @@ -90,6 +92,10 @@
o = maskApiMex( 'iou', dt', gt', varargin{:} );
end

function keep = nms( dt, thr )
keep = maskApiMex('nms',dt',thr);
end

function a = area( Rs )
a = maskApiMex( 'area', Rs );
end
Expand Down
14 changes: 14 additions & 0 deletions MatlabAPI/private/maskApiMex.c
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@ void mexFunction( int nl, mxArray *pl[], int nr, const mxArray *pr[] )
double *o=mxGetPr(pl[0]); bbIou(dt,gt,nDt,nGt,iscrowd,o);
}

} else if(!strcmp(action,"nms")) {
siz n; uint *keep; double thr=(double) mxGetScalar(pr[1]);
if(mxIsStruct(pr[0])) {
RLE *dt=frMxArray(pr[0],&n,1);
pl[0]=mxCreateNumericMatrix(1,n,mxUINT32_CLASS,mxREAL);
keep=(uint*) mxGetPr(pl[0]); rleNms(dt,n,keep,thr);
rlesFree(&dt,n);
} else {
checkType(pr[0],mxDOUBLE_CLASS);
double *dt=mxGetPr(pr[0]); n=mxGetN(pr[0]);
pl[0]=mxCreateNumericMatrix(1,n,mxUINT32_CLASS,mxREAL);
keep=(uint*) mxGetPr(pl[0]); bbNms(dt,n,keep,thr);
}

} else if(!strcmp(action,"toBbox")) {
R=frMxArray(pr[0],&n,0);
pl[0]=mxCreateNumericMatrix(4,n,mxDOUBLE_CLASS,mxREAL);
Expand Down
Binary file modified MatlabAPI/private/maskApiMex.mexa64
Binary file not shown.
Binary file modified MatlabAPI/private/maskApiMex.mexmaci64
Binary file not shown.
22 changes: 22 additions & 0 deletions common/maskApi.c
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@ void rleIou( RLE *dt, RLE *gt, siz m, siz n, byte *iscrowd, double *o ) {
}
}

void rleNms( RLE *dt, siz n, uint *keep, double thr ) {
siz i, j; double u;
for( i=0; i<n; i++ ) keep[i]=1;
for( i=0; i<n; i++ ) if(keep[i]) {
for( j=i+1; j<n; j++ ) if(keep[j]) {
rleIou(dt+i,dt+j,1,1,0,&u);
if(u>thr) keep[j]=0;
}
}
}

void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o ) {
double h, w, i, u, ga, da; siz g, d; int crowd;
for( g=0; g<n; g++ ) {
Expand All @@ -108,6 +119,17 @@ void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o ) {
}
}

void bbNms( BB dt, siz n, uint *keep, double thr ) {
siz i, j; double u;
for( i=0; i<n; i++ ) keep[i]=1;
for( i=0; i<n; i++ ) if(keep[i]) {
for( j=i+1; j<n; j++ ) if(keep[j]) {
bbIou(dt+i*4,dt+j*4,1,1,0,&u);
if(u>thr) keep[j]=0;
}
}
}

void rleToBbox( const RLE *R, BB bb, siz n ) {
siz i; for( i=0; i<n; i++ ) {
uint h, w, x, y, xs, ys, xe, ye, cc, t; siz j, m;
Expand Down
6 changes: 6 additions & 0 deletions common/maskApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,15 @@ void rleArea( const RLE *R, siz n, uint *a );
/* Compute intersection over union between masks. */
void rleIou( RLE *dt, RLE *gt, siz m, siz n, byte *iscrowd, double *o );

/* Compute non-maximum suppression between bounding masks */
void rleNms( RLE *dt, siz n, uint *keep, double thr );

/* Compute intersection over union between bounding boxes. */
void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o );

/* Compute non-maximum suppression between bounding boxes */
void bbNms( BB dt, siz n, uint *keep, double thr );

/* Get bounding boxes surrounding encoded masks. */
void rleToBbox( const RLE *R, BB bb, siz n );

Expand Down

0 comments on commit 6e41fb5

Please sign in to comment.