optuna.pruners.PatientPruner

class optuna.pruners.PatientPruner(wrapped_pruner, patience, min_delta=0.0)[source]

带有容错机制的剪枝器,用于包装另一个剪枝器。

示例

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

import optuna

X, y = load_iris(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y)
classes = np.unique(y)


def objective(trial):
    alpha = trial.suggest_float("alpha", 0.0, 1.0)
    clf = SGDClassifier(alpha=alpha)
    n_train_iter = 100

    for step in range(n_train_iter):
        clf.partial_fit(X_train, y_train, classes=classes)

        intermediate_value = clf.score(X_valid, y_valid)
        trial.report(intermediate_value, step)

        if trial.should_prune():
            raise optuna.TrialPruned()

    return clf.score(X_valid, y_valid)


study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.PatientPruner(optuna.pruners.MedianPruner(), patience=1),
)
study.optimize(objective, n_trials=20)
参数:
  • wrapped_pruner (BasePruner | None) – 被包装的剪枝器,当 PatientPruner 允许试验被剪枝时执行剪枝。如果为 None,则此剪枝器等同于仅依靠单个试验的中间值进行早停。

  • patience (int) – 剪枝将禁用,直到目标连续 patience 步没有改进。

  • min_delta (float) – 用于检查目标是否改进的容差值。此值应为非负数。

注意

在 v2.8.0 中作为实验性特性添加。其接口在后续版本中可能会在未提前通知的情况下发生变化。请参阅 https://github.com/optuna/optuna/releases/tag/v2.8.0

方法

prune(study, trial) - 根据报告的值判断试验是否应该被剪枝。

基于报告的值,判断是否应该修剪试验。

prune(study, trial)[source]

基于报告的值,判断是否应该修剪试验。

请注意,此方法不应由库用户直接调用。相反,optuna.trial.Trial.report()optuna.trial.Trial.should_prune() 提供了用户接口,用于在目标函数中实现剪枝机制。

参数:
  • study (Study) – 目标研究的 Study 对象。

  • trial (FrozenTrial) – 目标试验的 FrozenTrial 对象。在修改此对象之前请复制一份。

返回:

一个布尔值,表示试验是否应该被剪枝。

返回类型:

bool