【mmengine】优化器封装(OptimWrapper)(进阶)在执行器(Runner)中配置优化器封装(OptimWrapper)

一、 简单配置

  • 以配置一个 SGD 优化器封装为例:
    优化器封装需要接受 optimizer 参数,因此我们首先需要为优化器封装配置 optimizer。MMEngine 会自动将 PyTorch 中的所有优化器都添加进 OPTIMIZERS 注册表中

  • 以配置一个 SGD 优化器封装为例:

optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
  • 这样我们就配置好了一个优化器类型为 SGD 的优化器封装,学习率、动量等参数如配置所示。考虑到 OptimWrapper 为标准的单精度训练,因此我们也可以不配置 type 字段:
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(optimizer=optimizer)
  • 要想开启混合精度训练和梯度累加,需要将 type 切换成 AmpOptimWrapper,并指定 accumulative_counts 参数
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(type='AmpOptimWrapper', optimizer=optimizer, accumulative_counts=2)

二、 模型中的不同参数设置不同的超参数

PyTorch 的优化器支持对模型中的不同参数设置不同的超参数,例如对一个分类模型的骨干(backbone)和分类头(head)设置不同的学习率:

from torch.optim import SGD
import torch.nn as nn

model = nn.ModuleDict(dict(backbone=nn.Linear(1, 1), head=nn.Linear(1, 1)))

# backbone部分: lr=0.01, momentum=0.9
# head部分: lr=1e-3, momentum=0.8
optimizer = SGD(
    [
    {'params': model.backbone.parameters()},
    {'params': model.head.parameters(), 'lr': 1e-3,'momentum': 0.8}
    ],
    lr=0.01, momentum=0.9)

backbone部分: lr=0.01, momentum=0.9
head部分: lr=1e-3, momentum=0.8

三、 不同类型的参数设置不同的超参系数

例如,我们可以在 paramwise_cfg 中设置 norm_decay_mult=0,从而将正则化层(normalization layer)的权重(weight)和偏置(bias)的权值衰减系数(weight decay)设置为 0
具体示例如下,我们将 ToyModel 中所有正则化层(head.bn)的权重衰减系数设置为 0:

from mmengine.optim import build_optim_wrapper
from collections import OrderedDict

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.ModuleDict(
            dict(layer0=nn.Linear(1, 1), layer1=nn.Linear(1, 1)))
        self.head = nn.Sequential(
            OrderedDict(
                linear=nn.Linear(1, 1),
                bn=nn.BatchNorm1d(1)))

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.01, weight_decay=0.0001),
    paramwise_cfg=dict(norm_decay_mult=0))
optimizer = build_optim_wrapper(ToyModel(), optim_wrapper)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/885892.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【学习笔记】手写 Tomcat 六

目录 一、线程池 1. 构建线程池的类 2. 创建任务 3. 执行任务 测试 二、URL编码 解决方案 测试 三、如何接收客户端发送的全部信息 解决方案 测试 四、作业 1. 了解工厂模式 2. 了解反射技术 一、线程池 昨天使用了数据库连接池,我们了解了连接池的优…

渗透测试--文件上传常用绕过方式

文件上传常用绕过方式 1.前端代码,限制只允许上传图片。修改png为php即可绕过前端校验。 2.后端校验Content-Type 校验文件格式 前端修改,抓取上传数据包,并且修改 Content-Type 3.服务端检测(目录路径检测) 对目…

医院体检管理系统小程序的设计

管理员账户功能包括:系统首页,个人中心,用户管理,体检分类管理,体检套餐管理,体检预约管理,体检报告管理,系统管理 微信端账号功能包括:系统首页,体检套餐&a…

四、Drf认证组件

四、Drf认证组件 4.1 快速使用 from django.shortcuts import render,HttpResponse from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.authentication import BaseAuthentication from rest_framework.exception…

数据结构:将复杂的现实问题简化为计算机可以理解和处理的形式

整句话的总体意义是,**数据结构是用于将现实世界中的实体和关系抽象为数学模型,并在计算机中表示和实现的关键工具**。它不仅包括如何存储数据,还包括对这些数据的操作,能够有效支持计算机程序的运行。通过这一过程,数…

利用PDLP扩展线性规划求解能力

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

Java项目实战II基于Java+Spring Boot+MySQL的甘肃非物质文化网站设计与实现(源码+数据库+文档)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者 一、前言 甘肃省作为中国历史文化名省,拥有丰富的非物质文化遗产资源,涵盖表演艺术、手…

