訓練CycleGAN

``````from tqdm import tqdm
import torchvision.utils as vutils

for epoch in range(epochs):
    # NOTE(review): `progress_bar` is created outside this excerpt. If it wraps a
    # one-shot iterator (e.g. tqdm(enumerate(...))), it has to be rebuilt at the
    # start of every epoch, otherwise later epochs iterate over nothing — confirm.
    for idx, data in progress_bar:
        ############ define training data & label ############
        # presumably each batch pairs one item per domain, with element [0]
        # holding the image tensor — verify against the loader definition
        real_A = data[0][0].to(device)    # vangogh image
        real_B = data[1][0].to(device)    # real picture
``````

``````        ############ Train G ############

############  Train G - Adversarial Loss  ############
# Translate each real batch into the opposite domain and score the result
# with that domain's discriminator.
fake_A = G_B2A(real_B)
fake_out_A = D_A(fake_A)
fake_B = G_A2B(real_A)
fake_out_B = D_B(fake_B)

# Target tensors shaped like the discriminator output: all ones for the
# "real" label, all zeros for the "fake" label.
label_shape = fake_out_A.shape
real_label = torch.ones(label_shape, dtype=torch.float32, device=device)
fake_label = torch.zeros(label_shape, dtype=torch.float32, device=device)
``````

2.是否能重新建構 (Consistency Loss)：

``````        ############  G - Consistency Loss (Reconstruction)  ############
# Cycle consistency: an image translated to the other domain and back again
# should reproduce the original input.
rec_A = G_B2A(fake_B)    # real_A -> fake_B -> rec_A
rec_B = G_A2B(fake_A)    # real_B -> fake_A -> rec_B
# L1 is defined outside this excerpt — presumably an nn.L1Loss instance; confirm.
consistency_loss_B2A = L1(rec_A, real_A)
consistency_loss_A2B = L1(rec_B, real_B)
# Total reconstruction loss is the sum over both directions.
rec_loss = consistency_loss_B2A + consistency_loss_A2B
``````

3.是否能保持一致 (Identity Loss)：

``````        ############  G - Identity  Loss ############
# Identity mapping: each generator is fed an image that already belongs to its
# output domain, and the result is compared with that same input.
idt_A = G_B2A(real_A)
idt_B = G_A2B(real_B)
# L1 is defined outside this excerpt — presumably an nn.L1Loss instance; confirm.
identity_loss_A = L1(idt_A, real_A)
identity_loss_B = L1(idt_B, real_B)
# Total identity loss is the sum over both generators.
idt_loss = identity_loss_A + identity_loss_B
``````

``````        ############ Train D ############

############ D - Adversarial D_A Loss ############
# Real domain-A images are scored against the "real" label.
real_out_A = D_A(real_A)
real_out_A_loss = MSE(real_out_A, real_label)
# Generated domain-A images are passed through fake_A_sample.push_and_pop
# (defined outside this excerpt — presumably a history buffer of generated
# images; confirm) before being scored.
fake_out_A = D_A(fake_A_sample.push_and_pop(fake_A))
# The fake term must use fake_out_A, mirroring the D_B branch. Scoring
# real_out_A against fake_label would cancel the real-sample term and leave
# D_A untrained on generated images.
fake_out_A_loss = MSE(fake_out_A, fake_label)

loss_DA = real_out_A_loss + fake_out_A_loss

############  D - Adversarial D_B Loss  ############

# Real domain-B images are scored against the "real" label.
real_out_B = D_B(real_B)
real_out_B_loss = MSE(real_out_B, real_label)
# Generated domain-B images go through fake_B_sample.push_and_pop (defined
# outside this excerpt) and are scored against the "fake" label.
fake_out_B = D_B(fake_B_sample.push_and_pop(fake_B))
fake_out_B_loss = MSE(fake_out_B, fake_label)

loss_DB = ( real_out_B_loss + fake_out_B_loss )

############  D - Total Loss ############

# Both discriminator losses are summed and halved before backpropagation.
loss_D = ( loss_DA + loss_DB ) * 0.5

############  Backward & Update ############

# NOTE(review): no optim_D.zero_grad() is visible in this excerpt — confirm it
# is called before this backward pass so gradients do not accumulate.
loss_D.backward()
optim_D.step()
``````

