🔥码云GVP开源项目 12k star Uniapp+ElementUI 功能强大 支持多语言、二开方便! 广告
# 重新训练现有的 CNN 模型 从头开始训练新的图像识别需要大量的时间和计算能力。如果我们可以采用先前训练的网络并使用我们的图像重新训练它,它可以节省我们的计算时间。对于此秘籍,我们将展示如何使用预先训练的 TensorFlow 图像识别模型并对其进行微调以处理不同的图像集。 ## 做好准备 其思想是从卷积层重用先前模型的权重和结构,并重新训练网络顶部的完全连接层。 TensorFlow 在现有 CNN 模型的基础上创建了一个关于训练的教程(请参阅下一节中的第一个要点)。在本文中,我们将说明如何对 CIFAR-10 使用相同的方法。我们将采用的 CNN 网络使用一种非常流行的架构,称为 Inception。 Inception CNN 模型由 Google 创建,在许多图像识别基准测试中表现非常出色。有关详细信息,请参阅“另请参阅”部分的第二个要点中的纸张参考。 我们将介绍的主要 Python 脚本显示如何下载 CIFAR-10 图像数据并自动分离,标记和保存图像到每个训练和测试文件夹中的十个类。之后,我们将重申如何在我们的图像上训练网络。 ## 操作步骤 执行以下步骤: 1. 我们首先加载必要的库来下载,解压缩和保存 CIFAR-10 图像: ```py import os import tarfile import _pickle as cPickle import numpy as np import urllib.request import scipy.misc from imageio import imwrite ``` 1. 我们现在声明 CIFAR-10 数据链接并创建我们将存储数据的临时目录。我们还将在以后保存图像时声明要引用的十个类别: ```py cifar_link = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' data_dir = 'temp' if not os.path.isdir(data_dir): os.makedirs(data_dir) objects = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] ``` 1. 现在我们将下载 CIFAR-10 `.tar`数据文件,并解压该文件: ```py target_file = os.path.join(data_dir, 'cifar-10-python.tar.gz') if not os.path.isfile(target_file): print('CIFAR-10 file not found. Downloading CIFAR data (Size = 163MB)') print('This may take a few minutes, please wait.') filename, headers = urllib.request.urlretrieve(cifar_link, target_file) # Extract into memory tar = tarfile.open(target_file) tar.extractall(path=data_dir) tar.close() ``` 1. 我们现在为训练创建必要的文件夹结构。临时目录将有两个文件夹,`train_dir`和`validation_dir`。在每个文件夹中,我们将为每个类别创建 10 个子文件夹: ```py # Create train image folders train_folder = 'train_dir' if not os.path.isdir(os.path.join(data_dir, train_folder)): for i in range(10): folder = os.path.join(data_dir, train_folder, objects[i]) os.makedirs(folder) # Create test image folders test_folder = 'validation_dir' if not os.path.isdir(os.path.join(data_dir, test_folder)): for i in range(10): folder = os.path.join(data_dir, test_folder, objects[i]) os.makedirs(folder) ``` 1. 为了保存图像,我们将创建一个从内存加载它们并将它们存储在图像字典中的函数: ```py def load_batch_from_file(file): file_conn = open(file, 'rb') image_dictionary = cPickle.load(file_conn, encoding='latin1') file_conn.close() return(image_dictionary) ``` 1. 使用前面的字典,我们将使用以下函数将每个文件保存在正确的位置: ```py def save_images_from_dict(image_dict, folder='data_dir'): for ix, label in enumerate(image_dict['labels']): folder_path = os.path.join(data_dir, folder, objects[label]) filename = image_dict['filenames'][ix] #Transform image data image_array = image_dict['data'][ix] image_array.resize([3, 32, 32]) # Save image output_location = os.path.join(folder_path, filename) imwrite(output_location,image_array.transpose()) ``` 1. 使用上述函数,我们可以遍历下载的数据文件并将每个图像保存到正确的位置: ```py data_location = os.path.join(data_dir, 'cifar-10-batches-py') train_names = ['data_batch_' + str(x) for x in range(1,6)] test_names = ['test_batch'] # Sort train images for file in train_names: print('Saving images from file: {}'.format(file)) file_location = os.path.join(data_dir, 'cifar-10-batches-py', file) image_dict = load_batch_from_file(file_location) save_images_from_dict(image_dict, folder=train_folder) # Sort test images for file in test_names: print('Saving images from file: {}'.format(file)) file_location = os.path.join(data_dir, 'cifar-10-batches-py', file) image_dict = load_batch_from_file(file_location) save_images_from_dict(image_dict, folder=test_folder) ``` 1. 我们脚本的最后一部分创建了图像标签文件,这是我们需要的最后一条信息。这个文件让我们将输出解释为标签而不是数字索引: ```py cifar_labels_file = os.path.join(data_dir,'cifar10_labels.txt') print('Writing labels file, {}'.format(cifar_labels_file)) with open(cifar_labels_file, 'w') as labels_file: for item in objects: labels_file.write("{}n".format(item)) ``` 1. 当前面的脚本运行时,它将下载图像并将它们分类到 TensorFlow 再训练教程所期望的正确文件夹结构中。完成后,我们只需按照教程进行操作即可。首先,我们应该克隆教程仓库: ```py git clone https://github.com/tensorflow/models/tree/master/research/inception ``` 1. 为了使用先前训练的模型,我们必须下载网络权重并将其应用于我们的模型。为此,您必须访问该站点: [https://github.com/tensorflow/models/tree/master/research/slim](https://github.com/tensorflow/models/tree/master/research/slim) ,并按照说明下载并安装 cifar10 模型架构和权重。您还将最终下载包含下面描述的构建,训练和测试脚本的数据目录。 > 对于此步骤,我们导航到 research / inception / inception 目录,然后执行以下命令,`--train_directory`,`--validation_directory`,`--output_directory`和`--labels_file`的路径指向相对路径或完整路径创建的目录结构。 1. 现在我们将图像放在正确的文件夹结构中,我们必须将它们变成`TFRecords`对象。我们通过运行以下命令来完成此操作: ```py me@computer:~$ python3 data/build_image_data.py --train_directory="temp/train_dir/" --validation_directory="temp/validation_dir" --output_directory="temp/" --labels_file="temp/cifar10_labels.txt" ``` 1. 现在我们将使用`bazel`训练模型,将参数设置为`true`。该脚本每 10 代输出一次损失。我们可以随时终止此过程,模型输出将在`temp/training_results`文件夹中。我们可以从此文件夹加载模型以进行评估: ```py me@computer:~$ bazel-bin/inception/flowers_train --train_dir="temp/training_results" --data_dir="temp/data_dir" --pretrained_model_checkpoint_path="model.ckpt-157585" --fine_tune=True --initial_learning_rate=0.001 --input_queue_memory_factor=1 ``` 1. 这应该导致输出类似于以下内容: ```py 2018-06-02 11:10:10.557012: step 1290, loss = 2.02 (1.2 examples/sec; 23.771 sec/batch) ... ``` ## 工作原理 关于预训练 CNN 上的训练的官方 TensorFlow 教程需要设置一个文件夹;我们从 CIFAR-10 数据创建的设置。然后我们将数据转换为所需的`TFRecords`格式并开始训练模型。请记住,我们正在微调模型并重新训练顶部的完全连接的层以适合我们的 10 类数据。 ## 另见 * 官方 Tensorflow Inception-v3 教程: [https://www.tensorflow.org/tutoriaimg/image_recognition](https://www.tensorflow.org/tutoriaimg/image_recognition) * Googlenet Inception-v3 文件: [https://arxiv.org/abs/1512.00567](https://arxiv.org/abs/1512.00567)