且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

在矩阵的每一行中找到1的列索引

更新时间:2022-01-03 00:05:21

您实际上可以通过简单的矩阵乘法来解决这个问题.

You can actually solve this with simple matrix multiplication.

result = M * (1:size(M, 2)).';

     3
     1
     2
     1
     3

这可以通过将您的M x 3矩阵乘以3 x 1数组(其中3x1的元素只是[1; 2; 3])来工作.简而言之,对于M的每一行,使用3 x 1数组执行逐元素乘法. M行中只有1会在结果中产生任何结果.然后,将按元素相乘的结果相加.因为每行只有一个"1",所以结果将是该1所在的列索引.

This works by multiplying your M x 3 matrix with a 3 x 1 array where the elements of the 3x1 are simply [1; 2; 3]. Briefly, for each row of M, element-wise multiplication is performed with the 3 x 1 array. Only the 1's in the row of M will yield anything in the result. Then the result of this element-wise multiplication is summed. Because you only have one "1" per row, the result is going to be the column index where that 1 is located.

例如M的第一行.

element_wise_multiplication = [0 0 1] .* [1 2 3]

    [0, 0, 3]

sum(element_wise_multiplication)

    3

更新

基于 @reyryeng

Based on the solutions provided by @reyryeng and @Luis below, I decided to run a comparison to see how the performance of the various methods compared.

要设置测试矩阵(M),我创建了原始问题中指定形式的矩阵,并更改了行数.使用randi([1 nCols], size(M, 1))随机选择哪一列为1.使用timeit分析执行时间.

To setup the test matrix (M) I created a matrix of the form specified in the original question and varied the number of rows. Which column had the 1 was chosen randomly using randi([1 nCols], size(M, 1)). Execution times were analyzed using timeit.

使用类型为doubleM(MATLAB的默认值)运行时,将获得以下执行时间.

When run using M of type double (MATLAB's default) you get the following execution times.

如果Mlogical,则由于矩阵乘法在矩阵乘法之前必须转换为数字类型这一事实而大受打击,而其他两个在性能上都有所改善

If M is a logical, then the matrix multiplication takes a hit due to the fact that it has to be converted to a numerical type prior to matrix multiplication, whereas the other two have a bit of a performance improvement.

这是我使用的测试代码.

Here is the test code that I used.

sizes = round(linspace(100, 100000, 100));
times = zeros(numel(sizes), 3);

for k = 1:numel(sizes)
    M = generateM(sizes(k));
    times(k,1) = timeit(@()M * (1:size(M, 2)).');
    M = generateM(sizes(k));
    times(k,2) = timeit(@()max(M, [], 2), 2);
    M = generateM(sizes(k));
    times(k,3) = timeit(@()find(M.'), 2);
end

figure
plot(range, times / 1000);
legend({'Multiplication', 'Max', 'Find'})
xlabel('Number of rows in M')
ylabel('Execution Time (ms)')

function M = generateM(nRows)
    M = zeros(nRows, 3);
    col = randi([1 size(M, 2)], 1, size(M, 1));
    M(sub2ind(size(M), 1:numel(col), col)) = 1;
end