From 49f19227d228580ce15fe131347087eb6f758d2e Mon Sep 17 00:00:00 2001 From: chungongyu Date: Wed, 1 Jul 2026 10:18:45 +0800 Subject: [PATCH] fix: uncorrected gating shape. --- profold2/model/head.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/profold2/model/head.py b/profold2/model/head.py index 562614f1..2aae7958 100644 --- a/profold2/model/head.py +++ b/profold2/model/head.py @@ -1435,7 +1435,9 @@ def _hamiton_cat(logits): variant_mask = variant_mask[..., None] variant_mask = variant_mask * variant_mask[:, :1, ...] variant_logit = logits - logits[:, :1, ...] - variant_logit = self.predict(variant_logit, variant_mask, gating=gating) + variant_logit = self.predict( + variant_logit, variant_mask, gating=gating[..., None, :, :] + ) r.update(variant_logit=variant_logit) if exists(motifs): @@ -1544,7 +1546,9 @@ def loss(self, value, batch): variant_label_mask[..., :, None, :] * label_mask_ref[..., None, :, :] ) # variant_logit = torch.sum(variant_logit * variant_mask, dim=-1) - variant_logit = self.predict(variant_logit, variant_mask, gating=gating) + variant_logit = self.predict( + variant_logit, variant_mask, gating=gating[..., None, None, :, :] + ) logger.debug('FitnessHead.logit: %s', str(variant_logit)) logger.debug('FitnessHead.label: %s', str(variant_label)) with accelerator.autocast(enabled=False):