Note
Go to the end to download the full example code.
Multi-objective Optimization with Optuna
This tutorial showcases Optuna's multi-objective optimization feature by optimizing both the validation accuracy on the Fashion MNIST dataset and the FLOPS (the number of floating-point operations) of a model implemented in PyTorch.
We use fvcore to measure FLOPS.
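As a quick standalone illustration of the fvcore API used below, here is a minimal sketch; the two-layer MLP and its sizes are arbitrary and only for demonstration. The full example then embeds this measurement in the Optuna objective.

import torch
import torch.nn as nn
from fvcore.nn import FlopCountAnalysis

# Count the floating-point operations of a tiny two-layer MLP
# for a single flattened 28*28 input vector.
mlp = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))
flops = FlopCountAnalysis(mlp, inputs=(torch.randn(1, 28 * 28),))
print(flops.total())        # total FLOPS of one forward pass
print(flops.by_operator())  # breakdown per operator type (e.g. addmm)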
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from fvcore.nn import FlopCountAnalysis
import optuna
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DIR = ".."
BATCHSIZE = 128
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10
def define_model(trial):
    n_layers = trial.suggest_int("n_layers", 1, 3)
    layers = []

    in_features = 28 * 28
    for i in range(n_layers):
        out_features = trial.suggest_int("n_units_l{}".format(i), 4, 128)
        layers.append(nn.Linear(in_features, out_features))
        layers.append(nn.ReLU())
        p = trial.suggest_float("dropout_{}".format(i), 0.2, 0.5)
        layers.append(nn.Dropout(p))

        in_features = out_features

    layers.append(nn.Linear(in_features, 10))
    layers.append(nn.LogSoftmax(dim=1))

    return nn.Sequential(*layers)
# Defines training and evaluation.
def train_model(model, optimizer, train_loader):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        F.nll_loss(model(data), target).backward()
        optimizer.step()


def eval_model(model, valid_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(valid_loader):
            data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
            pred = model(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = correct / N_VALID_EXAMPLES

    flops = FlopCountAnalysis(model, inputs=(torch.randn(1, 28 * 28).to(DEVICE),)).total()
    return flops, accuracy
Define the multi-objective objective function. The objectives are FLOPS and accuracy.
def objective(trial):
    train_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=True, download=True, transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )

    val_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=False, transform=torchvision.transforms.ToTensor()
    )
    val_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )
    model = define_model(trial).to(DEVICE)

    optimizer = torch.optim.Adam(
        model.parameters(), trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    )

    for epoch in range(10):
        train_model(model, optimizer, train_loader)
    flops, accuracy = eval_model(model, val_loader)
    return flops, accuracy
Run multi-objective optimization
If your optimization problem is multi-objective, Optuna assumes that you will specify the optimization direction for each objective. Specifically, in this example, we want to minimize the FLOPS (we want a faster model) and maximize the accuracy. So we set directions to ["minimize", "maximize"].
study = optuna.create_study(directions=["minimize", "maximize"])
study.optimize(objective, n_trials=30, timeout=300)
print("Number of finished trials: ", len(study.trials))
Unsupported operator aten::log_softmax encountered 1 time(s)
(the warning above is emitted once per trial, 30 times in total)
Number of finished trials: 30
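The aten::log_softmax warning comes from fvcore: it has no FLOP formula registered for that operator, so its cost is simply left out of the total. If the warning is too noisy, it can be silenced; a minimal sketch, assuming fvcore's inherited JitModelAnalysis warning toggle (unsupported_ops_warnings returns the analysis object, so the calls can be chained):

# Hedged sketch: disable fvcore's unsupported-operator warnings
# before computing the total inside eval_model.
flops = (
    FlopCountAnalysis(model, inputs=(torch.randn(1, 28 * 28).to(DEVICE),))
    .unsupported_ops_warnings(False)
    .total()
)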
Note that the following sections require Plotly to be installed for visualization and scikit-learn for computing hyperparameter importances.
$ pip install plotly
$ pip install scikit-learn
$ pip install nbformat # Required if you are running this tutorial in Jupyter Notebook.
Check trials on the Pareto front visually.
optuna.visualization.plot_pareto_front(study, target_names=["FLOPS", "accuracy"])
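In a Jupyter notebook the figure renders inline; in a plain script, the plotly Figure returned by plot_pareto_front can be shown or saved explicitly. A small usage sketch (the file name is arbitrary):

fig = optuna.visualization.plot_pareto_front(study, target_names=["FLOPS", "accuracy"])
fig.show()                            # open in a browser when run as a script
fig.write_html("pareto_front.html")   # or persist the interactive plot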
Fetch the list of trials on the Pareto front with best_trials.
For example, the following code shows the number of trials on the Pareto front and picks the trial with the highest accuracy.
print(f"Number of trials on the Pareto front: {len(study.best_trials)}")
trial_with_highest_accuracy = max(study.best_trials, key=lambda t: t.values[1])
print("Trial with highest accuracy: ")
print(f"\tnumber: {trial_with_highest_accuracy.number}")
print(f"\tparams: {trial_with_highest_accuracy.params}")
print(f"\tvalues: {trial_with_highest_accuracy.values}")
Number of trials on the Pareto front: 3
Trial with highest accuracy:
number: 13
params: {'n_layers': 1, 'n_units_l0': 61, 'dropout_0': 0.31682158090960255, 'lr': 0.006763684547607584}
values: [48434.0, 0.8375]
Learn which hyperparameters are affecting the FLOPS most with hyperparameter importance.
optuna.visualization.plot_param_importances(
    study, target=lambda t: t.values[0], target_name="flops"
)
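Because each trial stores one value per objective, the same plot can be produced for the second objective by selecting t.values[1]; the ordering of values matches the directions passed to create_study (here, FLOPS first, accuracy second).

# Importances with respect to the second objective (accuracy).
optuna.visualization.plot_param_importances(
    study, target=lambda t: t.values[1], target_name="accuracy"
)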
Total running time of the script: (1 minute 16.620 seconds)