Hands-on: A Concise PyTorch Image Classification Tutorial | Full Code Included

Author | 小宋是呢
Reposted from CSDN Blog
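1. Introduction
Image classification is one of the most common tasks in deep learning competitions, and also one where a top ranking is hard to get: the problem has been studied so thoroughly that open-source networks reach high scores with little effort. If you cannot yet train an open-source network and then tune it step by step, a good result is hard to come by.
In "[PyTorch小试牛刀]实战六·准备自己的数据集用于训练" (PyTorch hands-on, part 6: preparing your own dataset for training) we covered how to build your own dataset. This tutorial builds on that post and moves on to training and applying a model.
(Link to part 6:
https://blog.csdn.net/xiaosongshine/article/details/85225873)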
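2. The data
Dataset download:
https://download.csdn.net/download/xiaosongshine/11128410
This tutorial uses a traffic-sign dataset with 62 classes. The training set has 4572 images (about 70 per class) and the test set has 2520 images (about 40 per class). The data is split into two subdirectories, train and test.
Why do we need a separate test set? It is never used for training; it serves to evaluate and tune the model.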

The train and test folders each contain 62 subfolders, one per class, and all images of a class sit in the same subfolder.
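torchvision's ImageFolder, which loads this layout later in the tutorial, assigns integer labels to the class folders in sorted name order. The following is a minimal pure-Python sketch of that mapping; the folder names are hypothetical stand-ins, not taken from the dataset.

```python
import os
import tempfile

def class_to_index(split_dir):
    """Map each class subfolder to an integer label, in sorted name order."""
    classes = sorted(
        entry for entry in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, entry))
    )
    return {name: idx for idx, name in enumerate(classes)}

# Build a tiny stand-in for traffic-sign/train (hypothetical folder names).
root = tempfile.mkdtemp()
for name in ["00000", "00001", "00010", "00002"]:
    os.makedirs(os.path.join(root, "train", name))

mapping = class_to_index(os.path.join(root, "train"))
print(mapping)
```

Because the names are sorted as strings, zero-padded folder names keep the labels in numeric order.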

Opening one of these folders shows the images it contains.

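Each image is 100*100*3 in size. The two 100s are the height and width of the image; the 3 is the number of color channels (RGB). A color image is stored in the computer as a three-dimensional array, and these arrays are what we feed into the network.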
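To make the array layout concrete, here is a small sketch using plain nested lists. It builds a hypothetical 2x2 RGB image in height x width x channel order and rearranges it into the channel x height x width layout, scaled to [0, 1], that torchvision's ToTensor produces for 8-bit images.

```python
# A tiny 2x2 RGB "image" in height x width x channel layout, values 0..255.
# (The real images are 100x100x3; the small size keeps the printout readable.)
image_hwc = [
    [[255, 0, 0], [0, 255, 0]],
    [[0, 0, 255], [255, 255, 255]],
]

height = len(image_hwc)
width = len(image_hwc[0])
channels = len(image_hwc[0][0])

# Rearrange to channel x height x width and scale to [0, 1].
image_chw = [
    [[image_hwc[r][c][ch] / 255 for c in range(width)] for r in range(height)]
    for ch in range(channels)
]

print((height, width, channels))
print(image_chw[0])    # the red channel as a 2x2 grid
print(100 * 100 * 3)   # number of values stored for one 100x100 RGB image
```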
3. Building the network
1. Import the Python packages and define some parameters
import torch as t
import torchvision as tv
import os
import time
import numpy as np
from tqdm import tqdm


class DefaultConfigs(object):

    data_dir = "./traffic-sign/"      # dataset root with train/ and test/ inside
    data_list = ["train","test"]

    lr = 0.001
    epochs = 10
    num_classes = 62
    image_size = 224
    batch_size = 40
    channels = 3
    gpu = "0"
    train_len = 4572
    test_len = 2520
    use_gpu = t.cuda.is_available()   # use the GPU whenever one is available

config = DefaultConfigs()
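These settings determine how many batches each epoch contains. A quick sketch of the arithmetic, assuming the DataLoader keeps the final partial batch (its default behavior):

```python
import math

train_len, test_len, batch_size = 4572, 2520, 40

def batches_per_epoch(num_samples, batch_size):
    """Number of batches when the last, possibly smaller, batch is kept."""
    return math.ceil(num_samples / batch_size)

def last_batch_size(num_samples, batch_size):
    """Size of the final batch in an epoch."""
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size

train_batches = batches_per_epoch(train_len, batch_size)
test_batches = batches_per_epoch(test_len, batch_size)
print(train_batches, last_batch_size(train_len, batch_size))
print(test_batches, last_batch_size(test_len, batch_size))
```

The training set therefore yields 115 batches per epoch, the last holding only 12 images, while the test set divides evenly into 63 full batches.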
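2. Prepare the data using the loading utilities PyTorch provides
Note: training data is normally cropped at random for augmentation, while test data should not be cropped. As written, the training pipeline resizes each image to 224*224 and then center-crops at that same size, so the crop has no effect; swap in a random crop if you want real augmentation.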
# Per-channel mean and std of ImageNet, matching the pretrained weights
normalize = tv.transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                    std = [0.229, 0.224, 0.225]
                                    )

transform = {
    config.data_list[0]:tv.transforms.Compose(
        [tv.transforms.Resize([224,224]),tv.transforms.CenterCrop([224,224]),
        tv.transforms.ToTensor(),normalize]  # tv.transforms.Resize sets the image size
    ) ,
    config.data_list[1]:tv.transforms.Compose(
        [tv.transforms.Resize([224,224]),tv.transforms.ToTensor(),normalize]
    ) 
}

# One ImageFolder dataset per split; labels come from the subfolder names
datasets = {
    x:tv.datasets.ImageFolder(root = os.path.join(config.data_dir,x),transform=transform[x])
    for x in config.data_list
}

dataloader = {
    x:t.utils.data.DataLoader(dataset= datasets[x],
        batch_size=config.batch_size,
        shuffle=True
    ) 
    for x in config.data_list
}
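The combined effect of ToTensor and Normalize on one pixel can be worked out by hand. The sketch below reproduces that arithmetic in plain Python for illustration; normalize_pixel is a hypothetical helper, not part of torchvision.

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Take an 8-bit (R, G, B) pixel and return the values the network sees."""
    scaled = [v / 255 for v in rgb]                                # ToTensor: 0..255 -> 0..1
    return [(s - m) / sd for s, m, sd in zip(scaled, MEAN, STD)]  # Normalize: (x - mean) / std

white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print([round(v, 3) for v in white])
print([round(v, 3) for v in black])
```

After normalization the inputs span roughly -2.1 to 2.6 instead of 0 to 255, which is the range the ImageNet-pretrained weights were trained on.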
3. Build the model (transfer learning with resnet18; the only trained parameters are those of the final fully connected layer, t.nn.Linear(512,num_classes))
def get_model(num_classes):

    # Load resnet18 with ImageNet-pretrained weights
    model = tv.models.resnet18(pretrained=True)
    # Freeze every pretrained parameter
    for param in model.parameters():
        param.requires_grad = False
    # Replace the classifier head; the new layers are trainable by default
    model.fc = t.nn.Sequential(
        t.nn.Dropout(p=0.3),
        t.nn.Linear(512,num_classes)
    )
    return model
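If your hardware can handle it, comment out the two lines below in get_model to train the whole network: final accuracy is likely to improve, but training will be slower. In that case the optimizer in the training code must also be given model.parameters() instead of model.fc.parameters(); otherwise only the final layer is updated.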
for param in model.parameters():
    param.requires_grad = False
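To get a feel for how little is trained when the backbone stays frozen, the size of the new head can be computed directly. This is a rough sketch: the backbone figure is an assumption, taken from the commonly reported 11,689,512 parameters of torchvision's resnet18 minus the 513,000 of its original 1000-class head.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameters in a fully connected layer: weight matrix plus optional biases."""
    return in_features * out_features + (out_features if bias else 0)

num_classes = 62
head_params = linear_params(512, num_classes)   # Dropout adds no parameters

# Assumption: commonly reported resnet18 parameter count without its fc layer.
BACKBONE_PARAMS = 11_176_512

trainable_share = 100 * head_params / (BACKBONE_PARAMS + head_params)
print(head_params)
print(round(trainable_share, 2))   # percent of all parameters that are trained
```

Under that assumption, well under one percent of the network is updated, which is why training only the head is so much cheaper than training everything.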
Model printout:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)
  (fc): Sequential(
    (0): Dropout(p=0.3)
    (1): Linear(in_features=512, out_features=62, bias=True)
  )
)
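The kernel, stride, and padding values in this printout are enough to trace how the spatial size of the feature map shrinks from layer to layer. The sketch below applies the standard output-size formula for convolution and pooling layers; the stage table is read off the printout, and the helper names are made up for illustration.

```python
def conv_out(size, kernel, stride, padding):
    """Output width of a conv or pool layer: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

# (name, kernel, stride, padding) of the layer that sets each stage's resolution.
# The other 3x3, stride-1, padding-1 convolutions leave the size unchanged.
STAGES = [
    ("conv1", 7, 2, 3),
    ("maxpool", 3, 2, 1),
    ("layer1", 3, 1, 1),
    ("layer2", 3, 2, 1),
    ("layer3", 3, 2, 1),
    ("layer4", 3, 2, 1),
]

def trace_sizes(size):
    """Return the feature-map width at the input and after every stage."""
    sizes = [size]
    for _name, kernel, stride, padding in STAGES:
        size = conv_out(size, kernel, stride, padding)
        sizes.append(size)
    return sizes

sizes_224 = trace_sizes(224)
sizes_100 = trace_sizes(100)
print(sizes_224)
print(sizes_100)
print(conv_out(sizes_224[-1], 7, 1, 0))  # avgpool applied to the final map
```

With 224-pixel inputs the last stage yields a 7x7 map, which the printed AvgPool2d(kernel_size=7) reduces to 1x1, leaving 512 features for the Linear layer. A raw 100-pixel input would reach only 4x4 at that point, too small for a 7x7 pooling window, which is consistent with the transforms resizing every image to 224.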
4. Train the model (uses the GPU automatically when one is available)
def train(epochs):

    model = get_model(config.num_classes)
    print(model)
    loss_f = t.nn.CrossEntropyLoss()
    if(config.use_gpu):
        model = model.cuda()
        loss_f = loss_f.cuda()

    # Only the new fc head is optimized; the backbone stays frozen
    opt = t.optim.Adam(model.fc.parameters(),lr = config.lr)
    time_start = time.time()

    for epoch in range(epochs):
        train_loss = []
        train_acc = []
        test_loss = []
        test_acc = []
        model.train(True)
        print("Epoch {}/{}".format(epoch+1,epochs))
        for batch, datas in tqdm(enumerate(dataloader["train"])):
            x,y = datas
            if (config.use_gpu):
                x,y = x.cuda(),y.cuda()
            y_ = model(x)
            # Predicted class = index of the largest logit
            _, pre_y_ = t.max(y_,1)
            loss = loss_f(y_,y)
            acc = t.sum(pre_y_ == y)

            opt.zero_grad()
            loss.backward()
            opt.step()
            # .item() returns a plain Python number, on CPU or GPU
            train_loss.append(loss.item())
            train_acc.append(acc.item())
        time_end = time.time()
        # Both the mean loss and the mean correct count are divided by batch_size
        print("Batch {}, Train loss:{:.4f}, Train acc:{:.4f}, Time: {}"
            .format(batch+1,np.mean(train_loss)/config.batch_size,np.mean(train_acc)/config.batch_size,(time_end-time_start)))
        time_start = time.time()

        # Evaluation: disable dropout and gradient tracking
        model.train(False)
        with t.no_grad():
            for batch, datas in tqdm(enumerate(dataloader["test"])):
                x,y = datas
                if (config.use_gpu):
                    x,y = x.cuda(),y.cuda()
                y_ = model(x)
                _, pre_y_ = t.max(y_,1)
                loss = loss_f(y_,y)
                acc = t.sum(pre_y_ == y)

                test_loss.append(loss.item())
                test_acc.append(acc.item())
        time_end = time.time()
        print("Batch {}, Test loss:{:.4f}, Test acc:{:.4f}, Time :{}".format(batch+1,np.mean(test_loss)/config.batch_size,np.mean(test_acc)/config.batch_size,(time_end-time_start)))
        time_start = time.time()

        # Save the whole model after every epoch
        t.save(model,str(epoch+1)+"ttmodel.pkl")



if __name__ == "__main__":
    train(config.epochs)
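One detail of the accuracy printed by this loop is worth noting: it divides the mean number of correct predictions per batch by batch_size, so a final batch smaller than 40 pulls the figure down slightly. A small sketch of the difference, using a hypothetical epoch in which every training image is classified correctly:

```python
def logged_accuracy(correct_per_batch, batch_size):
    """Metric as computed in the training loop: mean correct count / batch_size."""
    return sum(correct_per_batch) / len(correct_per_batch) / batch_size

def exact_accuracy(correct_per_batch, num_samples):
    """Fraction of all samples that were classified correctly."""
    return sum(correct_per_batch) / num_samples

# Hypothetical perfect epoch: 114 full batches of 40 plus a final batch of 12
# (4572 training images in total).
correct = [40] * 114 + [12]
logged = logged_accuracy(correct, 40)
exact = exact_accuracy(correct, 4572)
print(round(logged, 4))
print(exact)
```

The logged training accuracy tops out just below 1.0 for this reason. The test set splits into 63 full batches, so its logged accuracy is not affected.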
Training results:
Epoch 1/10
115it [00:48,  2.63it/s]
Batch 115, Train loss:0.0590, Train acc:0.4635, Time: 48.985504150390625
63it [00:24,  2.62it/s]
Batch 63, Test loss:0.0374, Test acc:0.6790, Time :24.648272275924683
Epoch 2/10
115it [00:45,  3.22it/s]
Batch 115, Train loss:0.0271, Train acc:0.7576, Time: 45.68823838233948
63it [00:23,  2.62it/s]
Batch 63, Test loss:0.0255, Test acc:0.7524, Time :23.271782875061035
Epoch 3/10
115it [00:45,  3.19it/s]
Batch 115, Train loss:0.0181, Train acc:0.8300, Time: 45.92648506164551
63it [00:23,  2.60it/s]
Batch 63, Test loss:0.0212, Test acc:0.7861, Time :23.80789279937744
Epoch 4/10
115it [00:45,  3.28it/s]
Batch 115, Train loss:0.0138, Train acc:0.8767, Time: 45.27525019645691
63it [00:23,  2.57it/s]
Batch 63, Test loss:0.0173, Test acc:0.8385, Time :23.736321449279785
Epoch 5/10
115it [00:44,  3.22it/s]
Batch 115, Train loss:0.0112, Train acc:0.8950, Time: 44.983638286590576
63it [00:22,  2.69it/s]
Batch 63, Test loss:0.0156, Test acc:0.8520, Time :22.790074348449707
Epoch 6/10
115it [00:44,  3.19it/s]
Batch 115, Train loss:0.0095, Train acc:0.9159, Time: 45.10426950454712
63it [00:22,  2.77it/s]
Batch 63, Test loss:0.0158, Test acc:0.8214, Time :22.80412459373474
Epoch 7/10
115it [00:45,  2.95it/s]
Batch 115, Train loss:0.0081, Train acc:0.9280, Time: 45.30439043045044
63it [00:23,  2.66it/s]
Batch 63, Test loss:0.0139, Test acc:0.8528, Time :23.122379541397095
Epoch 8/10
115it [00:44,  3.23it/s]
Batch 115, Train loss:0.0073, Train acc:0.9300, Time: 44.304762840270996
63it [00:22,  2.74it/s]
Batch 63, Test loss:0.0142, Test acc:0.8496, Time :22.801835536956787
Epoch 9/10
115it [00:43,  3.19it/s]
Batch 115, Train loss:0.0068, Train acc:0.9361, Time: 44.08414030075073
63it [00:23,  2.44it/s]
Batch 63, Test loss:0.0142, Test acc:0.8437, Time :23.604419231414795
Epoch 10/10
115it [00:46,  3.12it/s]
Batch 115, Train loss:0.0063, Train acc:0.9337, Time: 46.76597046852112
63it [00:24,  2.65it/s]
Batch 63, Test loss:0.0130, Test acc:0.8591, Time :24.64351773262024
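After 10 epochs, test accuracy reaches about 0.86, which is already a decent result. Adjusting the hyperparameters and training for longer can push the accuracy higher.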
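The loss values in the log look small because the training code divides the mean batch loss by batch_size, even though CrossEntropyLoss already averages over each batch by default. A short sketch that undoes the extra division and compares the result with the loss of a uniform guess over 62 classes; the two log values are copied from epochs 1 and 10.

```python
import math

batch_size = 40
num_classes = 62

# Test losses as printed in the log for epochs 1 and 10.
logged_test_loss = {1: 0.0374, 10: 0.0130}

# Undo the extra division by batch_size to recover the mean cross-entropy.
mean_ce = {epoch: round(v * batch_size, 3) for epoch, v in logged_test_loss.items()}

# Cross-entropy of a uniform guess over all classes, for reference.
chance_ce = math.log(num_classes)

print(mean_ce)
print(round(chance_ce, 3))
```

Read this way, the mean test cross-entropy falls from roughly 1.5 to roughly 0.5 over the ten epochs, against a chance level of about 4.1.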
Original article:
https://blog.csdn.net/xiaosongshine/article/details/89409223
(*This article was reposted by AI科技大本营; please contact the original author before reposting it.)