JAX + TPU and AQT int8: training loss is abnormal #71

@Lisennlp

Description

I used the aqt_einsum function to quantize only the qk scores and then trained the model. However, after a certain number of steps (around 200), the loss started dropping very slowly, which is quite different from the loss curve of the bfloat16 run. Am I missing something? For example, does the backward pass need some additional handling?
P.S. I am training on jax==0.4.23 with a TPU v5p-8.

In other words, is there an AQT int8 training example in Pax?
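
For reference, here is a minimal sketch of the kind of setup described above, written in plain JAX. It does not use the AQT API: `fake_quant_int8`, `qk_scores`, and `loss_fn` are hypothetical names, and the symmetric per-row scaling and the straight-through estimator in the backward pass are assumptions, not necessarily what aqt_einsum does internally. It only shows one way to compare the gradients of a quantized qk product against the unquantized one.

```python
import jax
import jax.numpy as jnp


def fake_quant_int8(x, axis):
    """Symmetric int8 fake quantization with a straight-through gradient (assumed scheme)."""
    # Per-slice scale so the largest magnitude along `axis` maps to 127.
    max_abs = jnp.max(jnp.abs(x), axis=axis, keepdims=True)
    scale = jnp.where(max_abs > 0, max_abs / 127.0, 1.0)
    x_q = jnp.clip(jnp.round(x / scale), -127, 127) * scale
    # Forward pass returns x_q; the backward pass treats rounding as identity.
    return x + jax.lax.stop_gradient(x_q - x)


def qk_scores(q, k, quantize):
    """Attention logits for q, k of shape [batch, seq, heads, dim]."""
    if quantize:
        # Quantize along the contraction dimension only.
        q = fake_quant_int8(q, axis=-1)
        k = fake_quant_int8(k, axis=-1)
    return jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])


def loss_fn(q, k, quantize):
    # Toy scalar loss, used only to obtain a gradient through the softmax.
    probs = jax.nn.softmax(qk_scores(q, k, quantize), axis=-1)
    return jnp.sum(probs ** 2)


key_q, key_k = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(key_q, (2, 8, 4, 16), dtype=jnp.float32)
k = jax.random.normal(key_k, (2, 8, 4, 16), dtype=jnp.float32)

# Gradients with respect to q, with and without quantization of the qk product.
grad_fp = jax.grad(loss_fn)(q, k, False)
grad_q8 = jax.grad(loss_fn)(q, k, True)
rel_err = jnp.linalg.norm(grad_q8 - grad_fp) / jnp.linalg.norm(grad_fp)
print("relative gradient error:", rel_err)
```

If the relative gradient error in a check like this is small while the real training run still diverges from the bfloat16 curve, the difference is more likely in how the backward pass is configured than in the forward quantization itself.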
