You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

__init__.py 2.0KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. """GAN Transforms."""
  2. import cv2
  3. from config import Config as Conf
  4. from transform import ImageTransform
  5. from transform.gan.generator import tensor2im
  6. from transform.gan.model import DeepModel, DataLoader
  7. class ImageTransformGAN(ImageTransform):
  8. """Abstract GAN Image Transformation Class."""
  9. def __init__(self, checkpoint, phase, input_index=(-1,), args=None):
  10. """
  11. Abstract GAN Image Transformation Class Constructor.
  12. :param checkpoint: <string> path to the checkpoint
  13. :param phase: <string> phase name
  14. :param input_index: <tuple> index where to take the inputs (default is (-1) for previous transformation)
  15. :param args: <dict> args parameter to run the image transformation (default use Conf.args)
  16. """
  17. super().__init__(input_index=input_index, args=args)
  18. self.__checkpoint = checkpoint
  19. self.__phase = phase
  20. self.__gpu_ids = self._args["gpu_ids"]
  21. def _setup(self, *args):
  22. """
  23. Load Dataset and Model fot the image.
  24. :param args: <[RGB]> image to be transform
  25. :return: None
  26. """
  27. if self.__gpu_ids:
  28. Conf.log.debug("GAN Processing Using GPU IDs: {}".format(self.__gpu_ids))
  29. else:
  30. Conf.log.debug("GAN Processing Using CPU")
  31. c = Conf()
  32. # Load custom phase options:
  33. data_loader = DataLoader(c, args[0])
  34. self.__dataset = data_loader.load_data()
  35. # Create Model
  36. self.__model = DeepModel()
  37. self.__model.initialize(c, self.__gpu_ids, self.__checkpoint)
  38. def _execute(self, *args):
  39. """
  40. Execute the GAN Transformation the image.
  41. :param *args: <[RGB]> image to transform
  42. :return: <RGB> image transformed
  43. """
  44. mask = None
  45. for data in self.__dataset:
  46. generated = self.__model.inference(data["label"], data["inst"])
  47. im = tensor2im(generated.data[0])
  48. mask = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
  49. return mask