# 重新训练现有的 CNN 模型
从头开始训练新的图像识别需要大量的时间和计算能力。如果我们可以采用先前训练的网络并使用我们的图像重新训练它,它可以节省我们的计算时间。对于此秘籍,我们将展示如何使用预先训练的 TensorFlow 图像识别模型并对其进行微调以处理不同的图像集。
## 做好准备
其思想是从卷积层重用先前模型的权重和结构,并重新训练网络顶部的完全连接层。
TensorFlow 在现有 CNN 模型的基础上创建了一个关于训练的教程(请参阅下一节中的第一个要点)。在本文中,我们将说明如何对 CIFAR-10 使用相同的方法。我们将采用的 CNN 网络使用一种非常流行的架构,称为 Inception。 Inception CNN 模型由 Google 创建,在许多图像识别基准测试中表现非常出色。有关详细信息,请参阅“另请参阅”部分的第二个要点中的纸张参考。
我们将介绍的主要 Python 脚本显示如何下载 CIFAR-10 图像数据并自动分离,标记和保存图像到每个训练和测试文件夹中的十个类。之后,我们将重申如何在我们的图像上训练网络。
## 操作步骤
执行以下步骤:
1. 我们首先加载必要的库来下载,解压缩和保存 CIFAR-10 图像:
```py
import os
import tarfile
import _pickle as cPickle
import numpy as np
import urllib.request
import scipy.misc
from imageio import imwrite
```
1. 我们现在声明 CIFAR-10 数据链接并创建我们将存储数据的临时目录。我们还将在以后保存图像时声明要引用的十个类别:
```py
cifar_link = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
data_dir = 'temp'
if not os.path.isdir(data_dir):
os.makedirs(data_dir)
objects = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
```
1. 现在我们将下载 CIFAR-10 `.tar`数据文件,并解压该文件:
```py
target_file = os.path.join(data_dir, 'cifar-10-python.tar.gz')
if not os.path.isfile(target_file):
print('CIFAR-10 file not found. Downloading CIFAR data (Size = 163MB)')
print('This may take a few minutes, please wait.')
filename, headers = urllib.request.urlretrieve(cifar_link, target_file)
# Extract into memory
tar = tarfile.open(target_file)
tar.extractall(path=data_dir)
tar.close()
```
1. 我们现在为训练创建必要的文件夹结构。临时目录将有两个文件夹,`train_dir`和`validation_dir`。在每个文件夹中,我们将为每个类别创建 10 个子文件夹:
```py
# Create train image folders
train_folder = 'train_dir'
if not os.path.isdir(os.path.join(data_dir, train_folder)):
for i in range(10):
folder = os.path.join(data_dir, train_folder, objects[i])
os.makedirs(folder)
# Create test image folders
test_folder = 'validation_dir'
if not os.path.isdir(os.path.join(data_dir, test_folder)):
for i in range(10):
folder = os.path.join(data_dir, test_folder, objects[i])
os.makedirs(folder)
```
1. 为了保存图像,我们将创建一个从内存加载它们并将它们存储在图像字典中的函数:
```py
def load_batch_from_file(file):
file_conn = open(file, 'rb')
image_dictionary = cPickle.load(file_conn, encoding='latin1')
file_conn.close()
return(image_dictionary)
```
1. 使用前面的字典,我们将使用以下函数将每个文件保存在正确的位置:
```py
def save_images_from_dict(image_dict, folder='data_dir'):
for ix, label in enumerate(image_dict['labels']):
folder_path = os.path.join(data_dir, folder, objects[label])
filename = image_dict['filenames'][ix]
#Transform image data
image_array = image_dict['data'][ix]
image_array.resize([3, 32, 32])
# Save image
output_location = os.path.join(folder_path, filename)
imwrite(output_location,image_array.transpose())
```
1. 使用上述函数,我们可以遍历下载的数据文件并将每个图像保存到正确的位置:
```py
data_location = os.path.join(data_dir, 'cifar-10-batches-py')
train_names = ['data_batch_' + str(x) for x in range(1,6)]
test_names = ['test_batch']
# Sort train images
for file in train_names:
print('Saving images from file: {}'.format(file))
file_location = os.path.join(data_dir, 'cifar-10-batches-py', file)
image_dict = load_batch_from_file(file_location)
save_images_from_dict(image_dict, folder=train_folder)
# Sort test images
for file in test_names:
print('Saving images from file: {}'.format(file))
file_location = os.path.join(data_dir, 'cifar-10-batches-py', file)
image_dict = load_batch_from_file(file_location)
save_images_from_dict(image_dict, folder=test_folder)
```
1. 我们脚本的最后一部分创建了图像标签文件,这是我们需要的最后一条信息。这个文件让我们将输出解释为标签而不是数字索引:
```py
cifar_labels_file = os.path.join(data_dir,'cifar10_labels.txt')
print('Writing labels file, {}'.format(cifar_labels_file))
with open(cifar_labels_file, 'w') as labels_file:
for item in objects:
labels_file.write("{}n".format(item))
```
1. 当前面的脚本运行时,它将下载图像并将它们分类到 TensorFlow 再训练教程所期望的正确文件夹结构中。完成后,我们只需按照教程进行操作即可。首先,我们应该克隆教程仓库:
```py
git clone https://github.com/tensorflow/models/tree/master/research/inception
```
1. 为了使用先前训练的模型,我们必须下载网络权重并将其应用于我们的模型。为此,您必须访问该站点: [https://github.com/tensorflow/models/tree/master/research/slim](https://github.com/tensorflow/models/tree/master/research/slim) ,并按照说明下载并安装 cifar10 模型架构和权重。您还将最终下载包含下面描述的构建,训练和测试脚本的数据目录。
> 对于此步骤,我们导航到 research / inception / inception 目录,然后执行以下命令,`--train_directory`,`--validation_directory`,`--output_directory`和`--labels_file`的路径指向相对路径或完整路径创建的目录结构。
1. 现在我们将图像放在正确的文件夹结构中,我们必须将它们变成`TFRecords`对象。我们通过运行以下命令来完成此操作:
```py
me@computer:~$ python3 data/build_image_data.py
--train_directory="temp/train_dir/"
--validation_directory="temp/validation_dir"
--output_directory="temp/" --labels_file="temp/cifar10_labels.txt"
```
1. 现在我们将使用`bazel`训练模型,将参数设置为`true`。该脚本每 10 代输出一次损失。我们可以随时终止此过程,模型输出将在`temp/training_results`文件夹中。我们可以从此文件夹加载模型以进行评估:
```py
me@computer:~$ bazel-bin/inception/flowers_train
--train_dir="temp/training_results" --data_dir="temp/data_dir"
--pretrained_model_checkpoint_path="model.ckpt-157585"
--fine_tune=True --initial_learning_rate=0.001
--input_queue_memory_factor=1
```
1. 这应该导致输出类似于以下内容:
```py
2018-06-02 11:10:10.557012: step 1290, loss = 2.02 (1.2 examples/sec; 23.771 sec/batch)
...
```
## 工作原理
关于预训练 CNN 上的训练的官方 TensorFlow 教程需要设置一个文件夹;我们从 CIFAR-10 数据创建的设置。然后我们将数据转换为所需的`TFRecords`格式并开始训练模型。请记住,我们正在微调模型并重新训练顶部的完全连接的层以适合我们的 10 类数据。
## 另见
* 官方 Tensorflow Inception-v3 教程: [https://www.tensorflow.org/tutoriaimg/image_recognition](https://www.tensorflow.org/tutoriaimg/image_recognition)
* Googlenet Inception-v3 文件: [https://arxiv.org/abs/1512.00567](https://arxiv.org/abs/1512.00567)
- TensorFlow 入门
- 介绍
- TensorFlow 如何工作
- 声明变量和张量
- 使用占位符和变量
- 使用矩阵
- 声明操作符
- 实现激活函数
- 使用数据源
- 其他资源
- TensorFlow 的方式
- 介绍
- 计算图中的操作
- 对嵌套操作分层
- 使用多个层
- 实现损失函数
- 实现反向传播
- 使用批量和随机训练
- 把所有东西结合在一起
- 评估模型
- 线性回归
- 介绍
- 使用矩阵逆方法
- 实现分解方法
- 学习 TensorFlow 线性回归方法
- 理解线性回归中的损失函数
- 实现 deming 回归
- 实现套索和岭回归
- 实现弹性网络回归
- 实现逻辑回归
- 支持向量机
- 介绍
- 使用线性 SVM
- 简化为线性回归
- 在 TensorFlow 中使用内核
- 实现非线性 SVM
- 实现多类 SVM
- 最近邻方法
- 介绍
- 使用最近邻
- 使用基于文本的距离
- 使用混合距离函数的计算
- 使用地址匹配的示例
- 使用最近邻进行图像识别
- 神经网络
- 介绍
- 实现操作门
- 使用门和激活函数
- 实现单层神经网络
- 实现不同的层
- 使用多层神经网络
- 改进线性模型的预测
- 学习玩井字棋
- 自然语言处理
- 介绍
- 使用词袋嵌入
- 实现 TF-IDF
- 使用 Skip-Gram 嵌入
- 使用 CBOW 嵌入
- 使用 word2vec 进行预测
- 使用 doc2vec 进行情绪分析
- 卷积神经网络
- 介绍
- 实现简单的 CNN
- 实现先进的 CNN
- 重新训练现有的 CNN 模型
- 应用 StyleNet 和 NeuralStyle 项目
- 实现 DeepDream
- 循环神经网络
- 介绍
- 为垃圾邮件预测实现 RNN
- 实现 LSTM 模型
- 堆叠多个 LSTM 层
- 创建序列到序列模型
- 训练 Siamese RNN 相似性度量
- 将 TensorFlow 投入生产
- 介绍
- 实现单元测试
- 使用多个执行程序
- 并行化 TensorFlow
- 将 TensorFlow 投入生产
- 生产环境 TensorFlow 的一个例子
- 使用 TensorFlow 服务
- 更多 TensorFlow
- 介绍
- 可视化 TensorBoard 中的图
- 使用遗传算法
- 使用 k 均值聚类
- 求解常微分方程组
- 使用随机森林
- 使用 TensorFlow 和 Keras