主页 > 软件开发  > 

DeepseekNativelySparseAttention

DeepseekNativelySparseAttention
NSA(Natively Sparse Attention)论文原理解析

论文标题: Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention 作者团队: DeepSeek-AI, Peking University, University of Washington 核心目标: 提出一种高效、可训练的稀疏注意力机制,以提高长文本处理的计算效率,同时保持模型性能。


1. NSA 研究背景 1.1 长文本建模的挑战 现代大模型(如 GPT-4, Gemini 1.5)需要处理超长文本(64k 甚至更长)。传统 全注意力(Full Attention) 计算复杂度为 O(N²),在长文本上计算开销巨大,导致 训练和推理效率低下。现有的稀疏注意力(Sparse Attention)方法在 训练阶段支持较弱,通常只优化推理阶段。 1.2 现有稀疏注意力方法的局限 理论计算减少 ≠ 真实速度提升: 许多方法仅在 推理阶段(Inference Stage) 优化,忽略了 训练时间的计算成本。例如 H2O、Quest 关注 KV 缓存剪枝(KV-cache pruning),但 在真实硬件上加速有限。 训练阶段支持不足: 许多方法采用 离散选择(Discrete Selection),导致 梯度无法回传,难以进行端到端训练(End-to-End Training)。
2. NSA 方法:基于层次化的稀疏注意力

NSA 提出的创新点:

层次化稀疏策略(Hierarchical Sparse Strategy) 结合 粗粒度 token 压缩(Compression) 和 细粒度 token 选择(Selection),同时保留 全局信息 和 局部精度。 硬件优化(Hardware-Aligned System) 设计 适用于现代 GPU(如 A100, H100)的优化算子,提升推理效率。 可训练性增强(Natively Trainable Design) 允许在 训练阶段 进行稀疏优化,而不仅仅是在推理阶段加速。 2.1 NSA 关键机制

NSA 通过 三种注意力路径 进行计算:

压缩注意力(Compressed Attention) 通过块级 Token 压缩(Blockwise Token Compression),减少计算开销。 选择性注意力(Selected Attention) 仅保留 Top-k 重要 token,忽略不重要的 Token,提高计算效率。 滑动窗口注意力(Sliding Attention) 确保局部上下文不会丢失,提高信息完整性。 NSA 计算过程 查询(Query) 经过 三种注意力路径 计算 注意力得分(Attention Score)。不同路径的注意力结果通过门控机制(Gating Mechanism)进行加权融合。最终得到优化后的注意力输出(Sparse Attention Output)。
3. NSA 在硬件上的优化 3.1 计算强度均衡(Arithmetic Intensity Balance) 在现代 GPU 上,计算强度(Arithmetic Intensity)决定了性能瓶颈: 高计算强度(Compute-Bound):计算单元占用率高,计算能力未完全发挥。低计算强度(Memory-Bound):计算单元空闲,受限于显存访问速度。 NSA 通过 块级计算(Blockwise Computation) 提高 计算密度(Compute Density),减少显存访问瓶颈。 3.2 Triton 自定义内核(Triton Kernel Optimization) 传统注意力计算 内存访问不连续,GPU 计算利用率低。NSA 通过 基于 Triton 的自定义 GPU 内核(Custom GPU Kernel for Sparse Selection): 组级数据加载(Group-Centric Data Loading):避免多次访问 KV 缓存,减少内存带宽压力。共享 KV 读取(Shared KV Fetching):减少重复数据加载,提高计算效率。
4. NSA 在实验中的表现 4.1 计算加速 相比全注意力(Full Attention),NSA 在 64k 序列上的速度提升最高可达 11.6×。在训练阶段,NSA 前向传播(Forward)速度提高 9.0×,反向传播(Backward)速度提高 6.0×。 4.2 模型性能 在多个 自然语言任务(NLP Benchmarks) 上,NSA 在 保持甚至超过全注意力性能 的同时,大幅提高计算效率。在 64k 长文本任务(LongBench Benchmark)中,NSA 超过所有现有稀疏注意力方法。 4.3 复杂推理能力 NSA 在 数学推理任务(AIME 24 Benchmark) 中表现出色: 在 8k 和 16k 上下文长度下,NSA 比全注意力基线提高 2.5× 和 1.6×。
5. NSA 的关键优势 特点NSA 贡献计算复杂度降低通过 层次化稀疏选择,将 O(N²) 降至 O(N log K)。硬件优化适配 GPU Tensor Cores,优化内存访问,提高计算效率。训练支持NSA 可训练(Natively Trainable),不同于只优化推理的稀疏方法。长文本处理能力在 64k 长文本任务上超越全注意力,同时加速 推理和训练。
6. 论文总结

