栏目分类:
子分类:
返回
文库吧用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
文库吧 > IT > 软件开发 > 后端开发 > Python

[Keras] 绘制训练过程中Acc和Loss曲线

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

[Keras] 绘制训练过程中Acc和Loss曲线

以Fashion MNIST数据集为例。

# coding = utf-8
from tensorflow import keras
import matplotlib.pyplot as plt

# Prepare data X_train: ndarray,(60000, 28, 28) y_train: ndarray, (60000,)
(X_train, y_train), (X_valid, y_valid) = keras.datasets.fashion_mnist.load_data()
X_train = X_train / 255.0
X_valid = X_valid / 255.0

# Build model
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),  # input layer
    keras.layers.Dense(128, activation='relu'),  # hidden layer
    keras.layers.Dense(10, activation='softmax')  # output layer
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train model
h = model.fit(X_train, y_train, epochs=100, batch_size=32, validation_split=0.2)

# Plot accuracy and loss curve
history = h.history
print(history.keys())
epochs = range(len(history['loss']))

plt.figure(figsize=(20, 5))
ax = plt.subplot(1, 2, 1)
ax.set_title('Train and Valid Accuracy')
plt.plot(epochs, history['accuracy'], 'b', label='Train accuracy')
plt.plot(epochs, history['val_accuracy'], 'r', label='Valid accuracy')
plt.legend()

ax = plt.subplot(1, 2, 2)
ax.set_title('Train and Valid Loss')
plt.plot(epochs, history['loss'], 'b', label='Train loss')
plt.plot(epochs, history['val_loss'], 'r', label='Valid loss')
plt.legend()

# plt.savefig(f'acc_and_loss.png')

plt.show()

Output:

...
Epoch 98/100
1500/1500 [==============================] - 2s 1ms/step - loss: 0.0504 - accuracy: 0.9818 - val_loss: 0.7474 - val_accuracy: 0.8877
Epoch 99/100
1500/1500 [==============================] - 2s 1ms/step - loss: 0.0455 - accuracy: 0.9831 - val_loss: 0.7986 - val_accuracy: 0.8813
Epoch 100/100
1500/1500 [==============================] - 2s 1ms/step - loss: 0.0450 - accuracy: 0.9835 - val_loss: 0.7605 - val_accuracy: 0.8895

dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])


model.fit()返回一个TensorFlow的History对象,其中包含一个history的字典,字典里的键根据设定的metrics=['accuracy']有关。

转载请注明:文章转载自 www.wk8.com.cn
本文地址:https://www.wk8.com.cn/it/1037521.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 wk8.com.cn

ICP备案号:晋ICP备2021003244-6号