Cross Validation is not a difficult topic. But when it comes to understanding how to get the tuning parameter using Cross Validation, a lot of people get confused. Hopefully, this blog might help out a little bit.
Let’s start from the beginning.
What is Cross Validation?
Cross Validation is a general technique used to identify the best-performing model out of a bunch of candidate models.
Let’s say we have some data and we divide it into train and test,
something like this —
But why is only this particular 25% used for testing? Why not the first 25%? Even if we randomly pick some 25% of the data, why only that one? The point is that if we train a model on a certain 75% of the data and test it on a certain 25%, we have introduced a data selection bias: the result only tells us how well the model works for ‘that’ 75% and ‘that’ 25%. It also begs the question: is only a certain 25% of the data good for testing, and only a certain 75% good for training?
Wouldn’t it be better if we somehow get to leverage the whole data set for testing as well as training?
And this is where K‑fold Cross Validation comes into play.
Basic Idea
Let’s say we have some data like so -
D = {(x1, y1), (x2, y2), …, (xn, yn)}
and some ML models m1, m2, …, mc
By the way, are these models yet? I guess not! They are not ready to make any predictions yet; they are simply a configuration of an algorithm and its predictor/feature variables. They only become models once they are fit to a data set. So, these c algorithms can be any mix of algorithms that you think should solve your problem: for a classification problem they could be logistic regression, SVM, Neural Nets, etc.
Anyhow, so here’s what K‑fold CV does -
Step 1 — Set some value of K. 5 or 10 are very common choices. Now, shuffle/permute the data randomly once.
Step 2 — Split the data into K folds (let’s say 5).
Step 3 — This is just for illustration purposes: since there are K folds (5 in this case), each ML algorithm will go through K iterations of training and testing, one iteration per fold.
For each of the iterations shown above, we have a test fold and the rest is the training set. Now, for each algorithm m in your set —
- Train the model on Iteration 1 training set and get the error by using its test set.
- Train the model on Iteration 2 training set and get the error by using its test set.
…and so on for K iterations. Now, find the average error as (1/K) × (sum of all the errors). Note that these errors will be misclassification errors for classification problems and residual-based errors for regression problems. For classification problems, we can also report accuracy instead.
We repeat the above for all our algorithms and choose the one that has the lowest average error (or the highest accuracy).
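To make this concrete, here is a minimal sketch of K-fold CV for model selection in Python with scikit-learn. The data set, the candidate algorithms, and K = 5 are all illustrative assumptions, not anything prescribed by the procedure itself:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score
from sklearn.svm import SVC

# Toy data standing in for D = {(x1, y1), ..., (xn, yn)}
X, y = make_classification(n_samples=500, n_features=10, random_state=0)

# The candidate algorithms m1, ..., mc: any mix you think fits the problem
candidates = {
    "logistic regression": LogisticRegression(max_iter=1000),
    "SVM": SVC(),
}

# Step 1 and 2: shuffle once, then split into K folds
cv = KFold(n_splits=5, shuffle=True, random_state=0)

for name, model in candidates.items():
    # Step 3: train on K-1 folds, test on the held-out fold, K times over
    scores = cross_val_score(model, X, y, cv=cv, scoring="accuracy")
    print(f"{name}: mean CV accuracy = {scores.mean():.3f}")

# Pick the candidate with the highest mean accuracy (lowest average error)
```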
So, what do we gain by this exercise? Just this — If there was going to be some bias in our final model due to data selection bias, then we have gotten rid of that and have hopefully selected a better model.
How can K‑Fold CV be used for Complexity Parameter?
When we build a tree, we run the risk of growing a big, overfitted tree. Think about it: if there is nothing stopping the tree from growing, it will try to fit the training data as closely as possible. Overfitting simply means that predictions may not be good on test data. So, how do we get around this problem? Well, we need to ‘prune’ our trees.
And what is pruning? It is the merging of nodes to make the tree shorter. As shown below, the tree is pruned at the t2 node.
As you can guess, the more pruned a tree is, the more Sum of Squared Error (for regression trees) or misclassification error (for classification trees) it will have on the training data: a tree with just the root node would have the most error (the most underfit tree) and the largest tree would have the least error (the most overfit tree). The job of pruning is to find the balance here, so that we can identify a tree that neither overfits nor underfits. And pruning does this through something called Cost Complexity Pruning.
In simple words, the idea is this: we add an extra term α|T| (the tree cost complexity) to the total cost and we seek to minimize the overall cost, i.e. total cost = error + α|T|. Here |T| is the number of terminal nodes (leaves) in the tree and α is the complexity parameter. In other words, this term penalizes big trees: each time the number of leaves increases by one, the cost increases by α. Depending on the value of α (≥ 0), a complex tree that makes no errors may now have a higher total cost than a small tree that makes a number of errors! This is how the term enables us to find a good tree.
Also, convince yourself that for a given α, there can only be one tree T that will minimize the loss function.
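To convince yourself of how the penalty works, here is a toy calculation; the error values and leaf counts are made-up numbers, purely for illustration:

```python
# Total cost of a tree under cost complexity pruning:
# cost = error + alpha * number_of_leaves
def total_cost(error, n_leaves, alpha):
    return error + alpha * n_leaves

big_tree = {"error": 2.0, "n_leaves": 40}     # fits training data almost perfectly
small_tree = {"error": 10.0, "n_leaves": 5}   # makes more errors, but is simple

for alpha in (0.0, 0.1, 0.5):
    print(
        f"alpha={alpha}: "
        f"big tree cost={total_cost(**big_tree, alpha=alpha)}, "
        f"small tree cost={total_cost(**small_tree, alpha=alpha)}"
    )
# At alpha=0 the big tree is cheaper; by alpha=0.5 the leaf penalty
# makes the small tree the better choice.
```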
But. How to find α?
And this is where CV comes in. Here’s what the APIs do as an overview (of course, there would be differences in implementations and more subtleties but the point here is to get the intuition right) —
We take a range of α’s we want to try out. For each of these α’s, given a data set and a pre-determined value of K (for CV), we do the following —
- For α1, find the best tree for Fold 1
- For α1, find the best tree for Fold 2
- ..and so on, For α1, find the best tree for Fold K
- For α2, find the best tree for Fold 1
- For α2, find the best tree for Fold 2
- ..and so on, For α2, find the best tree for Fold K
- …and so on, for however many α’s you want to try out (typically, anything between 30 and 50 values is enough)
So, for each α, we find the average accuracy of the model across the K folds (accuracy for classification models, RMSE/SSE for regression models; let’s go with accuracy for illustration purposes). We can plot this out like below (in practice you won’t need to plot it; the CV API can tell you the best value of alpha) and see for which value of α we get the highest accuracy. We choose that α, plug it into our original tree’s cost function, and build a good tree! Note that CV is not generally used this way for Random Forest or GBM-type algorithms, because it won’t be as effective there: those algorithms already have a lot of randomness built into them, so the chance of overfitting is greatly reduced.
Also, note that the shape of the graph below is the way it is because when the value of α is low it favors bigger trees, so accuracy will be high (if not the highest), and as its value keeps growing it favors shorter and shorter trees, and the shorter the tree, the lower its accuracy. The sweet spot is, therefore, somewhere in the middle.
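Here is a minimal sketch of that search in Python with scikit-learn, which exposes the complexity parameter as ccp_alpha on its tree estimators; the data set and the grid of α values are illustrative assumptions:

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)   # any labelled data set works here

# The grid of alphas to try (30 values, as discussed above)
alphas = np.linspace(0.0, 0.05, 30)

mean_accuracy = []
for alpha in alphas:
    tree = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha)
    # 5-fold CV accuracy for the tree grown with this alpha
    scores = cross_val_score(tree, X, y, cv=5, scoring="accuracy")
    mean_accuracy.append(scores.mean())

best_alpha = alphas[int(np.argmax(mean_accuracy))]
print("best alpha:", best_alpha)
```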
So, once we find our α, we can use it to build our new tree model rather than using arbitrary parameters like minbucket, etc. to limit our trees.
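And a short continuation of the sketch above (it reuses X, y and best_alpha from the previous snippet) showing the final refit; min_samples_leaf is scikit-learn's rough analogue of rpart's minbucket:

```python
from sklearn.tree import DecisionTreeClassifier

# Refit a single tree on the full training data with the CV-chosen alpha,
# rather than limiting the tree by hand (e.g. via min_samples_leaf).
final_tree = DecisionTreeClassifier(random_state=0, ccp_alpha=best_alpha)
final_tree.fit(X, y)
print("leaves in the pruned tree:", final_tree.get_n_leaves())
```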
Think about what CV did here: we had a bunch of models built using different values of α, we used CV on them to find a good value of α, and then we used that α to build a good single tree model. So, this is just an application of what CV does in general: finding a good model, given a bunch of models.