Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, model, batcher, vocab):
self._sess = tf.Session(config=util.get_config())

# Load an initial checkpoint to use for decoding
ckpt_path = util.load_ckpt(self._saver, self._sess)
ckpt_path = util.load_ckpt(self._saver, self._sess, load_best=True)

if FLAGS.single_pass:
# Make a descriptive decode directory name
Expand Down
15 changes: 13 additions & 2 deletions util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,23 @@ def get_config():
config.gpu_options.allow_growth=True
return config

def load_ckpt(saver, sess):
def load_ckpt(saver, sess, load_best=False):
"""Load checkpoint from the train directory and restore it to saver and sess, waiting 10 secs in the case of failure. Also returns checkpoint name."""
while True:
try:
ckpt_state = None
if load_best:
eval_dir = os.path.join(FLAGS.log_root, "eval")
if os.path.exists(eval_dir):
try:
ckpt_state = tf.train.get_checkpoint_state(eval_dir, latest_filename="checkpoint_best")
except ValueError:
pass

train_dir = os.path.join(FLAGS.log_root, "train")
ckpt_state = tf.train.get_checkpoint_state(train_dir)
if ckpt_state is None:
ckpt_state = tf.train.get_checkpoint_state(train_dir)

tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
saver.restore(sess, ckpt_state.model_checkpoint_path)
return ckpt_state.model_checkpoint_path
Expand Down