import functools
import os
from collections import OrderedDict

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms as transforms

from config import Config as conf
from transform import ImageTransform


class ImageTransformGAN(ImageTransform):
    """
    Abstract GAN Image Transformation Class.
    """

    def __init__(self, checkpoint, phase, input_index=(-1,), args=None):
        """
        Abstract GAN Image Transformation Class constructor.

        :param checkpoint: <string> path to the checkpoint
        :param phase: <string> phase name
        :param input_index: <tuple> index where to take the inputs (default is (-1,) for the previous transformation)
        :param args: <dict> args parameter to run the image transformation (default uses conf.args)
        """
        super().__init__(input_index=input_index, args=args)
        self.__checkpoint = checkpoint
        self.__phase = phase
        self.__gpu_ids = self._args["gpu_ids"]

    def setup(self, image):
        """
        Load the dataset and model for the image.

        :param image: <RGB> image to be transformed
        :return: None
        """
        if self.__gpu_ids:
            conf.log.debug("GAN Processing Using GPU IDs: {}".format(self.__gpu_ids))
        else:
            conf.log.debug("GAN Processing Using CPU")

        c = conf()

        # Load custom phase options:
        data_loader = DataLoader(c, image)
        self.__dataset = data_loader.load_data()

        # Create the model
        self.__model = DeepModel()
        self.__model.initialize(c, self.__gpu_ids, self.__checkpoint)

    def execute(self, image):
        """
        Execute the GAN transformation on the image.

        :param image: <RGB> image to transform
        :return: <RGB> transformed image
        """
        mask = None
        for data in self.__dataset:
            generated = self.__model.inference(data["label"], data["inst"])
            im = tensor2im(generated.data[0])
            mask = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
        return mask
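

# A minimal usage sketch, not part of the pipeline: ImageTransformGAN is meant
# to be specialized per phase. The checkpoint path and phase name below are
# hypothetical placeholders, and `bgr_image` is assumed to be an OpenCV BGR
# numpy array.
def _example_gan_transform(bgr_image):
    transform = ImageTransformGAN("checkpoints/example.pth", "example_phase")
    transform.setup(bgr_image)
    return transform.execute(bgr_image)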


class DataLoader:
    def __init__(self, opt, cv_img):
        super(DataLoader, self).__init__()
        self.dataset = Dataset()
        self.dataset.initialize(opt, cv_img)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads),
        )

    def load_data(self):
        return self.dataloader

    def __len__(self):
        return 1


class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        super(Dataset, self).__init__()

    def initialize(self, opt, cv_img):
        self.opt = opt
        self.root = opt.dataroot
        self.A = Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
        self.dataset_size = 1

    def __getitem__(self, index):
        transform_A = get_transform(self.opt)
        A_tensor = transform_A(self.A.convert("RGB"))
        B_tensor = inst_tensor = feat_tensor = 0

        input_dict = {
            "label": A_tensor,
            "inst": inst_tensor,
            "image": B_tensor,
            "feat": feat_tensor,
            "path": "",
        }
        return input_dict

    def __len__(self):
        return 1
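

# Sketch of how DataLoader/Dataset are consumed: the loader yields a single
# batch whose "label" entry is the transformed image tensor. The option
# object is assumed to expose batchSize, serial_batches, nThreads, dataroot,
# netG and n_downsample_global, as the classes above require.
def _demo_data_loader(opt, cv_img):
    loader = DataLoader(opt, cv_img).load_data()
    batch = next(iter(loader))
    return batch["label"]  # shape: (batchSize, 3, H, W)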


class DeepModel(torch.nn.Module):
    def initialize(self, opt, gpu_ids, checkpoints_dir):
        self.opt = opt
        self.checkpoints_dir = checkpoints_dir
        self.gpu_ids = [] if gpu_ids is None else gpu_ids

        self.netG = self.__define_G(
            opt.input_nc,
            opt.output_nc,
            opt.ngf,
            opt.netG,
            opt.n_downsample_global,
            opt.n_blocks_global,
            opt.n_local_enhancers,
            opt.n_blocks_local,
            opt.norm,
            self.gpu_ids,
        )

        # Load the network weights
        self.__load_network(self.netG)

    def inference(self, label, inst):
        # Encode inputs
        input_label, inst_map, _, _ = self.__encode_input(label, inst, infer=True)

        # Fake generation
        input_concat = input_label
        with torch.no_grad():
            fake_image = self.netG.forward(input_concat)
        return fake_image

    # Helper loading function that can be used by subclasses
    def __load_network(self, network):
        save_path = self.checkpoints_dir
        state_dict = torch.load(save_path)

        if len(self.gpu_ids) > 1:
            # DataParallel wraps the model, so the saved keys need a
            # "module." prefix to match the wrapped state dict.
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                new_state_dict["module." + k] = v
        else:
            new_state_dict = state_dict

        network.load_state_dict(new_state_dict)

    def __encode_input(
        self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False
    ):
        if len(self.gpu_ids) > 0:
            input_label = label_map.data.cuda()  # GPU
        else:
            input_label = label_map.data  # CPU
        return input_label, inst_map, real_image, feat_map

    def __weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find("BatchNorm2d") != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def __define_G(
        self,
        input_nc,
        output_nc,
        ngf,
        netG,
        n_downsample_global=3,
        n_blocks_global=9,
        n_local_enhancers=1,
        n_blocks_local=3,
        norm="instance",
        gpu_ids=None,
    ):
        norm_layer = self.__get_norm_layer(norm_type=norm)
        gpu_ids = gpu_ids or []

        # Model
        netG = GlobalGenerator(
            input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer
        )

        if len(gpu_ids) > 1:
            netG = torch.nn.DataParallel(netG, gpu_ids)
        if len(gpu_ids) > 0:
            netG.cuda(gpu_ids[0])

        netG.apply(self.__weights_init)
        return netG

    def __get_norm_layer(self, norm_type="instance"):
        # Only instance normalization is supported; norm_type is kept for
        # interface compatibility.
        norm_layer = functools.partial(torch.nn.InstanceNorm2d, affine=False)
        return norm_layer
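

# Illustrative option set covering the fields DeepModel.initialize() reads
# from `opt`; the values are pix2pixHD-style defaults, not necessarily this
# project's configuration (which comes from config.Config).
def _example_model_options():
    from types import SimpleNamespace

    return SimpleNamespace(
        input_nc=3, output_nc=3, ngf=64, netG="global",
        n_downsample_global=4, n_blocks_global=9,
        n_local_enhancers=1, n_blocks_local=3, norm="instance",
    )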


##############################################################################
# Generator
##############################################################################
class GlobalGenerator(torch.nn.Module):
    def __init__(
        self,
        input_nc,
        output_nc,
        ngf=64,
        n_downsampling=3,
        n_blocks=9,
        norm_layer=torch.nn.BatchNorm2d,
        padding_type="reflect",
    ):
        assert n_blocks >= 0
        super(GlobalGenerator, self).__init__()
        activation = torch.nn.ReLU(True)

        model = [
            torch.nn.ReflectionPad2d(3),
            torch.nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            activation,
        ]

        # Downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [
                torch.nn.Conv2d(
                    ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1
                ),
                norm_layer(ngf * mult * 2),
                activation,
            ]

        # Resnet blocks
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [
                ResnetBlock(
                    ngf * mult,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                )
            ]

        # Upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                torch.nn.ConvTranspose2d(
                    ngf * mult,
                    int(ngf * mult / 2),
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                norm_layer(int(ngf * mult / 2)),
                activation,
            ]

        model += [
            torch.nn.ReflectionPad2d(3),
            torch.nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            torch.nn.Tanh(),
        ]
        self.model = torch.nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
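

# Shape sketch (assumed sizes): the generator halves the spatial resolution
# n_downsampling times and then mirrors it back up, so for input sides that
# are multiples of 2 ** n_downsampling the output matches the input shape.
# get_transform() below enforces that multiple via __make_power_2().
def _demo_global_generator():
    net = GlobalGenerator(3, 3, ngf=8, n_downsampling=2, n_blocks=1)
    x = torch.randn(1, 3, 64, 64)
    assert net(x).shape == (1, 3, 64, 64)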


# Define a resnet block
class ResnetBlock(torch.nn.Module):
    def __init__(
        self,
        dim,
        padding_type,
        norm_layer,
        activation=torch.nn.ReLU(True),
        use_dropout=False,
    ):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.__build_conv_block(
            dim, padding_type, norm_layer, activation, use_dropout
        )

    def __build_conv_block(
        self, dim, padding_type, norm_layer, activation, use_dropout
    ):
        conv_block = []
        p = 0
        if padding_type == "reflect":
            conv_block += [torch.nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            conv_block += [torch.nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        else:
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)

        conv_block += [
            torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p),
            norm_layer(dim),
            activation,
        ]
        if use_dropout:
            conv_block += [torch.nn.Dropout(0.5)]

        p = 0
        if padding_type == "reflect":
            conv_block += [torch.nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            conv_block += [torch.nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        else:
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)

        conv_block += [
            torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p),
            norm_layer(dim),
        ]
        return torch.nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out
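

# Quick sanity sketch: a ResnetBlock preserves its input shape, which the
# residual sum in forward() requires. The channel count and image size here
# are arbitrary.
def _demo_resnet_block():
    norm_layer = functools.partial(torch.nn.InstanceNorm2d, affine=False)
    block = ResnetBlock(64, padding_type="reflect", norm_layer=norm_layer)
    x = torch.randn(1, 64, 32, 32)
    assert block(x).shape == x.shape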


# Data utils:
def get_transform(opt, method=Image.BICUBIC, normalize=True):
    transform_list = []
    base = float(2 ** opt.n_downsample_global)
    if opt.netG == "local":
        base *= 2 ** opt.n_local_enhancers
    transform_list.append(
        transforms.Lambda(lambda img: __make_power_2(img, base, method))
    )
    transform_list += [transforms.ToTensor()]
    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)
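

# Illustration of the rounding above: __make_power_2 snaps both sides to the
# nearest multiple of `base`, so e.g. a 500x375 image with base 16 is resized
# to 496x368. The sizes here are arbitrary examples.
def _demo_make_power_2():
    img = Image.new("RGB", (500, 375))
    resized = __make_power_2(img, base=16.0)
    assert resized.size == (496, 368)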


# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
    if isinstance(image_tensor, list):
        image_numpy = []
        for i in range(len(image_tensor)):
            image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
        return image_numpy

    image_numpy = image_tensor.cpu().float().numpy()
    if normalize:
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    else:
        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)
    if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
        image_numpy = image_numpy[:, :, 0]
    return image_numpy.astype(imtype)
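

# Hedged smoke test for tensor2im: a tensor in [-1, 1] (as produced by the
# generator's final Tanh) maps to uint8 pixels in [0, 255]. Run the module
# directly to execute it.
if __name__ == "__main__":
    t = torch.tanh(torch.randn(3, 8, 8))
    arr = tensor2im(t)
    assert arr.shape == (8, 8, 3) and arr.dtype == np.uint8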