Подтвердить что ты не робот

Как тренировать сеть TensorFlow с использованием генератора для создания входных данных?

TensorFlow docs описывает множество способов чтения данных с использованием TFRecordReader, TextLineReader, QueueRunner и т.д. и

То, что я хотел бы сделать, намного проще: у меня есть функция генератора питона, которая создает бесконечную последовательность данных обучения как (X, y) кортежей (оба представляют собой массивы numpy, а первое измерение - размер партии). Я просто хочу обучить сеть, используя эти данные в качестве входных данных.

Есть ли простой автономный пример обучения сети TensorFlow с использованием генератора, который производит данные? (по строкам примеров MNIST или CIFAR)

4b9b3361

Ответ 1

Предположим, что у вас есть функция, которая генерирует данные:

 def generator(data): 
    ...
    yield (X, y)

Теперь вам нужна другая функция, описывающая вашу модельную архитектуру. Это может быть любая функция, которая обрабатывает X и должна предсказать y как результат (например, нейронную сеть).

Предположим, что ваша функция принимает X и y в качестве входов, каким-то образом вычисляет предсказание для y из X и возвращает функцию потерь (например, кросс-энтропия или MSE в случае регрессии) между y и предсказанным y:

 def neural_network(X, y): 
    # computation of prediction for y using X
    ...
    return loss(y, y_pred)

Чтобы ваша модель работала, вам нужно определить заполнители для X и Y, а затем запустить сеанс:

 X = tf.placeholder(tf.float32, shape=(batch_size, x_dim))
 y = tf.placeholder(tf.float32, shape=(batch_size, y_dim))

Заполнители - это что-то вроде "свободных переменных", которые нужно указать при запуске сеанса feed_dict:

 with tf.Session() as sess:
     # variables need to be initialized before any sess.run() calls
     tf.global_variables_initializer().run()

     for X_batch, y_batch in generator(data):
         feed_dict = {X: X_batch, y: y_batch} 
         _, loss_value, ... = sess.run([train_op, loss, ...], feed_dict)
         # train_op here stands for optimization operation you have defined
         # and loss for loss function (return value of neural_network function)

Надеюсь, вам понравится. Однако имейте в виду, что это не полностью работающая реализация, а скорее псевдокод, поскольку вы не указали почти никаких подробностей.