MikeyBeez/hybrid-transformer-experiment

Static Functions Can Approximate Deep Attention Layers

This repository contains the code accompanying the paper "Static Functions Can Approximate Deep Attention Layers".
It demonstrates that small, fixed MLPs can replace trained Transformer attention blocks with minimal accuracy loss while improving inference speed.


🧩 Project Overview

The repository provides:

  • A minimal GPT-like baseline model (train_base_model.py)
  • A mechanism for extracting intermediate representations and training static approximator functions (train_approximators.py)
  • A hybrid model that combines trained Transformer layers with frozen approximators (train_end_to_end.py)
  • A benchmark script comparing runtime performance (benchmark.py)

All experiments use the Tiny Shakespeare dataset.
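
As a rough sketch of the idea (the class, dimensions, and layer sizes below are illustrative assumptions, not the exact definitions in model.py), a static approximator is a small position-wise MLP trained to reproduce the input-to-output mapping of a full transformer block:

import torch.nn as nn

class StaticApproximator(nn.Module):
    # A fixed MLP that maps a block's input hidden states directly to its
    # output hidden states, token by token, with no attention or token mixing.
    def __init__(self, n_embd=128, hidden=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
        )

    def forward(self, x):  # x: (batch, seq, n_embd)
        return self.net(x)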


⚙️ Environment Setup

Requirements

python >= 3.10
torch >= 2.1.0
tiktoken
requests

Install dependencies:

pip install torch tiktoken requests

(Optional) GPU acceleration requires a CUDA-enabled PyTorch build; use the install selector at https://pytorch.org/get-started/locally/ to get the command matching your CUDA version.
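
To confirm that the installed build can use your GPU:

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
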
📁 Repository Structure
.
├── config.py
├── model.py
├── prepare_data.py
├── train_base_model.py
├── train_approximators.py
├── train_end_to_end.py
├── benchmark.py
└── input.txt  (auto-downloaded)

📚 Data Preparation
The dataset (Tiny Shakespeare) is automatically downloaded.

python prepare_data.py

This creates input.txt (~1MB). No external credentials or APIs are needed.
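
For reference, the download step presumably amounts to something like the following (the URL shown is the commonly used Tiny Shakespeare mirror and is an assumption, not necessarily the one hard-coded in prepare_data.py):

import os, requests

URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
if not os.path.exists("input.txt"):
    with open("input.txt", "w", encoding="utf-8") as f:
        f.write(requests.get(URL, timeout=30).text)  # roughly 1 MB of plain text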

🧠 Train the Base Model
Train a small GPT baseline:

python train_base_model.py

This will produce:

base_gpt_model.pt
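
A quick way to sanity-check the checkpoint (this assumes train_base_model.py saves a plain state_dict; adjust if it saves a wrapper dictionary):

import torch

state = torch.load("base_gpt_model.pt", map_location="cpu")
tensors = [t for t in state.values() if torch.is_tensor(t)]
print(f"{len(tensors)} tensors, ~{sum(t.numel() for t in tensors) / 1e6:.1f}M parameters")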

🔁 Train the Static Approximators
Extract layer inputs/outputs and train MLP approximators:

python train_approximators.py

This creates:

approximator_data.pt
approximator_L2.pt
approximator_L3.pt
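
Conceptually, each approximator is fit by regressing a layer's recorded outputs from its recorded inputs. A minimal sketch of that loop (the layout of approximator_data.pt and its key names are assumptions):

import torch
import torch.nn as nn

data = torch.load("approximator_data.pt")     # assumed layout: {"L2": (inputs, targets), ...}
x, y = data["L2"]                             # hidden states entering / leaving layer 2
d = x.shape[-1]
mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
opt = torch.optim.AdamW(mlp.parameters(), lr=1e-3)
for step in range(1000):
    opt.zero_grad()
    loss = nn.functional.mse_loss(mlp(x), y)  # regress layer output from layer input
    loss.backward()
    opt.step()
torch.save(mlp.state_dict(), "approximator_L2.pt")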

🧩 Train Hybrid Model End-to-End
Train a model that mixes trained attention with static approximations:

python train_end_to_end.py

This verifies that the hybrid model maintains accuracy within the reported margin.
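
One plausible way the hybrid is assembled (attribute names such as blocks are assumptions; see model.py for the actual structure): the approximators are frozen and swapped in for the corresponding transformer blocks, and the remaining layers are trained as usual.

def build_hybrid(base_model, approximators, replace_layers=(2, 3)):
    # Swap the listed transformer blocks for frozen MLP approximators.
    for idx in replace_layers:
        approx = approximators[idx]
        for p in approx.parameters():
            p.requires_grad = False      # keep the approximator static
        base_model.blocks[idx] = approx  # assumes blocks is an nn.ModuleList
    return base_model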

⚡ Benchmark
Compare inference speed:

python benchmark.py

Example output:
--- Benchmarking Standard GPT with Context Size 256 ---
Prefill Time: 0.0050s
Decode Speed: 218.08 tok/s

--- Benchmarking Hybrid GPT with Context Size 256 ---
Prefill Time: 0.0029s
Decode Speed: 344.26 tok/s
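
The measurement follows the usual prefill/decode split. A simplified version of the timing loop (not the exact code in benchmark.py; it assumes the model returns logits of shape (batch, seq, vocab)):

import time
import torch

@torch.no_grad()
def benchmark(model, idx, n_new=200, ctx=256):
    t0 = time.time()
    model(idx)                                  # prefill: one forward pass over the prompt
    prefill = time.time() - t0
    t0 = time.time()
    for _ in range(n_new):                      # decode: generate one token at a time
        logits = model(idx[:, -ctx:])
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        idx = torch.cat([idx, next_tok], dim=1)
    return prefill, n_new / (time.time() - t0)  # seconds, tokens per second
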
🔍 Reproducibility Notes
Category                          Status
Code completeness                 ✅
Data accessibility                ✅ Auto-downloads
Random seed                       Optional (torch.manual_seed(1337))
Hardware                          Works on CPU or GPU
Expected result reproducibility   ✅ ±0.05 validation loss variance

Reproducing all results should take under 1 hour on a GPU.

📜 Citation
If you use or reference this work, please cite:
Micheal Bee (2025). Static Functions Can Approximate Deep Attention Layers.

🧾 License
MIT License.
Use, modify, and distribute freely with attribution.
