You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

model.py 5.5KB


  1. """GAN Model."""
  2. import functools
  3. import os
  4. from collections import OrderedDict
  5. import cv2
  6. import torch
  7. from PIL import Image
  8. from transform.gan.generator import GlobalGenerator, get_transform
  9. class DataLoader:
  10. """Dataset loader class."""
  11. def __init__(self, opt, cv_img):
  12. """
  13. Construct Data loader.
  14. :param opt: <Config> configuration to use
  15. :param cv_img: <RGB> image
  16. """
  17. super(DataLoader, self).__init__()
  18. self.dataset = Dataset()
  19. self.dataset.initialize(opt, cv_img)
  20. self.dataloader = torch.utils.data.DataLoader(
  21. self.dataset,
  22. batch_size=opt.batch_size,
  23. shuffle=not opt.serial_batches,
  24. num_workers=int(opt.n_threads),
  25. )
  26. def load_data(self):
  27. """
  28. Return loaded data.
  29. :return: <> loaded data
  30. """
  31. return self.dataloader
  32. def __len__(self):
  33. """
  34. Redefine __len___ for DataLoader.
  35. :return: <int> 1
  36. """
  37. return 1
  38. class Dataset(torch.utils.data.Dataset):
  39. """Dataset class."""
  40. def __init__(self):
  41. """Dataset Constructor."""
  42. super(Dataset, self).__init__()
  43. def initialize(self, opt, cv_img):
  44. """
  45. Initialize the Dataset.
  46. :param opt:
  47. :param cv_img:
  48. :return:
  49. """
  50. self.opt = opt
  51. self.A = Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
  52. self.dataset_size = 1
  53. def __getitem__(self, index):
  54. """
  55. Redefine Dataset __getitem__.
  56. :param index:
  57. :return:
  58. """
  59. transform_a = get_transform(self.opt)
  60. a_tensor = transform_a(self.A.convert("RGB"))
  61. b_tensor = inst_tensor = feat_tensor = 0
  62. input_dict = {
  63. "label": a_tensor,
  64. "inst": inst_tensor,
  65. "image": b_tensor,
  66. "feat": feat_tensor,
  67. "path": "",
  68. }
  69. return input_dict
  70. def __len__(self):
  71. """
  72. Redefine __len___ for Dataset.
  73. :return: <int> 1
  74. """
  75. return 1
  76. class DeepModel(torch.nn.Module):
  77. """Deep Model."""
  78. def initialize(self, opt, gpu_ids, checkpoints_dir):
  79. """
  80. Deep Model initialize.
  81. :param opt: <Config> configuration to use
  82. :param gpu_ids: <int[]|None> gpu id to use (None = cpu)
  83. :param checkpoints_dir: <string> path to the directoy where models are
  84. :return:
  85. """
  86. self.opt = opt
  87. self.checkpoints_dir = checkpoints_dir
  88. if gpu_ids is None:
  89. self.gpu_ids = []
  90. else:
  91. self.gpu_ids = gpu_ids
  92. self.net_g = self.__define_g(
  93. opt.input_nc,
  94. opt.output_nc,
  95. opt.ngf,
  96. opt.net_g,
  97. opt.n_downsample_global,
  98. opt.n_blocks_global,
  99. opt.n_local_enhancers,
  100. opt.n_blocks_local,
  101. opt.norm,
  102. self.gpu_ids,
  103. )
  104. # load networks
  105. self.__load_network(self.net_g)
  106. def inference(self, label, inst):
  107. """
  108. Infere an image.
  109. :param label: <> label
  110. :param inst: <> isnt
  111. :return: <RGB> image
  112. """
  113. # Encode Inputs
  114. input_label, _, _, _ = self.__encode_input(label, inst, infer=True)
  115. # Fake Generation
  116. input_concat = input_label
  117. with torch.no_grad():
  118. fake_image = self.net_g.forward(input_concat)
  119. return fake_image
  120. # helper loading function that can be used by subclasses
  121. def __load_network(self, network):
  122. save_path = os.path.join(self.checkpoints_dir)
  123. state_dict = torch.load(save_path)
  124. if len(self.gpu_ids) > 1:
  125. new_state_dict = OrderedDict()
  126. for k, v in state_dict.items():
  127. name = "module." + k # add `module.`
  128. new_state_dict[name] = v
  129. else:
  130. new_state_dict = state_dict
  131. network.load_state_dict(new_state_dict)
  132. def __encode_input(
  133. self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False
  134. ):
  135. if len(self.gpu_ids) > 0:
  136. input_label = label_map.data.cuda() # GPU
  137. else:
  138. input_label = label_map.data # CPU
  139. return input_label, inst_map, real_image, feat_map
  140. @staticmethod
  141. def __weights_init(m):
  142. classname = m.__class__.__name__
  143. if classname.find("Conv") != -1:
  144. m.weight.data.normal_(0.0, 0.02)
  145. elif classname.find("BatchNorm2d") != -1:
  146. m.weight.data.normal_(1.0, 0.02)
  147. m.bias.data.fill_(0)
  148. def __define_g(
  149. self,
  150. input_nc,
  151. output_nc,
  152. ngf,
  153. net_g,
  154. n_downsample_global=3,
  155. n_blocks_global=9,
  156. n_local_enhancers=1,
  157. n_blocks_local=3,
  158. norm="instance",
  159. gpu_ids=None,
  160. ):
  161. norm_layer = self.__get_norm_layer(norm_type=norm)
  162. # model
  163. net_g = GlobalGenerator(
  164. input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer
  165. )
  166. if len(gpu_ids) > 1:
  167. net_g = torch.nn.DataParallel(net_g, gpu_ids)
  168. if len(gpu_ids) > 0:
  169. net_g.cuda(gpu_ids[0])
  170. net_g.apply(self.__weights_init)
  171. return net_g
  172. @staticmethod
  173. def __get_norm_layer(norm_type="instance"):
  174. norm_layer = functools.partial(torch.nn.InstanceNorm2d, affine=False)
  175. return norm_layer