Подтвердить что ты не робот

Как добавить, если условие в графе TensorFlow?

Скажем, у меня есть следующий код:

x = tf.placeholder("float32", shape=[None, ins_size**2*3], name = "x_input")
condition = tf.placeholder("int32", shape=[1, 1], name = "condition")
W = tf.Variable(tf.zeros([ins_size**2*3,label_option]), name = "weights")
b = tf.Variable(tf.zeros([label_option]), name = "bias")

if condition > 0:
    y = tf.nn.softmax(tf.matmul(x, W) + b)
else:
    y = tf.nn.softmax(tf.matmul(x, W) - b)  

Будет ли оператор if работать в вычислении (я так не думаю)? Если нет, как я могу добавить оператор if в график расчета TensorFlow?

4b9b3361

Ответ 1

Вы правы, что инструкция if здесь не работает, потому что условие оценивается во время построения графика, тогда как, предположительно, вы хотите, чтобы условие зависело от значения, введенного в заполнитель во время выполнения. (Фактически, он всегда будет принимать первую ветвь, потому что condition > 0 оценивается как Tensor, который "правдивый" в Python. )

Чтобы поддерживать поток условного управления, TensorFlow предоставляет оператор tf.cond(), который оценивает одну из двух ветвей, в зависимости от логического условия. Чтобы показать вам, как его использовать, я переписал вашу программу, чтобы condition было скалярным значением tf.int32 для простоты:

x = tf.placeholder(tf.float32, shape=[None, ins_size**2*3], name="x_input")
condition = tf.placeholder(tf.int32, shape=[], name="condition")
W = tf.Variable(tf.zeros([ins_size**2 * 3, label_option]), name="weights")
b = tf.Variable(tf.zeros([label_option]), name="bias")

y = tf.cond(condition > 0, lambda: tf.matmul(x, W) + b, lambda: tf.matmul(x, W) - b)

Ответ 2

TensorFlow 2.0

В TF 2.0 появилась функция AutoGraph, которая позволяет JIT компилировать код Python в графические исполнения. Это означает, что вы можете использовать операторы потока управления python (да, это включает операторы if). Из документов,

Автограф поддерживает общие заявления Python, как while, for, if, break, continue и return с поддержкой вложенности. Это означает, что вы можете использовать выражения Tensor в условии операторов while и if или выполнять итерацию по Tensor в цикле for.

Вам нужно будет определить функцию, реализующую вашу логику, и аннотировать ее с помощью функции tf.function. Вот модифицированный пример из документации:

import tensorflow as tf

@tf.function
def sum_even(items):
  s = 0
  for c in items:
    if tf.equal(c % 2, 0): 
        s += c
  return s

sum_even(tf.constant([10, 12, 15, 20]))
#  <tf.Tensor: id=1146, shape=(), dtype=int32, numpy=42>