diff --git a/ffn/jax/train.py b/ffn/jax/train.py
index e22eedb..e84276c 100644
--- a/ffn/jax/train.py
+++ b/ffn/jax/train.py
@@ -359,17 +359,22 @@ def _get_ocp_args(
   return DatasetArgs(train_iter)
 
 
-def _make_ckpt_args(state, train_iter: DataIterator) -> ocp.args.CheckpointArgs:
-  return ocp.args.Composite(
-      train_state=ocp.args.StandardSave(state),
-      train_iter=_get_ocp_args(train_iter, restore=False),
-  )
+def _make_ckpt_args(
+    state,
+    train_iter: DataIterator,
+    checkpoint_items: Sequence[str] = ('train_state', 'train_iter'),
+) -> ocp.args.CheckpointArgs:
+  args = {'train_state': ocp.args.StandardSave(state)}
+  if 'train_iter' in checkpoint_items:
+    args['train_iter'] = _get_ocp_args(train_iter, restore=False)
+  return ocp.args.Composite(**args)
 
 
 def train_and_evaluate(
     config: ml_collections.ConfigDict,
     workdir: str,
     data_service_address: str | None = None,
+    checkpoint_items: Sequence[str] = ('train_state', 'train_iter'),
 ):
   """Main training loop."""
   workdir = epath.Path(workdir)
@@ -416,20 +421,54 @@ def train_and_evaluate(
   rng, dropout_rng = jax.random.split(rng)
 
   item_handlers = {}
-  if isinstance(train_iter, tf.data.Iterator):
+  if 'train_iter' in checkpoint_items and isinstance(
+      train_iter, tf.data.Iterator
+  ):
     item_handlers = {'train_iter': DatasetCheckpointHandler('ckpt', True)}
 
   # Checkpointing init.
   checkpoint_dir = epath.Path(workdir) / 'checkpoints'
+
+  options_kwargs = {}
+  if config.get('checkpoint_every_minutes'):
+    options_kwargs['save_decision_policy'] = (
+        ocp.checkpoint_managers.AnySavePolicy([
+            ocp.checkpoint_managers.ContinuousCheckpointingPolicy(
+                minimum_interval_secs=int(config.checkpoint_every_minutes * 60)
+            ),
+            ocp.checkpoint_managers.PreemptionCheckpointingPolicy(),
+        ])
+    )
+  else:
+    options_kwargs['save_interval_steps'] = config.checkpoint_every_steps
+
+  policies = []
+  if config.get('keep_checkpoint_every_minutes'):
+    policies.append(
+        ocp.checkpoint_managers.EveryNSeconds(
+            interval_secs=int(config.keep_checkpoint_every_minutes * 60)
+        )
+    )
+  if config.get('max_checkpoints_to_keep') is not None:
+    policies.append(
+        ocp.checkpoint_managers.LatestN(n=config.max_checkpoints_to_keep)
+    )
+  if policies:
+    options_kwargs['preservation_policy'] = (
+        ocp.checkpoint_managers.AnyPreservationPolicy(policies)
+    )
+
+  checkpoint_options = ocp.CheckpointManagerOptions(**options_kwargs)
+
   checkpoint_manager = ocp.CheckpointManager(
       checkpoint_dir,
-      item_names=('train_state', 'train_iter'),
+      item_names=tuple(checkpoint_items),
       item_handlers=item_handlers,
-      options=ocp.CheckpointManagerOptions(
-          save_interval_steps=config.checkpoint_every_steps
-      ),
+      options=checkpoint_options,
   )
-  checkpointed_state = {'train_state': state, 'train_iter': train_iter}
+  checkpointed_state = {'train_state': state}
+  if 'train_iter' in checkpoint_items:
+    checkpointed_state['train_iter'] = train_iter
   latest_step = checkpoint_manager.latest_step()
   # If an initial checkpoint is provided and the checkpointing library does not
   # report a 'latest' checkpoint, then we are starting a new experiment.
@@ -438,26 +477,30 @@ def train_and_evaluate(
   if config.init_from_cpoint and latest_step is None:
     handler = ocp.StandardCheckpointHandler()
     train_state_path = epath.Path(config.init_from_cpoint) / 'train_state'
-    train_iter_path = epath.Path(config.init_from_cpoint) / 'train_iter'
-
-    if isinstance(train_iter, tf.data.Iterator):
-      iter_handler = item_handlers['train_iter']
-      args = DatasetArgs(train_iter)
 
     checkpointed_state['train_state'] = handler.restore(
         train_state_path, args=ocp.args.StandardRestore(state)
     )
-    checkpointed_state['train_iter'] = iter_handler.restore(
-        train_iter_path, args=args
-    )
+
+    if 'train_iter' in checkpoint_items:
+      train_iter_path = epath.Path(config.init_from_cpoint) / 'train_iter'
+      if isinstance(train_iter, tf.data.Iterator):
+        iter_handler = item_handlers['train_iter']
+        args = DatasetArgs(train_iter)
+        checkpointed_state['train_iter'] = iter_handler.restore(
+            train_iter_path, args=args
+        )
+
     logging.info('Initializing training from %r', config.init_from_cpoint)
   elif latest_step is not None:
+    restore_args = {
+        'train_state': ocp.args.StandardRestore(state),
+    }
+    if 'train_iter' in checkpoint_items:
+      restore_args['train_iter'] = _get_ocp_args(train_iter)
     checkpointed_state = checkpoint_manager.restore(
         latest_step,
-        args=ocp.args.Composite(
-            train_state=ocp.args.StandardRestore(state),
-            train_iter=_get_ocp_args(train_iter),
-        ),
+        args=ocp.args.Composite(**restore_args),
     )
     logging.info('Restored checkpoint for step %d', latest_step)
 
@@ -474,7 +517,8 @@ def train_and_evaluate(
     # with the current setup. Avoid the problem by moving the state to the
    # host.
     state = jax.tree.map(np.array, checkpointed_state['train_state'])
-    train_iter = checkpointed_state['train_iter']
+    if 'train_iter' in checkpoint_items:
+      train_iter = checkpointed_state['train_iter']
 
   initial_step = int(state.step) + 1
   global_batch_size = config.per_device_batch_size * jax.device_count()
@@ -531,8 +575,9 @@ def train_fn(state, batch, loss_scale):
       batch_sharding,  # logits
       replicate_sharding,  # loss scale
   )
-  p_train_step = jax.jit(train_fn, in_shardings=shard_in,
-                         out_shardings=shard_out)
+  p_train_step = jax.jit(
+      train_fn, in_shardings=shard_in, out_shardings=shard_out
+  )
 
   # Initialize summary writer.
   writer = metric_writers.create_default_writer(
@@ -645,7 +690,8 @@ def _reshape(x):
       logging.info('Saving checkpoint at %d.', step)
       train_state = jax.tree.map(np.array, state)
       checkpoint_manager.save(
-          step, args=_make_ckpt_args(train_state, train_iter)
+          step,
+          args=_make_ckpt_args(train_state, train_iter, checkpoint_items),
      )
 
       if checkpoint_manager.reached_preemption(step):
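
For reviewers, a minimal caller sketch of the new surface area. This is an illustration, not part of the patch: the `workdir` path and the non-checkpointing config fields are hypothetical, and only the four checkpointing knobs shown are the ones this change actually reads.

```python
import ml_collections

from ffn.jax import train

config = ml_collections.ConfigDict()
# Step-based cadence; used only when 'checkpoint_every_minutes' is unset.
config.checkpoint_every_steps = 1_000
# Wall-clock cadence; when set, it replaces the step-based save policy.
config.checkpoint_every_minutes = 10
# Preservation: keep one checkpoint per hour, plus the five most recent.
config.keep_checkpoint_every_minutes = 60
config.max_checkpoints_to_keep = 5
# ... model and data fields elided ...

# Checkpoint only the train state; 'train_iter' is then neither saved nor
# restored, and the input pipeline restarts from scratch after a resume.
train.train_and_evaluate(
    config,
    workdir='/tmp/ffn_workdir',  # hypothetical path
    checkpoint_items=('train_state',),
)
```

Note that with `checkpoint_every_minutes` set, `PreemptionCheckpointingPolicy` still forces a save on preemption, which matches the existing `reached_preemption` handling at the end of the training loop.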