TypeScript 封装 Axios 1.7.7

随着Axios版本的不同,类型也在改变,以后怎么写类型? 1. 封装Axios 将Axios封装成一个类,同时重新封装request方法 重新封装request有几个好处: 所有的请求将从我们定义的requet请求中发送,这样以后更换…

Golang | Leetcode Golang题解之第441题排列硬币

题目: 题解: func arrangeCoins(n int) int {return sort.Search(n, func(k int) bool { k; return k*(k1) > 2*n }) }

【Unity服务】如何使用Unity Version Control

Unity上的线上服务有很多,我们接触到的第一个一般就是Version Control,用于对项目资源的版本管理。 本文介绍如何为项目添加Version Control,并如何使用,以及如何将项目与Version Control断开链接。 其实如果仅仅是对项目资源进…

09_OpenCV彩色图片直方图

import cv2 import numpy as np import matplotlib.pyplot as plt %matplotlib inlineimg cv2.imread(computer.jpeg, 1) img cv2.cvtColor(img, cv2.COLOR_BGR2RGB) plt.imshow(img) plt.show()plot绘制直方图 plt.hist(img.ravel(), 256) #ravel() 二维降一维 256灰度级…

学习记录:js算法(五十):二叉树的右视图

文章目录 二叉树的右视图我的思路网上思路 总结 二叉树的右视图 给定一个二叉树的 根节点 root,想象自己站在它的右侧,按照从顶部到底部的顺序,返回从右侧所能看到的节点值。 图一: 示例 1:如图一 输入: [1,2,3,null,5,null,4] …

C++面向对象基础

目录 一.函数 1.内联函数 2.函数重载 3.哑元函数 二.类和对象 2.1 类的定义 2.2 创建对象 三. 封装(重点) 四. 构造函数 constructor(重点) 4.1 基础使用 4.2 构造初始化列表 4.3 构造函数的调用方式(掌握…

解决方法:PDF文件打开之后不能打印?

打开PDF文件之后,发现文件不能打印?这是什么原因?首先我们需要先查看一下自己的打印机是否能够正常运行,如果打印机是正常的,我们再查看一下,文件中的打印功能按钮是否是灰色的状态。 如果PDF中的大多数功…

秋招内推--招联金融2025

【投递方式】 直接扫下方二维码,或点击内推官网https://wecruit.hotjob.cn/SU61025e262f9d247b98e0a2c2/mc/position/campus,使用内推码 igcefb 投递) 【招聘岗位】 后台开发 前端开发 数据开发 数据运营 算法开发 技术运维 软件测试 产品策…

数据结构-LRU缓存(C语言实现)

遇到困难,不必慌张,正是成长的时候,耐心一点! 目录 前言一、题目介绍二、实现过程2.1 实现原理2.2 实现思路2.2.1 双向链表2.2.2 散列表 2.3 代码实现2.3.1 结构定义2.3.2 双向链表操作实现2.3.3 实现散列表的操作2.3.4 内存释放代…

SigmaStudio控件Cross Mixer\Signal Merger算法效果分析

衰减与叠加混音算法验证分析一 CH2:输入源为-20dB正弦波1khz CH1叠加混音:参考混音算法https://blog.csdn.net/weixin_48408892/article/details/129878036?spm1001.2014.3001.5502 Ch0衰减混音:外部多个输入源做混音时,建议参考该算法控件&…

在VMware虚拟机上部署polardb

免密登录到我们的虚拟机之后,要在虚拟机上部署polardb数据库,首先第一步要先克隆源码: 为了进SSH协议进行传输源码需要先进行下面的步骤: 将宿主机上的私钥文件复制到虚拟机上 scp "C:\Users\waitw\.ssh\id_rsa" ann…

Azkaban:大数据任务调度与编排工具的安装与使用

在当今大数据时代,数据处理和分析任务变得越来越复杂。一个完整的大数据分析系统通常由大量任务单元组成,如 shell 脚本程序、mapreduce 程序、hive 脚本、spark 程序等。这些任务单元之间存在时间先后及前后依赖关系,为了高效地组织和执行这…

Leetcode 每日一题:Crack The Safe

写在前面: 学期真的忙起来了,我的两个社团也在上一周终于完成了大部分招新活动。虽然后面有即将到来的期中考试和求职,但希望能有时间将帖子的频率提上去吧(真的尽量因为从做题思考到写博客讲解思路需要大量的时间,在…