Added nn
This commit is contained in:
0
Pytorch/nn/Parameter/List(Dict)AreUsedToStorePms
Normal file
0
Pytorch/nn/Parameter/List(Dict)AreUsedToStorePms
Normal file
225
Pytorch/nn/Parameter/nn.Parameter.md
Normal file
225
Pytorch/nn/Parameter/nn.Parameter.md
Normal file
@@ -0,0 +1,225 @@
|
||||
根据 PyTorch 官方文档,`torch.nn.Parameter` 是一个类,继承自 `torch.Tensor`,主要用于表示神经网络模型中的可训练参数。它本身并没有定义许多独立的方法,而是继承了 `torch.Tensor` 的大部分方法,同时具备一些特殊属性,用于与 `torch.nn.Module` 配合管理模型参数。以下是对 `torch.nn.Parameter` 的详细讲解,包括其作用、特性以及与参数管理相关的方法,基于官方文档和其他可靠来源。
|
||||
|
||||
---
|
||||
|
||||
### 1. `torch.nn.Parameter` 的定义与作用
|
||||
|
||||
**官方定义**:
|
||||
```python
|
||||
class torch.nn.parameter.Parameter(data=None, requires_grad=True)
|
||||
```
|
||||
- **作用**:`torch.nn.Parameter` 是一种特殊的 `torch.Tensor`,用于表示神经网络模型的可训练参数。当它被赋值给 `torch.nn.Module` 的属性时,会自动注册到模块的参数列表中(通过 `parameters()` 或 `named_parameters()` 方法访问),并参与梯度计算和优化。
|
||||
- **特性**:
|
||||
- **自动注册**:当 `nn.Parameter` 实例被赋值给 `nn.Module` 的属性时,它会自动添加到模块的 `parameters()` 迭代器中,而普通 `torch.Tensor` 不会。
|
||||
- **默认梯度**:`requires_grad` 默认值为 `True`,表示参数需要计算梯度,即使在 `torch.no_grad()` 上下文中也是如此。
|
||||
- **用途**:常用于定义模型的可训练权重、偏置,或其他需要优化的参数(如 Vision Transformer 中的 positional embedding 或 class token)。
|
||||
|
||||
**参数说明**:
|
||||
- `data`:一个 `torch.Tensor`,表示参数的初始值。
|
||||
- `requires_grad`:布尔值,指示是否需要计算梯度,默认 `True`。
|
||||
|
||||
**示例**:
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# 创建一个 Parameter
|
||||
param = nn.Parameter(torch.randn(3, 3))
|
||||
print(param) # 输出 Parameter 类型的张量,requires_grad=True
|
||||
|
||||
# 定义一个简单的模型
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
self.weight = nn.Parameter(torch.randn(3, 3))
|
||||
|
||||
def forward(self, x):
|
||||
return torch.matmul(x, self.weight)
|
||||
|
||||
model = MyModel()
|
||||
print(list(model.parameters())) # 包含 self.weight
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 2. `torch.nn.Parameter` 的方法
|
||||
|
||||
`torch.nn.Parameter` 本身没有定义额外的方法,它继承了 `torch.Tensor` 的所有方法,并通过与 `nn.Module` 的交互提供参数管理的功能。以下是与 `nn.Parameter` 相关的核心方法(主要通过 `nn.Module` 访问)以及 `torch.Tensor` 的常用方法在 `nn.Parameter` 上的应用:
|
||||
|
||||
#### 2.1 通过 `nn.Module` 访问 `nn.Parameter` 的方法
|
||||
|
||||
这些方法是 `torch.nn.Module` 提供的,用于管理 `nn.Parameter` 实例:
|
||||
|
||||
1. **`parameters()`**:
|
||||
- **作用**:返回模型中所有 `nn.Parameter` 实例的迭代器。
|
||||
- **返回值**:`Iterator[Parameter]`,包含所有参数的张量。
|
||||
- **示例**:
|
||||
```python
|
||||
for param in model.parameters():
|
||||
print(param.shape) # 打印每个参数的形状
|
||||
```
|
||||
|
||||
2. **`named_parameters()`**:
|
||||
- **作用**:返回一个迭代器,包含模型中所有 `nn.Parameter` 的名称和对应的参数张量。
|
||||
- **返回值**:`Iterator[Tuple[str, Parameter]]`,每个元素是参数名称和参数的元组。
|
||||
- **示例**:
|
||||
```python
|
||||
for name, param in model.named_parameters():
|
||||
print(f"Parameter name: {name}, Shape: {param.shape}")
|
||||
```
|
||||
|
||||
3. **`_parameters`**:
|
||||
- **作用**:`nn.Module` 的属性,是一个 `OrderedDict`,存储模块中直接定义的 `nn.Parameter` 实例。
|
||||
- **示例**:
|
||||
```python
|
||||
print(model._parameters) # 输出 OrderedDict,包含 weight 参数
|
||||
```
|
||||
|
||||
4. **`apply(fn)`**:
|
||||
- **作用**:递归地将函数 `fn` 应用于模块及其子模块的所有参数,常用于参数初始化。
|
||||
- **示例**:
|
||||
```python
|
||||
def init_weights(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
m.weight.data.fill_(1.0)
|
||||
model.apply(init_weights) # 初始化所有参数
|
||||
```
|
||||
|
||||
5. **`cpu()` / `cuda(device_id=None)`**:
|
||||
- **作用**:将所有参数(包括 `nn.Parameter`)移动到 CPU 或指定的 GPU 设备。
|
||||
- **示例**:
|
||||
```python
|
||||
model.cuda() # 将模型参数移动到 GPU
|
||||
```
|
||||
|
||||
6. **`to(device)`**:
|
||||
- **作用**:将参数移动到指定设备(CPU 或 GPU),支持更灵活的设备管理。
|
||||
- **示例**:
|
||||
```python
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model.to(device)
|
||||
```
|
||||
|
||||
7. **`double()` / `float()` / `half()`**:
|
||||
- **作用**:将所有参数转换为指定的数据类型(如双精度、单精度或半精度)。
|
||||
- **示例**:
|
||||
```python
|
||||
model.double() # 转换为双精度
|
||||
```
|
||||
|
||||
#### 2.2 继承自 `torch.Tensor` 的方法
|
||||
|
||||
`nn.Parameter` 是 `torch.Tensor` 的子类,因此可以使用 `torch.Tensor` 的所有方法。以下是一些常用的方法,特别适用于参数操作:
|
||||
|
||||
1. **张量操作**:
|
||||
- `add_()`, `mul_()`, `div_()` 等:原地修改参数值。
|
||||
- `zero_()`:将参数值置零,常用于重置梯度或参数。
|
||||
- **示例**:
|
||||
```python
|
||||
param = nn.Parameter(torch.randn(3, 3))
|
||||
param.zero_() # 将参数置零
|
||||
```
|
||||
|
||||
2. **梯度相关**:
|
||||
- `grad`:访问参数的梯度(一个 `torch.Tensor`)。
|
||||
- `zero_grad()`:通过优化器调用,清除参数的梯度。
|
||||
- **示例**:
|
||||
```python
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
optimizer.zero_grad() # 清除所有参数的梯度
|
||||
```
|
||||
|
||||
3. **形状操作**:
|
||||
- `view()`, `reshape()`:改变参数的形状。
|
||||
- **示例**:
|
||||
```python
|
||||
param = nn.Parameter(torch.randn(3, 3))
|
||||
reshaped_param = param.view(9) # 展平为 1D 张量
|
||||
```
|
||||
|
||||
4. **数学运算**:
|
||||
- `sum()`, `mean()`, `std()` 等:对参数值进行统计计算。
|
||||
- **示例**:
|
||||
```python
|
||||
print(param.mean()) # 计算参数的均值
|
||||
```
|
||||
|
||||
5. **克隆与分离**:
|
||||
- `clone()`:创建参数的副本。
|
||||
- `detach()`:分离参数,创建一个不需要梯度的新张量。
|
||||
- **示例**:
|
||||
```python
|
||||
param_clone = param.clone() # 复制参数
|
||||
param_detached = param.detach() # 分离,requires_grad=False
|
||||
```
|
||||
|
||||
#### 2.3 与优化器交互
|
||||
|
||||
`nn.Parameter` 的主要用途是与优化器(如 `torch.optim.SGD` 或 `Adam`)一起使用。优化器通过 `model.parameters()` 获取所有 `nn.Parameter` 实例,并更新它们的值。
|
||||
|
||||
**示例**:
|
||||
```python
|
||||
import torch.optim as optim
|
||||
|
||||
model = MyModel()
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.01)
|
||||
loss_fn = nn.MSELoss()
|
||||
|
||||
# 前向传播
|
||||
x = torch.randn(1, 3)
|
||||
y = torch.randn(1, 3)
|
||||
out = model(x)
|
||||
loss = loss_fn(out, y)
|
||||
|
||||
# 反向传播与优化
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step() # 更新所有 nn.Parameter
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 3. 注意事项与常见问题
|
||||
|
||||
1. **与普通 Tensor 的区别**:
|
||||
- 普通 `torch.Tensor` 即使设置 `requires_grad=True`,也不会自动添加到 `nn.Module` 的参数列表中。
|
||||
- `nn.Parameter` 默认 `requires_grad=True`,且会自动注册为模型参数。
|
||||
|
||||
2. **初始化参数**:
|
||||
- 可以使用 `torch.nn.init` 模块初始化 `nn.Parameter`。
|
||||
- **示例**:
|
||||
```python
|
||||
import torch.nn.init as init
|
||||
param = nn.Parameter(torch.randn(3, 3))
|
||||
init.xavier_uniform_(param) # 使用 Xavier 初始化
|
||||
```
|
||||
|
||||
3. **临时状态 vs 参数**:
|
||||
- 如果需要在模型中存储临时状态(如 RNN 的隐藏状态),应使用普通 `torch.Tensor` 或 `nn.Module.register_buffer()`,避免注册为可训练参数。
|
||||
|
||||
4. **Vision Transformer 示例**:
|
||||
- 在 Vision Transformer 中,`nn.Parameter` 常用于定义可学习的 `cls_token` 和 `pos_embedding`:
|
||||
```python
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, num_patches, dim):
|
||||
super(ViT, self).__init__()
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4. 总结
|
||||
|
||||
`torch.nn.Parameter` 本身没有定义独特的方法,但通过继承 `torch.Tensor` 和与 `nn.Module` 的交互,提供了强大的参数管理功能。核心方法(如 `parameters()`、`named_parameters()`)通过 `nn.Module` 访问,而 `torch.Tensor` 的方法(如 `zero_()`、`view()`)直接应用于 `nn.Parameter` 实例。以下是关键点:
|
||||
|
||||
- **自动注册**:赋值给 `nn.Module` 属性时,自动加入参数列表。
|
||||
- **梯度计算**:默认 `requires_grad=True`,支持优化。
|
||||
- **灵活操作**:继承 `torch.Tensor` 的所有方法,适用于张量操作。
|
||||
|
||||
**参考文献**:
|
||||
- PyTorch 官方文档:`torch.nn.Parameter`[](https://docs.pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html)[](https://docs.pytorch.ac.cn/docs/stable/generated/torch.nn.parameter.Parameter.html)
|
||||
- 极客教程:理解 `torch.nn.Parameter`[](https://geek-docs.com/pytorch/pytorch-questions/21_pytorch_understanding_torchnnparameter.html)
|
||||
- CSDN 博客:`torch.nn.Parameter` 讲解[](https://blog.csdn.net/weixin_44878336/article/details/124733598)[](https://blog.csdn.net/weixin_44966641/article/details/118730730)
|
||||
|
||||
如果需要更详细的代码示例或特定方法的应用场景,请告诉我!
|
359
Pytorch/nn/Parameter/nn.ParameterDict.md
Normal file
359
Pytorch/nn/Parameter/nn.ParameterDict.md
Normal file
@@ -0,0 +1,359 @@
|
||||
根据 PyTorch 官方文档(`torch.nn.ParameterList` 和 `torch.nn.ParameterDict` 的相关内容),以下是对这两个类的详细讲解,包括它们的定义、作用以及所有方法的全面说明。由于 `torch.nn.ParameterList` 和 `torch.nn.ParameterDict` 是专门用于管理 `torch.nn.Parameter` 的容器类,它们的方法相对较少,主要继承自 Python 的列表和字典操作,并与 PyTorch 的模块机制结合使用。以下内容基于官方文档(`PyTorch 2.8`)和其他可靠来源(如 PyTorch Forums 和 Stack Overflow),确保准确且全面。
|
||||
|
||||
---
|
||||
|
||||
## 1. `torch.nn.ParameterList`
|
||||
|
||||
### 1.1 定义与作用
|
||||
|
||||
**官方定义**:
|
||||
```python
|
||||
class torch.nn.ParameterList(values=None)
|
||||
```
|
||||
- **作用**:`torch.nn.ParameterList` 是一个容器类,用于以列表的形式存储 `torch.nn.Parameter` 实例。它类似于 Python 的内置 `list`,但专为 PyTorch 的参数管理设计,存储的 `nn.Parameter` 会被自动注册到 `nn.Module` 的参数列表中,参与梯度计算和优化。
|
||||
- **特性**:
|
||||
- **自动注册**:当 `nn.ParameterList` 作为 `nn.Module` 的属性时,其中的所有 `nn.Parameter` 会自动被 `model.parameters()` 识别。
|
||||
- **Tensor 自动转换**:在构造、赋值、`append()` 或 `extend()` 时,传入的普通 `torch.Tensor` 会自动转换为 `nn.Parameter`。
|
||||
- **用途**:适用于需要动态管理一组参数的场景,例如在循环网络或变长模型中存储多个权重矩阵。
|
||||
- **参数**:
|
||||
- `values`:一个可选的可迭代对象,包含初始的 `nn.Parameter` 或 `torch.Tensor`。
|
||||
|
||||
**示例**:
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.params = nn.ParameterList([nn.Parameter(torch.randn(2, 2)) for _ in range(3)])
|
||||
|
||||
def forward(self, x):
|
||||
for param in self.params:
|
||||
x = x @ param # 矩阵乘法
|
||||
return x
|
||||
|
||||
model = MyModel()
|
||||
print(list(model.parameters())) # 输出 3 个 (2, 2) 的参数张量
|
||||
```
|
||||
|
||||
### 1.2 方法讲解
|
||||
|
||||
`torch.nn.ParameterList` 支持以下方法,主要继承自 Python 的 `list`,并添加了 PyTorch 的参数管理特性。以下是所有方法的详细说明(基于官方文档和实际用法):
|
||||
|
||||
1. **`append(value)`**:
|
||||
- **作用**:向 `ParameterList` 末尾添加一个值。如果 `value` 是 `torch.Tensor`,会自动转换为 `nn.Parameter`。
|
||||
- **参数**:
|
||||
- `value`:要添加的元素(可以是 `nn.Parameter` 或 `torch.Tensor`)。
|
||||
- **返回值**:`self`(`ParameterList` 本身,支持链式调用)。
|
||||
- **示例**:
|
||||
```python
|
||||
param_list = nn.ParameterList()
|
||||
param_list.append(torch.randn(2, 2)) # 自动转换为 nn.Parameter
|
||||
print(len(param_list)) # 输出 1
|
||||
print(isinstance(param_list[0], nn.Parameter)) # 输出 True
|
||||
```
|
||||
- **注意**:添加的 `nn.Parameter` 会自动注册到模块的参数列表中。
|
||||
|
||||
2. **`extend(values)`**:
|
||||
- **作用**:将一个可迭代对象中的值追加到 `ParameterList` 末尾。所有 `torch.Tensor` 会被转换为 `nn.Parameter`。
|
||||
- **参数**:
|
||||
- `values`:一个可迭代对象,包含 `nn.Parameter` 或 `torch.Tensor`。
|
||||
- **返回值**:`self`。
|
||||
- **示例**:
|
||||
```python
|
||||
param_list = nn.ParameterList()
|
||||
param_list.extend([torch.randn(2, 2), torch.randn(2, 2)])
|
||||
print(len(param_list)) # 输出 2
|
||||
```
|
||||
- **注意**:`extend` 比逐个 `append` 更高效,适合批量添加参数。
|
||||
- **Stack Overflow 示例**:可以使用 `extend` 合并两个 `ParameterList`:
|
||||
```python
|
||||
plist = nn.ParameterList()
|
||||
plist.extend(sub_list_1)
|
||||
plist.extend(sub_list_2)
|
||||
```
|
||||
或者使用解包方式:
|
||||
```python
|
||||
param_list = nn.ParameterList([*sub_list_1, *sub_list_2])
|
||||
```
|
||||
|
||||
3. **索引操作(如 `__getitem__`, `__setitem__`)**:
|
||||
- **作用**:支持像 Python 列表一样的索引访问和赋值操作。
|
||||
- **参数**:
|
||||
- 索引(整数或切片)。
|
||||
- **返回值**:指定索引处的 `nn.Parameter`。
|
||||
- **示例**:
|
||||
```python
|
||||
param_list = nn.ParameterList([nn.Parameter(torch.randn(2, 2)) for _ in range(3)])
|
||||
print(param_list[0]) # 访问第一个参数
|
||||
param_list[0] = nn.Parameter(torch.ones(2, 2)) # 替换第一个参数
|
||||
```
|
||||
- **注意**:赋值时,新值必须是 `nn.Parameter` 或 `torch.Tensor`(后者会自动转换为 `nn.Parameter`)。
|
||||
|
||||
4. **迭代操作(如 `__iter__`)**:
|
||||
- **作用**:支持迭代,允许遍历 `ParameterList` 中的所有参数。
|
||||
- **返回值**:迭代器,逐个返回 `nn.Parameter`。
|
||||
- **示例**:
|
||||
```python
|
||||
for param in param_list:
|
||||
print(param.shape) # 打印每个参数的形状
|
||||
```
|
||||
|
||||
5. **长度查询(如 `__len__`)**:
|
||||
- **作用**:返回 `ParameterList` 中参数的数量。
|
||||
- **返回值**:整数。
|
||||
- **示例**:
|
||||
```python
|
||||
print(len(param_list)) # 输出参数数量
|
||||
```
|
||||
|
||||
6. **其他列表操作**:
|
||||
- `ParameterList` 支持 Python 列表的常见方法,如 `pop()`, `clear()`, `insert()`, `remove()` 等,但这些方法在官方文档中未明确列出(基于 Python 的 `list` 实现)。
|
||||
- **示例**(`pop`):
|
||||
```python
|
||||
param_list = nn.ParameterList([nn.Parameter(torch.randn(2, 2)) for _ in range(3)])
|
||||
removed_param = param_list.pop(0) # 移除并返回第一个参数
|
||||
print(len(param_list)) # 输出 2
|
||||
```
|
||||
- **PyTorch Forums 讨论**:有用户询问如何在运行时移除 `ParameterList` 中的元素,可以使用 `pop(index)` 或重新构造列表:
|
||||
```python
|
||||
param_list = nn.ParameterList(param_list[:index] + param_list[index+1:])
|
||||
```
|
||||
|
||||
### 1.3 注意事项
|
||||
|
||||
- **与 `nn.ModuleList` 的区别**:`ParameterList` 存储 `nn.Parameter`,用于参数管理;`nn.ModuleList` 存储 `nn.Module`,用于子模块管理。
|
||||
- **自动转换**:任何添加到 `ParameterList` 的 `torch.Tensor` 都会被转换为 `nn.Parameter`,确保参数可被优化器识别。
|
||||
- **动态性**:`ParameterList` 适合动态模型(如变长序列模型),但操作频繁可能影响性能。
|
||||
- **优化器集成**:`ParameterList` 中的参数会自动被 `model.parameters()` 包含,优化器会更新这些参数。
|
||||
|
||||
---
|
||||
|
||||
## 2. `torch.nn.ParameterDict`
|
||||
|
||||
### 2.1 定义与作用
|
||||
|
||||
**官方定义**:
|
||||
```python
|
||||
class torch.nn.ParameterDict(parameters=None)
|
||||
```
|
||||
- **作用**:`torch.nn.ParameterDict` 是一个容器类,用于以字典的形式存储 `torch.nn.Parameter` 实例。它类似于 Python 的内置 `dict`,但专为 PyTorch 参数管理设计,存储的 `nn.Parameter` 会自动注册到 `nn.Module` 的参数列表中。
|
||||
- **特性**:
|
||||
- **键值存储**:使用字符串键来索引参数,便于按名称访问。
|
||||
- **自动注册**:与 `ParameterList` 类似,`ParameterDict` 中的参数会被 `model.parameters()` 识别。
|
||||
- **Tensor 自动转换**:赋值时,`torch.Tensor` 会被转换为 `nn.Parameter`。
|
||||
- **用途**:适用于需要按名称管理参数的场景,例如在多任务学习或复杂模型中。
|
||||
- **参数**:
|
||||
- `parameters`:一个可选的可迭代对象,包含键值对(键为字符串,值为 `nn.Parameter` 或 `torch.Tensor`)。
|
||||
|
||||
**示例**:
|
||||
```python
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.params = nn.ParameterDict({
|
||||
'weight1': nn.Parameter(torch.randn(2, 2)),
|
||||
'weight2': nn.Parameter(torch.randn(2, 2))
|
||||
})
|
||||
|
||||
def forward(self, x):
|
||||
x = x @ self.params['weight1']
|
||||
x = x @ self.params['weight2']
|
||||
return x
|
||||
|
||||
model = MyModel()
|
||||
print(list(model.parameters())) # 输出 2 个 (2, 2) 的参数张量
|
||||
```
|
||||
|
||||
### 2.2 方法讲解
|
||||
|
||||
`torch.nn.ParameterDict` 支持以下方法,主要继承自 Python 的 `dict`,并添加了 PyTorch 的参数管理特性。以下是所有方法的详细说明:
|
||||
|
||||
1. **`__setitem__(key, value)`**:
|
||||
- **作用**:向 `ParameterDict` 添加或更新一个键值对。如果 `value` 是 `torch.Tensor`,会自动转换为 `nn.Parameter`。
|
||||
- **参数**:
|
||||
- `key`:字符串,参数的名称。
|
||||
- `value`:`nn.Parameter` 或 `torch.Tensor`。
|
||||
- **示例**:
|
||||
```python
|
||||
param_dict = nn.ParameterDict()
|
||||
param_dict['weight'] = torch.randn(2, 2) # 自动转换为 nn.Parameter
|
||||
print(param_dict['weight']) # 输出 nn.Parameter
|
||||
```
|
||||
|
||||
2. **`update([other])`**:
|
||||
- **作用**:使用另一个字典或键值对更新 `ParameterDict`。输入的 `torch.Tensor` 会被转换为 `nn.Parameter`。
|
||||
- **参数**:
|
||||
- `other`:一个字典或键值对的可迭代对象。
|
||||
- **示例**:
|
||||
```python
|
||||
param_dict = nn.ParameterDict()
|
||||
param_dict.update({'weight1': torch.randn(2, 2), 'weight2': torch.randn(2, 2)})
|
||||
print(len(param_dict)) # 输出 2
|
||||
```
|
||||
|
||||
3. **索引操作(如 `__getitem__`)**:
|
||||
- **作用**:通过键访问 `ParameterDict` 中的 `nn.Parameter`。
|
||||
- **参数**:
|
||||
- 键(字符串)。
|
||||
- **返回值**:对应的 `nn.Parameter`。
|
||||
- **示例**:
|
||||
```python
|
||||
print(param_dict['weight1']) # 访问 weight1 参数
|
||||
```
|
||||
|
||||
4. **迭代操作(如 `__iter__`, `keys()`, `values()`, `items()`)**:
|
||||
- **作用**:支持迭代键、值或键值对。
|
||||
- **返回值**:
|
||||
- `keys()`:返回所有键的迭代器。
|
||||
- `values()`:返回所有 `nn.Parameter` 的迭代器。
|
||||
- `items()`:返回键值对的迭代器。
|
||||
- **示例**:
|
||||
```python
|
||||
for key, param in param_dict.items():
|
||||
print(f"Key: {key}, Shape: {param.shape}")
|
||||
```
|
||||
|
||||
5. **长度查询(如 `__len__`)**:
|
||||
- **作用**:返回 `ParameterDict` 中参数的数量。
|
||||
- **返回值**:整数。
|
||||
- **示例**:
|
||||
```python
|
||||
print(len(param_dict)) # 输出参数数量
|
||||
```
|
||||
|
||||
6. **其他字典操作**:
|
||||
- `ParameterDict` 支持 Python 字典的常见方法,如 `pop(key)`, `clear()`, `popitem()`, `get(key, default=None)` 等。
|
||||
- **示例**(`pop`):
|
||||
```python
|
||||
param_dict = nn.ParameterDict({'weight1': torch.randn(2, 2), 'weight2': torch.randn(2, 2)})
|
||||
removed_param = param_dict.pop('weight1') # 移除并返回 weight1
|
||||
print(len(param_dict)) # 输出 1
|
||||
```
|
||||
|
||||
### 2.3 注意事项
|
||||
|
||||
- **与 `nn.ModuleDict` 的区别**:`ParameterDict` 存储 `nn.Parameter`,用于参数管理;`nn.ModuleDict` 存储 `nn.Module`,用于子模块管理。
|
||||
- **键的唯一性**:`ParameterDict` 的键必须是字符串,且不能重复。
|
||||
- **动态管理**:适合需要按名称访问参数的场景,但操作复杂模型时需注意性能开销。
|
||||
- **优化器集成**:`ParameterDict` 中的参数也会被 `model.parameters()` 包含,优化器会自动更新。
|
||||
|
||||
---
|
||||
|
||||
## 3. 比较与使用场景
|
||||
|
||||
| 特性 | `ParameterList` | `ParameterDict` |
|
||||
|---------------------|--------------------------------------------|--------------------------------------------|
|
||||
| **存储方式** | 列表(按索引访问) | 字典(按键访问) |
|
||||
| **访问方式** | 索引(`param_list[0]`) | 键(`param_dict['key']`) |
|
||||
| **主要方法** | `append`, `extend`, `pop`, 索引操作 | `update`, `pop`, `keys`, `values`, `items` |
|
||||
| **适用场景** | 顺序参数管理(如循环网络中的多层权重) | 命名参数管理(如多任务模型中的权重) |
|
||||
| **动态性** | 适合动态添加/移除参数 | 适合按名称管理参数 |
|
||||
|
||||
**选择建议**:
|
||||
- 如果参数需要按顺序访问或动态增减,使用 `ParameterList`。
|
||||
- 如果需要按名称管理参数或参数具有明确语义,使用 `ParameterDict`。
|
||||
|
||||
---
|
||||
|
||||
## 4. 综合示例
|
||||
|
||||
以下是一个结合 `ParameterList` 和 `ParameterDict` 的示例,展示它们在模型中的使用:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class ComplexModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# 使用 ParameterList 存储一组权重
|
||||
self.list_params = nn.ParameterList([nn.Parameter(torch.randn(2, 2)) for _ in range(2)])
|
||||
# 使用 ParameterDict 存储命名参数
|
||||
self.dict_params = nn.ParameterDict({
|
||||
'conv_weight': nn.Parameter(torch.randn(3, 3)),
|
||||
'fc_weight': nn.Parameter(torch.randn(4, 4))
|
||||
})
|
||||
|
||||
def forward(self, x):
|
||||
# 使用 ParameterList
|
||||
for param in self.list_params:
|
||||
x = x @ param
|
||||
# 使用 ParameterDict
|
||||
x = x @ self.dict_params['conv_weight']
|
||||
x = x @ self.dict_params['fc_weight']
|
||||
return x
|
||||
|
||||
model = ComplexModel()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
# 打印所有参数
|
||||
for name, param in model.named_parameters():
|
||||
print(f"Parameter name: {name}, Shape: {param.shape}")
|
||||
|
||||
# 动态修改 ParameterList
|
||||
model.list_params.append(nn.Parameter(torch.ones(2, 2)))
|
||||
|
||||
# 动态修改 ParameterDict
|
||||
model.dict_params['new_weight'] = nn.Parameter(torch.zeros(4, 4))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. 常见问题与解答
|
||||
|
||||
1. **如何高效合并两个 `ParameterList`?**
|
||||
- 使用 `extend` 方法或解包方式:
|
||||
```python
|
||||
param_list = nn.ParameterList()
|
||||
param_list.extend(sub_list_1)
|
||||
param_list.extend(sub_list_2)
|
||||
# 或者
|
||||
param_list = nn.ParameterList([*sub_list_1, *sub_list_2])
|
||||
```
|
||||
|
||||
2. **如何从 `ParameterList` 删除元素?**
|
||||
- 使用 `pop(index)` 或重新构造列表:
|
||||
```python
|
||||
param_list.pop(0) # 移除第一个参数
|
||||
# 或者
|
||||
param_list = nn.ParameterList([param for i, param in enumerate(param_list) if i != 0])
|
||||
```
|
||||
|
||||
3. **如何检查 `ParameterDict` 中的参数?**
|
||||
- 使用 `keys()`, `values()`, 或 `items()` 遍历:
|
||||
```python
|
||||
for key, param in param_dict.items():
|
||||
print(f"Key: {key}, Parameter: {param}")
|
||||
```
|
||||
|
||||
4. **如何初始化参数?**
|
||||
- 使用 `torch.nn.init` 或通过 `nn.Module.apply`:
|
||||
```python
|
||||
import torch.nn.init as init
|
||||
for param in param_list:
|
||||
init.xavier_uniform_(param)
|
||||
for param in param_dict.values():
|
||||
init.xavier_uniform_(param)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 总结
|
||||
|
||||
- **`torch.nn.ParameterList`**:
|
||||
- 方法:`append`, `extend`, `pop`, 索引/迭代操作。
|
||||
- 特点:类似列表,适合顺序管理参数,自动将 `torch.Tensor` 转换为 `nn.Parameter`。
|
||||
- 场景:动态模型、变长参数列表。
|
||||
- **`torch.nn.ParameterDict`**:
|
||||
- 方法:`update`, `pop`, `keys`, `values`, `items`, 索引/迭代操作。
|
||||
- 特点:类似字典,适合按名称管理参数,自动转换 `torch.Tensor`。
|
||||
- 场景:多任务学习、需要语义化命名的模型。
|
||||
|
||||
**参考文献**:
|
||||
- PyTorch 官方文档:`torch.nn.ParameterList` 和 `torch.nn.ParameterDict`(`docs.pytorch.org`)[](https://docs.pytorch.org/docs/stable/generated/torch.nn.ParameterList.html)[](https://docs.pytorch.org/docs/stable/nn.html)
|
||||
- Stack Overflow:合并 `ParameterList` 的讨论[](https://stackoverflow.com/questions/70779631/combining-parameterlist-in-pytorch)
|
||||
- PyTorch Forums:`ParameterList` 的使用和动态操作[](https://discuss.pytorch.org/t/using-nn-parameterlist/86742)
|
||||
|
||||
如果需要更详细的代码示例、特定方法的实现,或其他相关问题(如与优化器的集成),请告诉我!
|
359
Pytorch/nn/Parameter/nn.ParameterList.md
Normal file
359
Pytorch/nn/Parameter/nn.ParameterList.md
Normal file
@@ -0,0 +1,359 @@
|
||||
根据 PyTorch 官方文档(`torch.nn.ParameterList` 和 `torch.nn.ParameterDict` 的相关内容),以下是对这两个类的详细讲解,包括它们的定义、作用以及所有方法的全面说明。由于 `torch.nn.ParameterList` 和 `torch.nn.ParameterDict` 是专门用于管理 `torch.nn.Parameter` 的容器类,它们的方法相对较少,主要继承自 Python 的列表和字典操作,并与 PyTorch 的模块机制结合使用。以下内容基于官方文档(`PyTorch 2.8`)和其他可靠来源(如 PyTorch Forums 和 Stack Overflow),确保准确且全面。
|
||||
|
||||
---
|
||||
|
||||
## 1. `torch.nn.ParameterList`
|
||||
|
||||
### 1.1 定义与作用
|
||||
|
||||
**官方定义**:
|
||||
```python
|
||||
class torch.nn.ParameterList(values=None)
|
||||
```
|
||||
- **作用**:`torch.nn.ParameterList` 是一个容器类,用于以列表的形式存储 `torch.nn.Parameter` 实例。它类似于 Python 的内置 `list`,但专为 PyTorch 的参数管理设计,存储的 `nn.Parameter` 会被自动注册到 `nn.Module` 的参数列表中,参与梯度计算和优化。
|
||||
- **特性**:
|
||||
- **自动注册**:当 `nn.ParameterList` 作为 `nn.Module` 的属性时,其中的所有 `nn.Parameter` 会自动被 `model.parameters()` 识别。
|
||||
- **Tensor 自动转换**:在构造、赋值、`append()` 或 `extend()` 时,传入的普通 `torch.Tensor` 会自动转换为 `nn.Parameter`。
|
||||
- **用途**:适用于需要动态管理一组参数的场景,例如在循环网络或变长模型中存储多个权重矩阵。
|
||||
- **参数**:
|
||||
- `values`:一个可选的可迭代对象,包含初始的 `nn.Parameter` 或 `torch.Tensor`。
|
||||
|
||||
**示例**:
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.params = nn.ParameterList([nn.Parameter(torch.randn(2, 2)) for _ in range(3)])
|
||||
|
||||
def forward(self, x):
|
||||
for param in self.params:
|
||||
x = x @ param # 矩阵乘法
|
||||
return x
|
||||
|
||||
model = MyModel()
|
||||
print(list(model.parameters())) # 输出 3 个 (2, 2) 的参数张量
|
||||
```
|
||||
|
||||
### 1.2 方法讲解
|
||||
|
||||
`torch.nn.ParameterList` 支持以下方法,主要继承自 Python 的 `list`,并添加了 PyTorch 的参数管理特性。以下是所有方法的详细说明(基于官方文档和实际用法):
|
||||
|
||||
1. **`append(value)`**:
|
||||
- **作用**:向 `ParameterList` 末尾添加一个值。如果 `value` 是 `torch.Tensor`,会自动转换为 `nn.Parameter`。
|
||||
- **参数**:
|
||||
- `value`:要添加的元素(可以是 `nn.Parameter` 或 `torch.Tensor`)。
|
||||
- **返回值**:`self`(`ParameterList` 本身,支持链式调用)。
|
||||
- **示例**:
|
||||
```python
|
||||
param_list = nn.ParameterList()
|
||||
param_list.append(torch.randn(2, 2)) # 自动转换为 nn.Parameter
|
||||
print(len(param_list)) # 输出 1
|
||||
print(isinstance(param_list[0], nn.Parameter)) # 输出 True
|
||||
```
|
||||
- **注意**:添加的 `nn.Parameter` 会自动注册到模块的参数列表中。
|
||||
|
||||
2. **`extend(values)`**:
|
||||
- **作用**:将一个可迭代对象中的值追加到 `ParameterList` 末尾。所有 `torch.Tensor` 会被转换为 `nn.Parameter`。
|
||||
- **参数**:
|
||||
- `values`:一个可迭代对象,包含 `nn.Parameter` 或 `torch.Tensor`。
|
||||
- **返回值**:`self`。
|
||||
- **示例**:
|
||||
```python
|
||||
param_list = nn.ParameterList()
|
||||
param_list.extend([torch.randn(2, 2), torch.randn(2, 2)])
|
||||
print(len(param_list)) # 输出 2
|
||||
```
|
||||
- **注意**:`extend` 比逐个 `append` 更高效,适合批量添加参数。
|
||||
- **Stack Overflow 示例**:可以使用 `extend` 合并两个 `ParameterList`:
|
||||
```python
|
||||
plist = nn.ParameterList()
|
||||
plist.extend(sub_list_1)
|
||||
plist.extend(sub_list_2)
|
||||
```
|
||||
或者使用解包方式:
|
||||
```python
|
||||
param_list = nn.ParameterList([*sub_list_1, *sub_list_2])
|
||||
```
|
||||
|
||||
3. **索引操作(如 `__getitem__`, `__setitem__`)**:
|
||||
- **作用**:支持像 Python 列表一样的索引访问和赋值操作。
|
||||
- **参数**:
|
||||
- 索引(整数或切片)。
|
||||
- **返回值**:指定索引处的 `nn.Parameter`。
|
||||
- **示例**:
|
||||
```python
|
||||
param_list = nn.ParameterList([nn.Parameter(torch.randn(2, 2)) for _ in range(3)])
|
||||
print(param_list[0]) # 访问第一个参数
|
||||
param_list[0] = nn.Parameter(torch.ones(2, 2)) # 替换第一个参数
|
||||
```
|
||||
- **注意**:赋值时,新值必须是 `nn.Parameter` 或 `torch.Tensor`(后者会自动转换为 `nn.Parameter`)。
|
||||
|
||||
4. **迭代操作(如 `__iter__`)**:
|
||||
- **作用**:支持迭代,允许遍历 `ParameterList` 中的所有参数。
|
||||
- **返回值**:迭代器,逐个返回 `nn.Parameter`。
|
||||
- **示例**:
|
||||
```python
|
||||
for param in param_list:
|
||||
print(param.shape) # 打印每个参数的形状
|
||||
```
|
||||
|
||||
5. **长度查询(如 `__len__`)**:
|
||||
- **作用**:返回 `ParameterList` 中参数的数量。
|
||||
- **返回值**:整数。
|
||||
- **示例**:
|
||||
```python
|
||||
print(len(param_list)) # 输出参数数量
|
||||
```
|
||||
|
||||
6. **其他列表操作**:
|
||||
- `ParameterList` 支持 Python 列表的常见方法,如 `pop()`, `clear()`, `insert()`, `remove()` 等,但这些方法在官方文档中未明确列出(基于 Python 的 `list` 实现)。
|
||||
- **示例**(`pop`):
|
||||
```python
|
||||
param_list = nn.ParameterList([nn.Parameter(torch.randn(2, 2)) for _ in range(3)])
|
||||
removed_param = param_list.pop(0) # 移除并返回第一个参数
|
||||
print(len(param_list)) # 输出 2
|
||||
```
|
||||
- **PyTorch Forums 讨论**:有用户询问如何在运行时移除 `ParameterList` 中的元素,可以使用 `pop(index)` 或重新构造列表:
|
||||
```python
|
||||
param_list = nn.ParameterList(param_list[:index] + param_list[index+1:])
|
||||
```
|
||||
|
||||
### 1.3 注意事项
|
||||
|
||||
- **与 `nn.ModuleList` 的区别**:`ParameterList` 存储 `nn.Parameter`,用于参数管理;`nn.ModuleList` 存储 `nn.Module`,用于子模块管理。
|
||||
- **自动转换**:任何添加到 `ParameterList` 的 `torch.Tensor` 都会被转换为 `nn.Parameter`,确保参数可被优化器识别。
|
||||
- **动态性**:`ParameterList` 适合动态模型(如变长序列模型),但操作频繁可能影响性能。
|
||||
- **优化器集成**:`ParameterList` 中的参数会自动被 `model.parameters()` 包含,优化器会更新这些参数。
|
||||
|
||||
---
|
||||
|
||||
## 2. `torch.nn.ParameterDict`
|
||||
|
||||
### 2.1 定义与作用
|
||||
|
||||
**官方定义**:
|
||||
```python
|
||||
class torch.nn.ParameterDict(parameters=None)
|
||||
```
|
||||
- **作用**:`torch.nn.ParameterDict` 是一个容器类,用于以字典的形式存储 `torch.nn.Parameter` 实例。它类似于 Python 的内置 `dict`,但专为 PyTorch 参数管理设计,存储的 `nn.Parameter` 会自动注册到 `nn.Module` 的参数列表中。
|
||||
- **特性**:
|
||||
- **键值存储**:使用字符串键来索引参数,便于按名称访问。
|
||||
- **自动注册**:与 `ParameterList` 类似,`ParameterDict` 中的参数会被 `model.parameters()` 识别。
|
||||
- **Tensor 自动转换**:赋值时,`torch.Tensor` 会被转换为 `nn.Parameter`。
|
||||
- **用途**:适用于需要按名称管理参数的场景,例如在多任务学习或复杂模型中。
|
||||
- **参数**:
|
||||
- `parameters`:一个可选的可迭代对象,包含键值对(键为字符串,值为 `nn.Parameter` 或 `torch.Tensor`)。
|
||||
|
||||
**示例**:
|
||||
```python
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.params = nn.ParameterDict({
|
||||
'weight1': nn.Parameter(torch.randn(2, 2)),
|
||||
'weight2': nn.Parameter(torch.randn(2, 2))
|
||||
})
|
||||
|
||||
def forward(self, x):
|
||||
x = x @ self.params['weight1']
|
||||
x = x @ self.params['weight2']
|
||||
return x
|
||||
|
||||
model = MyModel()
|
||||
print(list(model.parameters())) # 输出 2 个 (2, 2) 的参数张量
|
||||
```
|
||||
|
||||
### 2.2 方法讲解
|
||||
|
||||
`torch.nn.ParameterDict` 支持以下方法,主要继承自 Python 的 `dict`,并添加了 PyTorch 的参数管理特性。以下是所有方法的详细说明:
|
||||
|
||||
1. **`__setitem__(key, value)`**:
|
||||
- **作用**:向 `ParameterDict` 添加或更新一个键值对。如果 `value` 是 `torch.Tensor`,会自动转换为 `nn.Parameter`。
|
||||
- **参数**:
|
||||
- `key`:字符串,参数的名称。
|
||||
- `value`:`nn.Parameter` 或 `torch.Tensor`。
|
||||
- **示例**:
|
||||
```python
|
||||
param_dict = nn.ParameterDict()
|
||||
param_dict['weight'] = torch.randn(2, 2) # 自动转换为 nn.Parameter
|
||||
print(param_dict['weight']) # 输出 nn.Parameter
|
||||
```
|
||||
|
||||
2. **`update([other])`**:
|
||||
- **作用**:使用另一个字典或键值对更新 `ParameterDict`。输入的 `torch.Tensor` 会被转换为 `nn.Parameter`。
|
||||
- **参数**:
|
||||
- `other`:一个字典或键值对的可迭代对象。
|
||||
- **示例**:
|
||||
```python
|
||||
param_dict = nn.ParameterDict()
|
||||
param_dict.update({'weight1': torch.randn(2, 2), 'weight2': torch.randn(2, 2)})
|
||||
print(len(param_dict)) # 输出 2
|
||||
```
|
||||
|
||||
3. **索引操作(如 `__getitem__`)**:
|
||||
- **作用**:通过键访问 `ParameterDict` 中的 `nn.Parameter`。
|
||||
- **参数**:
|
||||
- 键(字符串)。
|
||||
- **返回值**:对应的 `nn.Parameter`。
|
||||
- **示例**:
|
||||
```python
|
||||
print(param_dict['weight1']) # 访问 weight1 参数
|
||||
```
|
||||
|
||||
4. **迭代操作(如 `__iter__`, `keys()`, `values()`, `items()`)**:
|
||||
- **作用**:支持迭代键、值或键值对。
|
||||
- **返回值**:
|
||||
- `keys()`:返回所有键的迭代器。
|
||||
- `values()`:返回所有 `nn.Parameter` 的迭代器。
|
||||
- `items()`:返回键值对的迭代器。
|
||||
- **示例**:
|
||||
```python
|
||||
for key, param in param_dict.items():
|
||||
print(f"Key: {key}, Shape: {param.shape}")
|
||||
```
|
||||
|
||||
5. **长度查询(如 `__len__`)**:
|
||||
- **作用**:返回 `ParameterDict` 中参数的数量。
|
||||
- **返回值**:整数。
|
||||
- **示例**:
|
||||
```python
|
||||
print(len(param_dict)) # 输出参数数量
|
||||
```
|
||||
|
||||
6. **其他字典操作**:
|
||||
- `ParameterDict` 支持 Python 字典的常见方法,如 `pop(key)`, `clear()`, `popitem()`, `get(key, default=None)` 等。
|
||||
- **示例**(`pop`):
|
||||
```python
|
||||
param_dict = nn.ParameterDict({'weight1': torch.randn(2, 2), 'weight2': torch.randn(2, 2)})
|
||||
removed_param = param_dict.pop('weight1') # 移除并返回 weight1
|
||||
print(len(param_dict)) # 输出 1
|
||||
```
|
||||
|
||||
### 2.3 注意事项
|
||||
|
||||
- **与 `nn.ModuleDict` 的区别**:`ParameterDict` 存储 `nn.Parameter`,用于参数管理;`nn.ModuleDict` 存储 `nn.Module`,用于子模块管理。
|
||||
- **键的唯一性**:`ParameterDict` 的键必须是字符串,且不能重复。
|
||||
- **动态管理**:适合需要按名称访问参数的场景,但操作复杂模型时需注意性能开销。
|
||||
- **优化器集成**:`ParameterDict` 中的参数也会被 `model.parameters()` 包含,优化器会自动更新。
|
||||
|
||||
---
|
||||
|
||||
## 3. 比较与使用场景
|
||||
|
||||
| 特性 | `ParameterList` | `ParameterDict` |
|
||||
|---------------------|--------------------------------------------|--------------------------------------------|
|
||||
| **存储方式** | 列表(按索引访问) | 字典(按键访问) |
|
||||
| **访问方式** | 索引(`param_list[0]`) | 键(`param_dict['key']`) |
|
||||
| **主要方法** | `append`, `extend`, `pop`, 索引操作 | `update`, `pop`, `keys`, `values`, `items` |
|
||||
| **适用场景** | 顺序参数管理(如循环网络中的多层权重) | 命名参数管理(如多任务模型中的权重) |
|
||||
| **动态性** | 适合动态添加/移除参数 | 适合按名称管理参数 |
|
||||
|
||||
**选择建议**:
|
||||
- 如果参数需要按顺序访问或动态增减,使用 `ParameterList`。
|
||||
- 如果需要按名称管理参数或参数具有明确语义,使用 `ParameterDict`。
|
||||
|
||||
---
|
||||
|
||||
## 4. 综合示例
|
||||
|
||||
以下是一个结合 `ParameterList` 和 `ParameterDict` 的示例,展示它们在模型中的使用:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class ComplexModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# 使用 ParameterList 存储一组权重
|
||||
self.list_params = nn.ParameterList([nn.Parameter(torch.randn(2, 2)) for _ in range(2)])
|
||||
# 使用 ParameterDict 存储命名参数
|
||||
self.dict_params = nn.ParameterDict({
|
||||
'conv_weight': nn.Parameter(torch.randn(3, 3)),
|
||||
'fc_weight': nn.Parameter(torch.randn(4, 4))
|
||||
})
|
||||
|
||||
def forward(self, x):
|
||||
# 使用 ParameterList
|
||||
for param in self.list_params:
|
||||
x = x @ param
|
||||
# 使用 ParameterDict
|
||||
x = x @ self.dict_params['conv_weight']
|
||||
x = x @ self.dict_params['fc_weight']
|
||||
return x
|
||||
|
||||
model = ComplexModel()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
# 打印所有参数
|
||||
for name, param in model.named_parameters():
|
||||
print(f"Parameter name: {name}, Shape: {param.shape}")
|
||||
|
||||
# 动态修改 ParameterList
|
||||
model.list_params.append(nn.Parameter(torch.ones(2, 2)))
|
||||
|
||||
# 动态修改 ParameterDict
|
||||
model.dict_params['new_weight'] = nn.Parameter(torch.zeros(4, 4))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. 常见问题与解答
|
||||
|
||||
1. **如何高效合并两个 `ParameterList`?**
|
||||
- 使用 `extend` 方法或解包方式:
|
||||
```python
|
||||
param_list = nn.ParameterList()
|
||||
param_list.extend(sub_list_1)
|
||||
param_list.extend(sub_list_2)
|
||||
# 或者
|
||||
param_list = nn.ParameterList([*sub_list_1, *sub_list_2])
|
||||
```
|
||||
|
||||
2. **如何从 `ParameterList` 删除元素?**
|
||||
- 使用 `pop(index)` 或重新构造列表:
|
||||
```python
|
||||
param_list.pop(0) # 移除第一个参数
|
||||
# 或者
|
||||
param_list = nn.ParameterList([param for i, param in enumerate(param_list) if i != 0])
|
||||
```
|
||||
|
||||
3. **如何检查 `ParameterDict` 中的参数?**
|
||||
- 使用 `keys()`, `values()`, 或 `items()` 遍历:
|
||||
```python
|
||||
for key, param in param_dict.items():
|
||||
print(f"Key: {key}, Parameter: {param}")
|
||||
```
|
||||
|
||||
4. **如何初始化参数?**
|
||||
- 使用 `torch.nn.init` 或通过 `nn.Module.apply`:
|
||||
```python
|
||||
import torch.nn.init as init
|
||||
for param in param_list:
|
||||
init.xavier_uniform_(param)
|
||||
for param in param_dict.values():
|
||||
init.xavier_uniform_(param)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 总结
|
||||
|
||||
- **`torch.nn.ParameterList`**:
|
||||
- 方法:`append`, `extend`, `pop`, 索引/迭代操作。
|
||||
- 特点:类似列表,适合顺序管理参数,自动将 `torch.Tensor` 转换为 `nn.Parameter`。
|
||||
- 场景:动态模型、变长参数列表。
|
||||
- **`torch.nn.ParameterDict`**:
|
||||
- 方法:`update`, `pop`, `keys`, `values`, `items`, 索引/迭代操作。
|
||||
- 特点:类似字典,适合按名称管理参数,自动转换 `torch.Tensor`。
|
||||
- 场景:多任务学习、需要语义化命名的模型。
|
||||
|
||||
**参考文献**:
|
||||
- PyTorch 官方文档:`torch.nn.ParameterList` 和 `torch.nn.ParameterDict`(`docs.pytorch.org`)[](https://docs.pytorch.org/docs/stable/generated/torch.nn.ParameterList.html)[](https://docs.pytorch.org/docs/stable/nn.html)
|
||||
- Stack Overflow:合并 `ParameterList` 的讨论[](https://stackoverflow.com/questions/70779631/combining-parameterlist-in-pytorch)
|
||||
- PyTorch Forums:`ParameterList` 的使用和动态操作[](https://discuss.pytorch.org/t/using-nn-parameterlist/86742)
|
||||
|
||||
如果需要更详细的代码示例、特定方法的实现,或其他相关问题(如与优化器的集成),请告诉我!
|
Reference in New Issue
Block a user