且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

过滤N-D numpy数组并仅保留特定元素

更新时间:2022-06-21 22:58:03

我通过将 reduce np.logical_or 组合起来并设置掩码,然后遍历应该保持:

I set up a mask by combining reduce with np.logical_or and iterated over the values that should remain:

import functools
import numpy as np

arr = np.array([[[36,  1, 72],
        [76, 50, 23],
        [28, 68, 17],
        [84, 75, 69]],
       [[ 5, 15, 93],
        [92, 92, 88],
        [11, 54, 21],
        [87, 76, 81]]])

# Set the values that should not
# be set to zero
vals = [11, 50, 72]

# Create a mask by looping over the above values
mask = functools.reduce(np.logical_or, (arr==val for val in vals))

masked = np.where(mask, arr, 0.)

print(masked)
> array([[[ 0.,  0., 72.],
        [ 0., 50.,  0.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.]],

       [[ 0.,  0.,  0.],
        [ 0.,  0.,  0.],
        [11.,  0.,  0.],
        [ 0.,  0.,  0.]]])