Multiple Linear Regression for AirBnB prediction¶
Consider the classic case of an affine function:
$$y=ax+b$$
Here, $a$ and $b$ are real numbers. These two numbers entirely define the line and thus the affine relationship between $x$ and $y$. In statistics, this relationship is the basis of so-called linear models, where a response variable is written as a linear combination of explanatory variables, each multiplied by a coefficient.
Multiple linear model¶
In the multiple case (with $p$ explanatory variables), for the $i$-th observation the model is written:
$$y_i= \beta_0 + \sum_{j=1}^p \beta_j x_{ij} + \varepsilon_i$$
Thus, an observation $x_i$ is no longer a single value but a vector $(x_{i1}, \dots, x_{ip})$. It is more convenient to gather the prices $y_i$ and the observation vectors $x_i$ into matrices:
$$Y=X \beta + \varepsilon$$
Under hypotheses analogous to those of the simple model, extended to higher dimension:
$$(\mathcal{H}) : \left\{\begin{matrix} \text{rank}(X)=p\\ \mathbb{E}[\varepsilon]=0 \text{ and }\text{Var}(\varepsilon)=\sigma^2 I_n \end{matrix}\right.$$
The different elements involved are:
- $\beta$ : the vector of coefficients
- $X$ : the matrix of observations
- $Y$ : the vector of prices
- $\varepsilon$ : the noise vector
With $X=( \mathbf{1}, X_1, \dots, X_p)$, $Y=(y_1, \dots, y_n)^\top$ and $\varepsilon=(\varepsilon_1, \dots, \varepsilon_n)^\top$. The OLS (Ordinary Least Squares) solution is then:
$$\hat{\beta}= (X^\top X)^{-1} X^\top Y$$
You can work through the proof on your own! For more mathematical background, I recommend the Wikipedia portal, which is very well done: link here
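As a hint, here is the key step of the proof (a sketch): $\hat{\beta}$ minimizes the squared error, and setting the gradient to zero yields the normal equations.
$$\hat{\beta} = \underset{\beta}{\arg\min} \, \|Y - X\beta\|^2, \qquad \nabla_\beta \|Y - X\beta\|^2 = -2 X^\top (Y - X\beta) = 0 \implies X^\top X \hat{\beta} = X^\top Y$$
Under $(\mathcal{H})$, $X$ has full column rank, so $X^\top X$ is invertible and the formula above follows.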
Implementation¶
#import your libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn import linear_model #linear model
from sklearn.metrics import mean_squared_error, r2_score #evaluation metrics
#load the data
#price_availability.csv
#listings_final.csv
#note: observation 589 has no price!!
prices = pd.read_csv("../data/price_availability.csv", sep=";", low_memory=False) #low_memory=False avoids the mixed-types DtypeWarning
listings = pd.read_csv("../data/listings_final.csv", sep=";")
listings = listings.drop(589)
print("Data loaded.")
Data loaded.
Understanding the input data¶
The objective here is to load the data to create the $X$ and $Y$ matrices of the linear model. Note that it is not necessary to add the column vector $\mathbf{1}$ as the first column, because scikit-learn does it automatically! (See the sketch below.)
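To see this, a minimal sketch on synthetic data (all names here are illustrative): fitting with the default settings is equivalent to adding the $\mathbf{1}$ column by hand and disabling fit_intercept.
import numpy as np
from sklearn.linear_model import LinearRegression

X_demo = np.random.rand(100, 3)                    #illustrative data
y_demo = X_demo @ np.array([1.0, 2.0, 3.0]) + 5.0  #known coefficients, intercept 5

#default: scikit-learn handles the intercept itself
auto = LinearRegression().fit(X_demo, y_demo)

#manual: prepend the column of ones and disable the built-in intercept
X_ones = np.hstack([np.ones((100, 1)), X_demo])
manual = LinearRegression(fit_intercept=False).fit(X_ones, y_demo)

print(auto.intercept_, manual.coef_[0])  #both equal 5.0 (up to numerical precision)
print(auto.coef_, manual.coef_[1:])      #identical coefficient vectors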
#define our input variable X and output variable Y
X = listings.loc[:, ["listing_id", "person_capacity", "bedrooms", "bathrooms" ]]
Y = []
#build the price vector
for i, row in X.iterrows():
    ID = int(row["listing_id"])
    subset = prices[prices["listing_id"] == ID]  #all listed prices for this listing
    y = subset["local_price"].mean()             #average listed price
    Y.append(y)
#convert into numpy array
Y = np.asarray(Y)
In Machine Learning, we usually split the data set into two subsets:
- A training set (train set), on which the model will be calibrated.
- A test set (test set), which will not be used during the calibration but will make it possible to check the ability of the model to generalize on new unknown observations.
In general, we split the data set by keeping $(100-\alpha)\%$ of the observations for training and $\alpha \%$ for testing. In most cases, we take $\alpha = 10$, $20$ or $30\%$.
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.30, random_state=42)
X_train.shape, y_train.shape, X_test.shape, y_test.shape
((699, 4), (699,), (300, 4), (300,))
Training the model¶
For information, scikit-learn relies on numpy's OLS (Ordinary Least Squares) solver under the hood; a quick sanity check against it is sketched after the coefficients below.
regr = linear_model.LinearRegression()
regr.fit(X_train.values, y_train)
LinearRegression()
We display the vector of the coefficients to quickly interpret the model.
#what do you think about the results?
print('Coefficients beta_j : \n', regr.coef_)
print('Coefficients INTERCEPT beta_0 : \n', regr.intercept_)
Coefficients beta_j : 
 [2.52164236e-06 3.27051274e+01 1.39869362e+01 7.80156818e+01]
Coefficients INTERCEPT beta_0 : 
 -80.85554900336902
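To double-check the claim about the solver, a minimal sketch (the name X_design is illustrative): solve the least-squares problem directly with numpy and compare with scikit-learn's estimates.
#sanity check (sketch): solve the least-squares problem with numpy
#and compare with scikit-learn's coefficients
X_design = np.hstack([np.ones((X_train.shape[0], 1)), X_train.values])  #prepend the 1 column
beta_hat, *_ = np.linalg.lstsq(X_design, y_train, rcond=None)
print(beta_hat[0])   #should match regr.intercept_
print(beta_hat[1:])  #should match regr.coef_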
Model Validation¶
Coefficient of determination $R^2$¶
Thereafter, we will assume that the noise is Gaussian. Intuitively, we would like a numerical value that tells us how meaningful linear regression is for our data. For this, we introduce the following notations:
- $SCT=\|Y-\bar{y} \mathbf{1}\|^2$ is the total sum of squares
- $SCE=\|\hat{Y}-\bar{y} \mathbf{1}\|^2$ is the explained sum of squares
- $SCR=\|\hat{\varepsilon}\|^2$ is the residual sum of squares
The idea is to decompose the total sum of squares into the sum of squares that the model explains plus the sum of squares tied to the residuals (which the model therefore cannot explain). Hence the interest of building a coefficient from the $SCE$. Since we have the following relationship:
$$SCT=SCE+SCR \text{ then } 1=\frac{SCE}{SCT}+\frac{SCR}{SCT}$$
The smaller the residuals (and therefore the better the regression), the smaller $SCR$ becomes and the larger $SCE$ becomes. The reverse works the same way. In the best case, we obtain $SCR=0$, hence $SCE=SCT$ and the ratio is worth $1$. In the opposite case, $SCE=0$ and the ratio is automatically null. This is how we define the coefficient of determination $R^2$ as $$R^2=\frac{SCE}{SCT}=1-\frac{SCR}{SCT}$$ Thus, $R^2 \in [0,1]$. The closer $R^2$ is to $1$, the more meaningful the linear regression. On the contrary, if $R^2$ is close to $0$, the linear model has weak explanatory power.
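As a quick numerical illustration (a sketch on toy data), we can check the decomposition $SCT=SCE+SCR$ and compute $R^2$ three equivalent ways:
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

x = np.array([[1.0], [2.0], [3.0], [4.0]])  #toy data
y = np.array([3.0, 5.0, 7.0, 10.0])

y_hat = LinearRegression().fit(x, y).predict(x)
y_bar = y.mean()

SCT = np.sum((y - y_bar) ** 2)      #total sum of squares
SCE = np.sum((y_hat - y_bar) ** 2)  #explained sum of squares
SCR = np.sum((y - y_hat) ** 2)      #residual sum of squares

print(SCT, SCE + SCR)                                #SCT = SCE + SCR for an OLS fit with intercept
print(SCE / SCT, 1 - SCR / SCT, r2_score(y, y_hat))  #three ways to get R^2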
X_test
|  | listing_id | person_capacity | bedrooms | bathrooms |
|---|---|---|---|---|
| 453 | 14992207 | 2 | 1 | 1.0 |
| 794 | 24564156 | 2 | 1 | 1.0 |
| 209 | 3452604 | 2 | 0 | 1.0 |
| 309 | 8243908 | 1 | 1 | 1.0 |
| 741 | 23233753 | 5 | 2 | 1.0 |
| ... | ... | ... | ... | ... |
| 314 | 8525469 | 8 | 3 | 2.5 |
| 404 | 12976143 | 7 | 3 | 2.0 |
| 7 | 5662637 | 2 | 1 | 1.0 |
| 155 | 2158913 | 2 | 0 | 1.0 |
| 810 | 24766158 | 9 | 4 | 3.0 |

300 rows × 4 columns
#compute y_pred
Y_pred = regr.predict(X_test.values)  #use .values, matching how the model was fitted
len(Y_pred)
300
Y_pred
array([114.36230811, 138.49934015, 71.27662022, 64.64038407, 247.24685811, 171.0029468 , 564.31714095, 107.86749667, 385.37925255, 235.39531125, 163.90996015, 614.7240766 , 160.96465918, 252.46586729, 119.24192937, 78.55782096, 125.91559204, 195.2273598 , 436.64291877, 794.25736816, 589.91839747, 172.78036143, 223.17212728, 129.38357904, 129.92715435, 176.17462164, 121.77210176, 155.17725776, 157.73633916, 69.31984118, 93.10191171, 216.986993 , 113.62122638, 758.93637273, 457.52511937, 289.16569258, 68.03509436, 360.42285622, 129.91033125, 216.36433134, 170.89834022, 61.12385668, 210.70917088, 237.01939097, 321.30896763, 124.19051511, 78.23050292, 200.28316065, 78.73200593, 146.09983247, 128.47832586, 153.12417842, 225.8534829 , 69.53185805, 300.16300834, 165.2637238 , 214.97562767, 186.70835543, 219.71281443, 120.85148612, 77.74362047, 169.98505384, 142.51521195, 264.61448676, 190.12144451, 460.55218915, 119.14417921, 132.41413551, 210.05680618, 347.19287902, 197.69976006, 395.84853998, 200.31141939, 162.3362118 , 185.55987594, 491.47541121, 109.9623458 , 118.84015609, 120.1807003 , 133.29218269, 229.58745931, 109.55731226, 229.58940602, 143.06252686, 312.48666386, 112.74668429, 194.74318804, 73.08322971, 179.45230433, 196.43496928, 328.55328423, 81.22221249, 284.70807285, 48.19219459, 124.64695125, 135.07116607, 78.39407811, 191.38776334, 159.81707365, 259.43582621, 191.53786287, 178.10517238, 244.22094281, 96.9979699 , 185.65565296, 146.57202513, 105.30228567, 201.50788077, 132.22847073, 197.0119954 , 247.10782483, 185.13368614, 301.88593538, 61.45064517, 96.6717085 , 129.57884471, 74.12809996, 607.88602216, 526.06074738, 87.83519031, 246.85174455, 445.16139293, 110.52627833, 129.53665281, 218.9508198 , 196.42999653, 159.84673047, 62.03017154, 373.51904914, 141.58897993, 156.61981148, 134.02162069, 357.32546438, 361.25625322, 79.86517905, 451.75656895, 76.98751357, 111.33763189, 77.42575232, 250.55696078, 125.18203497, 79.31946408, 81.01324117, 231.44010033, 83.6304349 , 519.89578554, 103.05509483, 108.01306634, 318.12222567, 136.51964161, 167.15697746, 197.02209832, 273.99128072, 137.35441618, 285.22038491, 143.48880242, 241.28075567, 125.61840014, 585.49836868, 108.80613647, 292.50426133, 190.83180049, 200.30679226, 117.81201946, 52.74165019, 144.20373393, 111.54527675, 243.99491382, 262.82007698, 126.36298543, 241.59927319, 366.68501024, 131.12659753, 130.77172046, 117.59558942, 138.06057438, 176.74338868, 97.74632779, 168.86630498, 212.38780398, 211.65758434, 47.98493882, 171.41349914, 170.72894233, 71.97206397, 222.20236163, 205.47054187, 437.84113272, 263.12463343, 206.1906475 , 144.06930013, 131.71790876, 139.66872963, 74.34394598, 372.16738832, 25.74015141, 156.37814584, 172.43078997, 289.48665504, 111.63905792, 104.99258765, 285.27412334, 147.93041066, 344.50948134, 103.87440791, 122.72884382, 78.46805553, 146.28272919, 74.88285374, 219.39910981, 139.01132937, 130.10754486, 140.50757108, 276.52122257, 226.19387536, 90.94380545, 180.37801337, 173.23106213, 168.21181241, 110.39014999, 131.56152406, 90.53025106, 107.81270895, 135.71747691, 223.40498582, 106.95143948, 144.2039987 , 148.0047265 , 116.81240895, 432.29504729, 145.03092593, 141.53214715, 192.12377962, 556.09592367, 149.5926193 , 166.18844536, 126.5750152 , 164.33319873, 296.00366852, 124.33251383, 241.37442813, 145.05923589, 102.9717709 , 125.424821 , 117.91701996, 112.38647272, 144.25657191, 422.69020867, 201.81151929, 249.31698123, 162.958896 , 145.13671639, 120.12715322, 147.77428469, 47.03376523, 283.50340749, 
135.37656318, 125.7027652 , 147.40797427, 174.96847271, 197.9171011 , 283.44124396, 157.47742931, 161.42677666, 273.07730634, 101.64372526, 492.66564106, 145.31890101, 118.1252264 , 236.81580787, 349.32795309, 207.8221955 , 248.67300298, 225.47676718, 142.34465006, 84.4449279 , 112.37024465, 223.61447326, 180.56428175, 115.41620202, 304.58199211, 153.87743182, 80.04466069, 38.84067305, 198.41309302, 136.49311745, 179.25980468, 152.83252627, 142.01233847, 216.26805054, 124.60142047, 118.4463172 , 187.67349013, 257.39270478, 148.69064095, 439.28366744, 378.79370715, 90.83646922, 68.0143942 , 565.93678128])
y_test
array([ 79.81038961, 125. , 95.45333333, 29. , 82.5883905 , 123.63829787, 360. , 130. , 450. , 164.845953 , 68.35142119, 294.18181818, 68.08247423, 51.01595745, 170.30548303, 80.24479167, 73.78249337, 195. , 375.0268714 , 834.96124031, 850.65633075, 89. , 650. , 63.38219895, 83. , 190. , 202.23514212, 96.6056701 , 108.95026178, 50. , 28.31937173, 195. , 93.70234987, 474.14258189, 450. , 517.0984456 , 83.37730871, 395. , 79. , 129.17493473, 135.11227154, 51.52785146, 117.22572178, 260. , 618. , 42. , 180. , 160. , 179.67315175, 40.0025641 , 80.984375 , 116.09947644, 429.28645833, 76.06896552, 219.5037594 , 459.71391076, 183.0848329 , 204.39276486, 161.5503876 , 131. , 44.16569201, 156.19693095, 288.68421053, 219.08349515, 60.05263158, 794.54663212, 86.13246753, 42.61658031, 290. , 350. , 191.34883721, 220. , 115. , 96.05714286, 180. , 329.36832061, 60.74093264, 241.3003876 , 59. , 119.02631579, 234.89460154, 68.55844156, 300.78640777, 175. , 21. , 53. , 152.67792208, 113.58333333, 152.56701031, 249.06493506, 240.37764933, 43.01578947, 197.68766404, 42.10471204, 59.53439153, 117. , 94. , 250. , 145.16966581, 450. , 86.77922078, 78. , 425. , 33.03645833, 177.4151436 , 242.98457584, 29. , 307.01547389, 53.53002611, 126.62239583, 182.21875 , 141.06994819, 240. , 36. , 149. , 74.87743191, 30. , 483.32432432, 399.47683398, 74.10810811, 350. , 403.08070866, 146. , 74.43410853, 203.21025641, 143.51181102, 92. , 33.4296875 , 494.36538462, 75.37105263, 156.14038462, 38.50773196, 353. , 233.69170984, 81. , 962.54025974, 36.58961039, 92. , 90.09819121, 124.45052083, 146.57519789, 53.50129199, 95.484375 , 140.16494845, 112. , 210.02362205, 180. , 188.57253886, 328.57699805, 94.67783505, 105. , 114.75773196, 152.53299492, 111.81443299, 311.54580153, 180.95287958, 165. , 42. , 1200. , 149. , 297. , 157. , 434.9132948 , 85.40682415, 81.3110687 , 104. , 63. , 163.60103627, 206.14690722, 132.18897638, 180. , 311.62934363, 88. , 42.40519481, 210.35917313, 52.16410256, 253.96850394, 109. , 158.8 , 104.78947368, 145. , 29. , 170. , 58. , 63.14397906, 115. , 195.72890026, 261.86046512, 381.96564885, 210. , 28.4947644 , 144.42159383, 1500. , 89. , 203.89312977, 43.84615385, 284.35356201, 99. , 296. , 65.86910995, 24. , 181.04986877, 170.64935065, 251.04639175, 72.39793282, 33. , 218.96373057, 145. , 48. , 390.16184971, 119. , 52. , 109.71611253, 222.8042328 , 179.21336761, 39. , 49.23697917, 97.18372703, 73.07731959, 347.67015707, 62.94010417, 232.38921002, 57. , 83.34563107, 237.90052356, 40. , 299. , 248.84971098, 207.69487179, 230. , 151.07915567, 113.53383459, 203.41755319, 240.99197861, 108. , 98.35695538, 165. , 250. , 331.59615385, 102.38157895, 196.98455598, 86. , 38.51171875, 105.71849866, 220. , 72.64583333, 109. , 355.52083333, 135.85677083, 238.25918762, 153.55844156, 99. , 59.16020672, 177.52645503, 49. , 220. , 170. , 54.10433071, 80.56363636, 61.62105263, 111.86458333, 250. , 79. , 121.50259067, 148.08051948, 94.74472169, 357.7046332 , 90.21052632, 158. , 89.43307087, 252.83854167, 191.13520408, 400.39787798, 221.99738903, 169. , 48.39030612, 40. , 172.67822736, 93.35248042, 61. , 139.17105263, 68.05526316, 79.10344828, 88.65066667, 330.46875 , 188.3203125 , 101.12631579, 109. , 124.16103896, 251.88824663, 69. , 47. , 221.83333333, 170. , 55.96938776, 718.09278351, 517.99618321, 213.32460733, 79. , 1300. ])
#display the mean squared error on the test set as well as the R2
print("Mean squared error: %.2f"
      % mean_squared_error(y_test, Y_pred))
# Coefficient of determination R2
print('Variance score: %.2f' % r2_score(y_test, Y_pred))
Mean squared error: 19684.34
Variance score: 0.42
#compute the RMSE for more intuitive results
print("RMSE: %.2f" % np.sqrt(mean_squared_error(y_test, Y_pred)))
RMSE: 140.30
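An RMSE of about 140 (in price units) is far from negligible. As a quick visual check (a sketch reusing the objects above), we can plot predictions against actual prices; a perfect model would put every point on the diagonal:
#sketch: predicted vs. actual prices on the test set
plt.figure(figsize=(8, 6))
plt.scatter(y_test, Y_pred, s=12, c="white", edgecolors="blue")
lims = [min(y_test.min(), Y_pred.min()), max(y_test.max(), Y_pred.max())]
plt.plot(lims, lims, color="green", alpha=0.6, label="Perfect predictions")
plt.xlabel("Actual prices $y_i$")
plt.ylabel("Predicted prices $\\hat{y}_i$")
plt.legend()
plt.show()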
Bonus: Homoscedasticity analysis¶
The analysis of homoscedasticity is essential: it is what allows us to check, from the residuals, whether the noise indeed satisfies the hypothesis $(\mathcal{H})$. We therefore compute the studentized residuals.
$$t_i^*=\frac{\hat{\varepsilon}_i}{\hat{\sigma}_{(i)} \sqrt{1-h_{ii}}}$$ with $h_{ii}=\{X(X^\top X)^{-1} X^\top\}_{ii}=H_{ii}$, the $i$-th diagonal element of the hat matrix $H$ which projects $Y$ onto the space spanned by the variables, that is, $\hat{Y}=HY$. Similarly, $\hat{\sigma}_{(i)}$ denotes the noise variance estimator obtained by removing observation $i$ (via a Leave-One-Out cross-validation scheme that we will not detail here).
In this case, it can be shown that the studentized residuals follow a Student's distribution with $n-p-1$ degrees of freedom.
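We compute these quantities by hand below. As an optional cross-check (a sketch, assuming statsmodels is available), the externally studentized residuals can also be obtained directly:
#optional cross-check (assumes statsmodels is installed; not used below)
import statsmodels.api as sm

ols_fit = sm.OLS(y_train, sm.add_constant(X_train.values)).fit()
t_star = ols_fit.get_influence().resid_studentized_external  #externally studentized residuals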
import scipy.stats

Y_pred = regr.predict(X_train.values)  #use .values, matching how the model was fitted
n = X_train.shape[0]
p = X_train.shape[1]  #number of explanatory variables
residuals = np.abs(y_train - Y_pred)  #absolute residuals, since we plot |t_i*|
#hat matrix H, which projects Y onto the space spanned by the variables
H = np.matmul(X_train.values, np.linalg.solve(np.dot(X_train.T.values, X_train.values), X_train.T.values))
sigma2_hat = np.dot(residuals, residuals) / (n - p)  #global estimator of the noise variance
#internally studentized residuals (global variance estimate)
standardized_residuals = np.asarray([residuals[i] / np.sqrt(sigma2_hat * (1 - H[i, i])) for i in range(n)])
#externally studentized residuals (equivalent to using the Leave-One-Out variance estimator)
student_residuals = np.asarray([standardized_residuals[i] * np.sqrt((n - p - 1) / (n - p - standardized_residuals[i]**2)) for i in range(n)])
#Cook's distances (influence of each observation; computed but not plotted below)
cook = np.asarray([H[i, i] * standardized_residuals[i]**2 / (p * (1 - H[i, i])) for i in range(n)])
plt.figure(figsize=(20, 12))
plt.subplot(221)
plt.scatter(Y_pred, student_residuals, s=12, c="white", edgecolors="blue")
plt.plot([min(Y_pred), max(Y_pred)], [ scipy.stats.t.ppf(q=0.975, df=n-p-1), scipy.stats.t.ppf(q=0.975, df=n-p-1)], color="green", alpha=0.6, label="Student quantile")
plt.title("Homoscedasticity analysis")
plt.xlabel("Predictions $\hat{y}_i$")
plt.ylabel("Studentized residuals $|t_i^*|$")
plt.legend()
<matplotlib.legend.Legend at 0x7f45ceb81250>