Browse Source

CLI improvements

Add n_cores, n_runs options
Add gif support to auto rescale options
Update readme
master
Ringo 3 years ago committed by PommeDroid
parent
commit
2ec6539528
  1. 85
      main.py
  2. 19
      run.py
  3. 3
      src/cli/utils.py

85
main.py

@ -12,6 +12,7 @@ import utils @@ -12,6 +12,7 @@ import utils
from run import process, process_gif
from multiprocessing import freeze_support
from multiprocessing.pool import ThreadPool
from dotenv import load_dotenv
#
@ -46,25 +47,31 @@ parser.add_argument( @@ -46,25 +47,31 @@ parser.add_argument(
help="generates pubic hair on output image",
)
parser.add_argument(
"--gif", action="store_true", default=False, help="Run the processing on a gif"
"--gif", action="store_true", default=False, help="run the processing on a gif"
)
parser.add_argument(
"--auto-resize",
action="store_true",
default=False,
help="Scale and pad image to 512x512. Maintains aspect ratio. Doesn't support gifs for now.",
help="Scale and pad image to 512x512 (maintains aspect ratio)",
)
parser.add_argument(
"--auto-resize-crop",
action="store_true",
default=False,
help="Scale and crop image to 512x512. Maintains aspect ratio. Doesn't support gifs for now.",
help="Scale and crop image to 512x512 (maintains aspect ratio)",
)
parser.add_argument(
"--auto-rescale",
action="store_true",
default=False,
help="Scale image to 512x512. Doesn't support gifs for now.",
help="Scale image to 512x512",
)
parser.add_argument(
"-n", "--n_runs", type=int, help="number of times to process input (default: 1)",
)
parser.add_argument(
"--n_cores", type=int, default=4, help="number of cpu cores to use (default: 4)",
)
args = parser.parse_args()
@ -88,9 +95,10 @@ def main(): @@ -88,9 +95,10 @@ def main():
gpu_ids = [0]
if not args.gif:
# Read input image
# Read image
image = cv2.imread(args.input)
# Preprocess
if args.auto_resize:
image = utils.resize_input(image)
elif args.auto_resize_crop:
@ -99,25 +107,44 @@ def main(): @@ -99,25 +107,44 @@ def main():
image = utils.rescale_input(image)
# Process
result = process(image, gpu_ids, args.enablepubes)
# Write output image
cv2.imwrite(args.output, result)
if args.n_runs is None or args.n_runs == 1:
result = process(image, gpu_ids, args.enablepubes)
cv2.imwrite(args.output, result)
else:
base_output_filename = utils.strip_file_extension(args.output, ".png")
def process_one_image(i):
result = process(image, gpu_ids, args.enablepubes)
cv2.imwrite(base_output_filename + "%03d.png" % i, result)
if args.cpu:
pool = ThreadPool(args.n_cores)
pool.map(process_one_image, range(args.n_runs))
pool.close()
pool.join()
else:
for i in range(args.n_runs):
process_one_image(i)
else:
# Read images
gif_imgs = imageio.mimread(args.input)
nums = len(gif_imgs)
print("Total {} frames in the gif!".format(nums))
tmp_dir = tempfile.mkdtemp()
process_gif(gif_imgs, gpu_ids, args.enablepubes, tmp_dir)
print("Creating gif")
imageio.mimsave(
args.output if args.output != "output.png" else "output.gif",
[
imageio.imread(os.path.join(tmp_dir, "output_{}.jpg".format(i)))
for i in range(nums)
],
)
shutil.rmtree(tmp_dir)
print("Total {} frames in the gif!".format(len(gif_imgs)))
# Preprocess
if args.auto_resize:
gif_imgs = [utils.resize_input(img) for img in gif_imgs]
elif args.auto_resize_crop:
gif_imgs = [utils.resize_crop_input(img) for img in gif_imgs]
elif args.auto_rescale:
gif_imgs = [utils.rescale_input(img) for img in gif_imgs]
# Process
if args.n_runs is None or args.n_runs == 1:
process_gif_wrapper(gif_imgs, args.output if args.output != "output.png" else "output.gif", gpu_ids, args.enablepubes, args.n_cores)
else:
base_output_filename = utils.strip_file_extension(args.output, ".gif") if args.output != "output.png" else "output"
for i in range(args.n_runs):
process_gif_wrapper(gif_imgs, base_output_filename + "%03d.gif" % i, gpu_ids, args.enablepubes, args.n_cores)
end = time.time()
duration = end - start
@ -129,6 +156,20 @@ def main(): @@ -129,6 +156,20 @@ def main():
sys.exit()
def process_gif_wrapper(gif_imgs, filename, gpu_ids, enablepubes, n_cores):
tmp_dir = tempfile.mkdtemp()
process_gif(gif_imgs, gpu_ids, enablepubes, tmp_dir, n_cores)
print("Creating gif")
imageio.mimsave(
filename,
[
imageio.imread(os.path.join(tmp_dir, "output_{}.jpg".format(i)))
for i in range(len(gif_imgs))
],
)
shutil.rmtree(tmp_dir)
def start_sentry():
dsn = os.getenv("SENTRY_DSN")

19
run.py

@ -191,25 +191,20 @@ def process(cv_img, gpu_ids, enable_pubes): @@ -191,25 +191,20 @@ def process(cv_img, gpu_ids, enable_pubes):
# return:
# gif
def process_gif(gif_imgs, gpu_ids, enable_pubes, tmp_dir):
def process_gif(gif_imgs, gpu_ids, enable_pubes, tmp_dir, n_cores):
def process_one_image(a):
print("Processing image : {}/{}".format(a[1] + 1, len(gif_imgs)))
img = cv2.resize(a[0], (512, 512))
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
cv2.imwrite(
os.path.join(tmp_dir, "output_{}.jpg".format(a[1])),
process(img, gpu_ids, enable_pubes),
)
print(gpu_ids)
if (
gpu_ids is None
): # Only multithreading with CPU because threads cause crashes with GPU
pool = ThreadPool(4)
cv2.imwrite(os.path.join(tmp_dir, "output_{}.jpg".format(a[1])), process(img, gpu_ids, enable_pubes))
print("GPU IDs: " + str(gpu_ids), flush=True)
if gpu_ids is None: # Only multithreading with CPU because threads cause crashes with GPU
pool = ThreadPool(n_cores)
pool.map(process_one_image, zip(gif_imgs, range(len(gif_imgs))))
pool.close()
pool.join()
else:
for x in zip(gif_imgs, range(len(gif_imgs))):
process_one_image(x)

3
src/cli/utils.py

@ -37,3 +37,6 @@ def resize_crop_input(img): @@ -37,3 +37,6 @@ def resize_crop_input(img):
def rescale_input(img):
return cv2.resize(img, (desired_size, desired_size))
def strip_file_extension(filename, extension):
return filename[::-1].replace(extension[::-1], "", 1)[::-1]

Loading…
Cancel
Save