# gan.py — single-image pix2pixHD-style GAN inference utilities
# (header reconstructed; original lines were web-scrape residue: file size and a fused run of line numbers)
  1. from PIL import Image
  2. import numpy as np
  3. import cv2
  4. import torchvision.transforms as transforms
  5. import torch
  6. import io
  7. import os
  8. import functools
  9. class DataLoader():
  10. def __init__(self, opt, cv_img):
  11. super(DataLoader, self).__init__()
  12. self.dataset = Dataset()
  13. self.dataset.initialize(opt, cv_img)
  14. self.dataloader = torch.utils.data.DataLoader(
  15. self.dataset,
  16. batch_size=opt.batchSize,
  17. shuffle=not opt.serial_batches,
  18. num_workers=int(opt.nThreads))
  19. def load_data(self):
  20. return self.dataloader
  21. def __len__(self):
  22. return 1
  23. class Dataset(torch.utils.data.Dataset):
  24. def __init__(self):
  25. super(Dataset, self).__init__()
  26. def initialize(self, opt, cv_img):
  27. self.opt = opt
  28. self.root = opt.dataroot
  29. self.A = Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
  30. self.dataset_size = 1
  31. def __getitem__(self, index):
  32. transform_A = get_transform(self.opt)
  33. A_tensor = transform_A(self.A.convert('RGB'))
  34. B_tensor = inst_tensor = feat_tensor = 0
  35. input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
  36. 'feat': feat_tensor, 'path': ""}
  37. return input_dict
  38. def __len__(self):
  39. return 1
  40. class DeepModel(torch.nn.Module):
  41. def initialize(self, opt):
  42. self.opt = opt
  43. self.gpu_ids = [] #FIX CPU
  44. self.netG = self.__define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
  45. opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
  46. opt.n_blocks_local, opt.norm, self.gpu_ids)
  47. # load networks
  48. self.__load_network(self.netG)
  49. def inference(self, label, inst):
  50. # Encode Inputs
  51. input_label, inst_map, _, _ = self.__encode_input(label, inst, infer=True)
  52. # Fake Generation
  53. input_concat = input_label
  54. with torch.no_grad():
  55. fake_image = self.netG.forward(input_concat)
  56. return fake_image
  57. # helper loading function that can be used by subclasses
  58. def __load_network(self, network):
  59. save_path = os.path.join(self.opt.checkpoints_dir)
  60. network.load_state_dict(torch.load(save_path))
  61. def __encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
  62. if (len(self.gpu_ids) > 0):
  63. input_label = label_map.data.cuda() #GPU
  64. else:
  65. input_label = label_map.data #CPU
  66. return input_label, inst_map, real_image, feat_map
  67. def __weights_init(self, m):
  68. classname = m.__class__.__name__
  69. if classname.find('Conv') != -1:
  70. m.weight.data.normal_(0.0, 0.02)
  71. elif classname.find('BatchNorm2d') != -1:
  72. m.weight.data.normal_(1.0, 0.02)
  73. m.bias.data.fill_(0)
  74. def __define_G(self, input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
  75. n_blocks_local=3, norm='instance', gpu_ids=[]):
  76. norm_layer = self.__get_norm_layer(norm_type=norm)
  77. netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
  78. if len(gpu_ids) > 0:
  79. netG.cuda(gpu_ids[0])
  80. netG.apply(self.__weights_init)
  81. return netG
  82. def __get_norm_layer(self, norm_type='instance'):
  83. norm_layer = functools.partial(torch.nn.InstanceNorm2d, affine=False)
  84. return norm_layer
  85. ##############################################################################
  86. # Generator
  87. ##############################################################################
  88. class GlobalGenerator(torch.nn.Module):
  89. def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=torch.nn.BatchNorm2d,
  90. padding_type='reflect'):
  91. assert(n_blocks >= 0)
  92. super(GlobalGenerator, self).__init__()
  93. activation = torch.nn.ReLU(True)
  94. model = [torch.nn.ReflectionPad2d(3), torch.nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
  95. ### downsample
  96. for i in range(n_downsampling):
  97. mult = 2**i
  98. model += [torch.nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
  99. norm_layer(ngf * mult * 2), activation]
  100. ### resnet blocks
  101. mult = 2**n_downsampling
  102. for i in range(n_blocks):
  103. model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
  104. ### upsample
  105. for i in range(n_downsampling):
  106. mult = 2**(n_downsampling - i)
  107. model += [torch.nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
  108. norm_layer(int(ngf * mult / 2)), activation]
  109. model += [torch.nn.ReflectionPad2d(3), torch.nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), torch.nn.Tanh()]
  110. self.model = torch.nn.Sequential(*model)
  111. def forward(self, input):
  112. return self.model(input)
  113. # Define a resnet block
  114. class ResnetBlock(torch.nn.Module):
  115. def __init__(self, dim, padding_type, norm_layer, activation=torch.nn.ReLU(True), use_dropout=False):
  116. super(ResnetBlock, self).__init__()
  117. self.conv_block = self.__build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
  118. def __build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
  119. conv_block = []
  120. p = 0
  121. if padding_type == 'reflect':
  122. conv_block += [torch.nn.ReflectionPad2d(1)]
  123. elif padding_type == 'replicate':
  124. conv_block += [torch.nn.ReplicationPad2d(1)]
  125. elif padding_type == 'zero':
  126. p = 1
  127. else:
  128. raise NotImplementedError('padding [%s] is not implemented' % padding_type)
  129. conv_block += [torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p),
  130. norm_layer(dim),
  131. activation]
  132. if use_dropout:
  133. conv_block += [torch.nn.Dropout(0.5)]
  134. p = 0
  135. if padding_type == 'reflect':
  136. conv_block += [torch.nn.ReflectionPad2d(1)]
  137. elif padding_type == 'replicate':
  138. conv_block += [torch.nn.ReplicationPad2d(1)]
  139. elif padding_type == 'zero':
  140. p = 1
  141. else:
  142. raise NotImplementedError('padding [%s] is not implemented' % padding_type)
  143. conv_block += [torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p),
  144. norm_layer(dim)]
  145. return torch.nn.Sequential(*conv_block)
  146. def forward(self, x):
  147. out = x + self.conv_block(x)
  148. return out
  149. # Data utils:
  150. def get_transform(opt, method=Image.BICUBIC, normalize=True):
  151. transform_list = []
  152. base = float(2 ** opt.n_downsample_global)
  153. if opt.netG == 'local':
  154. base *= (2 ** opt.n_local_enhancers)
  155. transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
  156. transform_list += [transforms.ToTensor()]
  157. if normalize:
  158. transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
  159. (0.5, 0.5, 0.5))]
  160. return transforms.Compose(transform_list)
  161. def __make_power_2(img, base, method=Image.BICUBIC):
  162. ow, oh = img.size
  163. h = int(round(oh / base) * base)
  164. w = int(round(ow / base) * base)
  165. if (h == oh) and (w == ow):
  166. return img
  167. return img.resize((w, h), method)
  168. # Converts a Tensor into a Numpy array
  169. # |imtype|: the desired type of the converted numpy array
  170. def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
  171. if isinstance(image_tensor, list):
  172. image_numpy = []
  173. for i in range(len(image_tensor)):
  174. image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
  175. return image_numpy
  176. image_numpy = image_tensor.cpu().float().numpy()
  177. if normalize:
  178. image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
  179. else:
  180. image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
  181. image_numpy = np.clip(image_numpy, 0, 255)
  182. if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
  183. image_numpy = image_numpy[:,:,0]
  184. return image_numpy.astype(imtype)