``````        ############ progress info ############
# Show the current epoch/batch counters and the latest losses in the bar description.
# NOTE(review): `dataloader` and `loss_G` are defined outside this excerpt.
progress_bar.set_description(
f"[{epoch}/{epochs - 1}][{idx}/{len(dataloader) - 1}] "
# NOTE(review): this prints loss_DA + loss_DB, i.e. twice the loss_D value that is backpropagated.
f"Loss_D: {(loss_DA + loss_DB).item():.4f} "
f"Loss_G: {loss_G.item():.4f} "
f"Loss_G_identity: {(idt_loss).item():.4f} "
f"loss_G_cycle: {(rec_loss).item():.4f}")
``````

1.儲存模型結構以及權重

``torch.save( model, PATH )``

2.只儲存權重

``torch.save( model.state_dict(), PATH )``

``````
# Every `log_freq` batches, write sample images and a checkpoint for this epoch.
# `idx` is the batch index of the enclosing training loop.
if idx % log_freq == 0:

    vutils.save_image(real_A, f"{output_path}/real_A_{epoch}.jpg", normalize=True)
    vutils.save_image(real_B, f"{output_path}/real_B_{epoch}.jpg", normalize=True)

    # Rescale the generator outputs with 0.5 * (x + 1) before saving.
    fake_A = ( G_B2A( real_B ).data + 1.0 ) * 0.5
    fake_B = ( G_A2B( real_A ).data + 1.0 ) * 0.5

    vutils.save_image(fake_A, f"{output_path}/fake_A_{epoch}.jpg", normalize=True)
    # fake_B needs its own file name; writing it to fake_A_* would overwrite
    # the image saved on the previous line.
    vutils.save_image(fake_B, f"{output_path}/fake_B_{epoch}.jpg", normalize=True)

    # Per-epoch checkpoints of all four networks (weights only).
    torch.save(G_A2B.state_dict(), f"weights/netG_A2B_epoch_{epoch}.pth")
    torch.save(G_B2A.state_dict(), f"weights/netG_B2A_epoch_{epoch}.pth")
    torch.save(D_A.state_dict(), f"weights/netD_A_epoch_{epoch}.pth")
    torch.save(D_B.state_dict(), f"weights/netD_B_epoch_{epoch}.pth")

############ Update learning rates ############
# Advance the generator and discriminator schedulers.
# NOTE(review): presumably executed once per epoch, after the batch loop — confirm placement.
lr_scheduler_G.step()
lr_scheduler_D.step()

############ save last checkpoint ############
# Use the same model names as the rest of the training script (G_A2B, G_B2A,
# D_A, D_B); the netG_*/netD_* names are not defined anywhere in it.
torch.save(G_A2B.state_dict(), "weights/netG_A2B.pth")
torch.save(G_B2A.state_dict(), "weights/netG_B2A.pth")
torch.save(D_A.state_dict(), "weights/netD_A.pth")
torch.save(D_B.state_dict(), "weights/netD_B.pth")
``````

測試

1.導入函式庫

``````import os
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from tqdm import tqdm
import torchvision.utils as vutils
``````

``````  batch_size = 12
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Resize to 256x256, convert to a tensor, then normalise each channel with
# mean 0.5 / std 0.5 (maps [0, 1] to [-1, 1]).
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

root = r'vangogh2photo'
targetC_path = os.path.join(root, 'custom')
output_path = os.path.join('./', r'output')

# Create the output directory on first run.
if not os.path.exists(output_path):
    os.mkdir(output_path)
    print('Create dir : ', output_path)

# NOTE(review): the prediction loop iterates `dataC_loader`, but its definition
# is missing from the published snippet. Reconstructed from the otherwise
# unused `dsets` import and `targetC_path`; ImageFolder expects the images to
# sit inside at least one sub-directory of `targetC_path` — confirm.
dataC_loader = torch.utils.data.DataLoader(
    dsets.ImageFolder(targetC_path, transform=transform),
    batch_size=batch_size,
    shuffle=False,
)

``````

3.實例化生成器、載入權重 (load_state_dict)、選擇模式 ( train or eval )，如果選擇 eval，PyTorch會自動將 Dropout 關掉；因為我只要真實照片轉成梵谷所以只宣告了G_B2A：

