You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

checkpoints.py 2.2KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import os
  2. import sys
  3. import checkpoints
  4. from config import Config as Conf
  5. from argv.common import arg_help, arg_debug
  def init_checkpoints_sub_parser(subparsers):
      """Register the ``checkpoints`` command on *subparsers*.

      The command gets the shared option set (checkpoints directory, help,
      debug, version) and a single nested sub-command, ``download``.
      Invoking ``checkpoints`` on its own dispatches to ``checkpoints.main``;
      invoking ``checkpoints download`` dispatches to ``checkpoints.download``
      (both via the ``func`` default).

      :param subparsers: object returned by ``add_subparsers()`` on the
          parent parser; its ``add_parser`` creates the command.
      :return: the parser created for the ``checkpoints`` command.
      """
      summary = "Handle checkpoints for dreampower."
      cmd_parser = subparsers.add_parser(
          'checkpoints',
          description=summary,
          help=summary,
          # argparse's automatic -h is disabled; arg_help() below is
          # presumably what re-adds a help option (defined elsewhere).
          add_help=False
      )

      # Register shared options; order matters for the help listing.
      for register in (arg_checkpoints, arg_help, arg_debug, arg_version):
          register(cmd_parser)

      # Nested ``download`` sub-command.
      nested = cmd_parser.add_subparsers()
      download_summary = "Download checkpoints for dreampower."
      download_parser = nested.add_parser(
          'download',
          description=download_summary,
          help=download_summary
      )

      cmd_parser.set_defaults(func=checkpoints.main)
      download_parser.set_defaults(func=checkpoints.download)
      return cmd_parser
  def set_args_checkpoints_parser(args):
      """Post-process the parsed ``checkpoints`` command arguments in place.

      Delegates to ``set_arg_checkpoints``, which replaces the
      ``args.checkpoints`` directory with a dict of checkpoint file paths.

      :param args: parsed arguments namespace (mutated).
      """
      set_arg_checkpoints(args)
  def check_args_checkpoints_parser(parser, args):
      """Validate the parsed ``checkpoints`` command arguments.

      Delegates to ``check_arg_checkpoints``, which reports through
      ``parser.error`` when a checkpoint file does not exist.

      :param parser: parser used to report validation errors.
      :param args: parsed arguments namespace.
      """
      check_arg_checkpoints(parser, args)
  def check_arg_checkpoints(parser, args):
      """Check that every checkpoint file referenced by *args* exists.

      ``args.checkpoints`` is expected to already be the mapping built by
      ``set_arg_checkpoints`` (key -> file path); only its values are used.
      If one or more files are missing, ``parser.error`` is called once with
      the full list of missing paths. With argparse, ``error`` prints the
      message and exits.

      :param parser: parser used to report the error.
      :param args: parsed arguments namespace.
      """
      Conf.log.debug(args.checkpoints)
      # Gather every missing file so the user sees all of them at once,
      # instead of being told about them one per run.
      missing = [path for path in args.checkpoints.values() if not os.path.isfile(path)]
      if missing:
          parser.error(
              "Checkpoints file not found: {}. "
              "You can download them using : {} checkpoints download".format(
                  ", ".join(missing), sys.argv[0]
              )
          )
  def set_arg_checkpoints(args):
      """Replace the ``args.checkpoints`` directory with a dict of file paths.

      The directory given on the command line is converted with ``str`` and
      joined with a fixed file name per key:

      * ``correct_to_mask``    -> ``<dir>/cm.lib``
      * ``maskref_to_maskdet`` -> ``<dir>/mm.lib``
      * ``maskfin_to_nude``    -> ``<dir>/mn.lib``

      :param args: parsed arguments namespace; ``checkpoints`` is rebound.
      """
      Conf.log.debug(args.checkpoints)
      base_dir = str(args.checkpoints)
      file_names = (
          ('correct_to_mask', "cm.lib"),
          ('maskref_to_maskdet', "mm.lib"),
          ('maskfin_to_nude', "mn.lib"),
      )
      args.checkpoints = {key: os.path.join(base_dir, name) for key, name in file_names}
  def arg_checkpoints(parser):
      """Add the ``-c/--checkpoints`` option to *parser*.

      The option holds the directory that contains the checkpoint files.
      Its default is ``checkpoints`` under the current working directory,
      resolved at the moment this function is called.

      :param parser: parser receiving the option.
      """
      default_dir = os.path.join(os.getcwd(), "checkpoints")
      parser.add_argument(
          "-c", "--checkpoints",
          default=default_dir,
          help="Path of the directory containing the checkpoints. Default : ./checkpoints"
      )
  def arg_version(parser):
      """Add ``-v/--version`` to *parser*.

      Uses argparse's ``version`` action, which prints
      ``checkpoints <Conf.checkpoints_version>`` and exits.

      :param parser: parser receiving the option.
      """
      version_text = 'checkpoints {}'.format(Conf.checkpoints_version)
      parser.add_argument("-v", "--version", action='version', version=version_text)