Подтвердить что ты не робот

Индексация тензорного потока с булевым тензором

В numpy, с двумя массивами одинаковой формы, x и y, можно сделать такие фрагменты как y[x > 1]. Как вы достигаете такого же результата в тензорном потоке? y[tf.greater(x, 1)] не работает, а tf.slice не поддерживает ничего подобного. Есть ли способ индексирования с булевым тензором прямо сейчас или в настоящее время не поддерживается?

4b9b3361

Ответ 1

Try:

ones = tf.ones_like(x) # create a tensor all ones
mask = tf.greater(x, ones) # boolean tensor, mask[i] = True iff x[i] > 1
slice_y_greater_than_one = tf.boolean_mask(y, mask)

См. tf.boolean_mask

EDIT: еще один способ (лучше?):

import tensorflow as tf

x = tf.constant([1, 2, 0, 4])
y = tf.Variable([1, 2, 0, 4])
mask = x > 1
slice_y_greater_than_one = tf.boolean_mask(y, mask)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print (sess.run(slice_y_greater_than_one)) # [2 4]

Ответ 2

Я бы не сказал, что он полностью не реализован. Как это для двойного отрицательного?

Tensorflow фактически поддерживает довольно много нарезки и нарезки, хотя синтаксис может быть немного менее симпатичным. Например, если вы хотите создать новый массив, равный y, когда x>1, но равный 0 в противном случае, вы можете это сделать. Проверьте операторы сравнения, например

masked = tf.greater(x,1)
zeros = tf.zeros_like(x)
new_tensor = tf.where(masked, y, zeros)

Если, с другой стороны, вы хотите создать новый массив, содержащий только парней, где x>1 вы можете сделать это, объединив where с функцией gather. Подробности для gather можно найти на

https://www.tensorflow.org/versions/master/api_docs/python/array_ops/slicing_and_joining

PS. Конечно, x>1 не дифференцируема относительно x... tf может быть большой, но она не работает магия:).

Ответ 4

tf.boolean_mask выполняет эту работу, но на некоторых платформах, таких как Raspberry Pi или OSX, операция не поддерживается в распределении колес Tensorflow (проверьте это tf. boolean_mask не поддерживается в OSX. Альтернативой является использование where и gather, как предложил @Jackson Loper. Например:

x = tf.Variable([1, 2, 0, 4])
ix = tf.where(x > 1)
y = tf.gather(x, ix)

with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 print(sess.run(y))