From 0d9fad8e7978c618d26fc37fe2155dccaf6b1864 Mon Sep 17 00:00:00 2001 From: B-lanc Date: Wed, 1 Jun 2022 11:56:53 +0700 Subject: [PATCH] fixed depth issue --- model/waveunet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/model/waveunet.py b/model/waveunet.py index a14aa55..0c3aadf 100644 --- a/model/waveunet.py +++ b/model/waveunet.py @@ -34,8 +34,9 @@ def forward(self, x, shortcut): combined = centre_crop(shortcut, upsampled) # Combine high- and low-level features + combined = torch.cat([combined, centre_crop(upsampled, combined)], dim=1) for conv in self.post_shortcut_convs: - combined = conv(torch.cat([combined, centre_crop(upsampled, combined)], dim=1)) + combined = conv(combined) return combined def get_output_size(self, input_size): @@ -230,4 +231,4 @@ def forward(self, x, inst=None): out_dict = {} for idx, inst in enumerate(self.instruments): out_dict[inst] = out[:, idx * self.num_outputs:(idx + 1) * self.num_outputs] - return out_dict \ No newline at end of file + return out_dict