HARU-Net:混合注意力 + 残差 U-Net,如何在 CBCT 降噪中"保住"齿根边界?
一句话总结
HARU-Net 把 CBAM 风格的混合注意力(通道+空间)嵌入残差 U-Net,解决了传统降噪方法”去噪越狠,边缘越糊”的内在矛盾,专为 CBCT 牙科影像设计。
为什么这篇论文重要?
问题的根源:CBCT 的噪声不是普通噪声
锥束 CT(Cone-Beam CT)的噪声来源是光子计数统计——泊松分布,不是高斯分布。低剂量 CBCT 为了减少辐射,光子数更少,噪声更重。
更麻烦的是,CBCT 还有结构性伪影:
- 光束硬化(Beam Hardening):金属修复体旁边的条纹,看起来像噪声,但不是
- 环形伪影(Ring Artifacts):探测器不均匀性导致的同心圆纹路
这意味着:任何”均匀平滑”的降噪策略都是错的。
现有方法的痛点
| 方法类别 | 代表 | 问题 |
|---|---|---|
| 滤波器法 | BM3D、NLM | 对 CBCT 非高斯噪声假设错误 |
| 标准 U-Net | DnCNN | 去噪即平滑,损伤高频边缘 |
| 自注意力 Transformer | Restormer | 医学图像数据少,容易过拟合 |
HARU-Net 的核心洞见
用注意力机制教网络区分”这里是噪声”和”这里是边缘”——两者在频率域上重叠,但在语义上完全不同。
核心方法解析
直觉:为什么”混合注意力”能保边缘?
想象你在 Photoshop 里降噪,你会:
- 先找到”这张图哪些区域是平坦的纹理(可以大力平滑)”
- 再找到”哪些是牙根/骨骼边界(不能碰)”
HARU-Net 的注意力模块做的正是这件事——只不过是自动学习的:
- 通道注意力(Channel Attention):学习”哪些特征图编码了边缘信息”,给这些 channel 更高权重
- 空间注意力(Spatial Attention):学习”哪些像素位置是边缘”,在那里抑制降噪力度
两者串联(先通道后空间),就是 CBAM 结构——这里被称为”Hybrid Attention”。
数学公式
通道注意力:
\[M_c(F) = \sigma\!\left(W_1 \operatorname{ReLU}\!\left(W_0 F^{avg}_c\right) + W_1 \operatorname{ReLU}\!\left(W_0 F^{max}_c\right)\right)\]其中 $F^{avg}_c, F^{max}_c$ 分别是全局平均池化和最大池化结果。
空间注意力:
\[M_s(F) = \sigma\!\left(f^{7\times7}\!\left([\operatorname{AvgPool}(F);\, \operatorname{MaxPool}(F)]\right)\right)\]残差降噪(Residual Learning):
论文采用噪声估计而非直接映射——网络学的不是干净图像,而是噪声本身:
\[\hat{x} = y - \mathcal{F}_\theta(y)\]这让残差连接有了双重含义:ResNet 的梯度通路 + DnCNN 的噪声残差学习。
整体架构
输入(噪声CBCT) → [编码器] → [瓶颈] → [解码器] → 输出(干净图像)
↕ skip connections(保留空间细节)
每个块内部:Conv → BN → ReLU → Conv → BN → HybridAttn → +残差
动手实现
核心模块:混合注意力(CBAM)
import torch
import torch.nn as nn
import torch.nn.functional as F
class ChannelAttention(nn.Module):
"""通道注意力:学习"哪些特征图"重要"""
def __init__(self, in_channels, reduction=16):
super().__init__()
mid = max(in_channels // reduction, 4)
self.shared_mlp = nn.Sequential(
nn.Linear(in_channels, mid),
nn.ReLU(inplace=True),
nn.Linear(mid, in_channels)
)
def forward(self, x):
B, C, H, W = x.shape
avg = x.mean(dim=[2, 3]) # (B, C) 全局平均
mx = x.amax(dim=[2, 3]) # (B, C) 全局最大
# 两路共享 MLP,相加后激活
attn = torch.sigmoid(self.shared_mlp(avg) + self.shared_mlp(mx))
return x * attn.view(B, C, 1, 1) # 广播到空间维度
class SpatialAttention(nn.Module):
"""空间注意力:学习"哪些像素位置"重要"""
def __init__(self, kernel_size=7):
super().__init__()
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
def forward(self, x):
avg = x.mean(dim=1, keepdim=True) # (B,1,H,W)
mx = x.amax(dim=1, keepdim=True) # (B,1,H,W)
attn = torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))
return x * attn
class HybridAttention(nn.Module):
"""先通道后空间——顺序很重要"""
def __init__(self, in_channels, reduction=16):
super().__init__()
self.ca = ChannelAttention(in_channels, reduction)
self.sa = SpatialAttention()
def forward(self, x):
return self.sa(self.ca(x)) # 通道 → 空间
残差注意力块
class ResAttnBlock(nn.Module):
"""U-Net 每级的基础块:Conv-BN-ReLU × 2 + HybridAttn + 残差"""
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv_path = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
nn.BatchNorm2d(out_ch),
)
self.attn = HybridAttention(out_ch)
# 通道数不同时需要投影
self.proj = nn.Conv2d(in_ch, out_ch, 1, bias=False) if in_ch != out_ch else nn.Identity()
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
out = self.attn(self.conv_path(x))
return self.relu(out + self.proj(x)) # 残差相加
HARU-Net 完整骨架
class HARUNet(nn.Module):
def __init__(self, in_ch=1, out_ch=1, features=(64, 128, 256, 512)):
super().__init__()
# 编码器:逐级下采样
self.encoders = nn.ModuleList()
self.pools = nn.ModuleList()
ch = in_ch
for f in features:
self.encoders.append(ResAttnBlock(ch, f))
self.pools.append(nn.MaxPool2d(2))
ch = f
# 瓶颈
self.bottleneck = ResAttnBlock(features[-1], features[-1] * 2)
# 解码器:逐级上采样 + skip connection
self.upconvs = nn.ModuleList()
self.decoders = nn.ModuleList()
ch = features[-1] * 2
for f in reversed(features):
self.upconvs.append(nn.ConvTranspose2d(ch, f, 2, stride=2))
self.decoders.append(ResAttnBlock(f * 2, f)) # *2 因为 concat skip
ch = f
self.head = nn.Conv2d(features[0], out_ch, 1)
def forward(self, x):
skips, inp = [], x
for enc, pool in zip(self.encoders, self.pools):
inp = enc(inp); skips.append(inp); inp = pool(inp)
inp = self.bottleneck(inp)
for up, dec, skip in zip(self.upconvs, self.decoders, reversed(skips)):
inp = up(inp)
# 处理奇数尺寸的边界情况
if inp.shape != skip.shape:
inp = F.interpolate(inp, size=skip.shape[2:])
inp = dec(torch.cat([inp, skip], dim=1))
# 残差输出:预测噪声,干净 = 输入 - 噪声
return x - self.head(inp)
边缘感知损失函数
这是论文最关键的工程决策——只用 MSE 是不够的:
class EdgePreservingLoss(nn.Module):
def __init__(self, edge_weight=0.3):
super().__init__()
self.w = edge_weight
# Sobel 算子检测边缘,固定权重不参与训练
sobel = torch.tensor([[[-1,0,1],[-2,0,2],[-1,0,1]],
[[-1,-2,-1],[0,0,0],[1,2,1]]], dtype=torch.float32)
self.register_buffer('sobel', sobel.unsqueeze(1)) # (2,1,3,3)
def _edge_map(self, x):
# x: (B,1,H,W) → 计算梯度幅值
grad = F.conv2d(x, self.sobel, padding=1) # (B,2,H,W)
return grad.pow(2).sum(dim=1, keepdim=True).sqrt() # (B,1,H,W)
def forward(self, pred, target):
loss_pixel = F.l1_loss(pred, target) # L1 比 MSE 更保边缘
loss_edge = F.l1_loss(self._edge_map(pred),
self._edge_map(target))
return loss_pixel + self.w * loss_edge
训练配置
model = HARUNet(in_ch=1, out_ch=1).cuda()
criterion = EdgePreservingLoss(edge_weight=0.3)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
for epoch in range(100):
for noisy, clean in dataloader:
noisy, clean = noisy.cuda(), clean.cuda()
pred = model(noisy)
loss = criterion(pred, clean)
optimizer.zero_grad(); loss.backward(); optimizer.step()
scheduler.step()
实现中的坑
坑 1:CBCT 图像的 HU 值范围
CBCT 原始数据是 Hounsfield Unit(HU),范围 $[-1000, +3000]$。直接送入网络会导致梯度爆炸:
# 错误做法
img = dicom.pixel_array.astype(np.float32) # 原始 HU 值
# 正确做法:窗宽窗位归一化(牙科常用 W=4000, L=1000)
def normalize_cbct(img, window=4000, level=1000):
low, high = level - window/2, level + window/2
return np.clip((img - low) / window, 0.0, 1.0)
坑 2:注意力的 reduction ratio 在小图像上会崩
当 in_channels=64, reduction=16 时 mid=4,没问题。但如果误设 reduction=32,mid=2,MLP 表达力严重不足:
# HybridAttention 构造时加保护
mid = max(in_channels // reduction, 4) # 前面代码已经处理了这个
坑 3:跳跃连接的尺寸错位
CBCT 图像常见非 2 的幂次尺寸(如 $512 \times 492$)。下采样后上采样回来会差 1 个像素——已在 forward 中用 F.interpolate 修复,但要确认 align_corners=False(PyTorch 默认)。
坑 4:边缘损失权重不是越大越好
edge_weight > 0.5 时,网络倾向于”制造边缘”来降低损失,在平坦区域产生锐化伪影。建议从 0.1 开始网格搜索。
实验:论文说的 vs 现实
论文预期结果
在 CBCT 牙科数据集上,HARU-Net 应当在 PSNR/SSIM 上超过:
- BM3D:约 +2-3 dB PSNR
- 标准 U-Net:约 +1 dB PSNR
- 主观评价:齿根边界更清晰
复现时的现实条件
| 条件 | 影响 |
|---|---|
| 训练数据 < 200 对 | 注意力模块过拟合风险高,建议关掉 SpatialAttention 或加 Dropout |
| 金属植入物区域 | 模型往往在此失效,边缘损失反而学到了金属伪影轮廓 |
| 不同剂量设备间迁移 | 泛化性差,需要 domain adaptation 或 fine-tune |
| 3D CBCT vs 2D 切片 | 论文大概率用 2D 切片训练,3D 推理需要逐层或改 3D 卷积 |
什么时候用 / 不用这个方法?
| 适用场景 | 不适用场景 |
|---|---|
| 低剂量 CBCT 降噪(泊松噪声为主) | 金属伪影去除(需要专门的 MAR 方法) |
| 数据量 > 500 对,且有配对干净/噪声图 | 无配对数据(需换用自监督方法如 Noise2Void) |
| 边缘精度要求高(齿根、骨骼轮廓测量) | 实时推理(注意力计算有额外开销) |
| 单中心固定设备 | 多中心多设备部署(域偏移问题显著) |
我的观点
值得肯定的部分:CBAM 加入 U-Net 用于医学图像降噪,技术路线是合理的。边缘感知损失是必须的设计,不是可选项——任何面向诊断的降噪系统都应该包含某种边缘保护机制。
我的疑虑:
-
“Hybrid”的命名有点虚。CBAM 是 2018 年的工作,在这里被重新包装为”Hybrid Attention”。真正的创新点需要看论文是否在注意力结构上有实质改动(比如专门设计的边缘引导注意力),还是直接套用 CBAM。
-
配对数据的获取是真正的瓶颈。医院能提供的配对数据(同一患者低剂量+高剂量扫描)极为稀少。论文可能使用合成噪声(往干净图上加泊松噪声),但真实低剂量噪声的分布远比合成复杂。
-
与 Diffusion-based 方法的比较缺失。2024-2025 年,扩散模型在医学图像重建上的表现已经相当出色(如 DiffuseIR、score-based CT denoising)。HARU-Net 这类判别式方法的推理速度有优势,但效果上可能已被超越。
结论:如果你在做 CBCT/CT 降噪的工程落地,HARU-Net 的架构值得参考,特别是残差学习 + 混合注意力 + 边缘损失这个组合。但如果追求 SOTA,建议同时评估基于 score matching 的生成式方法。
注:本文基于论文摘要进行架构推断,完整实验细节(超参数、数据集划分、具体注意力变体)请参考原论文 arxiv: 2602.22544。
Comments