100 changes: 73 additions & 27 deletions ffn/jax/train.py
@@ -359,17 +359,22 @@ def _get_ocp_args(
return DatasetArgs(train_iter)


def _make_ckpt_args(state, train_iter: DataIterator) -> ocp.args.CheckpointArgs:
return ocp.args.Composite(
train_state=ocp.args.StandardSave(state),
train_iter=_get_ocp_args(train_iter, restore=False),
)
def _make_ckpt_args(
state,
train_iter: DataIterator,
checkpoint_items: Sequence[str] = ('train_state', 'train_iter'),
) -> ocp.args.CheckpointArgs:
args = {'train_state': ocp.args.StandardSave(state)}
if 'train_iter' in checkpoint_items:
args['train_iter'] = _get_ocp_args(train_iter, restore=False)
return ocp.args.Composite(**args)


def train_and_evaluate(
config: ml_collections.ConfigDict,
workdir: str,
data_service_address: str | None = None,
checkpoint_items: Sequence[str] = ('train_state', 'train_iter'),
):
"""Main training loop."""
workdir = epath.Path(workdir)
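
Note: a minimal sketch of how the new checkpoint_items parameter could be used to checkpoint only the model state and skip the data iterator; the workdir path and config object below are illustrative and not part of this change.

    # Sketch: checkpoint only the train state; the data iterator is then
    # neither saved nor restored. `config` is an ml_collections.ConfigDict.
    train_and_evaluate(
        config=config,
        workdir='/tmp/ffn_experiment',  # illustrative path
        checkpoint_items=('train_state',),
    )
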
@@ -416,20 +421,54 @@ def train_and_evaluate(
rng, dropout_rng = jax.random.split(rng)

item_handlers = {}
if isinstance(train_iter, tf.data.Iterator):
if 'train_iter' in checkpoint_items and isinstance(
train_iter, tf.data.Iterator
):
item_handlers = {'train_iter': DatasetCheckpointHandler('ckpt', True)}

# Checkpointing init.
checkpoint_dir = epath.Path(workdir) / 'checkpoints'

options_kwargs = {}
if config.get('checkpoint_every_minutes'):
options_kwargs['save_decision_policy'] = (
ocp.checkpoint_managers.AnySavePolicy([
ocp.checkpoint_managers.ContinuousCheckpointingPolicy(
minimum_interval_secs=int(config.checkpoint_every_minutes * 60)
),
ocp.checkpoint_managers.PreemptionCheckpointingPolicy(),
])
)
else:
options_kwargs['save_interval_steps'] = config.checkpoint_every_steps

policies = []
if config.get('keep_checkpoint_every_minutes'):
policies.append(
ocp.checkpoint_managers.EveryNSeconds(
interval_secs=int(config.keep_checkpoint_every_minutes * 60)
)
)
if config.get('max_checkpoints_to_keep') is not None:
policies.append(
ocp.checkpoint_managers.LatestN(n=config.max_checkpoints_to_keep)
)
if policies:
options_kwargs['preservation_policy'] = (
ocp.checkpoint_managers.AnyPreservationPolicy(policies)
)

checkpoint_options = ocp.CheckpointManagerOptions(**options_kwargs)

checkpoint_manager = ocp.CheckpointManager(
checkpoint_dir,
item_names=('train_state', 'train_iter'),
item_names=tuple(checkpoint_items),
item_handlers=item_handlers,
options=ocp.CheckpointManagerOptions(
save_interval_steps=config.checkpoint_every_steps
),
options=checkpoint_options,
)
checkpointed_state = {'train_state': state, 'train_iter': train_iter}
checkpointed_state = {'train_state': state}
if 'train_iter' in checkpoint_items:
checkpointed_state['train_iter'] = train_iter
latest_step = checkpoint_manager.latest_step()
# If an initial checkpoint is provided and the checkpointing library does not
# report a 'latest' checkpoint, then we are starting a new experiment.
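
Note: the new CheckpointManagerOptions read several config fields; a minimal sketch of the relevant settings with illustrative values (only the field names come from this diff).

    # Sketch: checkpoint cadence and retention settings assumed by the code above.
    config = ml_collections.ConfigDict()
    config.checkpoint_every_steps = 1000        # used only if checkpoint_every_minutes is unset
    config.checkpoint_every_minutes = 30        # continuous time-based saving, plus on preemption
    config.keep_checkpoint_every_minutes = 120  # keep one checkpoint per interval (EveryNSeconds)
    config.max_checkpoints_to_keep = 5          # also keep the N most recent (LatestN)
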
@@ -438,26 +477,30 @@
if config.init_from_cpoint and latest_step is None:
handler = ocp.StandardCheckpointHandler()
train_state_path = epath.Path(config.init_from_cpoint) / 'train_state'
train_iter_path = epath.Path(config.init_from_cpoint) / 'train_iter'

if isinstance(train_iter, tf.data.Iterator):
iter_handler = item_handlers['train_iter']
args = DatasetArgs(train_iter)

checkpointed_state['train_state'] = handler.restore(
train_state_path, args=ocp.args.StandardRestore(state)
)
checkpointed_state['train_iter'] = iter_handler.restore(
train_iter_path, args=args
)

if 'train_iter' in checkpoint_items:
train_iter_path = epath.Path(config.init_from_cpoint) / 'train_iter'
if isinstance(train_iter, tf.data.Iterator):
iter_handler = item_handlers['train_iter']
args = DatasetArgs(train_iter)
checkpointed_state['train_iter'] = iter_handler.restore(
train_iter_path, args=args
)

logging.info('Initializing training from %r', config.init_from_cpoint)
elif latest_step is not None:
restore_args = {
'train_state': ocp.args.StandardRestore(state),
}
if 'train_iter' in checkpoint_items:
restore_args['train_iter'] = _get_ocp_args(train_iter)
checkpointed_state = checkpoint_manager.restore(
latest_step,
args=ocp.args.Composite(
train_state=ocp.args.StandardRestore(state),
train_iter=_get_ocp_args(train_iter),
),
args=ocp.args.Composite(**restore_args),
)
logging.info('Restored checkpoint for step %d', latest_step)

@@ -474,7 +517,8 @@
# with the current setup. Avoid the problem by moving the state to the
# host.
state = jax.tree.map(np.array, checkpointed_state['train_state'])
train_iter = checkpointed_state['train_iter']
if 'train_iter' in checkpoint_items:
train_iter = checkpointed_state['train_iter']
initial_step = int(state.step) + 1

global_batch_size = config.per_device_batch_size * jax.device_count()
@@ -531,8 +575,9 @@ def train_fn(state, batch, loss_scale):
batch_sharding, # logits
replicate_sharding, # loss scale
)
p_train_step = jax.jit(train_fn, in_shardings=shard_in,
out_shardings=shard_out)
p_train_step = jax.jit(
train_fn, in_shardings=shard_in, out_shardings=shard_out
)

# Initialize summary writer.
writer = metric_writers.create_default_writer(
@@ -645,7 +690,8 @@ def _reshape(x):
logging.info('Saving checkpoint at %d.', step)
train_state = jax.tree.map(np.array, state)
checkpoint_manager.save(
step, args=_make_ckpt_args(train_state, train_iter)
step,
args=_make_ckpt_args(train_state, train_iter, checkpoint_items),
)

if checkpoint_manager.reached_preemption(step):