Shouldn't

if (N / C) < 1:
    x = ((q @ k.transpose(-2, -1)) @ v).transpose(1, 2).reshape(B, N, C)
else:
    x = (q @ (k.transpose(-2, -1) @ v)).transpose(1, 2).reshape(B, N, C)

be

if (N / (C / H)) < 1:
    x = ((q @ k.transpose(-2, -1)) @ v).transpose(1, 2).reshape(B, N, C)
else:
    x = (q @ (k.transpose(-2, -1) @ v)).transpose(1, 2).reshape(B, N, C)

?
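To make the reasoning behind the question concrete, here is a rough cost sketch. It assumes q, k and v each have shape (B, H, N, C/H), which the final reshape to (B, N, C) suggests but the snippet does not state, and that C is divisible by H. The helper names `attention_matmul_costs` and `cheaper_order` are hypothetical and only for illustration.

```python
def attention_matmul_costs(N, C, H):
    """Per-head multiply-add counts for the two association orders.

    Assumes q, k, v are (B, H, N, D) with D = C // H (an assumption,
    not something the original snippet confirms).
    """
    D = C // H
    # (q @ k^T) @ v: (N x D)(D x N) costs N*N*D, then (N x N)(N x D) costs N*N*D
    left_first = 2 * N * N * D
    # q @ (k^T @ v): (D x N)(N x D) costs D*N*D, then (N x D)(D x D) costs N*D*D
    right_first = 2 * N * D * D
    return left_first, right_first


def cheaper_order(N, C, H):
    """Name the association order with fewer multiply-adds."""
    left_first, right_first = attention_matmul_costs(N, C, H)
    return "(q @ k^T) @ v" if left_first < right_first else "q @ (k^T @ v)"


# N = 128, C = 512, H = 8 gives D = 64, so N / C < 1 but N / (C / H) > 1.
print(attention_matmul_costs(128, 512, 8))  # (2097152, 1048576)
print(cheaper_order(128, 512, 8))           # q @ (k^T @ v)
```

Under these assumptions the two costs are 2·N²·D and 2·N·D², so the first order is cheaper exactly when N < D = C/H. In the example above, the `N / C` test would pick the more expensive order, while the `N / (C / H)` test would not.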