Подтвердить что ты не робот

Tensorflow: как получить все переменные из rnn_cell.BasicLSTM & rnn_cell.MultiRNNCell

У меня есть настройка, где мне нужно инициализировать LSTM после основной инициализации, которая использует tf.initialize_all_variables(). То есть Я хочу позвонить tf.initialize_variables([var_list])

Есть ли способ собрать все внутренние обучаемые переменные для обоих:

  • rnn_cell.BasicLSTM
  • rnn_cell.MultiRNNCell

чтобы я мог инициализировать JUST эти параметры?

Основная причина, по которой я хочу это, состоит в том, что я не хочу повторно инициализировать некоторые обучаемые значения из более ранних версий.

4b9b3361

Ответ 1

Самый простой способ решить вашу проблему - использовать область переменной. Имена переменных в пределах области будут иметь префикс с именем. Вот короткий фрагмент:

cell = rnn_cell.BasicLSTMCell(num_nodes)

with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    output[i], state = cell(input_data[i], state)

  # Retrieve just the LSTM variables.
  lstm_variables = [v for v in tf.all_variables()
                    if v.name.startswith(vs.name)]

# [..]
# Initialize the LSTM variables.
tf.initialize_variables(lstm_variables)

Он будет работать аналогично с MultiRNNCell.

EDIT: изменено tf.trainable_variables на tf.all_variables()

Ответ 2

Вы также можете использовать tf.get_collection():

cell = rnn_cell.BasicLSTMCell(num_nodes)
with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    output[i], state = cell(input_data[i], state)

  lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)

(частично скопирован из ответа Рафаля)

Обратите внимание, что последняя строка эквивалентна пониманию списка в коде Rafal.

В принципе, tensorflow хранит глобальный набор переменных, который может быть выбран с помощью tf.all_variables() или tf.get_collection(tf.GraphKeys.VARIABLES). Если вы укажете scope (имя области) в tf.get_collection(), то вы получите только тензоры (переменные в этом случае) в коллекции чьи области находятся под указанной областью.

EDIT: Вы можете также использовать tf.GraphKeys.TRAINABLE_VARIABLES для получения только обучаемых переменных. Но так как vanilla BasicLSTMCell не инициализирует какую-либо не обучаемую переменную, обе будут функционально эквивалентными. Для получения полного списка коллекций графов по умолчанию проверьте этот вне.