
Commit 341cf4a

initial commit

1 parent 6d5b31b commit 341cf4a
7 files changed: +372 -1 lines changed

.gitignore

+174
@@ -0,0 +1,174 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

LICENSE

+21
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Greg DeVosNouri

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md

+56-1
@@ -1 +1,56 @@
-# Scalable-Softmax
+# ScalableSoftmax

An unofficial PyTorch implementation of Scalable-Softmax (SSMax) from the paper "Scalable-Softmax Is Superior for Attention" (Nakanishi, 2025).

## Overview

ScalableSoftmax is a drop-in replacement for the standard Softmax that helps prevent attention fading in transformers by incorporating input-size scaling: the logits are multiplied by s * log(n) before the softmax, which keeps attention distributions focused even for large input sizes.

## Installation

```bash
pip install scalable-softmax
```

## Usage

```python
import torch
from scalable_softmax import ScalableSoftmax

# Initialize with default parameters
ssmax = ScalableSoftmax()

# Or customize parameters
ssmax = ScalableSoftmax(
    s=0.43,              # scaling parameter
    learn_scaling=True,  # make scaling parameter learnable
    bias=False           # whether to use bias term
)

# Apply to input tensor
x = torch.randn(32, 128)  # (batch_size, sequence_length)
output = ssmax(x)
```

## Features

- Drop-in replacement for standard softmax
- Learnable scaling parameter
- Optional bias term
- Maintains focused attention with large inputs

## Citation

```bibtex
@article{nakanishi2025scalable,
  title={Scalable-Softmax Is Superior for Attention},
  author={Nakanishi, Ken M.},
  journal={arXiv preprint arXiv:2501.19399},
  year={2025}
}
```

## License

MIT License
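To make the "maintains focused attention with large inputs" claim concrete, here is a small, illustrative sketch (not part of this commit) comparing the peak probability of standard softmax and SSMax as the input length grows; the random scores and printed numbers are assumptions for demonstration, not results from the paper.

```python
import torch
import torch.nn.functional as F

from scalable_softmax import ScalableSoftmax

# With plain softmax the peak probability collapses as n grows; SSMax's
# s * log(n) rescaling of the logits keeps the distribution peaked.
ssmax = ScalableSoftmax(s=0.43, learn_scaling=False)

torch.manual_seed(0)
for n in (128, 2048, 32768):
    scores = torch.randn(1, n)
    p_softmax = F.softmax(scores, dim=-1).max().item()
    p_ssmax = ssmax(scores, dim=-1).max().item()
    print(f"n={n:6d}  softmax peak={p_softmax:.4f}  ssmax peak={p_ssmax:.4f}")
```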

main.py

+18
@@ -0,0 +1,18 @@
import torch
from scalable_softmax import ScalableSoftmax

# Initialize with default parameters
smax = ScalableSoftmax()

# Or customize parameters
smax = ScalableSoftmax(
    s=0.43,              # scaling parameter
    learn_scaling=True,  # make scaling parameter learnable
    bias=False           # whether to use bias term
)

# Apply to input tensor
batch_size = 32
sequence_length = 128
x = torch.randn(batch_size, sequence_length)
output = smax(x)
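A small sanity check one could append to main.py (not part of the commit): because SSMax only rescales the logits by s * log(n) before the exponential, each row of the output is still a proper probability distribution.

```python
# Continuing from main.py above: SSMax outputs still sum to 1 along the
# softmax dimension, just like standard softmax.
row_sums = output.sum(dim=-1)
print(torch.allclose(row_sums, torch.ones_like(row_sums)))  # True
```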

pyproject.toml

+33
@@ -0,0 +1,33 @@
[project]
name = "scalable_softmax"
version = "0.1.0"
description = "PyTorch implementation of Scalable-Softmax for attention mechanisms"
authors = [
    { name = "Greg DeVosNouri", email = "gdevos010@gamil.com" }
]
readme = "README.md"
requires-python = ">= 3.9"
license = { file = "LICENSE" }
keywords = ["pytorch", "deep-learning", "attention", "transformer"]

classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.9',
]

dependencies = [
    "torch>=1.8",
]

[project.urls]
Homepage = "https://github.com/gdevos010/Scalable-Softmax"
Repository = "https://github.com/gdevos010/Scalable-Softmax"

[project.optional-dependencies]
dev = [
    "ruff"
]

src/__init__.py

Whitespace-only changes.

src/scalable_softmax.py

+70
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScalableSoftmax(nn.Module):
    """Scalable-Softmax (SSMax) implementation from the paper
    'Scalable-Softmax Is Superior for Attention'.

    This is a drop-in replacement for standard Softmax that helps prevent attention
    fading in transformers by incorporating input size scaling. The scaling helps
    maintain focused attention distributions even with large input sizes.

    Args:
        s (float, optional): Scaling parameter that controls attention focusing strength.
            Higher values (e.g. 1.0) produce sharper attention, lower values (e.g. 0.1)
            produce softer attention. Default: 0.43 as used in the paper.
        learn_scaling (bool, optional): If True, make the scaling parameter learnable.
            Default: True
        bias (bool, optional): If True, adds a learnable bias term. The paper found
            that while bias helps training, it can hurt length generalization.
            Default: False

    Shape:
        - Input: (*, N) where * is any number of dimensions and N is the sequence length
        - Output: Same shape as input
    """
    def __init__(self, s: float = 0.43, learn_scaling: bool = True, bias: bool = False):
        super().__init__()

        # Initialize scaling parameter
        if learn_scaling:
            self.s = nn.Parameter(torch.tensor(s, dtype=torch.float))
        else:
            self.register_buffer('s', torch.tensor(s, dtype=torch.float))

        # Optional bias parameter
        if bias:
            self.b = nn.Parameter(torch.zeros(1))
        else:
            self.b = None

    def forward(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Forward pass applying SSMax along the specified dimension.

        Args:
            x (torch.Tensor): Input tensor
            dim (int): Dimension along which to apply SSMax. Default: -1

        Returns:
            torch.Tensor: Output tensor with the same shape as the input
        """
        # Size of the dimension SSMax is applied to
        n = x.size(dim)

        # Scale the logits by s * log(n), plus the bias term if enabled
        if self.b is not None:
            # Version with bias term
            x_scaled = (self.s * math.log(n) + self.b) * x
        else:
            # Standard version from the paper
            x_scaled = self.s * math.log(n) * x

        # Apply standard softmax to the scaled logits
        return F.softmax(x_scaled, dim=dim)

    def extra_repr(self) -> str:
        """String representation of module."""
        return f's={self.s.item():.3f}, bias={self.b is not None}'
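Since the module is pitched as a drop-in replacement for softmax inside attention, here is a minimal, hypothetical sketch (not part of this commit) of how it could stand in for `F.softmax` in a plain scaled dot-product attention; the function name `sdpa_with_ssmax` and the tensor shapes are illustrative assumptions.

```python
import math

import torch
from scalable_softmax import ScalableSoftmax

# Hypothetical usage sketch: swap F.softmax for ScalableSoftmax inside a
# plain scaled dot-product attention. SSMax then sees n = number of keys.
def sdpa_with_ssmax(q, k, v, ssmax):
    # q, k, v: (batch, heads, seq_len, head_dim)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    attn = ssmax(scores, dim=-1)  # SSMax over the key dimension
    return attn @ v

ssmax = ScalableSoftmax()
q = k = v = torch.randn(2, 4, 128, 64)
out = sdpa_with_ssmax(q, k, v, ssmax)
print(out.shape)  # torch.Size([2, 4, 128, 64])
```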
