Saving/Resuming Study with RDB Backend
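An RDB backend enables persistent experiments (i.e., saving and resuming a study) and gives access to the history of studies. We can also use this feature to run multi-node optimization, as described in Easy Parallelization.
In this section, we run simple examples in a local environment using an SQLite database.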
Note
You can also use other RDB backends, such as PostgreSQL or MySQL, by setting the storage argument to the database's URL. See the SQLAlchemy documentation for how to set up the URL.
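The following rough sketch illustrates the URL format by assembling SQLAlchemy-style URLs of the form dialect+driver://username:password@host:port/database. The helper make_storage_url and every user name, password, host, and database name in it are hypothetical placeholders. Before Optuna could connect, a real server and a matching driver (for example PyMySQL for mysql+pymysql) would have to exist. Special characters in the password are escaped with urllib.parse.quote_plus, which is the approach the SQLAlchemy documentation suggests.

```python
from urllib.parse import quote_plus


def make_storage_url(dialect, user, password, host, database, port=None, driver=None):
    """Hypothetical helper: build dialect[+driver]://user:password@host[:port]/database."""
    scheme = f"{dialect}+{driver}" if driver else dialect
    # Escape characters such as '@' or ':' so they do not break the URL structure.
    netloc = f"{quote_plus(user)}:{quote_plus(password)}@{host}"
    if port is not None:
        netloc += f":{port}"
    return f"{scheme}://{netloc}/{database}"


print(make_storage_url("postgresql", "optuna", "secret", "localhost", "optuna_db", port=5432))
# postgresql://optuna:secret@localhost:5432/optuna_db
print(make_storage_url("mysql", "optuna", "p@ss", "localhost", "optuna_db", driver="pymysql"))
# mysql+pymysql://optuna:p%40ss@localhost/optuna_db
```

Under those assumptions, the returned string could be passed as the storage argument of optuna.create_study() in place of an sqlite:/// URL.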
New Study
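We can create a persistent study by calling the create_study() function, as shown below. The SQLite file example-study.db is created automatically and initialized with a new study record.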
import logging
import sys
import optuna
# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study_name = "example-study" # Unique identifier of the study.
storage_name = "sqlite:///{}.db".format(study_name)
study = optuna.create_study(study_name=study_name, storage=storage_name)
A new study created in RDB with name: example-study
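The database file name comes straight from the storage URL. After sqlite:///, a relative path names a file relative to the current working directory, and a fourth slash makes the path absolute. The small helper below, sqlite_path_from_url, is hypothetical. It only strips the prefix to show which file a URL refers to and does not open the database.

```python
def sqlite_path_from_url(url):
    """Hypothetical helper: return the file path that a 'sqlite:///...' URL points to."""
    prefix = "sqlite:///"
    if not url.startswith(prefix):
        raise ValueError(f"not an SQLite URL: {url!r}")
    # 'sqlite:///name.db' -> 'name.db' (relative); 'sqlite:////abs/name.db' -> '/abs/name.db'
    return url[len(prefix):]


study_name = "example-study"
storage_name = "sqlite:///{}.db".format(study_name)
print(sqlite_path_from_url(storage_name))           # example-study.db
print(sqlite_path_from_url("sqlite:////tmp/x.db"))  # /tmp/x.db
```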
To run the study, call the optimize() method and pass it an objective function.
def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    return (x - 2) ** 2
study.optimize(objective, n_trials=3)
Trial 0 finished with value: 95.06957693414783 and parameters: {'x': -7.750362912945745}. Best is trial 0 with value: 95.06957693414783.
Trial 1 finished with value: 66.26096634444009 and parameters: {'x': -6.140083927358495}. Best is trial 1 with value: 66.26096634444009.
Trial 2 finished with value: 0.6325095023350713 and parameters: {'x': 2.7953046600737803}. Best is trial 2 with value: 0.6325095023350713.
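The following toy, stdlib-only sketch makes the loop behind optimize() concrete by imitating its shape with plain random search. It is not Optuna's implementation. Optuna's default sampler is TPESampler rather than uniform random sampling, and the names toy_objective and toy_optimize are hypothetical. Each iteration samples x from [-10, 10], evaluates the objective, and updates the best trial so far, which is the "Best is trial ..." part of the log.

```python
import random


def toy_objective(x):
    # Same function as above, but taking x directly instead of a trial object.
    return (x - 2) ** 2


def toy_optimize(n_trials, seed=0):
    """Hypothetical random-search loop that mimics the shape of study.optimize()."""
    rng = random.Random(seed)
    trials = []  # list of (number, x, value)
    best = None
    for number in range(n_trials):
        x = rng.uniform(-10, 10)  # plays the role of trial.suggest_float("x", -10, 10)
        value = toy_objective(x)
        trials.append((number, x, value))
        if best is None or value < best[2]:  # lower is better (minimization)
            best = (number, x, value)
        print(
            f"Trial {number} finished with value: {value} and parameters: {{'x': {x}}}. "
            f"Best is trial {best[0]} with value: {best[2]}."
        )
    return trials, best


trials, best = toy_optimize(n_trials=3)
```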
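Resume Study
To resume a study, call create_study() again with the same study name, example-study, and the same database URL, sqlite:///example-study.db. Also pass load_if_exists=True so that the existing study is loaded instead of a new one being created.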
study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
study.optimize(objective, n_trials=3)
Using an existing study with name 'example-study' instead of creating a new one.
Trial 3 finished with value: 12.36026881536561 and parameters: {'x': -1.5157173969711515}. Best is trial 2 with value: 0.6325095023350713.
Trial 4 finished with value: 43.48397938180832 and parameters: {'x': -4.59423834736115}. Best is trial 2 with value: 0.6325095023350713.
Trial 5 finished with value: 28.165038569716963 and parameters: {'x': -3.3070743889375587}. Best is trial 2 with value: 0.6325095023350713.
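Resuming works because every finished trial is already stored in the database. A new session can read those trials back and continue numbering from there, which is why the trials above are numbered 3 to 5. The toy sketch below imitates this with the standard sqlite3 module. Its single trials table is a hypothetical simplification, not Optuna's actual schema, and run_session is an illustrative name.

```python
import os
import random
import sqlite3
import tempfile


def toy_objective(x):
    return (x - 2) ** 2


def run_session(db_path, study_name, n_trials, seed):
    """Hypothetical toy 'study' whose trials live in SQLite, so a later session can resume."""
    conn = sqlite3.connect(db_path)
    try:
        conn.execute(
            "CREATE TABLE IF NOT EXISTS trials ("
            " study_name TEXT, number INTEGER, x REAL, value REAL)"
        )
        # Continue numbering from the trials already stored for this study.
        (start,) = conn.execute(
            "SELECT COUNT(*) FROM trials WHERE study_name = ?", (study_name,)
        ).fetchone()
        rng = random.Random(seed)
        for number in range(start, start + n_trials):
            x = rng.uniform(-10, 10)
            conn.execute(
                "INSERT INTO trials VALUES (?, ?, ?, ?)",
                (study_name, number, x, toy_objective(x)),
            )
        conn.commit()
        rows = conn.execute(
            "SELECT number FROM trials WHERE study_name = ? ORDER BY number", (study_name,)
        )
        return [row[0] for row in rows]
    finally:
        conn.close()


db_path = os.path.join(tempfile.mkdtemp(), "toy-study.db")
first = run_session(db_path, "example-study", n_trials=3, seed=1)   # [0, 1, 2]
second = run_session(db_path, "example-study", n_trials=3, seed=2)  # [0, 1, 2, 3, 4, 5]
print(first, second)
```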
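Note that the storage does not save the state of sampler and pruner instances. When you resume a study with a sampler whose seed argument was set for reproducibility, you need to restore the sampler with pickle, as follows: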
import pickle
# Save the sampler with pickle to be loaded later.
with open("sampler.pkl", "wb") as fout:
    pickle.dump(study.sampler, fout)
# Load the saved sampler and pass it in when resuming the study.
with open("sampler.pkl", "rb") as fin:
    restored_sampler = pickle.load(fin)
study = optuna.create_study(
    study_name=study_name, storage=storage_name, load_if_exists=True, sampler=restored_sampler
)
study.optimize(objective, n_trials=3)
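Any seeded random number generator shows why a seeded sampler needs to be pickled. In the sketch below, random.Random stands in for a sampler's internal random state; this is an analogy, not Optuna code. Restoring the pickled object continues the random stream from where it stopped. Creating a new generator with the same seed restarts the stream and replays the draws of the first session.

```python
import pickle
import random

# A seeded RNG stands in for a sampler created with a fixed seed (analogy only).
rng = random.Random(42)
first_run = [rng.random() for _ in range(3)]

# Save the state at the end of the first session, as the tutorial does with study.sampler.
blob = pickle.dumps(rng)
continued = [rng.random() for _ in range(3)]  # what an uninterrupted run would draw next

restored = pickle.loads(blob)
resumed = [restored.random() for _ in range(3)]

reseeded = random.Random(42)
restarted = [reseeded.random() for _ in range(3)]

print(resumed == continued)    # True: the pickled state picks up where it left off
print(restarted == first_run)  # True: re-seeding replays the first session's draws
```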
Experimental History
Note that this section requires Pandas to be installed:
$ pip install pandas
We can access the histories of studies and trials through the Study class. For example, we can get all trials of example-study as follows:
study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
df = study.trials_dataframe(attrs=("number", "value", "params", "state"))
Using an existing study with name 'example-study' instead of creating a new one.
The trials_dataframe() method returns a pandas DataFrame like the following:
print(df)
number value params_x state
0 0 95.069577 -7.750363 COMPLETE
1 1 66.260966 -6.140084 COMPLETE
2 2 0.632510 2.795305 COMPLETE
3 3 12.360269 -1.515717 COMPLETE
4 4 43.483979 -4.594238 COMPLETE
5 5 28.165039 -3.307074 COMPLETE
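Because trials_dataframe() returns an ordinary pandas DataFrame, the usual pandas operations apply. The sketch below rebuilds the six printed rows by hand, so the values carry the printout's rounding. It then filters by state, sorts by value, and picks the row with the smallest value. With a real study, the same calls would be applied to df directly.

```python
import pandas as pd

# The six rows printed above, rebuilt by hand (values rounded as in the printout).
df = pd.DataFrame(
    {
        "number": [0, 1, 2, 3, 4, 5],
        "value": [95.069577, 66.260966, 0.632510, 12.360269, 43.483979, 28.165039],
        "params_x": [-7.750363, -6.140084, 2.795305, -1.515717, -4.594238, -3.307074],
        "state": ["COMPLETE"] * 6,
    }
)

# Keep finished trials and sort by objective value (lower is better here).
completed = df[df["state"] == "COMPLETE"]
ranked = completed.sort_values("value")
print(ranked.head(3))

# Row with the smallest value, i.e. the best trial of a minimization study.
best_row = df.loc[df["value"].idxmin()]
print(int(best_row["number"]), best_row["params_x"])
```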
A Study object also provides properties such as trials, best_value, and best_params (see also Lightweight, versatile, and platform agnostic architecture).
print("Best params: ", study.best_params)
print("Best value: ", study.best_value)
print("Best Trial: ", study.best_trial)
print("Trials: ", study.trials)
Best params: {'x': 2.7953046600737803}
Best value: 0.6325095023350713
Best Trial: FrozenTrial(number=2, state=<TrialState.COMPLETE: 1>, values=[0.6325095023350713], datetime_start=datetime.datetime(2025, 11, 10, 5, 19, 12, 244320), datetime_complete=datetime.datetime(2025, 11, 10, 5, 19, 12, 262623), params={'x': 2.7953046600737803}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=3, value=None)
Trials: [FrozenTrial(number=0, state=<TrialState.COMPLETE: 1>, values=[95.06957693414783], datetime_start=datetime.datetime(2025, 11, 10, 5, 19, 12, 169086), datetime_complete=datetime.datetime(2025, 11, 10, 5, 19, 12, 198753), params={'x': -7.750362912945745}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=1, value=None), FrozenTrial(number=1, state=<TrialState.COMPLETE: 1>, values=[66.26096634444009], datetime_start=datetime.datetime(2025, 11, 10, 5, 19, 12, 213304), datetime_complete=datetime.datetime(2025, 11, 10, 5, 19, 12, 232305), params={'x': -6.140083927358495}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=2, value=None), FrozenTrial(number=2, state=<TrialState.COMPLETE: 1>, values=[0.6325095023350713], datetime_start=datetime.datetime(2025, 11, 10, 5, 19, 12, 244320), datetime_complete=datetime.datetime(2025, 11, 10, 5, 19, 12, 262623), params={'x': 2.7953046600737803}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=3, value=None), FrozenTrial(number=3, state=<TrialState.COMPLETE: 1>, values=[12.36026881536561], datetime_start=datetime.datetime(2025, 11, 10, 5, 19, 12, 304658), datetime_complete=datetime.datetime(2025, 11, 10, 5, 19, 12, 328784), params={'x': -1.5157173969711515}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=4, value=None), FrozenTrial(number=4, state=<TrialState.COMPLETE: 1>, values=[43.48397938180832], datetime_start=datetime.datetime(2025, 11, 10, 5, 19, 12, 341031), datetime_complete=datetime.datetime(2025, 11, 10, 5, 19, 12, 360220), params={'x': -4.59423834736115}, user_attrs={}, system_attrs={}, intermediate_values={}, 
distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=5, value=None), FrozenTrial(number=5, state=<TrialState.COMPLETE: 1>, values=[28.165038569716963], datetime_start=datetime.datetime(2025, 11, 10, 5, 19, 12, 371146), datetime_complete=datetime.datetime(2025, 11, 10, 5, 19, 12, 389763), params={'x': -3.3070743889375587}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=6, value=None)]
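The sketch below is a quick check of what these properties mean. It copies each trial's number, parameter, and value from the printed output and recomputes the best trial by hand. Since create_study() minimizes by default, the best trial is the one with the lowest value. Here toy_objective is just the objective above, rewritten to take x directly.

```python
# (number, x, value) for each trial, copied from the printed study.trials above.
trials = [
    (0, -7.750362912945745, 95.06957693414783),
    (1, -6.140083927358495, 66.26096634444009),
    (2, 2.7953046600737803, 0.6325095023350713),
    (3, -1.5157173969711515, 12.36026881536561),
    (4, -4.59423834736115, 43.48397938180832),
    (5, -3.3070743889375587, 28.165038569716963),
]


def toy_objective(x):
    return (x - 2) ** 2


# For a minimization study, the best trial is the one with the lowest value.
best_number, best_x, best_value = min(trials, key=lambda t: t[2])
print(best_number, {"x": best_x}, best_value)

# Re-evaluating the objective at each stored parameter reproduces the stored value.
print(all(abs(toy_objective(x) - value) < 1e-9 for _, x, value in trials))
```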