且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

在numpy中查找对角线总和(更快)

更新时间:2023-11-09 21:41:22

有可能使用stride_tricks的解决方案.这部分是基于此问题的答案中可用的大量信息,但是问题恰恰不同,我认为,不应将其视为重复项.这是应用于方阵的基本思想,有关实现更通用解决方案的函数,请参见下文.

There's a possible solution using stride_tricks. This is based in part on the plethora of information available in the answers to this question, but the problem is just different enough, I think, not to count as a duplicate. Here's the basic idea, applied to a square matrix -- see below for a function implementing the more general solution.

>>> cols = 8
>>> a = numpy.arange(cols * cols).reshape((cols, cols))
>>> fill = numpy.zeros((cols - 1) * cols, dtype='i8').reshape((cols - 1, cols))
>>> stacked = numpy.vstack((a, fill, a))
>>> major_stride, minor_stride = stacked.strides
>>> strides = major_stride, minor_stride * (cols + 1)
>>> shape = (cols * 2 - 1, cols)
>>> numpy.lib.stride_tricks.as_strided(stacked, shape, strides)
array([[ 0,  9, 18, 27, 36, 45, 54, 63],
       [ 8, 17, 26, 35, 44, 53, 62,  0],
       [16, 25, 34, 43, 52, 61,  0,  0],
       [24, 33, 42, 51, 60,  0,  0,  0],
       [32, 41, 50, 59,  0,  0,  0,  0],
       [40, 49, 58,  0,  0,  0,  0,  0],
       [48, 57,  0,  0,  0,  0,  0,  0],
       [56,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  7],
       [ 0,  0,  0,  0,  0,  0,  6, 15],
       [ 0,  0,  0,  0,  0,  5, 14, 23],
       [ 0,  0,  0,  0,  4, 13, 22, 31],
       [ 0,  0,  0,  3, 12, 21, 30, 39],
       [ 0,  0,  2, 11, 20, 29, 38, 47],
       [ 0,  1, 10, 19, 28, 37, 46, 55]])
>>> diags = numpy.lib.stride_tricks.as_strided(stacked, shape, strides)
>>> diags.sum(axis=1)
array([252, 245, 231, 210, 182, 147, 105,  56,   7,  21,  42,  70, 105,
       147, 196])

当然,我不知道这实际上有多快.但是我敢打赌它会比Python列表理解要快.

Of course, I have no idea how fast this will actually be. But I bet it will be faster than a Python list comprehension.

为方便起见,这是一个完全通用的diagonals函数.假设您想沿最长轴移动对角线:

For convenience, here's a fully general diagonals function. It assumes you want to move the diagonal along the longest axis:

def diagonals(a):
    rows, cols = a.shape
    if cols > rows:
        a = a.T
        rows, cols = a.shape
    fill = numpy.zeros(((cols - 1), cols), dtype=a.dtype)
    stacked = numpy.vstack((a, fill, a))
    major_stride, minor_stride = stacked.strides
    strides = major_stride, minor_stride * (cols + 1)
    shape = (rows + cols - 1, cols)
    return numpy.lib.stride_tricks.as_strided(stacked, shape, strides)