Author: CSO & Co-Founder
Classification is one of the most common tasks in machine learning. It covers a wide variety of business needs, from predicting player churn in mobile gaming to fraud detection in banking. Although each problem demands a specific approach and a deep understanding of the business, it is still possible to extract common patterns that appear whenever these problems are approached with machine learning. This article covers some of them; if you want to read about the basics first, see Machine Learning Basics.
In simple terms, classification is a machine learning task in which a model is trained to assign things (items, users, transactions, etc.) to one of two classes (binary classification) or one of several classes (multiclass classification). As an example, consider churn prediction in gaming: every active user is classified either as low churn risk (first class) or high churn risk (second class). This allows us to approach the two groups of players with different means (ads, offers) and thereby encourage more of them to keep playing, which at the same time increases the company’s ROI from player payments or in-game ads.
For simplicity, we will use the gaming example throughout this article. Keep in mind that the problems we discuss are much more general and appear in almost all classification tasks.
In the simplest case, a machine learning model used for prediction returns the label of the class it believes the given example belongs to. In our example, this means stating, for each user, whether they are more likely to churn or not. An alternative approach is to return the estimated probability of belonging to each class, e.g. for a given user the model might predict a 95% probability of churning and a 5% probability of staying in the game.
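The two kinds of output are connected: a probability can always be turned into a label by comparing it with a threshold. The sketch below is a minimal illustration; the function name `to_label` and the default threshold of 0.5 are assumptions, and in practice the threshold is often tuned to the business problem.

```python
def to_label(p_churn: float, threshold: float = 0.5) -> str:
    """Convert a predicted churn probability into a hard class label.

    The 0.5 default threshold is an illustrative assumption; lowering it flags
    more users as churners, raising it flags fewer.
    """
    if not 0.0 <= p_churn <= 1.0:
        raise ValueError("probability must be in [0, 1]")
    return "churn" if p_churn >= threshold else "not churn"


# The user from the example: 95% probability of churning.
print(to_label(0.95))  # churn
# A user with a 30% churn probability stays below the default threshold...
print(to_label(0.30))  # not churn
# ...but is flagged if we choose a more cautious threshold.
print(to_label(0.30, threshold=0.2))  # churn
```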
Most models accept only numerical features, so before feeding data to the model you need to clean it and convert it into a numerical form. There are many possible steps; let us list some of them:
We will use the Titanic dataset, as it is simple and contains both continuous and categorical columns.
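As a minimal sketch of this kind of cleaning, the snippet below works on a few hand-made rows that mimic the Titanic columns `Pclass`, `Sex`, `Age`, `Fare` and `Embarked`; the values are invented for illustration and are not taken from the dataset. It fills missing numeric values with the column median and one-hot encodes the categorical columns using only the standard library. In a real project you would more likely use pandas (`fillna`, `get_dummies`) or scikit-learn's `SimpleImputer` and `OneHotEncoder`.

```python
from statistics import median

# Hypothetical toy rows with Titanic-style columns; None marks a missing value.
rows = [
    {"Pclass": 3, "Sex": "male", "Age": 22.0, "Fare": 7.25, "Embarked": "S"},
    {"Pclass": 1, "Sex": "female", "Age": 38.0, "Fare": 71.28, "Embarked": "C"},
    {"Pclass": 3, "Sex": "female", "Age": None, "Fare": 7.92, "Embarked": "S"},
    {"Pclass": 2, "Sex": "male", "Age": 35.0, "Fare": 13.0, "Embarked": None},
]

CATEGORICAL = ["Sex", "Embarked"]
NUMERIC = ["Pclass", "Age", "Fare"]


def preprocess(rows):
    """Return purely numerical feature vectors and the names of their columns."""
    # 1. Median of each numeric column, ignoring missing values.
    medians = {
        col: median(r[col] for r in rows if r[col] is not None) for col in NUMERIC
    }
    # 2. Categories seen in each categorical column; a missing value
    #    is treated as its own "missing" category.
    categories = {
        col: sorted({r[col] if r[col] is not None else "missing" for r in rows})
        for col in CATEGORICAL
    }
    columns = NUMERIC + [f"{col}_{cat}" for col in CATEGORICAL for cat in categories[col]]

    features = []
    for r in rows:
        # Impute missing numeric values with the column median.
        vec = [float(r[col]) if r[col] is not None else float(medians[col]) for col in NUMERIC]
        # 3. One-hot encode: one 0/1 column per category.
        for col in CATEGORICAL:
            value = r[col] if r[col] is not None else "missing"
            vec.extend(1.0 if value == cat else 0.0 for cat in categories[col])
        features.append(vec)
    return features, columns


X, columns = preprocess(rows)
print(columns)
# ['Pclass', 'Age', 'Fare', 'Sex_female', 'Sex_male', 'Embarked_C', 'Embarked_S', 'Embarked_missing']
print(X[2])  # [3.0, 35.0, 7.92, 1.0, 0.0, 0.0, 1.0, 0.0]
```

The third passenger's missing age is replaced by the median of the known ages (35.0), and every categorical value becomes a group of 0/1 columns the model can consume.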
There is an enormous variety of machine learning models that can be applied to classification tasks, and any of them may outperform the others depending on the data it learns from and the business needs behind the classification problem. One could try linear models (e.g. logistic regression, linear support vector machines), tree-based models (decision trees, random forest, XGBoost), or even deep neural networks. There is no universal answer to which model will perform best, and a large part of a data scientist’s job is usually to run many experiments to find the one that fits the data best. To do this, we first split the data into a training set and a test set.
The training set is used to pick the best model and the test set to validate it. Because the test set plays no part in choosing the model, the metrics measured on it give a fair estimate of what to expect when the model is applied in production, provided that production data resembles the data we trained on.
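A minimal sketch of such a split, written from scratch so that it is self-contained; it mirrors the idea behind scikit-learn's `train_test_split`, but the function below is our own simplified version, and the 25% test fraction and the toy churn data are assumptions for illustration.

```python
import random


def train_test_split(X, y, test_size=0.25, seed=42):
    """Shuffle example indices and hold out a fraction of them as the test set.

    A simplified stand-in for sklearn.model_selection.train_test_split.
    """
    if len(X) != len(y):
        raise ValueError("X and y must have the same length")
    indices = list(range(len(X)))
    random.Random(seed).shuffle(indices)  # fixed seed -> reproducible split
    n_test = max(1, round(len(X) * test_size))
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return (
        [X[i] for i in train_idx],
        [X[i] for i in test_idx],
        [y[i] for i in train_idx],
        [y[i] for i in test_idx],
    )


# Hypothetical data: 20 users described by one feature, with churn labels.
X = [[float(i)] for i in range(20)]
y = [1 if i >= 15 else 0 for i in range(20)]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)
print(len(X_train), len(X_test))  # 15 5
```

With imbalanced classes a purely random split can leave very few positives in the test set; scikit-learn's `train_test_split` accepts a `stratify` argument to keep class proportions equal in both parts.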
Then, to choose the best model and its hyperparameters, we perform a so-called grid search (or one of its alternatives, e.g. Bayesian optimization). Grid search consists of two parts:
This approach is much more time-consuming than a simple train-test split, but it gives us more confidence in the quality of the final model.
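The sketch below shows one common way to run a grid search: every candidate hyperparameter value is scored with k-fold cross-validation on the training data, and the best-scoring value is kept. It assumes a toy k-nearest-neighbours classifier whose only hyperparameter is the number of neighbours `k`; the model, the grid `[1, 3, 5, 7]` and the data are illustrative choices, not the article's setup. For real models, scikit-learn's `GridSearchCV` automates the same loop.

```python
import random
from statistics import mean


def knn_predict(X_train, y_train, x, k):
    """Predict a 0/1 label by majority vote among the k nearest training points."""
    # Squared Euclidean distance is enough for ranking neighbours.
    dists = sorted(
        (sum((a - b) ** 2 for a, b in zip(xi, x)), yi)
        for xi, yi in zip(X_train, y_train)
    )
    votes = [label for _, label in dists[:k]]
    return 1 if 2 * sum(votes) > len(votes) else 0


def k_fold_indices(n, n_folds):
    """Split indices 0..n-1 into n_folds contiguous folds (data assumed shuffled)."""
    fold_size, remainder = divmod(n, n_folds)
    folds, start = [], 0
    for f in range(n_folds):
        end = start + fold_size + (1 if f < remainder else 0)
        folds.append(list(range(start, end)))
        start = end
    return folds


def cross_val_accuracy(X, y, k, n_folds=5):
    """Mean validation accuracy of k-NN, each fold taking a turn as validation set."""
    scores = []
    for val_idx in k_fold_indices(len(X), n_folds):
        val_set = set(val_idx)
        X_tr = [X[i] for i in range(len(X)) if i not in val_set]
        y_tr = [y[i] for i in range(len(X)) if i not in val_set]
        correct = sum(knn_predict(X_tr, y_tr, X[i], k) == y[i] for i in val_idx)
        scores.append(correct / len(val_idx))
    return mean(scores)


def grid_search(X, y, param_grid, n_folds=5):
    """Score every candidate k with cross-validation and return the best one."""
    results = {k: cross_val_accuracy(X, y, k, n_folds) for k in param_grid}
    best_k = max(results, key=results.get)
    return best_k, results


# Hypothetical 1-D training data: users inactive for 20+ days churn (label 1).
X = [[float(d)] for d in range(30)]
y = [1 if d >= 20 else 0 for d in range(30)]
pairs = list(zip(X, y))
random.Random(0).shuffle(pairs)  # shuffle so the contiguous folds are mixed
X, y = [p[0] for p in pairs], [p[1] for p in pairs]

best_k, results = grid_search(X, y, param_grid=[1, 3, 5, 7])
print(best_k, results)
```

Only after `best_k` is chosen this way would the final model be evaluated, once, on the held-out test set. Training and scoring the model `n_folds` times per candidate is exactly why this approach costs more time than a single train-test split.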
Choosing the right metric is crucial for the real-world application of machine learning. It tells us how good a model is and whether one model is better than another. It also makes progress measurable: one can track the value of the metric from one experiment to the next and finally choose the configuration with the highest score. A properly chosen metric also shows how well the machine learning solution meets the business needs.
Because many different metrics fit different tasks, we will limit ourselves to metrics for binary classification with models that return hard (yes/no) labels.
One of the simplest metrics is accuracy: by definition, it is the percentage of cases in which the model is correct, and this clear interpretation is a great advantage from a business perspective. Accuracy works well when building a general classification model on a balanced dataset (meaning that each class has approximately the same number of examples). It fails, however, under significant class imbalance. Consider data in which only 1% of users churn: one can easily reach an accuracy of 99% just by assigning the not-churn label to every input. In that case, we end up with 99% accuracy and a completely useless model.
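The imbalance trap is easy to reproduce. The sketch below uses a hypothetical population of 100 users of whom exactly one churns, and a "model" that ignores its input and always predicts "not churn" (label 0).

```python
def accuracy(y_true, y_pred):
    """Fraction of examples where the predicted label equals the true label."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical data: 100 users, exactly 1 of whom churns (label 1).
y_true = [1] + [0] * 99
# A constant "model" that always predicts "not churn" (label 0).
y_pred = [0] * 100

print(accuracy(y_true, y_pred))  # 0.99
```

The score looks excellent, yet the model never identifies the single churner, which is the only user the business actually cares about here.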
For imbalanced data, more sophisticated metrics are usually used. One of them is the F-score (F1-score), the harmonic mean of precision (the fraction of examples labeled as positive by the model that really are positive) and recall (the fraction of actual positive examples that the model detects). Using the harmonic rather than the arithmetic mean ensures that, to maximize the F-score, one needs to maximize both precision and recall at the same time: a model with very high precision but very low recall (or vice versa) still gets a low score.
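Here is a small sketch that computes precision, recall and F1 from true and predicted labels, reusing the hypothetical 1%-churn data from the accuracy discussion. Returning 0.0 when a denominator is zero is a convention we assume here (scikit-learn makes a similar choice and emits a warning).

```python
def precision_recall_f1(y_true, y_pred, positive=1):
    """Compute precision, recall and F1 for the positive class.

    Any metric whose denominator is zero is reported as 0.0 (assumed convention).
    """
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)  # true positives
    fp = sum(t != positive and p == positive for t, p in pairs)  # false alarms
    fn = sum(t == positive and p != positive for t, p in pairs)  # missed positives
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    # Harmonic mean: high only if BOTH precision and recall are high.
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Hypothetical data: 100 users, exactly 1 churner.
y_true = [1] + [0] * 99

# Model 1: always "not churn" -> 99% accuracy, but F1 exposes it.
print(precision_recall_f1(y_true, [0] * 100))  # (0.0, 0.0, 0.0)

# Model 2: flags everyone as a churner -> perfect recall, terrible precision.
p, r, f1 = precision_recall_f1(y_true, [1] * 100)
print(p, r, round(f1, 4))  # 0.01 1.0 0.0198
# The arithmetic mean (0.01 + 1.0) / 2 would be about 0.5 and hide the problem.
```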
Depending on the business problem, one may care more about recall than precision. For example, when detecting customer churn it may be worth paying the price of more players being wrongly flagged as churners (lower precision) in return for correctly identifying more of the players who actually churn (higher recall). In that case, the F-beta score family can be helpful. It is the F-score with more weight given to recall or precision depending on the value of beta: beta > 1 favors recall, beta < 1 favors precision, and beta = 1 gives the ordinary F1-score.
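The standard formula is F_beta = (1 + beta²) · precision · recall / (beta² · precision + recall). The sketch below applies it to two hypothetical churn models with mirrored precision and recall: F1 cannot tell them apart, while F2 (beta = 2, an illustrative choice) prefers the model that catches more churners. scikit-learn offers the same metric as `sklearn.metrics.fbeta_score`.

```python
def f_beta(precision, recall, beta=1.0):
    """F-beta score computed from precision and recall.

    beta > 1 weights recall more heavily, beta < 1 weights precision more,
    and beta == 1 gives the usual F1-score (harmonic mean).
    """
    if beta <= 0:
        raise ValueError("beta must be positive")
    b2 = beta ** 2
    denom = b2 * precision + recall
    return (1 + b2) * precision * recall / denom if denom else 0.0


# Two hypothetical churn models with mirrored precision/recall.
model_a = {"precision": 0.8, "recall": 0.4}  # cautious: few false alarms, misses churners
model_b = {"precision": 0.4, "recall": 0.8}  # aggressive: catches most churners

for name, m in [("A", model_a), ("B", model_b)]:
    f1 = f_beta(m["precision"], m["recall"], beta=1)
    f2 = f_beta(m["precision"], m["recall"], beta=2)
    print(name, round(f1, 3), round(f2, 3))
# A 0.533 0.444
# B 0.533 0.667
```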
In this article, we have only glimpsed the diversity of problems that arise when performing classification tasks with machine learning models. Although we managed to extract some common subtasks, it should be clear that every machine learning project requires a very deep understanding of both the business and the data.
The success of a project depends heavily on both, as they are crucial for making optimal decisions while developing the solution.