The Basics of Regularization in Machine Learning

What is regularization?

Regularization is the set of measures we take to keep a model from overfitting its training data. There are many ways to do this, but one of the most important is to penalize model complexity. Generally, the simpler the model, the less likely it is to overfit.

How can we tell that we’re overfitting?

If the loss on the validation data initially goes down but then starts to rise again after a certain number of iterations (while the loss on the training data continues to go down), chances are we’re overfitting.
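As a rough illustration of that check, here is a minimal sketch in plain Python. The function name `overfitting_onset`, the `patience` parameter, and the loss values are all made up for this example; they are assumptions, not part of any particular library.

```python
def overfitting_onset(train_loss, val_loss, patience=2):
    """Return the index of the lowest validation loss if, after it,
    validation loss rose for `patience` consecutive steps while training
    loss kept falling. Return None if that pattern never appears."""
    best = 0   # index of the lowest validation loss seen so far
    rises = 0  # consecutive steps where validation rose and training fell
    for i in range(1, len(val_loss)):
        if val_loss[i] < val_loss[best]:
            best, rises = i, 0
        elif val_loss[i] > val_loss[i - 1] and train_loss[i] < train_loss[i - 1]:
            rises += 1
            if rises >= patience:
                return best
        else:
            rises = 0
    return None


# Illustrative (invented) loss curves, one value per iteration.
train = [1.0, 0.8, 0.6, 0.5, 0.4, 0.3]
val_overfit = [1.1, 0.9, 0.8, 0.85, 0.9, 0.95]   # turns upward after index 2
val_healthy = [1.1, 0.9, 0.8, 0.7, 0.65, 0.6]    # keeps falling

print(overfitting_onset(train, val_overfit))  # 2
print(overfitting_onset(train, val_healthy))  # None
```

The returned index is the iteration whose weights you would most likely want to keep, since validation loss only got worse after it.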

How do we define model complexity?

One common way to measure complexity is by the size of the weights attached to our features: the smaller the weights, the less complex our model. L_2 regularization (also known as ridge regularization) is one common method of regularizing: it takes the sum of the squares of the weights and aims to keep that sum small. This penalizes extreme weights.

The L_2 formula is:

    \[L_2 \text{ regularization term} = \|w\|_2^2 = w_1^2 + w_2^2 + \dots + w_n^2\]

That is, the sum of the squares of all the weights. Because squaring magnifies large values, a single extreme (outlier) weight can dominate the term, so the penalty discourages such weights.
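To make the formula concrete, here is a small sketch in plain Python. The helper name `l2_penalty` and the weight values are invented for illustration.

```python
def l2_penalty(weights):
    """Sum of the squared weights: w_1^2 + w_2^2 + ... + w_n^2."""
    return sum(w * w for w in weights)


modest = [0.5, -1.0, 0.25]
with_outlier = [0.5, -1.0, 0.25, 5.0]  # same weights plus one extreme value

print(l2_penalty(modest))        # 1.3125
print(l2_penalty(with_outlier))  # 26.3125
```

The single weight of 5.0 contributes 25 of the 26.3125 total, which is why minimizing this term pushes hardest against the largest weights.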

Ideally, weights should be centered around 0, and normally distributed.

How much should you regularize?

If you have lots of training data, and the training and test data look similar, you may need little or no regularization at all. If you have less data and/or the test and training data look different, you may require more regularization.

Model developers aim to minimize both loss AND complexity, such that:

    \[\text{minimize}\big(\text{Loss}(\text{Data} \mid \text{Model}) + \lambda \cdot \text{complexity}(\text{Model})\big)\]
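Here is one way this combined objective might look in code. This is a sketch under stated assumptions: the loss is taken to be mean squared error on a linear model, the complexity term is the L_2 sum of squared weights, and the function names and tiny dataset are invented for the example.

```python
def predict(weights, x):
    """Linear model: weighted sum of the features."""
    return sum(w * xi for w, xi in zip(weights, x))


def mse(weights, xs, ys):
    """Loss(Data|Model): mean squared error over the examples (an assumption)."""
    return sum((predict(weights, x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def l2_penalty(weights):
    """complexity(Model): sum of the squared weights."""
    return sum(w * w for w in weights)


def objective(weights, xs, ys, lam):
    """Loss plus lambda times complexity, the quantity to be minimized."""
    return mse(weights, xs, ys) + lam * l2_penalty(weights)


xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
ys = [2.0, 1.0, 3.0]
w = [2.0, 1.0]  # fits this toy data exactly, so the loss part is zero

print(objective(w, xs, ys, lam=0.0))  # 0.0
print(objective(w, xs, ys, lam=0.5))  # 2.5
```

With lambda set to zero only the fit matters. With lambda at 0.5, the same weights are charged 0.5 times their squared size (5.0), so an optimizer would be willing to give up some fit in exchange for smaller weights.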

L_2 regularization:

  • Encourages weight values towards zero
  • Encourages the mean of the weights towards zero, with a normal (bell-shaped, or Gaussian) distribution

The aim is to choose a lambda value that is neither too low (leaving the model complex enough to overfit the data) nor too high (making the model so simple that it underfits the data).
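The effect of lambda can be seen by fitting the same model several times with different values. The sketch below assumes a linear model trained by plain gradient descent on mean squared error plus lambda times the L_2 term. The function names, learning rate, step count, dataset, and lambda values are all illustrative choices, not recommendations.

```python
def predict(w, x):
    return sum(wi * xi for wi, xi in zip(w, x))


def mse(w, xs, ys):
    return sum((predict(w, x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def l2_penalty(w):
    return sum(wi * wi for wi in w)


def fit(xs, ys, lam, lr=0.01, steps=5000):
    """Gradient descent on mse + lam * l2_penalty, starting from zero weights."""
    w = [0.0] * len(xs[0])
    for _ in range(steps):
        grad = [0.0] * len(w)
        for x, y in zip(xs, ys):
            err = predict(w, x) - y
            for j, xj in enumerate(x):
                grad[j] += 2 * err * xj / len(xs)
        # The penalty adds 2 * lam * w_j to the gradient of each weight.
        w = [wi - lr * (g + 2 * lam * wi) for wi, g in zip(w, grad)]
    return w


xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
ys = [2.0, 1.0, 3.0]
lambdas = [0.0, 0.1, 1.0, 10.0]

results = [(lam, fit(xs, ys, lam)) for lam in lambdas]
for lam, w in results:
    print(f"lambda={lam:<5} weights={[round(wi, 3) for wi in w]} "
          f"penalty={l2_penalty(w):.3f} train_mse={mse(w, xs, ys):.3f}")
```

As lambda grows, the weights shrink towards zero and the training error rises. This toy run only shows the direction of that trade-off; in practice you would pick lambda by comparing loss on held-out validation data rather than on the training data.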
