2025-11-20T09:28:14.240195

Lightweight and Interpretable Transformer via Mixed Graph Algorithm Unrolling for Traffic Forecast

Qi, Do, Liu et al.
Unlike conventional "black-box" transformers with classical self-attention mechanism, we build a lightweight and interpretable transformer-like neural net by unrolling a mixed-graph-based optimization algorithm to forecast traffic with spatial and temporal dimensions. We construct two graphs: an undirected graph $\mathcal{G}^u$ capturing spatial correlations across geography, and a directed graph $\mathcal{G}^d$ capturing sequential relationships over time. We predict future samples of signal $\mathbf{x}$, assuming it is "smooth" with respect to both $\mathcal{G}^u$ and $\mathcal{G}^d$, where we design new $\ell_2$ and $\ell_1$-norm variational terms to quantify and promote signal smoothness (low-frequency reconstruction) on a directed graph. We design an iterative algorithm based on alternating direction method of multipliers (ADMM), and unroll it into a feed-forward network for data-driven parameter learning. We insert graph learning modules for $\mathcal{G}^u$ and $\mathcal{G}^d$ that play the role of self-attention. Experiments show that our unrolled networks achieve competitive traffic forecast performance as state-of-the-art prediction schemes, while reducing parameter counts drastically. Our code is available in https://github.com/SingularityUndefined/Unrolling-GSP-STForecast .
academic

Leichter und interpretierbarer Transformer durch gemischtes Graphen-Algorithmus-Unrolling für Verkehrsprognose

Grundinformationen

  • Paper-ID: 2505.13102
  • Titel: Lightweight and Interpretable Transformer via Mixed Graph Algorithm Unrolling for Traffic Forecast
  • Autoren: Ji Qi, Mingxiao Liu, Tam Thuc Do, Yuzhe Li, Zhuoshi Pan, Gene Cheung, H. Vicky Zhao
  • Klassifizierung: cs.LG cs.AI eess.SP
  • Veröffentlichungsdatum: 12. Oktober 2025 (arXiv v2)
  • Paper-Link: https://arxiv.org/abs/2505.13102

Zusammenfassung

Diese Arbeit präsentiert ein leichtes und interpretierbares Transformer-Modell basierend auf gemischtem Graphen-Algorithmus-Unrolling für Verkehrsprognosen. Im Gegensatz zu traditionellen „Black-Box"-Transformern wird ein interpretierbares Transformer-ähnliches neuronales Netzwerk durch das Unrolling eines gemischten Graphen-Optimierungsalgorithmus konstruiert. Das Modell konstruiert zwei Graphen: ein ungerichteter Graph Gu\mathcal{G}^u erfasst geographisch-räumliche Korrelationen, ein gerichteter Graph Gd\mathcal{G}^d erfasst zeitliche Beziehungen. Durch die Gestaltung neuer 2\ell_2- und 1\ell_1-Norm-Variationsterme werden Signalglätte auf dem gerichteten Graphen quantifiziert und gefördert. Basierend auf der Alternating Direction Method of Multipliers (ADMM) wird ein iterativer Algorithmus entworfen und als Feedforward-Netzwerk für datengesteuerte Parameterlernvorgänge entfaltet. Experimente zeigen, dass das Modell die Anzahl der Parameter erheblich reduziert, während es wettbewerbsfähige Verkehrsprognose-Leistung beibehält.

Forschungshintergrund und Motivation

Problemdefinition

Verkehrsprognose ist ein wichtiges raumzeitliches Datenmodellierungsproblem, das Folgendes erfassen muss:

  1. Räumliche Korrelation: Korrelationen zwischen geografisch nahe beieinander liegenden Überwachungsstationen
  2. Zeitliche Abhängigkeit: Auswirkungen historischer Beobachtungen auf zukünftige Zustände

Einschränkungen bestehender Methoden

  1. Traditionelle Transformer: Enorme Parametermenge, mangelnde Interpretierbarkeit, Herausforderungen bei Berechnungs- und Speicherbeschränkungen in der praktischen Bereitstellung
  2. Modellbasierte Methoden: Behandeln räumliche und zeitliche Dimensionen oft unabhängig, nutzen raumzeitliche Beziehungen nicht vollständig
  3. Bestehende Deep-Learning-Methoden: Obwohl leistungsstark, sind sie immer noch „Black-Box"-Modelle mit großer Parametermenge

Forschungsmotivation

  1. Dringender Bedarf der Industrie nach leichten Modellen
  2. Algorithm Unrolling bietet ein neues Paradigma, das modellgesteuerte und datengesteuerte Ansätze kombiniert
  3. Bestehende Arbeiten verwenden nur positive ungerichtete Graphen und können komplexe raumzeitliche Beziehungen nicht effektiv modellieren

