speed up nnenc layer
richzhang committed Aug 10, 2016
1 parent 8dce25f commit b30ee0a
Showing 1 changed file with 31 additions and 39 deletions.
resources/caffe_traininglayers.py
@@ -14,7 +14,7 @@
 class BGR2LabLayer(caffe.Layer):
     ''' Layer converts BGR to Lab
     INPUTS
-        bottom[0]        Nx3xXxY
+        bottom[0].data   Nx3xXxY
     OUTPUTS
         top[0].data      Nx3xXxY
     '''
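
For orientation, the conversion this layer's docstring describes can be sketched outside Caffe with skimage. This is a hedged illustration, not the layer's actual forward code; the Nx3xXxY blob layout and BGR channel order are taken from the docstring above, and bgr_blob_to_lab is a name invented here:

    import numpy as np
    from skimage import color

    def bgr_blob_to_lab(blob_bgr):
        # blob_bgr: Nx3xXxY, BGR channel order, float values in [0,1]
        rgb = blob_bgr[:, ::-1, :, :]             # BGR -> RGB
        rgb = rgb.transpose((0, 2, 3, 1))         # Nx3xXxY -> NxXxYx3 for skimage
        lab = np.stack([color.rgb2lab(im) for im in rgb], axis=0)
        return lab.transpose((0, 3, 1, 2))        # back to Nx3xXxY
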
@@ -144,16 +144,16 @@ def backward(self, top, propagate_down, bottom):
             bottom[i].diff[...] = np.zeros_like(bottom[i].data)
 
 class ClassRebalanceMultLayer(caffe.Layer):
-    # '''
-    # INPUTS
-    #     bottom[0]   NxMxXxY     feature map
-    #     bottom[1]   Nx1xXxY     boost coefficients
-    # OUTPUTS
-    #     top[0]      NxMxXxY     on forward, gets copied from bottom[0]
-    # FUNCTIONALITY
-    #     On forward pass, top[0] passes bottom[0]
-    #     On backward pass, bottom[0] gets boosted by bottom[1]
-    #     through pointwise multiplication (with singleton expansion) '''
+    '''
+    INPUTS
+        bottom[0]   NxMxXxY     feature map
+        bottom[1]   Nx1xXxY     boost coefficients
+    OUTPUTS
+        top[0]      NxMxXxY     on forward, gets copied from bottom[0]
+    FUNCTIONALITY
+        On forward pass, top[0] passes bottom[0]
+        On backward pass, bottom[0] gets boosted by bottom[1]
+        through pointwise multiplication (with singleton expansion) '''
     def setup(self, bottom, top):
         # check input pair
         if len(bottom)==0:
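
The "pointwise multiplication (with singleton expansion)" in the new docstring is ordinary NumPy broadcasting; a standalone sketch of the backward rule, with illustrative shapes (this is not the layer's own code):

    import numpy as np

    N, M, X, Y = 2, 4, 3, 3
    top_diff = np.random.rand(N, M, X, Y)   # gradient arriving at top[0]
    boost    = np.random.rand(N, 1, X, Y)   # bottom[1] boost coefficients

    # the size-1 channel axis of boost expands across all M channels
    bottom0_diff = top_diff * boost
    assert bottom0_diff.shape == (N, M, X, Y)
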
@@ -185,12 +185,11 @@ def backward(self, top, propagate_down, bottom):
 # ***************************
 class PriorFactor():
     ''' Class handles prior factor '''
-    # def __init__(self,alpha,gamma=0,verbose=True,priorFile='/home/eecs/rich.zhang/src/projects/cross_domain/save/ab_grid_10/prior_probs.npy',genc=-1):
     def __init__(self,alpha,gamma=0,verbose=True,priorFile=''):
         # INPUTS
-        #   alpha       integer     prior correction factor, 0 to ignore prior, 1 to divide by prior, alpha to divide by prior^alpha power
-        #   gamma       integer     percentage to mix in prior probability
-        #   priorFile   file        file which contains prior probabilities across classes
+        #   alpha       integer     prior correction factor, 0 to ignore prior, 1 to divide by prior, alpha to divide by prior**alpha
+        #   gamma       integer     percentage to mix in uniform prior with empirical prior
+        #   priorFile   file        file which contains prior probabilities across classes
 
         # settings
         self.alpha = alpha
@@ -216,30 +215,15 @@ def __init__(self,alpha,gamma=0,verbose=True,priorFile=''):
         self.implied_prior = self.prior_probs*self.prior_factor
         self.implied_prior = self.implied_prior/np.sum(self.implied_prior) # re-normalize
 
-        # add this to the softmax score
-        # self.softmax_correction = np.log(self.prior_probs/self.implied_prior * (1-self.implied_prior)/(1-self.prior_probs))
-
         if(self.verbose):
             self.print_correction_stats()
 
-        # if(not check_value(genc,-1)):
-            # self.expand_grid(genc)
-
-    # def expand_grid(self,genc):
-        # self.prior_probs_full_grid = genc.enc_full_grid_mtx_nd(self.prior_probs,axis=0,returnGrid=True)
-        # self.uni_probs_full_grid = genc.enc_full_grid_mtx_nd(self.uni_probs,axis=0,returnGrid=True)
-        # self.prior_mix_full_grid = genc.enc_full_grid_mtx_nd(self.prior_mix,axis=0,returnGrid=True)
-        # self.prior_factor_full_grid = genc.enc_full_grid_mtx_nd(self.prior_factor,axis=0,returnGrid=True)
-        # self.implied_prior_full_grid = genc.enc_full_grid_mtx_nd(self.implied_prior,axis=0,returnGrid=True)
-        # self.softmax_correction_full_grid = genc.enc_full_grid_mtx_nd(self.softmax_correction,axis=0,returnGrid=True)
-
     def print_correction_stats(self):
         print 'Prior factor correction:'
         print '  (alpha,gamma) = (%.2f, %.2f)'%(self.alpha,self.gamma)
         print '  (min,max,mean,med,exp) = (%.2f, %.2f, %.2f, %.2f, %.2f)'%(np.min(self.prior_factor),np.max(self.prior_factor),np.mean(self.prior_factor),np.median(self.prior_factor),np.sum(self.prior_factor*self.prior_probs))
 
     def forward(self,data_ab_quant,axis=1):
-        # data_ab_quant = net.blobs['data_ab_quant_map_233'].data[...]
         data_ab_maxind = np.argmax(data_ab_quant,axis=axis)
         corr_factor = self.prior_factor[data_ab_maxind]
         if(axis==0):
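
Reading the alpha/gamma comments together: the factor mixes gamma of a uniform distribution into the empirical class prior, raises the mix to the -alpha power, and renormalizes so the expected boost under the prior is 1 (the "exp" statistic printed above). The class's real computation lives in the collapsed lines; this is a minimal standalone sketch under that reading, assuming all prior entries are positive (prior_boost is a name invented here):

    import numpy as np

    def prior_boost(prior_probs, alpha=1., gamma=.5):
        # gamma: mix uniform prior with empirical prior
        uni_probs = np.ones_like(prior_probs)/prior_probs.size
        prior_mix = (1-gamma)*prior_probs + gamma*uni_probs
        # alpha: divide by prior_mix**alpha
        prior_factor = prior_mix**-alpha
        # renormalize so np.sum(prior_probs*prior_factor) == 1
        return prior_factor/np.sum(prior_probs*prior_factor)
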
@@ -259,23 +243,32 @@ def __init__(self,NN,sigma,km_filepath='',cc=-1):
         else:
             self.cc = cc
         self.K = self.cc.shape[0]
-        # self.NN = NN
+        self.NN = int(NN)
         self.sigma = sigma
         self.nbrs = nn.NearestNeighbors(n_neighbors=NN, algorithm='ball_tree').fit(self.cc)
 
-    def encode_points_mtx_nd(self,pts_nd,axis=1,returnSparse=False):
+        self.alreadyUsed = False
+
+    def encode_points_mtx_nd(self,pts_nd,axis=1,returnSparse=False,sameBlock=True):
+        t = rz.Timer();
         pts_flt = flatten_nd_array(pts_nd,axis=axis)
+        P = pts_flt.shape[0]
+        if(sameBlock and self.alreadyUsed):
+            self.pts_enc_flt[...] = 0 # already pre-allocated
+        else:
+            self.alreadyUsed = True
+            self.pts_enc_flt = np.zeros((P,self.K))
+            self.p_inds = np.arange(0,P,dtype='int')[:,na()]
 
         P = pts_flt.shape[0]
 
         (dists,inds) = self.nbrs.kneighbors(pts_flt)
 
-        pts_enc_flt = np.zeros((P,self.K))
         wts = np.exp(-dists**2/(2*self.sigma**2))
         wts = wts/np.sum(wts,axis=1)[:,na()]
 
-        pts_enc_flt[np.arange(0,P,dtype='int')[:,na()],inds] = wts
-        pts_enc_nd = unflatten_2d_array(pts_enc_flt,pts_nd,axis=axis)
+        self.pts_enc_flt[self.p_inds,inds] = wts
+        pts_enc_nd = unflatten_2d_array(self.pts_enc_flt,pts_nd,axis=axis)
 
         return pts_enc_nd
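
The speedup in this hunk is allocation reuse: instead of building a fresh PxK matrix and a fresh row-index array on every forward pass, the layer now pre-allocates self.pts_enc_flt and self.p_inds once and just zeroes the buffer in place whenever consecutive calls see the same block shape (sameBlock). The encoding itself is unchanged; a self-contained toy version of that soft encoding, with illustrative sizes (not the layer's code):

    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    P, K, NN, sigma = 4, 10, 5, 0.25
    cc = np.random.rand(K, 2)              # cluster centers (ab bins in the real layer)
    nbrs = NearestNeighbors(n_neighbors=NN).fit(cc)

    pts = np.random.rand(P, 2)
    dists, inds = nbrs.kneighbors(pts)     # NN nearest centers per point
    wts = np.exp(-dists**2/(2*sigma**2))   # Gaussian kernel on NN distances
    wts = wts/np.sum(wts, axis=1)[:, np.newaxis]

    enc = np.zeros((P, K))
    enc[np.arange(P)[:, np.newaxis], inds] = wts   # each row sums to 1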

@@ -293,7 +286,6 @@ def decode_1hot_mtx_nd(self,pts_enc_nd,axis=1,returnEncode=False):
         else:
             return pts_dec_nd
 
-
 # *****************************
 # ***** Utility functions *****
 # *****************************
@@ -311,7 +303,7 @@ def na(): # shorthand for new axis
 def flatten_nd_array(pts_nd,axis=1):
     ''' Flatten an nd array into a 2d array with a certain axis
     INPUTS
-        pts_nd      N0xN1x...xNd array
+        pts_nd      N0xN1x...xNd array
         axis        integer
     OUTPUTS
         pts_flt     prod(N \ N_axis) x N_axis array '''
@@ -328,8 +320,8 @@ def unflatten_2d_array(pts_flt,pts_nd,axis=1,squeeze=False):
     ''' Unflatten a 2d array with a certain axis
     INPUTS
         pts_flt     prod(N \ N_axis) x M array
-        pts_nd      N0xN1x...xNd array
-        axis        integer
+        pts_nd      N0xN1x...xNd array
+        axis        integer
         squeeze     bool        if true, M=1, squeeze it out
     OUTPUTS
         pts_out     N0xN1x...xNd array '''
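
A quick round trip through these two helpers as documented (shapes are illustrative; this assumes the functions behave exactly as their docstrings state):

    import numpy as np

    pts_nd = np.random.rand(2, 3, 4, 5)          # N0 x N1 x N2 x N3
    pts_flt = flatten_nd_array(pts_nd, axis=1)   # axis 1 becomes the columns
    assert pts_flt.shape == (2*4*5, 3)

    pts_out = unflatten_2d_array(pts_flt, pts_nd, axis=1)
    assert pts_out.shape == pts_nd.shape         # round trip restores the layout
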
