diff --git a/simtorch/similarity/cka.py b/simtorch/similarity/cka.py index bb0ff27..157b1bf 100644 --- a/simtorch/similarity/cka.py +++ b/simtorch/similarity/cka.py @@ -56,11 +56,11 @@ def compute(self, dataloader: Collection): # iterate through layers for i, (_, activation1) in enumerate(self.sim_model1.model_activations.items()): - X = self._normalize(activation1.view(batch_size, -1).to(self.device)) + X = self._normalize(activation1.reshape(batch_size, -1).to(self.device)) L_X = torch.matmul(X, X.T) for j, (_, activation2) in enumerate(self.sim_model2.model_activations.items()): - Y = self._normalize(activation2.view(batch_size, -1).to(self.device)) + Y = self._normalize(activation2.reshape(batch_size, -1).to(self.device)) L_Y = torch.matmul(Y, Y.T) layer_cka = self.linear_CKA(L_X=L_X, L_Y=L_Y)