library(tidyverse)
library(ggplot2)
library(ContaminatedMixt)
library(factoextra)
Presentation 5B: K-means and Random Forest
Part 1: K-means Clustering
Clustering is a type of unsupervised learning technique used to group similar data points together based on their features. The goal is to find inherent patterns or structures within the data, e.g. to see whether the data points fall into distinct groups with distinct features or not.
Wine dataset
For this we will use the wine data set as an example:
Let’s load in the dataset
data('wine') #load dataset
df_wine <- wine %>%
as_tibble() #convert to tibble
df_wine
# A tibble: 178 × 14
Type Alcohol Malic Ash Alcalinity Magnesium Phenols Flavanoids
<fct> <dbl> <dbl> <dbl> <dbl> <int> <dbl> <dbl>
1 Barolo 14.2 1.71 2.43 15.6 127 2.8 3.06
2 Barolo 13.2 1.78 2.14 11.2 100 2.65 2.76
3 Barolo 13.2 2.36 2.67 18.6 101 2.8 3.24
4 Barolo 14.4 1.95 2.5 16.8 113 3.85 3.49
5 Barolo 13.2 2.59 2.87 21 118 2.8 2.69
6 Barolo 14.2 1.76 2.45 15.2 112 3.27 3.39
7 Barolo 14.4 1.87 2.45 14.6 96 2.5 2.52
8 Barolo 14.1 2.15 2.61 17.6 121 2.6 2.51
9 Barolo 14.8 1.64 2.17 14 97 2.8 2.98
10 Barolo 13.9 1.35 2.27 16 98 2.98 3.15
# ℹ 168 more rows
# ℹ 6 more variables: Nonflavanoid <dbl>, Proanthocyanins <dbl>, Color <dbl>,
# Hue <dbl>, Dilution <dbl>, Proline <int>
This dataset contains 178 rows, each corresponding to one of three different cultivars of wine. It has 13 numerical columns that record different features of the wine.
We will try out a popular method, k-means clustering. It works by initializing K centroids and assigning each data point to the nearest centroid. The algorithm then recalculates the centroids as the mean of the points in each cluster, repeating the process until the clusters stabilize. You can see an illustration of the process below. Its weakness is that we need to define the number of centroids, i.e. clusters, beforehand.

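The assign-update loop can be sketched in a few lines of base R. This is a toy illustration on made-up one-dimensional data with K = 2, not the actual implementation behind kmeans():

```r
set.seed(1)
# Toy data: two well-separated 1-D groups centered at 0 and 10
x <- c(rnorm(10, mean = 0), rnorm(10, mean = 10))
K <- 2
centroids <- sample(x, K)  # initialize centroids at random data points

for (iter in 1:20) {
  # Assignment step: each point joins its nearest centroid
  cluster <- sapply(x, function(p) which.min(abs(p - centroids)))
  # Update step: recompute each centroid as the mean of its cluster
  new_centroids <- sapply(1:K, function(k) mean(x[cluster == k]))
  if (all(new_centroids == centroids)) break  # clusters have stabilized
  centroids <- new_centroids
}
sort(centroids)  # close to the true group means of 0 and 10
```

Real implementations add safeguards (multiple random starts, empty-cluster handling), which is what the nstart argument of kmeans() is for.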
Running k-means
For k-means it is very important that the data is numeric and scaled, so we will take care of that before running the algorithm.
# Set seed to ensure reproducibility
set.seed(123)
# Pull numeric variables and scale these
kmeans_df <- df_wine %>%
dplyr::select(where(is.numeric)) %>%
mutate(across(everything(), scale))
kmeans_df
# A tibble: 178 × 13
Alcohol[,1] Malic[,1] Ash[,1] Alcalinity[,1] Magnesium[,1] Phenols[,1]
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 1.51 -0.561 0.231 -1.17 1.91 0.807
2 0.246 -0.498 -0.826 -2.48 0.0181 0.567
3 0.196 0.0212 1.11 -0.268 0.0881 0.807
4 1.69 -0.346 0.487 -0.807 0.928 2.48
5 0.295 0.227 1.84 0.451 1.28 0.807
6 1.48 -0.516 0.304 -1.29 0.858 1.56
7 1.71 -0.417 0.304 -1.47 -0.262 0.327
8 1.30 -0.167 0.888 -0.567 1.49 0.487
9 2.25 -0.623 -0.716 -1.65 -0.192 0.807
10 1.06 -0.883 -0.352 -1.05 -0.122 1.09
# ℹ 168 more rows
# ℹ 7 more variables: Flavanoids <dbl[,1]>, Nonflavanoid <dbl[,1]>,
# Proanthocyanins <dbl[,1]>, Color <dbl[,1]>, Hue <dbl[,1]>,
# Dilution <dbl[,1]>, Proline <dbl[,1]>
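As a quick aside on why the scaling step matters: without it, Euclidean distances are dominated by whichever variable has the largest units. A made-up two-wine example (values loosely inspired by the Alcohol and Proline columns):

```r
# Two wines: a large Alcohol difference, a relatively small Proline one
toy <- data.frame(Alcohol = c(12.0, 14.5), Proline = c(700, 720))

# Unscaled: the 20-unit Proline gap swamps the 2.5-unit Alcohol gap
dist(toy)          # sqrt(2.5^2 + 20^2), about 20.2

# Scaled: both variables contribute comparably to the distance
dist(scale(toy))
```

With only two rows both scaled differences are identical, but the point carries over: scale() puts every feature on a comparable footing before distances are computed.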
K-means clustering in R is easy: we simply run the kmeans() function:
set.seed(123)
kmeans_res <- kmeans_df %>%
kmeans(centers = 4, nstart = 25)
kmeans_res
K-means clustering with 4 clusters of sizes 45, 56, 49, 28
Cluster means:
Alcohol Malic Ash Alcalinity Magnesium Phenols
1 -0.9051690 -0.53898599 -0.6498944 0.1592193 -0.71473842 -0.4537841
2 0.9580555 -0.37748461 0.1969019 -0.8214121 0.39943022 0.9000233
3 0.1860184 0.90242582 0.2485092 0.5820616 -0.05049296 -0.9857762
4 -0.7869073 0.04195151 0.2157781 0.3683284 0.43818899 0.6543578
Flavanoids Nonflavanoid Proanthocyanins Color Hue Dilution
1 -0.2408779 0.3315072 -0.4329238 -0.9177666 0.5202140 0.07869143
2 0.9848901 -0.6204018 0.5575193 0.2423047 0.4799084 0.76926636
3 -1.2327174 0.7148253 -0.7474990 0.9857177 -1.1879477 -1.29787850
4 0.5746004 -0.5429201 0.8888549 -0.7346332 0.2830335 0.60628629
Proline
1 -0.7820425
2 1.2184972
3 -0.3789756
4 -0.5169332
Clustering vector:
[1] 2 2 2 2 4 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 4 2 2 2 4 2 2 2 2 2 2 2 2 2 2 2
[38] 2 2 2 2 2 2 4 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 1 1 1 1 4 1 4 2 1 1 4 1 4 1 4
[75] 4 1 1 1 4 4 1 1 1 3 4 1 1 1 1 1 1 1 1 4 4 4 4 1 4 4 1 1 4 1 1 1 1 1 1 4 4
[112] 1 1 1 1 1 1 1 1 1 4 4 4 4 4 1 4 1 1 1 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
[149] 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
Within cluster sum of squares by cluster:
[1] 289.9515 268.5747 302.9915 307.0966
(between_SS / total_SS = 49.2 %)
Available components:
[1] "cluster" "centers" "totss" "withinss" "tot.withinss"
[6] "betweenss" "size" "iter" "ifault"
We can call kmeans_res$centers to inspect the values of the centroids. For example, the center of cluster 1 is placed at the coordinates -0.91 for Alcohol, -0.54 for Malic, -0.65 for Ash, and so on. Since our data has 13 dimensions, i.e. features, so do the cluster centers.
This is not super practical if we would like to visually inspect the clustering since we cannot plot in 13 dimensions. How could we solve this?
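One common answer is to project the 13-dimensional points onto their first two principal components and plot those; this is essentially what fviz_cluster() will do for us below. A base-R sketch of the idea, using the built-in iris measurements as a stand-in for the scaled wine features:

```r
set.seed(123)
X <- scale(iris[, 1:4])              # stand-in for a scaled feature matrix
km <- kmeans(X, centers = 3, nstart = 25)

# PCA via prcomp(); keep the first two component scores for plotting
pca <- prcomp(X)
proj <- data.frame(PC1     = pca$x[, 1],
                   PC2     = pca$x[, 2],
                   cluster = factor(km$cluster))

# plot(proj$PC1, proj$PC2, col = proj$cluster)  # quick base-graphics view
head(proj)
```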
Visualizing k-means results
We would like to see where our wine bottles and their clusters lie in a low-dimensional space. This can easily be done using the fviz_cluster() function:
fviz_cluster(object = kmeans_res,
data = kmeans_df,
palette = c("#2E9FDF", "#00AFBB", "#E7B800", "orchid3"),
geom = "point",
ellipse.type = "norm",
ggtheme = theme_bw())
Optimal number of clusters
There are several ways to investigate the ideal number of clusters, and fviz_nbclust() from the factoextra package implements three of them: the elbow (wss), silhouette, and gap statistic methods. We will try two here.
The so-called elbow method observes how the sum of squared errors (SSE) changes as we vary the number of clusters. This quantity is also referred to as the within-cluster sum of squares (wss).
kmeans_df %>%
fviz_nbclust(kmeans, method = "wss")
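Under the hood, the elbow plot just re-runs kmeans() for a range of K values and records tot.withinss each time. A manual sketch, again with iris as a stand-in for the scaled wine data:

```r
set.seed(123)
X <- scale(iris[, 1:4])  # stand-in for kmeans_df

# Total within-cluster sum of squares for K = 1..6
wss <- sapply(1:6, function(k) kmeans(X, centers = k, nstart = 25)$tot.withinss)
round(wss, 1)  # decreases with K; the "elbow" is where the drop levels off

# plot(1:6, wss, type = "b", xlab = "Number of clusters K", ylab = "Total within SS")
```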
The gap statistic compares the within-cluster variation (how compact the clusters are) for different values of K to the expected variation under a null reference distribution (i.e., random clustering).
kmeans_df %>%
fviz_nbclust(kmeans, method = "gap_stat")
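The gap statistic for a given K can be approximated by hand as mean(log(W_ref)) - log(W_obs), where W is tot.withinss and the reference datasets are drawn uniformly over each feature's observed range. A rough sketch with a hypothetical gap_for_k() helper (the real implementation, cluster::clusGap(), also estimates a standard error for each gap):

```r
set.seed(123)
X <- scale(iris[, 1:4])  # stand-in for the scaled features

gap_for_k <- function(X, k, B = 10) {
  W_obs <- kmeans(X, centers = k, nstart = 10)$tot.withinss
  # B reference datasets, each column drawn uniformly over its observed range
  W_ref <- replicate(B, {
    ref <- apply(X, 2, function(col) runif(nrow(X), min(col), max(col)))
    kmeans(ref, centers = k, nstart = 10)$tot.withinss
  })
  mean(log(W_ref)) - log(W_obs)  # larger gap = tighter clustering than random
}

gaps <- sapply(1:4, function(k) gap_for_k(X, k))
round(gaps, 2)
```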
Both of these tell us that there should be three clusters and we also know that there are three cultivars of wine in the dataset. Let’s redo k-means with three centroids.
# Set seed to ensure reproducibility
set.seed(123)
#run kmeans
kmeans_res <- kmeans_df %>%
kmeans(centers = 3, nstart = 25)
fviz_cluster(kmeans_res, data = kmeans_df,
palette = c("#2E9FDF", "#00AFBB", "#E7B800"),
geom = "point",
ellipse.type = "norm",
ggtheme = theme_bw())
Now, some obvious questions to ask might be (I) how do the clusters relate to the three cultivars of wine in the dataset? and (II) which variables are mainly driving the clustering?
Question (I) can be answered fairly easily, simply by coloring the clusters according to wine label or adding the label to the points in the plot above.
fzPlot <- fviz_cluster(kmeans_res, data = kmeans_df,
palette = c("#2E9FDF", "#00AFBB", "#E7B800"),
geom = "point",
ellipse.type = "norm",
ggtheme = theme_bw())
wine_labs <- transform(fzPlot$data,
my_label = df_wine$Type)
wine_labs
  name           x           y      coord cluster my_label
1    1 -3.30742097 -1.43940225 13.0109123       2   Barolo
2    2 -2.20324981  0.33245507  4.9648361       2   Barolo
3    3 -2.50966069 -1.02825072  7.3556964       2   Barolo
4    4 -3.74649719 -2.74861839 21.5911443       2   Barolo
5    5 -1.00607049 -0.86738404  1.7645329       2   Barolo
6    6 -3.04167373 -2.11643092 13.7310589       2   Barolo
(... 172 more rows: the Barolo wines all fall in cluster 2, the Barbera wines all in cluster 1, and the Grignolino wines mostly in cluster 3 with a handful of misassignments)
fzPlot + geom_text(data = wine_labs,
aes(label = my_label), size = 2.5)
The answer to question (II) is a bit more tricky. One way to get a sense of which variables may be contributing to the clustering is to examine the variance of the cluster centers for each variable. A higher variance suggests that the cluster centroids differ more on that variable, so that variable may be playing a larger role in separating the clusters.
Note: This is only a descriptive heuristic, not a statistical test. Unlike PCA, k-means does not produce loadings, so there is no direct measure of variable importance in the same sense.
kmeans_res$centers
     Alcohol      Malic        Ash Alcalinity   Magnesium     Phenols
1 0.1644436 0.8690954 0.1863726 0.5228924 -0.07526047 -0.97657548
2 0.8328826 -0.3029551 0.3636801 -0.6084749 0.57596208 0.88274724
3 -0.9234669 -0.3929331 -0.4931257 0.1701220 -0.49032869 -0.07576891
Flavanoids Nonflavanoid Proanthocyanins Color Hue Dilution
1 -1.21182921 0.72402116 -0.77751312 0.9388902 -1.1615122 -1.2887761
2 0.97506900 -0.56050853 0.57865427 0.1705823 0.4726504 0.7770551
3 0.02075402 -0.03343924 0.05810161 -0.8993770 0.4605046 0.2700025
Proline
1 -0.4059428
2 1.1220202
3 -0.7517257
var_importance <- apply(kmeans_res$centers, 2, var)
sort(var_importance, decreasing = TRUE)
     Flavanoids       Dilution        Proline            Hue        Phenols
1.2020837 1.1590920 0.9941934 0.8835955 0.8645478
Color Alcohol Malic Proanthocyanins Nonflavanoid
0.8523894 0.7858539 0.4957524 0.4680695 0.4169275
Alcalinity Magnesium Ash
0.3351087 0.2888914 0.2045454
Part 2: Random Forest
In this section, we will train a Random Forest (RF) model. Random Forest is a simple ensemble machine learning method that builds multiple decision trees and combines their predictions to improve accuracy and robustness. By averaging the results of many trees, it reduces overfitting and increases generalization, making it particularly effective for complex, non-linear relationships. One of its key strengths is its ability to handle large datasets with many features, while also providing insights into feature importance.

Why do we want to try a RF? Unlike linear, logistic, or elastic net regression (presentation 5C - only if time permits), RF does not require predictors (or model residuals) to be normally distributed, nor does RF assume a linear relationship between predictors and the outcome — it can naturally capture non-linear patterns and complex interactions between variables.
Another advantage is that RF considers one predictor at a time when splitting at a node, making it robust to differences in variable scales and allowing it to handle categorical variables directly.
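To see what "one predictor at a time" means, here is a toy base-R sketch of the split search a single tree node performs: try every threshold on one variable and keep the one that minimizes Gini impurity. The age/disease vectors are invented for illustration:

```r
# Toy outcome: disease occurs only at higher ages
age     <- c(34, 40, 45, 50, 55, 60, 62, 70)
disease <- c( 0,  0,  0,  0,  1,  1,  1,  1)

# Gini impurity of a 0/1 vector: 2 * p * (1 - p)
gini <- function(y) { p <- mean(y); 2 * p * (1 - p) }

# Size-weighted impurity after splitting at threshold t
split_score <- function(t) {
  left <- disease[age <= t]; right <- disease[age > t]
  (length(left) * gini(left) + length(right) * gini(right)) / length(disease)
}

# Candidate thresholds: midpoints between consecutive sorted values
thresholds <- (head(sort(age), -1) + tail(sort(age), -1)) / 2
best <- thresholds[which.min(sapply(thresholds, split_score))]
best  # 52.5: separates the two classes perfectly here
```

Replacing age with, say, log(age) would select the very same split, which is why no scaling is needed for tree-based models.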
The downside of an RF model is that it typically requires a reasonably large sample size to perform well and can be less interpretable than regression-based approaches.
First and foremost, let's load the R packages needed for the analysis:
library(tidyverse)
library(caret)
library(randomForest)
For this exercise we will use a dataset from patients with Heart Disease. Information on the columns in the dataset can be found here.
HD <- read_csv("../data/HeartDisease.csv")
head(HD)
# A tibble: 6 × 14
age sex chestPainType restBP chol fastingBP restElecCardio maxHR
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 52 1 0 125 212 0 1 168
2 53 1 0 140 203 1 0 155
3 70 1 0 145 174 0 1 125
4 61 1 0 148 203 0 1 161
5 62 0 0 138 294 1 1 106
6 58 0 0 100 248 0 0 122
# ℹ 6 more variables: exerciseAngina <dbl>, STdepEKG <dbl>,
# slopePeakExST <dbl>, nMajorVessels <dbl>, DefectType <dbl>,
# heartDisease <dbl>
Let’s convert some of the variables that are encoded as numeric datatypes but should be factors:
facCols <- c("sex",
"chestPainType",
"fastingBP",
"restElecCardio",
"exerciseAngina",
"slopePeakExST",
"DefectType",
"heartDisease")
HD <- HD %>%
mutate(across(all_of(facCols), as.factor))
head(HD)
# A tibble: 6 × 14
age sex chestPainType restBP chol fastingBP restElecCardio maxHR
<dbl> <fct> <fct> <dbl> <dbl> <fct> <fct> <dbl>
1 52 1 0 125 212 0 1 168
2 53 1 0 140 203 1 0 155
3 70 1 0 145 174 0 1 125
4 61 1 0 148 203 0 1 161
5 62 0 0 138 294 1 1 106
6 58 0 0 100 248 0 0 122
# ℹ 6 more variables: exerciseAngina <fct>, STdepEKG <dbl>,
# slopePeakExST <fct>, nMajorVessels <dbl>, DefectType <fct>,
# heartDisease <fct>
Next, let’s do some summary statistics to have a look at the variables we have in our dataset. Firstly, the numeric columns. We can get a quick overview of variable distributions and ranges with some histograms.
# Reshape data to long format for ggplot2
long_data <- HD %>%
dplyr::select(where(is.numeric)) %>%
pivot_longer(cols = everything(),
names_to = "variable",
values_to = "value")
head(long_data)# A tibble: 6 × 2
variable value
<chr> <dbl>
1 age 52
2 restBP 125
3 chol 212
4 maxHR 168
5 STdepEKG 1
6 nMajorVessels 2
# Plot histograms for each numeric variable in one grid
ggplot(long_data,
aes(x = value)) +
geom_histogram(binwidth = 0.5, fill = "#9395D3", color ='grey30') +
facet_wrap(vars(variable), scales = "free") +
theme_minimal()
Importantly, let’s check the balance of the categorical/factor variables.
HD %>%
dplyr::select(where(is.factor)) %>%
pivot_longer(everything(), names_to = "Variable", values_to = "Level") %>%
dplyr::count(Variable, Level, name = "Count")
# A tibble: 22 × 3
Variable Level Count
<chr> <fct> <int>
1 DefectType 0 7
2 DefectType 1 64
3 DefectType 2 544
4 DefectType 3 410
5 chestPainType 0 497
6 chestPainType 1 167
7 chestPainType 2 284
8 chestPainType 3 77
9 exerciseAngina 0 680
10 exerciseAngina 1 345
# ℹ 12 more rows
# OR
# cat_cols <- HD %>% dplyr::select(where(is.factor)) %>% colnames()
#
# for (col in cat_cols){
# print(col)
# print(table(HD[[col]]))
# }
From our count table above we see that the variables DefectType, chestPainType, restElecCardio, and slopePeakExST are unbalanced. DefectType and restElecCardio are especially problematic, with only 7 and 15 observations, respectively, for one of their factor levels.
To avoid issues when modelling, we will filter out these observations and re-level the two variables.
HD_RF <- HD %>%
filter(DefectType != "0", restElecCardio != "2") %>%
mutate(DefectType = as.factor(as.character(DefectType)),
restElecCardio = as.factor(as.character(restElecCardio)))
head(HD_RF)
# A tibble: 6 × 14
age sex chestPainType restBP chol fastingBP restElecCardio maxHR
<dbl> <fct> <fct> <dbl> <dbl> <fct> <fct> <dbl>
1 52 1 0 125 212 0 1 168
2 53 1 0 140 203 1 0 155
3 70 1 0 145 174 0 1 125
4 61 1 0 148 203 0 1 161
5 62 0 0 138 294 1 1 106
6 58 0 0 100 248 0 0 122
# ℹ 6 more variables: exerciseAngina <fct>, STdepEKG <dbl>,
# slopePeakExST <fct>, nMajorVessels <dbl>, DefectType <fct>,
# heartDisease <fct>
In addition to ensuring an at least somewhat balanced dataset, RF classification requires the outcome variable to be a factor rather than a numeric 0/1 variable. We converted heartDisease to a factor above, but because caret (with classProbs = TRUE) also needs factor levels that are valid R names, we recode the 0/1 levels to noHD and yesHD.
# Recode the outcome levels to valid, informative names
HD_RF <- HD_RF %>%
mutate(heartDisease = fct_recode(heartDisease, noHD = "0", yesHD = "1"))
head(HD_RF$heartDisease)
[1] noHD  noHD  noHD  noHD  noHD  yesHD
Levels: noHD yesHD
Train & Test Set
We now split our dataset into a train and a test set, keeping 70% of the data for training and holding out 30% for testing. Importantly, we must ensure that all levels of each categorical/factor variable are represented in both sets! There are different ways of doing this in R, but one very easy way to achieve a good split is with the createDataPartition() function from caret.
# Set seed to ensure same split when code is re-run
set.seed(123)
# Split on the outcome variable
idx <- createDataPartition(HD_RF$heartDisease, p = 0.7, list = FALSE)
train <- HD_RF[idx, ]
test <- HD_RF[-idx, ]
table(train$heartDisease)
noHD yesHD
339 364
table(test$heartDisease)
noHD yesHD
144 156
In our case, because we have a fairly balanced dataset, we are lucky and all levels of the outcome variable can be represented in both sets. If this was not the case, we might have had to remove some variables (or levels) altogether.
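For intuition, what createDataPartition() does can be approximated in base R by sampling 70% of the rows within each outcome level. This is a sketch of the idea on a toy outcome vector, not caret's exact algorithm:

```r
set.seed(123)
# Toy outcome with unbalanced levels
y <- factor(rep(c("noHD", "yesHD"), times = c(60, 40)))

# Sample 70% of the row indices separately within each level
idx <- unlist(lapply(split(seq_along(y), y),
                     function(rows) sample(rows, size = round(0.7 * length(rows)))))

table(y[idx])   # 42 noHD, 28 yesHD: class proportions preserved
table(y[-idx])  # the remaining 30% forms the test set
```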
Random Forest Model
Now let's set up an RF model with cross-validation so that we do not overfit the model. The R package caret has a very versatile function, trainControl(), which can be used with a range of re-sampling methods including bootstrapping, out-of-bag error, and leave-one-out cross-validation.
set.seed(123)
# Set up cross-validation: 5-fold CV
RFcv <- trainControl(
method = "cv",
number = 5,
classProbs = TRUE,
summaryFunction = twoClassSummary,
savePredictions = "final"
)
Now that we have set up the parameters for cross-validation in the RFcv object above, we can feed it to the train() function from the caret package. We also specify the training data, the name of the outcome variable, and, importantly, that we want to fit a random forest (method = "rf"), as the train() function supports many different models.
set.seed(123)
rf_model <- train(
x = train %>% dplyr::select(-heartDisease),
y = train$heartDisease,
method = "rf",
trControl = RFcv,
metric = "ROC",
tuneLength = 5,
importance = TRUE
)
print(rf_model)
Random Forest
703 samples
13 predictor
2 classes: 'noHD', 'yesHD'
No pre-processing
Resampling: Cross-Validated (5 fold)
Summary of sample sizes: 563, 562, 562, 563, 562
Resampling results across tuning parameters:
mtry ROC Sens Spec
2 0.9929161 0.9645742 0.9945205
4 0.9927587 0.9675154 0.9945205
7 0.9940726 0.9645303 0.9945205
10 0.9942133 0.9645303 0.9945205
13 0.9937031 0.9645303 0.9917808
ROC was used to select the optimal model using the largest value.
The final value used for the model was mtry = 10.
Next, we can plot our model fit to see how the cross-validated performance changes with the number of variables sampled at each split.
# Best parameters
rf_model$bestTune
  mtry
4   10
# Plot performance
plot(rf_model)
As we see from the plot above, the model performs best with ten variables randomly sampled as candidates at each split (mtry = 10). There are several tuning parameters in a random forest; in the randomForest package these include ntree (the number of trees), mtry (the number of variables tried at each split), nodesize (the minimum size of terminal nodes), and maxnodes (the maximum number of terminal nodes).
The tuneLength argument in the train() function specifies how many different values of the tuning parameter to try; for method = "rf", caret tunes only mtry.
Next, we use the test set to evaluate our model performance. We will do this with the predict function, which will give us the predicted probabilities for each class.
# Predict class probabilities
y_pred <- predict(rf_model, newdata = test, type = "prob")
head(y_pred)
   noHD yesHD
1 0.994 0.006
2 0.968 0.032
3 0.016 0.984
4 1.000 0.000
5 0.998 0.002
6 1.000 0.000
As the output of predict is a probability of belonging to a specific class, we will need to convert these probabilities to class labels (yesHD or noHD). This way we can compare a predicted label with the true label in the test set to evaluate our model performance.
y_pred <- as.factor(ifelse(y_pred$yesHD > 0.5, "yesHD", "noHD"))
caret::confusionMatrix(y_pred, test$heartDisease)
Confusion Matrix and Statistics
Reference
Prediction noHD yesHD
noHD 141 0
yesHD 3 156
Accuracy : 0.99
95% CI : (0.9711, 0.9979)
No Information Rate : 0.52
P-Value [Acc > NIR] : <2e-16
Kappa : 0.98
Mcnemar's Test P-Value : 0.2482
Sensitivity : 0.9792
Specificity : 1.0000
Pos Pred Value : 1.0000
Neg Pred Value : 0.9811
Prevalence : 0.4800
Detection Rate : 0.4700
Detection Prevalence : 0.4700
Balanced Accuracy : 0.9896
'Positive' Class : noHD
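As a sanity check, the headline statistics can be recomputed directly from the four cell counts, keeping in mind that the "positive" class here is noHD:

```r
TP <- 141; FN <- 3   # true noHD: predicted noHD / predicted yesHD
FP <- 0;   TN <- 156 # true yesHD: predicted noHD / predicted yesHD

accuracy    <- (TP + TN) / (TP + FN + FP + TN)  # 297/300 = 0.99
sensitivity <- TP / (TP + FN)                   # 141/144, about 0.9792
specificity <- TN / (TN + FP)                   # 156/156 = 1
c(accuracy, sensitivity, specificity)
```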
Lastly, we can extract the predictor variables with the greatest importance from our fit.
varImpOut <- varImp(rf_model)
varImpOut$importance
                    noHD     yesHD
age 62.08613 62.08613
sex 30.02973 30.02973
chestPainType 94.83905 94.83905
restBP 65.98055 65.98055
chol 75.11579 75.11579
fastingBP 0.00000 0.00000
restElecCardio 20.85050 20.85050
maxHR 63.45448 63.45448
exerciseAngina 12.77844 12.77844
STdepEKG 93.18319 93.18319
slopePeakExST 31.04470 31.04470
nMajorVessels 96.73773 96.73773
DefectType 100.00000 100.00000
# Order by importance
varImportance <- as.data.frame(as.matrix(varImpOut$importance)) %>%
rownames_to_column(var = 'VarName') %>%
arrange(desc(yesHD))
varImportance
          VarName      noHD     yesHD
1 DefectType 100.00000 100.00000
2 nMajorVessels 96.73773 96.73773
3 chestPainType 94.83905 94.83905
4 STdepEKG 93.18319 93.18319
5 chol 75.11579 75.11579
6 restBP 65.98055 65.98055
7 maxHR 63.45448 63.45448
8 age 62.08613 62.08613
9 slopePeakExST 31.04470 31.04470
10 sex 30.02973 30.02973
11 restElecCardio 20.85050 20.85050
12 exerciseAngina 12.77844 12.77844
13 fastingBP 0.00000 0.00000
Variable importance is based on how much each variable improves the model's accuracy. The variables DefectType, nMajorVessels, and chestPainType appear to be highly predictive of heart disease.
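Because we fit with importance = TRUE, these scores are permutation-based: shuffle one predictor and measure how much the accuracy drops. Here is the idea in a self-contained base-R sketch, with a hypothetical predict_fun standing in for a fitted model:

```r
set.seed(123)
n  <- 200
x1 <- rnorm(n); x2 <- rnorm(n)
y  <- as.integer(x1 > 0)        # outcome driven entirely by x1
df <- data.frame(x1, x2, y)

# Stand-in "fitted model": predicts from x1 only
predict_fun <- function(d) as.integer(d$x1 > 0)

perm_importance <- function(var) {
  base_acc <- mean(predict_fun(df) == df$y)
  shuffled <- df
  shuffled[[var]] <- sample(shuffled[[var]])      # break var's link to y
  base_acc - mean(predict_fun(shuffled) == df$y)  # drop in accuracy
}

perm_importance("x1")  # large drop: x1 is important
perm_importance("x2")  # exactly 0: x2 is never used by the model
```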