# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 113 lines
# 3.2 KiB
# Python

import logging
import cv2
import numpy as np
from scipy import signal
import torch
import os
def round_func(input):
    """Round a tensor element-wise while keeping the backward pass usable.

    Backward Pass Differentiable Approximation (BPDA): the returned tensor
    carries the values of ``torch.round(input)``, but it is built as a clone
    of ``input`` whose raw data is overwritten afterwards. Autograd therefore
    only sees the clone, so gradients pass through as if the operation were
    the identity.

    Args:
        input: Tensor to round.

    Returns:
        A new tensor holding the rounded values; ``input`` is not modified.
    """
    passthrough = input.clone()
    # Swap in the rounded values without recording the (non-differentiable)
    # round in the graph that ``passthrough`` belongs to.
    passthrough.data = torch.round(input).data
    return passthrough
def logger_info(logger_name, log_path='default_logger.log'):
    """Configure a named logger with a file handler and a console handler.

    The logger is only configured when ``Logger.hasHandlers()`` reports that
    no handler is available yet (that check also considers ancestor loggers
    while propagation is enabled); otherwise it is left untouched. A short
    status message is printed in either case.

    Args:
        logger_name: Name passed to ``logging.getLogger``.
        log_path: File that log records are appended to.
    """
    target = logging.getLogger(logger_name)

    # Guard clause: never stack a second set of handlers on the same logger.
    if target.hasHandlers():
        print('LogHandlers exist!')
        return

    print('LogHandlers setup!')
    record_format = logging.Formatter(
        '%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')

    to_file = logging.FileHandler(log_path, mode='a')
    to_console = logging.StreamHandler()
    target.setLevel(logging.INFO)

    # File handler first, console handler second; both share one format.
    for handler in (to_file, to_console):
        handler.setFormatter(record_format)
        target.addHandler(handler)
def modcrop(image, modulo):
    """Crop an image so its height and width are multiples of ``modulo``.

    Rows are dropped from the bottom and columns from the right. Only the
    first two axes are cropped; trailing axes (for example a channel axis of
    any size) are kept as they are, so grayscale ``(H, W)``, single-channel
    ``(H, W, 1)``, RGB and RGBA inputs are all accepted.

    Args:
        image: Array with at least two dimensions; the first two axes are
            treated as height and width.
        modulo: Integer that the cropped height and width must be divisible
            by.

    Returns:
        The cropped image, obtained by slicing ``image`` (for NumPy arrays
        this is a view, not a copy).

    Raises:
        NotImplementedError: If ``image`` has fewer than two dimensions.
    """
    if len(image.shape) < 2:
        raise NotImplementedError(
            'modcrop expects an image with at least 2 dimensions')

    size = np.asarray(image.shape[0:2])
    # Largest height/width not exceeding the original that divide by modulo.
    size = size - np.mod(size, modulo)
    # Indexing only the first two axes leaves any trailing axes untouched.
    return image[0:size[0], 0:size[1]]
def _rgb2ycbcr(img, maxVal=255):
O = np.array([[16],
[128],
[128]])
T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
[-0.148223529411765, -0.290992156862745, 0.439215686274510],
[0.439215686274510, -0.367788235294118, -0.071427450980392]])
if maxVal == 1:
O = O / 255.0
t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
t = np.dot(t, np.transpose(T))
t[:, 0] += O[0]
t[:, 1] += O[1]
t[:, 2] += O[2]
ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
return ycbcr
def PSNR(y_true, y_pred, shave_border=4):
    """Compute the peak signal-to-noise ratio (in dB) for a peak value of 255.

    Both inputs are converted to ``float32`` before subtraction, so unsigned
    integer images do not wrap around. When ``shave_border`` is positive,
    that many rows/columns are removed from every side of the difference
    image before the error is averaged.

    Args:
        y_true: Reference image (array-like, at least 2-D when a border is
            shaved).
        y_pred: Image to evaluate, same shape as ``y_true``.
        shave_border: Number of border pixels to ignore on each side; values
            ``<= 0`` disable shaving.

    Returns:
        ``20 * log10(255 / rmse)``. If the compared regions are identical
        (``rmse == 0``) the result is ``inf``. If shaving leaves no pixels,
        the mean over an empty array makes the result ``nan``.
    """
    reference = np.array(y_true, dtype=np.float32)
    estimate = np.array(y_pred, dtype=np.float32)
    diff = estimate - reference
    if shave_border > 0:
        diff = diff[shave_border:-shave_border, shave_border:-shave_border]
    rmse = np.sqrt(np.mean(np.power(diff, 2)))
    if rmse == 0:
        # Identical images: return the limit directly instead of dividing by
        # zero, which would emit a RuntimeWarning on the way to the same inf.
        return float('inf')
    return 20 * np.log10(255. / rmse)
def cal_ssim(img1, img2):
    """Compute the mean structural similarity (SSIM) of two 2-D images.

    Local statistics are gathered with an 11x11 Gaussian window (sigma 1.5)
    applied through ``scipy.signal.convolve2d`` in ``'valid'`` mode, so only
    positions where the window fits entirely inside the image contribute.
    The stabilising constants use ``K = (0.01, 0.03)`` and a dynamic range of
    255.

    Args:
        img1: First image, a 2-D array.
        img2: Second image, same shape as ``img1``.

    Returns:
        The mean of the SSIM map over all valid window positions.
    """
    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    window = gauss_1d * gauss_1d.T  # separable kernel -> 2-D window

    # The sizes are not used; unpacking just requires img1 to be 2-D.
    _rows, _cols = np.shape(img1)

    dynamic_range = 255
    c1 = (0.01 * dynamic_range) ** 2
    c2 = (0.03 * dynamic_range) ** 2

    def _local_mean(values):
        # Gaussian-weighted local average over fully covered positions.
        return signal.convolve2d(values, window, 'valid')

    first = np.float64(img1)
    second = np.float64(img2)

    mean_a = _local_mean(first)
    mean_b = _local_mean(second)
    mean_a_sq = mean_a * mean_a
    mean_b_sq = mean_b * mean_b
    mean_ab = mean_a * mean_b

    # Local (co)variances via E[xy] - E[x]E[y].
    var_a = _local_mean(first * first) - mean_a_sq
    var_b = _local_mean(second * second) - mean_b_sq
    cov_ab = _local_mean(first * second) - mean_ab

    numerator = (2 * mean_ab + c1) * (2 * cov_ab + c2)
    denominator = (mean_a_sq + mean_b_sq + c1) * (var_a + var_b + c2)
    return np.mean(numerator / denominator)