optuna.storages.RDBStorage
- class optuna.storages.RDBStorage(url, engine_kwargs=None, skip_compatibility_check=False, *, heartbeat_interval=None, grace_period=None, failed_trial_callback=None, skip_table_creation=False)[source]
RDB 后端的存储类。
请注意,库用户可以实例化此类,但不应直接访问此类提供的属性。
示例
创建一个具有自定义
pool_size
和timeout
设置的RDBStorage
实例。import optuna def objective(trial): x = trial.suggest_float("x", -100, 100) return x**2 storage = optuna.storages.RDBStorage( url="sqlite:///:memory:", engine_kwargs={"pool_size": 20, "connect_args": {"timeout": 10}}, ) study = optuna.create_study(storage=storage) study.optimize(objective, n_trials=10)
- 参数:
url (str) – 存储的 URL。
engine_kwargs (dict[str, Any] | None) – 传递给 sqlalchemy.engine.create_engine 函数的关键字参数字典。
heartbeat_interval (int | None) –
记录心跳的间隔。每隔
interval
秒记录一次。heartbeat_interval
必须是None
或正整数。注意
心跳机制应与
optimize()
结合使用。如果您改用ask()
和tell()
,将不起作用。grace_period (int | None) – 自上次心跳以来,运行中的试验被标记为失败之前的宽限期。
grace_period
必须是None
或正整数。如果为None
,则宽限期为2 * heartbeat_interval
。failed_trial_callback (Callable[['optuna.study.Study', FrozenTrial], None] | None) –
一个回调函数,在每个过期的试验失败后调用。函数必须按此顺序接受两个参数,类型如下:
Study
和FrozenTrial
。注意
标记现有过期试验失败的过程在向 study 请求新试验之前调用。
注意
如果您使用 MySQL,pool_pre_ping 默认设置为
True
以防止连接超时。您可以通过设置engine_kwargs['pool_pre_ping']=False
来关闭它,但如果您的目标函数的执行时间长于 MySQL 配置中的 wait_timeout,建议保留此设置。注意
我们绝不推荐将 SQLite3 用于并行优化。有关详细信息,请参阅常见问题 如何解决使用 SQLite3 进行并行优化时发生的错误?。
注意
主要在集群环境中,运行中的试验经常意外终止。如果您想检测试验的失败,请使用心跳机制。根据您的使用情况,适当地设置
heartbeat_interval
、grace_period
和failed_trial_callback
。有关更多详细信息,请参阅教程和示例页面。另请参阅
您可以使用
RetryFailedTrialCallback
自动重试由心跳检测到的失败试验。方法
check_trial_is_updatable
(trial_id, trial_state)检查试验状态是否可更新。
create_new_study
(directions[, study_name])根据名称创建一个新的 study。
create_new_trial
(study_id[, template_trial])创建新试验并将其添加到 study。
delete_study
(study_id)删除一个 study。
读取
FrozenStudy
对象列表。get_all_trials
(study_id[, deepcopy, states])读取 study 中的所有试验。
返回模式版本列表。
get_best_trial
(study_id)返回 study 中具有最佳值的试验。
返回此 storage 当前使用的模式版本。
获取失败试验回调函数。
返回最新的模式版本。
获取设置的心跳间隔。
get_n_trials
(study_id[, state])计算 study 中的试验数量。
get_study_directions
(study_id)读取 study 是最大化还是最小化目标。
get_study_id_from_name
(study_name)读取 study 的 ID。
get_study_name_from_id
(study_id)读取 study 的名称。
get_study_system_attrs
(study_id)读取 study 的 optuna 内部属性。
get_study_user_attrs
(study_id)读取 study 的用户定义属性。
get_trial
(trial_id)读取一个试验。
读取试验的试验 ID。
get_trial_number_from_id
(trial_id)读取试验的试验编号。
get_trial_param
(trial_id, param_name)读取试验的参数。
get_trial_params
(trial_id)读取试验的参数字典。
get_trial_system_attrs
(trial_id)读取试验的 optuna 内部属性。
get_trial_user_attrs
(trial_id)读取试验的用户定义属性。
record_heartbeat
(trial_id)记录试验的心跳。
移除当前会话。
set_study_system_attr
(study_id, key, value)将 optuna 内部属性注册到 study。
set_study_user_attr
(study_id, key, value)将用户定义属性注册到 study。
set_trial_intermediate_value
(trial_id, step, ...)报告目标函数的中间值。
set_trial_param
(trial_id, param_name, ...)为一个试验设置参数。
set_trial_state_values
(trial_id, state[, values])更新试验的状态和值。
set_trial_system_attr
(trial_id, key, value)为一个试验设置 optuna 内部属性。
set_trial_user_attr
(trial_id, key, value)为一个试验设置用户定义属性。
upgrade
()升级存储模式。
- check_trial_is_updatable(trial_id, trial_state)
检查试验状态是否可更新。
- 参数:
trial_id (int) – 试验的 ID。仅用于错误消息。
trial_state (TrialState) – 要检查的试验状态。
- 引发:
UpdateFinishedTrialError – 如果试验已完成。
- 返回类型:
None
- create_new_study(directions, study_name=None)[source]
根据名称创建一个新的 study。
如果未指定名称,则 storage 类会生成一个名称。返回的 study ID 在所有当前和已删除的 study 中是唯一的。
- 参数:
directions (Sequence[StudyDirection]) – 方向序列,其元素为
MAXIMIZE
或MINIMIZE
。study_name (str | None) – 要创建的新 study 的名称。
- 返回:
创建的 study 的 ID。
- 引发:
optuna.exceptions.DuplicatedStudyError – 如果具有相同
study_name
的 study 已存在。- 返回类型:
- create_new_trial(study_id, template_trial=None)[source]
创建新试验并将其添加到 study。
返回的试验 ID 在所有当前和已删除的试验中是唯一的。
- 参数:
study_id (int) – study 的 ID。
template_trial (FrozenTrial | None) – 包含默认用户属性、系统属性、中间值和状态的
FrozenTrial
模板。
- 返回:
创建的试验的 ID。
- 引发:
KeyError – 如果不存在具有匹配
study_id
的 study。- 返回类型:
- get_all_studies()[source]
读取
FrozenStudy
对象列表。- 返回:
FrozenStudy
对象列表,按study_id
排序。- 返回类型:
list[FrozenStudy]
- get_all_trials(study_id, deepcopy=True, states=None)[source]
读取 study 中的所有试验。
- 参数:
- 返回:
study 中的试验列表,按
trial_id
排序。- 引发:
KeyError – 如果不存在具有匹配
study_id
的 study。- 返回类型:
- get_best_trial(study_id)[source]
返回 study 中具有最佳值的试验。
此方法仅在单目标优化期间有效。
- 参数:
study_id (int) – study 的 ID。
- 返回:
study 中所有已完成试验中具有最佳目标值的试验。
- 引发:
KeyError – 如果不存在具有匹配
study_id
的 study。RuntimeError – 如果 study 有多个方向。
ValueError – 如果没有试验完成。
- 返回类型:
- get_failed_trial_callback()[source]
获取失败试验回调函数。
- 返回:
如果设置了失败试验回调函数,则返回该函数,否则返回
None
。- 返回类型:
Callable[[Study, FrozenTrial], None] | None
- get_n_trials(study_id, state=None)
计算 study 中的试验数量。
- 参数:
study_id (int) – study 的 ID。
state (tuple[TrialState, ...] | TrialState | None) – 要过滤的试验状态。如果为
None
,则包含所有状态。
- 返回:
study 中的试验数量。
- 引发:
KeyError – 如果不存在具有匹配
study_id
的 study。- 返回类型:
- get_study_directions(study_id)[source]
读取 study 是最大化还是最小化目标。
- get_trial(trial_id)[source]
读取一个试验。
- get_trial_number_from_id(trial_id)
读取试验的试验编号。
注意
试验编号仅在 study 中唯一,并且是顺序的。
- get_trial_params(trial_id)
读取试验的参数字典。
- remove_session()[source]
移除当前会话。
每个线程的会话都存储在 SQLAlchemy 的 ThreadLocalRegistry 中。此方法关闭并移除与当前线程关联的会话。特别是在多线程使用场景下,从每个线程调用此方法非常重要。否则,所有会话及其关联的数据库连接会被某个偶尔触发垃圾收集器的线程销毁。默认情况下,不允许从创建连接的线程以外的其他线程访问 SQLite 连接。因此,我们需要从每个线程显式地关闭连接。
- 返回类型:
None
- set_trial_intermediate_value(trial_id, step, intermediate_value)[source]
报告目标函数的中间值。
此方法会覆盖与给定 step 关联的任何现有中间值。
- 参数:
- 引发:
KeyError – 如果不存在具有匹配
trial_id
的试验。UpdateFinishedTrialError – 如果试验已完成。
- 返回类型:
None
- set_trial_param(trial_id, param_name, param_value_internal, distribution)[source]
为一个试验设置参数。
- 参数:
- 引发:
KeyError – 如果不存在具有匹配
trial_id
的试验。UpdateFinishedTrialError – 如果试验已完成。
- 返回类型:
None
- set_trial_state_values(trial_id, state, values=None)[source]
更新试验的状态和值。
将目标函数的返回值设置为 values 参数。如果 values 参数不为
None
,则此方法会覆盖任何现有试验值。
- set_trial_user_attr(trial_id, key, value)[source]
为一个试验设置用户定义属性。
此方法会覆盖任何现有属性。
- 参数:
- 引发:
KeyError – 如果不存在具有匹配
trial_id
的试验。UpdateFinishedTrialError – 如果试验已完成。
- 返回类型:
None