Files
python/Pytorch/nn/Functional/Functional.md
2025-09-09 15:56:55 +08:00

295 lines
14 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

`torch.nn.functional` 是 PyTorch 中提供神经网络相关功能的模块,包含了大量的函数,用于实现卷积、池化、激活函数、损失函数、归一化等操作。这些函数是 `torch.nn` 模块的函数式接口,通常用于定义神经网络的 `forward` 方法中,尤其是当不需要管理参数(如权重和偏置)时。以下是对 `torch.nn.functional` 模块中主要方法的全面讲解,基于 PyTorch 2.8 官方文档()以及相关信息。由于方法数量庞大,我将按类别组织并简要讲解每个方法的功能、输入输出和典型用途,尽量清晰简洁。如果需要更详细的讲解或代码示例,请告诉我![](https://docs.pytorch.org/docs/2.8/nn.functional.html)
---
### 1. 卷积函数 (Convolution Functions)
这些函数用于执行卷积操作,常用于处理图像、时间序列等数据。
- **`conv1d`**: 一维卷积,应用于序列数据(如时间序列)。
- **输入**: input (N, C_in, L_in), weight (C_out, C_in, kernel_size)
- **输出**: (N, C_out, L_out)
- **用途**: 特征提取,如音频信号处理。
- **参数**: stride, padding, dilation, groups 等控制卷积行为。
- **`conv2d`**: 二维卷积,适用于图像数据。
- **输入**: input (N, C_in, H, W), weight (C_out, C_in, kH, kW)
- **输出**: (N, C_out, H_out, W_out)
- **用途**: 图像特征提取,如边缘检测。
- **`conv3d`**: 三维卷积,适用于视频或体视显微镜数据。
- **输入**: input (N, C_in, D, H, W), weight (C_out, C_in, kD, kH, kW)
- **输出**: (N, C_out, D_out, H_out, W_out)
- **用途**: 视频处理或 3D 图像分析。
- **`conv_transpose1d` / `conv_transpose2d` / `conv_transpose3d`**: 转置卷积(有时称为“反卷积”),用于上采样。
- **输入**: 类似普通卷积,但 weight 维度相反。
- **输出**: 更大的特征图。
- **用途**: 生成模型、图像分割中的上采样。
- **`unfold`**: 从输入张量中提取滑动局部块。
- **输入**: input (N, C, *), kernel_size, stride, padding 等
- **输出**: (N, C * prod(kernel_size), L)
- **用途**: 实现自定义卷积或池化操作。
- **`fold`**: 将滑动局部块组合成大张量,与 `unfold` 相反。
- **输入**: input (N, C * prod(kernel_size), L), output_size, kernel_size
- **输出**: (N, C, output_size)
- **用途**: 重构特征图。
---
### 2. 池化函数 (Pooling Functions)
池化函数用于下采样,减少空间维度,增强模型的鲁棒性。
- **`avg_pool1d` / `avg_pool2d` / `avg_pool3d`**: 平均池化,计算区域内的平均值。
- **输入**: input (N, C, *), kernel_size, stride
- **输出**: 降维后的张量
- **用途**: 减少特征图尺寸,平滑特征。
- **`max_pool1d` / `max_pool2d` / `max_pool3d`**: 最大池化,提取区域内的最大值。
- **输入**: input (N, C, *), kernel_size, stride, return_indices可选
- **输出**: 降维后的张量,可选返回最大值索引
- **用途**: 提取显著特征,如边缘或纹理。
- **`max_unpool1d` / `max_unpool2d` / `max_unpool3d`**: 最大池化的逆操作,上采样。
- **输入**: input, indices来自 max_pooloutput_size
- **输出**: 恢复的张量
- **用途**: 解码器或生成模型中的上采样。
- **`lp_pool1d` / `lp_pool2d`**: Lp 范数池化,计算区域内的 p 次方均值。
- **输入**: input, norm_type, kernel_size
- **输出**: 降维后的张量
- **用途**: 更灵活的池化方式。
- **`adaptive_avg_pool1d` / `adaptive_avg_pool2d` / `adaptive_avg_pool3d`**: 自适应平均池化,输出固定大小。
- **输入**: input, output_size
- **输出**: 指定大小的张量
- **用途**: 适配不同输入尺寸的模型。
- **`adaptive_max_pool1d` / `adaptive_max_pool2d` / `adaptive_max_pool3d`**: 自适应最大池化。
- **输入/输出**: 类似自适应平均池化
- **用途**: 固定输出尺寸的特征提取。
---
### 3. 激活函数 (Non-linear Activations)
激活函数为神经网络引入非线性,增强表达能力。
- **`relu`**: ReLU 激活f(x) = max(0, x)。
- **输入**: 张量
- **输出**: 相同形状的张量
- **用途**: 常用激活函数,简单高效。
- **`relu6`**: ReLU 的变体,限制输出最大值为 6。
- **用途**: 移动设备模型,控制数值范围。
- **`elu`**: 指数线性单元f(x) = x if x > 0 else alpha * (exp(x) - 1)。
- **用途**: 缓解梯度消失问题。
- **`selu`**: 缩放指数线性单元,需配合特定初始化。
- **用途**: 自归一化网络。
- **`celu`**: 连续指数线性单元,平滑版 ELU。
- **用途**: 更平滑的非线性。
- **`leaky_relu`**: Leaky ReLUf(x) = x if x > 0 else negative_slope * x。
- **用途**: 允许负值梯度,缓解“死亡 ReLU”问题。
- **`prelu`**: 参数化的 ReLUnegative_slope 可学习。
- **用途**: 更灵活的激活。
- **`rrelu`**: 随机 ReLU负斜率在训练时随机。
- **用途**: 正则化,减少过拟合。
- **`glu`**: Gated Linear Unitf(x) = x[:half] * sigmoid(x[half:])。
- **用途**: 门控机制,语言模型常用。
- **`gelu`**: Gaussian Error Linear Unit近似 ReLU 和 Dropout 的组合。
- **用途**: Transformer 模型常用。
- **`sigmoid`**: Sigmoid 激活f(x) = 1 / (1 + exp(-x))。
- **用途**: 二分类输出。
- **`tanh`**: 双曲正切激活f(x) = tanh(x)。
- **用途**: 输出范围 [-1, 1] 的场景。
- **`softmax`**: Softmax 激活,归一化为概率分布。
- **输入**: 张量dim归一化维度
- **输出**: 概率分布
- **用途**: 多分类任务。
- **`log_softmax`**: Log-Softmax计算 softmax 的对数。
- **用途**: 数值稳定性,配合 NLLLoss 使用。
- **`softplus`**: f(x) = log(1 + exp(x)),平滑近似 ReLU。
- **用途**: 正输出场景。
- **`softsign`**: f(x) = x / (1 + |x|)。
- **用途**: 输出范围 [-1, 1],平滑激活。
- **`silu`**: SiLU (Sigmoid Linear Unit)f(x) = x * sigmoid(x)。
- **用途**: Transformer 模型。
- **`mish`**: f(x) = x * tanh(softplus(x))。
- **用途**: 现代网络的平滑激活。
- **`hardswish` / `hardsigmoid` / `hardtanh`**: 硬性激活函数,计算效率高。
- **用途**: 移动设备上的轻量模型。
- **`threshold`**: 阈值激活,小于阈值设为指定值。
- **用途**: 稀疏激活。
---
### 4. 归一化函数 (Normalization Functions)
归一化函数用于稳定训练,加速收敛。
- **`batch_norm`**: 批归一化,基于 mini-batch 统计归一化。
- **输入**: input, running_mean, running_var, weight, bias
- **输出**: 归一化后的张量
- **用途**: 加速训练,减少内部协变量偏移。
- **`instance_norm`**: 实例归一化,基于单个样本归一化。
- **用途**: 风格迁移、图像生成。
- **`layer_norm`**: 层归一化,基于特征维度归一化。
- **用途**: Transformer 和 RNN。
- **`group_norm`**: 组归一化,将通道分组归一化。
- **用途**: 小批量场景。
- **`local_response_norm`**: 局部响应归一化,基于邻域归一化。
- **用途**: 早期卷积网络(如 AlexNet
---
### 5. 线性函数 (Linear Functions)
- **`linear`**: 线性变换y = xW^T + b。
- **输入**: input, weight, bias
- **输出**: 变换后的张量
- **用途**: 全连接层。
- **`bilinear`**: 双线性变换y = x1^T W x2 + b。
- **用途**: 多输入特征交互。
---
### 6. 损失函数 (Loss Functions)
损失函数用于衡量模型预测与真实值之间的差异。
- **`mse_loss`**: 均方误差,适用于回归任务。
- **`l1_loss`**: L1 损失,绝对值误差。
- **`smooth_l1_loss`**: 平滑 L1 损失,结合 MSE 和 L1 的优点。
- **`kl_div`**: KL 散度,衡量分布差异。
- **`cross_entropy`**: 交叉熵损失,适用于分类任务。
- **`nll_loss`**: 负对数似然损失,常与 log_softmax 配合。
- **`hinge_loss`**: Hinge 损失,用于 SVM 式分类。
- **`margin_ranking_loss`**: 排序损失,用于排序任务。
- **`triplet_margin_loss`**: 三元组损失,用于嵌入学习。
- **`ctc_loss`**: 连接主义时序分类损失,用于序列任务。
- **`bce_loss` / `bce_with_logits_loss`**: 二元交叉熵,适用于二分类。
---
### 7. 注意力机制 (Attention Mechanisms)
- **`scaled_dot_product_attention`**: 缩放点积注意力Transformer 的核心组件。
- **输入**: query, key, value, attn_mask可选
- **输出**: 注意力加权后的值
- **用途**: Transformer 模型()。[](https://docs.pytorch.org/tutorials/deep-dive.html)
---
### 8. 其他函数
- **`dropout` / `dropout2d` / `dropout3d`**: Dropout 正则化,随机置零。
- **`embedding`**: 嵌入层,将索引映射为向量。
- **`one_hot`**: 独热编码,整数转独热向量。
- **`pad`**: 张量填充,用于调整尺寸。
- **`interpolate`**: 插值上采样或下采样。
- **`grid_sample`**: 网格采样,用于空间变换。
- **`affine_grid`**: 生成仿射变换的网格。
---
### 注意事项
1. **函数式 vs 模块化**:
- `torch.nn.functional` 提供无状态的函数式接口,适合在 `forward` 方法中调用。相比之下,`torch.nn` 的模块(如 `nn.Conv2d`)会管理参数(如权重和偏置),适合在 `__init__` 中定义()。[](https://discuss.pytorch.org/t/difference-of-methods-between-torch-nn-and-functional/1076)
- 例如,`F.relu` 不保存参数,而 `nn.ReLU` 是模块化的,性能上无明显差异()。[](https://discuss.pytorch.org/t/difference-of-methods-between-torch-nn-and-functional/1076)
2. **使用场景**:
- 卷积、池化、线性变换等常用于卷积神经网络CNN
- 注意力机制和归一化常用于 Transformer 模型。
- 损失函数根据任务类型(如分类、回归、排序)选择。
3. **文档更新**:
- 以上基于 PyTorch 2.8 文档(截至 2024 年 11 月。建议查看最新文档https://pytorch.org/docs/stable/nn.functional.html以获取新增功能或更新。[](https://docs.pytorch.org/docs/2.8/nn.functional.html)
---
### 示例代码
以下是一个简单的 CNN 前向传播示例,展示如何使用 `torch.nn.functional`
```python
import torch
import torch.nn.functional as F
# 输入张量 (batch_size=1, channels=1, height=28, width=28)
x = torch.randn(1, 1, 28, 28)
# 卷积核
weight = torch.randn(6, 1, 5, 5)
bias = torch.randn(6)
# 卷积 + ReLU + 最大池化
x = F.conv2d(x, weight, bias, stride=1, padding=2)
x = F.relu(x)
x = F.max_pool2d(x, kernel_size=2, stride=2)
print(x.shape) # 输出: torch.Size([1, 6, 14, 14])
```
---
如果您需要某类函数的详细讲解、特定方法的代码示例,或想深入某个功能(如注意力机制或损失函数),请告诉我,我可以进一步展开!
是的你的理解非常正确torch.nn.functional通常简写为 F中的方法是纯函数式接口它们只进行计算不管理或记录任何参数如权重或偏置。下面详细解答你的问题并澄清其与神经网络的关系。
1. torch.nn.functional 的本质
纯函数计算torch.nn.functional 提供的是无状态的函数,输入张量后直接执行特定操作(如卷积、激活、池化等),并返回结果。这些函数不会保存任何内部状态或参数。
例如F.conv2d(input, weight, bias) 需要显式传入 input、卷积核 weight 和偏置 bias如果有函数只负责计算卷积操作并返回结果。
不记录参数:与 torch.nn 模块(如 nn.Conv2d不同F 中的函数不会自动管理权重或偏置参数。这些参数需要由用户手动提供或通过其他机制(如 nn.Parameter管理。
与神经网络的关系:虽然 torch.nn.functional 本身不直接构成神经网络,但它提供了构建神经网络所需的核心计算操作。神经网络的实现通常结合 torch.nn.Module管理参数和层结构与 F 的函数(执行具体计算)。
2. torch.nn.functional vs torch.nn
为了更好理解,我们对比一下 torch.nn.functional 和 torch.nn 模块:
特性torch.nn.functionaltorch.nn (如 nn.Conv2d, nn.ReLU)类型函数式接口,无状态模块化接口,有状态参数管理不管理参数,需手动传入权重、偏置等自动管理参数(如权重和偏置)典型用途在 forward 方法中实现具体计算在 __init__ 中定义网络层灵活性更灵活适合自定义操作更封装适合标准网络结构示例F.conv2d(x, weight, bias)self.conv = nn.Conv2d(1, 6, 3)
例子:
使用 torch.nn.Conv2d
pythonimport torch
import torch.nn as nn
conv = nn.Conv2d(1, 6, 3) # 定义卷积层,自动管理 weight 和 bias
x = torch.randn(1, 1, 28, 28)
y = conv(x) # 直接调用,参数由模块管理
使用 torch.nn.functional
pythonimport torch.nn.functional as F
x = torch.randn(1, 1, 28, 28)
weight = torch.randn(6, 1, 3, 3) # 手动定义权重
bias = torch.randn(6) # 手动定义偏置
y = F.conv2d(x, weight, bias) # 显式传入参数
1. 为什么使用 torch.nn.functional
虽然 torch.nn 模块更方便(因为它封装了参数管理),但 F 的函数式接口在以下场景中更有优势:
自定义操作当需要自定义权重计算或动态调整参数时F 提供更大灵活性。例如,可以在每次前向传播中修改权重。
轻量级模型:在不需要持久化参数的场景(如某些实验性模型或推理阶段),直接使用 F 可以减少内存开销。
函数式编程:适合函数式编程风格,便于将操作组合成复杂的计算流程。
特殊场景:如实现自定义激活函数、池化操作,或在 Transformer 中实现注意力