Подтвердить что ты не робот

Сглаживание партии в тензорном потоке

У меня есть вход для тензорного потока формы [None, 9, 2] (где None является пакетным).

Для выполнения дальнейших действий (например, matmul) на нем мне нужно преобразовать его в форму [None, 18]. Как это сделать?

4b9b3361

Ответ 1

Вы можете сделать это легко с помощью tf.reshape(), не зная размер партии.

x = tf.placeholder(tf.float32, shape=[None, 9,2])
shape = x.get_shape().as_list()        # a list: [None, 9, 2]
dim = numpy.prod(shape[1:])            # dim = prod(9,2) = 18
x2 = tf.reshape(x, [-1, dim])           # -1 means "all"

-1 в последней строке означает весь столбец, вне зависимости от того, какой размер пакета находится во время выполнения. Вы можете увидеть его в tf.reshape().


Обновление: форма = [Нет, 3, Нет]

Спасибо @kbrose. Для случаев, когда более 1 размерности undefined, мы можем использовать tf.shape() с tf.reduce_prod() в качестве альтернативы.

x = tf.placeholder(tf.float32, shape=[None, 3, None])
dim = tf.reduce_prod(tf.shape(x)[1:])
x2 = tf.reshape(x, [-1, dim])

tf.shape() возвращает тензодамент формы, который можно оценить во время выполнения. Разницу между tf.get_shape() и tf.shape() можно увидеть в документе.

Я также попробовал tf.contrib.layers.flatten() в другом. Это проще всего для первого случая, но он не может обрабатывать второй.

Ответ 2

flat_inputs = tf.layers.flatten(inputs)

Ответ 3

Вы можете использовать динамическое преобразование, чтобы получить значение размера партии через tf.batch во время выполнения, вычислить весь набор новых измерений в tf.reshape. Здесь приведен пример преобразования плоского списка в квадратную матрицу без знания длины списка.

tf.reset_default_graph()
sess = tf.InteractiveSession("")
a = tf.placeholder(dtype=tf.int32)
# get [9]
ashape = tf.shape(a)
# slice the list from 0th to 1st position
ashape0 = tf.slice(ashape, [0], [1])
# reshape list to scalar, ie from [9] to 9
ashape0_flat = tf.reshape(ashape0, ())
# tf.sqrt doesn't support int, so cast to float
ashape0_flat_float = tf.to_float(ashape0_flat)
newshape0 = tf.sqrt(ashape0_flat_float)
# convert [3, 3] Python list into [3, 3] Tensor
newshape = tf.pack([newshape0, newshape0])
# tf.reshape doesn't accept float, so convert back to int
newshape_int = tf.to_int32(newshape)
a_reshaped = tf.reshape(a, newshape_int)
sess.run(a_reshaped, feed_dict={a: np.ones((9))})

Вы должны увидеть

array([[1, 1, 1],
       [1, 1, 1],
       [1, 1, 1]], dtype=int32)