Подтвердить что ты не робот

Есть ли пример того, как создавать файлы protobuf, содержащие обучаемые графики Tensorflow

Я рассматриваю пример Google о том, как развернуть и использовать предварительно подготовленный график (модель) Tensorflow на Android, по адресу:

https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android

В этом примере используется файл .pb по адресу: [это ссылка на файл, который автоматически загружается] https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip

В этом примере показано, как загрузить файл .pb в сеанс Tensorflow и использовать его для выполнения классификации, но не упоминает (?), как создать такой .pb файл после обучения графа (например, в Python).

Есть ли примеры того, как это сделать?

4b9b3361

Ответ 1

EDIT: freeze_graph.py script, который является частью репозитория TensorFlow, теперь служит инструмент, который генерирует буфер протокола, представляющий "замороженную" обучаемую модель, из существующего TensorFlow GraphDef и сохраненной контрольной точки. Он использует те же шаги, что описаны ниже, но гораздо проще в использовании.


В настоящее время процесс не очень хорошо документирован (и может быть уточнен), но приблизительные шаги заключаются в следующем:

  • Создайте и подготовьте свою модель как tf.Graph под названием g_1.
  • Получить окончательные значения каждой из переменных и сохранить их как массивы numpy (используя Session.run()).
  • В новом tf.Graph, называемом g_2, создайте тензоры tf.constant() для каждой из переменных, используя значение соответствующего массива numpy, выбранного на шаге 2.
  • Используйте tf.import_graph_def() для копирования узлов из g_1 в g_2 и используйте аргумент input_map для замены каждая переменная в g_1 с соответствующими тензорами tf.constant(), созданная на шаге 3. Вы также можете использовать input_map для указания нового входного тензора (например, заменяя введите < с tf.placeholder()). Используйте аргумент return_elements, чтобы указать имя прогнозируемого выходного тензора.

  • Вызвать g_2.as_graph_def(), чтобы получить представление буфера в протоколе графика.

( ПРИМЕЧАНИЕ: Сгенерированный граф будет иметь дополнительные узлы в графике для обучения. Хотя он не является частью общедоступного API, вы можете использовать внутренний graph_util.extract_sub_graph(), чтобы удалить эти узлы из графика.)

Ответ 2

В качестве альтернативы моему предыдущему ответу, используя freeze_graph(), который хорош только, если вы называете его script, есть очень приятная функция, которая сделает весь тяжелый подъем для вас и подходит для вызова из вашего нормальный код обучения модели.

convert_variables_to_constants() выполняет две вещи:

  • Он замораживает вес, заменяя переменные константами
  • Он удаляет узлы, которые не связаны с предсказанием вперед

Предполагая, что sess - ваш tf.Session() и "output" - это имя вашего прогноза node, следующий код сериализует ваш минимальный график как в текстовый, так и в двоичный protobuf.


from tensorflow.python.framework.graph_util import convert_variables_to_constants

minimal_graph = convert_variables_to_constants(sess, sess.graph_def, ["output"])

tf.train.write_graph(minimal_graph, '.', 'minimal_graph.proto', as_text=False)
tf.train.write_graph(minimal_graph, '.', 'minimal_graph.txt', as_text=True)

Ответ 3

Я не мог понять, как реализовать метод, описанный mrry. Но вот как я это решил. Я не уверен, что это лучший способ решить проблему, но, по крайней мере, она решает ее.

Поскольку write_graph также может хранить значения констант, я добавил следующий код в python непосредственно перед написанием графика с помощью функции write_graph:

for v in tf.trainable_variables():
    vc = tf.constant(v.eval())
    tf.assign(v, vc, name="assign_variables")

Это создает константы, которые сохраняют значения переменных после обучения, а затем создают тензоры " assign_variables", чтобы назначить их переменным. Теперь, когда вы вызываете write_graph, он будет хранить значения переменных в файле в виде констант.

Единственной оставшейся частью является вызов этих тензоров " assign_variables" в коде c, чтобы убедиться, что ваши переменные назначены значениями констант, которые хранятся в файле. Вот один из способов сделать это:

      Status status = NewSession(SessionOptions(), &session);
      std::vector<tensorflow::Tensor> outputs;
      char name[100];
      for(int i = 0;status.ok(); i++) {
        if (i==0)
            sprintf(name, "assign_variables");
        else
            sprintf(name, "assign_variables_%d", i);

        status = session->Run({}, {name}, {}, &outputs);
      }

Ответ 4

Вот еще один ответ на @Mostafa. Несколько более простой способ запуска tf.assign ops - сохранить их в tf.group. Здесь мой код Python:

  ops = []
  for v in tf.trainable_variables():
    vc = tf.constant(v.eval())
    ops.append(tf.assign(v, vc));
  tf.group(*ops, name="assign_trained_variables")

И в С++:

  std::vector<tensorflow::Tensor> tmp;
  status = session.Run({}, {}, { "assign_trained_variables" }, &tmp);
  if (!status.ok()) {
    // Handle error
  }

Таким образом, у вас есть только один именованный оператор op для запуска на стороне С++, поэтому вам не нужно путаться с итерацией по узлам.

Ответ 5

Просто нашел этот пост, и это было очень полезно! Я также использую метод @Mostafa, хотя мой код на С++ немного отличается:

    std::vector<string> names;
    int node_count = graph.node_size();
    cout << node_count << " nodes in graph" << endl;

    // iterate all nodes
    for(int i=0; i<node_count; i++) {
        auto n = graph.node(i);
        cout << i << ":" << n.name() << endl;

        // if name contains "var_hack", add to vector
        if(n.name().find("var_hack") != std::string::npos) {
            names.push_back(n.name());
            cout << "......bang" << endl;
        }
    }
    session.Run({}, names, {}, &outputs);

NB Я использую "var_hack" как имя моей переменной в python

Ответ 6

Я нашел функцию freeze_graph() в кодовой базе Tensorflow, которая может быть полезна при выполнении этого. Из того, что я понимаю, он меняет переменные с константами перед сериализацией GraphDef, поэтому, когда вы загружаете этот график из С++, у него нет никаких переменных, которые необходимо установить больше, и вы можете напрямую использовать их для предсказаний.

Существует также test для него и некоторое описание в Руководство по.

Это похоже на самый чистый вариант здесь.