You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

generator.py 6.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. """GAN generator."""
  2. import numpy as np
  3. import torch
  4. from PIL import Image
  5. from torchvision import transforms as transforms
  6. class GlobalGenerator(torch.nn.Module):
  7. """Global Generator."""
  8. def __init__(
  9. self,
  10. input_nc,
  11. output_nc,
  12. ngf=64,
  13. n_downsampling=3,
  14. n_blocks=9,
  15. norm_layer=torch.nn.BatchNorm2d,
  16. padding_type="reflect",
  17. ):
  18. """
  19. Global Generator Constructor.
  20. :param input_nc:
  21. :param output_nc:
  22. :param ngf:
  23. :param n_downsampling:
  24. :param n_blocks:
  25. :param norm_layer:
  26. :param padding_type:
  27. """
  28. if n_blocks < 0:
  29. raise AssertionError()
  30. super(GlobalGenerator, self).__init__()
  31. activation = torch.nn.ReLU(True)
  32. # activation = torch.nn.DataParallel(activation)
  33. model = [
  34. torch.nn.ReflectionPad2d(3),
  35. torch.nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
  36. norm_layer(ngf),
  37. activation,
  38. ]
  39. # downsample
  40. for i in range(n_downsampling):
  41. mult = 2 ** i
  42. model += [
  43. torch.nn.Conv2d(
  44. ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1
  45. ),
  46. norm_layer(ngf * mult * 2),
  47. activation,
  48. ]
  49. # resnet blocks
  50. mult = 2 ** n_downsampling
  51. for _ in range(n_blocks):
  52. model += [
  53. ResnetBlock(
  54. ngf * mult,
  55. padding_type=padding_type,
  56. activation=activation,
  57. norm_layer=norm_layer,
  58. )
  59. ]
  60. # upsample
  61. for i in range(n_downsampling):
  62. mult = 2 ** (n_downsampling - i)
  63. model += [
  64. torch.nn.ConvTranspose2d(
  65. ngf * mult,
  66. int(ngf * mult / 2),
  67. kernel_size=3,
  68. stride=2,
  69. padding=1,
  70. output_padding=1,
  71. ),
  72. norm_layer(int(ngf * mult / 2)),
  73. activation,
  74. ]
  75. model += [
  76. torch.nn.ReflectionPad2d(3),
  77. torch.nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
  78. torch.nn.Tanh(),
  79. ]
  80. self.model = torch.nn.Sequential(*model)
  81. # self.model = torch.nn.DataParallel(self.model)
  82. def forward(self, i):
  83. """
  84. Global Generator forward.
  85. :param i: <> input
  86. :return:
  87. """
  88. return self.model(i)
  89. class ResnetBlock(torch.nn.Module):
  90. """Define a resnet block."""
  91. def __init__(
  92. self,
  93. dim,
  94. padding_type,
  95. norm_layer,
  96. activation=None,
  97. use_dropout=False,
  98. ):
  99. """
  100. Resnet Block constuctor.
  101. :param dim: <> dim
  102. :param padding_type: <> padding_type
  103. :param norm_layer: <> norm_layer
  104. :param activation: <> activation
  105. :param use_dropout: <> use_dropout
  106. """
  107. super(ResnetBlock, self).__init__()
  108. if activation is None:
  109. activation = torch.nn.ReLU(True)
  110. self.conv_block = self.__build_conv_block(
  111. dim, padding_type, norm_layer, activation, use_dropout
  112. )
  113. @staticmethod
  114. def __build_conv_block(
  115. dim, padding_type, norm_layer, activation, use_dropout
  116. ):
  117. conv_block = []
  118. p = 0
  119. if padding_type == "reflect":
  120. conv_block += [torch.nn.ReflectionPad2d(1)]
  121. elif padding_type == "replicate":
  122. conv_block += [torch.nn.ReplicationPad2d(1)]
  123. elif padding_type == "zero":
  124. p = 1
  125. else:
  126. raise NotImplementedError("padding [%s] is not implemented" % padding_type)
  127. conv_block += [
  128. torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p),
  129. norm_layer(dim),
  130. activation,
  131. ]
  132. if use_dropout:
  133. conv_block += [torch.nn.Dropout(0.5)]
  134. p = 0
  135. if padding_type == "reflect":
  136. conv_block += [torch.nn.ReflectionPad2d(1)]
  137. elif padding_type == "replicate":
  138. conv_block += [torch.nn.ReplicationPad2d(1)]
  139. elif padding_type == "zero":
  140. p = 1
  141. else:
  142. raise NotImplementedError("padding [%s] is not implemented" % padding_type)
  143. conv_block += [
  144. torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p),
  145. norm_layer(dim),
  146. ]
  147. return torch.nn.Sequential(*conv_block)
  148. def forward(self, x):
  149. """
  150. Resnet Block forward.
  151. :param x: <> input
  152. :return: <> out
  153. """
  154. out = x + self.conv_block(x)
  155. return out
  156. def get_transform(opt, method=Image.BICUBIC, normalize=True):
  157. """
  158. Get transform list.
  159. :param opt: <Config> configuration
  160. :param method: <> transformation method used
  161. :param normalize: <boolean> if true normalization is enable
  162. :return:
  163. """
  164. transform_list = []
  165. base = float(2 ** opt.n_downsample_global)
  166. if opt.net_g == "local":
  167. base *= 2 ** opt.n_local_enhancers
  168. transform_list.append(
  169. transforms.Lambda(lambda img: make_power_2(img, base, method))
  170. )
  171. transform_list += [transforms.ToTensor()]
  172. if normalize:
  173. transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
  174. return transforms.Compose(transform_list)
  175. def make_power_2(img, base, method=Image.BICUBIC):
  176. """
  177. Make power 2.
  178. :param img: <> image
  179. :param base: <> base
  180. :param method: <> method
  181. :return:
  182. """
  183. ow, oh = img.size
  184. h = int(round(oh / base) * base)
  185. w = int(round(ow / base) * base)
  186. if (h == oh) and (w == ow):
  187. return img
  188. return img.resize((w, h), method)
  189. def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
  190. """
  191. Convert a Tensor into a Numpy array.
  192. :param image_tensor: <> image tensor
  193. :param imtype: <imtype> the desired type of the converted numpy array
  194. :param normalize: <Boolean> if true normalization is enable
  195. :return:
  196. """
  197. if isinstance(image_tensor, list):
  198. image_numpy = []
  199. for i in image_tensor:
  200. image_numpy.append(tensor2im(i, imtype, normalize))
  201. return image_numpy
  202. image_numpy = image_tensor.cpu().float().numpy()
  203. if normalize:
  204. image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
  205. else:
  206. image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
  207. image_numpy = np.clip(image_numpy, 0, 255)
  208. if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
  209. image_numpy = image_numpy[:, :, 0]
  210. return image_numpy.astype(imtype)