注意
跳转至末尾 下载完整示例代码。
高效优化算法
Optuna 采用最先进的算法来采样超参数并高效地剪枝无前景的试验,从而实现高效的超参数优化。
采样算法
采样器基本上通过建议参数值和评估目标值的记录不断缩小搜索空间,从而找到一个最优搜索空间,该空间能产生更好的目标值对应的参数。关于采样器如何建议参数的更详细解释,请参阅 BaseSampler
。
Optuna 提供以下采样算法
网格搜索,实现在
GridSampler
中随机搜索,实现在
RandomSampler
中树状 Parzen 估计器算法,实现在
TPESampler
中基于 CMA-ES 的算法,实现在
CmaEsSampler
中基于高斯过程的算法,实现在
GPSampler
中实现部分固定参数的算法,实现在
PartialFixedSampler
中非支配排序遗传算法 II,实现在
NSGAIISampler
中一种准蒙特卡洛采样算法,实现在
QMCSampler
中
默认的采样器是 TPESampler
。
切换采样器
import optuna
默认情况下,Optuna 使用 TPESampler
,如下所示。
study = optuna.create_study()
print(f"Sampler is {study.sampler.__class__.__name__}")
Sampler is TPESampler
如果您想使用不同的采样器,例如 RandomSampler
和 CmaEsSampler
,
study = optuna.create_study(sampler=optuna.samplers.RandomSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")
study = optuna.create_study(sampler=optuna.samplers.CmaEsSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")
Sampler is RandomSampler
Sampler is CmaEsSampler
剪枝算法
Pruners
(剪枝器)会在训练的早期阶段自动停止无前景的试验(又称自动化早期停止)。目前 pruners
模块预计仅用于单目标优化。
Optuna 提供以下剪枝算法
中位数剪枝算法,实现在
MedianPruner
中不执行剪枝的算法,实现在
NopPruner
中实现带容忍度剪枝器的算法,实现在
PatientPruner
中实现剪枝指定百分比试验的算法,实现在
PercentilePruner
中异步逐次减半算法,实现在
SuccessiveHalvingPruner
中Hyperband 算法,实现在
HyperbandPruner
中阈值剪枝算法,实现在
ThresholdPruner
中基于 Wilcoxon 符号秩检验 的剪枝算法,实现在
WilcoxonPruner
中
我们在大多数示例中使用 MedianPruner
,尽管基本上它不如 SuccessiveHalvingPruner
和 HyperbandPruner
,如 此基准测试结果 所示。
激活剪枝器
要启用剪枝功能,您需要在每次迭代训练步骤后调用 report()
和 should_prune()
。report()
定期监控中间目标值。should_prune()
决定不满足预定义条件的试验是否终止。
我们建议使用主要机器学习框架的集成模块。完整的列表在 integration
中,用例可在 optuna-examples 中找到。
import logging
import sys
import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection
def objective(trial):
iris = sklearn.datasets.load_iris()
classes = list(set(iris.target))
train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
iris.data, iris.target, test_size=0.25, random_state=0
)
alpha = trial.suggest_float("alpha", 1e-5, 1e-1, log=True)
clf = sklearn.linear_model.SGDClassifier(alpha=alpha)
for step in range(100):
clf.partial_fit(train_x, train_y, classes=classes)
# Report intermediate objective value.
intermediate_value = 1.0 - clf.score(valid_x, valid_y)
trial.report(intermediate_value, step)
# Handle pruning based on the intermediate value.
if trial.should_prune():
raise optuna.TrialPruned()
return 1.0 - clf.score(valid_x, valid_y)
将中位数停止规则设置为剪枝条件。
# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
A new study created in memory with name: no-name-ad883395-5901-41c3-85a2-08657acb0c1a
Trial 0 finished with value: 0.39473684210526316 and parameters: {'alpha': 5.0911054691409205e-05}. Best is trial 0 with value: 0.39473684210526316.
Trial 1 finished with value: 0.23684210526315785 and parameters: {'alpha': 0.005248996754035697}. Best is trial 1 with value: 0.23684210526315785.
Trial 2 finished with value: 0.13157894736842102 and parameters: {'alpha': 0.0002683583933567496}. Best is trial 2 with value: 0.13157894736842102.
Trial 3 finished with value: 0.13157894736842102 and parameters: {'alpha': 3.185663178106781e-05}. Best is trial 2 with value: 0.13157894736842102.
Trial 4 finished with value: 0.052631578947368474 and parameters: {'alpha': 1.6431435839092097e-05}. Best is trial 4 with value: 0.052631578947368474.
Trial 5 pruned.
Trial 6 pruned.
Trial 7 pruned.
Trial 8 pruned.
Trial 9 pruned.
Trial 10 pruned.
Trial 11 pruned.
Trial 12 pruned.
Trial 13 pruned.
Trial 14 pruned.
Trial 15 pruned.
Trial 16 pruned.
Trial 17 pruned.
Trial 18 pruned.
Trial 19 pruned.
正如您所见,有几个试验在完成所有迭代之前被剪枝(停止)了。消息的格式是 "Trial <Trial Number> pruned."
。
应该使用哪种采样器和剪枝器?
从 optuna/optuna - wiki “Benchmarks with Kurobako” 提供的基准测试结果来看,至少对于非深度学习任务,我们可以说
对于
RandomSampler
,MedianPruner
是最好的选择。对于
TPESampler
,HyperbandPruner
是最好的选择。
然而,请注意该基准测试不涉及深度学习。对于深度学习任务,请参考下表。该表来源于 Ozaki 等人的论文 Hyperparameter Optimization Methods: Overview and Characteristics, in IEICE Trans, Vol.J103-D No.9 pp.615-631, 2020,该论文是用日语撰写的。
并行计算资源 |
分类/条件超参数 |
推荐算法 |
---|---|---|
有限 |
否 |
TPE。如果搜索空间是低维且连续的,则使用 GP-EI。 |
是 |
TPE。如果搜索空间是低维且连续的,则使用 GP-EI。 |
|
充足 |
否 |
CMA-ES, 随机搜索 |
是 |
随机搜索或遗传算法 |
剪枝集成模块
为了以更简单的方式实现剪枝机制,Optuna 为以下库提供了集成模块。
有关 Optuna 集成模块的完整列表,请参阅 integration
。
例如,LightGBMPruningCallback 无需直接更改训练迭代逻辑即可引入剪枝。(另请参阅 示例 获取完整脚本。)
import optuna.integration
pruning_callback = optuna.integration.LightGBMPruningCallback(trial, 'validation-error')
gbm = lgb.train(param, dtrain, valid_sets=[dvalid], callbacks=[pruning_callback])
脚本总运行时间: (0 分钟 1.159 秒)