Active Learning (AL) for regression has been systematically under-researched due to the increased difficulty of measuring uncertainty in regression models. Since normalizing flows offer a full predictive distribution instead of a point forecast, they facilitate direct usage of known heuristics for AL like Entropy or Least-Confident sampling. However, we show that most of these heuristics do not work well for normalizing flows in pool-based AL and we need more sophisticated algorithms to distinguish between aleatoric and epistemic uncertainty. In this work we propose BALSA, an adaptation of the BALD algorithm, tailored for regression with normalizing flows. With this work we extend current research on uncertainty quantification with normalizing flows \cite{berry2023normalizing, berry2023escaping} to real world data and pool-based AL with multiple acquisition functions and query sizes. We report SOTA results for BALSA across 4 different datasets and 2 different architectures.
- Paper-ID: 2501.01248
- Titel: Bayesian Active Learning By Distribution Disagreement
- Autoren: Thorben Werner, Lars Schmidt-Thieme (Universität Hildesheim)
- Klassifikation: cs.LG (Machine Learning)
- Veröffentlichungsdatum: 2. Januar 2025 (arXiv-Preprint)
- Paper-Link: https://arxiv.org/abs/2501.01248
Das aktive Lernen für Regressionsprobleme ist untererforscht, da die Quantifizierung der Unsicherheit von Regressionsmodellen schwierig ist. Obwohl normalisierte Flüsse vollständige Vorhersageverteilungen statt Punktvorhersagen liefern und die direkte Verwendung bekannter Heuristiken wie Entropie oder Least Confidence Sampling ermöglichen, zeigt diese Arbeit, dass diese Heuristiken bei normalisierten Flüssen in Pool-basiertem aktivem Lernen schlecht funktionieren und komplexere Algorithmen zur Unterscheidung zwischen aleathorischer und epistemischer Unsicherheit erforderlich sind. Die Arbeit schlägt den BALSA-Algorithmus vor, eine verbesserte Version des BALD-Algorithmus, speziell für Regressionsprobleme mit normalisierten Flüssen. Diese Arbeit erweitert die Forschung zur Unsicherheitsquantifizierung normalisierter Flüsse auf reale Daten und Pool-basiertes aktives Lernen mit verschiedenen Akquisitionsfunktionen und Abfragegrößen. BALSA erreicht State-of-the-Art-Ergebnisse auf 4 verschiedenen Datensätzen und 2 verschiedenen Architekturen.
- Kernproblem: Das aktive Lernen für Regressionsprobleme ist stark untererforscht, hauptsächlich weil die Unsicherheitsquantifizierung von Regressionsmodellen schwieriger ist als bei Klassifikationsaufgaben
- Bedeutung: Aktives Lernen kann die Menge der annotierten Daten reduzieren, die zum Trainieren starker Modelle erforderlich ist, aber die bestehende Forschung konzentriert sich hauptsächlich auf Klassifikationsprobleme
- Einschränkungen bestehender Methoden:
- Traditionelle Regressionsmodelle (außer Gaußschen Prozessen) können nicht direkt Unsicherheitsquantifizierung bereitstellen
- Bestehende Unsicherheitsheuristiken (wie Standardabweichung, Least Confidence, Shannon-Entropie) funktionieren schlecht bei normalisierten Flüssen
- Können nicht effektiv zwischen aleathorischer Unsicherheit (Datenlärm) und epistemischer Unsicherheit (Modellunterpassung) unterscheiden
- Forschungsmotivation: Neue Modelle wie normalisierte Flüsse und Gaußsche neuronale Netze bieten vollständige Vorhersageverteilungen und eröffnen neue Möglichkeiten für aktives Lernen bei Regressionsprobleme
- Vorschlag des BALSA-Algorithmus: Eine verbesserte Version des BALD-Algorithmus, die für Modelle mit Vorhersageverteilungen konzipiert ist, mit zwei Varianten (BALSAKL und BALSAEMD)
- Aufbau einer umfassenden Benchmark: Erstellung einer umfassenden Benchmark für aktives Lernen mit Modellen mit Vorhersageverteilungen, einschließlich 3 Heuristik-Baselines und 3 BALD-Adaptationen
- Technische Innovation: Zwei neue BALD-Erweiterungsalgorithmen, die Vorhersageverteilungen direkt nutzen, anstatt sich auf Aggregationsmethoden zu verlassen
- Experimentelle Validierung: Umfangreiche Vergleiche auf 4 realen Datensätzen und 2 Modellarchitekturen, die die Wirksamkeit der Methode demonstrieren
- Eingabe: Trainingsdatensatz Dtrain:={(xi,yi)}i=1N, wobei x∈X,y∈Y
- Ziel: Durch eine aktive Lernstrategie die wertvollsten Stichproben zur Annotation auswählen und die Annotationskosten minimieren
- Einschränkung: Pool-basierte aktive Lerneinstellung mit festem Annotationsbudget B
Die Arbeit verwendet zwei Regressionsmodelle mit Vorhersageverteilungen:
- Gaußsche neuronale Netze (GNN): Verwenden einen MLP-Encoder zur Erzeugung von μ- und σ-Parametern und konstruieren eine Gaußsche Vorhersageverteilung
- Normalisierte Flüsse (NF): Verwenden invertierbare Transformationen zur Parametrisierung von freiformigen Vorhersageverteilungen und können komplexere Zielverteilungen modellieren
BALSA basiert auf der Kernidee des BALD-Algorithmus, wurde aber für Vorhersageverteilungen verbessert:
Ursprüngliche BALD-Formel:
BALD(x)=∑i=1k(H[yˉ(x)]−H[y^θi(x)])
BALSA-Verbesserungsstrategie:
BALD(x)=∑i=1kϕ(y^θi(x),yˉ(x))
wobei φ eine Maßfunktion ist, die direkt den Abstand zwischen Vorhersageverteilungen misst.
Gitter-Sampling-Methode:
- Normalisierung der Zielwerte auf 0,1
- Sampling über 200 Gitterpunkte verteilt
- Berechnung des Likelihood-Vektors und Mittelwertbildung: pˉ∣x=k1∑j=1kp^θj⊣∣x
Paarweise Vergleichsmethode:
- Vermeidung der Berechnung der Durchschnittsverteilung
- Verwendung von k-1 Parameterpaaren: ∑i=1k−1ϕ(p^θi∣x,p^θi+1∣x)
BALSAKL (KL-Divergenz):
- Gitter-Version: BALSAKLGrid(x)=∑i=1kKL(p^θi⊣∣x,pˉ∣x)
- Paarweise Version: BALSAKLPair(x)=∑i=1k−1KL(p^θi∣x,p^θi+1∣x)
BALSAEMD (Earth Mover's Distance):
BALSAEMD(x)=∑i=1k−1EMD(yθi′,yθi+1′)
wobei yθ′∼p^θ∣x
Verwendung von 4 Regressionsdatensätzen mit unterschiedlichen Größen und Komplexitäten:
| Datensatz | Merkmale | Trainingsmuster | Initiale Annotation | Budget |
|---|
| Parkinsons | 61 | 3.760 | 200 | 800 |
| Superconductors | 81 | 13.608 | 200 | 800 |
| Sarcos | 21 | 28.470 | 200 | 1.200 |
| Diamonds | 26 | 34.522 | 200 | 1.200 |
- Hauptmetrik: Negative Log-Likelihood (NLL)
- Hilfsmetriken: Mittlerer absoluter Fehler (MAE), CRPS-Score
- Statistische Methode: Wilcoxon-Vorzeichenrangtest, CD-Diagramme für Ergebnisaggregation
- Clustering-Methoden: Coreset, CoreGCN, TypiClust
- Heuristik-Methoden: Standardabweichung (Std), Least Confidence (LC), Shannon-Entropie (Entropy)
- BALD-Varianten: BALDσ, BALDLC, BALDH
- Vorgeschlagene Methoden: BALSAKL Grid/Pair, BALSAEMD
- Modellarchitektur: MLP-Encoder + Verteilungs-Decoder
- Normalisierte Flüsse: Autoregressive neuronale Spline-Flüsse mit rationalen quadratischen Spline-Transformationen
- Optimierer: NAdam
- Dropout-Rate: 0,008-0,05 (für jeden Datensatz optimiert)
- Experimentwiederholungen: Jedes Experiment 30-mal wiederholt
Critical Difference-Diagramm basierend auf NLL-Metrik zeigt:
- BALSAKL Pairs: Beste durchschnittliche Rangfolge, optimale Leistung
- BALSAKL Grid: Dicht dahinter, zweiter Platz
- BALDH: Dritter Platz
- Coreset: Beste Leistung unter geometrischen Methoden
Wichtigste Erkenntnisse:
- Traditionelle Heuristiken (Entropie, Standardabweichung, Least Confidence) funktionieren schlecht bei normalisierten Flüssen
- BALSA-Methoden zeigen deutliche Vorteile bei normalisierten Fluss-Architekturen
- Coreset und CoreGCN funktionieren besser bei GNN-Architekturen
Test der Auswirkungen unterschiedlicher Dropout-Raten in Trainings- und Bewertungsphasen:
- Inkonsistente Ergebnisse: BALSAEMD dual zeigt Leistungsabfall, BALSAKL Grid dual zeigt leichte Verbesserung
- Hypothese: Dropout-Rate-Wechsel könnte die Modellvorhersagequalität beeinflussen
Test der normalisierten Version von BALSAKL Grid:
- Normalisierte Version zeigt etwas niedrigere Leistung als nicht normalisierte Version
- Wahl der einfacheren nicht normalisierten Formel
Leistung bei τ = {50, 200}:
- Unsicherheits-Sampling-Methoden behalten Leistung bei großen Abfragegrößen
- Clustering-Algorithmen (Coreset, TypiClust) zeigen schnelleren Leistungsabfall
- Widerspricht gängigen Erkenntnissen bei Klassifikationsaufgaben
Aktive Lernverlauf des Diamonds-Datensatzes zeigt:
- BALSA-Methoden konvergieren schneller
- Traditionelle Heuristiken nähern sich zufälligem Sampling an
- Konsistente Leistung bei NLL- und MAE-Metriken
- Geometrische Methoden: Coreset, CoreGCN, TypiClust und andere basierend auf Datengeometrie-Eigenschaften
- Unsicherheitsmethoden: Meisten an spezifische Modellarchitekturen gebunden, geringe Allgemeingültigkeit
- BALD-Algorithmus: Einer der wenigen modellunabhängigen Ansätze
Berry und Meger 1,2:
- Schlagen normalisierte Fluss-Ensembles und MC-Dropout-Approximation vor
- Nur auf synthetischen Daten validiert
- Diese Arbeit erweitert auf reale Daten und mehrere Akquisitionsfunktionen
- Verwendung von Shannon-Entropie statt einfacher -∑logŷθ(x)
- Erweiterung auf reale Datensätze
- Vergleich mit mehreren aktiven Lernalgorithmen
- Methodische Wirksamkeit: BALSA zeigt hervorragende Leistung bei normalisierten Flüssen, besonders die BALSAKL Pairs-Version
- Heuristik-Versagen: Traditionelle Unsicherheitsheuristiken funktionieren schlecht bei normalisierten Flüssen
- Architektur-Abhängigkeit: Verschiedene Algorithmen zeigen signifikante Leistungsunterschiede bei verschiedenen Modellarchitekturen
- Abfragegrößen-Einfluss: Unsicherheitsmethoden sind bei großen Abfragegrößen stabiler
- Unzureichende theoretische Analyse: Fehlende Konvergenzanalyse des BALSA-Algorithmus
- Rechenkomplexität: MC-Dropout und Verteilungsdistanzberechnung erhöhen Rechenkosten
- Hyperparameter-Sensitivität: Dropout-Rate-Wahl hat großen Einfluss auf Leistung
- Datensatz-Einschränkung: Validierung nur auf 4 Datensätzen, Verallgemeinerbarkeit unklar
- Erweiterung auf andere Parametersampling-Methoden (Langevin Dynamics, SVGD)
- Theoretische Analyse der Konvergenzeigenschaften von BALSA
- Untersuchung weiterer Verteilungsdistanzmaße
- Validierung auf größeren Datensätzen
- Problemrelevanz: Löst das vernachlässigte aber wichtige Problem des aktiven Lernens bei Regression
- Methodische Innovativität: Erste direkte Verwendung von Verteilungsdistanzen für aktives Lernen, vermeidet Informationsverlust durch Aggregationsmethoden
- Experimentelle Umfassendheit: Umfassende Bewertung über mehrere Datensätze, Architekturen und Metriken
- Praktischer Wert: Bereitstellung von reproduzierbarem Code und detaillierten Experimenteinstellungen
- Schwache theoretische Grundlagen: Fehlende theoretische Analyse zur Erklärung der BALSA-Wirksamkeit
- Rechnerische Effizienz: MC-Dropout und EMD-Berechnung könnten praktische Anwendung beeinflussen
- Hyperparameter-Optimierung: Dropout-Rate-Wahl fehlt prinzipiengestützte Anleitung
- Bewertungsbeschränkungen: Hauptsächlich auf NLL basiert, Konsistenz mit anderen Regressions-Metriken unklar
- Akademischer Beitrag: Eröffnet neue Forschungsrichtung für Regressions-Aktives Lernen
- Praktischer Wert: Besonders geeignet für Regressionsanwendungen, die Unsicherheitsquantifizierung erfordern
- Reproduzierbarkeit: Vollständiger Code und Experimentkonfigurationen ermöglichen Folgeforschen
- Wissenschaftliche Berechnung: Physik-/Chemie-Modellierung mit erforderlicher Unsicherheitsquantifizierung
- Risikobewertung: Finanz-, Medizin- und andere Bereiche mit Unsicherheitssensitivität
- Ingenieuroptimierung: Designoptimierungsprobleme, die Explorations-Exploitations-Abwägung erfordern
- Zeitreihen: Vorhersageaufgaben mit komplexen Verteilungen
Diese Arbeit bezieht sich hauptsächlich auf folgende Schlüsselarbeiten:
- Berry & Meger (2023): Unsicherheitsmodellierung mit normalisierten Fluss-Ensembles
- Gal et al. (2017): Ursprüngliche Einführung des BALD-Algorithmus
- Sener & Savarese (2017): Coreset-Methode für aktives Lernen
- Durkan et al. (2019): Technische Grundlagen neuronaler Spline-Flüsse
Gesamtbewertung: Dies ist eine hochwertige Forschungsarbeit zu dem wichtigen, aber vernachlässigten Problem des Regressions-Aktiven Lernens. Der Vorschlag des BALSA-Algorithmus füllt die Lücke bei der Anwendung normalisierter Flüsse im aktiven Lernen, das Experimentdesign ist umfassend und die Ergebnisse überzeugend. Obwohl es Raum für Verbesserungen in theoretischer Analyse und Rechnerischer Effizienz gibt, leistet diese Arbeit einen wichtigen Beitrag zur Entwicklung dieses Forschungsbereichs.