Cross Validation — Intuitively Explained

Cross Validation is not a difficult topic. But when it comes to understanding how to pick a tuning parameter using Cross Validation, a lot of people get confused. Hopefully, this post will help a little.

Let’s start from the beginning.

What is Cross Validation?

Cross Validation is a general technique used to identify the best-performing model out of a set of candidate models.

Let’s say we have some data and we divide it into train and test sets, say 75% for training and the remaining 25% for testing.


But why use only this 25% for testing? Why not the first 25%? Even if we randomly pick some 25% of the data, why only that part? The point is that if we train a model on one particular 75% of the data and test it on one particular 25%, we have introduced data bias into our model: it works well for ‘that’ 75% and ‘that’ 25%. It also raises the following question: is only a certain 25% of the data good for testing, and only a certain 75% good for training?

Wouldn’t it be better if we could somehow leverage the whole data set for testing as well as training?

And this is where K-fold Cross Validation comes into play.

Basic Idea

Let’s say we have some data like so:

D = {(x1, y1), (x2, y2), …, (xn, yn)}

and some ML models m1, m2, …, mc.
By the way, are these models yet? I guess not! They are not ready to make any predictions. Each one is simply an algorithm configured with a target variable and feature variables. It only becomes a model once it has been trained on a data set. So these c algorithms can be any mix of algorithms that you think could solve your problem. For a classification problem, for example, they could be logistic regression, SVM, Neural Nets, etc.

Anyhow, here’s what K-fold CV does:

Step 1: Set some value of K. 5 or 10 are very common choices. Now, shuffle/permute the data randomly once.

Step 2: Split the data into K folds (let’s say 5) of roughly equal size.

Step 3: Run K iterations of training and testing for each ML algorithm (5 in this case). In iteration i, fold i is held out as the test set and the remaining K-1 folds together form the training set, so every data point is used for testing exactly once.
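The three steps above can be sketched in a few lines of plain Python. This is only a minimal illustration: the function name `k_fold_indices`, the fixed seed and the way the shuffled indices are dealt into folds are assumptions made for this example, not how any particular library does it.

```python
import random

def k_fold_indices(n_samples, k=5, seed=0):
    """Yield (train_idx, test_idx) pairs for K-fold cross validation."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)        # Step 1: shuffle once
    folds = [indices[i::k] for i in range(k)]   # Step 2: K roughly equal folds
    for i in range(k):                          # Step 3: K train/test iterations
        test_idx = folds[i]
        train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train_idx, test_idx

# Show which rows are held out in each of the 5 iterations for 10 data points.
for iteration, (train_idx, test_idx) in enumerate(k_fold_indices(10, k=5), start=1):
    print(f"Iteration {iteration}: test={sorted(test_idx)} train size={len(train_idx)}")
```

Each data point lands in exactly one test fold, which is the whole point: over the K iterations, all of the data gets used for testing as well as training.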






In each iteration, one fold is the test set and the rest is the training set. Now, for each algorithm m in your set:

  1. Train the model on the Iteration 1 training set and compute its error on the Iteration 1 test set.
  2. Train the model on the Iteration 2 training set and compute its error on the Iteration 2 test set.

    … and so on for K iterations. Now, find the average error as (1/K) x (sum of all errors). The error here is based on the misclassified cases for classification problems and on the residuals (for example, mean squared error) for regression problems. For classification problems, we can use accuracy instead.

We repeat the above for all our algorithms and choose the one that has the lowest average error (or highest accuracy).
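Here is a small, self-contained sketch of that selection procedure. The two candidate "algorithms" (predict the training mean, or fit a least-squares line), the synthetic data and all function names are made up for illustration; the only thing taken from the text is the recipe of averaging the K fold errors and picking the candidate with the lowest average.

```python
import random

def k_fold_indices(n_samples, k, seed=0):
    """Yield (train_idx, test_idx) pairs after one random shuffle."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    folds = [indices[i::k] for i in range(k)]
    for i in range(k):
        train = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train, folds[i]

def fit_mean(xs, ys):
    """Candidate 1: always predict the mean of the training targets."""
    mean_y = sum(ys) / len(ys)
    return lambda x: mean_y

def fit_line(xs, ys):
    """Candidate 2: simple least-squares line."""
    mean_x, mean_y = sum(xs) / len(xs), sum(ys) / len(ys)
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return lambda x: intercept + slope * x

def cv_error(fit, xs, ys, k=5):
    """Average test-fold mean squared error over K iterations."""
    fold_errors = []
    for train, test in k_fold_indices(len(xs), k):
        predict = fit([xs[i] for i in train], [ys[i] for i in train])
        mse = sum((ys[i] - predict(xs[i])) ** 2 for i in test) / len(test)
        fold_errors.append(mse)
    return sum(fold_errors) / k   # (1/K) x (sum of all errors)

# Synthetic data: a noisy straight line (an assumption for this demo).
rng = random.Random(42)
xs = [i / 10 for i in range(50)]
ys = [2.0 * x + 1.0 + rng.gauss(0, 0.5) for x in xs]

candidates = {"mean": fit_mean, "line": fit_line}
scores = {name: cv_error(fit, xs, ys) for name, fit in candidates.items()}
best = min(scores, key=scores.get)
print(scores, "->", best)
```

Note that every candidate is scored on the same folds, so the comparison between them is fair.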

So, what do we gain from this exercise? Just this: if our final choice was going to be biased by the particular way we split the data, we have largely averaged that out and have hopefully selected a better model.

How can K-Fold CV be used to find the Complexity Parameter?

When we build a tree, we run the risk of building a big, overfitted tree. Think about it: if nothing stops the tree from growing, it will try to fit the training data as closely as possible. Overfitting simply means that predictions might not be good on test data. So, how do we get around this problem? Well, we need to ‘prune’ our trees.

And so what’s pruning? It’s merging nodes to make the tree shorter. For example, pruning a tree at node t2 collapses everything below t2, so that t2 becomes a terminal node.












As you can guess, the more pruned a tree is, the higher its Sum of Squared Errors (for regression trees) or misclassification error (for classification trees) on the training data. A tree with just one node has the most error (the most underfit tree) and the largest tree has the least error (the most overfit tree). The job of pruning is to find the balance here, so that we can identify a tree model that neither overfits nor underfits. Pruning does this through something called Cost Complexity Pruning.




In simple words, the idea is this: we add an extra term α|T| (the Tree Cost Complexity) to the tree’s error, so the total cost becomes error + α|T|, and we seek to minimize this overall cost. |T| is the total number of terminal nodes in the tree and α is the complexity parameter. In other words, this term penalizes big trees: when the number of leaves in the tree increases by one, the cost increases by α. Depending on the value of α (≥ 0), a complex tree that makes no errors may now have a higher total cost than a small tree that makes a number of errors! This is how this term enables us to find a good tree.
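A quick bit of arithmetic makes the trade-off concrete. The two trees and their error counts below are invented purely for illustration; the cost is simply the error plus α times the number of terminal nodes.

```python
def total_cost(error, n_leaves, alpha):
    """Cost-complexity criterion: error plus alpha per terminal node."""
    return error + alpha * n_leaves

# Hypothetical trees: (misclassified training cases, number of leaves).
big_tree = (0, 20)    # fits the training data perfectly
small_tree = (6, 3)   # makes a few errors but is much simpler

for alpha in (0.0, 0.2, 0.5):
    big = total_cost(*big_tree, alpha)
    small = total_cost(*small_tree, alpha)
    winner = "big" if big < small else "small"
    print(f"alpha={alpha}: big={big:.1f}, small={small:.1f} -> {winner} tree wins")
```

With these made-up numbers the two costs are equal when 20α = 6 + 3α, that is at α = 6/17 (about 0.35). Below that value the big tree has the lower cost; above it the small tree does.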

Also, convince yourself that for a given α there is essentially one tree T that minimizes this cost (if several trees tie, we take the smallest one).

But how do we find α?

And this is where CV comes in. Here’s an overview of what the APIs do (of course, implementations differ and there are more subtleties, but the point here is to get the intuition right).

We take a range of α’s we want to try out. For each of these α’s, with a given data set and a pre-determined value of K (for CV), we do the following:

  • For α1, find the best tree for Fold 1 (the tree with the lowest total cost on the training folds) and score it on the held-out fold
  • For α1, find the best tree for Fold 2
  • ...and so on; for α1, find the best tree for Fold K
  • For α2, find the best tree for Fold 1
  • For α2, find the best tree for Fold 2
  • ...and so on; for α2, find the best tree for Fold K
  • ...and so on, for however many α’s you want to try out (typically somewhere between 30 and 50 is probably enough)

So, for each α, we find the average performance of the model across the K folds (accuracy for classification models, or RMSE/SSE for regression models; let’s go with accuracy for illustration purposes). We can plot average accuracy against α (in practice you won’t need to plot it, since the CV API can tell you the best value of α) and see for which value of α we get the highest accuracy. We choose that α, plug it into our original tree’s cost function and build a good tree!

Note that pruning via CV is generally not applied to the individual trees in Random Forests or GBM-style algorithms. It would not be very effective there, since these algorithms already have a lot of randomness and averaging built into them, so the chance of overfitting is greatly reduced. (CV is still commonly used to tune their other settings.)

Also, note the typical shape of the accuracy curve: when α is low, it favors bigger trees and accuracy is high (though not necessarily the highest, because very big trees overfit); as α keeps getting bigger, it favors shorter and shorter trees, and the shorter the tree, the lower its accuracy. The sweet spot is, therefore, somewhere in the middle.
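To see the whole loop in one place, here is a toy sketch in plain Python. It is a deliberately simplified stand-in for real pruning: the "trees" are piecewise-constant fits on [0, 1) with 1, 2, 4, ... equal-width leaves, the data is synthetic, and it uses held-out mean squared error (lower is better) instead of accuracy. All names and numbers are assumptions for illustration. Real libraries do proper pruning and expose the penalty differently (scikit-learn’s tree estimators, for instance, take a `ccp_alpha` parameter).

```python
import math
import random

LEAF_COUNTS = (1, 2, 4, 8, 16, 32)   # candidate tree sizes (toy stand-in for pruning)

def fit_tree(xs, ys, n_leaves):
    """Piecewise-constant fit on [0, 1) with n_leaves equal-width leaves."""
    overall = sum(ys) / len(ys)
    sums, counts = [0.0] * n_leaves, [0] * n_leaves
    for x, y in zip(xs, ys):
        leaf = min(int(x * n_leaves), n_leaves - 1)
        sums[leaf] += y
        counts[leaf] += 1
    # Empty leaves fall back to the overall training mean.
    means = [s / c if c else overall for s, c in zip(sums, counts)]
    predict = lambda x: means[min(int(x * n_leaves), n_leaves - 1)]
    sse = sum((y - predict(x)) ** 2 for x, y in zip(xs, ys))
    return predict, sse

def best_tree_for_alpha(xs, ys, alpha):
    """Pick the candidate tree minimizing training SSE + alpha * |T|."""
    best_cost, best_predict = math.inf, None
    for n_leaves in LEAF_COUNTS:
        predict, sse = fit_tree(xs, ys, n_leaves)
        cost = sse + alpha * n_leaves
        if cost < best_cost:
            best_cost, best_predict = cost, predict
    return best_predict

def cv_error_for_alpha(xs, ys, alpha, k=5, seed=0):
    """Average held-out MSE over K folds for one value of alpha."""
    indices = list(range(len(xs)))
    random.Random(seed).shuffle(indices)
    folds = [indices[i::k] for i in range(k)]
    fold_errors = []
    for i, test in enumerate(folds):
        train = [j for f, fold in enumerate(folds) if f != i for j in fold]
        predict = best_tree_for_alpha([xs[j] for j in train],
                                      [ys[j] for j in train], alpha)
        fold_errors.append(sum((ys[j] - predict(xs[j])) ** 2 for j in test) / len(test))
    return sum(fold_errors) / k

# Synthetic data: a noisy sine wave (an assumption for this demo).
rng = random.Random(1)
xs = [rng.random() for _ in range(200)]
ys = [math.sin(2 * math.pi * x) + rng.gauss(0, 0.3) for x in xs]

alphas = [0.0, 0.05, 0.1, 0.5, 1.0, 5.0, 20.0, 100.0]
cv_errors = {alpha: cv_error_for_alpha(xs, ys, alpha) for alpha in alphas}
best_alpha = min(cv_errors, key=cv_errors.get)
for alpha, err in cv_errors.items():
    print(f"alpha={alpha:<6} mean held-out MSE={err:.3f}")
print("best alpha:", best_alpha)
```

Very large α values force the single-leaf tree, which underfits badly, while α = 0 always picks the biggest tree. The α with the lowest cross-validated error is the one you would then use to build the final tree on all of the data.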

So, once we find our α, we can use it to build our new tree model rather than using arbitrary parameters like minbucket, etc. to limit our trees.

Think about what CV did here: we had a bunch of models built using different values of α, we used CV on them to find a good value of α, and then eventually used that α to find a good single tree model. So, this is just an application of what CV does in general: finding a good model, given a bunch of models.
