Files
2025-09-09 15:10:57 +08:00
..
2025-09-09 15:10:57 +08:00
2025-09-09 15:10:57 +08:00
2025-09-09 15:10:57 +08:00
2025-09-09 15:10:57 +08:00
2025-09-09 15:10:57 +08:00

好的 😄,我给你整理一个 PyTorch tensor 使用大全,覆盖创建、操作、转换、计算等常用方法。为了系统化,我会按类别来整理,并给简单示例。


🟢 1. 创建 Tensor

import torch

# 直接从数据创建
a = torch.tensor([1, 2, 3], dtype=torch.float32)

# 全 0 / 全 1 / 全某数
b = torch.zeros(3, 4)
c = torch.ones(2, 5) * 7

# 随机数
d = torch.rand(3, 3)       # 0~1 均匀分布
e = torch.randn(3, 3)      # 标准正态分布

# 类似 numpy 的 arange / linspace
f = torch.arange(0, 10, 2) # 0,2,4,6,8
g = torch.linspace(0, 1, steps=5) # 0,0.25,0.5,0.75,1

# 空 tensor占位
h = torch.empty(2, 3)

🟢 2. Tensor 属性

x = torch.randn(2, 3)

x.shape       # 返回 (2,3)
x.size()      # 等同 shape
x.dtype       # 数据类型,如 torch.float32
x.device      # 运行设备,如 cpu 或 cuda:0
x.numel()     # 元素总数
x.requires_grad_(True)  # 设置是否需要梯度

🟢 3. Tensor 运算

3.1 基本算术

a = torch.tensor([1,2,3])
b = torch.tensor([4,5,6])

c = a + b
d = a - b
e = a * b
f = a / b
g = a ** 2

3.2 矩阵运算

A = torch.rand(2,3)
B = torch.rand(3,2)

C = torch.matmul(A, B)   # 矩阵乘法
D = A @ B                # 等价写法
E = A.T                  # 转置

3.3 统计函数

x = torch.tensor([[1,2],[3,4]], dtype=torch.float32)

x.sum()       # 所有元素求和
x.mean()      # 平均值
x.max()       # 最大值
x.min()       # 最小值
x.argmax()    # 最大值索引
x.argmin()    # 最小值索引
x.std()       # 标准差
x.var()       # 方差

🟢 4. Tensor 索引与切片

x = torch.tensor([[1,2,3],[4,5,6]])

x[0]       # 第一行
x[:,1]     # 第二列
x[0,2]     # 第一行第三列
x[1,:2]    # 第二行前两列
x[-1,-1]   # 最后一个元素

🟢 5. Tensor 改变形状

x = torch.arange(12)

y = x.view(3,4)       # 改变形状 (reshape)
z = x.reshape(3,4)    # reshape 等同 view但更安全
w = x.unsqueeze(0)    # 增加维度
v = x.squeeze()       # 删除维度为1的维
t = x.transpose(0,0)  # 交换维度

🟢 6. Tensor 拼接与拆分

a = torch.randn(2,3)
b = torch.randn(2,3)

# 拼接
c = torch.cat([a,b], dim=0)   # 纵向拼接
d = torch.cat([a,b], dim=1)   # 横向拼接

# 堆叠(增加新维度)
e = torch.stack([a,b], dim=0) 

# 拆分
f1, f2 = torch.chunk(c, 2, dim=0)  # 按行拆成两块

🟢 7. Tensor 类型转换

x = torch.tensor([1,2,3], dtype=torch.int32)

x.float()       # 转 float
x.double()      # 转 double
x.long()        # 转 int64
x.bool()        # 转 bool

🟢 8. Tensor 与 NumPy 互转

import numpy as np

a = torch.tensor([1,2,3])
b = a.numpy()              # Tensor → numpy array

c = np.array([4,5,6])
d = torch.from_numpy(c)    # numpy → Tensor

⚠️ 注意:共享内存,如果修改其中一个,另一个也会变。


🟢 9. GPU 操作

x = torch.randn(2,3)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
x = x.to(device)

y = torch.tensor([1,2,3], device=device)  # 直接在 GPU 上创建

🟢 10. 自动求导(梯度)

x = torch.tensor([2.0,3.0], requires_grad=True)
y = x ** 2 + 3*x
z = y.sum()
z.backward()  # 自动求梯度
print(x.grad) # dx

🟢 11. 常用函数速查

类别 常用函数
线性代数 matmul, mm, bmm, einsum
激活函数 relu, sigmoid, tanh, softmax
聚合函数 sum, mean, max, min, prod, cumsum
比较函数 eq, ne, gt, lt, ge, le
随机函数 rand, randn, randint, randperm
其他 clone, detach, item, expand, repeat

如果你需要,我可以帮你整理一个 可直接拷贝的 PyTorch Tensor 使用小抄(含示例和用法备注),适合放在屏幕边上快速查。

你希望我帮你整理吗?