optuna.pruners.HyperbandPruner

class optuna.pruners.HyperbandPruner(min_resource=1, max_resource='auto', reduction_factor=3, bootstrap_count=0)[源代码]

使用 Hyperband 的 Pruner。

由于 SuccessiveHalving (SHA) 需要配置数量 \(n\) 作为其超参数。对于给定的有限预算 \(B\),所有配置平均拥有 \(B \over n\) 的资源。正如你所见,这会在 \(B\)\(B \over n\) 之间进行权衡。Hyperband 通过为固定预算尝试不同的 \(n\) 值来解决这种权衡。

注意

注意

如果你将 HyperbandPrunerTPESampler 一起使用,建议考虑设置更大的 n_trialstimeout,以便充分利用 TPESampler 的特性,因为 TPESampler 使用了一些(默认情况下,\(10\) 个) Trial 进行启动。

由于 Hyperband 运行多个 SuccessiveHalvingPruner 并根据当前 Trial 的 bracket ID 收集 trial,因此每个 bracket 需要观察超过 \(10\)Trial 才能使 TPESampler 适应其搜索空间。

因此,例如,如果 HyperbandPruner 中有 \(4\) 个 pruner,则至少会消耗 \(4 \times 10\) 个 trial 用于启动。

注意

Hyperband 有几个 SuccessiveHalvingPruner。每个 SuccessiveHalvingPruner 在原始论文中被称为“bracket”。bracket 的数量是控制 Hyperband 提前停止行为的重要因素,它由 min_resourcemax_resourcereduction_factor 自动确定,公式为 \(\mathrm{The\ number\ of\ brackets} = \mathrm{floor}(\log_{\texttt{reduction}\_\texttt{factor}} (\frac{\texttt{max}\_\texttt{resource}}{\texttt{min}\_\texttt{resource}})) + 1\)。请设置 reduction_factor,使 bracket 的数量不是太大(在大多数用例中约为 4-6)。有关详细信息,请参阅 原始论文 的第 3.6 节。

注意

HyperbandPruner 使用一个接受 Studystudy_namenumber 的函数来计算每个 trial 的 bracket ID。请指定 study_name 以使剪枝算法可复现。

示例

我们使用 Hyperband 剪枝算法最小化一个目标函数。

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

import optuna

X, y = load_iris(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y)
classes = np.unique(y)
n_train_iter = 100


def objective(trial):
    alpha = trial.suggest_float("alpha", 0.0, 1.0)
    clf = SGDClassifier(alpha=alpha)

    for step in range(n_train_iter):
        clf.partial_fit(X_train, y_train, classes=classes)

        intermediate_value = clf.score(X_valid, y_valid)
        trial.report(intermediate_value, step)

        if trial.should_prune():
            raise optuna.TrialPruned()

    return clf.score(X_valid, y_valid)


study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.HyperbandPruner(
        min_resource=1, max_resource=n_train_iter, reduction_factor=3
    ),
)
study.optimize(objective, n_trials=20)
参数:
  • min_resource (int) – 用于指定分配给 trial 的最小资源的参数,在论文中记为 \(r\)。较小的 \(r\) 将更快地得到结果,但较大的 \(r\) 将更好地保证配置之间判断的成功。有关详细信息,请参阅 SuccessiveHalvingPruner

  • max_resource (str | int) –

    用于指定分配给 trial 的最大资源的参数。论文中的 \(R\) 对应于 max_resource / min_resource。此值表示并且应与最大迭代步数匹配(例如,神经网络的 epoch 数)。当此参数为“auto”时,将根据已完成的 trial 估算最大资源。此参数的默认值为“auto”。

    注意

    使用“auto”时,最大资源将是第一个、或并行训练的第一个 trial 中 report() 报告的最大步数。在确定最大资源之前,不会进行任何 trial 的剪枝。

    注意

    如果最后一个中间值的步数可能因 trial 而异,请手动将可能的最大步数指定给 max_resource

  • reduction_factor (int) – 用于指定可晋升 trial 的缩减因子的参数,在论文中记为 \(\eta\)。有关详细信息,请参阅 SuccessiveHalvingPruner

  • bootstrap_count (int) – 指定在任何 trial 可以晋升之前,每个 rung 所需的 trial 数量的参数。与 max_resource"auto" 时不兼容。有关详细信息,请参阅 SuccessiveHalvingPruner

方法

prune(study, trial)

根据报告的值判断试验是否应被剪枝。

prune(study, trial)[源代码]

根据报告的值判断试验是否应被剪枝。

请注意,此方法不应由库用户调用。相反,optuna.trial.Trial.report()optuna.trial.Trial.should_prune() 提供了在目标函数中实现剪枝机制的用户界面。

参数:
  • study (Study) – 目标 study 的 study 对象。

  • trial (FrozenTrial) – 目标 trial 的 FrozenTrial 对象。修改此对象之前请复制一份。

返回:

一个布尔值,表示试验是否应被剪枝。

返回类型:

bool