Подтвердить что ты не робот

Получение общего количества записей из файла .tfrecords в Tensorflow

Можно ли получить общее количество записей из файла .tfrecords? В связи с этим, как обычно отслеживается количество эпох, прошедших во время учебных моделей? Хотя мы можем указать batch_size и num_of_epochs, я не уверен, что получить такие значения, как current epoch, количество партий в эпоху и т.д., Просто, чтобы я мог больше контролировать как проходит обучение. В настоящее время я просто использую грязный хак, чтобы вычислить это, поскольку я знаю перед собой, сколько записей есть в моем файле .tfrecords и размере моих мини-абзацев. Цените любую помощь..

4b9b3361

Ответ 1

Чтобы подсчитать количество записей, вы можете использовать tf.python_io.tf_record_iterator.

c = 0
for fn in tf_records_filenames:
  for record in tf.python_io.tf_record_iterator(fn):
     c += 1

Чтобы просто отслеживать обучение модели, tensorboard пригодится.

Ответ 2

Нет, это невозможно. TFRecord не хранит метаданные о хранящихся внутри. Этот файл

представляет последовательность (двоичных) строк. Формат не случайный доступ, поэтому он подходит для потоковой передачи больших объемов данных, но не если требуется быстрый осколок или другой не последовательный доступ.

Если вы хотите, вы можете сохранить эти метаданные вручную или использовать record_iterator, чтобы получить номер (вам нужно будет перебирать все записи, которые у вас есть:

sum(1 for _ in tf.python_io.tf_record_iterator(file_name))

Если вы хотите узнать текущую эпоху, вы можете сделать это либо с тензора или путем печати номера из цикла.

Ответ 3

В соответствии с предупреждением об устаревании tf_record_iterator мы также можем использовать активное выполнение для подсчета записей.

#!/usr/bin/env python
from __future__ import print_function

import tensorflow as tf
import sys

assert len(sys.argv) == 2, \
    "USAGE: {} <file_glob>".format(sys.argv[0])

tf.enable_eager_execution()

input_pattern = sys.argv[1]

# Expand glob if there is one
input_files = tf.io.gfile.glob(input_pattern)

# Create the dataset
data_set = tf.data.TFRecordDataset(input_files)

# Count the records
records_n = sum(1 for record in data_set)

print("records_n = {}".format(records_n))

Ответ 4

Поскольку tf.io.tf_record_iterator устарела, великий ответ Сальвадора Дали должен теперь читать

tf.enable_eager_execution()
sum(1 for _ in tf.data.TFRecordDataset(file_name))