Efficient Optimization Algorithms

Optuna enables efficient hyperparameter optimization by adopting state-of-the-art algorithms for sampling hyperparameters and efficiently pruning unpromising trials.

Sampling Algorithms

Samplers basically continually narrow down the search space using the records of suggested parameter values and evaluated objective values, leading to an optimal search space which gives off parameters yielding better objective values. For a more detailed explanation of how samplers suggest parameters, see BaseSampler.

Optuna provides a number of sampling algorithms, such as TPESampler, RandomSampler, and CmaEsSampler (see optuna.samplers for the full list).

The default sampler is TPESampler.

Switching Samplers

import optuna

By default, Optuna uses TPESampler as follows:

study = optuna.create_study()
print(f"Sampler is {study.sampler.__class__.__name__}")
Sampler is TPESampler

If you want to use a different sampler, for example RandomSampler or CmaEsSampler:

study = optuna.create_study(sampler=optuna.samplers.RandomSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")

study = optuna.create_study(sampler=optuna.samplers.CmaEsSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")
Sampler is RandomSampler
Sampler is CmaEsSampler

Pruning Algorithms

Pruners automatically stop unpromising trials at the early stages of the training (a.k.a. automated early stopping). Currently the pruners module is expected to be used only for single-objective optimization.

Optuna provides several pruning algorithms, including MedianPruner, SuccessiveHalvingPruner, and HyperbandPruner (see optuna.pruners for the full list).

We use MedianPruner in most examples, though basically it is outperformed by SuccessiveHalvingPruner and HyperbandPruner, as in this benchmark result.

Activating Pruners

To turn on the pruning feature, you need to call report() and should_prune() after each step of the iterative training. report() periodically monitors the intermediate objective values. should_prune() decides whether to terminate a trial that does not meet a predefined condition.

We would recommend using the integration modules for major machine learning frameworks. The exhaustive list can be found in integration, and use cases are available in optuna-examples.

import logging
import sys

import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection


def objective(trial):
    iris = sklearn.datasets.load_iris()
    classes = list(set(iris.target))
    train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
        iris.data, iris.target, test_size=0.25, random_state=0
    )

    alpha = trial.suggest_float("alpha", 1e-5, 1e-1, log=True)
    clf = sklearn.linear_model.SGDClassifier(alpha=alpha)

    for step in range(100):
        clf.partial_fit(train_x, train_y, classes=classes)

        # Report intermediate objective value.
        intermediate_value = 1.0 - clf.score(valid_x, valid_y)
        trial.report(intermediate_value, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.TrialPruned()

    return 1.0 - clf.score(valid_x, valid_y)

Set up the median stopping rule as the pruning condition.

# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
A new study created in memory with name: no-name-2bda921e-a832-40c1-a630-c2f495c462c6
Trial 0 finished with value: 0.13157894736842102 and parameters: {'alpha': 0.001106574928284306}. Best is trial 0 with value: 0.13157894736842102.
Trial 1 finished with value: 0.10526315789473684 and parameters: {'alpha': 2.9856925293062223e-05}. Best is trial 1 with value: 0.10526315789473684.
Trial 2 finished with value: 0.02631578947368418 and parameters: {'alpha': 0.004538198980597824}. Best is trial 2 with value: 0.02631578947368418.
Trial 3 finished with value: 0.39473684210526316 and parameters: {'alpha': 0.07934944353904619}. Best is trial 2 with value: 0.02631578947368418.
Trial 4 finished with value: 0.10526315789473684 and parameters: {'alpha': 0.00023121196253385567}. Best is trial 2 with value: 0.02631578947368418.
Trial 5 pruned.
Trial 6 finished with value: 0.23684210526315785 and parameters: {'alpha': 0.043141004056920194}. Best is trial 2 with value: 0.02631578947368418.
Trial 7 pruned.
Trial 8 pruned.
Trial 9 pruned.
Trial 10 pruned.
Trial 11 finished with value: 0.052631578947368474 and parameters: {'alpha': 1.1190741610989776e-05}. Best is trial 2 with value: 0.02631578947368418.
Trial 12 pruned.
Trial 13 pruned.
Trial 14 finished with value: 0.3157894736842105 and parameters: {'alpha': 0.00019949202646456292}. Best is trial 2 with value: 0.02631578947368418.
Trial 15 pruned.
Trial 16 pruned.
Trial 17 pruned.
Trial 18 pruned.
Trial 19 finished with value: 0.26315789473684215 and parameters: {'alpha': 0.000573677493078128}. Best is trial 2 with value: 0.02631578947368418.

As you can see, several of the trials were pruned (stopped) before they finished all of the iterations. The format of the message is "Trial <Trial Number> pruned.".
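The median stopping rule itself is simple arithmetic. The following framework-free sketch (the function name and data layout are hypothetical, and it omits MedianPruner options such as n_startup_trials and n_warmup_steps) shows the core decision:

```python
from statistics import median


def should_prune_median(intermediate, completed_histories, step):
    """Simplified median rule: prune if the current trial's intermediate value
    at `step` is worse (higher, for minimization) than the median of the
    previous trials' intermediate values at the same step."""
    previous = [h[step] for h in completed_histories if step < len(h)]
    if not previous:
        return False
    return intermediate > median(previous)


# Intermediate error curves of three completed trials.
histories = [[0.9, 0.5, 0.2], [0.8, 0.6, 0.3], [0.95, 0.7, 0.4]]
print(should_prune_median(0.85, histories, step=1))  # worse than median 0.6 -> True
print(should_prune_median(0.4, histories, step=1))   # better than median 0.6 -> False
```

The real MedianPruner additionally waits for a number of startup trials and warm-up steps before it starts pruning.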

Which Sampler and Pruner Should be Used?

From the benchmark results which can be found in optuna/optuna - wiki "Benchmarks with Kurobako", at least for not-deep-learning tasks, we would say that:

For RandomSampler, MedianPruner is the best.
For TPESampler, HyperbandPruner is the best.

However, note that the benchmark does not include deep learning tasks. For deep learning tasks, consult the following table, which is quoted from Ozaki et al., "Hyperparameter Optimization Methods: Overview and Characteristics" (in Japanese), IEICE Trans., Vol. J103-D, No. 9, pp. 615-631, 2020.

| Parallel Compute Resource | Categorical/Conditional Hyperparameters | Recommended Algorithms |
|---|---|---|
| Limited | No | TPE. GP-EI if the search space is low-dimensional and continuous. |
| Limited | Yes | TPE. GP-EI if the search space is low-dimensional and continuous. |
| Sufficient | No | CMA-ES, Random Search |
| Sufficient | Yes | Random Search or Genetic Algorithm |

Integration Modules for Pruning

To implement the pruning mechanism in a much simpler way, Optuna provides integration modules for the following libraries.

For the complete list of Optuna's integration modules, see integration.

For example, LightGBMPruningCallback introduces pruning without directly changing the logic of the training iteration. (See also the example for the entire script.)

import lightgbm as lgb
import optuna.integration

# `trial`, `param`, `dtrain`, and `dvalid` are assumed to be defined inside
# an objective function, as in the full example linked above.
pruning_callback = optuna.integration.LightGBMPruningCallback(trial, 'validation-error')
gbm = lgb.train(param, dtrain, valid_sets=[dvalid], callbacks=[pruning_callback])

Total running time of the script: (0 minutes 1.632 seconds)

Gallery generated by Sphinx-Gallery