用户自定义剪枝器

optuna.pruners 中,我们描述了目标函数如何可选地包含对剪枝功能的调用,该功能允许 Optuna 在中间结果看起来没有希望时终止优化试验。在本文档中,我们将描述如何实现您自己的剪枝器,即用于确定何时停止试验的自定义策略。

剪枝接口概述

create_study() 构造函数可选地接受一个继承自 BasePruner 的剪枝器作为参数。剪枝器应该实现抽象方法 prune(),该方法接受关联的 StudyTrial 作为参数,并返回一个布尔值:如果试验应该被剪枝则为 True,否则为 False。使用 Study 和 Trial 对象,您可以通过 get_trials() 方法访问所有其他试验,并从一个试验中,通过 intermediate_values() (一个将整数 step 映射到浮点值的字典)访问其报告的中间值。

您可以参考 Optuna 内置剪枝器的源代码作为构建您自己的剪枝器的模板。在本文档中,为了说明,我们将描述一个简单(但激进)的剪枝器的构建和使用,该剪枝器会剪枝在同一阶段与已完成试验相比处于最后位置的试验。

注意

请参考 BasePruner 的文档,或者例如 ThresholdPrunerPercentilePruner,以获取更稳健的剪枝器实现示例,包括错误检查和复杂的剪枝器内部逻辑。

示例:实现 LastPlacePruner

我们的目标是优化在 sklearn iris 数据集上运行的随机梯度下降分类器 (SGDClassifier) 的 lossalpha 超参数。我们实现了一个剪枝器,如果试验在某个阶段与同一阶段已完成的试验相比处于最后位置,它将终止该试验。我们在“热身”1 个训练步骤和 5 个已完成试验后开始考虑剪枝。为了演示目的,当 prune 即将返回 True (表示正在进行剪枝)时,我们 print() 一个诊断消息。

重要的是要注意,SGDClassifier 的评分(在保留集上评估)在足够多的训练步骤后会因过拟合而下降。这意味着即使试验在之前的训练集上获得了有利(高)的值,它也可能被剪枝。剪枝后,Optuna 将把最后报告的中间值作为该试验的值。

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import SGDClassifier

import optuna
from optuna.pruners import BasePruner
from optuna.trial._state import TrialState


class LastPlacePruner(BasePruner):
    def __init__(self, warmup_steps, warmup_trials):
        self._warmup_steps = warmup_steps
        self._warmup_trials = warmup_trials

    def prune(self, study: "optuna.study.Study", trial: "optuna.trial.FrozenTrial") -> bool:
        # Get the latest score reported from this trial
        step = trial.last_step

        if step:  # trial.last_step == None when no scores have been reported yet
            this_score = trial.intermediate_values[step]

            # Get scores from other trials in the study reported at the same step
            completed_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
            other_scores = [
                t.intermediate_values[step]
                for t in completed_trials
                if step in t.intermediate_values
            ]
            other_scores = sorted(other_scores)

            # Prune if this trial at this step has a lower value than all completed trials
            # at the same step. Note that steps will begin numbering at 0 in the objective
            # function definition below.
            if step >= self._warmup_steps and len(other_scores) > self._warmup_trials:
                if this_score < other_scores[0]:
                    print(f"prune() True: Trial {trial.number}, Step {step}, Score {this_score}")
                    return True

        return False

最后,让我们通过简单的超参数优化来确认实现是正确的。

def objective(trial):
    iris = load_iris()
    classes = np.unique(iris.target)
    X_train, X_valid, y_train, y_valid = train_test_split(
        iris.data, iris.target, train_size=100, test_size=50, random_state=0
    )

    loss = trial.suggest_categorical("loss", ["hinge", "log_loss", "perceptron"])
    alpha = trial.suggest_float("alpha", 0.00001, 0.001, log=True)
    clf = SGDClassifier(loss=loss, alpha=alpha, random_state=0)
    score = 0

    for step in range(0, 5):
        clf.partial_fit(X_train, y_train, classes=classes)
        score = clf.score(X_valid, y_valid)

        trial.report(score, step)

        if trial.should_prune():
            raise optuna.TrialPruned()

    return score


pruner = LastPlacePruner(warmup_steps=1, warmup_trials=5)
study = optuna.create_study(direction="maximize", pruner=pruner)
study.optimize(objective, n_trials=50)
prune() True: Trial 6, Step 3, Score 0.66
prune() True: Trial 10, Step 1, Score 0.7
prune() True: Trial 11, Step 1, Score 0.62
prune() True: Trial 12, Step 1, Score 0.64
prune() True: Trial 15, Step 3, Score 0.68
prune() True: Trial 16, Step 1, Score 0.62
prune() True: Trial 17, Step 1, Score 0.66
prune() True: Trial 18, Step 1, Score 0.7
prune() True: Trial 20, Step 1, Score 0.62
prune() True: Trial 24, Step 1, Score 0.68
prune() True: Trial 25, Step 3, Score 0.68
prune() True: Trial 26, Step 1, Score 0.7
prune() True: Trial 36, Step 3, Score 0.58
prune() True: Trial 38, Step 3, Score 0.66
prune() True: Trial 39, Step 3, Score 0.66

脚本总运行时间: (0 分钟 0.670 秒)

图库由 Sphinx-Gallery 生成