PyTorch和torchvision为例,怎样利用预训练的ResNet模子来训练水稻虫害分类数据集 14类 从数据准备到模子训练、评估全流程

[复制链接]
发表于 2025-5-7 18:07:40 | 显示全部楼层 |阅读模式
PyTorch和torchvision为例,怎样利用预训练的ResNet模子来训练水稻虫害分类数据集 14类 从数据准备到模子训练、评估全流程

  
水稻虫害分类数据集
包含14个类别共8417张图像:稻纵卷叶螟(rice leaf roller)、稻叶毛虫(rice leaf caterpillar)、稻潜叶蝇(paddy stem maggot)、水稻二化螟(asiatic rice borer)、水稻三化螟(yellow rice borer)、稻瘿蚊(rice gall midge)、稻秆蝇(Rice Stemfly)
褐稻虱(brown plant hopper)、白背飞虱(white backed plant)、灰飞虱(small brown plant)、稻水象甲(rice water weevil)、稻叶蝉(rice leafhopper)、谷物撒布机蓟马(grain spreader thrips)、稻苞虫(rice shell pest),训练集、验证集、测试集分别有5043、843、2531张

利用得当图像分类使命的模子举行处理。对于图像分类使命,可以利用ResNet、EfficientNet等深度学习模子。这里以PyTorch和torchvision为例,展示怎样利用预训练的ResNet模子来训练这个水稻虫害分类数据集。

1. 环境准备

确保安装了须要的依赖项:
  1. pip install torch torchvision torchaudio matplotlib
复制代码
2. 数据准备

首先,确保您的数据集按照以下结构组织:
  1. path/to/dataset/
  2.     train/
  3.         rice_leaf_roller/
  4.             img1.jpg
  5.             img2.jpg
  6.             ...
  7.         rice_leaf_caterpillar/
  8.             img1.jpg
  9.             img2.jpg
  10.             ...
  11.         ...
  12.     val/
  13.         rice_leaf_roller/
  14.             img1.jpg
  15.             img2.jpg
  16.             ...
  17.         rice_leaf_caterpillar/
  18.             img1.jpg
  19.             img2.jpg
  20.             ...
  21.         ...
  22.     test/  # 如果有测试集的话
  23.         rice_leaf_roller/
  24.             img1.jpg
  25.             img2.jpg
  26.             ...
  27.         rice_leaf_caterpillar/
  28.             img1.jpg
  29.             img2.jpg
  30.             ...
  31.         ...
复制代码

3. 数据加载与加强

利用torchvision.datasets.ImageFolder来加载数据,并应用一些数据加强技术。
  1. from torchvision import datasets, transforms
  2. from torch.utils.data import DataLoader
  3. # 定义数据预处理流程
  4. data_transforms = {
  5.     'train': transforms.Compose([
  6.         transforms.RandomResizedCrop(224),
  7.         transforms.RandomHorizontalFlip(),
  8.         transforms.ToTensor(),
  9.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  10.     ]),
  11.     'val': transforms.Compose([
  12.         transforms.Resize(256),
  13.         transforms.CenterCrop(224),
  14.         transforms.ToTensor(),
  15.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  16.     ]),
  17. }
  18. data_dir = './path/to/dataset'
  19. image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
  20. dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4) for x in ['train', 'val']}
  21. dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
  22. class_names = image_datasets['train'].classes
复制代码
4. 模子界说与训练

利用预训练的ResNet模子并举行微调。
  1. import torch.nn as nn
  2. import torch.optim as optim
  3. from torchvision import models
  4. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  5. # 加载预训练的ResNet模型
  6. model_ft = models.resnet18(pretrained=True)
  7. num_ftrs = model_ft.fc.in_features
  8. # 更改最后全连接层的输出为类别数
  9. model_ft.fc = nn.Linear(num_ftrs, len(class_names))
  10. model_ft = model_ft.to(device)
  11. criterion = nn.CrossEntropyLoss()
  12. # 观察所有参数都更新
  13. optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
  14. # 每7个epoch后降低学习率
  15. exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
复制代码
训练模子

  1. def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
  2.     best_model_wts = model.state_dict()
  3.     best_acc = 0.0
  4.     for epoch in range(num_epochs):
  5.         print(f'Epoch {epoch}/{num_epochs - 1}')
  6.         print('-' * 10)
  7.         # 每个epoch都有训练和验证阶段
  8.         for phase in ['train', 'val']:
  9.             if phase == 'train':
  10.                 model.train()  # 设置模型为训练模式
  11.             else:
  12.                 model.eval()   # 设置模型为评估模式
  13.             running_loss = 0.0
  14.             running_corrects = 0
  15.             # 迭代数据
  16.             for inputs, labels in dataloaders[phase]:
  17.                 inputs = inputs.to(device)
  18.                 labels = labels.to(device)
  19.                 # 零参数梯度
  20.                 optimizer.zero_grad()
  21.                 # 前向传播
  22.                 with torch.set_grad_enabled(phase == 'train'):
  23.                     outputs = model(inputs)
  24.                     _, preds = torch.max(outputs, 1)
  25.                     loss = criterion(outputs, labels)
  26.                     # 只在训练阶段反向传播和优化
  27.                     if phase == 'train':
  28.                         loss.backward()
  29.                         optimizer.step()
  30.                 # 统计损失
  31.                 running_loss += loss.item() * inputs.size(0)
  32.                 running_corrects += torch.sum(preds == labels.data)
  33.             if phase == 'train':
  34.                 scheduler.step()
  35.             epoch_loss = running_loss / dataset_sizes[phase]
  36.             epoch_acc = running_corrects.double() / dataset_sizes[phase]
  37.             print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
  38.             # 深拷贝模型
  39.             if phase == 'val' and epoch_acc > best_acc:
  40.                 best_acc = epoch_acc
  41.                 best_model_wts = model.state_dict()
  42.     print(f'Best val Acc: {best_acc:4f}')
  43.     # 加载最佳模型权重
  44.     model.load_state_dict(best_model_wts)
  45.     return model
  46. # 训练并保存最佳模型
  47. model = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)
  48. torch.save(model.state_dict(), './rice_pest_classification_resnet.pth')
复制代码
5. 测试模子(可选)

假如有一个单独的测试集,可以利用相似的方法来评估模子性能
  1. test_dataset = datasets.ImageFolder(os.path.join(data_dir, 'test'), data_transforms['val'])
  2. test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
  3. model.eval()
  4. running_corrects = 0
  5. for inputs, labels in test_dataloader:
  6.     inputs = inputs.to(device)
  7.     labels = labels.to(device)
  8.     with torch.no_grad():
  9.         outputs = model(inputs)
  10.         _, preds = torch.max(outputs, 1)
  11.    
  12.     running_corrects += torch.sum(preds == labels.data)
  13. accuracy = running_corrects.double() / len(test_dataset)
  14. print(f'Test Accuracy: {accuracy:.4f}')
复制代码
以上步骤提供了一个完备的流程,从环境配置到数据准备、模子训练及评估的具体实现。确保你根据实际环境调整路径和其他设置。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

×
回复

使用道具 举报

×
登录参与点评抽奖,加入IT实名职场社区
去登录
快速回复 返回顶部 返回列表