當前位置:
首頁 > 知識 > 簡單能看懂的感知機演算法PLA

簡單能看懂的感知機演算法PLA

什麼是感知機「Perceptron」

PLA全稱是Perceptron Linear Algorithm,即線性感知機演算法,屬於一種最簡單的感知機(Perceptron)模型。

感知機模型是機器學習二分類問題中的一個非常簡單的模型。它的基本結構如下圖所示:

簡單能看懂的感知機演算法PLA

其中,x

i

xi是輸入,w

i

wi表示權重係數,b

b表示偏移常數。感知機的線性輸出為:

scores=∑

i

N

w

i

x

i

+b

scores=∑iNwixi+b

為了簡化計算,通常我們將b

b作為權重係數的一個維度,即w

0

w0。同時,將輸入x

x擴展一個維度,為1。這樣,上式簡化為:

scores=∑

i

N+1

w

i

x

i

scores=∑iN+1wixi

scores

scores是感知機的輸出,接下來就要對scores

scores進行判斷:

  • 若scores≥0
  • scores≥0,則y
  • ^
  • =1
  • y^=1(正類)
  • 若scores<0
  • scores<0,則y
  • ^
  • =?1
  • y^=?1(負類)

以上就是線性感知機模型的基本概念,簡單來說,它由線性得分計算閾值比較兩個過程組成,最後根據比較結果判斷樣本屬於正類還是負類。

PLA理論解釋

對於二分類問題,可以使用感知機模型來解決。PLA的基本原理就是逐點修正,首先在超平面上隨意取一條分類面,統計分類錯誤的點;然後隨機對某個錯誤點就行修正,即變換直線的位置,使該錯誤點得以修正;接著再隨機選擇一個錯誤點進行糾正,分類面不斷變化,直到所有的點都完全分類正確了,就得到了最佳的分類面。

利用二維平面例子來進行解釋,第一種情況是錯誤地將正樣本(y=1)分類為負樣本(y=-1)。此時,wx<0

wx<0,即w

w與x

x的夾角大於90度,分類線l

l的兩側。修正的方法是讓夾角變小,修正w

w值,使二者位於直線同側:

w:=w+x=w+yx

w:=w+x=w+yx

修正過程示意圖如下所示:

簡單能看懂的感知機演算法PLA

第二種情況是錯誤地將負樣本(y=-1)分類為正樣本(y=1)。此時,wx>0

wx>0,即w

w與x

x的夾角小於90度,分類線l

l的同一側。修正的方法是讓夾角變大,修正w

w值,使二者位於直線兩側:

w:=w?x=w+yx

w:=w?x=w+yx

修正過程示意圖如下所示:

簡單能看懂的感知機演算法PLA

經過兩種情況分析,我們發現PLA每次w

w的更新表達式都是一樣的:w:=w+yx

w:=w+yx。掌握了每次w

w的優化表達式,那麼PLA就能不斷地將所有錯誤的分類樣本糾正並分類正確。

數據準備

導入數據

數據集存放在』../data/』目錄下,該數據集包含了100個樣本,正負樣本各50,特徵維度為2。

import numpy as np
import pandas as pd
data = pd.read_csv("./data/data1.csv", header=None)
# 樣本輸入,維度(100,2)
X = data.iloc[:,:2].values
# 樣本輸出,維度(100,)
y = data.iloc[:,2].values
1
2
3
4
5
6
7
8

數據分類與可視化

下面我們在二維平面上繪出正負樣本的分布情況。

import matplotlib.pyplot as plt
plt.scatter(X[:50, 0], X[:50, 1], color="blue", marker="o", label="Positive")
plt.scatter(X[50:, 0], X[50:, 1], color="red", marker="x", label="Negative")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.legend(loc = "upper left")
plt.title("Original Data")
plt.show()
1
2
3
4
5
6
7
8
9

簡單能看懂的感知機演算法PLA

PLA演算法

特徵歸一化

首先分別對兩個特徵進行歸一化處理,即:

X=X?μ

σ

X=X?μσ

其中,μ

μ是特徵均值,σ

σ是特徵標準差。

