Visualization of a neural network in PyTorch -> figure out what input images most excite the network. See my blog post for more details.

kuo348/visualization_pytorch

 
 

PyTorch input visualization of a neural network

This repository explores how to visualize what a neural network has learned once it is trained. The basic idea is to keep the network weights fixed and run backpropagation all the way back to the input image, updating the image itself (rather than the weights) so that it excites the chosen target output as strongly as possible.

We thus obtain images which are "idealized" versions of our target classes.
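As a rough illustration of this idea, here is a minimal sketch of optimizing an input image against frozen weights. It is not the code from generate_image.py: the function name `generate_input`, the stand-in `toy_model`, the blank starting image, the Adam optimizer, and the cross-entropy objective are all assumptions made for the example.

```python
import torch
import torch.nn as nn


def generate_input(model, target_class, shape=(1, 1, 28, 28), steps=100, lr=0.1):
    """Optimize an input image so that `model` scores `target_class` highly."""
    model.eval()
    # Freeze the weights: only the input image will be updated.
    for p in model.parameters():
        p.requires_grad_(False)

    # Start from a blank image (an assumption); this tensor is the only
    # thing handed to the optimizer.
    image = torch.zeros(shape, requires_grad=True)
    optimizer = torch.optim.Adam([image], lr=lr)
    target = torch.tensor([target_class])

    for _ in range(steps):
        optimizer.zero_grad()
        logits = model(image)
        # Minimizing cross-entropy against the target raises the target's score.
        loss = nn.functional.cross_entropy(logits, target)
        loss.backward()  # gradients flow back to the image, not the weights
        optimizer.step()

    return image.detach()


# Hypothetical, untrained stand-in classifier so the sketch runs on its own.
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
result = generate_input(toy_model, target_class=3)
predicted = toy_model(result).argmax(dim=1).item()
```

With a trained MNIST model in place of `toy_model`, the returned tensor would play the role of the "idealized" image for the chosen digit. In practice some regularization or clamping of pixel values may be needed to get visually meaningful results.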

For MNIST, these are the "idealized" input images that the network likes most for the digits 0 - 9:

[Figure: generated "idealized" input images for the digits 0 - 9]

How to run this?

Run python mnist.py to train the MNIST neural network. This saves the model weights as mnist_cnn.pt. Afterwards, run python generate_image.py to loop over the 10 target classes and generate the images in the generated folder.

Blog entry explaining the details

If you want to learn more, check out my blog entry explaining this visualization technique for deep neural networks.
