diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py index c3606a19..505e4d24 100644 --- a/mesh_tensorflow/transformer/utils.py +++ b/mesh_tensorflow/transformer/utils.py @@ -1903,7 +1903,8 @@ def auto_train_steps(batch_size, @gin.configurable -def get_checkpoint_iterator(checkpoint_step, model_dir, skip_until=0): +def get_checkpoint_iterator(checkpoint_step, model_dir, skip_until=0, + stop_after=None): """Get an iterable of checkpoint paths from a provided checkpoint step(s). Args: @@ -1917,6 +1918,9 @@ def get_checkpoint_iterator(checkpoint_step, model_dir, skip_until=0): model_dir: str, directory to look for checkpoints in. skip_until: an integer - for "all" or "None" behavior, filter out checkpoint numbers that are <= skip_until. + stop_after: an optional integer - for "None behavior, if specified + stop after finding a checkpoint number that is >= stop_at. When a + checkpoint number == stop_at is found, it is yielded before exiting. Returns: An iterable which yields checkpoint paths. @@ -1957,7 +1961,18 @@ def _filter_fn(p): return filter(_filter_fn, [_get_checkpoint_path(s) for s in sorted(list(ckpt_steps))]) elif checkpoint_step is None: - return filter(_filter_fn, tf.train.checkpoints_iterator(model_dir)) + checkpoints_iterator = filter( + _filter_fn, tf.train.checkpoints_iterator(model_dir)) + if stop_after is not None: + def _generate_checkpoints(): + for p in checkpoints_iterator: + step = get_step_from_checkpoint_path(p) + if step <= stop_after: + yield p + if step >= stop_after: + break + return _generate_checkpoints() + return checkpoints_iterator elif isinstance(checkpoint_step, int): return [_get_checkpoint_path(_get_closest_checkpoint(checkpoint_step))] else: