diff --git a/conversion/base.py b/conversion/base.py index d8f050ed32d..e3fa1b45d1e 100644 --- a/conversion/base.py +++ b/conversion/base.py @@ -909,7 +909,13 @@ def prepare_metadata(self, vocab_only: bool): total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count() - self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model_card, self.model_name, total_params) + self.metadata = gguf.Metadata.load( + self.metadata_override, + self.dir_model_card, + self.model_name, + total_params, + preserve_model_size_label=self.model_arch == gguf.MODEL_ARCH.MMPROJ, + ) # If we are using HF model id, set the metadata name to the model id if self.remote_hf_model_id: diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py index e954644e28f..5f045630367 100644 --- a/gguf-py/gguf/metadata.py +++ b/gguf-py/gguf/metadata.py @@ -58,7 +58,7 @@ class Metadata: datasets: Optional[list[dict]] = None @staticmethod - def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata: + def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0, preserve_model_size_label: bool = False) -> Metadata: # This grabs as many contextual authorship metadata as possible from the model repository # making any conversion as required to match the gguf kv store metadata format # as well as giving users the ability to override any authorship metadata that may be incorrect @@ -72,7 +72,7 @@ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Pat # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter # heuristics - metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params) + metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params, preserve_model_size_label) if gen_config: metadata.sampling_sequence = gen_config.get("sequence", metadata.sampling_sequence) @@ -237,7 +237,7 @@ def id_to_title(string): return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()]) @staticmethod - def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]: + def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0, preserve_model_size_label: bool = False) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]: # Huggingface often store model id as '/' # so let's parse it and apply some heuristics if possible for model name components @@ -273,6 +273,23 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int = set[Literal["basename", "size_label", "finetune", "version", "type"]] ] = [set() for _ in name_parts] + def get_size_label_params(label: str) -> float: + return float(label[:-1]) * pow(1000, " KMBT".find(label[-1])) + + def is_active_params_label(label: str) -> bool: + return re.fullmatch(r'A\d+([._]\d+)?[KMBT][\d]?', label, re.IGNORECASE) is not None + + def normalize_size_label(label: str) -> str: + label = label.replace("_", ".") + # Handle weird bloom-7b1 notation + if label[-1].isdecimal(): + label = label[:-2] + "." + label[-1] + label[-2] + # Normalize the size suffixes + if len(label) > 1 and label[-2].isdecimal(): + if label[-1] in "kmbt": + label = label[:-1] + label[-1].upper() + return label + # Annotate the name for i, part in enumerate(name_parts): # Version @@ -284,24 +301,36 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int = name_parts[i] = part.upper() # Model size elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE): - part = part.replace("_", ".") - # Handle weird bloom-7b1 notation - if part[-1].isdecimal(): - part = part[:-2] + "." + part[-1] + part[-2] - # Normalize the size suffixes - if len(part) > 1 and part[-2].isdecimal(): - if part[-1] in "kmbt": - part = part[:-1] + part[-1].upper() + part = normalize_size_label(part) if total_params != 0: try: - label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1])) + label_params = get_size_label_params(part) + next_part = normalize_size_label(name_parts[i + 1]) if i + 1 < len(name_parts) else None + # MoE models may encode both total and active parameters, e.g. 26B-A4B. + followed_by_matching_active_params = ( + next_part is not None + and is_active_params_label(next_part) + and abs(get_size_label_params(next_part[1:]) - abs(total_params)) <= 7 * abs(total_params) // 8 + ) + # When converting an mmproj, total_params is the projector size, not the + # parent model size. Keep the first model-size-looking label from the HF id + # and active parameter labels instead of demoting them to finetune text. + preserve_current_size_label = preserve_model_size_label and ( + is_active_params_label(part) + or ( + part[-1].upper() != "K" + and not any("size_label" in t and any(c.isdecimal() for c in n) for n, t in zip(name_parts[:i], name_types[:i])) + ) + ) # Only use it as a size label if it's close or bigger than the model size # Note that LoRA adapters don't necessarily include all layers, # so this is why bigger label sizes are accepted. # Do not use the size label when it's smaller than 1/8 of the model size - if (total_params < 0 and label_params < abs(total_params) // 8) or ( - # Check both directions when the current model isn't a LoRA adapter - total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8 + if not preserve_current_size_label and not followed_by_matching_active_params and ( + (total_params < 0 and label_params < abs(total_params) // 8) or ( + # Check both directions when the current model isn't a LoRA adapter + total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8 + ) ): # Likely a context length name_types[i].add("finetune") @@ -362,7 +391,7 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int = return model_full_name_component, org_component, basename, finetune, version, size_label @staticmethod - def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata: + def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0, preserve_model_size_label: bool = False) -> Metadata: # Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 # Model Card Heuristics @@ -459,7 +488,7 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str): match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", model_id) if match: model_id_component = match.group(1) - model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params) + model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params, preserve_model_size_label) # Populate model dictionary with extracted components if model_full_name_component is not None: @@ -471,7 +500,7 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str): else: # Likely a Hugging Face ID - model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) + model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params, preserve_model_size_label) # Populate model dictionary with extracted components if model_full_name_component is not None: @@ -517,7 +546,7 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str): match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", dataset_id) if match: dataset_id_component = match.group(1) - dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params) + dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params, preserve_model_size_label) # Populate dataset dictionary with extracted components if dataset_name_component is not None: @@ -529,7 +558,7 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str): else: # Likely a Hugging Face ID - dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params) + dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params, preserve_model_size_label) # Populate dataset dictionary with extracted components if dataset_name_component is not None: @@ -569,7 +598,7 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str): # Use _name_or_path only if its actually a model name and not some computer path # e.g. 'meta-llama/Llama-2-7b-hf' model_id = hf_name_or_path - model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) + model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params, preserve_model_size_label) if metadata.name is None and model_full_name_component is not None: metadata.name = Metadata.id_to_title(model_full_name_component) if metadata.organization is None and org_component is not None: @@ -587,7 +616,7 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str): ############################################ if model_path is not None: model_id = model_path.name - model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) + model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params, preserve_model_size_label) if metadata.name is None and model_full_name_component is not None: metadata.name = Metadata.id_to_title(model_full_name_component) if metadata.organization is None and org_component is not None: diff --git a/gguf-py/tests/test_metadata.py b/gguf-py/tests/test_metadata.py index b77c563ff25..4b298f762cb 100755 --- a/gguf-py/tests/test_metadata.py +++ b/gguf-py/tests/test_metadata.py @@ -62,6 +62,14 @@ def test_get_model_id_components(self): self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-57B-A14B-Instruct"), ('Qwen2-57B-A14B-Instruct', None, 'Qwen2', 'Instruct', None, '57B-A14B')) + # mmproj exports use the parent model id, but total_params is the projector size. + # Preserve the parent model size label instead of treating it as finetune text. + self.assertEqual(gguf.Metadata.get_model_id_components("google/gemma-4-26B-A4B-it", 576 * 10**6, preserve_model_size_label=True), + ('gemma-4-26B-A4B-it', 'google', 'gemma-4', 'it', None, '26B-A4B')) + + self.assertEqual(gguf.Metadata.get_model_id_components("google/gemma-4-31B-it", 576 * 10**6, preserve_model_size_label=True), + ('gemma-4-31B-it', 'google', 'gemma-4', 'it', None, '31B')) + # Check that it can handle a real model id with no version code # Note that 4k in this string is non standard and microsoft were referring to context length rather than weight count self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct", 4 * 10**9),