Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions tests/sft/checkpoint_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,34 @@ def setUp(self):
axis_names=('fsdp', 'tp'),
)

def test_handlers_options(self):
  """Checks handler storage options against the configured JAX platforms.

  Builds a CheckpointManager under a per-test directory, collects the
  registered item handlers by item name, and asserts that:
    * both 'model_params' and 'optimizer_state' handlers are registered;
    * OCDBT is off when the JAX platforms string mentions 'proxy' or
      'pathways', and on otherwise;
    * Zarr3 is off in either case.
  """
  manager = checkpoint_manager.CheckpointManager(
      f'{self.temp_path}/{self.id()}'
  )

  # `jax_platforms` may be unset (None); treat that as an empty string.
  configured_platforms = jax.config.jax_platforms or ''
  on_pathways = any(
      marker in configured_platforms for marker in ('proxy', 'pathways')
  )

  # The per-item handlers are only reachable through private attributes.
  root_handler = manager._checkpoint_manager._checkpointer._handler  # pytype: disable=attribute-error
  entries = root_handler._handler_registry.get_all_entries()  # pytype: disable=attribute-error

  # Registry keys are 2-tuples whose first element is the item name.
  handlers = {
      name: item_handler for (name, _), item_handler in entries.items()
  }

  item_names = ('model_params', 'optimizer_state')
  for name in item_names:
    self.assertIn(name, handlers)

  # Choose the OCDBT expectation once, then apply it to every item.
  check_ocdbt = self.assertFalse if on_pathways else self.assertTrue
  for name in item_names:
    check_ocdbt(handlers[name]._use_ocdbt)  # pytype: disable=attribute-error

  for name in item_names:
    self.assertFalse(handlers[name]._use_zarr3)  # pytype: disable=attribute-error

def test_empty_root_directory(self):
cp_manager = checkpoint_manager.CheckpointManager(root_directory=None)
self.assertIsNone(cp_manager.latest_step())
Expand Down Expand Up @@ -299,6 +327,12 @@ def test_restore_with_backward_compatibility(self, ckpt_path):
# The checkpoints in test_data is saved with StandardSave. The test is to
# verify the checkpoint manager with PyTreeRestore can still restore the
# checkpoints saved with StandardSave.
if os.getenv('ENABLE_PATHWAYS_PERSISTENCE', '') == '1':
self.skipTest(
'Pathways persistence cannot read standard backwards-compatible'
' checkpoints.'
)

ckpt_manager = checkpoint_manager.CheckpointManager(
os.path.join(os.path.dirname(__file__), ckpt_path)
)
Expand Down
11 changes: 10 additions & 1 deletion tunix/sft/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ def __init__(
if root_directory is not None:
# When using Pathways, the checkpoint manager only supports persistence
# APIs now.
if 'proxy' in os.getenv('JAX_PLATFORMS', ''):
platforms = jax.config.jax_platforms or ''
if (
'proxy' in platforms
or 'pathways' in platforms
):
item_handlers = {
'model_params': ocp.PyTreeCheckpointHandler(
use_ocdbt=False,
Expand All @@ -65,6 +69,11 @@ def __init__(
logging.info(
'Using persistence API for checkpointing with Pathways.'
)
ocp.pathways.register_type_handlers(
checkpointing_impl=ocp.pathways.CheckpointingImpl.from_options(
use_remote_python=True,
)
)
else:
logging.warning(
'Checkpointing without the persistence API, be aware of potential'
Expand Down
Loading