对模型结构 UNetWithAttention 的详细讲解

class PAM_Module(nn.Module):
    """Position Attention Module"""
    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.chanel_in = in_dim
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)
        out = self.gamma * out + x
        return out

class CAM_Module(nn.Module):
    """Channel Attention Module"""
    def __init__(self, in_dim):
        super(CAM_Module, self).__init__()
        self.chanel_in = in_dim
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        m_batchsize, C, height, width = x.size()
        proj_query = x.view(m_batchsize, C, -1)
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = self.softmax(energy_new)
        proj_value = x.view(m_batchsize, C, -1)
        out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, C, height, width)
        out = self.gamma * out + x
        return out

# ------------------- Enhanced U-Net Model -------------------
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class UNetWithAttention(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(UNetWithAttention, self).__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Bottleneck with attention
        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.pam = PAM_Module(features[-1]*2)
        self.cam = CAM_Module(features[-1]*2)

        # Up part
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        # Downsampling
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        # Bottleneck with attention
        x = self.bottleneck(x)
        x = self.pam(x)
        x = self.cam(x)

        # Upsampling
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]
            if x.shape != skip_connection.shape:
                x = F.interpolate(x, size=skip_connection.shape[2:], mode='bilinear', align_corners=True)
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)

这段代码实现了一个基于 U-Net 的分割模型,结合了注意力机制(PAM 和 CAM),用于图像分割任务(如农田分割)。以下是对模型结构 UNetWithAttention 的详细讲解,涵盖其组成部分、工作原理和设计意图。

image-20250415163501556

1. 模型概述

UNetWithAttention 是一个改进的 U-Net 模型,专门为图像分割任务设计。它继承了 U-Net 的经典编码器-解码器结构,同时在瓶颈层加入了 位置注意力模块(PAM)通道注意力模块(CAM),以增强特征提取能力。模型的输入是 RGB 图像(3 通道),输出是单通道的分割掩码(二值化预测)。

模型的核心结构包括:

  • 编码器(Downsampling Path):通过卷积和池化操作逐步提取高级特征。
  • 瓶颈层(Bottleneck):结合注意力机制增强特征表达。
  • 解码器(Upsampling Path):通过上采样和跳跃连接恢复空间分辨率。
  • 最终输出层:生成分割掩码。

2. 模型结构详解

2.1 整体架构

模型定义在 UNetWithAttention 类中,构造函数初始化了以下参数:

  • in_channels:输入通道数,默认为 3(RGB 图像)。
  • out_channels:输出通道数,默认为 1(单通道分割掩码)。
  • features:编码器和解码器每层的特征通道数,默认为 [64, 128, 256, 512]。

模型由以下模块组成:

  • downs:编码器模块列表,包含下采样操作。
  • ups:解码器模块列表,包含上采样和特征融合操作。
  • pool:最大池化层,用于下采样。
  • bottleneck:瓶颈层,包含双重卷积和注意力机制。
  • pam 和 cam:位置和通道注意力模块。
  • final_conv:最终的 1x1 卷积层,生成分割掩码。

2.2 子模块详解

2.2.1 DoubleConv 模块

DoubleConv 是模型的基础构建块,广泛用于编码器、解码器和瓶颈层。它的结构是一个双重卷积块:

  • 结构:
    • 第一个 3x3 卷积(Conv2d),无偏置,保持特征图大小(padding=1)。
    • 批归一化(BatchNorm2d),稳定训练。
    • ReLU 激活函数(ReLU),引入非线性。
    • 第二个 3x3 卷积,同样无偏置,保持特征图大小。
    • 批归一化。
    • ReLU 激活。
  • 作用:
    • 通过两次卷积增加网络深度,增强特征提取能力。
    • 批归一化和 ReLU 提高训练稳定性和非线性表达。
  • 输入输出:
    • 输入:in_channels 通道的特征图。
    • 输出:out_channels 通道的特征图。
2.2.2 PAM_Module(位置注意力模块)

