且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

TensorBoard中的Tensorflow混淆矩阵

更新时间:2023-12-03 23:27:58

Here is something I have put together that works reasonably well. I still need to adjust a few things, such as the tick placements.

Here is the function that will pretty much do everything for you.

from textwrap import wrap
import re
import itertools
import tfplot
import matplotlib
import numpy as np
from sklearn.metrics import confusion_matrix



def plot_confusion_matrix(correct_labels, predict_labels, labels, title='Confusion matrix', tensor_name = 'MyFigure/image', normalize=False):
    """Render a confusion matrix as an image summary for TensorBoard.

    Parameters:
        correct_labels                  : The true classification categories.
        predict_labels                  : The predicted classification categories.
        labels                          : List of label strings. It is passed to
                                          sklearn's confusion_matrix as `labels`
                                          and is also used for the axis tick labels
                                          (CamelCase names are split into words and
                                          wrapped at 40 characters).
        title='Confusion matrix'        : Title drawn above the matrix.
        tensor_name = 'MyFigure/image'  : Tag used for the output summary.
        normalize=False                 : When True, each row is divided by its row
                                          sum, multiplied by 10 and truncated to an
                                          integer, so every cell shows a value in
                                          0..10. Rows without any samples become 0.

    Returns:
        summary: The value returned by tfplot.figure.to_summary for the figure.

    Other items to note:
        - Depending on the number of categories and the data, you may have to
          modify the figsize, font sizes etc.
        - Currently, some of the ticks don't line up due to rotations.
        - np.set_printoptions is called, which changes NumPy's print precision
          process-wide.
    """
    # Import the submodule explicitly: a bare `import matplotlib` only exposes
    # `matplotlib.figure` if some other import has already loaded it.
    from matplotlib.figure import Figure

    cm = confusion_matrix(correct_labels, predict_labels, labels=labels)
    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # A row with no samples yields 0/0; silence the warning here and let
        # nan_to_num turn the resulting NaN into 0 below.
        with np.errstate(divide='ignore', invalid='ignore'):
            cm = cm.astype('float') * 10 / row_totals
        cm = np.nan_to_num(cm, copy=True)
        cm = cm.astype('int')

    np.set_printoptions(precision=2)

    fig = Figure(figsize=(7, 7), dpi=320, facecolor='w', edgecolor='k')
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(cm, cmap='Oranges')
    # The original accepted `title` but never drew it.
    ax.set_title(title, fontsize=7)

    # Insert spaces at CamelCase boundaries, then wrap long names over lines.
    classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in labels]
    classes = ['\n'.join(wrap(l, 40)) for l in classes]

    tick_marks = np.arange(len(classes))

    ax.set_xlabel('Predicted', fontsize=7)
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, fontsize=4, rotation=-90, ha='center')
    ax.xaxis.set_label_position('bottom')
    ax.xaxis.tick_bottom()

    ax.set_ylabel('True Label', fontsize=7)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes, fontsize=4, va='center')
    ax.yaxis.set_label_position('left')
    ax.yaxis.tick_left()

    # Annotate every cell with its count; zeros are shown as '.' to cut clutter.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        cell_text = format(cm[i, j], 'd') if cm[i, j] != 0 else '.'
        ax.text(j, i, cell_text, horizontalalignment="center", fontsize=6, verticalalignment='center', color="black")
    fig.set_tight_layout(True)
    summary = tfplot.figure.to_summary(fig, tag=tensor_name)
    return summary

#

And here is the rest of the code that you will need to call this function.

# Build the confusion-matrix image summary and write it to an event file.
# NOTE(review): this fragment relies on names that are not defined here --
# `os`, `tf`, `sess`, `checkpoint_dir`, `correct_labels`, `predict_labels`,
# `labels` and `current_step` must already exist where it is pasted.
''' confusion matrix summaries '''
# Event files are written under <checkpoint_dir>/summaries/img.
img_d_summary_dir = os.path.join(checkpoint_dir, "summaries", "img")
# NOTE(review): tf.summary.FileWriter looks like the TensorFlow 1.x summary
# API -- confirm the installed TensorFlow version provides it.
img_d_summary_writer = tf.summary.FileWriter(img_d_summary_dir, sess.graph)
# 'dev/cm' is the tag passed through as tensor_name for this summary.
img_d_summary = plot_confusion_matrix(correct_labels, predict_labels, labels, tensor_name='dev/cm')
# Record the summary against the step counter current_step.
img_d_summary_writer.add_summary(img_d_summary, current_step)

Confuse away!!!