神经网络手写数字识别

⚠申明: 未经许可,禁止以任何形式转载,若要引用,请标注链接地址。 全文共计4077字,阅读大概需要3分钟
🌈更多学习内容, 欢迎👏关注👀【文末】我的个人微信公众号:不懂开发的程序猿
个人网站:https://jerry-jy.co/

神经网络手写数字识别

  • 神经网络手写数字识别
    • 一、任务需求
    • 二、任务目标
          • 1、掌握神经网络的构建
          • 2、掌握神经网络的编译
          • 3、掌握神经网络的训练
          • 4、掌握神经网络的概要输出
          • 5、掌握神经网络的模型封装
    • 三、任务环境
          • 1、jupyter开发环境
          • 2、python3.6
          • 3、tensorflow2.4
    • 四、任务实施过程
    • 五、任务小结

神经网络手写数字识别


一、任务需求

MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素)。每个样本都是一张28 * 28像素的灰度手写数字图片,其中每一张图片都代表0~9中的一个数字。

在这里插入图片描述

要求:利用Sequential模型神经网络使用Tensorflow手写数字识别模型构建与输出

二、任务目标

1、掌握神经网络的构建
2、掌握神经网络的编译
3、掌握神经网络的训练
4、掌握神经网络的概要输出
5、掌握神经网络的模型封装

三、任务环境

1、jupyter开发环境
2、python3.6
3、tensorflow2.4

四、任务实施过程

1、导入灰度图数字识别所用到的模块

# 导入所需模块
import tensorflow as tf
from matplotlib import pyplot as plt

2、加载图片识别用到的训练集和测试集

# 导入数据,分别为训练集和测试集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

3、可视化训练集输入特征的第一个元素

plt.imshow(x_train[0], cmap='gray')  # 绘制灰度图
plt.show()

在这里插入图片描述

4、打印出训练集输入特征的第一个元素

# 打印出训练集输入特征的第一个元素
print("x_train[0]:\n", x_train[0])
x_train[0]:
 [[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   3  18  18  18 126 136
  175  26 166 255 247 127   0   0   0   0]
 [  0   0   0   0   0   0   0   0  30  36  94 154 170 253 253 253 253 253
  225 172 253 242 195  64   0   0   0   0]
 [  0   0   0   0   0   0   0  49 238 253 253 253 253 253 253 253 253 251
   93  82  82  56  39   0   0   0   0   0]
 [  0   0   0   0   0   0   0  18 219 253 253 253 253 253 198 182 247 241
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  80 156 107 253 253 205  11   0  43 154
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0  14   1 154 253  90   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0 139 253 190   2   0   0   0
    0   0   0   0   0   0   0   0   0   0]
...
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]]
Output is truncated. View as a scrollable element or open in a text editor. Adjust cell output settings...

5、打印出训练集标签的第一个元素

# 打印出训练集标签的第一个元素
print("y_train[0]:\n", y_train[0])

y_train[0]:
5

6、打印出整个训练集输入特征形状

# 打印出整个训练集输入特征形状
print("x_train.shape:\n", x_train.shape)

x_train.shape:
(60000, 28, 28)

7、打印出整个训练集标签的形状

print("y_train.shape:\n", y_train.shape)

y_train.shape:
(60000,)

8、打印出整个测试集输入特征的形状

print("x_test.shape:\n", x_test.shape)

x_test.shape:
(10000, 28, 28)

9、打印出整个测试集标签的形状

print("y_test.shape:\n", y_test.shape)

y_test.shape:
(10000,)

10、数据集的归一化处理

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

11、Sequential神经网络模型构建,分别为拉直层、全连接层128个神经元,激活函数为relu。10个全连接输出,激活函数为softmax。

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

12、模型编译,优化器为adm,损失函数为带softmax的交叉熵,评价函数为准确率。

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

13、模型训练,喂入的数据为5个批次,每个批次32条数据。校验的数据为训练集。检验的频率为1个批次

model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
Epoch 1/5
1875/1875 [==============================] - 9s 5ms/step - loss: 0.4221 - sparse_categorical_accuracy: 0.8801 - val_loss: 0.1338 - val_sparse_categorical_accuracy: 0.9594
Epoch 2/5
1875/1875 [==============================] - 9s 5ms/step - loss: 0.1173 - sparse_categorical_accuracy: 0.9652 - val_loss: 0.1107 - val_sparse_categorical_accuracy: 0.9669
Epoch 3/5
1875/1875 [==============================] - 9s 5ms/step - loss: 0.0789 - sparse_categorical_accuracy: 0.9762 - val_loss: 0.0804 - val_sparse_categorical_accuracy: 0.9759
Epoch 4/5
1875/1875 [==============================] - 9s 5ms/step - loss: 0.0567 - sparse_categorical_accuracy: 0.9827 - val_loss: 0.0761 - val_sparse_categorical_accuracy: 0.9756
Epoch 5/5
1875/1875 [==============================] - 8s 5ms/step - loss: 0.0441 - sparse_categorical_accuracy: 0.9864 - val_loss: 0.0710 - val_sparse_categorical_accuracy: 0.9777

