Browse Source

Add option location checkpoints

* add -c, --checkpoints to set a custom location for the checkpoints
master
PommeDroid 3 years ago
parent
commit
7acb25e9c0
  1. 19
      argv.py
  2. 7
      config.py
  3. 8
      transform/gan/mask.py

19
argv.py

@ -12,6 +12,17 @@ from gpu_info import get_info @@ -12,6 +12,17 @@ from gpu_info import get_info
def config_args(parser, args):
def config_checkpoints():
checkpoints = {
'correct_to_mask': os.path.join(args.checkpoints, "cm.lib"),
'maskref_to_maskdet': os.path.join(args.checkpoints, "mm.lib"),
'maskfin_to_nude': os.path.join(args.checkpoints, "mn.lib"),
}
for _, v in checkpoints.items():
if not os.path.isfile(v):
parser.error("Checkpoints file not found in directory {}".format(args.checkpoints))
return checkpoints
def config_body_parts_prefs():
prefs = {
"titsize": args.bsize,
@ -51,6 +62,7 @@ def config_args(parser, args): @@ -51,6 +62,7 @@ def config_args(parser, args):
if args.func == main:
conf.args = vars(args)
conf.args['checkpoints'] = config_checkpoints()
conf.args['gpu_ids'] = config_gpu_ids()
conf.args['prefs'] = config_body_parts_prefs()
config_args_in()
@ -253,6 +265,13 @@ def run(): @@ -253,6 +265,13 @@ def run():
help="path of the directory where steps images transformation are write."
)
parser.add_argument(
"-c",
"--checkpoints",
default=os.path.join(os.path.dirname(os.path.realpath(__file__)), "checkpoints" ),
help="path of the directory containing the checkpoints."
)
# Register Command Handlers
parser.set_defaults(func=main)
gpu_info_parser.set_defaults(func=gpu_info.main)

7
config.py

@ -46,13 +46,6 @@ class Config: @@ -46,13 +46,6 @@ class Config:
# Image requirement
desired_size = 512
# GAN checkpoints location
checkpoints = dict({
'correct_to_mask': os.path.join(os.path.dirname(os.path.realpath(__file__)), "checkpoints", "cm.lib"),
'maskref_to_maskdet': os.path.join(os.path.dirname(os.path.realpath(__file__)), "checkpoints", "mm.lib"),
'maskfin_to_nude': os.path.join(os.path.dirname(os.path.realpath(__file__)), "checkpoints", "mn.lib"),
})
# Argparser dict
args = {}

8
transform/gan/mask.py

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
from transform.gan import ImageTransformGAN
from config import Config as opt
from config import Config as conf
class CorrectToMask(ImageTransformGAN):
@ -11,7 +11,7 @@ class CorrectToMask(ImageTransformGAN): @@ -11,7 +11,7 @@ class CorrectToMask(ImageTransformGAN):
"""
CorrectToMask Constructor
"""
super().__init__(opt.checkpoints["correct_to_mask"], "correct_to_mask")
super().__init__(conf.args['checkpoints']["correct_to_mask"], "correct_to_mask")
class MaskrefToMaskdet(ImageTransformGAN):
@ -23,7 +23,7 @@ class MaskrefToMaskdet(ImageTransformGAN): @@ -23,7 +23,7 @@ class MaskrefToMaskdet(ImageTransformGAN):
"""
MaskrefToMaskdet Constructor
"""
super().__init__(opt.checkpoints["maskref_to_maskdet"], "maskref_to_maskdet")
super().__init__(conf.args['checkpoints']["maskref_to_maskdet"], "maskref_to_maskdet")
class MaskfinToNude(ImageTransformGAN):
@ -35,4 +35,4 @@ class MaskfinToNude(ImageTransformGAN): @@ -35,4 +35,4 @@ class MaskfinToNude(ImageTransformGAN):
"""
MaskfinToNude Constructor
"""
super().__init__(opt.checkpoints["maskfin_to_nude"], "maskfin_to_nude")
super().__init__(conf.args['checkpoints']["maskfin_to_nude"], "maskfin_to_nude")

Loading…
Cancel
Save