forked from SmirkCao/Lihang
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1 update perceptron README; 2 update code; 3 update unit_test
- Loading branch information
Showing
9 changed files
with
234 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
3 3 1 | ||
4 3 1 | ||
1 1 -1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
1 1 1 1 -1 | ||
1 0 -1 1 -1 | ||
0 1 -1 1 1 | ||
0 0 -1 -1 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
1 1 1 1 1 -1 | ||
1 1 0 -1 1 -1 | ||
1 0 0 -1 1 -1 | ||
0 1 1 -1 1 1 | ||
0 1 0 -1 1 1 | ||
0 0 0 -1 -1 1 | ||
1 0 1 -1 1 -1 | ||
0 0 1 -1 1 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,65 +1,59 @@ | ||
# -*-coding:utf-8-*- | ||
# Project: CH2 | ||
# Project: CH02 | ||
# Filename: perceptron | ||
# Author: 😏 <smirk dot cao at gmail dot com> | ||
import pandas as pd | ||
import numpy as np | ||
import random | ||
|
||
from sklearn.model_selection import train_test_split | ||
from sklearn.metrics import accuracy_score | ||
import argparse | ||
import logging | ||
|
||
|
||
class Perceptron(object):
    """Binary perceptron trained with stochastic updates.

    The bias is folded into the weight vector: ``w`` holds one entry per
    feature followed by the bias ``b`` as its last entry.

    Parameters
    ----------
    max_iter : int
        Maximum number of weight updates.  Training also stops once more
        than ``max_iter`` randomly drawn samples (counted over the whole
        run) were already classified correctly.
    eta : float
        Learning rate applied to every update.
    verbose : bool
        When true, ``fit`` prints the number of updates it performed.
    """

    def __init__(self,
                 max_iter=5000,
                 eta=0.00001,
                 verbose=True):
        self.eta_ = eta
        self.max_iter_ = max_iter
        # Placeholder until ``fit`` allocates the real weight vector.
        self.w = 0
        self.verbose = verbose

    def fit(self, X, y):
        """Learn ``w`` from samples ``X`` (n_samples, n_features) and labels ``y``.

        Labels may be encoded as {-1, +1} or {0, 1}; every value that is
        not strictly positive is treated as the negative class.

        Raises
        ------
        ValueError
            If ``y`` contains no samples.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        n_samples = y.shape[0]
        if n_samples == 0:
            raise ValueError("cannot fit a perceptron on an empty training set")

        self.w = np.zeros(X.shape[1] + 1)
        correct_count = 0
        n_iter_ = 0

        while n_iter_ < self.max_iter_:
            index = random.randint(0, n_samples - 1)
            # Append the constant 1 so the last weight acts as the bias.
            xx_ = np.hstack([X[index], 1])
            # Normalise the label to +/-1.  The former ``2 * y - 1`` remap
            # turned a -1 label into -3, tripling negative-class updates.
            yy_ = 1 if y[index] > 0 else -1
            wx = np.dot(self.w, xx_)

            if wx * yy_ > 0:
                # Correctly classified draw: no update, but count it so
                # training terminates once the data is fitted.
                correct_count += 1
                if correct_count > self.max_iter_:
                    break
                continue

            self.w += self.eta_ * yy_ * xx_
            n_iter_ += 1

        if self.verbose:
            print(n_iter_)

    def predict(self, X):
        """Return an array with +1 where ``w . [x, 1] > 0`` and -1 otherwise."""
        X = np.asarray(X)
        # Extra column of ones multiplies the bias term.
        X = np.hstack([X, np.ones((X.shape[0], 1))])
        # np.sign is not used because np.sign(0) == 0; a zero activation
        # must map to the negative class.
        return np.where(np.dot(X, self.w) > 0, 1, -1)
|
||
|
||
if __name__ == '__main__':
    # Script entry point: configure logging and parse the command line.
    # The earlier CSV demo is not kept: it scored 0/1 labels against the
    # +/-1 values that Perceptron.predict returns.
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

    ap = argparse.ArgumentParser()
    ap.add_argument("-p", "--path", required=False, help="path to input data file")
    # Parsed options as a plain dict, e.g. {"path": None} when -p is omitted.
    args = vars(ap.parse_args())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,56 +1,115 @@ | ||
# -*-coding:utf-8-*- | ||
# Project: CH2 | ||
# Project: CH02 | ||
# Filename: unit_test | ||
# Author: 😏 <smirk dot cao at gmail dot com> | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.metrics import accuracy_score | ||
from sklearn.datasets import load_digits | ||
from perceptron import * | ||
import numpy as np | ||
import argparse | ||
import logging | ||
import unittest | ||
|
||
|
||
def test_logic(x_, y_):
    """Fit a small perceptron on (x_, y_), then print its weights and its predictions for x_."""
    model = Perceptron(max_iter=100, eta=0.01)
    model.fit(x_, y_)
    print("w,b", model.w)
    predictions = model.predict(x_)
    print(predictions)
# Bound at import time so the tests also work under ``python -m unittest``
# or pytest, where the ``__main__`` section of this file never runs.
logger = logging.getLogger(__name__)


class TestPerceptron(unittest.TestCase):
    """Checks that Perceptron fits small separable data sets.

    The text fixtures are loaded from ``Input/`` relative to the current
    working directory.
    """

    def _assert_fits(self, clf, X, y):
        """Fit ``clf`` on (X, y), assert it reproduces ``y`` on X, return the predictions."""
        clf.fit(X, y)
        y_pred = clf.predict(X)
        self.assertListEqual(y.tolist(), y_pred.tolist())
        return y_pred

    def _check_columns(self, path, n_features, cases):
        """Fit one shared classifier against several label columns of one file.

        ``cases`` is a sequence of ``(column_index, log_message)`` pairs; the
        first ``n_features`` columns of the file are the inputs.
        """
        data_raw = np.loadtxt(path)
        X = data_raw[:, :n_features]
        clf = Perceptron(max_iter=100, eta=0.0001, verbose=False)
        for column, message in cases:
            logger.info(message)
            # fit() re-initialises the weights, so the classifier is reusable.
            with self.subTest(column=column):
                self._assert_fits(clf, X, data_raw[:, column])

    def test_e21(self):
        logger.info("test case e21")
        # data e2.1, learning rate 1
        data_raw = np.loadtxt("Input/data_2-1.txt")
        X = data_raw[:, :2]
        y = data_raw[:, -1]
        clf = Perceptron(eta=1)
        y_pred = self._assert_fits(clf, X, y)
        logger.info(clf.w)
        logger.info(str(y_pred))

    def test_e22(self):
        logger.info("test case e22")
        # same data file as e21, but with the default learning rate
        data_raw = np.loadtxt("Input/data_2-1.txt")
        X = data_raw[:, :2]
        y = data_raw[:, -1]
        clf = Perceptron(verbose=False)
        y_pred = self._assert_fits(clf, X, y)
        logger.info(clf.w)
        logger.info(str(y_pred))

    def test_logic_1(self):
        # two inputs; label columns 2, 3, 4 hold and / or / not
        self._check_columns("Input/logic_data_1.txt", 2,
                            ((2, "test case logic_1 and"),
                             (3, "test logic_1 or"),
                             (4, "test logic_1 not")))

    def test_logic_2(self):
        # three inputs; label columns 3, 4, 5 hold and / or / not
        self._check_columns("Input/logic_data_2.txt", 3,
                            ((3, "test case logic_2 and"),
                             (4, "test logic_2 or"),
                             (5, "test logic_2 not")))

    def test_mnist(self):
        # Uses the two-class subset (digits 0 and 1) of sklearn's digits data.
        raw_data = load_digits(n_class=2)
        X = raw_data.data
        y = raw_data.target
        # 0 and 1 should be fairly easy to tell apart; relabel 0 as -1 to
        # match the +/-1 output of Perceptron.predict.
        y[y == 0] = -1

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=2018)

        clf = Perceptron()
        clf.fit(X_train, y_train)
        test_predict = clf.predict(X_test)
        score = accuracy_score(y_test, test_predict)
        logger.info("The accuracy score is %2.2f", score)
        # Deliberately loose floor: the test previously asserted nothing.
        self.assertGreaterEqual(score, 0.9)
|
||
|
||
if __name__ == '__main__':
    import sys

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

    ap = argparse.ArgumentParser()
    ap.add_argument("-p", "--path", required=False, help="path to input data file")
    # Both this parser and unittest.main() read sys.argv.  Take only the
    # options declared here and pass the rest on, so that "-p FILE" does
    # not reach unittest and unittest flags such as "-v" are not rejected.
    known, remaining = ap.parse_known_args()
    args = vars(known)

    unittest.main(argv=[sys.argv[0]] + remaining)
Oops, something went wrong.