As machine learning practitioners, we spend a lot of time staring at loss curves. Over the years, after looking at thousands of them, we build intuition for what the shape of a curve tells us about the model: Does it need to be bigger? Should I increase the learning rate?
In other words, the way the model’s loss evolves during training on our dataset tells us something useful about the learning dynamics on that dataset.
This paper asks the question: does the way the loss evolves on each individual example also tell us something useful about the learning dynamics on that datapoint?
Production ML Papers to Know
This is a continuation of Production ML Papers to Know, a series from Gantry highlighting papers we think have been important to the evolving practice of production ML.
Why might we care about the learning dynamics on an individual datapoint?
As datasets grow larger, it’s becoming more important to understand which datapoints are most “interesting” or useful for downstream tasks like training. To do so, we can tag each example with metadata, such as whether it is:
- Out of distribution
Each of these metadata types could be useful for different tasks, like training on prototypical data or re-labeling noisy data. In prior work, each metadata type has been treated differently with its own heuristics.
In this work, the authors study metadata archaeology: the problem of finding a general way to infer metadata across categories.
How can you efficiently infer varied metadata, like noisiness and out-of-distribution status, for all of your data?
That brings us back to the idea of looking at loss curves on individual data points. The heuristic proposed by this paper is to assign metadata to a data point based on which category has the most similar loss curves.
More concretely, we’ll need a small dataset of examples from each metadata category, along with their loss curves. Then, to tag a new datapoint, we’ll use k-nearest neighbors (kNN) to see which metadata category its training dynamics resemble the most.
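To make the kNN step concrete, here is a minimal sketch in plain Python. It is not the authors' implementation: the function names, the Euclidean distance, the choice of k=3, and the toy loss values are all assumptions made for illustration. Each "curve" is just the list of losses a single example had at each epoch.

```python
import math
from collections import Counter


def curve_distance(a, b):
    """Euclidean distance between two equal-length loss curves (an assumed metric)."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def infer_metadata(curve, reference_curves, reference_labels, k=3):
    """Return the majority metadata label among the k nearest reference curves."""
    ranked = sorted(
        zip(reference_curves, reference_labels),
        key=lambda pair: curve_distance(curve, pair[0]),
    )
    votes = Counter(label for _, label in ranked[:k])
    return votes.most_common(1)[0][0]


# Toy reference set with made-up per-epoch losses (hypothetical values).
reference_curves = [
    [2.3, 1.0, 0.3, 0.1],  # loss drops quickly
    [2.2, 0.9, 0.4, 0.1],
    [2.4, 1.1, 0.2, 0.1],
    [2.5, 2.4, 2.2, 1.9],  # loss stays high
    [2.6, 2.5, 2.1, 2.0],
    [2.4, 2.3, 2.3, 1.8],
]
reference_labels = ["typical"] * 3 + ["noisy label"] * 3

# Tag an unlabeled example by the category whose curves it most resembles.
print(infer_metadata([2.5, 2.2, 2.0, 1.7], reference_curves, reference_labels))
```

In practice you would likely swap the hand-rolled search for a library kNN classifier, but the idea is the same: the features are the loss values over epochs, and the labels are the metadata categories.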
In short, do machine learning better by learning a model about the learning process.
How does it work?
So, is looking at per-datapoint loss curves actually useful?
To answer that question, the authors compared their more general method to pre-existing heuristics for some of the metadata categories.
On ImageNet, the qualitative results make sense. In particular, the “random output” examples correspond to images that appear mislabeled, and the “atypical” examples are strange but correctly labeled.
There are a couple of technical barriers preventing this technique from being a standard part of your toolkit, including keeping track of the per-datapoint losses at each epoch of training, and using the metadata information for downstream tasks without significantly changing your retraining pipelines.
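The bookkeeping for per-datapoint losses can be fairly light. The sketch below is one possible approach, not something taken from the paper: `LossCurveTracker` and the example ids are hypothetical names, and it assumes your training loop can produce an unreduced loss per example along with a stable id for each example (in PyTorch, for instance, loss modules such as `torch.nn.CrossEntropyLoss` accept `reduction="none"` for this).

```python
from collections import defaultdict


class LossCurveTracker:
    """Accumulates one loss value per example per epoch (hypothetical helper)."""

    def __init__(self):
        # example_id -> {epoch: loss}
        self._curves = defaultdict(dict)

    def record(self, epoch, example_ids, losses):
        """Store the losses from one batch, keyed by stable example ids."""
        for example_id, loss in zip(example_ids, losses):
            self._curves[example_id][epoch] = float(loss)

    def curve(self, example_id):
        """Return the example's losses ordered by epoch."""
        by_epoch = self._curves.get(example_id, {})
        return [by_epoch[epoch] for epoch in sorted(by_epoch)]


tracker = LossCurveTracker()

# Epoch 0, two batches (made-up loss values).
tracker.record(0, ["img_0", "img_1"], [2.3, 2.5])
tracker.record(0, ["img_2"], [2.4])

# Epoch 1: batches arrive in a different order because the data was shuffled.
tracker.record(1, ["img_1", "img_2"], [2.4, 1.1])
tracker.record(1, ["img_0"], [0.9])

print(tracker.curve("img_1"))
```

Keying by example id rather than by position in the batch matters because shuffling changes batch composition from epoch to epoch. For a large dataset you would probably want a preallocated array of shape (examples, epochs) instead of nested dictionaries.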
But if you’re interested in understanding your labeled data better, and willing to do some implementation work, this technique is promising!
Check out the paper here: https://arxiv.org/abs/2209.10015