Подтвердить что ты не робот

Как использовать tf.while_loop() в тензорном потоке

Это общий вопрос. Я обнаружил, что в тензорном потоке, после того как мы построим график, выберем данные в график, вывод из графика будет тензором. но во многих случаях нам нужно сделать некоторые вычисления на основе этого вывода (который является tensor), что недопустимо в тензорном потоке.

Например, я пытаюсь реализовать RNN, который циклически меняет время, основанное на собственном свойстве данных. То есть, мне нужно использовать tensor, чтобы судить о том, следует ли останавливаться (я не использую dynamic_rnn, поскольку в моем дизайне rnn очень настроен). Я считаю, что tf.while_loop(cond,body.....) может быть кандидатом на мою реализацию. Но официальный урок слишком прост. Я не знаю, как добавить больше функциональности в "тело". Может ли кто-нибудь дать мне несколько более сложных примеров?

Кроме того, в таком случае, если будущий расчет основан на тензорном выходе (например: останов RNN на основе критерия выхода), что является очень распространенным случаем. Есть ли элегантный способ или лучший способ вместо динамического графика?

4b9b3361

Ответ 1

Что мешает вам добавлять больше функциональности в тело? Вы можете построить любой сложный вычислительный граф, который вам нравится в теле, и взять любые входы, которые вам нравятся, из прилагаемого графика. Кроме того, вне цикла вы можете делать все, что хотите, с любыми выходами, которые вы возвращаете. Как видно из количества "битверов", примитивы управления потоком TensorFlow были построены с большой общности. Ниже приведен еще один "простой" пример, если он помогает.

import tensorflow as tf
import numpy as np

def body(x):
    a = tf.random_uniform(shape=[2, 2], dtype=tf.int32, maxval=100)
    b = tf.constant(np.array([[1, 2], [3, 4]]), dtype=tf.int32)
    c = a + b
    return tf.nn.relu(x + c)

def condition(x):
    return tf.reduce_sum(x) < 100

x = tf.Variable(tf.constant(0, shape=[2, 2]))

with tf.Session():
    tf.initialize_all_variables().run()
    result = tf.while_loop(condition, body, [x])
    print(result.eval())