注意
跳至末尾下载完整的示例代码。
Study.optimize 的回调函数
本教程展示了如何使用和实现用于 optimize()
的 Optuna Callback
。
Callback
在每次评估 objective
后被调用,它接收 Study
和 FrozenTrial
作为参数,并执行一些操作。
MLflowCallback 是一个很好的例子。
当连续出现多次剪枝试验时停止优化
本示例实现了一个有状态的回调函数,当连续一定数量的试验被剪枝时,它会停止优化。连续被剪枝的试验数量由 threshold
指定。
import optuna
class StopWhenTrialKeepBeingPrunedCallback:
def __init__(self, threshold: int):
self.threshold = threshold
self._consequtive_pruned_count = 0
def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
if trial.state == optuna.trial.TrialState.PRUNED:
self._consequtive_pruned_count += 1
else:
self._consequtive_pruned_count = 0
if self._consequtive_pruned_count >= self.threshold:
study.stop()
这个目标函数会剪枝除前 5 次试验之外的所有试验(trial.number
从 0 开始)。
def objective(trial):
if trial.number > 4:
raise optuna.TrialPruned
return trial.suggest_float("x", 0, 1)
在这里,我们将阈值设置为 2
:一旦连续两次试验被剪枝,优化就会结束。因此,我们预计此研究会在 7 次试验后停止。
import logging
import sys
# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study_stop_cb = StopWhenTrialKeepBeingPrunedCallback(2)
study = optuna.create_study()
study.optimize(objective, n_trials=10, callbacks=[study_stop_cb])
A new study created in memory with name: no-name-25aa1a6d-caf6-4b2f-8c42-69e9ecba73c0
Trial 0 finished with value: 0.8825610737221564 and parameters: {'x': 0.8825610737221564}. Best is trial 0 with value: 0.8825610737221564.
Trial 1 finished with value: 0.46034973343224217 and parameters: {'x': 0.46034973343224217}. Best is trial 1 with value: 0.46034973343224217.
Trial 2 finished with value: 0.2774235052222288 and parameters: {'x': 0.2774235052222288}. Best is trial 2 with value: 0.2774235052222288.
Trial 3 finished with value: 0.07815056456769365 and parameters: {'x': 0.07815056456769365}. Best is trial 3 with value: 0.07815056456769365.
Trial 4 finished with value: 0.9498664437998013 and parameters: {'x': 0.9498664437998013}. Best is trial 3 with value: 0.07815056456769365.
Trial 5 pruned.
Trial 6 pruned.
正如您在上面的日志中所见,该研究按预期在 7 次试验后停止了。
脚本总运行时间: (0 分钟 0.005 秒)