This repository contains code to reproduce the key findings of "Training spiking neural networks with Forward Propagation Through Time". The code implements spiking recurrent networks with Liquid Time-Constant (LTC) spiking neurons in PyTorch for various tasks. This is scientific software, and as such subject to many modifications; we aim to further improve the software to become more user-friendly and extensible in the future.
- S/P-MNIST, R-MNIST: these tasks use the MNIST dataset, which is available via torchvision.datasets.MNIST.
- Fashion-MNIST: this dataset is available via torchvision.datasets.FashionMNIST.
- DVS datasets: SpikingJelly includes the neuromorphic datasets Gesture128-DVS and Cifar10-DVS. You can also download the datasets from their official sites. Our preprocessing of the DVS datasets is also supported by SpikingJelly.
- PASCAL Visual Object Classes (VOC) dataset: contains 20 object categories. Each image in this dataset has pixel-level segmentation annotations, bounding box annotations, and object class annotations. This dataset has been widely used as a benchmark for object detection, semantic segmentation, and classification tasks. In the paper, the SPiking-YOLO (SPYv4) network was trained and tested on VOC07+12.
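As a rough illustration of the MNIST-based tasks listed above, the sketch below loads MNIST or Fashion-MNIST through torchvision and turns each image into a pixel-by-pixel sequence, optionally with a fixed permutation as in P-MNIST. The helper names `load_mnist`, `to_pixel_sequence`, and `fixed_permutation` are hypothetical and not part of this repository, and the `(batch, 1, 784)` layout is an assumption chosen to match the `x_in[:,:,i]` indexing used in the training loop; the repository's own data pipeline may differ.

```python
import torch


def load_mnist(root="./data", train=True, fashion=False):
    """Return MNIST or Fashion-MNIST as a torchvision dataset (downloads on first use)."""
    # Imported lazily so the sequence helpers below only need torch.
    from torchvision import datasets, transforms

    dataset_cls = datasets.FashionMNIST if fashion else datasets.MNIST
    return dataset_cls(root=root, train=train, download=True,
                       transform=transforms.ToTensor())


def fixed_permutation(length=784, seed=0):
    """Return a reproducible permutation of pixel positions (the seed is an assumption)."""
    generator = torch.Generator().manual_seed(seed)
    return torch.randperm(length, generator=generator)


def to_pixel_sequence(images, permutation=None):
    """Flatten (batch, 1, 28, 28) images into (batch, 1, 784) pixel sequences.

    If a permutation is given, the same pixel reordering is applied to every image.
    """
    seq = images.reshape(images.shape[0], 1, -1)
    if permutation is not None:
        seq = seq[:, :, permutation]
    return seq
```

With a `torch.utils.data.DataLoader` wrapped around `load_mnist()`, each batch of images could be passed through `to_pixel_sequence` before being fed to the network one timestep at a time.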
- Python 3.8.10
- A working installation of Python and PyTorch (torch==1.7.1). This should be easy: either use Google Colab, or do a simple local installation on your laptop, for example using pip (see "Start Locally" on the PyTorch website).
- SpikingJelly
- For the object detection tasks, OpenCV 2 is required.
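One way to check an environment against the requirements above is to query the installed package versions from Python. This is a minimal sketch using only the standard library; the distribution names in `REQUIRED` (`torch`, `spikingjelly`, `opencv-python`) are assumed PyPI names, and `installed_versions` is a hypothetical helper, not something shipped with this repository.

```python
import sys
from importlib import metadata

# Assumed PyPI distribution names for the requirements listed above.
REQUIRED = ("torch", "spikingjelly", "opencv-python")


def installed_versions(packages=REQUIRED):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    print("python", sys.version.split()[0])
    for name, version in installed_versions().items():
        print(name, version or "not installed")
```

Running the script prints the interpreter version followed by one line per package, which can then be compared with the versions listed above.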
for i in range(sequence_len): # read the sequence
    if i == 0:
        model.init_h(x_in.shape[0]) # At first step initialize the hidden states
    else:
        model.h = list(v.detach() for v in model.h) # detach computation graph from previous timestep

    out = model.forward_t(x_in[:,:,i]) # read input and generate output
    loss_c = (i)/sequence_len*criterion(out, targets) # get prediction loss
    loss_r = get_regularizer_named_params(named_params, _lambda=1.0 ) # get regularizer loss
    loss = loss_c+loss_r
    optimizer.zero_grad()
    loss.backward() # calculate gradient of current timestep
    optimizer.step() # update the network
    post_optimizer_updates( named_params, epoch) # update trace \bar{w} and \delta{l}
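To make the structure of the loop above easier to try out, here is a self-contained toy version. It is only a sketch under stated assumptions: `ToyRecurrentCell` is a hypothetical stand-in built on `nn.RNNCell` (not an LTC spiking neuron), and the regularizer and trace updates (`get_regularizer_named_params`, `post_optimizer_updates`) are deliberately left out because their implementation is specific to this repository. What it does show is the per-timestep update: the hidden state is detached, a time-weighted loss is computed, and the optimizer steps once per timestep.

```python
import torch
import torch.nn as nn


class ToyRecurrentCell(nn.Module):
    """Hypothetical stand-in for the repository's model (a plain RNN cell, not an LTC neuron)."""

    def __init__(self, n_in=1, n_hidden=16, n_out=10):
        super().__init__()
        self.n_hidden = n_hidden
        self.cell = nn.RNNCell(n_in, n_hidden)
        self.readout = nn.Linear(n_hidden, n_out)
        self.h = None

    def init_h(self, batch_size):
        # Hidden states are kept in a list, mirroring the `model.h` usage in the loop above.
        self.h = [torch.zeros(batch_size, self.n_hidden)]

    def forward_t(self, x_t):
        # Process a single timestep and return class scores.
        self.h = [self.cell(x_t, self.h[0])]
        return self.readout(self.h[0])


def train_on_sequence(model, optimizer, criterion, x_in, targets):
    """Run one sequence with a parameter update at every timestep (regularizer omitted)."""
    sequence_len = x_in.shape[2]
    losses = []
    for i in range(sequence_len):
        if i == 0:
            model.init_h(x_in.shape[0])
        else:
            # Cut the graph so gradients only flow through the current timestep.
            model.h = [v.detach() for v in model.h]
        out = model.forward_t(x_in[:, :, i])
        # Later timesteps get a larger weight; the first step has weight 0.
        loss = i / sequence_len * criterion(out, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


if __name__ == "__main__":
    torch.manual_seed(0)
    model = ToyRecurrentCell()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x_in = torch.randn(4, 1, 8)           # (batch, features, sequence_len)
    targets = torch.tensor([0, 1, 2, 3])  # one class label per sequence
    print(train_on_sequence(model, optimizer, nn.CrossEntropyLoss(), x_in, targets))
```

Because the prediction loss is scaled by `i/sequence_len`, a sequence of length 4 uses the weights 0, 0.25, 0.5, and 0.75, so early timesteps contribute little and the final timestep contributes most.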
A video demo of SPiking-YOLO (SPYv4):
You can find more details in the README file of each task.
Finally, we’d love to hear from you if you have any comments or suggestions.
[1]. https://github.com/bubbliiiing/yolov4-tiny-pytorch
MIT