<tensorflow.python.keras.callbacks.History at 0x7f5445b697b8>

14、模型概要,输出模型的参数信息。

model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_2 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 128)               100480    
_________________________________________________________________
dense_5 (Dense)              (None, 10)                1290      
=================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________

可以看到,模型的参数一共101770个。计算方式为:1282828 +128 + 10*128+10

15、使用类来封装模型

from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Flatten
class MnistModel(Model):
    def __init__(self):
        super(MnistModel, self).__init__()
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.flatten(x)
        x = self.d1(x)
        y = self.d2(x)
        return y

16、创建模型

model = MnistModel()

17、构型编译、训练和模型概要输出方式和上面的方式一样

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

在这里插入图片描述

在这里插入图片描述

五、任务小结

本节我们通过Sequential模型,构建神经网络来实现手写数字的识别任务。首先我们查看了手写数字的训练集和测试集的数量和维度,查看了训练集的第一张图片的样式。同时打印输出了训练集和测试集的第一个元素。其次我们按照神经网络构建的步骤,进行了模型的构建、模型训练、模型概要输出。最后我们对模型进行了封装,也进行了模型构建、训练和概要输出。可以看到使用封装和不使用封装的方式模型概要输出的内容是一样的。通过本任务我们需要掌握使用神经网络的步骤。掌握全连接网络的构建方式。

–end–

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/572472.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

python安装pytorch@FreeBSD

先上结论&#xff0c;最后在conda下安装成功了&#xff01; PyTorch是一个开源的人工智能深度学习框架&#xff0c;由Facebook人工智能研究院&#xff08;FAIR&#xff09;基于Torch库开发并维护。PyTorch提供了一个高效、灵活且易于使用的工具集&#xff0c;用于构建和训练深…

Python-VBA函数之旅-iter函数

目录 一、iter函数的常见应用场景&#xff1a; 二、iter函数使用注意事项&#xff1a; 三、如何用好iter函数&#xff1f; 1、iter函数&#xff1a; 1-1、Python&#xff1a; 1-2、VBA&#xff1a; 2、推荐阅读&#xff1a; 个人主页&#xff1a;神奇夜光杯-CSDN博客 …

AndroidStudio 新建工程的基本修改及事件添加

注&#xff1a;2022.3.1&#xff0c;新建Empty Activity默认是Kotlin&#xff0c;可以选择新建Empty View Activity&#xff0c;修改语言为JAVA 应用名称 修改应用名称 路径&#xff1a;res-values-strings.xml 是否显示应用名称 路径&#xff1a;res-values-themes.xml …

SpringMVC基础篇(一)

文章目录 1.基本介绍1.特点2.SpringMVC跟SpringBoot的关系 2.快速入门1.需求分析2.图解3.环境搭建1.创建普通java工程2.添加web框架支持3.配置lib文件夹1.导入jar包2.Add as Library3.以后自动添加 4.配置tomcat1.配置上下文路径2.配置热加载 5.src下创建Spring配置文件applica…

React.js 3D开发快速入门

如果你对 3D 图形的可能性着迷&#xff0c;但发现从头开始创建 3D 模型的想法是不可能的 - 不用担心&#xff01; Three.js 是一个强大的 JavaScript 库&#xff0c;它可以帮助我们轻松地将现有的 3D 模型集成到 React 应用程序中。因此&#xff0c;在本文中&#xff0c;我将深…

Educational Codeforces Round 164 (Rated for Div. 2) A-E

A. Painting the Ribbon 暴力模拟即可 #include <bits/stdc.h>using namespace std; const int N 2e5 5; typedef long long ll; typedef pair<ll, ll> pll; typedef array<ll, 3> p3; // int mod 998244353; const int maxv 4e6 5; // #define endl &…

ICCV2023人脸识别TransFace论文及代码学习笔记

论文链接&#xff1a;https://arxiv.org/pdf/2308.10133.pdf 代码链接&#xff1a;GitHub - DanJun6737/TransFace: Code of TransFace 背景 尽管ViTs在多种视觉任务中展示了强大的表示能力&#xff0c;但作者发现&#xff0c;当应用于具有极大数据集的人脸识别场景时&#…

Leaflet实现离线地图展示,同时显示地图上的坐标点和热力图

