Did you ever want to make your machine learning model available to other people, but didn’t know how? Or maybe you just heard about the term API, and want to know what’s behind it? Then this post is for you! Here at STATWORX, we use and write APIs daily. In this article, I explain how you can build your own API for a machine learning model that you create, and what some of the most important concepts, like REST, mean. After reading this short article, you will know how to make requests to your API within a Python program. So have fun reading and learning!

What is an API?

API is short for Application Programming Interface. It allows users to interact with the underlying functionality of some written code by accessing the interface. There is a multitude of APIs, and chances are good that you have already heard about the type of API we are going to talk about in this blog post: the web API. This specific type of API allows users to interact with functionality over the internet. In this example, we are building an API that will provide predictions through our trained machine learning model. In a real-world setting, this kind of API could be embedded in some type of application, where a user enters new data and receives a prediction in return. APIs are very flexible and easy to maintain, making them a handy tool in the daily work of a Data Scientist or Data Engineer. An example of a publicly available machine learning API is Time Door. It provides time series tools that you can integrate into your applications. APIs can also be used to make data available, not only machine learning models.
[Figure: API illustration]

And what is REST?

Representational State Transfer (or REST) is an approach that entails a specific style of communication through web services. When using some of the REST best practices to implement an API, we call that API a “REST API”. There are other approaches to web communication, too (such as the Simple Object Access Protocol: SOAP), but REST generally needs less bandwidth, making it preferable for serving your machine learning models. In a REST API, the four most important types of requests are:
  • GET
  • PUT
  • POST
  • DELETE
For our little machine learning application, we will mostly focus on the POST method, since it is very versatile and lets the client send the input data in the request body, which many clients cannot do with GET requests. It’s important to mention that REST APIs are stateless: the server does not store the inputs you give during an API call between requests, so it does not preserve any state. That’s significant because it allows multiple users and applications to use the API at the same time, without one user request interfering with another.
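To make the four request types a bit more concrete, here is a minimal sketch that uses only Python's standard library. It builds (but never sends) one request per method against a hypothetical /flowers resource. The URL, the resource name, and the payloads are assumptions made up for illustration and are not part of the API we build in this post.

```python
import json
from urllib.request import Request

# hypothetical resource, used only for illustration
BASE_URL = "http://127.0.0.1:5000/flowers"


def build_request(method, url, payload=None):
    """Build an HTTP request object without sending it."""
    data = json.dumps(payload).encode("utf-8") if payload is not None else None
    headers = {"Content-Type": "application/json"} if data is not None else {}
    return Request(url, data=data, headers=headers, method=method)


requests_by_purpose = {
    "read an entry": build_request("GET", f"{BASE_URL}/1"),
    "create an entry": build_request("POST", BASE_URL, {"name": "iris-setosa"}),
    "replace an entry": build_request("PUT", f"{BASE_URL}/1", {"name": "iris-virginica"}),
    "remove an entry": build_request("DELETE", f"{BASE_URL}/1"),
}

for purpose, req in requests_by_purpose.items():
    print(f"{req.get_method():6} {req.full_url}  # {purpose}")
```

Each request is self-contained: everything the server needs travels with the request itself, which is what statelessness amounts to in practice.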

The Model

For this How-To-article, I decided to serve a machine learning model trained on the famous iris dataset. If you don’t know the dataset, you can check it out here. When making predictions, we will have four input parameters: sepal length, sepal width, petal length, and finally, petal width. Those will help to decide which type of iris flower the input is. For this example I used the scikit-learn implementation of a simple KNN (K-nearest neighbor) algorithm to predict the type of iris:
# model.py
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import joblib  # sklearn.externals.joblib was removed from recent scikit-learn versions
import numpy as np


def train(X,y):

    # train test split
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

    knn = KNeighborsClassifier(n_neighbors=1)

    # fit the model
    knn.fit(X_train, y_train)
    preds = knn.predict(X_test)
    acc = accuracy_score(y_test, preds)
    print(f'Successfully trained model with an accuracy of {acc:.2f}')

    return knn

if __name__ == '__main__':

    iris_data = datasets.load_iris()
    X = iris_data['data']
    y = iris_data['target']

    labels = {0 : 'iris-setosa',
              1 : 'iris-versicolor',
              2 : 'iris-virginica'}

    # rename integer labels to actual flower names
    y = np.vectorize(labels.__getitem__)(y)

    mdl = train(X,y)

    # serialize model
    joblib.dump(mdl, 'iris.mdl')
As you can see, I trained the model with 70% of the data and then validated it with the remaining 30% as out-of-sample test data. After the model training has taken place, I serialize the model with the joblib library. Joblib is basically an alternative to pickle that efficiently persists scikit-learn estimators, which often carry large numpy arrays (such as the KNN model, which contains all the training data). After the file is saved with joblib (the file ending is not important by the way, so don’t be confused that some people call it .model or .joblib), it can be loaded again later in our application.
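If you have never seen serialization in action, the following sketch shows the save-and-restore round trip with the standard library's pickle module instead of joblib, so it runs without any extra packages. The tiny "model" here is just a dictionary holding two training rows, and predict_one is a hypothetical helper written for this illustration. It is a stand-in for a fitted KNN estimator, not how scikit-learn stores its models.

```python
import os
import pickle
import tempfile

# hypothetical stand-in for a fitted 1-nearest-neighbor model:
# like KNN, it simply keeps its training data around
model = {
    "X": [[5.1, 3.5, 1.4, 0.2], [6.7, 3.0, 5.2, 2.3]],
    "y": ["iris-setosa", "iris-virginica"],
}


def predict_one(mdl, row):
    """Return the label of the stored training row closest to `row`."""
    distances = [sum((a - b) ** 2 for a, b in zip(row, x)) for x in mdl["X"]]
    return mdl["y"][distances.index(min(distances))]


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny.mdl")  # the file ending is arbitrary
    with open(path, "wb") as f:
        pickle.dump(model, f)  # serialize: object -> bytes on disk
    with open(path, "rb") as f:
        restored = pickle.load(f)  # deserialize: bytes on disk -> object

print(predict_one(restored, [5.0, 3.4, 1.5, 0.2]))
```

The restored object behaves exactly like the original one, which is all the API needs when it loads the model at startup.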

The API with Python and Flask

To build an API from our trained model, we will be using the popular web development package Flask and Flask-RESTful. Further, we import joblib to load our model and numpy to handle the input and output data. In a new script, namely app.py, we can now set up an instance of a Flask app and an API and load the trained model (this requires saving the model in the same directory as the script):
from flask import Flask
from flask_restful import Api, Resource, reqparse
import joblib  # sklearn.externals.joblib was removed from recent scikit-learn versions
import numpy as np

APP = Flask(__name__)
API = Api(APP)

IRIS_MODEL = joblib.load('iris.mdl')
The second step now is to create a class, which is responsible for our prediction. This class will be a child class of the Flask-RESTful class Resource. This lets our class inherit the respective class methods and allows Flask to do the work behind your API without needing to implement everything. In this class, we can also define the methods (REST requests) that we talked about before. So now we implement a Predict class with the .post() method we talked about earlier. The post method allows the user to send a body along with the default API parameters. Usually, we want the body to be in JSON format. Since this body is not delivered directly in the URL, but as text, we have to parse this text and fetch the arguments. The flask_restful package offers the RequestParser class for that. We simply add all the arguments we expect to find in the JSON input with the .add_argument() method and parse them into a dictionary. We then convert it into an array and return the prediction of our model as JSON.
class Predict(Resource):

    @staticmethod
    def post():
        parser = reqparse.RequestParser()
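        # add the arguments in the column order the model was trained on
        parser.add_argument('sepal_length', type=float, required=True)
        parser.add_argument('sepal_width', type=float, required=True)
        parser.add_argument('petal_length', type=float, required=True)
        parser.add_argument('petal_width', type=float, required=True)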

        args = parser.parse_args()  # creates dict

        X_new = np.fromiter(args.values(), dtype=float)  # convert input to array

        out = {'Prediction': IRIS_MODEL.predict([X_new])[0]}

        return out, 200
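One detail worth spelling out: the order in which the values reach the model matters, because the iris data in scikit-learn stores its columns as sepal length, sepal width, petal length, and petal width. The following sketch shows, with the standard library only, how a JSON body can be turned into a feature vector in exactly that order, no matter how the client ordered its fields. The names FEATURE_ORDER and features_from_json are hypothetical helpers for this illustration and not part of Flask-RESTful.

```python
import json

# column order of the iris data the model was trained on
FEATURE_ORDER = ("sepal_length", "sepal_width", "petal_length", "petal_width")


def features_from_json(body_text):
    """Parse a JSON request body and return the feature values in training order."""
    body = json.loads(body_text)
    missing = [name for name in FEATURE_ORDER if name not in body]
    if missing:
        raise ValueError(f"missing fields: {', '.join(missing)}")
    return [float(body[name]) for name in FEATURE_ORDER]


# the client may send the fields in any order
raw_body = '{"petal_length": 2, "sepal_length": 2, "petal_width": 0.5, "sepal_width": 3}'
print(features_from_json(raw_body))  # [2.0, 3.0, 2.0, 0.5]
```

Looking the values up by name, rather than relying on the order in which they happen to arrive, keeps the prediction independent of how the request was assembled.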
You might be wondering what the 200 is that we are returning at the end: every HTTP response carries a status code that tells the client how the request went. You might be familiar with the famous 404 - page not found code. 200 just means that the request has succeeded. You basically let the user know that everything went according to plan. In the end, you just have to add the Predict class as a resource to the API, and write the main function:
API.add_resource(Predict, '/predict')

if __name__ == '__main__':
    APP.run(debug=True, port=1080)
The '/predict' you see in the .add_resource() call is the so-called API endpoint. Through this endpoint, users of your API will be able to access it and send (in this case) POST requests. If you don’t define a port, port 5000 will be the default. You can see the whole code for the app again here:
# app.py
from flask import Flask
from flask_restful import Api, Resource, reqparse
import joblib  # sklearn.externals.joblib was removed from recent scikit-learn versions
import numpy as np

APP = Flask(__name__)
API = Api(APP)

IRIS_MODEL = joblib.load('iris.mdl')


class Predict(Resource):

    @staticmethod
    def post():
        parser = reqparse.RequestParser()
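        # add the arguments in the column order the model was trained on
        parser.add_argument('sepal_length', type=float, required=True)
        parser.add_argument('sepal_width', type=float, required=True)
        parser.add_argument('petal_length', type=float, required=True)
        parser.add_argument('petal_width', type=float, required=True)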

        args = parser.parse_args()  # creates dict

        X_new = np.fromiter(args.values(), dtype=float)  # convert input to array

        out = {'Prediction': IRIS_MODEL.predict([X_new])[0]}

        return out, 200


API.add_resource(Predict, '/predict')

if __name__ == '__main__':
    APP.run(debug=True, port=1080)
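As a small aside on the status code that the post method returns: Python ships the standard HTTP status codes in its http module, so you can look up what a number means without leaving the interpreter. The is_success helper below is a hypothetical convenience function for this illustration.

```python
from http import HTTPStatus

# look up the standard phrase behind a few common status codes
for code in (200, 400, 404, 500):
    status = HTTPStatus(code)
    print(status.value, status.phrase)


def is_success(code):
    """Status codes in the 2xx range signal that the request succeeded."""
    return 200 <= code < 300
```

A client can use a check like is_success before it tries to read a prediction from the response body.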

Run the API

Now it’s time to run and test our API! To run the app, simply open a terminal in the same directory as your app.py script and run this command.
python app.py
You should now get a notification that the API runs on your localhost on the port you defined. There are several ways of accessing the API once it is deployed. For debugging and testing purposes, I usually use tools like Postman. We can also access the API from within a Python application, just like another user might want to do to use your model in their code. We use the requests module, by first defining the URL to access and the body to send along with our HTTP request:
import requests

url = 'http://127.0.0.1:1080/predict'  # localhost and the defined port + endpoint
body = {
    "petal_length": 2,
    "sepal_length": 2,
    "petal_width": 0.5,
    "sepal_width": 3
}
response = requests.post(url, json=body)  # send the body as JSON
response.json()
The output should look something like this:
Out[1]: {'Prediction': 'iris-versicolor'}
That’s how easy it is to include an API call in your Python code! Please note that this API is just running on your localhost. You would have to deploy the API to a live server (e.g., on AWS) for others to access it.

Conclusion

In this blog article, you got a brief overview of how to build a REST API to serve your machine learning model with a web interface. Further, you now understand how to integrate simple API requests into your Python code. For the next step, maybe try securing your APIs? If you are interested in learning how to build an API with R, you should check out this post. I hope that this gave you a solid introduction to the concept and that you will be building your own APIs immediately. Happy coding!  

Introduction

When working on data science projects in R, exporting internal R objects as files on your hard drive is often necessary to facilitate collaboration. Here at STATWORX, we regularly export R objects (such as outputs of a machine learning model) as .RDS files and put them on our internal file server. Our co-workers can then pick them up for further usage down the line of the data science workflow (such as visualizing them in a dashboard together with inputs from other colleagues). Over the last couple of months, I came to work a lot with RDS files and noticed a crucial shortcoming: The base R saveRDS function does not allow for any kind of archiving of existing same-named files on your hard drive. In this blog post, I will explain why this might be very useful by introducing the basics of serialization first and then showcasing my proposed solution: A wrapper function around the existing base R serialization framework.

Be wary of silent file replacements!

In base R, you can easily export any object from the environment to an RDS file with:
saveRDS(object = my_object, file = "path/to/dir/my_object.RDS")
However, including such a line somewhere in your script can carry unintended consequences: When calling saveRDS multiple times with identical file names, R silently overwrites existing, identically named .RDS files in the specified directory. If the object you are exporting is not what you expect it to be — for example due to some bug in newly edited code — your working copy of the RDS file is simply overwritten in-place. Needless to say, this can prove undesirable. If you are familiar with this pitfall, you probably used to forestall such potentially troublesome side effects by commenting out the respective lines, then carefully checking each time whether the R object looked fine, then executing the line manually. But even when there is nothing wrong with the R object you seek to export, it can make sense to retain an archived copy of previous RDS files: Think of a dataset you run through a data prep script, and then you get an update of the raw data, or you decide to change something in the data prep (like removing a variable). You may wish to archive an existing copy in such cases, especially with complex data prep pipelines with long execution time.

Don’t get tangled up in manual renaming

You could manually move or rename the existing file each time you plan to create a new one, but that’s tedious, error-prone, and does not allow for unattended execution and scalability. For this reason, I set out to write a carefully designed wrapper function around the existing saveRDS call, which is pretty straightforward: As a first step, it checks if the file you attempt to save already exists in the specified location. If it does, the existing file is renamed/archived (with customizable options), and the “updated” file will be saved under the originally specified name. This approach has the crucial advantage that the existing code that depends on the file name remaining identical (such as readRDS calls in other scripts) will continue to work with the latest version without any needs for adjustment! No more saving your objects as “models_2020-07-12.RDS”, then combing through the other scripts to replace the file name, only to repeat this process the next day. At the same time, an archived copy of the — otherwise overwritten — file will be kept.

What are RDS files anyways?

Before I walk you through my proposed solution, let’s first examine the basics of serialization, the underlying process behind high-level functions like saveRDS.
Simply speaking, serialization is the “process of converting an object into a stream of bytes so that it can be transferred over a network or stored in a persistent storage” (Stack Overflow: What is serialization?).
There is also a low-level R interface, serialize, which you can use to explore (un-)serialization first-hand: Simply fire up R and run something like serialize(object = c(1, 2, 3), connection = NULL). This call serializes the specified vector and prints the output right to the console. The result is an odd-looking raw vector, with each byte separately represented as a pair of hex digits. Now let’s see what happens if we reverse this process:
s <- serialize(object = c(1, 2, 3), connection = NULL)
print(s)
# >  [1] 58 0a 00 00 00 03 00 03 06 00 00 03 05 00 00 00 00 05 55 54 46 2d 38 00 00 00 0e 00
# > [29] 00 00 03 3f f0 00 00 00 00 00 00 40 00 00 00 00 00 00 00 40 08 00 00 00 00 00 00

unserialize(s)
# > [1] 1 2 3
The length of this raw vector increases rapidly with the complexity of the stored information: For instance, serializing the famous, although not too large, iris dataset results in a raw vector consisting of 5959 pairs of hex digits! Besides the already mentioned saveRDS function, there is also the more generic save function. The former saves a single R object to a file. It allows us to restore the object from that file (with the counterpart readRDS), possibly under a different variable name: That is, you can assign the contents of a call to readRDS to another variable. By contrast, save allows for saving multiple R objects, but when reading back in (with load), they are simply restored in the environment under the object names they were saved with. (That’s also what happens automatically when you answer “Yes” to the notorious question of whether to “save the workspace image to ~/.RData” when quitting RStudio.)

Creating the archives

Obviously, it’s great to have the possibility to save internal R objects to a file and then be able to re-import them in a clean session or on a different machine. This is especially true for the results of long and computationally heavy operations such as fitting machine learning models. But as we learned earlier, one wrong keystroke can potentially erase that one precious 3-hour-fit fine-tuned XGBoost model you ran and carefully saved to an RDS file yesterday.

Digging into the wrapper

So, how did I go about fixing this? Let’s take a look at the code. First, I define the arguments and their defaults: The object and file arguments are taken directly from the wrapped function, the remaining arguments allow the user to customize the archiving process: Append the archive file name with either the date the original file was archived or last modified, add an additional timestamp (not just the calendar date), or save the file to a dedicated archive directory. For more details, please check the documentation here. I also include the ellipsis ... for additional arguments to be passed down to saveRDS. Additionally, I do some basic input handling (not included here).
save_rds_archive <- function(object,
                             file = "",
                             archive = TRUE,
                             last_modified = FALSE,
                             with_time = FALSE,
                             archive_dir_path = NULL,
                             ...) {
The main body of the function is basically a series of if/else statements. I first check if the archive argument (which controls whether the file should be archived in the first place) is set to TRUE, and then if the file we are trying to save already exists (note that “file” here actually refers to the whole file path). If it does, I call the internal helper function create_archived_file, which eliminates redundancy and allows for concise code.
if (archive) {

    # check if file exists
    if (file.exists(file)) {

      archived_file <- create_archived_file(file = file,
                                            last_modified = last_modified,
                                            with_time = with_time)

Composing the new file name

In this function, I create the new name for the file which is to be archived, depending on user input: If last_modified is set, then the mtime of the file is accessed. Otherwise, the current system date/time (= the date of archiving) is taken instead. Then the spaces and special characters are replaced with underscores and hyphens, and, depending on the value of the with_time argument, the actual time information (not just the calendar date) is kept or not. To make it easier to identify directly from the file name what exactly (date of archiving vs. date of modification) the indicated date/time refers to, I also add appropriate information to the file name. Then I save the file extension for easier replacement (note that “.RDS”, “.Rds”, and “.rds” are all valid file extensions for RDS files). Lastly, I replace the current file extension with a concatenated string containing the type info, the new date/time suffix, and the original file extension. Note here that I add a “$” sign to the regex which is to be matched by gsub to only match the end of the string: If I did not do that and the file name were something like “my_RDS.RDS”, then both matches would be replaced.
# create_archived_file.R

create_archived_file <- function(file, last_modified, with_time) {

  # create main suffix depending on type
  suffix_main <- ifelse(last_modified,
                        as.character(file.info(file)$mtime),
                        as.character(Sys.time()))

  if (with_time) {

    # create clean date-time suffix
    suffix <- gsub(pattern = " ", replacement = "_", x = suffix_main)
    suffix <- gsub(pattern = ":", replacement = "-", x = suffix)

    # add "at" between date and time
    suffix <- paste0(substr(suffix, 1, 10), "_at_", substr(suffix, 12, 19))

  } else {

    # create date suffix
    suffix <- substr(suffix_main, 1, 10)

  }

  # create info to paste depending on type
  type_info <- ifelse(last_modified,
                      "_MODIFIED_on_",
                      "_ARCHIVED_on_")

  # get file extension (could be any of "RDS", "Rds", "rds", etc.)
  ext <- paste0(".", tools::file_ext(file))

  # replace extension with suffix
  archived_file <- gsub(pattern = paste0(ext, "$"),
                        replacement = paste0(type_info,
                                             suffix,
                                             ext),
                        x = file)

  return(archived_file)

}
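For readers who are more at home in Python, here is a rough analogue of the naming logic above. It is a sketch under simplifying assumptions and not the helfRlein implementation: the function name create_archived_name is hypothetical, and the timestamp is passed in explicitly (instead of being read from the file system or the system clock) so that the result is reproducible.

```python
from datetime import datetime
from pathlib import Path


def create_archived_name(file, timestamp, last_modified=False, with_time=False):
    """Insert a type label and a date(-time) suffix in front of the file extension."""
    path = Path(file)
    fmt = "%Y-%m-%d_at_%H-%M-%S" if with_time else "%Y-%m-%d"
    type_info = "_MODIFIED_on_" if last_modified else "_ARCHIVED_on_"
    # path.stem drops only the final extension, so "my_RDS.RDS" keeps its "my_RDS" part
    new_name = f"{path.stem}{type_info}{timestamp.strftime(fmt)}{path.suffix}"
    return str(path.with_name(new_name))


stamp = datetime(2020, 7, 12, 11, 31, 43)
print(create_archived_name("models.RDS", stamp, with_time=True))
# models_ARCHIVED_on_2020-07-12_at_11-31-43.RDS
```

Splitting the name into stem and suffix plays the same role as anchoring the regex with “$” in the R version: only the extension at the very end of the name is touched.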

Archiving the archives?

By way of example, with last_modified = FALSE and with_time = TRUE, this function would turn the character file name “models.RDS” into “models_ARCHIVED_on_2020-07-12_at_11-31-43.RDS”. However, this is just a character vector for now — the file itself is not renamed yet. For this, we need to call the base R file.rename function, which provides a direct interface to your machine’s file system. I first check, however, whether a file with the same name as the newly created archived file string already exists: This could well be the case if one appends only the date (with_time = FALSE) and calls this function several times per day (or potentially on the same file if last_modified = TRUE). Somehow, we are back to the old problem in this case. However, I decided that it was not a good idea to archive files that are themselves archived versions of another file since this would lead to too much confusion (and potentially too much disk space being occupied). Therefore, only the most recent archived version will be kept. (Note that if you still want to keep multiple archived versions of a single file, you can set with_time = TRUE. This will append a timestamp to the archived file name up to the second, virtually eliminating the possibility of duplicated file names.) A warning is issued, and then the already existing archived file will be overwritten with the current archived version.

The last puzzle piece: Renaming the original file

To do this, I call the file.rename function, renaming the “file” originally passed by the user call to the string returned by the helper function. The file.rename function always returns a boolean indicating if the operation succeeded, which I save to a variable temp to inspect later. Under some circumstances, the renaming process may fail, for instance due to missing permissions or OS-specific restrictions. We did set up a CI pipeline with GitHub Actions and continuously test our code on Windows, Linux, and macOS machines with different versions of R. So far, we didn’t run into any problems. Still, it’s better to provide built-in checks.

It’s an error! Or is it?

The problem here is that, when renaming the file on disk failed, file.rename raises merely a warning, not an error. Since any causes of these warnings most likely originate from the local file system, there is no sense in continuing the function if the renaming failed. That’s why I wrapped it into a tryCatch call that captures the warning message and passes it to the stop call, which then terminates the function with the appropriate message. Just to be on the safe side, I check the value of the temp variable, which should be TRUE if the renaming succeeded, and also check if the archived version of the file (that is, the result of our renaming operation) exists. If both of these conditions hold, I simply call saveRDS with the original specifications (now that our existing copy has been renamed, nothing will be overwritten if we save the new file with the original name), passing along further arguments with ....
        if (file.exists(archived_file)) {
          warning("Archived copy already exists - will overwrite!")
        }

        # rename existing file with the new name
        # save return value of the file.rename function
        # (returns TRUE if successful) and wrap in tryCatch
        temp <- tryCatch({file.rename(from = file,
                                      to = archived_file)
        },
        warning = function(e) {
          stop(e)
        })

      }

      # check return value and if archived file exists
      if (temp & file.exists(archived_file)) {
        # then save new file under specified name
        saveRDS(object = object, file = file, ...)
      }

    }
These code snippets represent the cornerstones of my function. I also skipped some portions of the source code for reasons of brevity, chiefly the creation of the “archive directory” (if one is specified) and the process of copying the archived file into it. Please refer to our GitHub for the complete source code of the main and the helper function. Finally, to illustrate, let’s see what this looks like in action:
x <- 5
y <- 10
z <- 20

## save to RDS
saveRDS(x, "temp.RDS")
saveRDS(y, "temp.RDS")

## "temp.RDS" is silently overwritten with y
## previous version is lost
readRDS("temp.RDS")
#> [1] 10

save_rds_archive(z, "temp.RDS")
## current version is updated
readRDS("temp.RDS")
#> [1] 20

## previous version is archived
readRDS("temp_ARCHIVED_on_2020-07-12.RDS")
#> [1] 10
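The same archive-then-save idea can be sketched in a few lines of Python. This is only an analogue under simplifying assumptions (pickle instead of RDS files, a date-only suffix, no separate archive directory), and save_with_archive and load are hypothetical names chosen for this illustration. Note that os.replace raises an exception if the renaming fails, so no extra warning-to-error conversion is needed here.

```python
import os
import pickle
import tempfile
from datetime import date


def save_with_archive(obj, path, today=None):
    """Rename an existing file to an archive name, then save obj under the original name."""
    today = today or date.today()
    if os.path.exists(path):
        root, ext = os.path.splitext(path)
        archived = f"{root}_ARCHIVED_on_{today.isoformat()}{ext}"
        # os.replace overwrites an existing archive of the same name,
        # so only the most recent archived copy per day is kept
        os.replace(path, archived)
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def load(path):
    """Read a pickled object back from disk."""
    with open(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp_dir:
    target = os.path.join(tmp_dir, "temp.pkl")
    day = date(2020, 7, 12)
    save_with_archive(10, target, today=day)  # nothing to archive yet
    save_with_archive(20, target, today=day)  # the file holding 10 is archived first
    current = load(target)
    archived = load(os.path.join(tmp_dir, "temp_ARCHIVED_on_2020-07-12.pkl"))

print(current, archived)  # 20 10
```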

Great, how can I get this?

The function save_rds_archive is now included in the newly refactored helfRlein package (now available in version 1.0.0!) which you can install directly from GitHub:
# install.packages("devtools")
devtools::install_github("STATWORX/helfRlein")
Feel free to check out additional documentation and the source code there. If you have any inputs or feedback on how the function could be improved, please do not hesitate to contact me or raise an issue on our GitHub.

Conclusion

That’s it! No more manually renaming your precious RDS files: with this function in place, you can automate this tedious task and easily keep a comprehensive archive of previous versions. You will be able to take another look at that one model you ran last week (and then discarded again) in the blink of an eye. I hope you enjoyed reading my post. Maybe the function will come in handy for you someday!

In my previous blog post, I showed you how to run your R-scripts inside a Docker container. For many of the projects we work on here at STATWORX, we end up using the RShiny framework to build our R-scripts into interactive applications. Using containerization for the deployment of ShinyApps has a multitude of advantages. There are the usual suspects such as easy cloud deployment, scalability, and easy scheduling, but it also addresses one of RShiny’s essential drawbacks: Shiny creates only a single R session per app, meaning that if multiple users access the same app, they all work with the same R session, leading to a multitude of problems. With the help of Docker, we can address this issue and start a container instance for every user, circumventing this problem by giving every user access to their own instance of the app and their individual corresponding R session. If you’re not familiar with building R-scripts into a Docker image or with Docker terminology, I would recommend that you first read my previous blog post. So let’s move on from simple R-scripts and run entire ShinyApps in Docker now!

The Setup

Setting up a project

It is highly advisable to use RStudio’s project setup when working with ShinyApps, especially when using Docker. Not only do projects make it easy to keep your RStudio neat and tidy, but they also allow us to use the renv package to set up a package library for our specific project. This will come in especially handy when installing the needed packages for our app to the Docker image. For demonstration purposes, I decided to use an example app created in a previous blog post, which you can clone from the STATWORX GitHub repository. It is located in the “example-app” subfolder and consists of the three typical scripts used by ShinyApps (global.R, ui.R, and server.R) as well as files belonging to the renv package library. If you choose to use the example app linked above, then you won’t have to set up your own RStudio Project. Instead, you can open “example-app.Rproj”, which opens the project context I have already set up. If you choose to work along with an app of your own and haven’t created a project for it yet, you can set up your own by following the instructions provided by RStudio.

Setting up a package library

The RStudio project I provided already comes with a package library stored in the renv.lock file. If you prefer to work with your own app, you can create your own renv.lock file by installing the renv package from within your RStudio project and executing renv::init(). This initializes renv for your project and creates a renv.lock file in your project root folder. You can find more information on renv over at RStudio’s introduction article on it.

The Dockerfile

The Dockerfile is once again the central piece of creating a Docker image. We now aim to repeat this process for an entire app where we previously only built a single script into an image. The step from a single script to a folder with multiple scripts is small, but there are some significant changes needed to make our app run smoothly.
# Base image https://hub.docker.com/u/rocker/
FROM rocker/shiny:latest

# system libraries of general use
## install debian packages
RUN apt-get update -qq && apt-get -y --no-install-recommends install \
    libxml2-dev \
    libcairo2-dev \
    libsqlite3-dev \
    libmariadbd-dev \
    libpq-dev \
    libssh2-1-dev \
    unixodbc-dev \
    libcurl4-openssl-dev \
    libssl-dev

## update system libraries
RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get clean

# copy necessary files
## app folder
COPY /example-app ./app
## renv.lock file
COPY /example-app/renv.lock ./renv.lock

# install renv & restore packages
RUN Rscript -e 'install.packages("renv")'
RUN Rscript -e 'renv::consent(provided = TRUE)'
RUN Rscript -e 'renv::restore()'

# expose port
EXPOSE 3838

# run app on container start
CMD ["R", "-e", "shiny::runApp('/app', host = '0.0.0.0', port = 3838)"]

The base image

The first difference is in the base image. Because we’re dockerizing a ShinyApp here, we can save ourselves a lot of work by using the rocker/shiny base image. This image handles the necessary dependencies for running a ShinyApp and comes with multiple R packages already pre-installed.

Necessary files

It is necessary to copy all relevant scripts and files for your app to your Docker image, so the Dockerfile does precisely that by copying the entire folder containing the app to the image. We can also make use of renv to handle package installation for us. This is why we first copy the renv.lock file to the image separately. We also need to install the renv package separately by using the Dockerfile’s ability to execute R-code by prefacing it with RUN Rscript -e. This package installation allows us to then call renv directly and restore our package library inside the image with renv::restore(). Now our entire project package library will be installed in our Docker image, with the exact same version and source of all the packages as in your local development environment. All this with just a few lines of code in our Dockerfile.

Starting the App at Runtime

At the very end of our Dockerfile, we tell the container to execute the following R-command:
shiny::runApp('/app', host = '0.0.0.0', port = 3838)
The first argument allows us to specify the file path to our scripts, which in our case is /app. For the exposed port, I have chosen 3838, as this is the default port for Shiny Server, but it can be freely changed to whatever suits you best. With the final command in place, every container based on this image will start the app in question automatically at runtime (and of course close it again once it’s been terminated).

The Finishing Touches

With the Dockerfile set up we’re now almost finished. All that remains is building the image and starting a container of said image.

Building the image

We open the terminal, navigate to the folder containing our new Dockerfile, and start the building process:
docker build -t my-shinyapp-image . 

Starting a container

After the building process has finished, we can now test our newly built image by starting a container:
docker run -d --rm -p 3838:3838 my-shinyapp-image
And there it is, running on localhost:3838.
docker-shiny-app-example

Outlook

Now that you have your ShinyApp running inside a Docker container, it is ready for deployment! Having containerized our app already makes this process a lot easier, and there are further tools we can employ to ensure state-of-the-art security, scalability, and seamless deployment. Stay tuned until next time, when we’ll go deeper into the full range of RShiny and Docker capabilities by introducing ShinyProxy.

From my experience here at STATWORX, the best way to learn something is by trying it out yourself – with a little help from a friend! In this article, I will focus on giving you a hands-on guide on how to build a dashboard in Python. As a framework, we will be using Dash, and the goal is to create a basic dashboard with a dropdown and two reactive graphs:
dash-app-final
Developed as an open-source library by Plotly, the Python framework Dash is built on top of Flask, Plotly.js, and React.js. Dash allows the building of interactive web applications in pure Python and is particularly suited for sharing insights gained from data. In case you’re interested in interactive charting with Python, I highly recommend my colleague Markus’ blog post Plotly – An Interactive Charting Library. For a general guide about basic visualization techniques, check out this great article by my colleague Vivian on Basic rules for good looking slides and dashboards. For our purposes, a basic understanding of HTML and CSS can be helpful. Nevertheless, I will provide you with external resources and explain every step thoroughly, so you’ll be able to follow the guide.

Guide structure

The source code can be found on GitHub.

Prerequisites

The project comprises a style sheet called style.css, the sample stock data stockdata2.csv, and the actual Dash application app.py.

Load the Stylesheet

If you want your dashboard to look like the one above, please download the file style.css from our STATWORX GitHub. That is completely optional and won’t affect the functionalities of your app. Our stylesheet is a customized version of the stylesheet used by the Dash Uber Rides Demo. Dash will automatically load any .css-file placed in a folder named assets.
dashapp
    |--assets
        |-- style.css
    |--data
        |-- stockdata2.csv
    |-- app.py
The documentation on external resources in dash can be found here.

Load the Data

Feel free to use the same data we did (stockdata2.csv), or pick any data with the following structure:
date stock value change
2007-01-03 MSFT 23.95070 -0.1667
2007-01-03 IBM 80.51796 1.0691
2007-01-03 SBUX 16.14967 0.1134
import pandas as pd

# Load data
df = pd.read_csv('data/stockdata2.csv', index_col=0, parse_dates=True)
df.index = pd.to_datetime(df['Date'])
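If you bring your own file, it can help to check its header before loading it into the app. The following is only a minimal sketch using the Python standard library; the helper name check_columns and the inline sample are illustrative assumptions and not part of the original app.

```python
import csv
import io

# Columns the dashboard relies on (compared case-insensitively).
REQUIRED_COLUMNS = {"date", "stock", "value", "change"}

# Hypothetical inline sample that mirrors the structure shown above.
SAMPLE_CSV = """Date,stock,value,change
2007-01-03,MSFT,23.95070,-0.1667
2007-01-03,IBM,80.51796,1.0691
2007-01-03,SBUX,16.14967,0.1134
"""


def check_columns(csv_text):
    """Return the set of required columns missing from the CSV header."""
    reader = csv.DictReader(io.StringIO(csv_text))
    header = {name.strip().lower() for name in (reader.fieldnames or [])}
    return REQUIRED_COLUMNS - header


missing = check_columns(SAMPLE_CSV)
print("missing columns:", sorted(missing))
```

An empty result means the file has everything the later steps expect; any name that is returned would have to be added or renamed first.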

Getting Started – How to start a Dash app

After installing Dash (instructions can be found here), we are ready to start with the application. The following statements will load the necessary packages dash and dash_html_components. Without any layout defined, the app won’t start. An empty html.Div will suffice to get the app up and running.
import dash
import dash_html_components as html
If you have already worked with the WSGI web application framework Flask, the next step will be very familiar to you, as Dash uses Flask under the hood.
# Initialise the app
app = dash.Dash(__name__)

# Define the app
app.layout = html.Div()
# Run the app
if __name__ == '__main__':
    app.run_server(debug=True)

How a .css-file changes the layout of an app

The module dash_html_components provides you with several HTML components; the documentation lists all of them. It is worth mentioning that the nesting of components is done via the children attribute.
app.layout = html.Div(children=[
                      html.Div(className='row',  # Define the row element
                               children=[
                                  html.Div(className='four columns div-user-controls'),  # Define the left element
                                  html.Div(className='eight columns div-for-charts bg-grey')  # Define the right element
                                  ])
                                ])
The first html.Div() has one child: another html.Div named row, which will contain all our content. The children of row are four columns div-user-controls and eight columns div-for-charts bg-grey. The style for these div components comes from our style.css. Now let’s first add some more information to our app, such as a title and a description. For that, we use the Dash components H2 to render a headline and P to generate HTML paragraphs.
children = [
    html.H2('Dash - STOCK PRICES'),
    html.P('''Visualising time series with Plotly - Dash'''),
    html.P('''Pick one or more stocks from the dropdown below.''')
]
Switch to your terminal and run the app with python app.py.
dash-app-first-layout

The basics of an app’s layout

Another nice feature of Flask (and hence Dash) is hot-reloading. It makes it possible to update our app on the fly without having to restart the app every time we make a change to our code. Running our app with debug=True also adds a button to the bottom right of our app, which lets us take a look at error messages, as well as a Callback Graph. We will come back to the Callback Graph in the last section of the article when we’re done implementing the functionalities of the app.
dash-app-layout

Charting in Dash – How to display a Plotly-Figure

With the building blocks for our web app in place, we can now define a plotly-graph. The function dcc.Graph() from dash_core_components uses the same figure argument as the plotly package. Dash translates every aspect of a plotly chart to a corresponding key-value pair, which will be used by the underlying JavaScript library Plotly.js. In the following section, we will need the express version of plotly.py, as well as the Package Dash Core Components. Both packages are available with the installation of Dash.
import dash_core_components as dcc
import plotly.express as px
Dash Core Components has a collection of useful and easy-to-use components, which add interactivity and functionalities to your dashboard. Plotly Express is the express-version of plotly.py, which simplifies the creation of a plotly-graph, with the drawback of having fewer functionalities. To draw a plot on the right side of our app, add a dcc.Graph() as a child to the html.Div() named eight columns div-for-charts bg-grey. The component dcc.Graph() is used to render any plotly-powered visualization. In this case, its figure will be created by px.line() from the Python package plotly.express. As the express version of Plotly has limited native configurations, we are going to change the layout of our figure with the method update_layout(). Here, we use rgba(0, 0, 0, 0) to set the background transparent. Without updating the default background and paper color, we would have a big white box in the middle of our app. As dcc.Graph() only renders the figure in the app, we can’t change its appearance once it’s created.
dcc.Graph(id='timeseries',
          config={'displayModeBar': False},
          animate=True,
          figure=px.line(df,
                         x='Date',
                         y='value',
                         color='stock',
                         template='plotly_dark').update_layout(
                                   {'plot_bgcolor': 'rgba(0, 0, 0, 0)',
                                    'paper_bgcolor': 'rgba(0, 0, 0, 0)'})
                                    )
After Dash reloads the application, you will end up with something like this, a dashboard with a plotted graph:
dash-app-with-plot
Another core component is dcc.Dropdown(), which is used – you’ve guessed it – to create a dropdown menu. The available options in the dropdown menu are either given as arguments or supplied by a function. For our dropdown menu, we need a function that returns a list of dictionaries. The list contains dictionaries with two keys, label and value. These dictionaries provide the available options to the dropdown menu. The value of label is displayed in our app. The value of value will be exposed for other functions to use, and should not be changed. If you prefer the full name of a company to be displayed instead of the short name, you can do so by changing the value of the key label, for example to Microsoft. For the sake of simplicity, we will use the same value for the keys label and value. Add the following function to your script, before defining the app’s layout.
# Creates a list of dictionaries, which have the keys 'label' and 'value'.
def get_options(list_stocks):
    dict_list = []
    for i in list_stocks:
        dict_list.append({'label': i, 'value': i})

    return dict_list
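The same list can be built with a list comprehension, and the idea of showing a full company name as label can be sketched with a small lookup. This is only an illustration: get_options_with_names and the full_names dictionary are hypothetical and not part of the original app.

```python
def get_options(list_stocks):
    """Same result as the loop above, written as a list comprehension."""
    return [{'label': stock, 'value': stock} for stock in list_stocks]


def get_options_with_names(list_stocks, full_names):
    """Hypothetical variant: use a full company name as label when known."""
    return [{'label': full_names.get(stock, stock), 'value': stock}
            for stock in list_stocks]


# Only the label changes; the value stays the short name used in the data.
options = get_options_with_names(['MSFT', 'IBM'], {'MSFT': 'Microsoft'})
print(options)
```

Keeping value untouched matters because the callbacks later filter the data by exactly this value.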
With a function that returns the names of stocks in our data in key-value pairs, we can now add dcc.Dropdown() from the Dash Core Components to our app. Add an html.Div() as a child to the list of children of four columns div-user-controls, with the argument className='div-for-dropdown'. This html.Div() has one child, dcc.Dropdown(). We want to be able to select multiple stocks at the same time and to have a default value selected, so our figure is not empty on startup. Set the argument multi=True and choose a default stock for value.
 html.Div(className='div-for-dropdown',
          children=[
              dcc.Dropdown(id='stockselector',
                           options=get_options(df['stock'].unique()),
                           multi=True,
                           value=[df['stock'].sort_values().iloc[0]],  # first stock by position
                           style={'backgroundColor': '#1E1E1E'},
                           className='stockselector')
                    ],
          style={'color': '#1E1E1E'})
The id and options arguments in dcc.Dropdown() will be important in the next section. Every other argument can be changed. If you want to try out different styles for the dropdown menu, follow the link for a list of different dropdown menus.

Working with Callbacks

How to add interactive functionalities to your app

Callbacks add interactivity to your app. They can take inputs, for example, certain stocks selected via a dropdown menu, pass these inputs to a function, and pass the return value of the function to another component. We will write a function that returns a figure based on provided stock names. A callback will pass the selected values from the dropdown to the function and return the figure to a dcc.Graph() in our app. At this point, the selected values in the dropdown menu do not change the stocks displayed in our graph. For that to happen, we need to implement a callback. The callback will handle the communication between our dropdown menu 'stockselector' and our graph 'timeseries'. We can delete the figure we have previously created, as we won’t need it anymore. We want two graphs in our app, so we will add another dcc.Graph() with a different id.
  • Remove the figure argument from dcc.Graph(id='timeseries')
  • Add another dcc.Graph() with id='change' as a child to the html.Div() named eight columns div-for-charts bg-grey.
dcc.Graph(id='timeseries', config={'displayModeBar': False})
dcc.Graph(id='change', config={'displayModeBar': False})
In our implementation, a callback will be triggered when a user selects a stock. The callback uses the value of the selected items in the dropdown menu (Input) and passes these values to our functions update_timeseries() and update_change(). The functions will filter the data based on the passed inputs and return a plotly figure from the filtered data. The callback then passes the figure returned from our functions back to the component specified in the output. A callback is implemented as a decorator for a function. Multiple inputs and outputs are possible, but for now, we will start with a single input and a single output. We need the classes dash.dependencies.Input and dash.dependencies.Output. Add the following line to your import statements.
from dash.dependencies import Input, Output
Input() and Output() take the id of a component (e.g. in dcc.Graph(id='timeseries') the component’s id is 'timeseries') and the property of a component as arguments. Example Callback:
# Update Time Series
@app.callback(Output('id of output component', 'property of output component'),
              [Input('id of input component', 'property of input component')])
def arbitrary_function(value_of_first_input):
    '''
    The property of the input component is passed to the function as value_of_first_input.
    The function's return value is passed to the property of the output component.
    '''
    return arbitrary_output
If we want our stockselector to display a time series for one or more specific stocks, we need a function. The value of our input is a list of stocks selected from the dropdown menu stockselector.
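To make the decorator idea more tangible, here is a toy model in plain Python. It is explicitly not how Dash is implemented internally; the names registry, callback, and trigger are made up for this sketch. It only shows that a decorator can remember which function belongs to which output and inputs, and that values reach the function by position, not by argument name.

```python
# Toy model of the decorator pattern; NOT how Dash is implemented internally.
registry = []  # each entry: (output, inputs, function)


def callback(output, inputs):
    """Register the decorated function together with its output and inputs."""
    def decorator(func):
        registry.append((output, inputs, func))
        return func
    return decorator


@callback(('timeseries', 'figure'), [('stockselector', 'value')])
def describe_selection(any_argument_name):
    # The argument receives the first input's property, whatever its name.
    return 'figure for ' + ', '.join(any_argument_name)


def trigger(state):
    """Call every registered function with its inputs, in declared order."""
    results = {}
    for output, inputs, func in registry:
        args = [state[key] for key in inputs]
        results[output] = func(*args)
    return results


updates = trigger({('stockselector', 'value'): ['MSFT', 'IBM']})
print(updates)
```

In the real app, Dash takes care of the registration and of calling the function whenever the input property changes.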

Implementing Callbacks

The function draws the traces of a plotly-figure based on the stocks which were passed as arguments and returns a figure that can be used by dcc.Graph(). The inputs for our function are given in the order in which they were set in the callback. Names chosen for the function’s arguments do not impact the way values are assigned. Update the figure time series:
# go.Scatter and go.Layout come from plotly.graph_objects
import plotly.graph_objects as go

@app.callback(Output('timeseries', 'figure'),
              [Input('stockselector', 'value')])
def update_timeseries(selected_dropdown_value):
    ''' Draw traces of the feature 'value' based on the currently selected stocks '''
    # STEP 1
    trace = []  
    df_sub = df
    # STEP 2
    # Draw and append traces for each stock
    for stock in selected_dropdown_value:   
        trace.append(go.Scatter(x=df_sub[df_sub['stock'] == stock].index,
                                 y=df_sub[df_sub['stock'] == stock]['value'],
                                 mode='lines',
                                 opacity=0.7,
                                 name=stock,
                                 textposition='bottom center'))  
    # STEP 3
    traces = [trace]
    data = [val for sublist in traces for val in sublist]
    # Define Figure
    # STEP 4
    figure = {'data': data,
              'layout': go.Layout(
                  colorway=["#5E0DAC", '#FF4F00', '#375CB1', '#FF7400', '#FFF400', '#FF0056'],
                  template='plotly_dark',
                  paper_bgcolor='rgba(0, 0, 0, 0)',
                  plot_bgcolor='rgba(0, 0, 0, 0)',
                  margin={'b': 15},
                  hovermode='x',
                  autosize=True,
                  title={'text': 'Stock Prices', 'font': {'color': 'white'}, 'x': 0.5},
                  xaxis={'range': [df_sub.index.min(), df_sub.index.max()]},
              ),

              }

    return figure

STEP 1

  • A trace will be drawn for each stock. Create an empty list that will collect the traces of the plotly figure.

STEP 2

Within the for-loop, a trace for a plotly figure will be drawn with the function go.Scatter().
  • Iterate over the stocks currently selected in our dropdown menu, draw a trace, and append that trace to our list from step 1.

STEP 3

  • Flatten the traces

STEP 4

Plotly figures are dictionaries with the keys data and layout. The value of data is our flattened list with the traces we have drawn. The layout is defined with the plotly class go.Layout().
  • Add the trace to our figure
  • Define the layout of our figure
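The four steps can also be followed without Dash or pandas. The sketch below assumes that plain dictionaries are enough to describe traces and layout (the callback above already returns the figure as a dictionary); build_figure and the sample rows are hypothetical and only serve as an illustration.

```python
def build_figure(rows, selected_stocks):
    """Steps 1 to 4 with plain dicts instead of go.Scatter and go.Layout."""
    # STEP 1: empty list that collects one trace per stock
    trace = []
    # STEP 2: one line trace per selected stock
    for stock in selected_stocks:
        stock_rows = [row for row in rows if row['stock'] == stock]
        trace.append({'type': 'scatter',
                      'mode': 'lines',
                      'name': stock,
                      'x': [row['Date'] for row in stock_rows],
                      'y': [row['value'] for row in stock_rows]})
    # STEP 3: wrapping the list and flattening it again returns the same
    # traces, because `trace` is already a flat list
    data = [val for sublist in [trace] for val in sublist]
    # STEP 4: a figure is a dictionary with the keys 'data' and 'layout'
    return {'data': data, 'layout': {'title': {'text': 'Stock Prices'}}}


# Hypothetical sample rows with made-up values, one dictionary per observation
rows = [
    {'Date': '2007-01-03', 'stock': 'MSFT', 'value': 1.0},
    {'Date': '2007-01-03', 'stock': 'IBM', 'value': 5.0},
    {'Date': '2007-01-04', 'stock': 'MSFT', 'value': 2.0},
]
figure = build_figure(rows, ['MSFT'])
print(figure['data'][0]['x'], figure['data'][0]['y'])
```

Selecting an empty list of stocks yields a figure without traces, which is why the dropdown gets a default value.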
Now we simply repeat the steps above for our second graph. Just change the data for our y-axis to the column change and slightly adjust the layout. Update the figure change:
@app.callback(Output('change', 'figure'),
              [Input('stockselector', 'value')])
def update_change(selected_dropdown_value):
    ''' Draw traces of the feature 'change' based on the currently selected stocks '''
    trace = []
    df_sub = df
    # Draw and append traces for each stock
    for stock in selected_dropdown_value:
        trace.append(go.Scatter(x=df_sub[df_sub['stock'] == stock].index,
                                 y=df_sub[df_sub['stock'] == stock]['change'],
                                 mode='lines',
                                 opacity=0.7,
                                 name=stock,
                                 textposition='bottom center'))
    traces = [trace]
    data = [val for sublist in traces for val in sublist]
    # Define Figure
    figure = {'data': data,
              'layout': go.Layout(
                  colorway=["#5E0DAC", '#FF4F00', '#375CB1', '#FF7400', '#FFF400', '#FF0056'],
                  template='plotly_dark',
                  paper_bgcolor='rgba(0, 0, 0, 0)',
                  plot_bgcolor='rgba(0, 0, 0, 0)',
                  margin={'t': 50},
                  height=250,
                  hovermode='x',
                  autosize=True,
                  title={'text': 'Daily Change', 'font': {'color': 'white'}, 'x': 0.5},
                  xaxis={'showticklabels': False, 'range': [df_sub.index.min(), df_sub.index.max()]},
              ),
              }

    return figure
Run your app again. You are now able to select one or more stocks from the dropdown. For each selected item, a line plot will be generated in the graph. By default, the dropdown menu has search functionalities, which makes the selection out of many available options an easy task.
dash-app-final

Visualize Callbacks – Callback Graph

With the callbacks in place and our app completed, let’s take a quick look at our callback graph. If you are running your app with debug=True, a button will appear in the bottom right corner of the app. Here we have access to a callback graph, which is a visual representation of the callbacks which we have implemented in our code. The graph shows that our components timeseries and change display a figure based on the value of the component stockselector. If your callbacks don’t work how you expect them to, especially when working on larger and more complex apps, this tool will come in handy.
dash-app-final-callback

Conclusion

Let’s recap the most important building blocks of Dash. Getting the App up and running requires just a couple lines of code. A basic understanding of HTML and CSS is enough to create a simple Dash dashboard. You don’t have to worry about creating interactive charts, Plotly already does that for you. Making your dashboard reactive is done via Callbacks, which are functions with the users’ interaction as the input. If you liked this blog, feel free to contact me via LinkedIn or Email. I am curious to know what you think and always happy to answer any questions about data, my journey to data science, or the exciting things we do here at STATWORX. Thank you for reading!

At STATWORX, coding is our bread and butter. Because our projects involve many different people in several organizations across multiple generations of programmers, writing clean code is essential. The main requirements for well structured and readable code are comments and sections. In RStudio, these sections are defined by comments that end with at least four dashes ---- (you can also use trailing equal signs ==== or hashes ####). In my opinion, the code is even clearer if the dashes cover the whole range of 80 characters (why you should not exceed the 80 characters limit). That’s what my code usually looks like:
# loading packages -------------------------------------------------------------
library(dplyr)

# load data --------------------------------------------------------------------
my_iris <- as_tibble(iris)

# prepare data -----------------------------------------------------------------
my_iris_preped <- my_iris %>% 
  filter(Species == "virginica") %>% 
  mutate_if(is.numeric, list(squared = ~ .x^2))

# ...
Clean, huh? Well, yes, but none of the three options available to achieve this is as neat as I want it to be:
  • Press - for some time.
  • Copy a certain amount of dashes and insert them sequentially. Both options often result in too many dashes, so I have to remove the redundant ones.
  • Use the shortcut to insert a new section (CMD/CTRL + SHIFT + R). However, you cannot neatly include it after you wrote your comments.
Wouldn’t it be nice to have a keyboard shortcut that included the right amount of dashes up from the cursor position? “Easy as can be,” I thought before trying to define a custom shortcut in RStudio. Unfortunately, it turned out not to be that easy. There is a manual from RStudio that actually covers how you can create your shortcut, but it requires you to put it in a package first. Since I have not been an expert in R package development myself, I decided to go the full distance in this blogpost. By following it step by step, you should be able to define your shortcuts within a few minutes. Note: This article is not about creating a CRAN-worthy package, but covers what is necessary to define your own shortcuts. If you have already created packages before, you can skip the parts about package development and jump directly to what is new to you.

Setting up an R package

First of all, open RStudio and create an R package directory. For this, please do the following steps:
  1. Go to “New Project…”
  2. “New Directory”
  3. “R Package”
  4. Select an awesome package name of your choice. In this example, I named my package shoRtcut
  5. In “Create project as subdirectory of:” select a directory of your choice. A new folder with your package name will be created in this directory.
Tada, everything necessary for a powerful R Package has been set up. RStudio also automatically provides a dummy function hello(). Since we do not want this function in our own package, move to the “R” folder in your project and delete the hello.R file. Do the same in the “man” folder and delete hello.Rd.

Creating an Addin Function

Now we can start and define our function. For this, we need the wonderful packages usethis and devtools. These provide all the functionality we need for the next steps.

Defining the Addin Function

Via the use_r() function, we define a new R script file with the given name. That should correspond to the name of the function we are about to create. In my case, I call it set_new_chapter.
# use this function to automatically create a new r script for your function
usethis::use_r("set_new_chapter")
You are directly forwarded to the created file. Now the tricky part begins, defining a function that does what you want. When defining shortcuts that interact with an R script in RStudio, you will soon discover the package rstudioapi. With its functions, you can grab all information from RStudio and make it available within R. Let me guide you through it step by step.
  1. As per usual, I set up a regular R function and define its name as set_new_chapter. Next, I define up until which limit I want to include the dashes. You will note that I set nchars to 81 rather than 80. This is because the number corresponds to the cursor position after including the dashes. You will notice that when you write text, the cursor automatically jumps to the position right after the newly typed character. After you have written your 80th character, the cursor will be at position 81.
  2. Now we have to find out where the cursor is currently located. This information can be unearthed by the getActiveDocumentContext() function. The returned object contains quite a bit of information, but we are only interested in the cursor position regarding the column. Why the column? You can think of the script like a matrix. Hitting return brings you to a new row, typing a character to a new column. Having a monospaced font, which is the default setting in RStudio, makes this concept easy to see.
  3. By sneaking into the nested list, we find the information we are looking for and store it in context_col. Now we check whether the cursor is already at “column” 81. If not, there is space in which we insert the dashes. For this final step, we can use another function: insertText.
  4. As its name implies, it inserts text in an R script or console. You can either specify a specific position in the document or, by leaving it empty, insert text at the current cursor position, which is exactly what I want right now. As the final step, I need to find out the number of dashes that should be inserted. That’s the difference between the current cursor location and its target position. For example, if the cursor blinks at column 51, meaning I already have typed 50 characters, I want to insert 30 dashes.
  5. To document the function, I use the “Code” > “Insert Roxygen Skeleton” feature and fill it out appropriately.
This is what my final function looks like.
#' Insert dashes from cursor position to up to 80 characters
#'
#' @return dashes inside RStudio
set_new_chapter <- function(){
  # set limit to which position dashes should be included
  nchars <- 81

  # grab current document information
  context <- rstudioapi::getActiveDocumentContext()
  # extract horizontal cursor position in document
  context_col <- context$selection[[1]]$range$end["column"]

  # if a line has less than 81 characters, insert hyphens at the current line
  # up to 80 characters
  if (nchars > context_col) {
    rstudioapi::insertText(strrep("-", nchars - context_col))
  }
}
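
The arithmetic behind nchars can be checked outside RStudio. The following is a small sketch in Python (chosen only because it runs anywhere; the function name dashes_to_insert is made up) that mirrors the logic of the function above without touching the editor.

```python
def dashes_to_insert(cursor_col, nchars=81):
    """Return the dashes needed to move the cursor from cursor_col to nchars."""
    if nchars > cursor_col:
        return '-' * (nchars - cursor_col)
    return ''


comment = '# load data '  # 12 characters, so the cursor sits at column 13
line = comment + dashes_to_insert(len(comment) + 1)
print(len(line))
```

With the cursor at column 51, the same function returns 30 dashes, which matches the example from the list above.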

Defining the Function as an Addin

Now we must somehow tell RStudio that this particular function should be used as an addin rather than a regular function. For this, go to “File” > “New File” > “Text File” and include the following text:
Name: Insert Dashes (---)
Description: Inserts `---` at the cursor position up to 80 characters.
Binding: set_new_chapter
Interactive: false
  • Name is a short description of what the addin does. This will be displayed when you want to set the shortcut later.
  • Description is a longer description of its functionality.
  • Binding sets the name of the function that should be called by the shortcut.
  • Interactive defines whether this addin is interactive (e.g., runs a Shiny application) or not.
You now must save this file as “addins.dcf” in your project with the following path: “inst” > “rstudio”. The result should look like this:

Finalize the Package

To wrap everything up and make the shortcut available to you and your colleagues, we only have to call a few more functions. Not all these steps are necessary, yet it is good practice to create a proper package.
# OPTIONAL: define the license of your package
usethis::use_mit_license(name = "Matthias Nistler")

# define dependencies you use in your package
usethis::use_package("rstudioapi")

# OPTIONAL: include your function description to the manual
roxygen2::roxygenise()

# check for errors
devtools::check()

# update/create your package
devtools::build()

> ✓  checking for file ‘/Users/matthiasnistler/Projekte/2020/blog_shoRtcut/DESCRIPTION’ ...
> ─  preparing ‘shoRtcut’:
> ✓  checking DESCRIPTION meta-information ...
> ─  checking for LF line-endings in source and make files and shell scripts
> ─  checking for empty or unneeded directories
> ─  building ‘shoRtcut_0.0.0.9000.tar.gz’
> [1] "/Users/matthiasnistler/Projekte/2020/shoRtcut_0.0.0.9000.tar.gz"
There you go! You just created an awesome package and distributed it to your friends and colleagues.

Make the shortcut available

For the last step, you have to install your package and set a keyboard combination for your shortcut. For this, use the following specification of install.packages:
install.packages(
    # same path as above
  "/Users/matthiasnistler/Projekte/2020/shoRtcut_0.0.0.9000.tar.gz", 
  # indicate it is a local file
  repos = NULL)

# check if everything works
shoRtcut:::set_new_chapter()
Now go to “Tools” > “Modify Keyboard Shortcuts…” and search for “dashes”. Here you can define the keyboard combination by clicking inside the empty “Shortcut” field and pressing the desired key-combination on your keyboard. Click “Apply”, and that’s it!
In case you are just here to use my shortcut, you can install it via remotes::install_github("mnist91/shoRtcut").

Congratulations!

You made it! Now you can use your own RStudio shortcut. Exciting, isn’t it? But that’s not all there is – next week, I will give you an introduction to the wonderful world of R package naming. So stay tuned and happy coding!

Data operations is an increasingly important part of data science because it enables companies to feed large business data back into production effectively. We at STATWORX, therefore, operationalize our models and algorithms by translating them into Application Programming Interfaces (APIs). Representational State Transfer (REST) APIs are well suited to be implemented as part of a modern micro-services infrastructure. They are flexible, easy to deploy, scale, and maintain, and they are further accessible by multiple clients and client types at the same time. Their primary purpose is to simplify programming by abstracting the underlying implementation and only exposing objects and actions that are needed for any further development and interaction. An additional advantage of APIs is that they allow for an easy combination of code written in different programming languages or by different development teams. This is because APIs are naturally separated from each other, and communication with and between APIs is handled by IP or URL (http), typically using JSON or XML format. Imagine, e.g., an infrastructure, where an API that’s written in Python and one that’s written in R communicate with each other and serve an application written in JavaScript. In this blog post, I will show you how to translate a simple R script, which transforms tables from wide to long format, into a REST API with the R package Plumber and how to run it locally or with Docker. I have created this example API for our trainee program, and it serves our new data scientists and engineers as a starting point to familiarize themselves with the subject.

Translate the R Script

Transforming an R script into a REST API is quite easy. All you need, in addition to R and RStudio, is the package Plumber and optionally Docker. REST APIs can be interacted with by sending a REST Request; probably the most commonly used ones are GET, PUT, POST, and DELETE. Here is the code of the example API that transforms tables from wide to long or from long to wide format:
## transform wide to long and long to wide format
#' @post /widelong
#' @get /widelong
function(req) {
  # library
  require(tidyr)
  require(dplyr)
  require(magrittr)
  require(httr)
  require(jsonlite)

  # post body
  body <- jsonlite::fromJSON(req$postBody)

  .data <- body$.data
  .trans <- body$.trans
  .key <- body$.key
  .value <- body$.value
  .select <- body$.select

  # wide or long transformation
  if(.trans == 'l' || .trans == 'long') {
    .data %<>% gather(key = !!.key, value = !!.value, !!.select)
    return(.data)
  } else if(.trans == 'w' || .trans == 'wide') {
    .data %<>% spread(key = !!.key, value = !!.value)
    return(.data)
  } else {
    print('Please specify the transformation')
  }
}
As you can see, it is a standard R function that is extended by the special plumber comments @post and @get, which enable the API to respond to those types of requests. It is necessary to add the path, /widelong, to any incoming request. That is done because it is possible to stack several API functions, which respond to different paths. We could, e.g., add another function with the path /naremove to our API, which removes NAs from tables.

The R function itself has one function argument req, which is used to receive a (POST) Request Body. In general, there are two different possibilities to send additional arguments and objects to a REST API, the header and the body. I decided to use a body only and no header at all, which makes the API cleaner, safer and allows us to send larger objects. A header could, e.g., be used to set some optional function arguments, but should be used sparingly otherwise. Using a body with the API is also the reason to allow for GET and POST Requests (@post, @get) at the same time. While some clients prefer to send a body with a GET Request when they do not permanently post something to the server, many other clients do not have the option to send a body with a GET Request at all. In this case, it is mandatory to add a POST Request. Typical clients are applications, Integrated Development Environments (IDEs), and other APIs. By accepting both request types, our API therefore gains greater response flexibility.

For the request-response format of the API, I have decided to stick with the JavaScript Object Notation (JSON), which is probably the most common format. It would be possible to use Extensible Markup Language (XML) with R Plumber instead as well. The decision for one or the other will most likely depend on which additional R packages you want to use or on which format the API’s clients are predominantly using. The R packages that are used to handle REST Requests in my example API are jsonlite and httr. The three Tidyverse packages are used to do the table transformation to wide or long.

Run the API

The finished REST API can be run locally with R or RStudio as follows:
library(plumber)

widelong_api <- plumber::plumb("./path/to/directory/widelongwide.R")
widelong_api$run(host = '127.0.0.1', port = 8000)
Upon starting the API, the Plumber package provides us with an IP address and a port, and a client, e.g., another R instance, can now begin to send REST Requests. It also opens a browser tool called Swagger, which can be useful to check if your API is working as intended. Once the development of an API is finished, I suggest building a Docker image and running it in a container. That makes the API highly portable and independent of its host system. Since most APIs are meant to be used in production and deployed to, e.g., a company server or the cloud, this is especially important. Here is the Dockerfile to build the Docker image of the example API:
FROM trestletech/plumber

# Install dependencies
RUN apt-get update --allow-releaseinfo-change && apt-get install -y \
    liblapack-dev \
    libpq-dev

# Install R packages
RUN R -e "install.packages(c('tidyr', 'dplyr', 'magrittr', 'httr', 'jsonlite'), \
    repos = 'http://cran.us.r-project.org')"

# Add API
COPY ./path/to/directory/widelongwide.R /widelongwide.R

# Make port available
EXPOSE 8000

# Entrypoint
ENTRYPOINT ["R", "-e", "widelong <- plumber::plumb('/widelongwide.R'); widelong$run(host = '0.0.0.0', port = 8000)"]

CMD ["/widelongwide.R"]

Send a REST Request

The wide-long example API can generally respond to any client sending a POST or GET Request with a Body in JSON format that contains a table in csv format and all the information needed to transform it. Here is an example of a web application, which I have written for our trainee program to supplement the wide-long API:
The application is written in R Shiny, which is a great R package for turning your static plots and outputs into an interactive dashboard. If you are interested in how to create dashboards in R, check out other posts on our STATWORX Blog. Last but not least, here is an example of how to send a REST Request from R or RStudio:
library(httr)
library(jsonlite)
options(stringsAsFactors = FALSE)

# url for local testing
url <- "http://127.0.0.1:8000"

# url for docker container (uncomment to use it instead of the local url)
# url <- "http://0.0.0.0:8000"

# read example stock data
.data <- read.csv('./path/to/data/stocks.csv')

# create example body
body <- list(
  .data = .data,
  .trans = "w",
  .key = "stock",
  .value = "price",
  .select = c("X","Y","Z")
)

# set API path
path <- 'widelong'

# send POST Request to API
raw.result <- POST(url = url, path = path, body = body, encode = 'json')

# check status code
raw.result$status_code

# retrieve transformed example stock data
.t_data <- fromJSON(rawToChar(raw.result$content))
As you can see, it is quite easy to make REST Requests in R. If you need some test data, you could use the stocks data example from the Tidyverse.
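The same request can also be sent from other languages. The following is a minimal Python sketch that uses only the standard library and assumes the API runs locally on port 8000 as shown above. The helper names build_widelong_request and send, the example rows, and the serialization of the table as a list of row objects are assumptions for illustration; only the body fields are taken from the R example.

```python
import json
import urllib.request

API_URL = "http://127.0.0.1:8000"  # assumption: API is running locally as shown above


def build_widelong_request(rows, trans, key, value, select, base_url=API_URL):
    """Build (but do not send) a POST request for the /widelong path."""
    body = {
        ".data": rows,      # table as a list of row dictionaries (assumed format)
        ".trans": trans,    # "w" as in the R example
        ".key": key,
        ".value": value,
        ".select": select,
    }
    return urllib.request.Request(
        url=f"{base_url}/widelong",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request, timeout=10):
    """Send the request and return the status code and the decoded JSON body."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return response.status, json.loads(response.read().decode("utf-8"))


# Hypothetical rows standing in for the stocks example data
example_rows = [
    {"time": "2009-01-01", "stock": "X", "price": 1.31},
    {"time": "2009-01-01", "stock": "Y", "price": -1.78},
]
request = build_widelong_request(example_rows, "w", "stock", "price", ["X", "Y", "Z"])
# status, result = send(request)  # uncomment once the API is running
```

Building the request separately from sending it makes it easy to inspect the JSON body before anything goes over the wire.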

Summary

In this blog post, I showed you how to translate a simple R script, which transforms tables from wide to long format, into a REST API with the R package Plumber and how to run it locally or with Docker. I hope you enjoyed the read and learned something about operationalizing R scripts. You are, of course, welcome to copy and use any code from this blog post to start creating your own REST APIs with R. Until then, stay tuned and visit our STATWORX Blog again soon.

Nearly one year ago, I analyzed how we use emojis in our Slack messages. Since then, STATWORX grew, and we are a lot more people now! So, I just wanted to check if something changed. Last time, I did not show our custom emojis, since they are, of course, not available in the fonts I used. This time, I will incorporate them with geom_image(). It is part of the ggimage package from Guangchuang Yu, which you can find here on his Github. With geom_image() you can include images like .png files to your ggplot.

What changed since last year?

Let’s first have a look at the number of emojis we are using. In the plot below, you can see that since my last analysis in October 2018 (red line) the number of emojis has been rising. Not as much as I thought it would, but compared to the previous period, we now have more days with a usage of over 100 emojis per day!
Like last time, our top emoji is 👍, followed by 😂 and 😄. But sneaking in at number ten is one of our custom emojis: party_hat_parrot!
[Figure: top 10 used emojis]

How to include custom images?

In my previous blog post, I hid all our custom emojis behind ❓ since they were not part of the font. It did not occur to me to use their images, even though the package is from the same creator! So, to make up for my ignorance, I grabbed the top 30 custom emojis, downloaded their images from our Slack servers, saved them as .png and made sure they are all roughly the same size. To use geom_image(), I just added the path of the images to my data (the ... is just an abbreviation for the complete path).
                NAME COUNT REACTION IMAGE
1:          alnatura    25       63 .../custom/alnatura.png
2:              blog    19       20 .../custom/blog.png
3:           dataiku    15       22 .../custom/dataiku.png
4: dealwithit_parrot     3      100 .../custom/dealwithit_parrot.png
5:      deananddavid    31       18 .../custom/deananddavid.png
This would have been enough to just add the images now, but since I wanted the NAME attribute as a label, I included geom_text_repel from the ggrepel library. This makes handling of non-overlapping labels much simpler!
ggplot(custom_dt, aes( x = REACTION, y = COUNT, label = NAME)) +
  geom_image(aes(image = IMAGE), size = 0.04) +
  geom_text_repel(point.padding = 0.9, segment.alpha = 0) +
  xlab("as reaction") +
  ylab("within message") +
  theme_minimal()
Usually, if a label is “too far” away from the marker, geom_text_repel includes a line to indicate where the labels belong. Since these lines would overlap the images, I used segment.alpha = 0 to make them invisible. With point.padding = 0.9 I gave the labels a bit more space, so it looks nicer. Depending on the size of the plot, this needs to be adjusted. In the plot, one can see our usage of emojis within a message (y-axis) and as a reaction (x-axis).
To combine the emoji font and custom emojis, I used the following data and code — really… why did I not do this last time? 🤔 Since the UNICODE is NA when I want to use the IMAGE, there is no “double plotting”.
                     EMOJI REACTION COUNT  SUM PLACE    UNICODE   IMAGE
 1:                    :+1:     1090     0 1090     1 U0001f44d
 2:                   :joy:      609   152  761     2 U0001f602
 3:                 :smile:       91   496  587     3 U0001f604
 4:                    :-1:      434     9  443     4 U0001f44e
 5:                  :tada:      346    38  384     5 U0001f389
 6:                  :fire:      274    17  291     6 U0001f525
 7: :slightly_smiling_face:        1   250  251     7 U0001f642
 8:                  :wink:       27   191  218     8 U0001f609
 9:                  :clap:      201    13  214     9 U0001f44f
10:      :party_hat_parrot:      192     9  201    10       <NA>  .../custom/party_hat_parrot.png
quartz()
ggplot(plotdata2, aes(x = PLACE, y = SUM, label = UNICODE)) +
  geom_bar(stat = "identity", fill = "steelblue") +
  geom_text(family="EmojiOne") +
  xlab("Most popular emojis") +
  ylab("Number of usage") +
  scale_fill_brewer(palette = "Paired") +
  geom_image(aes(image = IMAGE), size = 0.04) +
  theme_minimal()
ps = grid.export(paste0(main_path, "plots/top-10-used-emojis.svg"), addClass=T)
dev.off()

The meaning behind emojis

Now we know what our top emojis are. But what is the rest of the world doing? Thanks to Emojimore for providing me with this overview! On their site, you can find meanings for a lot more emojis.
Behind each of our custom emojis is a story as well. For example, all the food emojis are helping us every day to decide where to eat and provide information on what everyone is planning for lunch! And if you do not agree with the decision, just react with sadphan to let the others know about your feelings. If you want to know the whole stories behind all custom emojis or even help create new ones, then maybe you should join our team. Check out our available job offers here!

The tidyverse is making the life of a data scientist a lot easier. That’s why we at STATWORX love to execute our analytics and data science with the tidyverse. Its user-centered approach has many advantages. Instead of the base R version df_test[df_test$x > 10, ], we can write df_test %>% filter(x > 10), which is a lot more readable, especially if our data pipeline gets more complex and nested. Also, as you might have noticed, we can use the column names directly instead of referencing the data frame first. Because of those advantages, we want to use dplyr verbs for writing our function. Imagine we want to write our own summary function my_summary(), which takes a grouping variable and calculates some descriptive statistics. Let’s see what happens when we wrap a dplyr pipeline into a function:
my_summary <- function(df, grouping_var){
 df %>%
  group_by(grouping_var) %>% 
  summarise(
   avg = mean(air_time),
   sum = sum(air_time),
   min = min(air_time),
   max = max(air_time),
   obs = n()
  )
}
my_summary(airline_df, origin)
Error in grouped_df_impl(data, unname(vars), drop) : 
 Column `grouping_var` is unknown 
Our new function uses group_by(), which searches for a grouping variable named grouping_var and not for origin, as we intended. So, what happened here? group_by() quotes its arguments, grouping_var in our example, instead of evaluating them. This quoting is what allows dplyr to implement custom ways of handling its operations. Tidy evaluation is used throughout the tidyverse, which is why we can use column names as if they were variables. However, our data frame has no column named grouping_var.

Non-Standard Evaluation

Talking about whether an argument is quoted or evaluated is a more precise way of stating whether or not a function uses non-standard evaluation (NSE). – Hadley Wickham
The quoting used by group_by() means that it uses non-standard evaluation, like most verbs you can find in dplyr. However, non-standard evaluation is not limited to dplyr and the tidyverse; base R functions such as library() and subset() use it as well.
[Figure: non-standard evaluation]
Because dplyr quotes its arguments, we have to do two things to use it in our function:
  • First, we have to quote our argument
  • Second, we have to tell dplyr that we have already quoted the argument, which we do by unquoting it
We will see this quote-and-unquote pattern consistently in functions that use tidy evaluation.
my_summary <- function(df, grouping_var){
  df %>%
    group_by(!!grouping_var) %>% 
    summarise(
      avg = mean(air_time),
      sum = sum(air_time),
      min = min(air_time),
      max = max(air_time),
      obs = n()
    )
}
my_summary(airline_df, quo(origin))
Therefore, as input to our function, we quote the origin variable, which means that R does not search for the symbol origin in the global environment but defers its evaluation. The quotation takes place with the quo() function. To tell group_by() that the variable was already quoted, we need to use the !! operator, pronounced bang-bang (if you wondered about the title). If we do not use !!, group_by() first searches for the variable within its scope, which consists of the columns of the given data frame. As mentioned before, tidy evaluation is used throughout the tidyverse, via its eval_tidy() function. It also introduces the concept of a data mask, which makes data a first-class object in R.

Data Mask

[Figure: data mask]
Generally speaking, the data mask approach is much more convenient. However, on the programming side, we have to pay attention to some things, like the quote-and-unquote pattern from before. As a next step, we want the quotation to take place inside the function, so that the user of the function does not have to do it. Sadly, using quo() inside the function does not work.
my_summary <- function(df, grouping_var){
  quo(grouping_var)
  df %>%
    group_by(!!grouping_var) %>% 
    summarise(
      avg = mean(air_time),
      sum = sum(air_time),
      min = min(air_time),
      max = max(air_time),
      obs = n()
    )
}
my_summary(airline_df, origin)
Error in quos(...) : object 'origin' not found 
We get an error message because quo() takes it too literally and quotes grouping_var itself instead of substituting it with origin as desired. That’s why we use the function enquo() for enriched quotation, which creates a quosure. A quosure is an object that contains an expression and an environment. Quosures redefine the internal promise object into something that can be used for programming. Thus, the following code works, and we see the quote-and-unquote pattern again.
my_summary <- function(df, grouping_var){
  grouping_var <- enquo(grouping_var)
  df %>%
    group_by(!!grouping_var) %>% 
    summarise(
      avg = mean(air_time),
      sum = sum(air_time),
      min = min(air_time),
      max = max(air_time),
      obs = n()
    )
}
my_summary(airline_df, origin)
# A tibble: 2 x 6
  origin   avg    sum   min   max   obs
  <fct>  <dbl>  <int> <int> <int> <int>
1 JFK     166. 587966    19   415  3539
2 LAX     132. 850259     1   381  6461

All R code is a tree

To better understand what’s happening, it is useful to know that any R code can be represented by an Abstract Syntax Tree (AST), because the structure of the code is strictly hierarchical. The leaves of an AST are either symbols or constants. The more complex a function call gets, the deeper its AST becomes, with more and more levels. In the figures, symbols are drawn in dark blue and have rounded corners, whereas constants have green borders and square corners. Strings are surrounded by quotes so that they are not confused with symbols. The branches are function calls and are depicted as orange rectangles.
a(b(4, "s"), c(3, 4, d()))
[Figure: abstract syntax tree]
To understand how an expression is represented as an AST, it helps to write it in its prefix form.
y <- x * 10
`<-`(y, `*`(x, 10))
[Figure: prefix tree]
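The prefix idea is not specific to R. As a language-neutral illustration, here is a minimal Python sketch that uses Python's built-in ast module to render a small subset of expressions in prefix form. The function to_prefix and the backtick notation for operators are assumptions made for this sketch; it only mimics the R notation above and is not how R or lobstr work internally.

```python
import ast

# Map Python operator node types to their symbols
_OPERATORS = {ast.Add: "+", ast.Sub: "-", ast.Mult: "*", ast.Div: "/"}


def to_prefix(node):
    """Render a small subset of Python syntax in prefix (function call) form."""
    if isinstance(node, ast.Module):
        return to_prefix(node.body[0])
    if isinstance(node, ast.Expr):
        return to_prefix(node.value)
    if isinstance(node, ast.Assign):
        return f"`=`({to_prefix(node.targets[0])}, {to_prefix(node.value)})"
    if isinstance(node, ast.BinOp):
        symbol = _OPERATORS[type(node.op)]
        return f"`{symbol}`({to_prefix(node.left)}, {to_prefix(node.right)})"
    if isinstance(node, ast.Call):
        args = ", ".join(to_prefix(arg) for arg in node.args)
        return f"{to_prefix(node.func)}({args})"
    if isinstance(node, ast.Name):
        return node.id          # a symbol (leaf)
    if isinstance(node, ast.Constant):
        return repr(node.value)  # a constant (leaf)
    raise NotImplementedError(type(node).__name__)


print(to_prefix(ast.parse("y = x * 10")))
print(to_prefix(ast.parse('a(b(4, "s"), c(3, 4, d()))')))
```

Every operator becomes an ordinary function call with its operands as arguments, which is exactly what makes the tree structure visible.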
There is also the R package called lobstr, which contains the function ast() to create an AST from R Code. The code from the first example lobstr::ast(a(b(4, "s"), c(3, 4, d()))) results in this:
[Figure: lobstr AST output]
It looks as expected and just like our hand-drawn AST. The concept of ASTs helps us to understand what is happening in our function. So, if we have the following simple function call, !! introduces a placeholder (promise) for x.
x <- expr(-1)
f(!!x, y)
Due to R’s lazy evaluation, the function f() is not evaluated immediately, but only once we call it. At the moment of the function call, the placeholder x is replaced by an additional AST, which can get arbitrarily complex. Furthermore, this keeps the order of the operators correct, which is not the case when we build code from strings with parse() and paste(). So the resulting AST of our code snippet is the following:
[Figure: promise tree]
Furthermore, !! also works with symbols, functions, and constants.
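Why substituting a sub-tree keeps the operator order correct while pasting strings does not can be shown in any language with access to its own syntax tree. The following Python sketch is only an analogy to !!, not R's mechanism: the class Unquote and the helper evaluate are hypothetical names, and the expressions are made up for illustration.

```python
import ast


class Unquote(ast.NodeTransformer):
    """Replace every occurrence of a placeholder name with a sub-tree."""

    def __init__(self, name, replacement):
        self.name = name
        self.replacement = replacement

    def visit_Name(self, node):
        if node.id == self.name:
            return self.replacement
        return node


def evaluate(tree):
    """Compile and evaluate an expression tree."""
    ast.fix_missing_locations(tree)
    return eval(compile(tree, "<ast>", "eval"))


template = "x * 10"
fragment = "1 + 2"

# String pasting: the fragment loses its grouping and becomes "1 + 2 * 10"
pasted = template.replace("x", fragment)
pasted_value = eval(pasted)  # 21

# Tree substitution: the fragment stays one sub-tree, i.e. (1 + 2) * 10
tree = ast.parse(template, mode="eval")
sub_tree = ast.parse(fragment, mode="eval").body
spliced_value = evaluate(Unquote("x", sub_tree).visit(tree))  # 30
```

The pasted string is re-parsed with the usual operator precedence, whereas the spliced tree never passes through text again, so the grouping of the fragment is preserved.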

Perfecting our function

Now, we want to refine our function by adding an argument for the variable we are summarizing. At the moment, air_time is hardcoded into it. Thus, we want to replace it with a general summary_var as an argument of our function. Additionally, we want the column names of the final output data frame to be adjusted dynamically, depending on the input variable. For adding summary_var, we follow the quote-and-unquote pattern from above. However, for the column naming, we need two additional ingredients. Firstly, quo_name(), which converts a quoted symbol into a string, so that we can use normal string operations on it and, e.g., use the base paste command to manipulate it. However, we also need to unquote the new name, and this unquoting happens on the left-hand side of the assignment, where R does not allow any computations. Thus, we need the second ingredient, the vestigial operator :=, instead of the normal =.
my_summary <- function(df, grouping_var, summary_var){
  grouping_var <- enquo(grouping_var)
  summary_var <- enquo(summary_var)
  summary_nm <- quo_name(summary_var)
  summary_nm_avg <- paste0("avg_",summary_nm)
  summary_nm_sum <- paste0("sum_",summary_nm)
  summary_nm_obs <- paste0("obs_",summary_nm)

  df %>%
    group_by(!!grouping_var) %>% 
    summarise(
      !!summary_nm_avg := mean(!!summary_var),
      !!summary_nm_sum := sum(!!summary_var),
      !!summary_nm_obs := n()
    )
}
my_summary(airline_df, origin, air_time)
# A tibble: 2 x 4
  origin avg_air_time sum_air_time obs_air_time
  <fct>         <dbl>        <int>        <int>
1 JFK            166.       587966         3539
2 LAX            132.       850259         6461

Tidy Dots

In the next step, we want to add the possibility to summarize an arbitrary number of variables. Therefore, we need to use tidy dots (or dot-dot-dot). E.g., if we call the documentation for select(), we get
Usage
  select(.data, ...)
Arguments
  ...  One or more unquoted expressions separated by commas.
In select() we can use any number of variables we want to select. We will use tidy dots ... in our function. However, there are some things we have to account for. Within the function, ... is treated as a list. So we cannot use !! or enquo(), because these commands are made for single variables. However, there are counterparts for the case of .... In order to quote several arguments at once, we can use enquos(), which gives back a list of quoted arguments. In order to unquote several arguments, we need to use !!!, which is also called the big-bang operator. !!! replaces arguments one-to-many, which is called unquote-splicing and respects hierarchical orders.
[Figure: splicing]
Using purrr, we can neatly handle the computation with the list entries provided by ... (for more information ask your Purrr-Macist). So, putting everything together, we finally arrive at our final function.
my_summary <- function(df, grouping_var, ...) {
  grouping_var <- enquo(grouping_var)

  smry_vars <- enquos(..., .named = TRUE)

  smry_avg <- purrr::map(smry_vars, function(var) {
    expr(mean(!!var, na.rm = TRUE))
  })
  names(smry_avg) <- paste0("avg_", names(smry_avg))

  smry_sum <- purrr::map(smry_vars, function(var) {
    expr(sum(!!var, na.rm = TRUE))
  })
  names(smry_sum) <- paste0("sum_", names(smry_sum))

  df %>%
    group_by(!!grouping_var) %>%
    summarise(!!!smry_avg, !!!smry_sum, obs = n())
}

my_summary(airline_df, origin, dep_delay, arr_delay)
# A tibble: 2 x 6
  origin avg_dep_delay avg_arr_delay sum_dep_delay sum_arr_delay   obs
  <fct>          <dbl>         <dbl>         <int>         <int> <int>
1 JFK            12.9          11.8          45792         41625  3539
2 LAX             8.64          5.13         55816         33117  6461

And the tidy evaluation goes on and on

As mentioned in the beginning, tidy evaluation is not only used within dplyr but within most of the packages in the tidyverse. Thus, to know how tidy evaluation works is also helpful if one wants to use ggplot in order to create a function for a styled version of a grouped scatter plot. In this example, the function takes the data, the values for the x and y-axes as well as the grouping variable as inputs:
scatter_by <- function(.data, x, y, z=NULL) {
  x <- enquo(x)
  y <- enquo(y)
  z <- enquo(z)

  ggplot(.data) + 
    geom_point(aes(!!x, !!y, color = !!z)) +
    theme_minimal()
}
scatter_by(airline_df, distance, air_time, origin) 
[Figure: scatter plot]
Another example would be to use R Shiny inputs in a sparklyr pipeline. input$ cannot be used directly within sparklyr, because it would try to resolve the input list object on the Spark side. Unquoting it with !! evaluates the input value in R before the query is sent to Spark.
server.R
library(shiny)
library(dplyr)
library(sparklyr)

# Define server logic required to filter numbers
shinyServer(function(input, output) {
    tbl_1 <- tibble(a = 1:5, b = 6:10)
    sc <- spark_connect(master = "local")

    tbl_1_sp <-
        sparklyr::copy_to(
            dest = sc,
            df = tbl_1,
            name = "tbl_1_sp",
            overwrite = TRUE
        )

    observeEvent(input$select_a, {

        number_b <- tbl_1_sp %>%
            filter(a == !!input$select_a) %>%
            collect() %>%
            pull()

        output$text_b <- renderText({
            paste0("Selected number : ", number_b)
        })
    })
})
ui.R
library(shiny)
library(dplyr)
library(sparklyr)


# Define UI for application
shinyUI(fluidPage(
    # Application title
    titlePanel("Select Number Example"),

    # Sidebar with a slider input for number
    sidebarLayout(sidebarPanel(
        sliderInput(
            "select_a",
            "Number for 1:",
            min = 1,
            max = 5,
            value = 1
        )
    ),

    # Show a text as output
    mainPanel(textOutput("text_b")))
))

Conclusion

There are many use cases for tidy evaluation, especially for advanced programmers. With the tidyverse getting bigger by the day, knowing tidy evaluation becomes more and more useful. For more information about metaprogramming in R and other advanced topics, I can recommend the book Advanced R by Hadley Wickham.


Reinforcement learning is currently one of the hottest topics in machine learning. For a recent conference we attended (the awesome Data Festival in Munich), we developed a reinforcement learning model that learns to play Super Mario Bros on NES, so that visitors who came to our booth could compete against the agent in terms of level completion time.
The promotion was a great success and people enjoyed the “human vs. machine” competition. Only one contestant was able to beat the AI, by taking a secret shortcut that the AI wasn’t aware of. Also, developing the model in Python was a lot of fun. So, I decided to write a blog post about it that covers some of the fundamental concepts of reinforcement learning as well as the actual implementation of our Super Mario agent in TensorFlow (beware: I used TensorFlow 1.13.1, since TensorFlow 2.0 had not been released at the time of writing this article).

Recap: reinforcement learning

Most machine learning models have an explicit connection between inputs and outputs that does not change during training. Therefore, it can be difficult to model or predict systems where the inputs or targets themselves depend on previous predictions. Often, however, the world around the model updates itself with every prediction made. What sounds quite abstract is actually a very common situation in the real world: autonomous driving, machine control, process automation etc. In many situations, decisions made by models have an impact on their surroundings and consequently on the next actions to be taken. Classical supervised learning approaches can only be used to a limited extent in such situations. To handle them, machine learning models are needed that can cope with interdependent inputs and outputs that vary over time. This is where reinforcement learning comes into play. In reinforcement learning, the model (called the agent) interacts with its environment by choosing, in each state of the environment, from a set of possible actions (the action space), which cause either positive or negative rewards from the environment. Think of rewards as an abstract way of signaling that the action taken was good or bad. The reward issued by the environment can be immediate or delayed into the future. By learning from the combination of environment states, actions and corresponding rewards (so-called transitions), the agent tries to reach an optimal set of decision rules (the policy) that maximizes the total reward gathered by the agent in each state.

Q-learning and Deep Q-learning

In reinforcement learning we often use a learning concept called Q-learning. Q-learning is based on so-called Q-values, which help the agent determine the optimal action, given the current state of the environment. Q-values are “discounted” future rewards that our agent collects during training by taking actions and moving through the different states of the environment. The Q-values themselves are approximated during training, either by simple exploration of the environment or by using a function approximator, such as a deep neural network (as in our case here). Mostly, we select in each state the action that has the highest Q-value, i.e. the highest discounted future reward, given the current state of the environment. When using a neural network as a Q-function approximator, we learn by computing the difference between the predicted Q-values and the “true” Q-values, i.e. the representation of the optimal decision in the current state. Based on the computed loss, we update the network’s parameters using gradient descent, just like in any other neural network model. By doing this often, our network converges to a state where it can approximate the Q-values of the next state, given the current state of the environment. If the approximation is good enough, we simply select the action that has the highest Q-value. By doing so, the agent is able to decide in each situation which action generates the best outcome in terms of reward collection. In many deep reinforcement learning models there are actually two deep neural networks involved: the online network and the target network. This is done because during training, the loss function of a single neural network is computed against steadily changing targets (Q-values) that are based on the network’s weights themselves. This makes the optimization problem harder and might prevent convergence altogether.
The target network is basically a copy of the online network with frozen weights that are not directly trained. Instead, the target network’s weights are synchronized with the online network after a certain number of training steps. Enforcing “stable outputs” of the target network that do not change after each training step makes sure that the target Q-values needed for computing the loss do not change constantly, which supports convergence of the optimization problem.
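To make the roles of discounting and the target network more concrete, here is a minimal sketch in plain Python. It is not the implementation of our agent: the function names q_targets and sync_target, the discount factor of 0.9, the synchronization interval, and all numbers are assumptions chosen for illustration. The target Q-value of a transition is the reward plus the discounted maximum Q-value that the target network predicts for the next state.

```python
GAMMA = 0.9  # discount factor (assumed value for illustration)


def q_targets(rewards, next_q_values, dones, gamma=GAMMA):
    """Target Q-values: reward plus discounted best Q-value of the next state.

    next_q_values holds, per transition, the target network's Q-values
    for every action in the next state.
    """
    targets = []
    for reward, next_q, done in zip(rewards, next_q_values, dones):
        # No future reward is added once the episode has ended
        future = 0.0 if done else gamma * max(next_q)
        targets.append(reward + future)
    return targets


def sync_target(online_weights, target_weights, step, sync_every=1000):
    """Return a copy of the online weights every `sync_every` steps,
    otherwise keep the frozen target weights."""
    if step % sync_every == 0:
        return list(online_weights)
    return target_weights


# Three hypothetical transitions with two possible actions each
rewards = [1.0, 0.0, -1.0]
next_q_values = [[0.5, 2.0], [1.0, 0.0], [3.0, 4.0]]
dones = [False, False, True]
targets = q_targets(rewards, next_q_values, dones)
# targets is approximately [2.8, 0.9, -1.0]
```

The loss of the online network would then be computed between its predicted Q-values for the actions taken and these targets, while sync_target is only called to refresh the frozen copy from time to time.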

Deep Double Q-learning

Another possible issue with Q-learning is that, due to the selection of the maximum Q-value for determining the best action, the model sometimes produces extraordinarily high Q-values during training. This is not always a problem, but it might turn into one if there is a strong concentration on certain actions that in turn leads to the neglect of less favorable but “worth-to-try” actions. If the latter are neglected all the time, the model might run into a locally optimal solution or, even worse, select the same actions all the time. One way to deal with this problem is to introduce an updated version of Q-learning called double Q-learning. In double Q-learning, the actions in each state are not simply chosen by selecting the action with the maximum Q-value of the target network. Instead, the selection process is split into three distinct steps: (1) first, the target network computes the target Q-values of the state after taking the action. Then, (2) the online network computes the Q-values of the state after taking the action and selects the best action by finding the maximum Q-value. Finally, (3) the target Q-values are calculated using the Q-values of the target network, but at the action indices selected by the online network. This reduces the risk of overestimating Q-values, because action selection and action evaluation are no longer performed by the same network.
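The three steps can be sketched in a few lines of plain Python. This is an illustration under assumptions, not our training code: the function name double_q_targets, the discount factor of 0.9 and the example Q-values are made up. The online network's Q-values are only used to pick the action index, while the target network's Q-values supply the number that enters the target.

```python
def double_q_targets(rewards, next_q_online, next_q_target, dones, gamma=0.9):
    """Double Q-learning targets for a batch of transitions."""
    targets = []
    for reward, q_online, q_target, done in zip(
        rewards, next_q_online, next_q_target, dones
    ):
        if done:
            # Terminal state: no future reward
            targets.append(reward)
            continue
        # Step (2): the online network selects the best action
        best_action = max(range(len(q_online)), key=q_online.__getitem__)
        # Step (3): the target network's Q-value at that action index is used
        targets.append(reward + gamma * q_target[best_action])
    return targets


# One hypothetical transition with three possible actions
rewards = [1.0]
next_q_online = [[0.2, 0.8, 0.5]]   # online network prefers action 1
next_q_target = [[3.0, 1.0, 2.0]]   # step (1): target network's Q-values
dones = [False]

double_targets = double_q_targets(rewards, next_q_online, next_q_target, dones)
# For comparison: plain Q-learning takes the maximum of the target network
plain_targets = [r + 0.9 * max(q) for r, q in zip(rewards, next_q_target)]
```

In this example the plain target is about 3.7, driven by the single large Q-value of action 0, while the double Q-learning target is about 1.9 because the online network selected action 1.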

Gym environments

In order to build a reinforcement learning application, we need two things: (1) an environment that the agent can interact with and learn from and (2) the agent, which observes the state(s) of the environment and chooses appropriate actions using Q-values that (ideally) result in high rewards for the agent. An environment is typically provided as a so-called gym, a class that contains the necessary code to emulate the states and rewards of the environment as a function of the agent’s actions as well as further information, e.g. about the possible action space. Here is an example of a simple environment class in Python:
class Environment:
    """ A simple environment skeleton """
    def __init__(self):
        # Initializes the environment
        pass

    def step(self, action):
        # Changes the environment based on the agent's action.
        # Placeholder values; a real environment computes them here
        next_state, reward, done, info = None, 0.0, False, {}
        return next_state, reward, done, info

    def reset(self):
        # Resets the environment to its initial state
        pass

    def render(self):
        # Shows the state of the environment on screen
        pass
The environment has three major methods: (1) step() executes the environment code as a function of the action selected by the agent and returns the next state of the environment, the reward with respect to the action, a done flag indicating whether the environment has reached its terminal state, as well as a dictionary of additional information about the environment and its state, (2) reset() resets the environment to its original state and (3) render() prints the current state on the screen (for example, showing the current frame of the Super Mario Bros game). For Python, a go-to place for finding gyms is OpenAI. It contains lots of different games and problems well suited to being solved with reinforcement learning. Furthermore, there is an OpenAI project called Gym Retro that contains hundreds of Sega and SNES games, ready to be tackled by reinforcement learning algorithms.

Agent

The agent consumes the current state of the environment and selects an appropriate action based on the selection policy. The policy maps the state of the environment to the action to be taken by the agent. Finding the right policy is a key question in reinforcement learning and often involves the usage of deep neural networks. The following agent simply observes the state of the environment and returns action = 1 if state is larger than 0 and action = 0 otherwise.
class Agent:
    """ A simple agent """
    def __init__(self):
        pass

    def action(self, state):
        if state > 0:
            return 1
        else:
            return 0
This is of course a very simplistic policy. In practical reinforcement learning applications, the state of the environment can be very complex and high-dimensional. Video games are one example: the state of the environment is determined by the pixels on screen and the previous actions of the player. Our agent needs to find a policy that maps the screen pixels into actions that generate rewards from the environment.
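To show how environment and agent work together, here is a minimal, self-contained sketch of the interaction loop. The CountdownEnvironment is a made-up toy environment and run_episode is a hypothetical helper; only the step()/reset() interface and the simple agent follow the skeletons shown above.

```python
class CountdownEnvironment:
    """Toy environment (hypothetical): the state counts down from 3 to 0."""

    def reset(self):
        self.state = 3
        return self.state

    def step(self, action):
        # Reward 1 if the agent chose action 1, otherwise 0
        reward = 1 if action == 1 else 0
        self.state -= 1
        done = self.state <= 0
        return self.state, reward, done, {}


class Agent:
    """The simple agent from above: action 1 for positive states, else 0."""

    def action(self, state):
        return 1 if state > 0 else 0


def run_episode(env, agent):
    """Let the agent act until the environment signals the terminal state."""
    state = env.reset()
    total_reward, done = 0, False
    while not done:
        action = agent.action(state)
        state, reward, done, info = env.step(action)
        total_reward += reward
    return total_reward


total = run_episode(CountdownEnvironment(), Agent())  # 3
```

A learning agent would additionally store each transition (state, action, reward, next state) inside the loop and use it to update its policy.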

Environment wrappers

Gym environments contain most of the functionalities needed to use them in a reinforcement learning scenario. However, there are certain features that do not come prebuilt into the gym, such as image downscaling, frame skipping and stacking, reward clipping and so on. Luckily, there exist so called gym wrappers that provide such kinds of utility functions. An example that can be used for many video games such as Atari or NES can be found here. For video game gyms it is very common to use wrapper functions in order to achieve a good performance of the agent. The example below shows a simple reward clipping wrapper.
import gym
import numpy as np


class ClipRewardEnv(gym.RewardWrapper):
    """ Example wrapper for reward clipping """
    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward):
        # Clip reward to {1, 0, -1} by its sign
        return np.sign(reward)
From the example shown above you can see that it is possible to change the default behavior of the environment by “overwriting” its core functions. Here, rewards of the environment are clipped to {-1, 0, 1} using np.sign() based on the sign of the reward.

The Super Mario Bros NES environment

For our Super Mario Bros reinforcement learning experiment, I’ve used gym-super-mario-bros. The API is straightforward and very similar to the OpenAI gym API. The following code shows a random agent playing Super Mario. This causes Mario to wiggle around on the screen and – of course – does not lead to a successful completion of the game.
from nes_py.wrappers import BinarySpaceToDiscreteSpaceEnv
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT


# Make gym environment
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = BinarySpaceToDiscreteSpaceEnv(env, SIMPLE_MOVEMENT)

# Play random
done = True
for step in range(5000):
    if done:
        state = env.reset()
    state, reward, done, info = env.step(env.action_space.sample())
    env.render()

# Close device
env.close()
The agent interacts with the environment by choosing random actions from the action space of the environment. The action space of a video game is actually quite large since you can press multiple buttons at the same time. Here, the action space is reduced to SIMPLE_MOVEMENT, which covers basic game actions such as run in all directions, jump, duck and so on. BinarySpaceToDiscreteSpaceEnv transforms the binary action space (dummy indicator variables for all buttons and directions) into a single integer. So for example the integer action 12 corresponds to pressing right and A (running).
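The idea of collapsing button combinations into single integers can be sketched in a few lines. Note that the button order and the action list below are my own assumptions for illustration; the real bit layout is defined by the emulator and the real action lists by gym-super-mario-bros.

```python
# Hypothetical button order (the actual layout is defined by the emulator)
BUTTONS = ['right', 'left', 'down', 'up', 'start', 'select', 'B', 'A']

# Illustrative reduced action list in the spirit of SIMPLE_MOVEMENT
ACTIONS = [['NOOP'], ['right'], ['right', 'A'], ['right', 'B'], ['A'], ['left']]


def to_binary(combo):
    """Turn a list of pressed buttons into a 0/1 indicator vector."""
    return [1 if button in combo else 0 for button in BUTTONS]


# Discrete action (integer) -> binary button vector
ACTION_MAP = {i: to_binary(combo) for i, combo in enumerate(ACTIONS)}

print(len(ACTION_MAP), 'discrete actions instead of', 2 ** len(BUTTONS), 'button combinations')
print(ACTION_MAP[2])  # indicator vector for pressing right and A together
```

The agent then only has to choose among a handful of integers, which keeps the output layer of the network small.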

Using a deep learning model as an agent

When playing Super Mario Bros on NES, humans see the game screen – more precisely, they see consecutive frames of pixels, displayed at high speed on the screen. Our brains transform the raw sensory input from our eyes into electrical signals, which are processed and trigger corresponding actions (pressing buttons on the controller) that (hopefully) lead Mario to the finish line. When training the agent, the gym renders each game frame as a matrix of pixels, according to the respective action taken by the agent. Basically, those pixels can be used as an input to any machine learning model. However, in reinforcement learning we often use convolutional neural networks (CNNs), which excel at image recognition problems compared to other ML models. I won’t go into technical detail about CNNs here; there’s a plethora of great intro articles to CNNs like this one. Instead of using only the current game screen as an input to the model, it is common to use multiple stacked frames as an input to the CNN. By doing so, the model can process changes and “movements” on the screen between consecutive frames, which would not be possible when using only a single game frame. Here, the input tensor of our model is of size [84, 84, 4]. This corresponds to a stack of 4 grayscale frames, each of size 84×84 pixels, in the (height, width, channels) layout that 2-dimensional convolution layers expect by default. The architecture of the deep learning model consists of three convolutional layers, followed by a flatten layer and one fully connected layer with 512 neurons, as well as an output layer consisting of actions = 6 neurons, which corresponds to the action space of the game (in this case RIGHT_ONLY, i.e. actions to move Mario to the right – enlarging the action space usually causes an increase in problem complexity and training time).
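Frame stacking itself is simple to sketch with NumPy. The FrameStack class below is a hypothetical helper of my own, not the wrapper used in the experiment; it just keeps the last 4 frames in a deque and stacks them along the last axis to get the [84, 84, 4] input described above.

```python
from collections import deque

import numpy as np


class FrameStack:
    """Hypothetical helper that keeps the last k frames (illustration only)."""

    def __init__(self, k=4):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, frame):
        """Fill the stack with copies of the first frame of an episode."""
        for _ in range(self.k):
            self.frames.append(frame)
        return self.observation()

    def push(self, frame):
        """Add the newest frame; the oldest one drops out automatically."""
        self.frames.append(frame)
        return self.observation()

    def observation(self):
        """Stack the frames along the last axis: (height, width, k)."""
        return np.stack(list(self.frames), axis=-1)


stack = FrameStack(k=4)
obs = stack.reset(np.zeros((84, 84), dtype=np.uint8))
obs = stack.push(np.ones((84, 84), dtype=np.uint8))
print(obs.shape)
```

The newest frame always sits in the last channel, so the network sees the most recent screen together with its three predecessors.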
If you take a closer look at the TensorBoard image below, you’ll notice that the model actually consists of not only one but two identical convolutional branches. One is the online network branch, the other one is the target network branch. The online network is actually trained using gradient descent. The target network is not directly trained but periodically synchronized every copy = 10000 steps by copying the weights from the online branch to the target branch of the network. The target network branch is excluded from gradient descent training by using the tf.stop_gradient() function around the output layer of the branch. This causes a stop in the flow of gradients at the output layer so that they cannot propagate along the branch and so the weights are not updated.
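The periodic synchronization can be illustrated without TensorFlow. In the following sketch, plain NumPy arrays stand in for the network weights, the weight names are made up, and the sync interval is shortened to 5 steps (instead of copy = 10000) so the effect is visible after a few iterations.

```python
import numpy as np

rng = np.random.default_rng(0)

# NumPy arrays stand in for the trainable variables of both branches
online = {'dense/kernel': rng.normal(size=(4, 2)), 'dense/bias': np.zeros(2)}
target = {name: w.copy() for name, w in online.items()}

COPY_EVERY = 5  # shortened for illustration, the agent uses copy = 10000


def sync(online, target):
    """Copy all online weights into the target branch."""
    for name, w in online.items():
        target[name] = w.copy()


for step in range(1, 13):
    # Stand-in for a gradient step: only the online weights change
    online['dense/bias'] += 0.1
    if step % COPY_EVERY == 0:
        sync(online, target)

print(online['dense/bias'], target['dense/bias'])
```

After 12 steps the target branch still holds the weights from step 10, so it lags behind the online branch. This lag is what keeps the learning targets stable between two synchronizations.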
The agent learns by (1) taking random samples of historical transitions, (2) computing the “true” Q-values based on the states of the environment after the action, next_state, using the target network branch and the double Q-learning rule, (3) discounting the target Q-values using gamma = 0.9 and (4) running a batch gradient descent step based on the network’s internal Q-prediction and the true Q-values, supplied by target_q. In order to speed up the training process, the agent is not trained after each action: it skips learn_each = 3 steps between updates, which corresponds to one training step every 4 frames. In addition, not every frame is stored in the replay buffer but only every 4th frame. This is called frame skipping. More specifically, a max pooling operation is performed that aggregates the information of the last 4 consecutive frames. This is motivated by the fact that consecutive frames contain nearly the same information, which adds little to the learning problem and might introduce strongly autocorrelated datapoints. Speaking of correlated data: our network is trained with gradient descent using the adaptive moment estimation (Adam) optimizer at a learning_rate = 0.00025, which requires i.i.d. datapoints in order to work well. This means that we cannot simply train on new transition tuples in the order they arrive, since they are highly correlated. To solve this issue we use a concept called an experience replay buffer. We store every transition of our game in a ring buffer object (in Python, a deque from the collections module), from which we then randomly sample our training data of batch_size = 32. By using a random sampling strategy and a large enough replay buffer, we can assume that the resulting datapoints are (hopefully) not correlated. The following codebox shows the DQNAgent class.
import time
import random
import numpy as np
from collections import deque
import tensorflow as tf
from matplotlib import pyplot as plt


class DQNAgent:
    """ DQN agent """
    def __init__(self, states, actions, max_memory, double_q):
        self.states = states
        self.actions = actions
        self.session = tf.Session()
        self.build_model()
        # Keep the 10 latest checkpoints on disk
        self.saver = tf.train.Saver(max_to_keep=10)
        self.session.run(tf.global_variables_initializer())
        self.memory = deque(maxlen=max_memory)
        self.eps = 1
        self.eps_decay = 0.99999975
        self.eps_min = 0.1
        self.gamma = 0.90
        self.batch_size = 32
        self.burnin = 100000
        self.copy = 10000
        self.step = 0
        self.learn_each = 3
        self.learn_step = 0
        self.save_each = 500000
        self.double_q = double_q

    def build_model(self):
        """ Model builder function """
        self.input = tf.placeholder(dtype=tf.float32, shape=(None, ) + self.states, name='input')
        self.q_true = tf.placeholder(dtype=tf.float32, shape=[None], name='labels')
        self.a_true = tf.placeholder(dtype=tf.int32, shape=[None], name='actions')
        self.reward = tf.placeholder(dtype=tf.float32, shape=[], name='reward')
        self.input_float = tf.to_float(self.input) / 255.
        # Online network
        with tf.variable_scope('online'):
            self.conv_1 = tf.layers.conv2d(inputs=self.input_float, filters=32, kernel_size=8, strides=4, activation=tf.nn.relu)
            self.conv_2 = tf.layers.conv2d(inputs=self.conv_1, filters=64, kernel_size=4, strides=2, activation=tf.nn.relu)
            self.conv_3 = tf.layers.conv2d(inputs=self.conv_2, filters=64, kernel_size=3, strides=1, activation=tf.nn.relu)
            self.flatten = tf.layers.flatten(inputs=self.conv_3)
            self.dense = tf.layers.dense(inputs=self.flatten, units=512, activation=tf.nn.relu)
            self.output = tf.layers.dense(inputs=self.dense, units=self.actions, name='output')
        # Target network
        with tf.variable_scope('target'):
            self.conv_1_target = tf.layers.conv2d(inputs=self.input_float, filters=32, kernel_size=8, strides=4, activation=tf.nn.relu)
            self.conv_2_target = tf.layers.conv2d(inputs=self.conv_1_target, filters=64, kernel_size=4, strides=2, activation=tf.nn.relu)
            self.conv_3_target = tf.layers.conv2d(inputs=self.conv_2_target, filters=64, kernel_size=3, strides=1, activation=tf.nn.relu)
            self.flatten_target = tf.layers.flatten(inputs=self.conv_3_target)
            self.dense_target = tf.layers.dense(inputs=self.flatten_target, units=512, activation=tf.nn.relu)
            self.output_target = tf.stop_gradient(tf.layers.dense(inputs=self.dense_target, units=self.actions, name='output_target'))
        # Optimizer
        self.action = tf.argmax(input=self.output, axis=1)
        self.q_pred = tf.gather_nd(params=self.output, indices=tf.stack([tf.range(tf.shape(self.a_true)[0]), self.a_true], axis=1))
        self.loss = tf.losses.huber_loss(labels=self.q_true, predictions=self.q_pred)
        self.train = tf.train.AdamOptimizer(learning_rate=0.00025).minimize(self.loss)
        # Summaries
        self.summaries = tf.summary.merge([
            tf.summary.scalar('reward', self.reward),
            tf.summary.scalar('loss', self.loss),
            tf.summary.scalar('max_q', tf.reduce_max(self.output))
        ])
        self.writer = tf.summary.FileWriter(logdir='./logs', graph=self.session.graph)

    def copy_model(self):
        """ Copy weights to target network """
        self.session.run([tf.assign(new, old) for (new, old) in zip(tf.trainable_variables('target'), tf.trainable_variables('online'))])

    def save_model(self):
        """ Saves current model to disk """
        self.saver.save(sess=self.session, save_path='./models/model', global_step=self.step)

    def add(self, experience):
        """ Add observation to experience """
        self.memory.append(experience)

    def predict(self, model, state):
        """ Prediction """
        if model == 'online':
            return self.session.run(fetches=self.output, feed_dict={self.input: np.array(state)})
        if model == 'target':
            return self.session.run(fetches=self.output_target, feed_dict={self.input: np.array(state)})

    def run(self, state):
        """ Perform action """
        if np.random.rand() < self.eps:
            # Random action
            action = np.random.randint(low=0, high=self.actions)
        else:
            # Policy action
            q = self.predict('online', np.expand_dims(state, 0))
            action = np.argmax(q)
        # Decrease eps
        self.eps *= self.eps_decay
        self.eps = max(self.eps_min, self.eps)
        # Increment step
        self.step += 1
        return action

    def learn(self):
        """ Gradient descent """
        # Sync target network
        if self.step % self.copy == 0:
            self.copy_model()
        # Checkpoint model
        if self.step % self.save_each == 0:
            self.save_model()
        # Break if burn-in
        if self.step < self.burnin:
            return
        # Break if no training
        if self.learn_step < self.learn_each:
            self.learn_step += 1
            return
        # Sample batch
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(np.array, zip(*batch))
        # Get next q values from target network
        next_q = self.predict('target', next_state)
        # Calculate discounted future reward
        if self.double_q:
            q = self.predict('online', next_state)
            a = np.argmax(q, axis=1)
            target_q = reward + (1. - done) * self.gamma * next_q[np.arange(0, self.batch_size), a]
        else:
            target_q = reward + (1. - done) * self.gamma * np.amax(next_q, axis=1)
        # Update model
        summary, _ = self.session.run(fetches=[self.summaries, self.train],
                                      feed_dict={self.input: state,
                                                 self.q_true: np.array(target_q),
                                                 self.a_true: np.array(action),
                                                 self.reward: np.mean(reward)})
        # Reset learn step
        self.learn_step = 0
        # Write
        self.writer.add_summary(summary, self.step)
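To see what the replay buffer and the target computation in learn() do, here is a small NumPy-only sketch. The helper q_targets and the toy numbers are my own, chosen only for illustration; the formulas follow the two branches of the learn() method above.

```python
import random
from collections import deque

import numpy as np

# Ring buffer: once maxlen is reached, the oldest transition is dropped
memory = deque(maxlen=3)
for t in range(5):
    memory.append(t)  # integers stand in for transition tuples
batch = random.sample(memory, 2)  # uniform sampling without replacement


def q_targets(reward, done, q_online_next, q_target_next, gamma=0.9, double_q=True):
    """Compute the TD targets for a batch of transitions."""
    if double_q:
        # Online network picks the action, target network evaluates it
        best = np.argmax(q_online_next, axis=1)
        q_next = q_target_next[np.arange(len(best)), best]
    else:
        # Target network both picks and evaluates the action
        q_next = np.max(q_target_next, axis=1)
    # Terminal transitions (done = 1) get no discounted future value
    return reward + (1.0 - done) * gamma * q_next


# Toy batch of two transitions with two possible actions
reward = np.array([1.0, 0.0])
done = np.array([0.0, 1.0])
q_online_next = np.array([[0.2, 0.8], [0.5, 0.1]])
q_target_next = np.array([[1.0, 0.5], [2.0, 3.0]])

double = q_targets(reward, done, q_online_next, q_target_next, double_q=True)
vanilla = q_targets(reward, done, q_online_next, q_target_next, double_q=False)
print(list(memory), double, vanilla)
```

In this toy batch, the double Q-learning target of the first transition is 1.45, while the plain maximum over the target network gives 1.9. The second transition is terminal, so its target is just the reward.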

Training the agent to play

First, we need to instantiate the environment. Here, we use the first level of Super Mario Bros, SuperMarioBros-1-1-v0, as well as a discrete action space restricted to RIGHT_ONLY. Additionally, we use a wrapper that applies frame resizing, stacking and max pooling, reward clipping as well as lazy frame loading to the environment. When the training starts, the agent begins to explore the environment by taking random actions. This is done in order to build up initial experience that serves as a starting point for the actual learning process. After burnin = 100000 game frames, the agent slowly starts to replace random actions by actions determined by the CNN policy. This is called an epsilon-greedy policy. Epsilon-greedy means that the agent takes a random action with probability epsilon or a policy-based action with probability (1-epsilon). Here, epsilon is multiplied by eps_decay = 0.99999975 after every action, so it decays geometrically during training until it reaches eps_min = 0.1, where it remains constant for the rest of the training process. It is important to not completely eliminate random actions from the training process in order to avoid getting stuck on locally optimal solutions. For each action taken, the environment returns four objects: (1) the next game state, (2) the reward for taking the action, (3) a flag indicating whether the episode is done and (4) an info dictionary containing additional information from the environment. After taking the action, a tuple of the returned objects is added to the replay buffer and the agent performs a learning step. After learning, the current state is updated with the next_state and the loop increments. The while loop breaks if the done flag is True. This corresponds to either the death of Mario or to a successful completion of the level. Here, the agent is trained for 10000 episodes.
import time
import numpy as np
from nes_py.wrappers import BinarySpaceToDiscreteSpaceEnv
import gym_super_mario_bros
from gym_super_mario_bros.actions import RIGHT_ONLY
from agent import DQNAgent
from wrappers import wrapper

# Build env (first level, right only)
env = gym_super_mario_bros.make('SuperMarioBros-1-1-v0')
env = BinarySpaceToDiscreteSpaceEnv(env, RIGHT_ONLY)
env = wrapper(env)

# Parameters
states = (84, 84, 4)
actions = env.action_space.n

# Agent
agent = DQNAgent(states=states, actions=actions, max_memory=100000, double_q=True)

# Episodes
episodes = 10000
rewards = []

# Timing
start = time.time()
step = 0

# Main loop
for e in range(episodes):

    # Reset env
    state = env.reset()

    # Reward
    total_reward = 0
    iter = 0

    # Play
    while True:

        # Show env (disabled)
        # env.render()

        # Run agent
        action = agent.run(state=state)

        # Perform action
        next_state, reward, done, info = env.step(action=action)

        # Remember transition
        agent.add(experience=(state, next_state, action, reward, done))

        # Update agent
        agent.learn()

        # Total reward
        total_reward += reward

        # Update state
        state = next_state

        # Increment
        iter += 1

        # If done break loop
        if done or info['flag_get']:
            break

    # Rewards
    rewards.append(total_reward / iter)

    # Print
    if e % 100 == 0:
        print('Episode {e} - '
              'Frame {f} - '
              'Frames/sec {fs} - '
              'Epsilon {eps} - '
              'Mean Reward {r}'.format(e=e,
                                       f=agent.step,
                                       fs=np.round((agent.step - step) / (time.time() - start)),
                                       eps=np.round(agent.eps, 4),
                                       r=np.mean(rewards[-100:])))
        start = time.time()
        step = agent.step

# Save rewards
np.save('rewards.npy', rewards)
After each game episode, the average reward in this episode is appended to the rewards list. Furthermore, different stats such as frames per second and the current epsilon are printed after every 100 episodes.
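The epsilon values printed during training follow directly from the decay settings of the agent. The two helper functions below are my own sketch (they are not part of the DQNAgent class) and compute epsilon after a given number of agent steps as well as the number of steps until eps_min is reached.

```python
import math


def epsilon_at(step, eps_start=1.0, eps_decay=0.99999975, eps_min=0.1):
    """Epsilon after `step` actions with multiplicative decay and a floor."""
    return max(eps_min, eps_start * eps_decay ** step)


def steps_until_min(eps_start=1.0, eps_decay=0.99999975, eps_min=0.1):
    """Number of actions until epsilon first hits its floor."""
    return math.ceil(math.log(eps_min / eps_start) / math.log(eps_decay))


print(round(epsilon_at(100000), 4))  # epsilon at the end of the burn-in phase
print(steps_until_min())             # steps until epsilon stays at eps_min
```

With these settings, epsilon is still around 0.975 at the end of the burn-in phase and only reaches 0.1 after roughly 9.2 million steps, so the agent keeps exploring for a large part of the training.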

Replay

During training, the program checkpoints the current network at save_each = 500000 frames and keeps the 10 latest models on disk. I’ve downloaded several model versions during training to my local machine and produced the following video.
It is so awesome to see the learning progress of the agent! The training process took approximately 20 hours on a GPU accelerated VM on Google Cloud.

Summary and outlook

Reinforcement learning is an exciting field in machine learning that offers a wide range of possible applications in science and business alike. However, the training of reinforcement learning agents is still quite cumbersome and often requires tedious tuning of hyperparameters and network architecture in order to work well. There have been recent advances, such as RAINBOW (a combination of multiple RL learning strategies), that aim at a more robust framework for training reinforcement learning agents, but the field is still an area of active research. Besides Q-learning, there are many other interesting training concepts in reinforcement learning that have been developed. If you want to try different RL agents and training approaches, I suggest you check out Stable Baselines, a great way to easily use state-of-the-art RL agents and training concepts. If you are a deep learning beginner and want to learn more, you should check our brand-new STATWORX Deep Learning Bootcamp, a 5-day in-person introduction into the field that covers everything you need to know in order to develop your first deep learning models: neural net theory, backpropagation and gradient descent, programming models in Python, TensorFlow and Keras, CNNs and other image recognition models, recurrent networks and LSTMs for time series data and NLP as well as advanced topics such as deep reinforcement learning and GANs. If you have any comments or questions on my post, feel free to contact me! Also, feel free to use my code (link to GitHub repo) or share this post with your peers on social platforms of your choice. If you’re interested in more content like this, join our mailing list, constantly bringing you fresh data science, machine learning and AI reads and treats from me and my team at STATWORX right into your inbox! Lastly, follow me on LinkedIn or my company STATWORX on Twitter, if you’re interested in more!

It’s Valentine’s Day, making this the most romantic time of the year. But actually, 2018 was already a year full of love here at STATWORX: many of my STATWORX colleagues got engaged. And so we began to wonder – some fearful, some hopeful – who will be next? Therefore, today we’re going to tackle this question in the only true way: with data science!

Gathering the Data

To get my data, I surveyed my colleagues. I asked my (to be) married colleagues to answer my questions based on the very day they got engaged. My single colleagues answered my questions with respect to their current situation. I asked them about some factors that I’ve always suspected to influence someone’s likelihood of getting married. For example, I’m sure that in comparison to Python users, R users are much more romantic. The indiscreet questions I badgered my coworkers with were:
  • Are you married or engaged?
  • How long have you been in your relationship?
  • Is your employment permanent?
  • How long have you been working at STATWORX?
  • What’s your age?
  • Are you living together with your partner?
  • Are you co-owning a pet with your partner?
  • What’s your preferred programming language: R, Python or neither?
I’m going to treat the relationship status as a dichotomous variable: married or engaged vs. single or “only” dating. To maintain some of the privacy of my colleagues, I gave them all some randomly (!!) chosen pet names. (Side note: There really is a subreddit for everything.)

Descriptive Exploration

Since the first step in generating a data-driven answer should always be a descriptive exploration of the data at hand, I made some plots. First, I took a look at the absolute frequencies of preferred programming languages in the groups of singles vs. married or engaged STATWORX employees. I fear the romantic nature of R users is not the explanation we’re looking for:
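# packages used throughout this post
library(dplyr)
library(ggplot2)
library(ggimage)     # geom_image()
library(scales)      # pretty_breaks()
library(rpart)
library(rpart.plot)

# reformatting the target variable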
df1 <- df %>%
  dplyr::mutate(engaged = ifelse(engaged == "yes", 
                                 "Engaged or Married", 
                                 "Single")) %>%
  dplyr::group_by(`primary programming language`, engaged) %>%
  dplyr::summarise(freq = n(),
                   image = "~/Desktop/heart_red.png") 

# since in geom_image size cannot be mapped to variable
# multiple layers of data subsets  
ggplot() +
  geom_image(data = filter(df1, freq == 1), 
             aes(y = `primary programming language`,
                 x = engaged, 
                 image = image), 
             size = 0.1) + 
  geom_image(data = filter(df1, freq > 1 & freq <= 5), 
             aes(y = `primary programming language`, 
                 x = engaged, 
                 image = image),
             size = 0.2) +
  geom_image(data = filter(df1, freq >= 13), 
             aes(y= `primary programming language`, 
                 x = engaged, 
                 image = image),
             size = 0.3) +
  geom_text(data = df1, 
            aes(y =`primary programming language`, 
                x = engaged, 
                label = freq), 
            color = "white", size = 4) +
  ylab("Preferred programming language") +
  xlab("\n Absolute frequencies") +
  theme_minimal()
programming languages frequencies
I also explored the association of relationship status and the more conventional factors of age and relationship duration. And indeed, those of my colleagues who are in their late twenties or older and have been partnered for a while now, are mostly married or engaged.
# plotting age and relationship duration vs. relationship status

ggplot() +
# doing this only to get a legend:
  geom_point(data = df,
             aes(x = age, y = `relationship duration`,
                 color = engaged), shape = 15) + 
  geom_image(data = filter(df, engaged == "yes"), 
             aes(x = age, y = `relationship duration`,
                 image = "~/Desktop/heart_red.png")) +
  geom_image(data = filter(df, engaged == "no"), 
             aes(x = age, y = `relationship duration`,
                 image = "~/Desktop/heart_black.png")) +
  ylab("Relationship duration \n") +
  xlab("\n Age") +
  scale_color_manual(name = "Married or engaged?",
                     values = c("#000000", "#D00B0B")) +
  scale_x_continuous(breaks = pretty_breaks()) +
  theme_minimal() +
  theme(legend.position = "bottom")
age relationship duration

Statistical Models

I’ll employ some statistical models, but the dataset is rather small. Therefore, our model options are somewhat limited (and of course only suitable for entertainment). But it’s still possible to fit a decision tree, which might help to pinpoint the circumstances due to which some of us are still waiting for that special someone to put a ring on (it).
# recoding target to get more understandable labels
df <- df %>%
  dplyr::mutate(engaged = ifelse(engaged == "yes", 
                                 "(to be) married", 
                                 "single"))

# growing a decision tree with a ridiculously low minsplit
fit <- rpart(engaged ~ `relationship duration` + age + 
             `shared pet` + `permanent employment` +
             cohabitating + `years at STATWORX`,
             control = rpart.control(minsplit = 2), # overfitting ftw
             method = "class", data = df)

# plotting the tree
rpart.plot(fit, type = 3, extra = 2, 
           box.palette = c("#D00B0B", "#fae6e6"))

relationship decision tree
Our decision tree implies that the unintentionally unmarried among us should maybe consider moving in with their partner, since cohabitating seems to be the most important factor. But that still doesn’t exactly answer the question of who will be next. To predict our chances of getting engaged, I estimated a logistic regression. We see that cohabitating, one’s age and the time we’ve been working at STATWORX are accompanied by a higher probability of being (or soon being) married. However, simply having been together for a long time or owning a pet together with our partner does not help. (Although I assume that this rather unintuitive relationship is caused by a certain outlier in the data – “Honey”, I’m looking at you!) Finally, I got the logistic regression’s predicted probabilities for all of us to be married or engaged. As you can see below, the single days of “Teddy Bear”, “Honey”, “Sweet Pea” and “Babe” seem to be numbered.
# reformatting the target variable
df <- df %>%
  dplyr::mutate(engaged = ifelse(engaged == "(to be) married", 1, 0))

# in-sample fitting: estimating the model 
log_reg <- glm(engaged ~ `relationship duration` + age +
               `shared pet` + `permanent employment` + 
               cohabitating + `years at STATWORX`,
              family = binomial, data = df)

df$probability <- predict(log_reg, df, type = "response")

# plotting the predicted probabilities
ggplot() +
  # again, doing this only to get a legend:
  geom_point(data = df,
             aes(x = probability, y = nickname,
                 color = as.factor(engaged)), shape = 15) + 
  geom_image(data = filter(df, engaged == 1), 
             aes(x = probability, y = nickname,
                 image = "~/Desktop/heart_red.png")) +
  geom_image(data = filter(df, engaged == 0), 
             aes(x = probability, y = nickname,
                 image = "~/Desktop/heart_black.png")) +
  ylab(" ") +
  xlab("Predicted Probability") +
  scale_color_manual(name = "Married or engaged?",
                     values = c("#000000", "#D00B0B"),
                     labels = c("no", "yes")) +
  scale_x_continuous(breaks = pretty_breaks()) +
  theme_minimal() +
  theme(legend.position = "bottom")
predicted probabilities for marriage
I hope this was as insightful for you as it was for me. And to all of us, whose hopes have been shattered by cold, hard facts, let’s remember: there are tons of discounted chocolates waiting for us on February 15th.
It’s Valentine’s day, making this the most romantic time of the year. But actually, already 2018 was a year full of love here at STATWORX: many of my STATWORX colleagues got engaged. And so we began to wonder – some fearful, some hopeful – who will be next? Therefore, today we’re going to tackle this question in the only true way: with data science!

Gathering the Data

To get my data, I surveyed my colleagues. I asked my (to be) married colleagues to answer my questions based on the very day they got engaged. My single colleagues answered my questions with respect to their current situation. I asked them about some factors that I’ve always suspected to influence someone’s likeliness to get married. For example, I’m sure that in comparison to Python users, R users are much more romantic. The indiscreet questions I badgered my coworkers with were: I’m going to treat the relationship status as dichotomous variable: Married or engaged vs. single or “only” dating. To maintain some of the privacy of my colleagues I gave them all some randomly (!!) chosen pet names. (Side note: There really is a subreddit for everything.)

Descriptive Exploration

Since the first step in generating data driven answer should always be a descriptive exploration of the data at hand, I made some plots. First, I took a look at the absolute frequencies of preferred programming languages in the groups of singles vs. married or engaged STATWORX employees. I fear, the romantic nature of R users is not the explanation we’re looking for:
# reformatting the target variable
df1 <- df %>%
  dplyr::mutate(engaged = ifelse(engaged == "yes", 
                                 "Engaged or Married", 
                                 "Single")) %>%
  dplyr::group_by(`primary programming language`, engaged) %>%
  dplyr::summarise(freq = n(),
                   image = "~/Desktop/heart_red.png") 

# since in geom_image size cannot be mapped to variable
# multiple layers of data subsets  
ggplot() +
  geom_image(data = filter(df1, freq == 1), 
             aes(y = `primary programming language`,
                 x = engaged, 
                 image = image), 
             size = 0.1) + 
  geom_image(data = filter(df1, freq > 1 & freq <= 5), 
             aes(y = `primary programming language`, 
                 x = engaged, 
                 image = image),
             size = 0.2) +
  geom_image(data = filter(df1, freq >= 13), 
             aes(y= `primary programming language`, 
                 x = engaged, 
                 image = image),
             size = 0.3) +
  geom_text(data = df1, 
            aes(y =`primary programming language`, 
                x = engaged, 
                label = freq), 
            color = "white", size = 4) +
  ylab("Preferred programming language") +
  xlab("n Absolute frequencies") +
  theme_minimal()
programming languages frequencies
I also explored the association of relationship status and the more conventional factors of age and relationship duration. And indeed, those of my colleagues who are in their late twenties or older and have been partnered for a while now, are mostly married or engaged.
# plotting age and relationship duration vs. relationship status

ggplot() +
# doing this only to get a legend:
  geom_point(data = df,
             aes(x = age, y = `relationship duration`,
                 color = engaged), shape = 15) + 
  geom_image(data = filter(df, engaged == "yes"), 
             aes(x = age, y = `relationship duration`,
                 image = "~/Desktop/heart_red.png")) +
  geom_image(data = filter(df, engaged == "no"), 
             aes(x = age, y = `relationship duration`,
                 image = "~/Desktop/heart_black.png")) +
  ylab("Relationship duration n") +
  xlab("n Age") +
  scale_color_manual(name = "Married or engaged?",
                     values = c("#000000", "#D00B0B")) +
  scale_x_continuous(breaks = pretty_breaks()) +
  theme_minimal() +
  theme(legend.position = "bottom")
age relationship duration

Statistical Models

I’ll employ some statistical models, but the data base is rather small. Therefore, our model options are somewhat limited (and of course only suitable for entertainment). But it’s still possible to fit a decision tree, which might help to pinpoint due to which circumstance some of us are still waiting for that special someone to put a ring on (it).
# recoding target to get more understandable labels
df <- df %>%
  dplyr::mutate(engaged = ifelse(engaged == "yes", 
                                 "(to be) married", 
                                 "single"))

# growing a decision three with a ridiculous low minsplit
fit <- rpart(engaged ~ `relationship duration` + age + 
             `shared pet` + `permanent employment` +
             cohabitating + `years at STATWORX`,
             control = rpart.control(minsplit = 2), # overfitting ftw
             method = "class", data = df)

# plotting the tree
rpart.plot(fit, type = 3, extra = 2, 
           box.palette = c("#D00B0B", "#fae6e6"))

relationship decision tree
Our decision tree implies that the unintentionally unmarried among us should maybe consider moving in with their partner, since cohabitating seems to be the most important factor. But that still doesn’t exactly answer the question of who among us will be next.
To predict our chances of getting engaged, I estimated a logistic regression. We see that cohabiting, one’s age and the time we’ve been working at STATWORX are accompanied by a higher probability of being (or soon to be) married. However, simply having been together for a long time or owning a pet together with our partner does not help. (Although I assume that this rather unintuitive relationship is caused by a certain outlier in the data: “Honey”, I’m looking at you!) Finally, I got the logistic regression’s predicted probabilities for all of us to be married or engaged. As you can see down below, the single days of “Teddy Bear”, “Honey”, “Sweet Pea” and “Babe” seem to be numbered.
# reformatting the target variable
df <- df %>%
  dplyr::mutate(engaged = ifelse(engaged == "(to be) married", 1, 0))

# in-sample fitting: estimating the model 
log_reg <- glm(engaged ~ `relationship duration` + age +
               `shared pet` + `permanent employment` + 
               cohabitating + `years at STATWORX`,
               family = binomial, data = df)

df$probability <- predict(log_reg, df, type = "response")

# plotting the predicted probabilities
ggplot() +
  # again, doing this only to get a legend:
  geom_point(data = df,
             aes(x = probability, y = nickname,
                 color = as.factor(engaged)), shape = 15) + 
  geom_image(data = filter(df, engaged == 1), 
             aes(x = probability, y = nickname,
                 image = "~/Desktop/heart_red.png")) +
  geom_image(data = filter(df, engaged == 0), 
             aes(x = probability, y = nickname,
                 image = "~/Desktop/heart_black.png")) +
  ylab(" ") +
  xlab("Predicted Probability") +
  scale_color_manual(name = "Married or engaged?",
                     values = c("#000000", "#D00B0B"),
                     labels = c("no", "yes")) +
  scale_x_continuous(breaks = pretty_breaks()) +
  theme_minimal() +
  theme(legend.position = "bottom")
predicted probabilities for marriage
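If you want to check for yourself which way each predictor pushes the predicted probability, one option is to look at the model coefficients as odds ratios. The following is only a minimal sketch on simulated data: the `toy` data frame, its values and the simulation parameters are hypothetical and are not the survey data used in this post.

```r
# Hypothetical toy data, NOT the survey data from this post
set.seed(14)
n <- 40
toy <- data.frame(
  age          = round(runif(n, min = 22, max = 40)),
  cohabitating = rbinom(n, size = 1, prob = 0.5)
)

# simulate the target so that older and cohabiting people
# are engaged more often (assumed effect sizes, purely illustrative)
lin_pred    <- -6 + 0.15 * toy$age + 1.5 * toy$cohabitating
toy$engaged <- rbinom(n, size = 1, prob = plogis(lin_pred))

# same kind of model as above, just with fewer predictors
toy_fit <- glm(engaged ~ age + cohabitating,
               family = binomial, data = toy)

# coefficients are on the log-odds scale;
# exponentiating turns them into odds ratios
odds_ratios <- exp(coef(toy_fit))
print(round(odds_ratios, 2))
```

An odds ratio above 1 means the predictor goes along with higher odds of being married or engaged, a value below 1 with lower odds. With a data set as small as ours, these numbers are of course just as much for entertainment as everything else here.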
I hope this was as insightful for you as it was for me. And to all of us whose hopes have been shattered by cold, hard facts, let’s remember: there are tons of discounted chocolates waiting for us on February 15th.