Better support for multiclass/multilabel classifiers #360

@besaleli

Our current SequenceClassificationResult proto and the corresponding REST response shape assume single-label classification: we return a single predicted_index and an optional predicted_label. That is just wrong for multilabel models, where multiple labels can be active simultaneously. Users with multilabel models are currently getting silently incorrect results.

Changes needed:

Refactor SequenceClassificationResult to drop predicted_index and predicted_label in favor of a repeated ClassificationLabel field, where each entry contains the index, label string, and score for that prediction. For single-label models this will have one entry; for multilabel models it will contain all labels that cleared the sigmoid threshold. Raw logits stay as-is.

message ClassificationLabel {
  uint32 index = 1;
  string label = 2;
  float score = 3;  // per-label score, as described above
}

message SequenceClassificationResult {
  repeated float logits = 1;
  repeated ClassificationLabel labels = 2;
}

The same fix should be applied to the REST response shape for consistency.
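
For illustration, here is a rough Python sketch of how the repeated labels field could be populated from raw logits. Everything in it is an assumption rather than existing code: the `build_labels` helper, the `id2label` list, the `multilabel` flag, the 0.5 default threshold, and the choice of softmax probability as the single-label score are all hypothetical.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class ClassificationLabel:
    # Mirrors the proposed proto message: index, label string, score.
    index: int
    label: str
    score: float


def _sigmoid(x: float) -> float:
    # Numerically stable sigmoid; avoids overflow for large negative x.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def build_labels(
    logits: Sequence[float],
    id2label: Sequence[str],
    multilabel: bool,
    threshold: float = 0.5,  # hypothetical default, not from the issue
) -> List[ClassificationLabel]:
    """Return the labels to report for one input (logits must be non-empty)."""
    if multilabel:
        # Multilabel: every label whose sigmoid score clears the threshold.
        scores = [_sigmoid(x) for x in logits]
        return [
            ClassificationLabel(i, id2label[i], s)
            for i, s in enumerate(scores)
            if s >= threshold
        ]
    # Single-label: softmax over the logits, report only the argmax.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return [ClassificationLabel(best, id2label[best], probs[best])]
```

With this shape, a single-label model always yields exactly one entry, while a multilabel model can yield zero or more entries, so clients should not assume the list is non-empty.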

Note this is a breaking change to the proto.
