import logging
import os

import cv2
import numpy as np
from scipy import signal
import torch


def pil2numpy(image):
    """Convert an image object to a NumPy array with an explicit channel axis.

    The conversion is done with ``np.array(image)``. A 2-D result (e.g. a
    single-channel image) gets a trailing axis of length 1 so callers can
    always index it as ``[H, W, C]``; results of any other rank are returned
    unchanged.

    Args:
        image: Anything ``np.array`` accepts (presumably a ``PIL.Image`` --
            confirm against callers).

    Returns:
        np.ndarray: The converted array, made 3-D if it was 2-D.
    """
    np_image = np.array(image)
    if np_image.ndim == 2:
        # Add a channel axis so grayscale and color images share one layout.
        np_image = np_image[:, :, None]
    return np_image


def round_func(input):
    """Round ``input`` in the forward pass but pass gradients straight through.

    Backward Pass Differentiable Approximation (BPDA): ``torch.round`` is
    non-differentiable (zero gradient almost everywhere), so the returned
    tensor holds the rounded values while the backward pass behaves like the
    identity function.

    Args:
        input: A ``torch.Tensor``. The parameter name shadows the builtin but
            is kept so keyword callers keep working.

    Returns:
        torch.Tensor: Same shape and dtype as ``input`` holding
        ``torch.round(input)``; its gradient w.r.t. ``input`` is the identity.
    """
    # clone() records an identity op in the autograd graph. Overwriting the
    # clone's values under no_grad keeps that graph (identity backward) while
    # changing the forward value. This replaces the former ``out.data = ...``
    # assignment, which bypasses autograd's bookkeeping and is discouraged.
    out = input.clone()
    with torch.no_grad():
        out.copy_(torch.round(input))
    return out


def logger_info(logger_name, log_path='default_logger.log'):
    """Attach a file handler and a console handler to a named logger, once.

    If the logger already has handlers of its own, nothing is changed (so
    repeated calls do not duplicate output). Otherwise the logger's level is
    set to INFO and two handlers sharing one timestamped format are added: a
    ``FileHandler`` appending to ``log_path`` and a ``StreamHandler``.
    The logger's ``propagate`` setting is not modified.

    Args:
        logger_name: Name passed to ``logging.getLogger``.
        log_path: File that log records are appended to.
    """
    log = logging.getLogger(logger_name)
    # Inspect only this logger's own handlers. ``Logger.hasHandlers()`` also
    # walks ancestor loggers, so a handler on the root logger (e.g. installed
    # by ``logging.basicConfig``) used to skip setup entirely and the log
    # file was never written.
    if log.handlers:
        print('LogHandlers exist!')
        return

    print('LogHandlers setup!')
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')

    file_handler = logging.FileHandler(log_path, mode='a')
    file_handler.setFormatter(formatter)
    log.setLevel(logging.INFO)
    log.addHandler(file_handler)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    log.addHandler(stream_handler)


def modcrop(image, modulo):
    """Crop ``image`` so its height and width are multiples of ``modulo``.

    Excess rows and columns are dropped from the bottom and right edges; any
    channel axis is kept intact.

    Args:
        image: A 2-D ``(H, W)`` or 3-D ``(H, W, C)`` array-like supporting
            ``.shape`` and slicing.
        modulo: Positive integer the cropped height and width must divide by.

    Returns:
        The cropped slice of ``image`` (for NumPy arrays, a view sharing
        memory with the input).

    Raises:
        NotImplementedError: If ``image`` is neither 2-D nor 3-D.
        ValueError: If ``modulo`` is smaller than 1.
    """
    if len(image.shape) not in (2, 3):
        raise NotImplementedError(
            'modcrop expects a 2-D or 3-D image, got shape %s' % (image.shape,))
    # Reject 0 / negatives: modulo 0 used to return the image uncropped with a
    # divide-by-zero warning, and negatives produced meaningless bounds.
    if modulo < 1:
        raise ValueError('modulo must be a positive integer, got %r' % (modulo,))

    height, width = image.shape[:2]
    # Slicing only the first two axes works for both 2-D and 3-D inputs.
    return image[:height - height % modulo, :width - width % modulo]