Model Selection
Learning Objectives
After this unit, students should be able to:
- compare and contrast training error and testing error.
- explain the rationale for choosing an appropriate model.
- describe the concepts of bias and variance of a model.
- comprehend the synergy between statistics and machine learning.
The following figure presents a scatter plot of a sample generated using the true function. We use different hypotheses to estimate this true function. The plot on the left employs a degree-one polynomial (a linear hypothesis), the middle plot uses a degree-four polynomial, and the right plot uses a degree-fifteen polynomial. We notice that the mean squared error on the sample decreases as the polynomial degree (that is, the complexity of the model) increases; however, the estimated function increasingly deviates from the true function. How should we choose the model in this case?
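The comparison above can be reproduced numerically. The sketch below is a minimal illustration, not the code behind the figure: it assumes a hypothetical true function \(\cos(1.5\pi x)\) on \([0, 1]\), 30 noisy sample points with noise standard deviation 0.1, and uses NumPy's `Polynomial.fit` for least-squares fitting. It reports the mean squared error on the sample and, on a dense grid, the error against the true function.

```python
import numpy as np
from numpy.polynomial import Polynomial


def true_function(x):
    # Hypothetical true function, assumed only for this illustration.
    return np.cos(1.5 * np.pi * x)


rng = np.random.default_rng(0)
n_samples = 30
x = np.sort(rng.uniform(0.0, 1.0, n_samples))
y = true_function(x) + rng.normal(scale=0.1, size=n_samples)  # noisy sample

x_grid = np.linspace(0.0, 1.0, 500)  # dense grid to compare with the true function

train_mse = {}
true_mse = {}
for degree in (1, 4, 15):
    # Least-squares polynomial fit; Polynomial.fit rescales x for numerical stability.
    model = Polynomial.fit(x, y, deg=degree)
    train_mse[degree] = np.mean((model(x) - y) ** 2)
    true_mse[degree] = np.mean((model(x_grid) - true_function(x_grid)) ** 2)
    print(f"degree {degree:2d}: sample MSE = {train_mse[degree]:.4f}, "
          f"MSE vs true function = {true_mse[degree]:.4f}")
```

Because each lower-degree polynomial is a special case of a higher-degree one, the error on the sample can only decrease (or stay the same) as the degree grows. The error against the true function has no such guarantee, which is why the sample error alone cannot guide the choice of model.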
Overfitting versus Underfitting
The goal of machine learning models is to uncover the hidden patterns within the data that generate the observed outcomes. They achieve this by estimating, from a given data sample, the parameters that represent these patterns, usually by optimising a statistical objective. To ensure that the trained model performs well on unseen data, the available dataset is typically divided into two parts: the training dataset and the testing dataset. The training dataset is used to estimate the parameters, while the testing dataset, which is held out during training, is used to evaluate the model's ability to generalise to new data. The following conceptual diagram plots the error calculated on both the training and testing datasets against the complexity of the model.
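A random hold-out split like the one described above can be sketched in a few lines of NumPy. The function name `holdout_split` and the default 25% testing fraction are hypothetical choices for illustration; scikit-learn offers a ready-made `sklearn.model_selection.train_test_split` for the same purpose.

```python
import numpy as np


def holdout_split(x, y, test_fraction=0.25, seed=0):
    """Randomly partition (x, y) into training and testing sets.

    Hypothetical helper for illustration; not a library API.
    """
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(x))  # shuffle so the split is random
    n_test = int(round(test_fraction * len(x)))
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return x[train_idx], y[train_idx], x[test_idx], y[test_idx]


x = np.arange(20, dtype=float)
y = 2.0 * x + 1.0
x_train, y_train, x_test, y_test = holdout_split(x, y)
print(len(x_train), len(x_test))  # 15 5
```

Shuffling before splitting matters when the data are stored in some order (for example, sorted by time or by label); otherwise the testing set could be systematically different from the training set.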
A model with very high training error is said to underfit, while a model whose training error is very low but whose testing error is much higher is said to overfit. As the complexity of the model increases, it tends to overfit the training data. Initially, both the training error and the testing error decrease; however, beyond a certain point, the training error continues to decrease while the testing error begins to increase. Such overfit models typically fail to generalise to unseen data. We can therefore empirically identify the optimal model from a range of candidate models by comparing their performance on the training and testing datasets and choosing the one with the lowest testing error.
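The empirical selection procedure described above can be sketched as a sweep over model complexity. This is a minimal sketch, assuming hypothetical data drawn from \(\cos(1.5\pi x)\) plus Gaussian noise, polynomial degree as the measure of complexity, and a single random hold-out set of 20 points. It records training and testing errors for degrees 1 to 15 and keeps the degree with the lowest testing error.

```python
import numpy as np
from numpy.polynomial import Polynomial

rng = np.random.default_rng(1)
x = rng.uniform(0.0, 1.0, 60)
# Hypothetical data: assumed true function cos(1.5*pi*x) plus noise.
y = np.cos(1.5 * np.pi * x) + rng.normal(scale=0.1, size=x.size)

# Hold out one third of the data as the testing set.
perm = rng.permutation(x.size)
test_idx, train_idx = perm[:20], perm[20:]

train_mse, test_mse = {}, {}
for degree in range(1, 16):
    model = Polynomial.fit(x[train_idx], y[train_idx], deg=degree)
    train_mse[degree] = np.mean((model(x[train_idx]) - y[train_idx]) ** 2)
    test_mse[degree] = np.mean((model(x[test_idx]) - y[test_idx]) ** 2)

# Choose the complexity with the lowest testing error.
best_degree = min(test_mse, key=test_mse.get)
print("selected degree:", best_degree)
```

Plotting `train_mse` and `test_mse` against the degree would typically reproduce the shape of the conceptual diagram: the training error keeps falling while the testing error eventually rises.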
K-fold Cross Validation
Relying on a single split of the dataset into a training dataset and a testing dataset can be unreliable: if the training dataset happens to be a skewed subsample, it can introduce sampling bias into the estimation. To mitigate this, \(K\)-fold cross-validation is used. The original dataset is divided into \(K\) (approximately) equal subsets, called folds. A model is trained on \(K-1\) folds, and the remaining fold is used as the testing dataset. This process is repeated \(K\) times so that each fold serves as the testing dataset exactly once, and the \(K\) testing errors are averaged. The candidate model (for example, a particular polynomial degree) with the lowest average error is selected and is typically retrained on the full dataset. However, because every candidate must be trained \(K\) times, this approach can be impractical for very large datasets.
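The \(K\)-fold procedure can be sketched as follows. This is a minimal illustration, not a library API: the helper names `k_fold_indices` and `cross_val_mse`, the choice \(K = 5\), the hypothetical \(\cos(1.5\pi x)\) data, and polynomial degree as the family of candidate models are all assumptions. The same folds are reused for every degree so that candidates are compared on equal footing, and the selected degree is refit on the full dataset.

```python
import numpy as np
from numpy.polynomial import Polynomial


def k_fold_indices(n, k, seed=0):
    """Shuffle 0..n-1 and split into k (nearly) equal folds. Hypothetical helper."""
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(n), k)


def cross_val_mse(x, y, degree, k=5):
    """Average testing MSE of a degree-`degree` polynomial over k folds."""
    folds = k_fold_indices(len(x), k)  # fixed seed: same folds for every degree
    errors = []
    for i in range(k):
        test_idx = folds[i]
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        model = Polynomial.fit(x[train_idx], y[train_idx], deg=degree)
        errors.append(np.mean((model(x[test_idx]) - y[test_idx]) ** 2))
    return np.mean(errors)


rng = np.random.default_rng(2)
x = rng.uniform(0.0, 1.0, 50)
# Hypothetical data: assumed true function cos(1.5*pi*x) plus noise.
y = np.cos(1.5 * np.pi * x) + rng.normal(scale=0.1, size=x.size)

cv_mse = {d: cross_val_mse(x, y, d) for d in range(1, 11)}
best_degree = min(cv_mse, key=cv_mse.get)
final_model = Polynomial.fit(x, y, deg=best_degree)  # refit on all data
print("selected degree:", best_degree)
```

For real projects, scikit-learn's `KFold` and `cross_val_score` in `sklearn.model_selection` implement the same idea.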
Bias-Variance Tradeoff
Let's revisit machine learning through the lens of statistical analysis. From the frequentist point of view, the observed dataset is considered a random sample drawn from an unknown data-generating distribution. If this distribution were known, we could identify the true parameter \(\theta^*\) for the entire population. Since it is not known, we instead compute an estimate \(\hat{\theta}\) by maximising the likelihood or the posterior on the available dataset. Because \(\hat{\theta}\) is a function of a randomly drawn dataset, different datasets would yield different estimates. Thus, we can treat \(\hat{\theta}\) as a random variable.
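The idea that \(\hat{\theta}\) is a random variable can be checked by simulation. The sketch below assumes a hypothetical data-generating distribution \(\mathcal{N}(\theta^*, 1)\) with \(\theta^* = 2\), draws 10,000 datasets of 20 points each, and computes the maximum likelihood estimate of the mean (the sample mean) for every dataset. The estimates differ from dataset to dataset, and their average and spread give Monte Carlo approximations of the estimator's behaviour in expectation.

```python
import numpy as np

theta_star = 2.0  # hypothetical true mean of the data-generating distribution
n, n_datasets = 20, 10_000
rng = np.random.default_rng(0)

# Each row is one dataset of size n drawn from N(theta_star, 1).
datasets = rng.normal(loc=theta_star, scale=1.0, size=(n_datasets, n))
theta_hat = datasets.mean(axis=1)  # MLE of the mean, one estimate per dataset

print("first five estimates:", np.round(theta_hat[:5], 3))
print("average estimate minus theta_star:", theta_hat.mean() - theta_star)  # should be near 0
print("variance of the estimates:", theta_hat.var())  # should be near 1/n = 0.05
```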
Bias refers to how far the estimate \(\hat{\theta}\) is, in expectation, from the true parameter \(\theta^*\). Variance refers to the variance of the estimate itself, that is, how much \(\hat{\theta}\) fluctuates from one dataset to another.
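Written out for a scalar parameter, with expectations taken over datasets drawn from the data-generating distribution, these two quantities combine into the expected squared error of the estimator (a standard identity, added here for illustration):

\[
\operatorname{Bias}(\hat{\theta}) = \mathbb{E}[\hat{\theta}] - \theta^*, \qquad
\operatorname{Var}(\hat{\theta}) = \mathbb{E}\big[(\hat{\theta} - \mathbb{E}[\hat{\theta}])^2\big]
\]

\[
\begin{aligned}
\mathbb{E}\big[(\hat{\theta} - \theta^*)^2\big]
&= \mathbb{E}\Big[\big((\hat{\theta} - \mathbb{E}[\hat{\theta}]) + (\mathbb{E}[\hat{\theta}] - \theta^*)\big)^2\Big] \\
&= \mathbb{E}\big[(\hat{\theta} - \mathbb{E}[\hat{\theta}])^2\big]
 + 2\,(\mathbb{E}[\hat{\theta}] - \theta^*)\,\mathbb{E}\big[\hat{\theta} - \mathbb{E}[\hat{\theta}]\big]
 + (\mathbb{E}[\hat{\theta}] - \theta^*)^2 \\
&= \operatorname{Var}(\hat{\theta}) + \operatorname{Bias}(\hat{\theta})^2 .
\end{aligned}
\]

The cross term vanishes because \(\mathbb{E}[\hat{\theta}] - \theta^*\) is a constant that can be taken outside the expectation, and \(\mathbb{E}\big[\hat{\theta} - \mathbb{E}[\hat{\theta}]\big] = 0\). A low expected error therefore requires both a small bias and a small variance.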
If the model has high bias, the hypothesis is too simple to capture the pattern in the observed data and fits even the training data poorly; this is the case of underfitting. If the model has high variance, the hypothesis is so flexible that it also captures the noise in the observed data, so its estimate changes considerably from one training sample to another; this is the case of overfitting. Typically, approaches used to reduce the bias tend to increase the variance, and approaches used to reduce the variance tend to increase the bias. This is known as the bias-variance tradeoff in machine learning¹. In later units we will learn techniques such as regularisation and ensemble learning that help us manage this tradeoff.
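The tradeoff can be made concrete by repeating the polynomial experiment over many training samples. In this sketch, which is an illustration under stated assumptions rather than a standard procedure, the inputs are 30 fixed points on \([0, 1]\), the true function is the hypothetical \(\cos(1.5\pi x)\), only the noise (standard deviation 0.1) is redrawn for each of 500 datasets, and bias and variance are measured for the fitted function's predictions, averaged over the inputs, rather than for a single parameter \(\hat{\theta}\).

```python
import numpy as np
from numpy.polynomial import Polynomial


def true_function(x):
    return np.cos(1.5 * np.pi * x)  # hypothetical true function


rng = np.random.default_rng(0)
x_train = np.linspace(0.0, 1.0, 30)  # fixed inputs; only the noise is resampled
noise_sd, n_datasets = 0.1, 500

bias_sq, variance = {}, {}
for degree in (1, 4, 12):
    preds = np.empty((n_datasets, x_train.size))
    for i in range(n_datasets):
        y = true_function(x_train) + rng.normal(scale=noise_sd, size=x_train.size)
        preds[i] = Polynomial.fit(x_train, y, deg=degree)(x_train)
    mean_pred = preds.mean(axis=0)  # approximates the expected prediction over datasets
    # Squared bias and variance of the predictions, averaged over the inputs.
    bias_sq[degree] = np.mean((mean_pred - true_function(x_train)) ** 2)
    variance[degree] = np.mean(preds.var(axis=0))
    print(f"degree {degree:2d}: bias^2 = {bias_sq[degree]:.5f}, "
          f"variance = {variance[degree]:.5f}")
```

With these settings the degree-one fit should show a much larger squared bias than the degree-four fit, while the degree-twelve fit should show a larger variance than the degree-one fit, mirroring underfitting and overfitting respectively.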
¹ Interested readers can refer to the CS229 Lecture Notes for a more mathematical treatment.