Amortizing intractable inference in diffusion models for vision, language, and control

Original paper: arXiv:2405.20971, published 6/3/2024 by Siddarth Venkatraman, Moksh Jain, Luca Scimeca, Minsu Kim, Marcin Sendera, Mohsin Hasan, Luke Rowe, Sarthak Mittal, Pablo Lemos, Emmanuel Bengio and 5 others

Overview

  • Diffusion models have emerged as effective tools for generating and estimating distributions in various domains, including vision, language, and reinforcement learning.
  • However, using diffusion models as priors in downstream tasks poses an intractable posterior inference problem.
  • This paper studies the problem of amortized sampling from the posterior distribution over data, given a diffusion generative model prior and a black-box constraint or likelihood function.

Plain English Explanation

Diffusion models are a type of machine learning model that can be used to generate and estimate distributions of data. They have been successful in areas like computer vision, natural language processing, and reinforcement learning.

The challenge this paper addresses is that when a diffusion model is used as a starting point (a "prior") and combined with an extra constraint or scoring function, it becomes intractable to compute or sample from the resulting distribution (the "posterior"). The authors propose a way to train a diffusion model that samples directly from this posterior, so that diffusion models can be used more easily in these kinds of problems.

Their approach relies on a training objective called "relative trajectory balance," which lets the new diffusion model learn to sample from the desired posterior even though that distribution cannot be calculated directly. This makes diffusion models usable as building blocks for a wide range of applications: classifier-guided image generation, filling in missing text with a diffusion language model, text-to-image generation, and continuous control problems in reinforcement learning.

Technical Explanation

The paper proposes a method for amortized sampling from the posterior distribution over data, $\mathbf{x} \sim p^{\rm post}(\mathbf{x}) \propto p(\mathbf{x})r(\mathbf{x})$, where $p(\mathbf{x})$ is the diffusion generative model prior and $r(\mathbf{x})$ is a black-box constraint or likelihood function.
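
To make the target concrete, here is a minimal sketch over a hypothetical three-element data space with made-up numbers. Taking $r(\mathbf{x})$ to be a classifier's probability $p(c\mid\mathbf{x})$ for a fixed class $c$ (the classifier-guidance case listed in the paper's abstract), the normalised product $p(\mathbf{x})r(\mathbf{x})$ is exactly the Bayes posterior $p(\mathbf{x}\mid c)$, and the normaliser $Z$ equals $p(c)$. For a real diffusion prior, the sum defining $Z$ runs over all possible images and cannot be computed, which is what makes the problem intractable.

```python
# Hypothetical finite "data space" of three candidates, indexed 0..2.
# prior[i] plays the role of p(x); classifier[i] plays r(x) = p(c | x) for one class c.
# All numbers are made up for illustration.
prior = [0.5, 0.3, 0.2]
classifier = [0.1, 0.6, 0.9]


def posterior(prior, reward):
    """Normalise p(x) * r(x) into p_post(x) and return it with the constant Z."""
    unnorm = [p * r for p, r in zip(prior, reward)]
    z = sum(unnorm)  # Z = sum_x p(x) r(x); here it equals p(c)
    return [u / z for u in unnorm], z


post, z = posterior(prior, classifier)
print([round(q, 4) for q in post], round(z, 4))
```

Exact enumeration works here only because the space has three elements; the point of amortization is to train a sampler that draws from $p^{\rm post}$ without ever computing $Z$.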

The authors state and prove the asymptotic correctness of a data-free learning objective, called "relative trajectory balance," for training a diffusion model to sample from this posterior. "Data-free" means the objective needs no samples from the posterior, only the prior model and evaluations of $r(\mathbf{x})$. Existing methods, by contrast, solve this problem only approximately or in restricted cases.
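
The objective itself is not written out in this summary. As we read the paper, relative trajectory balance scores each denoising trajectory $\mathbf{x}_0 \to \dots \to \mathbf{x}_T$ with the squared log-ratio $\left(\log \frac{Z \prod_t q(\mathbf{x}_{t+1}\mid\mathbf{x}_t)}{r(\mathbf{x}_T)\prod_t p(\mathbf{x}_{t+1}\mid\mathbf{x}_t)}\right)^2$, where $q$ is the posterior model being trained, $p$ the frozen prior, and $Z$ a learned scalar; consult the paper for its exact notation and time indexing. The sketch below is a toy check of that form, not the authors' code: a hypothetical two-state Markov chain stands in for the diffusion prior, and the numbers in `PRIOR` and `REWARD` are made up. It builds the exact posterior policy (a Doob h-transform of the prior) and confirms that the loss is zero on every trajectory and that sampling with it yields $p(\mathbf{x}_T)r(\mathbf{x}_T)/Z$, while an untrained model (the prior itself, with $\log Z = 0$) has nonzero loss.

```python
import itertools
import math

# Hypothetical toy stand-in for a denoising chain: T steps over states {0, 1},
# starting from a fixed state. All probabilities below are made up.
T = 3
STATES = (0, 1)
START = 0
PRIOR = {0: {0: 0.7, 1: 0.3}, 1: {0: 0.4, 1: 0.6}}  # p(x_{t+1} | x_t), same at every step
REWARD = {0: 0.05, 1: 1.0}  # black-box r(x_T) on the final state


def exact_posterior_policy():
    """Doob h-transform of the prior: q_t(n | s) = p(n | s) h_{t+1}(n) / h_t(s),
    where h_t(s) = E_prior[r(x_T) | x_t = s]. Returns (policies, Z = h_0(START))."""
    h = [None] * (T + 1)
    h[T] = dict(REWARD)
    for t in range(T - 1, -1, -1):
        h[t] = {s: sum(PRIOR[s][n] * h[t + 1][n] for n in STATES) for s in STATES}
    policies = [
        {s: {n: PRIOR[s][n] * h[t + 1][n] / h[t][s] for n in STATES} for s in STATES}
        for t in range(T)
    ]
    return policies, h[0][START]


def rtb_residual(traj, policies, log_z):
    """log[Z * prod_t q(x_{t+1}|x_t)] - log[r(x_T) * prod_t p(x_{t+1}|x_t)]."""
    log_q = sum(math.log(policies[t][traj[t]][traj[t + 1]]) for t in range(T))
    log_p = sum(math.log(PRIOR[traj[t]][traj[t + 1]]) for t in range(T))
    return log_z + log_q - math.log(REWARD[traj[-1]]) - log_p


def terminal_marginal(policies):
    """Distribution of x_T when sampling forward with the given per-step policies."""
    dist = {START: 1.0}
    for t in range(T):
        nxt = {s: 0.0 for s in STATES}
        for s, mass in dist.items():
            for n in STATES:
                nxt[n] += mass * policies[t][s][n]
        dist = nxt
    return dist


trajectories = [(START,) + rest for rest in itertools.product(STATES, repeat=T)]
q_star, z = exact_posterior_policy()
opt_losses = [rtb_residual(tr, q_star, math.log(z)) ** 2 for tr in trajectories]

# Untrained baseline: use the prior itself as the "posterior" model, with log Z = 0.
prior_policies = [PRIOR] * T
base_losses = [rtb_residual(tr, prior_policies, 0.0) ** 2 for tr in trajectories]

# Target terminal distribution p(x_T) r(x_T) / Z, computed by brute force.
prior_marg = terminal_marginal(prior_policies)
target = {x: prior_marg[x] * REWARD[x] / z for x in STATES}
print(max(opt_losses), max(base_losses))
print(terminal_marginal(q_star), target)
```

In a real setting, $q$ would be a neural network and $\log Z$ a trainable scalar, both fit by gradient descent on this loss; the toy only verifies the optimum that training aims for.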

Relative trajectory balance arises from the "generative flow network" (GFlowNet) perspective on diffusion models, which makes it possible to use deep reinforcement learning techniques to improve mode coverage. The experiments illustrate unbiased inference of arbitrary posteriors under diffusion priors in vision (classifier guidance), language (infilling under a discrete diffusion language model), and multimodal data (text-to-image generation).
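
One way to see where the name comes from, sketched under our own assumptions rather than quoted from the paper: in the GFlowNet view, a sampler with forward transitions $q$ that targets an unnormalised density $R(\mathbf{x}_T)$ satisfies trajectory balance, $Z\prod_t q(\mathbf{x}_{t+1}\mid\mathbf{x}_t) = R(\mathbf{x}_T)\prod_t p_B(\mathbf{x}_t\mid\mathbf{x}_{t+1})$, for a backward (noising) kernel $p_B$. Taking the starting state $\mathbf{x}_0$ as fixed for simplicity, write this once for the prior (target $R = p$, already normalised so $Z = 1$) and once for the posterior (target $R = p\,r$) with the same $p_B$. Dividing the two, the intractable density $p(\mathbf{x}_T)$ and the backward kernel cancel:

```latex
\begin{aligned}
\prod_{t=0}^{T-1} p(\mathbf{x}_{t+1}\mid\mathbf{x}_t)
  &= p(\mathbf{x}_T)\prod_{t=0}^{T-1} p_B(\mathbf{x}_t\mid\mathbf{x}_{t+1})
  && \text{(prior)}\\
Z\prod_{t=0}^{T-1} q(\mathbf{x}_{t+1}\mid\mathbf{x}_t)
  &= p(\mathbf{x}_T)\,r(\mathbf{x}_T)\prod_{t=0}^{T-1} p_B(\mathbf{x}_t\mid\mathbf{x}_{t+1})
  && \text{(posterior target)}\\
\Longrightarrow\quad
\frac{Z\prod_{t} q(\mathbf{x}_{t+1}\mid\mathbf{x}_t)}{\prod_{t} p(\mathbf{x}_{t+1}\mid\mathbf{x}_t)}
  &= r(\mathbf{x}_T),
\qquad
\mathcal{L}_{\rm RTB}
  = \left(\log \frac{Z\prod_{t} q(\mathbf{x}_{t+1}\mid\mathbf{x}_t)}
                    {r(\mathbf{x}_T)\prod_{t} p(\mathbf{x}_{t+1}\mid\mathbf{x}_t)}\right)^{2}.
\end{aligned}
```

Because $\mathcal{L}_{\rm RTB}$ is a residual on individual trajectories rather than an expectation under $q$, the trajectories used for training need not come from $q$ itself; as we understand it, this is the opening that lets reinforcement-learning-style exploration be used to improve mode coverage.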

Beyond generative modeling, the authors also apply relative trajectory balance to the problem of continuous control with a score-based behavior prior, achieving state-of-the-art results on offline reinforcement learning benchmarks.
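
For the control setting, a natural reading, which we assume here rather than take from the summary, is that the prior is a behaviour policy $\mu(\mathbf{a}\mid\mathbf{s})$ fitted to the offline dataset and the constraint is $r(\mathbf{a}) = \exp(\beta Q(\mathbf{s},\mathbf{a}))$ for a learned critic $Q$ and inverse temperature $\beta$, so the posterior is the familiar KL-regularised policy $\pi(\mathbf{a}\mid\mathbf{s}) \propto \mu(\mathbf{a}\mid\mathbf{s})\exp(\beta Q(\mathbf{s},\mathbf{a}))$. The sketch below computes that target on a hypothetical five-point action grid for a single state, with made-up behaviour probabilities and a made-up quadratic critic, to show how $\beta$ trades off staying close to the data against chasing high $Q$.

```python
import math

# Hypothetical discretised action grid for one fixed state s.
ACTIONS = [-1.0, -0.5, 0.0, 0.5, 1.0]
# mu(a | s): made-up behaviour prior (in the paper's setting, a diffusion model).
BEHAVIOUR = [0.10, 0.40, 0.30, 0.15, 0.05]


def q_value(a):
    """Hypothetical critic Q(s, a) for the fixed state; it peaks at a = 0.8."""
    return -(a - 0.8) ** 2


def tilted_policy(beta):
    """pi(a | s) proportional to mu(a | s) * exp(beta * Q(s, a)), i.e. r(a) = exp(beta * Q)."""
    weights = [m * math.exp(beta * q_value(a)) for m, a in zip(BEHAVIOUR, ACTIONS)]
    z = sum(weights)
    return [w / z for w in weights]


for beta in (0.0, 1.0, 10.0):
    pi = tilted_policy(beta)
    mean_action = sum(p * a for p, a in zip(pi, ACTIONS))
    print(beta, [round(p, 3) for p in pi], round(mean_action, 3))
```

With $\beta = 0$ the policy is just the behaviour prior; as $\beta$ grows, mass shifts toward the actions the critic prefers. With a diffusion behaviour prior over continuous actions this normalisation can no longer be done by enumeration, which is where an amortized sampler trained with relative trajectory balance would come in.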

Critical Analysis

The paper presents a novel and ambitious approach to using diffusion models as flexible priors for a wide range of downstream tasks. The authors' theoretical analysis and empirical results are impressive, demonstrating the broad potential of their "relative trajectory balance" technique.

One potential limitation is the reliance on the "generative flow network" perspective, which may be less familiar to many readers than other interpretations of diffusion models. Additionally, the paper's guarantee concerns asymptotic correctness; how quickly training converges and how many sampled trajectories it needs in practice are not thoroughly explored.

Further research could investigate the performance and robustness of this approach on a wider range of real-world tasks, as well as potential extensions or alternatives to the relative trajectory balance objective that may offer improved stability or sample efficiency.

Conclusion

This paper presents a significant advance in the use of diffusion models as powerful priors for a variety of machine learning problems. By developing a principled framework for amortized posterior sampling, the authors have opened the door to applying diffusion models in a much broader range of applications, from image and text generation to reinforcement learning.

Beyond its technical contributions, the work highlights how versatile diffusion models have become: a single training objective covers image, text, multimodal, and control problems. If that generality holds up in further work, techniques like those described in this paper could play an increasingly central role in generative modeling.



This summary was produced with help from an AI and may contain inaccuracies - check out the links to read the original source documents!


Related Papers

Amortizing intractable inference in diffusion models for vision, language, and control

Siddarth Venkatraman, Moksh Jain, Luca Scimeca, Minsu Kim, Marcin Sendera, Mohsin Hasan, Luke Rowe, Sarthak Mittal, Pablo Lemos, Emmanuel Bengio, Alexandre Adam, Jarrid Rector-Brooks, Yoshua Bengio, Glen Berseth, Nikolay Malkin

Diffusion models have emerged as effective distribution estimators in vision, language, and reinforcement learning, but their use as priors in downstream tasks poses an intractable posterior inference problem. This paper studies amortized sampling of the posterior over data, $\mathbf{x}\sim p^{\rm post}(\mathbf{x})\propto p(\mathbf{x})r(\mathbf{x})$, in a model that consists of a diffusion generative model prior $p(\mathbf{x})$ and a black-box constraint or likelihood function $r(\mathbf{x})$. We state and prove the asymptotic correctness of a data-free learning objective, relative trajectory balance, for training a diffusion model that samples from this posterior, a problem that existing methods solve only approximately or in restricted cases. Relative trajectory balance arises from the generative flow network perspective on diffusion models, which allows the use of deep reinforcement learning techniques to improve mode coverage. Experiments illustrate the broad potential of unbiased inference of arbitrary posteriors under diffusion priors: in vision (classifier guidance), language (infilling under a discrete diffusion LLM), and multimodal data (text-to-image generation). Beyond generative modeling, we apply relative trajectory balance to the problem of continuous control with a score-based behavior prior, achieving state-of-the-art results on benchmarks in offline reinforcement learning.

6/3/2024

Amortized Posterior Sampling with Diffusion Prior Distillation

Abbas Mammadov, Hyungjin Chung, Jong Chul Ye

We propose a variational inference approach to sample from the posterior distribution for solving inverse problems. From a pre-trained diffusion model, our approach trains a conditional flow model to minimize the divergence between the proposal variational distribution and the posterior distribution implicitly defined through the diffusion model. Once trained, the flow model is capable of sampling from the posterior distribution with a single NFE, amortized with respect to the measurement. The proposed method paves a new path for distilling a diffusion prior for efficient posterior sampling. We show that our method is applicable to standard signals in Euclidean space, as well as signals on manifold.

7/26/2024

Diffusion Prior-Based Amortized Variational Inference for Noisy Inverse Problems

Sojin Lee, Dogyun Park, Inho Kong, Hyunwoo J. Kim

Recent studies on inverse problems have proposed posterior samplers that leverage the pre-trained diffusion models as powerful priors. These attempts have paved the way for using diffusion models in a wide range of inverse problems. However, the existing methods entail computationally demanding iterative sampling procedures and optimize a separate solution for each measurement, which leads to limited scalability and lack of generalization capability across unseen samples. To address these limitations, we propose a novel approach, Diffusion prior-based Amortized Variational Inference (DAVI) that solves inverse problems with a diffusion prior from an amortized variational inference perspective. Specifically, instead of separate measurement-wise optimization, our amortized inference learns a function that directly maps measurements to the implicit posterior distributions of corresponding clean data, enabling a single-step posterior sampling even for unseen measurements. Extensive experiments on image restoration tasks, e.g., Gaussian deblur, 4$\times$ super-resolution, and box inpainting with two benchmark datasets, demonstrate our approach's superior performance over strong baselines. Code is available at https://github.com/mlvlab/DAVI.

7/24/2024

Diffusion Posterior Sampling for General Noisy Inverse Problems

Hyungjin Chung, Jeongsol Kim, Michael T. Mccann, Marc L. Klasky, Jong Chul Ye

Diffusion models have been recently studied as powerful generative inverse problem solvers, owing to their high quality reconstructions and the ease of combining existing iterative solvers. However, most works focus on solving simple linear inverse problems in noiseless settings, which significantly under-represents the complexity of real-world problems. In this work, we extend diffusion solvers to efficiently handle general noisy (non)linear inverse problems via approximation of the posterior sampling. Interestingly, the resulting posterior sampling scheme is a blended version of diffusion sampling with the manifold constrained gradient without a strict measurement consistency projection step, yielding a more desirable generative path in noisy settings compared to the previous studies. Our method demonstrates that diffusion models can incorporate various measurement noise statistics such as Gaussian and Poisson, and also efficiently handle noisy nonlinear inverse problems such as Fourier phase retrieval and non-uniform deblurring. Code available at https://github.com/DPS2022/diffusion-posterior-sampling

5/21/2024