``````  # get generator
G_B2A = Generator().to(device)

# Load the trained weights; without this the generator runs with random
# initialisation. The path matches the final checkpoint written by the
# training script — adjust it if the file is stored elsewhere.
G_B2A.load_state_dict(torch.load('weights/netG_B2A.pth', map_location=device))

# Set model mode
G_B2A.eval()
``````

4.開始進行預測：

``````progress_bar = tqdm(enumerate(dataC_loader), total=len(dataC_loader))

for i, data in progress_bar:
    # get data
    real_images_B = data[0].to(device)

    # Generate output. no_grad skips building the autograd graph, which is
    # not needed for inference and only costs memory.
    with torch.no_grad():
        fake_image_A = 0.5 * (G_B2A(real_images_B) + 1.0)

    # Save image files (one file per batch, numbered from 0001)
    vutils.save_image(fake_image_A, f"{output_path}/FakeA_{i + 1:04d}.jpg", normalize=True)

    progress_bar.set_description(f"Process images {i + 1} of {len(dataC_loader)}")
``````

5.去output察看結果：

 ORIGINAL TRANSFORM

在JetsonNano中進行風格轉換

1.首先要將權重放到Jetson Nano中

2.重建生成器並導入權重值

``````import torch
from torch import nn
from torchsummary import summary

def conv_norm_relu(in_dim, out_dim, kernel_size, stride=1, padding=0):
    """Return a Conv2d -> InstanceNorm2d -> ReLU stack."""
    layer = nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding),
                          nn.InstanceNorm2d(out_dim),
                          nn.ReLU(True))
    return layer


def dconv_norm_relu(in_dim, out_dim, kernel_size, stride=1, padding=0, output_padding=0):
    """Return a ConvTranspose2d -> InstanceNorm2d -> ReLU stack.

    NOTE(review): the header and first layer of this helper were missing from
    the published snippet. They are reconstructed from the six-argument calls
    in Generator; the default values are an assumption — confirm.
    """
    layer = nn.Sequential(nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride, padding, output_padding),
                          nn.InstanceNorm2d(out_dim),
                          nn.ReLU(True))
    return layer


class ResidualBlock(nn.Module):
    """Residual block: output is the input plus a two-convolution branch.

    NOTE(review): the list-opening lines were missing from the published
    snippet. The reflection padding and second convolution are reconstructed so
    the branch keeps the spatial size (required by ``x + branch``). The layer
    order must match the network that produced the checkpoint — confirm.
    """

    def __init__(self, dim, use_dropout):
        super(ResidualBlock, self).__init__()
        # Reflection padding of 1 keeps H x W unchanged through each 3x3 conv.
        res_block = [nn.ReflectionPad2d(1),
                     conv_norm_relu(dim, dim, kernel_size=3)]

        if use_dropout:
            res_block += [nn.Dropout(0.5)]
        res_block += [nn.ReflectionPad2d(1),
                      nn.Conv2d(dim, dim, kernel_size=3, padding=0),
                      nn.InstanceNorm2d(dim)]

        self.res_block = nn.Sequential(*res_block)

    def forward(self, x):
        return x + self.res_block(x)


class Generator(nn.Module):
    """Encoder / residual bottleneck / decoder image-to-image generator.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        filters: channel count after the first convolution.
        use_dropout: whether each residual block contains Dropout(0.5).
        n_blocks: number of residual blocks in the bottleneck.

    The output passes through Tanh, so values lie in [-1, 1].

    NOTE(review): the ``model = [...`` opening line and the padding before the
    final 7x7 convolution were missing from the published snippet. They are
    reconstructed so the output has the same H x W as the input (the demo code
    concatenates a 256x256 output with a 256x256 frame) — confirm against the
    training-time definition before loading weights.
    """

    def __init__(self, input_nc=3, output_nc=3, filters=64, use_dropout=True, n_blocks=6):
        super(Generator, self).__init__()

        # Down-sampling: a 7x7 stem followed by two stride-2 convolutions.
        model = [nn.ReflectionPad2d(3),
                 conv_norm_relu(input_nc   , filters * 1, 7),
                 conv_norm_relu(filters * 1, filters * 2, 3, 2, 1),
                 conv_norm_relu(filters * 2, filters * 4, 3, 2, 1)]

        # Bottleneck: a stack of residual blocks.
        for i in range(n_blocks):
            model += [ResidualBlock(filters * 4, use_dropout)]

        # Up-sampling: two stride-2 transposed convolutions, then a 7x7 output conv.
        model += [dconv_norm_relu(filters * 4, filters * 2, 3, 2, 1, 1),
                  dconv_norm_relu(filters * 2, filters * 1, 3, 2, 1, 1),
                  nn.ReflectionPad2d(3),
                  nn.Conv2d(filters, output_nc, 7),
                  nn.Tanh()]

        # `model` is a list; unpack it because nn.Sequential takes the layers
        # as separate positional arguments.
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
``````

