Improving Generalization via Meta-Learning on Hard Samples

Read original: arXiv:2403.12236 - Published 4/1/2024 by Nishant Jain, Arun S. Suggala, Pradeep Shenoy

Introduction

The paper addresses overfitting in overparameterized supervised learning models. Typically, a separate validation set is used to measure generalization and select hyperparameters. More recently, validation data has been incorporated into the learning objective itself, through approaches such as learned reweighting (LRW). LRW learns importance weights for training instances or groups by optimizing a weighted training loss alongside an unweighted meta-loss on the validation data.

The authors hypothesize that optimizing the choice of validation data for LRW can further improve the generalization of the learned classifier. They formalize this as the problem of Meta-Optimization of the Learned Reweighting framework (MOLERE), in which the partitioning of data into training and validation sets and the LRW classifier corresponding to that split are jointly optimized.

The paper presents an efficient algorithm to solve this meta-optimization problem. Key contributions include:

  1. Proving that the optimization objective asymptotically maximizes accuracy on the hardest training samples.

  2. Simplifying the nested optimization into a tractable bi-level optimization.

  3. Empirically demonstrating the importance of optimizing the LRW validation set, and obtaining reliable gains over empirical risk minimization across datasets, domain generalization benchmarks, and noisy-label settings.

  4. Extending the approach to use naturally hard samples (such as ImageNet-R and ImageNet-A) as validation data at no additional cost.

The work establishes the value of meta-optimization of meta-learning techniques, providing a strong proof-of-concept for this research direction.

Related Work

Prior work has improved the robustness of machine learning models by importance-weighting training examples. The key points are:

  • Prior work has focused on learning per-instance weights, or on training an MLP that maps an example's loss value to its weight, in both cases guided by a clean validation set. However, the need for a clean validation set limits the applicability of these methods.

  • To address this, a meta-learning-based reweighting approach called Fast Sample Reweighting (FSR) was proposed, which generates pseudo-clean data to serve as a proxy validation set.

  • Another recent method, RHO-loss, selects only the most "worthy" training points by computing the difference between a point's training loss and its loss under a model trained on a held-out set.

  • Some methods have proposed weighting based on the context or relevance of an instance compared to the overall data distribution, targeting issues like domain shift and uncertainty.

  • The sample re-weighting task is closely related to meta-learning approaches like MAML, which involve optimizing model parameters to perform well on a meta-test set.

  • Recent work has also explored using probabilistic margins of neural networks to reweight training instances in the presence of adversarial attacks.

  • A "just train twice" approach was proposed, where the first stage trains a standard ERM model, and the second stage upweights the loss of incorrectly classified examples from the first stage.
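To make the RHO-loss selection rule above concrete, here is a minimal Python sketch. It assumes the per-example losses under the current model and under a model trained on held-out data are already available; the function name `rho_loss_select` and the toy numbers are hypothetical, and the original method applies this kind of selection repeatedly during training rather than once.

```python
def rho_loss_select(train_losses, holdout_model_losses, k):
    """Return indices of the k points with the largest reducible holdout loss.

    reducible loss = loss under the current model
                     - loss under a model trained on held-out data
    Noisy or inherently ambiguous points have high loss under both models,
    so their difference stays small and they are not selected.
    """
    scores = [t - h for t, h in zip(train_losses, holdout_model_losses)]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:k]


# Hypothetical losses: point 2 has the highest training loss, but the
# held-out model also finds it hard, so it is skipped.
selected = rho_loss_select([2.0, 0.5, 3.0, 1.0], [0.1, 0.4, 2.9, 0.2], k=2)
print(selected)  # [0, 3]
```

The contrast with plain loss-based selection is the point of the example: ranking by training loss alone would pick point 2 first.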

Preliminaries: Learned ReWeighting

A learned reweighting (LRW) classifier reweights the training data to optimize a specified metric on validation data. The key aspects are:

  • LRW works with two datasets: a training set S^tr and a validation set S^val.
  • The goal is to learn a classifier f_θ(·) and a reweighting function ϕ(·) that minimize a bi-level objective. This objective computes an estimate of classifier performance on the validation set and optimizes it by reweighting the training data.
  • The weighted-loss-minimizing classifier f_θ(·) and the weights ϕ(·) are learned jointly, typically using alternating stochastic updates.
  • Recent work uses a neural network ϕ(x) for the instance-dependent reweighting function, instead of free parameters.
  • LRW has been used to overcome training label noise and handle covariate shift, assuming the validation set is representative of test samples.
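One way to write the bi-level objective described above is shown below. The per-instance loss ℓ and the normalization by set size are notational assumptions rather than the paper's exact equation: the inner problem fits the classifier to the ϕ-weighted training loss, and the outer problem scores that classifier, unweighted, on the validation set.

```latex
\min_{\phi}\;\frac{1}{|S^{val}|}\sum_{(x,y)\in S^{val}} \ell\big(f_{\theta^*(\phi)}(x),\,y\big)
\quad\text{s.t.}\quad
\theta^*(\phi)=\arg\min_{\theta}\;\frac{1}{|S^{tr}|}\sum_{(x,y)\in S^{tr}} \phi(x)\,\ell\big(f_{\theta}(x),\,y\big)
```

In practice the inner arg min is not solved exactly; the alternating stochastic updates mentioned above approximate it.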

MOLERE: Optimizing LRW models

This section presents the paper's hypothesis and formal objective for improving the generalization of supervised learning models. The key ideas are:

  1. Combining a learned-reweighting classifier with an optimized validation set that encourages desired properties in the reweighting classifier. This is referred to as Meta-Optimization of the Learned Reweighting framework (MOLERE).

  2. The proposed approach involves learning an LRW classifier with "hard samples" as the validation set, to improve accuracy and generalization of the learned classifier.

  3. This leads to a joint optimization problem of data partitioning (train, validation) and LRW training.

  4. Theoretically, the paper shows that in the limit of infinite samples, MOLERE's objective is equivalent to a robust optimization problem, which is related to distributionally robust optimization (DRO).

  5. To efficiently solve the proposed tri-level optimization, the authors introduce a soft data assignment technique using a "splitter" network, and collapse the outer two loops into a min-max formulation.

  6. The overall algorithm, called LRWOpt, is described in detail.

  7. Additionally, a simple "train-twice" heuristic is introduced, which uses the probabilistic margin of an ERM classifier as a proxy for instance hardness, to design different LRW variants (LRW-Easy, LRW-Random, LRW-Hard) and quantify the impact of validation set optimization.
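The train-twice heuristic in item 7 can be sketched as follows. This assumes the softmax outputs of an already-trained ERM classifier are available for the pooled data; the function names are hypothetical, the probabilistic margin is taken to be the true-class probability minus the largest other-class probability, and the validation fraction δ = 0.1 follows the value reported in Appendix E.

```python
import random


def probabilistic_margin(probs, label):
    """True-class probability minus the largest probability among the other classes."""
    others = [p for k, p in enumerate(probs) if k != label]
    return probs[label] - max(others)


def train_twice_split(probs_list, labels, delta=0.1, variant="hard", seed=0):
    """Return (train_idx, val_idx) for the LRW-Hard / LRW-Easy / LRW-Random variants.

    probs_list: ERM softmax outputs for every pooled example (assumed given).
    delta: fraction of the pooled data placed in the validation set.
    """
    n = len(labels)
    n_val = max(1, round(delta * n))
    margins = [probabilistic_margin(p, y) for p, y in zip(probs_list, labels)]
    order = sorted(range(n), key=lambda i: margins[i])  # ascending: hardest first
    if variant == "hard":
        val_idx = order[:n_val]
    elif variant == "easy":
        val_idx = order[-n_val:]
    elif variant == "random":
        val_idx = random.Random(seed).sample(range(n), n_val)
    else:
        raise ValueError(f"unknown variant: {variant}")
    val_set = set(val_idx)
    train_idx = [i for i in range(n) if i not in val_set]
    return train_idx, sorted(val_idx)


# Toy usage: nine confidently correct examples and one misclassified one.
probs = [[0.9, 0.1]] * 9 + [[0.3, 0.7]]
labels = [0] * 10
_, hard_val = train_twice_split(probs, labels, delta=0.1, variant="hard")
print(hard_val)  # [9]
```

The second LRW training run would then use `train_idx` and `val_idx` as its training and validation sets; comparing the three variants isolates the effect of the validation-set choice.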

Experiments

The paper examines multiple classification tasks, including distribution shift benchmarks. For datasets with existing train-validation splits, the authors use the training data for the ERM classifier in the train-twice heuristic before pooling, ranking, and repartitioning. For end-to-end optimization, the authors start with pooled train-validation data and simultaneously learn the data splits and the corresponding LRWOpt model.

The datasets used include CIFAR-100, ImageNet-100, ImageNet-1K, Aircraft, Stanford Cars, Oxford-IIIT Pets (fine-grained cat and dog classification), and Diabetic Retinopathy. For out-of-distribution (OOD) analysis, the authors use ImageNet-A, ImageNet-R, Camelyon, iWildCam, and a country-shifted test set for Diabetic Retinopathy. They also analyze robustness to instance-dependent label noise using a noisy version of CIFAR-100 and the Clothing-1M dataset.

The paper compares the proposed method against various reweighting-based baselines, including learned reweighting methods like MWN, FSR, L2R, MAPLE, BiLAW, GDW, and StableNet, as well as Margin-Based Reweighting and Rho-Loss.

Results

This section of the paper analyzes the performance of MOLERE (Meta-Optimization of the Learned Reweighting framework), which improves the classification accuracy of deep neural networks. The key findings are:

  1. In-Distribution Generalization:

    • MOLERE variants (LRW-Easy, LRW-Random, LRW-Hard) show consistent gains over an ERM (Empirical Risk Minimization) baseline across multiple datasets.
    • The ordering of gains is LRW-Easy < LRW-Random < LRW-Hard, confirming the importance of optimizing the validation set.
    • The end-to-end optimization method LRWOpt matches or exceeds the gains of LRW-Hard.
  2. Out-of-Distribution Generalization:

    • MOLERE classifiers outperform existing re-weighting methods on out-of-distribution datasets, with LRWOpt being the best-performing method.
    • StableNet is the closest competitor to LRWOpt, as it is also designed for out-of-distribution robustness.
  3. Noisy Label Scenarios:

    • On datasets with instance-dependent noise, such as Clothing1M and noisy CIFAR-100, LRWOpt outperforms other bi-level optimization-based methods designed for noisy labels.
  4. Skewed Label Scenarios:

    • LRWOpt significantly outperforms existing instance-based re-weighting schemes on CIFAR-100 datasets with various label skew levels.
  5. Scalability to Large Pre-Trained Models:

    • MOLERE methods, especially LRW-Hard, improve the performance of a ViT-B/16 pre-trained backbone on ImageNet-1K, showing the scalability of the approach.
  6. Leveraging Out-of-Distribution Validation Sets:

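    • Using known hard or out-of-distribution datasets (ImageNet-R and ImageNet-A) as the validation set for LRW training on ImageNet improves performance on both clean and naturally hard test instances by 1-2%.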

Discussion & conclusion

The paper proposes optimizing the choice of validation data in a learned-reweighting setting. The researchers found that choosing "hard" samples as validation data in a learned-reweighting framework consistently outperforms empirical risk minimization (ERM) across a range of datasets and domain generalization benchmarks. This result supports the hypothesis that meta-optimization of the meta-learning workflow in learned reweighting is an important research direction with potential for substantial impact.

The specific heuristic of choosing low-margin points as validation data is a simple instance of this general approach, though it is not competitive under very high label noise. The paper suggests that more formal, optimization-driven approaches are needed to address this. The authors also identify elucidating the theoretical basis of the gains observed in the MOLERE framework as a direction for future work.

Supplementary Material

Appendix A Algorithmic Description

The paper describes Algorithm 1, which details the one-shot LRWOpt scheme. The scheme takes initialized Splitter parameters Θ, Meta-Network parameters ϕ, and classifier parameters θ, and outputs optimal classifier parameters θ*.

The algorithm also splits the total dataset D into train (S^tr) and validation (S^val) sets using a Splitter parameterized as a neural network f_Θ(x,y), where 0 < f_Θ(x,y) < 1 for any instance (x,y). Examples with f_Θ(x,y) > 0.5 are put into the validation set, and the rest go into the train set. The F_Θ(D) function applies this splitting to the overall dataset.

The paper notes that instead of running the nested loops of the bi-level setup at the epoch level, the authors run them at the batch level, which yields very similar results.
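A minimal sketch of the splitter's thresholding follows. The linear-logistic splitter, the feature vectors standing in for (x, y), and the soft weighting of the two losses by s and 1 - s are all assumptions made for illustration; the paper's splitter is a neural network whose architecture follows prior work.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function, so scores lie strictly in (0, 1)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def splitter_scores(features, weights, bias):
    """Hypothetical linear splitter: s(x, y) = sigmoid(w . features + b)."""
    return [sigmoid(sum(w * f for w, f in zip(weights, feat)) + bias) for feat in features]


def split_dataset(scores, threshold=0.5):
    """Examples scoring above the threshold go to validation; the rest to training."""
    val_idx = [i for i, s in enumerate(scores) if s > threshold]
    train_idx = [i for i, s in enumerate(scores) if s <= threshold]
    return train_idx, val_idx


def soft_split_losses(losses, scores):
    """Soft assignment (assumed form): each example contributes s_i to the
    validation loss and (1 - s_i) to the training loss, keeping both
    losses differentiable with respect to the splitter."""
    val_mass = sum(scores)
    train_mass = len(scores) - val_mass
    val_loss = sum(s * l for s, l in zip(scores, losses)) / val_mass
    train_loss = sum((1.0 - s) * l for s, l in zip(scores, losses)) / train_mass
    return train_loss, val_loss


scores = splitter_scores([[2.0], [-2.0], [0.0]], weights=[1.0], bias=0.0)
print(split_dataset(scores))  # ([1, 2], [0])
```

The hard split mirrors the "> 0.5 goes to validation" rule above; the soft version is one common way to let gradients reach the splitter, not necessarily the paper's.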

Appendix B Proof of Theorem 1

The paper presents a theorem on the asymptotics of a tri-level optimization problem. The key points are:

  1. The objective of the MOLERE model is asymptotically equivalent to minimizing, over the classifier parameters, the maximum average empirical loss over any subset S' of the dataset S whose size is a fraction δ of the total dataset size N+M.

  2. The proof shows that this objective is equivalent to two other optimization problems:

    • Minimizing the expected loss on the validation set Q^val under the optimal model parameter θ*.
    • Minimizing the weighted expected loss on the training set Q^tr, where the weighting function φ(x,y) is chosen appropriately.
  3. The proof considers two cases:

    • When the supports of the training and validation distributions are the same. Here, the weighting function φ is chosen as the ratio of the validation and training distributions.
    • When the supports differ. Here, φ is chosen to perform probability matching on the intersection of the supports, and set to 1 elsewhere.
  4. The equivalence of the optimization problems is shown by constructing the appropriate weighting function φ in each case.
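The quantity in point 1 and the reweighting identity in point 3 can both be checked numerically. The sketch below assumes finite lists of losses and discrete distributions given as dictionaries; it illustrates the statements, not the paper's proof. The worst-case δ-fraction average needs only a sort, because the maximizing subset always consists of the highest-loss examples.

```python
from itertools import combinations


def worst_subset_loss(losses, delta):
    """Largest average loss over any subset holding a delta fraction of the examples."""
    k = max(1, round(delta * len(losses)))
    hardest = sorted(losses, reverse=True)[:k]
    return sum(hardest) / k


def worst_subset_loss_bruteforce(losses, delta):
    """Exhaustive search over all subsets of size k; only feasible for tiny inputs."""
    k = max(1, round(delta * len(losses)))
    return max(sum(c) / k for c in combinations(losses, k))


def importance_weighted_loss(q_tr, q_val, loss):
    """Equal-support case: with phi = q_val / q_tr, the phi-weighted expected
    training loss equals the expected validation loss."""
    phi = {z: q_val[z] / q_tr[z] for z in q_tr}
    weighted_train = sum(q_tr[z] * phi[z] * loss[z] for z in q_tr)
    plain_val = sum(q_val[z] * loss[z] for z in q_val)
    return weighted_train, plain_val


print(worst_subset_loss([0.1, 0.5, 2.0, 0.3, 1.0], delta=0.4))  # 1.5
```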

Appendix C Deriving the Update Equations

The paper discusses the update equations for the neural networks in the proposed method, which is formulated as a bi-level optimization task.

The update equation for the Splitter Network (Θ) is provided in Equation 11. The splitter is updated at the outer-loop level using the validation loss; consistent with the min-max formulation, it acts as the maximizing player, favoring validation sets on which the classifier performs poorly.

The update equation for the Meta-Network (ϕ) is provided in Equations 12 and 13. It is updated alongside the Splitter objective, with the loss term corresponding to minimizing the error on the validation set.

The update equation for the Classifier Network (θ) is provided in Equations 14 and 15. It is updated through the weighted training loss on the split training set, with the weights approximated using the Meta-Network.

The paper notes that the Classifier and Meta-Network update equations are similar to existing instance-based re-weighting works, with the validation set bi-level setup approximated as a single-level optimization.

The paper also discusses early-stage behavior and convergence: initially the method performs on par with the ERM baseline, and its learning curves gradually separate from ERM's as training proceeds. Previous work is cited for convergence guarantees for bi-level and min-max objectives optimized with alternating updates.
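The one-step lookahead that makes such bi-level updates tractable can be illustrated on a toy problem. This sketch swaps the paper's meta-network for free per-example weights computed as in L2R (one of the compared baselines), uses a scalar linear model with squared loss, and omits the splitter; all names and data are hypothetical.

```python
def grad_sq_loss(theta, x, y):
    """Derivative with respect to theta of the squared loss (theta * x - y) ** 2."""
    return 2.0 * (theta * x - y) * x


def lookahead_weights(theta, train, val):
    """L2R-style weights from a one-step lookahead evaluated at eps = 0.

    Perturbing the update as theta' = theta - alpha * sum_i eps_i * g_i gives
    dL_val/deps_i = -alpha * g_val * g_i, so w_i is proportional to
    max(g_val * g_i, 0): examples whose gradient agrees with the validation
    gradient are upweighted, the rest get zero weight.
    """
    g_val = sum(grad_sq_loss(theta, x, y) for x, y in val) / len(val)
    raw = [max(g_val * grad_sq_loss(theta, x, y), 0.0) for x, y in train]
    total = sum(raw)
    if total == 0.0:
        return [0.0] * len(train)  # nothing aligned: skip this update, as L2R does
    return [r / total for r in raw]


def train_reweighted(train, val, lr=0.05, steps=50):
    """Alternate weight estimation with a weighted gradient step on the classifier."""
    theta = 0.0
    for _ in range(steps):
        w = lookahead_weights(theta, train, val)
        grad = sum(wi * grad_sq_loss(theta, x, y) for wi, (x, y) in zip(w, train))
        theta -= lr * grad
    return theta


clean = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # follow y = 2x
noisy = [(1.0, -2.0)]                          # corrupted label
val = [(0.5, 1.0), (1.5, 3.0)]                 # small clean validation set
print(round(train_reweighted(clean + noisy, val), 6))  # 2.0
```

Because the noisy example's gradient points against the validation gradient, its weight stays at zero and θ converges to the clean slope of 2, whereas unweighted ERM on the same four points would settle at 26/15 ≈ 1.73.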

Appendix D Experimental Details

The paper summarizes the training and evaluation details for the proposed methods:

Architectures: Various neural network architectures were used for the classifier, including WRN28-10, VGG-16, ResNet-152, ResNet-32, and ResNet-50, with dropout regularization. A pretrained backbone was used as the base of the meta-network, with a fully connected layer for predicting instance weights. The splitter architecture followed a prior work.

Training: The training used a batch size of 64 and an image size of 224x224 for most datasets, except CIFAR-100, which used 32x32. The classifier used an initial learning rate of 0.1, decayed by a factor of 10 every 50 epochs. The meta-network and splitter used a fixed learning rate of 1e-3. Momentum of 0.9 was used for all components. The main classifier was warm-started for 25 epochs before the meta-network and splitter began updating.
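The reported settings can be collected into a small configuration sketch. The field and function names are hypothetical, epochs are assumed to be counted from 0, and δ = 0.1 is taken from Appendix E; this summarizes the stated hyperparameters rather than reproducing the authors' code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    batch_size: int = 64
    image_size: int = 224          # 32 for CIFAR-100
    classifier_lr: float = 0.1     # initial classifier learning rate
    lr_decay_factor: float = 0.1   # i.e. divide by 10 ...
    lr_decay_every: int = 50       # ... every 50 epochs
    meta_lr: float = 1e-3          # meta-network, kept fixed
    splitter_lr: float = 1e-3      # splitter, kept fixed
    momentum: float = 0.9          # shared by all components
    warmup_epochs: int = 25        # classifier-only warm start
    val_fraction: float = 0.1      # delta, from Appendix E


def classifier_lr_at(cfg: TrainConfig, epoch: int) -> float:
    """Step decay: multiply by lr_decay_factor once per completed decay period."""
    return cfg.classifier_lr * cfg.lr_decay_factor ** (epoch // cfg.lr_decay_every)


def meta_updates_enabled(cfg: TrainConfig, epoch: int) -> bool:
    """The meta-network and splitter stay frozen during the warm start."""
    return epoch >= cfg.warmup_epochs


cifar_cfg = TrainConfig(image_size=32)
```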

Regularization: Two regularizers based on KL divergence were used to maintain the train-validation ratio and balance labels across splits.
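The exact form of the two KL regularizers is not given here, so the following is only one plausible reading: a Bernoulli KL that pulls the average splitter score toward the target validation fraction δ, and a KL between the soft label histogram of the validation split and that of the full dataset. The function names and the smoothing constant are hypothetical.

```python
import math


def bernoulli_kl(p, q, eps=1e-12):
    """KL(Bernoulli(p) || Bernoulli(q)), with probabilities clipped away from 0 and 1."""
    p = min(max(p, eps), 1.0 - eps)
    q = min(max(q, eps), 1.0 - eps)
    return p * math.log(p / q) + (1.0 - p) * math.log((1.0 - p) / (1.0 - q))


def ratio_regularizer(scores, delta=0.1):
    """Penalize an expected validation fraction (mean splitter score) far from delta."""
    return bernoulli_kl(sum(scores) / len(scores), delta)


def label_balance_regularizer(scores, labels, num_classes, eps=1e-12):
    """KL(label histogram of the soft validation split || label histogram of all data)."""
    val_mass = [eps] * num_classes
    all_mass = [eps] * num_classes
    for s, y in zip(scores, labels):
        val_mass[y] += s      # soft count: each example contributes its splitter score
        all_mass[y] += 1.0
    val_total, all_total = sum(val_mass), sum(all_mass)
    return sum(
        (v / val_total) * math.log((v / val_total) / (a / all_total))
        for v, a in zip(val_mass, all_mass)
    )
```

Both terms are zero exactly when the split already has the target size and the same label proportions as the full data, and grow as the splitter drifts from either.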

Datasets: The paper evaluates on several popular classification benchmarks including CIFAR-100, ImageNet-100, ImageNet-1K, and fine-grained datasets like Aircraft, Stanford Cars, Oxford-IIIT Pets, as well as out-of-distribution datasets like ImageNet-A, ImageNet-R, Camelyon, iWildCam, and Diabetic Retinopathy.

The paper also describes the baselines used for comparison, including standard ERM, margin-based reweighting, and several meta-learning based reweighting methods.

Appendix E Time Complexity, Compute, Tuning:

This appendix reports training time, FLOPS, and hyperparameter considerations for the compared methods:

Training time: Compared to the baseline ERM (Empirical Risk Minimization) cost, the runtimes for the different optimization methods are as follows: LRWOpt (1.6x), LRW-hard (2.4x), MWN (1.4x), and L2R (1.4x). LRWOpt is slightly more expensive than MWN but achieves noticeably higher accuracy, and is substantially cheaper than the train-twice heuristic (LRW-hard) while meeting or exceeding its accuracy.

FLOPS: LRWOpt requires roughly 1.7x and LRW-hard roughly 2.3x the FLOPS of ERM.

Hyperparameters: Compared to LRW, LRWOpt has one additional tunable hyperparameter for the splitter's learning rate. LRW itself requires a meta-network learning rate and Q. Sensitivity analysis suggests any moderate value of Q is sufficient, and the authors set Q=5 across datasets. The parameter δ is fixed at 0.1, which the authors believe is a reasonable general tradeoff between training and validation sizes. The authors also found that the ERM hyperparameters are sufficient for LRW, and the meta-network and splitter learning rates can be tied to the classifier learning rates without much degradation.

Appendix F Comparison with only hard examples in validation set

The paper analyzes three new variants of the methods used, all of which involve using only hard examples in the validation set:

  1. The first method incorporates the loss highlighted in [22] for the validation set, along with the LRW-Hard method.

  2. The second variant modifies LRWOpt by decreasing the splitter's train-set threshold to 0.2, so that only the hardest examples are included in the validation set.

  3. The third method limits the validation set in the LRW-Hard approach to only negative margin (incorrectly classified) examples from the Empirical Risk Minimization (ERM) method.

Table 7 in the paper shows the accuracy gains (%) of LRWOpt over these three variants on four randomly chosen datasets, including one out-of-distribution (OOD) dataset.

Appendix G Accuracy Comparison

The paper presents the raw accuracy values of all methods reported in the study: Table 5 corresponds to Figure 1 and Table 6 to Figure 2 of the main paper. The proposed method shows gains even where accuracy is already high, as on Oxford-IIIT Pets, and on relatively difficult datasets where models tend to perform worse, such as ImageNet-1K.

Appendix H Analysis of Predicted Margins

The paper presents histograms of the difference between the margins predicted by the LRW-Hard and ERM classifiers on additional datasets, including Aircraft, Stanford Cars, and ImageNet-100. As in the main results, more examples fall on the positive side, meaning LRW-Hard assigns larger margins than ERM to most examples.

The paper also shows an experiment that groups the data points based on the ERM margin values and reports the mean and standard deviation of the margin difference between the LRW-Hard, LRW-Easy, and ERM classifiers for the Aircraft, ImageNet-100, and ImageNet-1K datasets. The results are similar to the findings for the CIFAR-100 and Clothing datasets in the main paper. The mean margin difference is positive and significant, especially for the positive margin examples, demonstrating the effectiveness of the LRW-Hard approach compared to ERM. The significant difference between the LRW-Hard and LRW-Easy plots further supports the importance of the validation set and its effectiveness in margin maximization.
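The grouped analysis can be reproduced with a few lines of Python once per-example margins are available. The sketch assumes margins have already been computed for both classifiers, bins examples by their ERM margin, and reports the mean and population standard deviation of the margin difference per bin; the bin edges and toy values are hypothetical.

```python
import statistics


def margin_difference_by_bin(erm_margins, other_margins, edges):
    """Group examples by ERM margin and summarize (other - ERM) margin per group.

    Bins are [edges[k], edges[k+1]), except the last, which also includes its
    right edge so that a margin of exactly edges[-1] is not dropped.
    Returns {(lo, hi): (mean, population std)} for non-empty bins.
    """
    summary = {}
    last = len(edges) - 2
    for k in range(len(edges) - 1):
        lo, hi = edges[k], edges[k + 1]
        diffs = [
            o - e
            for e, o in zip(erm_margins, other_margins)
            if lo <= e < hi or (k == last and e == hi)
        ]
        if diffs:
            summary[(lo, hi)] = (statistics.mean(diffs), statistics.pstdev(diffs))
    return summary


# Toy margins in [-1, 1]: negative-margin bin vs. positive-margin bin.
result = margin_difference_by_bin(
    [-0.5, -0.2, 0.3, 0.8], [-0.1, 0.0, 0.5, 0.9], edges=[-1.0, 0.0, 1.0]
)
```

A positive mean in a bin indicates that the reweighted classifier widens margins for examples that ERM placed in that range.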



This summary was produced with help from an AI and may contain inaccuracies - check out the links to read the original source documents!


Related Papers

Improving Generalization via Meta-Learning on Hard Samples

Nishant Jain, Arun S. Suggala, Pradeep Shenoy

Learned reweighting (LRW) approaches to supervised learning use an optimization criterion to assign weights for training instances, in order to maximize performance on a representative validation dataset. We pose and formalize the problem of optimized selection of the validation set used in LRW training, to improve classifier generalization. In particular, we show that using hard-to-classify instances in the validation set has both a theoretical connection to, and strong empirical evidence of generalization. We provide an efficient algorithm for training this meta-optimized model, as well as a simple train-twice heuristic for careful comparative study. We demonstrate that LRW with easy validation data performs consistently worse than LRW with hard validation data, establishing the validity of our meta-optimization problem. Our proposed algorithm outperforms a wide range of baselines on a range of datasets and domain shift challenges (Imagenet-1K, CIFAR-100, Clothing-1M, CAMELYON, WILDS, etc.), with ~1% gains using VIT-B on Imagenet. We also show that using naturally hard examples for validation (Imagenet-R / Imagenet-A) in LRW training for Imagenet improves performance on both clean and naturally hard test instances by 1-2%. Secondary analyses show that using hard validation data in an LRW framework improves margins on test data, hinting at the mechanism underlying our empirical gains. We believe this work opens up new research directions for the meta-optimization of meta-learning in a supervised learning context.

4/1/2024

Reimplementation of Learning to Reweight Examples for Robust Deep Learning

Parth Patil, Ben Boardley, Jack Gardner, Emily Loiselle, Deerajkumar Parthipan

Deep neural networks (DNNs) have been used to create models for many complex analysis problems like image recognition and medical diagnosis. DNNs are a popular tool within machine learning due to their ability to model complex patterns and distributions. However, the performance of these networks is highly dependent on the quality of the data used to train the models. Two characteristics of these sets, noisy labels and training set biases, are known to frequently cause poor generalization performance as a result of overfitting to the training set. This paper aims to solve this problem using the approach proposed by Ren et al. (2018) using meta-training and online weight approximation. We will first implement a toy-problem to crudely verify the claims made by the authors of Ren et al. (2018) and then venture into using the approach to solve a real world problem of Skin-cancer detection using an imbalanced image dataset.

5/14/2024

Multiplicative Reweighting for Robust Neural Network Optimization

Noga Bar, Tomer Koren, Raja Giryes

Neural networks are widespread due to their powerful performance. However, they degrade in the presence of noisy labels at training time. Inspired by the setting of learning with expert advice, where multiplicative weight (MW) updates were recently shown to be robust to moderate data corruptions in expert advice, we propose to use MW for reweighting examples during neural networks optimization. We theoretically establish the convergence of our method when used with gradient descent and prove its advantages in 1d cases. We then validate our findings empirically for the general case by showing that MW improves the accuracy of neural networks in the presence of label noise on CIFAR-10, CIFAR-100 and Clothing1M. We also show the impact of our approach on adversarial robustness.

5/28/2024

Making Robust Generalizers Less Rigid with Soft Ascent-Descent

Matthew J. Holland, Toma Hamada

While the traditional formulation of machine learning tasks is in terms of performance on average, in practice we are often interested in how well a trained model performs on rare or difficult data points at test time. To achieve more robust and balanced generalization, methods applying sharpness-aware minimization to a subset of worst-case examples have proven successful for image classification tasks, but only using deep neural networks in a scenario where the most difficult points are also the least common. In this work, we show how such a strategy can dramatically break down under more diverse models, and as a more robust alternative, instead of typical sharpness we propose and evaluate a training criterion which penalizes poor loss concentration, which can be easily combined with loss transformations such as CVaR or DRO that control tail emphasis.

8/9/2024