2025-11-12T09:04:09.780506

SHAP-Based Supervised Clustering for Sample Classification and the Generalized Waterfall Plot

Lin, Fukuyama
In this growing age of data and technology, large black-box models are becoming the norm due to their ability to handle vast amounts of data and learn incredibly complex input-output relationships. The deficiency of these methods, however, is their inability to explain the prediction process, making them untrustworthy and their use precarious in high-stakes situations. SHapley Additive exPlanations (SHAP) analysis is an explainable AI method growing in popularity for its ability to explain model predictions in terms of the original features. For each sample and feature in the data set, we associate a SHAP value that quantifies the contribution of that feature to the prediction of that sample. Clustering these SHAP values can provide insight into the data by grouping samples that not only received the same prediction, but received the same prediction for similar reasons. In doing so, we map the various pathways through which distinct samples arrive at the same prediction. To showcase this methodology, we present a simulated experiment in addition to a case study in Alzheimer's disease using data from the Alzheimer's Disease Neuroimaging Initiative (ADNI) database. We also present a novel generalization of the waterfall plot for multi-classification.
academic

SHAP-Based Supervised Clustering for Sample Classification and the Generalized Waterfall Plot

基本信息

  • 论文ID: 2510.08737
  • 标题: SHAP-Based Supervised Clustering for Sample Classification and the Generalized Waterfall Plot
  • 作者: Justin Lin (Indiana University Mathematics Department), Julia Fukuyama (Indiana University Statistics Department)
  • 分类: cs.LG, stat.ME, stat.ML
  • 发表时间: 2025年10月9日 (arXiv预印本)
  • 论文链接: https://arxiv.org/abs/2510.08737v1

摘要

在数据和技术快速发展的时代,大型黑盒模型因其处理海量数据和学习复杂输入输出关系的能力而成为主流。然而,这些方法的缺陷在于无法解释预测过程,使其在高风险场景中的应用变得不可信且危险。SHAP(SHapley Additive exPlanations)分析作为一种可解释AI方法,因其能够用原始特征解释模型预测而日益流行。本文提出对SHAP值进行聚类分析,不仅能将获得相同预测的样本分组,更重要的是将因相似原因获得相同预测的样本分组。通过仿真实验和阿尔茨海默病案例研究(使用ADNI数据库),展示了该方法的有效性,并提出了多分类问题的瀑布图泛化方法。

研究背景与动机

问题定义

随着机器学习模型复杂度的不断提升,黑盒模型在预测准确性方面表现优异,但其缺乏可解释性的特点在医疗等高风险领域造成了应用障碍。传统的聚类分析仅基于原始数据特征,无法揭示样本到达相同预测结果的不同路径。

研究重要性

  1. 医学应用需求:在阿尔茨海默病等异质性疾病中,不同患者可能通过完全不同的病理机制到达相同的诊断结果
  2. 精准医疗:理解疾病的异质性有助于制定个性化治疗方案
  3. 模型可解释性:在高风险决策场景中,理解模型预测的原因至关重要

现有方法局限性

  1. 传统聚类方法:仅基于原始数据特征,无法捕获模型学习到的复杂输入输出关系
  2. SHAP值聚类研究稀少:现有文献中对SHAP值聚类的研究极为有限
  3. 可视化工具不足:多分类问题缺乏有效的SHAP值可视化方法

核心贡献

  1. 提出SHAP-based监督聚类方法:基于SHAP值而非原始数据进行聚类,揭示样本到达相同预测的不同路径
  2. 开发高维瀑布图:将传统瀑布图泛化到多分类问题,支持k维SHAP向量的可视化
  3. 提供完整的分析流程:包含预测建模、SHAP分析、可视化、聚类分析和聚类解释的五步工作流
  4. 验证方法有效性:通过仿真实验和阿尔茨海默病真实案例验证方法的实用性

方法详解

任务定义

给定训练数据集X' ⊂ X ⊂ R^p和训练好的模型f: X → R,对每个样本x ∈ X计算SHAP值φ(f;x)₁, ..., φ(f;x)ₚ,使得:

