Customer churn prediction using Spark with a declarative approach

Thiago Lima
10 min read · Nov 17, 2021

Overview

In business, churn occurs when a customer cancels or abandons a service. Predicting which customers are likely to churn can be very profitable, since a company can then offer discounts and incentives to raise its retention rate.

This article covers how to build a scalable model with Spark to predict customer churn. The data for this project was provided by Udacity and is a user log for a fictional music streaming app called Sparkify.

The code behind this article can be found here.

Sparkify’s user base is split into a free tier and a premium tier. Free users hear advertisements between songs. Premium users are referred to as paid users in this article. Users can upgrade, downgrade or cancel at any time, so it’s very important that they really like the service.

Among other interactions provided by the app, users can play songs, log in and log out, like a song with a thumbs up or dislike it with a thumbs down, add a song to a playlist and add another user as a friend.

Spark and Big Data

Dealing with big data usually requires a distributed system of multiple computers. Spark handles data engineering, data science and machine learning workloads on a single node or across multiple machines. You can find more information about it here.

Let’s start by setting up the Spark session so we can query the dataset with SQL. Once that is done, we can read the dataset and create a temporary view of the dataframe, which lets us work declaratively in SQL.

Preparing the dataset to be used by spark
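
A minimal PySpark sketch of this setup could look like the following. It assumes the log is stored as JSON lines; the application name, the view name `user_log` and the file path are placeholders for illustration, not values taken from the project's code. PySpark is imported inside the function so the module still loads where it is not installed.

```python
VIEW_NAME = "user_log"  # hypothetical view name


def load_user_log(path, view_name=VIEW_NAME):
    """Read a JSON user log and expose it to Spark SQL as a temporary view."""
    from pyspark.sql import SparkSession  # lazy import: needs PySpark installed

    spark = SparkSession.builder.appName("Sparkify").getOrCreate()
    df = spark.read.json(path)             # assumption: JSON-lines input
    df.createOrReplaceTempView(view_name)  # makes the dataframe queryable by name
    return spark, df


def count_rows_query(view_name=VIEW_NAME):
    """Build a simple declarative query against the temporary view."""
    return f"SELECT COUNT(*) AS n_rows FROM {view_name}"
```

With a session in hand, `spark.sql(count_rows_query()).show()` would run the query against the view.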

Exploratory data analysis

Let’s understand more about the data we have and the missing values.

Dataset Schema

Below, we can see a sample of the dataset with some of its columns. This dataset represents a user log, i.e., all interactions between the user and the music streaming app.

Selected features from the dataset

Each row of this dataset represents an interaction, identified by the variable page, a categorical variable that takes values such as Home, Submit Downgrade, Thumbs Down, Roll Advert, Logout, Save Settings, Cancellation Confirmation, Thumbs Up and NextSong, among others.

Each log row includes user information such as the user ID, first and last name, gender, location and level (paid or free). HTTP information is also part of each log row, including:

  • ts: timestamp of the request/response
  • method: HTTP request method, with 2 categories (GET and PUT)
  • status: HTTP status code, with 3 categories (200, 307 and 404)
  • auth: authentication type, with 4 categories (Logged In, Logged Out, Canceled and Guest)
  • sessionID: the session ID

The dataset has a total of 286,500 rows, of which 8,346 are interactions by logged-out and guest users with no userID; these rows should be removed.

By our definition of churn, 52 of the 225 users churned. This represents 23% of the customer base. Below, we can see that Sparkify has a serious customer retention problem, since losses exceed gains.

To produce the plots below, we had to convert the timestamp to a datetime, first dividing it by 1000 because the original timestamps are in milliseconds.

Inserting a new column with the timestamp converted to datetime
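
The conversion itself is simple arithmetic. As a plain-Python sketch (the function name and the sample value are illustrative, not taken from the dataset), the timestamp is divided by 1000 to get seconds before it is turned into a datetime. In Spark SQL the same division would typically feed a function such as `from_unixtime`.

```python
from datetime import datetime, timezone


def ms_to_datetime(ts_ms):
    """Convert a Unix timestamp in milliseconds to a timezone-aware UTC datetime."""
    return datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)


# 1538352000000 ms is an illustrative value, not a row from the dataset.
example = ms_to_datetime(1538352000000)
print(example.isoformat())  # 2018-10-01T00:00:00+00:00
```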

After this step, the data shows that the users’ logs cover an interval of about 2 months, starting on 1 October 2018 and ending on 3 December 2018. Below we can see how many active free and paid users the logs registered per day.

The number of active free users declines over time, which could be explained by a rise in the number of upgrades, as we can see in the plot below.

Feature engineering

Let’s try out some measures to understand how users behave, and add some of them to our data model. First, we can see below that free users are slightly more likely to churn than paid ones.

To engineer the next features, let’s use the concept of a session to group the events in the users’ logs. According to Wikipedia, a session can be defined as a set of requests exchanged between two devices over a period of time, which means it has a start and an end. In Sparkify’s dataset, ‘sessionID’ stores the ID of a session along with other information about the requests or responses in the log. This makes the upcoming per-session averages easier to understand.

One hypothesis was that free users receive many more advertisements on average per session and that this contributes to their churn. Below we can see that free users do receive many more advertisements on average, and this behaviour could be correlated with churn.

Extracting the average advertising events per session by free user
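
A query of this kind could be sketched as below. The SQL is written to be plain enough for Spark SQL, but here it runs on an in-memory SQLite table that stands in for the Spark temporary view, so the example works without a cluster. The table name `user_log`, the column names `userId` and `sessionId`, and the sample rows are assumptions made for illustration.

```python
import sqlite3

# Count adverts per session first, then average those counts per free user.
AVG_ADVERTS_SQL = """
SELECT userId, AVG(adverts) AS avgAdverts
FROM (
    SELECT userId,
           sessionId,
           SUM(CASE WHEN page = 'Roll Advert' THEN 1 ELSE 0 END) AS adverts
    FROM user_log
    WHERE level = 'free'
    GROUP BY userId, sessionId
) AS per_session
GROUP BY userId
ORDER BY userId
"""

# Made-up rows: (userId, sessionId, level, page)
rows = [
    ("A", 1, "free", "Roll Advert"),
    ("A", 1, "free", "NextSong"),
    ("A", 1, "free", "Roll Advert"),
    ("A", 2, "free", "NextSong"),
    ("B", 3, "free", "Roll Advert"),
    ("B", 3, "free", "Roll Advert"),
    ("B", 3, "free", "Roll Advert"),
    ("C", 4, "paid", "NextSong"),
]

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE user_log (userId TEXT, sessionId INTEGER, level TEXT, page TEXT)"
)
conn.executemany("INSERT INTO user_log VALUES (?, ?, ?, ?)", rows)
avg_adverts = conn.execute(AVG_ADVERTS_SQL).fetchall()
conn.close()

print(avg_adverts)  # [('A', 1.0), ('B', 3.0)]
```

Sessions with no adverts still count toward the average, which is why user A ends up at 1.0 and not 2.0.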

Before adding this measure to our data model, we need to restrict it to the advertisements received at the user’s most recent level, i.e., if the user was free and upgraded to a paid account, we use the advertisement average for the paid period only. The result can be seen below: with the advertisement average isolated this way, the effect of advertisements on churned users is easier to see.

Another hypothesis is that users who give more thumbs up on average are more likely to stay in the customer base, since this event can be read as a good sign for the app’s song recommendation system. Conversely, the average number of thumbs down can be read as a bad sign for the recommendation system. Below we can see how these measures could be correlated with churn.

Suppose that users who really like this music streaming service and its song recommendations accumulate a much greater proportion of thumbs up than thumbs down. Since we are trying to predict churn, let’s turn the hypothesis around and assume that users with a higher thumbs down proportion are more likely to churn.

Extracting thumbs down proportion by user
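
The arithmetic behind this feature can be sketched in plain Python. The proportion is thumbs down divided by thumbs up plus thumbs down, matching the feature definition used in the model. The function name, the sample users and the choice to return `None` for users who never voted are assumptions made for this example.

```python
from collections import Counter


def thumbs_down_proportion(pages):
    """Share of thumbs down among all thumbs events; None if the user never voted."""
    counts = Counter(pages)
    up, down = counts["Thumbs Up"], counts["Thumbs Down"]
    total = up + down
    return down / total if total else None


# Made-up page events per user, for illustration only.
pages_by_user = {
    "A": ["NextSong", "Thumbs Up", "Thumbs Up", "Thumbs Up", "Thumbs Down"],
    "B": ["Thumbs Down", "NextSong", "Thumbs Up"],
    "C": ["NextSong", "Home"],
}
proportions = {u: thumbs_down_proportion(p) for u, p in pages_by_user.items()}
print(proportions)  # {'A': 0.25, 'B': 0.5, 'C': None}
```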

Users who churned have, on average, a greater proportion of thumbs down than those who didn’t.

The user agent tells us which browser and operating system the user is running. Since web apps should work across different browsers, we should cross this information with churn to see how they are related.

We can see that some OS and browser combinations are much more prone to churn than others. The other features used in the model to capture user engagement can be seen in the code here.

Our model ended up with 13 features, listed below:

Categorical:

  1. level — free tier or premium account;
  2. gender — user’s gender, M (male) or F (female);
  3. userAgent — operating system and browser combination used by the user;

Numerical:

  1. advertising — average number of advertisements received per session, considering only the user’s latest level;
  2. avgThumbsUp — average of Thumbs up per session;
  3. avgThumbsDown — average of Thumbs down per session;
  4. thumbsDownProportion — proportion of thumbs down considering thumbs up + thumbs down as total;
  5. avgNumSongs — average number of songs played per session;
  6. avgSessionDuration — average duration time, in seconds, of a session;
  7. avgNumPlaylistAddition — average number of songs added to a playlist per session;
  8. avgNumFriendsAdded — average number of friends added per session;
  9. avgTimeBetweenSessions — average time, in seconds, between two sessions;
  10. avgTimeBetweenSongs — average time, in seconds, between two songs;

A pipeline was built to process the data model, with stages that index the categorical features and pass the result to a VectorAssembler, which merges multiple columns into a single vector column used to train our model.
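
A pipeline along these lines could be sketched with PySpark's `StringIndexer` and `VectorAssembler`. The feature names come from the list above; the `Index` suffix, the `features` output column and the function names are assumptions made for illustration, and the project's real stages may differ. PySpark is imported lazily so the module loads without it.

```python
CATEGORICAL = ["level", "gender", "userAgent"]
NUMERICAL = [
    "advertising", "avgThumbsUp", "avgThumbsDown", "thumbsDownProportion",
    "avgNumSongs", "avgSessionDuration", "avgNumPlaylistAddition",
    "avgNumFriendsAdded", "avgTimeBetweenSessions", "avgTimeBetweenSongs",
]


def indexed_name(column):
    """Name of the indexed column for a categorical feature (hypothetical suffix)."""
    return f"{column}Index"


def assembler_inputs():
    """Columns fed to the VectorAssembler: indexed categoricals, then numericals."""
    return [indexed_name(c) for c in CATEGORICAL] + NUMERICAL


def build_feature_pipeline():
    """Chain one StringIndexer per categorical column and a VectorAssembler."""
    from pyspark.ml import Pipeline  # lazy import: needs PySpark installed
    from pyspark.ml.feature import StringIndexer, VectorAssembler

    indexers = [
        StringIndexer(inputCol=c, outputCol=indexed_name(c)) for c in CATEGORICAL
    ]
    assembler = VectorAssembler(inputCols=assembler_inputs(), outputCol="features")
    return Pipeline(stages=indexers + [assembler])
```

Calling `build_feature_pipeline().fit(df).transform(df)` on a dataframe with these columns would add the `features` vector column.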

Model training

Three methods were chosen to train the model: Logistic Regression, Random Forest Classifier and Gradient Boosted Trees. Since the dataset is imbalanced, with many more users who stayed than users who churned, we need a metric that reflects how well the model predicts the minority class (churned users) and is not dominated by the majority class. The F1-score, the harmonic mean of precision and recall, serves this purpose.

Precision measures how precise the model’s positive predictions are, i.e., the proportion of predicted churners who actually churned. The formula is: True Positive/(True Positive + False Positive).

Recall, in this case, is the proportion of actual churned users that the model identified. The formula is: True Positive/(True Positive + False Negative).

F1-score is calculated as: 2*Precision*Recall/(Precision + Recall).
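
To see why accuracy alone would mislead here, the three formulas can be worked through in plain Python. The first case uses the counts reported earlier (52 churned out of 225 users) for a model that always predicts "stayed". The second case uses made-up confusion counts, not results from this project.

```python
def precision_recall_f1(tp, fp, fn):
    """Return (precision, recall, F1), using 0.0 where a ratio is undefined."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


# A model that predicts "stayed" for all 225 users misses all 52 churners.
always_stay_accuracy = 173 / 225
always_stay = precision_recall_f1(tp=0, fp=0, fn=52)

# Hypothetical model: 40 churners caught, 10 false alarms, 12 churners missed.
hypothetical = precision_recall_f1(tp=40, fp=10, fn=12)

print(round(always_stay_accuracy, 3), always_stay)  # 0.769 (0.0, 0.0, 0.0)
print([round(x, 3) for x in hypothetical])          # [0.8, 0.769, 0.784]
```

The always-stay model looks acceptable on accuracy but scores zero on F1, which is the behaviour we want from the metric on an imbalanced dataset.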

Performance of the fitted models for the three methods:

Table with the metrics used to measure the model performance

Hyperparameter tuning

To improve the performance of our model and mitigate overfitting, cross-validation was used. Combined with hyperparameter tuning, cross-validation gives more confidence that the model generalises to unseen data.

With these concepts in mind, we used the F1 score as the metric to optimize, which should lead to a more robust model. Also, k-fold cross-validation with 3 folds was performed for all methods, due to the small size of the final dataset.

As you could see, the distributions are far from normal, so we use MinMaxScaler, which preserves the shape of the distributions. To prevent data leakage, the scaler had to be inserted as a preprocessing stage in the pipeline. If we scaled all the training data before cross-validation, then when the training data is split again into training and validation folds, the validation folds would already carry information from the scaler.

Prevent data leakage
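
The leakage problem can be illustrated without Spark. The sketch below is a plain-Python stand-in for min-max scaling, not Spark's MinMaxScaler, and the fold values are made up. It compares a scaler fitted on all the data with one fitted on the training fold only, which is what happens when the scaler is a pipeline stage refitted inside every fold.

```python
def fit_min_max(values):
    """Learn the scaling range from the given values only."""
    return min(values), max(values)


def transform_min_max(values, fitted):
    """Rescale values with a previously learned (min, max) range."""
    lo, hi = fitted
    span = hi - lo
    # Simplification: a constant column is mapped to 0.0.
    return [(v - lo) / span if span else 0.0 for v in values]


train_fold = [2.0, 4.0, 6.0]
validation_fold = [8.0]

# Leaky: the range is learned from training and validation values together.
leaky = transform_min_max(validation_fold, fit_min_max(train_fold + validation_fold))

# Leak-free: the range is learned from the training fold alone.
clean = transform_min_max(validation_fold, fit_min_max(train_fold))

print(leaky)  # [1.0]
print(clean)  # [1.5]
```

In the leaky version the validation value has already shaped the range, so it lands neatly on 1.0. In the leak-free version it falls outside the range the model saw during training, as genuinely unseen data might.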

Logistic Regression

To prevent overfitting in logistic regression, regularization was used, controlled by the parameter regParam. You can find more about regularization here. The range of values tried for this parameter and the best value found for this model:

RANGE OF VALUES
regParam: [0.1, 0.15, 0.2, 0.5]
BEST PARAM
regParam: 0.1

Random Forest Classifier

In decision trees, preventing overfitting requires pruning the trees before they produce leaves with tiny samples. The stopping rule used for pruning can be checked here. The following parameters, with their respective value ranges, were tuned to avoid overfitting.

  • numTrees: Number of trees in the forest.
  • maxDepth: Maximum depth of each tree in the forest.

The range of values used for the params and the best parameters for this model:

RANGE OF VALUES
numTrees: [15, 30, 45]
maxDepth: [10, 15, 20]
BEST PARAM
numTrees: 30
maxDepth: 10
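
A grid search like this one could be set up with PySpark's `ParamGridBuilder` and `CrossValidator`, as sketched below. The grid values and the 3 folds come from this article; the column names and function names are assumptions, and the estimator is the bare classifier for brevity, whereas the approach described above would pass the whole pipeline, scaler included. Note that the evaluator's `f1` metric is averaged over classes weighted by class frequency, so it may differ from an F1 computed on the churn class alone.

```python
from itertools import product

NUM_TREES = [15, 30, 45]
MAX_DEPTH = [10, 15, 20]
NUM_FOLDS = 3


def count_model_fits():
    """Fits run during the search: one per grid point and fold (final refit excluded)."""
    return len(list(product(NUM_TREES, MAX_DEPTH))) * NUM_FOLDS


def build_rf_cross_validator(label_col="label", features_col="features"):
    """Grid search over numTrees and maxDepth, scored by F1 with 3-fold CV."""
    from pyspark.ml.classification import RandomForestClassifier  # lazy imports
    from pyspark.ml.evaluation import MulticlassClassificationEvaluator
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

    rf = RandomForestClassifier(labelCol=label_col, featuresCol=features_col)
    grid = (
        ParamGridBuilder()
        .addGrid(rf.numTrees, NUM_TREES)
        .addGrid(rf.maxDepth, MAX_DEPTH)
        .build()
    )
    evaluator = MulticlassClassificationEvaluator(labelCol=label_col, metricName="f1")
    return CrossValidator(
        estimator=rf, estimatorParamMaps=grid, evaluator=evaluator, numFolds=NUM_FOLDS
    )


print(count_model_fits())  # 27
```

Nine parameter combinations over three folds means 27 model fits, which is one reason to keep the grid small.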

Gradient Boosting Tree

Preventing overfitting in Gradient Boosted Trees is not so hard, since the algorithm is fairly robust to overfitting as the number of trees increases, as explained here. But just to be sure, we tuned maxDepth, as before, and minInstancesPerNode, which helps prune the trees. The range of values used for the params and the best parameters for this model:

RANGE OF VALUES
minInstancesPerNode: [5, 10, 20, 40]
maxDepth: [10, 15, 20]
BEST PARAM
minInstancesPerNode: 40
maxDepth: 10

Results

Below we can see the performance achieved by each method after applying the cross validation technique.

Table with the metrics used to measure the model performance

We can see that Random Forest Classifier stands out with the highest recall, precision and F1 score, making it the best method to choose. Unfortunately, Logistic Regression doesn’t have the featureImportances attribute, but the importance of each feature for the other two methods is shown below:

These numbers suggest that features tied to the song recommendation system matter most: the thumbs up and thumbs down measures and the proportion between them, which shows how much this type of feature contributes to model performance.

Conclusions

Using the declarative approach with some knowledge of SQL made Spark easier to work with and let us build a practical machine learning model that is ready to scale. The dataset used here is just a sample, but the model is ready to run on big data with powerful distributed systems.

Random Forest achieved a surprising F1 score of 0.9, a great result. The daily user logs could be used to retrain and evaluate the model at a frequency determined by the available infrastructure and budget. This would allow the model to evolve with the current users and improve its precision.

Users identified as likely to churn should be enrolled in a customer retention program with strategies to keep them in the base. The chosen strategy should be validated through A/B testing, to understand whether it is affecting users’ decisions positively.