# 均值
u = np.mean(X, axis=0)
# 方差
v = np.std(X, axis=0)
X = (X - u) / v
# 作圖
plt.scatter(X[:50, 0], X[:50, 1], color="blue", marker="o", label="Positive")
plt.scatter(X[50:, 0], X[50:, 1], color="red", marker="x", label="Negative")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.legend(loc = "upper left")
plt.title("Normalization data")
plt.show()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

簡單能看懂的感知機演算法PLA

直線初始化

# X加上偏置項
X = np.hstack((np.ones((X.shape[0],1)), X))
# 權重初始化
w = np.random.randn(3,1)
1
2
3
4

顯示初始化直線位置:

# 直線第一個坐標(x1,y1)
x1 = -2
y1 = -1 / w[2] * (w[0] * 1 + w[1] * x1)
# 直線第二個坐標(x2,y2)
x2 = 2
y2 = -1 / w[2] * (w[0] * 1 + w[1] * x2)
# 作圖
plt.scatter(X[:50, 1], X[:50, 2], color="blue", marker="o", label="Positive")
plt.scatter(X[50:, 1], X[50:, 2], color="red", marker="x", label="Negative")
plt.plot([x1,x2], [y1,y2],"r")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.legend(loc = "upper left")
plt.show()
1
2
3
4
5
6
7
8
9
10
11
12
13
14

簡單能看懂的感知機演算法PLA

由上圖可見,一般隨機生成的分類線,錯誤率很高。

計算scores,更新權重

接下來,計算scores,得分函數與閾值0做比較,大於零則y

^

=1

y^=1,小於零則y

^

=?1

y^=?1

s = np.dot(X, w)
y_pred = np.ones_like(y) # 預測輸出初始化
loc_n = np.where(s < 0)[0] # 大於零索引下標
y_pred[loc_n] = -1
1
2
3
4

接著,從分類錯誤的樣本中選擇一個,使用PLA更新權重係數w

w。

# 第一個分類錯誤的點
t = np.where(y != y_pred)[0][0]
# 更新權重w
w += y[t] * X[t, :].reshape((3,1))
1
2
3
4

迭代更新訓練

更新權重w

w是個迭代過程,只要存在分類錯誤的樣本,就不斷進行更新,直至所有的樣本都分類正確。(注意,前提是正負樣本完全可分)

for i in range(100):
s = np.dot(X, w)
y_pred = np.ones_like(y)
loc_n = np.where(s < 0)[0]
y_pred[loc_n] = -1
num_fault = len(np.where(y != y_pred)[0])
print("第%2d次更新,分類錯誤的點個數:%2d" % (i, num_fault))
if num_fault == 0:
break
else:
t = np.where(y != y_pred)[0][0]
w += y[t] * X[t, :].reshape((3,1))
1
2
3
4
5
6
7
8
9
10
11
12

迭代完畢後,得到更新後的權重係數w

w,繪製此時的分類直線是什麼樣子。

# 直線第一個坐標(x1,y1)
x1 = -2
y1 = -1 / w[2] * (w[0] * 1 + w[1] * x1)
# 直線第二個坐標(x2,y2)
x2 = 2
y2 = -1 / w[2] * (w[0] * 1 + w[1] * x2)
# 作圖
plt.scatter(X[:50, 1], X[:50, 2], color="blue", marker="o", label="Positive")
plt.scatter(X[50:, 1], X[50:, 2], color="red", marker="x", label="Negative")
plt.plot([x1,x2], [y1,y2],"r")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.legend(loc = "upper left")
plt.show()
1
2
3
4
5
6
7
8
9
10
11
12
13
14

簡單能看懂的感知機演算法PLA

其實,PLA演算法的效率還算不錯,只需要數次更新就能找到一條能將所有樣本完全分類正確的分類線。所以得出結論,對於正負樣本線性可分的情況,PLA能夠在有限次迭代後得到正確的分類直線。

總結與疑問

這裡導入的數據本身就是線性可分的,可以使用PCA來得到分類直線。但是,如果數據不是線性可分,即找不到一條直線能夠將所有的正負樣本完全分類正確,這種情況下,似乎PCA會永遠更新迭代下去,卻找不到正確的分類線。

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 程序員小新人學習 的精彩文章:

jQuery Mobile 安裝
XML CDATA

TAG:程序員小新人學習 |