Skip to content

Commit

Permalink
8강_교차검증_Cross Validation
Browse files Browse the repository at this point in the history
  • Loading branch information
dscoool authored Aug 17, 2021
1 parent 344ba4f commit b165a8b
Showing 1 changed file with 388 additions and 0 deletions.
388 changes: 388 additions & 0 deletions 8강_교차검증.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,388 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"## 8강 교차검증 (Cross Validation)\r\n",
"\r\n",
"데이터를 학습 및 검증할 때에는 \r\n",
"\r\n",
"훈련 세트(Training Set)와 테스트 세트(Test Set)으로 나누어, \r\n",
"\r\n",
"교차 검증하게 됩니다.\r\n",
"\r\n",
"\r\n",
"교차 검증(Cross Validation)은 머신러닝/딥러닝 평가에 필수적으로 사용되는 방법으로 데이터를 통한 모델을 설계한 후 모델을 검증하는 단계라고 볼 수 있습니다. 다시 말해 모델을 추정하는 데 사용되지 않았던 새로운 데이터를 예측하는 일반화 능력을 테스트하는 방법입니다. 교차 검증이 사용되는 이유를 알아보면서 어떤 방법인지 구체적으로 살펴보도록 하겠습니다. 우리가 딥러닝 모델을 사용할 데이터가 아래와 같다고 가정해 봅시다.\r\n",
"\r\n",
"![cross_val.png](attachment:cross_val.png)\r\n",
"\r\n",
"교차 검증을 이해하기 위해 먼저 주어진 데이터를 왜 훈련 세트(Train set)와 테스트 세트(Test set)로 분할하여 사용하는지에 관한 간단한 설명이 필요합니다. 예를 들어 위와 같이 레이블(Label)이 있는 데이터가 훈련 세트(Train set)과 테스트 세트(Test set)로 분할되어 구성되어 있다고 해봅시다. 만약 주어진 테스트 세트만 활용하여 모델의 성능을 확인하고 파라미터를 업데이트하는 일련의 과정을 반복한다면, 결국 해당 모델은 오로지 고정된 테스트 세트에 대해서만 잘 작동하게 됩니다. 다시 말해 테스트 세트에 과적합(overfitting)되어 실제 데이터를 가지고 예측 및 분류를 수행할 경우, 잘 작동하지 못할 가능성이 높습니다. 이러한 문제를 해결하고자 사용하는 방법이 바로 교차 검증(cross validation)입니다. 교차 검증은 테스트 세트를 하나로 고정하는 대신, 데이터의 모든 부분을 사용하여 모델을 검증하게 되며, 변동성을 낮추기 위해 여러 번의 검증 결과를 결합하여 모델의 예측 성능을 추정하게 됩니다. \r\n"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 9,
"source": [
"!pip install scikit-learn"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: scikit-learn in c:\\python39-at\\lib\\site-packages (0.24.2)\n",
"Requirement already satisfied: scipy>=0.19.1 in c:\\python39-at\\lib\\site-packages (from scikit-learn) (1.7.1)\n",
"Requirement already satisfied: joblib>=0.11 in c:\\python39-at\\lib\\site-packages (from scikit-learn) (1.0.1)\n",
"Requirement already satisfied: numpy>=1.13.3 in c:\\python39-at\\lib\\site-packages (from scikit-learn) (1.21.1)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in c:\\python39-at\\lib\\site-packages (from scikit-learn) (2.2.0)\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING: You are using pip version 21.1.3; however, version 21.2.3 is available.\n",
"You should consider upgrading via the 'c:\\python39-at\\python.exe -m pip install --upgrade pip' command.\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 1,
"source": [
"import numpy as np"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"X = np.arange(10).reshape((5, 2))"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 4,
"source": [
"X"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[0, 1],\n",
" [2, 3],\n",
" [4, 5],\n",
" [6, 7],\n",
" [8, 9]])"
]
},
"metadata": {},
"execution_count": 4
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"source": [
"y = np.arange(5)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 6,
"source": [
"y"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0, 1, 2, 3, 4])"
]
},
"metadata": {},
"execution_count": 6
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 12,
"source": [
"from sklearn.model_selection import train_test_split\r\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 13,
"source": [
"X_train"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[4, 5],\n",
" [0, 1],\n",
" [6, 7]])"
]
},
"metadata": {},
"execution_count": 13
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 14,
"source": [
"y_train"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([2, 0, 3])"
]
},
"metadata": {},
"execution_count": 14
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 15,
"source": [
"X_test"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[2, 3],\n",
" [8, 9]])"
]
},
"metadata": {},
"execution_count": 15
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 16,
"source": [
"y_test"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([1, 4])"
]
},
"metadata": {},
"execution_count": 16
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## CROSS VALIDATION\r\n",
"\r\n"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 17,
"source": [
"N = 5\r\n",
"X = np.arange(8 * N).reshape(-1, 2) * 10\r\n",
"y = np.hstack([np.ones(N), np.ones(N) * 2, np.ones(N) * 3, np.ones(N) * 4])\r\n",
"print(\"X:\\n\", X, sep=\"\")\r\n",
"print(\"y:\\n\", y, sep=\"\")"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"X:\n",
"[[ 0 10]\n",
" [ 20 30]\n",
" [ 40 50]\n",
" [ 60 70]\n",
" [ 80 90]\n",
" [100 110]\n",
" [120 130]\n",
" [140 150]\n",
" [160 170]\n",
" [180 190]\n",
" [200 210]\n",
" [220 230]\n",
" [240 250]\n",
" [260 270]\n",
" [280 290]\n",
" [300 310]\n",
" [320 330]\n",
" [340 350]\n",
" [360 370]\n",
" [380 390]]\n",
"y:\n",
"[1. 1. 1. 1. 1. 2. 2. 2. 2. 2. 3. 3. 3. 3. 3. 4. 4. 4. 4. 4.]\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 24,
"source": [
"from sklearn.model_selection import KFold \r\n",
"import numpy as np \r\n",
"\r\n",
"X = np.arange(16).reshape((8,-1)) \r\n",
"y = np.arange(8).reshape((-1,1)) \r\n",
"\r\n",
"kf = KFold(n_splits=4) \r\n",
"\r\n",
"for train_index, test_index in kf.split(X): \r\n",
" print(\"TRAIN:\", train_index, \"TEST:\", test_index) \r\n",
" X_train, X_test = X[train_index], X[test_index] \r\n",
" y_train, y_test = y[train_index], y[test_index]\r\n"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"TRAIN: [2 3 4 5 6 7] TEST: [0 1]\n",
"TRAIN: [0 1 4 5 6 7] TEST: [2 3]\n",
"TRAIN: [0 1 2 3 6 7] TEST: [4 5]\n",
"TRAIN: [0 1 2 3 4 5] TEST: [6 7]\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"stratified 는 label 의 분포를 유지시켜준다고 생각하면 된다. 즉, 각 fold 안의 데이터셋의 label 분포가 전체 데이터셋의 label 분포를 따른다.\r\n",
"\r\n",
"다시 말해서, 각 fold가 전체 데이터셋을 잘 대표한다.\r\n",
"\r\n",
"모델을 학습시킬 때 편향되지 않게 학습시킬 수 있다\r\n",
"\r\n",
"\r\n",
"\r\n",
"출처: https://sgmath.tistory.com/61 [신나게 공부하자]"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 26,
"source": [
"X = np.arange(12*2).reshape((12,-1)) \r\n",
"y = np.array([0,0,1,2,1,0,0,0,0,1,2,2]) \r\n",
"\r\n",
"skf = StratifiedKFold(n_splits=3) \r\n",
"\r\n",
"for train_index, test_index in skf.split(X, y): \r\n",
" print(\"TRAIN:\", train_index, \"TEST:\", test_index)\r\n",
"\r\n"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"TRAIN: [ 4 5 6 7 8 9 10 11] TEST: [0 1 2 3]\n",
"TRAIN: [ 0 1 2 3 7 8 9 11] TEST: [ 4 5 6 10]\n",
"TRAIN: [ 0 1 2 3 4 5 6 10] TEST: [ 7 8 9 11]\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## K-fold Cross Validation\r\n",
"\r\n",
"K-겹 교차 검증(K-fold Cross Validation)\r\n",
"K-겹 교차 검증은 가장 일반적으로 사용되는 방법으로 데이터를 K개의 세트로 분할하여 그 세트 중에서 하나를 추출하여 훈련 세트로 사용하는 것입니다. 이 과정에는 K는 단일 파라미터로 주어진 데이터 세트를 분할할 그룹의 수가 됩니다. 데이터를 K개의 데이터로 분할하여 각 반복(iteration)마다 테스트 세트를 다르게 할당한 후, 총 K개의 '데이터 폴드 세트'를 구성하게 됩니다. 아래 예시 그림의 경우 훈련 폴드(Train fold)는 4개, 테스트 폴드(Test fold)는 1개로 이루어져 있으며, 총 5개의 데이터 폴드 세트가 있습니다. 따라서 모델을 학습 및 훈련하는데 총 K번의 반복(iteration)이 필요하며, 일반적으로 각 데이터 폴드 세트에 대한 검증 결과들을 평균하여 최종적인 검증 결과를 도출하게 됩니다. K겹 교차 검증은 보통 일반화 성능을 만족시키는 최적의 하이퍼파리미터를 구하기 위한 모델 튜닝에 사용됩니다.\r\n",
"\r\n",
"![K-fold Cross Validation](kfold_01.png)\r\n",
" \r\n",
" ✓ 계층별 K-겹 교차 검증(Stratified K-fold Cross Validation)\r\n",
"계층별 K-겹 교차 검증은 K-겹 교차 검증의 변형된 방법으로 주로 분류(Classification) 문제에서 사용되며, 레이블(Label)의 분포가 각 클래스 별로 불균형한 경우 유용하게 활용됩니다. 예컨대 레이블의 분포가 불균형한 상태에서 샘플의 인덱스(index) 순으로 데이터 폴드 세트를 구성한다면 데이터를 검증하는데 오류를 야기할 수 있습니다. 이 때, 해당 교차 검증 방법을 통해 데이터 레이블 분포까지 고려하여 각 훈련 또는 테스트 폴드의 분포가 전체 데이터 세트의 분포에 근사하게 됩니다. 따라서 데이터 폴드 세트를 구성하는 함수에서 데이터 레이블의 값이 요구됩니다.\r\n",
"\r\n",
"\r\n",
"![Stratified K-fold Cross Validation](kfold_01.png)\r\n",
"\r\n"
],
"metadata": {}
}
],
"metadata": {
"orig_nbformat": 4,
"language_info": {
"name": "python",
"version": "3.9.6",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3.9.6 64-bit"
},
"interpreter": {
"hash": "63fd5069d213b44bf678585dea6b12cceca9941eaf7f819626cde1f2670de90d"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit b165a8b

Please sign in to comment.