optuna.pruners.PatientPruner
- class optuna.pruners.PatientPruner(wrapped_pruner, patience, min_delta=0.0)[source]
包装另一个剪枝器并提供容忍度的剪枝器。
此剪枝器会监视试验中的中间值,并在中间值的改进在达到耐心期后小于阈值时剪枝该试验。
- 剪枝器按以下方式处理 NaN 值:
1. 如果在耐心期之前或期间的所有中间值均为 NaN,则不会剪枝该试验。 2. 在剪枝计算期间,会忽略 NaN 值。仅考虑有效的数值。
示例
import numpy as np from sklearn.datasets import load_iris from sklearn.linear_model import SGDClassifier from sklearn.model_selection import train_test_split import optuna X, y = load_iris(return_X_y=True) X_train, X_valid, y_train, y_valid = train_test_split(X, y) classes = np.unique(y) def objective(trial): alpha = trial.suggest_float("alpha", 0.0, 1.0) clf = SGDClassifier(alpha=alpha) n_train_iter = 100 for step in range(n_train_iter): clf.partial_fit(X_train, y_train, classes=classes) intermediate_value = clf.score(X_valid, y_valid) trial.report(intermediate_value, step) if trial.should_prune(): raise optuna.TrialPruned() return clf.score(X_valid, y_valid) study = optuna.create_study( direction="maximize", pruner=optuna.pruners.PatientPruner(optuna.pruners.MedianPruner(), patience=1), ) study.optimize(objective, n_trials=20)
- 参数:
wrapped_pruner (BasePruner | None) – 当
PatientPruner允许剪枝试验时,用于执行剪枝的包装剪枝器。如果为None,则此剪枝器等同于采用单个试验中的中间值进行提前停止。patience (int) – 在目标值连续
patience步未得到改进之前,剪枝将被禁用。min_delta (float) – 用于检查目标值是否得到改进的容忍度值。此值应为非负数。
注意
在 v2.8.0 中作为实验性功能添加。接口在后续版本中可能在没有事先通知的情况下发生更改。请参阅 https://github.com/optuna/optuna/releases/tag/v2.8.0。
方法
prune(study, trial)根据报告的值判断试验是否应被剪枝。
- prune(study, trial)[source]
根据报告的值判断试验是否应被剪枝。
请注意,此方法不应由库用户调用。相反,
optuna.trial.Trial.report()和optuna.trial.Trial.should_prune()提供了在目标函数中实现剪枝机制的用户界面。- 参数:
study (Study) – 目标 study 的 study 对象。
trial (FrozenTrial) – 目标 trial 的 FrozenTrial 对象。修改此对象之前请复制一份。
- 返回:
一个布尔值,表示试验是否应被剪枝。
- 返回类型: