Study.optimize 的回调函数

本教程展示了如何使用和实现用于 optimize() 的 Optuna Callback

Callback 在每次评估 objective 后被调用,它接收 StudyFrozenTrial 作为参数,并执行一些操作。

MLflowCallback 是一个很好的例子。

当连续出现多次剪枝试验时停止优化

本示例实现了一个有状态的回调函数,当连续一定数量的试验被剪枝时,它会停止优化。连续被剪枝的试验数量由 threshold 指定。

import optuna


class StopWhenTrialKeepBeingPrunedCallback:
    def __init__(self, threshold: int):
        self.threshold = threshold
        self._consequtive_pruned_count = 0

    def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
        if trial.state == optuna.trial.TrialState.PRUNED:
            self._consequtive_pruned_count += 1
        else:
            self._consequtive_pruned_count = 0

        if self._consequtive_pruned_count >= self.threshold:
            study.stop()

这个目标函数会剪枝除前 5 次试验之外的所有试验(trial.number 从 0 开始)。

def objective(trial):
    if trial.number > 4:
        raise optuna.TrialPruned

    return trial.suggest_float("x", 0, 1)

在这里,我们将阈值设置为 2:一旦连续两次试验被剪枝,优化就会结束。因此,我们预计此研究会在 7 次试验后停止。

import logging
import sys

# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))

study_stop_cb = StopWhenTrialKeepBeingPrunedCallback(2)
study = optuna.create_study()
study.optimize(objective, n_trials=10, callbacks=[study_stop_cb])
A new study created in memory with name: no-name-25aa1a6d-caf6-4b2f-8c42-69e9ecba73c0
Trial 0 finished with value: 0.8825610737221564 and parameters: {'x': 0.8825610737221564}. Best is trial 0 with value: 0.8825610737221564.
Trial 1 finished with value: 0.46034973343224217 and parameters: {'x': 0.46034973343224217}. Best is trial 1 with value: 0.46034973343224217.
Trial 2 finished with value: 0.2774235052222288 and parameters: {'x': 0.2774235052222288}. Best is trial 2 with value: 0.2774235052222288.
Trial 3 finished with value: 0.07815056456769365 and parameters: {'x': 0.07815056456769365}. Best is trial 3 with value: 0.07815056456769365.
Trial 4 finished with value: 0.9498664437998013 and parameters: {'x': 0.9498664437998013}. Best is trial 3 with value: 0.07815056456769365.
Trial 5 pruned.
Trial 6 pruned.

正如您在上面的日志中所见,该研究按预期在 7 次试验后停止了。

脚本总运行时间: (0 分钟 0.005 秒)

图库由 Sphinx-Gallery 生成