PyTorch中的优化器

1
import torch.optim as optim

PyTorch提供个多个优化器,其中常用的有SGD、ASGD、RMSprop、Adam等。

SGD

CLASS torch.optim.SGD(params, lr=, momentum=0, dampening=0, weight_decay=0, nesterov=False)

随机梯度下降

参数

params(iterable)

待优化参数的iterable或者是定义了参数组的dict

lr(float)

学习率

momentum(float)

动量因子(默认为0)

dampening(float)

动量的抑制因子(默认为0)

nesterov(bool)

使用Nesterov动量(默认为False)

ASGD

class torch.optim.ASGD(params, lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0)[source]

平均随机梯度下降

参数

params(iterable)

待优化参数的iterable或者是定义了参数组的dict

lr(float)

学习率,默认为(1e-2)

lambd(float)

衰减项,默认为(1e-4)

alpha(float)

eta更新指数,默认为0.75

t0(float)

指明在哪一次开始平均化,默认为(1e6)

weight_decay(float)

权重衰减(L2惩罚)默认为0

Adam

class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

来源于adaptive moments

参数

params (iterable)

待优化参数的iterable或者是定义了参数组的dict

lr (float, 可选)

学习率(默认:1e-3)

betas (Tuple[float, float], 可选)

用于计算梯度以及梯度平方的运行平均值的系数(默认:0.9,0.999)

eps (float, 可选)

为了增加数值计算的稳定性而加到分母里的项(默认:1e-8)

weight_decay (float, 可选)

权重衰减(L2惩罚)(默认: 0)