optuna.integration
integration
模块包含用于将 Optuna 与外部机器学习框架集成的类。
对于 Optuna 支持的大多数 ML 框架,相应的 Optuna 集成类仅用于实现一个回调对象和函数,以符合框架特定的回调 API,并在模型训练的每个中间步骤中调用。这些回调函数在不同 ML 框架中实现的功能包括
使用
optuna.trial.Trial.report()
将中间模型分数报告回 Optuna 试验,根据
optuna.trial.Trial.should_prune()
的结果,通过引发optuna.TrialPruned()
来剪枝当前模型,以及将当前的试验编号等中间 Optuna 数据报告回框架,就像在
MLflowCallback
中所做的那样。
对于 scikit-learn,提供了一个集成的 OptunaSearchCV
估计器,它结合了 scikit-learn BaseEstimator 的功能并可以访问类级别的 Study
对象。
各集成模块的依赖项
我们总结了每个集成所需的依赖项。
集成模块 |
依赖项 |
---|---|
allennlp, torch, psutil, jsonnet |
|
botorch, gpytorch, torch |
|
catboost |
|
chainermn |
|
chainer |
|
cma |
|
distributed |
|
fastai |
|
keras |
|
lightgbm, scikit-learn |
|
lightgbm |
|
mlflow |
|
mxnet |
|
PyTorch Distributed |
torch |
PyTorch (Ignite) |
pytorch-ignite |
PyTorch (Lightning) |
pytorch-lightning |
scikit-learn, shap |
|
pandas, scipy, scikit-learn |
|
skorch |
|
tensorboard, tensorflow |
|
tensorflow, tensorflow-estimator |
|
tensorflow |
|
wandb |
|
xgboost |