You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

28 lines
1.2 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class PercievePattern():
    """Gather a fixed 4-pixel pattern around every pixel of a 4-D tensor.

    For each spatial position ``(i, j)`` of each channel, the values at
    ``(i + dr, j + dc)`` for each ``(dr, dc)`` offset in
    ``receptive_field_idxes`` are collected and returned as a 2x2 patch.
    Positions that fall past the bottom/right border take the value of the
    nearest edge pixel (replicate padding).
    """

    def __init__(self, receptive_field_idxes=None):
        """Store the sampling pattern.

        Args:
            receptive_field_idxes: four ``[row, col]`` offsets (non-negative
                ints) relative to the anchor pixel. Defaults to the 2x2
                square ``[[0, 0], [0, 1], [1, 0], [1, 1]]``.

        Raises:
            ValueError: if the offsets are not a 4x2 array of non-negative
                values.
        """
        # None sentinel instead of a shared mutable list default.
        if receptive_field_idxes is None:
            receptive_field_idxes = [[0, 0], [0, 1], [1, 0], [1, 1]]
        offsets = np.asarray(receptive_field_idxes)
        # __call__ reshapes the gathered values into 2x2 patches, so exactly
        # four (row, col) pairs are required.
        if offsets.shape != (4, 2):
            raise ValueError(
                "receptive_field_idxes must contain exactly 4 [row, col] "
                "pairs, got array of shape %s" % (offsets.shape,))
        # Padding is only applied to the bottom/right, so negative offsets
        # would index outside the unfolded window.
        if (offsets < 0).any():
            raise ValueError("receptive_field_idxes must be non-negative")
        # Smallest square window that contains every offset.
        self.window_size = int(offsets.max()) + 1
        # Flatten each (row, col) offset into its row-major position inside
        # the window_size x window_size patch produced by F.unfold.
        self.receptive_field_idxes = [
            int(row) * self.window_size + int(col) for row, col in offsets
        ]

    def __call__(self, x):
        """Extract the pattern at every pixel.

        Args:
            x: tensor of shape ``(b, c, h, w)``. Replicate padding is used,
                so a floating-point dtype is presumably expected.

        Returns:
            Tensor of shape ``(b * c * h * w, 1, 2, 2)``, ordered by batch,
            channel, row, column. Each channel is processed independently.
        """
        b, c, h, w = x.shape
        k = self.window_size
        # Fold channels into the batch so every channel is sampled; unfold
        # would otherwise interleave channels and only channel 0's window
        # positions would be addressed by receptive_field_idxes.
        x = x.reshape(b * c, 1, h, w)
        # Pad right and bottom by k-1 so the unfold below yields exactly one
        # window per original pixel (h * w windows).
        x = F.pad(x, pad=[0, k - 1, 0, k - 1], mode='replicate')
        x = F.unfold(x, kernel_size=k)                 # (b*c, k*k, h*w)
        x = x[:, self.receptive_field_idxes, :]        # (b*c, 4, h*w)
        x = x.permute(0, 2, 1)                         # (b*c, h*w, 4)
        return x.reshape(b * c * h * w, 1, 2, 2)