PAM_Module 是一个空间注意力机制,用于增强模型对空间位置关系的建模能力。

  • 结构:
    • 查询(Query)分支:1x1 卷积将输入通道数从 in_dim 降到 in_dim//8。
    • 键(Key)分支:1x1 卷积,同样降维。
    • 值(Value)分支:1x1 卷积保持原始通道数。
    • 注意力计算:
      • 查询和键进行矩阵乘法,生成空间注意力图(energy)。
      • 通过 Softmax 归一化生成注意力权重(attention)。
      • 值与注意力权重相乘,生成加权特征。
    • 残差连接:输出为 gamma * 加权特征 + 输入特征,其中 gamma 是可学习的缩放因子。
  • 作用:
    • 捕捉全局空间依赖关系,增强模型对不同位置像素的相关性建模。
    • 通过降维降低计算量,同时保留重要信息。
  • 输入输出:
    • 输入:形状为 (batch_size, in_dim, height, width) 的特征图。
    • 输出:相同形状,经过空间注意力增强的特征图。
2.2.3 CAM_Module(通道注意力模块)

CAM_Module 是一个通道注意力机制,关注不同通道之间的关系。

  • 结构:
    • 直接使用输入特征图作为查询、键和值。
    • 注意力计算:
      • 查询和键进行矩阵乘法,生成通道相关性矩阵(energy)。
      • 通过能量归一化(减去最大值后 Softmax)生成通道注意力权重。
      • 值与注意力权重相乘,生成加权特征。
    • 残差连接:输出为 gamma * 加权特征 + 输入特征。
  • 作用:
    • 增强重要通道的权重,抑制无关通道,提升特征表达的针对性。
    • 相比 PAM,CAM 更关注通道间的全局关系。
  • 输入输出:
    • 输入:形状为 (batch_size, in_dim, height, width) 的特征图。
    • 输出:相同形状,经过通道注意力增强的特征图。
2.2.4 编码器(Downsampling Path)

编码器由 downs 模块列表组成,每个模块是一个 DoubleConv,后接最大池化层(pool)。

  • 结构:

    • 对于每个特征数

      feature

      (如 64, 128, 256, 512):

      • DoubleConv(in_channels, feature):卷积操作提取特征。
      • MaxPool2d(kernel_size=2, stride=2):下采样,空间尺寸减半。
    • 通道数逐步增加:3 → 64 → 128 → 256 → 512。

    • 空间分辨率逐步减小:如 256x256 → 128x128 → 64x64 → 32x32 → 16x16。

  • 作用:

    • 提取多尺度特征,从低级(边缘、纹理)到高级(语义信息)。
    • 跳跃连接保存中间特征,供解码器使用。
  • 跳跃连接:

    • 每个 DoubleConv 的输出存储在 skip_connections 中,用于后续解码器融合。
2.2.5 瓶颈层(Bottleneck)

瓶颈层是模型的最深处,连接编码器和解码器。

  • 结构:
    • DoubleConv(features[-1], features[-1]*2):将通道数从 512 扩展到 1024。
    • PAM_Module(features[-1]*2):空间注意力,增强空间关系。
    • CAM_Module(features[-1]*2):通道注意力,增强通道重要性。
  • 作用:
    • 捕获深层语义信息。
    • 注意力机制提高特征的质量,减少冗余信息。
  • 输入输出:
    • 输入:512 通道的特征图(如 16x16 分辨率)。
    • 输出:1024 通道的特征图,经过注意力增强。
2.2.6 解码器(Upsampling Path)

解码器由 ups 模块列表组成,包含上采样和特征融合操作。

  • 结构:

    • 对于每个特征数

      feature

      (如 512, 256, 128, 64):

      • ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2):上采样,空间尺寸翻倍。
      • DoubleConv(feature*2, feature):融合跳跃连接特征并进一步处理。
    • 跳跃连接融合:

      • 将上采样特征与对应编码器层的特征(skip_connection)在通道维度拼接。
      • 如果尺寸不匹配,使用双线性插值(F.interpolate)调整大小。
    • 通道数逐步减少:1024 → 512 → 256 → 128 → 64。

    • 空间分辨率逐步增大:如 16x16 → 32x32 → 64x64 → 128x128 → 256x256。

  • 作用:

    • 恢复空间分辨率,生成高分辨率的分割掩码。
    • 跳跃连接融合低级和高级特征,保留细节和语义信息。