Kernbeiträge

  1. Erstmalige Vorschlag gemischter Graphen-Algorithmus-Unrolling: Kombiniert ungerichtete Graphen (räumlich) und gerichtete Graphen (zeitlich) zur Modellierung komplexer raumzeitlicher Beziehungen
  2. Innovative gerichtete Graphen-Regularisierungsterme: Entwurf von gerichteter Graphen-Laplace-Regularisierung (DGLR) und gerichteter Graphen-Totalvariation (DGTV)
  3. Leichter interpretierbarer Transformer: Durch ADMM-Algorithmus-Unrolling erreichte massive Parameterreduktion (nur 6,4% von PDFormer)
  4. Theoretischer Beitrag: Beweis, dass die gerichtete Graphen-Frequenzdefinition im Fall ungewichteter gerichteter Liniengraphen zu klassischen Fourier-Frequenzen degeneriert

Methodische Details

Aufgabendefinition

Gegeben sind Beobachtungen von N Überwachungsstationen über T+1 vergangene Zeitpunkte; die Aufgabe besteht darin, den Verkehrszustand für die nächsten S Zeitpunkte vorherzusagen. Die Eingabe ist ein teilweise beobachtetes raumzeitliches Signal yRMy \in \mathbb{R}^M, die Ausgabe ist ein vollständiges raumzeitliches Signal xRN(T+S+1)x \in \mathbb{R}^{N(T+S+1)}.

Gemischte Graphen-Konstruktion

Ungerichteter Graph Gu\mathcal{G}^u

  • Verbindet Knoten geografisch nahe beieinander liegender Standorte zum gleichen Zeitpunkt
  • Erfasst räumliche Korrelation
  • Verwendet symmetrische Nachbarschaftsmatrix WuW^u

Gerichteter Graph Gd\mathcal{G}^d

  • Verbindet Knoten von Zeitpunkt τ\tau mit Knoten der gleichen Standorte zu Zeitpunkten τ+1,...,τ+W\tau+1, ..., \tau+W
  • Erfasst zeitliche Kausalbeziehungen
  • Verwendet asymmetrische Nachbarschaftsmatrix WdW^d

Entwurf gerichteter Graphen-Variationsterme

2\ell_2-Norm-Term: Gerichtete Graphen-Laplace-Regularisierung (DGLR)

xTLrdx=xT(Lrd)TLrdx=xWrdx22x^T\mathcal{L}_r^d x = x^T(L_r^d)^T L_r^d x = \|x - W_r^d x\|_2^2

wobei Lrd=IWrdL_r^d = I - W_r^d die zufallsbasierte Laplace-Matrix ist und Wrd=(Dd)1WdW_r^d = (D^d)^{-1}W^d die zeilenweise stochastische Nachbarschaftsmatrix ist.

1\ell_1-Norm-Term: Gerichtete Graphen-Totalvariation (DGTV)

Lrdx1=jSˉxjiwj,ixi\|L_r^d x\|_1 = \sum_{j \in \bar{S}} |x_j - \sum_i w_{j,i} x_i|

Optimierungsziel-Funktion

minxyHx22+μuxTLux+μd,2xTLrdx+μd,1Lrdx1\min_x \|y - Hx\|_2^2 + \mu_u x^T L^u x + \mu_{d,2} x^T \mathcal{L}_r^d x + \mu_{d,1} \|L_r^d x\|_1

wobei HH die Abtastmatrix ist und μu,μd,2,μd,1\mu_u, \mu_{d,2}, \mu_{d,1} Gewichtungsparameter sind.

ADMM-Algorithmus-Entwurf

Durch Einführung von Hilfsvariablen ϕ\phi wird das Optimierungsproblem transformiert zu: minx,ϕyHx22+μuxTLux+μd,2xTLrdx+μd,1ϕ1\min_{x,\phi} \|y - Hx\|_2^2 + \mu_u x^T L^u x + \mu_{d,2} x^T \mathcal{L}_r^d x + \mu_{d,1} \|\phi\|_1s.t. ϕ=Lrdx\text{s.t. } \phi = L_r^d x

Teilproblem-Lösungen

  1. xx-Teilproblem: Gelöst durch konjugierte Gradientenmethode für lineares System
  2. ϕ\phi-Teilproblem: Soft-Thresholding-Operation ϕiτ+1=sign(δ)max(δρ1μd,1,0)\phi_i^{\tau+1} = \text{sign}(\delta) \cdot \max(|\delta| - \rho^{-1}\mu_{d,1}, 0) wobei δ=(Lrd)ixτ+1ρ1γiτ\delta = (L_r^d)_i x^{\tau+1} - \rho^{-1}\gamma_i^\tau

