Подтвердить что ты не робот

Как я могу определить только градиент для подграфа Tensorflow?

Во-первых: я всего несколько дней в Tensorflow, поэтому, пожалуйста, несите меня.

Я начал с кода cifar10 tutorial, и теперь я использую комбинацию сверток и разложения по собственным значениям, которые нарушают символическое дифференцирование. То есть график создается, а затем при вызове train() останавливается script с "Без градиента, определенного для операции [...] (тип op: SelfAdjointEig)". Не удивительно.

Входы на рассматриваемый подграф все еще являются только картами входных функций и используемыми фильтрами, и у меня есть формулы для градиентов под рукой, и они должны быть прямолинейными для реализации с учетом входных данных для подграфа и градиент относительно его выхода.

Из того, что я вижу в документах, я могу зарегистрировать метод градиента для пользовательских Ops с RegisterGradient или переопределить их с помощью экспериментального gradient_override_map. Оба из них должны дать мне доступ к тем вещам, которые мне нужны. Например, поиск в Github. Я нахожу много примеров, которые обращаются к операционным входам как op.input[0] или к такому.

Проблема заключается в том, что я хочу по существу "сократить" целый подграф, а не один операнд, поэтому у меня нет единого ор, чтобы украсить. Поскольку это происходит в одном из сверточных слоев примера cifar, я попытался использовать объект области видимости для этого слоя. Концептуально то, что входит и выходит из этого графика области, является именно тем, что я хочу, если бы я мог каким-то образом переопределить все градиенты области, которые "уже" это сделают.

Я видел tf.Graph.create_op, который (я думаю) мог бы использовать для регистрации нового типа операции, и я мог бы затем переопределить это вычисление градиента типа операции с вышеупомянутыми методами, Но я не вижу способа определить этот проход вперед, не записывая его в С++...

Может быть, я полностью подхожу к этому? Поскольку все мои операции вперед или назад могут быть реализованы с помощью интерфейса python, я, очевидно, хочу избежать реализации чего-либо на С++.

4b9b3361

Ответ 1

Вот трюк от Сергея Иоффе:

Предположим, что вы хотите, чтобы группа ops вела себя как f (x) в прямом режиме, но как g (x) в обратном режиме. Вы реализуете его как

t = g(x)
y = t + tf.stop_gradient(f(x) - t)

Итак, в вашем случае ваш g (x) может быть идентификатором op, с помощью специального градиента с использованием gradient_override_map

Ответ 2

Начиная с TensorFlow 1.7, tf.custom_gradient - это путь.

Ответ 3

Как насчет умножения и деления вместо добавления и вычитания t?

t = g(x)
y = tf.stop_gradient(f(x) / t) * t

Ответ 4

Вот подход, который работает для TensorFlow 2.0. Обратите внимание, что в версии 2.0 мы рады иметь 2 разных алгоритма автодифференцирования: GradientTape для режима ожидания и tf.gradient для режима без ожидания (здесь он называется "ленивый"). Мы демонстрируем, что tf.custom_gradient работает в обоих направлениях.

import tensorflow as tf
assert tf.version.VERSION.startswith('2.')
import numpy as np
from tensorflow.python.framework.ops import disable_eager_execution, enable_eager_execution
from tensorflow.python.client.session import Session

@tf.custom_gradient
def mysquare(x):
  res = x * x
  def _grad(dy):
    return dy * (2*x)
  return res, _grad

def run_eager():
  enable_eager_execution()

  x = tf.constant(np.array([[1,2,3],[4,5,6]]).astype('float32'))
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.reduce_sum(mysquare(x))

    dy_dx = tape.gradient(y,x)
    print('Eager mode')
    print('x:\n',x.numpy())
    print('y:\n',y.numpy())
    print('dy_dx:\n',dy_dx.numpy())


def run_lazy():
  disable_eager_execution()

  x = tf.constant(np.array([[1,2,3],[4,5,6]]).astype('float32'))
  y = tf.reduce_sum(mysquare(x))
  dy_dx = tf.gradients(y,x)

  with Session() as s:
    print('Lazy mode')
    print('x:\n',x.eval(session=s))
    print('y:\n',y.eval(session=s))
    assert len(dy_dx)==1
    print('dy_dx:\n',dy_dx[0].eval(session=s))

if __name__ == '__main__':
  run_eager()
  run_lazy()