Machine Learning with R: A Complete Guide to Decision Trees

<em><strong>Updated</strong>: August 22, 2022.</em>
<h2><span data-preserver-spaces="true">R Decision Trees</span></h2>
<span data-preserver-spaces="true">Decision trees are among the most fundamental algorithms in supervised machine learning, and they handle both regression and classification tasks. In a nutshell, you can think of a decision tree as a glorified collection of if-else statements. What sets these if-else statements apart from traditional programming is that the logical conditions are "generated" by the machine learning algorithm, but more on that later.</span>
<blockquote><span data-preserver-spaces="true">Interested in more basic machine learning guides? </span><a class="editor-rtfLink" href="" target="_blank" rel="noopener noreferrer"><span data-preserver-spaces="true">Check our detailed guide on Logistic Regression with R</span></a><span data-preserver-spaces="true">.</span></blockquote>
<span data-preserver-spaces="true">Today you'll learn the basic theory behind the decision tree algorithm and how to implement it in R.</span>
<span data-preserver-spaces="true">Table of contents:</span>
<ul><li><a href="#introduction">Introduction to R Decision Trees</a></li><li><a href="#data-loading">Dataset Loading and Preparation</a></li><li><a href="#modeling">Predictive Modeling with R Decision Trees</a></li><li><a href="#predictions">Generating Predictions</a></li><li><a href="#conclusion">Summary of R Decision Trees</a></li></ul>
<hr />
<h2 id="introduction"><span data-preserver-spaces="true">Introduction to R Decision Trees</span></h2>
<span data-preserver-spaces="true">Decision trees are intuitive. All they do is ask questions such as "Is the gender male?" or "Is the value of a particular variable higher than some threshold?". Based on the answers, either more questions are asked, or the classification is made. Simple!</span>
<span data-preserver-spaces="true">To predict class labels, the decision tree starts from the root (root node). 
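</span>
<span data-preserver-spaces="true">The root attribute is the one whose split leaves the purest groups of records, and Gini impurity is one common way to measure that purity. Here's a minimal sketch of the math in base R. The helper names <code>gini_impurity()</code> and <code>split_impurity()</code> are made up for illustration and aren't part of any package, and the 2.45 threshold is just an example cut-off:</span>
<pre><code class="language-r"># Gini impurity of a set of class labels: 1 - sum(p_k^2)
gini_impurity &lt;- function(labels) {
  p &lt;- table(labels) / length(labels)
  1 - sum(p^2)
}

# Weighted Gini impurity of the two groups created by a yes/no condition
split_impurity &lt;- function(labels, condition) {
  n &lt;- length(labels)
  left &lt;- labels[condition]
  right &lt;- labels[!condition]
  (length(left) / n) * gini_impurity(left) +
    (length(right) / n) * gini_impurity(right)
}

# Three balanced classes: 1 - 3 * (1/3)^2 = 0.667
gini_impurity(iris$Species)

# One pure group of 50 and one 50/50 group of 100: (100/150) * 0.5 = 0.333
split_impurity(iris$Species, iris$Petal.Length &lt; 2.45)</code></pre>
<span data-preserver-spaces="true">The lower the weighted impurity after a split, the better that split separates the classes.</span>
<span data-preserver-spaces="true">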
Calculating which attribute should represent the root node is straightforward and boils down to figuring out which attribute best separates the training records. The calculation is done with the</span><strong><span data-preserver-spaces="true"> Gini impurity </span></strong><span data-preserver-spaces="true">formula. It's simple math, but it gets tedious to do manually if you have many attributes.</span>
<span data-preserver-spaces="true">After determining the root node, the tree "branches out" to further separate the records that are still mixed after the root split.</span>
<span data-preserver-spaces="true">That's why it's common to hear the "decision tree = multiple if-else statements" analogy. The analogy makes sense to a degree, but the conditional statements are calculated automatically. In simple words, the machine learns the best conditions for your data.</span>
<span data-preserver-spaces="true">Let's take a look at the following decision tree representation to drive these points home:</span>
<img class="size-full wp-image-6590" src="" alt="Image 1 - Example decision tree" width="573" height="404" /> Image 1 - Example decision tree (<a href="" target="_blank" rel="noopener noreferrer">source</a>)
<span data-preserver-spaces="true">As you can see, the variables </span><em><span data-preserver-spaces="true">Outlook?</span></em><span data-preserver-spaces="true">, </span><em><span data-preserver-spaces="true">Humidity?</span></em><span data-preserver-spaces="true">, and </span><em><span data-preserver-spaces="true">Windy?</span></em><span data-preserver-spaces="true"> are used to predict the dependent variable - </span><em><span data-preserver-spaces="true">Play</span></em><span data-preserver-spaces="true">.</span>
<span data-preserver-spaces="true">You now know the basic theory behind the algorithm, and you'll learn how to implement it in R next.</span>
<h2 id="data-loading"><span data-preserver-spaces="true">Dataset Loading and Preparation</span></h2>
<span 
data-preserver-spaces="true">There's no machine learning without data, and there's no working with data without libraries. You'll need the following packages to follow along:</span>
<pre><code class="language-r">library(caTools)     # sample.split()
library(rpart)       # decision tree model
library(rpart.plot)  # tree visualization
library(caret)       # varImp(), confusionMatrix()
library(Boruta)      # feature selection
library(cvms)        # plot_confusion_matrix()
library(dplyr)       # %&gt;%, arrange()

head(iris)</code></pre>
<span data-preserver-spaces="true">As you can see, we'll use the Iris dataset to build our decision tree classifier. This is what the first few rows look like (output from the <code>head()</code> function call):</span>
<img class="size-full wp-image-6591" src="" alt="Image 2 - Iris dataset head" width="958" height="264" /> Image 2 - Iris dataset head
<span data-preserver-spaces="true">The dataset is familiar to anyone with even a week of experience in data science and machine learning, so it doesn't require further introduction. It's also as clean as they come, which will save us a lot of time in this section.</span>
<span data-preserver-spaces="true">The only thing we have to do before continuing to predictive modeling is to randomly split the dataset into training and testing subsets. You can use the following code snippet to do a split in a 75:25 ratio:</span>
<pre><code class="language-r">set.seed(42)  # make the random split reproducible
sample_split &lt;- sample.split(Y = iris$Species, SplitRatio = 0.75)
train_set &lt;- subset(x = iris, sample_split == TRUE)
test_set &lt;- subset(x = iris, sample_split == FALSE)</code></pre>
<span data-preserver-spaces="true">And that's it! Let's start with modeling next.</span>
<h2 id="modeling"><span data-preserver-spaces="true">Predictive Modeling with R Decision Trees</span></h2>
<span data-preserver-spaces="true">We're using the <code>rpart</code> library to build the model. The formula syntax is the same one used for linear and logistic regression. You'll need to put the target variable on the left and features on the right, separated with the ~ sign. 
If you want to use all features, put a dot (.) instead of feature names.</span>
<span data-preserver-spaces="true">Also, don't forget to specify <code>method = "class"</code> since we're dealing with a classification dataset here.</span>
<span data-preserver-spaces="true">Here's how to train the model:</span>
<pre><code class="language-r">model &lt;- rpart(Species ~ ., data = train_set, method = "class")
model</code></pre>
<span data-preserver-spaces="true">The output of calling <code>model</code> is shown in the following image:</span>
<img class="wp-image-6592 size-large" src="" alt="Image 3 - Decision tree classifier model" width="1024" height="302" /> Image 3 - Decision tree classifier model
<span data-preserver-spaces="true">From this image alone, you can see the "rules" the decision tree model uses to make classifications. If you'd like a more visual representation, you can use the <code>rpart.plot</code> package to visualize the tree:</span>
<pre><code class="language-r">rpart.plot(model)</code></pre>
<img class="size-full wp-image-6593" src="" alt="Image 4 - Visual representation of the decision tree" width="1886" height="1306" /> Image 4 - Visual representation of the decision tree
<span data-preserver-spaces="true">You can see how many classifications were correct (in the training set) by examining the bottom (leaf) nodes. The </span><em><span data-preserver-spaces="true">setosa</span></em><span data-preserver-spaces="true"> was correctly classified every time, the </span><em><span data-preserver-spaces="true">versicolor</span></em><span data-preserver-spaces="true"> was misclassified as </span><em><span data-preserver-spaces="true">virginica</span></em><span data-preserver-spaces="true"> 5% of the time, and </span><em><span data-preserver-spaces="true">virginica</span></em><span data-preserver-spaces="true"> was misclassified as </span><em><span data-preserver-spaces="true">versicolor</span></em><span data-preserver-spaces="true"> 3% of the time. 
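</span>
<span data-preserver-spaces="true">If you'd rather read those shares as numbers, a cross-tabulation of predicted and actual classes gives the same kind of information. This is a minimal sketch, not the article's exact output: it assumes a tree fitted on the full <code>iris</code> data so that it runs on its own (swap in <code>train_set</code> to match the plot), and the names <code>fit</code>, <code>fitted_classes</code>, and <code>counts</code> are illustrative:</span>
<pre><code class="language-r">library(rpart)

# Assumption: fit on the full iris data so the snippet is self-contained
fit &lt;- rpart(Species ~ ., data = iris, method = "class")
fitted_classes &lt;- predict(fit, newdata = iris, type = "class")

# Rows are predicted classes, columns are the actual species
counts &lt;- table(Predicted = fitted_classes, Actual = iris$Species)
counts

# Share of each actual species within every predicted class (rows sum to 1)
round(prop.table(counts, margin = 1), 2)</code></pre>
<span data-preserver-spaces="true">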
The plot is simple, but it tells you everything about how the tree classifies.</span>
<span data-preserver-spaces="true">Decision trees are also useful for examining feature importance, that is, how much predictive power lies in each feature. You can use the <code>varImp()</code> function to find out. The following snippet calculates the importances and sorts them in descending order:</span>
<pre><code class="language-r">importances &lt;- varImp(model)
importances %&gt;%
  arrange(desc(Overall))</code></pre>
<span data-preserver-spaces="true">The results are shown in the image below:</span>
<img class="size-full wp-image-6594" src="" alt="Image 5 - Feature importances" width="350" height="190" /> Image 5 - Feature importances
If <code>varImp()</code> doesn't do it for you and you're looking for something more advanced, look no further than Boruta.
<h3>Feature Importances with Boruta</h3>
Boruta is a feature ranking and selection algorithm based on the Random Forests algorithm. It will tell you if features in your dataset are relevant for making predictions. There are ways to adjust this "relevancy", such as tweaking the p-value and other parameters, but that's not something we'll go over today. A call to the <code>Boruta()</code> function is identical to a call to <code>rpart()</code>, with the additional <code>doTrace</code> parameter for limiting the console output. 
The code snippet below shows you how to find the importances and print them sorted in descending order:
<pre><code class="language-r">library(Boruta)

boruta_output &lt;- Boruta(Species ~ ., data = train_set, doTrace = 0)
# Resolve features left as "Tentative" into "Confirmed" or "Rejected"
rough_fix_mod &lt;- TentativeRoughFix(boruta_output)
boruta_signif &lt;- getSelectedAttributes(rough_fix_mod)
# Keep the non-rejected features, sorted by mean importance
importances &lt;- attStats(rough_fix_mod)
importances &lt;- importances[importances$decision != "Rejected", c("meanImp", "decision")]
importances[order(-importances$meanImp), ]</code></pre>
<img class="size-full wp-image-15282" src="" alt="Image 6 - Boruta importances" width="510" height="190" /> Image 6 - Boruta importances
In case you want to present these results visually, the package has you covered:
<pre><code class="language-r">plot(boruta_output, cex.axis = 0.7, las = 2, xlab = "", main = "Feature importance")</code></pre>
<img class="size-full wp-image-15284" src="" alt="Image 7 - Boruta plot" width="2186" height="1730" /> Image 7 - Boruta plot
Look only for the green color - it means the feature is important. The red color would indicate the feature isn't important, and blue marks the shadow attributes Boruta creates as a reference for judging importance, so you can ignore those. The higher the box plot sits on the Y-axis, the more important the feature. It's that easy!
<span data-preserver-spaces="true">You've built and explored the model so far, but you haven't put it to use yet. The next section shows you how to make predictions on previously unseen data and evaluate the model.</span>
<h2 id="predictions"><span data-preserver-spaces="true">Generating Predictions</span></h2>
<span data-preserver-spaces="true">Predicting new instances is now a trivial task. All you have to do is use the <code>predict()</code> function and pass in the testing subset. Also, make sure to specify <code>type = "class"</code> to get class labels back instead of class probabilities. 
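</span>
<span data-preserver-spaces="true">To see what the <code>type</code> argument changes, here's a small sketch that compares the default output with <code>type = "class"</code>. It assumes a throwaway tree fitted on the full <code>iris</code> data, and the names <code>fit</code>, <code>probs</code>, and <code>class_labels</code> are illustrative:</span>
<pre><code class="language-r">library(rpart)

# Assumption: a throwaway model, used only to compare the two output types
fit &lt;- rpart(Species ~ ., data = iris, method = "class")

# Default for a classification tree: a matrix with one probability column per class
probs &lt;- predict(fit, newdata = head(iris, 3))
probs

# type = "class": a factor holding the most probable label for each row
class_labels &lt;- predict(fit, newdata = head(iris, 3), type = "class")
class_labels</code></pre>
<span data-preserver-spaces="true">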
Here's how to generate predictions for the test set:</span>
<pre><code class="language-r">preds &lt;- predict(model, newdata = test_set, type = "class")
preds</code></pre>
<span data-preserver-spaces="true">The results are shown in the following image: </span>
<img class="wp-image-6595 size-full" src="" alt="Image 8 - Decision tree predictions" width="1402" height="424" /> Image 8 - Decision tree predictions
<span data-preserver-spaces="true">But how good are these predictions? Let's evaluate. The confusion matrix is one of the most commonly used tools for evaluating classification models. In R, the <code>confusionMatrix()</code> function from <code>caret</code> also outputs values for other metrics, such as sensitivity and specificity.</span>
<span data-preserver-spaces="true">Here's how you can print the confusion matrix:</span>
<pre><code class="language-r"># Predictions go in `data`, true labels in `reference`
confusionMatrix(data = preds, reference = test_set$Species)</code></pre>
<span data-preserver-spaces="true">And here are the results:</span>
<img class="wp-image-6596 size-full" src="" alt="Image 9 - Confusion matrix on the test set" width="1108" height="1144" /> Image 9 - Confusion matrix on the test set
<span data-preserver-spaces="true">As you can see, there are some misclassifications in the <em>versicolor</em> and <em>virginica</em> classes, similar to what we've seen in the training set. Overall, the model is just short of 90% accuracy, which is more than acceptable for a simple decision tree classifier.</span>
But let's be honest - the amount of detail in the previous image is overwhelming. What if you want to display the confusion matrix only, and display it visually as a heatmap? That's where the <code>cvms</code> package comes in. It plots a confusion matrix stored as a tibble, which is just what we need. Pay attention to the parameters of the <code>plot_confusion_matrix()</code> function - they're intuitive to understand, and their values are the column names of <code>cfm</code>. 
Yours might be different:
<pre><code class="language-r">library(cvms)
library(dplyr)  # re-exports as_tibble()

# Predictions go in `data`, true labels in `reference`
cm &lt;- confusionMatrix(data = preds, reference = test_set$Species)
cfm &lt;- as_tibble(cm$table)
plot_confusion_matrix(cfm, target_col = "Reference", prediction_col = "Prediction", counts_col = "n")</code></pre>
<img class="size-full wp-image-15286" src="" alt="Image 10 - Confusion matrix plot" width="1602" height="1618" /> Image 10 - Confusion matrix plot
Much better, isn't it? Now you have something to present. Let's wrap things up in the following section.
<hr />
<h2 id="conclusion"><span data-preserver-spaces="true">Summary of R Decision Trees</span></h2>
<span data-preserver-spaces="true">Decision trees are an excellent introduction to the whole family of tree-based algorithms. A decision tree is commonly used as a baseline model, which more sophisticated tree-based algorithms (such as random forests and gradient boosting) need to outperform.</span>
<span data-preserver-spaces="true">Today you've learned the basic logic and intuition behind decision trees, and how to implement and evaluate the algorithm in R. 
You can expect the whole suite of tree-based algorithms covered soon, so stay tuned to the Appsilon blog if you want to learn more.</span> <strong><span data-preserver-spaces="true">If you want to implement machine learning in your organization, you can always reach out to </span></strong><a class="editor-rtfLink" href="" target="_blank" rel="noopener noreferrer"><strong><span data-preserver-spaces="true">Appsilon</span></strong></a><strong><span data-preserver-spaces="true"> for help.</span></strong> <h3><span data-preserver-spaces="true">Learn More</span></h3><ul><li><a class="editor-rtfLink" href="" target="_blank" rel="noopener noreferrer"><span data-preserver-spaces="true">Machine Learning with R: A Complete Guide to Logistic Regression</span></a></li><li><a class="editor-rtfLink" href="" target="_blank" rel="noopener noreferrer"><span data-preserver-spaces="true">Machine Learning with R: A Complete Guide to Linear Regression</span></a></li><li><a class="editor-rtfLink" href="" target="_blank" rel="noopener noreferrer"><span data-preserver-spaces="true">What Can I Do With R? 6 Essential R Packages for Programmers</span></a></li><li><a class="editor-rtfLink" href="" target="_blank" rel="noopener noreferrer"><span data-preserver-spaces="true">AI for Good: ML Wildlife Image Classification to Analyze Camera Trap Datasets</span></a></li><li><a class="editor-rtfLink" href="" target="_blank" rel="noopener noreferrer"><span data-preserver-spaces="true">YOLO Algorithm and YOLO Object Detection: An Introduction</span></a></li></ul>
