Lightweight and Interpretable Transformer via Mixed Graph Algorithm Unrolling for Traffic Forecast

Qi, Do, Liu et al.
Unlike conventional "black-box" transformers with a classical self-attention mechanism, we build a lightweight and interpretable transformer-like neural net by unrolling a mixed-graph-based optimization algorithm to forecast traffic with spatial and temporal dimensions. We construct two graphs: an undirected graph $\mathcal{G}^u$ capturing spatial correlations across geography, and a directed graph $\mathcal{G}^d$ capturing sequential relationships over time. We predict future samples of signal $\mathbf{x}$, assuming it is "smooth" with respect to both $\mathcal{G}^u$ and $\mathcal{G}^d$, where we design new $\ell_2$ and $\ell_1$-norm variational terms to quantify and promote signal smoothness (low-frequency reconstruction) on a directed graph. We design an iterative algorithm based on the alternating direction method of multipliers (ADMM), and unroll it into a feed-forward network for data-driven parameter learning. We insert graph learning modules for $\mathcal{G}^u$ and $\mathcal{G}^d$ that play the role of self-attention. Experiments show that our unrolled networks achieve traffic forecast performance competitive with state-of-the-art prediction schemes, while reducing parameter counts drastically. Our code is available at https://github.com/SingularityUndefined/Unrolling-GSP-STForecast.

Basic Information

  • Paper ID: 2505.13102
  • Title: Lightweight and Interpretable Transformer via Mixed Graph Algorithm Unrolling for Traffic Forecast
  • Authors: Ji Qi, Mingxiao Liu, Tam Thuc Do, Yuzhe Li, Zhuoshi Pan, Gene Cheung, H. Vicky Zhao
  • Classification: cs.LG cs.AI eess.SP
  • Publication Date: October 12, 2025 (arXiv v2)
  • Paper Link: https://arxiv.org/abs/2505.13102

Abstract

This paper proposes a lightweight and interpretable Transformer model based on mixed graph algorithm unrolling for traffic forecasting. Unlike traditional "black-box" Transformers, this approach constructs an interpretable Transformer-like neural network by unrolling a mixed graph optimization algorithm. The model constructs two graphs: an undirected graph $\mathcal{G}^u$ capturing geospatial correlations and a directed graph $\mathcal{G}^d$ capturing temporal relationships. Novel $\ell_2$ and $\ell_1$ norm variational terms are designed to quantify and promote signal smoothness on the directed graph. Based on the Alternating Direction Method of Multipliers (ADMM), an iterative algorithm is designed and unrolled into a feedforward network for data-driven parameter learning. Experiments demonstrate that the model maintains competitive traffic forecasting performance while significantly reducing the number of parameters.

Research Background and Motivation

Problem Definition

Traffic forecasting is an important spatio-temporal data modeling problem requiring simultaneous capture of:

  1. Spatial Correlations: Relationships between monitoring stations at geographically proximate locations
  2. Temporal Dependencies: Influence of historical observations on future states

Limitations of Existing Methods

  1. Traditional Transformers: Massive parameter counts and a lack of interpretability, with computational and memory costs that constrain practical deployment
  2. Model-based Methods: Often process spatial and temporal dimensions independently, failing to fully exploit spatio-temporal relationships
  3. Existing Deep Learning Methods: While achieving excellent performance, they remain "black-box" models with large parameter counts

Research Motivation

  1. Urgent industrial demand for lightweight models
  2. Algorithm unrolling provides a new paradigm combining model-driven and data-driven approaches
  3. Existing work uses only undirected graphs with positive edge weights, which cannot effectively model complex spatio-temporal relationships

Core Contributions

  1. First Mixed Graph Algorithm Unrolling: Combines undirected graphs (spatial) and directed graphs (temporal) to model complex spatio-temporal relationships
  2. Innovative Directed Graph Regularization Terms: Designs directed graph Laplacian regularizer (DGLR) and directed graph total variation (DGTV)
  3. Lightweight Interpretable Transformer: Achieves significant parameter reduction (only 6.4% of PDFormer) through ADMM algorithm unrolling
  4. Theoretical Contribution: Proves that directed graph frequency definition degenerates to classical Fourier frequency in the unweighted directed line graph case

Methodology Details

Task Definition

Given observations from $N$ monitoring stations over $T+1$ past time steps, predict traffic states for $S$ future time steps. Input is a partially observed spatio-temporal signal $y \in \mathbb{R}^M$; output is the complete spatio-temporal signal $x \in \mathbb{R}^{N(T+S+1)}$.

Mixed Graph Construction

Undirected Graph $\mathcal{G}^u$

  • Connects nodes at geographically proximate locations at the same time step
  • Captures spatial correlations
  • Uses symmetric adjacency matrix $W^u$

Directed Graph $\mathcal{G}^d$

  • Connects nodes from time step $\tau$ to the same nodes at time steps $\tau+1, \ldots, \tau+W$
  • Captures temporal causal relationships
  • Uses asymmetric adjacency matrix $W^d$
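
The construction above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the function name `build_mixed_graph`, the flattening of node (i, t) to index `t * N + i`, the unit edge weights, and the convention that `W[j, i]` holds the weight of the edge from i to j are all choices made for the example (the paper learns the edge weights from data).

```python
import numpy as np

def build_mixed_graph(spatial_adj, num_steps, window):
    """Build adjacency matrices of the undirected (spatial) and directed
    (temporal) graphs over N * num_steps space-time nodes.

    Node (i, t) is flattened to index t * N + i.
    Assumed convention: W[j, i] != 0 means an edge from node i to node j.
    """
    n = spatial_adj.shape[0]
    total = n * num_steps
    w_u = np.zeros((total, total))
    w_d = np.zeros((total, total))
    for t in range(num_steps):
        block = slice(t * n, (t + 1) * n)
        # Undirected edges: spatial neighbors within the same time step.
        w_u[block, block] = spatial_adj
        # Directed edges: node i at step t -> node i at steps t+1 .. t+window.
        for k in range(1, window + 1):
            if t + k < num_steps:
                for i in range(n):
                    w_d[(t + k) * n + i, t * n + i] = 1.0  # unit weight (assumed)
    return w_u, w_d

# Toy usage: 3 stations on a path, 4 time steps, temporal window of 2.
spatial = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
w_u, w_d = build_mixed_graph(spatial, num_steps=4, window=2)
print(w_u.shape, np.allclose(w_u, w_u.T), np.allclose(w_d, w_d.T))
```

With time-major flattening, every directed edge points from an earlier block to a later one, so `w_d` is strictly lower triangular, which mirrors the causal reading of the temporal graph.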

Directed Graph Variational Terms Design

$\ell_2$ Norm Term: Directed Graph Laplacian Regularizer (DGLR)

$$x^T \mathcal{L}_r^d x = x^T (L_r^d)^T L_r^d x = \|x - W_r^d x\|_2^2$$

where $L_r^d = I - W_r^d$ is the random walk Laplacian matrix and $W_r^d = (D^d)^{-1} W^d$ is the row-stochastic adjacency matrix.

$\ell_1$ Norm Term: Directed Graph Total Variation (DGTV)

$$\|L_r^d x\|_1 = \sum_{j \in \bar{S}} \Big| x_j - \sum_i w_{j,i} x_i \Big|$$
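
As a rough numerical illustration of the two terms, the sketch below evaluates DGLR and DGTV on a small directed line graph. The helper names are hypothetical, and one detail is an assumption not stated in this summary: nodes without incoming edges are given a unit self-loop so that $(D^d)^{-1}$ is well defined and those nodes contribute zero to both terms.

```python
import numpy as np

def random_walk_laplacian(w_d):
    """Return L_r^d = I - (D^d)^{-1} W^d, with W[j, i] the weight of edge i -> j."""
    w = np.array(w_d, dtype=float)
    in_deg = w.sum(axis=1)
    # Assumption: a node with no incoming edge gets a unit self-loop, so its
    # row of L_r^d is zero and it adds nothing to DGLR or DGTV.
    sources = np.flatnonzero(in_deg == 0)
    w[sources, sources] = 1.0
    in_deg[sources] = 1.0
    w_r = w / in_deg[:, None]  # row-stochastic adjacency W_r^d
    return np.eye(w.shape[0]) - w_r

def dglr(l_r, x):
    """l2 term: ||L_r^d x||_2^2 = x^T (L_r^d)^T L_r^d x."""
    r = l_r @ x
    return float(r @ r)

def dgtv(l_r, x):
    """l1 term: ||L_r^d x||_1."""
    return float(np.abs(l_r @ x).sum())

# Directed line graph 0 -> 1 -> 2 -> 3 with unit weights.
w_d = np.zeros((4, 4))
for j in range(1, 4):
    w_d[j, j - 1] = 1.0
l_r = random_walk_laplacian(w_d)

x = np.array([1.0, 2.0, 4.0, 7.0])
print(l_r @ x)        # prediction residuals [0, 1, 2, 3]
print(dglr(l_r, x))   # 14.0
print(dgtv(l_r, x))   # 6.0
```

Each entry of $L_r^d x$ is the gap between a node's value and the weighted average of its predecessors, so a constant signal gives zero for both terms, while the $\ell_1$ version penalizes large jumps less heavily than the $\ell_2$ version.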

Optimization Objective Function

$$\min_x \|y - Hx\|_2^2 + \mu_u x^T L^u x + \mu_{d,2} x^T \mathcal{L}_r^d x + \mu_{d,1} \|L_r^d x\|_1$$

where $H$ is the sampling matrix and $\mu_u, \mu_{d,2}, \mu_{d,1}$ are weight parameters.

ADMM Algorithm Design

By introducing auxiliary variable $\phi$, the optimization problem is transformed to:

$$\min_{x,\phi} \|y - Hx\|_2^2 + \mu_u x^T L^u x + \mu_{d,2} x^T \mathcal{L}_r^d x + \mu_{d,1} \|\phi\|_1 \quad \text{s.t. } \phi = L_r^d x$$

Subproblem Solutions

  1. $x$ Subproblem: Solved via conjugate gradient method for linear systems
  2. $\phi$ Subproblem: Soft thresholding operation $\phi_i^{\tau+1} = \text{sign}(\delta) \cdot \max(|\delta| - \rho^{-1}\mu_{d,1}, 0)$, where $\delta = (L_r^d)_i x^{\tau+1} - \rho^{-1}\gamma_i^\tau$
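
The two subproblems can be put together in a short NumPy sketch. It is an illustration under assumptions rather than the paper's implementation: the function names are hypothetical, the $x$-update linear system is derived here from an augmented Lagrangian with multiplier term $\gamma^T(\phi - L_r^d x)$ and penalty $(\rho/2)\|\phi - L_r^d x\|_2^2$ (the paper's scaling may differ), and all weights are fixed by hand instead of being learned.

```python
import numpy as np

def soft_threshold(v, thresh):
    """Elementwise sign(v) * max(|v| - thresh, 0)."""
    return np.sign(v) * np.maximum(np.abs(v) - thresh, 0.0)

def conjugate_gradient(a, b, x0, iters=50, tol=1e-20):
    """Solve a @ x = b for symmetric positive definite a.
    tol is compared against the squared residual norm."""
    x = x0.copy()
    r = b - a @ x
    p = r.copy()
    rs = r @ r
    for _ in range(iters):
        if rs < tol:
            break
        ap = a @ p
        alpha = rs / (p @ ap)
        x = x + alpha * p
        r = r - alpha * ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

def admm(y, h, l_u, l_r, mu_u, mu_d2, mu_d1, rho=1.0, iters=100):
    """Minimize ||y - Hx||^2 + mu_u x'L^u x + mu_d2 ||L_r x||^2 + mu_d1 ||phi||_1
    subject to phi = L_r x (sign conventions as stated in the lead-in)."""
    n = l_r.shape[0]
    x = h.T @ y              # zero-filled initial guess
    phi = np.zeros(n)        # auxiliary variable
    gamma = np.zeros(n)      # Lagrange multiplier
    lap_d = l_r.T @ l_r
    a = 2 * h.T @ h + 2 * mu_u * l_u + (2 * mu_d2 + rho) * lap_d
    for _ in range(iters):
        # x subproblem: linear system solved by conjugate gradient.
        b = 2 * h.T @ y + l_r.T @ (gamma + rho * phi)
        x = conjugate_gradient(a, b, x)
        # phi subproblem: soft thresholding with threshold mu_d1 / rho.
        delta = l_r @ x - gamma / rho
        phi = soft_threshold(delta, mu_d1 / rho)
        # Multiplier update.
        gamma = gamma + rho * (phi - l_r @ x)
    return x, phi
```

In the unrolled network, each pass through this loop body corresponds to one layer, and quantities such as the $\mu$ weights and $\rho$ become candidates for data-driven tuning instead of hand-set constants.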

Graph Learning Module

Undirected Graph Learning (UGL)

Computes node similarity using Mahalanobis distance: $d^u(i,j) = (f_i^u - f_j^u)^T M (f_i^u - f_j^u)$

Edge weights computed via normalized exponential function: $w_{i,j}^u = \frac{\exp(-d^u(i,j))}{\sqrt{\sum_{l \in \mathcal{N}_i} \exp(-d^u(i,l))} \sqrt{\sum_{k \in \mathcal{N}_j} \exp(-d^u(k,j))}}$
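
A minimal NumPy sketch of this edge-weight computation follows. The function name, the `mask` argument encoding the neighborhoods $\mathcal{N}_i$, and the parametrization $M = QQ^T$ (used here only to keep $M$ positive semidefinite) are assumptions for illustration; how the paper produces the features $f_i^u$ and parametrizes $M$ is not specified in this summary.

```python
import numpy as np

def undirected_edge_weights(features, q, mask):
    """Symmetrically normalized exponential edge weights.

    features: (N, K) array, one feature vector f_i per node.
    q:        (K, R) factor of the metric matrix, M = q @ q.T (assumed).
    mask:     (N, N) symmetric 0/1 array, mask[i, j] = 1 if j is in N_i.
    """
    m = q @ q.T
    diff = features[:, None, :] - features[None, :, :]   # (N, N, K)
    # Mahalanobis distance d(i, j) = (f_i - f_j)^T M (f_i - f_j) >= 0.
    dist = np.einsum("ijk,kl,ijl->ij", diff, m, diff)
    affinity = np.exp(-dist) * mask
    norm = np.sqrt(affinity.sum(axis=1))
    norm[norm == 0] = 1.0   # isolated nodes: avoid division by zero
    return affinity / np.outer(norm, norm)

# Toy usage: 4 nodes on a path, random 3-dimensional features.
rng = np.random.default_rng(0)
feats = rng.normal(size=(4, 3))
path_mask = np.array([[0, 1, 0, 0],
                      [1, 0, 1, 0],
                      [0, 1, 0, 1],
                      [0, 0, 1, 0]], dtype=float)
w_u = undirected_edge_weights(feats, np.eye(3), path_mask)
print(np.allclose(w_u, w_u.T))
```

Because the distance and the mask are symmetric, the row and column sums in the denominator coincide, and the resulting weights are symmetric and non-negative. This is the sense in which the module stands in for a softmax-normalized attention matrix.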

Directed Graph Learning (DGL)

Similarly computes directed edge weights using metric matrix $P$.

Network Architecture

Each ADMM iteration is implemented as a neural layer:

  • 5 ADMM blocks, each with 25 layers
  • Graph learning module inserted before each block
  • Uses multi-head attention mechanism (4 parallel graph learning modules)

Experimental Setup

Datasets

  • METR-LA: Los Angeles traffic speed data, 207 nodes, 1315 edges
  • PEMS03: Traffic flow data, 358 nodes, 547 edges
  • Sampling interval: 5 minutes
  • Data split: 6:2:2 (train:validation:test)
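
The 6:2:2 split can be sketched as below. The helper name is hypothetical, and the chronological ordering (train on the earliest samples, test on the latest) is an assumption; the summary gives only the ratio.

```python
def chronological_split(num_samples, ratios=(6, 2, 2)):
    """Return (train, val, test) slices over time-ordered samples."""
    total = sum(ratios)
    train_end = num_samples * ratios[0] // total
    val_end = num_samples * (ratios[0] + ratios[1]) // total
    return (slice(0, train_end),
            slice(train_end, val_end),
            slice(val_end, num_samples))

train, val, test = chronological_split(1000)
print(train, val, test)
```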

Evaluation Metrics

  • RMSE: Root Mean Square Error
  • MAE: Mean Absolute Error
  • MAPE: Mean Absolute Percentage Error
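
For reference, the three metrics can be computed as in the sketch below. Skipping zero-valued targets in MAPE is an assumption made here to avoid division by zero; the summary does not say how the paper handles such entries.

```python
import numpy as np

def rmse(y_true, y_pred):
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

def mae(y_true, y_pred):
    return float(np.mean(np.abs(y_true - y_pred)))

def mape(y_true, y_pred):
    """Mean absolute percentage error in percent.
    Assumption: entries with a zero target are excluded."""
    valid = y_true != 0
    ratio = np.abs((y_true[valid] - y_pred[valid]) / y_true[valid])
    return float(100.0 * np.mean(ratio))

y_true = np.array([10.0, 20.0, 0.0, 40.0])
y_pred = np.array([12.0, 18.0, 5.0, 40.0])
print(rmse(y_true, y_pred))  # sqrt(8.25), about 2.872
print(mae(y_true, y_pred))   # 2.25
print(mape(y_true, y_pred))  # 10.0 (up to floating-point rounding)
```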

Baseline Methods

Six categories of baseline methods:

  • Model-based: VAR
  • GNN methods: STGCN, STSGCN
  • GAT methods: GMAN, ST-Wave
  • Transformer methods: PDFormer, STAEformer
  • Adaptive graph methods: Graph WaveNet, AGCRN
  • Simple linear models: STID, SimpleTM

Implementation Details

  • Prediction horizons: 30/60/120 minutes (6/12/24 steps)
  • Historical window: 60 minutes (12 steps)
  • Optimizer: Adam, learning rate 5×10⁻⁴
  • Loss function: Huber loss (δ=1)
  • Hardware: NVIDIA GeForce RTX 3090
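
The Huber loss with $\delta = 1$ listed above behaves quadratically for small errors and linearly for large ones. A minimal NumPy sketch, assuming mean reduction over all entries (the reduction is not stated in this summary):

```python
import numpy as np

def huber_loss(y_true, y_pred, delta=1.0):
    """Mean Huber loss: 0.5 * e^2 if |e| <= delta, else delta * (|e| - 0.5 * delta)."""
    err = np.abs(y_true - y_pred)
    quadratic = 0.5 * err ** 2
    linear = delta * (err - 0.5 * delta)
    return float(np.mean(np.where(err <= delta, quadratic, linear)))

# Errors of 0.5 and 2.0 give 0.125 and 1.5, so the mean is 0.8125.
print(huber_loss(np.array([0.0, 0.0]), np.array([0.5, 2.0])))
```

Relative to a squared-error loss, the linear branch limits the influence of occasional large deviations in the targets.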

Experimental Results

Main Results

Dataset   Duration   Proposed (RMSE/MAE/MAPE)   Best Baseline (RMSE/MAE/MAPE)   Parameter Comparison
PEMS03    30min      26.10/17.03/18.85          23.71/15.05/18.16               34K vs 531K
PEMS03    60min      27.67/17.46/17.72          25.56/15.97/15.49               (6.4% parameters)
METR-LA   60min      12.34/5.18/11.80           11.96/5.49/9.65

Key Findings

  1. Parameter Efficiency: Achieves competitive performance using only 6.4% of PDFormer's parameters
  2. Long-term Prediction Advantage: Performance gap with best methods decreases as prediction horizon increases (on PEMS03, the RMSE gap narrows from 2.39 at 30 minutes to 2.11 at 60 minutes)
  3. Data Efficiency: More stable performance in data-scarce scenarios

Ablation Study

Variant                      PEMS03 (RMSE/MAE/MAPE)   METR-LA (RMSE/MAE/MAPE)
Full Model                   27.67/17.46/17.72        12.34/5.18/11.80
Without DGTV                 27.78/17.85/17.90        12.36/5.40/12.31
Without DGLR                 30.89/20.02/21.10        12.41/5.35/12.20
Undirected Temporal Graph    27.52/17.87/18.82        12.51/5.42/12.11

Results demonstrate:

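  • DGLR term is most critical for performance: removing it raises PEMS03 RMSE from 27.67 to 30.89
  • DGTV term gives a smaller but consistent gain on every metric (e.g., METR-LA MAE of 5.18 vs. 5.40 without it)
  • Directed graph modeling outperforms undirected temporal modeling on all metrics except PEMS03 RMSE (27.67 vs. 27.52)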

Theoretical Verification

Theorem 3.1 proves that for unweighted directed line graphs, the symmetrized directed graph Laplacian $\mathcal{L}_r^d = (L_r^d)^T L_r^d$ is equivalent to the undirected line graph Laplacian, validating the reasonableness of the frequency definition.
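
The statement can be checked numerically on a small example. The sketch below is an illustration under one assumption not spelled out in this summary: the source node of the directed line graph carries a unit self-loop so that its in-degree is nonzero. Function names are hypothetical.

```python
import numpy as np

def directed_line_rw_laplacian(n):
    """L_r^d = I - (D^d)^{-1} W^d for the line graph 0 -> 1 -> ... -> n-1."""
    w = np.zeros((n, n))
    w[0, 0] = 1.0                 # assumed self-loop on the source node
    for j in range(1, n):
        w[j, j - 1] = 1.0         # W[j, i]: edge from i to j
    in_deg = w.sum(axis=1)
    return np.eye(n) - w / in_deg[:, None]

def undirected_line_laplacian(n):
    """Combinatorial Laplacian D - A of the undirected path graph."""
    a = np.zeros((n, n))
    for j in range(1, n):
        a[j, j - 1] = a[j - 1, j] = 1.0
    return np.diag(a.sum(axis=1)) - a

n = 6
l_r = directed_line_rw_laplacian(n)
sym = l_r.T @ l_r
print(np.allclose(sym, undirected_line_laplacian(n)))  # True

# Eigenvalues of the path-graph Laplacian are 2 - 2 cos(pi k / n), k = 0..n-1.
eigvals = np.sort(np.linalg.eigvalsh(sym))
print(np.allclose(eigvals, 2 - 2 * np.cos(np.pi * np.arange(n) / n)))  # True
```

The eigenvectors of the path-graph Laplacian are the DCT basis vectors, which is one way to read the claim that the directed-graph frequency notion reduces to a classical one in this special case.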

Related Work

Lightweight Models

  • Large language models: LoRA low-rank adaptation, parameter quantization
  • Speech enhancement: Local causal self-attention
  • Image processing: YUV channel separation processing

Traffic Forecasting Methods

  1. GNN Methods: STGCN, Graph WaveNet, etc., focusing on spatial modeling
  2. Transformer Methods: Dual Transformers handling spatial and temporal dimensions separately
  3. Simple Linear Models: Challenging the effectiveness of complex models

Algorithm Unrolling

  • Unfolds optimization algorithm iterations into neural layers
  • Combines mathematical interpretability with data-driven capability
  • Successful applications in image processing

Conclusions and Discussion

Main Conclusions

  1. Mixed graph algorithm unrolling successfully achieves lightweight and interpretable traffic forecasting models
  2. Directed graph variational terms effectively capture temporal causal relationships
  3. Maintains competitive performance while significantly reducing parameters

Limitations

  1. Distance Constraints: Learned Mahalanobis distances are non-negative, whereas affinities in traditional self-attention can be negative
  2. Graph Sparsity: Learning based on real road connections limits graph connectivity
  3. Fixed Temporal Window: Predefined temporal windows may lack flexibility

Future Directions

  1. Extend to signed distances and more complex graph modeling
  2. Adaptive temporal window learning
  3. Application to other spatio-temporal forecasting tasks

In-Depth Evaluation

Strengths

  1. Theoretical Innovation: First to define frequency concepts for directed graphs and design corresponding regularization terms
  2. Novel Methodology: Mixed graph algorithm unrolling provides new insights for Transformer design
  3. Practical Value: Significant parameter reduction has important implications for practical deployment
  4. Interpretability: Each layer corresponds to optimization algorithm iterations with clear mathematical meaning

Weaknesses

  1. Performance Trade-off: Still underperforms the best baseline methods on certain metrics (e.g., PEMS03 60-minute RMSE of 27.67 vs. 25.56)
  2. Limited Scope: Primarily validated on traffic forecasting; generalization to other spatio-temporal tasks unknown
  3. Theoretical Analysis: Lacks convergence and complexity analysis

Impact

  1. Academic Contribution: Provides new perspectives for graph signal processing and Transformer design
  2. Practical Value: Lightweight characteristics suitable for edge computing and resource-constrained environments
  3. Reproducibility: Open-source code provided with detailed experimental settings

Applicable Scenarios

  1. Resource-Constrained Environments: Mobile devices, edge computing
  2. Real-time Prediction Systems: Traffic management systems requiring rapid response
  3. Interpretable AI Applications: Safety-critical systems requiring model transparency

References

The paper cites important works including:

  • Original Transformer paper (Vaswani et al., 2017)
  • Algorithm unrolling survey (Monga et al., 2021)
  • Graph signal processing foundations (Ortega et al., 2018)
  • Traffic forecasting related work (Li et al., 2017; Yu et al., 2018)

Overall Assessment: This is an innovative work in the traffic forecasting domain that successfully extends algorithm unrolling ideas to mixed graph settings, achieving significant parameter reduction while maintaining performance. Although there remains room for improvement on certain metrics, its lightweight and interpretable characteristics make it of considerable practical value and academic significance.