On-Demand Sampling: Learning Optimally from Multiple Distributions

2210.12529

Published 4/4/2024 by Nika Haghtalab, Michael I. Jordan, Eric Zhao

Abstract

Social and real-world considerations such as robustness, fairness, social welfare and multi-agent tradeoffs have given rise to multi-distribution learning paradigms, such as collaborative learning, group distributionally robust optimization, and fair federated learning. In each of these settings, a learner seeks to uniformly minimize its expected loss over $n$ predefined data distributions, while using as few samples as possible. In this paper, we establish the optimal sample complexity of these learning paradigms and give algorithms that meet this sample complexity. Importantly, our sample complexity bounds for multi-distribution learning exceed that of learning a single distribution by only an additive factor of $n \log(n) / \epsilon^2$. This improves upon the best known sample complexity bounds for fair federated learning by Mohri et al. and collaborative learning by Nguyen and Zakynthinou by multiplicative factors of $n$ and $\log(n)/\epsilon^3$, respectively. We also provide the first sample complexity bounds for the group DRO objective of Sagawa et al. To guarantee these optimal sample complexity bounds, our algorithms learn to sample from data distributions on demand. Our algorithm design and analysis are enabled by our extensions of online learning techniques for solving stochastic zero-sum games. In particular, we contribute stochastic variants of no-regret dynamics that can trade off between players' differing sampling costs.

Overview

  • This paper explores learning algorithms that can perform well across multiple data distributions, rather than just optimizing for a single distribution.
  • The authors establish the optimal sample complexity for these multi-distribution learning paradigms and provide algorithms that achieve this optimal complexity.
  • The sample complexity of the new algorithms exceeds that of learning a single distribution by only an additive n log(n) / ε^2 term (n distributions, target accuracy ε), improving on the best previously known bounds.
  • The algorithms rely on techniques from online learning and stochastic zero-sum games to trade off between different distributions' sampling costs.

Plain English Explanation

Machine learning models are often trained on a single dataset, optimizing their performance on that specific data. However, in many real-world scenarios, we need models that can perform well across a variety of data distributions. For example, a lending algorithm should work fairly for applicants from different demographic groups, not just optimize for the majority group.

The learning paradigms explored in this paper, such as collaborative learning and fair federated learning, aim to create models that minimize their expected loss across multiple pre-defined data distributions. This is challenging because the model needs to generalize well while using as few training samples as possible from each distribution.

The key innovation of this work is establishing the optimal sample complexity, or the minimum number of training examples required, for these multi-distribution learning problems. The authors show that their new algorithms can achieve this optimal complexity, only slightly exceeding what would be needed to learn a single distribution. This represents a significant improvement over prior state-of-the-art approaches.

Importantly, the algorithms are designed to intelligently sample from the different data distributions, trading off their varying costs. This is enabled by techniques from online learning and game theory, allowing the model to efficiently learn across the multiple distributions.

Technical Explanation

The paper focuses on three multi-distribution learning paradigms: collaborative learning, group distributionally robust optimization (group DRO), and fair federated learning. In each case, the goal is to train a model that minimizes its expected loss across a set of n predefined data distributions, using as few samples as possible.
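One common way to write this shared goal is as a minimax guarantee. The notation here is an assumption introduced for illustration and does not come from the summary: $\mathcal{H}$ is the hypothesis class, $\ell$ a bounded loss, $D_1, \dots, D_n$ the predefined distributions, and $\hat{h}$ the learner's output (which may be a randomized hypothesis).

```latex
\[
  \max_{1 \le i \le n} \; \mathbb{E}_{z \sim D_i}\big[\ell(\hat{h}, z)\big]
  \;\le\;
  \min_{h \in \mathcal{H}} \; \max_{1 \le i \le n} \; \mathbb{E}_{z \sim D_i}\big[\ell(h, z)\big]
  \;+\; \varepsilon .
\]
```

Read this way, "uniformly minimizing" the loss means the output must be within ε of the best achievable worst-case loss, not merely good on average across the n distributions.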

The authors establish the optimal sample complexity for these problems and show that it exceeds the sample complexity of learning a single distribution only by an additive term of n log(n) / ε^2, where ε is the target accuracy. This improves on the best previously known bounds for fair federated learning (Mohri et al.) by a multiplicative factor of n, and for collaborative learning (Nguyen and Zakynthinou) by a multiplicative factor of log(n)/ε^3.
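To get a feel for what an additive rather than multiplicative dependence on n buys, here is a rough order-of-magnitude comparison. It rests on assumptions that are not in the summary: a finite hypothesis class $\mathcal{H}$, the standard $\log|\mathcal{H}|/\varepsilon^2$ rate for learning a single distribution, a baseline that obtains uniform convergence on each distribution separately, natural logarithms, and all constants and confidence terms dropped. The numbers are illustrative only.

```latex
\[
  \underbrace{\frac{n \,\log|\mathcal{H}|}{\varepsilon^{2}}}_{\text{separate estimation on each } D_i}
  \qquad \text{versus} \qquad
  \underbrace{\frac{\log|\mathcal{H}| + n \log n}{\varepsilon^{2}}}_{\text{single-distribution rate plus additive term}}
\]
% Illustrative values: n = 10, |H| = 2^{100}, epsilon = 0.1
%   log|H| = 100 ln 2 ~ 69.3,   n log n = 10 ln 10 ~ 23.0,   1/epsilon^2 = 100
\[
  \frac{10 \times 69.3}{0.01} \approx 6.9 \times 10^{4}
  \qquad \text{versus} \qquad
  \frac{69.3 + 23.0}{0.01} \approx 9.2 \times 10^{3} .
\]
```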

To achieve this optimal complexity, the authors develop algorithms that learn to intelligently sample from the different data distributions. This is enabled by extending online learning techniques for solving stochastic zero-sum games. Specifically, the authors contribute stochastic variants of no-regret dynamics that can balance the varying sampling costs across the distributions.
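The game-theoretic idea can be illustrated with a small sketch. This is not the authors' algorithm: it is a simplified stand-in in which both players run exponential-weights (Hedge) updates, and each player draws one fresh sample per round. The function name `on_demand_minimax`, the `samplers` interface, the step sizes, and the uniform importance-weighted estimate for the adversary are all assumptions made for illustration, and the loss is assumed to lie in [0, 1].

```python
import numpy as np


def _softmax(log_w):
    """Turn log-weights into a probability vector without overflow."""
    w = np.exp(log_w - np.max(log_w))
    return w / w.sum()


def on_demand_minimax(samplers, hypotheses, loss, rounds, seed=0):
    """Illustrative Hedge-vs-Hedge dynamics with on-demand sampling.

    samplers[i](rng) draws one data point from distribution i (hypothetical
    interface). Returns the learner's average mixture over `hypotheses`.
    """
    rng = np.random.default_rng(seed)
    n, m = len(samplers), len(hypotheses)
    # Step sizes are a conventional choice, not taken from the paper.
    eta_h = np.sqrt(np.log(max(m, 2)) / rounds)
    eta_d = np.sqrt(np.log(max(n, 2)) / (n * rounds))
    log_p = np.zeros(m)  # learner's log-weights over hypotheses
    log_q = np.zeros(n)  # adversary's log-weights over distributions
    avg_p = np.zeros(m)

    for _ in range(rounds):
        p, q = _softmax(log_p), _softmax(log_q)
        avg_p += p / rounds

        # Learner: one sample from a distribution drawn from q gives an
        # unbiased loss estimate for every hypothesis at once.
        i = rng.choice(n, p=q)
        z = samplers[i](rng)
        log_p -= eta_h * np.array([loss(h, z) for h in hypotheses])

        # Adversary: one sample from a uniformly chosen distribution j,
        # importance-weighted by n so the estimate of the loss of the
        # learner's current mixture on distribution j is unbiased.
        j = rng.integers(n)
        h = hypotheses[rng.choice(m, p=p)]
        log_q[j] += eta_d * n * loss(h, samplers[j](rng))

    return avg_p
```

The asymmetry the paper highlights is visible even in this toy version: a single sample lets the learner score every hypothesis, whereas the adversary learns about only one distribution per sample. As a usage example, with two point-mass distributions `samplers = [lambda rng: 0.0, lambda rng: 1.0]`, constant predictors `hypotheses = [0.0, 0.5, 1.0]`, and `loss = lambda h, z: abs(h - z)`, the returned mixture should place roughly equal weight on the predictors 0.0 and 1.0, since any such balanced mixture attains the best worst-case loss of 0.5 in this toy game.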

Critical Analysis

The paper provides a strong theoretical foundation for multi-distribution learning problems, establishing tight sample complexity bounds and designing algorithms to match this optimality. However, the analysis is largely focused on the sample efficiency aspect, without extensive empirical validation of the algorithms' practical performance.

While the theoretical results are compelling, it would be valuable to see experiments comparing the new algorithms to prior state-of-the-art approaches on real-world benchmarks. This could reveal additional insights into the tradeoffs and practical considerations of deploying these methods.

Additionally, the paper does not explore the sensitivity of the algorithms to factors like the specific choice of data distributions or the degree of distributional shift between them. Investigating these aspects could provide a more holistic understanding of the strengths and limitations of the proposed techniques.

Conclusion

This research advances the state of the art in multi-distribution learning, providing optimal sample complexity results and algorithms to achieve this efficiency. By developing techniques to intelligently sample from diverse data sources, the authors enable the creation of models that are more robust, fair, and socially beneficial than those trained on a single distribution.

These findings have important implications for real-world applications where fairness, robustness, and generalization across sub-populations are crucial, such as in healthcare, finance, and public policy. The theoretical insights and algorithmic innovations presented in this paper lay the groundwork for further advancements in this important area of machine learning.



This summary was produced with help from an AI and may contain inaccuracies - check out the links to read the original source documents!

Related Papers

Optimal Multi-Distribution Learning

Zihan Zhang, Wenhao Zhan, Yuxin Chen, Simon S. Du, Jason D. Lee

Multi-distribution learning (MDL), which seeks to learn a shared model that minimizes the worst-case risk across $k$ distinct data distributions, has emerged as a unified framework in response to the evolving demand for robustness, fairness, multi-group collaboration, etc. Achieving data-efficient MDL necessitates adaptive sampling, also called on-demand sampling, throughout the learning process. However, there exist substantial gaps between the state-of-the-art upper and lower bounds on the optimal sample complexity. Focusing on a hypothesis class of Vapnik-Chervonenkis (VC) dimension $d$, we propose a novel algorithm that yields an $\varepsilon$-optimal randomized hypothesis with a sample complexity on the order of $(d+k)/\varepsilon^2$ (modulo some logarithmic factor), matching the best-known lower bound. Our algorithmic ideas and theory are further extended to accommodate Rademacher classes. The proposed algorithms are oracle-efficient, which access the hypothesis class solely through an empirical risk minimization oracle. Additionally, we establish the necessity of randomization, revealing a large sample size barrier when only deterministic hypotheses are permitted. These findings resolve three open problems presented in COLT 2023 (i.e., Problems 1, 3 and 4 of Awasthi et al., 2023).

5/24/2024

Collaborative Learning with Different Labeling Functions

Yuyang Deng, Mingda Qiao

We study a variant of Collaborative PAC Learning, in which we aim to learn an accurate classifier for each of the $n$ data distributions, while minimizing the number of samples drawn from them in total. Unlike in the usual collaborative learning setup, it is not assumed that there exists a single classifier that is simultaneously accurate for all distributions. We show that, when the data distributions satisfy a weaker realizability assumption, which appeared in [Crammer and Mansour, 2012] in the context of multi-task learning, sample-efficient learning is still feasible. We give a learning algorithm based on Empirical Risk Minimization (ERM) on a natural augmentation of the hypothesis class, and the analysis relies on an upper bound on the VC dimension of this augmented class. In terms of the computational efficiency, we show that ERM on the augmented hypothesis class is NP-hard, which gives evidence against the existence of computationally efficient learners in general. On the positive side, for two special cases, we give learners that are both sample- and computationally-efficient.

5/24/2024

Taking a Moment for Distributional Robustness

Jabari Hastings, Christopher Jung, Charlotte Peale, Vasilis Syrgkanis

A rich line of recent work has studied distributionally robust learning approaches that seek to learn a hypothesis that performs well, in the worst-case, on many different distributions over a population. We argue that although the most common approaches seek to minimize the worst-case loss over distributions, a more reasonable goal is to minimize the worst-case distance to the true conditional expectation of labels given each covariate. Focusing on the minmax loss objective can dramatically fail to output a solution minimizing the distance to the true conditional expectation when certain distributions contain high levels of label noise. We introduce a new min-max objective based on what is known as the adversarial moment violation and show that minimizing this objective is equivalent to minimizing the worst-case $\ell_2$-distance to the true conditional expectation if we take the adversary's strategy space to be sufficiently rich. Previous work has suggested minimizing the maximum regret over the worst-case distribution as a way to circumvent issues arising from differential noise levels. We show that in the case of square loss, minimizing the worst-case regret is also equivalent to minimizing the worst-case $\ell_2$-distance to the true conditional expectation. Although their objective and our objective both minimize the worst-case distance to the true conditional expectation, we show that our approach provides large empirical savings in computational cost in terms of the number of groups, while providing the same noise-oblivious worst-distribution guarantee as the minimax regret approach, thus making positive progress on an open question posed by Agarwal and Zhang (2022).

5/10/2024

Robust Distribution Learning with Local and Global Adversarial Corruptions

Sloan Nietert, Ziv Goldfeld, Soroosh Shafiee

We consider learning in an adversarial environment, where an $\varepsilon$-fraction of samples from a distribution $P$ are arbitrarily modified (*global* corruptions) and the remaining perturbations have average magnitude bounded by $\rho$ (*local* corruptions). Given access to $n$ such corrupted samples, we seek a computationally efficient estimator $\hat{P}_n$ that minimizes the Wasserstein distance $\mathsf{W}_1(\hat{P}_n,P)$. In fact, we attack the fine-grained task of minimizing $\mathsf{W}_1(\Pi_\# \hat{P}_n, \Pi_\# P)$ for all orthogonal projections $\Pi \in \mathbb{R}^{d \times d}$, with performance scaling with $\mathrm{rank}(\Pi) = k$. This allows us to account simultaneously for mean estimation ($k=1$), distribution estimation ($k=d$), as well as the settings interpolating between these two extremes. We characterize the optimal population-limit risk for this task and then develop an efficient finite-sample algorithm with error bounded by $\sqrt{\varepsilon k} + \rho + d^{O(1)}\tilde{O}(n^{-1/k})$ when $P$ has bounded moments of order $2+\delta$, for constant $\delta > 0$. For data distributions with bounded covariance, our finite-sample bounds match the minimax population-level optimum for large sample sizes. Our efficient procedure relies on a novel trace norm approximation of an ideal yet intractable 2-Wasserstein projection estimator. We apply this algorithm to robust stochastic optimization, and, in the process, uncover a new method for overcoming the curse of dimensionality in Wasserstein distributionally robust optimization.

6/11/2024