
Easy-to-use UI-based tool that visualizes the internal layers and activations of any PyTorch network that takes an image as input, built using PyQt


neural-network-visualizer

Easy-to-use UI-based tool that visualizes the internal layers and activations of any network that takes an image as input, built with PyTorch and PyQt. It automatically reads the intermediate activations using hooks and displays them in a user-friendly GUI.

GUI Screenshot

Installation

pip install git+https://github.com/Param-Uttarwar/neural-network-visualizer.git

Usage

  • Register your modules/networks
  • Define a forward function
  • Run the visualizer

Example (ResNet18)

import torch.nn.functional as F
from torchvision.models import resnet18
import torch
import nn_viz.common.world as world
from nn_viz import NetworkVisualizer

world.set_device()  # Automatically defaults to cuda if available

encoder = resnet18(pretrained=True)

# Setup
viz = NetworkVisualizer()
viz.register_module(encoder, depth=1)  # Adds hooks to save intermediate outputs; use depth=2 to capture activations two levels deep


def forward_fn(x: torch.Tensor):  # Always takes a [-1, 1] BCHW tensor
    x = F.interpolate(x, (224, 224))  # Resizing and other preprocessing should be done here
    z = encoder(x)
    return {
        'z': z}  # Always return a dictionary; intermediate feature maps are added to it


# Set your forward function
viz.set_forward(forward_fn)
viz.run()
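Under the hood, `register_module` relies on PyTorch forward hooks to capture intermediate outputs as the network runs. The sketch below illustrates that mechanism on a toy model; it is not the library's actual implementation, just a minimal, self-contained illustration of hook-based activation capture at `depth=1` (hooks on each top-level child module).

```python
import torch
import torch.nn as nn

# Illustrative sketch of hook-based activation capture (depth = 1).
# NetworkVisualizer's real internals may differ.
activations = {}

def make_hook(name):
    def hook(module, inputs, output):
        activations[name] = output.detach()  # Save this layer's output
    return hook

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 4, 3, padding=1),
)

# Register a forward hook on each top-level child module
for name, child in model.named_children():
    child.register_forward_hook(make_hook(name))

x = torch.randn(1, 3, 32, 32)  # BCHW input
_ = model(x)
# activations now maps each child's name to its intermediate output tensor
```

Each hook fires during the forward pass, so after one call to the model the `activations` dict holds every registered layer's output, ready to be displayed.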

Example (Custom Encoder Decoder)

from nn_viz import NetworkVisualizer
import torch.nn.functional as F
import torch

encoder = SomeEncoder()  # nn.Module
decoder = SomeDecoder()  # nn.Module

# Setup
viz = NetworkVisualizer()
viz.register_module(encoder)
viz.register_module(decoder)


def forward_fn(x: torch.Tensor):  # Always takes a [-1, 1] BCHW tensor
    x = F.interpolate(x, (224, 224))  # Resizing and other preprocessing should be done here
    z = encoder(x)
    g_img = decoder(z)
    return {'img': g_img}  # Always return a dictionary; intermediate feature maps are added to it


# Set your forward function
viz.set_forward(forward_fn)
viz.run()
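Both examples assume the input is already a `[-1, 1]` BCHW tensor. If your images come in as uint8 HWC arrays (the usual layout from image loaders), a small conversion step is needed first. The helper below is hypothetical, not part of the library; it shows one way to produce the expected input format.

```python
import torch

# Hypothetical helper (not part of nn_viz): convert a uint8 HWC image
# tensor into the [-1, 1] BCHW tensor that forward_fn expects.
def to_model_input(img_uint8: torch.Tensor) -> torch.Tensor:
    x = img_uint8.float() / 255.0        # scale to [0, 1]
    x = x * 2.0 - 1.0                    # shift to [-1, 1]
    x = x.permute(2, 0, 1).unsqueeze(0)  # HWC -> CHW -> BCHW
    return x

img = torch.randint(0, 256, (64, 64, 3), dtype=torch.uint8)  # dummy image
x = to_model_input(img)
# x is now a (1, 3, 64, 64) float tensor with values in [-1, 1]
```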
