You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
55 lines
1.5 KiB
Python
55 lines
1.5 KiB
Python
import logging
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from scipy import signal
|
|
import torch
|
|
import os
|
|
|
|
|
|
def pil2numpy(image):
    """Convert ``image`` to a NumPy array, appending a trailing axis if 2-D.

    Args:
        image: anything ``np.array`` accepts (e.g. a PIL image or nested
            lists). A copy is always made.

    Returns:
        The converted array. A 2-D result gets a size-1 third axis appended,
        so grayscale input comes back as ``(H, W, 1)``. Arrays of any other
        rank are returned unchanged.
    """
    arr = np.array(image)
    # Only rank-2 input is expanded; everything else passes through as-is.
    return np.expand_dims(arr, axis=2) if arr.ndim == 2 else arr
|
|
|
|
class _RoundSTE(torch.autograd.Function):
    """Autograd op: ``torch.round`` forward, identity gradient backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pretend round() was the identity function.
        return grad_output


def round_func(input):
    """Round ``input`` element-wise, passing gradients straight through.

    Backward Pass Differentiable Approximation (BPDA) / straight-through
    estimator. ``round`` has zero gradient almost everywhere, so it is
    replaced by the identity function during the backward pass only.

    Args:
        input: tensor to round.

    Returns:
        A tensor whose forward values equal ``torch.round(input)``. When
        ``input`` requires grad, the gradient flowing back to ``input`` is
        the incoming gradient unchanged.
    """
    # A custom autograd.Function is the documented way to give an op a
    # surrogate gradient. It avoids writing through ``.data``, which
    # autograd does not track.
    return _RoundSTE.apply(input)
|
|
|
|
def logger_info(logger_name, log_path='default_logger.log'):
    """Attach a file handler and a console handler to a named logger, once.

    Args:
        logger_name: name passed to ``logging.getLogger``.
        log_path: log file path, opened in append mode.

    Behavior:
        If the logger already has its own handlers, nothing is added (only a
        notice is printed), so repeated calls do not duplicate output.
        Otherwise the logger's level is set to INFO, and a FileHandler for
        ``log_path`` plus a StreamHandler are added. Both use the format
        ``yy-mm-dd HH:MM:SS.mmm : message``.
    """
    log = logging.getLogger(logger_name)
    # Inspect only this logger's own handlers. ``Logger.hasHandlers()`` also
    # reports handlers on ancestor loggers (e.g. root after
    # ``logging.basicConfig``), which would wrongly skip the file handler
    # setup and leave ``log_path`` empty.
    if log.handlers:
        print('LogHandlers exist!')
        return

    print('LogHandlers setup!')
    level = logging.INFO
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')

    fh = logging.FileHandler(log_path, mode='a')
    fh.setFormatter(formatter)
    log.setLevel(level)
    log.addHandler(fh)

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    log.addHandler(sh)
|
|
|
|
|
|
def modcrop(image, modulo):
    """Crop ``image`` so its first two dimensions are multiples of ``modulo``.

    Rows and columns are trimmed from the end of axes 0 and 1 (typically
    height and width). Any third axis is kept whole.

    Args:
        image: array-like with a ``shape`` of length 2 or 3.
        modulo: divisor that the retained sizes of axes 0 and 1 must be
            multiples of.

    Returns:
        ``image`` sliced to the cropped extent (a view for NumPy arrays).

    Raises:
        NotImplementedError: if ``image`` has neither 2 nor 3 dimensions.
    """
    if len(image.shape) not in (2, 3):
        raise NotImplementedError

    spatial = image.shape[:2]
    # np.mod (rather than %) is used so the element-wise semantics, including
    # odd divisors, match across both supported ranks.
    kept_rows, kept_cols = spatial - np.mod(spatial, modulo)
    return image[0:kept_rows, 0:kept_cols]