From 3b0825cd8d2d8a8ae41dd3f3621077a18717f9ba Mon Sep 17 00:00:00 2001
From: Angel Mau
Date: Wed, 13 May 2026 09:08:15 -0700
Subject: [PATCH] Add Pathways support and proper testing.

PiperOrigin-RevId: 914897467
---
Review notes (text in this section is not part of the commit):
  * Line structure of the patch was restored; indentation follows the
    2-space / 80-column style implied by the visible line lengths.
  * test_handlers_options now builds the name -> handler map with a dict
    comprehension and checks each item under subTest; the individual
    assertions are the same as before.
  * Hunk line counts and the diffstat were updated for that revision; the
    post-image blob id on the tests "index" line was not regenerated.

 tests/sft/checkpoint_manager_test.py | 32 ++++++++++++++++++++++++++++++++
 tunix/sft/checkpoint_manager.py      | 11 ++++++++++-
 2 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/tests/sft/checkpoint_manager_test.py b/tests/sft/checkpoint_manager_test.py
index 63d120744..2fb209850 100644
--- a/tests/sft/checkpoint_manager_test.py
+++ b/tests/sft/checkpoint_manager_test.py
@@ -101,6 +101,32 @@ def setUp(self):
         axis_names=('fsdp', 'tp'),
     )
 
+  def test_handlers_options(self):
+    """Verifies OCDBT/Zarr3 options match active platform configuration."""
+    cp_path = f'{self.temp_path}/{self.id()}'
+    cp_manager = checkpoint_manager.CheckpointManager(cp_path)
+
+    platforms = jax.config.jax_platforms or ''
+    is_pathways_or_proxy = 'proxy' in platforms or 'pathways' in platforms
+
+    handler = cp_manager._checkpoint_manager._checkpointer._handler  # pytype: disable=attribute-error
+    registry_entries = handler._handler_registry.get_all_entries()  # pytype: disable=attribute-error
+
+    # Registry keys are (item_name, _) pairs; index the handlers by name.
+    handlers = {name: h for (name, _), h in registry_entries.items()}
+
+    for item_name in ('model_params', 'optimizer_state'):
+      with self.subTest(item_name=item_name):
+        self.assertIn(item_name, handlers)
+        item_handler = handlers[item_name]
+        # OCDBT is expected to be off only on the Pathways/proxy path.
+        if is_pathways_or_proxy:
+          self.assertFalse(item_handler._use_ocdbt)  # pytype: disable=attribute-error
+        else:
+          self.assertTrue(item_handler._use_ocdbt)  # pytype: disable=attribute-error
+        # Zarr3 is expected to be off on every platform.
+        self.assertFalse(item_handler._use_zarr3)  # pytype: disable=attribute-error
+
   def test_empty_root_directory(self):
     cp_manager = checkpoint_manager.CheckpointManager(root_directory=None)
     self.assertIsNone(cp_manager.latest_step())
@@ -299,6 +325,12 @@ def test_restore_with_backward_compatibility(self, ckpt_path):
     # The checkpoints in test_data is saved with StandardSave. The test is to
     # verify the checkpoint manager with PyTreeRestore can still restore the
     # checkpoints saved with StandardSave.
+    if os.getenv('ENABLE_PATHWAYS_PERSISTENCE', '') == '1':
+      self.skipTest(
+          'Pathways persistence cannot read standard backwards-compatible'
+          ' checkpoints.'
+      )
+
     ckpt_manager = checkpoint_manager.CheckpointManager(
         os.path.join(os.path.dirname(__file__), ckpt_path)
     )
diff --git a/tunix/sft/checkpoint_manager.py b/tunix/sft/checkpoint_manager.py
index d069a7706..c4e873390 100644
--- a/tunix/sft/checkpoint_manager.py
+++ b/tunix/sft/checkpoint_manager.py
@@ -50,7 +50,11 @@ def __init__(
     if root_directory is not None:
       # When using Pathways, the checkpoint manager only supports persistence
       # APIs now.
-      if 'proxy' in os.getenv('JAX_PLATFORMS', ''):
+      platforms = jax.config.jax_platforms or ''
+      if (
+          'proxy' in platforms
+          or 'pathways' in platforms
+      ):
         item_handlers = {
             'model_params': ocp.PyTreeCheckpointHandler(
                 use_ocdbt=False,
@@ -65,6 +69,11 @@ def __init__(
         logging.info(
             'Using persistence API for checkpointing with Pathways.'
         )
+        ocp.pathways.register_type_handlers(
+            checkpointing_impl=ocp.pathways.CheckpointingImpl.from_options(
+                use_remote_python=True,
+            )
+        )
       else:
         logging.warning(
             'Checkpointing without the persistence API, be aware of potential'