# Image classification using retrained Inception v3 in TensorFlow

Retraining Inception v3 differs from retraining VGG16 in that we use a softmax activation layer for the output and `tf.losses.softmax_cross_entropy()` as the loss function.

1. First, define the placeholders:

```py
is_training = tf.placeholder(tf.bool,name='is_training')
x_p = tf.placeholder(shape=(None, image_height, image_width, 3),
                     dtype=tf.float32,
                     name='x_p')
y_p = tf.placeholder(shape=(None,coco.n_classes),
                     dtype=tf.int32,
                     name='y_p')
```

2. Next, load the model:

```py
# note: the graph is built with is_training=True; the is_training
# placeholder defined above is not wired into the model
with slim.arg_scope(inception.inception_v3_arg_scope()):
    logits,_ = inception.inception_v3(x_p,
                                      num_classes=coco.n_classes,
                                      is_training=True
                                     )
probabilities = tf.nn.softmax(logits)
```

3. Next, define a function that restores all variables except those of the last layer. The `Logits` and `AuxLogits` scopes are excluded because their shapes depend on the number of classes, so they are trained from scratch:

```py
# restore except last layer
checkpoint_exclude_scopes=["InceptionV3/Logits",
                           "InceptionV3/AuxLogits"]
exclusions = [scope.strip() for scope in checkpoint_exclude_scopes]

variables_to_restore = []
for var in slim.get_model_variables():
    excluded = False
    for exclusion in exclusions:
        if var.op.name.startswith(exclusion):
            excluded = True
            break
    if not excluded:
        variables_to_restore.append(var)

init_fn = slim.assign_from_checkpoint_fn(
    os.path.join(model_home, '{}.ckpt'.format(model_name)),
    variables_to_restore)
```

4. Define the loss, the optimizer, and the training op:

```py
# the loss is registered in the losses collection and read back
# through get_total_loss()
tf.losses.softmax_cross_entropy(onehot_labels=y_p, logits=logits)
loss = tf.losses.get_total_loss()
learning_rate = 0.001
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
train_op = optimizer.minimize(loss)
```

5. Train the model and, once training is finished, run the predictions in the same session:

```py
n_epochs=10
coco.y_onehot = True
coco.batch_size = 32
coco.batch_shuffle = True
total_images = len(x_train_files)
n_batches = total_images // coco.batch_size

with tf.Session() as tfs:
    tfs.run(tf.global_variables_initializer())
    # overwrite the initialized values with the checkpoint weights
    init_fn(tfs)
    for epoch in range(n_epochs):
        print('Starting epoch ',epoch)
        coco.reset_index()
        epoch_accuracy=0
        epoch_loss=0
        for batch in range(n_batches):
            x_batch, y_batch = coco.next_batch()
            images=np.array([coco.preprocess_for_inception(x) \
                for x in x_batch])
            feed_dict={x_p:images,y_p:y_batch,is_training:True}
            batch_loss,_ = tfs.run([loss,train_op], feed_dict = feed_dict)
            epoch_loss += batch_loss
        epoch_loss /= n_batches
        print('Train loss in epoch {}:{}'
            .format(epoch,epoch_loss))

    # now run the predictions
    feed_dict={x_p:images_test,is_training: False}
    probs = tfs.run([probabilities],feed_dict=feed_dict)
    probs=probs[0]
```

We see that the loss decreases with every epoch:

```py
INFO:tensorflow:Restoring parameters from /home/armando/models/inception_v3/inception_v3.ckpt
Starting epoch 0
Train loss in epoch 0:2.7896385192871094
Starting epoch 1
Train loss in epoch 1:1.6651896286010741
Starting epoch 2
Train loss in epoch 2:1.2332031989097596
Starting epoch 3
Train loss in epoch 3:0.9912329530715942
Starting epoch 4
Train loss in epoch 4:0.8110128355026245
Starting epoch 5
Train loss in epoch 5:0.7177265572547913
Starting epoch 6
Train loss in epoch 6:0.6175705575942994
Starting epoch 7
Train loss in epoch 7:0.5542363750934601
Starting epoch 8
Train loss in epoch 8:0.523461252450943
Starting epoch 9
Train loss in epoch 9:0.4923107647895813
```

This time the results correctly identify the sheep, but the cat picture is misclassified as a dog:

![](https://img.kancloud.cn/d5/a9/d5a99434c27c21542f94d7f5aafd7fc0_315x306.png)

```py
Probability 98.84% of [zebra]
Probability 0.84% of [giraffe]
Probability 0.11% of [sheep]
Probability 0.07% of [cat]
Probability 0.06% of [dog]
```

---

![](https://img.kancloud.cn/49/a6/49a68966aaa0ee71305961e2c5cada13_315x306.png)

```py
Probability 95.77% of [horse]
Probability 1.34% of [dog]
Probability 0.89% of [zebra]
Probability 0.68% of [bird]
Probability 0.61% of [sheep]
```

---

![](https://img.kancloud.cn/a8/ff/a8ff8a087a8cb72538fce00f199d8497_315x306.png)

```py
Probability 94.83% of [dog]
Probability 4.53% of [cat]
Probability 0.56% of [sheep]
Probability 0.04% of [bear]
Probability 0.02% of [zebra]
```

---

![](https://img.kancloud.cn/63/19/6319209b3678f238237547e18f9c9e65_315x306.png)

```py
Probability 42.80% of [bird]
Probability 25.64% of [cat]
Probability 15.56% of [bear]
Probability 8.77% of [giraffe]
Probability 3.39% of [sheep]
```

---

![](https://img.kancloud.cn/d5/38/d5388bb62b6dff6e317c441799363147_315x306.png)

```py
Probability 72.58% of [sheep]
Probability 8.40% of [bear]
Probability 7.64% of [giraffe]
Probability 4.02% of [horse]
Probability 3.65% of [bird]
```

---

![](https://img.kancloud.cn/0a/18/0a18ac3f3565f5993a6a2738935e8b20_315x306.png)

```py
Probability 98.03% of [bear]
Probability 0.74% of [cat]
Probability 0.54% of [sheep]
Probability 0.28% of [bird]
Probability 0.17% of [horse]
```

---

![](https://img.kancloud.cn/95/9a/959ab88e20b5c821831cb2ec8a433883_315x306.png)

```py
Probability 96.43% of [giraffe]
Probability 1.78% of [bird]
Probability 1.10% of [sheep]
Probability 0.32% of [zebra]
Probability 0.14% of [bear]
```

---

![](https://img.kancloud.cn/62/ff/62fffd6d8c14b02a0b8d7a6761bc4f6a_315x306.png)

```py
Probability 34.43% of [horse]
Probability 23.53% of [dog]
Probability 16.03% of [zebra]
Probability 9.76% of [cat]
Probability 9.02% of [giraffe]
```
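
The top-five listings above can be derived from a per-image probability vector with a little post-processing. The following is a minimal, framework-free sketch and not the code that produced the output above: the helper names `softmax`, `top_k`, and `format_predictions`, the order of the classes in `LABELS`, and the example logits are all assumptions made for illustration. Only the class names and the output line format are taken from the listings.

```python
import math

# Assumed label order (hypothetical); the real order comes from the dataset object.
LABELS = ['zebra', 'giraffe', 'sheep', 'cat', 'dog', 'horse', 'bird', 'bear']


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, labels, k=5):
    """Return the k (probability, label) pairs with the highest probability."""
    ranked = sorted(zip(probs, labels), key=lambda pair: pair[0], reverse=True)
    return ranked[:k]


def format_predictions(probs, labels, k=5):
    """Format the top-k predictions in the same style as the listings above."""
    return ['Probability {:.2%} of [{}]'.format(p, name)
            for p, name in top_k(probs, labels, k)]


if __name__ == '__main__':
    # Made-up logits for a single image, for illustration only.
    example_logits = [6.0, 1.2, -0.8, -1.3, -1.4, -2.0, -2.5, -3.0]
    for line in format_predictions(softmax(example_logits), LABELS):
        print(line)
```

In the session code, each row of `probs` would play the role of the `softmax` output here, since the graph already applies `tf.nn.softmax` to the logits. A low top-1 probability, as in the images where the best guess is around 35% to 45%, is a sign that the model is unsure and the prediction deserves less trust.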