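Quick Visualization for Hyperparameter Optimization Analysis
Optuna provides various visualization features in optuna.visualization for analyzing optimization results visually.
Note that this tutorial requires Plotly to be installed: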
$ pip install plotly
# Required if you are running this tutorial in Jupyter Notebook.
$ pip install nbformat
If you prefer to use Matplotlib instead of Plotly, run the following command:
$ pip install matplotlib
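This tutorial walks you through this module by visualizing the optimization results of PyTorch models trained on the FashionMNIST dataset.
For visualizing multi-objective optimization (i.e., using optuna.visualization.plot_pareto_front()), please refer to the Multi-objective Optimization with Optuna tutorial.
Note
With Optuna Dashboard, you can also view the optimization history, hyperparameter importances, hyperparameter relationships, and more as graphs and tables. Make your study persistent by using an RDB backend, then run the following commands to launch Optuna Dashboard.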
$ pip install optuna-dashboard
$ optuna-dashboard sqlite:///example-study.db
See the GitHub repository for more details.
[Figure: Optuna Dashboard screenshots, captioned "Manage Studies" and "Visualize with Interactive Plots"]
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import optuna
# You can use Matplotlib instead of Plotly for visualization by simply replacing `optuna.visualization` with
# `optuna.visualization.matplotlib` in the following examples.
from optuna.visualization import plot_contour
from optuna.visualization import plot_edf
from optuna.visualization import plot_intermediate_values
from optuna.visualization import plot_optimization_history
from optuna.visualization import plot_parallel_coordinate
from optuna.visualization import plot_param_importances
from optuna.visualization import plot_rank
from optuna.visualization import plot_slice
from optuna.visualization import plot_timeline
SEED = 13
torch.manual_seed(SEED)
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DIR = ".."
BATCHSIZE = 128
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10
def define_model(trial):
    # Optimize the number of layers and the number of hidden units in each layer.
    n_layers = trial.suggest_int("n_layers", 1, 2)
    layers = []
    in_features = 28 * 28
    for i in range(n_layers):
        out_features = trial.suggest_int("n_units_l{}".format(i), 64, 512)
        layers.append(nn.Linear(in_features, out_features))
        layers.append(nn.ReLU())
        in_features = out_features
    # Output layer: 10 FashionMNIST classes as log-probabilities (consumed by F.nll_loss).
    layers.append(nn.Linear(in_features, 10))
    layers.append(nn.LogSoftmax(dim=1))
    return nn.Sequential(*layers)
# Defines training and evaluation.
def train_model(model, optimizer, train_loader):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        F.nll_loss(model(data), target).backward()
        optimizer.step()
def eval_model(model, valid_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(valid_loader):
            data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
            pred = model(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = correct / N_VALID_EXAMPLES
    return accuracy
Define the objective function.
def objective(trial):
    train_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=True, download=True, transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )
    val_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=False, transform=torchvision.transforms.ToTensor()
    )
    val_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )
    model = define_model(trial).to(DEVICE)
    optimizer = torch.optim.Adam(
        model.parameters(), trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    )
    for epoch in range(10):
        train_model(model, optimizer, train_loader)
        val_accuracy = eval_model(model, val_loader)
        trial.report(val_accuracy, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
    return val_accuracy
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=SEED),
    pruner=optuna.pruners.MedianPruner(),
)
study.optimize(objective, n_trials=30, timeout=300)
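The study above uses MedianPruner. To build some intuition for what it does, the simplified sketch below checks a running trial's best value so far against the median of other finished trials' values at the same step, for a "maximize" study. It is not Optuna's implementation. The function name `should_prune_median` and the dict-based data layout are hypothetical, `n_startup_trials` only mirrors the idea of Optuna's parameter of the same name, and options such as n_warmup_steps and interval_steps are not modeled.

```python
from statistics import median


def should_prune_median(current_values, completed_trials, step, n_startup_trials=5):
    """Simplified median-pruning rule for a maximization study (illustrative only).

    current_values: dict mapping step -> value reported so far by the running trial.
    completed_trials: list of dicts mapping step -> value for finished trials.
    """
    # Too few finished trials to form a meaningful median: never prune.
    if len(completed_trials) < n_startup_trials:
        return False
    # Values that the finished trials reported at this same step.
    others = [t[step] for t in completed_trials if step in t]
    if not others:
        return False
    # Best value the running trial has reported up to and including this step.
    best_so_far = max(v for s, v in current_values.items() if s <= step)
    # Maximization: prune when even the best value so far is below the median.
    return best_so_far < median(others)


finished = [{0: 0.5, 1: 0.6}, {0: 0.7, 1: 0.8}, {0: 0.6, 1: 0.7}]
print(should_prune_median({0: 0.3, 1: 0.4}, finished, step=1, n_startup_trials=3))  # True
```

Comparing the running trial against the median, rather than against the best trial, stops only clearly below-average trials, so most of the compute goes to trials that are at least middling.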
Plot functions
Visualize the optimization history. See plot_optimization_history() for details.
plot_optimization_history(study)
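Alongside each trial's objective value, the optimization history plot shows the best value found so far as a line. The sketch below computes that running best for a list of objective values in trial order. The function name `best_so_far` and the input values are made up for illustration and are not taken from the study above.

```python
from itertools import accumulate


def best_so_far(values, maximize=True):
    """Running best objective value, in trial order (illustrative helper)."""
    pick = max if maximize else min
    # accumulate applies pick to the running result and each next value.
    return list(accumulate(values, pick))


print(best_so_far([0.62, 0.71, 0.68, 0.80, 0.75]))  # [0.62, 0.71, 0.71, 0.8, 0.8]
```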
Visualize the learning curves of the trials. See plot_intermediate_values() for details.
plot_intermediate_values(study)
Visualize high-dimensional parameter relationships. See plot_parallel_coordinate() for details.
plot_parallel_coordinate(study)
Select parameters to visualize.
plot_parallel_coordinate(study, params=["lr", "n_layers"])
Visualize hyperparameter relationships. See plot_contour() for details.
plot_contour(study)
Select parameters to visualize.
plot_contour(study, params=["lr", "n_layers"])
Visualize individual hyperparameters as slice plots. See plot_slice() for details.
plot_slice(study)
Select parameters to visualize.
plot_slice(study, params=["lr", "n_layers"])
Visualize parameter importances. See plot_param_importances() for details.
plot_param_importances(study)
Use hyperparameter importances to learn which hyperparameters affect trial duration.
optuna.visualization.plot_param_importances(
    study, target=lambda t: t.duration.total_seconds(), target_name="duration"
)
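The `target` argument takes a callable that maps each finished trial to a float, here the trial's wall-clock duration in seconds. The stand-in below reproduces that mapping without Optuna, under the assumption (matching Optuna's FrozenTrial.duration) that a duration is `datetime_complete - datetime_start` and is None until both timestamps are set. `FakeTrial` is a hypothetical class, not part of Optuna.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional


@dataclass
class FakeTrial:
    """Hypothetical stand-in for optuna.trial.FrozenTrial (only the fields used here)."""

    datetime_start: Optional[datetime]
    datetime_complete: Optional[datetime]

    @property
    def duration(self) -> Optional[timedelta]:
        # Undefined until the trial has both started and finished.
        if self.datetime_start is None or self.datetime_complete is None:
            return None
        return self.datetime_complete - self.datetime_start


# Same shape as the target passed to plot_param_importances above.
target = lambda t: t.duration.total_seconds()

start = datetime(2024, 1, 1, 12, 0, 0)
trial = FakeTrial(start, start + timedelta(seconds=4.5))
print(target(trial))  # 4.5
```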
Visualize the empirical distribution function. See plot_edf() for details.
plot_edf(study)
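plot_edf() plots the empirical distribution function (EDF) of the objective values: for each value y, the fraction of trials whose objective value is at most y. The minimal sketch below computes that function from a list of values. The helper name and the sample values are illustrative, and it does not model which trials Optuna includes in the plot.

```python
from bisect import bisect_right


def empirical_cdf(values):
    """Return F such that F(y) is the fraction of values that are <= y."""
    data = sorted(values)
    n = len(data)

    def F(y):
        # bisect_right gives the number of sorted values that are <= y.
        return bisect_right(data, y) / n

    return F


F = empirical_cdf([0.70, 0.80, 0.75, 0.85])
print(F(0.75))  # 0.5
```

When comparing studies, a curve that rises later (further to the right) means that more trials reached high objective values, which is the useful outcome for a "maximize" study such as this one.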
Visualize parameter relations with scatter plots colored by objective value. See plot_rank() for details.
plot_rank(study)
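As the name suggests, plot_rank() colors points by the rank of their objective value rather than by the raw value. The sketch below shows one plausible rank-to-[0, 1] mapping that a color scale could use. It is an assumption for illustration, Optuna's exact mapping may differ, and ties here simply receive distinct consecutive ranks.

```python
def normalized_ranks(values):
    """Map each value to its rank scaled into [0, 1] (lowest -> 0.0, highest -> 1.0)."""
    n = len(values)
    if n == 1:
        return [0.5]
    # Indices of values in ascending order of value.
    order = sorted(range(n), key=lambda i: values[i])
    ranks = [0.0] * n
    for rank, i in enumerate(order):
        ranks[i] = rank / (n - 1)
    return ranks


print(normalized_ranks([0.80, 0.70, 0.90]))  # [0.5, 0.0, 1.0]
```

Using ranks keeps a few extreme trials, such as a diverged learning rate with near-chance accuracy, from compressing all other points into a narrow band of the color scale.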
Visualize the optimization timeline of the executed trials. See plot_timeline() for details.
plot_timeline(study)
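plot_timeline() draws each trial as a horizontal bar from its start time to its completion time, colored by the trial's state. This helps spot pruned trials and idle gaps. The sketch below only computes bar positions relative to the first trial's start. The tuple layout and the function name `timeline_bars` are hypothetical, and trials that are still running (no end time) are not handled.

```python
from datetime import datetime, timedelta


def timeline_bars(trials):
    """trials: list of (number, start, end, state).

    Returns a list of (number, offset_seconds, length_seconds, state).
    """
    t0 = min(start for _, start, _, _ in trials)
    bars = []
    for number, start, end, state in trials:
        offset = (start - t0).total_seconds()  # where the bar begins
        length = (end - start).total_seconds()  # how long the trial ran
        bars.append((number, offset, length, state))
    return bars


t0 = datetime(2024, 1, 1, 12, 0, 0)
trials = [
    (0, t0, t0 + timedelta(seconds=3), "COMPLETE"),
    (1, t0 + timedelta(seconds=3), t0 + timedelta(seconds=4), "PRUNED"),
]
for bar in timeline_bars(trials):
    print(bar)
# (0, 0.0, 3.0, 'COMPLETE')
# (1, 3.0, 1.0, 'PRUNED')
```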
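Customize generated figures
In optuna.visualization and optuna.visualization.matplotlib, each function returns an editable figure object: a plotly.graph_objects.Figure or a matplotlib.axes.Axes, depending on the module. This lets you modify the generated figure to suit your needs through the visualization library's own API. The following example manually replaces the title and axis labels of the figure drawn by the Plotly-based plot_intermediate_values().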
fig = plot_intermediate_values(study)
fig.update_layout(
    title="Hyperparameter optimization for FashionMNIST classification",
    xaxis_title="Epoch",
    yaxis_title="Validation Accuracy",
)
Total running time of the script: (1 minute 19.413 seconds)