
correct.py

  1. """OpenCV Correct Transforms."""
  2. import math
  3. import cv2
  4. import numpy as np
  5. from third.opencv.color_transfer import color_transfer
  6. from transform.opencv import ImageTransformOpenCV
class DressToCorrect(ImageTransformOpenCV):
    """Dress -> Correct [OPENCV]."""

    def _execute(self, *args):
        """
        Execute dress to correct phase.

        :param args: <[RGB]> Image to correct
        :return: <RGB> image corrected
        """
        return self.correct_color(args[0], 5)

    @staticmethod
    def correct_color(img, percent):
        """
        Correct the color of an image.

        :param img: <RGB> Image to correct
        :param percent: <int> Percent of correction (1-100)
        :return: <RGB> image corrected
        """
        if img.shape[2] != 3:
            raise AssertionError()
        if not 0 < percent <= 100:
            raise AssertionError()

        half_percent = percent / 200.0
        channels = cv2.split(img)

        out_channels = []
        for channel in channels:
            if len(channel.shape) != 2:
                raise AssertionError()

            # Find the low and high percentile values (based on the input percentile)
            height, width = channel.shape
            vec_size = width * height
            flat = channel.reshape(vec_size)

            if len(flat.shape) != 1:
                raise AssertionError()

            flat = np.sort(flat)
            n_cols = flat.shape[0]
            low_val = flat[math.floor(n_cols * half_percent)]
            high_val = flat[math.ceil(n_cols * (1.0 - half_percent))]

            # Saturate below the low percentile and above the high percentile
            thresholded = DressToCorrect.apply_threshold(channel, low_val, high_val)

            # Scale the channel to the full 0-255 range
            normalized = cv2.normalize(thresholded, thresholded.copy(), 0, 255, cv2.NORM_MINMAX)
            out_channels.append(normalized)

        return cv2.merge(out_channels)
    @staticmethod
    def apply_threshold(matrix, low_value, high_value):
        """
        Apply a threshold on a matrix.

        :param matrix: <array> matrix
        :param low_value: <float> low value
        :param high_value: <float> high value
        :return: <array> thresholded matrix
        """
        low_mask = matrix < low_value
        matrix = DressToCorrect.apply_mask(matrix, low_mask, low_value)

        high_mask = matrix > high_value
        matrix = DressToCorrect.apply_mask(matrix, high_mask, high_value)

        return matrix
    @staticmethod
    def apply_mask(matrix, mask, fill_value):
        """
        Apply a mask on a matrix.

        :param matrix: <array> matrix
        :param mask: <array> boolean mask selecting the values to replace
        :param fill_value: <float> fill value
        :return: <array> matrix with masked values replaced by the fill value
        """
        masked = np.ma.array(matrix, mask=mask, fill_value=fill_value)
        return masked.filled()


class ColorTransfer(ImageTransformOpenCV):
    """ColorTransfer [OPENCV]."""

    def __init__(self, input_index=(0, -1)):
        """
        Color Transfer constructor.

        :param input_index: <tuple> indexes where to take the inputs (default is (0, -1),
            the first and the previous transformation)
        """
        super().__init__(input_index=input_index)

    def _execute(self, *args):
        """
        Transfer the color distribution from the source image to the target image.

        :param args: <[RGB, RGB]> Image source, Image target
        :return: <RGB> color-transferred image
        """
        return color_transfer(args[0], args[1], clip=True, preserve_paper=False)
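

# Usage sketch (not part of the original module): a minimal example of calling the
# static color-balance helper directly, without going through the transform pipeline.
# The file names "input.png" and "output.png" are hypothetical placeholders.
if __name__ == "__main__":
    # Load a BGR image with OpenCV; cv2.imread returns None if the file is missing.
    image = cv2.imread("input.png")
    if image is None:
        raise FileNotFoundError("input.png not found")

    # percent=5 clips the darkest/brightest 2.5% of each channel, then stretches
    # the remaining values to the full 0-255 range.
    balanced = DressToCorrect.correct_color(image, 5)
    cv2.imwrite("output.png", balanced)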