optuna.pruners.HyperbandPruner
- class optuna.pruners.HyperbandPruner(min_resource=1, max_resource='auto', reduction_factor=3, bootstrap_count=0)[source]
使用 Hyperband 的剪枝器。
由于 SuccessiveHalving (SHA) 需要配置数量 \(n\) 作为其超参数。对于给定的有限预算 \(B\),所有配置平均拥有 \(B \over n\) 的资源。正如您所见,这将在 \(B\) 和 \(B \over n\) 之间存在权衡。Hyperband 通过在固定预算下尝试不同的 \(n\) 值来解决此权衡问题。
注意
在 Hyperband 论文中,使用了与
RandomSampler
对应的采样器。Optuna 默认使用
TPESampler
。基准测试结果表明
optuna.pruners.HyperbandPruner
支持这两种采样器。
注意
如果您将
HyperbandPruner
与TPESampler
一起使用,建议考虑设置更大的n_trials
或timeout
,以便充分利用TPESampler
的特性,因为TPESampler
在启动时会使用一些(默认为 \(10\))Trial
。由于 Hyperband 运行多个
SuccessiveHalvingPruner
并根据当前Trial
的 bracket ID 收集 trials,因此每个 bracket 需要观察超过 \(10\) 个Trial
,以便TPESampler
调整其搜索空间。因此,例如,如果
HyperbandPruner
包含 \(4\) 个 pruners,则启动时至少会消耗 \(4 \times 10\) 个 trials。注意
Hyperband 包含多个
SuccessiveHalvingPruner
。在原始论文中,每个SuccessiveHalvingPruner
被称为“bracket”。bracket 的数量是控制 Hyperband 早期停止行为的重要因素,并通过min_resource
、max_resource
和reduction_factor
自动确定,计算公式为 \(\mathrm{Bracket\数量} = \mathrm{floor}(\log_{\texttt{reduction}\_\texttt{factor}} (\frac{\texttt{max}\_\texttt{resource}}{\texttt{min}\_\texttt{resource}})) + 1\)。请设置reduction_factor
,使 bracket 的数量不要过大(在大多数用例中约为 4-6)。详细信息请参阅原始论文的第 3.6 节。注意
HyperbandPruner
使用一个函数计算每个 trial 的 bracket ID,该函数接收Study
的study_name
和number
。请指定study_name
以使剪枝算法可复现。示例
我们使用 Hyperband 剪枝算法最小化目标函数。
import numpy as np from sklearn.datasets import load_iris from sklearn.linear_model import SGDClassifier from sklearn.model_selection import train_test_split import optuna X, y = load_iris(return_X_y=True) X_train, X_valid, y_train, y_valid = train_test_split(X, y) classes = np.unique(y) n_train_iter = 100 def objective(trial): alpha = trial.suggest_float("alpha", 0.0, 1.0) clf = SGDClassifier(alpha=alpha) for step in range(n_train_iter): clf.partial_fit(X_train, y_train, classes=classes) intermediate_value = clf.score(X_valid, y_valid) trial.report(intermediate_value, step) if trial.should_prune(): raise optuna.TrialPruned() return clf.score(X_valid, y_valid) study = optuna.create_study( direction="maximize", pruner=optuna.pruners.HyperbandPruner( min_resource=1, max_resource=n_train_iter, reduction_factor=3 ), ) study.optimize(objective, n_trials=20)
- 参数:
min_resource (int) – 用于指定分配给 trial 的最小资源的参数,在论文中记作 \(r\)。较小的 \(r\) 会更快地得出结果,但较大的 \(r\) 能更好地保证配置之间的成功判断。详细信息请参阅
SuccessiveHalvingPruner
。用于指定分配给 trial 的最大资源的参数。论文中的 \(R\) 对应于
max_resource / min_resource
。此值代表并应与最大迭代步数(例如,神经网络的 epoch 数)匹配。当此参数为“auto”时,最大资源将根据已完成的 trials 进行估计。此参数的默认值为“auto”。注意
使用“auto”时,最大资源将是第一个(如果在并行训练中则是其中一个)已完成 trial 中
report()
报告的最大步数。在确定最大资源之前,不会对任何 trial 进行剪枝。注意
如果最后一个中间值的步数可能随每个 trial 而改变,请手动将最大可能步数指定给
max_resource
。reduction_factor (int) – 用于指定可晋升 trials 的缩减因子的参数,在论文中记作 \(\eta\)。详细信息请参阅
SuccessiveHalvingPruner
。bootstrap_count (int) – 在任何 trial 可以晋升之前,一个 rung 中所需 trials 数量的参数。与
max_resource
为"auto"
不兼容。详细信息请参阅SuccessiveHalvingPruner
。
方法
prune
(study, trial)根据报告的值判断 trial 是否应该被剪枝。
- prune(study, trial)[source]
根据报告的值判断 trial 是否应该被剪枝。
请注意,此方法不应由库用户直接调用。相反,
optuna.trial.Trial.report()
和optuna.trial.Trial.should_prune()
提供了用户界面,用于在目标函数中实现剪枝机制。- 参数:
study (Study) – 目标 study 的 Study 对象。
trial (FrozenTrial) – 目标 trial 的 FrozenTrial 对象。在修改此对象之前请先复制一份。
- 返回:
一个布尔值,表示 trial 是否应该被剪枝。
- 返回类型: