轻量、通用且平台无关的架构

Optuna 完全用 Python 编写,依赖很少。这意味着一旦您对 Optuna 感兴趣,就可以快速进入实际示例。

二次函数示例

通常,Optuna 用于优化超参数,但作为一个示例,让我们优化一个简单的二次函数:\((x - 2)^2\)

首先,导入 optuna

import optuna

在 Optuna 中,约定俗成地将要优化的函数命名为 objective

def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    return (x - 2) ** 2

此函数返回 \((x - 2)^2\) 的值。我们的目标是找到使 objective 函数输出最小化的 x 值。这就是“优化”。在优化过程中,Optuna 会重复调用 objective 函数,并使用不同的 x 值进行评估。

一个 Trial 对象对应于 objective 函数的一次执行,并在每次调用该函数时在内部实例化。

suggest API(例如 suggest_float())在 objective 函数内部调用,以获取 trial 的参数。suggest_float() 在给定范围内均匀地选择参数。在我们的示例中,范围是 \(-10\)\(10\)

要开始优化,我们创建一个 study 对象,并将 objective 函数传递给 optimize() 方法,如下所示。

study = optuna.create_study()
study.optimize(objective, n_trials=100)

您可以按如下方式获取最佳参数。

best_params = study.best_params
found_x = best_params["x"]
print("Found x: {}, (x - 2)^2: {}".format(found_x, (found_x - 2) ** 2))
Found x: 1.9992078095026264, (x - 2)^2: 6.27565784128964e-07

我们可以看到 Optuna 找到的 x 值接近最优值 2

注意

在机器学习中用于搜索超参数时,objective 函数通常会返回模型的损失或准确率。

Study 对象

让我们澄清 Optuna 中的术语如下

  • Trial(试验):objective 函数的一次调用

  • Study(研究):一个优化会话,即一组 trial

  • Parameter(参数):一个待优化的变量,例如上面示例中的 x

在 Optuna 中,我们使用 study 对象来管理优化。create_study() 方法返回一个 study 对象。study 对象具有用于分析优化结果的有用属性。

获取参数名称和参数值的字典

{'x': 1.9992078095026264}

获取 objective 函数的最佳观测值

study.best_value
6.27565784128964e-07

获取最佳 trial

study.best_trial
FrozenTrial(number=84, state=1, values=[6.27565784128964e-07], datetime_start=datetime.datetime(2025, 4, 14, 5, 9, 50, 822066), datetime_complete=datetime.datetime(2025, 4, 14, 5, 9, 50, 826011), params={'x': 1.9992078095026264}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=84, value=None)

获取所有 trial

study.trials
for trial in study.trials[:2]:  # Show first two trials
    print(trial)
FrozenTrial(number=0, state=1, values=[26.06884441311108], datetime_start=datetime.datetime(2025, 4, 14, 5, 9, 50, 553162), datetime_complete=datetime.datetime(2025, 4, 14, 5, 9, 50, 553817), params={'x': -3.105765800848202}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=0, value=None)
FrozenTrial(number=1, state=1, values=[111.66884448550267], datetime_start=datetime.datetime(2025, 4, 14, 5, 9, 50, 554070), datetime_complete=datetime.datetime(2025, 4, 14, 5, 9, 50, 554305), params={'x': -8.567348034653854}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=1, value=None)

获取 trial 的数量

len(study.trials)
100

再次执行 optimize(),我们可以继续优化。

study.optimize(objective, n_trials=100)

获取更新后的 trial 数量

len(study.trials)
200

由于 objective 函数非常简单,最后 100 个 trial 并未改进结果。但是,我们可以再次检查结果

best_params = study.best_params
found_x = best_params["x"]
print("Found x: {}, (x - 2)^2: {}".format(found_x, (found_x - 2) ** 2))
Found x: 2.0004170544142834, (x - 2)^2: 1.7393438447328081e-07

脚本总运行时间: (0 分钟 0.761 秒)

由 Sphinx-Gallery 生成的图库