2.2.7 最终输出层
  • 结构:
    • Conv2d(features[0], out_channels, kernel_size=1):1x1 卷积将通道数从 64 降到 1。
  • 作用:
    • 生成单通道的分割预测(未激活的 logits)。
    • 通过 Sigmoid 激活后,输出值为 [0, 1],表示像素属于目标类别的概率。
  • 输入输出:
    • 输入:64 通道的特征图(256x256 分辨率)。
    • 输出:1 通道的分割掩码(256x256 分辨率)。

2.3 前向传播(Forward Pass)

UNetWithAttention 的前向传播过程如下:

  1. 编码器阶段:
    • 输入图像(batch_size, 3, 256, 256)通过 downs 模块逐层卷积和池化。
    • 每次池化后,空间尺寸减半,通道数增加,生成跳跃连接特征。
    • 输出为 (batch_size, 512, 16, 16)。
  2. 瓶颈阶段:
    • 输入 (batch_size, 512, 16, 16) 通过 DoubleConv 扩展到 (batch_size, 1024, 16, 16)。
    • 应用 PAM_Module 和 CAM_Module,增强特征表达。
  3. 解码器阶段:
    • 通过 ConvTranspose2d 上采样,空间尺寸逐步恢复。
    • 拼接跳跃连接特征,输入 DoubleConv 进行融合和处理。
    • 最终输出为 (batch_size, 64, 256, 256)。
  4. 输出阶段:
    • 1x1 卷积生成 (batch_size, 1, 256, 256) 的分割掩码。

3. 设计意图与优势

  • U-Net 结构:
    • 编码器-解码器架构适合分割任务,能捕捉多尺度特征。
    • 跳跃连接保留低级细节,解决深层网络中的信息丢失问题。
  • 注意力机制:
    • PAM:增强空间关系建模,适合处理复杂场景(如农田边界)。
    • CAM:突出重要通道,减少冗余,提升特征针对性。
    • 在瓶颈层添加注意力模块,避免过早干扰低级特征提取。
  • DoubleConv:
    • 双重卷积增加非线性,提升特征提取能力。
    • 批归一化加速收敛,ReLU 引入稀疏性。
  • 灵活性:
    • 可调整 features 参数,适应不同任务的计算需求。
    • 支持多通道输入/输出,适配多种分割任务。

4. 与其他模型的对比

  • 传统 U-Net:
    • 相比传统 U-Net,UNetWithAttention 在瓶颈层引入注意力机制,增强了特征选择能力,适合复杂场景。
    • 计算复杂度略高,但性能提升显著。
  • SegNet:
    • SegNet 使用池化索引上采样,内存效率高,但细节恢复不如 U-Net 的跳跃连接。
    • UNetWithAttention 的注意力机制进一步提升了特征质量。
  • DeepLab:
    • DeepLab 使用空洞卷积捕捉多尺度信息,适合大场景分割。
    • UNetWithAttention 更适合需要精细边界的任务(如农田分割),且计算量较小。

5. 代码中的实现细节

  • 模块初始化:
    • 使用 nn.ModuleList 动态构建编码器和解码器,便于扩展。
    • PAM 和 CAM 的 gamma 参数初始化为 0,逐步学习注意力权重。
  • 跳跃连接处理:
    • 使用 F.interpolate 确保尺寸匹配,增强鲁棒性。
  • 训练支持:
    • 模型与自定义损失函数(ContinuousLoss)和数据增强(EnhancedSegmentationDataset)无缝集成,优化分割性能。

6. 总结

UNetWithAttention 是一个功能强大且灵活的分割模型,结合了 U-Net 的跳跃连接和注意力机制的特征增强能力。其结构清晰,分为编码器、瓶颈层和解码器,通过 DoubleConv、PAM 和 CAM 模块实现高效特征提取和融合。模型特别适合需要精细边界分割的任务,如农田分割,同时具有较好的扩展性和可调性。