Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions flash_mla/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,50 @@
# Adapted from deepseek-ai/FlashMLA(https://github.com/deepseek-ai/FlashMLA)
__version__ = "1.0.1"

# NOTE: flash_mla.flash_mla_interface is intentionally NOT imported at module
# import time. The public names listed in __all__ are the wrapper functions
# defined further down in this module; they import the interface module on
# first call via _load_interface(). An eager import here would make
# `import flash_mla` fail whenever the interface module cannot be imported.
__all__ = ["__version__", "get_mla_metadata", "flash_mla_with_kvcache"]

# Holds the (get_mla_metadata, flash_mla_with_kvcache) pair once
# _load_interface() has imported it; None until then.
_INTERFACE_FUNCS = None


def _load_interface():
    """Return the ``(get_mla_metadata, flash_mla_with_kvcache)`` pair.

    The pair is imported from ``flash_mla.flash_mla_interface`` the first time
    this function succeeds and stored in the module-level ``_INTERFACE_FUNCS``;
    later calls return the stored pair without importing again.

    Raises:
        ImportError: If importing from ``flash_mla.flash_mla_interface`` raises
            ``ImportError``. The original exception is chained as the cause.
    """
    global _INTERFACE_FUNCS

    # Fast path: the pair has already been resolved.
    if _INTERFACE_FUNCS is not None:
        return _INTERFACE_FUNCS

    try:
        from flash_mla.flash_mla_interface import (
            flash_mla_with_kvcache as kvcache_impl,
            get_mla_metadata as metadata_impl,
        )
    except ImportError as import_error:
        raise ImportError(
            "flash_mla_cuda is not available. Build and install FlashMLA from source "
            "in a configured MACA environment before calling FlashMLA kernels."
        ) from import_error

    _INTERFACE_FUNCS = (metadata_impl, kvcache_impl)
    return _INTERFACE_FUNCS


def get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k):
    """Call the ``get_mla_metadata`` implementation obtained from ``_load_interface()``.

    The three arguments are passed through positionally and unchanged, and the
    implementation's return value is returned as-is.

    Raises:
        ImportError: Propagated from ``_load_interface()`` when the interface
            module cannot be imported.
    """
    metadata_impl, _unused = _load_interface()
    result = metadata_impl(cache_seqlens, num_heads_per_head_k, num_heads_k)
    return result


def flash_mla_with_kvcache(
    q,
    k_cache,
    block_table,
    cache_seqlens,
    head_dim_v,
    tile_scheduler_metadata,
    num_splits,
    softmax_scale=None,
    causal=False,
):
    """Call the ``flash_mla_with_kvcache`` implementation from ``_load_interface()``.

    The first seven arguments are forwarded positionally in declaration order;
    ``softmax_scale`` and ``causal`` are forwarded as keyword arguments. The
    implementation's return value is returned as-is.

    Raises:
        ImportError: Propagated from ``_load_interface()`` when the interface
            module cannot be imported.
    """
    _unused, kvcache_impl = _load_interface()
    positional = (
        q,
        k_cache,
        block_table,
        cache_seqlens,
        head_dim_v,
        tile_scheduler_metadata,
        num_splits,
    )
    return kvcache_impl(*positional, softmax_scale=softmax_scale, causal=causal)
72 changes: 72 additions & 0 deletions tests/test_lazy_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import importlib
import inspect
import sys
import types
import unittest
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))


class LazyImportTest(unittest.TestCase):
    """Tests for the lazily resolved interface wrappers exposed by ``flash_mla``."""

    def tearDown(self):
        # Evict both modules so a fake (or None) interface entry installed by a
        # test never leaks into the next one and each test re-imports the package.
        sys.modules.pop("flash_mla.flash_mla_interface", None)
        sys.modules.pop("flash_mla", None)

    def test_import_package_without_compiled_extension(self):
        """Importing the package exposes the version and wrapper signatures."""
        module = importlib.import_module("flash_mla")

        self.assertEqual(module.__version__, "1.0.1")
        self.assertEqual(
            list(inspect.signature(module.get_mla_metadata).parameters),
            ["cache_seqlens", "num_heads_per_head_k", "num_heads_k"],
        )
        self.assertEqual(
            list(inspect.signature(module.flash_mla_with_kvcache).parameters),
            [
                "q",
                "k_cache",
                "block_table",
                "cache_seqlens",
                "head_dim_v",
                "tile_scheduler_metadata",
                "num_splits",
                "softmax_scale",
                "causal",
            ],
        )

    def test_interface_functions_are_cached_after_first_load(self):
        """Both wrappers keep using the functions resolved on the first load."""
        calls = []
        fake_interface = types.ModuleType("flash_mla.flash_mla_interface")

        def fake_get_mla_metadata(*args):
            calls.append(("metadata", args))
            return "metadata"

        def fake_flash_mla_with_kvcache(*args, **kwargs):
            calls.append(("flash", args, kwargs))
            return "flash"

        fake_interface.get_mla_metadata = fake_get_mla_metadata
        fake_interface.flash_mla_with_kvcache = fake_flash_mla_with_kvcache
        sys.modules["flash_mla.flash_mla_interface"] = fake_interface

        module = importlib.import_module("flash_mla")
        # Reset the cache in case the package was already imported and loaded.
        module._INTERFACE_FUNCS = None

        self.assertEqual(module.get_mla_metadata("seq", 2, 3), "metadata")

        # Drop the fake module: from here on a fresh import would no longer
        # yield the fakes, so reaching them proves the cached pair is used.
        sys.modules.pop("flash_mla.flash_mla_interface", None)
        self.assertEqual(module.get_mla_metadata("cached", 4, 5), "metadata")

        # The kvcache wrapper must use the same cached pair and forward the
        # seven leading arguments positionally plus the two keyword options.
        flash_args = ("q", "k_cache", "block_table", "cache_seqlens", 512, "tile_md", 1)
        self.assertEqual(
            module.flash_mla_with_kvcache(*flash_args, softmax_scale=0.5, causal=True),
            "flash",
        )

        self.assertEqual(
            calls,
            [
                ("metadata", ("seq", 2, 3)),
                ("metadata", ("cached", 4, 5)),
                ("flash", flash_args, {"softmax_scale": 0.5, "causal": True}),
            ],
        )

    def test_unavailable_interface_raises_descriptive_import_error(self):
        """A failing interface import surfaces as the package's ImportError."""
        module = importlib.import_module("flash_mla")
        module._INTERFACE_FUNCS = None
        # A None entry in sys.modules makes the import system raise
        # ModuleNotFoundError (an ImportError subclass) for that name.
        sys.modules["flash_mla.flash_mla_interface"] = None

        with self.assertRaisesRegex(ImportError, "flash_mla_cuda is not available"):
            module.get_mla_metadata("seq", 2, 3)

        # A failed load must not populate the cache.
        self.assertIsNone(module._INTERFACE_FUNCS)


# Run this file's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()