NSA 通过 层次化稀疏注意力、硬件优化、训练可行性,在 计算加速和性能保持之间取得了平衡。 相较于现有方法,NSA 不仅优化了推理(Inference),还显著降低了训练(Training)计算成本,为长文本建模提供了新的解决方案。


压缩注意力(Compressed Attention)机制解析

目标:

在保持全局信息的同时 降低计算复杂度,减少 Query-Key 计算量。通过 块级(blockwise)token 聚合,减少注意力计算中需要处理的 Key-Value 数量。
1. 为什么需要压缩注意力? 标准注意力机制:每个 Query q q q 需要计算所有 Key K K K 的注意力分数,计算复杂度为 O ( N 2 ) O(N^2) O(N2)。稀疏注意力(Sparse Attention):可以减少部分 Query-Key 计算,但仍然面临计算量和显存占用的问题。压缩注意力(Compressed Attention) 通过 对 Key-Value 进行块级聚合,减少 Key-Value 数量,降低计算复杂度。
2. 压缩注意力的具体方法

NSA 采用 块级 token 聚合 的方式,将 Key-Value 压缩成更少的代表性 token。 这一过程可以分为 四步:

2.1. 按块划分 Key-Value 设 输入序列长度为 T T T,Key-Value 维度为 d k d_k dk​(Key 维度)和 d v d_v dv​(Value 维度)。选择 块大小(block size) l l l,把 Key-Value 分成多个块: 第 i i i 块的 Key 表示为:

K i = { k i ⋅ l , k i ⋅ l + 1 , … , k ( i + 1 ) ⋅ l − 1 } K_i = \{ k_{i \cdot l}, k_{i \cdot l+1}, \dots, k_{(i+1) \cdot l - 1} \} Ki​={ki⋅l​,ki⋅l+1​,…,k(i+1)⋅l−1​} - 第 i i i 块的 Value 表示为:

V i = { v i ⋅ l , v i ⋅ l + 1 , … , v ( i + 1 ) ⋅ l − 1 } V_i = \{ v_{i \cdot l}, v_{i \cdot l+1}, \dots, v_{(i+1) \cdot l - 1} \} Vi​={vi⋅l​,vi⋅l+1​,…,v(i+1)⋅l−1​} - 这样,原始 Key-Value 变成了 块级 Key-Value,大幅减少了 Key 的数量。

2.2. 计算块级 Key 的代表性 块级 Key K cmp K_{\text{cmp}} Kcmp​ 需要能够代表整个块的信息,可以用 平均池化(Mean Pooling) 或 可训练 MLP: 平均池化(Mean Pooling):

K cmp , i = 1 l ∑ j = 0 l − 1 K i ⋅ l + j K_{\text{cmp}, i} = \frac{1}{l} \sum_{j=0}^{l-1} K_{i \cdot l + j} Kcmp,i​=l1​j=0∑l−1​Ki⋅l+j​ - 可训练 MLP(Multi-Layer Perceptron):

K cmp , i = MLP ( K i ⋅ l : ( i + 1 ) ⋅ l ) K_{\text{cmp}, i} = \text{MLP}(K_{i \cdot l : (i+1) \cdot l}) Kcmp,i​=MLP(Ki⋅l:(i+1)⋅l​) - 其中 MLP 可以学习更丰富的特征,而平均池化计算量更低。

2.3. 计算块级 Value 块级 Value V cmp V_{\text{cmp}} Vcmp​ 也可以采用类似方法: 平均池化:

