Global_scale is passed but not applied in MXFP quantization, yet is used in GEMM. Could this cause numerical inequivalence? #21

@manyizhang

In FP-Quant/inference_lib/src/fp_quant/module/linear_fns.py, the forward_quantize function receives a global_scale argument. However, in the MXFP4 branch this argument is never passed to fused_quantize_mx_op, so it is not applied during quantization.

def forward_quantize(
    x: torch.Tensor,
    hadamard_matrix: torch.Tensor,
    global_scale: torch.Tensor,
    dtype: FPQuantDtype,
    forward_method: str,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if dtype == FPQuantDtype.MXFP4:
        qweight, scales, mask = fused_quantize_mx_op(
            x.to(torch.bfloat16),
            hadamard_matrix.to(torch.bfloat16),
            forward_method,
            forward_method == "quest" and x.requires_grad,
        )
        return qweight, scales, mask

Later, in the forward_gemm function, the same global_scale is used as the alpha scaling factor for the GEMM operation.

def forward_gemm(x_q, w_q, x_scales, w_scales, alpha, dtype: FPQuantDtype):
    if dtype == FPQuantDtype.MXFP4:
        if False and x_q.shape[0] <= 64:  # TODO: remove when ada alpha is fixed
            return matmul_ada_mxf4_bf16_tn_op(
                x_q, w_q, x_scales, w_scales, alpha.float()
            )
        else:
            return matmul_mxf4_bf16_tn_op(x_q, w_q, x_scales, w_scales, alpha.float())

Shouldn't the global_scale be applied during quantization to ensure equivalence with the GEMM's alpha usage? Or is this intentional behavior?
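
To make the concern concrete, here is a minimal sketch of the arithmetic in plain Python. It assumes a toy block quantizer with a power-of-two block scale and an integer grid; `quantize_block` and `gemm_dot` are hypothetical stand-ins and do not reproduce fused_quantize_mx_op or matmul_mxf4_bf16_tn_op. Under those assumptions, the result matches the reference only if the quantizer divides by the global scale and the GEMM multiplies by it, or if alpha is 1.

```python
import math


def quantize_block(x, global_scale=1.0, qmax=6.0):
    """Toy block quantizer (hypothetical, NOT the FP-Quant kernel).

    Divides the input by global_scale, picks a power-of-two block scale,
    and rounds each value to an integer grid clamped to [-qmax, qmax].
    """
    scaled = [v / global_scale for v in x]
    amax = max(abs(v) for v in scaled) or 1.0
    scale = 2.0 ** math.ceil(math.log2(amax / qmax))
    q = [max(-qmax, min(qmax, round(v / scale))) for v in scaled]
    return q, scale


def gemm_dot(xq, wq, x_scale, w_scale, alpha):
    """Toy one-block 'GEMM': alpha * (x_scale * w_scale) * dot(xq, wq)."""
    acc = sum(a * b for a, b in zip(xq, wq))
    return alpha * x_scale * w_scale * acc


# Values chosen so that the toy quantizer is exact for them.
x = [1.0, -2.0, 3.0, 0.5]
w = [2.0, 1.0, -1.0, 4.0]
reference = sum(a * b for a, b in zip(x, w))

# Assumed global scales, purely for illustration.
gx, gw = 4.0, 2.0
alpha = gx * gw

# Case 1: global scale applied in quantization AND undone by alpha in the GEMM.
xq, xs = quantize_block(x, gx)
wq, ws = quantize_block(w, gw)
applied = gemm_dot(xq, wq, xs, ws, alpha)

# Case 2: global scale skipped in quantization, but alpha still used in the GEMM.
xq, xs = quantize_block(x)
wq, ws = quantize_block(w)
skipped = gemm_dot(xq, wq, xs, ws, alpha)

# Case 3: global scale skipped in quantization and alpha equal to 1.
neutral = gemm_dot(xq, wq, xs, ws, 1.0)

print(reference, applied, skipped, neutral)
```

In this sketch the second case is off from the reference by exactly a factor of alpha, while the first and third cases agree with it. So the question comes down to what value global_scale actually holds on the MXFP4 path: if it is always 1, skipping it in quantization is harmless; otherwise the two steps look inconsistent.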
Looking forward to your clarification :)
