Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Internal change #160

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions mesh_tensorflow/transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1903,7 +1903,8 @@ def auto_train_steps(batch_size,


@gin.configurable
def get_checkpoint_iterator(checkpoint_step, model_dir, skip_until=0):
def get_checkpoint_iterator(checkpoint_step, model_dir, skip_until=0,
stop_after=None):
"""Get an iterable of checkpoint paths from a provided checkpoint step(s).

Args:
Expand All @@ -1917,6 +1918,9 @@ def get_checkpoint_iterator(checkpoint_step, model_dir, skip_until=0):
model_dir: str, directory to look for checkpoints in.
skip_until: an integer - for "all" or "None" behavior, filter out
checkpoint numbers that are <= skip_until.
stop_after: an optional integer - for "None behavior, if specified
stop after finding a checkpoint number that is >= stop_at. When a
checkpoint number == stop_at is found, it is yielded before exiting.

Returns:
An iterable which yields checkpoint paths.
Expand Down Expand Up @@ -1957,7 +1961,18 @@ def _filter_fn(p):
return filter(_filter_fn,
[_get_checkpoint_path(s) for s in sorted(list(ckpt_steps))])
elif checkpoint_step is None:
return filter(_filter_fn, tf.train.checkpoints_iterator(model_dir))
checkpoints_iterator = filter(
_filter_fn, tf.train.checkpoints_iterator(model_dir))
if stop_after is not None:
def _generate_checkpoints():
for p in checkpoints_iterator:
step = get_step_from_checkpoint_path(p)
if step <= stop_after:
yield p
if step >= stop_after:
break
return _generate_checkpoints()
return checkpoints_iterator
elif isinstance(checkpoint_step, int):
return [_get_checkpoint_path(_get_closest_checkpoint(checkpoint_step))]
else:
Expand Down