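Saving/Resuming Study with RDB Backend

An RDB backend enables persistent experiments (i.e., saving and resuming a study) as well as access to the history of studies. In addition, this feature lets us run multi-node optimization tasks, which is described in Easy Parallelization.

In this section, let's try some simple examples that run in a local environment with an SQLite database.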

Note

You can also use other RDB backends, e.g., PostgreSQL or MySQL, by setting the storage argument to the database's URL. Please refer to SQLAlchemy's documentation for how to set up the URL.
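
As a rough illustration of what such URLs look like, the sketch below assembles PostgreSQL and MySQL storage URLs with the standard library only. The helper build_storage_url, the credentials, the host, and the database name are all hypothetical; the mysql+pymysql form additionally assumes the PyMySQL driver is installed. Percent-encoding the credentials matters when a password contains characters such as "@" or "/".

```python
from urllib.parse import quote_plus


def build_storage_url(dialect, user, password, host, port, database):
    """Assemble an RDB URL of the form dialect://user:password@host:port/database."""
    # Percent-encode the credentials so that characters such as "@" or "/"
    # are not mistaken for URL delimiters.
    return f"{dialect}://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{database}"


# Hypothetical credentials and hosts, for illustration only.
postgres_url = build_storage_url(
    "postgresql", "optuna_user", "p@ss/word", "localhost", 5432, "optuna_db"
)
mysql_url = build_storage_url(
    "mysql+pymysql", "optuna_user", "p@ss/word", "localhost", 3306, "optuna_db"
)

print(postgres_url)
print(mysql_url)
```

Either string could then be passed as the storage argument in place of the SQLite URL used in the rest of this section.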

New Study

We can create a persistent study by calling the create_study() function as follows. The SQLite file example-study.db is automatically initialized with a new study record.

import logging
import sys

import optuna

# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study_name = "example-study"  # Unique identifier of the study.
storage_name = "sqlite:///{}.db".format(study_name)
study = optuna.create_study(study_name=study_name, storage=storage_name)
A new study created in RDB with name: example-study

To run a study, call the optimize() method and pass it an objective function.

def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    return (x - 2) ** 2


study.optimize(objective, n_trials=3)
Trial 0 finished with value: 0.046443906828316436 and parameters: {'x': 1.784491515646561}. Best is trial 0 with value: 0.046443906828316436.
Trial 1 finished with value: 7.623390436986242 and parameters: {'x': 4.761048792938336}. Best is trial 0 with value: 0.046443906828316436.
Trial 2 finished with value: 4.241017991378444 and parameters: {'x': 4.059373203520538}. Best is trial 0 with value: 0.046443906828316436.

Resume Study

To resume a study, instantiate a Study object, passing the study name example-study and the database URL sqlite:///example-study.db.

study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
study.optimize(objective, n_trials=3)
Using an existing study with name 'example-study' instead of creating a new one.
Trial 3 finished with value: 4.278975688797918 and parameters: {'x': -0.0685685119903372}. Best is trial 0 with value: 0.046443906828316436.
Trial 4 finished with value: 115.38194887296103 and parameters: {'x': -8.741598990511656}. Best is trial 0 with value: 0.046443906828316436.
Trial 5 finished with value: 10.881585583368315 and parameters: {'x': 5.298724842021279}. Best is trial 0 with value: 0.046443906828316436.

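Note that the storage does not save the state of sampler and pruner instances. To resume a study whose sampler was created with a seed argument for reproducibility, you need to restore the sampler with pickle as follows: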

import pickle

# Save the sampler with pickle to be loaded later.
with open("sampler.pkl", "wb") as fout:
    pickle.dump(study.sampler, fout)

# Load the pickled sampler, closing the file once it has been read.
with open("sampler.pkl", "rb") as fin:
    restored_sampler = pickle.load(fin)
study = optuna.create_study(
    study_name=study_name, storage=storage_name, load_if_exists=True, sampler=restored_sampler
)
study.optimize(objective, n_trials=3)
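
To see why pickling matters for a seeded sampler, here is a standard-library analogy rather than Optuna code: random.Random stands in for a sampler's internal random number generator. A pickled generator continues its sequence where it left off, whereas re-creating it from the same seed restarts the sequence from the beginning.

```python
import pickle
import random

# Stand-in for a seeded sampler: a generator that has already made three draws.
rng = random.Random(42)
for _ in range(3):
    rng.random()

# Save the internal state, as pickling the sampler does.
snapshot = pickle.dumps(rng)

# The draw an uninterrupted run would make next.
expected_next = rng.random()

# Restoring from the pickle continues the original sequence.
restored_next = pickle.loads(snapshot).random()

# Re-creating from the seed restarts the sequence instead.
reseeded_next = random.Random(42).random()

print("restored matches:", restored_next == expected_next)
print("reseeded matches:", reseeded_next == expected_next)
```

The same reasoning suggests that constructing a fresh sampler with the original seed on resume would repeat earlier suggestions rather than continue them.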

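Experimental History

Note that this section requires the installation of Pandas:

$ pip install pandas

We can access the histories of studies and trials via the Study class. For example, we can get all trials of example-study as follows: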

study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
df = study.trials_dataframe(attrs=("number", "value", "params", "state"))
Using an existing study with name 'example-study' instead of creating a new one.

The trials_dataframe() method returns a pandas dataframe like the following:

print(df)
   number       value  params_x     state
0       0    0.046444  1.784492  COMPLETE
1       1    7.623390  4.761049  COMPLETE
2       2    4.241018  4.059373  COMPLETE
3       3    4.278976 -0.068569  COMPLETE
4       4  115.381949 -8.741599  COMPLETE
5       5   10.881586  5.298725  COMPLETE
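
Because the result is an ordinary pandas dataframe, the usual filtering and sorting operations apply. The sketch below rebuilds the table printed above by hand (df_demo is a hypothetical stand-in, so the snippet runs without a database) and picks the three completed trials with the lowest objective value.

```python
import pandas as pd

# Rows copied from the printed dataframe above.
df_demo = pd.DataFrame(
    {
        "number": [0, 1, 2, 3, 4, 5],
        "value": [0.046444, 7.623390, 4.241018, 4.278976, 115.381949, 10.881586],
        "params_x": [1.784492, 4.761049, 4.059373, -0.068569, -8.741599, 5.298725],
        "state": ["COMPLETE"] * 6,
    }
)

# Keep only completed trials, then sort ascending because the study minimizes.
completed = df_demo[df_demo["state"] == "COMPLETE"]
top3 = completed.sort_values("value").head(3)

print(top3)
```

With a real study, the same two lines can be applied directly to the dataframe returned by study.trials_dataframe().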

A Study object also provides properties such as trials, best_value, and best_params (see also Lightweight, versatile, and platform agnostic architecture).

print("Best params: ", study.best_params)
print("Best value: ", study.best_value)
print("Best Trial: ", study.best_trial)
print("Trials: ", study.trials)
Best params:  {'x': 1.784491515646561}
Best value:  0.046443906828316436
Best Trial:  FrozenTrial(number=0, state=1, values=[0.046443906828316436], datetime_start=datetime.datetime(2025, 4, 14, 5, 11, 28, 980453), datetime_complete=datetime.datetime(2025, 4, 14, 5, 11, 29, 16320), params={'x': 1.784491515646561}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=1, value=None)
Trials:  [FrozenTrial(number=0, state=1, values=[0.046443906828316436], datetime_start=datetime.datetime(2025, 4, 14, 5, 11, 28, 980453), datetime_complete=datetime.datetime(2025, 4, 14, 5, 11, 29, 16320), params={'x': 1.784491515646561}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=1, value=None), FrozenTrial(number=1, state=1, values=[7.623390436986242], datetime_start=datetime.datetime(2025, 4, 14, 5, 11, 29, 36411), datetime_complete=datetime.datetime(2025, 4, 14, 5, 11, 29, 57875), params={'x': 4.761048792938336}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=2, value=None), FrozenTrial(number=2, state=1, values=[4.241017991378444], datetime_start=datetime.datetime(2025, 4, 14, 5, 11, 29, 73596), datetime_complete=datetime.datetime(2025, 4, 14, 5, 11, 29, 93995), params={'x': 4.059373203520538}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=3, value=None), FrozenTrial(number=3, state=1, values=[4.278975688797918], datetime_start=datetime.datetime(2025, 4, 14, 5, 11, 29, 142694), datetime_complete=datetime.datetime(2025, 4, 14, 5, 11, 29, 169903), params={'x': -0.0685685119903372}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=4, value=None), FrozenTrial(number=4, state=1, values=[115.38194887296103], datetime_start=datetime.datetime(2025, 4, 14, 5, 11, 29, 187677), datetime_complete=datetime.datetime(2025, 4, 14, 5, 11, 29, 208178), params={'x': -8.741598990511656}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=5, value=None), FrozenTrial(number=5, 
state=1, values=[10.881585583368315], datetime_start=datetime.datetime(2025, 4, 14, 5, 11, 29, 223675), datetime_complete=datetime.datetime(2025, 4, 14, 5, 11, 29, 244282), params={'x': 5.298724842021279}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=6, value=None)]
