Updated: 2023-12-01 22:19:22
The kernel for a 2D convolution has the following shape:
[ height, width, input_filters, output_filters ]
The third dimension has the same size as the number of input filters. This is critical.
Let's consider how convolution is done manually. Here are the steps:
The output is the filters for each patch.
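As a rough illustration of doing the convolution by hand, here is a minimal NumPy sketch. It assumes stride 1 and no padding, and `manual_conv2d` is a hypothetical helper written for this example, not a library function. Like TensorFlow's convolution, it does not flip the kernel.

```python
import numpy as np


def manual_conv2d(x, kernel):
    """Naive convolution with stride 1 and no padding (illustrative helper only).

    x:      [batch, in_height, in_width, input_filters]
    kernel: [height, width, input_filters, output_filters]
    """
    batch, in_h, in_w, in_c = x.shape
    k_h, k_w, k_c, out_c = kernel.shape
    # The kernel's third dimension has to match the input filters
    assert in_c == k_c, "kernel's third dimension must match the input filters"

    out = np.zeros((batch, in_h - k_h + 1, in_w - k_w + 1, out_c), dtype=x.dtype)
    for i in range(out.shape[1]):
        for j in range(out.shape[2]):
            # Patch of shape [batch, height, width, input_filters]
            patch = x[:, i:i + k_h, j:j + k_w, :]
            # Multiply by the kernel and sum over height, width and input filters,
            # leaving one value per output filter for this patch
            out[:, i, j, :] = np.tensordot(patch, kernel, axes=([1, 2, 3], [0, 1, 2]))
    return out


x = np.ones((1, 4, 4, 3), dtype=np.float32)
kernel = np.ones((3, 3, 3, 2), dtype=np.float32)
result = manual_conv2d(x, kernel)
print(result.shape)  # (1, 2, 2, 2)
```

With all-ones input and kernel, every output value is 3 * 3 * 3 = 27, because each patch sums over height, width and all three input filters.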
Given that the weights in the convolution are shaped
like [ height, width, input_filters, output_filters ]
and we want to apply a mask of shape [ height, width ],
we can just broadcast that mask like so:
masked_weight = weight * mask.reshape([height,width,1,1])
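The broadcast can be checked in isolation with plain NumPy. This is a small sketch; the sizes and the mask values are example choices, not requirements.

```python
import numpy as np

height, width, input_filters, output_filters = 3, 3, 3, 2
weight = np.ones((height, width, input_filters, output_filters), dtype=np.float32)
mask = np.array([[1, 1, 0],
                 [1, 1, 1],
                 [0, 1, 1]], dtype=np.float32)

# The two trailing singleton axes let the [height, width] mask broadcast
# across every input filter and every output filter
masked_weight = weight * mask.reshape([height, width, 1, 1])

print(masked_weight.shape)  # (3, 3, 3, 2)
# Every [height, width] slice of the masked weights equals the mask itself
print(np.array_equal(masked_weight[:, :, 2, 1], mask))  # True
```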
Our TensorFlow Keras layer could be written like so:
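import functools

import numpy as np
import tensorflow as tf

class MaskedConv2D(tf.keras.layers.Layer):
    def __init__(self, *args, **kwargs):
        super(MaskedConv2D, self).__init__()
        self.conv2d = tf.keras.layers.Conv2D(*args, **kwargs)

    def build(self, input_shape):
        # The layer is called with (x, mask), so input_shape[0] is the shape of x
        self.conv2d.build(input_shape[0])
        # Note: _convolution_op is a private Keras attribute and may not exist in every TensorFlow version
        self._convolution_op = self.conv2d._convolution_op

    def masked_convolution_op(self, inputs, kernel, mask):
        # Broadcast the [height, width] mask over the input and output filter axes
        return self._convolution_op(inputs, tf.math.multiply(kernel, tf.reshape(mask, mask.shape + [1, 1])))

    def call(self, inputs):
        x, mask = inputs
        self.conv2d._convolution_op = functools.partial(self.masked_convolution_op, mask=mask)
        return self.conv2d.call(x)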
We can test it with the following script:
mcon = MaskedConv2D(filters=2,kernel_size=[3,3])
# hack: initialize it by running some data through it
mcon((np.ones([1,4,4,3], dtype=np.float32), tf.constant([[1,1,0],[1,1,1],[0,1,1]], dtype=tf.float32)))
# set all the weights to 1 for testing
mcon.set_weights([ np.ones([3,3,3,2]) , np.zeros([2]) ])
# pass in a matrix of 1s and mask out 2 elements for each input filter
mcon((np.ones([1,4,4,3], dtype=np.float32), tf.constant([[1,1,0],[1,1,1],[0,1,1]], dtype=tf.float32)))
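This gives the predictable output below. Each value is 21 because the mask keeps 7 of the 9 kernel positions and, with all-ones inputs and weights, each of the 3 input filters contributes 7.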
<tf.Tensor: shape=(1, 2, 2, 2), dtype=float32, numpy=
array([[[[21., 21.],
         [21., 21.]],

        [[21., 21.],
         [21., 21.]]]], dtype=float32)>
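The same numbers can be cross-checked without TensorFlow. The sketch below is an independent NumPy check, not part of the layer. It assumes NumPy 1.20 or newer (for `sliding_window_view`), stride 1 and no padding.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

x = np.ones((1, 4, 4, 3), dtype=np.float32)
weight = np.ones((3, 3, 3, 2), dtype=np.float32)
mask = np.array([[1, 1, 0],
                 [1, 1, 1],
                 [0, 1, 1]], dtype=np.float32)

# Apply the mask to the kernel by broadcasting over both filter axes
masked_weight = weight * mask.reshape(3, 3, 1, 1)

# All 3x3 patches: [batch, out_height, out_width, input_filters, k_height, k_width]
windows = sliding_window_view(x, (3, 3), axis=(1, 2))

# Sum over kernel height (i), kernel width (j) and input filters (c)
expected = np.einsum("bhwcij,ijco->bhwo", windows, masked_weight)

print(expected.shape)  # (1, 2, 2, 2)
print(expected[0, 0, 0])  # [21. 21.]
```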