PyTorch Quickstart

Created: July 27, 2024 · Modified: July 27, 2024

This article gives a quick tour of the APIs commonly used for machine learning tasks in PyTorch. Follow the links in each section to learn more.


Working with Data

PyTorch provides two primitives for working with data: torch.utils.data.DataLoader and torch.utils.data.Dataset. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset.

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
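
To make these two primitives concrete, here is a minimal sketch of a custom Dataset (the class name and in-memory lists are hypothetical, not part of this tutorial); a Dataset subclass only needs __len__ and __getitem__:

# Minimal custom Dataset sketch (hypothetical example, not part of this tutorial).
from torch.utils.data import Dataset

class InMemoryDataset(Dataset):
    def __init__(self, samples, labels):
        self.samples = samples  # e.g. a list of tensors
        self.labels = labels    # e.g. a list of integer class labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Return one (sample, label) pair; DataLoader assembles these into batches.
        return self.samples[idx], self.labels[idx]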

PyTorch offers domain-specific libraries such as TorchText, TorchVision, and TorchAudio, all of which include datasets. This tutorial uses a TorchVision dataset.

The torchvision.datasets module contains Dataset objects for many real-world vision datasets such as CIFAR and COCO (full list here). In this tutorial, we use the FashionMNIST dataset. Every TorchVision Dataset includes two arguments, transform and target_transform, to modify the samples and the labels respectively.

# Download training data from open datasets
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|##########| 26421880/26421880 [00:08<00:00, 3095816.12it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
100%|##########| 29515/29515 [00:00<00:00, 324595.50it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
100%|##########| 4422102/4422102 [00:00<00:00, 6090071.00it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
100%|##########| 5148/5148 [00:00<00:00, 28636972.14it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
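
The transform argument above converts the images to tensors; target_transform modifies the labels in the same way. As an illustrative sketch (the one-hot encoding and the variable name are additions of this article, not something the quickstart itself uses), labels could be mapped to one-hot vectors like this:

# Sketch: one-hot encode the labels via target_transform (illustrative only).
from torchvision.transforms import Lambda

training_data_onehot = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    # Scatter a 1 into a zero vector of length 10 at the index given by the label y.
    target_transform=Lambda(
        lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)
    ),
)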

We pass the Dataset as an argument to DataLoader. This wraps an iterable over our dataset and adds support for automatic batching, sampling, shuffling, and multiprocess data loading. Here we define a batch size of 64, i.e. each element in the dataloader iterable returns a batch of 64 features and labels.

batch_size = 64

# Create data loaders
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
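
The loaders above rely on the defaults; the shuffling and multiprocess loading mentioned earlier are opt-in via DataLoader arguments. A minimal sketch (the worker count is an arbitrary choice):

# Sketch: enable shuffling and multiprocess loading (worker count chosen arbitrarily).
shuffled_loader = DataLoader(
    training_data,
    batch_size=batch_size,
    shuffle=True,    # reshuffle the data at every epoch
    num_workers=2,   # load batches in two background worker processes
)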

torch.Size is a tuple that represents the dimensions of a Tensor.
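
Since it behaves like a tuple, a shape can be indexed or unpacked directly, e.g.:

# torch.Size behaves like a tuple: unpack or index it directly.
N, C, H, W = X.shape
print(N, C, H, W)  # 64 1 28 28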

Read more about loading data in PyTorch.

Creating Models

To define a neural network, we create a class that inherits from nn.Module and implements a forward method. The network layers are initialized in the __init__ method. To accelerate operations, we move the model to the GPU or MPS device when the hardware allows.

# Get a cpu, gpu, or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Define the model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)
Using cuda device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

Read more about building neural networks in PyTorch.

Optimizing the Model Parameters

To train a model, we need a loss function and an optimizer.

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
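
SGD is one choice among several; any optimizer in torch.optim can be swapped in with the same interface. For example (Adam here is an illustrative alternative, not what the rest of this tutorial uses):

# Illustrative alternative optimizer (not used in the rest of this tutorial).
optimizer_adam = torch.optim.Adam(model.parameters(), lr=1e-3)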

In a single training loop, the model makes predictions on the training dataset (fed to it in batches) and backpropagates the prediction error to adjust the model's parameters.

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

We also check the model's performance against the test dataset to ensure it is learning.

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

The training process is conducted over several iterations (epochs). During each epoch, the model learns parameters to make better predictions. We print the model’s accuracy and loss at each epoch; we’d like to see the accuracy increase and the loss decrease with every epoch.

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1
-------------------------------
loss: 2.303494  [   64/60000]
loss: 2.294637  [ 6464/60000]
loss: 2.277102  [12864/60000]
loss: 2.269977  [19264/60000]
loss: 2.254235  [25664/60000]
loss: 2.237146  [32064/60000]
loss: 2.231055  [38464/60000]
loss: 2.205037  [44864/60000]
loss: 2.203240  [51264/60000]
loss: 2.170889  [57664/60000]
Test Error:
 Accuracy: 53.9%, Avg loss: 2.168588

Epoch 2
-------------------------------
loss: 2.177787  [   64/60000]
loss: 2.168083  [ 6464/60000]
loss: 2.114910  [12864/60000]
loss: 2.130412  [19264/60000]
loss: 2.087473  [25664/60000]
loss: 2.039670  [32064/60000]
loss: 2.054274  [38464/60000]
loss: 1.985457  [44864/60000]
loss: 1.996023  [51264/60000]
loss: 1.917241  [57664/60000]
Test Error:
 Accuracy: 60.2%, Avg loss: 1.920374

Epoch 3
-------------------------------
loss: 1.951705  [   64/60000]
loss: 1.919516  [ 6464/60000]
loss: 1.808730  [12864/60000]
loss: 1.846550  [19264/60000]
loss: 1.740618  [25664/60000]
loss: 1.698733  [32064/60000]
loss: 1.708889  [38464/60000]
loss: 1.614436  [44864/60000]
loss: 1.646475  [51264/60000]
loss: 1.524308  [57664/60000]
Test Error:
 Accuracy: 61.4%, Avg loss: 1.547092

Epoch 4
-------------------------------
loss: 1.612695  [   64/60000]
loss: 1.570870  [ 6464/60000]
loss: 1.424730  [12864/60000]
loss: 1.489542  [19264/60000]
loss: 1.367256  [25664/60000]
loss: 1.373464  [32064/60000]
loss: 1.376744  [38464/60000]
loss: 1.304962  [44864/60000]
loss: 1.347154  [51264/60000]
loss: 1.230661  [57664/60000]
Test Error:
 Accuracy: 62.7%, Avg loss: 1.260891

Epoch 5
-------------------------------
loss: 1.337803  [   64/60000]
loss: 1.313278  [ 6464/60000]
loss: 1.151837  [12864/60000]
loss: 1.252142  [19264/60000]
loss: 1.123048  [25664/60000]
loss: 1.159531  [32064/60000]
loss: 1.175011  [38464/60000]
loss: 1.115554  [44864/60000]
loss: 1.160974  [51264/60000]
loss: 1.062730  [57664/60000]
Test Error:
 Accuracy: 64.6%, Avg loss: 1.087374

Done!

Read more about training your model.

Saving Models

A common way to save a model is to serialize the internal state dictionary (containing the model parameters).

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")
Saved PyTorch Model State to model.pth
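
If the goal is to resume training later rather than only run inference, it is common to save the optimizer state alongside the model's; a hedged sketch (the dictionary keys and the checkpoint.pth filename are arbitrary choices):

# Sketch: save a full training checkpoint (keys and filename are arbitrary choices).
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "epochs": epochs,
}
torch.save(checkpoint, "checkpoint.pth")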

Loading Models

The process for loading a model includes re-creating the model structure and loading the state dictionary into it.

model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth"))
<All keys matched successfully>
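
Note that recent PyTorch versions accept (and recommend for untrusted files) a weights_only flag when loading; a hedged variant of the call above:

# Variant: restrict unpickling to plain tensors (supported on recent PyTorch versions).
model.load_state_dict(torch.load("model.pth", weights_only=True))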

This model can now be used to make predictions.

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')
Predicted: "Ankle boot", Actual: "Ankle boot"

Read more about saving and loading your models.

Link to the original article