Подтвердить что ты не робот

Смутно поведением `tf.cond`

Мне нужен поток условного управления в моем графике. Если pred - True, граф должен вызывать op, который обновляет переменную и затем возвращает ее, в противном случае она возвращает переменную без изменений. Упрощенная версия:

pred = tf.constant(True)
x = tf.Variable([1])
assign_x_2 = tf.assign(x, [2])
def update_x_2():
  with tf.control_dependencies([assign_x_2]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())

Однако я обнаружил, что оба pred=True и pred=False приводят к одному результату y=[2], что означает, что присваивание op также вызывается, если update_x_2 не выбрано tf.cond. Как это объяснить? И как решить эту проблему?

4b9b3361

Ответ 1

TL; DR: Если вы хотите tf.cond() выполнить побочный эффект (например, присвоение) в одной из ветвей, вы должны создать op, который выполняет побочный эффект внутри, который вы передаете в tf.cond().

Поведение tf.cond() немного неинтуитивно. Поскольку выполнение в графе TensorFlow перемещается вперед по графику, все операции, которые вы указываете в ветке или, должны выполняться до вычисления условия. Это означает, что как истинная, так и ложная ветки получают зависимую зависимость от tf.assign() op, поэтому y всегда получает значение 2, даже если pred is False`.

Решение состоит в том, чтобы создать tf.assign() op внутри функции, которая определяет истинную ветвь. Например, вы можете структурировать свой код следующим образом:

pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
def update_x_2():
  with tf.control_dependencies([tf.assign(x, [2])]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval(feed_dict={pred: False}))  # ==> [1]
  print(y.eval(feed_dict={pred: True}))   # ==> [2]