A mathematical theory for understanding when abstract representations emerge in neural networks

Wang, Johnston, Fusi
Recent experiments reveal that task-relevant variables are often encoded in approximately orthogonal subspaces of the neural activity space. These disentangled low-dimensional representations are observed in multiple brain areas and across different species, and are typically the result of a process of abstraction that supports simple forms of out-of-distribution generalization. The mechanisms by which such geometries emerge remain poorly understood, and the mechanisms that have been investigated are typically unsupervised (e.g., based on variational auto-encoders). Here, we show mathematically that abstract representations of latent variables are guaranteed to appear in the last hidden layer of feedforward nonlinear networks when they are trained on tasks that depend directly on these latent variables. These abstract representations reflect the structure of the desired outputs or the semantics of the input stimuli. To investigate the neural representations that emerge in these networks, we develop an analytical framework that maps the optimization over the network weights into a mean-field problem over the distribution of neural preactivations. Applying this framework to a finite-width ReLU network, we find that its hidden layer exhibits an abstract representation at all global minima of the task objective. We further extend these analyses to two broad families of activation functions and deep feedforward architectures, demonstrating that abstract representations naturally arise in all these scenarios. Together, these results provide an explanation for the widely observed abstract representations in both the brain and artificial neural networks, as well as a mathematically tractable toolkit for understanding the emergence of different kinds of representations in task-optimized, feature-learning network models.

Basic Information

  • Paper ID: 2510.09816
  • Title: A mathematical theory for understanding when abstract representations emerge in neural networks
  • Authors: Bin Wang, W. Jeffrey Johnston, Stefano Fusi
  • Institution: Center for Theoretical Neuroscience, Columbia University
  • Classification: q-bio.NC math.OC physics.bio-ph physics.data-an stat.ML
  • Publication Date: October 14, 2025 (Preprint)
  • Paper Link: https://arxiv.org/abs/2510.09816

Abstract

This paper investigates the mathematical mechanisms underlying the emergence of abstract representations in neural networks. Experimental findings reveal that task-relevant variables are typically encoded in approximately orthogonal subspaces of neural activity space, forming disentangled low-dimensional representations. While this geometric structure supports simple out-of-distribution generalization, the mechanisms of its emergence remain unclear. The authors prove mathematically that abstract representations necessarily emerge in the final hidden layer of feedforward nonlinear networks trained on tasks that depend directly on the latent variables. To this end, they develop an analytical framework that maps the optimization over network weights to a mean-field problem over the distribution of neural pre-activations.

Research Background and Motivation

Core Problems

  1. Universality of abstract representations: Neuroscience experiments demonstrate that neural activity across multiple brain regions and species exhibits abstract representations, where task-relevant variables are encoded in approximately orthogonal subspaces
  2. Missing mechanistic understanding: Despite the widespread existence of this geometric structure, the network mechanisms underlying its emergence remain unclear
  3. Limitations of existing approaches: Previously studied mechanisms are primarily unsupervised methods (e.g., variational autoencoders), but pure unsupervised learning of disentangled representations faces significant challenges due to identifiability issues

Research Significance

  • Theoretical importance: Provides mathematical explanation for the widely observed phenomenon of abstract representations
  • Practical value: Understanding representation learning mechanisms aids in designing better neural network architectures
  • Cross-disciplinary impact: Bridges representation learning theory in neuroscience and machine learning

Core Contributions

  1. Theoretical guarantees: First mathematical proof that feedforward nonlinear networks necessarily produce abstract representations under multi-task supervised learning settings
  2. Analytical framework: Develops a general analytical tool mapping network weight optimization to mean-field problems over neural pre-activation distributions
  3. Activation function robustness: Proves that abstract representation emergence is robust to activation function choice
  4. Architecture extensions: Extends analysis to deep networks and recurrent networks
  5. Neuroscience insights: Provides computational explanations for abstract representations observed in biological neural networks

Methodology Details

Task Definition

Consider a training dataset $D = \{(x^i, y^i)\}_{i=1}^P$, where:

  • Input $x^i \in \mathbb{R}^{d_X}$ is essentially unstructured
  • Output $y^i \in \{\pm 1\}^{d_Y}$ contains $d_Y$ binary labels reflecting latent variable structure
  • All data form $2^{d_Y}$ distinct classes, each containing $n$ samples
  • Total sample count $P = n \cdot 2^{d_Y}$
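
The task structure above can be made concrete with a small sketch. The function name `make_dataset`, the Gaussian inputs, and the seed are illustrative assumptions, not taken from the paper; the only properties carried over from the text are the $2^{d_Y}$ label combinations and the $n$ samples per class.

```python
import itertools
import random

def make_dataset(d_Y, n, d_X, seed=0):
    """Return (X, Y): P input vectors and P label tuples in {+1, -1}^d_Y."""
    rng = random.Random(seed)
    X, Y = [], []
    # One class per combination of the d_Y binary labels: 2^d_Y classes.
    for labels in itertools.product((+1, -1), repeat=d_Y):
        for _ in range(n):  # n samples per class
            # Assumption: "unstructured" inputs are drawn as i.i.d. Gaussians.
            X.append([rng.gauss(0.0, 1.0) for _ in range(d_X)])
            Y.append(labels)
    return X, Y

X, Y = make_dataset(d_Y=3, n=2, d_X=5)
P = len(X)                    # n * 2^d_Y = 2 * 8 = 16
num_classes = len(set(Y))     # 2^d_Y = 8
```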

Network Architecture

The study focuses on the simplest two-layer network: $f_{W_1,W_2,b}(x) = W_2\phi(W_1 x + b)$

where:

  • $W_1 \in \mathbb{R}^{M \times d_X}$: first layer weight matrix
  • $W_2 \in \mathbb{R}^{d_Y \times M}$: second layer weight matrix
  • $b \in \mathbb{R}^M$: bias parameters
  • $\phi$: element-wise nonlinear activation function
  • $M$: hidden layer width

Loss Function

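Uses mean squared error with L2 regularization: $E(W_1,W_2,b) = \|Y - W_2\phi(WX)\|_F^2 + \lambda_1\|W\|_F^2 + \lambda_2\|W_2\|_F^2$

Here $W$ appears to denote the first-layer weights with the bias absorbed, $W = [W_1, b]$, acting on the input matrix $X$ extended by a constant row of ones.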
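
As a minimal sketch of the network and objective, the following plain-Python functions evaluate $f(x) = W_2\phi(W_1 x + b)$ and the regularized squared error. The helper names are hypothetical, and the code assumes the bias is regularized together with $W_1$ (that is, $W = [W_1, b]$), which is one reading of the formula rather than something this summary states.

```python
def relu(z):
    return max(z, 0.0)

def forward(W1, W2, b, x, phi=relu):
    """Two-layer network output W2 * phi(W1 x + b) for a single input x."""
    hidden = [phi(sum(w * xi for w, xi in zip(row, x)) + bk)
              for row, bk in zip(W1, b)]
    return [sum(w * hk for w, hk in zip(row, hidden)) for row in W2]

def sq_frobenius(A):
    return sum(a * a for row in A for a in row)

def loss(W1, W2, b, X, Y, lam1, lam2):
    """Squared error plus L2 penalties on first- and second-layer weights."""
    fit = sum((yk - fk) ** 2
              for x, y in zip(X, Y)
              for yk, fk in zip(y, forward(W1, W2, b, x)))
    # Assumption: the bias is penalized with W1, i.e. W = [W1, b].
    reg1 = lam1 * (sq_frobenius(W1) + sum(bk * bk for bk in b))
    reg2 = lam2 * sq_frobenius(W2)
    return fit + reg1 + reg2

# Toy check: relu(x) - relu(-x) = x, so this network fits y = x exactly.
W1, b, W2 = [[1.0], [-1.0]], [0.0, 0.0], [[1.0, -1.0]]
X_toy, Y_toy = [[1.0], [-1.0]], [(1,), (-1,)]
value = loss(W1, W2, b, X_toy, Y_toy, lam1=0.1, lam2=0.01)  # only penalties remain
```

With a perfect fit, the value reduces to the two penalty terms, $0.1 \cdot 2 + 0.01 \cdot 2 \approx 0.22$.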

Abstract Representation Metric

Uses Parallelism Score (PS) to quantify the degree of representation abstraction:

  1. Class prototype representation: $r^{(y)} = \frac{1}{n}\sum_{i:\, y^i=y} r^i$
  2. Representation variation direction: $\Delta r^{(k;\alpha)} = r^{(y_k=+1,\, y_{\setminus k}=\alpha)} - r^{(y_k=-1,\, y_{\setminus k}=\alpha)}$
  3. Parallelism score: $PS = \frac{1}{d_Y}\sum_{k=1}^{d_Y} PS_k$

where $PS_k$ measures the consistency of the encoding direction for the $k$-th latent label. PS = 1 corresponds to perfectly abstract representation.
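
The summary does not spell out $PS_k$. The sketch below assumes a common definition: the average cosine similarity between the variation directions $\Delta r^{(k;\alpha)}$ over all pairs of contexts $\alpha$. The function names and the two toy prototype sets are illustrative only.

```python
import itertools
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def parallelism_score(prototypes, d_Y):
    """prototypes: dict mapping each label tuple in {+1,-1}^d_Y to its mean representation."""
    if d_Y < 2:
        raise ValueError("need at least two labels to compare contexts")
    scores = []
    for k in range(d_Y):
        deltas = []
        # alpha fixes every label except the k-th one.
        for alpha in itertools.product((+1, -1), repeat=d_Y - 1):
            plus = alpha[:k] + (+1,) + alpha[k:]
            minus = alpha[:k] + (-1,) + alpha[k:]
            deltas.append([p - m for p, m in zip(prototypes[plus], prototypes[minus])])
        pairs = list(itertools.combinations(deltas, 2))
        # Assumed PS_k: mean pairwise cosine between coding directions.
        scores.append(sum(cosine(u, v) for u, v in pairs) / len(pairs))
    return sum(scores) / d_Y

classes = list(itertools.product((+1, -1), repeat=2))
# Each label gets its own axis: coding directions are parallel across contexts.
linear = {y: [float(y[0]), float(y[1])] for y in classes}
# Adding a product (XOR-like) feature makes the directions context dependent.
xor_like = {y: [float(y[0]), float(y[1]), float(y[0] * y[1])] for y in classes}

ps_linear = parallelism_score(linear, d_Y=2)   # 1.0
ps_xor = parallelism_score(xor_like, d_Y=2)    # 0.0
```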

Analytical Framework Core

Mean-Field Transformation

The key innovation transforms the original optimization problem: $\min_{W_1,W_2,b} E(W_1,W_2,b)$

into optimization over neural pre-activation distributions: $\min_{\rho_M} \mathcal{E}[\rho_M]$

where $\rho_M = \sum_{k=1}^M \delta_{h_k}$ is the empirical measure of pre-activation patterns, with $h_k \in \mathbb{R}^P$ the pre-activation of hidden neuron $k$ across the $P$ training samples.

Effective Energy Function

The effective system's energy function is: $\mathcal{E}[\rho_M] = \lambda_1\int h^T K_X^\dagger h \, d\rho_M(h) + \text{tr}\left(\frac{\lambda_2}{\lambda_2 + \int\phi(h)\phi(h)^T d\rho_M(h)} K_Y\right)$

where:

  • $K_X = X^TX$: input kernel matrix
  • $K_Y = Y^TY$: output kernel matrix
  • $K_X^\dagger$: Moore-Penrose pseudoinverse
  • $K[\rho] = \int\phi(h)\phi(h)^T d\rho(h)$: hidden representation kernel; the fraction is shorthand for the matrix inverse $(\lambda_2 I + K[\rho_M])^{-1}$
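
The form of this energy can be recovered by eliminating the weights in two steps. The following is a sketch using standard ridge-regression and minimum-norm identities, not necessarily the paper's own derivation; the symbols $H$ and $\Phi$ are introduced here for illustration, with $H = WX \in \mathbb{R}^{M \times P}$ having rows $h_k^T$ and $\Phi = \phi(H)$.

```latex
\begin{align*}
% Step 1: minimize over the readout W_2 (ridge regression).
\min_{W_2}\; \|Y - W_2\Phi\|_F^2 + \lambda_2\|W_2\|_F^2
  &= \lambda_2\,\mathrm{tr}\!\left[(\lambda_2 I_P + \Phi^T\Phi)^{-1}\, Y^T Y\right],
  \qquad W_2^* = Y\Phi^T(\lambda_2 I_M + \Phi\Phi^T)^{-1} \\
% The hidden kernel is a sum over neurons, i.e. an integral against rho_M.
\Phi^T\Phi &= \sum_{k=1}^{M} \phi(h_k)\phi(h_k)^T
  = \int \phi(h)\phi(h)^T \, d\rho_M(h) \\
% Step 2: cheapest first-layer weights that realize given pre-activations H.
\min_{W:\; WX = H}\; \|W\|_F^2
  &= \mathrm{tr}\!\left[H\,(X^T X)^{\dagger} H^T\right]
  = \sum_{k=1}^{M} h_k^T K_X^{\dagger} h_k
  = \int h^T K_X^{\dagger} h \, d\rho_M(h)
\end{align*}
```

Multiplying the last line by $\lambda_1$ and adding the first gives the two terms of $\mathcal{E}[\rho_M]$. The second step assumes each $h_k$ lies in the row space of $X$, so that the constraint $WX = H$ is feasible and the minimum-norm solution is $W = HX^\dagger$.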

KKT Conditions

Optimal solutions satisfy: $\lambda_1 h^T K_X^\dagger h - \lambda_2\phi(h)^T \frac{1}{\lambda_2 + K[\rho^*]} K_Y \frac{1}{\lambda_2 + K[\rho^*]} \phi(h) \geq 0$

with equality if and only if $h \in \text{supp}(\rho^*)$.

Experimental Setup

Data Configuration

  1. Whitened inputs: $X_{\text{data}}^T X_{\text{data}} = I_P$
  2. Target-aligned inputs: inputs with geometric structure partially aligned with outputs
  3. Anisotropic inputs: different directions with different scaling factors

Network Configuration

  • Activation functions: ReLU, hard sigmoid, tanh, etc.
  • Network width: $M \geq 2^{d_Y}$
  • Regularization parameters: small $\lambda_1, \lambda_2$

Evaluation Metrics

  • Parallelism Score (PS)
  • Training loss
  • Comparison of theoretical predictions vs. actual results for representation kernel matrices

Experimental Results

Main Results

Optimal Representation for ReLU Networks

For whitened inputs and singleton classes ($n=1$), the optimal hidden representation kernel is: $K[\rho^*] = b^*(d_Y \mathbf{1}\mathbf{1}^T + K_Y)$

where: $b^* = \sqrt{\frac{\lambda_2}{\lambda_1}\frac{P+1}{P(P+2)}} - \frac{\lambda_2}{P}$
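
To see what this kernel looks like, the sketch below evaluates the two formulas above for a small case. The function name and the parameter values are arbitrary illustrations; the code only assumes singleton classes, so that $P = 2^{d_Y}$ and the $(i,j)$ entry of $d_Y \mathbf{1}\mathbf{1}^T + K_Y$ equals $d_Y + y^i \cdot y^j$.

```python
import itertools
import math

def optimal_relu_kernel(d_Y, lam1, lam2):
    """Evaluate b* and K = b* (d_Y 1 1^T + K_Y) for one sample per class."""
    labels = list(itertools.product((+1, -1), repeat=d_Y))
    P = len(labels)  # n = 1, so P = 2^d_Y
    b_star = math.sqrt((lam2 / lam1) * (P + 1) / (P * (P + 2))) - lam2 / P
    K = [[b_star * (d_Y + sum(a * c for a, c in zip(yi, yj))) for yj in labels]
         for yi in labels]
    return labels, b_star, K

labels, b_star, K = optimal_relu_kernel(d_Y=2, lam1=0.01, lam2=0.01)
# labels[0] = (+1, +1) and labels[3] = (-1, -1) differ in every label,
# so d_Y + y^i . y^j = 0 and their representations are orthogonal.
same_class = K[0][0]       # 4 * b_star
opposite_class = K[0][3]   # 0
```

In this form the kernel entry between two classes is proportional to the number of labels they share, which is one way to read the abstract geometry off the formula.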

Abstract Representation Guarantees

Theorem: When $M \geq 2^{d_Y}$ and inputs are whitened or target-aligned, all global minima correspond to abstract representations (PS = 1).

Neural Tuning Properties

Optimal pre-activation patterns are: $h = \alpha(\mathbf{1} \pm v_i), \quad \alpha \geq 0, \; i \in \{1,2,\ldots,d_Y\}$

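This indicates that hidden layer neurons are divided into $2 d_Y$ groups (one for each label index $i$ and each sign), each responding only to a single output label. Here $v_i$ presumably denotes the vector of the $i$-th label across all samples, i.e. the $i$-th row of $Y$.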
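
A short numerical check connects these tuning patterns to the optimal kernel. The sketch assumes that $v_i$ is the $i$-th label across samples, that classes are singletons, and that there is exactly one ReLU neuron per sign and label index, all with the same scale $\alpha$. Under those assumptions the summed outer products of the activations come out as $2\alpha^2(d_Y + y^a \cdot y^c)$, the same form as $d_Y \mathbf{1}\mathbf{1}^T + K_Y$ up to a constant.

```python
import itertools

def relu(z):
    return max(z, 0.0)

def kernel_from_tuning(d_Y, alpha=1.0):
    """Build patterns h = alpha * (1 +/- v_i) and sum phi(h) phi(h)^T over them."""
    labels = list(itertools.product((+1, -1), repeat=d_Y))
    P = len(labels)
    K = [[0.0] * P for _ in range(P)]
    patterns = []
    for i in range(d_Y):
        v_i = [y[i] for y in labels]  # assumed meaning of v_i: i-th label per sample
        for sign in (+1, -1):
            h = [alpha * (1 + sign * v) for v in v_i]
            patterns.append(h)
            r = [relu(z) for z in h]  # nonzero only where label i equals sign
            for a in range(P):
                for c in range(P):
                    K[a][c] += r[a] * r[c]
    return labels, patterns, K

d_Y, alpha = 3, 0.5
labels, patterns, K = kernel_from_tuning(d_Y, alpha)
expected = [[2 * alpha ** 2 * (d_Y + sum(a * c for a, c in zip(ya, yc)))
             for yc in labels] for ya in labels]
max_error = max(abs(K[a][c] - expected[a][c])
                for a in range(len(labels)) for c in range(len(labels)))
```

Each pattern is active on exactly the half of the samples where one label takes one value, which is the single-label selectivity described above.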

Activation Function Robustness

Threshold-type Activation Functions

For activation functions of the form $\phi(z) = \phi_+(z) \cdot \mathbf{1}_{z \geq 0}$, the optimal representation kernel maintains the same form, with only coefficient changes.

Odd-symmetric Activation Functions

For odd-function activations, the optimal kernel is: $K[\rho^*] = b^* K_Y$

While lacking the constant term, it still corresponds to abstract representation (PS = 1).

Extension Results

Deep Networks

For $L$-layer deep networks, each layer exhibits abstract representation: $K^{(l)}[\rho_l^*] = b_l^*(d_Y \mathbf{1}\mathbf{1}^T + K_Y)$

where $b_l^* = (\gamma^*)^{l-1} b_1^*$.

Recurrent Networks

Abstract representations similarly emerge at the final time step, validating the broad applicability of the framework.

Related Work

Neuroscience Background

  • Abstract representations observed in multiple brain regions (hippocampus, prefrontal cortex, etc.)
  • These representations support out-of-distribution generalization and abstract reasoning

Machine Learning Approaches

  • Variational Autoencoders: Standard method for unsupervised disentangled representation learning
  • Supervised methods: Obtain disentangled representations through multi-task learning
  • Neural Collapse: Representation geometric phenomena in late-stage deep network training

Theoretical Analysis

  • Neural Tangent Kernel: Theoretical analysis of infinite-width networks
  • Mean-field theory: Statistical physics approaches to deep networks
  • Learning dynamics: Mathematical analysis of weight evolution

Conclusions and Discussion

Main Conclusions

  1. Theoretical guarantees: Under suitable conditions, supervised learning necessarily produces abstract representations
  2. Mechanism explanation: Task structure determines representation geometry, while input geometry affects learning efficiency
  3. Universality: Results are robust to activation function and network architecture choices

Biological Significance

  • Provides computational explanations for abstract representations widely observed in the brain
  • "Re-encoding" in brain regions like the hippocampus may facilitate downstream abstract representation formation
  • Single-neuron nonlinearity affects tuning properties but does not alter population geometry

Limitations

  1. Task constraints: Primarily applicable to combinatorial binary classification tasks
  2. Input assumptions: Requires specific input geometric structure
  3. Regularization dependence: Requires appropriate L2 regularization strength

Future Directions

  1. Continuous variables: Extension to representation learning with continuous latent variables
  2. Learning dynamics: Analysis of abstract representation formation processes
  3. Biological implementation: Investigation of representation emergence under biologically plausible learning rules

In-Depth Evaluation

Strengths

  1. Theoretical rigor: Provides mathematical proofs for abstract representation emergence, filling an important theoretical gap
  2. Methodological innovation: Mean-field framework provides new tools for analyzing finite-width networks
  3. Universal applicability: Results hold for multiple activation functions and network architectures
  4. Cross-disciplinary value: Bridges neuroscience observations and machine learning theory
  5. Comprehensive experimental validation: Theoretical predictions align well with numerical experiments

Limitations

  1. Restricted task scope: Primarily focuses on specific binary label combinatorial tasks
  2. Strict input conditions: Requires whitened or target-aligned input geometry
  3. Distance from practical applications: Still far from complex real-world tasks
  4. Computational complexity: Solving mean-field equations may be computationally expensive

Impact

  1. Theoretical contribution: Provides important mathematical foundations for representation learning theory
  2. Methodological value: Analytical framework applicable to other network models
  3. Practical guidance: Informs design of network architectures promoting abstract representations
  4. Cross-field influence: May impact interdisciplinary research between neuroscience and machine learning

Applicable Scenarios

  • Tasks requiring interpretable representation learning
  • Feature disentanglement in multi-task learning
  • Theoretical modeling of representation geometry in neuroscience
  • Applications requiring out-of-distribution generalization capability

Technical Innovation Points

Core Mathematical Tools

  1. Measure-theoretic methods: Transforms discrete neuron problems into continuous measure optimization
  2. Convex optimization theory: Utilizes KKT conditions to analyze global optimal solutions
  3. Matrix analysis: Characterizes representation geometric structure through kernel matrices

Analytical Techniques

  • Copositive programming: Handles non-convex constraints in ReLU networks
  • Schur convexity: Analyzes unified properties across different activation functions
  • Perturbation analysis: Extends results through continuity arguments

This work provides important theoretical foundations for understanding representation learning in neural networks, with mathematical frameworks and insights valuable to both neuroscience and machine learning.