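Saving/Resuming Study with RDB Backend
An RDB backend enables persistent experiments (i.e., saving and resuming a study) as well as access to the history of studies. It also lets us run multi-node optimization tasks, which is described in Easy Parallelization.
In this section, let's try simple examples running on a local environment with an SQLite database.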
Note
You can also use other RDB backends, e.g., PostgreSQL or MySQL, by setting the storage argument to the database's URL. Please refer to the SQLAlchemy documentation for how to set up the URL.
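As a rough illustration of the URL format mentioned in the note, the sketch below assembles SQLAlchemy-style URLs of the form dialect+driver://user:password@host:port/database. The helper build_storage_url, the credentials, the host, and the database names are all hypothetical placeholders, and the mysql+pymysql driver is only one possible choice. Percent-encoding the user name and password matters when they contain characters such as @ or /.

```python
from urllib.parse import quote_plus


def build_storage_url(dialect, user, password, host, port, database):
    """Assemble an SQLAlchemy-style URL (hypothetical helper, not part of Optuna)."""
    # quote_plus() percent-encodes characters such as "@" and "/" that would
    # otherwise be confused with the URL's own separators.
    return "{}://{}:{}@{}:{}/{}".format(
        dialect, quote_plus(user), quote_plus(password), host, port, database
    )


# Placeholder credentials and hosts, for illustration only.
postgres_url = build_storage_url(
    "postgresql", "optuna_user", "p@ss/word", "localhost", 5432, "optuna_db"
)
mysql_url = build_storage_url(
    "mysql+pymysql", "optuna_user", "secret", "localhost", 3306, "optuna_db"
)

print(postgres_url)  # postgresql://optuna_user:p%40ss%2Fword@localhost:5432/optuna_db
print(mysql_url)     # mysql+pymysql://optuna_user:secret@localhost:3306/optuna_db
```

Either string could then be passed as the storage argument of create_study(), provided the database server is reachable and the corresponding driver package is installed.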
New Study
We can create a persistent study by calling the create_study() function as follows. An SQLite file example-study.db is automatically initialized with a new study record.
import logging
import sys
import optuna
# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study_name = "example-study" # Unique identifier of the study.
storage_name = "sqlite:///{}.db".format(study_name)
study = optuna.create_study(study_name=study_name, storage=storage_name)
A new study created in RDB with name: example-study
To run a study, call the optimize() method and pass it an objective function.
def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    return (x - 2) ** 2
study.optimize(objective, n_trials=3)
Trial 0 finished with value: 0.046443906828316436 and parameters: {'x': 1.784491515646561}. Best is trial 0 with value: 0.046443906828316436.
Trial 1 finished with value: 7.623390436986242 and parameters: {'x': 4.761048792938336}. Best is trial 0 with value: 0.046443906828316436.
Trial 2 finished with value: 4.241017991378444 and parameters: {'x': 4.059373203520538}. Best is trial 0 with value: 0.046443906828316436.
Resume Study
To resume a study, call create_study() with load_if_exists=True, passing the study name example-study and the DB URL sqlite:///example-study.db.
study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
study.optimize(objective, n_trials=3)
Using an existing study with name 'example-study' instead of creating a new one.
Trial 3 finished with value: 4.278975688797918 and parameters: {'x': -0.0685685119903372}. Best is trial 0 with value: 0.046443906828316436.
Trial 4 finished with value: 115.38194887296103 and parameters: {'x': -8.741598990511656}. Best is trial 0 with value: 0.046443906828316436.
Trial 5 finished with value: 10.881585583368315 and parameters: {'x': 5.298724842021279}. Best is trial 0 with value: 0.046443906828316436.
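Note that the storage does not save the state of sampler and pruner instances. To resume a study with a sampler whose seed argument is specified for reproducibility, you need to restore the sampler with pickle as follows: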
import pickle
# Save the sampler with pickle to be loaded later.
with open("sampler.pkl", "wb") as fout:
    pickle.dump(study.sampler, fout)
# Load the pickled sampler; the with-block closes the file handle.
with open("sampler.pkl", "rb") as fin:
    restored_sampler = pickle.load(fin)
study = optuna.create_study(
    study_name=study_name, storage=storage_name, load_if_exists=True, sampler=restored_sampler
)
study.optimize(objective, n_trials=3)
Experimental History
Note that this section requires the installation of Pandas:
$ pip install pandas
We can access the histories of studies and trials via the Study class. For example, we can get all trials of example-study as follows:
study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
df = study.trials_dataframe(attrs=("number", "value", "params", "state"))
Using an existing study with name 'example-study' instead of creating a new one.
The trials_dataframe() method returns a pandas DataFrame like this:
print(df)
   number       value  params_x     state
0       0    0.046444  1.784492  COMPLETE
1       1    7.623390  4.761049  COMPLETE
2       2    4.241018  4.059373  COMPLETE
3       3    4.278976 -0.068569  COMPLETE
4       4  115.381949 -8.741599  COMPLETE
5       5   10.881586  5.298725  COMPLETE
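Because the result is an ordinary DataFrame, the usual pandas operations apply. The sketch below rebuilds the table by hand from the rows printed above (a real script would use the df returned by trials_dataframe() instead) and picks the three completed trials with the smallest objective values.

```python
import pandas as pd

# Rows copied from the printed DataFrame above; in practice this would be
# the object returned by study.trials_dataframe(...).
df = pd.DataFrame(
    {
        "number": [0, 1, 2, 3, 4, 5],
        "value": [0.046444, 7.623390, 4.241018, 4.278976, 115.381949, 10.881586],
        "params_x": [1.784492, 4.761049, 4.059373, -0.068569, -8.741599, 5.298725],
        "state": ["COMPLETE"] * 6,
    }
)

# Keep only completed trials, then sort ascending because we are minimizing.
completed = df[df["state"] == "COMPLETE"]
top3 = completed.sort_values("value").head(3)

print(top3[["number", "value", "params_x"]])
```

With the data shown above, this selects trials 0, 2, and 3.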
A Study object also provides properties such as trials, best_value, and best_params (see also Lightweight, versatile, and platform agnostic architecture).
print("Best params: ", study.best_params)
print("Best value: ", study.best_value)
print("Best Trial: ", study.best_trial)
print("Trials: ", study.trials)
Best params: {'x': 1.784491515646561}
Best value: 0.046443906828316436
Best Trial: FrozenTrial(number=0, state=1, values=[0.046443906828316436], datetime_start=datetime.datetime(2025, 4, 14, 5, 11, 28, 980453), datetime_complete=datetime.datetime(2025, 4, 14, 5, 11, 29, 16320), params={'x': 1.784491515646561}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=1, value=None)
Trials: [FrozenTrial(number=0, state=1, values=[0.046443906828316436], datetime_start=datetime.datetime(2025, 4, 14, 5, 11, 28, 980453), datetime_complete=datetime.datetime(2025, 4, 14, 5, 11, 29, 16320), params={'x': 1.784491515646561}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=1, value=None), FrozenTrial(number=1, state=1, values=[7.623390436986242], datetime_start=datetime.datetime(2025, 4, 14, 5, 11, 29, 36411), datetime_complete=datetime.datetime(2025, 4, 14, 5, 11, 29, 57875), params={'x': 4.761048792938336}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=2, value=None), FrozenTrial(number=2, state=1, values=[4.241017991378444], datetime_start=datetime.datetime(2025, 4, 14, 5, 11, 29, 73596), datetime_complete=datetime.datetime(2025, 4, 14, 5, 11, 29, 93995), params={'x': 4.059373203520538}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=3, value=None), FrozenTrial(number=3, state=1, values=[4.278975688797918], datetime_start=datetime.datetime(2025, 4, 14, 5, 11, 29, 142694), datetime_complete=datetime.datetime(2025, 4, 14, 5, 11, 29, 169903), params={'x': -0.0685685119903372}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=4, value=None), FrozenTrial(number=4, state=1, values=[115.38194887296103], datetime_start=datetime.datetime(2025, 4, 14, 5, 11, 29, 187677), datetime_complete=datetime.datetime(2025, 4, 14, 5, 11, 29, 208178), params={'x': -8.741598990511656}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=5, value=None), FrozenTrial(number=5, 
state=1, values=[10.881585583368315], datetime_start=datetime.datetime(2025, 4, 14, 5, 11, 29, 223675), datetime_complete=datetime.datetime(2025, 4, 14, 5, 11, 29, 244282), params={'x': 5.298724842021279}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=6, value=None)]
Total running time of the script: (0 minutes 1.319 seconds)