From 722f140f6c89be634c5aeb11a3c1bf1a5dd561e8 Mon Sep 17 00:00:00 2001 From: papertager <2567587994@qq.com> Date: Fri, 5 Jun 2026 11:15:03 +0800 Subject: [PATCH 1/2] Lazy load FlashMLA extension --- flash_mla/__init__.py | 26 ++++++++++++++++++++++---- tests/test_lazy_import.py | 14 ++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) create mode 100644 tests/test_lazy_import.py diff --git a/flash_mla/__init__.py b/flash_mla/__init__.py index e37cd19..352d51d 100644 --- a/flash_mla/__init__.py +++ b/flash_mla/__init__.py @@ -1,7 +1,25 @@ # Adapted from deepseek-ai/FlashMLA(https://github.com/deepseek-ai/FlashMLA) __version__ = "1.0.1" -from flash_mla.flash_mla_interface import( - get_mla_metadata, - flash_mla_with_kvcache -) +__all__ = ["__version__", "get_mla_metadata", "flash_mla_with_kvcache"] + + +def _load_interface(): + try: + from flash_mla.flash_mla_interface import flash_mla_with_kvcache, get_mla_metadata + except ImportError as exc: + raise ImportError( + "flash_mla_cuda is not available. Build and install FlashMLA from source " + "in a configured MACA environment before calling FlashMLA kernels." + ) from exc + return get_mla_metadata, flash_mla_with_kvcache + + +def get_mla_metadata(*args, **kwargs): + metadata_func, _ = _load_interface() + return metadata_func(*args, **kwargs) + + +def flash_mla_with_kvcache(*args, **kwargs): + _, flash_func = _load_interface() + return flash_func(*args, **kwargs) diff --git a/tests/test_lazy_import.py b/tests/test_lazy_import.py new file mode 100644 index 0000000..4153043 --- /dev/null +++ b/tests/test_lazy_import.py @@ -0,0 +1,14 @@ +import importlib +import unittest + + +class LazyImportTest(unittest.TestCase): + def test_import_package_without_compiled_extension(self): + module = importlib.import_module("flash_mla") + + self.assertEqual(module.__version__, "1.0.1") + self.assertTrue(callable(module.get_mla_metadata)) + + +if __name__ == "__main__": + unittest.main() From 1b5be744b3faffc237a4d5e457f286a8a5ff158a Mon Sep 17 00:00:00 2001 From: papertager <2567587994@qq.com> Date: Mon, 22 Jun 2026 13:39:07 +0800 Subject: [PATCH 2/2] Refine lazy import wrappers --- flash_mla/__init__.py | 49 ++++++++++++++++++++++++-------- tests/test_lazy_import.py | 60 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 96 insertions(+), 13 deletions(-) diff --git a/flash_mla/__init__.py b/flash_mla/__init__.py index 352d51d..9513b14 100644 --- a/flash_mla/__init__.py +++ b/flash_mla/__init__.py @@ -3,23 +3,48 @@ __all__ = ["__version__", "get_mla_metadata", "flash_mla_with_kvcache"] +_INTERFACE_FUNCS = None + def _load_interface(): - try: - from flash_mla.flash_mla_interface import flash_mla_with_kvcache, get_mla_metadata - except ImportError as exc: - raise ImportError( - "flash_mla_cuda is not available. Build and install FlashMLA from source " - "in a configured MACA environment before calling FlashMLA kernels." - ) from exc - return get_mla_metadata, flash_mla_with_kvcache + global _INTERFACE_FUNCS + if _INTERFACE_FUNCS is None: + try: + from flash_mla.flash_mla_interface import flash_mla_with_kvcache, get_mla_metadata + except ImportError as exc: + raise ImportError( + "flash_mla_cuda is not available. Build and install FlashMLA from source " + "in a configured MACA environment before calling FlashMLA kernels." + ) from exc + _INTERFACE_FUNCS = (get_mla_metadata, flash_mla_with_kvcache) + return _INTERFACE_FUNCS -def get_mla_metadata(*args, **kwargs): +def get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k): metadata_func, _ = _load_interface() - return metadata_func(*args, **kwargs) + return metadata_func(cache_seqlens, num_heads_per_head_k, num_heads_k) -def flash_mla_with_kvcache(*args, **kwargs): +def flash_mla_with_kvcache( + q, + k_cache, + block_table, + cache_seqlens, + head_dim_v, + tile_scheduler_metadata, + num_splits, + softmax_scale=None, + causal=False, +): _, flash_func = _load_interface() - return flash_func(*args, **kwargs) + return flash_func( + q, + k_cache, + block_table, + cache_seqlens, + head_dim_v, + tile_scheduler_metadata, + num_splits, + softmax_scale=softmax_scale, + causal=causal, + ) diff --git a/tests/test_lazy_import.py b/tests/test_lazy_import.py index 4153043..cad834b 100644 --- a/tests/test_lazy_import.py +++ b/tests/test_lazy_import.py @@ -1,13 +1,71 @@ import importlib +import inspect +import sys +import types import unittest +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) class LazyImportTest(unittest.TestCase): + def tearDown(self): + sys.modules.pop("flash_mla.flash_mla_interface", None) + sys.modules.pop("flash_mla", None) + def test_import_package_without_compiled_extension(self): module = importlib.import_module("flash_mla") self.assertEqual(module.__version__, "1.0.1") - self.assertTrue(callable(module.get_mla_metadata)) + self.assertEqual( + list(inspect.signature(module.get_mla_metadata).parameters), + ["cache_seqlens", "num_heads_per_head_k", "num_heads_k"], + ) + self.assertEqual( + list(inspect.signature(module.flash_mla_with_kvcache).parameters), + [ + "q", + "k_cache", + "block_table", + "cache_seqlens", + "head_dim_v", + "tile_scheduler_metadata", + "num_splits", + "softmax_scale", + "causal", + ], + ) + + def test_interface_functions_are_cached_after_first_load(self): + calls = [] + fake_interface = types.ModuleType("flash_mla.flash_mla_interface") + + def fake_get_mla_metadata(*args): + calls.append(("metadata", args)) + return "metadata" + + def fake_flash_mla_with_kvcache(*args, **kwargs): + calls.append(("flash", args, kwargs)) + return "flash" + + fake_interface.get_mla_metadata = fake_get_mla_metadata + fake_interface.flash_mla_with_kvcache = fake_flash_mla_with_kvcache + sys.modules["flash_mla.flash_mla_interface"] = fake_interface + + module = importlib.import_module("flash_mla") + module._INTERFACE_FUNCS = None + + self.assertEqual(module.get_mla_metadata("seq", 2, 3), "metadata") + sys.modules.pop("flash_mla.flash_mla_interface", None) + self.assertEqual(module.get_mla_metadata("cached", 4, 5), "metadata") + + self.assertEqual( + calls, + [ + ("metadata", ("seq", 2, 3)), + ("metadata", ("cached", 4, 5)), + ], + ) if __name__ == "__main__":