Token Pruning for Caching Better: 9 Times Acceleration on Stable Diffusion for Free
Zhang, Xiao, Tang et al.
Stable Diffusion has achieved remarkable success in the field of text-to-image generation, with its powerful generative capabilities and diverse generation results making a lasting impact. However, its iterative denoising introduces high computational costs and slows generation speed, limiting broader adoption. The community has made numerous efforts to reduce this computational burden, with methods like feature caching attracting attention due to their effectiveness and simplicity. Nonetheless, simply reusing features computed at previous timesteps causes the features across adjacent timesteps to become similar, reducing the dynamics of features over time and ultimately compromising the quality of generated images. In this paper, we introduce a dynamics-aware token pruning (DaTo) approach that addresses the limitations of feature caching. DaTo selectively prunes tokens with lower dynamics, allowing only high-dynamic tokens to participate in self-attention layers, thereby extending feature dynamics across timesteps. DaTo combines feature caching with token pruning in a training-free manner, achieving both temporal and token-wise information reuse. Applied to Stable Diffusion on the ImageNet, our approach delivered a 9$\times$ speedup while reducing FID by 0.33, indicating enhanced image quality. On the COCO-30k, we observed a 7$\times$ acceleration coupled with a notable FID reduction of 2.17.
academic
Token Pruning for Caching Better: 9× Beschleunigung auf Stable Diffusion kostenlos
Stable Diffusion hat bedeutende Erfolge im Bereich der Text-zu-Bild-Generierung erzielt, doch sein iterativer Entrauschungsmechanismus führt zu hohen Rechenkosten und langsamer Generierungsgeschwindigkeit. Obwohl Methoden wie Feature-Caching aufgrund ihrer Effektivität und Einfachheit Aufmerksamkeit erhalten, führt die einfache Wiederverwendung von Features aus vorherigen Zeitschritten dazu, dass Features zwischen benachbarten Zeitschritten ähnlich werden, was die Dynamik der Features über die Zeit verringert und letztendlich die Qualität der generierten Bilder beeinträchtigt. Dieses Paper präsentiert eine dynamikbewusste Token-Pruning-Methode (DaTo), um die Einschränkungen des Feature-Caching zu überwinden. DaTo beschneidet selektiv Tokens mit niedriger Dynamik und ermöglicht nur hochdynamischen Tokens, an Self-Attention-Schichten teilzunehmen, wodurch die Feature-Dynamik zwischen Zeitschritten erweitert wird. Bei Anwendung auf Stable Diffusion auf ImageNet erreicht die Methode eine 9×-Beschleunigung, während die FID um 0,33 sinkt; auf COCO-30k wird eine 7×-Beschleunigung mit signifikantem FID-Rückgang von 2,17 beobachtet.
Diffusionsmodelle haben bedeutende Fortschritte in der generativen Modellierung erzielt und werden weit verbreitet in Text-zu-Bild-Generierung, Videogenerierung und anderen Aufgaben eingesetzt. Allerdings führt der iterative Entrauschungsmechanismus von Diffusionsmodellen zu enormen Rechenkosten und langsamer Generierungsgeschwindigkeit, was ihre breitere Anwendung einschränkt.
Die aktuellen Methoden zur Beschleunigung von Diffusionsmodellen umfassen hauptsächlich:
Reduzierung der Sampling-Schritte: wie schnelle Sampler wie DDIM
Reduzierung der Rechenkosten pro Schritt: einschließlich Wissensdestillation, Strukturpruning, Quantisierung, Token-Pruning und Feature-Caching
Unter diesen ist Feature-Caching aufgrund seiner Effektivität und Einfachheit weit verbreitet. Es speichert Features aus vorherigen Zeitschritten und verwendet sie in nachfolgenden Zeitschritten erneut. Allerdings zwingt die Feature-Wiederverwendung Features verschiedener Zeitschritte, ähnliche Werte zu haben, was die Dynamik der Features über Zeitschritte hinweg verringert, den ursprünglichen Diffusionsprozess beschädigt und somit die Generierungsqualität beeinträchtigt.
Das Paper beobachtet durch Experimente, dass die Feature-Unterschiede zwischen benachbarten Zeitschritten bei Modellen mit Feature-Caching im Vergleich zum ursprünglichen Stable Diffusion signifikant abnehmen. Dies wirft eine kritische Frage auf: Ist es möglich, Feature-Caching durchzuführen und gleichzeitig die korrekte Feature-Dynamik zu bewahren?
Vorschlag der dynamikbewussten Token-Pruning-Methode (DaTo): Durch das Beschneiden von Tokens, deren Dynamik durch Feature-Caching in verschiedenen Zeitschritten verringert wird, und deren Wiederherstellung durch Tokens mit großer Dynamik, wird die Qualitätsverschlechterung vermieden, die durch Feature-Caching verursacht wird.
Entwurf einer evolutionären Suchstrategie: Vorschlag einer evolutionären Methode zur Suche nach optimalen Feature-Caching- und Token-Pruning-Strategien, um das volle Potenzial von DaTo freizusetzen.
Erreichung signifikanter Leistungsverbesserungen: Umfangreiche Experimente auf Stable Diffusion und SDXL zeigen, dass ohne Training und zusätzliche Daten auf Stable Diffusion bis zu 9×-Beschleunigung mit verlustfreier Generierungsqualität erreicht werden kann.
Die Aufgabe dieses Papers besteht darin, den Inferenzprozess des Stable Diffusion-Modells erheblich zu beschleunigen, während die Bildgenerierungsqualität erhalten bleibt. Die Eingabe ist ein Textprompt, die Ausgabe ist das entsprechende hochwertige Bild, und die Einschränkung besteht darin, dass das Modell nicht neu trainiert werden muss.
Zeitliche Rausch-Differenz-Bewertung: Für den t-ten Zeitschritt wird die absolute Differenz der Ausgaben der beiden benachbarten vorherigen Zeitschritte berechnet:
Patch-basierte Token-Auswahl: Das Bild wird in nicht überlappende s×s-Patches unterteilt, und in jedem Patch wird der Token mit dem höchsten DiffScore als Basis-Token ausgewählt.
CFG-Ausrichtung:
Um die klassifiziererfreie Anleitung (CFG) zu handhaben, werden die Basis-Token-Positionen der bedingten Generierung in die unbedingte Generierung kopiert:
Pruning-Token-Auswahl:
Basierend auf der Kosinusähnlichkeit werden die K Tokens ausgewählt, die den Basis-Tokens am ähnlichsten sind, um sie zu beschneiden:
X_prune = arg topK max Cosine_Similarity(X_i, X_j)
Pruning-Token-Wiederherstellung:
Die beschnittenen Tokens werden durch direkte Kopie ihres ähnlichsten Basis-Tokens wiederhergestellt.
Pruning-Verhältnis r ist auf {0,3, 0,4, 0,5, 0,6, 0,7} beschränkt
Evolutionärer Suchalgorithmus:
Der NSGA-II-Multiziel-Optimierungsalgorithmus wird verwendet, mit Optimierungszielen einschließlich:
Inferenz-Latenz
Generierungsqualität (FID)
Der Suchprozess umfasst Standard-Evolutionsoperationen wie Selektion, Crossover und Mutation, um letztendlich die optimale zeitschrittbewusste Strategie F(t) zu erhalten.
Dynamik-Wiederherstellungsmechanismus: Durch selektives Beschneiden von Tokens mit niedriger Dynamik und deren Wiederherstellung mit hochdynamischen Tokens wird die durch Feature-Caching beschädigte Feature-Dynamik-Verteilung erfolgreich wiederhergestellt.
Einheitliches Caching-Pruning-Framework: Feature-Caching und Token-Pruning werden in einem trainingsunabhängigen Framework kombiniert, um Informationswiederverwendung auf Zeit- und Token-Ebene zu erreichen.
Adaptive Strategiesuche: Für die unterschiedlichen Redundanzeigenschaften verschiedener Zeitschritte wird eine Methode zur automatischen Suche nach optimaler Caching-Tiefe und Pruning-Verhältnis vorgeschlagen.
Effektivität von DiffScore:
Bei verschiedenen Caching-Einstellungen und Pruning-Verhältnissen verbessert die Verwendung von DiffScore konsistent die FID-Werte und beweist die Effektivität der zeitlichen Rausch-Differenz-Bewertung.
Auswirkung der CFG-Ausrichtung:
Mit zunehmendem Pruning-Verhältnis nimmt der Nutzen der CFG-Ausrichtungskonfiguration allmählich zu, wobei die FID-Verbesserung bei hohem Pruning-Verhältnis (0,7) zwischen 13 und 30 Punkten liegt.
Feature-Dynamik-Wiederherstellung: DaTo stellt die Feature-Differenz-Verteilung erfolgreich auf ein Niveau nahe dem ursprünglichen Stable Diffusion wieder her
Sparsame Kodierungseffekt: Moderates Token-Pruning und Feature-Caching können die Modellleistung durch Konzentration auf Schlüsselfeatures verbessern
Strategieverallgemeinerung: Die auf SD v1.5 gesuchte Strategie zeigt gute Leistung auf SDXL und anderen Datensätzen
Das Paper zitiert 46 verwandte Literaturquellen, die Diffusionsmodelle, Token-Reduktion, Caching-Mechanismen und andere verwandte Bereiche abdecken und eine solide theoretische Grundlage und Vergleichsbenchmarks für diese Forschung bieten.
Gesamtbewertung: Dies ist ein hochqualitatives Computervisions-Paper, das eine innovative Lösung für das wichtige Problem der Diffusionsmodell-Beschleunigung bietet. Das Methodendesign ist elegant, die experimentelle Bewertung umfassend und der praktische Wert hervorragend. Obwohl die theoretische Analysentiefe etwas zu wünschen übrig lässt, sind die praktischen Beiträge und Auswirkungen bemerkenswert.