Tensorflow入门实战 T06-Vgg16 明星识别

目录

1、前言

2、 完整代码

3、运行过程+结果

4、遇到的问题

5、小结


  • 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

1、前言

这周主要是使用VGG16模型,完成明星照片识别。

2、 完整代码

from keras.utils import losses_utils
from tensorflow import keras
from keras import layers, models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from keras.callbacks import ModelCheckpoint, EarlyStopping

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0], "GPU")

# 导入数据
data_dir = "/Users/MsLiang/Documents/mySelf_project/pythonProject_pytorch/learn_demo/P_model/p06_vgg16/data"
data_dir = pathlib.Path(data_dir)

# 查看数据
image_count = len(list(data_dir.glob('*/*.jpg')))
print("图片总数为:",image_count)  # 1800

roses = list(data_dir.glob('Jennifer Lawrence/*.jpg'))
img = PIL.Image.open(str(roses[0]))
# img.show()  # 查看图片

# 数据预处理
# 1、加载数据
batch_size = 32
img_height = 224
img_width = 224

print('data_dir======>',data_dir)
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.1,
    subset="training",
    label_mode="categorical",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.1,
    subset="validation",
    label_mode="categorical",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)

# 可视化数据
plt.figure(figsize=(20, 10))

for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[np.argmax(labels[i])])
        plt.axis("off")
plt.show()

# 再次检查数据
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)   # (32, 224, 224, 3)
    print(labels_batch.shape)   # (32, 17)
    break

# 配置数据集
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

# 构建CNN网络
"""
关于卷积核的计算不懂的可以参考文章:https://blog.csdn.net/qq_38251616/article/details/114278995

layers.Dropout(0.4) 作用是防止过拟合,提高模型的泛化能力。
关于Dropout层的更多介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/115826689
"""

model = models.Sequential([
    keras.layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(img_height, img_width, 3)),

    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),  # 卷积层1,卷积核3*3
    layers.AveragePooling2D((2, 2)),  # 池化层1,2*2采样
    layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3
    layers.AveragePooling2D((2, 2)),  # 池化层2,2*2采样
    layers.Dropout(0.5),
    layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3
    layers.AveragePooling2D((2, 2)),
    layers.Dropout(0.5),
    layers.Conv2D(128, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3
    layers.Dropout(0.5),

    layers.Flatten(),  # Flatten层,连接卷积层与全连接层
    layers.Dense(128, activation='relu'),  # 全连接层,特征进一步提取
    layers.Dense(len(class_names))  # 输出层,输出预期结果
])

# model.summary()  # 打印网络结构


# 训练模型
# 1、设置动态学习率
# 设置初始学习率
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate,
        decay_steps=60,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.96,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

model.compile(optimizer=optimizer,
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])


# 损失函数
# 调用方式1:
model.compile(optimizer="adam",
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# 调用方式2:
# model.compile(optimizer="adam",
#               loss=tf.keras.losses.CategoricalCrossentropy(),
#               metrics=['accuracy'])

# sparse_categorical_crossentropy(稀疏性多分类的对数损失函数)
# 调用方式1:
model.compile(optimizer="adam",
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# ↑↑↑↑这里出现报错,需要将 sparse_categorical_crossentropy  改成→  categorical_crossentropy↑↑
# 调用方式2:
# model.compile(optimizer="adam",
#               loss=tf.keras.losses.SparseCategoricalCrossentropy(),
#               metrics=['accuracy'])

# 函数原型
tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=False,
    reduction=losses_utils.ReductionV2.AUTO,
    name='sparse_categorical_crossentropy'
)



epochs = 100

# 保存最佳模型参数
checkpointer = ModelCheckpoint('best_model.h5',
                                monitor='val_accuracy',
                                verbose=1,
                                save_best_only=True,
                                save_weights_only=True)

# 设置早停
earlystopper = EarlyStopping(monitor='val_accuracy',
                             min_delta=0.001,
                             patience=20,
                             verbose=1)

# 网络模型训练
history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=epochs,
                    callbacks=[checkpointer, earlystopper])

# 模型评估
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(len(loss))

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()


# 指定图片进行预测
# 加载效果最好的模型权重
model.load_weights('best_model.h5')

from PIL import Image
import numpy as np

