Unlike conventional "black-box" transformers with classical self-attention mechanism, we build a lightweight and interpretable transformer-like neural net by unrolling a mixed-graph-based optimization algorithm to forecast traffic with spatial and temporal dimensions. We construct two graphs: an undirected graph $\mathcal{G}^u$ capturing spatial correlations across geography, and a directed graph $\mathcal{G}^d$ capturing sequential relationships over time. We predict future samples of signal $\mathbf{x}$, assuming it is "smooth" with respect to both $\mathcal{G}^u$ and $\mathcal{G}^d$, where we design new $\ell_2$ and $\ell_1$-norm variational terms to quantify and promote signal smoothness (low-frequency reconstruction) on a directed graph. We design an iterative algorithm based on alternating direction method of multipliers (ADMM), and unroll it into a feed-forward network for data-driven parameter learning. We insert graph learning modules for $\mathcal{G}^u$ and $\mathcal{G}^d$ that play the role of self-attention. Experiments show that our unrolled networks achieve competitive traffic forecast performance as state-of-the-art prediction schemes, while reducing parameter counts drastically. Our code is available in https://github.com/SingularityUndefined/Unrolling-GSP-STForecast .
論文ID : 2505.13102タイトル : Lightweight and Interpretable Transformer via Mixed Graph Algorithm Unrolling for Traffic Forecast著者 : Ji Qi, Mingxiao Liu, Tam Thuc Do, Yuzhe Li, Zhuoshi Pan, Gene Cheung, H. Vicky Zhao分類 : cs.LG cs.AI eess.SP発表日 : 2025年10月12日 (arXiv v2)論文リンク : https://arxiv.org/abs/2505.13102 本論文は、混合グラフアルゴリズムアンローリングに基づく軽量で解釈可能なTransformerモデルを交通予測に提案している。従来の「ブラックボックス」Transformerとは異なり、本手法は混合グラフ最適化アルゴリズムをアンローリングすることで、解釈可能なTransformer型ニューラルネットワークを構築している。モデルは2つのグラフを構築する:無向グラフG u \mathcal{G}^u G u は地理的空間相関性を捉え、有向グラフG d \mathcal{G}^d G d は時間的関係を捉える。有向グラフ上の信号平滑性を定量化・促進するための新しいℓ 2 \ell_2 ℓ 2 およびℓ 1 \ell_1 ℓ 1 ノルム変分項を設計し、交互方向乗数法(ADMM)に基づいて反復アルゴリズムを設計し、これをフィードフォワードネットワークにアンローリングしてデータ駆動型のパラメータ学習を行う。実験により、本モデルは競争力のある交通予測性能を維持しながら、パラメータ数を大幅に削減することが示された。
交通予測は重要な時空間データモデリング問題であり、以下を同時に捉える必要がある:
空間相関性 :地理的に近い監視地点間の相関性時間依存性 :過去の観測が将来に与える影響関係従来のTransformer :パラメータ数が膨大で解釈性に欠け、実際の展開時に計算とメモリの制約に直面するモデルベースの手法 :空間と時間の次元を独立に処理することが多く、時空間関係を十分に活用できない既存の深層学習手法 :性能は優れているがなお「ブラックボックス」モデルであり、パラメータ数が多い産業応用における軽量モデルの緊急の需要 アルゴリズムアンローリング(Algorithm Unrolling)がモデル駆動とデータ駆動を組み合わせた新しいパラダイムを提供 既存の研究は正の無向グラフのみを使用しており、複雑な時空間関係を効果的にモデル化できない 混合グラフアルゴリズムアンローリングの初提案 :無向グラフ(空間)と有向グラフ(時間)を組み合わせて複雑な時空間関係をモデル化革新的な有向グラフ正則化項 :有向グラフラプラシアン正則化器(DGLR)と有向グラフ全変動(DGTV)を設計軽量で解釈可能なTransformer :ADMMアルゴリズムアンローリングにより、パラメータを大幅削減(PDFormerのわずか6.4%)理論的貢献 :有向グラフ周波数定義が無重み有向線グラフの場合に古典的フーリエ周波数に退化することを証明N個の監視地点における過去T+1時刻の観測値が与えられたとき、将来S時刻の交通状態を予測する。入力は部分的に観測された時空間信号y ∈ R M y \in \mathbb{R}^M y ∈ R M であり、出力は完全な時空間信号x ∈ R N ( T + S + 1 ) x \in \mathbb{R}^{N(T+S+1)} x ∈ R N ( T + S + 1 ) である。
同一時刻の地理的に近いノード同士を接続 空間相関性を捉える 対称隣接行列W u W^u W u を使用 時刻τ \tau τ のノードからτ + 1 , . . . , τ + W \tau+1, ..., \tau+W τ + 1 , ... , τ + W 時刻の同じノードへ接続 時間的因果関係を捉える 非対称隣接行列W d W^d W d を使用 x T L r d x = x T ( L r d ) T L r d x = ∥ x − W r d x ∥ 2 2 x^T\mathcal{L}_r^d x = x^T(L_r^d)^T L_r^d x = \|x - W_r^d x\|_2^2 x T L r d x = x T ( L r d ) T L r d x = ∥ x − W r d x ∥ 2 2
ここでL r d = I − W r d L_r^d = I - W_r^d L r d = I − W r d は確率的ウォークラプラシアン行列、W r d = ( D d ) − 1 W d W_r^d = (D^d)^{-1}W^d W r d = ( D d ) − 1 W d は行確率的隣接行列である。
∥ L r d x ∥ 1 = ∑ j ∈ S ˉ ∣ x j − ∑ i w j , i x i ∣ \|L_r^d x\|_1 = \sum_{j \in \bar{S}} |x_j - \sum_i w_{j,i} x_i| ∥ L r d x ∥ 1 = ∑ j ∈ S ˉ ∣ x j − ∑ i w j , i x i ∣
min x ∥ y − H x ∥ 2 2 + μ u x T L u x + μ d , 2 x T L r d x + μ d , 1 ∥ L r d x ∥ 1 \min_x \|y - Hx\|_2^2 + \mu_u x^T L^u x + \mu_{d,2} x^T \mathcal{L}_r^d x + \mu_{d,1} \|L_r^d x\|_1 min x ∥ y − H x ∥ 2 2 + μ u x T L u x + μ d , 2 x T L r d x + μ d , 1 ∥ L r d x ∥ 1
ここでH H H はサンプリング行列、μ u , μ d , 2 , μ d , 1 \mu_u, \mu_{d,2}, \mu_{d,1} μ u , μ d , 2 , μ d , 1 は重み付けパラメータである。
補助変数ϕ \phi ϕ を導入することで、最適化問題を以下に変換する:
min x , ϕ ∥ y − H x ∥ 2 2 + μ u x T L u x + μ d , 2 x T L r d x + μ d , 1 ∥ ϕ ∥ 1 \min_{x,\phi} \|y - Hx\|_2^2 + \mu_u x^T L^u x + \mu_{d,2} x^T \mathcal{L}_r^d x + \mu_{d,1} \|\phi\|_1 min x , ϕ ∥ y − H x ∥ 2 2 + μ u x T L u x + μ d , 2 x T L r d x + μ d , 1 ∥ ϕ ∥ 1 s.t. ϕ = L r d x \text{s.t. } \phi = L_r^d x s.t. ϕ = L r d x
x x x 部分問題 :共役勾配法により線形システムを求解ϕ \phi ϕ 部分問題 :ソフト閾値処理
ϕ i τ + 1 = sign ( δ ) ⋅ max ( ∣ δ ∣ − ρ − 1 μ d , 1 , 0 ) \phi_i^{\tau+1} = \text{sign}(\delta) \cdot \max(|\delta| - \rho^{-1}\mu_{d,1}, 0) ϕ i τ + 1 = sign ( δ ) ⋅ max ( ∣ δ ∣ − ρ − 1 μ d , 1 , 0 )
ここでδ = ( L r d ) i x τ + 1 − ρ − 1 γ i τ \delta = (L_r^d)_i x^{\tau+1} - \rho^{-1}\gamma_i^\tau δ = ( L r d ) i x τ + 1 − ρ − 1 γ i τ マハラノビス距離を用いてノード相似性を計算:
d u ( i , j ) = ( f i u − f j u ) T M ( f i u − f j u ) d^u(i,j) = (f_i^u - f_j^u)^T M (f_i^u - f_j^u) d u ( i , j ) = ( f i u − f j u ) T M ( f i u − f j u )
辺の重みは正規化指数関数により計算:
w i , j u = exp ( − d u ( i , j ) ) ∑ l ∈ N i exp ( − d u ( i , l ) ) ∑ k ∈ N j exp ( − d u ( k , j ) ) w_{i,j}^u = \frac{\exp(-d^u(i,j))}{\sqrt{\sum_{l \in \mathcal{N}_i} \exp(-d^u(i,l))} \sqrt{\sum_{k \in \mathcal{N}_j} \exp(-d^u(k,j))}} w i , j u = ∑ l ∈ N i e x p ( − d u ( i , l )) ∑ k ∈ N j e x p ( − d u ( k , j )) e x p ( − d u ( i , j ))
同様に計量行列P P P を用いて有向辺の重みを計算。
ADMMの各反復をニューラル層として実装:
5つのADMMブロック、各ブロック25層 各ブロック前にグラフ学習モジュールを挿入 マルチヘッドアテンション機構を使用(4つの並列グラフ学習モジュール) METR-LA :ロサンゼルス交通速度データ、207ノード、1315辺PEMS03 :交通流量データ、358ノード、547辺サンプリング間隔:5分 データ分割:6:2:2(訓練:検証:テスト) RMSE :二乗平均平方根誤差MAE :平均絶対誤差MAPE :平均絶対パーセント誤差6つのカテゴリーの基線手法を含む:
モデルベース:VAR GNN手法:STGCN, STSGCN GAT手法:GMAN, ST-Wave Transformer手法:PDFormer, STAEformer 適応グラフ手法:Graph WaveNet, AGCRN シンプルな線形モデル:STID, SimpleTM 予測期間:30/60/120分(6/12/24ステップ) 履歴ウィンドウ:60分(12ステップ) オプティマイザー:Adam、学習率5×10⁻⁴ 損失関数:Huber損失(δ=1) ハードウェア:NVIDIA GeForce RTX 3090 データセット 期間 本手法 最良基線 パラメータ数比較 PEMS03 30分 26.10/17.03/18.85 23.71/15.05/18.16 34K vs 531K PEMS03 60分 27.67/17.46/17.72 25.56/15.97/15.49 (6.4%パラメータ) METR-LA 60分 12.34/5.18/11.80 11.96/5.49/9.65
パラメータ効率 :PDFormerのわずか6.4%のパラメータ数で競争力のある性能を達成長期予測の優位性 :予測期間が長いほど、最良手法との性能差が小さくなるデータ効率 :データが不足している場合、より安定した性能を示す変種 PEMS03 (RMSE/MAE/MAPE) METR-LA (RMSE/MAE/MAPE) 完全モデル 27.67/17.46/17.72 12.34/5.18/11.80 DGTVなし 27.78/17.85/17.90 12.36/5.40/12.31 DGLRなし 30.89/20.02/21.10 12.41/5.35/12.20 無向時間グラフ 27.52/17.87/18.82 12.51/5.42/12.11
結果から以下が示される:
DGLR項が性能向上に最も重要 DGTV項も明らかな貢献を示す 有向グラフモデリングが無向グラフモデリングより優れている 定理3.1 は以下を証明する:無重み有向線グラフの場合、対称化有向グラフラプラシアンL r d = ( L r d ) T L r d \mathcal{L}_r^d = (L_r^d)^T L_r^d L r d = ( L r d ) T L r d は無向線グラフのラプラシアン行列と等価であり、周波数定義の妥当性を検証している。
大規模言語モデル:LoRA低ランク適応、パラメータ量子化 音声強調:局所因果自己注意 画像処理:YUVチャネル分離処理 GNN手法 :STGCN、Graph WaveNetなど、空間モデリングに注力Transformer手法 :時空間次元を別々に処理する二重Transformerシンプルな線形モデル :複雑なモデルの有効性に異議を唱える最適化アルゴリズムの反復をニューラル層にアンローリング 数学的解釈可能性とデータ駆動能力の両立 画像処理で既に成功事例がある 混合グラフアルゴリズムアンローリングは軽量で解釈可能な交通予測モデルの実現に成功 有向グラフ変分項は時間的因果関係を効果的に捉える パラメータ数を大幅削減しながら競争力のある性能を維持 距離制限 :学習されたマハラノビス距離は非負であるが、従来の自己注意は負になり得るグラフ疎性 :実際の道路接続に基づく制限がグラフの接続性を制限時間ウィンドウ固定 :事前定義された時間ウィンドウは十分に柔軟でない可能性がある符号付き距離とより複雑なグラフモデリングへの拡張 適応的時間ウィンドウ学習 他の時空間予測タスクへの応用 理論的革新 :有向グラフの周波数概念を初めて定義し、対応する正則化項を設計手法の新規性 :混合グラフアルゴリズムアンローリングはTransformer設計に新しい視点を提供実用的価値 :顕著なパラメータ削減は実際の展開に重要な意義を持つ解釈可能性 :各層は最適化アルゴリズムの反復に対応し、明確な数学的意味を持つ性能トレードオフ :いくつかの指標で最良の基線手法に及ばない適用範囲 :主に交通予測で検証されており、他の時空間タスクへの汎化性は未知理論分析 :収束性と複雑性の理論分析が不足している学術的貢献 :グラフ信号処理とTransformer設計に新しい視点を提供実用的価値 :軽量特性はモバイルデバイスとリソース制約環境に適している再現性 :オープンソースコードを提供し、実験設定が詳細リソース制約環境 :モバイルデバイス、エッジコンピューティングリアルタイム予測システム :高速応答が必要な交通管理システム解釈可能なAI応用 :モデルの透明性が必要な安全関連システム論文は以下を含む複数の重要な研究を引用している:
Transformer原論文 (Vaswani et al., 2017) アルゴリズムアンローリング総説 (Monga et al., 2021) グラフ信号処理基礎 (Ortega et al., 2018) 交通予測関連研究 (Li et al., 2017; Yu et al., 2018) 総合評価 :これは交通予測分野における革新的な研究であり、アルゴリズムアンローリングの考え方を混合グラフ設定に成功裏に拡張し、性能を維持しながらパラメータ数を大幅に削減している。いくつかの指標でなお改善の余地があるが、その軽量性と解釈可能性の特性により、重要な実用的価値と学術的意義を持つ。