Skip to content

Commit

Permalink
model tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
gruwnald committed Jan 29, 2024
1 parent b16724c commit 481aece
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 154 deletions.
133 changes: 74 additions & 59 deletions analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,24 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 2,
"id": "e1ebdb68d9ddea42",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-01-23T19:36:34.757732700Z",
"start_time": "2024-01-23T19:36:34.750485600Z"
}
},
"collapsed": false
},
"outputs": [],
"source": [
"from pyspark.ml.feature import VectorAssembler, MinMaxScaler\n",
"from pyspark.ml.feature import VectorAssembler, StandardScaler\n",
"from pyspark.sql.functions import when, col, array_max, udf, expr, size\n",
"from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, GBTClassifier, MultilayerPerceptronClassifier\n",
"from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n",
"from pyspark.ml import Pipeline\n",
"from pyspark.sql import SparkSession\n",
"from pyspark.sql.types import IntegerType, ArrayType, StringType\n",
"from pyspark.sql.types import IntegerType, ArrayType\n",
"\n",
"import os\n",
"import sys\n",
Expand All @@ -40,7 +40,8 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "9f76b9ee081a89f6",
"metadata": {
"ExecuteTime": {
"end_time": "2024-01-23T18:10:24.912647600Z",
Expand Down Expand Up @@ -71,8 +72,7 @@
" .option(\"rowTag\", \"row\") \\\n",
" .load(\"tex.stackexchange.com/Tags.xml\") \\\n",
" .alias('tags')"
],
"id": "9f76b9ee081a89f6"
]
},
{
"cell_type": "markdown",
Expand All @@ -86,7 +86,8 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "dcab8de6b06d41c6",
"metadata": {
"ExecuteTime": {
"end_time": "2024-01-23T18:10:26.638772500Z",
Expand All @@ -103,12 +104,12 @@
" return [tag_counts.get(tag, 0) for tag in tags]\n",
"\n",
"replace_tags_with_counts_udf = udf(replace_tags_with_counts, ArrayType(IntegerType()))"
],
"id": "dcab8de6b06d41c6"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "efd7bbf74f912b60",
"metadata": {
"ExecuteTime": {
"end_time": "2024-01-23T18:10:26.930372300Z",
Expand Down Expand Up @@ -136,12 +137,12 @@
" (questions._CreationDate - users._CreationDate).cast(\"integer\").alias(\"UserExperience\"),\n",
" when(col(\"_AcceptedAnswerId\").isNull(), 0).otherwise(1).alias(\"Accepted\")\n",
")"
],
"id": "efd7bbf74f912b60"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "94af3a489842a4d9",
"metadata": {
"ExecuteTime": {
"end_time": "2024-01-23T18:10:58.589420600Z",
Expand All @@ -156,36 +157,43 @@
"+----------+----------+-----------+------------+------------+-------+--------------+------------+---------------+----------+--------------+--------+\n",
"|QuestionId|BodyLength|TitleLength|TagsCountMax|NumberOfTags|OwnerId|OwnerDownVotes|OwnerUpVotes|OwnerReputation|OwnerViews|UserExperience|Accepted|\n",
"+----------+----------+-----------+------------+------------+-------+--------------+------------+---------------+----------+--------------+--------+\n",
"|515 |253 |11 |1424 |4 |22 |0 |1 |183 |14 |173490 |1 |\n",
"|36756 |40 |7 |2906 |3 |29 |1 |32 |15947 |415 |42537642 |1 |\n",
"|611902 |95 |6 |3490 |1 |29 |1 |32 |15947 |415 |349653635 |0 |\n",
"|148 |45 |5 |11290 |3 |34 |3 |57 |3295 |129 |5373 |1 |\n",
"|2769 |94 |12 |6023 |3 |54 |0 |18 |175 |208 |3544842 |1 |\n",
"+----------+----------+-----------+------------+------------+-------+--------------+------------+---------------+----------+--------------+--------+\n"
"| 515| 253| 11| 1424| 4| 22| 0| 1| 183| 14| 173490| 1|\n",
"| 36756| 40| 7| 2906| 3| 29| 1| 32| 15947| 415| 42537642| 1|\n",
"| 611902| 95| 6| 3490| 1| 29| 1| 32| 15947| 415| 349653635| 0|\n",
"| 148| 45| 5| 11290| 3| 34| 3| 57| 3295| 129| 5373| 1|\n",
"| 2769| 94| 12| 6023| 3| 54| 0| 18| 175| 208| 3544842| 1|\n",
"| 458427| 64| 14| 80| 1| 54| 0| 18| 175| 208| 261214023| 1|\n",
"| 2855| 89| 10| 2844| 1| 65| 6| 61| 1354| 184| 3711499| 1|\n",
"| 3363| 347| 8| 34156| 2| 65| 6| 61| 1354| 184| 4942559| 1|\n",
"| 32800| 72| 9| 34156| 1| 65| 6| 61| 1354| 184| 39473702| 1|\n",
"| 36733| 128| 5| 34156| 1| 65| 6| 61| 1354| 184| 42517217| 1|\n",
"+----------+----------+-----------+------------+------------+-------+--------------+------------+---------------+----------+--------------+--------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"questions.show(5, truncate=False)"
],
"id": "94af3a489842a4d9"
"questions.show(10)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 28,
"id": "3c4b86aa3eb93155",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-01-23T19:36:38.204290Z",
"start_time": "2024-01-23T19:36:38.189770800Z"
}
},
"collapsed": false
},
"outputs": [],
"source": [
"features = ['BodyLength', 'TitleLength', 'TagsCountMax', 'NumberOfTags', 'OwnerDownVotes', 'OwnerUpVotes', 'OwnerReputation', 'OwnerViews', 'UserExperience']\n",
"assembler = VectorAssembler(inputCols=features, outputCol=\"features\")\n",
"scaler = MinMaxScaler(inputCol=\"features\", outputCol=\"scaledFeatures\")"
"features = ['NumberOfTags', 'TagsCountMax', 'OwnerUpVotes',\n",
" 'OwnerDownVotes', 'OwnerReputation', 'OwnerViews', 'UserExperience',]\n",
"assembler = VectorAssembler(inputCols=features, outputCol=\"rawfeatures\")\n",
"scaler = StandardScaler(inputCol=\"rawfeatures\", outputCol=\"scaledFeatures\", withMean=True, withStd=True)"
]
},
{
Expand All @@ -210,87 +218,87 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 56,
"id": "bf8144da797e31d6",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-01-23T19:36:39.745638700Z",
"start_time": "2024-01-23T19:36:39.729033400Z"
}
},
"collapsed": false
},
"outputs": [],
"source": [
"train, test = questions.randomSplit([0.7, 0.3], seed=12345)"
"train, test = questions.randomSplit([0.8, 0.2])"
]
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 57,
"id": "d5014e0a1f9a13f5",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-01-23T19:39:27.681290200Z",
"start_time": "2024-01-23T19:36:40.600449200Z"
}
},
"collapsed": false
},
"outputs": [],
"source": [
"# Logistic Regression model\n",
"lr = LogisticRegression(labelCol=\"Accepted\", featuresCol=\"scaledFeatures\")\n",
"lr = LogisticRegression(labelCol=\"Accepted\", featuresCol=\"scaledFeatures\", maxIter=100)\n",
"lr_pipeline = Pipeline(stages=[assembler, scaler, lr])\n",
"lr_model = lr_pipeline.fit(train)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 58,
"id": "346ffc6627dbd121",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-01-23T19:43:02.864138300Z",
"start_time": "2024-01-23T19:39:27.683994200Z"
}
},
"collapsed": false
},
"outputs": [],
"source": [
"# Random Forest model\n",
"rf = RandomForestClassifier(labelCol=\"Accepted\", featuresCol=\"scaledFeatures\", numTrees=10)\n",
"rf = RandomForestClassifier(labelCol=\"Accepted\", featuresCol=\"scaledFeatures\", numTrees=100)\n",
"rf_pipeline = Pipeline(stages=[assembler, scaler, rf])\n",
"rf_model = rf_pipeline.fit(train)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 59,
"id": "6300714e9c0f9365",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-01-23T19:45:06.401487600Z",
"start_time": "2024-01-23T19:43:02.866133200Z"
}
},
"collapsed": false
},
"outputs": [],
"source": [
"# Gradient Boosting model\n",
"gbt = GBTClassifier(labelCol=\"Accepted\", featuresCol=\"scaledFeatures\", maxIter=10)\n",
"gbt = GBTClassifier(labelCol=\"Accepted\", featuresCol=\"scaledFeatures\", maxIter=100)\n",
"gbt_pipeline = Pipeline(stages=[assembler, scaler, gbt])\n",
"gbt_model = gbt_pipeline.fit(train)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 60,
"id": "f9a3c29342f41f28",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-01-23T19:47:58.408827400Z",
"start_time": "2024-01-23T19:45:06.404734300Z"
}
},
"collapsed": false
},
"outputs": [],
"source": [
Expand All @@ -313,14 +321,14 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 61,
"id": "74504727a2134dc0",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-01-23T19:47:58.843994900Z",
"start_time": "2024-01-23T19:47:58.408827400Z"
}
},
"collapsed": false
},
"outputs": [],
"source": [
Expand All @@ -342,25 +350,25 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 62,
"id": "d172d490df736a99",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-01-23T19:51:49.636614700Z",
"start_time": "2024-01-23T19:47:58.850519100Z"
}
},
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"===== Accuracy =====\n",
"Logistic Regression: 0.6025042444821732\n",
"Random Forest: 0.701573111205433\n",
"Gradient Boosting: 0.7057247453310697\n",
"Neural Network: 0.6203310696095077\n"
"Logistic Regression: 0.6027293555134974\n",
"Random Forest: 0.7045920908457017\n",
"Gradient Boosting: 0.7134575156888137\n",
"Neural Network: 0.66231696384102\n"
]
}
],
Expand All @@ -372,6 +380,13 @@
"print('Gradient Boosting: ', evaluator.evaluate(gbt_predictions))\n",
"print('Neural Network: ', evaluator.evaluate(nn_predictions))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
38 changes: 2 additions & 36 deletions features.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
"from pyspark.ml import Pipeline\n",
"from pyspark.ml.classification import RandomForestClassifier\n",
"from pyspark.ml.feature import VectorAssembler, StandardScaler\n",
"from pyspark.ml.tuning import ParamGridBuilder, CrossValidator\n",
"from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n",
"from pyspark.sql import SparkSession\n",
"from pyspark.sql.types import IntegerType, ArrayType\n",
"from pyspark.sql.functions import col, when, size, expr, udf, array_max\n",
Expand Down Expand Up @@ -120,38 +118,6 @@
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+----------+-----------+------------+------------+-------+--------------+------------+---------------+----------+--------------+--------+\n",
"|QuestionId|BodyLength|TitleLength|TagsCountMax|NumberOfTags|OwnerId|OwnerDownVotes|OwnerUpVotes|OwnerReputation|OwnerViews|UserExperience|Accepted|\n",
"+----------+----------+-----------+------------+------------+-------+--------------+------------+---------------+----------+--------------+--------+\n",
"| 515| 253| 11| 1424| 4| 22| 0| 1| 183| 14| 173490| 1|\n",
"| 36756| 40| 7| 2906| 3| 29| 1| 32| 15947| 415| 42537642| 1|\n",
"| 611902| 95| 6| 3490| 1| 29| 1| 32| 15947| 415| 349653635| 0|\n",
"| 148| 45| 5| 11290| 3| 34| 3| 57| 3295| 129| 5373| 1|\n",
"| 2769| 94| 12| 6023| 3| 54| 0| 18| 175| 208| 3544842| 1|\n",
"| 458427| 64| 14| 80| 1| 54| 0| 18| 175| 208| 261214023| 1|\n",
"| 2855| 89| 10| 2844| 1| 65| 6| 61| 1354| 184| 3711499| 1|\n",
"| 3363| 347| 8| 34156| 2| 65| 6| 61| 1354| 184| 4942559| 1|\n",
"| 32800| 72| 9| 34156| 1| 65| 6| 61| 1354| 184| 39473702| 1|\n",
"| 36733| 128| 5| 34156| 1| 65| 6| 61| 1354| 184| 42517217| 1|\n",
"+----------+----------+-----------+------------+------------+-------+--------------+------------+---------------+----------+--------------+--------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"data.show(10)"
]
},
{
"cell_type": "markdown",
"id": "8f448264834d27a4",
Expand Down Expand Up @@ -195,7 +161,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "60d52d7e512aefda",
"metadata": {
"ExecuteTime": {
Expand Down Expand Up @@ -224,7 +190,7 @@
],
"source": [
"print(\"Feature Importances:\")\n",
"for feature, importance in sorted(zip(assembler.getInputCols(), result), key=lambda x: x[1], reverse=True):\n",
"for feature, importance in sorted(zip(features, result), key=lambda x: x[1], reverse=True):\n",
" print(f\"{feature}: {importance}\")"
]
},
Expand Down
Loading

0 comments on commit 481aece

Please sign in to comment.