Browse Source

Merge pull request #39 from PommeDroid/checkpoints_download

Download checkpoints directly from dreampower
master
deeppppp 3 years ago committed by GitHub
parent
commit
b2d81a4e3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 8
      README.md
  2. 27
      argv.py
  3. 46
      checkpoints.py
  4. 6
      config.py
  5. 37
      gpu_info.py
  6. 7
      main.py
  7. 4
      requirements.txt
  8. 1
      transform/gan/__init__.py
  9. 59
      utils.py

8
README.md

@ -43,24 +43,20 @@ See **LICENSE.md** for more details. @@ -43,24 +43,20 @@ See **LICENSE.md** for more details.
### Download
Download DreamPower is very easy! 2 files and you are ready. _(Get ready to download ~3GB)_
Download DreamPower is very easy!
- [DreamPower](https://bit.ly/2KdqlYH): The command line interface (CLI), here you will find everything you need, just download the .zip file that fits your operating system.
- [Checkpoints](http://bit.ly/2JBP88o): This is the information that the transformation algorithm **requires**, if you do not have this file the application will not work. You only need to download it once, if you update DreamPower use this same file for checkpoints. (unless we tell you otherwise)
### Download Mirrors
- [DreamPower (MEGA)](https://bit.ly/2GD6aST)
- [DreamPower (MediaFire)](https://bit.ly/2LNjAQk)
- [Checkpoints (MEGA)](http://bit.ly/30GiSbh)
- [Checkpoints (MediaFire)](http://bit.ly/2Y0V6sO)
### Setup
- Create a folder on your computer, it can be anywhere you want it, call it `DreamPower` and inside it place the 2 zip files you have downloaded.
- Extract the file that contains the CLI, this should generate a folder called `dreampower`
- Extract the other file `checkpoints.zip` and move the extracted folder `checkpoints` inside `dreampower`.
- Ready! Now you can use the command line interface run the `dreampower/dreampower.exe` file from a console.
- Run `dreampower checkpoints download` to download the checkpoints.
> When you update DreamPower it will only be necessary to download the file that contains the `DreamPower`, you can reuse the checkpoints (unless we tell you otherwise)

27
argv.py

@ -6,6 +6,7 @@ import re @@ -6,6 +6,7 @@ import re
import sys
from json import JSONDecodeError
import checkpoints
import gpu_info
from main import main
from config import Config as conf
@ -28,6 +29,7 @@ class ArgvParser: @@ -28,6 +29,7 @@ class ArgvParser:
"""
def config_checkpoints(a):
checkpoints_dir = a.checkpoints
a.checkpoints = {
'correct_to_mask': os.path.join(str(a.checkpoints), "cm.lib"),
'maskref_to_maskdet': os.path.join(str(a.checkpoints), "mm.lib"),
@ -35,7 +37,10 @@ class ArgvParser: @@ -35,7 +37,10 @@ class ArgvParser:
}
for _, v in a.checkpoints.items():
if not os.path.isfile(v):
ArgvParser.parser.error("Checkpoints file not found in directory {}".format(a.checkpoints))
ArgvParser.parser.error(
"Checkpoints file not found in directory {}. "
"You can download them using : {} checkpoints download".format(checkpoints_dir, sys.argv[0])
)
def config_body_parts_prefs(a):
a.prefs = {
@ -332,17 +337,25 @@ class ArgvParser: @@ -332,17 +337,25 @@ class ArgvParser:
)
gpu_info_parser = subparsers.add_parser('gpu-info')
gpu_info_parser.add_argument(
"-j",
"--json",
default=False,
action="store_true",
help="Print GPU info as JSON"
gpu_info_subparser = gpu_info_parser.add_subparsers()
gpu_info_json_parser = gpu_info_subparser.add_parser('json')
checkpoints_parser = subparsers.add_parser('checkpoints')
checkpoints_parser_subparser = checkpoints_parser.add_subparsers()
checkpoints_parser_info_parser = checkpoints_parser_subparser.add_parser('download')
checkpoints_parser.add_argument(
"-v",
"--version",
action='version', version='checkpoints {}'.format(conf.checkpoints_version)
)
# Register Command Handlers
ArgvParser.parser.set_defaults(func=main)
gpu_info_parser.set_defaults(func=gpu_info.main)
gpu_info_json_parser.set_defaults(func=gpu_info.json)
checkpoints_parser.set_defaults(func=checkpoints.main)
checkpoints_parser_info_parser.set_defaults(func=checkpoints.download)
# Show usage is no args is provided
if len(sys.argv) == 1:

46
checkpoints.py

@ -0,0 +1,46 @@ @@ -0,0 +1,46 @@
import logging
import os
import shutil
import sys
import tempfile
from config import Config as conf
from utils import setup_log, dll_file, unzip
def main(_):
conf.log = setup_log(logging.DEBUG) if conf.args['debug'] else setup_log()
if sum([1 for x in ["cm.lib", "mm.lib", "mn.lib"] if os.path.isfile(os.path.join(conf.args['checkpoints'], x))]):
conf.log.info("Checkpoints Found In {}".format(conf.args['checkpoints']))
else:
conf.log.warn("Checkpoints Not Found In {}".format(conf.args['checkpoints']))
conf.log.info("You Can Download Them Using : {} checkpoints download".format(sys.argv[0]))
def download(_):
conf.log = setup_log(logging.DEBUG) if conf.args['debug'] else setup_log()
tempdir = tempfile.mkdtemp()
cdn_url = conf.checkpoints_cdn.format(conf.checkpoints_version)
temp_zip = os.path.join(tempdir, "{}.zip".format(conf.checkpoints_version))
try:
conf.log.info("Downloading {}".format(cdn_url))
dll_file(conf.checkpoints_cdn.format(conf.checkpoints_version), temp_zip)
conf.log.info("Extracting {}".format(temp_zip))
unzip(temp_zip, conf.args['checkpoints'])
conf.log.info("Moving Checkpoints To Final location")
[(lambda a: os.remove(a) and shutil.move(a, os.path.abspath(conf.args['checkpoints'])))(x)
for x in (os.path.join(conf.args['checkpoints'], 'checkpoints', y) for y in ("cm.lib", "mm.lib", "mn.lib"))]
shutil.rmtree(os.path.join(conf.args['checkpoints'], 'checkpoints'))
except Exception as e:
conf.log.error(e)
conf.log.error("Something Gone Bad Download Downloading The Checkpoints")
shutil.rmtree(tempdir)
sys.exit(1)
shutil.rmtree(tempdir)
conf.log.info("Checkpoints Downloaded Successfully")

6
config.py

@ -3,6 +3,8 @@ class Config: @@ -3,6 +3,8 @@ class Config:
Variables Configuration Class
"""
version = "v1.0.0"
checkpoints_version = "v0.0.1"
checkpoints_cdn = "https://cdn.dreamnet.tech/releases/checkpoints/{}.zip"
# experiment specifics
norm = "batch" # instance normalization or batch normalization
@ -38,10 +40,6 @@ class Config: @@ -38,10 +40,6 @@ class Config:
# number of epochs that we only train the outmost local enhancer
niter_fix_global = 0
# Phase specific options
checkpoints_dir = ""
dataroot = ""
# Image requirement
desired_size = 512
desired_shape = 512, 512, 3

37
gpu_info.py

@ -1,28 +1,25 @@ @@ -1,28 +1,25 @@
import torch
import json
import logging
from torch import cuda
import json as j
from config import Config as conf
from utils import setup_log
def get_info():
has_cuda = torch.cuda.is_available()
devices = []
if has_cuda:
count = torch.cuda.device_count()
for i in range(count):
devices.append(torch.cuda.get_device_name(i))
def get_info():
return {
"has_cuda": has_cuda,
"devices": devices,
"has_cuda": cuda.is_available(),
"devices": [] if not cuda.is_available() else [cuda.get_device_name(i) for i in range(cuda.device_count())],
}
def main(args):
def main(_):
conf.log = setup_log(logging.DEBUG) if conf.args['debug'] else setup_log()
info = get_info()
if args.json:
data = json.dumps(info)
print(data)
else:
print("Has Cuda: {}".format(info["has_cuda"]))
for (i, device) in enumerate(info["devices"]):
print("GPU {}: {}".format(i, device))
conf.log.info("Has Cuda: {}".format(info["has_cuda"]))
for (i, device) in enumerate(info["devices"]):
conf.log.info("GPU {}: {}".format(i, device))
def json(_):
print(j.dumps(get_info()))

7
main.py

@ -6,7 +6,7 @@ from multiprocessing import freeze_support @@ -6,7 +6,7 @@ from multiprocessing import freeze_support
import argv
from config import Config as conf
from utils import setup_log, read_image, check_shape
from utils import setup_log, check_shape
from processing import SimpleTransform, FolderImageTransform, MultipleImageTransform
from transform.gan.mask import CorrectToMask, MaskrefToMaskdet, MaskfinToNude
@ -55,11 +55,13 @@ def select_phases(): @@ -55,11 +55,13 @@ def select_phases():
phases = [phase] + phases
if conf.args['steps'] and conf.args['steps'][0] != 0:
shift_step(shift_starting=1)
if conf.args['steps'] and conf.args['steps'][1] == len(phases) - 1:
shift_step(shift_ending=1)
return phases
def add_head(phases, phase):
phases = phases + [phase]
if conf.args['steps'] and conf.args['steps'][0] == len(phases) - 1:
if conf.args['steps'] and conf.args['steps'][1] == len(phases) - 1:
shift_step(shift_ending=1)
return phases
@ -135,5 +137,4 @@ def processing_image_folder(phases): @@ -135,5 +137,4 @@ def processing_image_folder(phases):
if __name__ == "__main__":
freeze_support()
# start_rook()
argv.ArgvParser.run()

4
requirements.txt

@ -5,5 +5,5 @@ rsa==4.0 @@ -5,5 +5,5 @@ rsa==4.0
torchvision==0.2.2.post3
torch==1.1.0
imageio==2.5.0
python-dotenv==0.10.3
coloredlogs==10.0
coloredlogs==10.0
requests==2.22.0

1
transform/gan/__init__.py

@ -91,7 +91,6 @@ class Dataset(torch.utils.data.Dataset): @@ -91,7 +91,6 @@ class Dataset(torch.utils.data.Dataset):
def initialize(self, opt, cv_img):
self.opt = opt
self.root = opt.dataroot
self.A = Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
self.dataset_size = 1

59
utils.py

@ -2,12 +2,14 @@ import json @@ -2,12 +2,14 @@ import json
import logging
import os
import sys
import zipfile
from re import finditer
import coloredlogs
import cv2
import imageio
import numpy as np
import requests
from PIL import Image
from config import Config as conf
@ -127,4 +129,59 @@ def json_to_argv(data): @@ -127,4 +129,59 @@ def json_to_argv(data):
l.extend(["--{}".format(k), str(v)])
elif v:
l.append("--{}".format(k))
return l
return l
def dll_file(url, file_path):
"""
Download a file
:param url: <string> url of the file to download
:param file_path: <string> file path where save the file
:return: <string> full path of downloaded file
"""
conf.log.debug("Download url : {} to path: {}".format(url, file_path))
response = requests.get(url, stream=True)
dir = os.path.dirname(file_path)
if dir != '':
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "wb") as f:
total_length = response.headers.get('content-length')
if total_length is None: # no content length header
f.write(response.content)
else:
dl = 0
total_length = int(total_length)
for data in response.iter_content(chunk_size=4096):
dl += len(data)
f.write(data)
done = int(50 * dl / total_length)
print("[{}{}]".format('=' * done, ' ' * (50 - done)), end="\r")
print(" "*80, end="\r")
conf.log.info("{} Downloaded".format(url,))
return file_path
def unzip(zip_path, extract_path):
"""
Extract a zip
:param zip_path: <string> path to zip to extract
:param extract_path: <string> path to dir where to extract
:return: None
"""
conf.log.debug("Extracting zip : {} to path: {}".format(zip_path, extract_path))
if not os.path.exists(extract_path):
os.makedirs(extract_path, exist_ok=True)
with zipfile.ZipFile(zip_path, "r") as zf:
uncompress_size = sum((file.file_size for file in zf.infolist()))
extracted_size = 0
for file in zf.infolist():
done = int(50 * extracted_size / uncompress_size)
print("[{}{}]".format('=' * done, ' ' * (50 - done)), end="\r")
zf.extract(file, path=extract_path)
extracted_size += file.file_size
print(" "*80, end="\r")

Loading…
Cancel
Save