{
"cells": [
{
"cell_type": "markdown",
"id": "8cb96470-6a62-4e18-8595-8273e79949e8",
"metadata": {},
"source": [
"# Using `MLlib` from `pyspark` to Fit Machine Learning Models"
]
},
{
"cell_type": "markdown",
"id": "57124796-3da8-40dd-9bdd-780171c0ccfd",
"metadata": {},
"source": [
"Start our `spark` session."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2bdcc3f9-b2f8-45e5-ba9e-653b7b61fee7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting default log level to \"WARN\".\n",
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
"24/03/08 17:02:19 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
]
}
],
"source": [
"import pandas as pd\n",
"from pyspark.sql import SparkSession\n",
"spark = SparkSession.builder.getOrCreate()"
]
},
{
"cell_type": "markdown",
"id": "fd113d88-10dc-410f-b8bc-16a898049855",
"metadata": {},
"source": [
"Now read in a data set using `pandas` and convert it to a `spark` SQL style data frame."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9d4e18e1-0cb5-41e1-a472-11089adc8f7a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" name selling_price year seller_type \\\n",
"0 Royal Enfield Classic 350 175000 2019 Individual \n",
"1 Honda Dio 45000 2017 Individual \n",
"2 Royal Enfield Classic Gunmetal Grey 150000 2018 Individual \n",
"3 Yamaha Fazer FI V 2.0 [2016-2018] 65000 2015 Individual \n",
"4 Yamaha SZ [2013-2014] 20000 2011 Individual \n",
"\n",
" owner km_driven ex_showroom_price \n",
"0 1st owner 350 NaN \n",
"1 1st owner 5650 NaN \n",
"2 1st owner 12000 148114.0 \n",
"3 1st owner 23000 89643.0 \n",
"4 2nd owner 21000 NaN "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bike_data = pd.read_csv(\"bikeDetails.csv\")\n",
"bike_data.head()"
]
},
{
"cell_type": "markdown",
"id": "1e574259-a425-4f57-94f5-4e0e01cb4115",
"metadata": {},
"source": [
"Convert to a spark SQL data frame."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "bd795d76-463a-4260-b8de-f999d6592ed4",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-------------+----+-----------+---------+---------+-----------------+\n",
"| name|selling_price|year|seller_type| owner|km_driven|ex_showroom_price|\n",
"+--------------------+-------------+----+-----------+---------+---------+-----------------+\n",
"|Royal Enfield Cla...| 175000|2019| Individual|1st owner| 350| NaN|\n",
"| Honda Dio| 45000|2017| Individual|1st owner| 5650| NaN|\n",
"|Royal Enfield Cla...| 150000|2018| Individual|1st owner| 12000| 148114.0|\n",
"|Yamaha Fazer FI V...| 65000|2015| Individual|1st owner| 23000| 89643.0|\n",
"|Yamaha SZ [2013-2...| 20000|2011| Individual|2nd owner| 21000| NaN|\n",
"+--------------------+-------------+----+-----------+---------+---------+-----------------+\n",
"only showing top 5 rows\n",
"\n"
]
}
],
"source": [
"bike = spark.createDataFrame(bike_data)\n",
"bike.show(5)"
]
},
{
"cell_type": "markdown",
"id": "092eeeff-6522-435b-bb2f-c048d4c407e4",
"metadata": {},
"source": [
"We'll fit a linear regression model using log selling price as our response and log km driven, year, and a 1st owner indicator variable as our predictors. This means we need to create some new columns and drop a bunch of others. \n",
"- We also **need to rename the response** as 'label'. \n",
"- These first steps can easily be done using the `SQLTransformer()`, `StringIndexer()`, and `Binarizer()` [functions from `pyspark.ml.feature`](https://spark.apache.org/docs/latest/api/python/reference/pyspark.ml.html).\n",
"- Using the `MLlib` functions will allow us to place these transformations into a `pipeline` (which we'll do shortly). "
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d1bc8d13-8a31-4ea3-a73f-8063b5eace96",
"metadata": {},
"outputs": [],
"source": [
"from pyspark.ml.feature import SQLTransformer, StringIndexer, Binarizer, VectorAssembler"
]
},
{
"cell_type": "markdown",
"id": "8127493d-3737-4d32-80ab-592bb49dcea9",
"metadata": {},
"source": [
"First let's try to create our dummy variable. This isn't as easy as you'd like (nothing in here really is!). We can first map the string to a numeric index. Then we can binarize that!\n",
"\n",
"We'll use `StringIndexer()` first. An [example is given here](https://spark.apache.org/docs/latest/ml-features.html#stringindexer). This is an estimator, so we use the `.fit()` method on it; the fitted object then has a `.transform()` method."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d5cc9eab-7352-444f-b6b4-f8a7cc041dab",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-------------+----+-----------+---------+---------+-----------------+-------------+\n",
"| name|selling_price|year|seller_type| owner|km_driven|ex_showroom_price|owner_numeric|\n",
"+--------------------+-------------+----+-----------+---------+---------+-----------------+-------------+\n",
"|Royal Enfield Cla...| 175000|2019| Individual|1st owner| 350| NaN| 0.0|\n",
"| Honda Dio| 45000|2017| Individual|1st owner| 5650| NaN| 0.0|\n",
"|Royal Enfield Cla...| 150000|2018| Individual|1st owner| 12000| 148114.0| 0.0|\n",
"|Yamaha Fazer FI V...| 65000|2015| Individual|1st owner| 23000| 89643.0| 0.0|\n",
"|Yamaha SZ [2013-2...| 20000|2011| Individual|2nd owner| 21000| NaN| 1.0|\n",
"| Honda CB Twister| 18000|2010| Individual|1st owner| 60000| 53857.0| 0.0|\n",
"|Honda CB Hornet 160R| 78500|2018| Individual|1st owner| 17000| 87719.0| 0.0|\n",
"|Royal Enfield Bul...| 180000|2008| Individual|2nd owner| 39000| NaN| 1.0|\n",
"|Hero Honda CBZ ex...| 30000|2010| Individual|1st owner| 32000| NaN| 0.0|\n",
"| Bajaj Discover 125| 50000|2016| Individual|1st owner| 42000| 60122.0| 0.0|\n",
"| Yamaha FZ16| 35000|2015| Individual|1st owner| 32000| 78712.0| 0.0|\n",
"| Honda Navi| 28000|2016| Individual|2nd owner| 10000| 47255.0| 1.0|\n",
"|Bajaj Avenger Str...| 80000|2018| Individual|1st owner| 21178| 95955.0| 0.0|\n",
"| Yamaha YZF R3| 365000|2019| Individual|1st owner| 1127| 351680.0| 0.0|\n",
"| Jawa 42| 185000|2020| Individual|1st owner| 1700| NaN| 0.0|\n",
"|Suzuki Access 125...| 25000|2012| Individual|1st owner| 55000| 58314.0| 0.0|\n",
"| Hero Honda Glamour| 25000|2006| Individual|1st owner| 27000| NaN| 0.0|\n",
"| Yamaha YZF R15 S| 40000|2010| Individual|2nd owner| 45000| 117926.0| 1.0|\n",
"|Royal Enfield Cla...| 150000|2018| Individual|1st owner| 23000| 148114.0| 0.0|\n",
"| Yamaha FZ25| 120000|2018| Individual|1st owner| 39000| 132680.0| 0.0|\n",
"|Hero Passion Pro 110| 15000|2008| Individual|1st owner| 60000| NaN| 0.0|\n",
"|Honda Navi [2016-...| 26000|2016| Individual|1st owner| 17450| 44389.0| 0.0|\n",
"| Honda Activa i| 32000|2013| Individual|2nd owner| 20696| 53900.0| 1.0|\n",
"| Jawa Standard| 180000|2019| Individual|1st owner| 2000| NaN| 0.0|\n",
"|Royal Enfield Thu...| 110000|2016| Individual|1st owner| 20000| NaN| 0.0|\n",
"| Honda Dream Yuga| 25000|2012| Individual|1st owner| 35000| 56147.0| 0.0|\n",
"|TVS Apache RTR 16...| 80000|2018| Individual|1st owner| 15210| NaN| 0.0|\n",
"|Honda Navi [2016-...| 42000|2017| Individual|1st owner| 24000| 44389.0| 0.0|\n",
"|Yamaha Fazer [200...| 40000|2013| Individual|3rd owner| 35000| 84751.0| 2.0|\n",
"|Hero Honda Splend...| 21000|2009| Individual|1st owner| 10000| NaN| 0.0|\n",
"+--------------------+-------------+----+-----------+---------+---------+-----------------+-------------+\n",
"only showing top 30 rows\n",
"\n"
]
}
],
"source": [
"indexer = StringIndexer(inputCols = [\"owner\"], outputCols = [\"owner_numeric\"])\n",
"indexerTrans = indexer.fit(bike) #now indexerTrans will have a .transform() method\n",
"indexerTrans.transform(bike).show(30)"
]
},
{
"cell_type": "markdown",
"id": "640ba048-a75f-43ec-8f2b-aaca1209f35c",
"metadata": {},
"source": [
"Alright, now let's convert that to a 0/1 indicator with `Binarizer()`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8dea17f3-5b2f-48d0-b29e-33748801c87c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-------------+----+-----------+---------+---------+-----------------+-------------+---------------+\n",
"| name|selling_price|year|seller_type| owner|km_driven|ex_showroom_price|owner_numeric|owner_indicator|\n",
"+--------------------+-------------+----+-----------+---------+---------+-----------------+-------------+---------------+\n",
"|Royal Enfield Cla...| 175000|2019| Individual|1st owner| 350| NaN| 0.0| 0.0|\n",
"| Honda Dio| 45000|2017| Individual|1st owner| 5650| NaN| 0.0| 0.0|\n",
"|Royal Enfield Cla...| 150000|2018| Individual|1st owner| 12000| 148114.0| 0.0| 0.0|\n",
"|Yamaha Fazer FI V...| 65000|2015| Individual|1st owner| 23000| 89643.0| 0.0| 0.0|\n",
"|Yamaha SZ [2013-2...| 20000|2011| Individual|2nd owner| 21000| NaN| 1.0| 1.0|\n",
"| Honda CB Twister| 18000|2010| Individual|1st owner| 60000| 53857.0| 0.0| 0.0|\n",
"|Honda CB Hornet 160R| 78500|2018| Individual|1st owner| 17000| 87719.0| 0.0| 0.0|\n",
"|Royal Enfield Bul...| 180000|2008| Individual|2nd owner| 39000| NaN| 1.0| 1.0|\n",
"|Hero Honda CBZ ex...| 30000|2010| Individual|1st owner| 32000| NaN| 0.0| 0.0|\n",
"| Bajaj Discover 125| 50000|2016| Individual|1st owner| 42000| 60122.0| 0.0| 0.0|\n",
"| Yamaha FZ16| 35000|2015| Individual|1st owner| 32000| 78712.0| 0.0| 0.0|\n",
"| Honda Navi| 28000|2016| Individual|2nd owner| 10000| 47255.0| 1.0| 1.0|\n",
"|Bajaj Avenger Str...| 80000|2018| Individual|1st owner| 21178| 95955.0| 0.0| 0.0|\n",
"| Yamaha YZF R3| 365000|2019| Individual|1st owner| 1127| 351680.0| 0.0| 0.0|\n",
"| Jawa 42| 185000|2020| Individual|1st owner| 1700| NaN| 0.0| 0.0|\n",
"|Suzuki Access 125...| 25000|2012| Individual|1st owner| 55000| 58314.0| 0.0| 0.0|\n",
"| Hero Honda Glamour| 25000|2006| Individual|1st owner| 27000| NaN| 0.0| 0.0|\n",
"| Yamaha YZF R15 S| 40000|2010| Individual|2nd owner| 45000| 117926.0| 1.0| 1.0|\n",
"|Royal Enfield Cla...| 150000|2018| Individual|1st owner| 23000| 148114.0| 0.0| 0.0|\n",
"| Yamaha FZ25| 120000|2018| Individual|1st owner| 39000| 132680.0| 0.0| 0.0|\n",
"+--------------------+-------------+----+-----------+---------+---------+-----------------+-------------+---------------+\n",
"only showing top 20 rows\n",
"\n"
]
}
],
"source": [
"binaryTrans = Binarizer(threshold = 0.5, inputCol = \"owner_numeric\", outputCol = \"owner_indicator\")\n",
"binaryTrans.transform(\n",
" indexerTrans.transform(bike)\n",
").show()"
]
},
{
"cell_type": "markdown",
"id": "6335b0e2-ef57-4791-9b74-17bd25f88e24",
"metadata": {},
"source": [
"Great! Now the easy part. We just want to create some log variables and select only certain columns. The `SQLTransformer()` allows for (only) basic SQL commands. Note we want our response to be called `label` so we'll rename it here."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5b17fcfb-c334-4bce-a31f-76495305ecb0",
"metadata": {},
"outputs": [],
"source": [
"sqlTrans = SQLTransformer(\n",
" statement = \"\"\"\n",
" SELECT owner_indicator, year, log(km_driven) as log_km_driven, log(selling_price) as label FROM __THIS__\n",
" \"\"\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "eddd2f65-bb66-4854-996c-127f23ba1a79",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+---------------+----+------------------+------------------+\n",
"|owner_indicator|year| log_km_driven| label|\n",
"+---------------+----+------------------+------------------+\n",
"| 0.0|2019| 5.857933154483459|12.072541252905651|\n",
"| 0.0|2017| 8.639410824140487|10.714417768752456|\n",
"| 0.0|2018| 9.392661928770137|11.918390573078392|\n",
"| 0.0|2015|10.043249494911286|11.082142548877775|\n",
"| 1.0|2011| 9.95227771670556| 9.903487552536127|\n",
"| 0.0|2010|11.002099841204238| 9.798127036878302|\n",
"| 0.0|2018| 9.740968623038354| 11.2708539037705|\n",
"| 1.0|2008|10.571316925111784|12.100712129872347|\n",
"| 0.0|2010|10.373491181781864|10.308952660644293|\n",
"| 0.0|2016|10.645424897265505|10.819778284410283|\n",
"| 0.0|2015|10.373491181781864| 10.46310334047155|\n",
"| 1.0|2016| 9.210340371976184|10.239959789157341|\n",
"| 0.0|2018| 9.9607181859904|11.289781913656018|\n",
"| 0.0|2019| 7.027314514039777| 12.80765263256463|\n",
"| 0.0|2020| 7.438383530044307|12.128111104060462|\n",
"| 0.0|2012|10.915088464214607|10.126631103850338|\n",
"| 0.0|2006|10.203592144986466|10.126631103850338|\n",
"| 1.0|2010|10.714417768752456|10.596634733096073|\n",
"| 0.0|2018|10.043249494911286|11.918390573078392|\n",
"| 0.0|2018|10.571316925111784|11.695247021764184|\n",
"| 0.0|2008|11.002099841204238| 9.615805480084347|\n",
"| 0.0|2016| 9.767094927630573|10.165851817003619|\n",
"| 1.0|2013| 9.937695723865865|10.373491181781864|\n",
"| 0.0|2019| 7.600902459542082|12.100712129872347|\n",
"| 0.0|2016| 9.903487552536127|11.608235644774552|\n",
"| 0.0|2012| 10.46310334047155|10.126631103850338|\n",
"| 0.0|2018| 9.62970838525334|11.289781913656018|\n",
"| 0.0|2017|10.085809109330082|10.645424897265505|\n",
"| 1.0|2013| 10.46310334047155|10.596634733096073|\n",
"| 0.0|2009| 9.210340371976184| 9.95227771670556|\n",
"+---------------+----+------------------+------------------+\n",
"only showing top 30 rows\n",
"\n"
]
}
],
"source": [
"sqlTrans.transform(\n",
" binaryTrans.transform(\n",
" indexerTrans.transform(bike)\n",
" )\n",
").show(30)"
]
},
{
"cell_type": "markdown",
"id": "87f5c404-0e3e-4aa3-9f13-33cb35935c8b",
"metadata": {},
"source": [
"We also **need to put the predictors into a single column** called 'features'. \n",
"Placing multiple columns into one can be done via `VectorAssembler()` from `pyspark.ml.feature`."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "cdcf2144-ca15-4c0b-8eaf-70cc823a8ca8",
"metadata": {},
"outputs": [],
"source": [
"assembler = VectorAssembler(inputCols = [\"year\", \"log_km_driven\", \"owner_indicator\"], outputCol = \"features\", handleInvalid = 'keep')"
]
},
{
"cell_type": "markdown",
"id": "a6fb39fb-0197-43ba-96d8-85d702f4e9cf",
"metadata": {},
"source": [
"Notice that we are passing the column names that result from the previous SQL transform. The `VectorAssembler()` also has a `.transform()` method. Let's see how these would be used together to produce a new data set."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c68a3a66-8303-4f92-8e41-e43b5cc47437",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+---------------+----+------------------+------------------+--------------------+\n",
"|owner_indicator|year| log_km_driven| label| features|\n",
"+---------------+----+------------------+------------------+--------------------+\n",
"| 0.0|2019| 5.857933154483459|12.072541252905651|[2019.0,5.8579331...|\n",
"| 0.0|2017| 8.639410824140487|10.714417768752456|[2017.0,8.6394108...|\n",
"| 0.0|2018| 9.392661928770137|11.918390573078392|[2018.0,9.3926619...|\n",
"| 0.0|2015|10.043249494911286|11.082142548877775|[2015.0,10.043249...|\n",
"| 1.0|2011| 9.95227771670556| 9.903487552536127|[2011.0,9.9522777...|\n",
"| 0.0|2010|11.002099841204238| 9.798127036878302|[2010.0,11.002099...|\n",
"| 0.0|2018| 9.740968623038354| 11.2708539037705|[2018.0,9.7409686...|\n",
"| 1.0|2008|10.571316925111784|12.100712129872347|[2008.0,10.571316...|\n",
"| 0.0|2010|10.373491181781864|10.308952660644293|[2010.0,10.373491...|\n",
"| 0.0|2016|10.645424897265505|10.819778284410283|[2016.0,10.645424...|\n",
"+---------------+----+------------------+------------------+--------------------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"assembler.transform(\n",
" sqlTrans.transform(\n",
" binaryTrans.transform(\n",
" indexerTrans.transform(bike)\n",
" )\n",
" )\n",
").show(10)"
]
},
{
"cell_type": "markdown",
"id": "41f0fd7f-1ff7-4da5-8f8a-5b523cf2d891",
"metadata": {},
"source": [
"Awesome! Our data is now transformed to have the new variables we want and is in the format needed to fit a model using `MLlib`. Specifically:\n",
"- A column named `label` that is the response variable\n",
"- A column named `features` that is a column containing all the predictor values together\n",
"\n",
"Next step, we'll fit a basic multiple linear regression model using the `LinearRegression()` function from `pyspark.ml.regression`. This function does regularized regression (elastic net, which includes LASSO and ridge regression as special cases), so there are a few setup parameters we can modify to get the usual (unpenalized) MLR fit."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "69aded9d-78ae-4b53-9b41-4810faac0ea1",
"metadata": {},
"outputs": [],
"source": [
"from pyspark.ml.regression import LinearRegression"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e48e2a16-40cf-495d-9c59-fabb8e04fcb3",
"metadata": {},
"outputs": [],
"source": [
"lr = LinearRegression(regParam = 0, elasticNetParam = 0)"
]
},
{
"cell_type": "markdown",
"id": "e4dc8a8c-0c29-49e1-97a5-b1ac70f1bfa2",
"metadata": {},
"source": [
"Now we can use the `.fit()` method on the transformed data we made above."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8ab09c66-b292-43cc-9db7-07fff2cf0e02",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DataFrame[owner_indicator: double, year: bigint, log_km_driven: double, label: double, features: vector]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lr_data = assembler.transform(\n",
" sqlTrans.transform(\n",
" binaryTrans.transform(\n",
" indexerTrans.transform(bike)\n",
" )\n",
" )\n",
")\n",
"lr_data"
]
},
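{
"cell_type": "markdown",
"id": "3f2a9c1e-5b6d-4e7f-9a0b-1c2d3e4f5a6b",
"metadata": {},
"source": [
"As an aside, the nested `.transform()` calls above are exactly what a `pipeline` cleans up. A minimal sketch of combining the stages we defined (not run here; it should reproduce `lr_data`):\n",
"\n",
"```python\n",
"from pyspark.ml import Pipeline\n",
"\n",
"# Stages run in order; fitting the Pipeline fits any estimators (like StringIndexer)\n",
"# and then chains each stage's .transform()\n",
"pipe = Pipeline(stages = [indexer, binaryTrans, sqlTrans, assembler])\n",
"pipeModel = pipe.fit(bike)\n",
"lr_data_piped = pipeModel.transform(bike)\n",
"```"
]
},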
{
"cell_type": "code",
"execution_count": 14,
"id": "edf05787-5cb1-422d-9a46-a3e22433c346",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"24/03/08 17:03:10 WARN Instrumentation: [1d9ded2d] regParam is zero, which might cause numerical instability and overfitting.\n",
" \r"
]
}
],
"source": [
"lrModel = lr.fit(lr_data)"
]
},
{
"cell_type": "markdown",
"id": "1f9c894a-d9b5-48ee-ada3-0560933c16c7",
"metadata": {},
"source": [
"We can inspect the model fit using attributes of our `lrModel` object."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "3fba9c91-ba10-488d-8e8f-795fa3f7df8b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Intercept: -151.65184885852284 Coefficients: [0.08175654813171404,-0.22825871252866586,0.10021253855848164]\n"
]
}
],
"source": [
"print(\"Intercept: %s\" % str(lrModel.intercept), \"Coefficients: %s\" % str(lrModel.coefficients))"
]
},
{
"cell_type": "markdown",
"id": "07ae7bc8-6923-4a46-9790-6ecfd12bbf9b",
"metadata": {},
"source": [
"Let's inspect the training set RMSE and other model fit metrics."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "8081aeaf-b29f-4596-8586-05ddc2c06f98",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+\n",
"| residuals|\n",
"+--------------------+\n",
"|-0.00495628658079...|\n",
"| -0.5646701626673813|\n",
"| 0.7294822208803797|\n",
"| 0.28700612130943703|\n",
"| -0.6856003220335012|\n",
"+--------------------+\n",
"only showing top 5 rows\n",
"\n",
"RMSE: 0.509892\n",
"r2: 0.485191\n"
]
}
],
"source": [
"trainingSummary = lrModel.summary\n",
"trainingSummary.residuals.show(5)\n",
"print(\"RMSE: %f\" % trainingSummary.rootMeanSquaredError)\n",
"print(\"r2: %f\" % trainingSummary.r2)"
]
},
{
"cell_type": "markdown",
"id": "eeb91591-2a1b-4695-9ddf-a508217ca507",
"metadata": {},
"source": [
"Of course we will often want to do predictions as well. This is easy to do! The model itself actually becomes a transformer once it is fit! We just use the `.transform()` method and pass it data in the same format as the data the model was trained on. Here we'll just look at the predictions for the data used to fit the model."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "1adaaba3-85eb-435a-ba23-8e4c9ea0101e",
"metadata": {},
"outputs": [],
"source": [
"preds = lrModel.transform(lr_data)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "fa5bcea2-110d-4581-9a88-7001b7009ef7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+---------------+----+------------------+------------------+--------------------+------------------+\n",
"|owner_indicator|year| log_km_driven| label| features| prediction|\n",
"+---------------+----+------------------+------------------+--------------------+------------------+\n",
"| 0.0|2019| 5.857933154483459|12.072541252905651|[2019.0,5.8579331...|12.077497539486444|\n",
"| 0.0|2017| 8.639410824140487|10.714417768752456|[2017.0,8.6394108...|11.279087931419838|\n",
"| 0.0|2018| 9.392661928770137|11.918390573078392|[2018.0,9.3926619...|11.188908352198013|\n",
"| 0.0|2015|10.043249494911286|11.082142548877775|[2015.0,10.043249...|10.795136427568337|\n",
"| 1.0|2011| 9.95227771670556| 9.903487552536127|[2011.0,9.9522777...|10.589087874569628|\n",
"+---------------+----+------------------+------------------+--------------------+------------------+\n",
"only showing top 5 rows\n",
"\n"
]
}
],
"source": [
"preds.show(5)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "85bfad84-e0e8-40d0-9d4d-fd5a261b2053",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------------------+------------------+\n",
"| label| prediction|\n",
"+------------------+------------------+\n",
"|12.072541252905651|12.077497539486444|\n",
"|10.714417768752456|11.279087931419838|\n",
"|11.918390573078392|11.188908352198013|\n",
"|11.082142548877775|10.795136427568337|\n",
"| 9.903487552536127|10.589087874569628|\n",
"+------------------+------------------+\n",
"only showing top 5 rows\n",
"\n"
]
}
],
"source": [
"preds.select(\"label\", \"prediction\").show(5)"
]
},
{
"cell_type": "markdown",
"id": "e001274a-e662-484a-a7da-89a082d3ef86",
"metadata": {},
"source": [
"Just a sanity check to show we get the same RMSE if we do it with these values:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "a4951eef-915f-4ad8-8ecc-7ffb4e346313",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.5098915575316225"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"my_preds = preds.select(\"label\", \"prediction\").toPandas()\n",
"import numpy as np\n",
"np.sqrt(np.mean((my_preds[\"label\"]-my_preds[\"prediction\"])**2))"
]
},
{
"cell_type": "markdown",
"id": "d1e56c09-ae72-4381-9e3a-3195d74372e8",
"metadata": {},
"source": [
"One more sanity check. Compare with `sklearn`.\n",
"\n",
"Note: I had an error with the `scipy` module not being recognized. I opened a new terminal window (file --> New Launcher, choose Terminal at the bottom) and had to submit the code:\n",
"`pip install scipy --force`"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "238ffde7-c389-44e5-a9ac-3b586335a956",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-151.55163631488233 [ 0.08175655 -0.22825871 -0.10021254]\n",
"0.5098915575316225\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jupyter-jbpost2@ncsu.edu/.local/lib/python3.9/site-packages/sklearn/metrics/_regression.py:483: FutureWarning: 'squared' is deprecated in version 1.4 and will be removed in 1.6. To calculate the root mean squared error, use the function'root_mean_squared_error'.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from sklearn import linear_model\n",
"from sklearn.metrics import mean_squared_error\n",
"bike_data = pd.read_csv(\"https://www4.stat.ncsu.edu/~online/datasets/bikeDetails.csv\")\n",
"bike_data['log_selling_price'] = np.log(bike_data['selling_price'])\n",
"bike_data['log_km_driven'] = np.log(bike_data['km_driven'])\n",
"bike_data['one_owner'] = pd.get_dummies(data = bike_data['owner'])['1st owner']\n",
"mlr_fit = linear_model.LinearRegression() #Create a reg object\n",
"mlr_fit.fit(bike_data[['year','log_km_driven','one_owner']], bike_data['log_selling_price'].values)\n",
"print(mlr_fit.intercept_, mlr_fit.coef_)\n",
"print(mean_squared_error(bike_data['log_selling_price'], mlr_fit.predict(bike_data[[\"year\", \"log_km_driven\", \"one_owner\"]]), squared = False))"
]
},
{
"cell_type": "markdown",
"id": "1a156fd2-4b04-4d25-ac8f-0868c4f25bbe",
"metadata": {},
"source": [
"Cool! We've fit an MLR model using the `pyspark` `MLlib` library!"
]
},
{
"cell_type": "markdown",
"id": "e5cc71e4-fcba-4a40-bdee-60d9dd906437",
"metadata": {},
"source": [
"# Using Cross-Validation to Select Our Model"
]
},
{
"cell_type": "markdown",
"id": "61fc9e99-0fa7-4fab-b392-5a8d12ea5fd2",
"metadata": {},
"source": [
"Of course we know that what we care about is the quality of the prediction our model makes on **data it wasn't trained on**. Generally, this means we need to either \n",
"- split our data into a training and test (sometimes called validation) set\n",
"- use cross-validation\n",
"\n",
"And of course, we've talked about the need for both when selecting from many types of models.\n",
"\n",
"Let's focus on doing k-fold CV using `pyspark`. We need to set up our grid of tuning parameters (if applicable) and then run our CV algorithm. This can be done using the `ParamGridBuilder()` and `CrossValidator()` functions from `pyspark.ml.tuning`, respectively."
]
},
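{
"cell_type": "markdown",
"id": "7a8b9c0d-1e2f-4a3b-8c4d-5e6f7a8b9c0d",
"metadata": {},
"source": [
"As a quick note, the training/test split route is easy in `pyspark` via `.randomSplit()`. A minimal sketch (the fractions and seed here are arbitrary choices):\n",
"\n",
"```python\n",
"# Roughly 80/20 train/test split of the transformed data; rows are assigned\n",
"# randomly, so the realized fractions are approximate\n",
"train, test = lr_data.randomSplit([0.8, 0.2], seed = 558)\n",
"```"
]
},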
{
"cell_type": "code",
"execution_count": 24,
"id": "67652b75-cc9a-4fe7-8fa0-49e3bd13fad2",
"metadata": {},
"outputs": [],
"source": [
"from pyspark.ml.tuning import CrossValidator, ParamGridBuilder"
]
},
{
"cell_type": "markdown",
"id": "5548fcca-f145-40fe-a592-8c964fea4f43",
"metadata": {},
"source": [
"The only tuning parameter we'll worry about is the `regParam`, which is $\\lambda$ in the penalty part of the loss function:\n",
"$$\\lambda\\left(\\alpha||w||_1+\\frac{1-\\alpha}{2}||w||_2^2\\right)$$\n",
"If we set the `elasticNetParam` ($\\alpha$ above) to 1, then we are doing LASSO regression.\n",
"\n",
"We can see some [details for each algorithm via the help files](https://spark.apache.org/docs/latest/api/python/reference/pyspark.ml.html#regression). We can look at how well the model does when we use different amounts of LASSO regularization. We use `ParamGridBuilder()` and `.addGrid()` to specify the tuning parameter values. Then finally we use the `.build()` method to instruct it to build the grid."
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "749f9944-99b6-47b4-a98a-4b34f8b4a455",
"metadata": {},
"outputs": [],
"source": [
"lr = LinearRegression()\n",
"paramGrid = ParamGridBuilder() \\\n",
" .addGrid(lr.regParam, [0, 0.05, 0.1, 0.15]) \\\n",
" .addGrid(lr.elasticNetParam, [1]) \\\n",
" .build()"
]
},
{
"cell_type": "markdown",
"id": "716dcfac-9400-400f-b699-69cba47a9e82",
"metadata": {},
"source": [
"Next up, we can use the `CrossValidator()` function to run k-fold CV over the grid of tuning parameters we just set up. \n",
"\n",
"We need to tell `pyspark` what loss function to use when evaluating. This is done via the `RegressionEvaluator()` function from `pyspark.ml.evaluation`. To override the default metric used we can specify it explicitly using the `metricName=` argument when calling the function. `rmse` is the default here."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "5baecc00-3a23-4361-867e-06b25bdf1214",
"metadata": {},
"outputs": [],
"source": [
"from pyspark.ml.evaluation import RegressionEvaluator"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "f3e7849f-d98d-4715-ace8-05bc31e0dc0c",
"metadata": {},
"outputs": [],
"source": [
"crossval = CrossValidator(estimator = lr,\n",
" estimatorParamMaps = paramGrid,\n",
" evaluator = RegressionEvaluator(metricName='rmse'),\n",
" numFolds=5)"
]
},
{
"cell_type": "markdown",
"id": "9d776905-895e-440f-96a1-e7b7a90a0aaa",
"metadata": {},
"source": [
"We've now set up the `crossval` object. Just like with `sklearn` we now use the `.fit()` method to actually fit the models."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "0a508983-2880-44ae-b288-5bb2308ed1e9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"24/03/08 17:06:44 WARN Instrumentation: [e4f7ec65] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:06:49 WARN Instrumentation: [96732fed] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:06:52 WARN Instrumentation: [ccc91038] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:06:54 WARN Instrumentation: [04867e84] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:06:56 WARN Instrumentation: [e89224fc] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:06:58 WARN Instrumentation: [ca268c57] regParam is zero, which might cause numerical instability and overfitting.\n"
]
}
],
"source": [
"cvModel = crossval.fit(lr_data)"
]
},
{
"cell_type": "markdown",
"id": "d3ce9f80-5e7e-4c5f-b644-15d4358a7ad8",
"metadata": {},
"source": [
"By default, the only model returned is the **best** model as measured by our loss function. \n",
"\n",
"To determine which model was returned we can look at the `.avgMetrics` attribute along with the `paramGrid` object we used to fit the models."
]
},
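{
"cell_type": "markdown",
"id": "9e8d7c6b-5a4f-4e3d-8c2b-1a0f9e8d7c6b",
"metadata": {},
"source": [
"Alternatively, recent versions of `pyspark` expose the tuning parameter values directly on the fitted model, so we can ask the best model which settings it was trained with. A sketch (assuming a `pyspark` 3+ model object):\n",
"\n",
"```python\n",
"best = cvModel.bestModel\n",
"# The fitted model carries the parameter values used to train it\n",
"print(best.getRegParam(), best.getElasticNetParam())\n",
"```"
]
},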
{
"cell_type": "code",
"execution_count": 29,
"id": "bdd219e7-b39d-4a2f-a679-1cc94adfe19a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(0, 0.5112406194486974),\n",
" (0.05, 0.5136269855905575),\n",
" (0.1, 0.5235001947739562),\n",
" (0.15, 0.5395464815213404)]"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(zip([0, 0.05, 0.1, 0.15], cvModel.avgMetrics))"
]
},
{
"cell_type": "markdown",
"id": "15dd89db-2aeb-4de1-9ad6-6c9a73e28b60",
"metadata": {},
"source": [
"We can see the best model's parameter estimates as well."
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "b2f19cee-96e1-448b-99bc-14d0d8e89735",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-151.6518488577565 [0.08175654813133452,-0.22825871252886246,0.10021253855757116]\n"
]
}
],
"source": [
"print(cvModel.bestModel.intercept, cvModel.bestModel.coefficients)"
]
},
{
"cell_type": "markdown",
"id": "2e837395-8910-47ac-abda-a611870bf2cc",
"metadata": {},
"source": [
"As with the single linear model fit we did above, this new object, `cvModel`, is now a transformer as well. This allows us to get predictions using the best-fitting model. If we had a test set, we could use it to get test set predictions easily. As we didn't split off a test set, let's just see how to use it to predict on the training data (i.e. return the fitted values for the model)."
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "e15c4b9a-3b9c-4733-85a6-ee2eed11d066",
"metadata": {},
"outputs": [],
"source": [
"predsCV = cvModel.transform(lr_data)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "e1e2f6d6-85fe-42df-b47c-1173bc0aa901",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------------------+------------------+\n",
"| label| prediction|\n",
"+------------------+------------------+\n",
"|12.072541252905651|12.077497539485364|\n",
"|10.714417768752456|11.279087931418985|\n",
"|11.918390573078392|11.188908352196648|\n",
"|11.082142548877775|10.795136427567968|\n",
"| 9.903487552536127|10.589087874569884|\n",
"+------------------+------------------+\n",
"only showing top 5 rows\n",
"\n"
]
}
],
"source": [
"predsCV.select(\"label\", \"prediction\").show(5)"
]
},
{
"cell_type": "markdown",
"id": "98b942ad-bc72-4a1f-990b-43cf7103b0cd",
"metadata": {},
"source": [
"As our model chosen was not regularized, we get the same predictions as previous!"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "0bf159e1-cd81-469f-860d-d9d623bbdc09",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------------------+------------------+\n",
"| label| prediction|\n",
"+------------------+------------------+\n",
"|12.072541252905651|12.077497539486444|\n",
"|10.714417768752456|11.279087931419838|\n",
"|11.918390573078392|11.188908352198013|\n",
"|11.082142548877775|10.795136427568337|\n",
"| 9.903487552536127|10.589087874569628|\n",
"+------------------+------------------+\n",
"only showing top 5 rows\n",
"\n"
]
}
],
"source": [
"preds.select(\"label\", \"prediction\").show(5)"
]
},
{
"cell_type": "markdown",
"id": "f3c79f2f-7263-406b-a9ff-7be469bcddcc",
"metadata": {},
"source": [
"We can find the training RMSE by passing these predictions to the `RegressionEvaluator().evaluate()` method."
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "ab922b95-c156-42ba-8ee5-47198257caf6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.509891557531623"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"RegressionEvaluator().evaluate(cvModel.transform(lr_data))"
]
},
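{
"cell_type": "markdown",
"id": "7c1d2e3f-4a5b-4c6d-8e9f-0a1b2c3d4e01",
"metadata": {},
"source": [
"The value above is the evaluator's default metric, RMSE: $\\sqrt{\\frac{1}{n}\\sum_i (y_i - \\hat{y}_i)^2}$. As a sanity check, here is a hand-rolled `numpy` sketch of the same computation (ours, not part of `MLlib`):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# hand-rolled RMSE: the default metric of RegressionEvaluator()\n",
"def rmse(label, prediction):\n",
"    label, prediction = np.asarray(label), np.asarray(prediction)\n",
"    return float(np.sqrt(np.mean((label - prediction) ** 2)))\n",
"\n",
"# toy label/prediction values just to exercise the function\n",
"rmse([12.07, 10.71, 11.92], [12.08, 11.28, 11.19])\n",
"```"
]
},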
{
"cell_type": "markdown",
"id": "e003660f-23f3-4b35-a719-40d7dccb9c8d",
"metadata": {},
"source": [
"# Pipelines and the Training/Test Split"
]
},
{
"cell_type": "markdown",
"id": "2911145f-aabc-4c3a-b3c9-01fe66ddcaba",
"metadata": {},
"source": [
"Of course we often want to have a training and test set so we can do all of our model fitting and tuning on the training set and then see how different model types compare on the test set. We should always split our data into training and test sets first, before doing transformations. Reason being:\n",
"- If we do transformations on the training set, we want to use the exact same transformations on the test set\n",
"- For instance, if we center some predictors (subtract the mean) in our training set, we want to subtract the mean of the **training** set to standardize the test set values prior to doing test set predictions!\n",
"- By splitting the data first, we can make sure we aren't using any test data (and the knowledge that comes with it) when training our models\n",
"\n",
"By setting up a **pipeline** in `MLlib` we can easily (ha, that's kind of a joke) create the sequence of transformations/model fits and apply those same transformations on our test set!\n",
"\n",
"Let's start with the training/test split. This can be done using the `.randomSplit()` method on a spark SQL style data frame."
]
},
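{
"cell_type": "markdown",
"id": "7c1d2e3f-4a5b-4c6d-8e9f-0a1b2c3d4e02",
"metadata": {},
"source": [
"To make the centering point concrete, here is a plain `numpy` sketch (in `MLlib` the analogous tool would be a feature transformer such as `StandardScaler`, fit on the training data only):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"train_x = np.array([2.0, 4.0, 6.0])\n",
"test_x = np.array([3.0, 5.0])\n",
"\n",
"train_mean = train_x.mean()          # computed on the training data only\n",
"train_centered = train_x - train_mean\n",
"test_centered = test_x - train_mean  # reuse the TRAINING mean on the test set\n",
"```"
]
},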
{
"cell_type": "code",
"execution_count": 35,
"id": "eea5f9d6-be15-4e2e-8f73-06cc555e9245",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"840 221\n"
]
}
],
"source": [
"train, test = bike.randomSplit([0.8,0.2], seed = 1)\n",
"print(train.count(), test.count())"
]
},
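{
"cell_type": "markdown",
"id": "7c1d2e3f-4a5b-4c6d-8e9f-0a1b2c3d4e03",
"metadata": {},
"source": [
"Note that `.randomSplit()` assigns rows randomly, so the counts are only approximately 80/20 (840 + 221 = 1061 here), not exact. A rough pure-python sketch of the idea (not Spark's exact algorithm):\n",
"\n",
"```python\n",
"import random\n",
"\n",
"random.seed(1)\n",
"# each row independently lands in train with probability 0.8,\n",
"# so the split sizes are approximate rather than exactly 80/20\n",
"n = 1061  # 840 + 221 rows, per the output above\n",
"in_train = [random.random() < 0.8 for _ in range(n)]\n",
"n_train = sum(in_train)\n",
"n_test = n - n_train\n",
"```"
]
},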
{
"cell_type": "markdown",
"id": "79577693-3b0f-4294-a24a-39c9a9459fd1",
"metadata": {},
"source": [
"Now that we've split our data we can **set up** the transformations to do. Just like with other spark stuff, things are set up as a DAG and done only at run time. \n",
"We created the transformation plans previously so we'll just pull them down here for clarity."
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "3bd14b8c-f4d7-4a8b-832d-1506c81a6fab",
"metadata": {},
"outputs": [],
"source": [
"#Sequence of transformations\n",
"#indexerTrans\n",
"#binaryTrans\n",
"#sqlTrans\n",
"#assembler\n",
"#(linear model or other model here!)"
]
},
{
"cell_type": "markdown",
"id": "92451add-fe23-4196-86b4-0e3971f0bb4f",
"metadata": {},
"source": [
"We are then ready to create our pipeline that includes these transformations and the model that we want to fit. We use the `Pipeline()` function from the `pyspark.ml` module to set up our sequence of transformations/estimators."
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "aeb06165-54e7-4184-b9ce-d0efd5fe34ac",
"metadata": {},
"outputs": [],
"source": [
"from pyspark.ml import Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "930fceb4-f177-4fa5-9a67-337e26819844",
"metadata": {},
"outputs": [],
"source": [
"lr = LinearRegression()\n",
"paramGrid = ParamGridBuilder() \\\n",
" .addGrid(lr.regParam, [0, 0.01, 0.04, 0.06, 0.1]) \\\n",
" .addGrid(lr.elasticNetParam, [0, 0.5, 0.8, 0.9, 1]) \\\n",
" .build()\n",
"pipeline = Pipeline(stages = [indexerTrans, binaryTrans, sqlTrans, assembler, lr])"
]
},
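{
"cell_type": "markdown",
"id": "7c1d2e3f-4a5b-4c6d-8e9f-0a1b2c3d4e04",
"metadata": {},
"source": [
"The grid above crosses 5 `regParam` values with 5 `elasticNetParam` values, so each CV fold fits 25 models (125 fits total with 5-fold CV). The cross product can be sketched with the standard library's `itertools.product`:\n",
"\n",
"```python\n",
"from itertools import product\n",
"\n",
"reg_params = [0, 0.01, 0.04, 0.06, 0.1]\n",
"elastic_net_params = [0, 0.5, 0.8, 0.9, 1]\n",
"\n",
"# one (regParam, elasticNetParam) pair per model fit in each CV fold\n",
"grid = list(product(reg_params, elastic_net_params))\n",
"len(grid) * 5  # total fits with numFolds=5\n",
"```"
]
},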
{
"cell_type": "markdown",
"id": "4d31b5fc-e665-41c6-b074-d28d89c95ba0",
"metadata": {},
"source": [
"Our DAG is now set up and we can use this pipeline within our CV calculation (or basic model fitting). What's nice is that since it contains all the information about the transformations done, we can easily apply this to a test set and not have to worry about how to do the transformations/prepping of the data on that set. \n",
"\n",
"Instead of using the model type as the `estimator` we'll pass the pipeline we've set up."
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "3cd29241-7948-4521-92ba-3e90655a1bca",
"metadata": {},
"outputs": [],
"source": [
"crossval = CrossValidator(estimator = pipeline,\n",
" estimatorParamMaps = paramGrid,\n",
" evaluator = RegressionEvaluator(),\n",
" numFolds=5)"
]
},
{
"cell_type": "markdown",
"id": "0f55a754-3bd5-4e4e-ad3b-543cbd0cf1cf",
"metadata": {},
"source": [
"With everything set up, we can now fit our models!"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "dbdf6b6d-8cb6-412e-b7bc-4f4f8e302827",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"24/03/08 17:08:42 WARN Instrumentation: [45408c1a] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:08:43 WARN Instrumentation: [bdf949e2] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:08:43 WARN Instrumentation: [c9e03b2c] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:08:44 WARN Instrumentation: [cc03885d] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:08:45 WARN Instrumentation: [8f5c1468] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:08:56 WARN Instrumentation: [d5eaed65] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:08:57 WARN Instrumentation: [e1d0afb2] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:08:58 WARN Instrumentation: [1e88fc13] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:08:58 WARN Instrumentation: [769528f1] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:08:59 WARN Instrumentation: [a95045a7] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:09:10 WARN Instrumentation: [9e6d1f94] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:09:11 WARN Instrumentation: [778a593a] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:09:11 WARN Instrumentation: [7a4bdc3a] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:09:12 WARN Instrumentation: [ffefcfcf] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:09:12 WARN Instrumentation: [333c0cdd] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:09:23 WARN Instrumentation: [0c04f6cd] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:09:24 WARN Instrumentation: [a9177fea] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:09:24 WARN Instrumentation: [efc198f1] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:09:25 WARN Instrumentation: [4c91d04f] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:09:25 WARN Instrumentation: [0d67662a] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:09:35 WARN Instrumentation: [1f57078a] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:09:36 WARN Instrumentation: [5973b334] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:09:36 WARN Instrumentation: [f17e57e2] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:09:37 WARN Instrumentation: [d043d7d8] regParam is zero, which might cause numerical instability and overfitting.\n",
"24/03/08 17:09:37 WARN Instrumentation: [750b1ce7] regParam is zero, which might cause numerical instability and overfitting.\n"
]
}
],
"source": [
"# Run cross-validation, and choose the best set of parameters.\n",
"cvModel = crossval.fit(train)"
]
},
{
"cell_type": "markdown",
"id": "27e22222-1aaa-4dbd-a5fb-ee397aefc048",
"metadata": {},
"source": [
"Check which model is chosen as the best:"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "254ce767-18d7-47b2-ba79-7ddf3844f5f4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[0.5124021026719168, dict_values([0.0, 0.0])],\n",
" [0.5124021026718817, dict_values([0.0, 0.5])],\n",
" [0.5124021026718577, dict_values([0.0, 0.8])],\n",
" [0.5124021026718556, dict_values([0.0, 0.9])],\n",
" [0.5124021026719084, dict_values([0.0, 1.0])],\n",
" [0.5123205996805408, dict_values([0.01, 0.0])],\n",
" [0.5123483858580572, dict_values([0.01, 0.5])],\n",
" [0.5122975622688871, dict_values([0.01, 0.8])],\n",
" [0.5122867285770079, dict_values([0.01, 0.9])],\n",
" [0.5122589945611346, dict_values([0.01, 1.0])],\n",
" [0.5123909761343102, dict_values([0.04, 0.0])],\n",
" [0.5122102563576977, dict_values([0.04, 0.5])],\n",
" [0.5120946331251414, dict_values([0.04, 0.8])],\n",
" [0.5121238218400215, dict_values([0.04, 0.9])],\n",
" [0.51224440770079, dict_values([0.04, 1.0])],\n",
" [0.512663579012802, dict_values([0.06, 0.0])],\n",
" [0.5126627597862455, dict_values([0.06, 0.5])],\n",
" [0.5137013318120469, dict_values([0.06, 0.8])],\n",
" [0.5142717497055167, dict_values([0.06, 0.9])],\n",
" [0.5148905572778366, dict_values([0.06, 1.0])],\n",
" [0.5136433512562688, dict_values([0.1, 0.0])],\n",
" [0.5159728675900753, dict_values([0.1, 0.5])],\n",
" [0.5199269354732108, dict_values([0.1, 0.8])],\n",
" [0.5215278837702534, dict_values([0.1, 0.9])],\n",
" [0.523281713834437, dict_values([0.1, 1.0])]]"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"my_list = []\n",
"for i in range(len(paramGrid)):\n",
" my_list.append([cvModel.avgMetrics[i], paramGrid[i].values()])\n",
"my_list"
]
},
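{
"cell_type": "markdown",
"id": "7c1d2e3f-4a5b-4c6d-8e9f-0a1b2c3d4e05",
"metadata": {},
"source": [
"Rather than scanning the printed list by eye, the best combination is just a `min` over the zipped (metric, params) pairs, since a smaller RMSE is better. A sketch with stand-in numbers (on the real objects this would be `zip(cvModel.avgMetrics, paramGrid)`):\n",
"\n",
"```python\n",
"# stand-in (metric, params) pairs; in practice use zip(cvModel.avgMetrics, paramGrid)\n",
"avg_metrics = [0.5124, 0.5121, 0.5137]\n",
"params = [(0.0, 0.0), (0.04, 0.8), (0.06, 0.8)]\n",
"\n",
"best_metric, best_params = min(zip(avg_metrics, params), key=lambda pair: pair[0])\n",
"```"
]
},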
{
"cell_type": "markdown",
"id": "52bc633a-c4af-41b8-8a1c-5f603ff9bd0b",
"metadata": {},
"source": [
"Use that best model to get test error on the test set."
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "3e5818e2-3522-46a3-af5c-e001962df133",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+---------------+----+------------------+------------------+--------------------+------------------+\n",
"|owner_indicator|year| log_km_driven| label| features| prediction|\n",
"+---------------+----+------------------+------------------+--------------------+------------------+\n",
"| 0.0|2019| 5.857933154483459|12.072541252905651|[2019.0,5.8579331...|11.953749958528732|\n",
"| 0.0|2020| 7.438383530044307|12.128111104060462|[2020.0,7.4383835...|11.705822488439821|\n",
"| 1.0|2013| 9.937695723865865|10.373491181781864|[2013.0,9.9376957...|10.679205674574945|\n",
"| 0.0|2018|10.571316925111784|11.695247021764184|[2018.0,10.571316...|10.919885599143953|\n",
"| 0.0|2017| 7.824046010856292|10.915088464214607|[2017.0,7.8240460...|11.405445806006924|\n",
"+---------------+----+------------------+------------------+--------------------+------------------+\n",
"only showing top 5 rows\n",
"\n"
]
}
],
"source": [
"cvModel.transform(test).show(5)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "8acc2249-0b6a-48dd-826b-461a380cd802",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.5114772577953631\n"
]
}
],
"source": [
"test_error = RegressionEvaluator().evaluate(cvModel.transform(test))\n",
"print(test_error)"
]
},
{
"cell_type": "markdown",
"id": "a6a633d5-7132-4736-a960-82d37fd41704",
"metadata": {},
"source": [
"Fantastic! We can now set up a pipeline and use CV to fit our model. Then we can take the best model and find test set predictions! \n",
"\n",
"Remember, once we have our final model we fit that model on the entire data set. Then we use it for future predictions and what-not."
]
},
{
"cell_type": "markdown",
"id": "469bc77e-4a8f-4c5c-a7c6-a786217440d1",
"metadata": {},
"source": [
"# `MLflow` Basics\n",
"\n",
"Note: To run this code in our jupyterhub I had to do a few things first. I needed to force install a few modules. Open a new terminal window and do the following:\n",
"\n",
"`pip install requests --force` \n",
"`pip install pyyaml --force` \n",
"`pip install entrypoints --force`\n",
"\n",
"\n",
"This section is modified from the [tutorial given here](https://github.com/mlflow/mlflow/blob/master/examples/sklearn_elasticnet_wine/train.ipynb). Although this uses `sklearn`, it can be modified to work with `MLlib`!\n",
"\n",
"First, they create a function to train an elastic net model specifically on the wine data we've used before. This function takes in the $\\alpha$ and $L_1$ ratio values from the `sklearn` `ElasticNet()` function. This function can then be passed different values of these and the given metrics will be found and logged."
]
},
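{
"cell_type": "markdown",
"id": "7c1d2e3f-4a5b-4c6d-8e9f-0a1b2c3d4e06",
"metadata": {},
"source": [
"For reference, `sklearn`'s `ElasticNet` penalizes the squared error with `alpha * (l1_ratio * ||w||_1 + 0.5 * (1 - l1_ratio) * ||w||_2^2)`, so `l1_ratio = 1` gives the lasso penalty and `l1_ratio = 0` the ridge penalty. A small `numpy` sketch of just that penalty term:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def elastic_net_penalty(w, alpha, l1_ratio):\n",
"    # sklearn-style penalty: alpha * (l1_ratio * ||w||_1 + 0.5 * (1 - l1_ratio) * ||w||_2^2)\n",
"    w = np.asarray(w)\n",
"    return alpha * (l1_ratio * np.abs(w).sum() + 0.5 * (1.0 - l1_ratio) * (w ** 2).sum())\n",
"```"
]
},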
{
"cell_type": "code",
"execution_count": 1,
"id": "19f91f92-e399-4361-ac2a-8c7cbc76059f",
"metadata": {},
"outputs": [],
"source": [
"# Wine Quality Sample\n",
"def train(in_alpha, in_l1_ratio):\n",
" import os\n",
" import warnings\n",
" import sys\n",
"\n",
" import pandas as pd\n",
" import numpy as np\n",
" from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
" from sklearn.model_selection import train_test_split\n",
" from sklearn.linear_model import ElasticNet\n",
"\n",
" import mlflow\n",
" import mlflow.sklearn\n",
"\n",
" import logging\n",
"\n",
" logging.basicConfig(level=logging.WARN)\n",
" logger = logging.getLogger(__name__)\n",
"\n",
" def eval_metrics(actual, pred):\n",
" rmse = np.sqrt(mean_squared_error(actual, pred))\n",
" mae = mean_absolute_error(actual, pred)\n",
" r2 = r2_score(actual, pred)\n",
" return rmse, mae, r2\n",
"\n",
" warnings.filterwarnings(\"ignore\")\n",
" np.random.seed(40)\n",
"\n",
" # Read the wine-quality csv file from the URL\n",
" csv_url = (\n",
" \"http://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv\"\n",
" )\n",
" try:\n",
" data = pd.read_csv(csv_url, sep=\";\")\n",
" except Exception as e:\n",
" logger.exception(\n",
" \"Unable to download training & test CSV, check your internet connection. Error: %s\", e\n",
" )\n",
"\n",
" # Split the data into training and test sets. (0.75, 0.25) split.\n",
" train, test = train_test_split(data)\n",
"\n",
" # The predicted column is \"quality\" which is a scalar from [3, 9]\n",
" train_x = train.drop([\"quality\"], axis=1)\n",
" test_x = test.drop([\"quality\"], axis=1)\n",
" train_y = train[[\"quality\"]]\n",
" test_y = test[[\"quality\"]]\n",
"\n",
" # Set default values if no alpha is provided\n",
" if float(in_alpha) is None:\n",
" alpha = 0.5\n",
" else:\n",
" alpha = float(in_alpha)\n",
"\n",
" # Set default values if no l1_ratio is provided\n",
" if float(in_l1_ratio) is None:\n",
" l1_ratio = 0.5\n",
" else:\n",
" l1_ratio = float(in_l1_ratio)\n",
"\n",
" # Useful for multiple runs (only doing one run in this sample notebook)\n",
" with mlflow.start_run():\n",
" # Execute ElasticNet\n",
" lr = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, random_state=42)\n",
" lr.fit(train_x, train_y)\n",
"\n",
" # Evaluate Metrics\n",
" predicted_qualities = lr.predict(test_x)\n",
" (rmse, mae, r2) = eval_metrics(test_y, predicted_qualities)\n",
"\n",
" # Print out metrics\n",
" print(\"Elasticnet model (alpha=%f, l1_ratio=%f):\" % (alpha, l1_ratio))\n",
" print(\" RMSE: %s\" % rmse)\n",
" print(\" MAE: %s\" % mae)\n",
" print(\" R2: %s\" % r2)\n",
"\n",
" # Log parameter, metrics, and model to MLflow\n",
" mlflow.log_param(\"alpha\", alpha)\n",
" mlflow.log_param(\"l1_ratio\", l1_ratio)\n",
" mlflow.log_metric(\"rmse\", rmse)\n",
" mlflow.log_metric(\"r2\", r2)\n",
" mlflow.log_metric(\"mae\", mae)\n",
"\n",
" mlflow.sklearn.log_model(lr, \"model\")"
]
},
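{
"cell_type": "markdown",
"id": "7c1d2e3f-4a5b-4c6d-8e9f-0a1b2c3d4e07",
"metadata": {},
"source": [
"The three metrics logged above have simple closed forms; here is a `numpy`-only sketch equivalent to the `eval_metrics` helper (so the logic can be checked without `sklearn`):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def eval_metrics_np(actual, pred):\n",
"    # numpy-only versions of the metrics logged above: RMSE, MAE, and R^2\n",
"    actual, pred = np.asarray(actual, float), np.asarray(pred, float)\n",
"    rmse = float(np.sqrt(np.mean((actual - pred) ** 2)))\n",
"    mae = float(np.mean(np.abs(actual - pred)))\n",
"    ss_res = np.sum((actual - pred) ** 2)\n",
"    ss_tot = np.sum((actual - actual.mean()) ** 2)\n",
"    r2 = float(1.0 - ss_res / ss_tot)\n",
"    return rmse, mae, r2\n",
"```"
]
},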
{
"cell_type": "markdown",
"id": "9c71a444-e898-40e3-9aaa-310f5d9d5123",
"metadata": {},
"source": [
"With this function set up, we now just call it with different values and it creates logs for them."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d46d80f2-b850-4089-a729-23f22deb4417",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024/03/29 13:51:46 WARNING mlflow.utils.git_utils: Failed to import Git (the Git executable is probably not on your PATH), so Git SHA is not available. Error: No module named 'git'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Elasticnet model (alpha=0.500000, l1_ratio=0.500000):\n",
" RMSE: 0.7931640229276851\n",
" MAE: 0.6271946374319586\n",
" R2: 0.10862644997792614\n"
]
}
],
"source": [
"train(0.5, 0.5)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9a824103-5f1f-4440-9137-20969c0b33c7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Elasticnet model (alpha=0.500000, l1_ratio=1.000000):\n",
" RMSE: 0.832819092896359\n",
" MAE: 0.6681279771237894\n",
" R2: 0.017268050734704055\n"
]
}
],
"source": [
"train(0.5, 1)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d37a8c7e-328a-46ab-9cbd-733efb8f08d4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Elasticnet model (alpha=0.100000, l1_ratio=0.500000):\n",
" RMSE: 0.7308996187375898\n",
" MAE: 0.5615486628017713\n",
" R2: 0.2430813606733676\n"
]
}
],
"source": [
"train(0.1, 0.5)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "05ca843a-0b4e-4875-922b-123b170229d8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Elasticnet model (alpha=1.000000, l1_ratio=0.000000):\n",
" RMSE: 0.7508731220796289\n",
" MAE: 0.5811664801219333\n",
" R2: 0.2011470433755671\n"
]
}
],
"source": [
"train(1, 0)"
]
},
{
"cell_type": "markdown",
"id": "d5591501-13dd-48a6-af2f-381e25f81e8a",
"metadata": {},
"source": [
"If not in a docker container we would now be able to call the mlflow UI and look through things. I couldn't get that to work and gave up! Instead, we can just read in the data it uses in the UI and look at it within python."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4b6a94ff-0b22-47a9-be4f-bc30e9e3a993",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" run_id | \n",
" experiment_id | \n",
" status | \n",
" artifact_uri | \n",
" start_time | \n",
" end_time | \n",
" metrics.r2 | \n",
" metrics.mae | \n",
" metrics.rmse | \n",
" params.l1_ratio | \n",
" params.alpha | \n",
" tags.mlflow.log-model.history | \n",
" tags.mlflow.source.type | \n",
" tags.mlflow.user | \n",
" tags.mlflow.runName | \n",
" tags.mlflow.source.name | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 65d548898f1e40558fef1624d3d78cc6 | \n",
" 0 | \n",
" FINISHED | \n",
" file:///home/jupyter-jbpost2%40ncsu.edu/mlruns... | \n",
" 2024-03-08 22:15:28.734000+00:00 | \n",
" 2024-03-08 22:15:29.939000+00:00 | \n",
" 0.201147 | \n",
" 0.581166 | \n",
" 0.750873 | \n",
" 0.0 | \n",
" 1.0 | \n",
" [{\"run_id\": \"65d548898f1e40558fef1624d3d78cc6\"... | \n",
" LOCAL | \n",
" jupyter-jbpost2@ncsu.edu | \n",
" bedecked-dolphin-527 | \n",
" /opt/tljh/user/envs/pySpark/lib/python3.9/site... | \n",
"
\n",
" \n",
" 1 | \n",
" 6f421ddea3b5467889fcbc030ff9c9df | \n",
" 0 | \n",
" FINISHED | \n",
" file:///home/jupyter-jbpost2%40ncsu.edu/mlruns... | \n",
" 2024-03-08 22:15:26.864000+00:00 | \n",
" 2024-03-08 22:15:28.101000+00:00 | \n",
" 0.243081 | \n",
" 0.561549 | \n",
" 0.730900 | \n",
" 0.5 | \n",
" 0.1 | \n",
" [{\"run_id\": \"6f421ddea3b5467889fcbc030ff9c9df\"... | \n",
" LOCAL | \n",
" jupyter-jbpost2@ncsu.edu | \n",
" enthused-ape-110 | \n",
" /opt/tljh/user/envs/pySpark/lib/python3.9/site... | \n",
"
\n",
" \n",
" 2 | \n",
" e79aa57b53084ea79b24950ecbb0284f | \n",
" 0 | \n",
" FINISHED | \n",
" file:///home/jupyter-jbpost2%40ncsu.edu/mlruns... | \n",
" 2024-03-08 22:15:23.857000+00:00 | \n",
" 2024-03-08 22:15:25.117000+00:00 | \n",
" 0.017268 | \n",
" 0.668128 | \n",
" 0.832819 | \n",
" 1.0 | \n",
" 0.5 | \n",
" [{\"run_id\": \"e79aa57b53084ea79b24950ecbb0284f\"... | \n",
" LOCAL | \n",
" jupyter-jbpost2@ncsu.edu | \n",
" wise-worm-607 | \n",
" /opt/tljh/user/envs/pySpark/lib/python3.9/site... | \n",
"
\n",
" \n",
" 3 | \n",
" c40bba6d7d8d43dc8cbce75cfb448b93 | \n",
" 0 | \n",
" FINISHED | \n",
" file:///home/jupyter-jbpost2%40ncsu.edu/mlruns... | \n",
" 2024-03-08 22:15:16.728000+00:00 | \n",
" 2024-03-08 22:15:18.462000+00:00 | \n",
" 0.108626 | \n",
" 0.627195 | \n",
" 0.793164 | \n",
" 0.5 | \n",
" 0.5 | \n",
" [{\"run_id\": \"c40bba6d7d8d43dc8cbce75cfb448b93\"... | \n",
" LOCAL | \n",
" jupyter-jbpost2@ncsu.edu | \n",
" welcoming-pug-473 | \n",
" /opt/tljh/user/envs/pySpark/lib/python3.9/site... | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" run_id experiment_id status \\\n",
"0 65d548898f1e40558fef1624d3d78cc6 0 FINISHED \n",
"1 6f421ddea3b5467889fcbc030ff9c9df 0 FINISHED \n",
"2 e79aa57b53084ea79b24950ecbb0284f 0 FINISHED \n",
"3 c40bba6d7d8d43dc8cbce75cfb448b93 0 FINISHED \n",
"\n",
" artifact_uri \\\n",
"0 file:///home/jupyter-jbpost2%40ncsu.edu/mlruns... \n",
"1 file:///home/jupyter-jbpost2%40ncsu.edu/mlruns... \n",
"2 file:///home/jupyter-jbpost2%40ncsu.edu/mlruns... \n",
"3 file:///home/jupyter-jbpost2%40ncsu.edu/mlruns... \n",
"\n",
" start_time end_time \\\n",
"0 2024-03-08 22:15:28.734000+00:00 2024-03-08 22:15:29.939000+00:00 \n",
"1 2024-03-08 22:15:26.864000+00:00 2024-03-08 22:15:28.101000+00:00 \n",
"2 2024-03-08 22:15:23.857000+00:00 2024-03-08 22:15:25.117000+00:00 \n",
"3 2024-03-08 22:15:16.728000+00:00 2024-03-08 22:15:18.462000+00:00 \n",
"\n",
" metrics.r2 metrics.mae metrics.rmse params.l1_ratio params.alpha \\\n",
"0 0.201147 0.581166 0.750873 0.0 1.0 \n",
"1 0.243081 0.561549 0.730900 0.5 0.1 \n",
"2 0.017268 0.668128 0.832819 1.0 0.5 \n",
"3 0.108626 0.627195 0.793164 0.5 0.5 \n",
"\n",
" tags.mlflow.log-model.history tags.mlflow.source.type \\\n",
"0 [{\"run_id\": \"65d548898f1e40558fef1624d3d78cc6\"... LOCAL \n",
"1 [{\"run_id\": \"6f421ddea3b5467889fcbc030ff9c9df\"... LOCAL \n",
"2 [{\"run_id\": \"e79aa57b53084ea79b24950ecbb0284f\"... LOCAL \n",
"3 [{\"run_id\": \"c40bba6d7d8d43dc8cbce75cfb448b93\"... LOCAL \n",
"\n",
" tags.mlflow.user tags.mlflow.runName \\\n",
"0 jupyter-jbpost2@ncsu.edu bedecked-dolphin-527 \n",
"1 jupyter-jbpost2@ncsu.edu enthused-ape-110 \n",
"2 jupyter-jbpost2@ncsu.edu wise-worm-607 \n",
"3 jupyter-jbpost2@ncsu.edu welcoming-pug-473 \n",
"\n",
" tags.mlflow.source.name \n",
"0 /opt/tljh/user/envs/pySpark/lib/python3.9/site... \n",
"1 /opt/tljh/user/envs/pySpark/lib/python3.9/site... \n",
"2 /opt/tljh/user/envs/pySpark/lib/python3.9/site... \n",
"3 /opt/tljh/user/envs/pySpark/lib/python3.9/site... "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import mlflow\n",
"#throws an issue but it returns the data appropriately\n",
"runs = mlflow.search_runs(experiment_ids=[\"0\"])\n",
"runs.head()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "202e8a7c-5907-493c-a9d7-7aca7871ca77",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['run_id', 'experiment_id', 'status', 'artifact_uri', 'start_time',\n",
" 'end_time', 'metrics.r2', 'metrics.mae', 'metrics.rmse',\n",
" 'params.l1_ratio', 'params.alpha', 'tags.mlflow.log-model.history',\n",
" 'tags.mlflow.source.type', 'tags.mlflow.user', 'tags.mlflow.runName',\n",
" 'tags.mlflow.source.name'],\n",
" dtype='object')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"runs.columns"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "dfb34037-48a7-4584-a238-6cc424db46e6",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" metrics.rmse | \n",
" metrics.r2 | \n",
" metrics.mae | \n",
" params.l1_ratio | \n",
" params.alpha | \n",
"
\n",
" \n",
" \n",
" \n",
" 1 | \n",
" 0.730900 | \n",
" 0.243081 | \n",
" 0.561549 | \n",
" 0.5 | \n",
" 0.1 | \n",
"
\n",
" \n",
" 0 | \n",
" 0.750873 | \n",
" 0.201147 | \n",
" 0.581166 | \n",
" 0.0 | \n",
" 1.0 | \n",
"
\n",
" \n",
" 3 | \n",
" 0.793164 | \n",
" 0.108626 | \n",
" 0.627195 | \n",
" 0.5 | \n",
" 0.5 | \n",
"
\n",
" \n",
" 2 | \n",
" 0.832819 | \n",
" 0.017268 | \n",
" 0.668128 | \n",
" 1.0 | \n",
" 0.5 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" metrics.rmse metrics.r2 metrics.mae params.l1_ratio params.alpha\n",
"1 0.730900 0.243081 0.561549 0.5 0.1\n",
"0 0.750873 0.201147 0.581166 0.0 1.0\n",
"3 0.793164 0.108626 0.627195 0.5 0.5\n",
"2 0.832819 0.017268 0.668128 1.0 0.5"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"runs[[\"metrics.rmse\", \"metrics.r2\", \"metrics.mae\", \"params.l1_ratio\", \"params.alpha\"]].sort_values(\"metrics.rmse\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9fbe1831-ca0a-4082-bb46-824aa09dd072",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pySpark",
"language": "python",
"name": "pyspark"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}