2025-11-20T09:28:14.240195

Lightweight and Interpretable Transformer via Mixed Graph Algorithm Unrolling for Traffic Forecast

Qi, Do, Liu et al.
Unlike conventional "black-box" transformers with classical self-attention mechanism, we build a lightweight and interpretable transformer-like neural net by unrolling a mixed-graph-based optimization algorithm to forecast traffic with spatial and temporal dimensions. We construct two graphs: an undirected graph $\mathcal{G}^u$ capturing spatial correlations across geography, and a directed graph $\mathcal{G}^d$ capturing sequential relationships over time. We predict future samples of signal $\mathbf{x}$, assuming it is "smooth" with respect to both $\mathcal{G}^u$ and $\mathcal{G}^d$, where we design new $\ell_2$ and $\ell_1$-norm variational terms to quantify and promote signal smoothness (low-frequency reconstruction) on a directed graph. We design an iterative algorithm based on alternating direction method of multipliers (ADMM), and unroll it into a feed-forward network for data-driven parameter learning. We insert graph learning modules for $\mathcal{G}^u$ and $\mathcal{G}^d$ that play the role of self-attention. Experiments show that our unrolled networks achieve competitive traffic forecast performance as state-of-the-art prediction schemes, while reducing parameter counts drastically. Our code is available in https://github.com/SingularityUndefined/Unrolling-GSP-STForecast .
academic

混合グラフアルゴリズムアンローリングによる軽量で解釈可能なTransformerを用いた交通予測

基本情報

  • 論文ID: 2505.13102
  • タイトル: Lightweight and Interpretable Transformer via Mixed Graph Algorithm Unrolling for Traffic Forecast
  • 著者: Ji Qi, Mingxiao Liu, Tam Thuc Do, Yuzhe Li, Zhuoshi Pan, Gene Cheung, H. Vicky Zhao
  • 分類: cs.LG cs.AI eess.SP
  • 発表日: 2025年10月12日 (arXiv v2)
  • 論文リンク: https://arxiv.org/abs/2505.13102

要旨

本論文は、混合グラフアルゴリズムアンローリングに基づく軽量で解釈可能なTransformerモデルを交通予測に提案している。従来の「ブラックボックス」Transformerとは異なり、本手法は混合グラフ最適化アルゴリズムをアンローリングすることで、解釈可能なTransformer型ニューラルネットワークを構築している。モデルは2つのグラフを構築する:無向グラフGu\mathcal{G}^uは地理的空間相関性を捉え、有向グラフGd\mathcal{G}^dは時間的関係を捉える。有向グラフ上の信号平滑性を定量化・促進するための新しい2\ell_2および1\ell_1ノルム変分項を設計し、交互方向乗数法(ADMM)に基づいて反復アルゴリズムを設計し、これをフィードフォワードネットワークにアンローリングしてデータ駆動型のパラメータ学習を行う。実験により、本モデルは競争力のある交通予測性能を維持しながら、パラメータ数を大幅に削減することが示された。

研究背景と動機

問題定義

交通予測は重要な時空間データモデリング問題であり、以下を同時に捉える必要がある:

  1. 空間相関性:地理的に近い監視地点間の相関性
  2. 時間依存性:過去の観測が将来に与える影響関係

既存手法の限界

  1. 従来のTransformer:パラメータ数が膨大で解釈性に欠け、実際の展開時に計算とメモリの制約に直面する
  2. モデルベースの手法:空間と時間の次元を独立に処理することが多く、時空間関係を十分に活用できない
  3. 既存の深層学習手法:性能は優れているがなお「ブラックボックス」モデルであり、パラメータ数が多い

研究動機

  1. 産業応用における軽量モデルの緊急の需要
  2. アルゴリズムアンローリング(Algorithm Unrolling)がモデル駆動とデータ駆動を組み合わせた新しいパラダイムを提供
  3. 既存の研究は正の無向グラフのみを使用しており、複雑な時空間関係を効果的にモデル化できない

核心的貢献

  1. 混合グラフアルゴリズムアンローリングの初提案:無向グラフ(空間)と有向グラフ(時間)を組み合わせて複雑な時空間関係をモデル化
  2. 革新的な有向グラフ正則化項:有向グラフラプラシアン正則化器(DGLR)と有向グラフ全変動(DGTV)を設計
  3. 軽量で解釈可能なTransformer:ADMMアルゴリズムアンローリングにより、パラメータを大幅削減(PDFormerのわずか6.4%)
  4. 理論的貢献:有向グラフ周波数定義が無重み有向線グラフの場合に古典的フーリエ周波数に退化することを証明

方法の詳細

タスク定義

N個の監視地点における過去T+1時刻の観測値が与えられたとき、将来S時刻の交通状態を予測する。入力は部分的に観測された時空間信号yRMy \in \mathbb{R}^Mであり、出力は完全な時空間信号xRN(T+S+1)x \in \mathbb{R}^{N(T+S+1)}である。

