Saving/Resuming Study with RDB Backend

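An RDB backend makes experiments persistent (that is, it lets you save and resume a study) and gives access to the history of studies. The same feature also allows multi-node optimization tasks, as described in Easy Parallelization.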

In this section, we run a simple example in a local environment using an SQLite database.

Note

You can also use other RDB backends, such as PostgreSQL or MySQL, by setting the storage argument to the database's URL. See SQLAlchemy's documentation for how to set up the URL.
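As a rough illustration of what such a URL looks like, the sketch below assembles strings in the general SQLAlchemy form dialect+driver://user:password@host:port/database. The helper build_storage_url, the credentials, the hosts, and the database names are all hypothetical placeholders, not part of Optuna or SQLAlchemy. The only real behavior relied on is that special characters in a password need to be percent-encoded.

```python
from urllib.parse import quote_plus


def build_storage_url(dialect, user, password, host, port, database, driver=None):
    """Hypothetical helper: assemble an SQLAlchemy-style database URL."""
    scheme = f"{dialect}+{driver}" if driver else dialect
    # Percent-encode the password so characters such as "@" or "/"
    # are not mistaken for URL delimiters.
    return f"{scheme}://{user}:{quote_plus(password)}@{host}:{port}/{database}"


# Placeholder credentials and hosts, for illustration only.
postgres_url = build_storage_url(
    "postgresql", "optuna", "p@ss/word", "localhost", 5432, "optuna_db"
)
mysql_url = build_storage_url(
    "mysql", "optuna", "secret", "localhost", 3306, "optuna_db", driver="pymysql"
)

print(postgres_url)
print(mysql_url)

# The resulting string would then be passed as the storage argument, e.g.
# optuna.create_study(study_name="example-study", storage=postgres_url)
```

Connecting to these backends also requires the matching database driver to be installed and a reachable server, which this sketch does not cover.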

New Study

We can create a persistent study by calling the create_study() function as follows. The SQLite file example-study.db is automatically initialized with a new study record.

import logging
import sys

import optuna

# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study_name = "example-study"  # Unique identifier of the study.
storage_name = "sqlite:///{}.db".format(study_name)
study = optuna.create_study(study_name=study_name, storage=storage_name)
A new study created in RDB with name: example-study

To run the study, call the optimize() method and pass it an objective function.

def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    return (x - 2) ** 2


study.optimize(objective, n_trials=3)
Trial 0 finished with value: 95.06957693414783 and parameters: {'x': -7.750362912945745}. Best is trial 0 with value: 95.06957693414783.
Trial 1 finished with value: 66.26096634444009 and parameters: {'x': -6.140083927358495}. Best is trial 1 with value: 66.26096634444009.
Trial 2 finished with value: 0.6325095023350713 and parameters: {'x': 2.7953046600737803}. Best is trial 2 with value: 0.6325095023350713.

Resume Study

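To resume a study, instantiate a Study object by passing the study name example-study and the DB URL sqlite:///example-study.db. Setting load_if_exists=True makes create_study() reuse the stored study instead of raising an error.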

study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
study.optimize(objective, n_trials=3)
Using an existing study with name 'example-study' instead of creating a new one.
Trial 3 finished with value: 12.36026881536561 and parameters: {'x': -1.5157173969711515}. Best is trial 2 with value: 0.6325095023350713.
Trial 4 finished with value: 43.48397938180832 and parameters: {'x': -4.59423834736115}. Best is trial 2 with value: 0.6325095023350713.
Trial 5 finished with value: 28.165038569716963 and parameters: {'x': -3.3070743889375587}. Best is trial 2 with value: 0.6325095023350713.

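Note that the storage does not store the state of sampler and pruner instances. When you resume a study that uses a sampler whose seed argument was set for reproducibility, you need to restore the sampler yourself, for example with pickle, as follows: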

import pickle

# Save the sampler with pickle to be loaded later.
with open("sampler.pkl", "wb") as fout:
    pickle.dump(study.sampler, fout)

# Load the pickled sampler and pass it to the resumed study.
with open("sampler.pkl", "rb") as fin:
    restored_sampler = pickle.load(fin)
study = optuna.create_study(
    study_name=study_name, storage=storage_name, load_if_exists=True, sampler=restored_sampler
)
study.optimize(objective, n_trials=3)

Experimental History

Note that this section requires Pandas to be installed:

$ pip install pandas

We can access the history of studies and trials via the Study class. For example, we can get all trials of example-study as follows:

study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
df = study.trials_dataframe(attrs=("number", "value", "params", "state"))
Using an existing study with name 'example-study' instead of creating a new one.

The trials_dataframe() method returns a pandas DataFrame like this:

print(df)
   number      value  params_x     state
0       0  95.069577 -7.750363  COMPLETE
1       1  66.260966 -6.140084  COMPLETE
2       2   0.632510  2.795305  COMPLETE
3       3  12.360269 -1.515717  COMPLETE
4       4  43.483979 -4.594238  COMPLETE
5       5  28.165039 -3.307074  COMPLETE

A Study object also provides properties such as trials, best_value, and best_params (see also Lightweight, versatile, and platform agnostic architecture).

print("Best params: ", study.best_params)
print("Best value: ", study.best_value)
print("Best Trial: ", study.best_trial)
print("Trials: ", study.trials)
Best params:  {'x': 2.7953046600737803}
Best value:  0.6325095023350713
Best Trial:  FrozenTrial(number=2, state=<TrialState.COMPLETE: 1>, values=[0.6325095023350713], datetime_start=datetime.datetime(2025, 11, 10, 5, 19, 12, 244320), datetime_complete=datetime.datetime(2025, 11, 10, 5, 19, 12, 262623), params={'x': 2.7953046600737803}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=3, value=None)
Trials:  [FrozenTrial(number=0, state=<TrialState.COMPLETE: 1>, values=[95.06957693414783], datetime_start=datetime.datetime(2025, 11, 10, 5, 19, 12, 169086), datetime_complete=datetime.datetime(2025, 11, 10, 5, 19, 12, 198753), params={'x': -7.750362912945745}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=1, value=None), FrozenTrial(number=1, state=<TrialState.COMPLETE: 1>, values=[66.26096634444009], datetime_start=datetime.datetime(2025, 11, 10, 5, 19, 12, 213304), datetime_complete=datetime.datetime(2025, 11, 10, 5, 19, 12, 232305), params={'x': -6.140083927358495}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=2, value=None), FrozenTrial(number=2, state=<TrialState.COMPLETE: 1>, values=[0.6325095023350713], datetime_start=datetime.datetime(2025, 11, 10, 5, 19, 12, 244320), datetime_complete=datetime.datetime(2025, 11, 10, 5, 19, 12, 262623), params={'x': 2.7953046600737803}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=3, value=None), FrozenTrial(number=3, state=<TrialState.COMPLETE: 1>, values=[12.36026881536561], datetime_start=datetime.datetime(2025, 11, 10, 5, 19, 12, 304658), datetime_complete=datetime.datetime(2025, 11, 10, 5, 19, 12, 328784), params={'x': -1.5157173969711515}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=4, value=None), FrozenTrial(number=4, state=<TrialState.COMPLETE: 1>, values=[43.48397938180832], datetime_start=datetime.datetime(2025, 11, 10, 5, 19, 12, 341031), datetime_complete=datetime.datetime(2025, 11, 10, 5, 19, 12, 360220), params={'x': -4.59423834736115}, user_attrs={}, system_attrs={}, intermediate_values={}, 
distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=5, value=None), FrozenTrial(number=5, state=<TrialState.COMPLETE: 1>, values=[28.165038569716963], datetime_start=datetime.datetime(2025, 11, 10, 5, 19, 12, 371146), datetime_complete=datetime.datetime(2025, 11, 10, 5, 19, 12, 389763), params={'x': -3.3070743889375587}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=6, value=None)]
