Visualizzazione concettuale del meccanismo di attenzione Parallax per AI

Parallax: nuovo meccanismo di attenzione AI che unisce softmax a una correzione covariante

Dal 2017, il meccanismo di attenzione dei Transformer è rimasto pressoché invariato. La maggior parte dei tentativi di miglioramento ha cercato di sostituire il softmax con alternative più efficienti. Un nuovo articolo di ricerca prende una strada diversa: mantiene il softmax e ci aggiunge una rampa di correzione laterale.

Un team di ricercatori della Northwestern University, Tilde Research e University of Washington ha introdotto Parallax, un’attenzione locale lineare parametrizzata che scala fino al pre-training di modelli linguistici di grandi dimensioni (LLM) e viene progettata in code-design con l’ottimizzatore Muon. Parallax non insegue l’efficienza tagliando calcoli; aggiunge deliberatamente calcoli, ma rende quei calcoli più economici da eseguire sulle GPU moderne.

Cos’è Parallax

Parallax si basa sull’Attenzione Locale Lineare (LLA). LLA deriva da un framework di regressione: interpreta l’attenzione come un solutore di regressione su coppie chiave-valore. In questa visione, le chiavi sono dati di training, i valori sono etichette e la query è il punto di test. Il softmax è uno stimatore non parametrico chiamato Nadaraya-Watson. LLA migliora questo stimatore rendendolo lineare locale, il che riduce l’errore quadratico medio integrato. Tuttavia, LLA ha un problema: richiede la risoluzione di un sistema lineare per ogni query, usando un solutore con gradienti coniugati (CG), che causa intensivo I/O, un compromesso tra regolarizzazione ed espressività, e incompatibilità con precisioni basse.

Parallax rimuove il solutore CG e lo sostituisce con una matrice di proiezione appresa, WR, che esamina la covarianza KV direttamente dall’input del layer. Questo rende il meccanismo più semplice, efficiente e facile da implementare.

Come funziona il meccanismo

Parallax riformula LLA come softmax più una correzione additiva. L’output è l’output del softmax meno un termine di covarianza proiettato. Il team ha anche eliminato un fattore di amplificazione di confine, necessario per la stabilità. Quando il probe è parametrico, l’interpretazione geometrica originale viene meno, e mantenere il fattore potrebbe far divergere o invertire il segno della scala.

Parallax fa parte di una famiglia di meccanismi di attenzione organizzati da tre assi: larghezza di banda, costruzione del probe e struttura affine. Un punto chiave: quando WR = 0, Parallax si comporta identico al softmax. Questo permette di convertire un checkpoint pre-addestrato aggiungendo WR e poi fare fine-tuning.

Il vantaggio hardware

Parallax eredita la struttura streaming di FlashAttention e aggiunge un ramo di covarianza che riusa lo stesso flusso chiave-valore. Il forward si espande in due rami di scoring paralleli, che condividono il massimo online, il fattore di rescaling e i tile K e V. Così Parallax non richiede I/O extra per iterazione.

La proprietà chiave è la maggiore intensità aritmetica (AI), il rapporto tra operazioni in virgola mobile e traffico di memoria. Parallax raddoppia approssimativamente l’AI nel regime in cui il lavoro KV domina. Aggiunge calcolo riusando lo stesso flusso di memoria, spostando l’attenzione verso un regime più compute-bound — esattamente dove l’ottimizzazione del kernel aiuta sull’hardware moderno.

Il team ha prototipato un kernel di decode su GPU Hopper di NVIDIA, testato contro FlashAttention 2 e 3 su GPU H200 a precisione BF16. Il kernel ha eguagliato o superato FlashAttention in tutte le configurazioni, con speedup fino a 1.54× in setting compute-matched e 1.14× in setting I/O-matched.

Cosa mostrano gli esperimenti

Il team ha validato Parallax su compiti sintetici e su pre-training di LLM a scale di 0.6 miliardi e 1.7 miliardi di parametri, usando architettura Qwen-3 e dataset Ultra-FineWeb. I confronti includevano softmax (Transformer standard), Mamba, Gated DeltaNet, MesaNet e Kimi DeltaAttention.

Nel MAD-Benchmark, Parallax ha raggiunto la massima accuratezza complessiva (0.716 media), migliorando compiti di richiamo come In-Context-Recall e Selective-Copying. Nel language modeling, con Muon ha ottenuto la migliore perplexity a entrambe le scale e la maggiore accuratezza media downstream: 62.45 contro 61.43 del Transformer a 1.7B. I controlli mostrano che il guadagno viene dal meccanismo in sé, non da parametri o compute extra.

L’interazione con l’ottimizzatore Muon

Un risultato chiave è l’interazione architettura-ottimizzatore. Parallax mostra un grande vantaggio sotto Muon, un ottimizzatore per parametri matriciali che usa il fattore polare del buffer di momentum. Sotto AdamW, il vantaggio si riduce drasticamente o scompare. La differenza è legata al rapporto correzione-uscita (COR): sotto Muon, COR supera 8 nei layer più profondi; sotto AdamW, resta sotto 4. La proiezione WR è influenzata: il suo rango stabile collassa con AdamW ma rimane alto con Muon. Il team definisce questa la prima dimostrazione di code-design forte architettura-ottimizzatore per meccanismi di attenzione.

Differenze nelle distribuzioni dei punteggi

Parallax produce distribuzioni di punteggi diverse dal softmax: i pesi per token possono essere negativi e superiori a 1. Questo permette tre effetti:

  • sottrarre attivamente componenti di valore da token irrilevanti
  • ridurre sostanzialmente il sink di attenzione sul primo token
  • mantenere un’entropia softmax di base più alta, con pesi di attenzione più diffusi

Punti di forza e debolezza

Punti di forza:

  • Mantiene il softmax intatto, consentendo la conversione di checkpoint pre-addestrati con fine-tuning
  • Nessun I/O extra per iterazione
  • Raddoppia l’intensità aritmetica, con kernel più veloci di FlashAttention
  • Guadagni consistenti in perplexity e downstream

Debolezze e domande aperte:

  • I guadagni dipendono fortemente da Muon; con AdamW scompaiono in gran parte
  • La causa precisa della dipendenza dall’ottimizzatore rimane sconosciuta
  • I risultati si fermano a 1.7B di parametri, senza MoE, contesto lungo o scale maggiori
  • Il vantaggio si riduce durante la fase di decay di WSD

Takeaway: Parallax mantiene il softmax e aggiunge una correzione covariante appresa, sostituendo la procedura costosa di risolvere un sistema lineare per ogni query. Il risultato è un meccanismo di attenzione più efficiente dal punto di vista computazionale, che ha mostrato prestazioni migliori in compiti di language modeling a scale fino a 1.7 miliardi di parametri, purché si usi l’ottimizzatore giusto.

Articoli simili