optuna.storages.RDBStorage

class optuna.storages.RDBStorage(url, engine_kwargs=None, skip_compatibility_check=False, *, heartbeat_interval=None, grace_period=None, failed_trial_callback=None, skip_table_creation=False)[source]

RDB 后端的存储类。

请注意,库用户可以实例化此类,但不应直接访问此类提供的属性。

示例

创建一个具有自定义 pool_sizetimeout 设置的 RDBStorage 实例。

import optuna


def objective(trial):
    x = trial.suggest_float("x", -100, 100)
    return x**2


storage = optuna.storages.RDBStorage(
    url="sqlite:///:memory:",
    engine_kwargs={"pool_size": 20, "connect_args": {"timeout": 10}},
)

study = optuna.create_study(storage=storage)
study.optimize(objective, n_trials=10)
参数:
  • url (str) – 存储的 URL。

  • engine_kwargs (dict[str, Any] | None) – 传递给 sqlalchemy.engine.create_engine 函数的关键字参数字典。

  • skip_compatibility_check (bool) – 如果设置为 True,则跳过模式兼容性检查的标志。

  • heartbeat_interval (int | None) –

    记录心跳的间隔。每隔 interval 秒记录一次。heartbeat_interval 必须是 None 或正整数。

    注意

    心跳机制应与 optimize() 结合使用。如果您改用 ask()tell(),将不起作用。

  • grace_period (int | None) – 自上次心跳以来,运行中的试验被标记为失败之前的宽限期。grace_period 必须是 None 或正整数。如果为 None,则宽限期为 2 * heartbeat_interval

  • failed_trial_callback (Callable[['optuna.study.Study', FrozenTrial], None] | None) –

    一个回调函数,在每个过期的试验失败后调用。函数必须按此顺序接受两个参数,类型如下:StudyFrozenTrial

    注意

    标记现有过期试验失败的过程在向 study 请求新试验之前调用。

  • skip_table_creation (bool) – 如果设置为 True,则跳过表创建的标志。

注意

如果您使用 MySQL,pool_pre_ping 默认设置为 True 以防止连接超时。您可以通过设置 engine_kwargs['pool_pre_ping']=False 来关闭它,但如果您的目标函数的执行时间长于 MySQL 配置中的 wait_timeout,建议保留此设置。

注意

我们绝不推荐将 SQLite3 用于并行优化。有关详细信息,请参阅常见问题 如何解决使用 SQLite3 进行并行优化时发生的错误?

注意

主要在集群环境中,运行中的试验经常意外终止。如果您想检测试验的失败,请使用心跳机制。根据您的使用情况,适当地设置 heartbeat_intervalgrace_periodfailed_trial_callback。有关更多详细信息,请参阅教程示例页面

另请参阅

您可以使用 RetryFailedTrialCallback 自动重试由心跳检测到的失败试验。

方法

check_trial_is_updatable(trial_id, trial_state)

检查试验状态是否可更新。

create_new_study(directions[, study_name])

根据名称创建一个新的 study。

create_new_trial(study_id[, template_trial])

创建新试验并将其添加到 study。

delete_study(study_id)

删除一个 study。

get_all_studies()

读取 FrozenStudy 对象列表。

get_all_trials(study_id[, deepcopy, states])

读取 study 中的所有试验。

get_all_versions()

返回模式版本列表。

get_best_trial(study_id)

返回 study 中具有最佳值的试验。

get_current_version()

返回此 storage 当前使用的模式版本。

get_failed_trial_callback()

获取失败试验回调函数。

get_head_version()

返回最新的模式版本。

get_heartbeat_interval()

获取设置的心跳间隔。

get_n_trials(study_id[, state])

计算 study 中的试验数量。

get_study_directions(study_id)

读取 study 是最大化还是最小化目标。

get_study_id_from_name(study_name)

读取 study 的 ID。

get_study_name_from_id(study_id)

读取 study 的名称。

get_study_system_attrs(study_id)

读取 study 的 optuna 内部属性。

get_study_user_attrs(study_id)

