@creotiv
Created June 18, 2019 17:39
static style
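import torch.nn.functional as F  # provides the F.mse_loss used in the loss terms

# Load the style image once and precompute its Gram matrices; they stay fixed during training.
style = load_image(args.style_image, size=args.style_size)
style = style_transform(style)
style = style.repeat(args.batch_size, 1, 1, 1).to(device)
# Divide by 255 before normalization, then take the style-layer features from VGG.
features_style = vgg(normalize_batch(style.div_(255.0))).style
gram_style = [gram_matrix(y) for y in features_style]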
for e in range(args.epochs):
    transformer.train()
    agg_content_loss = 0.
    agg_style_loss = 0.
    count = 0
    for batch_id, (x, _) in enumerate(train_loader):
        n_batch = len(x)
        count += n_batch
        optimizer.zero_grad()

        # Stylize the batch, then preprocess output and input identically for VGG.
        _x = x.to(device)
        y = transformer(_x)
        y = normalize_batch(y.div_(255.0))
        _x = normalize_batch(_x.div_(255.0))

        features_y = vgg(y)
        features_x = vgg(_x)

        content_loss = args.content_weight * F.mse_loss(features_y.content[0], features_x.content[0])

        style_loss = 0.
        for ft_y, gm_s in zip(features_y.style, gram_style):
            gm_y = gram_matrix(ft_y)
            # Slice the precomputed style Grams so a smaller final batch still matches in shape.
            style_loss += F.mse_loss(gm_y, gm_s[:n_batch])
        style_loss *= args.style_weight

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()

        # Track running loss totals for the epoch.
        agg_content_loss += content_loss.item()
        agg_style_loss += style_loss.item()
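The snippet calls `gram_matrix` and `normalize_batch` without defining them. A minimal sketch of what they might look like follows. It is not the author's code. It assumes feature maps of shape `(B, C, H, W)`, a Gram matrix normalized by `C * H * W`, and ImageNet mean/std normalization applied to a batch that was already divided by 255 (as the calls above do).

```python
import torch


def gram_matrix(y: torch.Tensor) -> torch.Tensor:
    """Per-sample Gram matrices, shape (B, C, C), for features of shape (B, C, H, W)."""
    b, ch, h, w = y.size()
    features = y.reshape(b, ch, h * w)
    # Channel-by-channel inner products, normalized by the number of feature elements
    # (the normalization constant is an assumption).
    return features.bmm(features.transpose(1, 2)) / (ch * h * w)


def normalize_batch(batch: torch.Tensor) -> torch.Tensor:
    """ImageNet mean/std normalization for a batch already scaled to [0, 1] (assumed)."""
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    return (batch - mean) / std
```

With these definitions, `gram_matrix(vgg_features)` returns one `C x C` matrix per image, which is why the style targets need the same leading batch dimension as the current batch.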