Graphen-Lernmodul

Ungerichtetes Graphen-Lernen (UGL)

Berechnung der Knotenähnlichkeit mittels Mahalanobis-Distanz: du(i,j)=(fiufju)TM(fiufju)d^u(i,j) = (f_i^u - f_j^u)^T M (f_i^u - f_j^u)

Kantengewichte werden durch normalisierte Exponentialfunktion berechnet: wi,ju=exp(du(i,j))lNiexp(du(i,l))kNjexp(du(k,j))w_{i,j}^u = \frac{\exp(-d^u(i,j))}{\sqrt{\sum_{l \in \mathcal{N}_i} \exp(-d^u(i,l))} \sqrt{\sum_{k \in \mathcal{N}_j} \exp(-d^u(k,j))}}

Gerichtetes Graphen-Lernen (DGL)

Ähnlich werden gerichtete Kantengewichte mittels Metrik-Matrix PP berechnet.

Netzwerk-Architektur

Jede ADMM-Iteration wird als neuronale Schicht implementiert:

  • 5 ADMM-Blöcke, jeder mit 25 Schichten
  • Graphen-Lernmodul vor jedem Block eingefügt
  • Verwendung von Multi-Head-Aufmerksamkeitsmechanismus (4 parallele Graphen-Lernmodule)

Experimentelle Einrichtung

Datensätze

  • METR-LA: Verkehrsgeschwindigkeitsdaten von Los Angeles, 207 Knoten, 1315 Kanten
  • PEMS03: Verkehrsfluss-Daten, 358 Knoten, 547 Kanten
  • Abtastintervall: 5 Minuten
  • Datenteilung: 6:2:2 (Training:Validierung:Test)

Bewertungsmetriken

  • RMSE: Quadratischer Mittelfehler
  • MAE: Mittlerer absoluter Fehler
  • MAPE: Mittlerer absoluter prozentualer Fehler

Vergleichsmethoden

Umfasst 6 Kategorien von Baseline-Methoden:

  • Modellbasiert: VAR
  • GNN-Methoden: STGCN, STSGCN
  • GAT-Methoden: GMAN, ST-Wave
  • Transformer-Methoden: PDFormer, STAEformer
  • Adaptive Graphen-Methoden: Graph WaveNet, AGCRN
  • Einfache lineare Modelle: STID, SimpleTM

Implementierungsdetails

  • Prognosezeitraum: 30/60/120 Minuten (6/12/24 Schritte)
  • Historisches Fenster: 60 Minuten (12 Schritte)
  • Optimierer: Adam, Lernrate 5×10⁻⁴
  • Verlustfunktion: Huber-Verlust (δ=1)
  • Hardware: NVIDIA GeForce RTX 3090

Experimentelle Ergebnisse

Hauptergebnisse

DatensatzZeitraumDiese MethodeBeste BaselineParametervergleich
PEMS0330min26.10/17.03/18.8523.71/15.05/18.1634K vs 531K
PEMS0360min27.67/17.46/17.7225.56/15.97/15.49(6,4% Parameter)
METR-LA60min12.34/5.18/11.8011.96/5.49/9.65

Wichtigste Erkenntnisse

  1. Parametereffizienz: Erreicht wettbewerbsfähige Leistung mit nur 6,4% der Parameter von PDFormer
  2. Vorteil bei Langzeitprognose: Je länger der Prognosezeitraum, desto kleiner der Leistungsunterschied zur besten Methode
  3. Dateneffizienz: Stabilere Leistung bei spärlichen Daten

Ablationsstudien

VariantePEMS03 (RMSE/MAE/MAPE)METR-LA (RMSE/MAE/MAPE)
Vollständiges Modell27.67/17.46/17.7212.34/5.18/11.80
Ohne DGTV27.78/17.85/17.9012.36/5.40/12.31
Ohne DGLR30.89/20.02/21.1012.41/5.35/12.20
Ungerichteter Zeitgraph27.52/17.87/18.8212.51/5.42/12.11

Ergebnisse zeigen:

  • DGLR-Term ist kritischster für Leistungsverbesserung
  • DGTV-Term trägt auch deutlich bei
  • Gerichtete Graphen-Modellierung übertrifft ungerichtete Modellierung

Theoretische Verifikation

Theorem 3.1 beweist: Für ungewichtete gerichtete Liniengraphen ist die symmetrisierte gerichtete Graphen-Laplace-Matrix Lrd=(Lrd)TLrd\mathcal{L}_r^d = (L_r^d)^T L_r^d äquivalent zur Laplace-Matrix des ungerichteten Liniengraphen, was die Rationalität der Frequenzdefinition bestätigt.

Verwandte Arbeiten

