对模型结构 UNetWithAttention 的详细讲解

对模型结构 UNetWithAttention 的详细讲解
ytkzclass PAM_Module(nn.Module):
"""Position Attention Module"""
def __init__(self, in_dim):
super(PAM_Module, self).__init__()
self.chanel_in = in_dim
self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
m_batchsize, C, height, width = x.size()
proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
energy = torch.bmm(proj_query, proj_key)
attention = self.softmax(energy)
proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
out = out.view(m_batchsize, C, height, width)
out = self.gamma * out + x
return out
class CAM_Module(nn.Module):
"""Channel Attention Module"""
def __init__(self, in_dim):
super(CAM_Module, self).__init__()
self.chanel_in = in_dim
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
m_batchsize, C, height, width = x.size()
proj_query = x.view(m_batchsize, C, -1)
proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
energy = torch.bmm(proj_query, proj_key)
energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
attention = self.softmax(energy_new)
proj_value = x.view(m_batchsize, C, -1)
out = torch.bmm(attention, proj_value)
out = out.view(m_batchsize, C, height, width)
out = self.gamma * out + x
return out
# ------------------- Enhanced U-Net Model -------------------
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(DoubleConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.conv(x)
class UNetWithAttention(nn.Module):
def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
super(UNetWithAttention, self).__init__()
self.downs = nn.ModuleList()
self.ups = nn.ModuleList()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# Down part
for feature in features:
self.downs.append(DoubleConv(in_channels, feature))
in_channels = feature
# Bottleneck with attention
self.bottleneck = DoubleConv(features[-1], features[-1]*2)
self.pam = PAM_Module(features[-1]*2)
self.cam = CAM_Module(features[-1]*2)
# Up part
for feature in reversed(features):
self.ups.append(
nn.ConvTranspose2d(
feature*2, feature, kernel_size=2, stride=2
)
)
self.ups.append(DoubleConv(feature*2, feature))
self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
def forward(self, x):
skip_connections = []
# Downsampling
for down in self.downs:
x = down(x)
skip_connections.append(x)
x = self.pool(x)
# Bottleneck with attention
x = self.bottleneck(x)
x = self.pam(x)
x = self.cam(x)
# Upsampling
skip_connections = skip_connections[::-1]
for idx in range(0, len(self.ups), 2):
x = self.ups[idx](x)
skip_connection = skip_connections[idx//2]
if x.shape != skip_connection.shape:
x = F.interpolate(x, size=skip_connection.shape[2:], mode='bilinear', align_corners=True)
concat_skip = torch.cat((skip_connection, x), dim=1)
x = self.ups[idx+1](concat_skip)
return self.final_conv(x)
这段代码实现了一个基于 U-Net 的分割模型,结合了注意力机制(PAM 和 CAM),用于图像分割任务(如农田分割)。以下是对模型结构 UNetWithAttention 的详细讲解,涵盖其组成部分、工作原理和设计意图。
1. 模型概述
UNetWithAttention 是一个改进的 U-Net 模型,专门为图像分割任务设计。它继承了 U-Net 的经典编码器-解码器结构,同时在瓶颈层加入了 位置注意力模块(PAM) 和 通道注意力模块(CAM),以增强特征提取能力。模型的输入是 RGB 图像(3 通道),输出是单通道的分割掩码(二值化预测)。
模型的核心结构包括:
- 编码器(Downsampling Path):通过卷积和池化操作逐步提取高级特征。
- 瓶颈层(Bottleneck):结合注意力机制增强特征表达。
- 解码器(Upsampling Path):通过上采样和跳跃连接恢复空间分辨率。
- 最终输出层:生成分割掩码。
2. 模型结构详解
2.1 整体架构
模型定义在 UNetWithAttention 类中,构造函数初始化了以下参数:
- in_channels:输入通道数,默认为 3(RGB 图像)。
- out_channels:输出通道数,默认为 1(单通道分割掩码)。
- features:编码器和解码器每层的特征通道数,默认为 [64, 128, 256, 512]。
模型由以下模块组成:
- downs:编码器模块列表,包含下采样操作。
- ups:解码器模块列表,包含上采样和特征融合操作。
- pool:最大池化层,用于下采样。
- bottleneck:瓶颈层,包含双重卷积和注意力机制。
- pam 和 cam:位置和通道注意力模块。
- final_conv:最终的 1x1 卷积层,生成分割掩码。
2.2 子模块详解
2.2.1 DoubleConv 模块
DoubleConv 是模型的基础构建块,广泛用于编码器、解码器和瓶颈层。它的结构是一个双重卷积块:
- 结构:
- 第一个 3x3 卷积(Conv2d),无偏置,保持特征图大小(padding=1)。
- 批归一化(BatchNorm2d),稳定训练。
- ReLU 激活函数(ReLU),引入非线性。
- 第二个 3x3 卷积,同样无偏置,保持特征图大小。
- 批归一化。
- ReLU 激活。
- 作用:
- 通过两次卷积增加网络深度,增强特征提取能力。
- 批归一化和 ReLU 提高训练稳定性和非线性表达。
- 输入输出:
- 输入:in_channels 通道的特征图。
- 输出:out_channels 通道的特征图。
2.2.2 PAM_Module(位置注意力模块)
PAM_Module 是一个空间注意力机制,用于增强模型对空间位置关系的建模能力。
- 结构:
- 查询(Query)分支:1x1 卷积将输入通道数从 in_dim 降到 in_dim//8。
- 键(Key)分支:1x1 卷积,同样降维。
- 值(Value)分支:1x1 卷积保持原始通道数。
- 注意力计算:
- 查询和键进行矩阵乘法,生成空间注意力图(energy)。
- 通过 Softmax 归一化生成注意力权重(attention)。
- 值与注意力权重相乘,生成加权特征。
- 残差连接:输出为 gamma * 加权特征 + 输入特征,其中 gamma 是可学习的缩放因子。
- 作用:
- 捕捉全局空间依赖关系,增强模型对不同位置像素的相关性建模。
- 通过降维降低计算量,同时保留重要信息。
- 输入输出:
- 输入:形状为 (batch_size, in_dim, height, width) 的特征图。
- 输出:相同形状,经过空间注意力增强的特征图。
2.2.3 CAM_Module(通道注意力模块)
CAM_Module 是一个通道注意力机制,关注不同通道之间的关系。
- 结构:
- 直接使用输入特征图作为查询、键和值。
- 注意力计算:
- 查询和键进行矩阵乘法,生成通道相关性矩阵(energy)。
- 通过能量归一化(减去最大值后 Softmax)生成通道注意力权重。
- 值与注意力权重相乘,生成加权特征。
- 残差连接:输出为 gamma * 加权特征 + 输入特征。
- 作用:
- 增强重要通道的权重,抑制无关通道,提升特征表达的针对性。
- 相比 PAM,CAM 更关注通道间的全局关系。
- 输入输出:
- 输入:形状为 (batch_size, in_dim, height, width) 的特征图。
- 输出:相同形状,经过通道注意力增强的特征图。
2.2.4 编码器(Downsampling Path)
编码器由 downs 模块列表组成,每个模块是一个 DoubleConv,后接最大池化层(pool)。
结构:
对于每个特征数
feature
(如 64, 128, 256, 512):
- DoubleConv(in_channels, feature):卷积操作提取特征。
- MaxPool2d(kernel_size=2, stride=2):下采样,空间尺寸减半。
通道数逐步增加:3 → 64 → 128 → 256 → 512。
空间分辨率逐步减小:如 256x256 → 128x128 → 64x64 → 32x32 → 16x16。
作用:
- 提取多尺度特征,从低级(边缘、纹理)到高级(语义信息)。
- 跳跃连接保存中间特征,供解码器使用。
跳跃连接:
- 每个 DoubleConv 的输出存储在 skip_connections 中,用于后续解码器融合。
2.2.5 瓶颈层(Bottleneck)
瓶颈层是模型的最深处,连接编码器和解码器。
- 结构:
- DoubleConv(features[-1], features[-1]*2):将通道数从 512 扩展到 1024。
- PAM_Module(features[-1]*2):空间注意力,增强空间关系。
- CAM_Module(features[-1]*2):通道注意力,增强通道重要性。
- 作用:
- 捕获深层语义信息。
- 注意力机制提高特征的质量,减少冗余信息。
- 输入输出:
- 输入:512 通道的特征图(如 16x16 分辨率)。
- 输出:1024 通道的特征图,经过注意力增强。
2.2.6 解码器(Upsampling Path)
解码器由 ups 模块列表组成,包含上采样和特征融合操作。
结构:
对于每个特征数
feature
(如 512, 256, 128, 64):
- ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2):上采样,空间尺寸翻倍。
- DoubleConv(feature*2, feature):融合跳跃连接特征并进一步处理。
跳跃连接融合:
- 将上采样特征与对应编码器层的特征(skip_connection)在通道维度拼接。
- 如果尺寸不匹配,使用双线性插值(F.interpolate)调整大小。
通道数逐步减少:1024 → 512 → 256 → 128 → 64。
空间分辨率逐步增大:如 16x16 → 32x32 → 64x64 → 128x128 → 256x256。
作用:
- 恢复空间分辨率,生成高分辨率的分割掩码。
- 跳跃连接融合低级和高级特征,保留细节和语义信息。
2.2.7 最终输出层
- 结构:
- Conv2d(features[0], out_channels, kernel_size=1):1x1 卷积将通道数从 64 降到 1。
- 作用:
- 生成单通道的分割预测(未激活的 logits)。
- 通过 Sigmoid 激活后,输出值为 [0, 1],表示像素属于目标类别的概率。
- 输入输出:
- 输入:64 通道的特征图(256x256 分辨率)。
- 输出:1 通道的分割掩码(256x256 分辨率)。
2.3 前向传播(Forward Pass)
UNetWithAttention 的前向传播过程如下:
- 编码器阶段:
- 输入图像(batch_size, 3, 256, 256)通过 downs 模块逐层卷积和池化。
- 每次池化后,空间尺寸减半,通道数增加,生成跳跃连接特征。
- 输出为 (batch_size, 512, 16, 16)。
- 瓶颈阶段:
- 输入 (batch_size, 512, 16, 16) 通过 DoubleConv 扩展到 (batch_size, 1024, 16, 16)。
- 应用 PAM_Module 和 CAM_Module,增强特征表达。
- 解码器阶段:
- 通过 ConvTranspose2d 上采样,空间尺寸逐步恢复。
- 拼接跳跃连接特征,输入 DoubleConv 进行融合和处理。
- 最终输出为 (batch_size, 64, 256, 256)。
- 输出阶段:
- 1x1 卷积生成 (batch_size, 1, 256, 256) 的分割掩码。
3. 设计意图与优势
- U-Net 结构:
- 编码器-解码器架构适合分割任务,能捕捉多尺度特征。
- 跳跃连接保留低级细节,解决深层网络中的信息丢失问题。
- 注意力机制:
- PAM:增强空间关系建模,适合处理复杂场景(如农田边界)。
- CAM:突出重要通道,减少冗余,提升特征针对性。
- 在瓶颈层添加注意力模块,避免过早干扰低级特征提取。
- DoubleConv:
- 双重卷积增加非线性,提升特征提取能力。
- 批归一化加速收敛,ReLU 引入稀疏性。
- 灵活性:
- 可调整 features 参数,适应不同任务的计算需求。
- 支持多通道输入/输出,适配多种分割任务。
4. 与其他模型的对比
- 传统 U-Net:
- 相比传统 U-Net,UNetWithAttention 在瓶颈层引入注意力机制,增强了特征选择能力,适合复杂场景。
- 计算复杂度略高,但性能提升显著。
- SegNet:
- SegNet 使用池化索引上采样,内存效率高,但细节恢复不如 U-Net 的跳跃连接。
- UNetWithAttention 的注意力机制进一步提升了特征质量。
- DeepLab:
- DeepLab 使用空洞卷积捕捉多尺度信息,适合大场景分割。
- UNetWithAttention 更适合需要精细边界的任务(如农田分割),且计算量较小。
5. 代码中的实现细节
- 模块初始化:
- 使用 nn.ModuleList 动态构建编码器和解码器,便于扩展。
- PAM 和 CAM 的 gamma 参数初始化为 0,逐步学习注意力权重。
- 跳跃连接处理:
- 使用 F.interpolate 确保尺寸匹配,增强鲁棒性。
- 训练支持:
- 模型与自定义损失函数(ContinuousLoss)和数据增强(EnhancedSegmentationDataset)无缝集成,优化分割性能。
6. 总结
UNetWithAttention 是一个功能强大且灵活的分割模型,结合了 U-Net 的跳跃连接和注意力机制的特征增强能力。其结构清晰,分为编码器、瓶颈层和解码器,通过 DoubleConv、PAM 和 CAM 模块实现高效特征提取和融合。模型特别适合需要精细边界分割的任务,如农田分割,同时具有较好的扩展性和可调性。