PyTorch大白话解释算子二
- 开源代码
- 2025-09-13 22:57:03

目录
一、reshape
1. 什么是 Reshape?
2. Reshape 的核心作用
① 数据适配
② 维度对齐
③ 特征重组
3. Reshape 的数学表示
4.代码示例
5.permute()方法
6.view()方法
7. Reshape 的注意事项
① 数据连续性
② 维度顺序
③ 性能优化
二、sequenze 和 unequenze
1. Sequence(序列)是什么?
2. Unsequence(非序列)是什么?
1. 填充(Padding)
三、concat、stack、expand 和 flatten
1. concat(拼接)
功能
关键点
2. stack(堆叠)
功能
关键点
3. expand(扩展)
功能
关键点
4. flatten(展平)
功能
关键点
5. 对比总结
四、pointwise
五、split和slice
一、reshape
返回一个具有与输入相同的数据和元素数量,但具有指定形状的张量。如果可能的话,返回的张量将是输入的视图。否则,它将是一个副本。连续的输入和具有兼容步幅的输入可以进行重塑而无需复制,但您不应依赖于复制与视图行为。
1. 什么是 Reshape? 核心功能:改变张量的维度(形状),但不改变其元素内容和存储顺序。数学本质:通过重新排列索引,将原张量映射到新的形状空间。 2. Reshape 的核心作用 ① 数据适配 将数据转换为模型输入要求的形状(如 [batch_size, channels, height, width])。示例:将 [100, 784](MNIST 图像展平)转换为 [100, 1, 28, 28]。 ② 维度对齐 在矩阵乘法、卷积等操作中,确保输入张量的维度匹配。示例:将 [3, 5, 5] 转换为 [3, 1, 5, 5] 以适配卷积层。 ③ 特征重组 提取特定维度的特征(如将 [batch, height, width, channels] 转换为 [batch, channels, height, width])。 3. Reshape 的数学表示 输入形状:(N, C_in, H_in, W_in)输出形状:(N, C_out, H_out, W_out)关键约束:N×Cin×Hin×Win=N×Cout×Hout×Wout即总元素数量必须保持不变。 4.代码示例 # 输入:[2, 3, 5, 5] x = torch.randn(2, 3, 5, 5) # Reshape to [2, 15, 5] y = x.reshape(2, -1, 5) print(y.shape) # torch.Size([2, 15, 5]) 5.permute()方法 功能:重新排列张量的轴顺序(不改变元素值)。 # 输入:[batch=2, channels=3, height=5, width=5] x = torch.randn(2, 3, 5, 5) # 将 channels 和 height 交换 y = x.permute(0, 2, 1, 3) # 输出形状:[2, 5, 3, 5] print(y.shape) 6.view()方法 功能:返回一个与原张量共享内存的新视图(需数据连续)。 import torch # 输入:[batch=2, channels=3, height=5, width=5] x = torch.randn(2, 3, 5, 5) # Reshape to [2, 15, 5](3 * 5=15) y = x.view(2, -1, 5) # -1 表示自动计算剩余维度 print(y.shape) # torch.Size([2, 15, 5]) 7. Reshape 的注意事项 ① 数据连续性 view():要求原张量数据连续,否则会报错。reshape():允许非连续数据,但会复制内存,可能影响性能。 ② 维度顺序 使用 permute() 时需明确指定轴顺序,避免逻辑错误。 ③ 性能优化 尽量使用 view() 而非 reshape() 以复用内存。 二、sequenze 和 unequenze 1. Sequence(序列)是什么? 定义:按顺序排列的数据,每个元素之间存在时间或逻辑上的依赖关系。常见场景: 自然语言处理(NLP):句子、单词序列。时间序列分析:股票价格、传感器数据。语音识别:音频信号帧序列。 数学形式:X=[x1,x2,...,xT],其中 T 是序列长度。 2. Unsequence(非序列)是什么? 定义:无顺序依赖的数据,元素之间是独立或空间相关的。常见场景: 图像分类:二维像素矩阵。无监督聚类:客户分群、文档分类。图神经网络(GNN):节点间无固定顺序的图结构。3. 序列 vs 非序列的核心差异
维度序列非序列数据依赖时间/逻辑顺序敏感无顺序依赖典型任务文本生成、语音识别、时间序列预测图像分类、目标检测、聚类常用模型RNN、LSTM、TransformerCNN、GCN、全连接层输入形状[batch, T, ...](T为序列长度)[batch, C, H, W](C为通道数)4. 序列数据的处理方法
1. 填充(Padding) 目的:将不同长度的序列统一到相同长度。 import torch.nn.utils.rnn as rnn_utils # 输入序列:batch=2, 最大长度=5 sequences = [ torch.randn(3), # 序列1(长度3) torch.randn(5) # 序列2(长度5) ] # 填充到长度5,用0填充 padded = rnn_utils.pad_sequence(sequences, batch_first=True) print(padded.shape) # torch.Size([2, 5, ...])2.打包(Packing)
目的:仅保留有效数据,忽略填充部分,提升计算效率。 # 输入序列和长度掩码 lengths = [3, 5] packed = rnn_utils.pack_padded_sequence(sequences, lengths, batch_first=True) # 解包输出 output, output_lengths = rnn_utils.unpack_packed_sequence(packed) 三、concat、stack、expand 和 flatten 1. concat(拼接) 功能沿指定维度将多个张量连接成一个更大的张量,不改变原有维度。
import torch # 定义两个二维张量 a = torch.tensor([[1, 2], [3, 4]]) b = torch.tensor([[5, 6], [7, 8]]) # 沿第0维(行方向)拼接 concatenated = torch.cat([a, b], dim=0) print(concatenated) # 输出: # tensor([[1, 2], # [3, 4], # [5, 6], # [7, 8]]) 关键点 输入张量的其他维度必须一致。结果形状:(N+M, ...), 其中 N 和 M 是拼接张量的大小。 2. stack(堆叠) 功能沿新维度将多个张量堆叠成更高维度的张量,新增一个维度。
import torch a = torch.tensor([[1, 2], [3, 4]]) b = torch.tensor([[5, 6], [7, 8]]) # 沿新维度(第1维)堆叠 stacked = torch.stack([a, b], dim=1) print(stacked) # 输出: # tensor([[[1, 2], # [3, 4]], # [[5, 6], # [7, 8]]]) 关键点 所有输入张量的形状必须完全相同。结果形状:(K, ...,),其中 K 是堆叠的张量数量。 3. expand(扩展) 功能通过广播机制,将张量在指定维度上重复元素,不复制数据(仅创建视图)。
import torch # 原始张量:[1, 2] x = torch.tensor([1, 2]) # 在第0维扩展2倍,得到 [1, 2, 1, 2] expanded = x.expand(2, -1) print(expanded) # tensor([1, 2, 1, 2]) # 在第1维扩展3倍,得到 [[1,1,1], [2,2,2]] expanded_2d = x.unsqueeze(1).expand(-1, 3) print(expanded_2d) # tensor([[1, 1, 1], # [2, 2, 2]]) 关键点 expand 的参数需满足:new_dim_size >= original_dim_size。需先通过 unsqueeze 创建新维度才能扩展。 4. flatten(展平) 功能将多维张量压缩为一维或指定维度的连续数组,忽略其他维度。
import torch # 原始张量:[2, 3, 4] x = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) # 展平为一维 flattened = x.flatten() print(flattened) # 输出: # tensor([1, 2, 3, 4, 5, 6, 7, 8]) # 展平到指定维度(保留第0维,合并后两维) flattened_2d = x.flatten(start_dim=1) print(flattened_2d) # 输出: # tensor([[1, 2, 3, 4], # [5, 6, 7, 8]]) 关键点 start_dim 指定从哪个维度开始展平,默认为 0。展平后张量的总元素数不变。 5. 对比总结 操作核心功能是否改变维度内存消耗典型场景concat沿指定维度拼接张量否(保持原有维度)低(共享数据)数据合并(如特征拼接)stack新增维度堆叠张量是(维度+1)中(复制数据)多模型输出堆叠(如图像分割)expand广播机制扩展元素可能改变维度极低(仅视图)扩展特征图尺寸(如上采样)flatten压缩多维张量为低维是(降维)低(共享数据)全连接层输入适配 四、pointwiseTensor 中逐元素进行的操作,也叫element wise 操作,大部分的activation 算子以及 add、sub、mul、div、sqrt 等都属于pointwise 类别。操作和numpy数组差不多
五、split和slice将张量分割成多个块。每个块都是原始张量的视图。
import torch # 创建一个示例张量 tensor = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) # 对张量进行切片 slice_tensor = tensor[2:7] # 从索引2到索引6(不包含7) print(slice_tensor) # 输出: tensor([3, 4, 5, 6, 7]) # 使用步长对张量进行切片 step_slice_tensor = tensor[1:9:2] # 从索引1到索引8(不包含9),步长为2 print(step_slice_tensor) # 输出: tensor([2, 4, 6, 8]) # 省略起始索引和结束索引来选择整个张量 full_tensor = tensor[:] print(full_tensor) # 输出: tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])PyTorch大白话解释算子二由讯客互联开源代码栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“PyTorch大白话解释算子二”
上一篇
              C语言:51单片机程序设计基础
 
               
               
               
               
               
               
               
               
   
   
   
   
   
   
   
   
   
   
   
   
  