Easy-to-use, UI-based tool that visualizes the internal layers and activations of any network that takes an image as input, using PyTorch and PyQt. It automatically reads the intermediate activations using hooks and displays them in a user-friendly GUI.
pip install git+https://github.com/Param-Uttarwar/neural-network-visualizer.git
- Register modules/networks
- Define forward_function
- Run
- See examples/resnet18_visualizer.py for the full code.
import torch.nn.functional as F
from torchvision.models import resnet18
import torch
import nn_viz.common.world as world
from nn_viz import NetworkVisualizer
world.set_device()  # Automatically defaults to cuda if available
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` -- confirm against the installed torchvision version.
encoder = resnet18(pretrained=True)

# Setup
viz = NetworkVisualizer()
# Adds hooks to save intermediate outputs; use depth=2 for two-level-deep activations
viz.register_module(encoder, depth=1)


def forward_fn(x: torch.Tensor):
    """Forward pass handed to the visualizer via ``viz.set_forward``.

    Args:
        x: Input batch; always a BCHW tensor with values in [-1, 1].

    Returns:
        A dictionary of outputs. Always return a dictionary, since the
        intermediate feature maps are added to it.
    """
    # Resizing and any other preprocessing should be done here
    x = F.interpolate(x, (224, 224))
    z = encoder(x)
    return {'z': z}


# Set your forward function
viz.set_forward(forward_fn)
viz.run()
from nn_viz import NetworkVisualizer
import torch.nn.functional as F
import torch
encoder = SomeEncoder()  # nn.Module
decoder = SomeDecoder()  # nn.Module

# Setup: register every module whose intermediate outputs should be shown
viz = NetworkVisualizer()
viz.register_module(encoder)
viz.register_module(decoder)


def forward_fn(x: torch.Tensor):
    """Forward pass handed to the visualizer via ``viz.set_forward``.

    Args:
        x: Input batch; always a BCHW tensor with values in [-1, 1].

    Returns:
        A dictionary of outputs. Always return a dictionary, since the
        intermediate feature maps are added to it.
    """
    # Resizing and any other preprocessing should be done here
    x = F.interpolate(x, (224, 224))
    z = encoder(x)
    g_img = decoder(z)
    return {'img': g_img}


# Set your forward function
viz.set_forward(forward_fn)
viz.run()