``````def init_model():

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
G_B2A = Generator().to(device)
G_B2A.eval()

return G_B2A
``````

3.在Colab中拍照

``````def test(G, img):
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
transform = transforms.Compose([transforms.Resize((256,256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

data = transform(img).to(device)
data = data.unsqueeze(0)
out = (0.5 * (G(data).data + 1.0)).squeeze(0)    return out
``````

``````if __name__=='__main__':

G = init_model()

trans_path = 'test_transform.jpg'
org_path = 'test_original.jpg'

cap = cv2.VideoCapture(0)

while(True):

cv2.imshow('webcam', frame)

key = cv2.waitKey(1)

if key==ord('q'):
cap.release()
cv2.destroyAllWindows()
break

elif key==ord('s'):

output = test(G, Image.fromarray(frame))
style_img = np.array(output.cpu()).transpose([1,2,0])
org_img = cv2.resize(frame, (256, 256))

cv2.imwrite(trans_path, style_img*255)
cv2.imwrite(org_path, org_img)
break

cap.release()
cv2.destroyWindow('webcam')

``````

``````    res = np.concatenate((style_img, org_img/255), axis=1)
# Display `res` (the stylised and original images concatenated along axis 1)
# in a window named 'res'.
cv2.imshow('res',res )

# Block until any key is pressed, then close every OpenCV window.
cv2.waitKey(0)
cv2.destroyAllWindows()
``````

在Jetson Nano中做即時影像轉換

``````if __name__=='__main__':

G = init_model()
cap = cv2.VideoCapture(0)
change_style = False
save_img_name = 'test.jpg'
cv2text = ''

while(True):
# Do Something Cool
############################
if change_style:
style_img = test(G, Image.fromarray(frame))
out = np.array(style_img.cpu()).transpose([1,2,0])
cv2text = 'Style Transfer'
else:
out = frame
cv2text = 'Original'

out = cv2.resize(out, (512, 512))
out = cv2.putText(out, f'{cv2text}', (20, 40), cv2.FONT_HERSHEY_SIMPLEX ,
1, (255, 255, 255), 2, cv2.LINE_AA)

###########################

cv2.imshow('webcam', out)
key = cv2.waitKey(1)
if key==ord('q'):
break
elif key==ord('s'):
if change_style==True:
cv2.imwrite(save_img_name,out*255)
else:
cv2.imwrite(save_img_name,out)
elif key==ord('t'):
change_style = False if change_style else True
cap.release()
cv2.destroyAllWindows()
``````

補充 – Nano 安裝Torch 1.6的方法

``````
$ wget https://nvidia.box.com/shared/static/yr6sjswn25z7oankw8zy1roow9cy5ur1.whl -O torch-1.6.0rc2-cp36-cp36m-linux_aarch64.whl
$ sudo apt-get install python3-pip libopenblas-base libopenmpi-dev
$ pip3 install Cython
$ pip3 install torch-1.6.0rc2-cp36-cp36m-linux_aarch64.whl
``````

``````
$ sudo apt-get install libjpeg-dev zlib1g-dev
$ git clone --branch v0.7.0 https://github.com/pytorch/vision torchvision
$ cd torchvision
$ export BUILD_VERSION=0.7.0  # where 0.x.0 is the torchvision version
$ sudo python3 setup.py install     # use python3 if installing for Python 3.6
$ cd ../  # attempting to load torchvision from build dir will result in import error
``````
