残差收缩模块
- IT业界
- 2025-09-15 02:24:02

1. 多尺度阈值生成
创新思路:融合不同尺度的统计信息(如平均池化+最大池化)生成更鲁棒的阈值。
class MultiScaleShrinkage(nn.Module): def __init__(self, channel, reduction=4): super().__init__() # 多尺度池化分支 self.avg_pool = nn.AdaptiveAvgPool1d(1) self.max_pool = nn.AdaptiveMaxPool1d(1) # 双分支融合 self.fc = nn.Sequential( nn.Linear(channel*2, channel//reduction), nn.ReLU(), nn.Linear(channel//reduction, channel), nn.Sigmoid() ) def forward(self, x): residual = x x_abs = torch.abs(x) # 双分支池化 avg = self.avg_pool(x_abs).squeeze(-1) max_ = self.max_pool(x_abs).squeeze(-1) combined = torch.cat([avg, max_], dim=-1) # (B, 2C) threshold = self.fc(combined).unsqueeze(-1) # 后续软阈值处理相同 2. 空间-通道协同阈值化创新思路:引入空间注意力机制,实现通道与空间联合自适应。
class SpatioChannelShrinkage(nn.Module): def __init__(self, channel, reduction=4): super().__init__() # 通道分支 self.channel_fc = nn.Sequential( nn.Linear(channel, channel//reduction), nn.ReLU(), nn.Linear(channel//reduction, channel), nn.Sigmoid() ) # 空间分支(1D卷积) self.spatial_conv = nn.Sequential( nn.Conv1d(channel, 1, kernel_size=3, padding=1), nn.Sigmoid() ) def forward(self, x): residual = x x_abs = torch.abs(x) # 通道阈值 channel_avg = x_abs.mean(-1) # (B,C) channel_th = self.channel_fc(channel_avg).unsqueeze(-1) # (B,C,1) # 空间阈值 spatial_th = self.spatial_conv(x_abs) # (B,1,L) # 联合阈值 combined_th = channel_th * spatial_th # (B,C,L) # 动态软阈值 sub = x_abs - combined_th return torch.sign(residual) * torch.clamp_min(sub, 0) 3. 可微硬阈值化创新思路:通过自适应选择软/硬阈值化,增强特征选择性。
class AdaptiveThreshold(nn.Module): def __init__(self, channel): super().__init__() # 可学习阈值比例系数 self.alpha = nn.Parameter(torch.randn(1, channel, 1)) def forward(self, x): x_abs = torch.abs(x) threshold = self.alpha * x_abs.mean(-1, keepdim=True) # 硬阈值直通式梯度 mask = (x_abs > threshold).float() return x * mask 4. 轻量化动态卷积阈值创新思路:用深度可分离卷积替代全连接层,减少参数量。
class LightShrinkage(nn.Module): def __init__(self, channel): super().__init__() # 深度可分离卷积 self.dw_conv = nn.Sequential( nn.Conv1d(channel, channel, kernel_size=3, padding=1, groups=channel), nn.AdaptiveAvgPool1d(1), nn.Conv1d(channel, channel, kernel_size=1), nn.Sigmoid() ) def forward(self, x): residual = x x_abs = torch.abs(x) # 通过卷积提取局部模式 threshold = self.dw_conv(x_abs) sub = x_abs - threshold return torch.sign(residual) * torch.relu(sub) 5. 残差收缩增强创新思路:引入残差连接避免信息丢失,增强梯度流动。
class ResidualShrinkage(nn.Module): def __init__(self, channel): super().__init__() self.shrink = Shrinkage(channel) # 原收缩模块 def forward(self, x): return x + self.shrink(x) # 残差连接 创新方向总结 方向关键改进适用场景多尺度统计融合平均/最大池化高噪声数据空间-通道协同1D卷积+通道注意力需要局部上下文的任务软硬阈值结合可学习阈值类型需精确特征选择的场景轻量化设计深度可分离卷积移动端/实时处理残差增强收缩结果与原始输入相加深层网络训练稳定性建议通过消融实验验证不同改进方案的有效性,根据具体任务选择最佳组合。例如,对于高噪声时序信号处理,多尺度+空间通道协同的方案可能更有效;而对于计算资源受限的场景,轻量化设计更为合适。
 
               
               
               
               
               
               
               
   
   
   
   
   
   
   
   
   
   
   
  