主页 > IT业界  > 

Transformer代码剖析7-词元嵌入(TokenEmbedding)(pytorch实现)

Transformer代码剖析7-词元嵌入(TokenEmbedding)(pytorch实现)
一、类定义与继承关系剖析 1.1 代码结构图示 #mermaid-svg-9COHbtmHJhpiroHM {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-9COHbtmHJhpiroHM .error-icon{fill:#552222;}#mermaid-svg-9COHbtmHJhpiroHM .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-9COHbtmHJhpiroHM .edge-thickness-normal{stroke-width:2px;}#mermaid-svg-9COHbtmHJhpiroHM .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-9COHbtmHJhpiroHM .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-9COHbtmHJhpiroHM .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-9COHbtmHJhpiroHM .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-9COHbtmHJhpiroHM .marker{fill:#333333;stroke:#333333;}#mermaid-svg-9COHbtmHJhpiroHM .marker.cross{stroke:#333333;}#mermaid-svg-9COHbtmHJhpiroHM svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-9COHbtmHJhpiroHM .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-9COHbtmHJhpiroHM .cluster-label text{fill:#333;}#mermaid-svg-9COHbtmHJhpiroHM .cluster-label span{color:#333;}#mermaid-svg-9COHbtmHJhpiroHM .label text,#mermaid-svg-9COHbtmHJhpiroHM span{fill:#333;color:#333;}#mermaid-svg-9COHbtmHJhpiroHM .node rect,#mermaid-svg-9COHbtmHJhpiroHM .node circle,#mermaid-svg-9COHbtmHJhpiroHM .node ellipse,#mermaid-svg-9COHbtmHJhpiroHM .node polygon,#mermaid-svg-9COHbtmHJhpiroHM .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-9COHbtmHJhpiroHM .node .label{text-align:center;}#mermaid-svg-9COHbtmHJhpiroHM .node.clickable{cursor:pointer;}#mermaid-svg-9COHbtmHJhpiroHM .arrowheadPath{fill:#333333;}#mermaid-svg-9COHbtmHJhpiroHM .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-9COHbtmHJhpiroHM .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-9COHbtmHJhpiroHM .edgeLabel{background-color:#e8e8e8;text-align:center;}#mermaid-svg-9COHbtmHJhpiroHM .edgeLabel rect{opacity:0.5;background-color:#e8e8e8;fill:#e8e8e8;}#mermaid-svg-9COHbtmHJhpiroHM .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-9COHbtmHJhpiroHM .cluster text{fill:#333;}#mermaid-svg-9COHbtmHJhpiroHM .cluster span{color:#333;}#mermaid-svg-9COHbtmHJhpiroHM div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-9COHbtmHJhpiroHM :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;}#mermaid-svg-9COHbtmHJhpiroHM .base>*{fill:#f0f0f0!important;stroke:#333!important;stroke-width:2px!important;}#mermaid-svg-9COHbtmHJhpiroHM .base span{fill:#f0f0f0!important;stroke:#333!important;stroke-width:2px!important;}#mermaid-svg-9COHbtmHJhpiroHM .inherit>*{fill:#e6f7ff!important;stroke:#3399ff!important;stroke-width:2px!important;}#mermaid-svg-9COHbtmHJhpiroHM .inherit span{fill:#e6f7ff!important;stroke:#3399ff!important;stroke-width:2px!important;}#mermaid-svg-9COHbtmHJhpiroHM .method>*{fill:#fff3e0!important;stroke:#ffa726!important;stroke-width:2px!important;}#mermaid-svg-9COHbtmHJhpiroHM .method span{fill:#fff3e0!important;stroke:#ffa726!important;stroke-width:2px!important;}#mermaid-svg-9COHbtmHJhpiroHM .param>*{fill:#d4edda!important;stroke:#28a745!important;stroke-width:2px!important;}#mermaid-svg-9COHbtmHJhpiroHM .param span{fill:#d4edda!important;stroke:#28a745!important;stroke-width:2px!important;} 神经网络基础模块 词嵌入基类 自定义词元嵌入 构造函数定义 基类初始化 词汇量参数 维度参数 填充标识参数 1.2 代码实现精讲 """ @author : Hyunwoong @when : 2019-10-22 @homepage : github /gusdnd852 """ from torch import nn class TokenEmbedding(nn.Embedding): """ 基于PyTorch实现的动态词元嵌入模块 实现词元索引到高维向量的可学习映射 核心功能:将离散的词元序列转换为连续的语义空间表示 """ def __init__(self, vocab_size, d_model): """ 词元嵌入构造器 :param vocab_size: 词表容量(不同词元的总数) :param d_model: 嵌入维度(与Transformer模型维度一致) 设计要点: - 继承nn.Embedding的矩阵运算特性 - 固化填充索引为可训练参数 - 保持维度与模型其他组件兼容 """ super(TokenEmbedding, self).__init__( vocab_size, # 嵌入数量 num_embeddings # 嵌入矩阵行数 = 词表大小 d_model, # 嵌入维度 embedding_dim # 嵌入矩阵列数 = 模型维度 padding_idx=1 # 填充符索引的特殊处理 ) 二、核心参数深度解读 2.1 参数矩阵可视化