混合グラフの構築

無向グラフGu\mathcal{G}^u

  • 同一時刻の地理的に近いノード同士を接続
  • 空間相関性を捉える
  • 対称隣接行列WuW^uを使用

有向グラフGd\mathcal{G}^d

  • 時刻τ\tauのノードからτ+1,...,τ+W\tau+1, ..., \tau+W時刻の同じノードへ接続
  • 時間的因果関係を捉える
  • 非対称隣接行列WdW^dを使用

有向グラフ変分項の設計

2\ell_2ノルム項:有向グラフラプラシアン正則化器(DGLR)

xTLrdx=xT(Lrd)TLrdx=xWrdx22x^T\mathcal{L}_r^d x = x^T(L_r^d)^T L_r^d x = \|x - W_r^d x\|_2^2

ここでLrd=IWrdL_r^d = I - W_r^dは確率的ウォークラプラシアン行列、Wrd=(Dd)1WdW_r^d = (D^d)^{-1}W^dは行確率的隣接行列である。

1\ell_1ノルム項:有向グラフ全変動(DGTV)

Lrdx1=jSˉxjiwj,ixi\|L_r^d x\|_1 = \sum_{j \in \bar{S}} |x_j - \sum_i w_{j,i} x_i|

最適化目的関数

minxyHx22+μuxTLux+μd,2xTLrdx+μd,1Lrdx1\min_x \|y - Hx\|_2^2 + \mu_u x^T L^u x + \mu_{d,2} x^T \mathcal{L}_r^d x + \mu_{d,1} \|L_r^d x\|_1

ここでHHはサンプリング行列、μu,μd,2,μd,1\mu_u, \mu_{d,2}, \mu_{d,1}は重み付けパラメータである。

ADMMアルゴリズム設計

補助変数ϕ\phiを導入することで、最適化問題を以下に変換する: minx,ϕyHx22+μuxTLux+μd,2xTLrdx+μd,1ϕ1\min_{x,\phi} \|y - Hx\|_2^2 + \mu_u x^T L^u x + \mu_{d,2} x^T \mathcal{L}_r^d x + \mu_{d,1} \|\phi\|_1s.t. ϕ=Lrdx\text{s.t. } \phi = L_r^d x

部分問題の求解

  1. xx部分問題:共役勾配法により線形システムを求解
  2. ϕ\phi部分問題:ソフト閾値処理 ϕiτ+1=sign(δ)max(δρ1μd,1,0)\phi_i^{\tau+1} = \text{sign}(\delta) \cdot \max(|\delta| - \rho^{-1}\mu_{d,1}, 0) ここでδ=(Lrd)ixτ+1ρ1γiτ\delta = (L_r^d)_i x^{\tau+1} - \rho^{-1}\gamma_i^\tau

グラフ学習モジュール

無向グラフ学習(UGL)

マハラノビス距離を用いてノード相似性を計算: du(i,j)=(fiufju)TM(fiufju)d^u(i,j) = (f_i^u - f_j^u)^T M (f_i^u - f_j^u)

辺の重みは正規化指数関数により計算: wi,ju=exp(du(i,j))lNiexp(du(i,l))kNjexp(du(k,j))w_{i,j}^u = \frac{\exp(-d^u(i,j))}{\sqrt{\sum_{l \in \mathcal{N}_i} \exp(-d^u(i,l))} \sqrt{\sum_{k \in \mathcal{N}_j} \exp(-d^u(k,j))}}

有向グラフ学習(DGL)

同様に計量行列PPを用いて有向辺の重みを計算。

ネットワークアーキテクチャ

ADMMの各反復をニューラル層として実装:

  • 5つのADMMブロック、各ブロック25層
  • 各ブロック前にグラフ学習モジュールを挿入
  • マルチヘッドアテンション機構を使用(4つの並列グラフ学習モジュール)

実験設定

データセット

  • METR-LA:ロサンゼルス交通速度データ、207ノード、1315辺
  • PEMS03:交通流量データ、358ノード、547辺
  • サンプリング間隔:5分
  • データ分割:6:2:2(訓練:検証:テスト)

評価指標

  • RMSE:二乗平均平方根誤差
  • MAE:平均絶対誤差
  • MAPE:平均絶対パーセント誤差

比較手法

6つのカテゴリーの基線手法を含む:

  • モデルベース:VAR
  • GNN手法:STGCN, STSGCN
  • GAT手法:GMAN, ST-Wave
  • Transformer手法:PDFormer, STAEformer
  • 適応グラフ手法:Graph WaveNet, AGCRN
  • シンプルな線形モデル:STID, SimpleTM

実装詳細

  • 予測期間:30/60/120分(6/12/24ステップ)
  • 履歴ウィンドウ:60分(12ステップ)
  • オプティマイザー:Adam、学習率5×10⁻⁴
  • 損失関数:Huber損失(δ=1)
  • ハードウェア:NVIDIA GeForce RTX 3090

実験結果

