optuna.integration

integration 模块包含用于将 Optuna 与外部机器学习框架集成的类。

注意

Optuna 第三方库的集成模块已开始从 Optuna 本身迁移到一个名为 optuna-integration 的包中。请查看该 仓库文档

对于 Optuna 支持的大多数 ML 框架,相应的 Optuna 集成类仅用于实现一个回调对象和函数,以符合框架特定的回调 API,并在模型训练的每个中间步骤中调用。这些回调函数在不同 ML 框架中实现的功能包括

  1. 使用 optuna.trial.Trial.report() 将中间模型分数报告回 Optuna 试验,

  2. 根据 optuna.trial.Trial.should_prune() 的结果,通过引发 optuna.TrialPruned() 来剪枝当前模型,以及

  3. 将当前的试验编号等中间 Optuna 数据报告回框架,就像在 MLflowCallback 中所做的那样。

对于 scikit-learn,提供了一个集成的 OptunaSearchCV 估计器,它结合了 scikit-learn BaseEstimator 的功能并可以访问类级别的 Study 对象。

各集成模块的依赖项

我们总结了每个集成所需的依赖项。

集成模块

依赖项

AllenNLP

allennlp, torch, psutil, jsonnet

BoTorch

botorch, gpytorch, torch

CatBoost

catboost

ChainerMN

chainermn

Chainer

chainer

pycma

cma

Dask

distributed

FastAI

fastai

Keras

keras

LightGBMTuner

lightgbm, scikit-learn

LightGBMPruningCallback

lightgbm

MLflow

mlflow

MXNet

mxnet

PyTorch Distributed

torch

PyTorch (Ignite)

pytorch-ignite

PyTorch (Lightning)

pytorch-lightning

SHAP

scikit-learn, shap

Scikit-learn

pandas, scipy, scikit-learn

SKorch

skorch

TensorBoard

tensorboard, tensorflow

TensorFlow

tensorflow, tensorflow-estimator

TensorFlow + Keras

tensorflow

Weights & Biases

wandb

XGBoost

xgboost