forked from Shikhargupta/Spiking-Neural-Network
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c700e94
commit 41d104f
Showing
8 changed files
with
454 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
|
||
|
||
####################################################### README #################################################################### | ||
|
||
# This is the main file which calls all the functions and trains the network by updating weights | ||
|
||
|
||
##################################################################################################################################### | ||
|
||
|
||
import numpy as np | ||
from neuron import neuron | ||
import random | ||
from matplotlib import pyplot as plt | ||
from recep_field import rf | ||
import cv2 | ||
from spike_train import encode | ||
from rl import rl | ||
from rl import update | ||
from reconstruct import reconst_weights | ||
from parameters import param as par | ||
from var_th import threshold | ||
import os | ||
|
||
#potentials of output neurons | ||
pot_arrays = [] | ||
for i in range(par.n): | ||
pot_arrays.append([]) | ||
|
||
#time series | ||
time = np.arange(1, par.T+1, 1) | ||
|
||
layer2 = [] | ||
|
||
# creating the hidden layer of neurons | ||
for i in range(par.n): | ||
a = neuron() | ||
layer2.append(a) | ||
|
||
#synapse matrix initialization | ||
synapse = np.zeros((par.n,par.m)) | ||
|
||
for i in range(par.n): | ||
for j in range(par.m): | ||
synapse[i][j] = random.uniform(0,0.4*par.scale) | ||
|
||
|
||
for k in range(par.epoch): | ||
for i in range(322,323): | ||
print i," ",k | ||
img = cv2.imread("mnist1/" + str(i) + ".png", 0) | ||
|
||
#Convolving image with receptive field | ||
pot = rf(img) | ||
|
||
#Generating spike train | ||
train = np.array(encode(pot)) | ||
|
||
#calculating threshold value for the image | ||
var_threshold = threshold(train) | ||
|
||
# print var_threshold | ||
# synapse_act = np.zeros((par.n,par.m)) | ||
# var_threshold = 9 | ||
# print var_threshold | ||
# var_D = (var_threshold*3)*0.07 | ||
|
||
var_D = 0.15*par.scale | ||
|
||
for x in layer2: | ||
x.initial(var_threshold) | ||
|
||
#flag for lateral inhibition | ||
f_spike = 0 | ||
|
||
img_win = 100 | ||
|
||
active_pot = [] | ||
for index1 in range(par.n): | ||
active_pot.append(0) | ||
|
||
#Leaky integrate and fire neuron dynamics | ||
for t in time: | ||
for j, x in enumerate(layer2): | ||
active = [] | ||
if(x.t_rest<t): | ||
x.P = x.P + np.dot(synapse[j], train[:,t]) | ||
if(x.P>par.Prest): | ||
x.P -= var_D | ||
active_pot[j] = x.P | ||
|
||
pot_arrays[j].append(x.P) | ||
|
||
# Lateral Inhibition | ||
if(f_spike==0): | ||
high_pot = max(active_pot) | ||
if(high_pot>var_threshold): | ||
f_spike = 1 | ||
winner = np.argmax(active_pot) | ||
img_win = winner | ||
print "winner is " + str(winner) | ||
for s in range(par.n): | ||
if(s!=winner): | ||
layer2[s].P = par.Pmin | ||
|
||
#Check for spikes and update weights | ||
for j,x in enumerate(layer2): | ||
s = x.check() | ||
if(s==1): | ||
x.t_rest = t + x.t_ref | ||
x.P = par.Prest | ||
for h in range(par.m): | ||
for t1 in range(-2,par.t_back-1, -1): | ||
if 0<=t+t1<par.T+1: | ||
if train[h][t+t1] == 1: | ||
# print "weight change by" + str(update(synapse[j][h], rl(t1))) | ||
synapse[j][h] = update(synapse[j][h], rl(t1)) | ||
|
||
|
||
|
||
|
||
for t1 in range(2,par.t_fore+1, 1): | ||
if 0<=t+t1<par.T+1: | ||
if train[h][t+t1] == 1: | ||
# print "weight change by" + str(update(synapse[j][h], rl(t1))) | ||
synapse[j][h] = update(synapse[j][h], rl(t1)) | ||
|
||
if(img_win!=100): | ||
for p in range(par.m): | ||
if sum(train[p])==0: | ||
synapse[img_win][p] -= 0.06*par.scale | ||
if(synapse[img_win][p]<par.w_min): | ||
synapse[img_win][p] = par.w_min | ||
|
||
|
||
ttt = np.arange(0,len(pot_arrays[0]),1) | ||
Pth = [] | ||
for i in range(len(ttt)): | ||
Pth.append(layer2[0].Pth) | ||
|
||
#plotting | ||
for i in range(par.n): | ||
axes = plt.gca() | ||
axes.set_ylim([-20,50]) | ||
plt.plot(ttt,Pth, 'r' ) | ||
plt.plot(ttt,pot_arrays[i]) | ||
plt.show() | ||
|
||
#Reconstructing weights to analyse training | ||
for i in range(par.n): | ||
reconst_weights(synapse[i],i+1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
############################################################ README ############################################################## | ||
|
||
# This is neuron class which defines the dynamics of a neuron. All the parameters are initialised and methods are included to check | ||
# for spikes and apply lateral inhibition. | ||
|
||
################################################################################################################################### | ||
|
||
import numpy as np | ||
import random | ||
from matplotlib import pyplot as plt | ||
from parameters import param as par | ||
|
||
class neuron: | ||
def __init__(self): | ||
self.t_ref = 30 | ||
self.t_rest = -1 | ||
self.P = par.Prest | ||
def check(self): | ||
if self.P>= self.Pth: | ||
self.P = par.Prest | ||
return 1 | ||
elif self.P < par.Pmin: | ||
self.P = par.Prest | ||
return 0 | ||
else: | ||
return 0 | ||
def inhibit(self): | ||
self.P = par.Pmin | ||
def initial(self, th): | ||
self.Pth = th | ||
self.t_rest = -1 | ||
self.P = par.Prest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
################################################ README ######################################################### | ||
|
||
# This file contains all the parameters of the network. | ||
|
||
################################################################################################################# | ||
|
||
class param: | ||
scale = 1 | ||
T = 200 | ||
t_back = -20 | ||
t_fore = 20 | ||
|
||
pixel_x = 28 | ||
Prest = 0 | ||
m = pixel_x*pixel_x #Number of neurons in first layer | ||
n = 3 #Number of neurons in second layer | ||
Pmin = -500*scale | ||
# Pth = 5 | ||
# D = 0.7 | ||
w_max = 1.5*scale | ||
w_min = -1.2*scale | ||
sigma = 0.1 #0.02 | ||
A_plus = 0.8 # time difference is positive i.e negative reinforcement | ||
A_minus = 0.3 # 0.01 # time difference is negative i.e positive reinforcement | ||
tau_plus = 8 | ||
tau_minus = 5 | ||
|
||
epoch = 12 | ||
|
||
|
||
fr_bits = 12 | ||
int_bits = 12 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
####################################################### README ######################################################### | ||
|
||
# This file consists of function that convolves an image with a receptive field so that input to the network is | ||
# close to the form perceived by our eyes. | ||
|
||
######################################################################################################################### | ||
|
||
|
||
import numpy as np | ||
import cv2 | ||
from parameters import param as par | ||
|
||
def rf(inp): | ||
sca1 = 0.625 | ||
sca2 = 0.125 | ||
sca3 = -0.125 | ||
sca4 = -.5 | ||
|
||
#Receptive field kernel | ||
w = [[ sca4 ,sca3 , sca2 ,sca3 ,sca4], | ||
[ sca3 ,sca2 , sca1 ,sca2 ,sca3], | ||
[ sca2 ,sca1 , 1 ,sca1 ,sca2], | ||
[ sca3 ,sca2 , sca1 ,sca2 ,sca3], | ||
[ sca4 ,sca3 , sca2 ,sca3 ,sca4]] | ||
|
||
pot = np.zeros([par.pixel_x,par.pixel_x]) | ||
ran = [-2,-1,0,1,2] | ||
ox = 2 | ||
oy = 2 | ||
|
||
#Convolution | ||
for i in range(par.pixel_x): | ||
for j in range(par.pixel_x): | ||
summ = 0 | ||
for m in ran: | ||
for n in ran: | ||
if (i+m)>=0 and (i+m)<=par.pixel_x-1 and (j+n)>=0 and (j+n)<=par.pixel_x-1: | ||
summ = summ + w[ox+m][oy+n]*inp[i+m][j+n]/255 | ||
pot[i][j] = summ | ||
return pot | ||
|
||
if __name__ == '__main__': | ||
|
||
img = cv2.imread("mnist1/" + str(1) + ".png", 0) | ||
pot = rf(img) | ||
max_a = [] | ||
min_a = [] | ||
for i in pot: | ||
max_a.append(max(i)) | ||
min_a.append(min(i)) | ||
print "max", max(max_a) | ||
print "min", min(min_a) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
###################################################### README ##################################################### | ||
|
||
# This file is used to leverage the generative property of a Spiking Neural Network. reconst_weights function is used | ||
# for that purpose. Looking at the reconstructed images helps to analyse training process. | ||
|
||
#################################################################################################################### | ||
|
||
|
||
import numpy as np | ||
from numpy import interp | ||
import cv2 | ||
from recep_field import rf | ||
from parameters import param as par | ||
|
||
|
||
def reconst_weights(weights, num): | ||
weights = np.array(weights) | ||
weights = np.reshape(weights, (par.pixel_x,par.pixel_x)) | ||
img = np.zeros((par.pixel_x,par.pixel_x)) | ||
for i in range(par.pixel_x): | ||
for j in range(par.pixel_x): | ||
img[i][j] = int(interp(weights[i][j], [par.w_min,par.w_max], [0,255])) | ||
|
||
cv2.imwrite('neuron' + str(num) + '.png' ,img) | ||
return img | ||
|
||
def reconst_rf(weights, num): | ||
weights = np.array(weights) | ||
weights = np.reshape(weights, (par.pixel_x,par.pixel_x)) | ||
img = np.zeros((par.pixel_x,par.pixel_x)) | ||
for i in range(par.pixel_x): | ||
for j in range(par.pixel_x): | ||
img[i][j] = int(interp(weights[i][j], [-2,3.625], [0,255])) | ||
|
||
cv2.imwrite('neuron' + str(num) + '.png' ,img) | ||
return img | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
img = cv2.imread("images2/" + "69" + ".png", 0) | ||
pot = rf(img) | ||
reconst_rf(pot, 12) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
########################################################## README ########################################################### | ||
|
||
# This file implements STDP curve and weight update rule | ||
|
||
############################################################################################################################## | ||
|
||
|
||
|
||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
from parameters import param as par | ||
|
||
#STDP reinforcement learning curve | ||
def rl(t): | ||
|
||
if t>0: | ||
return -par.A_plus*np.exp(-float(t)/par.tau_plus) | ||
if t<=0: | ||
return par.A_minus*np.exp(float(t)/par.tau_minus) | ||
|
||
|
||
#STDP weight update rule | ||
def update(w, del_w): | ||
if del_w<0: | ||
return w + par.sigma*del_w*(w-abs(par.w_min))*par.scale | ||
elif del_w>0: | ||
return w + par.sigma*del_w*(par.w_max-w)*par.scale | ||
|
||
if __name__ == '__main__': | ||
|
||
print rl(-20)*par.sigma | ||
|
||
|
Oops, something went wrong.