diff --git a/tests/sft/checkpoint_manager_test.py b/tests/sft/checkpoint_manager_test.py index 63d120744..2fb209850 100644 --- a/tests/sft/checkpoint_manager_test.py +++ b/tests/sft/checkpoint_manager_test.py @@ -101,6 +101,34 @@ def setUp(self): axis_names=('fsdp', 'tp'), ) + def test_handlers_options(self): + """Verifies OCDBT/Zarr3 options match active platform configuration.""" + cp_path = f'{self.temp_path}/{self.id()}' + cp_manager = checkpoint_manager.CheckpointManager(cp_path) + + platforms = jax.config.jax_platforms or '' + is_pathways_or_proxy = 'proxy' in platforms or 'pathways' in platforms + + handler = cp_manager._checkpoint_manager._checkpointer._handler # pytype: disable=attribute-error + registry_entries = handler._handler_registry.get_all_entries() # pytype: disable=attribute-error + + handlers = {} + for (item_name, _), h in registry_entries.items(): + handlers[item_name] = h + + self.assertIn('model_params', handlers) + self.assertIn('optimizer_state', handlers) + + if is_pathways_or_proxy: + self.assertFalse(handlers['model_params']._use_ocdbt) # pytype: disable=attribute-error + self.assertFalse(handlers['optimizer_state']._use_ocdbt) # pytype: disable=attribute-error + else: + self.assertTrue(handlers['model_params']._use_ocdbt) # pytype: disable=attribute-error + self.assertTrue(handlers['optimizer_state']._use_ocdbt) # pytype: disable=attribute-error + + self.assertFalse(handlers['model_params']._use_zarr3) # pytype: disable=attribute-error + self.assertFalse(handlers['optimizer_state']._use_zarr3) # pytype: disable=attribute-error + def test_empty_root_directory(self): cp_manager = checkpoint_manager.CheckpointManager(root_directory=None) self.assertIsNone(cp_manager.latest_step()) @@ -299,6 +327,12 @@ def test_restore_with_backward_compatibility(self, ckpt_path): # The checkpoints in test_data is saved with StandardSave. The test is to # verify the checkpoint manager with PyTreeRestore can still restore the # checkpoints saved with StandardSave. + if os.getenv('ENABLE_PATHWAYS_PERSISTENCE', '') == '1': + self.skipTest( + 'Pathways persistence cannot read standard backwards-compatible' + ' checkpoints.' + ) + ckpt_manager = checkpoint_manager.CheckpointManager( os.path.join(os.path.dirname(__file__), ckpt_path) ) diff --git a/tunix/sft/checkpoint_manager.py b/tunix/sft/checkpoint_manager.py index d069a7706..c4e873390 100644 --- a/tunix/sft/checkpoint_manager.py +++ b/tunix/sft/checkpoint_manager.py @@ -50,7 +50,11 @@ def __init__( if root_directory is not None: # When using Pathways, the checkpoint manager only supports persistence # APIs now. - if 'proxy' in os.getenv('JAX_PLATFORMS', ''): + platforms = jax.config.jax_platforms or '' + if ( + 'proxy' in platforms + or 'pathways' in platforms + ): item_handlers = { 'model_params': ocp.PyTreeCheckpointHandler( use_ocdbt=False, @@ -65,6 +69,11 @@ def __init__( logging.info( 'Using persistence API for checkpointing with Pathways.' ) + ocp.pathways.register_type_handlers( + checkpointing_impl=ocp.pathways.CheckpointingImpl.from_options( + use_remote_python=True, + ) + ) else: logging.warning( 'Checkpointing without the persistence API, be aware of potential'