Diffusion models have emerged as a promising approach for generating high-quality, high-dimensional images. Nevertheless, these models are hindered by their high computational cost and slow inference, partly due to the quadratic computational complexity of the self-attention mechanisms with respect to input size. Various approaches have been proposed to address this drawback. One such approach focuses on reducing the number of tokens fed into the self-attention, known as token merging (ToMe). In our method, which is called cached adaptive token merging(CA-ToMe), we calculate the similarity between tokens and then merge the r proportion of the most similar tokens. However, due to the repetitive patterns observed in adjacent steps and the variation in the frequency of similarities, we aim to enhance this approach by implementing an adaptive threshold for merging tokens and adding a caching mechanism that stores similar pairs across several adjacent steps. Empirical results demonstrate that our method operates as a training-free acceleration method, achieving a speedup factor of 1.24 in the denoising process while maintaining the same FID scores compared to existing approaches.
- ID Articolo: 2501.00946
- Titolo: Cached Adaptive Token Merging: Dynamic Token Reduction and Redundant Computation Elimination in Diffusion Model
- Autori: Omid Saghatchian, Atiyeh Gh. Moghadam, Ahmad Nickabadi (Amirkabir University of Technology)
- Classificazione: cs.CV (Visione Artificiale)
- Data di Pubblicazione: 1 gennaio 2025 (preprint arXiv)
- Link Articolo: https://arxiv.org/abs/2501.00946
- Link Codice: https://github.com/omidiu/ca_tome
I modelli di diffusione si sono affermati come metodi promettenti per la generazione di immagini ad alta qualità e ad alta dimensionalità. Tuttavia, questi modelli sono ostacolati da elevati costi computazionali e velocità di inferenza lenta, in parte dovuti alla complessità computazionale quadratica del meccanismo di auto-attenzione rispetto alla dimensione dell'input. Questo articolo propone il metodo Cached Adaptive Token Merging (CA-ToMe), che affronta questo problema calcolando la similarità tra i token e unendo i token con similarità superiore a un parametro di soglia t. A causa dei modelli ripetitivi osservati in fasi adiacenti e delle variazioni nella frequenza di similarità, il metodo migliora l'approccio di fusione dei token implementando una soglia adattiva e aggiungendo un meccanismo di memorizzazione. I risultati sperimentali dimostrano che il metodo, come approccio di accelerazione senza addestramento, raggiunge un'accelerazione di 1,24 volte nel processo di denoising mantenendo lo stesso punteggio FID dei metodi esistenti.
I modelli di diffusione eccellono nei compiti di generazione di immagini, ma affrontano gravi problemi di efficienza computazionale:
- Elevati costi computazionali: La complessità quadratica del meccanismo di auto-attenzione determina una velocità di inferenza lenta
- Processo di denoising seriale: Non può essere parallelizzato, ogni fase di denoising richiede calcoli ripetuti
- Calcoli ridondanti: Esiste una notevole quantità di calcoli ripetuti tra fasi temporali adiacenti
- L'elevata latenza dei modelli di diffusione limita il loro utilizzo in applicazioni che richiedono inferenza rapida
- L'elevato costo computazionale rende difficile il deployment del modello, specialmente in ambienti con risorse limitate
- I metodi di accelerazione esistenti richiedono o un nuovo addestramento o comportano perdite significative di qualità
- I metodi che riducono il numero di fasi di campionamento generalmente richiedono un nuovo addestramento o l'utilizzo di risolutori numerici complessi
- I metodi di potatura dei token causano perdita di informazioni e degradazione delle prestazioni
- La fusione tradizionale dei token (ToMe) utilizza un tasso di fusione fisso, incapace di adattarsi alle variazioni nella distribuzione di similarità tra diversi passi temporali e strati
Basata su due fenomeni chiave osservati:
- Esistono variazioni significative nella distribuzione di similarità dei token tra diversi passi temporali e strati
- Le coppie di token tra fasi di inferenza adiacenti mostrano un'elevata similarità
- Propone un meccanismo di soglia adattiva: Regola dinamicamente la strategia di fusione in base alla distribuzione di similarità dei token, sostituendo il tasso di fusione fisso
- Progetta un meccanismo di memorizzazione: Sfrutta la similarità tra fasi adiacenti, memorizzando le coppie di token per ridurre i calcoli ripetuti
- Implementa accelerazione senza addestramento: Il metodo può essere applicato direttamente ai modelli pre-addestrati senza richiedere un nuovo addestramento
- Raggiunge un migliore compromesso qualità-velocità: Rispetto al metodo ToMe di base, realizza una velocità di inferenza più rapida mantenendo la qualità dell'immagine
Input: Sequenza di token nel processo di denoising del modello di diffusione
Output: Processo di inferenza accelerato attraverso fusione adattiva e ottimizzazione della memorizzazione
Vincoli: Mantenere il calo della qualità dell'immagine generata non significativo
Il metodo ToMe tradizionale utilizza un rapporto fisso r per la fusione dei token, mentre CA-ToMe introduce una soglia di similarità t:
Idea Centrale:
- Dividere l'immagine in regioni di stride di dimensione sx × sy
- Selezionare il token nell'angolo in alto a sinistra di ogni regione di stride come token di destinazione
- Calcolare la similarità del coseno tra i token sorgente e i token di destinazione
- Fondere solo le coppie di token con similarità superiore alla soglia t
Analisi dei Vantaggi:
- Scenario A: Quando la maggior parte dei token ha bassa similarità, il tasso di fusione fisso forza la fusione di token non simili, causando perdita di informazioni. La soglia adattiva garantisce la fusione solo di token ad alta similarità
- Scenario B: Quando la maggior parte dei token è altamente simile (come nelle fasi iniziali del denoising), il tasso di fusione fisso limita la quantità di fusione. La soglia adattiva consente la fusione di più token, migliorando l'efficienza
Basato sull'analisi della distanza di Jaccard che scopre l'elevata similarità delle coppie di token tra fasi adiacenti:
JaccardDistance(An,An+1)=1−∣An∪An+1∣∣An∩An+1∣
dove An rappresenta l'insieme di tutte le coppie di token sorgente-destinazione al passo n.
Strategia di Implementazione:
- Impostare checkpoint, calcolando la matrice di similarità solo in fasi temporali specifiche
- Riutilizzare le coppie di token calcolate in precedenza nelle fasi non-checkpoint
- Ridurre significativamente l'overhead di calcolo ripetuto della matrice di similarità
- Adattività Dinamica: Regola automaticamente la strategia di fusione in base alla distribuzione di similarità, evitando i limiti dei parametri fissi
- Ottimizzazione della Dimensione Temporale: Sfrutta la ridondanza tra fasi temporali, riducendo la quantità di calcolo attraverso la memorizzazione
- Applicazione Selettiva a Livello di Strato: Applica specificamente l'ottimizzazione ai livelli superiori della U-Net ad alta intensità computazionale (D1 e U1)
- Nessun Riaddestramento Richiesto: Come metodo di accelerazione plug-and-play, può essere applicato direttamente ai modelli esistenti
- Dataset ImageNet-1k: Generazione di 2000 immagini a risoluzione 512×512 (2 immagini per classe, 1000 classi totali)
- Set di Validazione: Utilizzo di 5000 immagini di validazione ImageNet-1k per il calcolo del punteggio FID
- Modello di Prompt: "A high-quality photograph of a classname."
- FID (Fréchet Inception Distance): Metrica principale per misurare la qualità dell'immagine generata
- Tempo di Inferenza: Tempo medio per generare 2000 immagini
- PSNR: Rapporto Picco Segnale-Rumore, misura la qualità della ricostruzione a livello di pixel
- SSIM: Indice di Similarità Strutturale, valuta la coerenza spaziale e strutturale
- Baseline: Stable Diffusion v1.5 originale
- ToMe: Metodo tradizionale di fusione dei token (r=50%)
- Hardware: GPU Tesla V100S
- Fasi di Diffusione: 50 fasi di campionamento PLMS
- Scala CFG: 7.5
- Dimensione dello Stride: Fissa a 2×2
- Strati di Applicazione: Applicato solo ai livelli D1 e U1 della U-Net
| Modello | FID | Tempo Medio (s) | Rapporto di Accelerazione |
|---|
| Baseline | 33.66 | 7.61±0.001 | 1.0× |
| ToMe | 34.16 | 6.39±0.006 | 1.19× |
| CA-ToMe | 34.05 | 6.09±0.001 | 1.24× |
Scoperte Chiave:
- CA-ToMe raggiunge la velocità di inferenza più rapida (6.09s)
- Il punteggio FID (34.05) è superiore a ToMe (34.16) e vicino al baseline (33.66)
- Raggiunge il miglior equilibrio tra velocità e qualità
| Soglia t | FID | Tempo Medio (s) | PSNR | SSIM |
|---|
| 0.4 | 35.28 | 6.07±0.007 | 27.90 | 0.191 |
| 0.5 | 35.46 | 6.07±0.004 | 27.909 | 0.208 |
| 0.6 | 35.56 | 6.10±0.005 | 27.908 | 0.218 |
| 0.7 | 34.30 | 6.23±0.002 | 27.910 | 0.234 |
| 0.8 | 33.80 | 6.58±0.004 | 27.904 | 0.239 |
| 0.9 | 33.42 | 6.92±0.003 | 27.907 | 0.238 |
Osservazioni dei Risultati:
- Le variazioni nell'intervallo di soglia 0.4-0.6 sono minime, poiché la maggior parte dei token ha similarità ≥0.6
- La soglia 0.7 fornisce il miglior compromesso tra qualità e velocità
- Soglie più elevate migliorano la qualità ma riducono la velocità
| Configurazione | Impostazione Checkpoint | Tempo (s) | FID |
|---|
| CONFIG 1 | 0,1,2,3,5,10,15,25,35 | 6.18±0.02 | 36.14 |
| CONFIG 2 | 0,10,11,12,15,20,25,30,35,45 | 6.13±0.001 | 34.33 |
| CONFIG 3 | 0,8,11,13,20,25,30,35,45,46,47,48,49 | 6.09±0.001 | 34.05 |
CONFIG 3 mostra le migliori prestazioni, coerente con l'analisi della distanza di Jaccard, con più checkpoint impostati ai passi 8, 11, 13 e negli ultimi passi.
Attraverso il confronto del contributo di diversi componenti:
- Solo soglia adattiva: Migliora la qualità dell'immagine rispetto al tasso di fusione fisso
- Solo meccanismo di memorizzazione: Riduce significativamente il tempo di calcolo
- CA-ToMe Completo: La combinazione di entrambe le tecniche raggiunge le migliori prestazioni
- Riduzione del numero di fasi di campionamento:
- Metodi di distillazione della conoscenza 26,51,28
- Campionamento implicito 32
- Risolutori di equazioni differenziali avanzati 52,33
- La maggior parte richiede un nuovo addestramento
- Riduzione del calcolo per fase:
- Metodi di quantizzazione 31,36
- Riduzione dei token 21,40,41,43,44
- Tecniche di memorizzazione 24,37,38,39
- Plug-and-play, senza richiedere nuovo addestramento
- Potatura dei token: Eliminazione diretta di token non importanti, potrebbe causare perdita di informazioni
- Fusione dei token: Fusione di token simili, preservando l'integrità delle informazioni
- ToMe 21: Utilizza un tasso di fusione fisso
- CA-ToMe di questo articolo: Soglia adattiva + meccanismo di memorizzazione
I metodi di memorizzazione esistenti si rivolgono a diversi componenti:
- Memorizzazione dell'attenzione incrociata 38
- Memorizzazione dell'encoder U-Net 39
- Memorizzazione delle caratteristiche avanzate 24
Questo articolo è il primo ad applicare la memorizzazione al calcolo di similarità nella fusione dei token.
- La soglia adattiva affronta efficacemente i limiti del tasso di fusione fisso, regolando dinamicamente la strategia di fusione in base alla distribuzione di similarità
- Il meccanismo di memorizzazione sfrutta la ridondanza tra fasi temporali, riducendo significativamente i calcoli ripetuti
- Il metodo CA-ToMe raggiunge un'accelerazione di 1,24 volte mantenendo e persino leggermente migliorando la qualità dell'immagine
- La caratteristica senza addestramento conferisce al metodo una buona praticità e scalabilità
- Ottimizzazione dei parametri di soglia: Richiede l'adeguamento della soglia ottimale per diversi modelli e compiti
- Limitazione dell'ambito di applicabilità: Principalmente orientato ai modelli di diffusione con architettura U-Net
- Overhead di memorizzazione: Richiede memoria aggiuntiva per memorizzare le informazioni delle coppie di token memorizzate
- Limitazione dei livelli: Applicato solo ai livelli superiori, potrebbe perdere opportunità di ottimizzazione in altri livelli
- Apprendimento automatico della soglia: Sviluppare metodi per determinare automaticamente la soglia ottimale
- Estensione ad altre architetture: Adattamento ad architetture di modelli di diffusione nuove come DiT
- Strategie di memorizzazione più raffinate: Meccanismi di memorizzazione adattivi basati sul contenuto
- Ottimizzazione hardware: Implementazioni ottimizzate per hardware specifico
- Forte innovatività: Introduce il concetto di adattività nella fusione dei token, combinando il meccanismo di memorizzazione per formare una soluzione completa
- Elevato valore pratico: La caratteristica senza addestramento e plug-and-play la rende facile da distribuire
- Esperimenti completi: Esperimenti di ablazione completi e analisi dei parametri supportano l'efficacia del metodo
- Fondamento teorico solido: L'analisi di similarità basata sulla distanza di Jaccard fornisce supporto teorico al meccanismo di memorizzazione
- Analisi teorica non sufficientemente approfondita: Manca la guida teorica per la selezione della soglia adattiva
- Ambito sperimentale limitato: Validazione solo su ImageNet, mancanza di valutazione su altri dataset e compiti
- Pochi metodi di confronto: Principalmente confronto con ToMe, mancanza di confronto con altri metodi di accelerazione
- Valutazione della qualità singolare: Principalmente dipendente dalla metrica FID, mancanza di valutazione umana e altre metriche di qualità
- Contributo accademico: Fornisce nuove idee e metodi per l'accelerazione dei modelli di diffusione
- Valore pratico: Può essere applicato direttamente ai modelli di diffusione esistenti, con ampi prospettive di applicazione
- Riproducibilità: Fornisce implementazione completa del codice, facilitando la riproduzione e l'estensione
- Natura ispirativa: Le idee di adattività e memorizzazione possono ispirare più ricerche correlate
- Ambienti con risorse limitate: Dispositivi mobili, scenari di edge computing
- Applicazioni in tempo reale: Applicazioni interattive che richiedono generazione rapida di immagini
- Distribuzione su larga scala: Riduzione dei costi computazionali del server e della latenza
- Prototipi di ricerca: Fornire componenti di base per altre tecniche di accelerazione
Questo articolo cita 54 riferimenti correlati, principalmente includenti:
- Teoria fondamentale dei modelli di diffusione 1,2,3
- Applicazioni di generazione di immagini 4,5,18,19,20
- Tecniche di accelerazione 24,25,26,27,28
- Metodi di elaborazione dei token 21,40,41,43,44
- Tecniche di memorizzazione 24,37,38,39
Valutazione Complessiva: Questo è un lavoro di valore pratico nel campo dell'accelerazione dei modelli di diffusione. Attraverso la combinazione ingegnosa di soglia adattiva e meccanismo di memorizzazione, realizza un significativo aumento di velocità mantenendo la qualità dell'immagine. Sebbene vi sia ancora spazio per miglioramenti nell'analisi teorica e nell'ambito sperimentale, la sua caratteristica senza addestramento e i buoni risultati sperimentali gli conferiscono un elevato valore pratico e impatto.