Improved Sample Complexity For Diffusion Model Training Without Empirical Risk Minimizer Access
Gaur, Trivedi, Kunapuli et al.
Diffusion models have demonstrated state-of-the-art performance across vision, language, and scientific domains. Despite their empirical success, prior theoretical analyses of the sample complexity suffer from poor scaling with input data dimension or rely on unrealistic assumptions such as access to exact empirical risk minimizers. In this work, we provide a principled analysis of score estimation, establishing a sample complexity bound of $\mathcal{O}(ε^{-4})$. Our approach leverages a structured decomposition of the score estimation error into statistical, approximation, and optimization errors, enabling us to eliminate the exponential dependence on neural network parameters that arises in prior analyses. It is the first such result that achieves sample complexity bounds without assuming access to the empirical risk minimizer of score function estimation loss.
academic
Complejidad de Muestra Mejorada para Entrenamiento de Modelos de Difusión sin Acceso al Minimizador de Riesgo Empírico
Los modelos de difusión han demostrado un rendimiento de última generación en los campos de visión, lenguaje y ciencia. A pesar del éxito empírico, los análisis teóricos previos sobre complejidad de muestra presentan dos problemas principales: primero, crecimiento exponencial con la dimensionalidad de los datos de entrada; segundo, dependencia de suposiciones poco realistas (como acceso a un minimizador de riesgo empírico exacto). Este artículo proporciona un análisis principista de la estimación de puntuación, estableciendo un límite de complejidad de muestra de O~(ϵ−4). El método descompone estructuradamente el error de estimación de puntuación en error estadístico, error de aproximación y error de optimización, eliminando la dependencia exponencial de los parámetros de la red neuronal en análisis previos. Este es el primer resultado que logra un límite de complejidad de muestra sin asumir acceso a un minimizador de riesgo empírico de la pérdida de estimación de puntuación.
Los modelos de difusión muestrean desde distribuciones complejas aprendiendo a invertir el proceso de adición de ruido, siendo el núcleo la estimación de la función de puntuación (score function) ∇logpt(x). Aunque los modelos de difusión funcionan excepcionalmente bien en la práctica, la comprensión teórica sigue siendo limitada, particularmente:
Problema de Complejidad de Muestra: ¿Cuántas muestras se necesitan para entrenar un modelo de difusión de alta calidad?
Maldición de Dimensionalidad: Los resultados teóricos existentes muestran dependencia exponencial de la dimensionalidad de datos d (como O~(ϵ−d))
Suposiciones Poco Realistas: Todos los trabajos previos asumen acceso a un minimizador de riesgo empírico (ERM) de la pérdida de estimación de puntuación, lo cual es imposible en la práctica
Este artículo busca responder la pregunta central:
¿Cuántas muestras se necesitan para que una red neuronal suficientemente expresiva estime la función de puntuación sin acceso a un minimizador de riesgo empírico, permitiendo generar muestras de alta calidad usando el algoritmo DDPM?
Primer Límite de Complejidad de Muestra en Tiempo Finito sin Suposición ERM: Establece un límite de complejidad de muestra de O~(ϵ−4) sin necesidad de acceso a un minimizador de riesgo empírico, sin dependencia exponencial de la dimensionalidad o parámetros de la red neuronal
Marco de Descomposición de Error Principista: Propone descomponer sistemáticamente el error de estimación de puntuación en tres componentes:
Error de Aproximación (Approximation Error): Limitaciones de capacidad expresiva de la clase de funciones de red neuronal
Error Estadístico (Statistical Error): Error causado por muestras finitas
Error de Optimización (Optimization Error): Error causado por número finito de pasos SGD
Análisis Técnico Novedoso:
Utiliza normalidad condicional para manejar funciones de pérdida no acotadas en el error estadístico
Delimita el error de optimización mediante la condición de Polyak-Łojasiewicz (PL) y análisis recursivo
Proporciona garantías de convergencia para tasas de aprendizaje constantes y decrecientes
Puente entre Teoría y Práctica: Conecta directamente la calidad de la función de puntuación aprendida con la distancia de variación total entre la distribución generada y la distribución objetivo
Proceso de Difusión Hacia Adelante: Utiliza el proceso de Ornstein-Uhlenbeck (OU):
dxt=−xtdt+2dBt,x0∼p0,x∈Rd
La solución en forma cerrada es:
xt∼e−tx0+1−e−2tϵ,ϵ∼N(0,I)
Cuando t→∞, el proceso converge a la distribución estacionaria N(0,I).
Proceso de Difusión Inversa: Obtenido mediante teoría de inversión temporal:
dxT−t=(xT−t+2∇logpT−t(xT−t))dt+2dBt
Discretización: Se discretiza en puntos de tiempo 0<t0<t1<⋯<tK=T, implementando el algoritmo DDPM usando la función de puntuación estimada s^tk.
Objetivo: Cuantificar la distancia de variación total (TV) entre el modelo generativo aprendido p^t0 y la distribución de datos verdadera p:
TV(pt0,p^t0)≤O(ϵ)
Suposición 1 (Distribución de Datos con Segundo Momento Acotado): La distribución de datos p0 es absolutamente continua, con soporte en un conjunto compacto Γ⊂Rd, y E[∥x0∥2]≤C1.
Suposición 2 (Condición de Polyak-Łojasiewicz): La función de pérdida Lk(θ) satisface la condición PL:
21∥∇Lk(θ)∥2≥μt(Lk(θ)−Lk(θ∗))
Esto es significativamente más débil que convexidad fuerte y es común en redes neuronales sobreparametrizadas.
Suposición 3 (Error de Aproximación): Existe un parámetro de red neuronal θ∈Θ tal que:
Ex∼pt[∥sθ(x,t)−∇logpt(x)∥2]≤ϵapprox
Suposición 4 (Suavidad y Varianza de Gradiente Acotada):
Función de pérdida κ-suave: ∥∇Lk(θ)−∇Lk(θ′)∥≤κ∥θ−θ′∥
Varianza de estimación de gradiente acotada: E∥∇L^k(θ)−∇Lk(θ)∥2≤σ2
Lema 1 (Error de Aproximación): Obtenido directamente de la Suposición 3:
Ekapprox≤ϵapprox
Lema 2 (Error Estadístico): Utilizando normalidad condicional y segundo momento acotado, con probabilidad al menos 1−δ:
Ekstat≤O(WD⋅d⋅nklog(2/δ))
Técnicas Clave:
Definición de función de puntuación truncada para manejar no acotación
Uso de complejidad de Rademacher para delimitar error de generalización
Control de masa de probabilidad fuera de la región truncada
Lema 3 (Error de Optimización): Usando tasa de aprendizaje decreciente ηi=i+γα (donde αμ>1, γ>ακ), con probabilidad al menos 1−δ:
Ekopt≤O(WD⋅d⋅nklog(2/δ))
Técnicas Clave:
Explotación de la propiedad de crecimiento cuadrático de la condición PL
Análisis recursivo de cada paso SGD
Manejo de ruido de cola pesada bajo recorte de gradientes
Nota: Este es un artículo puramente teórico sin sección experimental. Las contribuciones principales radican en el análisis teórico y el establecimiento de límites de complejidad de muestra.
Constante de Error de Aproximación: Trata ϵapprox como constante, sin analizar su relación con el tamaño de red (en la práctica puede requerir redes exponencialmente grandes para lograr error de aproximación pequeño)
Condición PL: Aunque más débil que convexidad fuerte, puede no cumplirse en configuraciones no convexas generales (aunque es común en redes sobreparametrizadas)
Tiempo de Parada Temprana: El límite es para TV(pt0,p^t0) en lugar de TV(p0,p^t0), siendo este último requiere suposiciones sub-Gaussiano adicionales (Teorema 2)
Generación Incondicional: El análisis es solo para distribuciones incondicionales, la extensión a configuraciones condicionales es una dirección futura
Verificación Experimental: Como trabajo puramente teórico, carece de verificación experimental de predicciones teóricas
La teoría de aprendizaje estadístico tradicional (como Shalev-Shwartz & Ben-David, 2014) requiere funciones de pérdida acotadas para aplicar complejidad de Rademacher. Pero la función de puntuación ∇logpt(x)=σt2x−e−tx0 es no acotada cuando x es no acotado.