在实际工作中,因为部署环境的要求,必须使用离线地图,而不是调用地图接口。我们应该怎么解决这种项目呢? 下面介绍一种解决该问题的方案:Leaflet+瓦片地图 一、Leaflet Leaflet 是一个开源并且对移动端友好的交互式地图 JavaScript 库。 它大小仅仅只有 42 KB of JS, 并且拥…

opencv图片绘制图形-------c++

绘制图形 #include <opencv2/opencv.hpp> #include <opencv2/core.hpp> #include <filesystem>bool opencvTool::drawPolygon(std::string image_p, std::vector<cv::Point> points) {cv::Mat ima cv::imread(image_p.c_str()); // 读取图像&#xf…

如何调节电脑屏幕亮度?让你的眼睛更舒适!

电脑屏幕亮度的调节对于我们的视力保护和使用舒适度至关重要。不同的环境和使用习惯可能需要不同的亮度设置。可是如何调节电脑屏幕亮度呢&#xff1f;本文将介绍三种不同的电脑屏幕亮度调节方法&#xff0c;帮助您轻松调节电脑屏幕亮度&#xff0c;以满足您的需求。 方法1&…

C++必修:从C到C++的过渡(下)

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ &#x1f388;&#x1f388;养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; 所属专栏&#xff1a;C学习 贝蒂的主页&#xff1a;Betty’s blog 1. 缺省参数 1.1. 缺省参数的使用 缺省参数是声明或定义函数时为函数的参数指定…

直接插入排序与希尔排序的详解及对比

目录 1.直接插入排序&#xff08;至少有两个元素才可以使用&#xff09; 排序逻辑 B站动画演示&#xff1a;直接插入排序 逻辑转为代码&#xff1a; 稳定性&#xff1a;稳定 时间复杂度&#xff1a;O(N^2) 空间复杂度&#xff1a;O(1) 应用场景 2.希尔排序&#xff08;对…

VUE父组件向子组件传递值

创作灵感 最近在写一个项目时&#xff0c;遇到了这样的一个需求。我封装了一个组件&#xff0c;这个组件需要被以下两个地方使用&#xff0c;一个是搜索用户时用到&#xff0c;一个是修改用户信息时需要用到。其中&#xff0c;在搜索用户时&#xff0c;可以根据姓名或者账号进…

C++之STL-String

目录 一、STL简介 1.1 什么是STL 1.2 STL的版本 1.3 STL的六大组件 ​编辑 1.4 STL的重要性 二、String类 2.1 Sting类的简介 2.2 string之构造函数 2.3 string类对象的容量操作 2.3.1 size() 2.3.2 length() 2.3.3 capacity() 2.3.4 empty() 2.3.5 clear() 2.3.6…

【Unity】苹果(IOS)开发证书保姆级申请教程

前言 我们在使用xcode出包的时候&#xff0c;需要用到iOS证书(.p12)和描述文件(.mobileprovision) 开发证书及对应的描述文件用于开发阶段使用&#xff0c;可以直接将 App 安装到手机上&#xff0c;一个描述文件最多绑定100台测试设备 1.证书管理 进入网站Apple Developer &…

从虚拟化走向云原生,红帽OpenShift“一手托两家”

汽车行业已经迈入“软件定义汽车”的新时代。吉利汽车很清醒地意识到&#xff0c;只有通过云原生技术和数字化转型&#xff0c;才能巩固其作为中国领先汽车制造商的地位。 和很多传统企业一样&#xff0c;吉利汽车在走向云原生的过程中也经历了稳态业务与敏态业务并存带来的前所…

视频美颜SDK原理与实践:从算法到应用

当下&#xff0c;从社交媒体到视频通话&#xff0c;人们越来越依赖于视频美颜功能来提升自己的形象。而视频美颜SDK作为支撑这一技术的重要工具&#xff0c;其原理和实践至关重要。 一、什么是视频美颜SDK&#xff1f; 视频美颜SDK是一种软件开发工具包&#xff0c;用于集成到…

FloodFill算法---DFS

目录 floodfill算法概念&#xff1a; 算法模板套路&#xff1a; 例题1&#xff1a;图像渲染 例题2&#xff1a;岛屿数量 例题3&#xff1a;岛屿的最大面积 例题4&#xff1a;被围绕的区域 floodfill算法概念&#xff1a; floodfill算法是一种常用的图像处理算法&#xf…

【IDEA】在IntelliJ IDEA中导入Eclipse项目:详细指南

IntelliJ IDEA和Eclipse是两款常用的集成开发环境&#xff08;IDE&#xff09;&#xff0c;在软件开发中经常会遇到需要在它们之间迁移项目的情况。本文将重点介绍如何在IntelliJ IDEA中导入Eclipse项目&#xff0c;以帮助开发者顺利地迁移他们的项目&#xff0c;并在IntelliJ …
最新文章