You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
28 lines
1.2 KiB
Python
28 lines
1.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
|
|
class PercievePattern:
    """Gather a fixed four-point neighbourhood around every pixel of an image batch.

    Each pixel (h, w) is treated as the top-left corner of a ``window_size`` x
    ``window_size`` window. The four pixels at the configured offsets inside
    that window are collected into a 2x2 patch. The default offsets select the
    2x2 square anchored at the pixel itself. Offsets that fall past the right
    or bottom image border read replicated border values.

    Args:
        receptive_field_idxes: four ``(row, col)`` non-negative integer offsets.
            They are laid out row-major into the 2x2 output patch: point 0 goes
            to ``[0, 0]``, point 1 to ``[0, 1]``, point 2 to ``[1, 0]`` and
            point 3 to ``[1, 1]``.

    Attributes:
        window_size: side length of the square window that covers all offsets.
        receptive_field_idxes: flattened (``row * window_size + col``) position
            of each offset inside the window, as a list of four Python ints.

    Raises:
        ValueError: if the offsets are not exactly four non-negative integer
            ``(row, col)`` pairs.
    """

    def __init__(self, receptive_field_idxes=((0, 0), (0, 1), (1, 0), (1, 1))):
        # A tuple default avoids the shared-mutable-default pitfall. np.asarray
        # accepts lists and tuples alike, so callers passing lists still work.
        coords = np.asarray(receptive_field_idxes)
        if coords.shape != (4, 2):
            raise ValueError(
                "receptive_field_idxes must contain exactly four (row, col) pairs, "
                f"got array of shape {coords.shape}"
            )
        if not np.issubdtype(coords.dtype, np.integer):
            raise ValueError("receptive_field_idxes must contain integer offsets")
        if (coords < 0).any():
            # Negative flattened indices would silently wrap around in the
            # channel dimension of the unfolded tensor, so reject them here.
            raise ValueError("receptive_field_idxes must be non-negative")

        # A plain Python int, so torch APIs never receive a numpy scalar.
        self.window_size = int(coords.max()) + 1
        self.receptive_field_idxes = [
            int(row) * self.window_size + int(col) for row, col in coords
        ]

    def __call__(self, x):
        """Extract the configured 2x2 pattern at every pixel of every channel.

        Args:
            x: a 4-D tensor of shape ``(B, C, H, W)``.

        Returns:
            A tensor of shape ``(B * C * H * W, 1, 2, 2)``. Patches are ordered
            by batch, then channel, then pixel in row-major order. For
            single-channel input this is one patch per pixel, in (b, h, w) order.

        Raises:
            ValueError: if ``x`` is not 4-D.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (B, C, H, W) tensor, got shape {tuple(x.shape)}"
            )
        b, c, h, w = x.shape
        k = self.window_size

        # Pad only right/bottom so the k x k window at output location (h, w)
        # starts at input pixel (h, w); the unfold therefore yields H*W windows.
        padded = F.pad(x, pad=[0, k - 1, 0, k - 1], mode='replicate')
        windows = F.unfold(input=padded, kernel_size=k)  # (B, C*k*k, H*W)

        # unfold lays channels out channel-major (c*k*k + pos). Splitting that
        # axis lets the offsets be applied to every channel, not just channel 0.
        windows = windows.view(b, c, k * k, h * w)
        picked = windows[:, :, self.receptive_field_idxes, :]  # (B, C, 4, H*W)

        # Move the 4 selected points last so each pixel's values are contiguous
        # before folding them into a 2x2 patch.
        return picked.permute(0, 1, 3, 2).reshape(b * c * h * w, 1, 2, 2)