MNIST手写字体识别实例
1.步骤概要
-
读取数据
-
创建网络
-
训练参数得到 [w1,b1,w2,b2,w3,b3](三层)
-
准确度测试
2.具体实现
2.1 工具包
对图片进行可视化展示以及onehot编码
1 | import torch |
2.2导入包
1 | from torch import nn #网络模型 |
2.3 下载并读取数据
1 | #每个批次多少张照片 |
对训练数据进行可视化展示,查看图片
1 | x,y=next(iter(train_loader)) |
1 | output: |
2.4 创建网络模型
1 | class Net(nn.Module): |
调用模型进行训练迭代
1 | net=Net() |
2.5 准确度测试
1 | #4.准确度测试 |
1 | output: |
2.6训练集与测试集数据展示
1 | #训练集 |
1 | output: |