You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
113 lines
3.2 KiB
Python
113 lines
3.2 KiB
Python
import logging
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from scipy import signal
|
|
import torch
|
|
import os
|
|
|
|
class _RoundSTE(torch.autograd.Function):
    """Round in the forward pass; pass the gradient straight through in backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient: rounding is treated as y = x when differentiating.
        return grad_output


def round_func(input):
    """Round ``input`` element-wise with a straight-through (identity) gradient.

    Backward Pass Differentiable Approximation (BPDA): the forward pass returns
    ``torch.round(input)``. The backward pass replaces the non-differentiable
    round function with the identity function, so the incoming gradient is
    passed to ``input`` unchanged.

    This is implemented as a custom ``torch.autograd.Function`` rather than by
    overwriting ``tensor.data``. The ``.data`` trick silently bypasses autograd
    tracking and is discouraged by PyTorch.

    Args:
        input: Tensor to round. (The name shadows the builtin but is kept so
            keyword callers keep working.)

    Returns:
        Tensor holding ``torch.round(input)``. It requires grad iff ``input``
        does, and its gradient w.r.t. ``input`` is the identity.
    """
    return _RoundSTE.apply(input)
|
|
|
|
def logger_info(logger_name, log_path='default_logger.log'):
    """Configure a named logger with INFO-level file and console output, once.

    On the first call for ``logger_name`` this sets the logger's level to INFO.
    It then attaches two handlers sharing one timestamped formatter:
    a ``logging.FileHandler`` appending to ``log_path``, and a
    ``logging.StreamHandler`` (stderr). If the logger already has its own
    handlers, it only prints a notice and leaves them alone, so repeated calls
    do not duplicate output.

    Args:
        logger_name: Name passed to ``logging.getLogger``.
        log_path: File that log records are appended to.
    """
    log = logging.getLogger(logger_name)

    # Check only this logger's *own* handlers. ``Logger.hasHandlers()`` also
    # walks up to ancestor loggers (including root). A handler installed on
    # root (e.g. by ``logging.basicConfig`` or a test runner) would then make
    # us skip setup entirely, leaving this logger without handlers or INFO level.
    if log.handlers:
        print('LogHandlers exist!')
        return

    print('LogHandlers setup!')
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')

    file_handler = logging.FileHandler(log_path, mode='a')
    file_handler.setFormatter(formatter)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)

    log.setLevel(logging.INFO)
    log.addHandler(file_handler)
    log.addHandler(stream_handler)
|
|
|
|
|
|
def modcrop(image, modulo):
    """Crop an image so its height and width are multiples of ``modulo``.

    Rows are removed from the bottom edge and columns from the right edge.
    Any trailing axes (e.g. colour channels) are kept as-is, so grayscale
    ``(H, W)``, RGB ``(H, W, 3)``, RGBA ``(H, W, 4)`` and ``(H, W, 1)``
    inputs are all supported.

    Args:
        image: Array-like with a ``shape``; the first two axes are treated as
            height and width.
        modulo: Positive integer that the cropped height and width must be
            divisible by.

    Returns:
        ``image[:H - H % modulo, :W - W % modulo, ...]``. This is basic slicing,
        so a NumPy input yields a view rather than a copy.

    Raises:
        ValueError: If ``image`` has fewer than two dimensions.
    """
    shape = image.shape
    if len(shape) < 2:
        raise ValueError(
            f"modcrop expects an image with at least 2 dimensions, got shape {shape}")

    height = shape[0] - shape[0] % modulo
    width = shape[1] - shape[1] % modulo
    # The Ellipsis keeps every trailing axis untouched (and matches nothing for 2-D input).
    return image[:height, :width, ...]
|
|
|
|
|
|
def _rgb2ycbcr(img, maxVal=255):
    """Convert RGB pixels to YCbCr with the BT.601 matrix used by MATLAB's ``rgb2ycbcr``.

    Each pixel is mapped by ``ycbcr = T @ rgb + O``, where the offset ``O`` is
    (16, 128, 128) for data in [0, 255]. The transform is affine, so for data
    in [0, maxVal] only the offset needs rescaling, by ``maxVal / 255``; ``T``
    is range-independent. ``maxVal=1`` therefore handles images scaled to [0, 1].

    Args:
        img: Array whose last axis holds the R, G, B channels, e.g. (H, W, 3)
            or (N, 3).
        maxVal: Maximum value of the input data range (255 for 8-bit data,
            1 for [0, 1] data, etc.).

    Returns:
        Float array with the same shape as ``img``; the last axis holds Y, Cb, Cr.

    Raises:
        ValueError: If the last axis of ``img`` does not have length 3.
    """
    O = np.array([[16],
                  [128],
                  [128]])
    T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
                  [-0.148223529411765, -0.290992156862745, 0.439215686274510],
                  [0.439215686274510, -0.367788235294118, -0.071427450980392]])

    if maxVal != 255:
        # Rescale the offsets to the input's data range (exactly O / 255.0 when maxVal == 1).
        O = O * maxVal / 255.0

    shape = np.shape(img)
    if len(shape) == 0 or shape[-1] != 3:
        raise ValueError(f"expected RGB data with a last axis of length 3, got shape {shape}")

    # Flatten to one pixel per row, apply the linear part, then add the per-channel offsets.
    pixels = np.reshape(img, (-1, 3))
    ycbcr = np.dot(pixels, np.transpose(T))
    ycbcr += np.transpose(O)  # (1, 3) broadcasts across all pixel rows
    return np.reshape(ycbcr, shape)
|
|
|
|
|
|
def PSNR(y_true, y_pred, shave_border=4, data_range=255.):
    """Peak signal-to-noise ratio, in dB, between two images.

    Args:
        y_true: Reference image (array-like); converted to float32.
        y_pred: Image to evaluate, same shape as ``y_true``; converted to float32.
        shave_border: Number of pixels dropped from each edge of the first two
            axes before measuring the error; 0 disables shaving.
        data_range: Peak signal value, e.g. 255 for 8-bit data or 1.0 for data
            scaled to [0, 1].

    Returns:
        ``20 * log10(data_range / RMSE)`` over the (shaved) region. Returns
        ``inf`` when that region is identical in both images.
    """
    target_data = np.array(y_true, dtype=np.float32)
    ref_data = np.array(y_pred, dtype=np.float32)

    diff = ref_data - target_data
    if shave_border > 0:
        # Exclude a border of ``shave_border`` pixels on every side.
        diff = diff[shave_border:-shave_border, shave_border:-shave_border]
    rmse = np.sqrt(np.mean(np.power(diff, 2)))

    if rmse == 0:
        # Identical images: PSNR is unbounded. Return it explicitly instead of
        # dividing by zero, which would emit a RuntimeWarning.
        return float('inf')
    return 20 * np.log10(data_range / rmse)
|
|
|
|
|
|
def cal_ssim(img1, img2, data_range=255):
    """Mean structural similarity (SSIM) between two single-channel images.

    Uses the standard SSIM settings: an 11x11 Gaussian window with sigma 1.5
    and stabilizing constants ``C1 = (0.01 * L)**2`` and ``C2 = (0.03 * L)**2``,
    where ``L`` is ``data_range``. Local statistics are computed only where the
    window fits entirely inside the image ('valid' convolution), and the SSIM
    map is averaged.

    Args:
        img1: 2-D image (array-like); converted to float64.
        img2: 2-D image with the same shape as ``img1``.
        data_range: Dynamic range of the pixel values (255 for 8-bit data).

    Returns:
        Mean of the SSIM map (1.0 for identical images).

    Raises:
        ValueError: If the inputs are not 2-D, differ in shape, or are smaller
            than the 11x11 window in either dimension.
    """
    K1, K2 = 0.01, 0.03
    win_size, win_sigma = 11, 1.5
    kernel_1d = cv2.getGaussianKernel(win_size, win_sigma)
    window = kernel_1d * kernel_1d.T  # separable 2-D Gaussian, sums to 1

    img1 = np.float64(img1)
    img2 = np.float64(img2)
    if img1.ndim != 2 or img1.shape != img2.shape:
        raise ValueError(
            f"cal_ssim expects two 2-D images of equal shape, got {img1.shape} and {img2.shape}")
    if min(img1.shape) < win_size:
        # With a smaller image, convolve2d(..., 'valid') would swap its inputs
        # and return a meaningless map instead of failing.
        raise ValueError(
            f"images must be at least {win_size}x{win_size}, got {img1.shape}")

    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    # Local (Gaussian-weighted) means.
    mu1 = signal.convolve2d(img1, window, 'valid')
    mu2 = signal.convolve2d(img2, window, 'valid')

    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance via E[xy] - E[x]E[y].
    sigma1_sq = signal.convolve2d(img1 * img1, window, 'valid') - mu1_sq
    sigma2_sq = signal.convolve2d(img2 * img2, window, 'valid') - mu2_sq
    sigma12 = signal.convolve2d(img1 * img2, window, 'valid') - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return np.mean(ssim_map)
|
|
|