
Forward Propagation Through Time (FPTT) on Spiking Neural Networks (SNNs)


byin-cwi/sFPTT


Training spiking neural networks with Forward Propagation Through Time (FPTT)


This repository contains code to reproduce the key findings of "Training spiking neural networks with Forward Propagation Through Time". It implements spiking recurrent networks with Liquid Time-Constant (LTC) spiking neurons in PyTorch for various tasks. This is scientific software and, as such, subject to many modifications; we aim to make it more user-friendly and extensible in the future.
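The following is a rough illustration of what "liquid time-constant" means for a spiking neuron. It is a hypothetical PyTorch sketch in which the membrane decay factor is computed from the current input and the recurrent spikes instead of being a fixed constant. Several details are assumptions and are not taken from the repository's LTC implementation: the class names `SpikeFn` and `LTCSpikingCell`, the sigmoid gate, the soft reset, and the rectangular surrogate gradient.

```python
import torch
import torch.nn as nn


class SpikeFn(torch.autograd.Function):
    """Heaviside spike in the forward pass, rectangular surrogate gradient backward (assumed)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return (x >= 0).float()  # spike wherever v - threshold >= 0

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * (x.abs() < 0.5).float()  # pass gradient only near the threshold


class LTCSpikingCell(nn.Module):
    """Hypothetical spiking layer whose membrane decay depends on its input."""

    def __init__(self, in_dim, hid_dim, threshold=1.0):
        super().__init__()
        self.syn = nn.Linear(in_dim + hid_dim, hid_dim)   # input + recurrent spikes -> current
        self.gate = nn.Linear(in_dim + hid_dim, hid_dim)  # assumed gate producing the decay
        self.threshold = threshold

    def forward(self, x_t, v, s):
        z = torch.cat([x_t, s], dim=-1)
        decay = torch.sigmoid(self.gate(z))            # in (0, 1): input-dependent time constant
        v = decay * v + (1.0 - decay) * self.syn(z)    # leaky integration of the input current
        s = SpikeFn.apply(v - self.threshold)          # binary spikes
        v = v - s * self.threshold                     # soft reset after a spike
        return v, s
```

Calling the cell once per timestep and carrying `(v, s)` forward gives a recurrent SNN. In a full model, a readout on the spikes or membrane potentials would produce the per-step output used for training.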

Datasets


  1. S/P-MNIST, R-MNIST: These datasets are built from MNIST, which is available via torchvision.datasets.MNIST.
  2. Fashion-MNIST: This dataset is available via torchvision.datasets.FashionMNIST.
  3. DVS datasets: SpikingJelly includes the neuromorphic datasets Gesture128-DVS and Cifar10-DVS. You can also download them from their official sites. Our preprocessing of the DVS datasets is also supported in SpikingJelly.
  4. PASCAL VOC: The PASCAL Visual Object Classes (VOC) dataset contains 20 object categories. Each image has pixel-level segmentation annotations, bounding box annotations, and object class annotations, and the dataset is widely used as a benchmark for object detection, semantic segmentation, and classification. In this paper, the SPiking-YOLO (SPYv4) network was trained and tested on VOC07+12.
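S-MNIST is commonly built by presenting the 784 pixels of each 28x28 MNIST image one at a time. P-MNIST applies a single fixed random permutation to that pixel order. The helper below is a hypothetical sketch of that conversion. It assumes the (batch, features, time) layout that the FPTT pseudocode indexes as `x_in[:, :, i]`, and the repository's own data loaders may differ. The name `mnist_to_sequence` and the seed are illustrative.

```python
import torch


def mnist_to_sequence(images, perm=None):
    """Hypothetical helper: flatten MNIST images into (batch, 1, 784) pixel sequences.

    Accepts (batch, 28, 28) or (batch, 1, 28, 28) tensors; the last axis becomes time.
    """
    seq = images.reshape(images.shape[0], 1, -1)  # row-major pixel order (S-MNIST)
    if perm is not None:
        seq = seq[:, :, perm]  # one fixed permutation shared by all images (P-MNIST)
    return seq


# Draw the permutation once and reuse it for both training and testing.
PERM = torch.randperm(784, generator=torch.Generator().manual_seed(0))
```

For example, `x_in = mnist_to_sequence(images, PERM)` yields `sequence_len = x_in.shape[2] = 784`.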

Requirements


  1. Python 3.8.10
  2. A working installation of PyTorch (torch==1.7.1). This should be easy: either use the Google Colab facilities, or do a simple installation on your laptop, probably using pip (see "Start Locally | PyTorch").
  3. SpikingJelly
  4. For object detection tasks: OpenCV 2
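The short script below checks an environment against the list above before you run any task. It is a hypothetical helper and not part of the repository. It assumes that OpenCV is imported as `cv2`, which is the module the `opencv-python` pip package provides. It treats the listed versions as targets to compare against, not as hard failures.

```python
import importlib.util
import sys

REQUIRED_MODULES = ["torch", "spikingjelly"]
OPTIONAL_MODULES = ["cv2"]  # assumed import name for OpenCV; only needed for object detection


def missing_modules(names):
    """Return the module names that cannot be found in the current environment."""
    return [n for n in names if importlib.util.find_spec(n) is None]


def check_environment():
    """Collect Python/PyTorch versions and missing packages into a dict."""
    report = {
        "python": "%d.%d.%d" % sys.version_info[:3],
        "python_matches_3_8": sys.version_info[:2] == (3, 8),
        "missing_required": missing_modules(REQUIRED_MODULES),
        "missing_optional": missing_modules(OPTIONAL_MODULES),
    }
    if "torch" not in report["missing_required"]:
        import torch  # imported lazily so the check works without PyTorch installed
        report["torch"] = torch.__version__
        report["torch_matches_1_7_1"] = torch.__version__.startswith("1.7.1")
    return report
```

Running `print(check_environment())` shows at a glance which packages still need installing.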

FPTT pseudocode


for i in range(sequence_len):  # read the sequence one timestep at a time
    if i == 0:
        model.init_h(x_in.shape[0])  # at the first step, initialize the hidden states
    else:
        model.h = [v.detach() for v in model.h]  # detach the computation graph from the previous timestep
    out = model.forward_t(x_in[:, :, i])  # read the input at step i and generate the output
    loss_c = i / sequence_len * criterion(out, targets)  # time-weighted prediction loss
    loss_r = get_regularizer_named_params(named_params, _lambda=1.0)  # regularizer loss
    loss = loss_c + loss_r
    optimizer.zero_grad()
    loss.backward()  # compute the gradient for the current timestep only
    optimizer.step()  # update the network
    post_optimizer_updates(named_params, epoch)  # update the traces \bar{w} and \delta{l}
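The pseudocode calls two helpers, `get_regularizer_named_params` and `post_optimizer_updates`, whose bodies are not shown here. The following is a hedged, self-contained sketch of one way to implement them, together with a runnable version of the loop above. It assumes the following:

- Each parameter w keeps a running average w_bar (the \bar{w} trace) and a trace lm of past loss gradients.
- The penalty is -0.5 * <w, lm> + (alpha/2) * ||w - w_bar||^2.
- After each step the traces are updated as lm <- lm - alpha * (w - w_bar) and w_bar <- (1 - beta) * w_bar + beta * w - (beta/alpha) * lm, with alpha = beta = 0.5.

The constants, the helper `init_named_params`, the stand-in model `ToyRNN` and the function `fptt_train_sequence` are all hypothetical. This is not the repository's implementation.

```python
import torch
import torch.nn as nn


def init_named_params(model):
    """Hypothetical: pair every trainable parameter with its FPTT state (w_bar, lm)."""
    named_params = {}
    for name, p in model.named_parameters():
        if p.requires_grad:
            sm = p.detach().clone()   # running average w_bar (\bar{w}), starts at the initial weights
            lm = torch.zeros_like(p)  # trace of past loss gradients
            named_params[name] = (p, sm, lm)
    return named_params


def get_regularizer_named_params(named_params, alpha=0.5, _lambda=1.0):
    """Assumed penalty: -0.5 * <w, lm> + _lambda * (alpha / 2) * ||w - w_bar||^2."""
    reg = torch.zeros(())
    for p, sm, lm in named_params.values():
        reg = reg - 0.5 * torch.sum(p * lm)
        reg = reg + _lambda * 0.5 * alpha * torch.sum((p - sm) ** 2)
    return reg


@torch.no_grad()
def post_optimizer_updates(named_params, epoch, alpha=0.5, beta=0.5):
    """Update the traces after each optimizer step (epoch is unused in this sketch)."""
    for p, sm, lm in named_params.values():
        lm.add_(-alpha * (p - sm))                               # lm <- lm - alpha * (w - w_bar)
        sm.mul_(1.0 - beta).add_(beta * p - (beta / alpha) * lm)  # w_bar <- (1-beta) w_bar + beta w - (beta/alpha) lm


class ToyRNN(nn.Module):
    """Hypothetical stand-in exposing init_h / h / forward_t like the pseudocode's model."""

    def __init__(self, in_dim, hid_dim, n_classes):
        super().__init__()
        self.cell = nn.RNNCell(in_dim, hid_dim)
        self.readout = nn.Linear(hid_dim, n_classes)
        self.hid_dim = hid_dim
        self.h = []

    def init_h(self, batch_size):
        self.h = [torch.zeros(batch_size, self.hid_dim)]

    def forward_t(self, x_t):
        self.h = [self.cell(x_t, self.h[0])]
        return self.readout(self.h[0])


def fptt_train_sequence(model, x_in, targets, optimizer, criterion, named_params, epoch):
    """Runnable version of the loop above; x_in has shape (batch, features, time)."""
    sequence_len = x_in.shape[2]
    for i in range(sequence_len):
        if i == 0:
            model.init_h(x_in.shape[0])
        else:
            model.h = [v.detach() for v in model.h]  # truncate the graph at every step
        out = model.forward_t(x_in[:, :, i])
        loss_c = i / sequence_len * criterion(out, targets)
        loss_r = get_regularizer_named_params(named_params, _lambda=1.0)
        loss = loss_c + loss_r
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        post_optimizer_updates(named_params, epoch)
    return loss.item()
```

Because the hidden state is detached at every step, each `backward()` spans only one timestep, so memory does not grow with `sequence_len`. The regularizer and traces are what link the per-step updates across time.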

Object detection Demo


A video demo of SPiking-YOLO (SPYv4):

SPYv4

Running code


More details can be found in the README file of each task:

  1. Adding task
  2. P/S-MNIST task
  3. Image and DVS task
  4. Spiking YOLO Demo

Finally, we’d love to hear from you if you have any comments or suggestions.

References


[1]. https://github.com/bubbliiiing/yolov4-tiny-pytorch

License

MIT