i=1pϕ(f;x)i=f(x)E[f(X)]\sum_{i=1}^{p} \phi(f;x)_i = f(x) - E[f(X')]

目标是对SHAP值矩阵进行聚类,发现具有相似模型解释的样本群组。

监督聚类工作流

1. 预测建模

  • 使用XGBoost构建预测模型
  • 通过重复交叉验证确保模型泛化性能

2. SHAP分析

  • 二分类:每个特征对应一个SHAP值
  • 多分类:每个特征对应k维SHAP向量(k为类别数)
  • 使用TreeSHAP算法计算树模型的SHAP值
  • 通过交叉验证避免过拟合

3. 可视化

  • 使用UMAP进行降维可视化
  • 保持局部结构,适合聚类检测

4. 聚类分析

  • 采用HDBSCAN进行层次密度聚类
  • 能够处理噪声和可变密度聚类

5. 聚类解释

  • 使用热图分析原始数据
  • 采用高维瀑布图解释聚类

高维瀑布图创新

传统瀑布图局限

传统瀑布图仅适用于一维SHAP值,无法处理多分类的k维SHAP向量。

解决方案

  1. 投影到类别子空间:选择两个类别,忽略其他类别的SHAP值,适合类别间的两两比较
  2. PCA投影:投影到保留最多信息的二维子空间,保留所有k个类别的信息但轴解释较复杂

数学表示

将SHAP向量序列视为k维空间中的路径,每个路径段对应一个特征的贡献,从平均预测点出发到达样本的具体预测点。

实验设置

数据集

仿真数据

  • 生成模型:多项逻辑回归
  • 样本规模:1,500个样本,10维特征
  • 设计思想:创建到达相同目标类别的不同路径
  • 函数定义
    • f₁(x) = 4x₁x₂ + 4x₁ + 4x₂ + Σβ₁,ᵢxᵢ
    • f₂(x) = 4x₁x₂ - 4x₁ - 4x₂ + Σβ₂,ᵢxᵢ
    • 其中βⱼ,ᵢ ~ N(0,1)

ADNI数据

  • 数据来源:阿尔茨海默病神经影像倡议数据库
  • 样本规模:2,422名患者,39个特征
  • 目标类别:认知正常(CN)、轻度认知障碍(MCI)、阿尔茨海默病/痴呆(AD)
  • 预处理:移除访问数据、设备信息等,线性缩放到0,1区间

评价指标

  • 分类性能:精确率、召回率、F1分数
  • 聚类质量:通过可视化和领域知识验证

实现细节

  • 预测模型:XGBoost
  • 降维方法:UMAP
  • 聚类算法:HDBSCAN
  • 交叉验证:重复交叉验证计算SHAP值

实验结果

仿真实验结果

模型性能

XGBoost模型在测试集上表现优异:

  • 整体准确率:90%
  • 各类别F1分数:0.88-0.92
  • 证明了模型解释的可靠性

聚类发现

  1. 原始数据无聚类结构:UMAP可视化显示原始数据无明显聚类模式
  2. SHAP值揭示4个聚类
    • 聚类0:x₁ < 0, x₂ < 0 → 类别0
    • 聚类3:x₁ > 0, x₂ > 0 → 类别1
    • 聚类1和2:x₁, x₂异号 → 类别2(两条不同路径)

高维瀑布图验证

  • 成功识别了到达类别2的两条不同路径
  • 聚类1:x₁ > 0, x₂ < 0
  • 聚类2:x₁ < 0, x₂ > 0

更精细聚类

进一步分析发现聚类3可细分为两个子聚类,主要区别在于特征8的贡献,验证了方法的稳定性。

ADNI案例研究结果

模型性能

  • 整体准确率:93%
  • 各类别表现:CN(F1=0.96)、MCI(F1=0.92)、AD(F1=0.86)

关键特征识别

  1. CDRSB(临床痴呆评定量表总分):最重要的预测因子
  2. LDELTOTAL:在CN和MCI区分中作用显著
  3. mPACCdigitMMSE:在MCI和AD区分中重要

聚类发现

  1. CN患者:聚类0和4,尽管APOE4基因型不同但SHAP模式相似
  2. MCI患者:聚类3和6
    • 聚类3:CDRSB对AD贡献为-1.50(保护性)
    • 聚类6:CDRSB对AD贡献为-0.50(风险性)
  3. AD患者:聚类1、2、5,展现不同的疾病路径

临床意义

  • 揭示了相同诊断类别内的异质性
  • CDRSB评估可用于MCI患者的风险分层
  • 不同AD聚类可能需要不同的治疗策略

相关工作

SHAP分析发展

  • 理论基础:基于Shapley值(Lloyd Shapley, 1953)
  • 现代发展:Lundberg和Lee (2017)将其应用于机器学习
  • TreeSHAP算法:专门用于树模型的SHAP值计算

聚类方法演进

  • 传统方法:K-means、层次聚类等基于原始特征
  • 密度聚类:DBSCAN及其改进版本HDBSCAN
  • 监督聚类:结合监督学习信息的聚类方法

SHAP值聚类研究

现有研究极为有限,本文是该领域的重要贡献之一,为后续研究奠定了基础。

结论与讨论

主要结论

  1. SHAP-based聚类有效性:能够发现原始数据中无法观察到的有意义分组
  2. 高维瀑布图实用性:成功解决了多分类SHAP值可视化问题
  3. 医学应用价值:在阿尔茨海默病研究中展现了实际应用潜力
  4. 疾病异质性洞察:揭示了相同诊断类别内的不同病理路径

局限性

  1. 计算复杂度:需要计算大量SHAP值,计算成本较高
  2. 模型依赖性:聚类结果依赖于底层预测模型的质量
  3. 参数敏感性:HDBSCAN等算法的参数选择可能影响结果
  4. 类别数限制:高维瀑布图的可视化仍受类别数量限制

未来方向

  1. 可视化方法扩展:开发其他SHAP图表的高维版本(条形图、热图、蜂群图等)
  2. 算法优化:提高大规模数据的计算效率
  3. 理论分析:建立SHAP-based聚类的理论基础
  4. 应用扩展:在更多领域验证方法的普适性

深度评价

优点

  1. 创新性强:首次系统性地提出SHAP-based监督聚类方法
  2. 实用价值高:在医学等高风险领域具有重要应用价值
  3. 方法完整:提供了从建模到解释的完整工作流程
  4. 验证充分:通过仿真和真实案例双重验证
  5. 可视化创新:高维瀑布图解决了多分类可解释性难题

不足

  1. 理论基础薄弱:缺乏对SHAP-based聚类的理论分析
  2. 计算效率:大规模应用时的计算复杂度问题未充分讨论
  3. 参数选择:聚类算法参数选择的指导原则不够明确
  4. 统计显著性:缺乏聚类结果的统计显著性检验
  5. 对比实验不足:与其他解释性聚类方法的比较有限

影响力

  1. 学术贡献:为可解释AI和监督聚类领域提供了新思路
  2. 实用价值:在精准医疗等领域具有直接应用潜力
  3. 方法推广:工作流程可推广到其他领域和问题
  4. 后续研究:为SHAP值的深度应用开辟了新方向

适用场景

  1. 医疗诊断:疾病异质性分析和个性化治疗
  2. 金融风控:客户风险分层和差异化策略
  3. 推荐系统:用户行为模式分析
  4. 质量控制:产品缺陷的不同成因分析

参考文献

论文引用了23篇重要文献,涵盖SHAP理论、聚类算法、可视化方法和阿尔茨海默病研究等多个领域,为跨学科研究提供了良好的理论支撑。


总体评价:这是一篇高质量的跨学科研究论文,在可解释AI和监督聚类的交叉领域做出了重要贡献。方法创新性强,实验验证充分,在医疗等高风险应用领域具有重要价值。尽管在理论分析和计算效率方面还有改进空间,但为后续研究奠定了良好基础。