大家好,欢迎来到IT知识分享网。
前言
为了区分深度学习中这两者的定义,详细讲解其关系以及代码
1. 定义
在 PyTorch 中,“epoch”(周期)和 “batch size”(批大小)是训练神经网络时的两个重要概念
它们用于控制训练的迭代和数据处理方式。
一、Epoch(周期):
- Epoch 是指整个训练数据集被神经网络完整地遍历一次的次数。
- 在每个 epoch 中,模型会一次又一次地使用数据集中的不同样本进行训练,以更新模型的权重。
- 通常,一个 epoch 包含多个迭代(iterations),每个迭代是一次权重更新的过程。
- 训练多个 epoch 的目的是让模型不断地学习,提高性能,直到收敛到最佳性能或达到停止条件。
二、Batch Size(批大小):
- Batch size 指的是每次模型权重更新时所使用的样本数。
- 通过将训练数据分成小批次,可以实现并行计算,提高训练效率。
- 较大的 batch size 可能会加速训练,但可能需要更多内存和计算资源。较小的 batch size 可能更适合小型数据集或资源受限的情况。
- 常见的 batch size 值通常是 32、64、128 等。
三、如何理解它们的关系:
- 在训练过程中,每个 epoch 包含多个 batch,而 batch size 决定了每个 batch 中包含多少样本。
- 在每个 epoch 开始时,数据集会被随机划分为多个 batch,然后模型使用这些 batch 逐一进行前向传播和反向传播,从而更新权重。
- 一次 epoch 完成后,数据集会被重新随机划分为新的 batch,这个过程会重复多次,直到完成指定数量的 epoch 或达到停止条件。
200个样本(数据行)的数据集
选择batch为5,epoch为1000
代表
划分为40个batch了,也就是每个batch有5个样本
每个batch训练好之后就会更新下
一个epoch 也就是40个batch,刚好200个样本训练一轮
总之,epoch 控制了整个训练的迭代次数,而 batch size 决定了每次迭代中处理的样本数量。这两个参数的选择取决于你的任务和资源,通常需要进行调优以获得最佳性能。
2. 代码
大致深度学习的代码中如下:
import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset # 创建一个包含数字1到10的数据集 X_train = torch.arange(1, 11, dtype=torch.float32) y_train = X_train * 2 # 假设我们的任务是学习一个简单的线性关系,y = 2x # 转换数据为 PyTorch 张量 X_train = X_train.view(-1, 1) # 将数据转换为列向量 y_train = y_train.view(-1, 1) # 定义神经网络模型 model = nn.Sequential( nn.Linear(1, 1) ) # 定义损失函数和优化器 criterion = nn.MSELoss() optimizer = optim.SGD(model.parameters(), lr=0.01) # 创建 DataLoader 并指定 batch size batch_size = 3 train_dataset = TensorDataset(X_train, y_train) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # 训练循环 num_epochs = 10 for epoch in range(num_epochs): total_loss = 0.0 for i, (inputs, labels) in enumerate(train_loader): optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() total_loss += loss.item() print("inputs:",inputs.numpy()) average_loss = total_loss / len(train_loader) print(f"Epoch {
epoch + 1}/{
num_epochs}, Loss: {
average_loss:.4f}")
执行完的结果截图:
大致结果详细如下:
inputs: [[1.] [8.] [7.]] inputs: [[4.] [3.] [6.]] inputs: [[ 5.] [ 9.] [10.]] inputs: [[2.]] Epoch 1/10, Loss: 39.6693 inputs: [[ 1.] [ 2.] [10.]] inputs: [[9.] [8.] [6.]] inputs: [[5.] [3.] [7.]] inputs: [[4.]] Epoch 2/10, Loss: 0.1154 inputs: [[2.] [1.] [9.]] inputs: [[10.] [ 5.] [ 4.]] inputs: [[6.] [8.] [7.]] inputs: [[3.]] Epoch 3/10, Loss: 0.0317 inputs: [[7.] [9.] [1.]] inputs: [[6.] [3.] [4.]] inputs: [[10.] [ 8.] [ 5.]] inputs: [[2.]] Epoch 4/10, Loss: 0.0414 inputs: [[9.] [6.] [4.]] inputs: [[2.] [3.] [1.]] inputs: [[ 8.] [10.] [ 5.]] inputs: [[7.]] Epoch 5/10, Loss: 0.0260 inputs: [[6.] [3.] [4.]] inputs: [[ 5.] [10.] [ 8.]] inputs: [[2.] [7.] [9.]] inputs: [[1.]] Epoch 6/10, Loss: 0.0386 inputs: [[ 6.] [10.] [ 4.]] inputs: [[5.] [7.] [8.]] inputs: [[1.] [9.] [2.]] inputs: [[3.]] Epoch 7/10, Loss: 0.0254 inputs: [[6.] [8.] [2.]] inputs: [[ 3.] [10.] [ 1.]] inputs: [[9.] [4.] [5.]] inputs: [[7.]] Epoch 8/10, Loss: 0.0197 inputs: [[ 2.] [ 3.] [10.]] inputs: [[9.] [4.] [5.]] inputs: [[8.] [1.] [6.]] inputs: [[7.]] Epoch 9/10, Loss: 0.0179 inputs: [[ 7.] [ 9.] [10.]] inputs: [[3.] [2.] [5.]] inputs: [[4.] [1.] [8.]] inputs: [[6.]] Epoch 10/10, Loss: 0.0216
这说明一个epoch会把整个数据都训练完
免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://haidsoft.com/143763.html