Skip to content

Restore device-aware num_aie_columns in SwiGLU operators#104

Closed
albiol2004 wants to merge 1 commit intoamd:develfrom
albiol2004:swiglu-num-aie-columns
Closed

Restore device-aware num_aie_columns in SwiGLU operators#104
albiol2004 wants to merge 1 commit intoamd:develfrom
albiol2004:swiglu-num-aie-columns

Conversation

@albiol2004
Copy link
Copy Markdown
Contributor

Restores device-aware column selection in SwiGLUDecode and SwiGLUPrefill so both composite operators run on Phoenix (NPU1, 4 columns) as well as Strix (NPU2, 8 columns). The branching was originally introduced in #89 and inadvertently dropped during the simplifying refactor in #88, which hardcoded num_aie_columns=8 across both SwiGLU variants. This PR replaces the literal 8 with aie_utils.get_current_device().cols, matching the pattern already used by rms_norm, gemm, gemv, mem_copy, etc.

While re-threading the column count, num_aie_columns=n_cols is also now passed to the two GEMM calls in SwiGLUPrefill. Previously those defaulted to GEMM's fallback, which meant SwiGLU prefill was implicitly under-parallelized on NPU2 and misaligned with the column count used by the surrounding SiLU and ElementwiseMul sub-ops.

A new rectangular FFN shape (seq_len=256, embedding_dim=1024, hidden_dim=3584) is added to swiglu_prefill/test.py so real decoder-model FFN dimensions (e.g. Qwen3.5-0.8B) are exercised in CI alongside the existing square 2048² smoke test.

Added

  • (256, 1024, 3584, False) rectangular FFN shape in iron/operators/swiglu_prefill/test.py, reflecting real decoder-model dims so non-square paths are covered.

Changed

  • iron/operators/swiglu_decode/op.py: derive n_cols = aie_utils.get_current_device().cols and pass it to the gemv_1 / silu / eltwise_mul / gemv_2 sub-ops in place of the hardcoded 8.
  • iron/operators/swiglu_prefill/op.py: same device-aware derivation, applied to silu / eltwise_mul and (newly) threaded through the gemm_1 and gemm_2 calls, which previously omitted num_aie_columns entirely.

Removed

  • Hardcoded num_aie_columns=8 and associated // 8, // 16 literals in both SwiGLU op files.

Testing

Verified on NPU2 (Strix, aie2p) with ironenv + XRT sourced:

  • pytest iron/operators/ -m "not extensive" --iterations 1 : all previously-passing tests still pass (pre-existing LeakyReLU skips unchanged).
  • pytest iron/operators/swiglu_decode/test.py -v --iterations 1 : square 2048² passes.
  • pytest iron/operators/swiglu_prefill/test.py -v --iterations 1 : both square 2048² and the new rectangular 256 × 1024 × 3584 pass.

NPU1 (Phoenix, aie2) hardware was not available for local validation; the column-selection logic is structurally identical to the previously-shipped #89 code path, so Phoenix behavior is expected to match that baseline and should be re-confirmed by a reviewer with Phoenix access.

During this work a separate, pre-existing numerical issue was observed in swiglu_decode at rectangular decode shapes (e.g. 1024 × 3584) that is unrelated to this change, the decode failure reproduces with num_aie_columns=8 either before or after this PR, so the rectangular case was not added to swiglu_decode/test.py in this PR. That issue is being investigated separately.

PR Merge Checklist

  1. The PR is rebased on the latest devel commit and pointing to devel.
  2. Your PR has been reviewed and approved.
  3. All checks are passing.

Copy link
Copy Markdown
Collaborator

@andrej andrej left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for restoring that fix

Comment on lines +45 to +47
# (Phoenix, aie2) and 8 on NPU2 (Strix, aie2p). Restores NPU1 support
# that was originally added in #89 and inadvertently dropped during the
# simplifying refactor (#88).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's drop that second sentence from the comment. It's enough for the comments to just describe the functionality/reasoning of the code, not it's history in this case.

Comment on lines +52 to +54
# (Phoenix, aie2) and 8 on NPU2 (Strix, aie2p). Restores NPU1 support
# that was originally added in #89 and inadvertently dropped during the
# simplifying refactor (#88).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop second sentence

Comment on lines +21 to +23
# Square shapes cover the historical smoke-test config; rectangular
# shapes reflect real decoder-model FFN dims (e.g. Qwen3.5-0.8B
# embedding=1024, hidden=3584) that downstream runtimes actually hit.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that comment can be dropped as well

@andrej
Copy link
Copy Markdown
Collaborator

andrej commented Apr 17, 2026

Sorry for the back and forth. I just noticed this doesn't point to the latest devel. The fixes for SwiGLU were already included in #90 so I'm going to close this.

@andrej andrej closed this Apr 17, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants