import os 
from pathlib import Path
import numpy as np
import cv2 
from scipy import signal
from skimage.metrics import structural_similarity
from PIL import Image
import argparse

import time
from datetime import datetime
import ray
parser = argparse.ArgumentParser(
    description="Benchmark bicubic upscaling (OpenCV vs Pillow) against HR ground truth.",
)
parser.add_argument(
    "path_to_dataset",
    type=str,
    help="Dataset root containing HR/ and LR_bicubic/X<scale>/ sub-directories.",
)
parser.add_argument("--scale", type=int, default=4, help="Upscaling factor (default: 4).")
parser.add_argument(
    "--num-cpus",
    type=int,
    default=16,
    help="Number of CPUs Ray may use for parallel evaluation (default: 16).",
)
args = parser.parse_args()
if args.scale < 1:
    parser.error("--scale must be a positive integer")

# Start Ray only after the arguments are parsed, so `--help` or invalid
# arguments exit immediately instead of first spinning up a local cluster.
ray.init(num_cpus=args.num_cpus, num_gpus=0, ignore_reinit_error=True, log_to_driver=False)

def cal_ssim(img1, img2, K=(0.01, 0.03), L=255, win_size=11, sigma=1.5):
    """Return the mean structural similarity (SSIM) of two single-channel images.

    Local means, variances and covariance are taken under a ``win_size`` x
    ``win_size`` Gaussian window (standard deviation ``sigma``) using 'valid'
    2-D convolution, so only positions where the window lies entirely inside
    the image contribute. The resulting per-pixel SSIM map is averaged.

    Args:
        img1, img2: 2-D arrays of identical shape (e.g. luma channels).
        K: Stabilising constants ``(K1, K2)``; ``C_i = (K_i * L) ** 2``.
        L: Dynamic range of the pixel values (255 for 8-bit data).
        win_size: Side length of the Gaussian window.
        sigma: Standard deviation of the Gaussian window.

    Returns:
        The mean SSIM value.

    Raises:
        ValueError: If the inputs are not 2-D, differ in shape, or are smaller
            than the window in either dimension.
    """
    img1 = np.float64(img1)
    img2 = np.float64(img2)
    if img1.ndim != 2 or img2.ndim != 2:
        raise ValueError(
            f"cal_ssim expects 2-D images, got shapes {np.shape(img1)} and {np.shape(img2)}"
        )
    if img1.shape != img2.shape:
        raise ValueError(f"cal_ssim shape mismatch: {img1.shape} vs {img2.shape}")
    if min(img1.shape) < win_size:
        # 'valid' convolution is not meaningful when the window does not fit
        # inside the image.
        raise ValueError(
            f"image {img1.shape} is smaller than the {win_size}x{win_size} SSIM window"
        )

    # The 1-D kernel from getGaussianKernel sums to 1, so its outer product is a
    # separable 2-D Gaussian window that also sums to 1.
    kernel_1d = cv2.getGaussianKernel(win_size, sigma)
    window = kernel_1d * kernel_1d.T

    C1 = (K[0] * L) ** 2
    C2 = (K[1] * L) ** 2

    def _local_mean(x):
        # Gaussian-weighted local average over window positions fully inside the image.
        return signal.convolve2d(x, window, 'valid')

    mu1 = _local_mean(img1)
    mu2 = _local_mean(img2)

    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2

    # Var(X) = E[X^2] - E[X]^2 and Cov(X, Y) = E[XY] - E[X]E[Y], under the window.
    sigma1_sq = _local_mean(img1 * img1) - mu1_sq
    sigma2_sq = _local_mean(img2 * img2) - mu2_sq
    sigma12 = _local_mean(img1 * img2) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )
    return np.mean(ssim_map)

def PSNR(y_true, y_pred, shave_border=4, max_val=255.0):
    """Return the peak signal-to-noise ratio (dB) between two images.

    A frame of ``shave_border`` pixels is removed from the first two axes
    before the error is measured; any further axes (e.g. colour channels)
    are kept and included in the mean.

    Args:
        y_true: Reference image (array-like).
        y_pred: Image to score; must have the same shape as ``y_true``.
        shave_border: Pixels to drop from each edge of the first two axes.
        max_val: Peak possible pixel value (255 for 8-bit data).

    Returns:
        PSNR in decibels as a float; ``inf`` when the (shaved) images are identical.

    Raises:
        ValueError: If the shapes differ or the shaving leaves no pixels.
    """
    target_data = np.asarray(y_true, dtype=np.float64)
    ref_data = np.asarray(y_pred, dtype=np.float64)
    if target_data.shape != ref_data.shape:
        # Prevent silent broadcasting between differently-shaped images.
        raise ValueError(f"PSNR shape mismatch: {target_data.shape} vs {ref_data.shape}")

    diff = ref_data - target_data
    if shave_border > 0:
        diff = diff[shave_border:-shave_border, shave_border:-shave_border]
    if diff.size == 0:
        raise ValueError(f"shave_border={shave_border} removes every pixel of a {target_data.shape} image")

    mse = np.mean(np.square(diff))
    if mse == 0:
        # Perfect reconstruction: PSNR is unbounded; avoid a divide-by-zero warning.
        return float("inf")
    return float(20 * np.log10(max_val / np.sqrt(mse)))

def _rgb2ycbcr(img, maxVal=255):
    O = np.array([[16],
                  [128],
                  [128]])
    T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
                  [-0.148223529411765, -0.290992156862745, 0.439215686274510],
                  [0.439215686274510, -0.367788235294118, -0.071427450980392]])

    if maxVal == 1:
        O = O / 255.0

    t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
    t = np.dot(t, np.transpose(T))
    t[:, 0] += O[0]
    t[:, 1] += O[1]
    t[:, 2] += O[2]
    ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])

    return ycbcr


def modcrop(image, modulo):
    """Crop ``image`` so that its height and width are multiples of ``modulo``.

    Rows are removed from the bottom and columns from the right; any trailing
    axes (e.g. colour channels, of any count) are kept intact. The result is a
    view of ``image``, not a copy.

    Args:
        image: Array with at least two dimensions, laid out as (H, W, ...).
        modulo: Positive integer the cropped height and width must divide by.

    Returns:
        The cropped view.

    Raises:
        ValueError: If ``modulo`` is less than 1.
        NotImplementedError: If ``image`` has fewer than two dimensions.
    """
    if modulo < 1:
        raise ValueError(f"modulo must be a positive integer, got {modulo!r}")
    if image.ndim < 2:
        raise NotImplementedError(f"modcrop needs an (H, W, ...) array, got shape {image.shape}")

    height = image.shape[0] - image.shape[0] % modulo
    width = image.shape[1] - image.shape[1] % modulo
    # The Ellipsis keeps every remaining axis untouched (and is a no-op for 2-D input).
    return image[:height, :width, ...]

scale = args.scale

dataset_path = Path(args.path_to_dataset)
hr_path = dataset_path / "HR/"
lr_path = dataset_path / f"LR_bicubic/X{scale}/"


print(hr_path, lr_path)


def _list_images(directory):
    """Return the non-hidden entries of ``directory`` in sorted order."""
    # Dotfiles (e.g. .DS_Store) are not images and would break positional pairing.
    return sorted(name for name in os.listdir(directory) if not name.startswith("."))


# os.listdir() order is arbitrary and filesystem-dependent, but HR and LR
# images are later paired by position. Sorting both listings lines them up,
# assuming HR and LR file names sort in the same order
# (e.g. 0001.png <-> 0001x4.png) — confirm for your dataset.
hr_files = _list_images(hr_path)
lr_files = _list_images(lr_path)

if len(hr_files) != len(lr_files):
    # zip() would silently truncate and misalign the pairs; fail loudly instead.
    raise SystemExit(
        f"HR/LR file count mismatch: {len(hr_files)} HR vs {len(lr_files)} LR images"
    )

@ray.remote(num_cpus=1)
def benchmark_image_pair(hr_image_path, lr_image_path, interpolation_function):
    """Upscale one LR image and score it against its HR ground truth (Ray task).

    The HR image is cropped with ``modcrop`` to a multiple of the module-level
    ``scale``. ``interpolation_function(lr_rgb, scale)`` is timed with a
    monotonic clock, and the result is compared on luma (Y) and on all RGB
    channels.

    Luma comes from two sources: ``_rgb2ycbcr`` (studio-range Y) and OpenCV's
    ``COLOR_RGB2YCrCb`` (full-range Y). Note that ``PSNR`` shaves a 4-pixel
    border by default, whereas ``cv2.PSNR`` scores every pixel.

    Returns:
        ``(ssim, cv2_ssim, ssim_scikit, cv2_scikit_ssim, psnr, cpsnr,
        cv2_psnr, cv2_cpsnr, seconds)``. The ``cv2_*`` metrics use OpenCV luma
        or OpenCV's PSNR, ``cpsnr``/``cv2_cpsnr`` are computed on RGB, and
        ``seconds`` is the interpolation time.

    Raises:
        FileNotFoundError: If either image cannot be read.
        ValueError: If the upscaled image does not match the cropped HR shape.
    """
    hr_image = cv2.imread(hr_image_path)
    lr_image = cv2.imread(lr_image_path)
    # cv2.imread signals failure by returning None rather than raising.
    if hr_image is None:
        raise FileNotFoundError(f"Could not read HR image: {hr_image_path}")
    if lr_image is None:
        raise FileNotFoundError(f"Could not read LR image: {lr_image_path}")

    hr_image = hr_image[:, :, ::-1]  # BGR -> RGB
    lr_image = lr_image[:, :, ::-1]  # BGR -> RGB

    # perf_counter is monotonic and high-resolution, unlike wall-clock datetime.now().
    start = time.perf_counter()
    upscaled_lr_image = interpolation_function(lr_image, scale)
    processing_time = time.perf_counter() - start

    # HR dimensions need not be multiples of scale; crop so they match LR * scale.
    hr_image = modcrop(hr_image, scale)
    if upscaled_lr_image.shape != hr_image.shape:
        raise ValueError(
            f"Upscaled {lr_image_path} has shape {upscaled_lr_image.shape}, "
            f"expected {hr_image.shape} to match {hr_image_path}"
        )

    # Compute each luma channel once and reuse it for PSNR and both SSIM variants.
    hr_y = _rgb2ycbcr(hr_image)[:, :, 0]
    up_y = _rgb2ycbcr(upscaled_lr_image)[:, :, 0]
    hr_y_cv2 = cv2.cvtColor(hr_image, cv2.COLOR_RGB2YCrCb)[:, :, 0]
    up_y_cv2 = cv2.cvtColor(upscaled_lr_image, cv2.COLOR_RGB2YCrCb)[:, :, 0]

    psnr = PSNR(hr_y, up_y)
    cpsnr = PSNR(hr_image, upscaled_lr_image)
    cv2_psnr = cv2.PSNR(hr_y_cv2, up_y_cv2)
    cv2_cpsnr = cv2.PSNR(hr_image, upscaled_lr_image)

    ssim = cal_ssim(hr_y, up_y)
    cv2_ssim = cal_ssim(hr_y_cv2, up_y_cv2)
    # Only the mean SSIM is needed, so the full similarity map is not requested.
    ssim_scikit = structural_similarity(hr_y, up_y, data_range=255.0)
    cv2_scikit_ssim = structural_similarity(hr_y_cv2, up_y_cv2, data_range=255.0)

    return (ssim, cv2_ssim, ssim_scikit, cv2_scikit_ssim,
            psnr, cpsnr, cv2_psnr, cv2_cpsnr, processing_time)
    

def benchmark_interpolation(interpolation_function):
    """Score ``interpolation_function`` on every HR/LR pair and print averaged metrics.

    Dispatches one ``benchmark_image_pair`` Ray task per (HR, LR) file pair
    from the module-level ``hr_files``/``lr_files`` listings and shows a
    progress counter. It then prints human-readable averages and one CSV line
    containing the same metrics; the last CSV column is the 90th-percentile
    processing time in seconds.
    """
    tasks = [
        benchmark_image_pair.remote(
            str(hr_path / hr_name), str(lr_path / lr_name), interpolation_function
        )
        for hr_name, lr_name in zip(hr_files, lr_files)
    ]
    if not tasks:
        # ray.wait rejects num_returns greater than the number of refs.
        print("No image pairs found; nothing to benchmark.")
        return

    # Wait only on the still-pending refs, one completion at a time, so each
    # call is cheap and the counter reaches the full total.
    pending = tasks
    finished = 0
    while pending:
        ready, pending = ray.wait(pending, num_returns=1, timeout=None)
        finished += len(ready)
        print(f"\rReady {finished}/{len(tasks)}", end="     ")
    print()

    # ray.get on the original list returns results in task (file) order.
    results = ray.get(tasks)
    (ssims, cv2_ssims, scikit_ssims, cv2_scikit_ssims,
     psnrs, cpsnrs, cv2_psnrs, cv2_cpsnrs, processing_times) = zip(*results)

    # np.percentile takes q as a percentage in [0, 100]; 90 is the 90th percentile.
    p90_time = np.percentile(processing_times, q=90)

    print(f"AVG PSNR: {np.mean(psnrs):.2f} PSNR + _rgb2ycbcr")
    print(f"AVG PSNR: {np.mean(cv2_psnrs):.2f} cv2.PSNR + cv2.cvtColor")
    print(f"AVG cPSNR: {np.mean(cpsnrs):.2f} PSNR")
    print(f"AVG cPSNR: {np.mean(cv2_cpsnrs):.2f} cv2.PSNR ")
    print(f"AVG SSIM: {np.mean(ssims):.4f} cal_ssim + _rgb2ycbcr")
    print(f"AVG SSIM: {np.mean(cv2_ssims):.4f} cal_ssim + cv2.cvtColor")
    print(f"AVG SSIM: {np.mean(scikit_ssims):.4f} structural_similarity + _rgb2ycbcr")
    print(f"AVG SSIM: {np.mean(cv2_scikit_ssims):.4f} structural_similarity + cv2.cvtColor")
    print(f"AVG Time s: {np.mean(processing_times)}")
    print(f"P90 Time s: {p90_time}")
    print(f"{np.mean(psnrs):.2f},{np.mean(cv2_psnrs):.2f},{np.mean(cpsnrs):.2f},{np.mean(cv2_cpsnrs):.2f},{np.mean(ssims):.4f},{np.mean(cv2_ssims):.4f},{np.mean(scikit_ssims):.4f},{np.mean(cv2_scikit_ssims):.4f},{p90_time}")

def cv2_interpolation(image, scale):
    """Upscale ``image`` by ``scale`` along both axes with OpenCV's INTER_CUBIC resize."""
    # dsize=None makes OpenCV derive the output size from the fx/fy factors.
    return cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)

def pillow_interpolation(image, scale):
    """Upscale ``image`` by ``scale`` with Pillow's bicubic filter.

    The channel axis is reversed before the array is handed to Pillow and
    reversed again on the way out, so the result has the input's channel order.
    """
    pil_image = Image.fromarray(image[:, :, ::-1])
    target_size = (int(pil_image.width * scale), int(pil_image.height * scale))
    resized = pil_image.resize(target_size, resample=Image.Resampling.BICUBIC)
    upscaled = np.array(resized)
    return upscaled[:, :, ::-1]

# Run each interpolation backend in turn, separating their reports with a blank line.
for run_index, (run_title, run_interpolation) in enumerate((
    ("cv2 bicubic interpolation", cv2_interpolation),
    ("pillow bicubic interpolation", pillow_interpolation),
)):
    if run_index:
        print()
    print(run_title)
    benchmark_interpolation(run_interpolation)