Подтвердить что ты не робот

Как работают обновления anano.scan?

theano.scan вернуть две переменные: переменную значений и обновить переменную. Например,

a = theano.shared(1)

values, updates = theano.scan(fn=lambda a:a+1, outputs_info=a,  n_steps=10)

Однако я замечаю, что в большинстве примеров, с которыми я работаю, переменная обновлений пуста. Кажется, только когда мы пишем функцию в theano.scan, это определенно, мы получаем обновления. Например,

a = theano.shared(1)

values, updates = theano.scan(lambda: {a: a+1}, n_steps=10)

Может кто-нибудь объяснить мне, почему в первом примере обновления пусты, но во втором примере переменная обновлений не пуста? и в целом, как работает переменная обновлений в theano.scan? Спасибо.

4b9b3361

Ответ 1

Рассмотрим следующие четыре варианта (этот код может быть выполнен для наблюдения различий) и анализ ниже.

import theano


def v1a():
    a = theano.shared(1)
    outputs, updates = theano.scan(lambda x: x + 1, outputs_info=a, n_steps=10)
    f = theano.function([], outputs=outputs)
    print f(), a.get_value()


def v1b():
    a = theano.shared(1)
    outputs, updates = theano.scan(lambda x: x + 1, outputs_info=a, n_steps=10)
    f = theano.function([], outputs=outputs, updates=updates)
    print f(), a.get_value()


def v2a():
    a = theano.shared(1)
    outputs, updates = theano.scan(lambda: {a: a + 1}, n_steps=10)
    f = theano.function([], outputs=outputs)
    print f(), a.get_value()


def v2b():
    a = theano.shared(1)
    outputs, updates = theano.scan(lambda: {a: a + 1}, n_steps=10)
    f = theano.function([], outputs=outputs, updates=updates)
    print f(), a.get_value()


def main():
    v1a()
    v1b()
    v2a()
    v2b()


main()

Выход этого кода

[ 2  3  4  5  6  7  8  9 10 11] 1
[ 2  3  4  5  6  7  8  9 10 11] 1
[] 1
[] 11

В вариантах v1x используется lambda x: x + 1. результат лямбда-функции - это символическая переменная, значение которой больше 1. Имя параметра функции лямбда было изменено, чтобы избежать затенения имени общей переменной. В этих вариантах общая переменная не используется или каким-либо образом не обрабатывается сканированием, кроме использования ее в качестве начального значения повторяющейся символической переменной, увеличиваемой с помощью функции шага сканирования.

В вариациях v2x используется lambda {a: a + 1}. Результатом лямбда-функции является словарь, в котором объясняется, как обновить общую переменную a.

updates из вариаций v1x пуст, потому что мы не вернули словарь из функции шага, определяющей любые обновления общих переменных. Вариант outputs из v2x пуст, потому что мы не предоставили никакого символьного вывода функции шага. updates используется только в том случае, если функция шага возвращает словарь выражения с расширением общей переменной (как в v2x), а outputs используется только в том случае, если функция шага возвращает выход символьной переменной (как в v1x).

Когда словарь возвращается, он не будет иметь никакого эффекта, если не указан theano.function. Обратите внимание, что общая переменная не обновлена ​​в v2a, но она обновлена ​​в v2b.

Ответ 2

Чтобы дополнить ответ Daniel, если вы хотите одновременно вычислить выходы и обновления в сканировании anano, посмотрите на этот пример.

Этот код перебирает последовательность, вычисляя сумму ее элементов и обновляет общую переменную t (длина предложения)

import theano
import numpy as np

t = theano.shared(0)
s = theano.tensor.vector('v')

def rec(s, first, t):
    first = s + first
    second = s
    return (first, second), {t: t+1}

first = np.float32(0)

(firsts, seconds), updates = theano.scan(
    fn=rec,
    sequences=s,
    outputs_info=[first, None],
    non_sequences=t)

f = theano.function([s], [firsts, seconds], updates=updates, allow_input_downcast=True)

v = np.arange(10)

print f(v)
print t.get_value()

Выход этого кода

[array([  0.,   1.,   3.,   6.,  10.,  15.,  21.,  28.,  36.,  45.], dtype=float32), 
array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.], dtype=float32)]
10
Функция

rec выводит кортеж и словарь. Сканирование по последовательности будет как вычислять выходы, так и добавлять словарь к обновлениям, позволяя одновременно создавать функцию обновления t и вычислять firsts и seconds.