import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class PercievePattern:
    """Gather a 2x2 pattern of pixels, defined by four offsets, at every pixel.

    The pattern is four ``(row, col)`` offsets inside a square window whose
    side is ``max(offset) + 1``.

    For every pixel ``(i, j)`` of the input, the pixels at
    ``(i + row_k, j + col_k)`` for ``k = 0..3`` are collected. They are
    arranged row-major into a ``1 x 2 x 2`` patch. The bottom and right
    borders are extended by edge replication, so every pixel has a full
    window.
    """

    def __init__(self, receptive_field_idxes=((0, 0), (0, 1), (1, 0), (1, 1))):
        """Build the pattern from four ``(row, col)`` offsets.

        Args:
            receptive_field_idxes: array-like of shape ``(4, 2)`` holding
                non-negative ``(row, col)`` offsets. The default is the
                contiguous 2x2 block. A tuple is used as the default to avoid
                a shared mutable default argument.

        Raises:
            ValueError: if the offsets are not of shape ``(4, 2)`` or any
                offset is negative.
        """
        coords = np.asarray(receptive_field_idxes)
        # The gathered values are reshaped into a 2x2 patch, so exactly four
        # points are required. Extra points used to be silently ignored while
        # still enlarging the window.
        if coords.shape != (4, 2):
            raise ValueError(
                "receptive_field_idxes must contain exactly four (row, col) "
                f"pairs, got array of shape {coords.shape}"
            )
        # A negative offset would produce a negative flat index, which torch
        # interprets as counting from the end of the window.
        if (coords < 0).any():
            raise ValueError("receptive_field_idxes must be non-negative")

        # Plain Python int, so it can be passed directly as a kernel size and
        # pad amount to torch.
        self.window_size = int(coords.max()) + 1
        # Row-major position of each offset inside the unfolded window.
        self.receptive_field_idxes = [
            int(row) * self.window_size + int(col) for row, col in coords
        ]

    def __call__(self, x):
        """Extract the pattern at every pixel of ``x``.

        Args:
            x: 4-D tensor of shape ``(B, C, H, W)``. Each channel is processed
                as an independent single-channel image.

        Returns:
            Tensor of shape ``(B * C * H * W, 1, 2, 2)``. Patches are ordered
            by batch, then channel, then pixel in row-major order.
        """
        b, c, h, w = x.shape
        # The flat window indices only address a single channel of the
        # unfolded tensor. Folding channels into the batch dimension keeps the
        # extra channels from being silently dropped; this is a no-op for C == 1.
        x = x.reshape(b * c, 1, h, w)

        k = self.window_size
        # Pad only right and bottom, so every pixel is the top-left corner of a
        # full k x k window and unfold yields exactly H * W windows.
        x = F.pad(x, pad=[0, k - 1, 0, k - 1], mode='replicate')
        x = F.unfold(input=x, kernel_size=k)            # (B*C, k*k, H*W)
        x = x[:, self.receptive_field_idxes, :]         # (B*C, 4, H*W)
        x = x.transpose(1, 2)                           # (B*C, H*W, 4)
        return x.reshape(x.shape[0] * x.shape[1], 1, 2, 2)