Подтвердить что ты не робот

TensorFlow: получение переменной по имени

При использовании TensorFlow Python API я создал переменную (без указания ее name в конструкторе), а ее свойство name имеет значение "Variable_23:0". Когда я пытаюсь выбрать эту переменную с помощью tf.get_variable("Variable23"), вместо нее создается новая переменная с именем "Variable_23_1:0". Как правильно выбрать "Variable_23" вместо создания нового?

Что я хочу сделать, это выбрать переменную по имени и повторно инициализировать ее, чтобы я мог точно определить вес.

4b9b3361

Ответ 1

Самый простой способ получить переменную по имени - найти ее в коллекции tf.global_variables():

var_23 = [v for v in tf.global_variables() if v.name == "Variable_23:0"][0]

Это хорошо подходит для повторного использования существующих переменных. Более структурированный подход: когда вы хотите обмениваться переменными между несколькими частями модели, вы можете ознакомиться в Разделение переменных переменных.

Ответ 2

Функция get_variable() создает новую переменную или возвращает ранее созданную get_variable(). Он не будет возвращать переменную, созданную с помощью tf.Variable(). Вот краткий пример:

>>> with tf.variable_scope("foo"):
...   bar1 = tf.get_variable("bar", (2,3)) # create
... 
>>> with tf.variable_scope("foo", reuse=True):
...   bar2 = tf.get_variable("bar")  # reuse
... 

>>> with tf.variable_scope("", reuse=True): # root variable scope
...   bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
... 
>>> (bar1 is bar2) and (bar2 is bar3)
True

Если вы не создали переменную с помощью tf.get_variable(), у вас есть пара вариантов. Во-первых, вы можете использовать tf.global_variable() (как предлагает @mrry):

>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0]
>>> bar1 is bar2
True

Или вы можете использовать tf.get_collection() так:

>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]
>>> bar1 is bar2
True

Edit

Вы также можете использовать get_tensor_by_name():

>>> bar1 = tf.Variable(0.0, name="bar")
>>> graph = tf.get_default_graph()
>>> bar2 = tf.get_tensor_by_name("bar:0")
>>> bar1 is bar2
True

Напомним, что тензор является результатом операции. Он имеет то же имя, что и операция, плюс :0. Если операция имеет несколько выходов, они имеют то же имя, что и операция плюс :0, :1, :2 и т.д.