diff --git a/model/waveunet.py b/model/waveunet.py index a14aa55..0c3aadf 100644 --- a/model/waveunet.py +++ b/model/waveunet.py @@ -34,8 +34,9 @@ def forward(self, x, shortcut): combined = centre_crop(shortcut, upsampled) # Combine high- and low-level features + combined = torch.cat([combined, centre_crop(upsampled, combined)], dim=1) for conv in self.post_shortcut_convs: - combined = conv(torch.cat([combined, centre_crop(upsampled, combined)], dim=1)) + combined = conv(combined) return combined def get_output_size(self, input_size): @@ -230,4 +231,4 @@ def forward(self, x, inst=None): out_dict = {} for idx, inst in enumerate(self.instruments): out_dict[inst] = out[:, idx * self.num_outputs:(idx + 1) * self.num_outputs] - return out_dict \ No newline at end of file + return out_dict