img = Image.open("/Users/MsLiang/Documents/mySelf_project/pythonProject_pytorch/learn_demo/P_model/p06_vgg16/data/Jennifer Lawrence/003_963a3627.jpg")  #这里选择你需要预测的图片
image = tf.image.resize(img, [img_height, img_width])

img_array = tf.expand_dims(image, 0)

predictions = model.predict(img_array) # 这里选用你已经训练好的模型
print("预测结果为:",class_names[np.argmax(predictions)])




3、运行过程+结果

【查看图片】

【模型运行过程---第21epoch就早停了】

【训练精度、损失-----显然结果很很差】

4、遇到的问题

① 在运行代码的时候遇到报错:

错误:Graph execution error: Detected at node 'sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits' defined at (most recent call last):

出现这个问题来自我们使用的损失函数。

model.compile(optimizer="adam",
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

解决办法:

将损失函数里面的loss='sparse_categorical_crossentropy' 改成 'categorical_crossentropy',即可解决报错问题。

关于sparse_categorical_crossentropy和categorical_crossentropy的更多细节,详细参考这篇博文:交叉熵损失_多分类交叉熵损失函数-CSDN博客

5、小结

原始模型,跑出来效果很差很差!!!

(1)将原来的Adam优化器换成SGD优化器,效果如下:

(2)后续再补充,最近在写结课论文,有些忙。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/755844.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[XYCTF新生赛2024]-PWN:EZ2.0?(arm架构,arm架构下的系统调用)

查看保护 查看ida 完整exp: from pwn import*pprocess(./arm) premote(gz.imxbt.cn,20082) svc0x0001c58c mov_r2_r4_blx_r30x00043224 pop_r70x00027d78 pop_r40x000104e0 pop_r30x00010160 pop_r10x0005f824 pop_r00x0005f73c sh0x0008A090payloadba*0x44 payloa…

探索高效开发神器:Blackbox AI(免费编程助手)

人不走空 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌赋:斯是陋室,惟吾德馨 🤖 想要代码生成?👌 💬 需要和AI聊天解决难题?&#…

完全离线的本地问答模型LocalGPT如何实现无公网IP远程连接提问

文章目录 前言环境准备1. localGPT部署2. 启动和使用3. 安装cpolar 内网穿透4. 创建公网地址5. 公网地址访问6. 固定公网地址 前言 本文主要介绍如何本地部署LocalGPT并实现远程访问,由于localGPT只能通过本地局域网IP地址端口号的形式访问,实现远程访问…

[C语言]指针

一、指针简介 1、指针(Pointer)是C语言的一个重要知识点,其使用灵活、功能强大,是C语言的灵魂 2、指针与底层硬件联系紧密,使用指针可操作数据的地址,实现数据的间接访问 3、计算机存储机制 4、定义指针 (1&#x…

视频共享融合赋能平台LntonCVS安防监控平台现场方案实现和应用场景

LntonCVS国标视频融合云平台采用端-边-云一体化架构,部署简单灵活,功能多样化。支持多协议(GB28181/RTSP/Onvif/海康SDK/Ehome/大华SDK/RTMP推流等)和多类型设备接入(IPC/NVR/监控平台)。主要功能包括视频直…

技术驱动的音乐变革:AI带来的产业重塑

📑引言 近一个月来,随着几款音乐大模型的轮番上线,AI在音乐产业的角色迅速扩大。这些模型不仅将音乐创作的门槛降至前所未有的低点,还引发了一场关于AI是否会彻底颠覆音乐行业的激烈讨论。从初期的兴奋到现在的理性审视&#xff0…

衣服、帽子、鞋子相关深度学习数据集大合集(1)

最近收集了一大波关于衣物深度学习数据集,主要有衣服、帽子、鞋子、短裤、短袖、T恤等。 1、运动裤、短裤图片数据集 数据格式:图片 是否标注:已标注 标注格式:yolov8 图片数量:915张 查看地址:https…

virutalBox安装debian并配置docker环境

下载镜像 https://gemmei.ftp.acc.umu.se/debian-cd/current/amd64/iso-cd/debian-12.5.0-amd64-netinst.iso 虚拟机安装 如何在Virtual Box 上安装Debian系统_virtual box debian iso netinst-CSDN博客 启动命令行模式 如何设置Debian图形启动或命令行界面启动&#xff1…

S32K3 --- Wdg(内狗) Mcal配置

前言 看门狗的作用是用来检测程序是否跑飞,进入死循环。我们需要不停地喂狗,来确保程序是正常运行的,一旦停止喂狗,意味着程序跑飞,超时后就会reset复位程序。 一、Wdg 1.1 WdgGeneral Wdg Disable Allowed : 启用此参数后,允许在运行的时候禁用看门狗 Wdg Enable User…

微信小程序服务器从腾讯云迁移到阿里云出现的坑

微信小程序服务器从腾讯云迁移到阿里云出现的坑 背景 原先小程序后台服务器到期,因为之前买的是腾讯云新用户,便宜,到期后续费金额懂的都懂。就在阿里云用新用户买了个新的,遂把服务全转到了阿里云服务器上。 此时,域…

加油站可视化:打造智能化运营与管理新模式

智慧加油站可视化通过图扑 HT 构建仿真的三维模型,将加油站的布局、设备状态、人员活动等信息动态呈现。管理者可以通过直观的可视化界面实时监控和分析运营状况,快速做出决策,提高管理效率和安全水平,推动加油站向智能化管理转型…

云服务器的三大核心要素

云服务器作为云计算服务的重要组成部分,面向各类互联网用户提供综合业务能力的服务平台。该平台整合了传统意义上的互联网应用三大核心要素:计算、存储、网络,为用户提供公用化的互联网基础设施服务。 下面将围绕这三大核心要素展开详细的阐…

OverTheWire Bandit 靶场通关解析(下)

介绍 OverTheWire Bandit 是一个针对初学者设计的网络安全挑战平台,旨在帮助用户掌握基本的命令行操作和网络安全技能。Bandit 游戏包含一系列的关卡,每个关卡都需要解决特定的任务来获取进入下一关的凭证。通过逐步挑战更复杂的问题,用户可…

【Dison夏令营 Day 03】使用 Python 创建我们自己的 21 点游戏

21 点(英文:Blackjack)是一种在赌场玩的纸牌游戏。这种游戏的参与者不是互相竞争,而是与赌场指定的庄家竞争。在本文中,我们将从头开始创建可在终端上玩的玩家与庄家之间的二十一点游戏。 二十一点规则 我们将为从未玩过二十一点的读者提供…

【Python实战因果推断】6_元学习器1

目录 Metalearners for Discrete Treatments T-Learner 简单回顾一下,在之前的部分中,你们的重点是干预效果的异质性,也就是确定各单位对治疗的不同反应。在此框架下,您希望估算 或连续情况下的 。换句话说,您想知道…

“深入解析操作系统核心:进程管理与并发艺术“

操作系统中最核心的概念是进程:这是对正在运行程序的ー个抽象 并行 并发 “在任何多道程序设计系统中,CPU由一个进程快速切换至另ー个进程,使每个进程各运行几十或几百毫秒。严格地说,在某ー个瞬间,CPU只能运行ー个进程。但在1秒钟内,它可能运行多个进程,这样就产生并行的错觉…

【RAG】FoRAG:面向网络增强型长形式问答的事实性优化RAG

一、解决问题 在基于网络的长形式问答(Web-enhanced Long-form Question Answering, LFQA)任务中,现有RAG在生成答案时存在的问题: 事实性不足:研究表明,现有系统生成的答案中只有大约一半的陈述能够完全得…

SpringBoot学习04-[定制SpringMVC]

定制SpringMVC 定制SpringMvc的自动配置定制springmvc-configurePathMatch配置定制SpringMVC-拦截器Interceptor定制SpringMVC-CORS配置全局cors配置针对某个方法加跨域解决 WebMvcConfigurer原理定制SpringMVC-JSONJSON开发jackson的使用定制化json序列化和反序列化 JSON国际化…

Kafka面试必备:深度解析Replica副本的作用与机制

我是小米,一个喜欢分享技术的29岁程序员。如果你喜欢我的文章,欢迎关注我的微信公众号“软件求生”,获取更多技术干货! Hey大家好!我是小米,一个超级喜欢分享技术干货的大哥哥!今天咱们来聊聊阿里巴巴面试题中的一个热门话题:Kafka中的Replica副本作用。这可是个既基础…