主页 > 手机  > 

Pytorch中的ebmedding到底怎么理解?

Pytorch中的ebmedding到底怎么理解?

在 PyTorch 中,nn.Embedding 是一个用于处理离散符号映射到连续向量空间的模块。它通常用于自然语言处理(NLP)任务(如词嵌入)、处理分类特征,或任何需要将离散索引转换为密集向量的场景。


核心理解

功能:

将离散的整数索引(例如单词的索引、类别ID)映射为固定维度的连续向量。这些向量是可学习的参数,在训练过程中通过反向传播优化。

参数:

num_embeddings:词汇表的大小(有多少个唯一的符号/类别)。embedding_dim:每个符号对应的向量维度。例如:nn.Embedding(1000, 128) 表示将 1000 个符号映射到 128 维的向量空间。

输入与输出:

输入:一个整数张量,形状为 (*)(可以是任意维度,通常是 [batch_size, sequence_length])。输出:形状为 (*, embedding_dim) 的张量。例如,输入形状为 [2, 3],输出为 [2, 3, 128]。
工作原理

内部权重矩阵:

nn.Embedding 内部维护一个形状为 (num_embeddings, embedding_dim) 的权重矩阵。当输入索引 i 时,输出是该矩阵的第 i 行(即 weight[i])。

类比 One-Hot + 全连接层:

可以理解为对输入进行 One-Hot 编码,然后通过一个 无偏置的全连接层。例如,输入 3 会转换为一个 One-Hot 向量 [0,0,0,1,0,...],再与权重矩阵相乘,直接取出第 3 行的向量。但实际实现是高效的直接索引查找,避免了显式的 One-Hot 计算。
使用示例 import torch import torch.nn as nn # 定义 Embedding 层:10 个符号,每个符号映射到 3 维向量 embedding = nn.Embedding(num_embeddings=10, embedding_dim=3) # 输入:形状为 [2, 4] 的整数张量(例如,两个样本,每个样本长度为4) input_indices = torch.LongTensor([[1,2,4,5], [4,3,2,9]]) # 输出:形状为 [2, 4, 3] output = embedding(input_indices) print(output)
关键特性

可学习的参数:

通过 embedding.weight 可以访问或修改权重矩阵(例如加载预训练词向量)。默认初始化:权重矩阵的值从正态分布 N(0,1) 中随机采样。

填充索引(Padding):

通过 padding_idx 参数指定填充位置的索引(例如 padding_idx=0),使该位置的向量在训练中不更新。

冻结权重:

通过 embedding.weight.requires_grad_(False) 可以冻结参数,使其不参与训练。
应用场景

词嵌入(Word Embedding):

vocab_size = 5000 # 词汇表大小 embedding_dim = 300 embedding_layer = nn.Embedding(vocab_size, embedding_dim)

类别特征嵌入:

处理分类特征时,将类别ID转换为向量(类似One-Hot的密集版本)。

推荐系统:

用户ID、物品ID的嵌入表示。
注意事项

输入范围:

输入的索引必须在 [0, num_embeddings-1] 范围内,否则会报错。

梯度传播:

只有实际被用到的索引对应的向量会更新梯度(未被使用的索引不影响模型参数)。

预训练初始化:

可以加载预训练的权重(如 Word2Vec、GloVe):embedding_layer.weight.data.copy_(torch.from_numpy(pretrained_matrix))
总结

nn.Embedding 是 PyTorch 中实现嵌入操作的核心模块,它将离散符号映射到连续的语义空间,是处理符号数据的基础工具。通过训练,模型可以自动学习符号之间的语义关系(例如相似性)。

标签:

Pytorch中的ebmedding到底怎么理解?由讯客互联手机栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“Pytorch中的ebmedding到底怎么理解?