Questions about the wikiart training sets #12

@Rancherzhang

Hello, I want to reproduce your great work, but given my limited knowledge I have two questions.
First, I rewrote the training phase and started training on WikiArt with a content-dir of 'wikiart/Rococo' and a style-dir of 'wikiart/Symbolism', but the intermediate results are not as good as yours. Which content-dir and style-dir did you use on the WikiArt dataset?
Second, my style distribution loss does not converge; it always stays between 4.2 and 4.4. My code is below:

import random

import torch
import torch.nn as nn


class StyleDistLoss(nn.Module):
    '''
    Style distribution loss of s and s'.
    '''
    def __init__(self, pool_size):
        super().__init__()
        self.pool_size = pool_size
        # History pool of detached style codes from earlier batches
        self.num_style_batch = 0
        self.style_batches = []
        self.loss = nn.L1Loss()

    def forward(self, sc, st):
        '''
        Return the standard Gaussian distribution loss of the input
        style source {sc} and style target {st}, which correspond to s and s' in the paper.
        '''
        styles = []
        if self.pool_size == 0:
            styles.extend([sc, st])
        else:
            # Pooled codes are detached, so gradients flow only through sc and st
            styles += self.style_batches
            styles.extend([sc, st])

            detach_sc = sc.detach().clone()
            detach_st = st.detach().clone()

            if self.num_style_batch + 2 <= self.pool_size:
                self.style_batches.extend([detach_sc, detach_st])
                self.num_style_batch += 2
            elif self.num_style_batch >= 2:
                # Pool is full: overwrite two distinct random entries
                i, j = random.sample(range(self.num_style_batch), 2)
                self.style_batches[i] = detach_sc
                self.style_batches[j] = detach_st
        # Stack all codes into an (N, D) matrix; flatten(1) never drops the N or D axis
        tensor_styles = torch.cat(styles, 0).flatten(1)
        styles_mean = torch.mean(tensor_styles, dim=0)
        tminuss = tensor_styles - styles_mean
        cov = torch.mm(tminuss.t(), tminuss) / tensor_styles.shape[0]
        std_cov = cov.diag(diagonal=0)
        total_loss = self.loss(styles_mean, torch.zeros_like(styles_mean))
        total_loss += self.loss(cov, torch.ones_like(cov))
        total_loss += self.loss(std_cov, torch.ones_like(std_cov))
        return total_loss

Could you please give me some advice? Thanks!
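For reference, a minimal sketch of a moment-matching loss that pulls style codes toward a standard Gaussian N(0, I), assuming the codes can be flattened to an (N, D) matrix. It is an assumption for comparison, not necessarily the loss used in the paper, and `gaussian_moment_loss` is a hypothetical name. One difference from the posted code: a standard Gaussian has identity covariance, so the sketch compares `cov` with `torch.eye(D)`, whereas the posted code compares it with `torch.ones_like(cov)`, an all-ones matrix.

```python
import torch
import torch.nn.functional as F


def gaussian_moment_loss(codes: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: L1 distance of the batch mean and covariance
    of `codes` from those of a standard Gaussian N(0, I)."""
    x = codes.flatten(1)                        # (N, D) style codes
    mean = x.mean(dim=0)                        # (D,)
    centered = x - mean
    cov = centered.t() @ centered / x.shape[0]  # (D, D), biased estimate
    eye = torch.eye(x.shape[1], dtype=x.dtype, device=x.device)
    # Target mean is the zero vector; target covariance is the identity matrix
    return F.l1_loss(mean, torch.zeros_like(mean)) + F.l1_loss(cov, eye)


if __name__ == "__main__":
    torch.manual_seed(0)
    samples = torch.randn(4096, 8)               # draws from N(0, I)
    print(gaussian_moment_loss(samples).item())  # small, close to 0
    # Distance between the identity and an all-ones 8x8 target: 56 / 64
    print(F.l1_loss(torch.eye(8), torch.ones(8, 8)).item())  # 0.875
```

Samples that really are drawn from N(0, I) give a loss near zero under the identity target, while the all-ones target stays far from the identity even for perfect Gaussian codes. Whether this alone explains the 4.2-4.4 plateau is not established here.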
