Package ml.dmlc.xgboost4j.java
Class ExternalCheckpointManager
java.lang.Object
ml.dmlc.xgboost4j.java.ExternalCheckpointManager
This class contains the methods required for managing the state of the training
process. The training state is stored in a distributed file system as a set of
UBJ (Universal Binary JSON) model files.
The class provides methods for saving, loading, and cleaning up checkpoints.
-
Constructor Summary
Constructors
Constructor | Description
ExternalCheckpointManager(String checkpointPath, org.apache.hadoop.fs.FileSystem fs) | Creates a new external checkpoint manager at the specified path in the specified file system.
-
Method Summary
Modifier and Type | Method | Description
void | cleanPath() | Cleans all the directories and files that are present in the checkpoint path.
void | cleanUpHigherVersions(int currentRound) | Cleans up all the checkpoint versions that are higher than the current round.
List<Integer> | getCheckpointRounds(int firstRound, int checkpointInterval, int numOfRounds) | Gets a list of iterations that need checkpointing.
Booster | loadCheckpointAsBooster() | Reads the checkpoint from the checkpoint path.
void | updateCheckpoint(Booster boosterToCheckpoint) | Updates the booster checkpoint to the latest or current version and deletes all the previous versions of the checkpoint.
-
Constructor Details
-
ExternalCheckpointManager
public ExternalCheckpointManager(String checkpointPath, org.apache.hadoop.fs.FileSystem fs) throws XGBoostError
Creates a new external checkpoint manager at the specified path in the specified file system.
- Parameters:
checkpointPath - The directory path where checkpoints will be stored.
fs - The file system to use for storing checkpoints.
- Throws:
XGBoostError - Thrown if the checkpoint path is null or empty.
-
-
Method Details
-
cleanPath
This method cleans all the directories and files that are present in the checkpoint path.
- Throws:
IOException - Thrown when there is an error deleting the checkpoint path.
-
loadCheckpointAsBooster
Reads the checkpoint from the checkpoint path. Once the checkpoint path has been read, the latest version among all the checkpoint versions is selected and loaded into the booster for the purpose of making predictions.
- Returns:
- The booster object that is used for making predictions.
- Throws:
IOException - Any exception that occurs when reading the checkpoint path.
XGBoostError - Any exception that occurs when loading the model into the booster.
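Selecting "the latest version" from a directory listing can be sketched as follows. This is a minimal illustration, not the library's implementation: the class `LatestCheckpointSketch`, its methods, and the `<round>.ubj` file naming scheme are all assumptions made for the example.

```java
import java.util.List;
import java.util.OptionalInt;

/** Hypothetical helper for illustration; not part of XGBoost4J. */
class LatestCheckpointSketch {

    // Assumed naming scheme: "<round>.ubj", for example "12.ubj".
    private static final String SUFFIX = ".ubj";

    /** Parses the version from a file name, or returns empty if the name does not match. */
    static OptionalInt parseVersion(String fileName) {
        if (!fileName.endsWith(SUFFIX)) {
            return OptionalInt.empty();
        }
        String stem = fileName.substring(0, fileName.length() - SUFFIX.length());
        try {
            return OptionalInt.of(Integer.parseInt(stem));
        } catch (NumberFormatException e) {
            // Names such as "model.ubj" carry no version and are ignored.
            return OptionalInt.empty();
        }
    }

    /** Returns the highest version among the given file names, if any name matches. */
    static OptionalInt latestVersion(List<String> fileNames) {
        return fileNames.stream()
                .map(LatestCheckpointSketch::parseVersion)
                .filter(OptionalInt::isPresent)
                .mapToInt(OptionalInt::getAsInt)
                .max();
    }
}
```

Under these assumptions, a listing of `3.ubj`, `10.ubj` and `notes.txt` yields version 10, and a directory with no matching files yields an empty result, which a caller could treat as "no checkpoint to resume from".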
-
updateCheckpoint
This method updates the booster checkpoint to the latest or current version and deletes all the previous versions of the checkpoint.
- Parameters:
boosterToCheckpoint - The booster object that is to be checkpointed and saved as a model file.
- Throws:
IOException - Any exception that occurs when writing the model file to the checkpoint path.
XGBoostError - Any exception that occurs when saving the model from the booster.
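The "save the new version, then delete the previous ones" behavior can be illustrated on a local directory with `java.nio.file`. This is a sketch under stated assumptions, not the XGBoost4J code: the class `UpdateCheckpointSketch`, the `<version>.ubj` naming, the temporary-file step, and the use of a raw byte array in place of a serialized booster are all hypothetical.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** Hypothetical local-disk stand-in for illustration; not part of XGBoost4J. */
class UpdateCheckpointSketch {

    /** Writes modelBytes as "<version>.ubj" in dir, then deletes the older ".ubj" files. */
    static void update(Path dir, int version, byte[] modelBytes) {
        try {
            // Record the existing checkpoints before the new file appears.
            List<Path> previous;
            try (Stream<Path> files = Files.list(dir)) {
                previous = files
                        .filter(p -> p.getFileName().toString().endsWith(".ubj"))
                        .collect(Collectors.toList());
            }

            // Write under a temporary name first, then move into place,
            // so a reader never picks up a half-written model file.
            Path target = dir.resolve(version + ".ubj");
            Path temp = Files.createTempFile(dir, version + "-", ".tmp");
            Files.write(temp, modelBytes);
            Files.move(temp, target, StandardCopyOption.REPLACE_EXISTING);

            // Only after the new version is in place are earlier versions removed.
            for (Path old : previous) {
                if (!old.equals(target)) {
                    Files.deleteIfExists(old);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

Deleting the old files last means a failure during the write leaves the previous checkpoint intact, which is the reason for the ordering chosen in this sketch.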
-
cleanUpHigherVersions
This method cleans up all the checkpoint versions that are higher than the current round. This is useful when multiple training instances are running and only the checkpoints from the current training instance should be retained.
- Parameters:
currentRound - The current round of training.
- Throws:
IOException - Any exception that occurs when deleting the checkpoint files.
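Which versions count as "higher than the current round" can be shown with a sorted set. The helper below is hypothetical and only computes the versions that would be removed; it assumes "higher" means strictly greater than `currentRound`, and it performs no file deletion.

```java
import java.util.Collection;
import java.util.SortedSet;
import java.util.TreeSet;

/** Hypothetical helper for illustration; not part of XGBoost4J. */
class HigherVersionsSketch {

    /** Returns the versions strictly greater than currentRound, in ascending order. */
    static SortedSet<Integer> versionsToDelete(Collection<Integer> existingVersions, int currentRound) {
        // tailSet(x, false) keeps only the elements greater than x.
        return new TreeSet<>(existingVersions).tailSet(currentRound, false);
    }
}
```

For example, with existing versions 1, 5, 9 and 12 and a current round of 5, the sketch selects 9 and 12, leaving 1 and 5 untouched.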
-
getCheckpointRounds
public List<Integer> getCheckpointRounds(int firstRound, int checkpointInterval, int numOfRounds) throws IOException
Gets a list of iterations that need checkpointing.
- Parameters:
firstRound - The first round of training.
checkpointInterval - The interval at which checkpoints are to be saved.
numOfRounds - The number of rounds to be trained.
- Returns:
- A list of integer rounds that need checkpointing.
- Throws:
IOException - Any exception that occurs when getting the list of rounds.
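One plausible reading of how the three parameters combine is sketched below. The class `CheckpointRoundsSketch` is hypothetical, and two details are assumptions rather than documented behavior: rounds are stepped from `firstRound` in increments of `checkpointInterval`, and the final round is always included so that the finished model is checkpointed.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper for illustration; not part of XGBoost4J. */
class CheckpointRoundsSketch {

    /**
     * Returns firstRound, firstRound + interval, ... within the training range,
     * and always includes the last round (an assumption of this sketch).
     */
    static List<Integer> checkpointRounds(int firstRound, int checkpointInterval, int numOfRounds) {
        if (numOfRounds <= 0) {
            throw new IllegalArgumentException("numOfRounds must be positive");
        }
        int end = firstRound + numOfRounds; // exclusive upper bound
        int lastRound = end - 1;

        List<Integer> rounds = new ArrayList<>();
        if (checkpointInterval > 0) {
            for (int r = firstRound; r < end; r += checkpointInterval) {
                rounds.add(r);
            }
        }
        // Make sure the final round is checkpointed even if the interval skips it.
        if (rounds.isEmpty() || rounds.get(rounds.size() - 1) != lastRound) {
            rounds.add(lastRound);
        }
        return rounds;
    }
}
```

Under these assumptions, `checkpointRounds(0, 3, 5)` gives rounds 0, 3 and 4: the interval selects 0 and 3, and round 4 is appended as the final round.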
-