注意
前往末尾 下载完整的示例代码。
高效优化算法
Optuna 通过采用最先进的采样超参数和高效剪枝不那么有希望的试验的算法,实现了高效的超参数优化。
采样算法
采样器基本上会利用建议的参数值和评估的目标值记录,不断缩小搜索空间,从而找到能给出更好目标值的参数的最优搜索空间。有关采样器如何建议参数的更详细解释,请参见 BaseSampler。
Optuna 提供以下采样算法:
在
GridSampler中实现的网格搜索在
RandomSampler中实现的随机搜索在
TPESampler中实现的树状 Parzen 估计器算法在
CmaEsSampler中实现的基于 CMA-ES 的算法在
GPSampler中实现的基于高斯过程的算法在
PartialFixedSampler中实现的允许部分固定参数的算法在
NSGAIISampler中实现的非支配排序遗传算法 II在
QMCSampler中实现的准蒙特卡洛采样算法
默认采样器是 TPESampler。
切换采样器
import optuna
默认情况下,Optuna 如下所示使用 TPESampler。
study = optuna.create_study()
print(f"Sampler is {study.sampler.__class__.__name__}")
Sampler is TPESampler
如果你想使用不同的采样器,例如 RandomSampler 和 CmaEsSampler,
study = optuna.create_study(sampler=optuna.samplers.RandomSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")
study = optuna.create_study(sampler=optuna.samplers.CmaEsSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")
Sampler is RandomSampler
Sampler is CmaEsSampler
剪枝算法
Pruners 在训练的早期阶段自动停止那些不太有希望的试验(也称为自动提前停止)。目前 pruners 模块预计仅用于单目标优化。
Optuna 提供以下剪枝算法:
在
MedianPruner中实现的其中位数剪枝算法在
NopPruner中实现的非剪枝算法在
PatientPruner中实现的带有容差运行剪枝器的算法在
PercentilePruner中实现的按指定百分比剪枝试验的算法在
SuccessiveHalvingPruner中实现的异步连续减半算法在
HyperbandPruner中实现的 Hyperband 算法在
ThresholdPruner中实现的阈值剪枝算法在
WilcoxonPruner中实现的基于 Wilcoxon 符号秩检验 的剪枝算法
我们在大多数示例中使用 MedianPruner,尽管实际上 SuccessiveHalvingPruner 和 HyperbandPruner 的性能优于它,正如 此基准测试结果 所示。
激活剪枝器
要启用剪枝功能,您需要在每次迭代训练后调用 report() 和 should_prune()。 report() 定期监视中间目标值。 should_prune() 决定是否终止不满足预定义条件的试验。
我们建议使用主流机器学习框架的集成模块。独占列表为 integration,并且用例可在 optuna-examples 中找到。
import logging
import sys
import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection
def objective(trial):
iris = sklearn.datasets.load_iris()
classes = list(set(iris.target))
train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
iris.data, iris.target, test_size=0.25, random_state=0
)
alpha = trial.suggest_float("alpha", 1e-5, 1e-1, log=True)
clf = sklearn.linear_model.SGDClassifier(alpha=alpha)
for step in range(100):
clf.partial_fit(train_x, train_y, classes=classes)
# Report intermediate objective value.
intermediate_value = 1.0 - clf.score(valid_x, valid_y)
trial.report(intermediate_value, step)
# Handle pruning based on the intermediate value.
if trial.should_prune():
raise optuna.TrialPruned()
return 1.0 - clf.score(valid_x, valid_y)
将中位数停止规则设置为剪枝条件。
# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
A new study created in memory with name: no-name-2bda921e-a832-40c1-a630-c2f495c462c6
Trial 0 finished with value: 0.13157894736842102 and parameters: {'alpha': 0.001106574928284306}. Best is trial 0 with value: 0.13157894736842102.
Trial 1 finished with value: 0.10526315789473684 and parameters: {'alpha': 2.9856925293062223e-05}. Best is trial 1 with value: 0.10526315789473684.
Trial 2 finished with value: 0.02631578947368418 and parameters: {'alpha': 0.004538198980597824}. Best is trial 2 with value: 0.02631578947368418.
Trial 3 finished with value: 0.39473684210526316 and parameters: {'alpha': 0.07934944353904619}. Best is trial 2 with value: 0.02631578947368418.
Trial 4 finished with value: 0.10526315789473684 and parameters: {'alpha': 0.00023121196253385567}. Best is trial 2 with value: 0.02631578947368418.
Trial 5 pruned.
Trial 6 finished with value: 0.23684210526315785 and parameters: {'alpha': 0.043141004056920194}. Best is trial 2 with value: 0.02631578947368418.
Trial 7 pruned.
Trial 8 pruned.
Trial 9 pruned.
Trial 10 pruned.
Trial 11 finished with value: 0.052631578947368474 and parameters: {'alpha': 1.1190741610989776e-05}. Best is trial 2 with value: 0.02631578947368418.
Trial 12 pruned.
Trial 13 pruned.
Trial 14 finished with value: 0.3157894736842105 and parameters: {'alpha': 0.00019949202646456292}. Best is trial 2 with value: 0.02631578947368418.
Trial 15 pruned.
Trial 16 pruned.
Trial 17 pruned.
Trial 18 pruned.
Trial 19 finished with value: 0.26315789473684215 and parameters: {'alpha': 0.000573677493078128}. Best is trial 2 with value: 0.02631578947368418.
正如你所见,有几个试验在完成所有迭代之前就被剪枝(停止)了。消息的格式为 "Trial <Trial Number> pruned."。
应该使用哪个采样器和剪枝器?
根据可在 optuna/optuna - wiki “Benchmarks with Kurobako” 中找到的基准测试结果,至少对于非深度学习任务,我们可以说:
对于
RandomSampler,MedianPruner是最佳选择。对于
TPESampler,HyperbandPruner是最佳选择。
但是,请注意,该基准测试不包含深度学习。对于深度学习任务,请参阅下表。该表摘自 Ozaki 等人 2020 年发表的论文《超参数优化方法:概述和特征》(日文),该论文发表在 IEICE Trans,Vol.J103-D No.9 pp.615-631。
并行计算资源 |
分类/条件超参数 |
推荐算法 |
|---|---|---|
有限 |
否 |
TPE。如果搜索空间是低维且连续的,则为 GP-EI。 |
是 |
TPE。如果搜索空间是低维且连续的,则为 GP-EI。 |
|
充足 |
否 |
CMA-ES, 随机搜索 |
是 |
随机搜索或遗传算法 |
用于剪枝的集成模块
为了更简单地实现剪枝机制,Optuna 为以下库提供了集成模块。
Optuna 集成模块的完整列表,请参阅 integration。
例如,LightGBMPruningCallback 在不直接更改训练迭代逻辑的情况下引入了剪枝。(另请参阅 示例 以获取整个脚本。)
import optuna.integration
pruning_callback = optuna.integration.LightGBMPruningCallback(trial, 'validation-error')
gbm = lgb.train(param, dtrain, valid_sets=[dvalid], callbacks=[pruning_callback])
脚本总运行时间: (0 分 1.632 秒)