回顾神经网络分类任务的整体流程

更新时间:2023-05-22 02:39:51 阅读: 评论:0

结婚男方父亲讲话致辞-格造句

回顾神经网络分类任务的整体流程
2023年5月22日发(作者:变的成语)

回顾神经⽹络分类任务的整体流程

下⾯构建了三个线性叠加的模型,⽆论它怎样嵌套,仍然是⼀个线性模型,对于⼀个⼿写数字来

pred=w[w[wx+b]+b]+b

321123

讲,⼈脑之所以能在复杂噪声的条件下把⼿写数字识别出来,是因为⼈脑具有很强的⾮线性表达能⼒,⽽对于线性模型来说,它很难完成像

⼿写数字识别任务这种现实⽣活中遇到的简单问题。如何解决这个问题?就是在每⼀个线性函数之后添加⼀个⾮线性的部分——

Relu,Relu有⼀个更好特性就是避免梯度离散。除了Relu之外,还有sigmoid等激活函数。早期模型都是使⽤sigmoid把加权和映射[0,1]

区间内,但现在的神经⽹络基本上都不使⽤sigmoid,实际上sigmoid并没有让训练结果变得更好,或者在某种程度上使模型很难训练,⽽

Relu在特别深层次的神经⽹络上效果特别好和更好训练。

# -*- coding: utf-8 -*-

import torch

from torch import nn #

神经⽹络库

from torch.nn import functional as F #

常⽤函数

from torch import optim #

优化⼯具包

import torchvision #

视觉⼯具包

from matplotlib import pyplot as plt #

数据可⽰化⼯具包

from utils import plot_image, plot_curve, one_hot

batch_size = 512

step1. load datat

Normalize 零—均值规范化也叫标准差标准化,mean:0.1307,std:0.3081,其转化公式s = (x - mean)/std,特征标准化不会改变

特征取值分布,只是为了保证参数变量的取值范围具有相似的尺度,以帮助梯度下降算法收敛更快。

shuffle 将数据集随机打乱

train_loader = torch.utils.data.DataLoader(

torchvision.datats.MNIST('mnist_data', train=True, download=True,

transform=torchvision.transforms.Compo([

torchvision.transforms.ToTensor(),

torchvision.transforms.Normalize(

(0.1307,), (0.3081,))

])),

batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(

torchvision.datats.MNIST('mnist_data/', train=Fal, download=True,

transform=torchvision.transforms.Compo([

torchvision.transforms.ToTensor(),

torchvision.transforms.Normalize(

(0.1307,), (0.3081,))

])),

batch_size=batch_size, shuffle=Fal)

x, y = next(iter(train_loader))

print(x.shape, y.shape, x.min(), x.max())

打印结果

([512, 1, 28, 28]) ([512]) tensor(-0.4242) tensor(2.8215)

plot_image(x, y, 'image sample')

打印结果

step2. Build Model

class Net(nn.Module):

def __init__(lf):

super(Net, lf).__init__()

# xw+b , 256,6428*28100-9

其中的数值都是由经验决定的,输⼊的维度,是⼀个分类值

lf.fc1 = nn.Linear(28*28, 256)

lf.fc2 = nn.Linear(256, 64)

lf.fc3 = nn.Linear(64, 10)

def forward(lf, x):

# x: [b, 1, 28, 28]

# h1 = relu(xw1+b1)

x = F.relu(lf.fc1(x))

# h2 = relu(h1w2+b2)

x = F.relu(lf.fc2(x))

# h3 = h2w3+b3

x = lf.fc3(x)

return x

net = Net()

# [w1, b1, w2, b2, w3, b3]optimizer

是⼀个优化器,更新参数值

optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

step3. Train

train_loss = []

for epoch in range(60):

for batch_idx, (x, y) in enumerate(train_loader):

# x: [b, 1, 28, 28], y: [512]

# [b, 1, 28, 28] => [b, 784]bbatchsize28*28 => 784 x_i

,其中可以看作是的样本数据

x = x.view(x.size(0), 28*28)

# => [b, 10]

out = net(x)

# [b, 10]

y_onehot = one_hot(y)

# loss = m(out, y_onehot)

loss = F.m_loss(out, y_onehot) #

获得代价函数的初始值

optimizer.zero_grad() # BP

之前⾸先将梯度清零,以保证每次更新的负梯度值是最新的。

loss.backward() #

计算出梯度信息

# w' = w - lr*grad

optimizer.step() #

更新参数信息

train_loss.append(loss.item()) #

保存当前参数信息

if batch_idx % 10 == 0:

print(epoch, batch_idx, loss.item()) # mini-batch

每训练完⼀个就显⽰当前训练模型的参数状态

plot_curve(train_loss) #

模型训练完毕,显⽰代价函数曲线收敛的⾛势

# we get optimal [w1, b1, w2, b2, w3, b3] # loss

模型训练完之后会得到这⼀组最优参数解,使得值全局最⼩。

打印结果

这⾥的loss值不是⽤来衡量模型的性能指标,只是⽤来辅助我们更好地训练模型,衡量模型的性能指标有很多种⽅法,最终

衡量模型的指标是它的准确度。

下⾯使⽤测试集对模型进⾏准确度测试。

step4. Test

total_correct = 0

for x,y in test_loader:

x = x.view(x.size(0), 28*28)

out = net(x) # x_i

输⼊测试样本数据,预测出概率模型

'''

out: [b, 10] => pred: [b] , ⽐如输出标签对应的预测概率为[0.1,0.9,0.01,......,0.08],∑P(y|x) = 1

argmax获得预测概率最⼤元素所在的索引号,max=0.9,argmax(out)=[0,1,0,......,0]

从⽽获得one-hot的预测编码

若预测概率是out = [0.01,0.02,0.03,0.705,...,0.09], argmax(out) = [0,0,0,3,0,0,0,0,0,0]

'''

pred = out.argmax(dim=1)

correct = pred.eq(y).sum().float().item()

total_correct += correct

total_num = len(test_t)

acc = total_correct / total_num

print('test acc:', acc)

打印结果

test acc: 0.9684

x, y = next(iter(test_loader))

out = net(((0), 28*28))

pred = (dim=1)

plot_image(x, pred, 'test')

打印预测结果

写于2020.03.10 01:13:14

焦点网络-张俪的腿

回顾神经网络分类任务的整体流程

本文发布于:2023-05-22 02:39:50,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/zhishi/a/1684694391174615.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

本文word下载地址:回顾神经网络分类任务的整体流程.doc

本文 PDF 下载地址:回顾神经网络分类任务的整体流程.pdf

标签:网络任务
相关文章
留言与评论(共有 0 条评论)
   
验证码:
Copyright ©2019-2022 Comsenz Inc.Powered by © 实用文体写作网旗下知识大全大全栏目是一个全百科类宝库! 优秀范文|法律文书|专利查询|