This post is part 2 of my four-part series on Keras Callbacks. As I stated in Part 1, a callback is an object that can perform actions at various stages of training, and it is passed to model.fit() when the model is trained. Callbacks are super helpful and save a lot of lines of code. That's part of the beauty of Keras in general: lots of lines of code saved! Some other callbacks include EarlyStopping, discussed in Part 1 of this series, TensorBoard, and LearningRateScheduler. I'll discuss TensorBoard and LearningRateScheduler in the last two parts. For the full list of callbacks, see the Keras Callbacks API documentation. In this post, I will discuss ModelCheckpoint.
The purpose of the ModelCheckpoint callback is exactly what it sounds like. It saves the Keras model, or just the model weights, at some frequency that you determine. Want the model saved every epoch? No problem. Maybe after a certain number of batches? Again, no problem. Want to only keep the best model based on accuracy or loss values? You guessed it - no problem.
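To make those three cases concrete, here is a minimal sketch of how each one might be configured. It assumes a recent TensorFlow 2.x with its bundled Keras (tf.keras). The checkpoints/ paths and the batch count of 100 are hypothetical choices for illustration.

```python
import tensorflow as tf

# 1) Save the full model at the end of every epoch (save_freq="epoch" is the default).
#    Keras fills in {epoch:02d}, so each epoch is written to its own file.
every_epoch = tf.keras.callbacks.ModelCheckpoint(
    filepath="checkpoints/epoch-{epoch:02d}.keras",
    save_freq="epoch",
)

# 2) Save only the weights every 100 training batches (100 is an arbitrary example).
#    The path has no placeholder, so the same file is overwritten each time.
every_100_batches = tf.keras.callbacks.ModelCheckpoint(
    filepath="checkpoints/batch-ckpt.weights.h5",
    save_weights_only=True,
    save_freq=100,
)

# 3) Keep only the best model so far, judged by validation accuracy (higher is better).
best_only = tf.keras.callbacks.ModelCheckpoint(
    filepath="checkpoints/best.keras",
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
)
```

Note that weights-only checkpoints use a `.weights.h5` file name, while full-model checkpoints use `.keras`. Newer Keras versions enforce these extensions.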
As with other callbacks, you create a ModelCheckpoint object that defines how it should behave, then pass it to model.fit() through the callbacks argument when you train the model. I'm sure you can read the full documentation on your own, so I'll just give an example from my own work and hit the high points. Here is some example code.
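As a self-contained sketch of the general pattern, the example below defines a ModelCheckpoint and hands it to model.fit(). It assumes a recent TensorFlow 2.x. The toy data, the small model, and the file name best_model.weights.h5 are hypothetical placeholders, not taken from my own project. The checkpoint watches the training loss and saves only the weights, and only when the loss improves.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 256 samples with 8 features and a binary label.
rng = np.random.default_rng(0)
x_train = rng.normal(size=(256, 8)).astype("float32")
y_train = (x_train.sum(axis=1) > 0).astype("float32").reshape(-1, 1)

# A tiny placeholder model, just big enough to train.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Watch the training loss (lower is better) and write the weights
# only when it beats the best value seen so far.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath="best_model.weights.h5",
    monitor="loss",
    mode="min",
    save_best_only=True,
    save_weights_only=True,
    verbose=1,  # print a message whenever a checkpoint is (or is not) saved
)

# The callback goes into model.fit() through the callbacks list.
history = model.fit(x_train, y_train, epochs=6, batch_size=32, callbacks=[checkpoint])
```

With verbose=1, Keras prints a line at the end of each epoch saying whether the monitored value improved and the file was written.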
Let's discuss the options above.
Here is the output from some training I did using the ModelCheckpoint I defined above. In epochs 3 and 4, the loss dropped below the best value seen so far, so the model weights were saved. In epochs 5 and 6, the loss did not decrease, so the weights were not saved.
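The rule behind that behavior can be modeled in a few lines of plain Python. This is a simplified sketch of how save_best_only=True acts when monitoring a loss with mode="min". It is not Keras's actual source, and the loss values are hypothetical, not the numbers from the training run described above.

```python
import math


def epochs_that_save(losses):
    """Return the 1-based epochs at which a save_best_only checkpoint
    monitoring a loss (mode="min") would write a file."""
    best = math.inf  # nothing saved yet, so the first epoch always improves
    saved = []
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:  # only a strict improvement triggers a save
            best = loss
            saved.append(epoch)
    return saved


# Hypothetical losses: epochs 5 and 6 do not beat epoch 4's 0.48.
print(epochs_that_save([0.90, 0.70, 0.55, 0.48, 0.50, 0.49]))  # → [1, 2, 3, 4]
```

Epoch 6 is skipped even though its loss is lower than epoch 5's, because the comparison is always against the best value so far, not the previous epoch.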
Now that we have our weights saved, we can later go back and load them for inference. I won't cover that (although it is simple) for the sake of brevity. Jason Brownlee (@TeachTheMachine) of Machine Learning Mastery has a very good tutorial on how to do it. He always writes great stuff!
If you have questions and want to connect, you can message me on LinkedIn or Twitter. Also, follow me on Twitter @pacejohn, LinkedIn https://www.linkedin.com/in/john-pace-phd-20b87070/, and follow my company, Mark III Systems, on Twitter @markiiisystems.
#artificialintelligence #ai #machinelearning #ml #tensorflow #keras #neuralnetworks #deeplearning #modelcheckpoint #hyperparameters #callbacks