2019/9/15

以 python 實作迴歸與分類程式

迴歸


學習資料


取得學習資料文字檔 click.csv


x,y
235,591
216,539
148,413
35,310
85,308
...

先利用 matplotlib 繪製到圖表上


import numpy as np
import matplotlib.pyplot as plt

train = np.loadtxt('click.csv', delimiter=',', skiprows=1)
train_x = train[:,0]
train_y = train[:,1]

plt.plot(train_x, train_y, 'o')
plt.show()


另外針對原始的學習資料,進行標準化(z-score正規化),也就是將資料平均轉換為 0,分散轉換為1。其中 𝜇 是所有資料的平均,𝜎 是所有資料的標準差。這樣處理後,會讓參數收斂更快。


\(z^{(i)} = \frac{x^{(i)} - 𝜇}{𝜎}\)


import numpy as np
import matplotlib.pyplot as plt

train = np.loadtxt('click.csv', delimiter=',', skiprows=1)
train_x = train[:,0]
train_y = train[:,1]

# 標準化
mu = train_x.mean()
sigma = train_x.std()
def standardize(x):
    return (x - mu) / sigma

train_z = standardize(train_x)

plt.plot(train_z, train_y, 'o')
plt.show()


一次函數


先使用一次目標函數 \(f_𝜃(x)\)


\({f_𝜃(x)=𝜃_0+𝜃_1x}​\)


