mirror of https://github.com/da0c/DL_Course_SamU
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
122 lines
4.9 KiB
Cython
122 lines
4.9 KiB
Cython
4 years ago
|
import numpy as np
|
||
|
cimport numpy as np
|
||
|
cimport cython
|
||
|
|
||
|
# Element type shared by the functions below: every DTYPE_t-typed function is
# specialized by Cython for both float32 and float64 arrays, so callers do not
# need to cast their data to a single precision first.
ctypedef fused DTYPE_t:
    np.float32_t
    np.float64_t
|
||
|
|
||
|
def im2col_cython(np.ndarray[DTYPE_t, ndim=4] x, int field_height,
                  int field_width, int padding, int stride):
    """Unfold the sliding windows of ``x`` into the columns of a 2-D array.

    Parameters
    ----------
    x : ndarray of DTYPE_t (float32 or float64), ndim=4
        Input, indexed as ``x[n, c, h, w]`` with shape (N, C, H, W).
    field_height, field_width : int
        Height and width of the sliding window.
    padding : int
        Number of zeros added before and after each of the last two axes.
    stride : int
        Step between consecutive window positions.

    Returns
    -------
    cols : ndarray, shape (C * field_height * field_width, N * HH * WW)
        Same dtype as ``x``; filled by ``im2col_cython_inner``. HH and WW are
        the number of window positions along the height and width axes.
    """
    cdef int N = x.shape[0]
    cdef int C = x.shape[1]
    cdef int H = x.shape[2]
    cdef int W = x.shape[3]

    # Number of window positions per spatial axis. Use floor division so the
    # result is computed in integers on every Cython language_level (with
    # language_level=3, ``/`` on C ints is true division and yields a double).
    cdef int HH = (H + 2 * padding - field_height) // stride + 1
    cdef int WW = (W + 2 * padding - field_width) // stride + 1

    cdef int p = padding
    cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.pad(
        x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')

    cdef np.ndarray[DTYPE_t, ndim=2] cols = np.zeros(
        (C * field_height * field_width, N * HH * WW), dtype=x.dtype)

    # Moving the inner loop to a C function with no bounds checking works, but
    # does not seem to help performance in any measurable way.
    im2col_cython_inner(cols, x_padded, N, C, H, W, HH, WW,
                        field_height, field_width, padding, stride)
    return cols
|
||
|
|
||
|
|
||
|
@cython.boundscheck(False)
@cython.wraparound(False)
cdef int im2col_cython_inner(np.ndarray[DTYPE_t, ndim=2] cols,
                             np.ndarray[DTYPE_t, ndim=4] x_padded,
                             int N, int C, int H, int W, int HH, int WW,
                             int field_height, int field_width,
                             int padding, int stride) except? -1:
    # Copy every (field_height x field_width) window of x_padded into cols:
    #
    #   cols[row, col] = x_padded[i, c, stride * yy + ii, stride * xx + jj]
    #   row = (c * field_height + ii) * field_width + jj
    #   col = (yy * WW + xx) * N + i
    #
    # Bounds checking and wraparound are disabled (all indices here are
    # non-negative), so the caller must allocate cols with shape
    # (C * field_height * field_width, N * HH * WW) and pass a padded input
    # that contains every window. H, W and padding are not used here.
    cdef int c, ii, jj, row, yy, xx, i, col

    for c in range(C):
        for yy in range(HH):
            for xx in range(WW):
                for ii in range(field_height):
                    for jj in range(field_width):
                        # jj runs fastest, so one window row spans
                        # field_width consecutive rows of cols. Scaling ii by
                        # field_height instead would only be correct for square
                        # windows and would index past the end of cols otherwise.
                        row = c * field_width * field_height + ii * field_width + jj
                        for i in range(N):
                            col = yy * WW * N + xx * N + i
                            cols[row, col] = x_padded[i, c, stride * yy + ii, stride * xx + jj]
|
||
|
|
||
|
|
||
|
|
||
|
def col2im_cython(np.ndarray[DTYPE_t, ndim=2] cols, int N, int C, int H, int W,
                  int field_height, int field_width, int padding, int stride):
    """Fold a column matrix back into an (N, C, H, W) array.

    Contributions from overlapping windows are accumulated (``+=``) by
    ``col2im_cython_inner`` into a zero-initialized, zero-padded buffer.

    Parameters
    ----------
    cols : ndarray of DTYPE_t (float32 or float64), ndim=2
        Column matrix; its dtype determines the output dtype.
    N, C, H, W : int
        Shape of the (unpadded) output.
    field_height, field_width : int
        Height and width of the sliding window.
    padding : int
        Zero padding that was applied on each side of the spatial axes.
    stride : int
        Step between consecutive window positions.

    Returns
    -------
    ndarray, shape (N, C, H, W)
        When ``padding > 0`` this is a (non-contiguous) view into the larger
        padded buffer with the border stripped off.
    """
    # Number of window positions per spatial axis (integer floor division).
    cdef int HH = (H + 2 * padding - field_height) // stride + 1
    cdef int WW = (W + 2 * padding - field_width) // stride + 1

    # Accumulation target; must start at zero because the inner loop uses +=.
    cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.zeros(
        (N, C, H + 2 * padding, W + 2 * padding), dtype=cols.dtype)

    # Moving the inner loop to a C-function with no bounds checking improves
    # performance quite a bit for col2im.
    col2im_cython_inner(cols, x_padded, N, C, H, W, HH, WW,
                        field_height, field_width, padding, stride)
    if padding > 0:
        return x_padded[:, :, padding:-padding, padding:-padding]
    return x_padded
|
||
|
|
||
|
|
||
|
@cython.boundscheck(False)
@cython.wraparound(False)
cdef int col2im_cython_inner(np.ndarray[DTYPE_t, ndim=2] cols,
                             np.ndarray[DTYPE_t, ndim=4] x_padded,
                             int N, int C, int H, int W, int HH, int WW,
                             int field_height, int field_width,
                             int padding, int stride) except? -1:
    # Scatter-add cols back into x_padded (the inverse layout of im2col):
    #
    #   x_padded[i, c, stride * yy + ii, stride * xx + jj] += cols[row, col]
    #   row = (c * field_height + ii) * field_width + jj
    #   col = (yy * WW + xx) * N + i
    #
    # Overlapping windows accumulate, so x_padded should be zero-initialized
    # by the caller. Bounds checking and wraparound are disabled (all indices
    # are non-negative); the caller must size both arrays consistently.
    # H, W and padding are not used here.
    cdef int c, ii, jj, row, yy, xx, i, col

    for c in range(C):
        for ii in range(field_height):
            for jj in range(field_width):
                # jj runs fastest, so one window row spans field_width
                # consecutive rows of cols. Scaling ii by field_height instead
                # would only be correct for square windows and would read past
                # the end of cols otherwise.
                row = c * field_width * field_height + ii * field_width + jj
                for yy in range(HH):
                    for xx in range(WW):
                        for i in range(N):
                            col = yy * WW * N + xx * N + i
                            x_padded[i, c, stride * yy + ii, stride * xx + jj] += cols[row, col]
|
||
|
|
||
|
|
||
|
@cython.wraparound(False)
@cython.boundscheck(False)
cdef col2im_6d_cython_inner(np.ndarray[DTYPE_t, ndim=6] cols,
                            np.ndarray[DTYPE_t, ndim=4] x_padded,
                            int N, int C, int H, int W, int HH, int WW,
                            int out_h, int out_w, int pad, int stride):
    # Scatter-add a 6-d column tensor into a padded image batch:
    #
    #   x_padded[n, c, stride * oy + fy, stride * ox + fx] += cols[c, fy, fx, n, oy, ox]
    #
    # fy/fx walk the HH x WW window, oy/ox walk the out_h x out_w positions.
    # Overlapping windows accumulate into the same element. Bounds checking
    # and negative-index wraparound are disabled, so the caller must size both
    # arrays consistently. H, W and pad are not used here.
    cdef int n, c, fy, fx, oy, ox, dst_row

    for n in range(N):
        for c in range(C):
            for fy in range(HH):
                for fx in range(WW):
                    for oy in range(out_h):
                        dst_row = stride * oy + fy
                        for ox in range(out_w):
                            x_padded[n, c, dst_row, stride * ox + fx] += cols[c, fy, fx, n, oy, ox]
|
||
|
|
||
|
|
||
|
def col2im_6d_cython(np.ndarray[DTYPE_t, ndim=6] cols, int N, int C, int H, int W,
                     int HH, int WW, int pad, int stride):
    """Fold a 6-d column tensor back into an (N, C, H, W) array.

    ``cols`` is read as ``cols[c, hh, ww, n, h, w]`` by
    ``col2im_6d_cython_inner``, where (hh, ww) indexes the HH x WW window and
    (h, w) the output position. Overlapping windows are accumulated into a
    zero-initialized, zero-padded buffer.

    Parameters
    ----------
    cols : ndarray of DTYPE_t (float32 or float64), ndim=6
        Column tensor; its dtype determines the output dtype.
    N, C, H, W : int
        Shape of the (unpadded) output.
    HH, WW : int
        Window height and width.
    pad : int
        Zero padding that was applied on each side of the spatial axes.
    stride : int
        Step between consecutive window positions.

    Returns
    -------
    ndarray, shape (N, C, H, W)
        When ``pad > 0`` this is a (non-contiguous) view into the larger
        padded buffer with the border stripped off.
    """
    # Number of window positions per spatial axis (integer floor division).
    cdef int out_h = (H + 2 * pad - HH) // stride + 1
    cdef int out_w = (W + 2 * pad - WW) // stride + 1

    # Accumulation target; must start at zero because the inner loop uses +=.
    cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.zeros(
        (N, C, H + 2 * pad, W + 2 * pad), dtype=cols.dtype)

    col2im_6d_cython_inner(cols, x_padded, N, C, H, W, HH, WW, out_h, out_w, pad, stride)

    if pad > 0:
        return x_padded[:, :, pad:-pad, pad:-pad]
    return x_padded
|