读取 study 的用户定义属性。

get_trial(trial_id)

读取一个试验。

get_trial_id_from_study_id_trial_number(...)

读取试验的试验 ID。

get_trial_number_from_id(trial_id)

读取试验的试验编号。

get_trial_param(trial_id, param_name)

读取试验的参数。

get_trial_params(trial_id)

读取试验的参数字典。

get_trial_system_attrs(trial_id)

读取试验的 optuna 内部属性。

get_trial_user_attrs(trial_id)

读取试验的用户定义属性。

record_heartbeat(trial_id)

记录试验的心跳。

remove_session()

移除当前会话。

set_study_system_attr(study_id, key, value)

将 optuna 内部属性注册到 study。

set_study_user_attr(study_id, key, value)

将用户定义属性注册到 study。

set_trial_intermediate_value(trial_id, step, ...)

报告目标函数的中间值。

set_trial_param(trial_id, param_name, ...)

为一个试验设置参数。

set_trial_state_values(trial_id, state[, values])

更新试验的状态和值。

set_trial_system_attr(trial_id, key, value)

为一个试验设置 optuna 内部属性。

set_trial_user_attr(trial_id, key, value)

为一个试验设置用户定义属性。

upgrade()

升级存储模式。

check_trial_is_updatable(trial_id, trial_state)

检查试验状态是否可更新。

参数:
  • trial_id (int) – 试验的 ID。仅用于错误消息。

  • trial_state (TrialState) – 要检查的试验状态。

引发:

UpdateFinishedTrialError – 如果试验已完成。

返回类型:

None

create_new_study(directions, study_name=None)[source]

根据名称创建一个新的 study。

如果未指定名称,则 storage 类会生成一个名称。返回的 study ID 在所有当前和已删除的 study 中是唯一的。

参数:
返回:

创建的 study 的 ID。

引发:

optuna.exceptions.DuplicatedStudyError – 如果具有相同 study_name 的 study 已存在。

返回类型:

int

create_new_trial(study_id, template_trial=None)[source]

创建新试验并将其添加到 study。

返回的试验 ID 在所有当前和已删除的试验中是唯一的。

参数:
  • study_id (int) – study 的 ID。

  • template_trial (FrozenTrial | None) – 包含默认用户属性、系统属性、中间值和状态的 FrozenTrial 模板。

返回:

创建的试验的 ID。

引发:

KeyError – 如果不存在具有匹配 study_id 的 study。

返回类型:

int

delete_study(study_id)[source]

删除一个 study。

参数:

study_id (int) – study 的 ID。

引发:

KeyError – 如果不存在具有匹配 study_id 的 study。

返回类型:

None

get_all_studies()[source]

读取 FrozenStudy 对象列表。

返回:

FrozenStudy 对象列表,按 study_id 排序。

返回类型:

list[FrozenStudy]

get_all_trials(study_id, deepcopy=True, states=None)[source]

读取 study 中的所有试验。

参数:
  • study_id (int) – study 的 ID。

  • deepcopy (bool) – 返回之前是否复制试验列表。如果您打算更新列表或列表中的元素,请设置为 True

  • states (Container[TrialState] | None) – 要过滤的试验状态。如果为 None,则包含所有状态。

返回:

study 中的试验列表,按 trial_id 排序。

引发:

KeyError – 如果不存在具有匹配 study_id 的 study。

返回类型:

list[FrozenTrial]

get_all_versions()[source]

返回模式版本列表。

返回类型:

list[str]

get_best_trial(study_id)[source]

返回 study 中具有最佳值的试验。

此方法仅在单目标优化期间有效。

参数:

study_id (int) – study 的 ID。

返回:

study 中所有已完成试验中具有最佳目标值的试验。

引发:
  • KeyError – 如果不存在具有匹配 study_id 的 study。

  • RuntimeError – 如果 study 有多个方向。

  • ValueError – 如果没有试验完成。

返回类型:

FrozenTrial

get_current_version()[source]

返回此 storage 当前使用的模式版本。

返回类型:

str

get_failed_trial_callback()[source]

