An unofficial PyTorch implementation of Scalable-Softmax (SSMax) from the paper "Scalable-Softmax Is Superior for Attention" (Nakanishi, 2025).
ScalableSoftmax is a drop-in replacement for standard Softmax that helps prevent attention fading in transformers by scaling the logits with the input size, so attention distributions stay focused even for large inputs.
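Concretely, the paper defines SSMax by replacing Softmax's exponential base e with the input size n (with scaling parameter s), which is equivalent to applying the standard Softmax to logits multiplied by s·ln(n):

```math
\mathrm{SSMax}(z)_i = \frac{n^{s z_i}}{\sum_{j=1}^{n} n^{s z_j}} = \mathrm{softmax}\bigl(s \ln(n)\, z\bigr)_i
```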
Install with pip:

pip install scalable-softmax
import torch
from scalable_softmax import ScalableSoftmax
# Initialize with default parameters
ssmax = ScalableSoftmax()
# Or customize parameters
ssmax = ScalableSoftmax(
s=0.43, # scaling parameter
learn_scaling=True, # make scaling parameter learnable
bias=False # whether to use bias term
)
# Apply to an input tensor
batch_size, sequence_length = 2, 4096
x = torch.randn(batch_size, sequence_length)
output = ssmax(x)
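The sketch below (not part of the package) drops ScalableSoftmax into a minimal single-head self-attention block. It assumes the module normalizes over the last dimension of its input, as in the example above; the class name is only for illustration:

```python
import torch
import torch.nn as nn
from scalable_softmax import ScalableSoftmax


class SelfAttentionWithSSMax(nn.Module):
    """Minimal single-head self-attention using ScalableSoftmax in place of softmax."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.ssmax = ScalableSoftmax(s=0.43, learn_scaling=True)

    def forward(self, x):
        # x: (batch, seq_len, dim)
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (batch, seq_len, seq_len)
        attn = self.ssmax(scores)  # normalize over the key dimension (last dim)
        return torch.matmul(attn, v)


# Example usage
attn = SelfAttentionWithSSMax(dim=64)
tokens = torch.randn(2, 128, 64)  # (batch, seq_len, dim)
out = attn(tokens)                # (2, 128, 64)
```

Because SSMax only rescales the logits, it can replace the softmax in existing attention code without changing any tensor shapes.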
- Drop-in replacement for standard softmax
- Learnable scaling parameter
- Optional bias term
- Maintains focused attention with large inputs (see the sketch below)
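A toy comparison of the last point, using random logits: as n grows, the largest probability under standard softmax collapses toward zero, while SSMax's log-n scaling keeps the distribution peaked. The printed values depend on the random draw; this is an illustration, not a benchmark:

```python
import torch
from scalable_softmax import ScalableSoftmax

ssmax = ScalableSoftmax(s=0.43)  # scaling parameter from the usage example above

for n in (256, 4096, 65536):
    logits = torch.randn(1, n)
    softmax_peak = torch.softmax(logits, dim=-1).max().item()
    ssmax_peak = ssmax(logits).max().item()
    # Standard softmax flattens as n grows; SSMax stays comparatively focused.
    print(f"n={n:6d}  softmax peak={softmax_peak:.4f}  ssmax peak={ssmax_peak:.4f}")
```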
If you use this implementation, please cite the original paper:

@article{nakanishi2025scalable,
title={Scalable-Softmax Is Superior for Attention},
author={Nakanishi, Ken M.},
journal={arXiv preprint arXiv:2501.19399},
year={2025}
}
This project is released under the MIT License.