大家好,欢迎来到IT知识分享网。
1.简介
在之前的文章里,我们介绍了集成一阶动量和二阶动量的优化器Adam。AdamW其实是在Adam的基础上加入了weight decay正则化,但是我们上一篇文章里也看到了Adam的代码中已经有正则化,那么两者有什么区别呢?
2.AdamW
其实AdamW和Adam唯一的区别,就是weight decay的加入方式。
在Adam当中,weight decay是直接加入到梯度当中:
其中


而在AdamW中,正则化变成了:
其中
所以AdamW的思路特别简单:反正正则化系数加进梯度之后最终也要在权重上进行更新,那为什么还需要加进梯度去呢?因此,AdamW直接在权重上进行衰减,在收敛速度上也能领先于Adam。
3.思考
但仔细一想,Adam+L2正则化和AdamW虽然都可以实现权重衰减,但是两者的实施细节上其实是不一样的。L2正则化是在loss上加入权重的惩罚系数,也可以说是在梯度上进行修改,而AdamW其实是更字面意思的weight decay,就是直接让权重衰减。
这两者其实在SGD上是对等的:
只不过在Adam这种要考虑一阶和二阶动量时,以上方程已不满足线性关系,所以最终的结果是有区别的。那么AdamW相对于Adam而言,除了收敛速度更快之外,它的正则系数也不再受动量的影响(一般会被除以二阶动量而稀释),因此拥有超参独立和正则力度增加的优点,这也是原论文名字中带有decouple的原因。
4.pytorch代码
AdamW的伪代码流程如下:
以下代码为pytorch官方Adam的代码。
def _single_tensor_adamw( params: List[Tensor], grads: List[Tensor], exp_avgs: List[Tensor], exp_avg_sqs: List[Tensor], max_exp_avg_sqs: List[Tensor], state_steps: List[Tensor], grad_scale: Optional[Tensor], found_inf: Optional[Tensor], *, amsgrad: bool, beta1: float, beta2: float, lr: float, weight_decay: float, eps: float, maximize: bool, capturable: bool, differentiable: bool, ): assert grad_scale is None and found_inf is None for i, param in enumerate(params): grad = grads[i] if not maximize else -grads[i] exp_avg = exp_avgs[i] exp_avg_sq = exp_avg_sqs[i] step_t = state_steps[i] if capturable: assert ( param.is_cuda and step_t.is_cuda ), "If capturable=True, params and state_steps must be CUDA tensors." if torch.is_complex(param): grad = torch.view_as_real(grad) exp_avg = torch.view_as_real(exp_avg) exp_avg_sq = torch.view_as_real(exp_avg_sq) param = torch.view_as_real(param) # update step step_t += 1 # Perform stepweight decay param.mul_(1 - lr * weight_decay) # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) if capturable or differentiable: step = step_t # 1 - beta1 step can't be captured in a CUDA graph, even if step is a CUDA tensor # (incurs "RuntimeError: CUDA error: operation not permitted when stream is capturing") bias_correction1 = 1 - torch.pow(beta1, step) bias_correction2 = 1 - torch.pow(beta2, step) step_size = lr / bias_correction1 step_size_neg = step_size.neg() bias_correction2_sqrt = bias_correction2.sqrt() if amsgrad: # Maintains the maximum of all 2nd moment running avg. till now if differentiable: max_exp_avg_sqs_i = max_exp_avg_sqs[i].clone() else: max_exp_avg_sqs_i = max_exp_avg_sqs[i] max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sqs_i, exp_avg_sq)) # Uses the max. for normalizing running avg. of gradient # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor) denom = ( max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg) ).add_(eps / step_size_neg) else: denom = ( exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg) ).add_(eps / step_size_neg) param.addcdiv_(exp_avg, denom) else: step = _get_value(step_t) bias_correction1 = 1 - beta1 step bias_correction2 = 1 - beta2 step step_size = lr / bias_correction1 bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) if amsgrad: # Maintains the maximum of all 2nd moment running avg. till now torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) # Use the max. for normalizing running avg. of gradient denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps) else: denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) param.addcdiv_(exp_avg, denom, value=-step_size)
业务合作/学习交流+v:lizhiTechnology
如果想要了解更多优化器相关知识,可以参考我的专栏和其他相关文章:
优化器_Lcm_Tech的博客-CSDN博客
【优化器】(一) SGD原理 & pytorch代码解析_sgd优化器-CSDN博客
【优化器】(二) AdaGrad原理 & pytorch代码解析_adagrad优化器-CSDN博客
【优化器】(三) RMSProp原理 & pytorch代码解析_rmsprop优化器-CSDN博客
【优化器】(四) AdaDelta原理 & pytorch代码解析_adadelta里rho越大越敏感-CSDN博客
【优化器】(五) Adam原理 & pytorch代码解析_adam优化器-CSDN博客
【优化器】(六) AdamW原理 & pytorch代码解析-CSDN博客
【优化器】(七) 优化器统一框架 & 总结分析_mosec优化器优点-CSDN博客
如果想要了解更多深度学习相关知识,可以参考我的其他文章:
【损失函数】(一) L1Loss原理 & pytorch代码解析_l1 loss-CSDN博客
【图像生成】(一) DNN 原理 & pytorch代码实例_pytorch dnn代码-CSDN博客
免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://haidsoft.com/111411.html
