Подтвердить что ты не робот

Как получить global_step при восстановлении контрольных точек в Tensorflow?

Я сохраняю состояние сеанса следующим образом:

self._saver = tf.saver()
self._saver.save(self._session, '/network', global_step=self._time)

Когда я потом восстанавливаю, я хочу получить значение global_step для контрольной точки, из которой я восстанавливаю. Это делается для того, чтобы установить из него некоторые гиперпараметры.

Хакерный способ сделать это - пропустить и проанализировать имена файлов в каталоге контрольной точки. Но угрюмый должен быть лучше, встроенный способ сделать это?

4b9b3361

Ответ 1

Общий шаблон должен иметь переменную global_step для отслеживания шагов

global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)

Затем вы можете сохранить с помощью

saver.save(sess, save_path, global_step=global_step)

При восстановлении восстанавливается значение global_step

Ответ 2

Это немного взломать, но другие ответы не срабатывали для меня вообще

ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 

#Extract from checkpoint filename
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])

Обновление 9/2017

Я не уверен, что это начало работать из-за обновлений, но следующий метод кажется эффективным для того, чтобы global_step обновлялся и загружался должным образом:

Создайте два ops. Один для хранения global_step и другого для его увеличения:

    global_step = tf.Variable(0, trainable=False, name='global_step')
    increment_global_step = tf.assign_add(global_step,1,
                                            name = 'increment_global_step')

Теперь в вашем цикле обучения запустите инкремент op каждый раз, когда вы запустите свой тренировочный op.

sess.run([train_op,increment_global_step],feed_dict=feed_dict)

Если вы хотите получить глобальное значение шага как целое число в любой точке, просто используйте следующую команду после загрузки модели:

sess.run(global_step)

Это может быть полезно для создания имен файлов или расчета того, какова ваша текущая эпоха, без второй переменной tensorflow Variable для хранения этого значения. Например, вычисление текущей эпохи при загрузке будет выглядеть примерно так:

loaded_epoch = sess.run(global_step)//(batch_size*num_train_records)

Ответ 3

У меня была та же проблема, что и Lawrence Du, я не смог найти способ получить global_step, восстановив модель. Поэтому я применил его взлома к начальный код обучения v3 в Tensorflow/models github repo Я использую. В приведенном ниже коде также содержится исправление, связанное с pretrained_model_checkpoint_path.

Если у вас есть лучшее решение или знаете, что мне не хватает, оставьте комментарий!

В любом случае этот код работает для меня:

...

# When not restoring start at 0
last_step = 0
if FLAGS.pretrained_model_checkpoint_path:
    # A model consists of three files, use the base name of the model in
    # the checkpoint path. E.g. my-model-path/model.ckpt-291500
    #
    # Because we need to give the base name you can't assert (will always fail)
    # assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)

    variables_to_restore = tf.get_collection(
        slim.variables.VARIABLES_TO_RESTORE)
    restorer = tf.train.Saver(variables_to_restore)
    restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
    print('%s: Pre-trained model restored from %s' %
          (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

    # HACK : global step is not restored for some unknown reason
    last_step = int(os.path.basename(FLAGS.pretrained_model_checkpoint_path).split('-')[1])

    # assign to global step
    sess.run(global_step.assign(last_step))

...

for step in range(last_step + 1, FLAGS.max_steps):

  ...

Ответ 4

TL; DR

Как переменная tenorflow (будет оцениваться в сеансе)

global_step = tf.train.get_or_create_global_step()
# use global_step variable to calculate your hyperparameter 
# this variable will be evaluated later in the session
saver = tf.train.Saver()
with tf.Session() as sess:
    # restore all variables from checkpoint
    saver.restore(sess, checkpoint_path)
    # than init table and local variables and start training/evaluation ...

Или: как numpy integer (без какой-либо сессии):

reader = tf.train.NewCheckpointReader(absolute_checkpoint_path)
global_step = reader.get_tensor('global_step')


Длинный ответ

Существует как минимум два способа получения глобальных данных с контрольной точки. Как переменная tenorflow или как целое число numpy. Синтаксическое имя файла не будет работать, если global_step не был представлен в качестве параметра в save способе Saver. Для предварительно обученных моделей см. Примечание в конце ответа.

Как переменная Tensorflow

Если вам нужна переменная global_step для вычисления некоторых гиперпараметров, вы можете просто использовать tf.train.get_or_create_global_step(). Это вернет переменную tenorflow. Поскольку переменная будет оценена позже в сеансе, вы можете использовать только тензорные операции для вычисления ваших гиперпараметров. Так, например: max(global_step, 100) не будет работать. Вы должны использовать тензор потока, эквивалентный tf.maximum(global_step, 100) который может быть оценен позже в сеансе.

В течение сеанса вы можете инициализировать глобальную переменную шага с помощью контрольной точки, используя saver.restore(sess, checkpoint_path)

global_step = tf.train.get_or_create_global_step()
# use global_step variable to calculate your hyperparameter 
# this variable will be evaluated later in the session
hyper_parameter = tf.maximum(global_step, 100) 
saver = tf.train.Saver()
with tf.Session() as sess:
    # restore all variables from checkpoint
    saver.restore(sess, checkpoint_path)
    # than init table and local variables and start training/evaluation ...

    # for verification you can print the global step and your hyper parameter
    print(sess.run([global_step, hyper_parameter]))

Или: как целое число numpy (без сессии)

Если вам нужна глобальная переменная шага в виде скаляра без запуска сеанса, вы также можете прочитать эту переменную непосредственно из файла (ов) контрольных точек. Вам просто нужен NewCheckpointReader. Из-за ошибки в более старых версиях тензорного потока вы должны преобразовать путь файла контрольных точек в абсолютный путь. С помощью ридера вы можете получить все тензоры модели в виде числовых переменных. Имя глобальной переменной шага является константной строкой tf.GraphKeys.GLOBAL_STEP определенной как 'global_step'.

absolute_checkpoint_path = os.path.abspath(checkpoint_path)
reader = tf.train.NewCheckpointReader(absolute_checkpoint_path)
global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)

Примечание для моделей с предварительной подготовкой: в большинстве моделей с предварительной подготовкой, доступных в Интернете, глобальный шаг сбрасывается на ноль. Таким образом, эти модели можно использовать для инициализации параметров модели для тонкой настройки без перезаписи глобального шага.

Ответ 5

Текущая версия 0.10rc0, по-видимому, отличается, и нет tf.saver(). Теперь это tf.train.Saver(). Кроме того, команда save добавляет информацию к имени файла save_path для global_step, поэтому мы не можем просто вызвать восстановление на том же пути save_path, поскольку это не фактический файл сохранения.

Самый простой способ, который я вижу сейчас, - использовать SessionManager вместе с заставкой вроде этого:

my_checkpoint_dir = "/tmp/checkpoint_dir"
# make a saver to use with SessionManager for restoring
saver = tf.train.Saver()
# Build an initialization operation to run below.
init = tf.initialize_all_variables()
# use a SessionManager to help with automatic variable restoration
sm = tf.train.SessionManager()
# try to find the latest checkpoint in my_checkpoint_dir, then create a session with that restored
# if no such checkpoint, then call the init_op after creating a new session
sess = sm.prepare_session("", init_op=init, saver=saver, checkpoint_dir=my_checkpoint_dir))

