Updated: 2023-12-03 23:45:04
I use these two functions to compute a confusion matrix (as defined in sklearn):
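import torch

# rewrite sklearn method to torch
def confusion_matrix_1(y_true, y_pred):
    # Number of classes, inferred from the largest label seen
    N = max(max(y_true), max(y_pred)) + 1
    y_true = torch.tensor(y_true, dtype=torch.long)
    y_pred = torch.tensor(y_pred, dtype=torch.long)
    # One entry of value 1 per (true, pred) pair; duplicate pairs are summed by to_dense().
    # Note: recent PyTorch releases deprecate this constructor in favor of torch.sparse_coo_tensor.
    return torch.sparse.LongTensor(
        torch.stack([y_true, y_pred]),
        torch.ones_like(y_true, dtype=torch.long),
        torch.Size([N, N])).to_dense()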
# weird trick with bincount
def confusion_matrix_2(y_true, y_pred):
    N = max(max(y_true), max(y_pred)) + 1
    y_true = torch.tensor(y_true, dtype=torch.long)
    y_pred = torch.tensor(y_pred, dtype=torch.long)
    # Encode each (true, pred) pair as one flat index into an N x N matrix
    y = N * y_true + y_pred
    y = torch.bincount(y)
    # bincount stops at the largest index seen, so pad with zeros up to N * N entries
    if len(y) < N * N:
        y = torch.cat([y, torch.zeros(N * N - len(y), dtype=torch.long)])
    y = y.reshape(N, N)
    return y
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
confusion_matrix_1(y_true, y_pred)
# tensor([[2, 0, 0],
#         [0, 0, 1],
#         [1, 0, 2]])
The second function is faster when the number of classes is small. On the 3-class example above it takes about 25 µs per call, versus about 102 µs for the first:
%%timeit
confusion_matrix_1(y_true, y_pred)
# 102 µs ± 30.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%%timeit
confusion_matrix_2(y_true, y_pred)
# 25 µs ± 149 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
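The manual padding step in confusion_matrix_2 can probably be avoided with the minlength argument of torch.bincount, which guarantees a minimum number of bins. Below is a minimal sketch of that variant, assuming the labels are non-negative integers. The name confusion_matrix_minlength and the optional num_classes parameter are hypothetical and are not part of the functions above. It has not been timed against them.

```python
import torch

def confusion_matrix_minlength(y_true, y_pred, num_classes=None):
    # Hypothetical variant of confusion_matrix_2; assumes non-negative integer labels.
    y_true = torch.as_tensor(y_true, dtype=torch.long)
    y_pred = torch.as_tensor(y_pred, dtype=torch.long)
    if num_classes is None:
        # Infer the class count from the largest label seen
        num_classes = max(int(y_true.max()), int(y_pred.max())) + 1
    # Flat index of cell (true, pred) in a row-major num_classes x num_classes matrix
    flat = num_classes * y_true + y_pred
    # minlength guarantees num_classes ** 2 bins, so no zero padding is needed
    counts = torch.bincount(flat, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

cm = confusion_matrix_minlength([2, 0, 2, 2, 0, 1], [0, 0, 2, 2, 0, 2])
print(cm)
```

Passing num_classes explicitly is useful when a batch does not contain every class, since inferring it from the data would otherwise give a smaller matrix.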