用户自定义剪枝器

optuna.pruners 中,我们描述了目标函数如何可选地包含剪枝功能调用,该功能允许 Optuna 在中间结果不理想时终止优化过程。在本文档中,我们将介绍如何实现您自己的剪枝器,即一种自定义的策略来确定何时停止一个 trial。

剪枝接口概述

create_study() 构造函数接受一个可选参数,即继承自 BasePruner 的剪枝器。剪枝器应实现抽象方法 prune(),该方法接受与关联的 StudyTrial 相关的参数,并返回一个布尔值:如果 trial 应被剪枝则返回 True,否则返回 False。使用 Study 和 Trial 对象,您可以通过 get_trials() 方法访问所有其他 trial,并通过 intermediate_values() (一个将整数 step 映射到浮点值的字典) 访问 trial 的已报告中间值。

您可以参考内置 Optuna 剪枝器的源代码作为构建您自己的剪枝器的模板。在本文档中,为了说明,我们将描述一个简单的(但激进的)剪枝器的构建和使用,该剪枝器会剪枝掉在同一步中与已完成的 trial 相比处于最后位置的 trial。我们将在“预热” 1 个训练步骤和 5 个已完成的 trial 后开始考虑剪枝。出于演示目的,当 prune 即将返回 True(表示剪枝)时,我们会从 pruneprint() 一个诊断消息。

注意

请参阅 BasePruner 的文档,或者例如 ThresholdPrunerPercentilePruner,以获取更健壮的剪枝器实现示例,包括错误检查和复杂的剪枝器内部逻辑。

示例:实现 LastPlacePruner

我们的目标是在 sklearn 的鸢尾花数据集上运行的随机梯度下降分类器(SGDClassifier)中优化 lossalpha 超参数。我们实现了一个剪枝器,该剪枝器如果在同一步中落后于已完成的 trial,则在该步终止一个 trial。我们将在“预热” 1 个训练步骤和 5 个已完成的 trial 后开始考虑剪枝。出于演示目的,当 prune 即将返回 True(表示剪枝)时,我们会从 pruneprint() 一个诊断消息。

需要注意的是,由于 SGDClassifier 的得分是在一个独立的验证集上评估的,因此由于过拟合,随着训练步骤的增加,得分会下降。这意味着即使一个 trial 在之前的训练集上具有有利(高)值,也可能被剪枝。剪枝后,Optuna 将使用最后报告的中间值作为 trial 的值。

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import SGDClassifier

import optuna
from optuna.pruners import BasePruner
from optuna.trial._state import TrialState


class LastPlacePruner(BasePruner):
    def __init__(self, warmup_steps, warmup_trials):
        self._warmup_steps = warmup_steps
        self._warmup_trials = warmup_trials

    def prune(self, study: "optuna.study.Study", trial: "optuna.trial.FrozenTrial") -> bool:
        # Get the latest score reported from this trial
        step = trial.last_step

        if step:  # trial.last_step == None when no scores have been reported yet
            this_score = trial.intermediate_values[step]

            # Get scores from other trials in the study reported at the same step
            completed_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
            other_scores = [
                t.intermediate_values[step]
                for t in completed_trials
                if step in t.intermediate_values
            ]
            other_scores = sorted(other_scores)

            # Prune if this trial at this step has a lower value than all completed trials
            # at the same step. Note that steps will begin numbering at 0 in the objective
            # function definition below.
            if step >= self._warmup_steps and len(other_scores) > self._warmup_trials:
                if this_score < other_scores[0]:
                    print(f"prune() True: Trial {trial.number}, Step {step}, Score {this_score}")
                    return True

        return False

最后,让我们通过简单的超参数优化来确认实现的正确性。

def objective(trial):
    iris = load_iris()
    classes = np.unique(iris.target)
    X_train, X_valid, y_train, y_valid = train_test_split(
        iris.data, iris.target, train_size=100, test_size=50, random_state=0
    )

    loss = trial.suggest_categorical("loss", ["hinge", "log_loss", "perceptron"])
    alpha = trial.suggest_float("alpha", 0.00001, 0.001, log=True)
    clf = SGDClassifier(loss=loss, alpha=alpha, random_state=0)
    score = 0

    for step in range(0, 5):
        clf.partial_fit(X_train, y_train, classes=classes)
        score = clf.score(X_valid, y_valid)

        trial.report(score, step)

        if trial.should_prune():
            raise optuna.TrialPruned()

    return score


pruner = LastPlacePruner(warmup_steps=1, warmup_trials=5)
study = optuna.create_study(direction="maximize", pruner=pruner)
study.optimize(objective, n_trials=50)
prune() True: Trial 9, Step 3, Score 0.48
prune() True: Trial 11, Step 1, Score 0.36
prune() True: Trial 16, Step 4, Score 0.5
prune() True: Trial 28, Step 1, Score 0.34
prune() True: Trial 41, Step 1, Score 0.38
prune() True: Trial 43, Step 1, Score 0.34
prune() True: Trial 44, Step 1, Score 0.38
prune() True: Trial 45, Step 2, Score 0.48
prune() True: Trial 49, Step 4, Score 0.62

脚本总运行时间: (0 分钟 0.467 秒)

由 Sphinx-Gallery 生成的画廊