Что это. Теперь у вас есть сеанс, который либо восстановлен из my_checkpoint_dir (убедитесь, что каталог существует до его вызова), либо там нет контрольной точки, он создает новый сеанс и вызывает init_op для инициализации ваших переменных.

Когда вы хотите сохранить, вы просто сохраняете любое имя, которое вы хотите в этом каталоге, и передаете global_step. Здесь пример, где я сохраняю переменную шага в цикле как global_step, поэтому он возвращается к этой точке, если вы убиваете программу и перезапускаете ее, чтобы восстановить контрольную точку:

checkpoint_path = os.path.join(my_checkpoint_dir, 'model.ckpt')
saver.save(sess, checkpoint_path, global_step=step)

Это создает файлы в my_checkpoint_dir, такие как "model.ckpt-1000", где 1000 - это global_step. Он продолжает работать, а затем вы больше похожи на "model.ckpt-2000". При запуске программы SessionManager поднимает последнюю из них. Путь checkpoint_path может быть любым желаемым именем файла, если он находится в checkpoint_dir. Save() создаст этот файл с добавлением global_step (как показано выше). Он также создает индексный файл "контрольной точки", в результате которого SessionManager обнаруживает последнюю контрольную точку сохранения.

Ответ 6

просто обратите внимание на мое решение о глобальном сохранении и восстановлении шагов.

Сохранить:

global_step = tf.Variable(0, trainable=False, name='global_step')
saver.save(sess, model_path + model_name, global_step=_global_step)

Восстановить:

if os.path.exists(model_path):
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    print("Model restore finished, current globle step: %d" % global_step.eval())

Ответ 7

Вы можете использовать переменную global_step для отслеживания шагов, но если в вашем коде вы инициализируете или присваиваете это значение другой переменной step, оно может быть непоследовательным.

Например, вы определяете ваш global_step используя:

global_step = tf.Variable(0, name='global_step', trainable=False)

Назначьте вашей учебной операции:

train_op = optimizer.minimize(loss, global_step=global_step)

Сохранить в вашем контрольном пункте:

saver.save(sess, checkpoint_path, global_step=global_step)

И восстановить с вашего контрольного пункта:

saver.restore(sess, checkpoint_path) 

значение global_step восстанавливается, но если вы присваиваете его другой переменной, скажем, step, то вы должны сделать что-то вроде:

step = global_step.eval(session=sess)

Переменная step, содержит последний сохраненный global_step в контрольной точке.

Было бы неплохо также определить global_step из графа, а не как нулевую переменную (как определено ранее):

global_step = tf.train.get_or_create_global_step()

Это получит ваш последний global_step если он существует, или создаст его, если нет.

Ответ 8

Причина того, что переменная не была восстановлена должным образом, наиболее вероятна из-за того, что она была создана после создания вашего tf.Saver().

Место, где вы создаете объект tf.Saver() имеет значение, если вы явно не указали var_list или не указали None для var_list. Ожидаемое поведение для многих программистов состоит в том, что все переменные в графе сохраняются при вызове метода save(), но это не тот случай, и, возможно, его следует документировать как таковой. Снимок всех переменных на графике сохраняется во время создания объекта.

Если у вас нет каких-либо проблем с производительностью, безопаснее всего создать объект-заставку, когда вы решите сохранить свой прогресс. В противном случае обязательно создайте объект-заставку после создания всех ваших переменных.

Кроме того, global_step который передается в saver.save(sess, save_path, global_step=global_step) является просто счетчиком, используемым для создания имени файла, и не имеет никакого отношения к тому, будет ли оно восстановлено как переменная global_step. Это неправильный параметр IMO, поскольку, если вы сохраняете свой прогресс в конце каждой эпохи, вероятно, лучше передать номер эпохи для этого параметра.