V cmp , i = 1 l ∑ j = 0 l − 1 V i ⋅ l + j V_{\text{cmp}, i} = \frac{1}{l} \sum_{j=0}^{l-1} V_{i \cdot l + j} Vcmp,i​=l1​j=0∑l−1​Vi⋅l+j​ - 或使用 MLP:

V cmp , i = MLP ( V i ⋅ l : ( i + 1 ) ⋅ l ) V_{\text{cmp}, i} = \text{MLP}(V_{i \cdot l : (i+1) \cdot l}) Vcmp,i​=MLP(Vi⋅l:(i+1)⋅l​) - 这样可以降低计算量,同时保留重要信息。

2.4. 使用压缩 Key-Value 计算注意力 计算 Query Q Q Q 和压缩后的 Key K cmp K_{\text{cmp}} Kcmp​ 之间的注意力:

A cmp = Q K cmp T d k A_{\text{cmp}} = \frac{Q K_{\text{cmp}}^T}{\sqrt{d_k}} Acmp​=dk​ ​QKcmpT​​

计算 Softmax:

A cmp ′ = Softmax ( A cmp ) A'_{\text{cmp}} = \text{Softmax}(A_{\text{cmp}}) Acmp′​=Softmax(Acmp​)

计算最终的注意力输出:

O cmp = A cmp ′ V cmp O_{\text{cmp}} = A'_{\text{cmp}} V_{\text{cmp}} Ocmp​=Acmp′​Vcmp​


3. 压缩注意力的优势 对比项普通注意力稀疏注意力(Sparse Attention)压缩注意力(Compressed Attention)计算复杂度 O ( N 2 ) O(N^2) O(N2) O ( N log ⁡ k ) O(N \log k) O(Nlogk) O ( N ⋅ M ) O(N \cdot M) O(N⋅M)(( M \ll N \))信息保留完整信息仅保留 Top-k 信息保留全局信息,同时减少计算量适用场景短文本长文本,但计算仍然较大适合超长文本(64k+),计算高效 相比全注意力(Full Attention),压缩注意力减少了计算量。相比其他稀疏注意力方法,压缩注意力能保留更多全局信息,同时具有更好的计算效率。
4. 代码示例

这里是一个 PyTorch 实现的 压缩注意力:

import torch import torch.nn as nn class CompressedAttention(nn.Module): def __init__(self, embed_dim, num_heads, block_size=32): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.block_size = block_size self.head_dim = embed_dim // num_heads self.q_proj = nn.Linear(embed_dim, embed_dim) self.k_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) def forward(self, query, key, value): B, T, C = query.size() # Batch, Sequence Length, Embedding Dimension # Projection Q = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) # Blockwise compression (mean pooling) num_blocks = T // self.block_size K_cmp = K.view(B, self.num_heads, num_blocks, self.block_size, self.head_dim).mean(dim=3) V_cmp = V.view(B, self.num_heads, num_blocks, self.block_size, self.head_dim).mean(dim=3) # Compute attention with compressed keys attn_weights = torch.matmul(Q, K_cmp.transpose(-2, -1)) / (self.head_dim ** 0.5) attn_weights = torch.softmax(attn_weights, dim=-1) attn_output = torch.matmul(attn_weights, V_cmp) # Reshape and output attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C) return self.out_proj(attn_output) # 示例调用 B, T, C = 2, 64, 128 # Batch size, Sequence length, Embedding dimension num_heads = 8 block_size = 16 attention = CompressedAttention(C, num_heads, block_size) query = torch.randn(B, T, C) key = torch.randn(B, T, C) value = torch.randn(B, T, C) output = attention(query, key, value) print(output.shape) # (B, T, C)
5. 总结 压缩注意力(Compressed Attention) 通过 块级聚合 Key-Value,大幅降低计算量,同时保留全局信息。计算复杂度降低: O ( N 2 ) → O ( N ⋅ M ) O(N^2) \to O(N \cdot M) O(N2)→O(N⋅M),其中 M ≪ N M \ll N M≪N(压缩后的块数)。适用于超长文本建模,在 64k 甚至更长的序列 上能够高效工作。硬件友好,支持 GPU Tensor Core 优化,减少显存占用。 选择性注意力(Selected Attention)机制解析

目标:

选择 最重要的 Key-Value 进行计算,而不是对所有 Key 计算注意力,从而降低计算复杂度。通过 Top-K 选择策略,保留最关键的信息,减少冗余计算,提高长序列建模能力。
1. 为什么需要选择性注意力? 普通注意力(Full Attention) 计算复杂度 O(N²),当序列长度很长(如 64k+),计算量巨大。压缩注意力(Compressed Attention) 通过 块级聚合 降低计算量,但可能损失部分细节信息。选择性注意力(Selected Attention) 进一步优化,只保留最重要的 Token 参与计算,避免处理不重要的信息,减少计算开销,同时保持全局和局部信息。
2. 选择性注意力的核心步骤

NSA 采用 基于注意力得分的动态 Top-K 选择(Top-K Token Selection) 方法来筛选关键 Token:

2.1. 计算 Query-Key 相关性

首先,计算 查询(Query) 和 所有键(Key) 的相似性(即注意力分数):

A = Q K T d k A = \frac{Q K^T}{\sqrt{d_k}} A=dk​ ​QKT​

其中:

A A A 是注意力分数矩阵,形状为 ( B , H , T , T ) (B, H, T, T) (B,H,T,T),表示每个 Query 对应 Key 的注意力得分。 2.2. 选择 Top-K 重要 Token 对于每个 Query,选择 Top-K 重要的 Key,其余的 Key 设为 − ∞ -\infty −∞(即被 Mask)。具体实现: 计算每个 Query 对所有 Key 的注意力分数。使用 Top-K 算法 找出最大的 K K K 个值,索引存入 I top-k I_{\text{top-k}} Itop-k​:

I top-k = argtopk ( A , K ) I_{\text{top-k}} = \text{argtopk}(A, K) Itop-k​=argtopk(A,K) - 构造稀疏化的注意力分数矩阵:

