optuna.pruners.WilcoxonPruner

class optuna.pruners.WilcoxonPruner(*, p_threshold=0.1, n_startup_steps=2)[源码]

基于 Wilcoxon 符号秩检验 的剪枝器。

该剪枝器对当前试验与当前最佳试验进行 Wilcoxon 符号秩检验,并在剪枝器确信(达到给定 p 值)当前试验比最佳试验差时停止。

此剪枝器对于优化一组问题实例上的某些(评估成本较高)性能得分的平均值/中位数非常有效。示例应用包括优化

  • 启发式方法(模拟退火、遗传算法、SAT 求解器等)在一组问题实例上的平均性能,

  • 机器学习模型的 k 折交叉验证得分,以及

  • 大型语言模型 (LLM) 在一组问题上的输出准确性。

可能存在“容易”或“困难”的实例(剪枝器处理不同试验之间实例的对应关系)。在每次试验中,建议打乱评估顺序,以避免优化过度拟合开头的实例。

使用此剪枝器时,必须为每个步骤(实例 ID)调用 Trial.report(value, step) 方法,并提供评估值。实例 ID 可能不是升序排列的。这与其他剪枝器不同,因为报告的值不必收敛到实际值。要在相同设置下使用诸如 SuccessiveHalvingPruner 的剪枝器,您必须提供例如评估值的历史平均值。

另请参阅

请参阅 report()

示例

import optuna
import numpy as np


# We minimize the mean evaluation loss over all the problem instances.
def evaluate(param, instance):
    # A toy loss function for demonstrative purpose.
    return (param - instance) ** 2


problem_instances = np.linspace(-1, 1, 100)


def objective(trial):
    # Sample a parameter.
    param = trial.suggest_float("param", -1, 1)

    # Evaluate performance of the parameter.
    results = []

    # For best results, shuffle the evaluation order in each trial.
    instance_ids = np.random.permutation(len(problem_instances))
    for instance_id in instance_ids:
        loss = evaluate(param, problem_instances[instance_id])
        results.append(loss)

        # Report loss together with the instance id.
        # CAVEAT: You need to pass the same id for the same instance,
        # otherwise WilcoxonPruner cannot correctly pair the losses across trials and
        # the pruning performance will degrade.
        trial.report(loss, instance_id)

        if trial.should_prune():
            # Return the current predicted value instead of raising `TrialPruned`.
            # This is a workaround to tell the Optuna about the evaluation
            # results in pruned trials. (See the note below.)
            return sum(results) / len(results)

    return sum(results) / len(results)


study = optuna.create_study(pruner=optuna.pruners.WilcoxonPruner(p_threshold=0.1))
study.optimize(objective, n_trials=100)

注意

此剪枝器无法处理 infinitynan 值。包含这些值的试验永远不会被剪枝。

注意

如果 should_prune() 返回 True,您可以返回最终值的估计值(例如,所有评估值的平均值),而不是 raise optuna.TrialPruned()。这是一个针对当前无法告知 Optuna 抛出 optuna.TrialPruned 的试验的预测目标值的问题的权宜之计。

参数:
  • p_threshold (float) –

    剪枝的 p 值阈值。此值应介于 0 和 1 之间。每当剪枝器确信(达到给定 p 值)当前试验比最佳试验差时,该试验就会被剪枝。此值越大,剪枝将越激进。默认为 0.1。

    注意

    此剪枝器在当前试验与当前最佳试验之间重复执行统计检验,并增加样本。这种序列检验的假阳性率与仅执行一次检验的假阳性率不同。要获得名义假阳性率,请指定 Pocock 校正后的 p 值。

  • n_startup_steps (int) – 在此步数之前不进行剪枝。只有在当前试验和最佳试验之间有 n_startup_steps 步可用于比较的观测值后,才开始剪枝。默认为 2。请注意,即使将 n_startup_steps 设置为 0 或 1,由于缺乏足够的比较数据,试验也不会在第一步和第二步被剪枝。

注意

在 v3.6.0 中作为实验性功能添加。接口在更新版本中可能会有改动,恕不另行通知。请参阅 https://github.com/optuna/optuna/releases/tag/v3.6.0

方法

prune(study, trial)

根据报告的值判断试验是否应被剪枝。

prune(study, trial)[源码]

根据报告的值判断试验是否应被剪枝。

请注意,库用户不应调用此方法。相反,optuna.trial.Trial.report()optuna.trial.Trial.should_prune() 提供了用户界面,以便在目标函数中实现剪枝机制。

参数:
  • study (Study) – 目标 Study 的 Study 对象。

  • trial (FrozenTrial) – 目标 Trial 的 FrozenTrial 对象。修改此对象前请先复制。

返回:

一个布尔值,表示试验是否应被剪枝。

返回类型:

bool