获取失败试验回调函数。

返回:

如果设置了失败试验回调函数,则返回该函数,否则返回 None

返回类型:

Callable[[Study, FrozenTrial], None] | None

get_head_version()[source]

返回最新的模式版本。

返回类型:

str

get_heartbeat_interval()[source]

获取设置的心跳间隔。

返回:

如果设置了心跳间隔,则返回该值,否则返回 None

返回类型:

int | None

get_n_trials(study_id, state=None)

计算 study 中的试验数量。

参数:
  • study_id (int) – study 的 ID。

  • state (tuple[TrialState, ...] | TrialState | None) – 要过滤的试验状态。如果为 None,则包含所有状态。

返回:

study 中的试验数量。

引发:

KeyError – 如果不存在具有匹配 study_id 的 study。

返回类型:

int

get_study_directions(study_id)[source]

读取 study 是最大化还是最小化目标。

参数:

study_id (int) – study 的 ID。

返回:

study 的优化方向列表。

引发:

KeyError – 如果不存在具有匹配 study_id 的 study。

返回类型:

list[StudyDirection]

get_study_id_from_name(study_name)[source]

读取 study 的 ID。

参数:

study_name (str) – study 的名称。

返回:

study 的 ID。

引发:

KeyError – 如果不存在具有匹配 study_name 的 study。

返回类型:

int

get_study_name_from_id(study_id)[source]

读取 study 的名称。

参数:

study_id (int) – study 的 ID。

返回:

study 的名称。

引发:

KeyError – 如果不存在具有匹配 study_id 的 study。

返回类型:

str

get_study_system_attrs(study_id)[source]

读取 study 的 optuna 内部属性。

参数:

study_id (int) – study 的 ID。

返回:

包含 study 的 optuna 内部属性的字典。

引发:

KeyError – 如果不存在具有匹配 study_id 的 study。

返回类型:

dict[str, Any]

get_study_user_attrs(study_id)[source]

读取 study 的用户定义属性。

参数:

study_id (int) – study 的 ID。

返回:

包含 study 的用户属性的字典。

引发:

KeyError – 如果不存在具有匹配 study_id 的 study。

返回类型:

dict[str, Any]

get_trial(trial_id)[source]

读取一个试验。

参数:

trial_id (int) – 试验的 ID。

返回:

具有匹配试验 ID 的试验。

引发:

KeyError – 如果不存在具有匹配 trial_id 的试验。

返回类型:

FrozenTrial

get_trial_id_from_study_id_trial_number(study_id, trial_number)[source]

读取试验的试验 ID。

参数:
  • study_id (int) – study 的 ID。

  • trial_number (int) – 试验的编号。

返回:

试验的 ID。

引发:

KeyError – 如果不存在具有匹配 study_idtrial_number 的试验。

返回类型:

int

get_trial_number_from_id(trial_id)

读取试验的试验编号。

注意

试验编号仅在 study 中唯一,并且是顺序的。

参数:

trial_id (int) – 试验的 ID。

返回:

试验的编号。

引发:

KeyError – 如果不存在具有匹配 trial_id 的试验。

返回类型:

int

get_trial_param(trial_id, param_name)[source]

读取试验的参数。

参数:
  • trial_id (int) – 试验的 ID。

  • param_name (str) – 参数的名称。

返回:

参数的内部表示。

引发:

KeyError – 如果不存在具有匹配 trial_id 的试验。如果不存在此参数。

返回类型:

float

get_trial_params(trial_id)

读取试验的参数字典。

参数:

trial_id (int) – 试验的 ID。

返回:

参数字典。键是参数名称,值是参数值的外部表示。

引发:

KeyError – 如果不存在具有匹配 trial_id 的试验。

返回类型:

dict[str, Any]

get_trial_system_attrs(trial_id)[source]

读取试验的 optuna 内部属性。

参数:

trial_id (int) – 试验的 ID。

返回:

包含试验的 optuna 内部属性的字典。

引发:

KeyError – 如果不存在具有匹配 trial_id 的试验。

返回类型:

