注意
跳转至末尾下载完整的示例代码。
轻量、通用且平台无关的架构
Optuna 完全用 Python 编写,依赖很少。这意味着一旦您对 Optuna 感兴趣,就可以快速进入实际示例。
二次函数示例
通常,Optuna 用于优化超参数,但作为一个示例,让我们优化一个简单的二次函数:\((x - 2)^2\)。
首先,导入 optuna
。
import optuna
在 Optuna 中,约定俗成地将要优化的函数命名为 objective。
def objective(trial):
x = trial.suggest_float("x", -10, 10)
return (x - 2) ** 2
此函数返回 \((x - 2)^2\) 的值。我们的目标是找到使 objective
函数输出最小化的 x
值。这就是“优化”。在优化过程中,Optuna 会重复调用 objective 函数,并使用不同的 x
值进行评估。
一个 Trial
对象对应于 objective 函数的一次执行,并在每次调用该函数时在内部实例化。
suggest API(例如 suggest_float()
)在 objective 函数内部调用,以获取 trial 的参数。suggest_float()
在给定范围内均匀地选择参数。在我们的示例中,范围是 \(-10\) 到 \(10\)。
要开始优化,我们创建一个 study 对象,并将 objective 函数传递给 optimize()
方法,如下所示。
study = optuna.create_study()
study.optimize(objective, n_trials=100)
您可以按如下方式获取最佳参数。
best_params = study.best_params
found_x = best_params["x"]
print("Found x: {}, (x - 2)^2: {}".format(found_x, (found_x - 2) ** 2))
Found x: 1.9992078095026264, (x - 2)^2: 6.27565784128964e-07
我们可以看到 Optuna 找到的 x
值接近最优值 2
。
注意
在机器学习中用于搜索超参数时,objective 函数通常会返回模型的损失或准确率。
Study 对象
让我们澄清 Optuna 中的术语如下
Trial(试验):objective 函数的一次调用
Study(研究):一个优化会话,即一组 trial
Parameter(参数):一个待优化的变量,例如上面示例中的
x
在 Optuna 中,我们使用 study 对象来管理优化。create_study()
方法返回一个 study 对象。study 对象具有用于分析优化结果的有用属性。
获取参数名称和参数值的字典
study.best_params
{'x': 1.9992078095026264}
获取 objective 函数的最佳观测值
study.best_value
6.27565784128964e-07
获取最佳 trial
study.best_trial
FrozenTrial(number=84, state=1, values=[6.27565784128964e-07], datetime_start=datetime.datetime(2025, 4, 14, 5, 9, 50, 822066), datetime_complete=datetime.datetime(2025, 4, 14, 5, 9, 50, 826011), params={'x': 1.9992078095026264}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=84, value=None)
获取所有 trial
study.trials
for trial in study.trials[:2]: # Show first two trials
print(trial)
FrozenTrial(number=0, state=1, values=[26.06884441311108], datetime_start=datetime.datetime(2025, 4, 14, 5, 9, 50, 553162), datetime_complete=datetime.datetime(2025, 4, 14, 5, 9, 50, 553817), params={'x': -3.105765800848202}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=0, value=None)
FrozenTrial(number=1, state=1, values=[111.66884448550267], datetime_start=datetime.datetime(2025, 4, 14, 5, 9, 50, 554070), datetime_complete=datetime.datetime(2025, 4, 14, 5, 9, 50, 554305), params={'x': -8.567348034653854}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=1, value=None)
获取 trial 的数量
len(study.trials)
100
再次执行 optimize()
,我们可以继续优化。
study.optimize(objective, n_trials=100)
获取更新后的 trial 数量
len(study.trials)
200
由于 objective 函数非常简单,最后 100 个 trial 并未改进结果。但是,我们可以再次检查结果
best_params = study.best_params
found_x = best_params["x"]
print("Found x: {}, (x - 2)^2: {}".format(found_x, (found_x - 2) ** 2))
Found x: 2.0004170544142834, (x - 2)^2: 1.7393438447328081e-07
脚本总运行时间: (0 分钟 0.761 秒)