2025-11-12T19:34:10.329996

Bayesian Active Learning By Distribution Disagreement

Werner, Schmidt-Thieme
Active Learning (AL) for regression has been systematically under-researched due to the increased difficulty of measuring uncertainty in regression models. Since normalizing flows offer a full predictive distribution instead of a point forecast, they facilitate direct usage of known heuristics for AL like Entropy or Least-Confident sampling. However, we show that most of these heuristics do not work well for normalizing flows in pool-based AL and we need more sophisticated algorithms to distinguish between aleatoric and epistemic uncertainty. In this work we propose BALSA, an adaptation of the BALD algorithm, tailored for regression with normalizing flows. With this work we extend current research on uncertainty quantification with normalizing flows \cite{berry2023normalizing, berry2023escaping} to real world data and pool-based AL with multiple acquisition functions and query sizes. We report SOTA results for BALSA across 4 different datasets and 2 different architectures.
academic

Apprentissage Actif Bayésien Par Désaccord de Distribution

Informations Fondamentales

  • ID de l'article : 2501.01248
  • Titre : Bayesian Active Learning By Distribution Disagreement
  • Auteurs : Thorben Werner, Lars Schmidt-Thieme (Université de Hildesheim)
  • Classification : cs.LG (Apprentissage Automatique)
  • Date de publication : 2 janvier 2025 (prépublication arXiv)
  • Lien de l'article : https://arxiv.org/abs/2501.01248

Résumé

L'apprentissage actif pour les tâches de régression est sous-étudié en raison de la difficulté à quantifier l'incertitude des modèles de régression. Bien que les flots normalisés fournissent des distributions prédictives complètes plutôt que des prédictions ponctuelles, permettant l'utilisation directe d'heuristiques connues telles que l'entropie ou l'échantillonnage le moins confiant, cet article démontre que ces heuristiques fonctionnent mal sur les flots normalisés dans l'apprentissage actif basé sur un ensemble de candidats. Il est nécessaire d'utiliser des algorithmes plus sophistiqués pour distinguer l'incertitude aléatoire de l'incertitude épistémique. Cet article propose l'algorithme BALSA, une version améliorée de l'algorithme BALD spécialement conçue pour les tâches de régression utilisant des flots normalisés. Ce travail étend la recherche sur la quantification de l'incertitude des flots normalisés à des données du monde réel et à l'apprentissage actif basé sur un ensemble de candidats avec diverses fonctions d'acquisition et tailles de requête. BALSA atteint des résultats de pointe sur 4 ensembles de données différents et 2 architectures différentes.

Contexte et Motivation de la Recherche

