@antimatter15
Last active June 24, 2026 21:20

Scalar Sort Attention

Overview

Standard transformer attention computes similarity between queries and keys as a dot product over high-dimensional vectors, then normalizes with softmax to produce attention weights over values. This work proposes replacing the dot-product similarity with a Gaussian kernel over scalar (one-dimensional) projections of queries and keys:

Attention(Q, K, V) = softmax(-(Q - K^T)^2 / τ) @ V

where Q, K ∈ R^(T×1) are learned scalar projections of the input sequence (so Q - K^T is the T×T matrix of pairwise differences q_t - k_s, squared elementwise), V ∈ R^(T×d_v) are standard high-dimensional value vectors, and τ > 0 is a learnable temperature parameter controlling the bandwidth of the kernel.

The key properties that follow from this design are:

  1. Attention weights are determined entirely by the scalar distance between a query and each key. Tokens with similar key values receive high attention weight; tokens with dissimilar key values receive exponentially suppressed weight.
  2. Because queries and keys are scalars, the KV cache is sortable by key value. At inference time, a query can locate its most relevant keys by binary search rather than a linear scan, reducing per-step memory access from O(n) to O(log n + k), where k is the size of the retrieved window.
  3. The Gaussian kernel is effectively sparse: weight decays as exp(-(q-k)²/τ), so only keys within a bounded scalar distance of q (set by τ, not by sequence length) receive non-negligible weight. Adding distant keys does not dilute or collapse the attention distribution as context grows; only keys that land near q compete for its mass.

Mechanism

Projections

Each input token x_t ∈ R^(d_model) is projected to a scalar query, scalar key, and vector value:

q_t = x_t W_Q    (W_Q ∈ R^(d_model × 1))
k_t = x_t W_K    (W_K ∈ R^(d_model × 1))
v_t = x_t W_V    (W_V ∈ R^(d_model × d_v))

The scalar projections W_Q and W_K are unconstrained learned parameters. There is no requirement to tie them (unlike some L2 attention formulations motivated by Lipschitz analysis).

Attention Weights

For a query at position t attending to key at position s, the unnormalized log-weight is:

a_{ts} = -(q_t - k_s)^2 / τ

This is a Gaussian (radial basis function) kernel evaluated at the scalar distance |q_t - k_s|. After causal masking (setting a_{ts} = -∞ for s > t) and row-wise softmax normalization:

w_{ts} = exp(a_{ts}) / Σ_{s'≤t} exp(a_{ts'})

The output at position t is the weighted combination of values:

o_t = Σ_{s≤t} w_{ts} v_s
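For concreteness, the three equations above can be written as a direct O(T²) loop. This is a minimal sketch that assumes a single head, one global τ, and plain Python lists. The function name `scalar_gaussian_attention` is hypothetical and does not come from the original work.

```python
import math


def scalar_gaussian_attention(q, k, v, tau):
    """Causal scalar Gaussian attention (reference sketch, O(T^2)).

    q, k: lists of T scalar queries/keys; v: list of T value vectors (lists of floats).
    Returns (outputs, weights) with weights[t][s] = w_{ts}, zero for masked s > t.
    """
    T, d_v = len(q), len(v[0])
    outputs, weights = [], []
    for t in range(T):
        # a_{ts} = -(q_t - k_s)^2 / tau for s <= t; positions s > t are masked out.
        logits = [-(q[t] - k[s]) ** 2 / tau for s in range(t + 1)]
        m = max(logits)  # subtract the row max for a numerically stable softmax
        exps = [math.exp(a - m) for a in logits]
        z = sum(exps)
        w = [e / z for e in exps] + [0.0] * (T - t - 1)
        # o_t = sum_{s <= t} w_{ts} v_s
        o = [sum(w[s] * v[s][j] for s in range(t + 1)) for j in range(d_v)]
        outputs.append(o)
        weights.append(w)
    return outputs, weights


q = [0.0, 1.0, 0.1]
k = [0.0, 1.0, 2.0]
v = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, w = scalar_gaussian_attention(q, k, v, tau=0.1)
```

In this example the third query (0.1) sits next to the first key (0.0), so nearly all of its weight goes to that key and `out[2]` is close to `v[0]`.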

Temperature

The temperature τ controls the peakedness of the attention distribution. It can be:

  • A single global scalar shared across all heads
  • A per-head scalar
  • A per-token scalar derived from the input (data-dependent bandwidth)

Small τ concentrates weight on only the nearest key; large τ spreads weight across many keys. As τ → 0 the attention approaches hard nearest-neighbor lookup; as τ → ∞ it approaches uniform attention over the causal prefix. τ is learned during training and thereby sets the effective attention radius.
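The two limits can be checked numerically with a small helper. The name `gaussian_weights` is hypothetical, and the example uses one query and a handful of assumed keys. At a tiny τ nearly all weight lands on the closest key, and at a huge τ the weights approach 1/n.

```python
import math


def gaussian_weights(q, keys, tau):
    """Softmax of -(q - k)^2 / tau over a list of scalar keys."""
    logits = [-(q - k) ** 2 / tau for k in keys]
    m = max(logits)  # stabilize the exponentials
    exps = [math.exp(a - m) for a in logits]
    z = sum(exps)
    return [e / z for e in exps]


keys = [-1.0, 0.2, 0.5, 3.0]
w_small = gaussian_weights(0.25, keys, 0.001)   # near hard nearest-neighbor lookup
w_large = gaussian_weights(0.25, keys, 1000.0)  # near uniform over the four keys
```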


The Key Ordering Property

Because queries and keys are scalars, they induce a natural total ordering on the key set. For a given query value q, the attention weight w_{ts} is a unimodal function of k_s — it is maximized when k_s = q and decays symmetrically as |k_s - q| increases. This has two important consequences.

The KV Cache is Sortable

Across the full context, keys {k_1, k_2, ..., k_n} are a sequence of scalars. This sequence can be sorted, and the corresponding values {v_1, v_2, ..., v_n} kept in the same order, forming a sorted KV cache indexed by key value rather than by token position.

Standard attention caches are indexed by position, and computing attention against them requires touching every entry. A sorted scalar key cache can be navigated by value proximity — the structure the Gaussian kernel actually exploits.

Binary Search at Inference

At each autoregressive decode step, given a new query q:

  1. Binary search the sorted key cache to find the insertion point of q. Cost: O(log n).
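  2. Retrieve the window of k keys nearest to q by expanding outward from the insertion point toward whichever neighbor is closer (roughly k/2 on each side when keys are evenly spread). For any fixed error tolerance, a window of large enough scalar radius r captures essentially all attention mass: if keys are locally uniform around q, the fraction of mass outside radius r is erfc(r / sqrt(τ)), independent of n.
  3. Compute softmax(-(q - K_window)^2 / τ) @ V_window using only the k retrieved entries.
  4. Insert the new token's key into the sorted cache in O(log n) using an appropriate data structure (e.g., a skip list with expected O(log n) insert, or a balanced BST or B-tree). A plain sorted array, even with a gap buffer, needs O(n) element shifts in the worst case. Because the causal mask includes s = t, the new token's key and value can instead be inserted before step 1 so that the token can attend to itself.

The total cost per decode step is O(log n + k) to locate and score the window, plus O(k·d_v) to combine its values. Because the approximation error decays like exp(-r²/τ) in the window radius r (rather than as a power law), k can be small in practice, on the order of tens to low hundreds. The number of keys within radius r is roughly 2r times the local key density around q, so k stays constant in n only if that density does not grow with n (for example, because keys spread over a wider range or τ shrinks as context grows).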
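The decode step can be sketched for one head using plain Python lists and a hypothetical `SortedScalarKVCache` class. `bisect` provides the O(log n) search. The window grows outward from the insertion point toward whichever neighbor is closer, which yields exactly the k nearest keys. Python's `list.insert` still shifts O(n) elements, so this sketch illustrates the access pattern but not the asymptotic insert cost.

```python
import bisect
import math
import random


class SortedScalarKVCache:
    """Sorted KV cache for one scalar-key head (hypothetical class, a sketch).

    keys are kept in ascending order and values[i] belongs to keys[i]. Locating a
    position is O(log n) via bisect, but list.insert shifts elements in O(n); a skip
    list or B-tree would be needed for a true O(log n) insert.
    """

    def __init__(self):
        self.keys = []
        self.values = []

    def insert(self, key, value):
        i = bisect.bisect_right(self.keys, key)  # O(log n) search for the insertion point
        self.keys.insert(i, key)
        self.values.insert(i, value)

    def nearest_window(self, q, k):
        """Half-open range [left, right) holding the k keys nearest to q."""
        n = len(self.keys)
        right = bisect.bisect_left(self.keys, q)  # step 1: insertion point of q, O(log n)
        left = right
        while right - left < min(k, n):  # step 2: grow toward the closer neighbor
            if left > 0 and (right == n or q - self.keys[left - 1] <= self.keys[right] - q):
                left -= 1
            else:
                right += 1
        return left, right

    def attend(self, q, tau, k):
        """Step 3: softmax(-(q - K_window)^2 / tau) @ V_window over the retrieved window."""
        left, right = self.nearest_window(q, k)
        logits = [-(q - key) ** 2 / tau for key in self.keys[left:right]]
        m = max(logits)  # max-subtraction for numerical stability
        exps = [math.exp(a - m) for a in logits]
        z = sum(exps)
        window_values = self.values[left:right]
        d_v = len(window_values[0])
        return [sum(e * v[j] for e, v in zip(exps, window_values)) / z for j in range(d_v)]


# Usage: fill a cache with random keys/values, then compare a 64-key window with the full cache.
rng = random.Random(0)
cache = SortedScalarKVCache()
for _ in range(500):
    cache.insert(rng.uniform(-3.0, 3.0), [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)])
approx = cache.attend(0.4, tau=0.01, k=64)
exact = cache.attend(0.4, tau=0.01, k=len(cache.keys))
```

The approximation only drops keys outside the window. Comparing `approx` with `exact` (a window equal to the whole cache) therefore measures the truncation error directly for a chosen τ and k.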

This is qualitatively different from approximate nearest neighbor (ANN) methods used in vector-key approaches such as RetrievalAttention. Those methods need distribution-aware corrections because attention queries are out-of-distribution relative to the keys they search, which violates the assumption of standard ANNS indexes that queries resemble the indexed data. With scalar keys, the one-dimensional sorted order is exact: there is no dimensionality to approximate across and no distribution correction is needed.


Training Complexity

During training, the exact full Gaussian attention matrix requires all n^2 pairwise scalar differences, so the naive exact form remains O(n^2) in sequence length.

The sorted scalar-key structure may still enable approximate subquadratic training algorithms, such as truncated local windows in key order, radius-limited attention, kernel feature approximations, Nyström-style approximations, or fast-Gauss-transform-style methods. These are approximation strategies rather than an exact consequence of the Gaussian kernel.
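The radius-limited strategy can be illustrated by sorting the keys once and, for each query, scanning only the slice of the sorted order within scalar radius r before applying the causal mask. This sketch rests on several assumptions: a hypothetical `radius_limited_attention` function, plain Python, and a fallback to the exact row when nothing lies within r. Its worst case is still O(T²) when many keys crowd inside the radius.

```python
import bisect
import math
import random


def radius_limited_attention(q, k, v, tau, r):
    """Causal scalar Gaussian attention that skips keys with |q_t - k_s| > r (a sketch).

    q, k: lists of T scalars; v: list of T value vectors. Keys are sorted once, and each
    query scans only the in-radius slice of the sorted order, keeping positions s <= t.
    """
    order = sorted(range(len(k)), key=lambda s: k[s])  # positions sorted by key value
    sorted_keys = [k[s] for s in order]
    d_v = len(v[0])
    outputs = []
    for t, qt in enumerate(q):
        lo = bisect.bisect_left(sorted_keys, qt - r)
        hi = bisect.bisect_right(sorted_keys, qt + r)
        cand = [s for s in order[lo:hi] if s <= t]  # causal mask applied after the range lookup
        if not cand:
            cand = list(range(t + 1))  # nothing within r: fall back to the exact row
        logits = [-(qt - k[s]) ** 2 / tau for s in cand]
        m = max(logits)
        exps = [math.exp(a - m) for a in logits]
        z = sum(exps)
        outputs.append([sum(e * v[s][j] for e, s in zip(exps, cand)) / z for j in range(d_v)])
    return outputs


# Usage: queries equal to their own keys, so every row has an in-radius candidate (itself).
rng = random.Random(1)
k = [rng.uniform(-2.0, 2.0) for _ in range(200)]
q = list(k)
v = [[rng.gauss(0.0, 1.0)] for _ in range(200)]
exact = radius_limited_attention(q, k, v, tau=0.01, r=float("inf"))
approx = radius_limited_attention(q, k, v, tau=0.01, r=0.5)
```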


Multi-Head Extension

The scalar Q/K design is compatible with multi-head attention. In a multi-head setting, each head independently learns its own scalar projections W_Q^h, W_K^h and its own temperature τ^h, along with a value projection. Different heads can learn different notions of "similarity" along different 1D axes of the representation space: head h attends based on the scalar difference q_t^h - k_s^h, and that axis can correspond to a different latent property of the tokens in each head.

Grouped Value Attention

Because Q and K projections are scalar, their contribution to the KV cache is negligible — one float per token per head regardless of d_model. The expensive part of the cache is the value vectors. This asymmetry makes it natural to decouple the number of attention patterns (heads) from the number of value projections.

In grouped value attention, H scalar Q/K heads are partitioned into G groups (G << H), where all heads within a group share a single value projection W_V^g ∈ R^(d_model × d_v). Each head h still has its own W_Q^h and W_K^h, so it learns a distinct soft-selection criterion. But heads within the same group read from the same value vectors and differ only in which tokens they attend to.

The KV cache cost under this scheme is:

G × n × d_v   (values, one set per group)
H × n × 1     (keys, one scalar per head — negligible)

compared to H × n × d_v for standard multi-head attention with per-head values. With e.g. H=64 heads and G=4 value groups, 64 distinct learned selection patterns are maintained at roughly the value-cache cost of 4 standard attention heads, plus one scalar key per head per token.
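Here is the cache arithmetic above as a small function. The name is hypothetical, and it counts only the quantities the text counts. It is evaluated with assumed sizes n = 4096 and d_v = 64.

```python
def kv_cache_floats(n, H, G, d_v):
    """Floats cached per layer (a sketch that counts only what the text counts).

    Returns (grouped_values, scalar_keys, standard_values):
      grouped_values  = G * n * d_v  (one value set per group)
      scalar_keys     = H * n * 1    (one scalar key per head)
      standard_values = H * n * d_v  (per-head values in standard multi-head attention)
    """
    return G * n * d_v, H * n, H * n * d_v


grouped_values, scalar_keys, standard_values = kv_cache_floats(n=4096, H=64, G=4, d_v=64)
print(grouped_values, scalar_keys, standard_values)
print((grouped_values + scalar_keys) / standard_values)  # fraction of the standard value cache
```

With these sizes the grouped scheme stores 5/64 of the standard value cache. The scalar keys amount to H/(G·d_v) = 1/4 of the grouped value cache here. They are negligible next to standard attention, but not entirely negligible next to the grouped values.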

This is possible specifically because scalar keys are cheap. In standard multi-head attention, increasing H proportionally increases both the key and value cache. Here, increasing H increases only the key cache, which is already negligible, leaving the value cache cost determined entirely by G. The number of heads and the number of value groups are independent design choices.

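At inference time, the value store is maintained per group, but the sorted index is maintained per head. Binary search only works on keys sorted in the searching head's own order, and the heads of a group order the same tokens differently. Each head therefore keeps its own sorted list of (scalar key, token index) pairs pointing into the group's shared value vectors. The values are stored once per group; each head adds only one scalar key and one integer index per token. Sorting the group's values by a single designated "sort head" would make only that head's lookup a binary search.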
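Below is a sketch of one value group. Values are stored once, in position order, and each head keeps its own sorted (key, position) index. The class and method names are hypothetical. `bisect.insort` again shifts O(n) elements, so this shows the layout rather than the asymptotic cost.

```python
import bisect


class GroupCache:
    """One value group (hypothetical class): values stored once in position order,
    plus one sorted (key, position) index per head in the group."""

    def __init__(self, num_heads):
        self.values = []                             # shared value vectors, by position
        self.index = [[] for _ in range(num_heads)]  # per head: (key, position), sorted by key

    def append(self, head_keys, value):
        """Add one token: its value once, and its scalar key into every head's own index."""
        pos = len(self.values)
        self.values.append(value)
        for h, key in enumerate(head_keys):
            bisect.insort(self.index[h], (key, pos))

    def window(self, h, q, k):
        """Positions of the (up to) k tokens whose head-h keys are nearest to query q."""
        idx = self.index[h]
        n = len(idx)
        right = bisect.bisect_left(idx, (q, -1))  # first entry with key >= q
        left = right
        while right - left < min(k, n):
            if left > 0 and (right == n or q - idx[left - 1][0] <= idx[right][0] - q):
                left -= 1
            else:
                right += 1
        return [pos for _, pos in idx[left:right]]


g = GroupCache(num_heads=2)
g.append([0.1, 5.0], [1.0])
g.append([0.9, -5.0], [2.0])
g.append([0.2, 4.9], [3.0])
# Head 0 orders the tokens (0, 2, 1); head 1 orders them (1, 2, 0). Both read g.values.
```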

Hybrid Architecture

This also admits a natural hybrid: some heads use scalar Q/K with the Gaussian kernel, while other heads use standard dot-product attention. The scalar heads can specialize on local or similarity-based retrieval patterns while the dot-product heads handle global, content-based patterns.


Long Context Behavior

Standard softmax attention over dot-product logits suffers from attention dilution at long context. The softmax normalizer sums over all n keys, so with bounded logits Q_t K_s^T / sqrt(d) the log-normalizer grows roughly as log n. A fixed logit advantage for the relevant key then buys less and less probability mass, and the attention distribution flattens toward uniform. Keeping it sharp requires explicit corrections such as the LogN scaling trick, Scalable-Softmax, or YaRN.

Scalar Gaussian attention is much less exposed to this problem. The unnormalized weight exp(-(q-k)^2/τ) depends only on the distance |q-k|, not on n. Tokens whose keys land far from the query (|k_s - q| large) contribute exponentially suppressed mass to the softmax denominator. As long as new keys do not crowd into the neighborhood of q, the attention distribution stays sharp and concentrated around nearby keys as n grows, without any explicit scaling correction.

This follows directly from the Gaussian kernel's exponential decay rather than from an approximation. Tokens within the effective bandwidth of the query dominate the softmax; tokens outside it are irrelevant. The effective bandwidth, measured in scalar distance, is O(1) in n and controlled by τ; the number of tokens inside it depends on the local key density.
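The toy comparison below assumes that all distractors sit at one fixed distance. As distractors are added, the target's dot-product weight falls. The Gaussian weight stays near 1 when distractors are far from the query, but it collapses just the same when they crowd inside the bandwidth. Function names and numbers are illustrative.

```python
import math


def target_weight_dot(logit_gap, n_distractors):
    """Softmax weight on one key whose logit beats n equal distractor logits by logit_gap."""
    return 1.0 / (1.0 + n_distractors * math.exp(-logit_gap))


def target_weight_gaussian(distance, tau, n_distractors):
    """Weight on a key at k = q when n distractor keys all sit at |k - q| = distance."""
    return 1.0 / (1.0 + n_distractors * math.exp(-distance ** 2 / tau))


for n in (128, 4096, 1_000_000):
    print(n,
          target_weight_dot(5.0, n),            # fixed logit advantage of 5
          target_weight_gaussian(3.0, 0.1, n),  # distractors far outside the bandwidth
          target_weight_gaussian(0.1, 0.1, n))  # distractors inside the bandwidth
```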

Hybrid Scalar + Dot-Product Attention

Date: 2026-04-30


Overview

This document describes a variant of scalar sort attention that combines standard dot-product attention with the scalar RBF kernel in a single attention score:

score(q, k) = q·k / sqrt(d)  -  (qs - ks)² / τ
Attention(Q, K, V) = softmax(score) @ V

where q, k ∈ R^d are standard vector queries and keys, qs, ks ∈ R are learned scalar projections, and τ is a learned temperature per head.

This can equivalently be written as a product of two softmax distributions:

p_combined(k|q) ∝ p_dot(k|q) · p_rbf(k|q)
                ∝ exp(q·k/sqrt(d)) · exp(-(qs-ks)²/τ)

The product-of-experts form collapses exactly to the combined-score softmax after renormalization. There is no separate two-softmax computation — the cleanest implementation simply adds the scalar penalty to the existing attention logits before softmax.
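The product-of-experts equivalence can be checked numerically. This sketch draws random vector and scalar queries and keys (all values assumed, plain Python). It then compares one softmax over the summed logits with the renormalized product of the two separate softmaxes.

```python
import math
import random


def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    z = sum(es)
    return [e / z for e in es]


rng = random.Random(0)
d, n, tau = 16, 10, 0.3
q = [rng.gauss(0.0, 1.0) for _ in range(d)]
keys = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
qs = rng.gauss(0.0, 1.0)
ks = [rng.gauss(0.0, 1.0) for _ in range(n)]

dot = [sum(a * b for a, b in zip(q, key)) / math.sqrt(d) for key in keys]  # q·k / sqrt(d)
rbf = [-(qs - k) ** 2 / tau for k in ks]                                  # -(qs - ks)^2 / tau

combined = softmax([a + b for a, b in zip(dot, rbf)])           # one softmax over the summed score
experts = [a * b for a, b in zip(softmax(dot), softmax(rbf))]   # p_dot * p_rbf
z = sum(experts)
poe = [x / z for x in experts]                                  # renormalized product of experts
```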


Motivation

Pure scalar sort attention (see idea.md) is competitive with standard dot-product attention on TinyStories language modeling but does not clearly win. The scalar bottleneck limits expressivity: each head can only distinguish tokens along one learned 1D axis, and attention patterns are constrained to be unimodal in scalar space.

The hybrid retains the sortability property of scalar keys while restoring the full expressivity of dot-product attention within the sorted window. The scalar term acts as a pre-filter; the dot-product term provides fine-grained discrimination among the filtered candidates.


Inference

The sortability property is preserved. At each autoregressive decode step:

  1. Binary search the sorted scalar key cache to find tokens near qs. Cost: O(log n).
  2. Retrieve a window of k candidates by scalar proximity.
  3. Load the full vector keys for those k candidates.
  4. Compute softmax(q·k/sqrt(d) - (qs-ks)²/τ) over the k candidates.
  5. Read the corresponding values and compute output.
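  6. Insert the new token's scalar key into the sorted cache. (Alternatively, insert it together with its vector key and value before step 1, so that the token can attend to itself as the causal mask allows.)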

Total cost per step: O(log n + k·d), compared to O(n·d) for standard attention.
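The sketch below covers steps 1 through 5 for one head. It assumes the scalar cache is a sorted list of (ks, position) pairs and that vector keys and values are stored by position; the function `hybrid_decode_step` is hypothetical. Random data is used only to exercise the code. No alignment between the scalar and dot-product scores is assumed, so τ is set small enough for the scalar term to dominate.

```python
import bisect
import math
import random


def hybrid_decode_step(q, qs, tau, k, scalar_index, vec_keys, values):
    """One hybrid decode step restricted to the k scalar-nearest cached tokens (a sketch).

    scalar_index: list of (ks, position) sorted by ks; vec_keys, values: indexed by position.
    """
    n = len(scalar_index)
    right = bisect.bisect_left(scalar_index, (qs, -1))  # step 1: binary search, O(log n)
    left = right
    while right - left < min(k, n):  # step 2: k candidates by scalar proximity
        if left > 0 and (right == n or qs - scalar_index[left - 1][0] <= scalar_index[right][0] - qs):
            left -= 1
        else:
            right += 1
    cand = scalar_index[left:right]
    scale = math.sqrt(len(q))
    # steps 3-4: load the candidates' vector keys and apply the combined score
    logits = [sum(a * b for a, b in zip(q, vec_keys[p])) / scale - (qs - ks) ** 2 / tau
              for ks, p in cand]
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    # step 5: read the candidates' values and combine them
    d_v = len(values[0])
    return [sum(e * values[p][j] for e, (_, p) in zip(exps, cand)) / z for j in range(d_v)]


rng = random.Random(0)
n, d = 300, 8
vec_keys = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
values = [[rng.gauss(0.0, 1.0)] for _ in range(n)]
scalar_index = sorted((rng.uniform(-3.0, 3.0), p) for p in range(n))
q = [rng.gauss(0.0, 1.0) for _ in range(d)]
full = hybrid_decode_step(q, 0.0, 0.005, n, scalar_index, vec_keys, values)
windowed = hybrid_decode_step(q, 0.0, 0.005, 64, scalar_index, vec_keys, values)
```

With a larger τ and no alignment training, `windowed` can diverge from `full`. The gap between them is the approximation error that the alignment regularizer is meant to shrink.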

The approximation quality depends on whether the top-scoring dot-product tokens are concentrated within the scalar window. If the model learns aligned representations — where semantic similarity correlates with proximity in scalar space — the filter is effective. If the two metrics are orthogonal, high-dot-product tokens outside the window are missed.

The cache requires both scalar keys (1 float/token/head, negligible) and full vector keys (d_head floats/token/head). The memory win of pure scalar sort attention is not retained — only the compute access pattern changes.


Relationship to Pure Scalar Sort Attention

| Property | Pure scalar sort | Hybrid |
|---|---|---|
| Attention score | -(qs-ks)²/τ | q·k/sqrt(d) - (qs-ks)²/τ |
| Attention pattern | Unimodal in scalar space | Unimodal in scalar space, modulated by dot product |
| Key cache | Scalar only (negligible) | Scalar + full vector key |
| Inference cost per step | O(log n + k) to score, plus O(k·d_v) to read values | O(log n + k·d) |
| Expressivity | One learned 1D axis per head | Full vector similarity within scalar window |
| Scalar alignment required | No (scalar determines everything) | Yes (scalar must predict dot-product relevance) |

Bandwidth and Window Size

Low τ = narrow bandwidth = fewer keys accessed.

The unnormalized weight for a key at scalar distance δ = |qs - ks| is exp(-δ²/τ). It falls below a threshold ε once δ > sqrt(-τ · log ε), so the effective radius scales as sqrt(τ). The expected number of keys inside this radius is:

E[keys accessed] ≈ 2 · sqrt(-τ · log ε) · ρ(qs)

where ρ(qs) is the local density of scalar keys near qs. To minimize window size: keep τ small and keep key density low (spread keys uniformly).
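The radius and expected-window formulas can be evaluated for τ values near those reported in the next paragraph. The example assumes, hypothetically, 4096 keys spread uniformly over [-3, 3] and ε = 10⁻³.

```python
import math


def effective_radius(tau, eps):
    """Scalar distance beyond which exp(-delta**2 / tau) drops below eps."""
    return math.sqrt(-tau * math.log(eps))


def expected_keys_accessed(tau, eps, density):
    """2 * radius * local key density (keys per unit of scalar distance)."""
    return 2.0 * effective_radius(tau, eps) * density


density = 4096 / 6.0  # assumed: 4096 keys spread uniformly over [-3, 3]
for tau in (0.05, 0.35):
    print(tau, round(effective_radius(tau, 1e-3), 3), round(expected_keys_accessed(tau, 1e-3, density)))
```

Under these assumptions the window at τ = 0.35 covers roughly 2,100 of the 4,096 keys, about sqrt(7) ≈ 2.6 times the window at τ = 0.05. The ratio of window sizes between two temperatures depends only on the square root of their ratio, not on ε or density.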

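The concern from experiments: trained τ consistently rises from its initialization (0.03–0.1) to 0.35–0.42 during training on TinyStories. Since the effective radius scales as sqrt(τ), that drift widens the window by a factor of roughly 1.9 to 3.7. Without intervention the same drift is likely to occur in the hybrid, eroding the inference efficiency advantage.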


Regularization for Efficient Inference

Three regularization terms work together to make the scalar coordinate a useful index for the sorted cache.

Term 1: Attention-Weighted Scalar Distance (primary)

L_align = (1/n²) Σ_q Σ_k  stopgrad(w(q,k)) · (qs - ks)²

This penalizes attention mass placed on tokens far in scalar space. The stop-gradient through w(q,k) ensures gradient flows only to the scalar projections, not to the vector Q/K projections. The scalar space is supervised to predict what the vector attention already cares about — one-way distillation from vector attention → scalar projections.

Per-query bias-variance decomposition:

Σ_k w(q,k) · (qs - ks)²  =  Var_w(ks)  +  (qs - E_w[ks])²
                               ^^^^^^^^       ^^^^^^^^^^^^^^^
                               spread of      mean attended key
                               attended keys  offset from query

Minimizing this simultaneously encourages focused attention (low scalar variance within attended set) and centered attention (mean attended scalar near query scalar). Both directly compress the required inference window. No additional asymptotic training cost: the attention weight matrix is already computed.
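The decomposition holds for any weights that sum to 1. A quick numerical check with random weights and keys (values assumed) confirms it.

```python
import random

rng = random.Random(0)
n = 20
raw = [rng.random() for _ in range(n)]
z = sum(raw)
w = [x / z for x in raw]                      # attention weights for one query, summing to 1
ks = [rng.gauss(0.0, 1.0) for _ in range(n)]  # scalar keys
qs = 0.7                                      # scalar query

lhs = sum(wi * (qs - k) ** 2 for wi, k in zip(w, ks))    # attention-weighted scalar distance
mean = sum(wi * k for wi, k in zip(w, ks))               # E_w[ks]
var = sum(wi * (k - mean) ** 2 for wi, k in zip(w, ks))  # Var_w(ks)
rhs = var + (qs - mean) ** 2                             # spread + squared offset
```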

Gradient behavior: for a key outside the target window with high combined weight, the gradient pushes ks toward qs, bringing it into the window. Because w is detached, L_align sends no gradient to τ, so it cannot by itself stop the bandwidth from widening; the tau term (Term 3) closes that escape hatch.

Implementation:

w_detached = w_combined.detach()                               # (B, H, T, T)
sq_dist = (qs.unsqueeze(-1) - ks.unsqueeze(-2)) ** 2          # (B, H, T, T)
L_align = (w_detached * sq_dist).mean()

Term 2: Key Spread (prevents collapse)

L_align alone is minimized trivially if all scalar keys collapse to the same value — every query is near every key, L_align = 0, and the window contains all n tokens. A spread term is required.

Wasserstein distance to uniform over a target range [a, b]:

import torch
ks_sorted, _ = torch.sort(ks, dim=-1)  # ks: (..., n) keys of one context; gradient reaches ks values, not the sort order
targets = torch.linspace(a, b, ks.shape[-1], device=ks.device, dtype=ks.dtype)  # uniform quantile targets
L_spread = ((ks_sorted - targets) ** 2).mean()

Gradient does not flow through the sort operation. The target range can be fixed (e.g., [-3, 3]) or adaptive per batch as [mean(ks) ± 3·std(ks)]. This is strictly stronger than -Var(ks), which allows bimodal or non-uniform distributions with dense local regions.
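A small illustration shows why the sorted-quantile loss is stronger than maximizing variance. It uses assumed key sets of size 100 on [-3, 3]. The bimodal set has higher variance than the evenly spaced one, yet only the evenly spaced set drives the spread loss to zero. This pure-Python `spread_loss` mirrors the sort/linspace form above.

```python
def spread_loss(ks, a, b):
    """Mean squared gap between sorted keys and n evenly spaced targets on [a, b]."""
    n = len(ks)
    targets = [a + (b - a) * i / (n - 1) for i in range(n)]  # same as linspace(a, b, n)
    return sum((k - t) ** 2 for k, t in zip(sorted(ks), targets)) / n


def variance(ks):
    m = sum(ks) / len(ks)
    return sum((k - m) ** 2 for k in ks) / len(ks)


n = 100
uniform = [-3 + 6 * i / (n - 1) for i in range(n)]   # evenly spread keys
bimodal = [-3.0] * (n // 2) + [3.0] * (n // 2)       # two dense clusters
collapsed = [0.0] * n                                # every key at the same value
```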

Term 3: Tau Regularization (prevents bandwidth growth)

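L_tau = Σ_heads log(1 + τ_h)

Using log(1 + τ) rather than log τ keeps the gradient bounded: it is 1/(1 + τ) ≤ 1, whereas log τ has gradient 1/τ, which blows up as τ → 0. The penalty also grows only logarithmically at large τ. It can alternatively be implemented as a hard cap, τ = min(softplus(τ_raw), τ_max), with τ_max set from the acceptable window size.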

Complete Regularizer

L_reg = λ_a · L_align  +  λ_s · L_spread  +  λ_t · L_tau

| Pressure | Effect |
|---|---|
| L_align | High-weight (q,k) pairs pushed toward scalar proximity |
| L_spread | All scalar keys pushed toward uniform spacing |
| L_tau | Bandwidth stays narrow |
| Combined | Scalar space learns a 1D ordering where attended tokens cluster; overall distribution stays spread; local density per query stays low |

At convergence, the scalar projections learn a 1D metric approximating the dot-product attention structure. Tokens that attend to each other map to nearby scalar coordinates; the overall key distribution remains uniformly spread. The result is an ordered cache where each query has a small, semantically coherent neighborhood.

Tuning notes:

  • Start with λ_a (the L_align weight) as the dominant term; introduce λ_s (the L_spread weight) after initial convergence.
  • Monitor inference approximation error directly (fraction of combined attention mass captured by top-k nearest scalar neighbors) as the primary metric — this is what determines inference cost.
  • Use adaptive range for L_spread across different model sizes and training stages.
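The monitoring metric from the second note can be computed from a materialized attention row. This sketch uses hypothetical function names and a brute-force sort, which is acceptable for evaluation. It measures the fraction of one query's combined attention mass that falls on its k nearest scalar neighbors, and averages that fraction over queries.

```python
def mass_captured(weights, ks, qs, k):
    """Fraction of one query's attention mass on its k nearest keys in scalar space.

    weights: that query's combined attention weights over its causal prefix (summing to 1);
    ks: the scalar keys of the same prefix, aligned with weights.
    """
    nearest = sorted(range(len(ks)), key=lambda s: abs(ks[s] - qs))[:k]
    return sum(weights[s] for s in nearest)


def mean_mass_captured(weight_rows, ks_rows, qs_values, k):
    """Average of mass_captured over a set of queries."""
    vals = [mass_captured(w, ks, q, k) for w, ks, q in zip(weight_rows, ks_rows, qs_values)]
    return sum(vals) / len(vals)
```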

Attention Sinks and Token Disposal

Why Attention Sinks Exist

Softmax must sum to 1. When a token does not need to retrieve anything specific, attention mass still needs a destination. Models learn to dump it on tokens with approximately neutral value vectors — typically BOS or other early structural tokens. These absorb garbage mass not because they are semantically relevant, but because softmax forces mass somewhere and the model learns the safest garbage destination.

How Scalar Sort Attention Changes the Sink Geometry

In standard attention, garbage mass can be routed to any token via a learned Q/K direction. In pure scalar sort attention, mass concentrates on the nearest scalar neighbors regardless of semantic relevance, and there is no principled sink unless one is explicitly placed nearby. If no relevant token exists near qs, whichever token is nearest in scalar space absorbs the mass, even when it is a content-bearing token whose value then contaminates the output.

Natural solution: a learned scalar sink token, a persistent entry in the KV cache with a fixed scalar key (e.g., ks_sink = 0) and a neutral value vector. Query projections learn to map unfocused queries toward qs ≈ 0. If the sink is kept outside the sorted structure and appended to every retrieved window, it is always a baseline candidate at O(1) extra cost per step, whether or not binary search lands near ks_sink.
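The following sketch computes attention over a retrieved window with the sink appended. It assumes the sink is kept outside the sorted cache, with a hypothetical fixed key of 0 and an all-zero (neutral) value vector. It returns the output and the weight the sink absorbs, and it assumes the window is non-empty when no sink value is supplied.

```python
import math


def attend_with_sink(qs, window_keys, window_values, tau, sink_key=0.0, sink_value=None):
    """Scalar Gaussian attention over a retrieved window plus one always-present sink entry.

    Returns (output, sink_weight). Assumes a non-empty window when sink_value is None.
    """
    if sink_value is None:
        sink_value = [0.0] * len(window_values[0])  # neutral value: adds nothing to the output
    keys = list(window_keys) + [sink_key]
    vals = list(window_values) + [sink_value]
    logits = [-(qs - k) ** 2 / tau for k in keys]
    m = max(logits)
    exps = [math.exp(a - m) for a in logits]
    z = sum(exps)
    out = [sum(e * v[j] for e, v in zip(exps, vals)) / z for j in range(len(sink_value))]
    return out, exps[-1] / z


out, sink_w = attend_with_sink(0.0, [2.0, -2.0], [[5.0], [7.0]], tau=0.1)    # nothing nearby: sink absorbs the mass
out2, sink_w2 = attend_with_sink(2.0, [2.0, -2.0], [[5.0], [7.0]], tau=0.1)  # a key matches: sink weight is tiny
```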

The Hybrid's Natural Solution

In the hybrid, the combined score has two independent suppression mechanisms: scalar distance kills tokens far from qs, and dot product kills tokens with low semantic similarity. A token is truly ignorable only when it fails both filters. This reduces the sink problem relative to pure scalar attention: nearby-but-irrelevant tokens are suppressed by low dot product. The scalar sink token remains useful as a fallback but the failure mode is less severe.

Emergent Token Disposal via Alignment Regularization

With L_align active, high-weight pairs are pushed toward scalar proximity, while consistently ignored tokens receive essentially no pull from L_align. The expectation is that, over training, ignored tokens drift toward peripheral regions of scalar space that no query occupies. At inference, binary search for any query would then never reach them, so they are effectively disposed of without explicit eviction.

This is a structural analog of H2O (Heavy Hitter Oracle) eviction: instead of tracking cumulative attention scores post-hoc, the geometry of the trained scalar space encodes relevance directly.

Explicit Eviction Policies in Scalar Space

  • Geometric eviction: Tokens with scalar keys outside a rolling window [min_recent_qs - r, max_recent_qs + r] are eviction candidates. O(log n) range comparison, no attention tracking needed.
  • Density-based compression: In high-density regions of scalar space, adjacent tokens have nearly identical attention patterns. Replace k adjacent tokens with one representative (value = weighted average, scalar key = centroid). 1D histogram compression of the value cache.
  • Sink zone reservation: Reserve a small fixed scalar range as a permanent sink zone. All other tokens are subject to eviction.
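Here is a sketch of geometric eviction plus sink-zone reservation over a sorted key list, using hypothetical names and an assumed reserved range. Both boundaries are found with `bisect` in O(log n). Building the kept lists is O(n) here, whereas a tree-based cache could drop the evicted ranges directly.

```python
import bisect


def geometric_eviction_range(sorted_keys, recent_queries, r):
    """(lo, hi) such that sorted_keys[lo:hi] lie in [min(recent) - r, max(recent) + r]."""
    lo = bisect.bisect_left(sorted_keys, min(recent_queries) - r)
    hi = bisect.bisect_right(sorted_keys, max(recent_queries) + r)
    return lo, hi


def evict(sorted_keys, sorted_values, recent_queries, r, sink_zone=None):
    """Keep keys inside the rolling query window, plus any inside a reserved sink zone."""
    lo, hi = geometric_eviction_range(sorted_keys, recent_queries, r)
    keep = set(range(lo, hi))
    if sink_zone is not None:  # hypothetical (low, high) scalar range that is never evicted
        s_lo = bisect.bisect_left(sorted_keys, sink_zone[0])
        s_hi = bisect.bisect_right(sorted_keys, sink_zone[1])
        keep.update(range(s_lo, s_hi))
    idx = sorted(keep)
    return [sorted_keys[i] for i in idx], [sorted_values[i] for i in idx]


keys = [-5.0, -1.0, 0.0, 0.5, 1.0, 4.0, 9.0]
vals = ["sink", "a", "b", "c", "d", "e", "f"]
kept_keys, kept_vals = evict(keys, vals, recent_queries=[0.4, 0.8], r=0.5, sink_zone=(-5.5, -4.5))
```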

The Structural Picture

At convergence, the scalar space develops a natural topology:

peripheral cold zone  |  ... active hot zone ...  |  peripheral cold zone
   (ignored tokens)       (relevant, high-attn)       (ignored tokens)
        ↑                          ↑                         ↑
  never retrieved           binary search               never retrieved
                             lands here

The key difference from standard attention sinks: in standard attention, sinks are an emergent workaround located arbitrarily in embedding space. In scalar sort attention, sinks are a geometric feature of a 1D ordered cache — explicitly locatable by binary search, trivially preserved under any eviction policy, and designable by reserving a scalar region.


Prior Art

1. Standard Attention Already Implicitly Combines Dot Product and RBF

Implicit Kernel Attention (Song et al., AAAI 2021, arXiv:2006.06147) derives that standard scaled dot-product attention decomposes into two multiplicative components:

α_ij ∝  exp(-‖qi - kj‖² / 2√dk)   ×   exp((‖qi‖² + ‖kj‖²) / 2√dk)
              ↑                                     ↑
        RBF kernel (similarity)            magnitude / importance

This follows from the algebraic identity:

q·k = -½‖q-k‖² + ½‖q‖² + ½‖k‖²

Row-wise softmax removes the query-norm term, so standard attention is exactly Gaussian distance attention plus a key-norm bias. The hybrid makes the distance component explicit and independent via a learned 1D projection rather than letting it be entangled with the key-norm term.
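The identity can be confirmed numerically: softmax over q·k/√d matches softmax over the RBF term plus the key-norm bias once the query-norm term is dropped. The vectors are random, assumed values, and the code is plain Python.

```python
import math
import random


def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    z = sum(es)
    return [e / z for e in es]


def sqnorm(x):
    return sum(a * a for a in x)


rng = random.Random(0)
d = 8
q = [rng.gauss(0.0, 1.0) for _ in range(d)]
keys = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(6)]

dot_form = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys])
# -||q-k||^2 / (2 sqrt d) is the RBF part; +||k||^2 / (2 sqrt d) is the key-norm bias.
# The +||q||^2 / (2 sqrt d) term is the same for every key, so softmax ignores it.
dist_form = softmax([(-sqnorm([a - b for a, b in zip(q, k)]) + sqnorm(k)) / (2 * math.sqrt(d)) for k in keys])
```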

Scaled RBF Attention (Pisoni, 2025, pisoni.ai; arXiv:2310.18805 for a related formulation) exploits this identity further: replacing dot-product with −γ‖Q−K‖² produces score = 2γ(Q·K) − γ‖K‖² once the −γ‖Q‖² term, which is constant across keys for a given query, is dropped by the softmax. This is standard dot-product attention with an explicit L2 penalty on key magnitude. It is a combined score (dot product + magnitude penalty), but it applies to full-dimensional vectors uniformly and does not use a separate scalar projection for sorted-cache inference.

2. Additive Score Biases: ALiBi, T5, TUPE, Transformer-XL

The structural pattern combined_score = dot_product_score + secondary_score is well-established, primarily for positional encoding.

ALiBi (Press et al., ICLR 2022, arXiv:2108.12409) adds a distance-based penalty to dot-product logits before softmax:

score = q·k / √d  +  m · (−|position_i − position_j|)

where m is a head-specific slope. This is structurally identical to the hybrid with (qs-ks)²/τ replaced by a linear penalty on the distance between token positions. ALiBi was motivated by length generalization; the hybrid is motivated by inference efficiency via content-derived scalar projections.

T5 relative position bias (Raffel et al., 2020) adds a learned scalar bias indexed by bucketed relative position to each attention logit. Same structural pattern; the bias is learned over a lookup table rather than an analytic function.

Transformer-XL / XLNet (Dai et al., 2019; Yang et al., 2019) decompose the attention score into four terms (content-content, content-position, position-content, position-position) and sum them. The hybrid's q·k/√d − (qs-ks)²/τ is a two-component decomposition in the same spirit, where the second term uses a content-derived scalar coordinate instead of positional encoding.

TUPE (Ke et al., ICLR 2021) explicitly separates content and positional correlations with independent parameterizations and sums them to form the attention logit. This prevents the interference that arises from adding positional embeddings to token embeddings before attention.

The pattern of score = f_content(q,k) + f_structural(q,k) is therefore well-established. The hybrid's contribution is making f_structural a learned RBF on content-derived scalar projections rather than a function of token position, and connecting this to inference efficiency via sorted caches.

3. Gaussian Spatial Bias on Content Attention: SMCA

Spatially Modulated Co-Attention (SMCA) (Gao et al., ICCV 2021, arXiv:2101.07448) proposes combining dot-product attention with a Gaussian spatial weight map:

score = Kᵀ·Q / √d  +  log G

where G is a predicted 2D Gaussian centered at an estimated bounding box location. This is structurally the closest match to the hybrid in the literature. The combined score is dot product plus the log of a Gaussian kernel, which is exactly q·k/√d − (center_distance)²/σ². The Gaussian provides spatial locality while the dot product provides content discrimination.

The difference from the hybrid: SMCA uses predicted 2D spatial coordinates derived from cross-attention context (object position in an image), not a learned scalar projection of the token itself. The Gaussian in SMCA is a spatial prior about where relevant objects are, not a learned ordering of the key space. SMCA does not claim or use inference speedup via sorted caches.

4. Learned Log-Prior on Attention Logits: GOAT

GOAT: Generalized Optimal Transport Attention with Trainable Priors (Litman and Guo, arXiv:2601.15380, January 2026) derives attention through entropic optimal transport and shows that standard attention corresponds to transport regularized by a uniform prior. GOAT adds a learnable continuous prior as an additive log term on the attention logits:

score_j = content_score_j / τ  +  log π_j

The prior log π_j is encoded via a factorization of query and key vectors into content and prior subspaces, absorbed into the dot product without materializing a dense bias matrix (FlashAttention-compatible). The prior uses learnable spectral weights over truncated Fourier series encoding relative positions, and can express both attraction and repulsion patterns.

GOAT is close in motivation to the hybrid: combining a content dot-product score with a learned structural prior. The prior in GOAT encodes positional structure; the scalar projection in the hybrid encodes content-derived order for retrieval. GOAT also provides a principled explanation for attention sinks as optimal transport defaults under low semantic signal, and decouples sink behavior from content via key-only bias terms — paralleling the scalar sink token proposal above.

GOAT is very recent (January 2026) and represents convergent thinking: the field is independently arriving at "dot-product attention score plus a learned second term" as a general framework.

5. Pure Distance Attention Alternatives

Several works replace dot-product attention entirely with distance-based scores:

  • L2 / RBF Self-Attention (Kim et al., ICML 2021, arXiv:2006.04710): Score −‖q−k‖² for Lipschitz analysis. Uses full-dimensional Q/K vectors, not scalar projections.
  • Inverse Distance Weighting Attention (McCarter, 2023, arXiv:2310.18805): Score −log(ε + ‖q−k‖²), simplifying to IDW interpolation. Pure distance replacement.
  • SOFT (Lu et al., 2021, arXiv:2110.11945): Gaussian kernel for softmax-free linear complexity.
  • Skyformer (Chen et al., 2021, arXiv:2111.00035): Gaussian kernel with Nyström approximation.

All are pure replacements rather than combinations, and none exploits scalar projections for sorted-cache inference.

6. Scalar / 1D Projections for Efficient Attention

Sliced ReLU Attention (Boufadene et al., arXiv:2512.11411, December 2025) is the most structurally related work to the scalar sort attention project. It uses 1D projections of key-query differences with a ReLU-derived kernel and exploits sorting for O(n log n) computation via cumulative sums on sorted sequences. It provides theoretical expressivity arguments (contextual universal approximation) for 1D sliced mechanisms.

The hybrid is complementary: instead of replacing dot-product with a 1D ReLU kernel, it adds a 1D Gaussian term to the existing dot-product score. Sliced ReLU enables exact O(n log n) training because the ReLU kernel is piecewise linear, so its sums decompose into prefix sums over the sorted order. The Gaussian softmax kernel does not admit such finite prefix statistics and does not inherit this property.

7. Gating After Attention

Gated Attention for LLMs (Qiu et al., NeurIPS 2025 Oral and Best Paper, arXiv:2505.06708) applies a head-specific sigmoid gate after SDPA output:

output = sigmoid(gate_query) ⊙ SDPA(Q, K, V)

This modulates the output rather than the score and is not a second score component. However, the motivation overlaps: introducing sparsity, addressing attention sinks, and improving long-context generalization. The NeurIPS 2025 Best Paper designation indicates the field considers learned gating alongside dot-product attention to be a significant direction. The paper includes ablations at 15B MoE and 1.7B dense scale on 3.5 trillion tokens.

Gated Sparse Attention (arXiv:2601.15305, integrated into Qwen3-Next) uses dual gates on values and outputs alongside sparse token selection. These demonstrate production-scale validation of combining a secondary learned signal with dot-product attention.

8. Elliptical Attention

Elliptical Attention (Nielsen et al., NeurIPS 2024, arXiv:2406.13770) replaces the standard Euclidean distance in attention with a Mahalanobis distance:

score = Q · M · Kᵀ / √D

where M is a learned diagonal matrix stretching the feature space according to coordinate-wise relevance. This is a modified inner product rather than a sum of two separate score components. It addresses representation collapse and adversarial robustness but does not produce a sorted-cache inference argument.


Summary of Novelty Positioning

| Property | Hybrid proposal | Closest prior |
|---|---|---|
| Dot product + distance score in a single softmax | Yes | SMCA (2D spatial, ICCV 2021): exact structural match in vision domain |
| Learned content-based scalar projection | Yes | Not found |
| Scalar projection sortable for O(log n) inference | Yes | Sliced ReLU Attention (different kernel, no dot-product hybrid) |
| Alignment regularizer (scalar ↔ attention) | Yes | Not found |
| Emergent geometric token eviction from sorted cache | Yes | Not found |
| Additive secondary score on dot-product logits | Yes | ALiBi, T5, TUPE, GOAT (all use position-derived terms) |
| Product-of-experts framing | Yes | Standard in Bayesian learning; not found in attention literature |
| Attention sink as geometric scalar-space feature | Yes | Not found; GOAT addresses sinks differently |

The combination not found in the literature surveyed here: a dot-product attention score augmented with a learned scalar RBF term, where the scalar projection is content-derived and trained to enable sorted-cache approximate inference, with an explicit alignment regularizer to enforce correspondence between the two score components. The individual components all have prior art. The closest single paper in structure is SMCA (dot product + log Gaussian score in vision), which arrived at the same formula independently for a different task with no inference efficiency argument.


Proposed Experiments

1. Baseline hybrid without regularizer

Replace pure scalar sort with combined score q·k/√d − (qs-ks)²/τ. No regularizer. Compare to pure sort attention and pure standard attention on TinyStories ctx-128 at 100k steps. Expected result: better than pure sort, likely close to standard attention.

2. Alignment measurement before and after regularization

For each (q,k) pair sorted by combined attention weight, plot |qs−ks|. If alignment regularizer is working, high-weight pairs should have small scalar distance. Run without and with L_align and compare.

3. Regularizer ablation

Compare no regularization / L_align only / L_align + L_tau / L_align + L_spread / all three terms. Metrics: val loss, tau trajectory, inference approximation error at k = 8, 16, 32, 64.

4. Inference approximation error vs. context length

For a trained hybrid model, sweep context length from 128 to 4096 and measure the fraction of combined attention mass captured by top-k scalar neighbors at fixed k. This is the key systems experiment: does error grow with context length, or does the regularizer keep the window effectively constant?

5. Scalar sink token

Add a learned sink token with fixed ks = 0. Measure whether it absorbs garbage mass and whether tau dynamics change.

6. Comparison to GOAT

Implement GOAT's prior-augmented attention and compare on TinyStories. GOAT adds a positional prior; the hybrid adds a content-derived scalar prior. Are they complementary? Can both be combined?