dict[str, Any]

get_trial_user_attrs(trial_id)[source]

读取试验的用户定义属性。

参数:

trial_id (int) – 试验的 ID。

返回:

包含试验的用户定义属性的字典。

引发:

KeyError – 如果不存在具有匹配 trial_id 的试验。

返回类型:

dict[str, Any]

record_heartbeat(trial_id)[source]

记录试验的心跳。

参数:

trial_id (int) – 试验的 ID。

返回类型:

None

remove_session()[source]

移除当前会话。

每个线程的会话都存储在 SQLAlchemy 的 ThreadLocalRegistry 中。此方法关闭并移除与当前线程关联的会话。特别是在多线程使用场景下,从每个线程调用此方法非常重要。否则,所有会话及其关联的数据库连接会被某个偶尔触发垃圾收集器的线程销毁。默认情况下,不允许从创建连接的线程以外的其他线程访问 SQLite 连接。因此,我们需要从每个线程显式地关闭连接。

返回类型:

None

set_study_system_attr(study_id, key, value)[source]

将 optuna 内部属性注册到 study。

此方法会覆盖任何现有属性。

参数:
  • study_id (int) – study 的 ID。

  • key (str) – 属性键。

  • value (Mapping[str, JSONSerializable] | Sequence[JSONSerializable] | str | int | float | bool | None) – 属性值。它应该是 JSON 可序列化的。

引发:

KeyError – 如果不存在具有匹配 study_id 的 study。

返回类型:

None

set_study_user_attr(study_id, key, value)[source]

将用户定义属性注册到 study。

此方法会覆盖任何现有属性。

参数:
  • study_id (int) – study 的 ID。

  • key (str) – 属性键。

  • value (Any) – 属性值。它应该是 JSON 可序列化的。

引发:

KeyError – 如果不存在具有匹配 study_id 的 study。

返回类型:

None

set_trial_intermediate_value(trial_id, step, intermediate_value)[source]

报告目标函数的中间值。

此方法会覆盖与给定 step 关联的任何现有中间值。

参数:
  • trial_id (int) – 试验的 ID。

  • step (int) – 试验的 step(例如,训练神经网络时的 epoch)。

  • intermediate_value (float) – 与 step 对应的中间值。

引发:
返回类型:

None

set_trial_param(trial_id, param_name, param_value_internal, distribution)[source]

为一个试验设置参数。

参数:
  • trial_id (int) – 试验的 ID。

  • param_name (str) – 参数的名称。

  • param_value_internal (float) – 参数值的内部表示。

  • distribution (BaseDistribution) – 参数的采样分布。

引发:
返回类型:

None

set_trial_state_values(trial_id, state, values=None)[source]

更新试验的状态和值。

将目标函数的返回值设置为 values 参数。如果 values 参数不为 None,则此方法会覆盖任何现有试验值。

参数:
  • trial_id (int) – 试验的 ID。

  • state (TrialState) – 试验的新状态。

  • values (Sequence[float] | None) – 目标函数的值。

返回:

如果状态成功更新,则为 True。如果状态保持不变,则为 False。后一种情况发生在当此方法尝试将 RUNNING 试验的状态更新为 RUNNING 时。

引发:
返回类型:

bool

set_trial_system_attr(trial_id, key, value)[source]

为一个试验设置 optuna 内部属性。

此方法会覆盖任何现有属性。

参数:
  • trial_id (int) – 试验的 ID。

  • key (str) – 属性键。

  • value (Mapping[str, JSONSerializable] | Sequence[JSONSerializable] | str | int | float | bool | None) – 属性值。它应该是 JSON 可序列化的。

引发:
返回类型:

None

set_trial_user_attr(trial_id, key, value)[source]

为一个试验设置用户定义属性。

此方法会覆盖任何现有属性。

参数:
  • trial_id (int) – 试验的 ID。

  • key (str) – 属性键。

  • value (Any) – 属性值。它应该是 JSON 可序列化的。

引发:
返回类型:

None

upgrade()[source]

升级存储模式。

返回类型:

None