Définition du Problème

  1. Problème central : L'apprentissage actif pour les tâches de régression est gravement sous-étudié, principalement en raison de la difficulté à quantifier l'incertitude des modèles de régression par rapport aux tâches de classification
  2. Importance : L'apprentissage actif peut réduire la quantité de données annotées nécessaires pour entraîner des modèles robustes, mais les recherches existantes se concentrent principalement sur les problèmes de classification
  3. Limitations des méthodes existantes :
    • Les modèles de régression traditionnels (à l'exception des processus gaussiens) ne fournissent pas directement de quantification d'incertitude
    • Les heuristiques d'incertitude existantes (écart-type, moins confiant, entropie de Shannon) fonctionnent mal sur les flots normalisés
    • Incapacité à distinguer efficacement l'incertitude aléatoire (bruit des données) de l'incertitude épistémique (sous-ajustement du modèle)
  4. Motivation de la recherche : Les flots normalisés et les réseaux de neurones gaussiens émergents fournissent des distributions prédictives complètes, offrant de nouvelles opportunités pour l'apprentissage actif dans les tâches de régression

Contributions Principales

  1. Proposition de l'algorithme BALSA : Une version améliorée de l'algorithme BALD conçue pour les modèles avec distributions prédictives, incluant deux variantes (BALSA_KL et BALSA_EMD)
  2. Construction d'un benchmark complet : Création d'un benchmark exhaustif pour l'apprentissage actif de modèles avec distributions prédictives, contenant 3 lignes de base heuristiques et 3 versions adaptées de BALD
  3. Innovation technique : Deux nouveaux algorithmes d'extension BALD qui exploitent directement les distributions prédictives sans dépendre de méthodes d'agrégation
  4. Validation expérimentale : Comparaisons approfondies sur 4 ensembles de données du monde réel et 2 architectures de modèles, démontrant l'efficacité de la méthode

Détails de la Méthode

Définition de la Tâche

  • Entrée : Ensemble de données d'entraînement Dtrain:={(xi,yi)}i=1ND_{train} := \{(x_i, y_i)\}_{i=1}^N, où xX,yYx \in \mathcal{X}, y \in \mathcal{Y}
  • Objectif : Sélectionner les échantillons les plus précieux pour annotation via une stratégie d'apprentissage actif, minimisant le coût d'annotation
  • Contrainte : Paramètre d'apprentissage actif basé sur un ensemble de candidats avec un budget d'annotation fixe B

Architecture du Modèle

1. Modèles de Base

L'article utilise deux modèles de régression avec distributions prédictives :

  • Réseau de Neurones Gaussien (GNN) : Utilise un encodeur MLP pour produire les paramètres μ et σ, construisant une distribution prédictive gaussienne
  • Flot Normalisé (NF) : Utilise des transformations inversibles pour paramétrer une distribution prédictive de forme libre, capable de modéliser des distributions cibles plus complexes

2. Idée Centrale de l'Algorithme BALSA

BALSA est basé sur l'idée centrale de l'algorithme BALD, mais amélioré pour les distributions prédictives :

Formule BALD originale : BALD(x)=i=1k(H[yˉ(x)]H[y^θi(x)])BALD(x) = \sum_{i=1}^k (H[\bar{y}(x)] - H[\hat{y}_{\theta_i}(x)])

Stratégie d'amélioration de BALSA : BALD(x)=i=1kϕ(y^θi(x),yˉ(x))BALD(x) = \sum_{i=1}^k \phi(\hat{y}_{\theta_i}(x), \bar{y}(x))

où φ est une fonction de mesure qui quantifie directement la distance entre les distributions prédictives.

Points d'Innovation Technique

1. Calcul de la Distribution Moyenne

Méthode d'échantillonnage sur grille :

  • Normalisation des valeurs cibles à 0,1
  • Échantillonnage distribué sur 200 points de grille
  • Calcul du vecteur de vraisemblance et moyenne : pˉx=1kj=1kp^θjx\bar{p}|x = \frac{1}{k}\sum_{j=1}^k \hat{p}^⊣_{\theta_j}|x

Méthode de comparaison par paires :

  • Évite le calcul de la distribution moyenne
  • Utilise k-1 paires d'échantillons de paramètres : i=1k1ϕ(p^θix,p^θi+1x)\sum_{i=1}^{k-1} \phi(\hat{p}_{\theta_i}|x, \hat{p}_{\theta_{i+1}}|x)

2. Fonctions de Mesure de Distance

BALSA_KL (Divergence de Kullback-Leibler) :

  • Version grille : BALSAKLGrid(x)=i=1kKL(p^θix,pˉx)BALSA_{KL}^{Grid}(x) = \sum_{i=1}^k KL(\hat{p}^⊣_{\theta_i}|x, \bar{p}|x)
  • Version paires : BALSAKLPair(x)=i=1k1KL(p^θix,p^θi+1x)BALSA_{KL}^{Pair}(x) = \sum_{i=1}^{k-1} KL(\hat{p}_{\theta_i}|x, \hat{p}_{\theta_{i+1}}|x)

BALSA_EMD (Distance du Transport Optimal) : BALSAEMD(x)=i=1k1EMD(yθi,yθi+1)BALSA_{EMD}(x) = \sum_{i=1}^{k-1} EMD(y'_{\theta_i}, y'_{\theta_{i+1}})

yθp^θxy'_\theta \sim \hat{p}_\theta|x

Configuration Expérimentale

Ensembles de Données

Utilisation de 4 ensembles de données de régression couvrant différentes échelles et complexités :

Ensemble de DonnéesNombre de CaractéristiquesÉchantillons d'EntraînementEnsemble Initial AnnotéBudget
Parkinsons613,760200800
Superconducteurs8113,608200800
Sarcos2128,4702001,200
Diamants2634,5222001,200

Métriques d'Évaluation

  • Métrique principale : Vraisemblance logarithmique négative (NLL)
  • Métriques auxiliaires : Erreur absolue moyenne (MAE), Score CRPS
  • Méthode statistique : Test de rang signé de Wilcoxon, utilisation de diagrammes CD pour l'agrégation des résultats

Méthodes de Comparaison

  • Méthodes de clustering : Coreset, CoreGCN, TypiClust
  • Méthodes heuristiques : Écart-type (Std), Moins confiant (LC), Entropie de Shannon (Entropy)
  • Variantes BALD : BALD_σ, BALD_LC, BALD_H
  • Méthodes proposées : BALSA_KL Grille/Paires, BALSA_EMD

Détails d'Implémentation

  • Architecture du modèle : Encodeur MLP + décodeur de distribution
  • Flot normalisé : Flot neuronal de splines auto-régressif avec transformations de splines rationnelles quadratiques
  • Optimiseur : NAdam
  • Taux de Dropout : 0,008-0,05 (optimisé pour chaque ensemble de données)
  • Répétitions expérimentales : 30 répétitions pour chaque expérience

Résultats Expérimentaux

Résultats Principaux

Le diagramme de Différence Critique basé sur la métrique NLL montre :

  1. BALSA_KL Paires : Meilleur classement moyen, performance optimale
  2. BALSA_KL Grille : Suit de près, classement deuxième
  3. BALD_H : Classement troisième
  4. Coreset : Meilleure performance parmi les méthodes géométriques

Découvertes clés :

  • Les méthodes heuristiques traditionnelles (entropie, écart-type, moins confiant) fonctionnent très mal sur les flots normalisés
  • Les méthodes BALSA montrent des avantages évidents sur l'architecture des flots normalisés
  • Coreset et CoreGCN fonctionnent mieux sur l'architecture GNN

Études d'Ablation

1. Expérience en Mode Dual

Test de l'effet d'utiliser différents taux de dropout aux phases d'entraînement et d'évaluation :

  • Résultats incohérents : BALSA_EMD dual montre une baisse de performance, BALSA_KL Grille dual montre une légère amélioration
  • Hypothèse : Le changement de taux de dropout peut affecter la qualité des prédictions du modèle

2. Expérience de Renormalisation

Test de la version renormalisée de BALSA_KL Grille :

  • La version renormalisée montre une performance légèrement inférieure à la version non renormalisée
  • Choix de la formule non renormalisée plus simple

3. Expérience de Taille de Requête

Performance sur τ = {50, 200} :

  • Les méthodes d'échantillonnage d'incertitude maintiennent les performances avec des tailles de requête plus grandes
  • Les algorithmes de clustering (Coreset, TypiClust) montrent une baisse de performance plus rapide
  • Contraste avec les idées reçues communes dans les tâches de classification

Étude de Cas

La trajectoire d'apprentissage actif sur l'ensemble de données Diamants montre :

  • Les méthodes BALSA convergent plus rapidement
  • Les méthodes heuristiques traditionnelles approchent les performances d'échantillonnage aléatoire
  • Performances cohérentes sur les métriques NLL et MAE

Travaux Connexes

Apprentissage Actif pour la Régression

  • Méthodes géométriques : Coreset, CoreGCN, TypiClust basées sur les propriétés géométriques des données
  • Méthodes d'incertitude : La plupart sont liées à des architectures de modèles spécifiques, avec une généralité limitée
  • Algorithme BALD : L'une des rares méthodes indépendantes du modèle

Travaux les Plus Pertinents

Travaux de Berry et Meger 1,2 :

  • Proposition d'ensembles de flots normalisés et d'approximations MC dropout
  • Validation uniquement sur données synthétiques
  • Cet article étend à des données réelles et à plusieurs fonctions d'acquisition

Différences et Améliorations

  1. Utilisation de l'entropie de Shannon plutôt que simplement -∑logŷ_θ(x)
  2. Extension à des ensembles de données du monde réel
  3. Comparaison avec plusieurs algorithmes d'apprentissage actif

Conclusion et Discussion

Conclusions Principales

  1. Efficacité de la méthode : BALSA montre d'excellentes performances sur les flots normalisés, particulièrement la version BALSA_KL Paires
  2. Échec des heuristiques : Les heuristiques d'incertitude traditionnelles fonctionnent mal sur les flots normalisés
  3. Dépendance à l'architecture : Les différents algorithmes montrent des variations de performance significatives selon les architectures de modèles
  4. Impact de la taille de requête : Les méthodes d'incertitude sont plus stables avec des tailles de requête plus grandes

Limitations

  1. Analyse théorique insuffisante : Absence d'analyse de convergence théorique pour l'algorithme BALSA
  2. Surcharge computationnelle : MC dropout et calcul de distance de distribution augmentent les coûts computationnels
  3. Sensibilité aux hyperparamètres : Le choix du taux de dropout a un impact significatif sur les performances
  4. Limitation des ensembles de données : Validation sur seulement 4 ensembles de données, la généralisation reste à vérifier

Directions Futures

  1. Extension à d'autres méthodes d'échantillonnage de paramètres (Dynamique de Langevin, SVGD)
  2. Analyse théorique des propriétés de convergence de BALSA
  3. Exploration de mesures de distance de distribution supplémentaires
  4. Validation sur des ensembles de données de plus grande envergure

Évaluation Approfondie

Points Forts

  1. Importance du problème : Résout le problème négligé mais important de l'apprentissage actif pour la régression
  2. Créativité de la méthode : Première utilisation directe de distances de distribution pour l'apprentissage actif, évitant la perte d'information des méthodes d'agrégation
  3. Complétude expérimentale : Évaluation exhaustive sur plusieurs ensembles de données, architectures et métriques
  4. Valeur pratique : Fournit du code reproductible et des configurations expérimentales détaillées

Insuffisances

  1. Fondations théoriques faibles : Absence d'analyse théorique expliquant pourquoi BALSA est plus efficace
  2. Efficacité computationnelle : MC dropout et calcul EMD peuvent affecter les applications pratiques
  3. Ajustement des hyperparamètres : Le choix du taux de dropout manque de directives principielles
  4. Limitations d'évaluation : Basée principalement sur NLL, la cohérence avec d'autres métriques de régression reste à vérifier

Impact

  1. Contribution académique : Fournit une nouvelle direction de recherche pour l'apprentissage actif en régression
  2. Valeur pratique : Particulièrement adapté aux applications de régression nécessitant une quantification d'incertitude
  3. Reproductibilité : Code complet et configurations expérimentales fournis, facilitant les recherches ultérieures

Scénarios d'Application

  1. Calcul scientifique : Modélisation physique/chimique nécessitant une quantification d'incertitude
  2. Évaluation des risques : Domaines sensibles à l'incertitude comme la finance et la médecine
  3. Optimisation d'ingénierie : Problèmes d'optimisation de conception nécessitant d'équilibrer exploration et exploitation
  4. Séries temporelles : Tâches de prédiction avec distributions complexes

Références

Cet article s'appuie principalement sur les travaux clés suivants :

  1. Berry & Meger (2023) : Modélisation d'incertitude par ensembles de flots normalisés
  2. Gal et al. (2017) : Proposition originale de l'algorithme BALD
  3. Sener & Savarese (2017) : Méthode d'apprentissage actif Coreset
  4. Durkan et al. (2019) : Fondations techniques des flots de splines neuronaux

Évaluation Globale : Ceci est une recherche de haute qualité abordant le problème important mais négligé de l'apprentissage actif pour la régression. La proposition de l'algorithme BALSA comble le vide dans l'application des flots normalisés à l'apprentissage actif, avec une conception expérimentale suffisante et des résultats convaincants. Bien qu'il y ait encore de la place pour amélioration en termes d'analyse théorique et d'efficacité computationnelle, cette recherche apporte une contribution importante au développement du domaine.