Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/test_memory_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import sys
from argparse import Namespace
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "tools"))

from estimate_flash_mla_memory import estimate_bytes # noqa: E402


def test_memory_estimator_counts_k_cache_blocks():
    """Check the K-cache and output byte counts for a tiny bf16 configuration."""
    config = {
        "dtype": "bf16",
        "batch_size": 2,
        "s_q": 1,
        "mean_sk": 17,
        "h_q": 4,
        "h_kv": 1,
        "d": 8,
        "dv": 4,
        "block_size": 16,
    }
    result = estimate_bytes(Namespace(**config))

    bf16_bytes = 2
    # batch(2) * blocks per sequence(16) * block_size(16) * h_kv(1) * d(8) elements.
    expected_k_cache = 2 * 16 * 16 * 1 * 8 * bf16_bytes
    # batch(2) * s_q(1) * h_q(4) * dv(4) elements.
    expected_out = 2 * 1 * 4 * 4 * bf16_bytes

    assert result["k_cache"] == expected_k_cache
    assert result["out"] == expected_out
    # The total includes the K-cache, so it can never be smaller than it.
    assert result["total"] >= result["k_cache"]
Comment on lines +25 to +27

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

建议在测试中增加对 out 估算值的断言,以确保其正确使用了 dtype_bytes(在 bf16 下为 2 字节),从而避免后续引入类似的计算错误。

Suggested change
assert estimates["k_cache"] == 2 * 16 * 16 * 1 * 8 * 2
assert estimates["total"] >= estimates["k_cache"]
assert estimates["k_cache"] == 2 * 16 * 16 * 1 * 8 * 2
assert estimates["out"] == 2 * 1 * 4 * 4 * 2
assert estimates["total"] >= estimates["k_cache"]



if __name__ == "__main__":
    # Allow running this test file directly as a script.
    test_memory_estimator_counts_k_cache_blocks()
69 changes: 69 additions & 0 deletions tools/estimate_flash_mla_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/usr/bin/env python3
import argparse
import json
import math


# Bytes occupied by a single element of each supported dtype name.
DTYPE_BYTES = {
    "bf16": 2,
    "fp16": 2,
    "fp32": 4,
}


def _ceil_div(numerator: int, denominator: int) -> int:
    """Return ceil(numerator / denominator) using exact integer arithmetic."""
    # Floor-dividing the negated numerator avoids the float round-trip of
    # math.ceil(a / b), which loses precision for very large integers.
    return -(-numerator // denominator)


def estimate_bytes(args: argparse.Namespace) -> dict[str, int]:
    """Estimate the bytes held by each tensor of a FlashMLA test configuration.

    Args:
        args: Namespace providing ``dtype`` (a key of ``DTYPE_BYTES``) and the
            integer fields ``batch_size``, ``s_q``, ``mean_sk``, ``h_q``,
            ``h_kv``, ``d``, ``dv`` and ``block_size``.

    Returns:
        Byte counts keyed by ``q``, ``k_cache``, ``out``, ``lse``,
        ``block_table`` and ``cache_seqlens``, plus their sum under ``total``.

    Raises:
        KeyError: If ``args.dtype`` is not a key of ``DTYPE_BYTES``.
        ValueError: If ``block_size`` is not positive or any other size field
            is negative.
    """
    dtype_bytes = DTYPE_BYTES[args.dtype]

    # A zero block size would divide by zero below; a negative one would
    # silently produce negative byte counts.
    if args.block_size <= 0:
        raise ValueError(f"block_size must be positive, got {args.block_size}")
    for field in ("batch_size", "s_q", "mean_sk", "h_q", "h_kv", "d", "dv"):
        value = getattr(args, field)
        if value < 0:
            raise ValueError(f"{field} must be non-negative, got {value}")

    # Sequence length is rounded up to a multiple of 256 before being split
    # into cache blocks.
    max_seqlen_pad = _ceil_div(args.mean_sk, 256) * 256
    blocks_per_seq = _ceil_div(max_seqlen_pad, args.block_size)
    num_blocks = args.batch_size * blocks_per_seq

    q = args.batch_size * args.s_q * args.h_q * args.d * dtype_bytes
    k_cache = num_blocks * args.block_size * args.h_kv * args.d * dtype_bytes
    out = args.batch_size * args.s_q * args.h_q * args.dv * dtype_bytes
    # The remaining tensors use a fixed 4 bytes per element regardless of
    # ``dtype`` (presumably float32 / int32 -- confirm against the test harness).
    lse = args.batch_size * args.h_q * args.s_q * 4
    block_table = num_blocks * 4
    cache_seqlens = args.batch_size * 4

    total = q + k_cache + out + lse + block_table + cache_seqlens
    return {
        "q": q,
        "k_cache": k_cache,
        "out": out,
        "lse": lse,
        "block_table": block_table,
        "cache_seqlens": cache_seqlens,
        "total": total,
    }


def _int_at_least(minimum: int):
    """Build an argparse ``type`` callable that accepts integers >= ``minimum``."""

    def convert(text: str) -> int:
        try:
            value = int(text)
        except ValueError:
            # Keep the wording argparse reports for a plain ``type=int`` failure.
            raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
        if value < minimum:
            raise argparse.ArgumentTypeError(
                f"expected an integer >= {minimum}, got {value}"
            )
        return value

    return convert


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the estimator's command-line options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read the process command line, as before.

    Returns:
        Namespace with ``batch_size``, ``s_q``, ``mean_sk``, ``h_q``, ``h_kv``,
        ``d``, ``dv``, ``block_size``, ``dtype`` and ``json``.

    Invalid values (a non-positive ``--block-size``, a negative size, or an
    unknown ``--dtype``) make argparse print a usage error and exit.
    """
    # Sizes may be zero but never negative; the block size is later used as a
    # divisor, so it must be strictly positive.
    non_negative = _int_at_least(0)
    positive = _int_at_least(1)

    parser = argparse.ArgumentParser(description="Estimate FlashMLA test tensor memory.")
    parser.add_argument("--batch-size", type=non_negative, default=128)
    parser.add_argument("--s-q", type=non_negative, default=1)
    parser.add_argument("--mean-sk", type=non_negative, default=4096)
    parser.add_argument("--h-q", type=non_negative, default=16)
    parser.add_argument("--h-kv", type=non_negative, default=1)
    parser.add_argument("--d", type=non_negative, default=576)
    parser.add_argument("--dv", type=non_negative, default=512)
    parser.add_argument("--block-size", type=positive, default=16)
    parser.add_argument("--dtype", choices=sorted(DTYPE_BYTES), default="bf16")
    parser.add_argument("--json", action="store_true", help="Print JSON instead of text.")
    return parser.parse_args(argv)


def main() -> int:
    """Parse the CLI options, print the memory estimate, and return 0.

    With ``--json`` the byte counts plus a ``total_gib`` entry are printed as
    indented JSON with sorted keys. Otherwise every entry is printed in MiB,
    followed by the total in GiB.
    """
    options = parse_args()
    sizes = estimate_bytes(options)
    total_gib = sizes["total"] / 1024**3

    if not options.json:
        for label, size in sizes.items():
            print(f"{label}: {size / 1024**2:.2f} MiB")
        print(f"total_gib: {total_gib:.3f}")
        return 0

    report = {**sizes, "total_gib": total_gib}
    print(json.dumps(report, indent=2, sort_keys=True))
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)