\({E(𝜃)= \frac{1}{2} \sum_{i=1}^{n}( y^{(i)} - f_𝜃(x^{(i)})^2 }​\)


\(𝜃_0, 𝜃_1​\) 可任意選擇初始值


\(𝜃_0, 𝜃_1\) 的參數更新式為


\(𝜃_0 := 𝜃_0 - 𝜂 \sum_{i=1}^{n}( f_𝜃(x^{(i)} )-y^{(i)} )\)


\(𝜃_1 := 𝜃_1 - 𝜂 \sum_{i=1}^{n}( f_𝜃(x^{(i)} )-y^{(i)} )x^{(i)}\)


用這個方法,就可以找出正確的 \(𝜃_0, 𝜃_1\)


其中 𝜂 是任意數值,先設定為 \(10^{-3}\) 試試看。一般來說,會指定要處理的次數,有時會比較參數更新前後,目標函數的值,如果差異不大,就直接結束。另外 \(𝜃_0, 𝜃_1\) 必須同時一起更新。


import numpy as np
import matplotlib.pyplot as plt

# 載入學習資料
train = np.loadtxt('click.csv', delimiter=',', dtype='int', skiprows=1)
train_x = train[:,0]
train_y = train[:,1]

# 標準化
mu = train_x.mean()
sigma = train_x.std()
def standardize(x):
    return (x - mu) / sigma

train_z = standardize(train_x)

# 任意選擇初始值
theta0 = np.random.rand()
theta1 = np.random.rand()

# 預測函數
def f(x):
    return theta0 + theta1 * x

# 目標函數 E(𝜃)
def E(x, y):
    return 0.5 * np.sum((y - f(x)) ** 2)

# 學習率
ETA = 1e-3

# 誤差
diff = 1

# 更新次數
count = 0

# 重複學習
error = E(train_z, train_y)
while diff > 1e-2:
    # 暫存更新結果
    tmp_theta0 = theta0 - ETA * np.sum((f(train_z) - train_y))
    tmp_theta1 = theta1 - ETA * np.sum((f(train_z) - train_y) * train_z)

    # 更新參數
    theta0 = tmp_theta0
    theta1 = tmp_theta1

    # 計算誤差
    current_error = E(train_z, train_y)
    diff = error - current_error
    error = current_error

    # log
    count += 1
    log = '{}次數: theta0 = {:.3f}, theta1 = {:.3f}, 誤差 = {:.4f}'
    print(log.format(count, theta0, theta1, diff))

# 繪製學習資料與預測函數的直線
x = np.linspace(-3, 3, 100)
plt.plot(train_z, train_y, 'o')
plt.plot(x, f(x))
plt.show()

測試結果


391次數: theta0 = 428.991, theta1 = 93.444, 誤差 = 0.0109
392次數: theta0 = 428.994, theta1 = 93.445, 誤差 = 0.0105
393次數: theta0 = 428.997, theta1 = 93.446, 誤差 = 0.0101
394次數: theta0 = 429.000, theta1 = 93.446, 誤差 = 0.0097


驗證


可輸入 x 預測點擊數,但因為剛剛有將學習資料正規化,預測資料也必須正規化


>>> f(standardize(100))
370.96741051658194
>>> f(standardize(500))
928.9775823086377

二次多項式迴歸


\(f_𝜃(x) = 𝜃_0 + 𝜃_1x + 𝜃_2x^2\) 要增加 \( 𝜃_2\) 這個參數


目標的誤差函數 \({E(𝜃)= \frac{1}{2} \sum_{i=1}^{n}( y^{(i)} - f_𝜃(x^{(i)})^2 }​\)


因為有多筆學習資料,可將資料以矩陣方式處理


\( X = \begin{bmatrix}
(x^{(1)})^T\\
(x^{(2)})^T\\
\cdot \\
\cdot \\
(x^{(n)})^T \\
\end{bmatrix}
= \begin{bmatrix}
1 & x^{(1)} & (x^{(1)})^2 \\
1 & x^{(2)} & (x^{(2)})^2 \\
\cdot \\
\cdot \\
1 & x^{(n)} & (x^{(n)})^2 \\
\end{bmatrix} ​\)


\(f_𝜃(x) = \begin{bmatrix}
1 & x^{(1)} & (x^{(1)})^2 \\
1 & x^{(2)} & (x^{(2)})^2 \\
\cdot \\
\cdot \\
1 & x^{(n)} & (x^{(n)})^2 \\
\end{bmatrix} \begin{bmatrix}
𝜃_0 \\
𝜃_1 \\
𝜃_2 \\
\end{bmatrix}
= \begin{bmatrix}
𝜃_0 + 𝜃_1 x^{(1)} + 𝜃_2 (x^{(1)})^2\\
𝜃_0 + 𝜃_1 x^{(2)} + 𝜃_2 (x^{(2)})^2\\
\cdot \\
\cdot \\
𝜃_0 + 𝜃_1 x^{(n)} + 𝜃_2 (x^{(n)})^2\\
\end{bmatrix}\)


第j 項參數的更新式定義為


\(𝜃_j := 𝜃_j - 𝜂 \sum_{i=1}^{n}( f_𝜃(x^{(i)} )-y^{(i)} )x_j^{(i)}​\)


可將 \( ( f_𝜃(x^{(i)} )-y^{(i)} ) ​\) 以及 \(x_j^{(i)}​\) 這兩部分各自以矩陣方式處理


\( f= \begin{bmatrix}
( f_𝜃(x^{(1)} )-y^{(1)} )\\
( f_𝜃(x^{(2)} )-y^{(2)} )\\
\cdot \\
\cdot \\
( f_𝜃(x^{(n)} )-y^{(n)} ) \\
\end{bmatrix} \)


\( x_0 = \begin{bmatrix}
x_0^{(1)} \\
x_0^{(2)}\\
\cdot \\
\cdot \\
x_0^{(n)} \\
\end{bmatrix} \)


\( \sum_{i=1}^{n}( f_𝜃(x^{(i)} )-y^{(i)} )x_0^{(i)} = f^Tx_0 \)


分別考慮三個參數


\( x_0 = \begin{bmatrix}
x_0^{(1)} \\
x_0^{(2)}\\
\cdot \\
\cdot \\
x_0^{(n)} \\
\end{bmatrix} ,
x_1 = \begin{bmatrix}
x^{(1)} \\
x^{(2)}\\
\cdot \\
\cdot \\
x^{(n)} \\
\end{bmatrix} ,
x_2 = \begin{bmatrix}
(x^{(1)})^2 \\
(x^{(2)})^2\\
\cdot \\
\cdot \\
(x^{(n)})^2 \\
\end{bmatrix}\)


\( X = \begin{bmatrix}
x_0 & x_1 & x_2
\end{bmatrix}
= \begin{bmatrix}
1 & x^{(1)} & (x^{(1)})^2 \\
1 & x^{(2)} & (x^{(2)})^2\\
\cdot \\
\cdot \\
1 & x^{(n)} & (x^{(n)})^2 \\
\end{bmatrix} \)


使用 \( f^TX\) 就可以一次更新三個參數


import numpy as np
import matplotlib.pyplot as plt

# 讀取學習資料
train = np.loadtxt('click.csv', delimiter=',', dtype='int', skiprows=1)
train_x = train[:,0]
train_y = train[:,1]

# 標準化
mu = train_x.mean()
sigma = train_x.std()
def standardize(x):
    return (x - mu) / sigma

train_z = standardize(train_x)

# 任意初始值
theta = np.random.rand(3)

# 學習資料轉換為矩陣
def to_matrix(x):
    return np.vstack([np.ones(x.size), x, x ** 2]).T

X = to_matrix(train_z)

# 預測函數
def f(x):
    return np.dot(x, theta)

# 目標函數
def E(x, y):
    return 0.5 * np.sum((y - f(x)) ** 2)

# 學習率
ETA = 1e-3

# 誤差
diff = 1

# 更新次數
count = 0

# 重複學習
error = E(X, train_y)
while diff > 1e-2:
    # 更新參數
    theta = theta - ETA * np.dot(f(X) - train_y, X)

    # 計算誤差
    current_error = E(X, train_y)
    diff = error - current_error
    error = current_error

    # log
    count += 1
    log = '{}次: theta = {}, 誤差 = {:.4f}'
    print(log.format(count, theta, diff))

# 繪製學習資料與預測函數
x = np.linspace(-3, 3, 100)
plt.plot(train_z, train_y, 'o')
plt.plot(x, f(to_matrix(x)))
plt.show()




也可以將重複停止的條件,改為均方誤差


目標的誤差函數 \({E(𝜃)= \frac{1}{n} \sum_{i=1}^{n}( y^{(i)} - f_𝜃(x^{(i)})^2 }\)


import numpy as np
import matplotlib.pyplot as plt

# 讀取學習資料
train = np.loadtxt('click.csv', delimiter=',', dtype='int', skiprows=1)
train_x = train[:,0]
train_y = train[:,1]

# 標準化
mu = train_x.mean()
sigma = train_x.std()
def standardize(x):
    return (x - mu) / sigma

train_z = standardize(train_x)

# 任意初始值
theta = np.random.rand(3)

# 學習資料轉換為矩陣
def to_matrix(x):
    return np.vstack([np.ones(x.size), x, x ** 2]).T

X = to_matrix(train_z)

# 預測函數
def f(x):
    return np.dot(x, theta)

# 目標函數
def MSE(x, y):
    return ( 1 / x.shape[0] * np.sum( (y-f(x)))**2 )

# 學習率
ETA = 1e-3

# 誤差
diff = 1

# 更新次數
count = 0

# 均方誤差的歷史資料
errors = []

# 重複學習
errors.append( MSE(X, train_y) )
while diff > 1e-2:
    # 更新參數
    theta = theta - ETA * np.dot(f(X) - train_y, X)

    # 計算誤差
    errors.append( MSE(X, train_y) )
    diff = errors[-2] - errors[-1]

    # log
    count += 1
    log = '{}次: theta = {}, 誤差 = {:.4f}'
    print(log.format(count, theta, diff))

# 繪製重複次數 與誤差的關係
x = np.arange(len(errors))
plt.plot(x, errors)
plt.show()


隨機梯度下降法


隨機選擇一項學習資料,套用在參數的更新上,例如選擇第 k 項。


\(𝜃_j := 𝜃_j - 𝜂 ( f_𝜃(x^{(k)} )-y^{(k)} )x_j^{(k)}\)


import numpy as np
import matplotlib.pyplot as plt

# 載入學習資料
train = np.loadtxt('click.csv', delimiter=',', dtype='int', skiprows=1)
train_x = train[:,0]
train_y = train[:,1]

# 標準化
mu = train_x.mean()
sigma = train_x.std()
def standardize(x):
    return (x - mu) / sigma

train_z = standardize(train_x)

# 任意選擇初始值
theta = np.random.rand(3)

# 學習資料轉換為矩陣
def to_matrix(x):
    return np.vstack([np.ones(x.size), x, x ** 2]).T

X = to_matrix(train_z)

# 預測函數
def f(x):
    return np.dot(x, theta)

# 均方差
def MSE(x, y):
    return (1 / x.shape[0]) * np.sum((y - f(x)) ** 2)

# 學習率
ETA = 1e-3

# 誤差
diff = 1

# 更新次數
count = 0

# 重複學習
error = MSE(X, train_y)
while diff > 1e-2:
    # 排列學習資料所需的隨機排列
    p = np.random.permutation(X.shape[0])
    # 將學習資料以隨機方式取出,並用隨機梯度下降法 更新參數
    for x, y in zip(X[p,:], train_y[p]):
        theta = theta - ETA * (f(x) - y) * x

    # 計算跟前一個誤差的差距
    current_error = MSE(X, train_y)
    diff = error - current_error
    error = current_error

    # log
    count += 1
    log = '{}回目: theta = {}, 差分 = {:.4f}'
    print(log.format(count, theta, diff))

# 列印結果
x = np.linspace(-3, 3, 100)
plt.plot(train_z, train_y, 'o')
plt.plot(x, f(to_matrix(x)))
plt.show()

多元迴歸


如果要處理多元迴歸,就跟多項式迴歸一樣改用矩陣,但在多元迴歸中要注意,要對所有變數 \(x_1, x_2, x_3\)都進行標準化。


\(z_1^{(i)} = \frac{x_1^{(i)} - 𝜇_1}{𝜎_1} \)


\(z_2^{(i)} = \frac{x_2^{(i)} - 𝜇_2}{𝜎_2} \)


\(z_3^{(i)} = \frac{x_3^{(i)} - 𝜇_3}{𝜎_3} \)


分類(感知器)


使用 images1.csv 資料


x1,x2,y
153,432,-1
220,262,-1
118,214,-1
474,384,1
485,411,1
233,430,-1
...

先將原始資料標記在圖表上,y=1 用圓圈,y=-1 用


import numpy as np
import matplotlib.pyplot as plt

# 載入學習資料
train = np.loadtxt('images1.csv', delimiter=',', skiprows=1)
train_x = train[:,0:2]
train_y = train[:,2]

# 繪圖
x1 = np.arange(0, 500)
plt.plot(train_x[train_y ==  1, 0], train_x[train_y ==  1, 1], 'o')
plt.plot(train_x[train_y == -1, 0], train_x[train_y == -1, 1], 'x')
plt.savefig('1.png')


  • 識別函數 \(f_w(x)\) 就是給定向量 \(x\) 後,回傳 1 或 -1 的函數,用來判斷橫向或縱向。

\(f_w(x) = \left\{\begin{matrix} 1 \quad (w \cdot x \geq 0) \\ -1 \quad (w \cdot x < 0) \end{matrix}\right.\)


  • 權重更新式

\(w := \left\{\begin{matrix} w + y^{(i)}x^{(i)} \quad (f_w(x) \neq y^{(i)}) \\ w \quad \quad \quad \quad (f_w(x) = y^{(i)}) \end{matrix}\right.\)


感知器使用精度作為停止的標準比較好,但目前先直接設定訓練次數


最後繪製以權重向量為法線的直線方程式


\(w \cdot x = w_1x_1 + w_2x_2 = 0​\)


\(x_2 = - \frac{w_1}{w2} x_1​\)


import numpy as np
import matplotlib.pyplot as plt

# 載入學習資料
train = np.loadtxt('images1.csv', delimiter=',', skiprows=1)
train_x = train[:,0:2]
train_y = train[:,2]

# 任意初始值
w = np.random.rand(2)

# 識別函數,判斷矩形是橫向或縱向
def f(x):
    if np.dot(w, x) >= 0:
        return 1
    else:
        return -1

# 重複次數
epoch = 10

# 更新次數
count = 0

# 學習權重
for _ in range(epoch):
    for x, y in zip(train_x, train_y):
        if f(x) != y:
            w = w + y * x

            # log
            count += 1
            print('{}次數: w = {}'.format(count, w))

# 繪圖
x1 = np.arange(0, 500)
plt.plot(train_x[train_y ==  1, 0], train_x[train_y ==  1, 1], 'o')
plt.plot(train_x[train_y == -1, 0], train_x[train_y == -1, 1], 'x')
plt.plot(x1, -w[0] / w[1] * x1, linestyle='dashed')
plt.savefig("1.png")


驗證


python -i classification1_perceptron.py
>>> f([200,100])
1
>>> f([100,200])
-1

分類(邏輯迴歸)


邏輯迴歸要先修改學習資料,橫向為 1 ,縱向為 0


x1,x2,y
153,432,0
220,262,0
118,214,0
474,384,1
485,411,1
...

預測函數就是 S 函數


\(f_𝜃(x) = \frac{1}{1 + exp(-𝜃^Tx)}\)


參數更新式為


\(𝜃_j := 𝜃_j - 𝜂 \sum_{i=1}^{n}( f_𝜃(x^{(i)}) - y^{(i)} )x_j^{(i)}\)


可用矩陣處理,轉換時要加上 \(x_0\),且設定為 1,如果當 \(f_𝜃(x) \geq 0.5\),也就是 \(𝜃^T x >0​\) ,就判定為橫向。


將 \(Q^Tx = 0 \) 整理後,就可得到一條直線


\(Q^Tx = 𝜃_0x_0 + 𝜃_1x_1 + 𝜃_2x_2 = 𝜃_0 +𝜃_1x_1+𝜃_2x_2 =0\)


\(x_2 = - \frac{𝜃_0 + 𝜃_1x_2}{𝜃_2}​\)


import numpy as np
import matplotlib.pyplot as plt

# 載入學習資料
train = np.loadtxt('images2.csv', delimiter=',', skiprows=1)
train_x = train[:,0:2]
train_y = train[:,2]

# 任意初始值
theta = np.random.rand(3)

# 以平均及標準差進行標準化
mu = train_x.mean(axis=0)
sigma = train_x.std(axis=0)
def standardize(x):
    return (x - mu) / sigma

train_z = standardize(train_x)

# 轉換為矩陣,加上 x0
def to_matrix(x):
    x0 = np.ones([x.shape[0], 1])
    return np.hstack([x0, x])

X = to_matrix(train_z)

# 預測函數 S函數
def f(x):
    return 1 / (1 + np.exp(-np.dot(x, theta)))

# 識別函數
def classify(x):
    return (f(x) >= 0.5).astype(np.int)

# 學習率
ETA = 1e-3

# 重複次數
epoch = 5000

# 更新次數
count = 0

# 重複學習
for _ in range(epoch):
    theta = theta - ETA * np.dot(f(X) - train_y, X)

    # log
    count += 1
    print('{}次數: theta = {}'.format(count, theta))

# 繪製圖形
x0 = np.linspace(-2, 2, 100)
plt.plot(train_z[train_y == 1, 0], train_z[train_y == 1, 1], 'o')
plt.plot(train_z[train_y == 0, 0], train_z[train_y == 0, 1], 'x')
plt.plot(x0, -(theta[0] + theta[1] * x0) / theta[2], linestyle='dashed')
# plt.show()
plt.savefig("機器學習4_coding_.png")


驗證


這樣的意思是 200x100 的矩形有 91.6% 的機率會是橫向


>>> f(to_matrix(standardize([[200,100], [100,200]])))
array([0.91604483, 0.03009514])

可再轉化為 1 與 0


>>> classify(to_matrix(standardize([[200,100], [100,200]])))
array([1, 0])

線性不可分離的分類


學習資料為 data3.csv


x1,x2,y
0.54508775,2.34541183,0
0.32769134,13.43066561,0
4.42748117,14.74150395,0
2.98189041,-1.81818172,1
4.02286274,8.90695686,1
2.26722613,-6.61287392,1
-2.66447221,5.05453871,1
-1.03482441,-1.95643469,1
4.06331548,1.70892541,1
2.89053966,6.07174283,0
2.26929206,10.59789814,0
4.68096051,13.01153161,1
1.27884366,-9.83826738,1
-0.1485496,12.99605136,0
-0.65113893,10.59417745,0
3.69145079,3.25209182,1
-0.63429623,11.6135625,0
0.17589959,5.84139826,0
0.98204409,-9.41271559,1
-0.11094911,6.27900499,0

先將學習資料繪製到圖表上看起來無法用一條直線來分類,增加 \(x_1^2​\) 進行分類



參數變成四個,將 \(Q^Tx = 0 ​\) 整理後,就可得到一條曲線


\(Q^Tx = 𝜃_0x_0 + 𝜃_1x_1 + 𝜃_2x_2 +𝜃_3x_1^2 = 𝜃_0 +𝜃_1x_1+𝜃_2x_2 +𝜃_3x_1^2 =0​\)


\(x_2 = - \frac{𝜃_0 + 𝜃_1x_2 +𝜃_3x_1^2}{𝜃_2}\)


import numpy as np
import matplotlib.pyplot as plt

# 載入學習資料
train = np.loadtxt('data3.csv', delimiter=',', skiprows=1)
train_x = train[:,0:2]
train_y = train[:,2]

# 任意初始值
theta = np.random.rand(4)

# 標準化
mu = train_x.mean(axis=0)
sigma = train_x.std(axis=0)
def standardize(x):
    return (x - mu) / sigma

train_z = standardize(train_x)

# 轉換為矩陣,加上 x0, x3
def to_matrix(x):
    x0 = np.ones([x.shape[0], 1])
    x3 = x[:,0,np.newaxis] ** 2
    return np.hstack([x0, x, x3])

X = to_matrix(train_z)

# 預測函數 S函數
def f(x):
    return 1 / (1 + np.exp(-np.dot(x, theta)))

# 識別函數
def classify(x):
    return (f(x) >= 0.5).astype(np.int)

# 學習率
ETA = 1e-3

# 重複次數
epoch = 5000

# 更新次數
count = 0

# 重複學習
for _ in range(epoch):
    theta = theta - ETA * np.dot(f(X) - train_y, X)

    # log
    count += 1
    print('{}次數: theta = {}'.format(count, theta))

# 繪製圖形
x1 = np.linspace(-2, 2, 100)
x2 = -(theta[0] + theta[1] * x1 + theta[3] * x1 ** 2) / theta[2]
plt.plot(train_z[train_y == 1, 0], train_z[train_y == 1, 1], 'o')
plt.plot(train_z[train_y == 0, 0], train_z[train_y == 0, 1], 'x')
plt.plot(x1, x2, linestyle='dashed')
# plt.show()
plt.savefig("機器學習4_coding_.png")


分類的精度,就是在全部的資料中,能夠被正確分類的 TP與 TN 佔的比例,可表示為


\( Accuracy = \frac{TP + TN}{TP+FP+FN+TN} ​\)


# 精度
accuracies = []

# 重複學習
for _ in range(epoch):
    theta = theta - ETA * np.dot(f(X) - train_y, X)

    # 計算精度
    result = classify(X) == train_y
    accuracy = len(result[result ==True]) / len(result)
    accuracies.append(accuracy)

# 繪製圖形
x = np.arange(len(accuracies))
plt.plot(x, accuracies)
plt.savefig("機器學習4_coding_.png")

計算精度,繪製圖表



隨機梯度下降法


import numpy as np
import matplotlib.pyplot as plt

# 載入學習資料
train = np.loadtxt('data3.csv', delimiter=',', skiprows=1)
train_x = train[:,0:2]
train_y = train[:,2]

# 任意初始值
theta = np.random.rand(4)

# 標準化
mu = train_x.mean(axis=0)
sigma = train_x.std(axis=0)
def standardize(x):
    return (x - mu) / sigma

train_z = standardize(train_x)

# 轉換為矩陣,加上 x0, x3
def to_matrix(x):
    x0 = np.ones([x.shape[0], 1])
    x3 = x[:,0,np.newaxis] ** 2
    return np.hstack([x0, x, x3])

X = to_matrix(train_z)

# 預測函數 S函數
def f(x):
    return 1 / (1 + np.exp(-np.dot(x, theta)))

# 識別函數
def classify(x):
    return (f(x) >= 0.5).astype(np.int)

# 學習率
ETA = 1e-3

# 重複次數
epoch = 5000

# 更新次數
count = 0

# 重複學習
for _ in range(epoch):
    # 以隨機梯度下降法更新參數
    p = np.random.permutation(X.shape[0])
    for x, y in zip(X[p,:], train_y[p]):
        theta = theta - ETA * (f(x) - y) * x

    # log
    count += 1
    print('{}次數: theta = {}'.format(count, theta))

# 繪製圖形
x1 = np.linspace(-2, 2, 100)
x2 = -(theta[0] + theta[1] * x1 + theta[3] * x1 ** 2) / theta[2]
plt.plot(train_z[train_y == 1, 0], train_z[train_y == 1, 1], 'o')
plt.plot(train_z[train_y == 0, 0], train_z[train_y == 0, 1], 'x')
plt.plot(x1, x2, linestyle='dashed')
# plt.show()
plt.savefig("機器學習4_coding_.png")


正規化


首先考慮這樣的函數


\(g(x) = 0.1(x^3 + x^2 + x)\)


產生一些雜訊的學習資料,並繪製圖表



import numpy as np
import matplotlib.pyplot as plt

# 原始真正的函數
def g(x):
    return 0.1 * (x ** 3 + x ** 2 + x)

# 適當地利用原本的函數,加上一些雜訊,產生學習資料
train_x = np.linspace(-2, 2, 8)
train_y = g(train_x) + np.random.randn(train_x.size) * 0.05


plt.clf()
x=np.linspace(-2, 2, 100)
plt.plot(train_x, train_y, 'o')
plt.plot(x, g(x), linestyle='dashed')
plt.ylim(-1,2)
plt.savefig("機器學習4_coding_1.png")


# 標準化
mu = train_x.mean()
sigma = train_x.std()
def standardize(x):
    return (x - mu) / sigma

train_z = standardize(train_x)

# 產生學習資料的矩陣 (10次多項式)
def to_matrix(x):
    return np.vstack([
        np.ones(x.size),
        x,
        x ** 2,
        x ** 3,
        x ** 4,
        x ** 5,
        x ** 6,
        x ** 7,
        x ** 8,
        x ** 9,
        x ** 10
    ]).T

X = to_matrix(train_z)

# 參數使用任意初始值
theta = np.random.randn(X.shape[1])

# 預測函數
def f(x):
    return np.dot(x, theta)

# 目標函數
def E(x, y):
    return 0.5 * np.sum((y - f(x)) ** 2)

# 正規化常數
LAMBDA = 0.5

# 學習率
ETA = 1e-4

# 誤差
diff = 1

# 重複學習
error = E(X, train_y)
while diff > 1e-6:
    theta = theta - ETA * (np.dot(f(X) - train_y, X))

    current_error = E(X, train_y)
    diff = error - current_error
    error = current_error

theta1 = theta

# 加上正規化項
theta = np.random.randn(X.shape[1])
diff = 1
error = E(X, train_y)
while diff > 1e-6:
    # 正規化項,因為偏差項不適用於正規化,所以為 0,當 j>0,正規化項為 𝜆 * 𝜃
    reg_term = LAMBDA * np.hstack([0, theta[1:]])
    # 適用於正規化項,更新參數
    theta = theta - ETA * (np.dot(f(X) - train_y, X) + reg_term)

    current_error = E(X, train_y)
    diff = error - current_error
    error = current_error

theta2 = theta

# 繪製圖表
plt.clf()
plt.plot(train_z, train_y, 'o')
z = standardize(np.linspace(-2, 2, 100))
theta = theta1 # 無正規化的結果,虛線
plt.plot(z, f(to_matrix(z)), linestyle='dashed')
theta = theta2 # 有正規化的結果,實線
plt.plot(z, f(to_matrix(z)))
# plt.show()
plt.savefig("機器學習4_coding_2.png")


References


練好機器學習的基本功 範例下載

2019/9/8

機器學習_評估


確認模型的正確性,針對建立的模型,以評估的方法進行機器學習。


迴歸與分類都是定義預測時所需要的函數 \(f_𝜃(x)\),藉由學習資料來找到函數中的 𝜃。方法是將目標函數進行微分,求得參數更新式。但實際上需要的,是預測函數得到的預測值,例如花多少廣告費可得多少點擊率。


我們需要量測函數 \(f_𝜃(x)\) 的正確性(精確度),但像多元迴歸,無法用圖形表示,就需要將機器學習的模型的精度,以定量的方式表示,然後表現其精確度,這就是模型評估。因為參數是透過學習資料修正而來的,對於原本的學習資料來說,參數是正確的,但對於新的資料就不一定了。


交叉驗證 Cross Validation


將學習資料區分為學習以及測試使用,用測試用的資料評估模型,一般來說,學習用的資料會比較多。


在回歸問題中,函數 \(f_𝜃(x)\) 是透過習資料修正而來的,對於原本的學習資料來說,參數是正確的,但對於測試部分的資料就不一定正確了。


假設測試資料有 n 筆,將測試用的資料,透過模型求得結果,再跟原本的實際值比較得到誤差。以下為均方差 MSE (Mean Square Error)


\( \frac{1}{n} \sum_{i=1}^{n} ( y^{(i)} - f_𝜃(x^{(i)} ) )^2 ​\)


當均方差越小,表示模型的精度很高。


分類問題的驗證


因迴歸問題是連續值,可用誤差進行驗證,分類問題是邏輯迴歸,回到矩形是橫向或縱向的問題上,會有四種狀況。


分類結果 原本是橫向 原本是縱向
橫向 正確 錯誤
縱向 錯誤 正確

可將二元分類轉換為這樣的表格


分類結果 + -
+ Positive True Positive (TP) False Positive (FP)
- Negative False Negative (FN) True Negative (TN)

分類的精度,就是在全部的資料中,能夠被正確分類的 TP與 TN 佔的比例,可表示為


\( Accuracy = \frac{TP + TN}{TP+FP+FN+TN} ​\)


ex: 100 筆測試資料,有 80 筆正確


\( Accuracy = \frac{80}{100} = 0.8\)


精確率與回現率


有時候只用精確度評估分類結果會遇到問題。例如當原始資料有大量資料 95 筆為 Negative,只有一點點資料 5 筆為 Positive,如果將全部測試資料都分類為 False 的模型,Accuracy 為 0.95,精確度很高但實際上這是一個錯誤的預測模型。


因此要導入其他評估的指標。


  • 精確率 Precision: 分類為 Positive 的資料中,實際為 Positive 的資料數的比例。值越高,代表分類錯誤的越少。


    \( Precision = \frac{TP}{TP+FP}\)

  • 回現率 Recall: Positive 資料中,實際上分類為 Positive 的資料數。值越高,代表沒有被遺漏,且被正確分類的比例。


    \( Recall = \frac{TP}{TP+FN} \)


通常精確率與回現率,只要有ㄧ個是高的,另一個就會變低。


舉例來說,


資料 個數
Positive 5
Negative 95
評估結果
True Positive 1
False Positive 2
False Negative 4
True Negative 93
精確度 Accuracy 94%
精確率 Precision \(\frac{1}{1+2} = 0.333\)
回現率 Recall \(\frac{1}{1+4} = 0.2\)

F 值 (Fmeasure)


通常精確率與回現率,只要有ㄧ個是高的,另一個就會變低。但直接將兩個平均,也不是好的指標


模型 精確率 回現率 平均
A 0.6 0.39 0.495
B 0.02 1.0 0.51

B 模型是將全部的資料都分類為 Positive,但因為 Negative 也分類為 Positive,所以精確率很低,實際上,B 不是一個好的模型。


Fmeasure 定義如下,只要 Precision 或 Recall 其中一項變低,就會影響到 Fmeasure


\( Fmeasure = \frac{2}{ \frac{1}{Precision} + \frac{1}{Recall} } = \frac{2 \cdot Precision \cdot Recall}{Precision + Recall}\)


模型 精確率 回現率 平均 Fmeasure
A 0.6 0.39 0.495 0.472
B 0.02 1.0 0.51 0.039

F值有時被稱為 F1 值




F值可再加上權重


\(WeightedFmeasure = \frac{ (1+𝛽)^2 \cdot Precision \cdot Recall }{ 𝛽^2 \cdot Precision + Recall }\)


將權重設定為 1 就是原本的 F 值,也就是 F1值




剛剛都是以 TP 為主,考慮精確率與回現率。如果以 TN 為主


\( Precision = \frac{TN}{TN+FN} \)


\( Recall = \frac{TN}{TN+FP} ​\)


當測試資料 Positive 的部分較少,就用 Positive 的 Precision, Recall 來評估。




交叉驗證中,以 K 等分交叉驗證最常見


  • 將學習資料分為 K 筆
  • K-1 筆作為學習用的資料,1 筆作為測試資料
  • 將學習用資料與測試資料,一邊交換,一邊驗證,重複 K 次交叉驗證
  • 最後計算 K 筆精度平均值,視為最終的精度

例如 4 等分



正規化


過適 Overfitting


只跟學習資料吻合的狀態就是 overfitting。如果迴歸中 \(f_𝜃(x)\) 的次方數過度增加,就會造成 overfitting。分類有一樣的問題。


為了避免 overfitting,有以下對應方式


  • 增加學習資料的數量
  • 將模型簡化為較簡單的形式
  • 正規化

正規化


在迴歸分析中的誤差函數為


\({E(𝜃)= \frac{1}{2} \sum_{i=1}^{n}( y^{(i)} - f_𝜃(x^{(i)})^2 }​\)


對該目標函數,再增加正規化的項目 \(R(𝜃) = \frac{𝜆}{2} \sum_{j=1}^{m} 𝜃_j^2​\)


\({E(𝜃)= \frac{1}{2} \sum_{i=1}^{n}( y^{(i)} - f_𝜃(x^{(i)})^2 } + R(𝜃)​\)


對新的目標函數進行最小化,就是正規化


m 是參數的個數,通常對於 \(𝜃_0\) 來說,無法做正規化,只能從 j=1 開始。例如 \(f_𝜃(x) = 𝜃_0 + 𝜃_1x + 𝜃_2x^2\) 中,m 為 2,正規化的參數對象為 \(𝜃_1, 𝜃_2\)。\(𝜃_0\) 被稱為 bias 項


𝜆 是決定對於正規化項的影響為正的常數,要自己決定用什麼值。


正規化的效果


先將目標函數分為兩項


\( C(𝜃)= \frac{1}{2} \sum_{i=1}^{n}( y^{(i)} - f_𝜃(x^{(i)})^2 \)


\(R(𝜃) = \frac{𝜆}{2} \sum_{j=1}^{m} 𝜃_j^2​\)


因為 C(𝜃) 假設是任意一個曲線,R(𝜃) 是二次函數,任意假設它為 \( \frac{1}{2} 𝜃_1^2​\) ,是通過原點的二次函數



正規化後,\(𝜃_1\) 最小值會往原點靠近




\( f_𝜃(x) = 𝜃_0 + 𝜃_1x+ 𝜃_2x^2 \) 是二次曲線,但如果 \( 𝜃_2 \) 為 0,變成一次直線,就簡化了模型


𝜆 的大小,決定正規化的影響,如果 \( 𝜆=0 \) 就等於沒有用到正規化


分類的正規化


分類問題是用對數似然函數


\(\log L(𝜃) = \log \prod _{i=1}^{n} P( y^{(i)} = 1|x^{(i)} )^{y^{(i)}} P( y^{(i)} = 0|x^{(i)} )^{1-y^{(i)}}​\)


正規化就是再加上 R(𝜃) 的部分,另外因為對數似然指數,原本是要最大化,為了轉換為最小化,加上負號


\(\log L(𝜃) = - \log \prod _{i=1}^{n} P( y^{(i)} = 1|x^{(i)} )^{y^{(i)}} P( y^{(i)} = 0|x^{(i)} )^{1-y^{(i)}} + \frac{𝜆}{2} \sum_{j=1}^{m}𝜃_j^2\)


正規化後的微分


因為 \( E(𝜃) = C(𝜃) + R(𝜃) ​\) ,就分別對 C(𝜃), R(𝜃) 進行偏微分


\(\frac{𝜕C(𝜃)}{𝜕𝜃_j} = \sum_{i=1}^{n}( f_𝜃(x^{(i)} )-y^{(i)} )x_j^{(i)} ​\)


\(R(𝜃) = \frac{𝜆}{2} \sum_{j=1}^{m} 𝜃_j^2 = \frac{𝜆}{2}𝜃_1^2 + \frac{𝜆}{2}𝜃_2^2 + \cdots + \frac{𝜆}{2}𝜃_m^2 \)


\( \frac{𝜕R(𝜃)}{𝜕𝜃_j} = 𝜆𝜃_j​\)


因此參數更新式就改為


\(𝜃_j := 𝜃_j - 𝜂 ( \sum_{i=1}^{n}( f_𝜃(x^{(i)} )-y^{(i)} )x_j^{(i)} + 𝜆𝜃_j )\)


關於 \(𝜃_0​\) 的部分,無法處理正規化,因為 R(𝜃) 以 \(𝜃_0​\) 微分後變成 0




邏輯迴歸是類似的過程


E(𝜃) = C(𝜃) + R(𝜃)


\(\log L(𝜃) = - \log \prod _{i=1}^{n} P( y^{(i)} = 1|x^{(i)} )^{y^{(i)}} P( y^{(i)} = 0|x^{(i)} )^{1-y^{(i)}} + \frac{𝜆}{2} \sum_{j=1}^{m}𝜃_j^2\)


\(R(𝜃) = \frac{𝜆}{2} \sum_{j=1}^{m} 𝜃_j^2\)


微分後,因為 R(𝜃) 以 \(𝜃_0​\) 微分後變成 0


\(𝜃_0:= 𝜃_0 - 𝜂 ( \sum_{i=1}^{n}( f_𝜃(x^{(i)}) - y^{(i)} )x_j^{(i)} ) ​\)


\(𝜃_j := 𝜃_j - 𝜂 ( \sum_{i=1}^{n}( f_𝜃(x^{(i)}) - y^{(i)} )x_j^{(i)} + 𝜆𝜃_j ) \quad\quad\quad j > 0 \)




正規化不是只有一種,目前都是 L2 正規化。 L1正規化,用在判斷非必要的參數,會變成0,用來減少變數的數量。L1 正規化是要減少不必要的變數,L2 正規化是要抑制變數的影響。


學習曲線


乏適 Underfitting


跟 Overfitting 相反的狀況是 Underfitting,也就是找不到適合學習資料的模型。


光只有查看精度,沒辦法判斷是 overfitting 或是 underfitting


以這個圖形為例,看起來學習資料是二次曲線,很難找到一條一次函數的直線,適合這些學習資料,隨著學習資料變多,模型的精度會一直下降



使用學習資料數量較少的模型,去預測未知資料會比較困難,精度會比較低。如果學習資料越多,精度會增加。


將學習用資料與測試用資料數量,及精度的對照,繪製圖表。


如出現高偏差的狀況,對學習資料數量增加,但精度降低,測試資料增加,精度會增加,且兩個精度會越來越接近。這種狀況就是 underfitting



如果發現有高方差的狀況,就是對學習資料數量增加,但精度一直維持很高,測試資料增加,卻沒辦法增加精度。這種狀況就是 overfitting



這種將資料個數與精度繪製的圖表,就稱為學習曲線。


References


練好機器學習的基本功:用Python進行基礎數學理論的實作

2019/8/25

機器學習_分類

如果要判斷矩形是橫向或是縱向,通常可以直接看出結果。


可以將所有矩形左下角跟原點對齊,以右上角的座標位置,來判斷橫或縱向。分類就是要找出一條分類的線,將兩種矩形分類。



這條線,是將權重向量 (weight) 視為法線的直線,因為互相垂直,所以內積為 0。


\(w \cdot x = \sum_{i=1}^{n} w_ix_i = 0\)


如果有兩個維度,且 \(weight = (1,1)\)


\(w \cdot x = w_1x_1 + w_2x_2 = 1 \cdot x_1 + 1 \cdot x_2 = x_1 + x_2 = 0\)



內積也可以用另一個式子


\(w \cdot x = |w| \cdot |x| \cdot cos𝜃​\)


因為內積為 0,表示 \(cos𝜃 =0\),也就是 \(𝜃=90^o \) 或 \(𝜃=270^o\)


一般來說是要用機器學習,找出權重向量,得到跟該向量垂直的直線,再透過該直線進行分類。


感知器


感知器是一個可以接受多個輸入,並對每一個值,乘上權重再加總,輸出得到的結果的模型。



  • 準備機器學習的資料

矩形大小 形狀 \(x_1\) \(x_2\) \(y\)
80 x 150 縱向 80 150 -1
60 x 110 縱向 60 110 -1
160 x 50 橫向 160 50 1
125 x 30 橫向 125 30 1

識別函數 \(f_w(x)​\) 就是給定向量 \(x​\) 後,回傳 1 或 -1 的函數,用來判斷橫向或縱向。


\(f_w(x) = \left\{\begin{matrix} 1 \quad (w \cdot x \geq 0) \\ -1 \quad (w \cdot x < 0) \end{matrix}\right.​\)


回到 \(w \cdot x = |w| \cdot |x| \cdot cos𝜃\) 這個式子,如果內積為負數,那麼表示 \( 90^o < 𝜃 < 270^o\)


內積是用來表示向量之間相似程度,如果是正數,就是相似,0 是直角,負數代表不相似。


  • 權重更新式

\(w := \left\{\begin{matrix} w + y^{(i)}x^{(i)} \quad (f_w(x) \neq y^{(i)}) \\ w \quad \quad \quad \quad (f_w(x) = y^{(i)}) \end{matrix}\right.​\)


i 是學習資料的索引,這個權重更新是針對所有學習資料重複處理,用來更新權重向量。上面的部分,意思是藉由識別函數進行分類失敗時,才要去更新權重向量。


權重向量是用隨機的值初始化的



因為一開始識別函數的結果為 -1,而學習資料 \(y^{(1)} =1\),所以要更新權重向量


\( w + y^{(1)}x^{(1)} = w + x^{(1)}​\)



運用向量加法得到新的 w,而新的 w 的法線,會讓 (125, 30) 跟 w 向量在同一側。


線性可分離


感知器有個缺點,只能用來解決線性可分離的問題,以下這樣的問題,無法用一條線去分類。



所以如果要處理圖片分類,就沒辦法以線性的方式處理。


上面例子是單層的感知器,多層感知器,就是神經網路。


另外有一種方法,可用在線性不可分離的問題上:邏輯迴歸 Logistic Regression


邏輯迴歸 Logistic Regression


這種方法是將分類用機率來思考。以一開始矩形橫向或縱向的例子,這邊假設橫向為 1,縱向為 0。


S 型函數


前面的回歸有提到 \(f_𝜃(x) = 𝜃^T x​\) 這個函數,可用最速下降法或隨機梯度下降法學習 𝜃,然後用 𝜃 求得未知資料 x 的輸出值。


這邊需要的函數為,其中 \(exp(-𝜃^Tx) = e^{-𝜃^T x}\) , e 為自然對數的底數 Euler's number (2.71828)


\(f_𝜃(x) = \frac{1}{1 + exp(-𝜃^Tx)}​\)


會稱為 S 函數是因為如果將 \(𝜃^Tx\) 設為橫軸,\(f_𝜃(x)\) 設定為縱軸,會出現這樣的圖形



S函數的特徵是 當 \(𝜃^Tx = 0​\) 會得到 \(f_𝜃(x) = 0.5​\),且 \(0< f_𝜃(x) <1 ​\)


當作機率處理的原因是 \(0< f_𝜃(x) <1 \)


決策邊界


將未知資料 x 屬於橫向的機率設為 \(f_𝜃(x)\) ,用條件機率的方式描述為


\(P( y = 1 | x) = f_𝜃(x)​\)


當給予資料 x 時,y=1的機率為 \(f_𝜃(x)\) 。如果計算後的結果,機率為 0.7,就表示矩形為橫向的機率為 0.7。以 0.5 為閥值,判斷是不是橫向。


\(y = \left\{\begin{matrix} 1 \quad (f_𝜃(x) \geq 0.5) \\ 0 \quad (f_𝜃(x) < 0.5) \end{matrix}\right.​\)


回頭看 S 函數,當 \(f_𝜃(x) \geq 0.5​\),也就是 \(𝜃^T x >0​\) ,就判定為橫向。可將判斷式改為


\(y = \left\{\begin{matrix} 1 \quad (𝜃^T x \geq 0) \\ 0 \quad (𝜃^T x < 0) \end{matrix}\right.\)




任意選擇一個 𝜃 為例子,\(x_1\) 是橫長, \(x_2\) 是高


\(𝜃= \left[ \begin{matrix} 𝜃_0 \\ 𝜃_1 \\𝜃_2 \end{matrix} \right] = \left[ \begin{matrix} -100 \\ 2 \\ 1 \end{matrix} \right] , x= \left[ \begin{matrix} 1 \\ x_1 \\x_2 \end{matrix} \right] ​\)


\(𝜃^T x = -100 \cdot 1 +2x_1 +x_2 \geq 0 \)


\( x_2 \geq -2x_1 +100 \) 就表示分類為橫向


以圖形表示



以 \(𝜃^Tx =0\) 這條直線為邊界線,就能區分橫向或縱向,這條線就是決策邊界。但實際上,這個任意選擇的 𝜃 並不能正確的進行分類,因此為了求得正確的 𝜃 ,就要定義目標函數,進行微分,以求得正確的參數 𝜃 ,這個方法就稱為邏輯迴歸。


似然函數


現在要找到 𝜃 的更新式


一開始將 x 為橫向的機率 \(P( y = 1 | x) ​\) 定義為 \(f_𝜃(x)​\) ,根據這個定義,學習資料 \(y​\) 跟 \(f_𝜃(x)​\) 的關係,最佳的狀況是 \(y=1, f_𝜃(x)=1​\) , \(y=0, f_𝜃(x)=0​\) ,但還要改寫為


  • \( y=1 ​\) 時,要讓機率 \(P( y = 1 | x) ​\) 是最大,判定為橫向
  • \( y=0 ​\) 時,要讓機率 \(P( y = 0| x) ​\) 是最大,判定為縱向

矩形大小 形狀 \(y\) 機率
80 x 150 縱向 0 要讓機率 $P( y = 0
60 x 110 縱向 0 要讓機率 $P( y = 0
160 x 50 橫向 1 要讓機率 $P( y = 1
125 x 30 橫向 1 要讓機率 $P( y = 1

因為所有學習資料互相獨立沒有關聯,整體的機率就是全部的機率相乘


\(L(𝜃) = P( y^{(1)} = 0|x^{(1)} ) P( y^{(2)} = 0|x^{(2)} ) P( y^{(3)} = 1|x^{(3)} ) P( y^{(4)} = 1|x^{(4)} ) \)


這個式子可改寫為


\(L(𝜃) = \prod _{i=1}^{n} P( y^{(i)} = 1|x^{(i)} )^{y^{(i)}} P( y^{(i)} = 0|x^{(i)} )^{1-y^{(i)}}​\)




如果假設 \(y^{(i)} = 1​\)


\(P( y^{(i)} = 1|x^{(i)} )^1 P( y^{(i)} = 0|x^{(i)} )^0 = P( y^{(i)} = 1|x^{(i)} )​\)


如果假設 \(y^{(i)} =0 \)


\(P( y^{(i)} = 1|x^{(i)} )^0 P( y^{(i)} = 0|x^{(i)} )^1 = P( y^{(i)} = 0|x^{(i)} )​\)




目標函數 \(L(𝜃)\) 就稱為似然函數 Likelihood,就是要找到讓 \(L(𝜃)\) 最大的參數 𝜃


對數似然函數


因為機率都小於 1,機率的乘積會不斷變小,在程式設計會產生精確度的問題。所以加上 log


\(\log L(𝜃) = \log \prod _{i=1}^{n} P( y^{(i)} = 1|x^{(i)} )^{y^{(i)}} P( y^{(i)} = 0|x^{(i)} )^{1-y^{(i)}}​\)


因為 log 是單調遞增函數,因此不會影響到結果。換句話說,要讓 \(L(𝜃)​\) 最大化,跟要讓 \(logL(𝜃)​\) 最大化是一樣的。


\(\begin{equation}
\begin{split}
\log L(𝜃) &= \log \prod _{i=1}^{n} P( y^{(i)} = 1|x^{(i)} )^{y^{(i)}} P( y^{(i)} = 0|x^{(i)} )^{1-y^{(i)}}\\
&=\sum_{i=1}^{n}( \log P( y^{(i)} = 1|x^{(i)} )^{y^{(i)}} + log P( y^{(i)} = 0|x^{(i)} )^{1-y^{(i)}} ) \\
&=\sum_{i=1}^{n}( {y^{(i)}} \log P( y^{(i)} = 1|x^{(i)} ) + ({1-y^{(i)}}) \log P( y^{(i)} = 0|x^{(i)} ) ) \\
&=\sum_{i=1}^{n}( {y^{(i)}} \log P( y^{(i)} = 1|x^{(i)} ) + ({1-y^{(i)}}) \log (1-P( y^{(i)} = 1|x^{(i)} )) ) \\
&=\sum_{i=1}^{n}( {y^{(i)}} \log f_𝜃( x^{(i)} ) + ({1-y^{(i)}}) \log (1- f_𝜃(x^{(i)} )) ) \\
\end{split}
\end{equation}​\)


  • \(\log(ab) = \log a + \log b\)
  • \(\log a^b = b \log a​\)
  • 因為只考慮 \(y=1\) 或 \(y=0\),所以 \(P( y^{(i)} = 0|x^{(i)} ) = 1 - P( y^{(i)} = 1|x^{(i)} )\)

似然函數的微分


邏輯迴歸,就是將這個對數似然函數當作目標函數使用


\( \log L(𝜃) =\sum_{i=1}^{n}( {y^{(i)}} \log f_𝜃( x^{(i)} ) + ({1-y^{(i)}}) \log (1- f_𝜃(x^{(i)} )) ) ​\)


要將這個函數,個別針對參數 \(𝜃_j\) 進行偏微分


同樣利用合成函數的微分方法


\( u = \log L(𝜃)​\)


\(v = f_𝜃 (x) = \frac{1}{1 + exp(-𝜃^Tx)}​\)


然後


\(\frac{𝜕E}{𝜕𝜃_j} = \frac{𝜕u}{𝜕v} \frac{𝜕v}{𝜕𝜃_j}​\)




先計算第一項


因為 \(\log (v)\) 的微分是 \(\frac{1}{v}\)


而 \( \log (1-v)\) 的微分為


\( s =1-v ​\)


\( t = \log (s) \)


\( \frac{dt}{dv} = \frac{dt}{ds} \cdot \frac{ds}{dv} = \frac{1}{s} \cdot -1 = - \frac{1}{1-v} \)


所以


\( \frac{𝜕u}{𝜕v} = \frac{𝜕}{𝜕v} \sum_{i=1}^{n}( {y^{(i)}} \log (v) + ({1-y^{(i)}}) \log (1- v ) ) = \sum_{i=1}^{n} ( \frac{y^{(i)}}{v} - \frac{1- y^{(ii)} }{1-v} )\)




然後將 \(v\) 以 \(𝜃_j\) 微分


\(\frac{𝜕v}{𝜕𝜃_j} = \frac{𝜕}{𝜕𝜃_j} \frac{1}{ 1+ exp(-𝜃^Tx)} ​\)


因為 \(f_𝜃 (x)\) 是 S 型函數,且已知 S 型函數的微分為


\( \frac{d𝜎(x)}{dx} = 𝜎(x) (1-𝜎(x)) \)


利用合成函數的微分方法


\( z = 𝜃^T x\)


\(v = f_𝜃 (x) = \frac{1}{1 + exp(-z)}\)


\(\frac{𝜕v}{𝜕𝜃_j} = \frac{𝜕v}{𝜕z} \frac{𝜕z}{𝜕𝜃_j} ​\)


前面的部分


\( \frac{𝜕v}{𝜕z} = v(1-v) ​\)


後面的部分


\( \frac{𝜕z}{𝜕𝜃_j} = \frac{𝜕}{𝜕𝜃_j} 𝜃^Tx = \frac{𝜕}{𝜕𝜃_j} (𝜃_0x_0 +𝜃_1x_1 +\cdots + 𝜃_nx_n ) = x_j ​\)


所以


\(\frac{𝜕v}{𝜕𝜃_j} = \frac{𝜕v}{𝜕z} \frac{𝜕z}{𝜕𝜃_j} = v(1-v) x_j ​\)




\(\begin{equation}
\begin{split}
\frac{𝜕u}{𝜕𝜃_j} &= \frac{𝜕u}{𝜕v} \frac{𝜕v}{𝜕𝜃_j} \\
& = \sum_{i=1}^{n}( \frac{y^{(i)}}{v} - \frac {1 - y^{(i)}}{1-v} ) \cdot v(1-v) \cdot x_j^{(i)} \\
& = \sum_{i=1}^{n}( y^{(i)}(1-v) - (1-y^{(i)})v )x_j^{(i)} \\
& = \sum_{i=1}^{n}( y^{(i)} -v )x_j^{(i)} \\
& = \sum_{i=1}^{n}( y^{(i)} - f_𝜃(x^{(i)}) )x_j^{(i)} \\
\end{split}
\end{equation}​\)


先前最小化,是要往微分後的結果的正負符號相反方向移動。但現在要最大化,所以要往微分後的結果的正負符號相同方向移動。


\( 𝜃_j := 𝜃_j + 𝜂 \sum_{i=1}^{n}( y^{(i)} - f_𝜃(x^{(i)}) )x_j^{(i)} ​\)


也能配合多元迴歸,改寫成這樣


\(𝜃_j := 𝜃_j - 𝜂 \sum_{i=1}^{n}( f_𝜃(x^{(i)}) - y^{(i)} )x_j^{(i)}\)


線性不可分離


線性不可分離的問題不能直線,但可嘗試用曲線。


例如,將 \(x_1^2\) 加入學習資料


\(𝜃= \left[ \begin{matrix} 𝜃_0 \\ 𝜃_1 \\𝜃_2 \\𝜃_3 \end{matrix} \right] , x= \left[ \begin{matrix} 1 \\ x_1 \\x_2 \\x_1^2\end{matrix} \right] ​\)


然後


\( 𝜃^Tx = 𝜃_0+𝜃_1x_1+𝜃_2x_2+𝜃_3x_1^2 ​\)


假設


\(𝜃= \left[ \begin{matrix} 𝜃_0 \\ 𝜃_1 \\𝜃_2 \\𝜃_3 \end{matrix} \right] = \left[ \begin{matrix} 0 \\ 0 \\1 \\-1 \end{matrix} \right] \)


因為 \(𝜃^Tx \geq 0​\)


\( 𝜃^Tx = 𝜃_0+𝜃_1x_1+𝜃_2x_2+𝜃_3x_1^2 = x_2 - x_1^2 \geq 0 ​\)


得到方程式 \( x_2 \geq x_1^2 ​\)



現在的決策邊界變成曲線,因為參數 𝜃 是任意選擇的,所以資料沒有被正確地分類。


可以增加次方數,得到複雜的形狀的決策邊界。


另外還有 SVM (支援向量機) 的分類演算法,多元分類處理方法。


References


練好機器學習的基本功:用Python進行基礎數學理論的實作

2019/8/18

機器學習_線性迴歸


機器學習擅長


  1. 回歸 regression


    將連續性的資料進行觀察,用以預測未來的結果。例如股價、身高、體重

  2. 分類 classification


    收集既有的資料,進行訓練,根據訓練結果預測新資料的分類。例如:垃圾郵件判斷、手寫數字辨識

  3. 分群 clustering


    根據資料進行分群,但跟分類不同,分類的訓練資料已經有標記結果,要用來分群的資料並沒有群組的標記。例如:根據學測成績,進行文理組分群


使用具有標記的資料進行機器學習,稱為監督式學習。


使用不具有標記的資料進行機器學習,稱為非監督式學習。


線性迴歸(Linear regression)


當我們取得原始量測資料時,如果在平面座標上標記這些量測點,會感覺到這些點之間,可以畫出最接近這些點的一條直線方程式,線性回歸方法,可以找到這樣的方程式,未來就可以根據這個方程式,預測數值。



從圖形看起來,我們可找到一條「最接近」所有紅色觀測值的直線,以直線方程式 \({f(x)=ax+b}\) 表示這條直線,我們要做的就是找到一個方法,確定 a 與 b 的值,未來就可以利用這個方程式預測數據。在統計中,為了因應未來可能有很多未知數的問題,改以這樣的寫法:


\({f(x)=𝜃_0+𝜃_1x}\)


最小平方法


假設目前有這些數據,當我們任意找 \({𝜃_0 =1, 𝜃_1=2}​\) 時,f(x) 跟實際上的 y 之間有誤差。所以要找到適當的參數,讓 f(x) 與 y 之間的誤差最小,當然如果誤差為 0 是最好的。


x y f(x)
58 374 117
70 385 141
81 375 163
84 401 169

定義誤差函數為


\({E(𝜃)= \frac{1}{2} \sum_{i=1}^{n}( y^{(i)} - f_𝜃(x^{(i)})^2 }\)


  • \(x^{(i)}, y^{(i)}\) 分別是第 i 項的 x 與 y,例如 \(x^{(1)} = 58, y^{(1)}=374\)
  • \({(y^{(i)} - f_𝜃(x^{(i)}) }\) 是誤差值,但因為誤差有可能是負數,所以就用平方,轉成正數
  • 將所有誤差的平方加總後,為了微分計算方便,就在前面再乘上 1/2。因為乘上任意的正數,只會讓圖形橫向壓扁,但不會改變最小值的位置。任意的正數都可以,選擇 1/2 是因為後面的例子,f(x) 是二次函數的關係。

在讓 E(𝜃) 最小的狀況下,找到的 \(𝜃_0, 𝜃_1​\) 就是最小平方法


最速下降法


剛剛要讓 E(𝜃) 最小的狀況下,必須不斷地找到不同的 \(𝜃_0, 𝜃_1​\),這個計算很麻煩,可用微分來解決,因為微分就是在找函數切線斜率的變化。


例如 \(f(x) = (x-1)^2\) ,微分後 \(f'(x) = 2x-2\)


x f'(x) 的正負 f(x) 遞增或遞減
\(x < 1\) \(-\) 遞減
\(x=0\) 0
\(x>1\) \(+\) 遞增

f(x) 的圖形如下,當 x 由 3 往 1 逼近,f(x) 就越來越小,另外當 x 由 -1 往 1 逼近,f(x) 也會越來越小



意思就是說,只要 x 往導函數 (微分) 的反方向移動,就函數值會往最小值移動。


最速下降法(梯度下降法) 就是定義為


\(x := x - 𝜂 \frac{d}{dx}f(x)​\)


以實際的數字為例,當 \(𝜂 = 1, x = 3​\),x 會在 3 與 -1 之間往返


\(x := 3-1(2*3-2) = -1​\)


\(x := -1-1(2*(-1)-2) = 3​\)


當 \(𝜂 = 0.1, x = 3\),x 會往最小值逼近


\(x := 3-0.1(2*3-2) = 2.6​\)


\(x := 2.6-0.1(2*2.6-2) = 2.3​\)


\(x := 2.3-0.1(2*2.3-2) = 2.1​\)


當 𝜂 越大,x 就會往返,當 𝜂 越小,x 會往最小值逼近




回到剛剛的 誤差函數


\({E(𝜃)= \frac{1}{2} \sum_{i=1}^{n}(y^{(i)} - f_𝜃(x^{(i)})^2 }\)


因為 \(f_𝜃(x^{(i)})\) 是 \({f(x)=𝜃_0+𝜃_1x}\) ,有兩個未知的參數 \(𝜃_0, 𝜃_1\) ,要改用偏微分找最小值。


\(𝜃_0 := 𝜃_0 - 𝜂 \frac{𝜕E}{𝜕𝜃_0}​\)


\(𝜃_1 := 𝜃_1 - 𝜂 \frac{𝜕E}{𝜕𝜃_1}​\)


因 E(𝜃) 裡面有 \(f_𝜃(x)​\),而 \(f_𝜃(x)​\) 裡面有 𝜃


\(u = E(𝜃)​\)


\(v = f_𝜃(x)​\)


然後用合成函數的方式,計算微分


\(\frac{𝜕E}{𝜕𝜃_0} = \frac{𝜕u}{𝜕v} \frac{𝜕v}{𝜕𝜃_0}​\)


其中,前面的部分


\( \frac{𝜕u}{𝜕v} = \frac{𝜕}{𝜕v}( \frac{1}{2} \sum_{i=1}^{n}(y^{(i)} - v)^2 ) = \frac{1}{2} \sum_{i=1}^{n}\frac{𝜕}{𝜕v}(y^{(i)} - v)^2 = \frac{1}{2}\sum_{i=1}^{n}( -2y^{(i)} +2v ) = \sum_{i=1}^{n}( v-y^{(i)} )\)


後面的部分


\(\frac{𝜕v}{𝜕𝜃_0} = \frac{𝜕}{𝜕𝜃_0}( 𝜃_0 + 𝜃_1x ) = 1​\)


所以


\(\frac{𝜕E}{𝜕𝜃_0} = \sum_{i=1}^{n}( f_𝜃(x^{(i)})-y^{(i)} )\)


另外對 \(𝜃_1​\) 微分,可得到


\(\frac{𝜕v}{𝜕𝜃_1} = \frac{𝜕}{𝜕𝜃_1}( 𝜃_0 + 𝜃_1x ) = x\)


\(\frac{𝜕E}{𝜕𝜃_1} = \sum_{i=1}^{n}( f_𝜃(x^{(i)} )-y^{(i)} )x^{(i)} \)




最後


\(𝜃_0 := 𝜃_0 - 𝜂 \sum_{i=1}^{n}( f_𝜃(x^{(i)} )-y^{(i)} )\)


\(𝜃_1 := 𝜃_1 - 𝜂 \sum_{i=1}^{n}( f_𝜃(x^{(i)} )-y^{(i)} )x^{(i)}​\)


用這個方法,就可以找出正確的 \(𝜃_0, 𝜃_1​\)


多項式回歸


一開始,我們假設數據的模型是線性的,所以使用一次函數,但也可能用二次或更高次的函數來定義 \(f_𝜃(x)​\),會更貼近原本的數據模型


\(f_𝜃(x) = 𝜃_0 + 𝜃_1x + 𝜃_2x^2​\)


\(f_𝜃(x) = 𝜃_0 + 𝜃_1x + 𝜃_2x^2 + \dots +𝜃_nx^n​\)


回到剛剛的問題,要對 \(𝜃_2​\) 進行偏微分


對 \(𝜃_1​\) 微分,可得到


\(\frac{𝜕v}{𝜕𝜃_2} = \frac{𝜕}{𝜕𝜃_2}( 𝜃_0 + 𝜃_1x +𝜃_2x^2 ) = x^2\)


\(\frac{𝜕E}{𝜕𝜃_2} = \sum_{i=1}^{n}( f_𝜃(x^{(i)} )- y^{(i)} )(x^{(i)} )^2\)


多元回歸


目前解決的問題,都只有一個變數 x,但大多數的問題,都是有兩個以上的變數。例如廣告的點擊率,可能會受廣告費、顯示位置、顯示大小 等原因影響。


\(f_𝜃(x_1, x_2, x_3)=𝜃_0+𝜃_1x_1+𝜃_2x_2+𝜃_3x_3\)


當變數有 n 個,可改用向量的方式表示


\(𝜃= \left[ \begin{matrix} 𝜃_0 \\ 𝜃_1 \\𝜃_2 \\. \\. \\𝜃_n \end{matrix} \right] x= \left[ \begin{matrix} x_1 \\ x_2 \\x_3 \\. \\. \\x_n \end{matrix} \right] ​\)


但因為 𝜃 跟 x 個數不同,不容易計算,就再加上一項 \(x_0 =1​\)


\(𝜃= \left[ \begin{matrix} 𝜃_0 \\ 𝜃_1 \\𝜃_2 \\. \\. \\𝜃_n \end{matrix} \right] x= \left[ \begin{matrix} x_0 \\ x_1 \\x_2 \\. \\. \\x_n \end{matrix} \right] \)


將 𝜃 變成轉置矩陣後,再跟 x 相乘,就會是剛剛的 \(f_𝜃(x)\)


\(𝜃^Tx = 𝜃_0x_0+𝜃_1x_1+ \dots + 𝜃_nx_n = f_𝜃(x) \)


變成向量後,再用剛剛合成函數偏微分的方法


\(u = E(𝜃)​\)


\(v = f_𝜃(x)​\)


\(\frac{𝜕u}{𝜕𝜃_j} =\frac{𝜕E}{𝜕𝜃_j} = \frac{𝜕u}{𝜕v} \frac{𝜕v}{𝜕𝜃_j}\)


前面的部分一樣,後面的部分


\(\frac{𝜕v}{𝜕𝜃_j} = \frac{𝜕}{𝜕𝜃_j}( 𝜃^Tx ) = \frac{𝜕}{𝜕𝜃_j}( 𝜃_0x_0+𝜃_1x_1+\dots+𝜃_nx_n )= x_j\)


第 j 項參數的定義為


\(𝜃_j := 𝜃_j - 𝜂 \sum_{i=1}^{n}( f_𝜃(x^{(i)} )-y^{(i)} )x_j^{(i)}\)


當變數增加,計算量變大,用最速下降法會導致計算速度變慢,可用隨機梯度下降法改進。


最速下降法除了有計算速度慢的問題,還有可能陷入局部解的問題,像以下的函數圖形中,不同的起點,可能會找到局部最小值。



隨機梯度下降法


在多元迴歸中,第 j 項參數的定義為


\(𝜃_j := 𝜃_j - 𝜂 \sum_{i=1}^{n}( f_𝜃(x^{(i)} )-y^{(i)} )x_j^{(i)}​\)


但因為用到所有的資料的誤差,計算量太大,隨機梯度下降法式隨機選擇一項學習資料,套用在參數的更新上,例如選擇第 k 項。


\(𝜃_j := 𝜃_j - 𝜂 ( f_𝜃(x^{(k)} )-y^{(k)} )x_j^{(k)}​\)


原本最速下降法用來更新一次參數的時間,隨機梯度下降法可更新 n 次參數。因為是隨機選擇學習資料,不會陷入局部解的問題。




另外也有隨機選擇 m 筆學習資料的方法,也稱為小量批次資料法,假設 m 筆資料的集合為 K


\(𝜃_j := 𝜃_j - 𝜂 \sum_{k𝜖K} ( f_𝜃(x^{(k)} )-y^{(k)} )x_j^{(k)}​\)


References


練好機器學習的基本功:用Python進行基礎數學理論的實作

2019/8/12

erlang lager with date in log filename

erlang lager 預設是以設定中的 filename 加上 .1 .2 的 postfix 作為 logfile rotate 的依據,但通常在使用 logfile,會希望直接在 logfile 看到產生該 log 的日期,這時需要使用 Custom Log Rotation 的功能,自己撰寫 log_rotator。


首先我們先找到 lager 原始程式碼中預設的 lagerrotatordefault.erl,先複製成 mylagerlog_rotator,然後修改裡面的程式碼。


-module(my_lager_log_rotator).

-include_lib("kernel/include/file.hrl").

-behaviour(lager_rotator_behaviour).

-export([
  create_logfile/2, open_logfile/2, ensure_logfile/4, rotate_logfile/2
]).

create_logfile(Name, Buffer) ->
  {{Y, M, D}, {H, _, _}} = calendar:now_to_local_time(os:timestamp()),
  DateHour =  {Y, M, D, H},
  FileName = filename(Name, DateHour, 1),
  file:delete(Name),
  file:make_symlink(filename:absname(FileName), Name),
  open_logfile(Name, Buffer).

open_logfile(Name, Buffer) ->
  case filelib:ensure_dir(Name) of
    ok ->
      Options = [append, raw] ++
        case  Buffer of
          {Size, Interval} when is_integer(Interval), Interval >= 0, is_integer(Size), Size >= 0 ->
            [{delayed_write, Size, Interval}];
          _ -> []
        end,
      case file:open(Name, Options) of
        {ok, FD} ->
          case file:read_file_info(Name) of
            {ok, FInfo} ->
              Inode = FInfo#file_info.inode,
              {ok, {FD, Inode, FInfo#file_info.size}};
            X -> X
          end;
        Y -> Y
      end;
    Z -> Z
  end.


ensure_logfile(Name, FD, Inode, Buffer) ->
  case file:read_link(Name) of
    {ok, _} ->
      lager_ensure_logfile(Name, FD, Inode, Buffer);
    _ ->
      create_logfile(Name, Buffer)
  end.


lager_ensure_logfile(Name, undefined, _Inode, Buffer) ->
  open_logfile(Name, Buffer);
lager_ensure_logfile(Name, FD, Inode, Buffer) ->
  case file:read_file_info(Name) of
    {ok, FInfo} ->
      Inode2 = FInfo#file_info.inode,
      case Inode == Inode2 of
        true ->
          {ok, {FD, Inode, FInfo#file_info.size}};
        false ->
          %% delayed write can cause file:close not to do a close
          _ = file:close(FD),
          _ = file:close(FD),
          case open_logfile(Name, Buffer) of
            {ok, {FD2, Inode3, Size}} ->
              %% inode changed, file was probably moved and
              %% recreated
              {ok, {FD2, Inode3, Size}};
            Error ->
              Error
          end
      end;
    _ ->
      %% delayed write can cause file:close not to do a close
      _ = file:close(FD),
      _ = file:close(FD),
      case open_logfile(Name, Buffer) of
        {ok, {FD2, Inode3, Size}} ->
          %% file was removed
          {ok, {FD2, Inode3, Size}};
        Error ->
          Error
      end
  end.
%%
%%%% renames failing are OK
%%rotate_logfile(File, 0) ->
%%  %% open the file in write-only mode to truncate/create it
%%  case file:open(File, [write]) of
%%    {ok, FD} ->
%%      file:close(FD),
%%      ok;
%%    Error ->
%%      Error
%%  end;
%%rotate_logfile(File0, 1) ->
%%  File1 = File0 ++ ".0",
%%  _ = file:rename(File0, File1),
%%  rotate_logfile(File0, 0);
%%rotate_logfile(File0, Count) ->
%%  File1 = File0 ++ "." ++ integer_to_list(Count - 2),
%%  File2 = File0 ++ "." ++ integer_to_list(Count - 1),
%%  _ = file:rename(File1, File2),
%%  rotate_logfile(File0, Count - 1).
%%

rotate_logfile(Name, _Count) ->
  case file:read_link(Name) of
    {ok, LinkedName} ->
      case filelib:file_size(LinkedName) of
        0 ->
          %% if the files size is zero, it is removed
          catch file:delete(LinkedName);
        _ ->
          void
      end;
    _ ->
      void
  end,
  {ok, {FD, _, _}} = create_logfile(Name, []),
  file:close(FD).

%% @doc Create name of a new file
%% @private
filename(BaseFileName, DateHour, Branch) ->
  FileName = lists:append([BaseFileName,
    suffix(DateHour, false), ".", integer_to_list(Branch)
  ]),
  case filelib:is_file(FileName) of
    true ->
      filename(BaseFileName, DateHour, Branch + 1);
    _ ->
      FileName
  end.

%% @doc Zero-padding number
%% @private
zeropad(Num, MinLength) ->
  NumStr = integer_to_list(Num),
  zeropad_str(NumStr, MinLength - length(NumStr)).
zeropad_str(NumStr, Zeros) when Zeros > 0 ->
  zeropad_str([$0 | NumStr], Zeros - 1);
zeropad_str(NumStr, _) ->
  NumStr.

%% @doc Create a suffix
%% @private
suffix({Y, M, D, H}, WithHour) ->
  YS = zeropad(Y, 4),
  MS = zeropad(M, 2),
  DS = zeropad(D, 2),
  HS = zeropad(H, 2),
  case WithHour of
    true ->
      lists:flatten([$., YS, MS, DS, $., HS]);
    _ ->
      lists:flatten([$., YS, MS, DS])
  end.

將 mylagerlogrotator 套用在 lager 的設定檔的 lagerfile_backend 中,{rotator, my_lager_log_rotator}


[
  {lager, [
    {log_root, "./log"},
    {crash_log, "crash.log"},
    {error_logger_redirect, false},
    {colored, true},
    {colors, [
      {debug,     "\e[0;36m" },
      {info,      "\e[1;37m" },
      {notice,    "\e[1;36m" },
      {warning,   "\e[1;33m" },
      {error,     "\e[1;31m" },
      {critical,  "\e[1;35m" },
      {alert,     "\e[1;44m" },
      {emergency, "\e[1;41m" }
    ]},
    {handlers, [
      {lager_console_backend, [{level, debug}, {formatter, lager_default_formatter},
        {formatter_config, [date, " ", time, color, " ", pid, " ", module, ":", line, " [", severity, "] ", message, "\e[0m\n"]}]},
      {lager_file_backend, [{file, "debug.log"}, {level, debug}, {size, 3000}, {date, "$H00"}, {count, 2},
        {formatter_config, [date, " ", time, " ", pid, " ", module, ":", line, " [", severity, "] ", message, "\n"]}, {rotator, my_lager_log_rotator}]}
    ]}
  ]}
].

現在就會產生像這樣的 logfile


debug.log (symbolic link to debug.log.20190426.2)
debug.log.20190426.1
debug.log.20190426.2

目前還會需要調整的是設定檔中的 count,如果在 logfile 加上日期,count 應該是要代表保留幾天的資料,但目前還是依照 lager 原本的定義,為保留幾個同樣 prefix 檔名的 logfile。


References


erlang lager


leologgerrotator.erl

2019/8/5

奧卡姆剃刀 Occam's Razor, Ockham's Razor

奧卡姆剃刀,是由14世紀邏輯學家、聖方濟各會修士奧卡姆的威廉(William of Occam,約1285年至1349年)提出。奧卡姆(Ockham)在英格蘭的薩裡郡,那是他出生的地方。他在《箴言書註》2捲15題說「切勿浪費較多東西,去做『用較少的東西,同樣可以做好的事情』。」後來這個原理也簡化為「如無必要,勿增實體。」(Do not multiply entities beyond necessity.)


另外一些更簡要的說法是:避重趨輕、避繁逐簡、以簡御繁、避虛就實,更白話的說法是「夠用就好」


奧卡姆剃刀並不是一個數學定理,他是一個思考的原則,這個原則告訴我們應該追求簡化,不追求完全正確的答案,選擇一個最簡單的方式,來解釋與面對問題。


剃刀原則並不是說簡單的理論就是正確的理論,應該了解為「當兩個假說具有完全相同的解釋力和預測力時,我們以那個較為簡單的假說作為討論依據。」


這個原則套用在不同知識領域時:


  1. 科學領域:當你有兩個處於競爭地位的理論能得出同樣的結論,那麼簡單的那個更好。
  2. 企業管理領域:在管理企業制定決策時,應該儘量把複雜的事情簡單化,剔除干擾,解決最根本的問題,才能讓企業保持正確的方向。
  3. 投資領域:面對複雜當投資市場,應把複雜事情簡單化,簡化自己的投資策略,擺脫那些消耗了大量金錢、時間、精力的事情。
  4. 組織結構:組織結構扁平化與組織結構非層級化已經成為企業組織變革的趨勢。員工之間的關係是平等的分工合作關係,基層員工被賦予更多的權力。由於員工的積極參與,組織目標與個人目標之間的矛盾得到最大程度地消除。
  5. 簡化文書流程:簡單的信息遠比複雜的信息更有利於人們的思考與決策。
  6. 日心說/地心說:五百年前,還不知道世界的中心是太陽還是地球,但利用望遠鏡得到的許多數據,可證明地球跟太陽都可能是中心,不過日心說只需要 7 個假設,而地心說需要更多假設,於是哥白尼在天體運行論中利用奧卡姆剃刀原則,判斷太陽才是中心。

跟人溝通時,如果遇到一些喜歡長篇大論的人,就會覺得講半天講不到重點,但講話太精簡,簡單到讓人聽不懂也是問題,精簡程度怎麼拿捏就是個人修養的問題。


References


奧卡姆剃刀定律


奧卡姆剃刀定律 wiki


奧卡姆剃刀決策原則


設計法則:奧卡姆剃刀原理

2019/7/29

NFL(No Free Lunch Theorems) 沒有免費的午餐定理


在認識或瞭解 machine learing 或 AI 的概念後,通常會想到,如果能有一個可以處理所有問題的 AI,那麼就可以解決所有問題,那就能省去大量人力,大家都不用工作了。這是一個尋找 AI 的通用演算法的問題,只要能有一個超強的演算法,那就能很快地製造出符合不同需求的 AI 機器人。


但是在尋找這個演算法以前,我們要先知道,已經有人用了數學的方法,證明了並不存在一個能一統天下的 AI 演算法模型,這就是 NFL(No Free Lunch Theorems) 沒有免費的午餐定理。


NFL 定理 www.no-free-lunch.org 有兩個,一個是 No Free Lunch for Supervised Machine Learning (WOLPERT, David H., 1996. The lack of a priori distinctions between learning algorithms. Neural Computation, 8(7), 1341–1390.),一個是 No Free Lunch for Search/Optimization (WOLPERT, David H., and William G. MACREADY, 1997. No free lunch theorems for optimization. IEEE Transactions on Evolutionary Computation, 1(1), 67–82.)。


實際上在了解到這個定理的概念後,我們要知道,在不考慮具體問題的情況下,沒有任何一個算法比另一個算法更優,甚至直接胡亂猜測還會更好。我們無法去討論哪一個演算法比較好,但如果針對某個具體的特定的問題,確實可找到表現比較好的機器學習演算法,但這個演算法,卻無法解決其他的問題。換句話說,不同的問題,就可以找到最適當的演算法,而每個演算法,都有各自適用的問題。


也可以說如果我們對要解決的問題一無所知,且並假設其分佈完全隨機且平等,那麼任何演算法的預期性能都是相似的。在某個領域、特定假設下表現卓越的演算法,不一定在另一個領域也能是最厲害的。正因如此,我們才需要研究和發明更多的機器學習算法來處理不同的假設和數據,也就是處理不同的問題。


華爾街漫步 是一本投資指南,其中廣為人知的「隨機漫步理論」研究,也就是說一隻矇著眼睛的猴子,對著眾多股票隨意擲飛鏢,所挑選出來的投資組合,也會跟專家挑的組合一樣好,猴子隨機選股大勝專業經理人。就像是 NFL 有著類似的概念,特定的演算法就像是經理人選擇的股票一樣,沒有任何演算法/選股組合可以證明,他的選股策略比較厲害。


NFL for Machine Learning 有兩條規則


  1. 沒有一個機器學習演算法,在所有可能的函數中,能夠比隨機猜測的結果更好。
  2. 每個機器學習演算法都必須包含一些數據之外的知識或者假設,才能夠將數據一般化。

要用什麼策略選擇一個適當的演算法?


想像一個由 n 個 solution(算法) 與 m 個 problem 構成的矩陣,每一個格子的值表示問題被解決的程度。


  1. Restarting


    重複運算相同的演算法,會產生不同的結果,得到多個 solution,然後分析某個演算法,看看是否為相對較優良的演算法。這種方式特別適合用於初始 & 過程隨機化的算法,因為這些演算法每次執行都會得到不完全一樣的結果。

  2. Ordinal Optimization + Softend Goals


    因為計算 n x m 矩陣實際上每個格子裡面的「值」,可能是很複雜的向量/矩陣等需要耗費大量資源來處理。為了簡化計算的負擔, Ordinal Optimization 只考慮 A 算法是否比 B 算法好,而不管兩者之間的差距有多大。


    Softened Goals 則是常見的取捨法則:放棄追求「最最最好」的解法,轉而追求「足夠好」的方式。

  3. Ensemble Learning


    透過多個演算法彼此獨立工作,最終以類似「投票」、「比稿」的方式來決定最終預測的結果值。


References


應該如何理解No Free Lunch Theorems for Optimization?


百度百科 no free lunch


機器學習裡不存在的免費午餐:NO FREE LUNCH THEOREMS


機器學習周志華--沒有免費的午餐定理


帶你瞭解機器學習(一): 機器學習中的“哲學”


No Free Lunch on Machine Learning

2019/7/21

Linux IO: select, poll, epoll


以下是 Linux 的 IO model 的說明,另外 select, poll 及 epoll 是 IO multiplexing 的三種方法。


Kernal Space, User Space


Linux 是多工作業系統,因此有可能會發生多個 process 競爭相同的記憶體位址的狀況,為了解決衝突的問題,由 kernel 進行資源分配,也可避免資源佔用的問題,也可避免異動了別的 process 的資源而造成系統 crash。


CPU 執行程式時,會在 user space 與 kernel space 之間來回切換,user space 的系統函式庫,會轉換為 kernel space 的 system call,並由 kernel 處理,當 system call 完成後,就會回到 user space 繼續下去。


32 bits 的 OS,定址空間是 2^32 也就是 4G。kernel space 限制為 1G (虛擬地址0xC0000000到0xFFFFFFFF),而 user space 為 3G (虛擬地址0x00000000到0xBFFFFFFF),由各 process 使用。


64 bits 的 OS,會將 virtual address 分成一半,第一個 bit 為 0 是 user space,第一個 bit 為 1 是 kernel space,理論上是 8EB+8EB。但目前 processors 只實作了 48 bits,也就是 128TB+128TB。



process context switching


為了控制讓多個 process 分享系統資源,kernel 必須能夠儲存目前在 CPU 運作的 process,載入並執行新的 process,這個切換方式稱為 context switching,時間長短由硬體運算能力決定。


context switching 有以下的步驟


1.儲存 CPU 的 context,包括program counter和其他 register
2.更新PCB (Process Control Block)
3.把進程的PCB移入相應的queue,如 ready/blocking 等隊列
4.選擇另一個進程執行,並更新其PCB
5.更新memory的資料結構
6.恢復 CPU context


FD: file descriptor


file descriptor 是指向檔案 reference 的抽象概念。他是非負整數的索引值,指向 kernel 為每一個 process 維護的開啟檔案的記錄表,當程式打開或建立一個檔案,kernel 就會產生一個 file descriptor 給 process。

每一個 linux process 都有三個標準的 POSIX file descriptor: stdin 0, stdout 1, stderr 2


Buffered I/O


大多數文件系統的默認I/O 操作都是 Buffered I/O,在 Linux 會將 IO 資料先暫存在 page cache 中,也就是先複製到 kernel 的 buffer,然後再由 kernel buffer 複製到 user space。


Buffered I/O 分離了 user space 及實際的儲存設備,可以減少 HD 的讀取次數,提高系統效能。


但也因為多次複製,可能會造成 CPU 及 cache buffer 的消耗,有些特殊的應用,會避開 kernel cache buffer,而直接由 user space 儲存到 HD,以獲取更高的效能。


IO model


因為資料會先複製到 kernel buffer 裡面,然後再複製到 user space,當對一個資料進行 read,會經歷兩個階段:


  1. waiting for data to be ready
  2. copying the data from the kernel to the process

因為兩階段的 IO,linux 產生了五種 IO model


  1. blocking IO
  2. nonblocking IO
  3. IO multiplexing
  4. signal driven IO (不常用)
  5. asynchronous IO

blocking IO


linux 預設大部分的 socket 都是使用 blocking IO,當 process 呼叫 recv_from,會進入 wait for data 階段,在這個階段的 process 會進入 blocking 狀態,直到 kernel 將資料複製到 user space,該 process 才會解除 blocking 狀態,重新運作。


blocking IO 就是兩個階段的 IO 都被 block


nonblocking IO


當 process 呼叫 recv_from 如果 kernel 還沒將資料準備好,他不會 block process,而是產生 error,直到 kernel 將資料準備好,就會複製到 user space,並完成該讀取的工作。


nonblocking 需要 process 不斷向 kernel 詢問,資料是否 ready。


IO multiplexing


這就是常見的 select, poll, epoll,也稱為 event driven IO。這個方式可讓單一 porcess 就可以處理多個 IO,他會不斷地 polling 多個 socket,當某個 socket 有收到資料,就會主動callback 通知 process。


如果是 select,當 process 呼叫了 select,該 process 就會被 block,同時 kernel 會監控所有 select 處理的 sockets,如果有資料,select 就會 return,然後再由 process 呼叫 read,將資料由 kernel 複製到 user space。


這個方法類似 blocking IO,但進行了兩個 system call (select 及 recv_from),但 select 可處理多個 sockets。


select/epoll 的優點是可以處理多個 sockets,而不是效能。一般在 IO multiplexing 中,socket 都是設定為 non-blocking 的,process 是在 select 被 block 而不是 recv_from。


signal driven


先通知 kernel 如果某個 socket 有資料時,就以 signal 通知 process,process 在第二個步驟,才會被 block。


asynchronous


當 process 進行 read,就可以處理別的事情,當 kernel 收到非同步 read,就會馬上 return,直到將資料複製到 user space,完成後,才會發送 signal 給 process,通知已經完成了 read。


Comparison


  • non-blocking 跟 asynchronous 是不同的

  • synchronous 跟 asynchronous 的差異是 IO operation 會不會 blocking process,因此前面四種 model 都屬於 synchronous IO

  • nonblocking IO 中,在複製資料到 user space 的步驟,還是會有 blocking 的狀態


IO Multiplexing: select, poll, epoll


IO Multiplexing 可讓單一 process 監視多個 fd,當某個 fd 有資料,就可通知 process 進行 IO 操作,select, poll, epoll 都是同步 IO,都需要自己進行讀寫,在讀寫的過程中,process 都是被 blocked。


  • select

int select (int n, fd_set *readfds, fd_set *writefds, fd_set *exceptfds, struct timeval *timeout);

select 可監視 writefd, readfd, 及 exceptfd。呼叫 select 後,該 process 會被 blocked,直到某個 fd ready 或是 timeout。當 select return 後,必須要 traverse 所有 fdset,來找到 ready 的 fd。


select 在所有平台都支援,缺點是監視的 fd 有數量上限,通常是 1024,但可修改 macro 或是重新編譯 kernel 增加這個上限。


  • poll

int poll (struct pollfd *fds, unsigned int nfds, int timeout);

struct pollfd {
    int fd; /* file descriptor */
    short events; /* requested events to watch */
    short revents; /* returned events witnessed */
};

poll 使用一個 pollfd pointer 表示 fd,該 pollfd 包含要監視的 event及發生的 event,pollfd 沒有數量上限。poll return 後,必須 traverse pollfd,找到 ready 的 fd。


  • epoll

這是在 linux kernel 2.6 以後提供的,epoll 將跟 process 有關的 fd 事件,存放在 event table 裡面。


// size 為監視的 fd 數量
int epoll_create(int size);
int epoll_ctl(int epfd, int op, int fd, struct epoll_event *event);
int epoll_wait(int epfd, struct epoll_event * events, int maxevents, int timeout);

  1. int epoll_create(int size);


    當產生了 epoll 後,會佔用一個 fd value,不同於 select 必須提供最大監視 fd 數量 +1,size 並不是該 epoll 能監視的 fd 數量上限,而是配置 kernel 內部資料的建議參數。

  2. int epoll_ctl(int epfd, int op, int fd, struct epoll_event *event);


    • epfd: 是 epoll_create 的 return value
    • op: 有三個 macro 表示 operation: EPOLLCTLADD, EPOLLCTLDEL, EPOLLCTLMOD,分別是新增、刪除、修改 fd 監視的 events
    • fd: 需要監視的 fd
    • epoll_event: 告訴 kernel 要監視什麼 event

    struct epoll_event {
      __uint32_t events;  /* Epoll events */
      epoll_data_t data;  /* User data variable */
    };
    
    //events 是以下幾個 macro 的集合:
    EPOLLIN:表示對應的文件描述符可以讀(包括對端SOCKET正常關閉)
    EPOLLOUT:表示對應的文件描述符可以寫
    EPOLLPRI:表示對應的文件描述符有緊急的資料可讀(這裡應該表示有外部資料到來)
    EPOLLERR:表示對應的文件描述符發生錯誤
    EPOLLHUP:表示對應的文件描述符被掛斷
    EPOLLET: 將EPOLL設為 Edge Triggered 模式,這是相對於水平觸發(Level Triggered)來說的
    EPOLLONESHOT:只監聽一次事件,當監聽完這次事件之後,如果還需要繼續監聽這個socket的話,需要再次把這個socket加入到EPOLL隊列
  3. int epoll_wait(int epfd, struct epoll_event * events, int maxevents, int timeout);


    等待 epfd 的 IO event,最多回傳 maxevents 個 events,events 是事件的集合,maxevents 不能超過 epoll_create 的size


    timeout 為 0 表示要馬上 return,如果回傳的事件數量為 0 表示發生了 timeout




epoll 對 fd 的操作有兩種模式: LT (Level Trigger) 及 ET (Edge Trigger)


  • LT: 當 epollwait 偵測到 fd 事件發生,將該事件通知 process,該 process 可不立刻處理該 event,當下次呼叫 epollwait 時,會再次通知 process 這個事件


    同時支援 blocking 與 non-blocking socket,可對該 ready 的 fd 進行 IO,如果不做,kernel 會持續通知 ready

  • ET: 當 epollwait 偵測到 fd 事件發生,將該事件通知 process,該 process 必須立刻處理該 event,如果沒有處理,當下次呼叫 epollwait 時,不會再次通知 process 這個事件


    這是高速運作方式,只支援 non-blocking socket




在 select/poll 中,process 必須呼叫某些 function,kernel 才會對該 fd 進行監視,而 epoll 事先利用 epollctl 註冊 fd,當某個 fd ready 後,會透過 callback 機制,啟動這個 fd,當 process 呼叫 epollwait 就可得到通知。


因為 epoll 去掉了 traverse fds 的步驟,因此可以快速處理 IO event。


epoll 監視的 fd 數量沒有限制,通常是可以打開的文件的數量,可查詢 cat /proce/sys/fs/file-max 得知。select 的缺點,是該 process 可打開的 fd 有數量上限。


如果沒有大量的 idle/dead connection,epoll 的效率不會比 select/poll 高很多。


References


Linux IO模式及 select、poll、epoll詳解


select、poll、epoll之間的區別總結


select,poll,epoll優缺點及比較


關於epoll和select的區別,哪些說法是正確的?


Select、Poll與Epoll比較


Linux 開發,使用多線程還是用 IO 復用 select/epoll?


Socket IO model of humor on the Linux

2019/7/15

Rust Programming Language


今年的 Stack Overflow 調查報告中,發現 Rust 是目前最受歡迎的程式語言。因此我們了解一下Rust 程式語言的定位是:一種撰寫可靠且有效率的軟體的程式語言。 A language empowering everyone to build reliable and efficient software. Rust 是 Mozilla 開發的程式語言,2010 年誕生,目前是 1.34.2 版,其設計目的是開發大型 server 端軟體,強調安全性、記憶體處理及並行處理。雖然效能比 C++ 稍差一點點,但提供了安全的保障。


調查9萬名程序員後,我們發現了一堆不為人知的秘密 Stack Overflow 的年度開發者調查是面向全球開發者的規模最大、最全面的調查,每年的調查內容會涵蓋開發人員最喜歡的技術以及工作偏好等內容。今年是 Stack Overflow 連續第九年進行開發者調查,吸引了將近 9w 名開發人員參加。今年的調查報告結果:Rust 是最受喜愛的編程語言,Python 則是增長最快的。今年 Python 超過 Java 在開發者最喜愛的編程語言榜中排名第二。


一般在討論 Rust 的時候,會跟 C++ 一起比較,通常會使用 C++ 是因為要開發貼近硬體,高速且穩定的系統程式,因為要接近 real time,就不能使用高階語言的 GC 機制,C/C++ 可以開發出高效率的軟體,但常會遇到有關記憶體的問題,也經常發生在底層 library 發生安全漏洞時,導致使用這些 library 的軟體一夕崩壞的問題。


C++ 的發展初期,為了跟 1972年誕生的 C 語言相容,保留了很多設計上的相容性,卻也留下很多問題。而 Rust 是一個沒有歷史包袱的程式語言,當然能吸收新的設計理念,解決安全性的問題,沒有 GC 機制,可直接編譯為machine code,可以直接跟 C 語言互通。


另一種討論,是將 Rust 跟 Golang 比較,Go 有 GC 機制,但 Rust 沒有。Rust 的語法比 Go 複雜,但 Go 更常發生執行期的 crash,Rust 支援泛型。Golang 的目標,應該是取代 Java, Python 在後端運算的地位,但 Rust 的目標,是 C/C++ 的環境,基本上,Rust 跟 Golang 的使用情境應該沒有衝突。


另外,Rust 最常被討論的,是學習曲線陡峭的問題,雖然 Rust 受到開發者的推崇,但實際上,使用 Rust 的開發者並不多,主要原因是 Rust 的用意是取代 C++,而 C++ 的學習曲線比 Rust 更陡峭,而 Rust 本身的困難點在於它接近作業系統,常常會遇到要跟 C/C++ 互通的狀況,還有 Rust 有著其他程式語言不存在的語言特性。


Installation


在 mac/linux 安裝測試環境


curl https://sh.rustup.rs -sSf | sh

他會將 rust 安裝在 $HOME/.rustup,將 cargo 安裝在 $HOME/.cargo,cargo是 rust 的套件管理工具。


另外會在 $HOME/.profile 增加 PATH 環境變數,裡面有常用的指令:rustc, cargo, and rustup


export PATH="$HOME/.cargo/bin:$PATH"

確認是否有安裝完成


$ rustc --version
rustc 1.34.2 (6c2484dc3 2019-05-13)

如果要移除就要


rustup self uninstall

官方網站提供的書本有兩本


The Rust Programming Language


Rust by example


Hello World


hello.rs


fn main() {
    println!("Hello, World!");
}

編譯後,就會產生執行檔


$ rustc hello.rs
$./hello
Hello, World!

另一種開發方式,是利用 cargo 產生專案


cargo new --bin hello
cd hello/

然後修改 main.rs 內容跟剛剛的 hello.rs 一樣,再利用 cargo 編譯 hello project


$ cargo run
   Compiling hello v0.1.0 (/Users/charley/Downloads/hello)
    Finished dev [unoptimized + debuginfo] target(s) in 1.44s
     Running `target/debug/hello`
Hello, world!

如果要產生 release 版,沒有 debug message 的 執行檔,就要加上 build --release 參數


cargo build --release

另外類似 Go 內建程式碼 format 工具,rust 提供 rustfmt 工具,可用 cargo 安裝這個套件


cargo install rustfmt

在 project 中,可用以下指令重排專案


cargo fmt

如果執行時發生下錯誤


$ cargo fmt
error: 'cargo-fmt' is not installed for the toolchain 'stable-x86_64-apple-darwin'
To install, run `rustup component add rustfmt --toolchain stable-x86_64-apple-darwin`

依照說明,再安裝 rustfmt component 即可


rustup component add rustfmt --toolchain stable-x86_64-apple-darwin

Multi Thread


use std::thread;

// 產生 10 個 concurrent thread
fn main() {
    // 因 greeting 是不可變的,可以安全地同時被多個線程使用
    let greeting = "Hello";

    let mut threads = Vec::new();
    // for 可處理任何實作 iterator 特型的類別
    for num in 0..10 {
        threads.push(thread::spawn(move || {
            println!("{} from thread number {}", greeting, num);
        }));
    }

    // 等待所有 thread 結束
    for thread in threads {
        thread.join().unwrap();
    }
}

執行結果


Hello from thread number 0
Hello from thread number 2
Hello from thread number 3
Hello from thread number 1
Hello from thread number 4
Hello from thread number 5
Hello from thread number 6
Hello from thread number 7
Hello from thread number 8
Hello from thread number 9

thead之間是透過 channel 傳遞訊息


use std::thread;
use std::sync::mpsc::channel;

fn main() {
    // 產生 channel,channel 有 tx, rx 兩端
    let (tx, rx) = channel();

    // spawn a new thread,在 thread 中持續使用 channel 的 rx 接收訊息
    let join_handle = thread::spawn(move || {
        // 在 loop 中持續接收訊息,直到 tx 被 dropped
        // recv() 是 blocking method
        while let Ok(n) = rx.recv() {
            println!("Received {}", n);
        }
    });

    // 因為 rx 已經被 thread 使用了,不能在這邊使用到 rx,如果用到就會產生 compile error
    // 透過 tx 發送 10 個訊息
    // 如果接收端 被 dropped,那麼呼叫 unwrap() 就會發生 crash
    for i in 0..10 {
        tx.send(i).unwrap(); // send() 不是 blocking call
    }

    // drop tx 時,會讓 rx.recv() 收到 Err(_)
    drop(tx);

    // 等待 thread 結束
    join_handle.join().unwrap();
}

執行結果


Received 0
Received 1
Received 2
Received 3
Received 4
Received 5
Received 6
Received 7
Received 8
Received 9

Web Framework


如果要開發 web project,參考這兩個網頁的資訊


What are the best web frameworks for Rust?


Rust web framework comparison


可選用Actix 或是 Rocket,其中 Actix 的支援範圍最廣,支援 https, http client, WebSocket, asynchronous


另外因為 Rust 可直接編譯為 WebAssembly,故可以在網頁上運作。


使用 Actix 開發的 web project,雖然看起來不適合開發大型動態網頁的資料,但應該會適合開發 microservice,提供網頁微服務。


References


Rust wiki


「Rust」可進行安全的系統程式設計


如何看待 Rust 的應用前景?


[Rust] 程式設計教學:基礎概念


【譯】Tokio 內部機制:從頭理解 Rust 非同步 I/O 框架


我們為什麼要選擇小眾語言 Rust 來實現 TiKV?


【譯】Rust vs. Go


明明很好很強大,Rust 卻還是那麼小眾


「RustConAsia 2019」如何高效學習Rust

2019/7/1

WebSocket Support in Jetty


WebSocket 是在 http protocol 進行雙向通訊傳輸的協定,可以用 UTF-8 Text 或 Binary format。message 沒有長度限制,但 framing 有限制長度。可發送無限個訊息。訊息必須依照順序傳送,無法支援 interleaved messages。


WebSocket connection state


有四種


State Description
Connecting 當 HTTP upgrade 到 Websocket
Open socket is open, ready to read/write
Closing 啟動 WebSocket Close Handshake
Closed websocket is closed

WebSocket Events


Event Description
on Connect 成功連線,會收到 org.eclipse.jetty.websocket.api.Session object reference,這是該 socket 的 session
on Close 會有 Status Code
on Error websocket 發生 error
on Message 代表收到完整的 message,可以是 UTF-8 Text 或是 raw BINARY message

Jetty 提供的 WebSocket Spec


  • RFC-6455

    目前支援 WebSocket Protocol version 13

  • JSR-356

    Java WebScoket API (javax.webscoket),這是處理 websocket 的標準 java API


目前還不穩定的功能


  • perframe-compression

    Per Frame Compression Extension


    這是 Google/Chromium team 提供的 frame compression,但還在 early draft,Jetty 支援 draft-04 spec,目前已經被 permessage-compression 取代

  • permessage-compression

    Per Message Compression Extension


    將壓縮改為整個 message,而不是每一個 frame


WebSocket Session


websocket Session 物件有以下的使用方式


  1. 檢查 connection state (opened or not)


    if(session.isOpen()) {
    }
  2. 檢查 secure


    if(session.isSecure()) {
      // connection is using 'wss://'
    }
  3. 有哪些在 Upgrade Request and Response


    UpgradeRequest req = session.getUpgradeRequest();
    String channelName = req.getParameterMap().get("channelName");
    
    UpgradeResponse resp = session.getUpgradeResponse();
    String subprotocol = resp.getAcceptedSubProtocol();
  4. 取得 Local and Remote Address


    InetSocketAddress remoteAddr = session.getRemoteAddress();
  5. 存取 idle timeout


    session.setIdleTimeout(2000); // 2 second timeout

Jetty WebSocket API


同時支援 server 及 client


要開發 Jetty Websocket 程式,首先要在 Maven POM 加上 library,因測試同時要支援 RFC-6455 及 JSR-356,故同時加上了兩種 library


<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>tw.com.maxkit</groupId>
    <artifactId>test</artifactId>
    <version>0.1</version>

    <properties>
        <jetty.version>9.4.12.v20180830</jetty.version>
    </properties>

    <build>
        <plugins>
            <plugin>
                <groupId>org.eclipse.jetty</groupId>
                <artifactId>jetty-maven-plugin</artifactId>
                <version>${jetty.version}</version>
                <configuration>
                    <scanIntervalSeconds>2</scanIntervalSeconds>
                </configuration>
            </plugin>
        </plugins>
    </build>

    <dependencies>
        <!--Jetty dependencies start here -->
        <dependency>
            <groupId>org.eclipse.jetty</groupId>
            <artifactId>jetty-server</artifactId>
            <version>${jetty.version}</version>
        </dependency>

        <dependency>
            <groupId>org.eclipse.jetty</groupId>
            <artifactId>jetty-servlet</artifactId>
            <version>${jetty.version}</version>
        </dependency>
        <!--Jetty dependencies end here -->

        <!--Jetty Websocket server side dependencies start here -->
        <!--Jetty JSR-356 Websocket server side dependency -->
        <dependency>
            <groupId>org.eclipse.jetty.websocket</groupId>
            <artifactId>javax-websocket-server-impl</artifactId>
            <version>${jetty.version}</version>
        </dependency>

        <!--Jetty Websocket API server side dependency -->
        <dependency>
            <groupId>org.eclipse.jetty.websocket</groupId>
            <artifactId>websocket-server</artifactId>
            <version>${jetty.version}</version>
        </dependency>
        <!--Jetty Websocket server dependencies end here -->


        <!--Jetty Websocket client side dependencies start here -->
        <!--JSR-356 Websocket client side depencency  -->
        <dependency>
            <groupId>org.eclipse.jetty.websocket</groupId>
            <artifactId>javax-websocket-client-impl</artifactId>
            <version>${jetty.version}</version>
        </dependency>

        <!--Jetty Websocket API client side dependency -->
        <dependency>
            <groupId>org.eclipse.jetty.websocket</groupId>
            <artifactId>websocket-client</artifactId>
            <version>${jetty.version}</version>
        </dependency>
        <!--Jetty Websocket client side  dependencies end here -->

    </dependencies>

</project>

RFC-6455 websocket Server

首先要將 Jetty path 透過 WebSocketServlet 跟 WebSocket class 綁定。


以下是 ToUpperWebSocketServlet 的 servlet,會處理 /toUpper 這個 url,因為在 IDE 裡面,通常會將 webapp 對應到某個 context,假設是 test,那麼 websocket 服務的 url,應該是 ws://localhost:8080/test/toUpper


ToUpperWebSocketServlet.java


package tw.com.maxkit.jetty.server;

import javax.servlet.annotation.WebServlet;

import org.eclipse.jetty.websocket.servlet.WebSocketServlet;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;

@WebServlet(name = "ToUpper WebSocket Servlet", urlPatterns="/toUpper")
public class ToUpperWebSocketServlet  extends WebSocketServlet{

    @Override
    public void configure(WebSocketServletFactory factory) {
        // set a 10 second timeout
        factory.getPolicy().setIdleTimeout(10000);

//      factory.register(ToUpperWebSocket.class);
//      factory.register(ToUpperWebSocketListener.class);
        factory.register(ToUpperWebSocketAdapter.class);
    }

}

程式裡面設定了 ide timeout 的時間為 10s,另外有三種真正實作 websocket 訊息的方式,如果要使用某一種實作方式,只要調整 register 的 implementation class 即可。


//      factory.register(ToUpperWebSocket.class);
//      factory.register(ToUpperWebSocketListener.class);
        factory.register(ToUpperWebSocketAdapter.class);

  • WebSocket annotation

annotation description
@WebSocket 將這個 POJO 標記為 WebSocket,class 不能是 abstract and public
@OnWebSocketClose (optional) 收到 onClose event
@OnWebSocketMessage (optional) 有兩個 method,分別是 TEXT 與 BINARY message
@OnWebSocketError (optional) 收到 error event
@OnWebSocketFrame (optional) 收到 frame event

ToUppderWebSocket.java


package tw.com.maxkit.jetty.server;

import java.io.IOException;

import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketClose;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketConnect;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketMessage;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;

@WebSocket
public class ToUpperWebSocket {

    @OnWebSocketMessage
    public void onText(Session session, String message) throws IOException {
        System.out.println("ToUpperWebSocket received:" + message);
        if (session.isOpen()) {
            String response = message.toUpperCase();
            session.getRemote().sendString(response);
        }
    }

    @OnWebSocketConnect
    public void onConnect(Session session) throws IOException {
        System.out.println( session.getRemoteAddress().getHostName() + " connected!");
    }

    @OnWebSocketClose
    public void onClose(Session session, int status, String reason) {
        System.out.println(session.getRemoteAddress().getHostName() + " closed!");
    }

}

  • WebSocketListener

ToUpperWebSocketListener.java


package tw.com.maxkit.jetty.server;

import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;

public class ToUpperWebSocketListener implements WebSocketListener {
    private Session outbound;

    public void onWebSocketBinary(byte[] payload, int offset, int len) {
        /* only interested in text messages */
    }

    public void onWebSocketClose(int statusCode, String reason) {
        this.outbound = null;
    }

    public void onWebSocketConnect(Session session) {
        this.outbound = session;
    }

    public void onWebSocketError(Throwable cause) {
        cause.printStackTrace(System.err);
    }

    public void onWebSocketText(String message) {
        if ((outbound != null) && (outbound.isOpen())) {
            System.out.printf("ToUpperWebSocketListener [%s]%n", message);
            // echo the message back
            outbound.getRemote().sendString(message.toUpperCase(), null);
        }
    }
}

  • WebSocketAdpapter

比 listener 簡單,提供檢查 session state 的 methods


ToUpperWebSocketAdapter.java


package tw.com.maxkit.jetty.server;


import org.eclipse.jetty.websocket.api.WebSocketAdapter;

import java.io.IOException;

public class ToUpperWebSocketAdapter extends WebSocketAdapter
{
    @Override
    public void onWebSocketText(String message)
    {
        if (isConnected())
        {
            try
            {
                System.out.printf("ToUpperWebSocketAdapter received: [%s]%n",message);
                // echo the message back
                getRemote().sendString(message.toUpperCase());
            }
            catch (IOException e)
            {
                e.printStackTrace(System.err);
            }
        }
    }
}

JSR-356 websocket Server

在網址 ws://localhost:8008/test/jsr356toUpper 提供服務


ToUpper356Socket.java


package tw.com.maxkit.jsr356.server;

import java.io.IOException;

import javax.websocket.CloseReason;
import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpoint;

@ServerEndpoint("/jsr356toUpper")
public class ToUpper356Socket {

    @OnOpen
    public void onOpen(Session session) {
        System.out.println("WebSocket opened: " + session.getId());
    }
    @OnMessage
    public void onMessage(String txt, Session session) throws IOException {
        System.out.println("Message received: " + txt);
        session.getBasicRemote().sendText(txt.toUpperCase());
    }

    @OnClose
    public void onClose(CloseReason reason, Session session) {
        System.out.println("Closing a WebSocket due to " + reason.getReasonPhrase());

    }
}

測試網頁


websocketecho.html


<html>
<body>
    <div>
        <input type="text" id="input" />
    </div>
    <div>
        <input type="button" id="connectBtn" value="CONNECT"
            onclick="connect()" /> <input type="button" id="sendBtn"
            value="SEND" onclick="send()" disabled="true" />
    </div>
    <div id="output">
        <p>Output</p>
    </div>
</body>

<script type="text/javascript">
    var webSocket;
    var output = document.getElementById("output");
    var connectBtn = document.getElementById("connectBtn");
    var sendBtn = document.getElementById("sendBtn");

    function connect() {
        // oprn the connection if one does not exist
        if (webSocket !== undefined
                && webSocket.readyState !== WebSocket.CLOSED) {
            return;
        }
        // Create a websocket
        webSocket = new WebSocket("ws://localhost:8080/test/toUpper");

        webSocket.onopen = function(event) {
            updateOutput("Connected!");
            connectBtn.disabled = true;
            sendBtn.disabled = false;

        };

        webSocket.onmessage = function(event) {
            updateOutput(event.data);
        };

        webSocket.onclose = function(event) {
            updateOutput("Connection Closed");
            connectBtn.disabled = false;
            sendBtn.disabled = true;
        };
    }

    function send() {
        var text = document.getElementById("input").value;
        webSocket.send(text);
    }

    function closeSocket() {
        webSocket.close();
    }

    function updateOutput(text) {
        output.innerHTML += "<br/>" + text;
    }
</script>
</html>

WebSocket Client


client 同樣分 RFC-6455 與 JSR-356 兩種


RFC-6455

WebSocketClientMain.java


package tw.com.maxkit.jetty.client;

import java.net.URI;

import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;

public class WebSocketClientMain {

    public static void main(String[] args) {
        String dest = "ws://localhost:8080/test/toUpper";
        WebSocketClient client = new WebSocketClient();
        try {
            
            ToUpperClientSocket socket = new ToUpperClientSocket();
            client.start();
            URI echoUri = new URI(dest);
            ClientUpgradeRequest request = new ClientUpgradeRequest();
            client.connect(socket, echoUri, request);
            socket.getLatch().await();
            socket.sendMessage("echo");
            socket.sendMessage("test");
            Thread.sleep(10000l);

        } catch (Throwable t) {
            t.printStackTrace();
        } finally {
            try {
                client.stop();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
}

ToUpperClientSocket.java


package tw.com.maxkit.jetty.client;

import java.io.IOException;
import java.util.concurrent.CountDownLatch;

import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketConnect;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketMessage;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;

@WebSocket
public class ToUpperClientSocket {

    private Session session;
    
    CountDownLatch latch= new CountDownLatch(1);

    @OnWebSocketMessage
    public void onText(Session session, String message) throws IOException {
        System.out.println("Message received from server:" + message);
    }

    @OnWebSocketConnect
    public void onConnect(Session session) {
        System.out.println("Connected to server");
        this.session=session;
        latch.countDown();
    }
    
    public void sendMessage(String str) {
        try {
            session.getRemote().sendString(str);
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
    }
    
    public CountDownLatch getLatch() {
        return latch;
    }

}

JSR-356 Client

WebSocket356ClientMain.java


package tw.com.maxkit.jsr356.client;

import java.net.URI;

import javax.websocket.ContainerProvider;
import javax.websocket.WebSocketContainer;

public class WebSocket356ClientMain {

    public static void main(String[] args) {
    
        try {

            String dest = "ws://localhost:8080/test/jsr356toUpper";
            ToUpper356ClientSocket socket = new ToUpper356ClientSocket();
            WebSocketContainer container = ContainerProvider.getWebSocketContainer();
            container.connectToServer(socket, new URI(dest));

            socket.getLatch().await();
            socket.sendMessage("echo356");
            socket.sendMessage("test356");
            Thread.sleep(10000l);

        } catch (Throwable t) {
            t.printStackTrace();
        }
    }
}

ToUpper356ClientSocket.java


package tw.com.maxkit.jsr356.client;

import java.io.IOException;
import java.util.concurrent.CountDownLatch;

import javax.websocket.ClientEndpoint;
import javax.websocket.CloseReason;
import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;

@ClientEndpoint
public class ToUpper356ClientSocket {

    CountDownLatch latch = new CountDownLatch(1);
    private Session session;

    @OnOpen
    public void onOpen(Session session) {
        System.out.println("Connected to server");
        this.session = session;
        latch.countDown();
    }

    @OnMessage
    public void onText(String message, Session session) {
        System.out.println("Message received from server:" + message);
    }

    @OnClose
    public void onClose(CloseReason reason, Session session) {
        System.out.println("Closing a WebSocket due to " + reason.getReasonPhrase());
    }

    public CountDownLatch getLatch() {
        return latch;
    }

    public void sendMessage(String str) {
        try {
            session.getBasicRemote().sendText(str);
        } catch (IOException e) {

            e.printStackTrace();
        }
    }
}



Sending Message to Remote Endpoint


發送訊息有幾種方式


Blocking Send Message

在完成訊息發送後,該 method 才會 return


這是發送 binary message


RemoteEndpoint remote = session.getRemote();

// Blocking Send of a BINARY message to remote endpoint
ByteBuffer buf = ByteBuffer.wrap(new byte[] { 0x11, 0x22, 0x33, 0x44 });
try
{
    remote.sendBytes(buf);
}
catch (IOException e)
{
    e.printStackTrace(System.err);
}

這是發送 text message


RemoteEndpoint remote = session.getRemote();

// Blocking Send of a TEXT message to remote endpoint
try
{
    remote.sendString("Hello World");
}
catch (IOException e)
{
    e.printStackTrace(System.err);
}

發送 Partial Message

如果有個大訊息,希望切割成多個部分,可利用 partial message sending methods,最後一個的 isLast == true


binary message


RemoteEndpoint remote = session.getRemote();

// Blocking Send of a BINARY message to remote endpoint
// Part 1
ByteBuffer buf1 = ByteBuffer.wrap(new byte[] { 0x11, 0x22 });
// Part 2 (last part)
ByteBuffer buf2 = ByteBuffer.wrap(new byte[] { 0x33, 0x44 });
try
{
    remote.sendPartialBytes(buf1,false);
    remote.sendPartialBytes(buf2,true); // isLast is true
}
catch (IOException e)
{
    e.printStackTrace(System.err);
}

text message


RemoteEndpoint remote = session.getRemote();

// Blocking Send of a TEXT message to remote endpoint
String part1 = "Hello";
String part2 = " World";
try
{
    remote.sendPartialString(part1,false);
    remote.sendPartialString(part2,true); // last part
}
catch (IOException e)
{
    e.printStackTrace(System.err);
}

發送 Ping / Pong Control Frame

PING


RemoteEndpoint remote = session.getRemote();

// Blocking Send of a PING to remote endpoint
String data = "You There?";
ByteBuffer payload = ByteBuffer.wrap(data.getBytes());
try
{
    remote.sendPing(payload);
}
catch (IOException e)
{
    e.printStackTrace(System.err);
}

PONG


RemoteEndpoint remote = session.getRemote();

// Blocking Send of a PONG to remote endpoint
String data = "Yup, I'm here";
ByteBuffer payload = ByteBuffer.wrap(data.getBytes());
try
{
    remote.sendPong(payload);
}
catch (IOException e)
{
    e.printStackTrace(System.err);
}

發非同步訊息發送

有兩個 async send message methods


  • RemoteEndpoint.sendBytesByFuture(ByteBuffer message)
  • RemoteEndpoint.sendStringByFuture(String message)

會回傳 java.util.concurrent.Future,可用來測試是否有發送成功


binary


RemoteEndpoint remote = session.getRemote();

// Async Send of a BINARY message to remote endpoint
ByteBuffer buf = ByteBuffer.wrap(new byte[] { 0x11, 0x22, 0x33, 0x44 });
remote.sendBytesByFuture(buf);

可利用 get 等待發送是否完成


RemoteEndpoint remote = session.getRemote();

// Async Send of a BINARY message to remote endpoint
ByteBuffer buf = ByteBuffer.wrap(new byte[] { 0x11, 0x22, 0x33, 0x44 });
try
{
    Future<Void> fut = remote.sendBytesByFuture(buf);
    // wait for completion (forever)
    fut.get();
}
catch (ExecutionException | InterruptedException e)
{
    // Send failed
    e.printStackTrace();
}

可在 get 加上 timeout 時間


RemoteEndpoint remote = session.getRemote();

// Async Send of a BINARY message to remote endpoint
ByteBuffer buf = ByteBuffer.wrap(new byte[] { 0x11, 0x22, 0x33, 0x44 });
Future<Void> fut = null;
try
{
    fut = remote.sendBytesByFuture(buf);
    // wait for completion (timeout)
    fut.get(2,TimeUnit.SECONDS);
}
catch (ExecutionException | InterruptedException e)
{
    // Send failed
    e.printStackTrace();
}
catch (TimeoutException e)
{
    // timeout
    e.printStackTrace();
    if (fut != null)
    {
        // cancel the message
        fut.cancel(true);
    }
}

text 訊息跟 binary 類似,只是將 sendBytesByFuture 換成 sendStringByFuture


References


Jetty WebSocket Example


Chapter 27. Jetty Websocket API