Train-Test Split

Learning Objectives

After this unit, students should be able to

  • explain the rationale behind splitting data into training and testing datasets.
  • perform a train-test split.
  • understand the need for a validation dataset.
  • explain cross validation.

The train-test split is a fundamental step in preparing data for a machine learning model. The aim of a machine learning model is to find latent patterns in the data that generalise to unseen data. To simulate this situation, the available data is split into two parts: training data and testing data. Training data is the portion of the dataset used to train the model; the model learns patterns, relationships, and parameters from this data. Testing data is the portion of the dataset used to evaluate the performance of the trained model; the model makes predictions on this data, and the predictions are compared to the actual values to assess performance. Conventionally, between \(70\%\) and \(80\%\) of the data is kept for training, whereas the remaining \(20\%\) to \(30\%\) is reserved for testing the trained model.
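
The split described above can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not a reference implementation: the function name `train_test_split`, the `test_fraction` and `seed` parameters, and the stand-in data are all chosen here for the example. The data is shuffled before splitting so that any ordering in the original dataset does not end up concentrated in one part.

```python
import random

def train_test_split(rows, test_fraction=0.2, seed=0):
    """Shuffle a copy of `rows` and split it into (train, test) lists."""
    if not 0.0 < test_fraction < 1.0:
        raise ValueError("test_fraction must be strictly between 0 and 1")
    shuffled = list(rows)                  # copy, so the caller's data is untouched
    random.Random(seed).shuffle(shuffled)  # seeded shuffle for reproducibility
    n_test = round(len(shuffled) * test_fraction)
    return shuffled[n_test:], shuffled[:n_test]

data = list(range(100))                    # stand-in for 100 datapoints
train, test = train_test_split(data, test_fraction=0.2)
print(len(train), len(test))               # 80 20
```

Fixing the seed makes the split reproducible, which matters when results from different experiments are to be compared on the same test data.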

Should we preprocess the data before or after splitting?

Whether to preprocess data before or after performing the train-test split depends on the specific preprocessing steps and the type of data.

  • Preprocess before the split. Steps that do not depend on statistics computed across the dataset, such as removing duplicates, correcting data errors, and dropping records with missing values, should be done before splitting the data. These operations ensure the quality of the entire dataset. New features that are computed from each record on its own should likewise be created for the entire dataset, so this kind of feature generation should also be done before splitting.
  • Preprocess after the split. Steps that learn parameters from the data, such as scaling, normalisation, encoding, and imputation, should be done after splitting the dataset: the parameters (for example, a mean or a set of categories) are estimated on the training data only and then applied unchanged to the test data. The main reason is to avoid information leakage, where information from the test set influences the training process.
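
To make the leakage point concrete, the sketch below standardises a single feature after the split. The helper names `fit_standardiser` and `apply_standardiser` and the feature values are made up for illustration. The mean and standard deviation are computed from the training values only and then reused for the test values.

```python
import statistics

def fit_standardiser(values):
    """Learn the mean and standard deviation from the training values only."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)
    return mean, (std if std > 0 else 1.0)   # guard against constant features

def apply_standardiser(values, mean, std):
    """Scale values with previously learned statistics."""
    return [(v - mean) / std for v in values]

train_feature = [1.0, 2.0, 3.0, 4.0, 5.0]    # made-up training values
test_feature = [6.0, 0.0]                    # made-up test values

mean, std = fit_standardiser(train_feature)                  # statistics from train only
train_scaled = apply_standardiser(train_feature, mean, std)
test_scaled = apply_standardiser(test_feature, mean, std)    # reuse train statistics
```

Had the mean and standard deviation been computed on the full dataset, the test values would have influenced how the training data is scaled, which is exactly the leakage described above.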

Validation Set

We often want to experiment with several machine learning models to find one that accurately represents the latent patterns in the data. Many machine learning models also have hyperparameters (we will learn more about them in later chapters) that need to be set before training begins. If we select a model and tune its hyperparameters directly on the test set, the chosen model may end up overfitting to the test data, leading to an overestimation of its performance on new, unseen data. Therefore, another dataset is needed for model selection and model validation.

To facilitate model tuning and selection during the training process, the training data is divided into two subsets: the training dataset and the validation dataset. The model is trained on the training set and assessed on the validation set. It is improved iteratively until it begins to overfit, that is, until its performance on the validation data stops improving. At this stage, the training and validation sets are combined, and the selected model is retrained on this combined dataset. Finally, the model's effectiveness is measured on the test dataset.
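
The workflow above can be sketched end to end on synthetic data. Everything in this example is an assumption made for illustration: the data (a noisy quadratic), the model (a one-dimensional k-nearest-neighbour regressor, whose hyperparameter is the number of neighbours `k`), the 60/20/20 proportions, and the candidate values of `k`.

```python
import random

def knn_predict(train_pairs, x, k):
    """Predict y at x as the mean y of the k nearest training points (1-D inputs)."""
    nearest = sorted(train_pairs, key=lambda p: abs(p[0] - x))[:k]
    return sum(y for _, y in nearest) / len(nearest)

def mse(train_pairs, eval_pairs, k):
    """Mean squared error of k-NN fitted on train_pairs, measured on eval_pairs."""
    errors = [(knn_predict(train_pairs, x, k) - y) ** 2 for x, y in eval_pairs]
    return sum(errors) / len(errors)

rng = random.Random(0)
# Synthetic data: y = x^2 plus Gaussian noise (purely illustrative).
xs = [rng.uniform(-3, 3) for _ in range(200)]
pairs = [(x, x * x + rng.gauss(0, 1)) for x in xs]
rng.shuffle(pairs)

test = pairs[:40]     # 20% held out, used only once at the very end
val = pairs[40:80]    # 20% for comparing hyperparameter values
train = pairs[80:]    # 60% for fitting

candidates = [1, 3, 5, 9, 15, 25]
val_errors = {k: mse(train, val, k) for k in candidates}
best_k = min(val_errors, key=val_errors.get)   # lowest validation error wins

# Refit on train + validation with the chosen k, then report the test error once.
test_error = mse(train + val, test, best_k)
print(best_k, round(test_error, 3))
```

The test set plays no part in choosing `best_k`; it is consulted a single time, after the choice has been made.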

What should we do if the validated model performs poorly on the test data?

Cross Validation

Hyperparameter tuning and model selection using only a single training and validation dataset can lead to misleading and overly optimistic results. The issue is more pronounced with small datasets, where the validation set may be too small to reflect the characteristics of the test data. Cross validation repeatedly partitions the "training" (non-testing) data into a training set and a validation set. The model is trained on several different training sets and evaluated on the corresponding validation sets, and the results are aggregated. This provides a more robust and less biased estimate of model performance.

For this discussion, consider the scenario in which the dataset has already been split into two sets: a training set and a testing set. Unless explicitly specified, "dataset" refers to the training set for the rest of the discussion below. Cross validation can be performed in the following ways:

  • Leave-One-Out Cross Validation. This strategy splits the dataset such that the validation set consists of a single datapoint and the model is trained on the remaining datapoints. The process is repeated for every datapoint, so the model is trained as many times as there are datapoints; this makes the strategy impractical for large datasets. The strategy can be extended to leave-\(k\)-out cross validation, wherein \(k\) datapoints are reserved for validation.

  • K-fold Cross Validation. This strategy splits the dataset into \(K\) disjoint sets of (approximately) equal size. \(K-1\) sets are used for training and the remaining set is used for validation. The process is repeated \(K\) times such that each of the \(K\) sets serves as the validation dataset exactly once.

  • Stratified Cross Validation. The earlier techniques are not well suited to datasets with class imbalance: dividing the dataset at random may produce validation sets that contain data from a single class. This strategy splits the dataset so that each of the \(K\) sets approximately preserves the class distribution of the whole dataset. As in K-fold cross validation, the process is repeated \(K\) times.
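
The three strategies above can be sketched as index generators in plain Python. The function names `k_fold_indices` and `stratified_k_fold_indices`, the round-robin assignment used for stratification, and the toy labels are assumptions made for this illustration; they are one simple way to realise the ideas, not the only one.

```python
import random
from collections import defaultdict

def k_fold_indices(n, k, seed=0):
    """Yield (train_idx, val_idx) pairs; each index is in exactly one validation fold."""
    if not 2 <= k <= n:
        raise ValueError("k must satisfy 2 <= k <= n")
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]        # k disjoint, near-equal folds
    for i in range(k):
        train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train_idx, folds[i]

def stratified_k_fold_indices(labels, k, seed=0):
    """Like k_fold_indices, but deals each class's indices round-robin across folds."""
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)
    rng = random.Random(seed)
    folds = [[] for _ in range(k)]
    position = 0
    for label in sorted(by_class):               # fixed class order for reproducibility
        members = by_class[label]
        rng.shuffle(members)
        for i in members:
            folds[position % k].append(i)        # spread each class over all folds
            position += 1
    for i in range(k):
        train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train_idx, folds[i]

# Leave-one-out is the special case K = n: every validation set has one datapoint.
loo = list(k_fold_indices(5, 5))

# Imbalanced toy labels: 90 datapoints of class 0 and 10 of class 1.
labels = [0] * 90 + [1] * 10
strat = list(stratified_k_fold_indices(labels, k=5))
for train_idx, val_idx in strat:
    # each validation fold: 20 datapoints, 2 of them from class 1
    print(len(val_idx), sum(labels[i] for i in val_idx))
```

In use, a model would be trained on `train_idx` and scored on `val_idx` for each pair, and the \(K\) scores would then be averaged to compare models or hyperparameter values.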

What to do after cross-validation?

Exercise. After obtaining the best model (or the best hyperparameters), what should we do? Should we use that model as it is, or retrain it on the combined training and validation datasets?