Подтвердить что ты не робот

TensorFlow сохранение/загрузка графика из файла

Из того, что я собрал до сих пор, существует несколько разных способов сброса графика TensorFlow в файл, а затем загрузка его в другую программу, но я не смог найти четкие примеры/информацию о том, как они работают, Я уже знаю это:

  • Сохраните переменные модели в файл контрольной точки (.ckpt) с помощью tf.train.Saver() и восстановите их позже (source)
  • Сохраните модель в файле .pb и загрузите ее обратно с помощью tf.train.write_graph() и tf.import_graph_def() (source)
  • Загрузите модель из файла .pb, переустановите ее и выгрузите в новый .pb файл с помощью Bazel (source)
  • Зафиксируйте график, чтобы сохранить график и вес вместе (источник)
  • Используйте as_graph_def() для сохранения модели и для весов/переменных, сопоставьте их с константами (источник)

Однако я не смог прояснить несколько вопросов относительно этих разных методов:

  • Что касается файлов контрольных точек, они сохраняют только подготовленные веса модели? Могут ли файлы контрольных точек загружаться в новую программу и использоваться для запуска модели, или они просто служат в качестве способов сохранения весов в модели в определенное время/этап?
  • Что касается tf.train.write_graph(), также сохраняются ли весы/переменные?
  • Что касается Bazel, может ли он только сохранить/загрузить из .pb файлов для переподготовки? Есть ли простая команда Bazel, чтобы сбрасывать граф в .pb?
  • Что касается замораживания, может ли загруженный замороженный граф использовать tf.import_graph_def()?
  • Демонстрация Android для загрузки TensorFlow в модели Google Inception из файла .pb. Если бы я хотел подставить свой собственный .pb файл, как бы я это сделал? Должен ли я изменить любой собственный код/​​методы?
  • В общем, какая именно разница между всеми этими методами? Или более широко, в чем разница между as_graph_def()/. Ckpt/.pb?

Короче говоря, то, что я ищу, - это метод, который позволяет сохранить как график (как в, различные операции и т.д.), так и его вес/переменные в файл, который затем можно использовать для загрузки графика и весов в другую программу для использования (не обязательно продолжение/переподготовка).

Документация по этой теме не очень проста, поэтому любые ответы/информация были бы оценены.

4b9b3361

Ответ 1

Есть много способов подойти к проблеме сохранения модели в TensorFlow, что может сделать ее несколько запутанной. Принимая каждый из ваших вопросов:

  1. Файлы контрольных точек (создаваемые, например, путем вызова saver.save() объекта tf.train.Saver) содержат только веса и любые другие переменные, определенные в одной и той же программе. Чтобы использовать их в другой программе, вы должны повторно создать связанную структуру графа (например, запустив код для его сборки снова или вызвав tf.import_graph_def()), который сообщает TensorFlow, что делать с этими весами. Обратите внимание, что вызов saver.save() также создает файл, содержащий MetaGraphDef, который содержит график и сведения о том, как связать веса с контрольной точки с этим графом. Дополнительную информацию см. В руководстве.

  2. tf.train.write_graph() записывает только структуру графа; а не веса.

  3. Bazel не связан с чтением или написанием графиков TensorFlow. (Возможно, я неправильно понимаю ваш вопрос: не стесняйтесь прояснить это в комментарии.)

  4. Замороженный график можно загрузить с помощью tf.import_graph_def(). В этом случае весы (обычно) встроены в график, поэтому вам не нужно загружать отдельную контрольную точку.

  5. Основное изменение заключалось бы в обновлении имен тензора (ов), которые подаются в модель, и имен тензора (ов), которые извлекаются из модели. В демоверсии TensorFlow Android это будет соответствовать inputName и outputName, которые передаются TensorFlowClassifier.initializeTensorFlow().

  6. GraphDef - это структура программы, которая обычно не изменяется в процессе обучения. Контрольная точка представляет собой моментальный снимок состояния процесса обучения, который обычно изменяется на каждом этапе учебного процесса. В результате TensorFlow использует разные форматы хранения данных этих типов, а низкоуровневый API предоставляет различные способы их сохранения и загрузки. Библиотеки более высокого уровня, такие как MetaGraphDef библиотеки, Keras и skflow на основе этих механизмов, чтобы обеспечить более удобные способы сохранения и восстановления целой модели.

Ответ 2

Вы можете попробовать следующий код:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)