假设词表容量vocab_size=10000,模型维度d_model=512时:

参数维度元素数量数学意义weight[10000,512]5,120,000可训练的嵌入查询矩阵padding_idxscalar1动态掩码位置标识 2.2 关键参数说明

1. vocab_size

控制嵌入矩阵的行维度决定模型可处理的词元种类上限典型值域:BERT系列(~30000),GPT系列(~50000)

2. d_model

控制嵌入向量的列维度与Transformer隐藏层维度严格对齐典型值域:512(原始论文)、768(BERT-base)、1024(大型模型)

3. padding_idx

实现动态序列掩码的关键参数索引位置对应的梯度会被自动抑制防止填充符影响模型语义理解 三、运算过程分步推演 3.1 前向传播示例

输入序列:[3, 28, 1, 0] (1为填充符)

运算步骤:

1. 建立索引映射:

[[3], → [[0.2, -0.5, ..., 1.2], # 索引3的嵌入 [28], → [0.7, 1.1, ..., -0.3], # 索引28的嵌入 [1], → [0.0, 0.0, ..., 0.0], # 填充符固定值 [0]] → [-0.9, 0.4, ..., 0.1]] # 索引0的嵌入

2. 矩阵缩放(后续处理):

embeddings * sqrt(d_model) # 维度对齐的数学技巧 3.2 梯度传播特性 可微分性: 整个映射过程保持梯度通路参数更新: 通过反向传播调整嵌入矩阵特殊处理: padding_idx位置梯度始终为0 四、设计哲学解析 4.1 继承关系价值 #mermaid-svg-yPOt5xNZ81fG2rSl {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-yPOt5xNZ81fG2rSl .error-icon{fill:#552222;}#mermaid-svg-yPOt5xNZ81fG2rSl .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-yPOt5xNZ81fG2rSl .edge-thickness-normal{stroke-width:2px;}#mermaid-svg-yPOt5xNZ81fG2rSl .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-yPOt5xNZ81fG2rSl .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-yPOt5xNZ81fG2rSl .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-yPOt5xNZ81fG2rSl .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-yPOt5xNZ81fG2rSl .marker{fill:#333333;stroke:#333333;}#mermaid-svg-yPOt5xNZ81fG2rSl .marker.cross{stroke:#333333;}#mermaid-svg-yPOt5xNZ81fG2rSl svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-yPOt5xNZ81fG2rSl .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-yPOt5xNZ81fG2rSl .cluster-label text{fill:#333;}#mermaid-svg-yPOt5xNZ81fG2rSl .cluster-label span{color:#333;}#mermaid-svg-yPOt5xNZ81fG2rSl .label text,#mermaid-svg-yPOt5xNZ81fG2rSl span{fill:#333;color:#333;}#mermaid-svg-yPOt5xNZ81fG2rSl .node rect,#mermaid-svg-yPOt5xNZ81fG2rSl .node circle,#mermaid-svg-yPOt5xNZ81fG2rSl .node ellipse,#mermaid-svg-yPOt5xNZ81fG2rSl .node polygon,#mermaid-svg-yPOt5xNZ81fG2rSl .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-yPOt5xNZ81fG2rSl .node .label{text-align:center;}#mermaid-svg-yPOt5xNZ81fG2rSl .node.clickable{cursor:pointer;}#mermaid-svg-yPOt5xNZ81fG2rSl .arrowheadPath{fill:#333333;}#mermaid-svg-yPOt5xNZ81fG2rSl .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-yPOt5xNZ81fG2rSl .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-yPOt5xNZ81fG2rSl .edgeLabel{background-color:#e8e8e8;text-align:center;}#mermaid-svg-yPOt5xNZ81fG2rSl .edgeLabel rect{opacity:0.5;background-color:#e8e8e8;fill:#e8e8e8;}#mermaid-svg-yPOt5xNZ81fG2rSl .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-yPOt5xNZ81fG2rSl .cluster text{fill:#333;}#mermaid-svg-yPOt5xNZ81fG2rSl .cluster span{color:#333;}#mermaid-svg-yPOt5xNZ81fG2rSl div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-yPOt5xNZ81fG2rSl :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;}#mermaid-svg-yPOt5xNZ81fG2rSl .custom>*{fill:#ffeb3b!important;stroke:#333!important;}#mermaid-svg-yPOt5xNZ81fG2rSl .custom span{fill:#ffeb3b!important;stroke:#333!important;}#mermaid-svg-yPOt5xNZ81fG2rSl .torch>*{fill:#e0f7fa!important;stroke:#333!important;}#mermaid-svg-yPOt5xNZ81fG2rSl .torch span{fill:#e0f7fa!important;stroke:#333!important;}#mermaid-svg-yPOt5xNZ81fG2rSl .base>*{fill:#f0f0f0!important;stroke:#333!important;}#mermaid-svg-yPOt5xNZ81fG2rSl .base span{fill:#f0f0f0!important;stroke:#333!important;} TokenEmbedding torch.nn.Embedding torch.nn.Module PyTorch基础设施

优势分析:

复用性:继承矩阵运算和参数管理功能扩展性:保留自定义前向传播的可能性兼容性:无缝对接PyTorch生态工具 4.2 工程实践建议

1. 初始化技巧:

默认采用均匀分布 U ( − 1 d m o d e l , 1 d m o d e l ) U(-\sqrt{\frac{1}{d_{model}}}, \sqrt{\frac{1}{d_{model}}}) U(−dmodel​1​ ​,dmodel​1​ ​)可扩展为Xavier/Kaiming初始化:# Xavier均匀初始化(默认) nn.init.xavier_uniform_(self.weight) # 特殊处理填充符 self.weight.data[1].zero_()

2. 维度对齐策略:

# 与位置编码相加前的缩放 embeddings = embeddings * math.sqrt(d_model)

3. 混合精度训练:

# 自动转换为半精度 with autocast(): embeddings = embedding_layer(input_ids)

4. 填充符处理机制:

训练阶段自动跳过无效位置的计算推理阶段维持序列形状一致性

5. 计算复杂度分析:

时间复杂度: O ( B ⋅ S ⋅ D ) O(B \cdot S \cdot D) O(B⋅S⋅D)空间复杂度: O ( V ⋅ D ) O(V \cdot D) O(V⋅D)

完整实现细节可参考PyTorch中sparse.py 模块解析的相关文章(嵌入(Embedding)基类代码解析)或PyTorch官方Embedding文档。

标签:

Transformer代码剖析7-词元嵌入(TokenEmbedding)(pytorch实现)由讯客互联IT业界栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“Transformer代码剖析7-词元嵌入(TokenEmbedding)(pytorch实现)