主要結果

データセット期間本手法最良基線パラメータ数比較
PEMS0330分26.10/17.03/18.8523.71/15.05/18.1634K vs 531K
PEMS0360分27.67/17.46/17.7225.56/15.97/15.49(6.4%パラメータ)
METR-LA60分12.34/5.18/11.8011.96/5.49/9.65

主要な発見

  1. パラメータ効率:PDFormerのわずか6.4%のパラメータ数で競争力のある性能を達成
  2. 長期予測の優位性:予測期間が長いほど、最良手法との性能差が小さくなる
  3. データ効率:データが不足している場合、より安定した性能を示す

アブレーション実験

変種PEMS03 (RMSE/MAE/MAPE)METR-LA (RMSE/MAE/MAPE)
完全モデル27.67/17.46/17.7212.34/5.18/11.80
DGTVなし27.78/17.85/17.9012.36/5.40/12.31
DGLRなし30.89/20.02/21.1012.41/5.35/12.20
無向時間グラフ27.52/17.87/18.8212.51/5.42/12.11

結果から以下が示される:

  • DGLR項が性能向上に最も重要
  • DGTV項も明らかな貢献を示す
  • 有向グラフモデリングが無向グラフモデリングより優れている

理論的検証

定理3.1は以下を証明する:無重み有向線グラフの場合、対称化有向グラフラプラシアンLrd=(Lrd)TLrd\mathcal{L}_r^d = (L_r^d)^T L_r^dは無向線グラフのラプラシアン行列と等価であり、周波数定義の妥当性を検証している。

関連研究

軽量モデル

  • 大規模言語モデル:LoRA低ランク適応、パラメータ量子化
  • 音声強調:局所因果自己注意
  • 画像処理:YUVチャネル分離処理

交通予測手法

  1. GNN手法:STGCN、Graph WaveNetなど、空間モデリングに注力
  2. Transformer手法:時空間次元を別々に処理する二重Transformer
  3. シンプルな線形モデル:複雑なモデルの有効性に異議を唱える

アルゴリズムアンローリング

  • 最適化アルゴリズムの反復をニューラル層にアンローリング
  • 数学的解釈可能性とデータ駆動能力の両立
  • 画像処理で既に成功事例がある

結論と考察

主要な結論

  1. 混合グラフアルゴリズムアンローリングは軽量で解釈可能な交通予測モデルの実現に成功
  2. 有向グラフ変分項は時間的因果関係を効果的に捉える
  3. パラメータ数を大幅削減しながら競争力のある性能を維持

限界

  1. 距離制限:学習されたマハラノビス距離は非負であるが、従来の自己注意は負になり得る
  2. グラフ疎性:実際の道路接続に基づく制限がグラフの接続性を制限
  3. 時間ウィンドウ固定:事前定義された時間ウィンドウは十分に柔軟でない可能性がある

今後の方向性

  1. 符号付き距離とより複雑なグラフモデリングへの拡張
  2. 適応的時間ウィンドウ学習
  3. 他の時空間予測タスクへの応用

深層的評価

強み

  1. 理論的革新:有向グラフの周波数概念を初めて定義し、対応する正則化項を設計
  2. 手法の新規性:混合グラフアルゴリズムアンローリングはTransformer設計に新しい視点を提供
  3. 実用的価値:顕著なパラメータ削減は実際の展開に重要な意義を持つ
  4. 解釈可能性:各層は最適化アルゴリズムの反復に対応し、明確な数学的意味を持つ

不足点

  1. 性能トレードオフ:いくつかの指標で最良の基線手法に及ばない
  2. 適用範囲:主に交通予測で検証されており、他の時空間タスクへの汎化性は未知
  3. 理論分析:収束性と複雑性の理論分析が不足している

影響力

  1. 学術的貢献:グラフ信号処理とTransformer設計に新しい視点を提供
  2. 実用的価値:軽量特性はモバイルデバイスとリソース制約環境に適している
  3. 再現性:オープンソースコードを提供し、実験設定が詳細

適用シーン

  1. リソース制約環境:モバイルデバイス、エッジコンピューティング
  2. リアルタイム予測システム:高速応答が必要な交通管理システム
  3. 解釈可能なAI応用:モデルの透明性が必要な安全関連システム

参考文献

論文は以下を含む複数の重要な研究を引用している:

  • Transformer原論文 (Vaswani et al., 2017)
  • アルゴリズムアンローリング総説 (Monga et al., 2021)
  • グラフ信号処理基礎 (Ortega et al., 2018)
  • 交通予測関連研究 (Li et al., 2017; Yu et al., 2018)

総合評価:これは交通予測分野における革新的な研究であり、アルゴリズムアンローリングの考え方を混合グラフ設定に成功裏に拡張し、性能を維持しながらパラメータ数を大幅に削減している。いくつかの指標でなお改善の余地があるが、その軽量性と解釈可能性の特性により、重要な実用的価値と学術的意義を持つ。