A i j ′ = { A i j , j ∈ I top-k ( i ) − ∞ , 否则 A'_{ij} = \begin{cases} A_{ij}, & j \in I_{\text{top-k}}(i) \\ -\infty, & \text{否则} \end{cases} Aij′​={Aij​,−∞,​j∈Itop-k​(i)否则​ - 这样,我们 只在最重要的 Top-K Token 上计算 Softmax:

A ~ = Softmax ( A ′ ) \tilde{A} = \text{Softmax}(A') A~=Softmax(A′)

2.3. 计算注意力输出

最终,用选择的 Top-K 注意力分数 计算新的 Value 权重求和:

O = A ~ V O = \tilde{A} V O=A~V

这样,Query 只会与 最相关的 Key-Value 交互,提高计算效率,同时保留重要信息。


3. 选择性注意力的优势 方法计算复杂度信息保留能力适用场景全注意力(Full Attention) O ( N 2 ) O(N^2) O(N2)完整适用于短文本压缩注意力(Compressed Attention) O ( N ⋅ M ) O(N \cdot M) O(N⋅M)保留全局信息适用于长文本选择性注意力(Selected Attention) O ( N ⋅ K ) O(N \cdot K) O(N⋅K)只保留最重要信息适用于超长文本(64k+) 相比全注意力(Full Attention),选择性注意力只计算 Top-K 重要信息,大幅降低计算量。相比压缩注意力(Compressed Attention),选择性注意力能保留 更精确的局部信息,保证高精度。
4. PyTorch 实现

以下是 选择性注意力 的 PyTorch 代码:

import torch import torch.nn as nn import torch.nn.functional as F class SelectedAttention(nn.Module): def __init__(self, embed_dim, num_heads, top_k): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.top_k = top_k self.head_dim = embed_dim // num_heads self.q_proj = nn.Linear(embed_dim, embed_dim) self.k_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) def forward(self, query, key, value): B, T, C = query.shape # Projection Q = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, T, d_k) K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) # Compute attention scores attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5) # (B, H, T, T) # Select Top-K tokens topk_values, topk_indices = torch.topk(attn_scores, self.top_k, dim=-1) # (B, H, T, K) # Create a mask for non-Top-K elements mask = torch.full_like(attn_scores, float('-inf')) # Default mask mask.scatter_(-1, topk_indices, topk_values) # Retain Top-K values # Apply softmax on selected tokens attn_weights = F.softmax(mask, dim=-1) # Compute attention output attn_output = torch.matmul(attn_weights, V) # (B, H, T, d_k) # Reshape and output attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C) return self.out_proj(attn_output) # 示例调用 B, T, C = 2, 64, 128 # Batch size, Sequence length, Embedding dimension num_heads = 8 top_k = 16 # 选择 Top-K 重要 Token attention = SelectedAttention(C, num_heads, top_k) query = torch.randn(B, T, C) key = torch.randn(B, T, C) value = torch.randn(B, T, C) output = attention(query, key, value) print(output.shape) # (B, T, C)
5. 选择性注意力的优化方向 Top-K 选择的优化 目前使用 torch.topk() 进行选择,时间复杂度为 O ( N log ⁡ K ) O(N \log K) O(NlogK)。可以优化为 Heap Sort + 近似选择算法,进一步提高效率。 自适应 K 值选择 目前的 K 值是固定的,可以使用 Learnable Gate 机制,让模型 动态决定 K 的大小。 结合其他稀疏注意力 压缩注意力 + 选择性注意力 可以同时 减少计算量 和 保留最关键信息,适合超长序列任务(64k+)。
6. 总结 选择性注意力(Selected Attention) 通过 Top-K 选择 只保留最重要的 Key-Value,降低计算量。计算复杂度从 O ( N 2 ) O(N^2) O(N2) 降低到 O ( N ⋅ K ) O(N \cdot K) O(N⋅K),适用于 超长文本(64k+)。相较于全注意力和压缩注意力,选择性注意力能更精准地保留信息,同时减少计算成本。可进一步优化 通过 更快的 Top-K 选择算法 或 自适应 K 值选择 提升性能。 滑动窗口注意力(Sliding Attention)机制解析

目标:

在减少计算量的同时,保留局部上下文信息,确保模型能够感知短期依赖关系。结合 压缩注意力(Compressed Attention) 和 选择性注意力(Selected Attention),在局部窗口范围内保留完整的注意力计算,避免远程信息丢失。
1. 为什么需要滑动窗口注意力? 全注意力(Full Attention) 计算复杂度 O(N²),长序列(64k+)下计算成本极高。压缩注意力(Compressed Attention) 关注全局信息,但可能会丢失局部细节。选择性注意力(Selected Attention) 关注最关键的信息,但可能无法保留局部语境。滑动窗口注意力(Sliding Attention) 通过局部窗口机制,确保模型可以关注最近的信息,同时减少计算量。
2. 滑动窗口注意力的核心步骤

NSA 采用 基于局部窗口的注意力计算(Local Context Attention),主要分为 四步:

2.1. 定义窗口范围 设序列长度为 T T T,窗口大小设定为 W W W(window size),则对于每个 Query Q i Q_i Qi​,它只会计算:

K win , i = { k i − W , k i − W + 1 , … , k i } K_{\text{win}, i} = \{ k_{i-W}, k_{i-W+1}, \dots, k_i \} Kwin,i​={ki−W​,ki−W+1​,…,ki​}

V win , i = { v i − W , v i − W + 1 , … , v i } V_{\text{win}, i} = \{ v_{i-W}, v_{i-W+1}, \dots, v_i \} Vwin,i​={vi−W​,vi−W+1​,…,vi​} - 窗口只包含最近 W W W 个 Token,降低计算复杂度。 - 可变窗口机制:可根据任务需求设定不同的窗口大小(例如代码生成任务可能需要更大的窗口)。

2.2. 计算窗口内的 Query-Key 注意力 在窗口范围 W W W 内计算标准注意力:

A win , i = Q i K win , i T d k A_{\text{win}, i} = \frac{Q_i K_{\text{win}, i}^T}{\sqrt{d_k}} Awin,i​=dk​ ​Qi​Kwin,iT​​ - 相比于全局注意力(O(N²)),窗口内计算量为 O(N × W),显著降低复杂度。 - 仅关注 最近 W W W 个 Token,保证短期依赖关系。

2.3. 计算 Softmax 并加权求和 计算窗口内的注意力分布:

A win , i ′ = Softmax ( A win , i ) A'_{\text{win}, i} = \text{Softmax}(A_{\text{win}, i}) Awin,i′​=Softmax(Awin,i​)

计算最终的注意力输出:

O win , i = A win , i ′ V win , i O_{\text{win}, i} = A'_{\text{win}, i} V_{\text{win}, i} Owin,i​=Awin,i′​Vwin,i​

2.4. 结合其他注意力机制 最终输出:

O = g cmp O cmp + g sel O sel + g win O win O = g_{\text{cmp}} O_{\text{cmp}} + g_{\text{sel}} O_{\text{sel}} + g_{\text{win}} O_{\text{win}} O=gcmp​Ocmp​+gsel​Osel​+gwin​Owin​ - g cmp , g sel , g win g_{\text{cmp}}, g_{\text{sel}}, g_{\text{win}} gcmp​,gsel​,gwin​ 是可学习的门控参数(Gating Mechanism)。 - 这样可以在训练过程中,让模型学习最佳的注意力组合方式。


3. 滑动窗口注意力的优势 方法计算复杂度局部信息保留适用场景全注意力(Full Attention) O ( N 2 ) O(N^2) O(N2)✅ 完整适用于短文本压缩注意力(Compressed Attention) O ( N ⋅ M ) O(N \cdot M) O(N⋅M)⚠️ 可能丢失局部信息适用于长文本选择性注意力(Selected Attention) O ( N ⋅ K ) O(N \cdot K) O(N⋅K)⚠️ 仅保留关键 Token适用于超长文本滑动窗口注意力(Sliding Attention) O ( N ⋅ W ) O(N \cdot W) O(N⋅W)✅ 重点保留局部信息适用于超长文本(64k+) 相比全注意力(Full Attention),滑动窗口注意力 显著减少计算量。相比压缩注意力(Compressed Attention),滑动窗口注意力确保局部信息不会丢失。相比选择性注意力(Selected Attention),滑动窗口注意力不会忽略短期依赖。
4. PyTorch 实现

以下是 滑动窗口注意力(Sliding Attention) 的 PyTorch 代码:

import torch import torch.nn as nn import torch.nn.functional as F class SlidingWindowAttention(nn.Module): def __init__(self, embed_dim, num_heads, window_size): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.window_size = window_size self.head_dim = embed_dim // num_heads self.q_proj = nn.Linear(embed_dim, embed_dim) self.k_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) def forward(self, query, key, value): B, T, C = query.shape # Projection Q = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, T, d_k) K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) # Initialize attention scores (masked) attn_scores = torch.full((B, self.num_heads, T, T), float('-inf'), device=query.device) # Apply sliding window mask for i in range(T): start_idx = max(0, i - self.window_size) attn_scores[:, :, i, start_idx:i+1] = torch.matmul( Q[:, :, i:i+1, :], K[:, :, start_idx:i+1, :].transpose(-2, -1) ) / (self.head_dim ** 0.5) # Compute attention with masked softmax attn_weights = F.softmax(attn_scores, dim=-1) # Compute attention output attn_output = torch.matmul(attn_weights, V) # (B, H, T, d_k) # Reshape and output attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C) return self.out_proj(attn_output) # 示例调用 B, T, C = 2, 64, 128 # Batch size, Sequence length, Embedding dimension num_heads = 8 window_size = 16 # 滑动窗口大小 attention = SlidingWindowAttention(C, num_heads, window_size) query = torch.randn(B, T, C) key = torch.randn(B, T, C) value = torch.randn(B, T, C) output = attention(query, key, value) print(output.shape) # (B, T, C)
5. 进一步优化方向 动态窗口大小 当前的 窗口大小 W W W 是固定的,可以使用 自适应机制(Adaptive Window Size) 让模型学习最佳的窗口大小。 结合 FlashAttention 提高计算效率 目前的 滑动窗口计算仍然需要遍历 Query,可以优化成 块级计算(Blockwise Computation),提升 GPU 利用率。
6. 总结 滑动窗口注意力(Sliding Attention) 通过 局部窗口计算,减少计算量,同时保留最近的上下文信息。计算复杂度从 O ( N 2 ) O(N^2) O(N2) 降低到 O ( N ⋅ W ) O(N \cdot W) O(N⋅W),适用于 超长文本(64k+)。结合其他注意力机制(压缩 + 选择性 + 滑动窗口)可以 提高计算效率,同时保留全局 + 局部信息。 NSA 论文中如何结合三种注意力机制?

在 Natively Sparse Attention(NSA) 机制中,作者采用了一种 层次化稀疏注意力策略(Hierarchical Sparse Strategy),将 压缩注意力(Compressed Attention)、选择性注意力(Selected Attention)和滑动窗口注意力(Sliding Attention) 结合,以 同时保留全局信息、关键 Token 以及局部信息,提高计算效率并优化长序列建模。


1. NSA 采用的三条注意力路径

NSA 通过以下三种不同的注意力计算路径,让 Transformer 既能高效处理长序列,又不会丢失关键信息:

压缩注意力(Compressed Attention) 作用:全局信息提取方式:将 Key-Value 进行 块级压缩,生成粗粒度的全局 Token 表示。计算复杂度: O ( N ⋅ M ) O(N \cdot M) O(N⋅M)(其中 M ≪ N M \ll N M≪N)。 选择性注意力(Selected Attention) 作用:筛选最关键的 Token 进行计算方式:对所有 Query 计算注意力分数,并选择 Top-K 重要 Token,仅对这些 Key 计算注意力。计算复杂度: O ( N ⋅ K ) O(N \cdot K) O(N⋅K)(其中 K ≪ N K \ll N K≪N)。 滑动窗口注意力(Sliding Attention) 作用:局部上下文信息保留方式:每个 Query 仅在其 最近的 W W W 个 Token 内 计算注意力,保留短期依赖信息。计算复杂度: O ( N ⋅ W ) O(N \cdot W) O(N⋅W)(其中 W ≪ N W \ll N W≪N)。

最终的注意力输出是三种机制的加权和:

O = g cmp O cmp + g sel O sel + g win O win O = g_{\text{cmp}} O_{\text{cmp}} + g_{\text{sel}} O_{\text{sel}} + g_{\text{win}} O_{\text{win}} O=gcmp​Ocmp​+gsel​Osel​+gwin​Owin​

其中:

g cmp , g sel , g win g_{\text{cmp}}, g_{\text{sel}}, g_{\text{win}} gcmp​,gsel​,gwin​ 是 可学习的门控参数(Gating Mechanism),用于控制不同注意力机制的重要性。
2. NSA 具体如何组合这三种注意力? (1) 计算 Query-Key 相关性

首先,对 Query 计算三种不同 Key 形式的注意力分数:

压缩 Key( K cmp K_{\text{cmp}} Kcmp​):计算 Query 和 压缩后的 Key 的相关性:

A cmp = Q K cmp T d k A_{\text{cmp}} = \frac{Q K_{\text{cmp}}^T}{\sqrt{d_k}} Acmp​=dk​ ​QKcmpT​​ 2. 选择性 Key( K sel K_{\text{sel}} Ksel​):计算 Query 和 Top-K 选择的 Key 的相关性:

A sel = Q K sel T d k A_{\text{sel}} = \frac{Q K_{\text{sel}}^T}{\sqrt{d_k}} Asel​=dk​ ​QKselT​​ 3. 滑动窗口 Key( K win K_{\text{win}} Kwin​):计算 Query 在 局部窗口范围内 的注意力:

A win = Q K win T d k A_{\text{win}} = \frac{Q K_{\text{win}}^T}{\sqrt{d_k}} Awin​=dk​ ​QKwinT​​

(2) 计算 Softmax 归一化

对每个注意力分数进行 Softmax 计算:

A ~ cmp = Softmax ( A cmp ) \tilde{A}_{\text{cmp}} = \text{Softmax}(A_{\text{cmp}}) A~cmp​=Softmax(Acmp​)

A ~ sel = Softmax ( A sel ) \tilde{A}_{\text{sel}} = \text{Softmax}(A_{\text{sel}}) A~sel​=Softmax(Asel​)

A ~ win = Softmax ( A win ) \tilde{A}_{\text{win}} = \text{Softmax}(A_{\text{win}}) A~win​=Softmax(Awin​)

(3) 计算注意力输出

计算不同注意力的加权求和:

O cmp = A ~ cmp V cmp O_{\text{cmp}} = \tilde{A}_{\text{cmp}} V_{\text{cmp}} Ocmp​=A~cmp​Vcmp​

O sel = A ~ sel V sel O_{\text{sel}} = \tilde{A}_{\text{sel}} V_{\text{sel}} Osel​=A~sel​Vsel​

O win = A ~ win V win O_{\text{win}} = \tilde{A}_{\text{win}} V_{\text{win}} Owin​=A~win​Vwin​

(4) 加权融合不同注意力结果

最终的输出由三种注意力结果加权融合:

O = g cmp O cmp + g sel O sel + g win O win O = g_{\text{cmp}} O_{\text{cmp}} + g_{\text{sel}} O_{\text{sel}} + g_{\text{win}} O_{\text{win}} O=gcmp​Ocmp​+gsel​Osel​+gwin​Owin​

其中:

g cmp , g sel , g win g_{\text{cmp}}, g_{\text{sel}}, g_{\text{win}} gcmp​,gsel​,gwin​ 是 可学习的门控参数,通过 MLP 计算:

g = σ ( MLP ( X ) ) g = \sigma(\text{MLP}(X)) g=σ(MLP(X))

其中 σ \sigma σ 是 Sigmoid 激活函数,确保 g g g 取值在 (0,1) 之间。


3. PyTorch 实现

以下是 结合三种注意力的 NSA 模型:

import torch import torch.nn as nn import torch.nn.functional as F class NSA(nn.Module): def __init__(self, embed_dim, num_heads, top_k, window_size): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.top_k = top_k self.window_size = window_size self.head_dim = embed_dim // num_heads self.q_proj = nn.Linear(embed_dim, embed_dim) self.k_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) self.gate_mlp = nn.Sequential(nn.Linear(embed_dim, 3), nn.Sigmoid()) # 生成3个门控权重 def forward(self, query, key, value): B, T, C = query.shape # Projection Q = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) # 计算门控权重 gate_weights = self.gate_mlp(query).unsqueeze(-1).unsqueeze(-1) # (B, T, 3) -> (B, T, 3, 1, 1) # 压缩注意力 K_cmp = K.mean(dim=-2, keepdim=True) V_cmp = V.mean(dim=-2, keepdim=True) attn_cmp = torch.matmul(Q, K_cmp.transpose(-2, -1)) / (self.head_dim ** 0.5) O_cmp = torch.matmul(F.softmax(attn_cmp, dim=-1), V_cmp) # 选择性注意力 topk_values, topk_indices = torch.topk(attn_cmp, self.top_k, dim=-1) attn_sel = torch.zeros_like(attn_cmp).scatter_(-1, topk_indices, topk_values) O_sel = torch.matmul(F.softmax(attn_sel, dim=-1), V) # 滑动窗口注意力 attn_win = attn_cmp.masked_fill(torch.arange(T)[:, None] < (torch.arange(T) - self.window_size), float('-inf')) O_win = torch.matmul(F.softmax(attn_win, dim=-1), V) # 加权求和 O = gate_weights[..., 0] * O_cmp + gate_weights[..., 1] * O_sel + gate_weights[..., 2] * O_win O = O.transpose(1, 2).contiguous().view(B, T, C) return self.out_proj(O) # 测试 B, T, C = 2, 64, 128 attention = NSA(C, num_heads=8, top_k=16, window_size=16) query = torch.randn(B, T, C) key = torch.randn(B, T, C) value = torch.randn(B, T, C) output = attention(query, key, value) print(output.shape) # (B, T, C)
总结 NSA 通过三种注意力机制的组合,既保证全局信息,又保留关键 Token 和局部上下文信息。最终的注意力结果通过可学习的门控机制(Gating Mechanism)进行融合,实现动态调整。计算复杂度降低到 O ( N log ⁡ K ) O(N \log K) O(NlogK),适用于超长文本(64k+)。

代码是AI生成的 还在调试中

标签:

DeepseekNativelySparseAttention由讯客互联软件开发栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“DeepseekNativelySparseAttention