遵循步骤····
- 探索数据准备的选项
- 尝试多个模型
- 列出最佳模型
- 用GirdSearchCV对其超参数进行微调,等等
- 改进:分析错误类型
查看混淆矩阵
y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3) conf_mx = confusion_matrix(y_train, y_train_pred) conf_mx#### 查看混淆矩阵的图像表示
plt.matshow(conf_mx, cmap=plt.cm.gray) save_fig("confusion_matrix_plot", tight_layout=False) plt.show() 或者 def plot_confusion_matrix(matrix): """If you prefer color and a colorbar""" fig = plt.figure(figsize=(8,8)) ax = fig.add_subplot(111) cax = ax.matshow(matrix) fig.colorbar(cax)
重新绘制
- 每个值除以相应类别中的图片数量,获得错误率
- 用0填充对角线
row_sums = conf_mx.sum(axis=1, keepdims=True)
norm_conf_mx = conf_mx / row_sums
np.fill_diagonal(norm_conf_mx, 0)
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
save_fig(“confusion_matrix_errors_plot”, tight_layout=False)
plt.show()
- 看到许多图片被错误分为8/9
- 8和9经常和其他图混淆
- 错误不完全对称
解决方案
- 收集更多训练数据
- 开发新特征
- 对图片预处理(平移,旋转)