Class ExternalCheckpointManager

java.lang.Object
ml.dmlc.xgboost4j.java.ExternalCheckpointManager

public class ExternalCheckpointManager extends Object
This class provides the methods required to manage the state of a training process through external checkpoints. The training state is stored in a distributed file system as a set of versioned UBJ (Universal Binary JSON) model files. The class provides methods for saving, loading and cleaning up checkpoints.
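The storage model described above can be illustrated with a small, self-contained sketch. It assumes, for illustration only, that each checkpoint version is stored as a file named `<version>.ubj` directly under the checkpoint path. The class `CheckpointLayout` and its methods are hypothetical and are not part of XGBoost4J.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Hypothetical model of a checkpoint directory holding one UBJ model file per
 * version. The "<version>.ubj" naming is an assumption of this sketch.
 */
final class CheckpointLayout {
  // Matches names such as "12.ubj" and captures the numeric version.
  private static final Pattern VERSION_FILE = Pattern.compile("(\\d+)\\.ubj");

  private CheckpointLayout() {}

  /** File name this sketch uses for a given checkpoint version. */
  static String fileNameFor(int version) {
    return version + ".ubj";
  }

  /** Versions present in a directory listing; unrelated file names are skipped. */
  static List<Integer> versionsIn(List<String> fileNames) {
    List<Integer> versions = new ArrayList<>();
    for (String name : fileNames) {
      Matcher m = VERSION_FILE.matcher(name);
      if (m.matches()) {
        versions.add(Integer.parseInt(m.group(1)));
      }
    }
    return versions;
  }
}
```

Keeping the version number in the file name means the set of available checkpoints can be recovered from a plain directory listing, without a separate index file.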
  • Constructor Summary

    Constructors
    Constructor
        Description
    ExternalCheckpointManager(String checkpointPath, org.apache.hadoop.fs.FileSystem fs)
        This constructor creates a new External Checkpoint Manager at the specified path in the specified file system.
  • Method Summary

    Modifier and Type   Method
        Description
    void                cleanPath()
        This method cleans all the directories and files that are present in the checkpoint path.
    void                cleanUpHigherVersions(int currentRound)
        This method cleans up all the checkpoint versions that are higher than the current round.
    List<Integer>       getCheckpointRounds(int firstRound, int checkpointInterval, int numOfRounds)
        Get a list of iterations that need checkpointing.
    Booster             loadCheckpointAsBooster()
        Read the checkpoint from the checkpoint path.
    void                updateCheckpoint(Booster boosterToCheckpoint)
        This method saves the booster as the latest checkpoint version and deletes all the previous versions of the checkpoint.

    Methods inherited from class java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
  • Constructor Details

    • ExternalCheckpointManager

      public ExternalCheckpointManager(String checkpointPath, org.apache.hadoop.fs.FileSystem fs) throws XGBoostError
      This constructor creates a new External Checkpoint Manager at the specified path in the specified file system.
      Parameters:
      checkpointPath - The directory path where checkpoints will be stored.
      fs - The file system to use for storing checkpoints.
      Throws:
      XGBoostError - thrown if the checkpoint path is null or empty.
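The sketch below is a hypothetical local stand-in for the constructor's argument check. It uses `java.nio.file.FileSystem` in place of `org.apache.hadoop.fs.FileSystem` and throws `IllegalArgumentException` where the real constructor throws `XGBoostError`. The class `LocalCheckpointDir` is an illustration, not the library's implementation.

```java
import java.nio.file.FileSystem;
import java.nio.file.Path;
import java.util.Objects;

/**
 * Hypothetical local stand-in for the ExternalCheckpointManager constructor.
 * java.nio.file.FileSystem replaces org.apache.hadoop.fs.FileSystem, and
 * IllegalArgumentException replaces XGBoostError.
 */
final class LocalCheckpointDir {
  private final Path dir;

  LocalCheckpointDir(String checkpointPath, FileSystem fs) {
    // Mirrors the documented rule: a null or empty checkpoint path is rejected.
    if (checkpointPath == null || checkpointPath.isEmpty()) {
      throw new IllegalArgumentException("checkpointPath must be non-null and non-empty");
    }
    // Rejecting a null file system is an extra check added by this sketch.
    this.dir = Objects.requireNonNull(fs, "fs").getPath(checkpointPath);
  }

  /** Directory under which this sketch would store checkpoint files. */
  Path dir() {
    return dir;
  }
}
```

Validating the path in the constructor means a misconfigured manager fails immediately, rather than at the first save or load during training.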
  • Method Details

    • cleanPath

      public void cleanPath() throws IOException
      This method cleans all the directories and files that are present in the checkpoint path.
      Throws:
      IOException - Any exception that occurs when deleting the checkpoint path.
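As an illustration of what cleaning the checkpoint path involves, the following sketch deletes a local directory tree with `java.nio.file`. It assumes that the checkpoint path itself is removed along with its contents, and it treats a missing path as already clean. Both are choices of this hypothetical `LocalCleanPath` helper rather than documented behavior.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** Hypothetical local analogue of cleanPath(): removes a directory tree. */
final class LocalCleanPath {
  private LocalCleanPath() {}

  /** Deletes dir and everything beneath it; does nothing if dir does not exist. */
  static void cleanPath(Path dir) throws IOException {
    if (Files.notExists(dir)) {
      return; // Sketch choice: a missing path counts as already clean.
    }
    List<Path> paths;
    try (Stream<Path> walk = Files.walk(dir)) {
      // Reverse order places children before the directories that contain them.
      paths = walk.sorted(Comparator.reverseOrder()).collect(Collectors.toList());
    }
    for (Path p : paths) {
      Files.delete(p);
    }
  }
}
```

Deleting in reverse path order matters because `Files.delete` refuses to remove a directory that still has entries.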
    • loadCheckpointAsBooster

      public Booster loadCheckpointAsBooster() throws IOException, XGBoostError
      Read the checkpoint from the checkpoint path. Once the checkpoint path is read, the latest of all the checkpoint versions is loaded into a booster, which can then be used, for example, to resume training or to make predictions.
      Returns:
      The booster loaded from the latest checkpoint version.
      Throws:
      IOException - Any exception that occurs when reading the checkpoint path.
      XGBoostError - Any exception that occurs when loading the model into the booster.
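Selecting the latest version should compare version numbers numerically: sorted as strings, `10.ubj` would come before `9.ubj`. The sketch below finds the latest checkpoint file in a local directory under the same assumed `<version>.ubj` naming and returns `Optional.empty()` when no checkpoint exists. The class `LatestCheckpoint` is hypothetical, and the real method additionally loads the selected model into a `Booster`.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Comparator;
import java.util.Optional;
import java.util.regex.Pattern;
import java.util.stream.Stream;

/** Hypothetical helper: finds the newest "<version>.ubj" file in a local directory. */
final class LatestCheckpoint {
  private static final Pattern VERSION_FILE = Pattern.compile("\\d+\\.ubj");

  private LatestCheckpoint() {}

  /** Returns the file with the numerically highest version, or empty if there is none. */
  static Optional<Path> latestCheckpointFile(Path dir) throws IOException {
    try (Stream<Path> files = Files.list(dir)) {
      return files
          .filter(Files::isRegularFile)
          .filter(p -> VERSION_FILE.matcher(p.getFileName().toString()).matches())
          // Compare parsed integers, not strings, so that 10 ranks above 9.
          .max(Comparator.comparingInt(LatestCheckpoint::versionOf));
    }
  }

  /** Parses the numeric prefix of a name already known to match VERSION_FILE. */
  private static int versionOf(Path p) {
    String name = p.getFileName().toString();
    return Integer.parseInt(name.substring(0, name.length() - ".ubj".length()));
  }
}
```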
    • updateCheckpoint

      public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError
      This method saves the booster as the latest checkpoint version and deletes all the previous versions of the checkpoint.
      Parameters:
      boosterToCheckpoint - The booster object that is to be checkpointed and saved as a model file.
      Throws:
      IOException - Any exception that occurs when writing the model file to the checkpoint path.
      XGBoostError - Any exception that occurs when saving the model from the booster.
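The update step can be sketched as "write the new version, then delete the older ones". In this hypothetical `LocalUpdateCheckpoint` helper, the version number and the serialized model bytes are passed in explicitly, whereas the real `updateCheckpoint` takes only a `Booster`. The `<version>.ubj` naming is again an assumption.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;

/** Hypothetical local analogue of updateCheckpoint(Booster). */
final class LocalUpdateCheckpoint {
  private LocalUpdateCheckpoint() {}

  /** Writes modelBytes as the given version, then deletes all lower versions. */
  static Path updateCheckpoint(Path dir, int version, byte[] modelBytes) throws IOException {
    Files.createDirectories(dir);
    // Write the new version first, so it already exists if the clean-up below fails.
    Path current = Files.write(dir.resolve(version + ".ubj"), modelBytes);
    List<Path> stale = new ArrayList<>();
    try (Stream<Path> files = Files.list(dir)) {
      files.forEach(p -> {
        Integer v = versionOf(p);
        if (v != null && v < version) {
          stale.add(p);
        }
      });
    }
    for (Path p : stale) {
      Files.deleteIfExists(p);
    }
    return current;
  }

  /** Parses "<n>.ubj" into n, or returns null for any other file name. */
  private static Integer versionOf(Path p) {
    String name = p.getFileName().toString();
    if (!name.matches("\\d+\\.ubj")) {
      return null;
    }
    return Integer.parseInt(name.substring(0, name.length() - ".ubj".length()));
  }
}
```

Writing before deleting means that a failure while removing old files never leaves the directory without the newly saved checkpoint; the cost is briefly storing two versions at once.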
    • cleanUpHigherVersions

      public void cleanUpHigherVersions(int currentRound) throws IOException
      This method cleans up all the checkpoint versions that are higher than the current round. This is useful when multiple training instances write to the same checkpoint path, to make sure that only the checkpoints from the current training instance are retained.
      Parameters:
      currentRound - The current round of training.
      Throws:
      IOException - Any exception that occurs when deleting the checkpoint files.
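The selection rule for stale versions is simple enough to show as a pure function. This sketch assumes that a checkpoint's version number equals the training round at which it was saved; `HigherVersions` is a hypothetical helper, not part of XGBoost4J. Removing these versions matters because a loader that picks the highest version, as `loadCheckpointAsBooster` is documented to do, would otherwise pick a checkpoint newer than the current round.

```java
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

/** Hypothetical helper: decides which stored versions cleanUpHigherVersions would remove. */
final class HigherVersions {
  private HigherVersions() {}

  /** Selects the stored versions that are strictly higher than currentRound, in ascending order. */
  static List<Integer> toDelete(Collection<Integer> versions, int currentRound) {
    return versions.stream()
        .filter(v -> v > currentRound)
        .sorted()
        .collect(Collectors.toList());
  }
}
```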
    • getCheckpointRounds

      public List<Integer> getCheckpointRounds(int firstRound, int checkpointInterval, int numOfRounds) throws IOException
      Get a list of iterations that need checkpointing.
      Parameters:
      firstRound - The first round of training.
      checkpointInterval - The interval at which checkpoints are to be saved.
      numOfRounds - The number of rounds to be trained.
      Returns:
      A list of integer rounds that need checkpointing.
      Throws:
      IOException - Any exception that occurs when getting the list of rounds.
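The exact rule that `getCheckpointRounds` applies is not documented here, so the following is only one plausible schedule, written to show how the three parameters interact: it checkpoints every `checkpointInterval` rounds after `firstRound` and always includes the final round, `firstRound + numOfRounds - 1`. The class `CheckpointSchedule`, its treatment of a non-positive interval as "checkpointing disabled", and its argument check are all assumptions of this sketch.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical checkpoint schedule; not necessarily the rule XGBoost4J uses. */
final class CheckpointSchedule {
  private CheckpointSchedule() {}

  /**
   * Every checkpointInterval-th round after firstRound, plus the last round,
   * within [firstRound, firstRound + numOfRounds).
   */
  static List<Integer> checkpointRounds(int firstRound, int checkpointInterval, int numOfRounds) {
    if (numOfRounds <= 0) {
      throw new IllegalArgumentException("numOfRounds must be positive");
    }
    List<Integer> rounds = new ArrayList<>();
    if (checkpointInterval <= 0) {
      return rounds; // Sketch choice: a non-positive interval disables checkpointing.
    }
    int lastRound = firstRound + numOfRounds - 1;
    for (int r = firstRound + checkpointInterval; r < lastRound; r += checkpointInterval) {
      rounds.add(r);
    }
    rounds.add(lastRound); // Always checkpoint the final round.
    return rounds;
  }
}
```

For example, `checkpointRounds(0, 3, 10)` yields `[3, 6, 9]` under this rule, and `checkpointRounds(0, 4, 10)` yields `[4, 8, 9]`.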