Adding 12 Attention Mechanisms to MMDetection 3

2024/7/7 19:04:53  Tags: deep learning, pytorch, python

Add attention_layers.py under the mmdetection/mmdet/models/layers/ directory:

import torch.nn as nn
from mmdet.registry import MODELS
# Custom attention mechanism implementations, wrapped below for MMDetection's MODELS registry
from .attention.CBAM import CBAMBlock as _CBAMBlock
from .attention.BAM import BAMBlock as _BAMBlock
from .attention.SEAttention import SEAttention as _SEAttention
from .attention.ECAAttention import ECAAttention as _ECAAttention
from .attention.ShuffleAttention import ShuffleAttention as _ShuffleAttention
from .attention.SGE import SpatialGroupEnhance as _SpatialGroupEnhance
from .attention.A2Atttention import DoubleAttention as _DoubleAttention
from .attention.PolarizedSelfAttention import SequentialPolarizedSelfAttention as _SequentialPolarizedSelfAttention
from .attention.CoTAttention import CoTAttention as _CoTAttention
from .attention.TripletAttention import TripletAttention as _TripletAttention
from .attention.CoordAttention import CoordAtt as _CoordAtt
from .attention.ParNetAttention import ParNetAttention as _ParNetAttention


@MODELS.register_module()
class CBAMBlock(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(CBAMBlock, self).__init__()
        print("======激活注意力机制模块【CBAMBlock】======")
        self.module = _CBAMBlock(channel = in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)
    
    
@MODELS.register_module()
class BAMBlock(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(BAMBlock, self).__init__()
        print("======激活注意力机制模块【BAMBlock】======")
        self.module = _BAMBlock(channel = in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class SEAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(SEAttention, self).__init__()
        print("======激活注意力机制模块【SEAttention】======")
        self.module = _SEAttention(channel = in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)   
 

@MODELS.register_module()
class ECAAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(ECAAttention, self).__init__()
        print("======激活注意力机制模块【ECAAttention】======")
        self.module = _ECAAttention(**kwargs)

    def forward(self, x):
        return self.module(x)  


@MODELS.register_module()
class ShuffleAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(ShuffleAttention, self).__init__()
        print("======激活注意力机制模块【ShuffleAttention】======")
        self.module = _ShuffleAttention(channel = in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class SpatialGroupEnhance(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(SpatialGroupEnhance, self).__init__()
        print("======激活注意力机制模块【SpatialGroupEnhance】======")
        self.module = _SpatialGroupEnhance(**kwargs)

    def forward(self, x):
        return self.module(x)   
    

@MODELS.register_module()
class DoubleAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(DoubleAttention, self).__init__()
        print("======激活注意力机制模块【DoubleAttention】======")
        self.module = _DoubleAttention(in_channels, 128, 128,True)

    def forward(self, x):
        return self.module(x)  


@MODELS.register_module()
class SequentialPolarizedSelfAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(SequentialPolarizedSelfAttention, self).__init__()
        print("======激活注意力机制模块【Polarized Self-Attention】======")
        self.module = _SequentialPolarizedSelfAttention(channel=in_channels)

    def forward(self, x):
        return self.module(x)   
    
    
@MODELS.register_module()
class CoTAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(CoTAttention, self).__init__()
        print("======激活注意力机制模块【CoTAttention】======")
        self.module = _CoTAttention(dim=in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)  

    
@MODELS.register_module()
class TripletAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(TripletAttention, self).__init__()
        print("======激活注意力机制模块【TripletAttention】======")
        self.module = _TripletAttention()

    def forward(self, x):
        return self.module(x)      


@MODELS.register_module()
class CoordAtt(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(CoordAtt, self).__init__()
        print("======激活注意力机制模块【CoordAtt】======")
        self.module = _CoordAtt(in_channels, in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)    


@MODELS.register_module()
class ParNetAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(ParNetAttention, self).__init__()
        print("======激活注意力机制模块【ParNetAttention】======")
        self.module = _ParNetAttention(channel=in_channels)

    def forward(self, x):
        return self.module(x)  

Create an attention folder in the same directory as attention_layers.py and place the 12 attention mechanism implementation files inside it.
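
The resulting layout should look roughly like this (the file names follow directly from the import statements above; an empty __init__.py in the attention folder may be needed so it can be imported as a package):

mmdet/models/layers/
├── attention_layers.py
└── attention/
    ├── __init__.py            (may be empty)
    ├── CBAM.py
    ├── BAM.py
    ├── SEAttention.py
    ├── ECAAttention.py
    ├── ShuffleAttention.py
    ├── SGE.py
    ├── A2Atttention.py
    ├── PolarizedSelfAttention.py
    ├── CoTAttention.py
    ├── TripletAttention.py
    ├── CoordAttention.py
    └── ParNetAttention.py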

Download: the 12 attention mechanism files for MMDetection 3 (CSDN resource): https://download.csdn.net/download/lanyan90/89513979
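
Once the files are in place, a quick standalone check confirms that the wrappers register and run. This is a minimal sketch, assuming the downloaded implementations match the imports above; the CBAMBlock parameters are the same ones used in the config example further down:

import torch
from mmdet.registry import MODELS
import mmdet.models.layers.attention_layers  # noqa: F401  # triggers @MODELS.register_module()

# Build one of the registered wrappers from a config dict and run a dummy feature map through it.
block = MODELS.build(dict(type='CBAMBlock', in_channels=256, reduction=16, kernel_size=7))
x = torch.randn(2, 256, 64, 64)
print(block(x).shape)  # expected: torch.Size([2, 256, 64, 64])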

Usage:

Taking faster-rcnn_r50 as an example, create faster-rcnn_r50_fpn_1x_coco_attention.py:

_base_ = 'configs/detection/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py'

custom_imports = dict(imports=['mmdet.models.layers.attention_layers'], allow_failed_imports=False)

model = dict(
    backbone=dict(
        plugins = [
            dict(
                position='after_conv3',
                #cfg = dict(type='CBAMBlock', reduction=16, kernel_size=7)
                #cfg = dict(type='BAMBlock', reduction=16, dia_val=1)
                #cfg = dict(type='SEAttention', reduction=8)
                #cfg = dict(type='ECAAttention', kernel_size=3)
                #cfg = dict(type='ShuffleAttention', G=8)
                #cfg = dict(type='SpatialGroupEnhance', groups=8)
                #cfg = dict(type='DoubleAttention')
                #cfg = dict(type='SequentialPolarizedSelfAttention')
                #cfg = dict(type='CoTAttention', kernel_size=3)
                #cfg = dict(type='TripletAttention')
                #cfg = dict(type='CoordAtt', reduction=32)
                #cfg = dict(type='ParNetAttention')
            )
        ]
    )
)

To use a particular attention mechanism, just uncomment the corresponding cfg line inside plugins, as shown in the example below.
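
For instance, enabling CBAMBlock (with the values from its commented line above) makes the backbone section read:

model = dict(
    backbone=dict(
        plugins = [
            dict(
                position='after_conv3',
                cfg = dict(type='CBAMBlock', reduction=16, kernel_size=7)
            )
        ]
    )
)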

Taking mask-rcnn_r50 as an example, create mask-rcnn_r50_fpn_1x_coco_attention.py:

_base_ = 'configs/segmentation/mask_rcnn/mask-rcnn_r50_fpn_1x_coco.py'
custom_imports = dict(imports=['mmdet.models.layers.attention_layers'], allow_failed_imports=False)

model = dict(
    backbone=dict(
        plugins = [
            dict(
                position='after_conv3',
                #cfg = dict(type='CBAMBlock', reduction=16, kernel_size=7)
                #cfg = dict(type='BAMBlock', reduction=16, dia_val=1)
                #cfg = dict(type='SEAttention', reduction=8)
                #cfg = dict(type='ECAAttention', kernel_size=3)
                #cfg = dict(type='ShuffleAttention', G=8)
                #cfg = dict(type='SpatialGroupEnhance', groups=8)
                #cfg = dict(type='DoubleAttention')
                #cfg = dict(type='SequentialPolarizedSelfAttention')
                #cfg = dict(type='CoTAttention', kernel_size=3)
                #cfg = dict(type='TripletAttention')
                #cfg = dict(type='CoordAtt', reduction=32)
                #cfg = dict(type='ParNetAttention')
            )
        ]
    )
)

The usage is exactly the same! A sample training command is shown below.
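
Training then runs like any other MMDetection config. The paths below assume the new config files sit in the repository root, so adjust them to wherever you saved yours:

python tools/train.py faster-rcnn_r50_fpn_1x_coco_attention.py
python tools/train.py mask-rcnn_r50_fpn_1x_coco_attention.py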

