如何使數字圖片分類任務的準確率達到99%以上?
最新
05-04
任務描述:訓練模型,識別圖片上的手繪數字。
數據集描述:數據集train.csv和test.csv包含從0-9的手繪數字的灰度圖像,每張圖片為28×28 = 784個像素值,像素值的範圍在0-255之間。
train.csv文件:一共785=(1+784)列,第一列稱為「標籤」,即手繪圖片上的數字,其餘列為與標籤對應的像素值。
test.csv文件:與train.csv文件相同,只是不含有標籤列。
提交文件應採用以下格式:對於測試集中的每個圖像,輸出一行,其中包含圖像標識ID和您預測的數字例如,如果您預測第一張圖像是3,第二張圖像是7,第三張圖像是8,那麼您的提交文件將如下所示:
ID,label
1,3
2,7
3,8
本文使用的深度學習框架
tensorflow-gpu版本:1.2.0
keras版本:2.1.6
建立模型
Step 1:導入要用的包
Step 2:讀取數據
Step 3:查看數據集
Step 4:查看數據集分布是否平衡及是否有缺失值
對每一類圖片進行數量統計
train_y.value_counts()
查看數據集是否有缺失值
train_x.isnull().any().describe()
Step 5:數據歸一化和重整數據集
Step 6:將類別標籤轉化成one-hot編碼
Step 7:劃分train_x數據集為訓練集和驗證集
Step 8:建立CNN模型
Step 9:運行模型
Step 10:畫出訓練集和驗證集的準確率和損失圖
Step 11:畫出混淆矩陣
Step 12:展示預測錯誤的結果
思考:如何進一步改進模型?
Step 13:對測試集做預測
TAG:計算材料學與大數據 |