Files
python/Pytorch/nn/Module/nn.Module.md
2025-09-09 15:56:55 +08:00

464 lines
15 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

`torch.nn.Module` 是 PyTorch 中构建神经网络的核心基类,所有的神经网络模型都继承自它。它提供了许多方法来管理模块、参数、子模块和前向传播等功能。以下基于 PyTorch 官方文档https://pytorch.org/docs/stable/generated/torch.nn.Module.html`torch.nn.Module` 的所有方法进行详细讲解,力求清晰、简洁且实用。
---
### 1. `__init__` 方法
- **描述**: 初始化方法,用于定义模块的结构(如层、子模块等)。
- **用法**: 在自定义模块时,重写 `__init__` 来定义子模块或参数。
- **示例**:
```python
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(10, 5)
```
- **说明**: 调用 `super().__init__()` 确保基类的初始化,子模块需注册到模块中(通过赋值给 `self`)。
---
### 2. `forward(*args, **kwargs)`
- **描述**: 定义模块的前向传播逻辑,子类必须实现此方法。
- **用法**: 输入数据通过此方法计算输出。
- **示例**:
```python
def forward(self, x):
x = self.fc1(x)
return x
```
- **说明**: 这是核心方法,所有输入数据的计算流程在此定义。调用模型实例(如 `model(x)`)时会自动调用 `forward`。
---
### 3. `__call__` 方法
- **描述**: 使模块实例可调用,内部调用 `forward` 方法并添加钩子hook功能。
- **用法**: 不需要显式重写,用户通过 `model(input)` 间接调用。
- **说明**: 通常不需要直接操作,但了解它是 `model(input)` 的实现基础。
---
### 4. `add_module(name: str, module: Module) -> None`
- **描述**: 向模块添加一个子模块,并以指定名称注册。
- **用法**: 动态添加子模块。
- **示例**:
```python
model = nn.Module()
model.add_module("fc1", nn.Linear(10, 5))
```
- **说明**: 等价于 `self.name = module`,但更灵活,适合动态构建模型。
---
### 5. `apply(fn: Callable[['Module'], None]) -> T`
- **描述**: 递归地将函数 `fn` 应用到模块及其所有子模块。
- **用法**: 用于初始化参数或修改模块属性。
- **示例**:
```python
def init_weights(m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
model.apply(init_weights)
```
- **说明**: 常用于自定义权重初始化。
---
### 6. `buffers(recurse: bool = True) -> Iterator[Tensor]`
- **描述**: 返回模块中所有缓冲区(如 `running_mean` 或 `running_var`)的迭代器。
- **用法**: 获取模块的非可训练参数(如 BatchNorm 的统计数据)。
- **示例**:
```python
for buf in model.buffers():
print(buf)
```
- **说明**: 如果 `recurse=True`,包括子模块的缓冲区。
---
### 7. `children() -> Iterator[Module]`
- **描述**: 返回直接子模块的迭代器。
- **用法**: 遍历模型的直接子模块。
- **示例**:
```python
for child in model.children():
print(child)
```
- **说明**: 仅返回直接子模块,不递归到更深层。
---
### 8. `cpu() -> T`
- **描述**: 将模块的所有参数和缓冲区移动到 CPU。
- **用法**: 用于设备切换。
- **示例**:
```python
model.cpu()
```
- **说明**: 确保模型和数据在同一设备上运行。
---
### 9. `cuda(device: Optional[Union[int, torch.device]] = None) -> T`
- **描述**: 将模块的所有参数和缓冲区移动到指定的 GPU 设备。
- **用法**: 指定 GPU 设备(如 `cuda:0`)。
- **示例**:
```python
model.cuda(0) # 移动到 GPU 0
```
- **说明**: 如果不指定 `device`,使用当前默认 GPU。
---
### 10. `double() -> T`
- **描述**: 将模块的参数和缓冲区转换为 `torch.float64`(双精度浮点)。
- **用法**: 提高数值精度。
- **示例**:
```python
model.double()
```
- **说明**: 通常在需要高精度计算时使用,但会增加内存占用。
---
### 11. `eval() -> T`
- **描述**: 将模块设置为评估模式。
- **用法**: 关闭训练特有的行为(如 Dropout 和 BatchNorm 的更新)。
- **示例**:
```python
model.eval()
```
- **说明**: 与 `train(False)` 等效,用于推理阶段。
---
### 12. `float() -> T`
- **描述**: 将模块的参数和缓冲区转换为 `torch.float32`(单精度浮点)。
- **用法**: 默认浮点类型,适合大多数深度学习任务。
- **示例**:
```python
model.float()
```
- **说明**: 比 `double()` 更节省内存。
---
### 13. `get_buffer(target: str) -> Tensor`
- **描述**: 返回指定名称的缓冲区。
- **用法**: 访问特定缓冲区(如 BatchNorm 的 `running_mean`)。
- **示例**:
```python
running_mean = model.get_buffer("bn.running_mean")
```
- **说明**: 如果缓冲区不存在,会抛出 `KeyError`。
---
### 14. `get_parameter(target: str) -> Parameter`
- **描述**: 返回指定名称的参数。
- **用法**: 访问特定参数(如 `weight` 或 `bias`)。
- **示例**:
```python
weight = model.get_parameter("fc1.weight")
```
- **说明**: 参数必须是注册的参数,否则抛出 `KeyError`。
---
### 15. `get_submodule(target: str) -> Module`
- **描述**: 返回指定路径的子模块。
- **用法**: 访问嵌套子模块。
- **示例**:
```python
submodule = model.get_submodule("block1.conv1")
```
- **说明**: 使用点号(如 `block1.conv1`)访问嵌套结构。
---
### 16. `half() -> T`
- **描述**: 将模块的参数和缓冲区转换为 `torch.float16`(半精度浮点)。
- **用法**: 用于加速计算和节省 GPU 内存。
- **示例**:
```python
model.half()
```
- **说明**: 需确保硬件支持半精度运算。
---
### 17. `load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> _IncompatibleKeys`
- **描述**: 从 `state_dict` 加载参数和缓冲区。
- **用法**: 用于加载预训练模型或检查点。
- **示例**:
```python
model.load_state_dict(torch.load("model.pth"))
```
- **说明**:
- `strict=True`:要求 `state_dict` 的键完全匹配。
- `assign=True`:直接赋值而非复制(实验性功能)。
- 返回 `_IncompatibleKeys`,指示缺失或多余的键。
---
### 18. `modules() -> Iterator[Module]`
- **描述**: 返回模块及其所有子模块(递归)的迭代器。
- **用法**: 遍历整个模块树。
- **示例**:
```python
for module in model.modules():
print(module)
```
- **说明**: 比 `children()` 更深入,包含所有层级子模块。
---
### 19. `named_buffers(recurse: bool = True, prefix: str = '', remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]`
- **描述**: 返回模块中所有缓冲区的名称和值的迭代器。
- **用法**: 获取缓冲区的名称和值。
- **示例**:
```python
for name, buf in model.named_buffers():
print(name, buf)
```
- **说明**: 如果 `recurse=True`,包括子模块的缓冲区。
---
### 20. `named_children() -> Iterator[Tuple[str, Module]]`
- **描述**: 返回直接子模块的名称和模块的迭代器。
- **用法**: 遍历直接子模块及其名称。
- **示例**:
```python
for name, child in model.named_children():
print(name, child)
```
- **说明**: 仅返回直接子模块,不递归。
---
### 21. `named_modules(memo: Optional[Set[Module]] = None, prefix: str = '', remove_duplicate: bool = True) -> Iterator[Tuple[str, Module]]`
- **描述**: 返回模块及其所有子模块(递归)的名称和模块的迭代器。
- **用法**: 遍历整个模块树及其名称。
- **示例**:
```python
for name, module in model.named_modules():
print(name, module)
```
- **说明**: 比 `named_children()` 更深入,包含所有层级。
---
### 22. `named_parameters(recurse: bool = True, prefix: str = '', remove_duplicate: bool = True) -> Iterator[Tuple[str, Parameter]]`
- **描述**: 返回模块中所有参数的名称和值的迭代器。
- **用法**: 访问模型的参数(如权重和偏置)。
- **示例**:
```python
for name, param in model.named_parameters():
print(name, param.shape)
```
- **说明**: 如果 `recurse=True`,包括子模块的参数。
---
### 23. `parameters(recurse: bool = True) -> Iterator[Parameter]`
- **描述**: 返回模块中所有参数的迭代器。
- **用法**: 用于优化器配置。
- **示例**:
```python
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
```
- **说明**: 如果 `recurse=True`,包括子模块的参数。
---
### 24. `register_buffer(name: str, tensor: Optional[Tensor], persistent: bool = True) -> None`
- **描述**: 注册一个缓冲区(如非可训练的张量)。
- **用法**: 用于存储不需要梯度的张量(如 BatchNorm 的 `running_mean`)。
- **示例**:
```python
self.register_buffer("running_mean", torch.zeros(10))
```
- **说明**: `persistent=True` 表示缓冲区会保存到 `state_dict`。
---
### 25. `register_forward_hook(hook: Callable[..., None], *, prepend: bool = False, with_kwargs: bool = False) -> RemovableHandle`
- **描述**: 注册一个前向传播钩子函数,在模块前向传播时调用。
- **用法**: 用于监控或修改前向传播的输入/输出。
- **示例**:
```python
def hook(module, input, output):
print(output)
handle = model.register_forward_hook(hook)
```
- **说明**: 返回 `RemovableHandle`,可通过 `handle.remove()` 移除钩子。
---
### 26. `register_forward_pre_hook(hook: Callable[..., None], *, prepend: bool = False, with_kwargs: bool = False) -> RemovableHandle`
- **描述**: 注册一个前向传播前的钩子函数,在 `forward` 调用前触发。
- **用法**: 用于修改输入或调试。
- **示例**:
```python
def pre_hook(module, input):
print(input)
handle = model.register_forward_pre_hook(pre_hook)
```
- **说明**: 类似 `register_forward_hook`,但在 `forward` 之前运行。
---
### 27. `register_full_backward_hook(hook: Callable[..., None], prepend: bool = False) -> RemovableHandle`
- **描述**: 注册一个反向传播钩子函数,在梯度计算时调用。
- **用法**: 用于监控或修改梯度。
- **示例**:
```python
def backward_hook(module, grad_input, grad_output):
print(grad_output)
handle = model.register_full_backward_hook(backward_hook)
```
- **说明**: 在 PyTorch 2.0+ 中推荐使用,替代旧的 `register_backward_hook`。
---
### 28. `register_parameter(name: str, param: Optional[Parameter]) -> None`
- **描述**: 注册一个参数到模块。
- **用法**: 动态添加可训练参数。
- **示例**:
```python
self.register_parameter("weight", nn.Parameter(torch.randn(10, 5)))
```
- **说明**: 参数会自动加入 `parameters()` 和 `state_dict`。
---
### 29. `requires_grad_(requires_grad: bool = True) -> T`
- **描述**: 设置模块所有参数的 `requires_grad` 属性。
- **用法**: 冻结或解冻参数。
- **示例**:
```python
model.requires_grad_(False) # 冻结参数
```
- **说明**: 常用于冻结预训练模型的部分层。
---
### 30. `share_memory() -> T`
- **描述**: 将模块的参数和缓冲区移动到共享内存。
- **用法**: 用于多进程数据共享。
- **示例**:
```python
model.share_memory()
```
- **说明**: 主要用于 `torch.multiprocessing` 场景。
---
### 31. `state_dict(*args, destination: Optional[Dict[str, Tensor]] = None, prefix: str = '', keep_vars: bool = False) -> Dict[str, Tensor]`
- **描述**: 返回模块的参数和缓冲区的状态字典。
- **用法**: 保存模型状态。
- **示例**:
```python
torch.save(model.state_dict(), "model.pth")
```
- **说明**: `keep_vars=True` 时保留 `Tensor` 的 `Variable` 特性(较少使用)。
---
### 32. `to(*args, **kwargs) -> T`
- **描述**: 将模块的参数和缓冲区移动到指定设备或数据类型。
- **用法**: 灵活的设备/类型转换。
- **示例**:
```python
model.to(device="cuda", dtype=torch.float16)
```
- **说明**: 支持多种参数形式(如 `to(device)`, `to(dtype)`)。
---
### 33. `to_empty(*, device: Union[str, torch.device], recurse: bool = True) -> T`
- **描述**: 将模块的参数和缓冲区移动到指定设备,但不初始化内容。
- **用法**: 用于初始化空模型。
- **示例**:
```python
model.to_empty(device="cuda")
```
- **说明**: 较少使用,适合特殊场景。
---
### 34. `train(mode: bool = True) -> T`
- **描述**: 设置模块的训练模式。
- **用法**: 启用/禁用训练特有的行为(如 Dropout、BatchNorm
- **示例**:
```python
model.train() # 训练模式
model.train(False) # 等同于 eval()
```
- **说明**: 与 `eval()` 相对,用于切换训练/评估模式。
---
### 35. `type(dst_type: Union[str, torch.dtype]) -> T`
- **描述**: 将模块的参数和缓冲区转换为指定数据类型。
- **用法**: 类似 `float()`, `double()` 等,但更通用。
- **示例**:
```python
model.type(torch.float16)
```
- **说明**: 支持字符串(如 `"torch.float32"`)或 `torch.dtype`。
---
### 36. `xpu(device: Optional[Union[int, torch.device]] = None) -> T`
- **描述**: 将模块移动到 XPU 设备Intel GPU
- **用法**: 用于支持 XPU 的设备。
- **示例**:
```python
model.xpu(0)
```
- **说明**: 仅在支持 XPU 的环境中有效。
---
### 37. `zero_grad(set_to_none: bool = False) -> None`
- **描述**: 将模块所有参数的梯度清零。
- **用法**: 在优化步骤前调用。
- **示例**:
```python
model.zero_grad()
```
- **说明**: 如果 `set_to_none=True`,梯度设为 `None` 而非 0更节省内存
---
### 注意事项
- **模块管理**: 使用 `children()`, `modules()`, `named_parameters()` 等方法可以方便地管理复杂模型的结构和参数。
- **设备与类型**: 确保模型和输入数据在同一设备和数据类型(`to()`, `cuda()`, `float()` 等)。
- **钩子函数**: 钩子(如 `register_forward_hook`)是调试和动态修改模型的强大工具。
- **状态保存**: `state_dict()` 和 `load_state_dict()` 是保存和加载模型的关键方法。
### 实践建议
你可以尝试以下代码来探索 `torch.nn.Module` 的功能:
```python
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 5)
self.register_buffer("running_mean", torch.zeros(5))
def forward(self, x):
return self.fc1(x)
model = MyModel()
print(list(model.named_parameters())) # 查看参数
print(list(model.named_buffers())) # 查看缓冲区
model.to("cuda") # 移动到 GPU
model.eval() # 设置评估模式
torch.save(model.state_dict(), "model.pth") # 保存模型
```
如果你想深入某个方法(如钩子函数或 `state_dict`)或需要具体代码示例,请告诉我,我可以进一步展开!