Package ml.dmlc.xgboost4j.java
Class XGBoost
java.lang.Object
ml.dmlc.xgboost4j.java.XGBoost
Trainer for XGBoost.
- Author:
- hzx

Field Summary
Fields: MAXIMIZ_METRICES

Constructor Summary
Constructors: XGBoost()

Method Summary

Modifier and Type | Method | Description
--- | --- | ---
static String[] | crossValidation(DMatrix data, Map<String, Object> params, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) | Cross-validation with given parameters.
static boolean | isMaximizeEvaluation(String evalInfo, String[] evalNames, Map<String, Object> params) | Decides whether the evaluation metrics are to be maximized or not.
static Booster | loadModel(byte[] buffer) | Load a new Booster model from a byte array buffer.
static Booster | loadModel(InputStream in) | Load a new Booster model from a file opened as an input stream.
static Booster | loadModel(String modelPath) | Load a model from modelPath.
static Booster | train(DMatrix dtrain, Map<String, Object> params, int round, Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval, int earlyStoppingRound) | Train a booster given parameters.
static Booster | train(DMatrix dtrain, Map<String, Object> params, int round, Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval, int earlyStoppingRounds, Booster booster) | Train a booster given parameters.
static Booster | train(DMatrix dtrain, Map<String, Object> params, int round, Map<String, DMatrix> watches, IObjective obj, IEvaluation eval) | Train a booster given parameters.
static Booster | trainAndSaveCheckpoint(DMatrix dtrain, Map<String, Object> params, int numRounds, Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval, int earlyStoppingRounds, Booster booster, int checkpointInterval, String checkpointPath, org.apache.hadoop.fs.FileSystem fs) |

Field Details

MAXIMIZ_METRICES

Constructor Details

XGBoost
public XGBoost()

Method Details

loadModel
public static Booster loadModel(String modelPath) throws XGBoostError
Load a model from modelPath.
- Parameters:
modelPath - booster model path (a model generated by booster.saveModel)
- Throws:
XGBoostError - native error

loadModel
public static Booster loadModel(InputStream in) throws XGBoostError, IOException
Load a new Booster model from a file opened as an input stream. The input stream is assumed to contain exactly one XGBoost model. This can be used to load existing booster models saved by other XGBoost bindings.
- Parameters:
in - The input stream of the file; it will be closed after this function call.
- Returns:
- The created booster.
- Throws:
XGBoostError
IOException

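The contract above (the stream holds a single model and is closed by the call) can be illustrated with a plain-JDK sketch that drains a stream into a byte array, the form accepted by the `loadModel(byte[] buffer)` overload. This is an assumption about how such a loader might work, not XGBoost4J's code; `StreamToBytes` and `readFullyAndClose` are hypothetical names.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;

class StreamToBytes {
    // Hypothetical helper: reads the whole stream into memory and always closes it,
    // mirroring the documented behaviour that the stream is closed after the call.
    static byte[] readFullyAndClose(InputStream in) throws IOException {
        try (InputStream input = in;
             ByteArrayOutputStream out = new ByteArrayOutputStream()) {
            byte[] buf = new byte[8192];
            int n;
            // Copy chunks until the stream reports end-of-data (-1).
            while ((n = input.read(buf)) != -1) {
                out.write(buf, 0, n);
            }
            return out.toByteArray();
        }
    }
}
```

With the library on the classpath, the resulting bytes could then be passed to `XGBoost.loadModel(byte[])`.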
loadModel
public static Booster loadModel(byte[] buffer) throws XGBoostError, IOException
Load a new Booster model from a byte array buffer. The array is assumed to contain exactly one XGBoost model. This can be used to load existing booster models saved by other XGBoost bindings.
- Parameters:
buffer - The byte contents of the booster.
- Returns:
- The created booster.
- Throws:
XGBoostError
IOException

train
public static Booster train(DMatrix dtrain, Map<String, Object> params, int round, Map<String, DMatrix> watches, IObjective obj, IEvaluation eval) throws XGBoostError
Train a booster given parameters.
- Parameters:
dtrain - Training data.
params - Booster parameters.
round - Number of boosting iterations.
watches - A map from names to matrices evaluated during training; this lets the user monitor performance on the validation set.
obj - Customized objective.
eval - Customized evaluation.
- Returns:
- The trained booster.
- Throws:
XGBoostError

train
public static Booster train(DMatrix dtrain, Map<String, Object> params, int round, Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval, int earlyStoppingRound) throws XGBoostError
Train a booster given parameters.
- Parameters:
dtrain - Training data.
params - Booster parameters.
round - Number of boosting iterations.
watches - A map from names to matrices evaluated during training; this lets the user monitor performance on the validation set.
metrics - Array receiving the evaluation metric of each matrix in watches at each iteration.
earlyStoppingRound - If non-zero, training stops after the specified number of consecutive rounds in which any evaluation metric moves in the unwanted direction (for example, an increasing error).
obj - Customized objective.
eval - Customized evaluation.
- Returns:
- The trained booster.
- Throws:
XGBoostError

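The early-stopping rule described for `earlyStoppingRound` can be sketched over the history of a single metric. This is a hypothetical helper written from the description above, not XGBoost4J's implementation, which may instead compare each round with the best round seen so far; `EarlyStopSketch` and `stopIteration` are illustrative names, and treating negative values like zero is an assumption.

```java
class EarlyStopSketch {
    // Returns the iteration at which training would stop, or -1 if it never stops.
    // A round counts as "worse" when the metric moves in the unwanted direction:
    // up for minimized metrics (e.g. error), down for maximized ones (e.g. AUC).
    static int stopIteration(float[] history, int earlyStoppingRound, boolean maximize) {
        if (earlyStoppingRound <= 0) {
            return -1; // zero (and, by assumption, negative values) disables early stopping
        }
        int worseStreak = 0;
        for (int i = 1; i < history.length; i++) {
            boolean worse = maximize ? history[i] < history[i - 1]
                                     : history[i] > history[i - 1];
            // Any round that is not worse (including a tie) breaks the streak.
            worseStreak = worse ? worseStreak + 1 : 0;
            if (worseStreak >= earlyStoppingRound) {
                return i;
            }
        }
        return -1;
    }
}
```

Passing the direction explicitly keeps the helper usable for both error-style and score-style metrics.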
trainAndSaveCheckpoint
public static Booster trainAndSaveCheckpoint(DMatrix dtrain, Map<String, Object> params, int numRounds, Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval, int earlyStoppingRounds, Booster booster, int checkpointInterval, String checkpointPath, org.apache.hadoop.fs.FileSystem fs) throws XGBoostError, IOException
- Throws:
XGBoostError
IOException

train
public static Booster train(DMatrix dtrain, Map<String, Object> params, int round, Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval, int earlyStoppingRounds, Booster booster) throws XGBoostError
Train a booster given parameters.
- Parameters:
dtrain - Training data.
params - Booster parameters.
round - Number of boosting iterations.
watches - A map from names to matrices evaluated during training; this lets the user monitor performance on the validation set.
metrics - Array receiving the evaluation metric of each matrix in watches at each iteration.
earlyStoppingRounds - If non-zero, training stops after the specified number of consecutive rounds in which any evaluation metric moves in the unwanted direction.
obj - Customized objective.
eval - Customized evaluation.
booster - If null, training starts from scratch; otherwise training continues from this existing booster.
- Returns:
- The trained booster.
- Throws:
XGBoostError

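The `metrics` argument is described as holding one evaluation value per watched matrix per iteration. The sketch below assumes a `[watch][iteration]` layout allocated by the caller as `new float[watches.size()][round]`, and shows how a caller might scan one row for its best iteration. Both the layout and the helper names (`MetricsLayoutSketch`, `allocate`, `bestIteration`) are assumptions to check against the library.

```java
class MetricsLayoutSketch {
    // Assumed layout: metrics[w][i] is the value for the w-th watched matrix at iteration i.
    static float[][] allocate(int numWatches, int numRounds) {
        return new float[numWatches][numRounds];
    }

    // Index of the best iteration for one watched matrix, looking only at the first
    // roundsRun entries; ties keep the earliest iteration.
    static int bestIteration(float[][] metrics, int watch, int roundsRun, boolean maximize) {
        float[] row = metrics[watch];
        int limit = Math.min(roundsRun, row.length);
        int best = 0;
        for (int i = 1; i < limit; i++) {
            boolean better = maximize ? row[i] > row[best] : row[i] < row[best];
            if (better) {
                best = i;
            }
        }
        return best;
    }
}
```

The `roundsRun` bound exists because, if training stops early, trailing entries may be left at Java's default of `0.0f` (an assumption), and an unbounded scan would mistake those zeros for a perfect error score.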
isMaximizeEvaluation
public static boolean isMaximizeEvaluation(String evalInfo, String[] evalNames, Map<String, Object> params)
Decides whether the evaluation metrics are to be maximized or not.
- Parameters:
evalInfo - The evaluation log string from which the metric name is inferred.
evalNames - The names of the evaluation matrices.
params - The parameters that contain information regarding whether the evaluation metrics are to be maximized or not.
- Returns:
- True if the evaluation metrics are to be maximized, false otherwise.

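The inference described above can be sketched as: honour an explicit setting in `params` if present, otherwise extract the metric name from the log line and check it against metrics where larger is better. Everything specific here is an assumption for illustration: the log format `"[iter]\t<evalName>-<metric>:<value>"`, the `maximize_evaluation_metrics` key, and the metric set. `MaximizeSketch.inferMaximize` is a hypothetical name, not the library's method.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

class MaximizeSketch {
    // Assumed, non-exhaustive set of metrics where larger values are better.
    private static final Set<String> MAXIMIZED =
            new HashSet<>(Arrays.asList("auc", "aucpr", "map", "ndcg", "pre"));

    static boolean inferMaximize(String evalInfo, String[] evalNames, Map<String, Object> params) {
        // Assumed explicit override key; an explicit setting wins over inference.
        Object flag = params.get("maximize_evaluation_metrics");
        if (flag != null) {
            return Boolean.parseBoolean(flag.toString());
        }
        // Assumed log format: "[iter]\t<evalName>-<metric>:<value>\t..."
        String[] fields = evalInfo.split("\t");
        if (fields.length < 2 || evalNames.length == 0) {
            return false;
        }
        String metric = fields[1];
        String prefix = evalNames[0] + "-";
        if (metric.startsWith(prefix)) {
            metric = metric.substring(prefix.length()); // "auc:0.91"
        }
        int colon = metric.indexOf(':');
        if (colon >= 0) {
            metric = metric.substring(0, colon);         // "auc"
        }
        int at = metric.indexOf('@');
        if (at >= 0) {
            metric = metric.substring(0, at);            // "ndcg@5" -> "ndcg"
        }
        return MAXIMIZED.contains(metric);
    }
}
```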
crossValidation
public static String[] crossValidation(DMatrix data, Map<String, Object> params, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) throws XGBoostError
Cross-validation with given parameters.
- Parameters:
data - Data to be trained.
params - Booster params.
round - Number of boosting iterations.
nfold - Number of folds in CV.
metrics - Evaluation metrics to be watched in CV.
obj - Customized objective (set to null if not used).
eval - Customized evaluation (set to null if not used).
- Returns:
- Evaluation history.
- Throws:
XGBoostError - native error

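Cross-validation with `nfold` folds holds out each fold once while training on the rest. The plain-JDK sketch below shows one common way to partition row indices: a seeded shuffle followed by round-robin assignment. `FoldSketch.splitFolds` is a hypothetical helper, and XGBoost4J's own fold assignment may differ.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

class FoldSketch {
    // Assigns row indices 0..nRows-1 to nfold folds after a seeded shuffle.
    static List<List<Integer>> splitFolds(int nRows, int nfold, long seed) {
        if (nfold < 2 || nfold > nRows) {
            throw new IllegalArgumentException("need 2 <= nfold <= nRows");
        }
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < nRows; i++) {
            rows.add(i);
        }
        Collections.shuffle(rows, new Random(seed));
        List<List<Integer>> folds = new ArrayList<>();
        for (int k = 0; k < nfold; k++) {
            folds.add(new ArrayList<>());
        }
        // Round-robin keeps fold sizes within one row of each other.
        for (int i = 0; i < rows.size(); i++) {
            folds.get(i % nfold).add(rows.get(i));
        }
        return folds;
    }
}
```

Each fold would then act once as the validation set, with the per-round metrics aggregated (typically averaged) across folds.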