Подтвердить что ты не робот

Как найти имена переменных, которые сохраняются в контрольной точке тензорного потока?

Я хочу видеть переменные, которые сохраняются в контрольной точке tensorflow вместе с их значениями. Как найти имена переменных, которые сохраняются в контрольной точке тензорного потока?

EDIT:

Я использовал tf.train.NewCheckpointReader, который объясняется здесь. Но это не дано в документации тензорного потока. Есть ли другой способ?

`

    import tensorflow as tf
    v0 = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype=tf.float32, name="v0")
    v1 = tf.Variable([[[1], [2]], [[3], [4]], [[5], [6]]], dtype=tf.float32,
                     name="v1")
    init_all_op = tf.initialize_all_variables()
    save = tf.train.Saver({"v0": v0, "v1": v1})
    checkpoint_path = os.path.join(model_dir, "model.ckpt")    

    with tf.Session() as sess:
      sess.run(init_all_op)
      # Saves a checkpoint.      
      save.save(sess, checkpoint_path)

      # Creates a reader.
      reader = tf.train.NewCheckpointReader(checkpoint_path)
      print('reder:\n', reader)

      # Verifies that the tensors exist.
      print('is exist v0?', reader.has_tensor("v0"))
      print('is exist v1?', reader.has_tensor("v1"))

      # Verifies that debug string contains the right strings.
      debug_string = reader.debug_string()
      print('\n All Variables: \n', debug_string)

      # Verifies get_variable_to_shape_map() returns the correct information.
      var_map = reader.get_variable_to_shape_map()
      print('\n All Variables information :\n', var_map)

      # Verifies get_tensor() returns the tensor value.
      v0_tensor = reader.get_tensor("v0")
      v1_tensor = reader.get_tensor("v1")
      print('\n   returns the v0 tensor value:\n', v0_tensor)
      print('\n   returns the v1 tensor value:\n', v1_tensor)

`

4b9b3361

Ответ 2

Использование примера:

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
checkpoint_path = os.path.join(model_dir, "model.ckpt")

# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')

# List contents of v0 tensor.
# Example output: tensor_name:  v0 [[[[  9.27958265e-02   7.40226209e-02   4.52989563e-02   3.15700471e-02
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')

# List contents of v1 tensor.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')

Обновление: all_tensors аргумент был добавлен в print_tensors_in_checkpoint_file, поскольку Tensorflow 0.12.0-rc0, чтобы вы может потребоваться добавить all_tensors=False или all_tensors=True, если это необходимо.

Альтернативный метод:

from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key)) # Remove this is you want to print only variable names

Надеюсь, что это поможет.

Ответ 3

Добавление к предыдущему ответу:

Если модель сохраняется в формате V2

model-10000.data-00000-of-00001
model-10000.index
model-10000.meta

Вводимое имя контрольной точки должно быть только префиксом

print_tensors_in_checkpoint_file(file_name='/home/RNN/models/model_10000', tensor_name='',all_tensors=True)

источник: by @LingjiaDeng в https://github.com/tensorflow/tensorflow/issues/7696