Leichte Modelle

  • Große Sprachmodelle: LoRA-Niedrig-Rang-Anpassung, Parameterquantisierung
  • Sprachverbesserung: Lokale kausale Selbstaufmerksamkeit
  • Bildverarbeitung: YUV-Kanal-Trennung

Verkehrsprognose-Methoden

  1. GNN-Methoden: STGCN, Graph WaveNet usw., konzentriert sich auf räumliche Modellierung
  2. Transformer-Methoden: Duale Transformer behandeln raumzeitliche Dimensionen separat
  3. Einfache lineare Modelle: Hinterfragen die Effektivität komplexer Modelle

Algorithmus-Unrolling

  • Entfaltet Optimierungsalgorithmus-Iterationen als neuronale Schichten
  • Kombiniert mathematische Interpretierbarkeit und datengesteuerte Fähigkeit
  • Erfolgreiche Anwendungen in der Bildverarbeitung

Schlussfolgerungen und Diskussion

Hauptschlussfolgerungen

  1. Gemischtes Graphen-Algorithmus-Unrolling realisiert erfolgreich leichte und interpretierbare Verkehrsprognose-Modelle
  2. Gerichtete Graphen-Variationsterme erfassen effektiv zeitliche Kausalbeziehungen
  3. Massive Parameterreduktion bei Beibehaltung wettbewerbsfähiger Leistung

Einschränkungen

  1. Distanz-Einschränkung: Gelernte Mahalanobis-Distanz ist nicht-negativ, während traditionelle Selbstaufmerksamkeit negativ sein kann
  2. Graphen-Sparsität: Basierend auf echten Straßenverbindungen wird die Graphen-Konnektivität begrenzt
  3. Festes Zeitfenster: Vordefiniertes Zeitfenster könnte weniger flexibel sein

Zukünftige Richtungen

  1. Erweiterung auf signierte Distanzen und komplexere Graphen-Modellierung
  2. Adaptives Zeitfenster-Lernen
  3. Anwendung auf andere raumzeitliche Prognosaufgaben

Tiefe Bewertung

Stärken

  1. Theoretische Innovation: Erstmalige Definition von Frequenzkonzept für gerichtete Graphen mit entsprechenden Regularisierungstermen
  2. Methodische Neuheit: Gemischtes Graphen-Algorithmus-Unrolling bietet neue Perspektive für Transformer-Entwurf
  3. Praktischer Wert: Signifikante Parameterreduktion hat große Bedeutung für praktische Bereitstellung
  4. Interpretierbarkeit: Jede Schicht entspricht Optimierungsalgorithmus-Iteration mit klarer mathematischer Bedeutung

Mängel

  1. Leistungs-Kompromiss: Bei einigen Metriken immer noch nicht so gut wie beste Baseline-Methoden
  2. Anwendungsbereich: Hauptsächlich auf Verkehrsprognose validiert, Generalisierbarkeit auf andere raumzeitliche Aufgaben unbekannt
  3. Theoretische Analyse: Fehlende Konvergenz- und Komplexitätsanalyse

Auswirkungen

  1. Akademischer Beitrag: Bietet neue Perspektive für Graphen-Signalverarbeitung und Transformer-Entwurf
  2. Praktischer Wert: Leichte Eigenschaften eignen sich für Edge-Computing und ressourcenbeschränkte Umgebungen
  3. Reproduzierbarkeit: Bereitstellung von Open-Source-Code mit detaillierten Experimenteinstellungen

Anwendungsszenarien

  1. Ressourcenbeschränkte Umgebungen: Mobile Geräte, Edge-Computing
  2. Echtzeit-Prognosesysteme: Verkehrsmanagementsysteme, die schnelle Reaktion erfordern
  3. Interpretierbare KI-Anwendungen: Sicherheitskritische Systeme, die Modell-Transparenz erfordern

Referenzen

Das Paper zitiert mehrere wichtige Arbeiten, einschließlich:

  • Originales Transformer-Paper (Vaswani et al., 2017)
  • Algorithmus-Unrolling-Übersicht (Monga et al., 2021)
  • Grundlagen der Graphen-Signalverarbeitung (Ortega et al., 2018)
  • Verwandte Arbeiten zur Verkehrsprognose (Li et al., 2017; Yu et al., 2018)

Gesamtbewertung: Dies ist eine innovative Arbeit im Bereich der Verkehrsprognose, die erfolgreich die Algorithmus-Unrolling-Idee auf gemischte Graphen-Einstellungen erweitert und dabei die Parametermenge erheblich reduziert, während die Leistung beibehalten wird. Obwohl bei einigen Metriken noch Verbesserungsspielraum besteht, machen die leichten und interpretierbaren Eigenschaften diese Arbeit von großem praktischem Wert und akademischer Bedeutung.