This project demonstrates a complete implementation of a Convolutional Neural Network (CNN) in Java from scratch, without any external libraries. It includes the full training and inference pipeline for the FashionMNIST dataset, highlighting core deep learning algorithms and data structures.
After just one epoch, the model reaches roughly 85% accuracy, with 90%+ confidence in most of its predictions.
- Clone the repository:

  ```bash
  git clone git@github.com:Shaesh-Kuiper/JavaCNN-FromScratch.git
  cd JavaCNN-FromScratch
  ```

- Compile all `.java` files:

  ```bash
  javac -encoding UTF-8 *.java
  ```

- Run the training:

  ```bash
  java Trainer
  ```
NOTE: Training time depends on the number of epochs (on average, one epoch takes 5–10 minutes). After training, a file named `trained-model.bin` will be saved containing the learned model.
- Run inference:

  ```bash
  java Inference trained-model.bin data\FashionMNIST\raw\t10k-images-idx3-ubyte data\FashionMNIST\raw\t10k-labels-idx1-ubyte
  ```
This runs prediction on a single sample from the test set. It takes three arguments:
- Path to the trained model file
- Path to the test images file
- Path to the test labels file
JavaCNN is a learning project designed to understand how CNNs, gradient descent, and backpropagation work internally by building everything from the ground up. It covers:
- CNN architecture implementation
- Manual backpropagation
- Multithreaded execution
- Data loading and processing
- Training and evaluation loop
- Model inference with softmax visualization
Implements a convolutional layer:
- 3×3 kernels, stride 1, valid padding
- Forward and backward passes
- Xavier initialization
- Parallelized with `ExecutorService` for output channels
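A minimal sketch of the forward computation, assuming a single input channel (the repository's `Conv2D` generalizes this across channels and parallelizes it):

```java
// Sketch: 3x3 convolution, stride 1, valid padding, one channel.
// An HxW input produces an (H-2)x(W-2) output.
static double[][] conv3x3Valid(double[][] input, double[][] kernel, double bias) {
    int outH = input.length - 2, outW = input[0].length - 2;
    double[][] out = new double[outH][outW];
    for (int y = 0; y < outH; y++) {
        for (int x = 0; x < outW; x++) {
            double sum = bias;
            for (int ky = 0; ky < 3; ky++)
                for (int kx = 0; kx < 3; kx++)
                    sum += input[y + ky][x + kx] * kernel[ky][kx];
            out[y][x] = sum;
        }
    }
    return out;
}
```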
Implements 2×2 max pooling:
- Downsamples feature maps
- Stores max positions for backprop
- Sparse gradient propagation in backward pass
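A minimal sketch of the pooling idea (names hypothetical). Storing each window's argmax lets the backward pass route the gradient only to the winning cell; every other position receives zero, which is why the backward pass is sparse:

```java
// Sketch: 2x2 max pooling, stride 2. argmax[y][x] records the flat index
// of the winning input cell so backprop can send the gradient only there.
static double[][] maxPool2x2(double[][] in, int[][] argmax) {
    int outH = in.length / 2, outW = in[0].length / 2;
    double[][] out = new double[outH][outW];
    for (int y = 0; y < outH; y++) {
        for (int x = 0; x < outW; x++) {
            double best = Double.NEGATIVE_INFINITY;
            for (int dy = 0; dy < 2; dy++) {
                for (int dx = 0; dx < 2; dx++) {
                    int iy = 2 * y + dy, ix = 2 * x + dx;
                    if (in[iy][ix] > best) {
                        best = in[iy][ix];
                        argmax[y][x] = iy * in[0].length + ix;
                    }
                }
            }
            out[y][x] = best;
        }
    }
    return out;
}
```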
Fully connected layer:
- Manual matrix multiplication
- He initialization
- Backprop with gradient update using SGD
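A minimal sketch of those three pieces, with hypothetical helper names:

```java
import java.util.Random;

// He initialization: weights drawn from N(0, sqrt(2 / fanIn)).
static void heInit(double[][] w, Random rng) {
    double std = Math.sqrt(2.0 / w[0].length);
    for (double[] row : w)
        for (int i = 0; i < row.length; i++)
            row[i] = rng.nextGaussian() * std;
}

// Forward: out = W * in + b, via manual matrix multiplication.
static double[] denseForward(double[][] w, double[] b, double[] in) {
    double[] out = new double[w.length];
    for (int o = 0; o < w.length; o++) {
        double sum = b[o];
        for (int i = 0; i < in.length; i++) sum += w[o][i] * in[i];
        out[o] = sum;
    }
    return out;
}

// SGD step: dL/dW[o][i] = gradOut[o] * in[i], dL/db[o] = gradOut[o].
static void sgdStep(double[][] w, double[] b, double[] in, double[] gradOut, double lr) {
    for (int o = 0; o < w.length; o++) {
        for (int i = 0; i < in.length; i++)
            w[o][i] -= lr * gradOut[o] * in[i];
        b[o] -= lr * gradOut[o];
    }
}
```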
Applies ReLU activation:
- Element-wise `max(0, x)`
- Gradient zeroed for negatives during backward pass
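Both directions fit in a few lines (sketch with hypothetical names):

```java
// ReLU forward: y = max(0, x), element-wise.
static double[] reluForward(double[] x) {
    double[] y = new double[x.length];
    for (int i = 0; i < x.length; i++) y[i] = Math.max(0.0, x[i]);
    return y;
}

// ReLU backward: pass the gradient through only where the input was positive.
static double[] reluBackward(double[] x, double[] gradOut) {
    double[] gradIn = new double[x.length];
    for (int i = 0; i < x.length; i++) gradIn[i] = x[i] > 0 ? gradOut[i] : 0.0;
    return gradIn;
}
```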
Handles loss calculation:
- Softmax + Negative log-likelihood
- Stable computation via max-subtraction trick
- Computes softmax gradient manually
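A minimal sketch of the stable loss and its gradient (hypothetical names). Subtracting the maximum logit before exponentiating leaves the softmax unchanged but prevents overflow, and the combined softmax + NLL gradient reduces to `probs - one_hot(label)`:

```java
// Stable softmax + negative log-likelihood; fills probsOut, returns the loss.
static double softmaxNll(double[] logits, int label, double[] probsOut) {
    double max = Double.NEGATIVE_INFINITY;
    for (double v : logits) max = Math.max(max, v);
    double sum = 0.0;
    for (int i = 0; i < logits.length; i++) {
        probsOut[i] = Math.exp(logits[i] - max); // max-subtraction trick
        sum += probsOut[i];
    }
    for (int i = 0; i < probsOut.length; i++) probsOut[i] /= sum;
    return -Math.log(probsOut[label]);
}

// Gradient w.r.t. the logits: probs - one_hot(label).
static double[] softmaxNllGrad(double[] probs, int label) {
    double[] grad = probs.clone();
    grad[label] -= 1.0;
    return grad;
}
```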
Custom thread pool:
- 4-core fixed thread executor
- Shared by layers like `Conv2D`
- Optimizes CPU usage for forward/backward passes
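A minimal sketch of the fan-out pattern, assuming one task per output channel (class and method names are hypothetical):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class PoolSketch {
    // One fixed 4-thread pool shared by every layer that needs it.
    static final ExecutorService POOL = Executors.newFixedThreadPool(4);

    static void forwardAllChannels(int numChannels) throws Exception {
        List<Future<?>> pending = new ArrayList<>();
        for (int c = 0; c < numChannels; c++) {
            final int channel = c;
            pending.add(POOL.submit(() -> computeChannel(channel)));
        }
        for (Future<?> f : pending) f.get(); // join all tasks, surface exceptions
    }

    static void computeChannel(int c) { /* convolve one output channel */ }
}
```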
Loads FashionMNIST data:
- Parses IDX format (images + labels)
- Normalizes pixels to [0, 1]
- Wraps data in a custom `DataSet` object
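The IDX format is a simple big-endian binary layout: a magic number, the dimension sizes, then raw unsigned bytes. A minimal sketch of an image reader (hypothetical name):

```java
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;

// Reads an IDX image file: magic (0x00000803), count, rows, cols,
// then one unsigned byte per pixel, normalized here to [0, 1].
static double[][][] loadIdxImages(String path) throws IOException {
    try (DataInputStream in = new DataInputStream(
            new BufferedInputStream(new FileInputStream(path)))) {
        int magic = in.readInt(); // 0x00000803 for image files
        int count = in.readInt();
        int rows = in.readInt();
        int cols = in.readInt();
        double[][][] images = new double[count][rows][cols];
        for (int n = 0; n < count; n++)
            for (int r = 0; r < rows; r++)
                for (int c = 0; c < cols; c++)
                    images[n][r][c] = in.readUnsignedByte() / 255.0;
        return images;
    }
}
```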
Defines the CNN architecture:
- 2 × (Conv → ReLU → Conv → ReLU → Pool)
- Final dense layer for classification
- Includes flattening from 3D to 1D
- Supports full forward and backward training pass
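Since each 3×3 valid convolution trims one pixel from every edge and each 2×2 pool halves the spatial size, the 28×28 input shrinks as follows (the channel count C is whatever the model configures):

```java
// Shape trace for a 28x28 input through the stack above:
// 28x28 -> conv -> 26x26 -> relu -> conv -> 24x24 -> relu -> pool -> 12x12
// 12x12 -> conv -> 10x10 -> relu -> conv ->  8x8  -> relu -> pool ->  4x4
// flatten: 4 * 4 * C values -> dense -> 10 class logits
```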
Trains the CNN:
- Loads and shuffles data
- Trains over mini-batches (batch size = 64)
- 2 training epochs using SGD
- Evaluates accuracy on test set
- Saves model to disk
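The loop structure, as a sketch with hypothetical names (`model`, `trainSet`, and the helper calls stand in for the repository's actual classes):

```java
int batchSize = 64;
double lr = 0.01; // illustrative learning rate
for (int epoch = 0; epoch < 2; epoch++) {
    trainSet.shuffle();                                   // reshuffle every epoch
    for (int start = 0; start < trainSet.size(); start += batchSize) {
        int end = Math.min(start + batchSize, trainSet.size());
        for (int i = start; i < end; i++) {
            double[] logits = model.forward(trainSet.image(i));
            double[] grad = softmaxNllGrad(softmax(logits), trainSet.label(i));
            model.backward(grad);                         // accumulate gradients
        }
        model.sgdStep(lr / (end - start));                // average over the mini-batch
    }
}
```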
Performs image classification:
- Loads a sample image and the trained model
- Runs forward pass and softmax
- Displays image and predicted label in GUI
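The prediction itself is a forward pass followed by softmax and argmax (sketch, hypothetical names):

```java
double[] logits = model.forward(image);
double[] probs = new double[logits.length];
double max = Double.NEGATIVE_INFINITY, sum = 0.0;
for (double v : logits) max = Math.max(max, v);        // stability trick
for (int k = 0; k < logits.length; k++) {
    probs[k] = Math.exp(logits[k] - max);
    sum += probs[k];
}
int pred = 0;
for (int k = 0; k < probs.length; k++) {
    probs[k] /= sum;                                    // normalize to probabilities
    if (probs[k] > probs[pred]) pred = k;               // argmax = predicted class
}
System.out.printf("Predicted class %d with %.1f%% confidence%n", pred, 100 * probs[pred]);
```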
- Gradient descent and backpropagation implemented manually
- Uses only native Java arrays (1D–4D)
- Parallel execution for performance
- Classic DSA usage: arrays, loops, task queues, producer-consumer threading
- No external ML or math libraries used
- End-to-end CNN implementation in Java
- Educational focus: understand the math and logic behind training neural networks
- Works entirely offline
- Ideal for learners exploring DL fundamentals
- FashionMNIST
- 28×28 grayscale images
- 10 classes (e.g., sneakers, t-shirts, bags)
- Data loading from IDX binary format
This project is aimed at demystifying deep learning by writing everything yourself. It builds the bridge between theoretical understanding and practical implementation — completely in Java.
This project is open for learning and inspiration. If you use or build upon this work, please give credit to the original author:
Created by Sarvesh R
GitHub: Shaesh-Kuiper
LinkedIn: Sarvesh R
Email: shvoyager2k4@gmail.com
A simple mention or link back is appreciated! 💙