错误分析

遵循步骤····

  • 探索数据准备的选项
  • 尝试多个模型
  • 列出最佳模型
  • 用GirdSearchCV对其超参数进行微调,等等
  • 改进:分析错误类型

查看混淆矩阵

y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
conf_mx
#### 查看混淆矩阵的图像表示
plt.matshow(conf_mx, cmap=plt.cm.gray)
save_fig("confusion_matrix_plot", tight_layout=False)
plt.show()
或者
def plot_confusion_matrix(matrix):
    """If you prefer color and a colorbar"""
    fig = plt.figure(figsize=(8,8))
    ax = fig.add_subplot(111)
    cax = ax.matshow(matrix)
    fig.colorbar(cax)

重新绘制

  • 每个值除以相应类别中的图片数量,获得错误率
  • 用0填充对角线

    row_sums = conf_mx.sum(axis=1, keepdims=True)
    norm_conf_mx = conf_mx / row_sums

np.fill_diagonal(norm_conf_mx, 0)
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
save_fig(“confusion_matrix_errors_plot”, tight_layout=False)
plt.show()

  • 看到许多图片被错误分为8/9
  • 8和9经常和其他图混淆
  • 错误不完全对称

解决方案

  • 收集更多训练数据
  • 开发新特征
  • 对图片预处理(平移,旋转)

参见Sklearn03