Подтвердить что ты не робот

Как я могу реализовать пользовательский RNN (в частности, ESN) в Tensorflow?

Я пытаюсь определить свой собственный RNNCell (Echo State Network) в Tensorflow, согласно ниже определению.

x (t + 1) = tanh (Win * u (t) + W * x (t) + Wfb * y (t))

y (t) = Wout * z (t)

z (t) = [x (t), u (t)]

x - состояние, u вводится, y выводится. Win, W и Wfb не являются обучаемыми. Все веса случайным образом инициализируются, но W модифицируется следующим образом: "Установите определенный процент элементов W в 0, масштабируйте W, чтобы сохранить его спектральный радиус ниже 1,0

У меня есть этот код для генерации уравнения.

x = tf.Variable(tf.reshape(tf.zeros([N]), [-1, N]), trainable=False, name="state_vector")
W = tf.Variable(tf.random_normal([N, N], 0.0, 0.05), trainable=False)
# TODO: setup W according to the ESN paper
W_x = tf.matmul(x, W)

u = tf.placeholder("float", [None, K], name="input_vector")
W_in = tf.Variable(tf.random_normal([K, N], 0.0, 0.05), trainable=False)
W_in_u = tf.matmul(u, W_in)

z = tf.concat(1, [x, u])
W_out = tf.Variable(tf.random_normal([K + N, L], 0.0, 0.05))
y = tf.matmul(z, W_out)
W_fb = tf.Variable(tf.random_normal([L, N], 0.0, 0.05), trainable=False)
W_fb_y = tf.matmul(y, W_fb)

x_next = tf.tanh(W_in_u + W_x + W_fb_y)

y_ = tf.placeholder("float", [None, L], name="train_output")

Моя проблема в два раза. Сначала я не знаю, как реализовать это как суперкласс RNNCell. Во-вторых, я не знаю, как создать тензор W в соответствии с вышеприведенной спецификацией.

Любая помощь по любому из этих вопросов очень ценится. Может быть, я могу выяснить способ подготовки W, но я уверен, что, черт возьми, я не понимаю, как реализовать свой собственный RNN как суперкласс RNNCell.

4b9b3361

Ответ 1

Чтобы дать краткое описание:

Посмотрите в исходном коде TensorFlow в разделе python/ops/rnn_cell.py, также посмотрите, как подкласс RNNCell. Это обычно так:

class MyRNNCell(RNNCell):
  def __init__(...):

  @property
  def output_size(self):
  ...

  @property
  def state_size(self):
  ...

  def __call__(self, input_, state, name=None):
     ... your per-step iteration here ...