Skip to content

Commit

Permalink
fix: multi-class vs binary classifier (#1191)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewferlitsch authored Oct 26, 2022
1 parent 4f09c94 commit 1a538fd
Showing 1 changed file with 35 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -297,25 +297,32 @@
{
"cell_type": "markdown",
"metadata": {
"id": "timestamp"
"id": "06571eb4063b"
},
"source": [
"#### Timestamp\n",
"#### UUID\n",
"\n",
"If you are in a live tutorial session, you might be using a shared test account or project. To avoid name collisions between users on resources created, you create a timestamp for each instance session, and append the timestamp onto the name of resources you create in this tutorial."
"If you are in a live tutorial session, you might be using a shared test account or project. To avoid name collisions between users on the resources created, you create a UUID for each session and append it to the names of the resources you create in this tutorial."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JYtXOocrox9Q"
"id": "4e166d927e36"
},
"outputs": [],
"source": [
"from datetime import datetime\n",
"import random\n",
"import string\n",
"\n",
"TIMESTAMP = datetime.now().strftime(\"%Y%m%d%H%M%S\")"
"\n",
"# Generate a uuid of a specified length (default=8)\n",
"def generate_uuid(length: int = 8) -> str:\n",
" return \"\".join(random.choices(string.ascii_lowercase + string.digits, k=length))\n",
"\n",
"\n",
"UUID = generate_uuid()"
]
},
{
Expand Down Expand Up @@ -362,12 +369,11 @@
"import sys\n",
"\n",
"# If on Vertex AI Workbench, then don't execute this code\n",
"IS_COLAB = False\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"if not os.path.exists(\"/opt/deeplearning/metadata/env_version\") and not os.getenv(\n",
" \"DL_ANACONDA_HOME\"\n",
"):\n",
" if \"google.colab\" in sys.modules:\n",
" IS_COLAB = True\n",
" from google.colab import auth as google_auth\n",
"\n",
" google_auth.authenticate_user()\n",
Expand Down Expand Up @@ -402,7 +408,8 @@
},
"outputs": [],
"source": [
"BUCKET_URI = \"gs://[your-bucket-name]\" # @param {type:\"string\"}"
"BUCKET_NAME = \"[your-bucket-name]\" # @param {type:\"string\"}\n",
"BUCKET_URI = f\"gs://{BUCKET_NAME}\""
]
},
{
Expand All @@ -413,8 +420,9 @@
},
"outputs": [],
"source": [
"if BUCKET_URI == \"\" or BUCKET_URI is None or BUCKET_URI == \"gs://[your-bucket-name]\":\n",
" BUCKET_URI = \"gs://\" + PROJECT_ID + \"aip-\" + TIMESTAMP"
"if BUCKET_NAME == \"\" or BUCKET_NAME is None or BUCKET_NAME == \"[your-bucket-name]\":\n",
" BUCKET_NAME = PROJECT_ID + \"aip-\" + UUID\n",
" BUCKET_URI = \"gs://\" + BUCKET_NAME"
]
},
{
Expand Down Expand Up @@ -749,6 +757,7 @@
"import hypertune\n",
"import argparse\n",
"import logging\n",
"import numpy as np\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score\n",
Expand Down Expand Up @@ -790,16 +799,23 @@
"def train_model(dtrain):\n",
" logging.info(\"Start training ...\")\n",
" # Train XGBoost model\n",
" model = xgb.train({}, dtrain, num_boost_round=args.boost_rounds)\n",
" params = {\n",
" 'objective': 'multi:softprob',\n",
" 'num_class': 3\n",
" }\n",
" model = xgb.train(params, dtrain, num_boost_round=args.boost_rounds)\n",
" logging.info(\"Training completed\")\n",
" return model\n",
"\n",
"def evaluate_model(model, test_data, test_labels):\n",
" dtest = xgb.DMatrix(test_data)\n",
" pred = model.predict(dtest)\n",
" predictions = [round(value) for value in pred]\n",
" predictions = [np.around(value) for value in pred]\n",
" # evaluate predictions\n",
" accuracy = accuracy_score(test_labels, predictions)\n",
" try:\n",
" accuracy = accuracy_score(test_labels, predictions)\n",
" except:\n",
" accuracy = 0.0\n",
" logging.info(f\"Evaluation completed with model accuracy: {accuracy}\")\n",
"\n",
" # report metric for hyperparameter tuning\n",
Expand Down Expand Up @@ -893,7 +909,7 @@
},
"outputs": [],
"source": [
"DISPLAY_NAME = \"iris_\" + TIMESTAMP\n",
"DISPLAY_NAME = \"iris_\" + UUID\n",
"\n",
"job = aip.CustomPythonPackageTrainingJob(\n",
" display_name=DISPLAY_NAME,\n",
Expand Down Expand Up @@ -932,7 +948,7 @@
},
"outputs": [],
"source": [
"MODEL_DIR = \"{}/{}\".format(BUCKET_URI, TIMESTAMP)\n",
"MODEL_DIR = \"{}/{}\".format(BUCKET_URI, UUID)\n",
"DATASET_DIR = \"gs://cloud-samples-data/ai-platform/iris\"\n",
"\n",
"ROUNDS = 20\n",
Expand Down Expand Up @@ -983,7 +999,7 @@
"source": [
"if TRAIN_GPU:\n",
" model = job.run(\n",
" model_display_name=\"iris_\" + TIMESTAMP,\n",
" model_display_name=\"iris_\" + UUID,\n",
" args=CMDARGS,\n",
" replica_count=1,\n",
" machine_type=TRAIN_COMPUTE,\n",
Expand All @@ -994,7 +1010,7 @@
" )\n",
"else:\n",
" model = job.run(\n",
" model_display_name=\"iris_\" + TIMESTAMP,\n",
" model_display_name=\"iris_\" + UUID,\n",
" args=CMDARGS,\n",
" replica_count=1,\n",
" machine_type=TRAIN_COMPUTE,\n",
Expand Down Expand Up @@ -1095,7 +1111,7 @@
},
"outputs": [],
"source": [
"delete_bucket = False\n",
"delete_bucket = True\n",
"\n",
"if delete_bucket or os.getenv(\"IS_TESTING\"):\n",
" ! gsutil rm -r $BUCKET_URI"
Expand Down

0 comments on commit 1a538fd

Please sign in to comment.