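Efficient Optimization Algorithms

Optuna makes hyperparameter optimization efficient by using state-of-the-art algorithms to sample hyperparameters and to prune unpromising trials early.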

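Sampling Algorithms

Samplers continually narrow down the search space using the record of suggested parameter values and evaluated objective values, converging on a region whose parameters yield better objective values. For a more detailed explanation of how samplers suggest parameters, see BaseSampler.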
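The idea of narrowing the search space from recorded results can be illustrated with a toy sketch in plain Python. This is not TPE or any other Optuna sampler; the function name `toy_history_sampler` and its parameters `n_startup` and `shrink` are hypothetical. It draws a few uniform samples first, then samples from a window around the best value seen so far and shrinks that window at every step.

```python
import random


def toy_history_sampler(objective, low, high, n_trials=30, n_startup=5,
                        shrink=0.8, seed=0):
    """Toy illustration only: sample near the best point recorded so far."""
    rng = random.Random(seed)
    history = []  # records of (objective value, suggested x)
    width = high - low
    for i in range(n_trials):
        if i < n_startup:
            # No useful history yet: explore the whole range uniformly.
            x = rng.uniform(low, high)
        else:
            # Use the record: centre a shrinking window on the best x so far.
            best_x = min(history)[1]
            width *= shrink
            window_low = max(low, best_x - width / 2)
            window_high = min(high, best_x + width / 2)
            x = rng.uniform(window_low, window_high)
        history.append((objective(x), x))
    return history


history = toy_history_sampler(lambda x: (x - 2.0) ** 2, -10.0, 10.0)
best_value, best_x = min(history)
print(f"best x = {best_x:.3f}, objective = {best_value:.5f}")
```

Real samplers such as TPESampler build a probabilistic model from the history rather than a fixed shrinking window, so this sketch only conveys the general principle.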

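Optuna provides several sampling algorithms, including TPESampler, RandomSampler, and CmaEsSampler.

The default sampler is TPESampler.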

Switching Samplers

import optuna

By default, Optuna uses TPESampler, as shown below.

study = optuna.create_study()
print(f"Sampler is {study.sampler.__class__.__name__}")
Sampler is TPESampler

To use a different sampler, for example RandomSampler or CmaEsSampler, pass it to create_study:

study = optuna.create_study(sampler=optuna.samplers.RandomSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")

study = optuna.create_study(sampler=optuna.samplers.CmaEsSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")
Sampler is RandomSampler
Sampler is CmaEsSampler

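Pruning Algorithms

Pruners automatically stop unpromising trials at the early stages of training (also known as automated early stopping). Currently, the pruners module is expected to be used only for single-objective optimization.

Optuna provides several pruning algorithms, including MedianPruner, SuccessiveHalvingPruner, and HyperbandPruner.

We use MedianPruner in most examples, although it is generally outperformed by SuccessiveHalvingPruner and HyperbandPruner, as shown in this benchmark result.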

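Activating Pruners

To turn on pruning, call report() and should_prune() after each step of the iterative training. report() records the intermediate objective value at that step. should_prune() decides whether to terminate a trial that does not meet a predefined condition.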
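To make the "predefined condition" concrete, here is a simplified plain-Python sketch of a median stopping rule for a minimization problem. It is not Optuna's implementation; the function `should_prune_median` and its arguments are hypothetical. The rule assumed here: once enough earlier trials have completed, stop the current trial if its best intermediate value so far is worse than the median of the values that earlier trials reported at the same step.

```python
from statistics import median


def should_prune_median(current, completed, step, n_startup_trials=5):
    """Simplified median stopping rule (minimization).

    current:   intermediate values reported so far by the running trial,
               indexed by step.
    completed: one list of intermediate values per earlier completed trial.
    """
    # Never prune until enough trials have completed to form a baseline.
    if len(completed) < n_startup_trials:
        return False
    # Only earlier trials that reached this step can be compared.
    peers = [values[step] for values in completed if len(values) > step]
    if not peers:
        return False
    best_so_far = min(current[: step + 1])
    return best_so_far > median(peers)


completed = [
    [0.9, 0.5, 0.3],
    [0.8, 0.6, 0.4],
    [0.7, 0.4, 0.2],
    [0.9, 0.7, 0.5],
    [0.8, 0.5, 0.3],
]
print(should_prune_median([0.95, 0.65], completed, step=1))  # worse than median
print(should_prune_median([0.95, 0.45], completed, step=1))  # better than median
```

MedianPruner's default of five startup trials is consistent with the run later in this tutorial, where trials 0 through 4 all finish before any trial is pruned.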

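We recommend using the integration modules for major machine learning frameworks. The complete list is in integration, and use cases are available in optuna-examples.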

import logging
import sys

import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection


def objective(trial):
    iris = sklearn.datasets.load_iris()
    classes = list(set(iris.target))
    train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
        iris.data, iris.target, test_size=0.25, random_state=0
    )

    alpha = trial.suggest_float("alpha", 1e-5, 1e-1, log=True)
    clf = sklearn.linear_model.SGDClassifier(alpha=alpha)

    for step in range(100):
        clf.partial_fit(train_x, train_y, classes=classes)

        # Report intermediate objective value.
        intermediate_value = 1.0 - clf.score(valid_x, valid_y)
        trial.report(intermediate_value, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.TrialPruned()

    return 1.0 - clf.score(valid_x, valid_y)

Set the median stopping rule as the pruning condition.

# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
A new study created in memory with name: no-name-ad883395-5901-41c3-85a2-08657acb0c1a
Trial 0 finished with value: 0.39473684210526316 and parameters: {'alpha': 5.0911054691409205e-05}. Best is trial 0 with value: 0.39473684210526316.
Trial 1 finished with value: 0.23684210526315785 and parameters: {'alpha': 0.005248996754035697}. Best is trial 1 with value: 0.23684210526315785.
Trial 2 finished with value: 0.13157894736842102 and parameters: {'alpha': 0.0002683583933567496}. Best is trial 2 with value: 0.13157894736842102.
Trial 3 finished with value: 0.13157894736842102 and parameters: {'alpha': 3.185663178106781e-05}. Best is trial 2 with value: 0.13157894736842102.
Trial 4 finished with value: 0.052631578947368474 and parameters: {'alpha': 1.6431435839092097e-05}. Best is trial 4 with value: 0.052631578947368474.
Trial 5 pruned.
Trial 6 pruned.
Trial 7 pruned.
Trial 8 pruned.
Trial 9 pruned.
Trial 10 pruned.
Trial 11 pruned.
Trial 12 pruned.
Trial 13 pruned.
Trial 14 pruned.
Trial 15 pruned.
Trial 16 pruned.
Trial 17 pruned.
Trial 18 pruned.
Trial 19 pruned.

As you can see, several trials were pruned (stopped) before they finished all of the iterations. The message format is "Trial <Trial Number> pruned.".

Which Sampler and Pruner Should Be Used?

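For tasks other than deep learning, see the benchmark results on the optuna/optuna wiki page "Benchmarks with Kurobako" for how the samplers and pruners compare.

Note, however, that this benchmark does not cover deep learning. For deep learning tasks, consult the table below, which comes from Ozaki et al., "Hyperparameter Optimization Methods: Overview and Characteristics", IEICE Trans., Vol. J103-D, No. 9, pp. 615-631, 2020 (written in Japanese).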

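Parallel compute resource   Categorical/conditional hyperparameters   Recommended algorithms
Limited                     No                                        TPE. GP-EI if the search space is low-dimensional and continuous.
Limited                     Yes                                       TPE. GP-EI if the search space is low-dimensional and continuous.
Sufficient                  No                                        CMA-ES, random search
Sufficient                  Yes                                       Random search or genetic algorithm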

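Integration Modules for Pruning

To implement pruning more simply, Optuna provides integration modules for a number of machine learning libraries.

For the complete list of Optuna's integration modules, see integration.

For example, LightGBMPruningCallback introduces pruning without directly changing the logic of the training iteration. (See also the example for the full script.)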

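import lightgbm as lgb
import optuna.integration

# `trial`, `param`, `dtrain`, and `dvalid` are assumed to be defined in the surrounding objective function.
# The metric name must match one that LightGBM reports for the validation set.
pruning_callback = optuna.integration.LightGBMPruningCallback(trial, 'validation-error')
gbm = lgb.train(param, dtrain, valid_sets=[dvalid], callbacks=[pruning_callback])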
