Я думаю, что было бы очень полезно сообществу Tensorflow, если бы было хорошо документированное решение решающей задачи тестирования одного нового изображения против модели, созданной convnet в учебнике CIFAR-10.
Возможно, я ошибаюсь, но этот критический шаг, который делает пригодную для обучения на практике модель, кажется, отсутствует. В этом учебнике есть "недостающее звено" - script, которое будет напрямую загружать одно изображение (в виде массива или двоичного кода), сравнивать его с обученной моделью и возвращать классификацию.
Предыдущие ответы дают частичные решения, которые объясняют общий подход, но ни один из которых я не смог успешно реализовать. Другие кусочки можно найти здесь и там, но, к сожалению, не добавили рабочего решения. Просьба рассмотреть сделанное мной исследование, прежде чем пометить это как дублирующее или уже ответившее.
Tensorflow: как сохранить/восстановить модель?
Восстановление модели TensorFlow
Невозможно восстановить модели в tensorflow v0.8
https://gist.github.com/nikitakit/6ef3b72be67b86cb7868
Самый популярный ответ - первый, в котором @RyanSepassi и @YaroslavBulatov описывают проблему и подход: нужно "вручную построить график с идентичными именами node и использовать Saver для загрузки в него весов", Хотя оба ответа полезны, неясно, как можно было бы подключить его к проекту CIFAR-10.
Полностью функциональное решение было бы очень желательно, поэтому мы могли бы перенести его на другие проблемы классификации изображений. В этом отношении есть несколько вопросов, касающихся SO, которые требуют этого, но до сих пор нет полного ответа (например Загрузить контрольную точку и оценить одиночное изображение с DNN тензорного потока).
Надеюсь, мы сможем сблизиться с рабочим script, который каждый мог бы использовать.
Ниже script еще не функционирует, и я был бы рад услышать от вас, как это можно улучшить, чтобы обеспечить решение для классификации одного изображения с использованием обучаемой модели учебника CIFAR-10 TF.
Предположим, что все переменные, имена файлов и т.д. нетронуты из исходного учебника.
Новый файл: cifar10_eval_single.py
import cv2
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('eval_dir', './input/eval',
"""Directory where to write event logs.""")
tf.app.flags.DEFINE_string('checkpoint_dir', './input/train',
"""Directory where to read model checkpoints.""")
def get_single_img():
file_path = './input/data/single/test_image.tif'
pixels = cv2.imread(file_path, 0)
return pixels
def eval_single_img():
# below code adapted from @RyanSepassi, however not functional
# among other errors, saver throws an error that there are no
# variables to save
with tf.Graph().as_default():
# Get image.
image = get_single_img()
# Build a Graph.
# TODO
# Create dummy variables.
x = tf.placeholder(tf.float32)
w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
print('Checkpoint found')
else:
print('No checkpoint found')
# Run the model to get predictions
predictions = sess.run(y_hat, feed_dict={x: image})
print(predictions)
def main(argv=None):
if tf.gfile.Exists(FLAGS.eval_dir):
tf.gfile.DeleteRecursively(FLAGS.eval_dir)
tf.gfile.MakeDirs(FLAGS.eval_dir)
eval_single_img()
if __name__ == '__main__':
tf.app.run()