
CTGAN-ENN: a tabular GAN-based hybrid sampling method for imbalanced and overlapped data in customer churn prediction

Abstract

Class imbalance is one of many problems in customer churn datasets. Another common problem is class overlap, where instances from different classes are similar to each other. Customer churn prediction becomes more challenging when the training data contain overlapping classes. In this research, we propose a hybrid sampling method based on tabular GANs, called CTGAN-ENN, to address class overlap and class imbalance in customer churn datasets. We used six different customer churn datasets from an open platform. CTGAN is a tabular GAN-based oversampling method that addresses class imbalance but suffers from class overlap. We combined CTGAN with the ENN under-sampling technique to overcome the class overlap. CTGAN-ENN reduced the degree of class overlap for each feature in all datasets. We investigated how effective CTGAN-ENN is with each machine learning technique. In our experiments, CTGAN-ENN achieved satisfactory customer churn prediction performance with KNN, GBM, XGB and LGB. We compared CTGAN-ENN with common oversampling and hybrid sampling methods, and CTGAN-ENN outperformed the other sampling methods, as well as algorithm-level methods based on cost-sensitive learning, with several machine learning algorithms. We also compared the algorithm execution time after CTGAN and CTGAN-ENN preprocessing; CTGAN-ENN required less time than CTGAN. Our work provides a new framework for handling customer churn prediction with several types of imbalanced datasets and can be useful for real-world customer churn prediction data.

Introduction

Customer churn prediction uses data analysis and predictive algorithms to identify which clients are most likely to leave a company or stop using its goods or services. By anticipating customer churn, a business can take proactive measures to retain key clients and optimize its marketing and retention strategies. Companies can increase revenue by accurately predicting client churn behavior and providing retention solutions [1].

Classifying customers is challenging because of class imbalance. Class imbalance denotes a situation where one or more classes are much more prevalent than the others [2]. Imbalanced classes can produce misleading performance metrics and increase false negative predictions. A false negative occurs when the model fails to identify a customer who is likely to churn. This can significantly undermine retention strategies, because customers who will churn are predicted as not churning.

Data-level approaches are one solution to the class imbalance problem [3]. Data-level methods operate at the preprocessing stage and are independent of the machine learning prediction method. Generative Adversarial Networks (GANs) are a type of neural-network-based generative model designed to generate realistic samples [4].

GAN-based oversampling has been used to address class imbalance; however, class imbalance is not the only difficulty in customer churn data. Class overlap is another issue that GANs are unable to resolve. Class overlap is the degree of similarity between instances of distinct classes. It makes it difficult to train a classifier that accurately distinguishes between the classes, and poor-quality training data can degrade machine learning performance [5].

The latest study on GAN-based hybrid sampling [2] overcomes class overlap and achieves the best results compared with other oversampling and hybrid sampling methods. We propose a research framework using a tabular GAN called CTGAN. Tabular data frequently contain categorical variables, numerical values, and specific relationships between columns, along with other organizational features. Tabular data have a specific structure and limitations that traditional GANs are not well adapted to handle, whereas tabular GANs can be configured to meet the difficulties posed by structured data, making them a valuable tool for tabular data generation tasks. Our experiments focus on customer churn problems, and we evaluate the method with both classical and ensemble machine learning. The objective is to observe differences in evaluation metrics and execution time across all algorithms. Besides data-level solutions, we compare CTGAN-ENN with an algorithm-level solution called cost-sensitive learning.

Our research objectives are highlighted as follows:

  • Customer churn prediction on datasets with high-dimensional features and different imbalance ratios.

  • How CTGAN-ENN can handle class imbalance and class overlap in customer churn datasets.

  • The effectiveness of CTGAN-ENN in different machine learning algorithms within several evaluation metrics in customer churn prediction.

  • Comparison of CTGAN-ENN performance with data-level and algorithm-level solutions.

  • Comparison of algorithm execution time between the original CTGAN and CTGAN-ENN in customer churn prediction.

The novelty of this paper is that it is the first work to investigate a combination of a tabular GAN (CTGAN) and an under-sampling technique (ENN). This work presents a framework for handling high-dimensional features in customer churn prediction with class imbalance and class overlap. Moreover, it evaluates the effectiveness of CTGAN-ENN in terms of both evaluation metrics and algorithm execution time. The scope of this paper is data-level solutions, because this study aims to show that a tabular GAN hybrid sampling method achieves better performance than non-tabular GAN hybrid sampling and classical hybrid sampling methods on customer churn prediction data. Furthermore, we provide a comparison between CTGAN-ENN and cost-sensitive learning with several machine learning algorithms.

Related works

Customer churn prediction

Machine learning algorithms such as K-nearest neighbor, naïve Bayes, and decision trees have been widely used as supervised learning methods for customer churn prediction in recent years. Classical machine learning methods often yield unsatisfactory performance, so ensemble approaches such as Random Forest, AdaBoost, Gradient Boosting, and XGBoost are also used [6]. The main problem in customer churn prediction is imbalanced data. One deep learning approach uses a Deep & Cross Network (DCN) to learn latent features for customer churn prediction and an Asymmetric Loss Function (ASL) to handle imbalanced data [1]. According to recent studies, machine learning and deep learning alone have not achieved outstanding results for customer churn prediction; another approach is the data-level solution. This approach works at the data preprocessing stage, using traditional oversampling techniques such as SMOTE and ADASYN, or hybrid sampling techniques that combine SMOTE with under-sampling methods [6, 7]. Traditional sampling methods have difficulty representing the real data in their synthetic samples, and deep learning-based sampling methods can be used to tackle this problem. One of the common methods is the Generative Adversarial Network (GAN), which produces synthetic data with deep learning and represents real data better than traditional oversampling methods [8]. The latest customer churn prediction studies use GAN-based hybrid sampling to overcome the class overlap that GAN oversampling produces; this method outperformed other sampling methods in AUC, F1-Score, and G-mean [2, 9]. The class overlap problem and hybrid sampling methods are described in the next subsections, Class overlap problem and Hybrid sampling methods.

Class overlap problem

Class overlap in machine learning makes it challenging to develop a reliable classifier that can differentiate between the classes. Factors contributing to class overlap include similar data, sparse distribution, and noise. Figure 1 shows oversampling results using CTGAN on all datasets. All datasets in this research have overlap issues, as is evident from the lack of clear separation between classes and the stacked data points of different classes.

Fig. 1 Class overlap in customer churn datasets

Hybrid sampling methods

In the last few years, there has been research on customer churn prediction and several studies on managing class imbalance. Sampling approaches, which are further classified into under-sampling, oversampling, and hybrid sampling, are well-known strategies for addressing class imbalance, and some studies have explored hybrid sampling. Sáez et al. [10] proposed an extension of SMOTE with a noise filter called the Iterative-Partitioning Filter (IPF) to handle not only class imbalance but also noisy data, another data distribution problem. Their results showed that the method performs better than existing techniques. Besides class imbalance, class overlap is another problem in customer churn datasets. Vuttipittayamongkol [11] proposed an NCR-based under-sampling architecture that eliminates potentially overlapping data to address class imbalance in binary datasets. The approach is centered on under-sampling, and the experimental results show significant improvements in sensitivity.

Geiler et al. [6] investigated effective strategies for churn prediction with SMOTE-based hybrid sampling in several machine learning algorithms, including ensemble methods. They compared SMOTE-random under-sampling, SMOTE-NCR, and SMOTE-Tomek links. In their experiments, SMOTE-NCR performed best on two datasets and SMOTE-Tomek links achieved the best AUC on one dataset. In another SMOTE-based hybrid method, Zhaozhao Xu et al. [12] combined M-SMOTE with ENN and evaluated it with an RF algorithm. The main objective of their work was to handle class imbalance in medical data with M-SMOTE, while ENN was used to remove misclassified data. Compared with similar approaches, the results on ten medical datasets were more encouraging, reaching an F1-score of 99.5%.

Besides traditional oversampling methods, there are neural-network-based oversampling methods such as the GAN (Generative Adversarial Network). Ding et al. [9] proposed RGAN-EL, which trains the GAN to prioritize the class overlapping region when fitting the sample distribution. They experimented with 41 imbalanced datasets, compared the proposed framework with several other oversampling methods, and found that RGAN-EL gave promising results. A limitation of this work is that it does not use high-dimensional data from complex fields: most of the datasets have fewer than ten dimensions.

Zhu et al. [2] developed a hybrid sampling method combining a GAN with adaptive neighborhood-based weighted under-sampling (ANWU) for customer classification datasets. ANWU handles class overlap by removing both generated instances and original majority-class instances. They used KNN, DT, RF and GBM in their experiments. Compared with existing benchmark approaches, the GAN-based hybrid sampling method performs better in accuracy and profit-based evaluation criteria.

Cost-sensitive learning

Cost-sensitive learning is a paradigm within machine learning that focuses on incorporating the varying costs associated with different types of errors or misclassifications into the learning process. In traditional machine learning algorithms, the objective is typically to minimize overall classification error without considering the potential differences in the consequences or costs of different types of misclassifications.

In cost-sensitive learning, costs are represented as a cost matrix, as shown in Fig. 2 [13]. The cost of a false positive is \({C}_{10}\) and the cost of a false negative is \({C}_{01}\); in a classification problem, it is important to set these costs according to the objective of the prediction. For example, when building a model that predicts customer churn, the priority is to avoid false negatives, that is, customers who are predicted not to churn but actually churn. Therefore, assigning \({C}_{01}\) a higher cost than the other entries of the matrix is one algorithm-level solution.
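As a minimal worked example of how the cost matrix changes decisions, the sketch below applies the standard minimum-expected-cost rule, assuming that correct predictions cost nothing: predicting churn is chosen when \((1-p)\,C_{10} < p\,C_{01}\), which is equivalent to \(p > C_{10}/(C_{10}+C_{01})\). The function names are hypothetical, and the default costs simply mirror the 1 and 10 weights reported for the cost-sensitive experiments; this is an illustration, not the authors' implementation.

```python
def expected_costs(p_churn, c10, c01):
    """Expected cost of each possible prediction for one customer.

    p_churn: estimated probability that the customer churns (class 1).
    c10: cost of a false positive (predict churn, customer stays).
    c01: cost of a false negative (predict no churn, customer churns).
    Correct predictions are assumed to cost nothing.
    Returns (cost of predicting 0, cost of predicting 1).
    """
    cost_predict_0 = p_churn * c01          # wrong only if the customer churns
    cost_predict_1 = (1.0 - p_churn) * c10  # wrong only if the customer stays
    return cost_predict_0, cost_predict_1


def cost_sensitive_predict(p_churn, c10=1.0, c01=10.0):
    """Predict churn (1) whenever that choice has the lower expected cost."""
    cost_0, cost_1 = expected_costs(p_churn, c10, c01)
    return 1 if cost_1 < cost_0 else 0


def decision_threshold(c10, c01):
    """Churn probability above which predicting churn is the cheaper choice."""
    return c10 / (c10 + c01)


print(decision_threshold(1.0, 10.0))  # about 0.0909 rather than the usual 0.5
print(cost_sensitive_predict(0.2))    # 1: a 20% churn risk is already worth flagging
print(cost_sensitive_predict(0.05))   # 0
```

Raising \(C_{01}\) relative to \(C_{10}\) lowers the threshold, which is exactly the behavior the text asks for when false negatives are the main concern.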

Fig. 2 Cost-sensitive learning cost matrix

Machine learning algorithm

After applying CTGAN-ENN, we used two types of machine learning for the classification task in customer churn prediction. The first type is classical machine learning: KNN, DT, and NB. The second type is ensemble machine learning: XGB, RF, GBM, and LGB. These algorithms are described in this section.

K-nearest neighbor (KNN)

K-nearest neighbor is a case-based learning approach that retains all the training data for classification [14]. KNN is a well-known machine learning method that can be used for both regression and classification. Because the approach is instance-based and non-parametric, it makes no assumptions about how the data are distributed. The idea behind KNN is that similar data points should lead to similar outcomes. In other words, to classify or predict a new data point, KNN uses the labels (for classification) or values (for regression) of the K nearest data points (neighbors) in the training dataset.
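The voting idea can be illustrated with a small pure-Python sketch. It assumes Euclidean distance and an unweighted majority vote; the distance metric, K, and other settings actually used in the experiments are those in Table 3, and the helper names here are illustrative.

```python
import math
from collections import Counter


def euclidean(a, b):
    """Euclidean distance between two equal-length numeric sequences."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def knn_predict(train_X, train_y, query, k=3):
    """Classify `query` by majority vote among its k nearest training points."""
    # Rank training indices by their distance to the query point.
    order = sorted(range(len(train_X)), key=lambda i: euclidean(train_X[i], query))
    neighbor_labels = [train_y[i] for i in order[:k]]
    # most_common(1) returns [(label, count)]; on a tie, the label seen first wins.
    return Counter(neighbor_labels).most_common(1)[0][0]


X = [[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5]]
y = [0, 0, 0, 1, 1, 1]
print(knn_predict(X, y, [0.5, 0.5]))  # 0
print(knn_predict(X, y, [5.5, 5.5]))  # 1
```

Because no model is fitted, all the work happens at prediction time, which is why KNN's cost depends directly on how many training rows remain after sampling.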

Decision tree (DT)

A decision tree classifier represents a recursive partition of the instance space. It consists of nodes that divide the instance space into two or more subspaces based on a discrete function of the input attributes. The first split occurs at the root node, which covers the entire instance space.

The subsequent nodes are either leaf nodes, which represent the final classification, or internal nodes, which have one predecessor and multiple successors [15]. In classification, a new data point that follows the path to a leaf node is predicted to belong to the majority class in that leaf. In regression, the leaf value is usually the mean or median of the target variable in that node.
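To illustrate how a node chooses its split, the sketch below searches every feature and threshold for the binary split with the lowest weighted Gini impurity. The criterion is an assumption (it is scikit-learn's default for `DecisionTreeClassifier`); the text does not state which criterion the experiments used.

```python
def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def best_split(X, y):
    """Find the (feature, threshold) pair that minimizes weighted Gini impurity.

    Rows with X[i][f] <= threshold go to the left child, the rest to the right.
    Returns (feature_index, threshold, weighted_impurity); (None, None, parent
    impurity) means no split improves on the parent node.
    """
    n = len(y)
    best = (None, None, gini(y))
    for f in range(len(X[0])):
        for threshold in sorted({row[f] for row in X}):
            left = [y[i] for i in range(n) if X[i][f] <= threshold]
            right = [y[i] for i in range(n) if X[i][f] > threshold]
            if not left or not right:
                continue  # skip splits that leave a child empty
            score = (len(left) * gini(left) + len(right) * gini(right)) / n
            if score < best[2]:
                best = (f, threshold, score)
    return best


X = [[1, 7], [2, 3], [3, 8], [4, 2]]
y = [0, 0, 1, 1]
print(best_split(X, y))  # (0, 2, 0.0): feature 0 at threshold 2 separates the classes
```

Applying `best_split` recursively to each child until the leaves are pure (or a depth limit is reached) yields the recursive partition described above.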

Naïve Bayes (NB)

The Naïve Bayes algorithm uses a probabilistic approach to classification. It assumes that the presence of a feature in a class is independent of the presence of any other feature in the same class [16]. The classifier is based on Bayes' theorem combined with this "naïve" assumption of feature independence, which simplifies the computations but might not hold in all real-world situations.
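A minimal Gaussian Naïve Bayes sketch follows. Modelling each continuous feature with a per-class normal distribution is an assumption for illustration; the independence assumption enters where the per-feature likelihoods are multiplied (summed in log space).

```python
import math
from collections import defaultdict


def fit_gaussian_nb(X, y):
    """Estimate class priors and per-feature means/variances for each class."""
    by_class = defaultdict(list)
    for row, label in zip(X, y):
        by_class[label].append(row)
    model = {}
    for label, rows in by_class.items():
        n = len(rows)
        columns = list(zip(*rows))
        means = [sum(col) / n for col in columns]
        # A tiny constant keeps the variance positive for constant features.
        variances = [sum((v - m) ** 2 for v in col) / n + 1e-9
                     for col, m in zip(columns, means)]
        model[label] = (n / len(y), means, variances)
    return model


def predict_gaussian_nb(model, x):
    """Return the class with the largest log posterior (up to a shared constant)."""
    def log_posterior(label):
        prior, means, variances = model[label]
        score = math.log(prior)
        # Naive independence: the joint likelihood is a product of per-feature terms.
        for v, m, var in zip(x, means, variances):
            score += -0.5 * math.log(2 * math.pi * var) - (v - m) ** 2 / (2 * var)
        return score
    return max(model, key=log_posterior)


X = [[1.0, 2.0], [1.2, 1.8], [0.8, 2.2], [5.0, 6.0], [5.2, 5.8], [4.8, 6.2]]
y = [0, 0, 0, 1, 1, 1]
model = fit_gaussian_nb(X, y)
print(predict_gaussian_nb(model, [1.0, 2.0]))  # 0
print(predict_gaussian_nb(model, [5.0, 6.0]))  # 1
```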

XGBoost (XGB)

XGBoost is a scalable tree-boosting system that uses a sparsity-aware algorithm and a weighted quantile sketch for approximate tree learning; its design also takes cache access patterns and data compression into account [17]. XGBoost starts with a weak learner, a simple model that makes initial predictions about the data, and then iteratively adds new trees, each trained to correct the errors of the model built so far.
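XGBoost's tree learning is driven by first- and second-order gradients. For a node with gradient sum \(G\) and Hessian sum \(H\), the XGBoost paper gives the optimal leaf weight \(w^{*}=-G/(H+\lambda)\) and the split gain \(\tfrac{1}{2}\big[G_L^2/(H_L+\lambda)+G_R^2/(H_R+\lambda)-(G_L+G_R)^2/(H_L+H_R+\lambda)\big]-\gamma\). The sketch evaluates these quantities for the logistic loss; \(\lambda=1\) and \(\gamma=0\) are illustrative defaults, not the settings used in the experiments.

```python
import math


def logistic_grad_hess(y_true, raw_score):
    """Gradient and Hessian of log loss with respect to the raw (logit) score."""
    p = 1.0 / (1.0 + math.exp(-raw_score))
    return p - y_true, p * (1.0 - p)


def leaf_weight(G, H, lam=1.0):
    """Optimal leaf value w* = -G / (H + lambda)."""
    return -G / (H + lam)


def split_gain(GL, HL, GR, HR, lam=1.0, gamma=0.0):
    """Reduction in the regularized objective from splitting a node in two."""
    def score(G, H):
        return G * G / (H + lam)
    return 0.5 * (score(GL, HL) + score(GR, HR) - score(GL + GR, HL + HR)) - gamma


# Every row starts at raw score 0 (p = 0.5). Two churners go left, two stayers go right.
g_churn, h_churn = logistic_grad_hess(1, 0.0)  # (-0.5, 0.25)
g_stay, h_stay = logistic_grad_hess(0, 0.0)    # (0.5, 0.25)
GL, HL = 2 * g_churn, 2 * h_churn
GR, HR = 2 * g_stay, 2 * h_stay
print(split_gain(GL, HL, GR, HR))  # positive gain: the split is worth making
print(leaf_weight(GL, HL))         # positive: pushes churners' scores up
```

A positive gain (after subtracting \(\gamma\)) is what makes the new tree worth adding; the leaf weights then move each group's score in the direction that reduces its loss.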

Random forest (RF)

Random forest is a bagging technique: it constructs several models on different data subsets and then aggregates their predictions into a final prediction. A random forest combines a set of decision trees, each grown on a randomly selected subset of the data, into a prediction ensemble. Each tree is built by randomly selecting, at each node, a subset of the input coordinates (hereafter referred to as features or variables), and the training set determines the best split based on these features [18].
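The two sources of randomness, bootstrap sampling of rows and random selection of candidate features, can be sketched with one-split trees (stumps) as base learners. Using stumps is a simplification: because a stump has only one node, drawing the feature subset once per tree is the same as drawing it at that node. All names and parameter values are illustrative.

```python
import random
from collections import Counter


def majority(labels):
    """Most frequent label (first seen wins on ties)."""
    return Counter(labels).most_common(1)[0][0]


def fit_stump(X, y, features):
    """Fit a one-split tree on the candidate features, minimizing misclassifications."""
    best = None
    for f in features:
        for t in sorted({row[f] for row in X}):
            left = [lab for row, lab in zip(X, y) if row[f] <= t]
            right = [lab for row, lab in zip(X, y) if row[f] > t]
            if not left or not right:
                continue
            l_lab, r_lab = majority(left), majority(right)
            errors = sum(lab != l_lab for lab in left) + sum(lab != r_lab for lab in right)
            if best is None or errors < best[0]:
                best = (errors, f, t, l_lab, r_lab)
    if best is None:  # no valid split (e.g. a constant feature): predict the majority
        label = majority(y)
        return lambda row: label
    _, f, t, l_lab, r_lab = best
    return lambda row: l_lab if row[f] <= t else r_lab


def fit_forest(X, y, n_trees=25, max_features=1, seed=0):
    """Train stumps on bootstrap samples, each restricted to random features."""
    rng = random.Random(seed)
    n, d = len(X), len(X[0])
    trees = []
    for _ in range(n_trees):
        idx = [rng.randrange(n) for _ in range(n)]     # bootstrap sample, with replacement
        features = rng.sample(range(d), max_features)  # random subset of features
        trees.append(fit_stump([X[i] for i in idx], [y[i] for i in idx], features))
    return trees


def predict_forest(trees, row):
    """Aggregate the trees' predictions by majority vote."""
    return majority([tree(row) for tree in trees])


X = [[0, 0], [1, 1], [2, 1], [8, 9], [9, 8], [10, 10]]
y = [0, 0, 0, 1, 1, 1]
forest = fit_forest(X, y)
print(predict_forest(forest, [0, 0]), predict_forest(forest, [10, 10]))
```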

Gradient boosting machine (GBM)

GBM is an ensemble model built by forward stage-wise learning, in which many weak predictors are combined into a strong one. This decision-tree-based method compares each successive tree with the others, using the structure score, gain computations, and increasingly fine approximations to produce a set of the most suitable tree structures [19]. To obtain a more accurate and complete model, GBM combines the predictions of several weak models, typically decision trees. It does this by training a sequence of decision trees, each intended to correct the errors made by the ones before it.
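The stage-wise error-correcting idea can be sketched as follows. The sketch assumes squared-error loss on a single numeric feature, so each stump is fitted to the current residuals (the negative gradient); the churn experiments use classification losses, and the function names and settings here are hypothetical.

```python
def fit_regression_stump(x, residuals):
    """Best single threshold on 1-D inputs for predicting residuals (squared error)."""
    best = None
    for t in sorted(set(x))[:-1]:  # the largest value would leave the right side empty
        left = [r for xi, r in zip(x, residuals) if xi <= t]
        right = [r for xi, r in zip(x, residuals) if xi > t]
        lm, rm = sum(left) / len(left), sum(right) / len(right)
        sse = sum((r - lm) ** 2 for r in left) + sum((r - rm) ** 2 for r in right)
        if best is None or sse < best[0]:
            best = (sse, t, lm, rm)
    _, t, lm, rm = best
    return lambda xi: lm if xi <= t else rm


def gradient_boost(x, y, n_rounds=50, learning_rate=0.1):
    """Stage-wise boosting: each new stump is fitted to the current residuals."""
    base = sum(y) / len(y)  # initial model: the mean target
    stumps = []
    pred = [base] * len(y)
    for _ in range(n_rounds):
        # For squared loss, the negative gradient is simply y - prediction.
        residuals = [yi - pi for yi, pi in zip(y, pred)]
        stump = fit_regression_stump(x, residuals)
        stumps.append(stump)
        pred = [pi + learning_rate * stump(xi) for pi, xi in zip(pred, x)]
    return lambda xi: base + learning_rate * sum(s(xi) for s in stumps)


x = [1, 2, 3, 4, 5, 6]
y = [0, 0, 0, 1, 1, 1]
model = gradient_boost(x, y)
print(round(model(1), 3), round(model(6), 3))  # close to 0 and 1
```

The learning rate shrinks each correction, so the ensemble approaches the targets gradually instead of letting a single tree fit them at once.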

Light gradient-boosting machine (LGB)

Light Gradient-Boosting Machine (LightGBM) is an algorithm based on GBDT (Gradient Boosting Decision Tree) that aims to reduce the number of features considered by using their information gain [20]. LightGBM uses a parallel voting decision tree algorithm, which is designed to maximize parallel learning by reducing memory usage, accelerating training, and making efficient use of network communication. The training data are partitioned among several machines; in each iteration, local voting selects the top-k attributes on each machine, and global voting selects the top-2k attributes [20].
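The two voting stages described above can be sketched as follows, assuming each machine has already computed a local information gain for every feature. The feature names and gain values are hypothetical, and the later step in which the statistics of the selected candidate features are merged to choose the final split is omitted.

```python
from collections import Counter


def local_top_k(local_gains, k):
    """The k features with the highest information gain on one machine's data."""
    return sorted(local_gains, key=local_gains.get, reverse=True)[:k]


def global_vote(all_local_gains, k):
    """Count the local top-k votes and keep (up to) the 2k most-voted features."""
    votes = Counter()
    for gains in all_local_gains:
        votes.update(local_top_k(gains, k))
    # most_common orders by vote count; ties keep first-seen order.
    return [feature for feature, _ in votes.most_common(2 * k)]


machines = [
    {"tenure": 0.9, "charges": 0.5, "age": 0.1, "region": 0.05},
    {"tenure": 0.8, "charges": 0.2, "age": 0.6, "region": 0.1},
    {"tenure": 0.7, "charges": 0.4, "age": 0.3, "region": 0.35},
]
print(global_vote(machines, k=2))  # ['tenure', 'charges', 'age']
```

Only the winning candidates' statistics need to be exchanged between machines, which is where the communication savings come from.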

Summary

We reviewed several studies on hybrid sampling methods. Most recent work has used traditional oversampling techniques such as SMOTE [6, 10, 12], which can produce oversampled data with limited diversity. Neural-network-based oversampling can address this lack of diversity, and the most common neural network technique for oversampling is the Generative Adversarial Network (GAN). A recent study proposed GAN-based hybrid sampling to overcome the class overlap problem of GANs. However, the latest work on GAN-based hybrid sampling has some limitations. First, it does not use a tabular GAN for tabular data problems [2, 9]. Second, it does not measure algorithm execution time, which can provide useful insight for real-world implementation.

In this research, we propose a framework that includes a GAN-based hybrid sampling scheme, using ENN under-sampling to address class overlap and a tabular GAN (CTGAN) to handle tabular data. Our work focuses on customer churn prediction problems and includes execution-time analysis in the experiments. We hope this method will be useful in real-world cases and for any company that wants to use churn prediction in its campaign strategies.

Methodology

As shown in Fig. 3, our proposed research framework has two phases, the first of which is data preprocessing using the CTGAN-ENN approach. The input imbalanced dataset is divided into minority and majority classes. The minority class is oversampled by CTGAN [21], which generates data until the minority class is balanced with the majority class. The next step concatenates the majority and generated data.

Fig. 3 Proposed research framework

CTGAN addresses several challenges that traditional GAN models do not. The first is mixed data types, because real-world data contain a mix of discrete and continuous columns. The second is highly imbalanced categorical columns [21]; the customer churn prediction datasets used in this research have high imbalance ratios.

Figure 4 shows the CTGAN workflow. Assume the data have two features, \({D}_{1}\) and \({D}_{2}\), and we want to oversample with respect to \({D}_{2}\) and pick \({D}_{2}=1\). All rows in which \({D}_{2}\) equals 1 are selected as the training data \({T}_{train}\). The generator's output is compared with \({T}_{train}\) to produce a critic score, which estimates the distance between the learned conditional distribution and the conditional distribution of the real data. This process is called training-by-sampling, and it works better than the random sampling used in traditional GANs.
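The conditional draw at the heart of training-by-sampling can be sketched as follows, following the CTGAN paper's description: pick a discrete column uniformly, pick one of its categories with probability based on the logarithm of its frequency, build a one-hot conditional vector, and fetch a real row that satisfies the condition. Adding 1 inside the logarithm (so that categories seen once keep a non-zero probability) is an assumption of this sketch, and the generator, critic, and continuous-column handling are omitted.

```python
import math
import random


def training_by_sampling(rows, discrete_columns, rng):
    """One simplified CTGAN-style conditional draw.

    rows: list of dicts, one per record; discrete_columns: names of discrete columns.
    Returns (cond_vector, real_row): cond_vector one-hot encodes the chosen
    (column, category) pair across all discrete columns, and real_row is a
    training record that satisfies the same condition.
    """
    # 1. Pick a discrete column uniformly at random.
    column = rng.choice(discrete_columns)
    # 2. Pick a category with probability proportional to log(count + 1),
    #    which gives rare categories more weight than their raw frequency.
    counts = {}
    for row in rows:
        counts[row[column]] = counts.get(row[column], 0) + 1
    categories = sorted(counts)
    weights = [math.log(counts[c] + 1) for c in categories]
    category = rng.choices(categories, weights=weights, k=1)[0]
    # 3. Conditional vector: one slot per (column, category) pair.
    cond_vector = []
    for col in discrete_columns:
        for c in sorted({row[col] for row in rows}):
            cond_vector.append(1 if (col == column and c == category) else 0)
    # 4. Real rows matching the condition play the role of T_train in the text;
    #    the critic compares generated rows with such rows under the same condition.
    matching = [row for row in rows if row[column] == category]
    return cond_vector, rng.choice(matching)


rows = [{"D1": a, "D2": b} for a, b in [(0, 0), (1, 0), (0, 0), (1, 0), (0, 1), (1, 1)]]
cond, real_row = training_by_sampling(rows, ["D1", "D2"], random.Random(0))
print(cond, real_row)
```

In the example, \(D_2=1\) appears in only two of six rows, yet the log-frequency weighting still selects it regularly, so the generator sees the rare condition during training.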

Fig. 4 CTGAN original framework

The ENN technique handles class overlap effectively by removing instances. ENN examines the k nearest neighbors of each instance and selectively removes instances whose class differs from the majority class of their neighbors [11]. The details of the ENN algorithm are presented in Algorithm 1. The input is the concatenation of the CTGAN-generated data and the original data \(D\); the number of nearest neighbors is \(k=3\), the majority samples \({D}_{maj}\) depend on the dataset, and the initial under-sampling rate is \(R=1\).
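A minimal ENN sketch with \(k=3\) follows. It uses the majority-vote rule of Wilson's original ENN; note that the `EditedNearestNeighbours` class in imbalanced-learn defaults to a stricter rule (`kind_sel="all"`) in which all neighbours must agree. The optional `target_classes` argument, which restricts which classes can be removed, is a hypothetical parameter standing in for the choice of which samples (for example \({D}_{maj}\)) are under-sampled; the under-sampling rate \(R\) is not modelled.

```python
import math
from collections import Counter


def edited_nearest_neighbours(X, y, k=3, target_classes=None):
    """Majority-vote ENN: drop samples whose label disagrees with their k neighbours.

    Only samples whose class is in `target_classes` may be removed (all classes
    if None). Returns the indices of the samples that are kept.
    """
    def dist(a, b):
        return math.sqrt(sum((p - q) ** 2 for p, q in zip(a, b)))

    keep = []
    for i, xi in enumerate(X):
        if target_classes is not None and y[i] not in target_classes:
            keep.append(i)  # this class is protected from removal
            continue
        # k nearest neighbours of sample i, excluding sample i itself.
        neighbours = sorted((j for j in range(len(X)) if j != i),
                            key=lambda j: dist(xi, X[j]))[:k]
        neighbour_majority = Counter(y[j] for j in neighbours).most_common(1)[0][0]
        if neighbour_majority == y[i]:
            keep.append(i)
    return keep


X = [[0, 0], [0, 1], [1, 0], [1, 1], [0.5, 0.5], [5, 5], [5, 6], [6, 5]]
y = [0, 0, 0, 0, 1, 1, 1, 1]
print(edited_nearest_neighbours(X, y))  # [0, 1, 2, 3, 5, 6, 7]: the overlapping point 4 is removed
```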

Algorithm 1 Edited Nearest Neighbor (ENN)

The under-sampling method removes overlapping instances from the concatenated data. The result of this stage is the final customer churn dataset, which is ready for machine learning prediction. The effectiveness of the CTGAN-ENN method is reported in the Experiment design and result analysis section.

After the final datasets are formed, the second phase of our framework applies two types of machine learning to the customer churn datasets. The first type uses KNN, Decision Trees, and Naïve Bayes, while the second type consists of ensemble models: XGBoost, Random Forest, GBM and LightGBM. To evaluate the models, we used four criteria: AUC-ROC, G-Mean, F1-Score, and algorithm execution time.

Algorithm 2 Pseudo-code of Proposed Framework

Algorithm 2 presents the pseudo-code of the proposed framework, which addresses imbalanced datasets in machine learning through phases for data preprocessing, augmentation, training, and evaluation. To address the underlying class imbalance, the dataset is first split into majority and minority classes. A crucial next step is generating synthetic minority instances with CTGAN, which uses generative adversarial networks to mimic the distribution of the minority class. This augmentation gives the model a more balanced representation of the data and mitigates class imbalance. After augmentation, the original majority and minority subsets are combined with the synthetic minority data to create a consolidated dataset. The approach then uses the Edited Nearest Neighbors (ENN) algorithm to refine the dataset further. By removing noisy or borderline instances, ENN enhances the discriminative power of the final dataset. This cleaned dataset (called \(finalDataset_{m}\)) is the basis for subsequent model training and assessment.
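The preprocessing phase of Algorithm 2 can be sketched end to end as follows. The generator is a hypothetical stand-in (`jitter_generator`, which copies minority rows with small Gaussian noise) so that the example runs without a trained model; with the open-source `ctgan` package, it would typically be replaced by fitting a `CTGAN` model on the minority rows and calling its `sample` method. ENN is applied to all classes with a majority-vote rule, which is an assumption rather than the exact configuration of Algorithm 1.

```python
import math
import random
from collections import Counter


def enn_keep(X, y, k=3):
    """Indices kept by majority-vote ENN: samples that agree with their k neighbours."""
    def dist(a, b):
        return math.sqrt(sum((p - q) ** 2 for p, q in zip(a, b)))
    keep = []
    for i, xi in enumerate(X):
        nn = sorted((j for j in range(len(X)) if j != i), key=lambda j: dist(xi, X[j]))[:k]
        if Counter(y[j] for j in nn).most_common(1)[0][0] == y[i]:
            keep.append(i)
    return keep


def ctgan_enn_pipeline(X, y, generate_minority, minority_label=1, k=3):
    """Preprocessing phase: split, oversample the minority class, concatenate, clean.

    generate_minority(minority_rows, n) is a placeholder for a fitted CTGAN
    sampler that returns n synthetic minority rows.
    """
    majority = [(x, lab) for x, lab in zip(X, y) if lab != minority_label]
    minority = [x for x, lab in zip(X, y) if lab == minority_label]
    # Generate enough synthetic rows to match the majority-class size.
    synthetic = list(generate_minority(minority, max(len(majority) - len(minority), 0)))
    # Concatenate the original majority, original minority and synthetic minority rows.
    X_all = [x for x, _ in majority] + minority + synthetic
    y_all = [lab for _, lab in majority] + [minority_label] * (len(minority) + len(synthetic))
    # ENN removes samples whose label disagrees with their neighbourhood.
    keep = enn_keep(X_all, y_all, k)
    return [X_all[i] for i in keep], [y_all[i] for i in keep]


def jitter_generator(rows, n, seed=0):
    """Hypothetical stand-in for CTGAN: copies random minority rows with small noise."""
    rng = random.Random(seed)
    return [[v + rng.gauss(0, 0.1) for v in rng.choice(rows)] for _ in range(n)]


X = [[0, 0], [0, 1], [1, 0], [1, 1], [0.5, 0.2], [0.2, 0.5], [5, 5], [5.2, 5.1], [5.1, 5.0]]
y = [0, 0, 0, 0, 0, 0, 1, 1, 0]  # the last (majority) row sits inside the churn cluster
X_final, y_final = ctgan_enn_pipeline(X, y, jitter_generator)
print(len(y_final), y_final.count(0), y_final.count(1))
```

In this toy run, oversampling first balances the classes at seven rows each, and ENN then drops the majority row that lies inside the churn cluster, which is the overlap-removal step the framework relies on.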

Next, ensemble learning (\(ensembleML_{m}\)) and classical machine learning (\(classicalML_{m}\)) approaches are applied to the balanced final dataset. This dual approach lets us verify the performance of the proposed method across different modeling paradigms. Finally, the trained models are assessed with the Area Under the Curve (AUC), F-measure, and G-mean. AUC represents overall discriminative power, the F-measure reflects the trade-off between precision and recall, and G-mean evaluates the model's ability to handle class imbalance; together, these metrics show how well the models discriminate between classes. The last metric is algorithm execution time, which reflects computational efficiency and provides insight into the algorithm's scalability and practical applicability.

Experiment design and result analysis

Data description

This research uses six datasets published on the Kaggle platform: telecommunication 1 (telco 1) [22], bank [23], mobile [24], telecommunication 2 (telco 2) [25], telecommunication 3 (telco 3) [26], and insurance [27]. All datasets are imbalanced, with imbalance ratios ranging from 2.7 to an extreme 7.5. This is one of the factors that make it difficult for classifiers to achieve satisfactory results. An overview of the datasets is shown in Table 1 below.
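The imbalance ratio is assumed here to be the usual number of majority (non-churn) samples divided by the number of minority (churn) samples, which matches the range of values quoted above. A minimal sketch with hypothetical counts:

```python
def imbalance_ratio(labels, minority_label=1):
    """IR = number of majority-class samples / number of minority-class samples."""
    n_minority = sum(1 for lab in labels if lab == minority_label)
    n_majority = len(labels) - n_minority  # assumes a binary (churn / no churn) target
    return n_majority / n_minority


# Hypothetical counts: 7,500 customers who stay and 1,000 who churn.
print(imbalance_ratio([0] * 7500 + [1] * 1000))  # 7.5
```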

Table 1 Datasets overview

Experiment settings

All experiments in this research were performed using available Python packages. Table 2 shows the techniques, packages, and parameters used in the experiments. To demonstrate the efficacy of our framework, we compared CTGAN-ENN (CE) with the conventional hybrid sampling techniques SMOTE-ENN (SE) and ADASYN-ENN (AE), and with another GAN-based hybrid sampling technique, WGAN-GP+ENN (WE).

Table 2 Packages and parameters for the sampling method

As a preprocessing step, we replaced null or missing values with the mean of the observed values. Because the mean fills in continuous values without introducing outliers, it is a suitable replacement for missing values. The experiments use fivefold cross-validation for training and testing; k-fold cross-validation gives stable evaluation results because every subset of the data is used for both training and testing.
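The preprocessing and validation steps can be sketched as follows: column-wise mean imputation, followed by a plain five-fold split in which every sample appears in exactly one test fold. Whether the experiments used stratified folds is not stated. Missing values are represented as `None`, and the fallback to 0.0 for a column with no observed values is an assumption; in practice, library utilities such as scikit-learn's `SimpleImputer` and `KFold` serve the same purpose.

```python
import random


def mean_impute(rows):
    """Replace None entries in each column with that column's mean of observed values."""
    n_cols = len(rows[0])
    means = []
    for c in range(n_cols):
        observed = [row[c] for row in rows if row[c] is not None]
        # Assumed fallback: 0.0 if a column has no observed values at all.
        means.append(sum(observed) / len(observed) if observed else 0.0)
    return [[means[c] if row[c] is None else row[c] for c in range(n_cols)] for row in rows]


def k_fold_indices(n, k=5, seed=0):
    """Yield (train_idx, test_idx) pairs; each sample is tested exactly once."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]
    for i in range(k):
        test = sorted(folds[i])
        train = sorted(j for f, fold in enumerate(folds) if f != i for j in fold)
        yield train, test


print(mean_impute([[1.0, None], [3.0, 4.0], [None, 6.0]]))  # [[1.0, 5.0], [3.0, 4.0], [2.0, 6.0]]
for train, test in k_fold_indices(10, k=5):
    print(len(train), len(test))  # 8 2 for every fold
```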

After the customer churn data are preprocessed with the hybrid sampling method, the next stage is the classification task. Table 3 lists the models, packages, and parameters used for classical and ensemble machine learning.

Table 3 Packages, functions, and parameters for machine learning method in data-level experiment

In the experiments, we also applied cost-sensitive learning to the Decision Tree, XGBoost, Random Forest and Light Gradient Boosting algorithms, in order to compare CTGAN-ENN not only with data-level solutions but also with algorithm-level solutions. Table 4 shows the packages and parameters used in this experiment, which was implemented with the scikit-learn library.

Table 4 Packages, functions, and parameters for machine learning method in algorithm-level experiment

Evaluation metrics

Three evaluation metrics were used in the experiment, each with a distinct purpose. The first metric, the F1 score, is the harmonic mean of precision and recall. It provides a balanced assessment of a model's efficacy by considering both false positives and false negatives.

$$F1\text{-}Score=2\cdot\frac{PR\cdot RC}{PR+RC}\qquad(1)$$

\(RC\) stands for recall, the ratio of correctly predicted positive observations to the total number of actual positives. It measures the model's ability to recognize every instance of the positive class. \(PR\) stands for precision, the ratio of correctly predicted positive observations to the total number of predicted positives. It measures how reliable the model's positive predictions are.

The second metric is the Area Under the Receiver Operating Characteristic Curve (AUC-ROC). The ROC curve plots the true positive rate against the false positive rate, and AUC denotes the area under this curve. It summarizes a model's overall performance across different classification thresholds.

The third metric is G-mean, the geometric mean of recall and specificity, which provides a balanced assessment of a model's performance [2]. Because it combines both rates, G-mean is high only when the model identifies positive examples while also avoiding false positives, regardless of the class distribution.

$$G\text{-}mean=\sqrt{RC\cdot SP}\qquad(2)$$

Specificity, represented by \(SP\), is a gauge of how well the model can identify negative cases. Rather than dividing the training and testing data at random, we employed the k-fold cross-validation approach with the number of k = 5.
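Equations (1) to (3), together with specificity and AUC, can be computed from predictions as follows. The AUC uses its rank interpretation (the probability that a randomly chosen churner receives a higher score than a randomly chosen non-churner, with ties counted as one half), which equals the area under the ROC curve. Returning 0.0 when a ratio is undefined is an assumption of this sketch, and the labels and scores below are hypothetical.

```python
import math


def confusion(y_true, y_pred):
    """Counts (tp, tn, fp, fn) with class 1 (churn) as the positive class."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    return tp, tn, fp, fn


def f1_and_gmean(y_true, y_pred):
    tp, tn, fp, fn = confusion(y_true, y_pred)
    pr = tp / (tp + fp) if tp + fp else 0.0          # precision (PR)
    rc = tp / (tp + fn) if tp + fn else 0.0          # recall (RC), Eq. (3)
    sp = tn / (tn + fp) if tn + fp else 0.0          # specificity (SP)
    f1 = 2 * pr * rc / (pr + rc) if pr + rc else 0.0  # Eq. (1)
    return f1, math.sqrt(rc * sp)                     # Eq. (2)


def roc_auc(y_true, scores):
    """AUC as the probability that a random positive outranks a random negative."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 0, 0, 0, 1, 0]
scores = [0.9, 0.8, 0.4, 0.3, 0.2, 0.1, 0.85, 0.35]
print(f1_and_gmean(y_true, y_pred), roc_auc(y_true, scores))
```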

F1-score result

Tables 5, 6, 7, 8 and 9 present the results on each dataset for the different evaluation metrics, with the best value marked in bold. Table 5 shows that CTGAN-ENN (CE) achieved the best F1-Score in 21 of 42 scenarios.

Table 5 Experimental result on F1-Score
Table 6 Experimental result on AUC
Table 7 Experimental result on G-Mean
Table 8 Result of Accuracy in Minority Class (%)
Table 9 Cost-Sensitive Learning Result

WE and SE each ranked first in 6 scenarios for F1-Score. WE performed well with the Decision Tree and Random Forest algorithms on two datasets, while SE performed well with all machine learning algorithms but only on the telco 1 dataset. Our proposed framework outperformed the other sampling methods, especially with KNN, GBM, XGB and LGB.

AUC-ROC result

Table 6 shows that CE obtained the best AUC in 29 of 42 scenarios. It outperformed the other sampling methods on all datasets except with the KNN and NB algorithms. Table 7 confirms that our proposed framework performs well, reaching the best G-mean in 29 of 42 scenarios. Interestingly, WE consistently performed well with the DT algorithm across all evaluation metrics.

We also found that CE worked well with ensemble machine learning in all scenarios, performing exceptionally well with GBM, XGB, RF and LGB. Table 11 shows the mean ranking across all scenarios and confirms that CE achieved the best mean ranking with the ensemble machine learning algorithms.

G-Mean result

The machine learning algorithms fall into two types: KNN, DT, and NB are classical models, whereas GBM, XGB, RF, and LGB are ensemble models. The experimental results show the effectiveness of our proposed framework. For the ensemble models, combining CTGAN with ENN (CE) gave only small improvements over the original CTGAN, whereas for the classical models KNN, DT, and NB the improvement was larger. This is likely because ensemble models have robust classification ability and achieve satisfactory results even without a sampling method.

Accuracy of minority class

Minority class accuracy refers to accuracy calculated specifically for the minority class; in this study, the minority class in all datasets is the churn class. Because the minority class is the positive class, minority class accuracy was measured by recall, as shown in Eq. (3) below.

$$recall=\frac{true\,positive}{true\,positive+false\,negative}\qquad(3)$$

Table 8 shows the minority class accuracy as a percentage, shedding light on the efficacy of CTGAN-ENN in addressing the challenges of imbalanced datasets. CTGAN-ENN performed best in 19 of 42 scenarios, followed by SE, which achieved the best result in 10 of 42 scenarios.

CTGAN-ENN and algorithm-level comparison

CTGAN-ENN achieved better results than all algorithm-level experiments. As shown in Table 9, we applied cost-sensitive (CS) learning to Decision Tree (DT), Random Forest (RF), Light Gradient-Boosting Machine (LGB) and XGBoost (XGB). We used a cost of 1 for class 0 (not churn) and 10 for class 1 (churn), assigning a higher cost to the minority class to make the algorithms more sensitive to it. In our experiments, CTGAN-ENN (CE) consistently achieved better AUC, F1-Score and G-Mean than cost-sensitive learning.
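One common way to express such class costs in tree learners is to weight each sample by the cost of its class; in scikit-learn, for example, `DecisionTreeClassifier` and `RandomForestClassifier` accept `class_weight={0: 1, 1: 10}`. The exact parameters used in the experiments are those in Table 4. The sketch below shows the effect on the Gini impurity of a node containing ten non-churners and one churner: with weights 1 and 10, the single churner carries as much weight as all ten non-churners, so splits that isolate churners become much more attractive.

```python
def weighted_gini(labels, class_weight):
    """Gini impurity where each sample counts with the weight of its class."""
    totals = {}
    for lab in labels:
        totals[lab] = totals.get(lab, 0.0) + class_weight[lab]
    total = sum(totals.values())
    return 1.0 - sum((w / total) ** 2 for w in totals.values())


node = [0] * 10 + [1]  # ten customers who stay, one who churns
print(weighted_gini(node, {0: 1.0, 1: 1.0}))   # about 0.165: the churner barely matters
print(weighted_gini(node, {0: 1.0, 1: 10.0}))  # 0.5: the node now looks maximally mixed
```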

Mean ranking score, class overlap degree, and time execution result

As shown in Table 10, CTGAN-ENN achieved a lower Fisher's discriminant ratio, indicating that it considerably reduces the degree of class overlap in the datasets.
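The paper does not give the formula behind Table 10, so the following is a sketch under stated assumptions. The per-feature, two-class Fisher's discriminant ratio is \((\mu_1-\mu_2)^2/(\sigma_1^2+\sigma_2^2)\), where larger values mean better separation; the normalized variant used in the data-complexity literature, \(1/(1+\max_f r_f)\), reverses the direction so that lower values indicate less overlap, which is consistent with the statement above. The feature values are hypothetical.

```python
import statistics


def fisher_ratio(values_a, values_b):
    """Per-feature Fisher's discriminant ratio: (mu_a - mu_b)^2 / (var_a + var_b).

    Larger values mean the two classes are better separated on this feature.
    Assumes at least one class has non-zero variance on the feature.
    """
    mu_a, mu_b = statistics.mean(values_a), statistics.mean(values_b)
    var_a, var_b = statistics.pvariance(values_a), statistics.pvariance(values_b)
    return (mu_a - mu_b) ** 2 / (var_a + var_b)


def normalized_f1(ratios):
    """Complexity-style F1 = 1 / (1 + max ratio); lower values mean less overlap."""
    return 1.0 / (1.0 + max(ratios))


churn = [0.0, 1.0, 2.0]  # hypothetical feature values for the churn class
stay = [4.0, 5.0, 6.0]   # hypothetical feature values for the non-churn class
r = fisher_ratio(churn, stay)  # (1 - 5)^2 / (2/3 + 2/3), about 12
print(r, normalized_f1([r]))
```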

Table 10 Class overlaps from features in datasets

Table 11 provides the average rank across all experimental scenarios; a lower mean ranking score indicates a better result. CE achieved the best average rank in 14 of 21 scenarios. WE achieved the best average rank in five scenarios; WE performed best with the DT classifier, while CE performed best with almost all ensemble machine learning models.
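Mean ranking can be computed as follows, assuming that higher metric values are better and that tied methods share the average of their ranks (the tie rule used for Table 11 is not stated). Each row of the input holds the scores of all sampling methods in one scenario; the values are hypothetical.

```python
def rank_row(scores):
    """Rank methods in one scenario: 1 = highest score; ties share the average rank."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0.0] * len(scores)
    pos = 0
    while pos < len(order):
        end = pos
        # Extend the block of tied scores starting at position `pos`.
        while end + 1 < len(order) and scores[order[end + 1]] == scores[order[pos]]:
            end += 1
        avg = (pos + end) / 2 + 1  # average of the 1-based positions pos..end
        for j in range(pos, end + 1):
            ranks[order[j]] = avg
        pos = end + 1
    return ranks


def mean_ranks(table):
    """table[s][m] = score of method m in scenario s; returns the mean rank per method."""
    all_ranks = [rank_row(row) for row in table]
    return [sum(r[m] for r in all_ranks) / len(table) for m in range(len(table[0]))]


scores = [[0.9, 0.8, 0.7],  # scenario 1: methods A, B, C
          [0.6, 0.6, 0.5],  # scenario 2: A and B tie
          [0.7, 0.9, 0.8]]  # scenario 3
print(mean_ranks(scores))  # lowest value = best method on average
```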

Table 11 Mean ranking score

We measured the execution time of each algorithm after preprocessing the data with CTGAN and with CTGAN-ENN; the detailed results are provided in Appendix 1. CTGAN-ENN reduced the execution time in almost every algorithm and dataset combination; the exceptions are NB on the Bank dataset, where the time was unchanged, and LGB on Telco1 and Telco2, where it increased. Although the reduction is small in several cases, the execution time decreased by 38.47% on average, suggesting that CTGAN-ENN can work efficiently on large customer churn datasets.
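The timing comparison can be sketched as follows, assuming that "execution time" means wall-clock time to fit on the resampled training set and predict the test set; the paper does not describe its exact timing protocol. The arrays are hypothetical, with the CTGAN-ENN set simply truncated to mimic the rows ENN removes.

```python
import time

import numpy as np
from sklearn.neighbors import KNeighborsClassifier

def timed_fit_predict(model, X_train, y_train, X_test):
    """Wall-clock seconds to fit on a (resampled) training set and predict a test set."""
    start = time.perf_counter()
    model.fit(X_train, y_train)
    model.predict(X_test)
    return time.perf_counter() - start

def percent_reduction(t_ctgan, t_ctgan_enn):
    """Relative time saving of CTGAN-ENN over CTGAN, in percent."""
    return 100.0 * (t_ctgan - t_ctgan_enn) / t_ctgan

# Hypothetical data: after ENN cleaning, fewer training rows remain than with CTGAN alone.
rng = np.random.default_rng(0)
X_ctgan = rng.normal(size=(4000, 8))
y_ctgan = (X_ctgan[:, 0] > 0).astype(int)
X_enn, y_enn = X_ctgan[:2500], y_ctgan[:2500]
X_test = rng.normal(size=(1000, 8))

t_ct = timed_fit_predict(KNeighborsClassifier(), X_ctgan, y_ctgan, X_test)
t_ce = timed_fit_predict(KNeighborsClassifier(), X_enn, y_enn, X_test)
print(f"CTGAN: {t_ct:.3f}s, CTGAN-ENN: {t_ce:.3f}s")

# KNN on Bank from Appendix 1: 2.10 s with CTGAN versus 1.54 s with CTGAN-ENN.
print(f"{percent_reduction(2.10, 1.54):.2f}%")  # 26.67%
```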

Figure 5 visualizes all datasets using the t-SNE technique. The difference between the middle (CTGAN) and right (CTGAN-ENN) panels is clearly visible: ENN removed almost all points where the two classes were stacked on top of each other, leaving a clear gap between the classes. In all datasets, CTGAN-ENN thus made it easier for the machine learning algorithms to learn and led to better results than the other sampling methods.
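A two-dimensional t-SNE projection like the panels of Fig. 5 can be sketched with scikit-learn's `TSNE`. Standardizing the features first, the `perplexity` value and the PCA initialization are assumptions, since the paper does not report its t-SNE settings; the input data here is a hypothetical stand-in.

```python
import numpy as np
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

def tsne_2d(X, perplexity=30.0, random_state=0):
    """Standardize the features, then embed them in 2-D with t-SNE for a class-overlap plot."""
    X_scaled = StandardScaler().fit_transform(X)
    tsne = TSNE(n_components=2, perplexity=perplexity, init="pca", random_state=random_state)
    return tsne.fit_transform(X_scaled)

# Hypothetical stand-in for a resampled dataset: 40 non-churn and 20 churn rows, 5 features.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 1.0, size=(40, 5)), rng.normal(2.0, 1.0, size=(20, 5))])
y = np.array([0] * 40 + [1] * 20)

# Perplexity must be smaller than the number of samples.
emb = tsne_2d(X, perplexity=10.0)
print(emb.shape)  # (60, 2)
```

To draw a panel similar to those in Fig. 5, the two columns of `emb` can be plotted as a scatter plot coloured by `y`, for example with matplotlib.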

Fig. 5

Visualization of all datasets without sampling (left), with CTGAN (middle) and with CTGAN-ENN (right)

In summary, across all experimental measurements, the proposed research framework, CTGAN-ENN, outperformed the other over-sampling and hybrid sampling methods. CTGAN-ENN worked particularly well with KNN and with the ensemble models GBM, XGB and RF. A company that wants a fast yet accurate customer churn prediction model may therefore prefer KNN. With the proposed framework, a company can preprocess its customer data with CTGAN-ENN and then choose the machine learning algorithm that best suits its objective.

Compared with the latest work [2, 9], our work offers a new perspective on hybrid sampling by using a tabular GAN. The strength of a tabular GAN is that its generator is designed specifically for tabular data, whereas other GANs were originally designed for image data.
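To make the tabular-data point concrete, the sketch below shows a simplified version of the mode-specific normalization that the CTGAN paper [21] describes for continuous columns: a variational Gaussian mixture is fitted to the column, and each value is encoded as a one-hot indicator of its mode plus a scalar offset scaled by four standard deviations of that mode. Using scikit-learn's `BayesianGaussianMixture` and taking the most probable mode instead of sampling one are simplifying assumptions; this is not the `ctgan` package's internal code.

```python
import numpy as np
from sklearn.mixture import BayesianGaussianMixture

def mode_specific_normalize(column, max_modes=10, random_state=0):
    """Simplified CTGAN-style encoding of one continuous column.

    Returns a scalar alpha per value and a one-hot vector beta marking the mixture mode used.
    """
    c = np.asarray(column, dtype=float).reshape(-1, 1)
    # A Dirichlet-process prior lets the variational mixture switch off modes it does not need.
    vgm = BayesianGaussianMixture(
        n_components=max_modes,
        weight_concentration_prior_type="dirichlet_process",
        weight_concentration_prior=0.001,
        covariance_type="diag",
        max_iter=500,
        random_state=random_state,
    ).fit(c)
    # Simplification: take the most probable mode; CTGAN samples a mode from these probabilities.
    mode = vgm.predict_proba(c).argmax(axis=1)
    mean = vgm.means_[mode, 0]
    std = np.sqrt(vgm.covariances_[mode, 0])
    alpha = np.clip((c[:, 0] - mean) / (4.0 * std), -0.99, 0.99)
    beta = np.eye(max_modes)[mode]
    return alpha, beta

# Hypothetical bimodal feature, e.g. monthly charges of two customer segments.
rng = np.random.default_rng(0)
charges = np.concatenate([rng.normal(20, 2, 300), rng.normal(80, 5, 300)])
alpha, beta = mode_specific_normalize(charges)
print(alpha.shape, beta.shape)  # (600,) (600, 10)
```

A single min-max or z-score scaling would squash both customer segments into one range; the mode indicator lets a generator model each segment separately, which is one reason a tabular GAN can suit multimodal columns better than an image-oriented GAN.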

Summary and analysis of results

According to our experiments, classical oversampling algorithms such as SMOTE and ADASYN worked well on only one dataset. A likely reason is that SMOTE and ADASYN generate synthetic samples by interpolating between existing minority samples and their nearest neighbours, so the generated data may not be diverse enough for training machine learning models. Where they did work, it was probably because the original data already had a favourable distribution.
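The interpolation idea behind SMOTE can be sketched in a few lines. This is a from-scratch illustration, not imbalanced-learn's `SMOTE` implementation [28]; the function name `smote_like` and the data are hypothetical. ADASYN uses the same interpolation but generates more samples around minority points that have more majority-class neighbours.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def smote_like(X_min, n_new, k=5, random_state=0):
    """SMOTE-style oversampling sketch: each synthetic sample lies on the line segment
    between a random minority sample and one of its k nearest minority neighbours."""
    rng = np.random.default_rng(random_state)
    X_min = np.asarray(X_min, dtype=float)
    # k + 1 neighbours because each sample is returned as its own nearest neighbour.
    _, idx = NearestNeighbors(n_neighbors=k + 1).fit(X_min).kneighbors(X_min)
    base = rng.integers(0, len(X_min), size=n_new)
    neighbour = idx[base, rng.integers(1, k + 1, size=n_new)]
    gap = rng.random((n_new, 1))  # interpolation factor in [0, 1)
    return X_min[base] + gap * (X_min[neighbour] - X_min[base])

rng = np.random.default_rng(1)
X_minority = rng.normal(size=(30, 3))  # hypothetical minority (churn) samples
X_synthetic = smote_like(X_minority, n_new=50)
print(X_synthetic.shape)  # (50, 3)
```

Every synthetic point stays inside the region spanned by existing minority samples, which illustrates why such generators may add little diversity when the original minority data is poorly distributed.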

CTGAN-ENN still has some limitations. Customer churn datasets usually have only two classes (churn and not churn), whereas customer classification in real-world data may be a multiclass problem. Moreover, many GAN variants have been developed, and beyond classification, GAN techniques can be applied to other customer-related machine learning tasks, such as predicting daily product demand or the number of customer transactions per period. Our work does not yet cover all the possible uses of GANs and their hybrid methods.

With the XGBoost algorithm, our proposed research framework achieved better results than the latest work on the Telco 1 and Insurance datasets [1]: an F1-Score of 0.949 on Telco 1 and 0.981 on Insurance, compared with 0.635 and 0.623, respectively, for the latest work. We also compared CTGAN-ENN with an algorithm-level approach, cost-sensitive learning with DT, RF, LGB and XGB, and CTGAN-ENN consistently performed better in all experiments.

Because the datasets used in the other recent work [2] differ from ours, we compared datasets with similar characteristics in terms of imbalance ratio: the IBM dataset in [2] has an imbalance ratio of 2.76, and our Telco 1 dataset has an imbalance ratio of 2.7. Our work achieved better AUC and G-mean. For AUC, the latest work achieved 0.832 with GBM, while our work achieved 0.991. For G-mean, the latest work achieved 0.743 with Random Forest, while our work achieved 0.955.
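The imbalance ratio used to match the datasets can be sketched as the size of the majority class divided by the size of the minority class. This common definition is an assumption, as the section does not define the ratio, and the label counts below are hypothetical.

```python
from collections import Counter

def imbalance_ratio(labels):
    """Imbalance ratio: number of majority-class samples / number of minority-class samples."""
    counts = Counter(labels)
    return max(counts.values()) / min(counts.values())

# Hypothetical labels: 27 non-churn (0) and 10 churn (1) customers.
print(imbalance_ratio([0] * 27 + [1] * 10))  # 2.7
```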

Conclusion

Class imbalance and class overlap are common problems in customer churn prediction. A GAN-based hybrid sampling method was recently proposed to handle these issues. To obtain better results, we proposed CTGAN-ENN, a tabular GAN-based hybrid sampling method that combines tabular GAN-based oversampling (CTGAN) with Edited Nearest Neighbor (ENN) under-sampling.
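The two-stage idea, oversampling the minority class and then cleaning overlapping points with ENN, can be sketched as follows. The generator is passed in as a callable because CTGAN itself needs the `ctgan` package [31], where one would fit `CTGAN(epochs=...)` on the minority rows and call `.sample(n)`; the `gaussian_jitter` stand-in below is hypothetical and is not CTGAN. The ENN step is a from-scratch version of Wilson's majority-vote editing rule rather than imbalanced-learn's `EditedNearestNeighbours` [32], and editing every class after balancing is an assumption.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def edited_nearest_neighbours(X, y, k=3, clean_classes=None):
    """Wilson's ENN rule, majority-vote variant: remove a sample when a strict majority of
    its k nearest neighbours has a different label. Only classes in `clean_classes`
    are edited; None means every class."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    # k + 1 neighbours because each sample is returned as its own nearest neighbour.
    _, idx = NearestNeighbors(n_neighbors=k + 1).fit(X).kneighbors(X)
    agree = (y[idx[:, 1:]] == y[:, None]).sum(axis=1)
    misfit = 2 * agree < k
    if clean_classes is not None:
        misfit &= np.isin(y, list(clean_classes))
    return X[~misfit], y[~misfit]

def gaussian_jitter(X_min, n, scale=0.05, seed=0):
    """Hypothetical stand-in generator (NOT CTGAN): resample minority rows and add Gaussian noise."""
    rng = np.random.default_rng(seed)
    rows = X_min[rng.integers(0, len(X_min), size=n)]
    return rows + rng.normal(0.0, scale, size=rows.shape)

def hybrid_resample(X, y, generate_minority, minority_label=1, k=3):
    """Hybrid sampling sketch: oversample the minority class up to the majority count, then apply ENN."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    n_needed = max(0, int(np.sum(y != minority_label) - np.sum(y == minority_label)))
    X_syn = generate_minority(X[y == minority_label], n_needed)
    X_bal = np.vstack([X, X_syn])
    y_bal = np.concatenate([y, np.full(len(X_syn), minority_label)])
    return edited_nearest_neighbours(X_bal, y_bal, k=k)

# Toy check: the class-0 point at 5.15 sits inside the class-1 cluster, so ENN removes it.
X_toy = np.array([[0.0], [0.1], [0.2], [0.3], [5.0], [5.1], [5.2], [5.15]])
y_toy = np.array([0, 0, 0, 0, 1, 1, 1, 0])
X_clean, y_clean = edited_nearest_neighbours(X_toy, y_toy)
print(X_clean.ravel(), y_clean)

# Hypothetical overlapping data: 80 non-churn and 20 churn rows.
rng = np.random.default_rng(0)
X_demo = np.vstack([rng.normal(0.0, 1.0, (80, 2)), rng.normal(1.5, 1.0, (20, 2))])
y_demo = np.array([0] * 80 + [1] * 20)
X_res, y_res = hybrid_resample(X_demo, y_demo, gaussian_jitter)
print(np.bincount(y_res))  # class counts after oversampling and ENN cleaning
```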

Our primary takeaway from the experimental results is that CTGAN-ENN improved customer churn prediction performance. The KNN algorithm achieved the best results in terms of time consumption. With CTGAN-ENN as the preprocessing strategy, GBM and XGB may be the most accurate models for predicting customer churn when time consumption is not a concern. We also found that the DT algorithm performed well with WGAN-GP+ENN. Compared with recent studies that used identical datasets or datasets with comparable properties, our proposed research framework produced better results.

Our findings have practical implications for stakeholders who want to build customer churn prediction models on their company data. In real-world settings, customer data grows rapidly over time, and the results of this study may offer insight for big data applications. By choosing an appropriate combination of CTGAN-ENN and machine learning algorithm, stakeholders can build a customer churn prediction model that fits their available resources.

A limitation of this study is that it focuses only on data-level solutions and binary classification problems. Future work could investigate modifications of our methodology for multiclass classification tasks in other customer classification problems. Our results also reveal remaining weaknesses in the outcomes of the classical machine learning models. This issue might be addressed with algorithm-level solutions; however, our algorithm-level experiments used only simple cost-sensitive learning without further detailed analysis. In future work, we could extend our framework with a more detailed analysis of cost-sensitive learning as an algorithm-level solution.

Availability of data and materials

Code and datasets can be accessed on GitHub: https://github.com/mahayasa/gan-hybrid-sampling-customer-churn. No datasets were generated or analysed during the current study.

Abbreviations

ADA: ADASYN
ADASYN: Adaptive Synthetic Sampling
AE: ADASYN+Edited Nearest Neighbor
AUC: Area Under the Curve
CE: CTGAN+Edited Nearest Neighbor
CS: Cost-sensitive
CT: CTGAN
DT: Decision tree
ENN: Edited Nearest Neighbor
GAN: Generative Adversarial Network
GBM: Gradient boosting machine
KNN: K-nearest neighbor
LGB: Light Gradient-Boosting Machine
NB: Naïve Bayes
RF: Random forest
SE: SMOTE+Edited Nearest Neighbor
SM: SMOTE
SMOTE: Synthetic Minority Over-sampling Technique
WE: WGAN-GP+Edited Nearest Neighbor
WG: WGAN-GP
XGB: XGBoost

References

  1. Wen X, Wang Y, Ji X, Traoré MK. Three-stage churn management framework based on DCN with asymmetric loss. Expert Syst Appl. 2022;207:117998. https://doi.org/10.1016/j.eswa.2022.117998.


  2. Zhu B, Pan X, Vanden Broucke S, Xiao J. A GAN-based hybrid sampling method for imbalanced customer classification. Inf Sci. 2022;609:1397–411. https://doi.org/10.1016/j.ins.2022.07.145.


  3. Das S, Mullick SS, Zelinka I. On supervised class-imbalanced learning: an updated perspective and some key challenges. IEEE Trans Artif Intell. 2022;3(6):973–93. https://doi.org/10.1109/TAI.2022.3160658.


  4. Goodfellow IJ, et al. Generative adversarial networks. 2014. http://arxiv.org/abs/1406.2661.

  5. Huyen C. Designing machine learning systems. Sebastopol: O’Reilly Media; 2022.


  6. Geiler L, Affeldt S, Nadif M. An effective strategy for churn prediction and customer profiling. Data Knowl Eng. 2022. https://doi.org/10.1016/j.datak.2022.102100.


  7. Wu S, Yau W-C, Ong T-S, Chong S-C. Integrated churn prediction and customer segmentation framework for telco business. IEEE Access. 2021;9:62118–36. https://doi.org/10.1109/ACCESS.2021.3073776.


  8. Su C, Wei L, Xie X. Churn prediction in telecommunications industry based on conditional Wasserstein GAN, In: 2022 IEEE 29th International Conference on High Performance Computing, Data, and Analytics (HiPC), 2022, pp. 186–191. https://doi.org/10.1109/HiPC56025.2022.00034.

  9. Ding H, Sun Y, Wang Z, Huang N, Shen Z, Cui X. RGAN-EL: a GAN and ensemble learning-based hybrid approach for imbalanced data classification. Inf Process Manag. 2023;60(2):103235. https://doi.org/10.1016/j.ipm.2022.103235.


  10. Sáez JA, Luengo J, Stefanowski J, Herrera F. SMOTE-IPF: addressing the noisy and borderline examples problem in imbalanced classification by a re-sampling method with filtering. Inf Sci. 2015;291:184–203. https://doi.org/10.1016/j.ins.2014.08.051.


  11. Vuttipittayamongkol P, Elyan E. Neighbourhood-based undersampling approach for handling imbalanced and overlapped data. Inf Sci. 2020;509:47–70. https://doi.org/10.1016/j.ins.2019.08.062.


  12. Xu Z, Shen D, Nie T, Kou Y. A hybrid sampling algorithm combining M-SMOTE and ENN based on random forest for medical imbalanced data. J Biomed Inform. 2020;107:103465. https://doi.org/10.1016/j.jbi.2020.103465.


  13. Elkan C. The Foundations of Cost-Sensitive Learning.

  14. Guo G, Wang H, Bell D, Bi Y, Greer K. KNN model-based approach in classification. Lecture Notes in Computer Science, vol. 2888. Berlin: Springer; 2003.


  15. Altuve M, Alvarez AJ, Severeyn E. Multiclass classification of metabolic conditions using fasting plasma levels of glucose and insulin. Health Technol (Berl). 2021;11(4):953–62. https://doi.org/10.1007/s12553-021-00550-w.


  16. Kumari S, Kumar D, Mittal M. An ensemble approach for classification and prediction of diabetes mellitus using soft voting classifier. Int J Cogn Comput Eng. 2021;2:40–6. https://doi.org/10.1016/j.ijcce.2021.01.001.


  17. Chen T, Guestrin C. XGBoost: a scalable tree boosting system. 2016. https://doi.org/10.1145/2939672.2939785.

  18. Biau G. Analysis of a random forests model. 2012.

  19. Shrivastav LK, Jha SK. A gradient boosting machine learning approach in modeling the impact of temperature and humidity on the transmission rate of COVID-19 in India. Appl Intell. 2021;51(5):2727–39. https://doi.org/10.1007/s10489-020-01997-6.


  20. Ke G, et al. LightGBM: a highly efficient gradient boosting decision tree. https://github.com/Microsoft/LightGBM. Accessed 17 Mar 2023.

  21. Xu L, Skoularidou M, Cuesta-Infante A, Veeramachaneni K. Modeling Tabular data using Conditional GAN. 2019. http://arxiv.org/abs/1907.00503. Accessed 8 May 2023.

  22. Telco Customer Churn | Kaggle. https://www.kaggle.com/datasets/blastchar/telco-customer-churn. Accessed 07 Jun 2023.

  23. Churn Modelling | Kaggle. https://www.kaggle.com/datasets/shrutimechlearn/churn-modelling. Accessed 07 Jun 2023.

  24. mobile-churn-data.xlsx | Kaggle. https://www.kaggle.com/datasets/dimitaryanev/mobilechurndataxlsx. Accessed 07 Jun 2023.

  25. Customer Churn Prediction 2020 | Kaggle. https://www.kaggle.com/competitions/customer-churn-prediction-2020. Accessed 07 Jun 2023.

  26. Customer Churn. https://www.kaggle.com/datasets/royjafari/customer-churn. Accessed 18 Mar 2024.

  27. Kumar V. Insurance churn prediction: weekend hackathon. https://www.kaggle.com/datasets/k123vinod/insurance-churn-prediction-weekend-hackathon. Accessed 15 Mar 2023.

  28. SMOTE—Version 0.10.1. https://imbalanced-learn.org/stable/references/generated/imblearn.over_sampling.SMOTE.html. Accessed 08 Jun 2023.

  29. Lemaître G, Nogueira F, Aridas CK. Imbalanced-learn: a python toolbox to tackle the curse of imbalanced datasets in machine learning. J Mach Learn Res. 2017;18(17):1–5.


  30. ydata-synthetic Python package for synthetic data generation for tabular and time-series data. https://docs.synthetic.ydata.ai/1.3/. Accessed 04 Jul 2023.

  31. ctgan · PyPI. https://pypi.org/project/ctgan/. Accessed 08 Jun 2023.

  32. EditedNearestNeighbours—Version 0.10.1. https://imbalanced-learn.org/stable/references/generated/imblearn.under_sampling.EditedNearestNeighbours.html. Accessed 08 Jun 2023.

  33. SMOTEENN—Version 0.10.1. https://imbalanced-learn.org/stable/references/generated/imblearn.combine.SMOTEENN.html. Accessed 08 Jun 2023.

  34. Pedregosa F, et al. Scikit-learn: machine learning in python. J Mach Learn Res. 2011;12(85):2825–30.


  35. XGBoost Documentation—xgboost 2.0.3 documentation. https://xgboost.readthedocs.io/en/stable/. Accessed 19 Mar 2024.

  36. sklearn.ensemble.RandomForestClassifier—scikit-learn 1.4.1 documentation. https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html. Accessed 12 Mar 2024.

  37. lightgbm.LGBMClassifier—LightGBM 4.3.0.99 documentation. https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html. Accessed 19 Mar 2024.


Acknowledgements

This work was supported by a Khon Kaen University ASEAN GMS grant and is part of the work of the AIDA (Applied Intelligence and Data Analytics) lab in the College of Computing, Khon Kaen University, Thailand.

Funding

Khon Kaen University.

Author information


Contributions

I Nyoman Mahayasa Adiputra and Paweena Wanchai conceived of the presented idea. Both authors contributed to the design and implementation of the research, to the analysis of the results and to the writing of the manuscript. I Nyoman Mahayasa Adiputra carried out the experiments. Paweena Wanchai supervised the project, provided critical feedback, and helped shape the research and analysis. Both authors discussed the results and contributed to the final version of the manuscript.

Corresponding author

Correspondence to Paweena Wanchai.

Ethics declarations

Competing interests

The authors declare that they have no competing interests.

Additional information

Publisher's Note

Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.

Appendix 1


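CTGAN and CTGAN-ENN algorithm execution time (s)

| Algorithm | Dataset | CTGAN | CTGAN-ENN |
|---|---|---|---|
| KNN | Bank | 2.10 | 1.54 |
| KNN | Mobile | 336.54 | 276.99 |
| KNN | Telco1 | 8.58 | 4.89 |
| KNN | Telco2 | 1.72 | 1.14 |
| KNN | Telco3 | 3.22 | 1.84 |
| KNN | Insurance | 81.91 | 44.45 |
| DT | Bank | 1.78 | 0.92 |
| DT | Mobile | 79.83 | 39.53 |
| DT | Telco1 | 1.38 | 1.33 |
| DT | Telco2 | 2.07 | 1.49 |
| DT | Telco3 | 0.79 | 0.67 |
| DT | Insurance | 7.40 | 6.20 |
| NB | Bank | 0.30 | 0.30 |
| NB | Mobile | 4.69 | 3.52 |
| NB | Telco1 | 0.66 | 0.58 |
| NB | Telco2 | 0.18 | 0.16 |
| NB | Telco3 | 0.23 | 0.22 |
| NB | Insurance | 1.40 | 0.90 |
| GBM | Bank | 43.70 | 27.34 |
| GBM | Mobile | 1820 | 1673 |
| GBM | Telco1 | 29.40 | 23.80 |
| GBM | Telco2 | 51.26 | 39.15 |
| GBM | Telco3 | 27.42 | 23.18 |
| GBM | Insurance | 247.61 | 177.9 |
| XGB | Bank | 27.56 | 11.95 |
| XGB | Mobile | 458.07 | 353.40 |
| XGB | Telco1 | 27.20 | 23.00 |
| XGB | Telco2 | 27.25 | 16.98 |
| XGB | Telco3 | 2.72 | 2.19 |
| XGB | Insurance | 158.94 | 123.6 |
| RF | Bank | 41.12 | 28.73 |
| RF | Mobile | 605.53 | 336.22 |
| RF | Telco1 | 23.90 | 14.20 |
| RF | Telco2 | 37.72 | 23.53 |
| RF | Telco3 | 12.18 | 10.72 |
| RF | Insurance | 141.45 | 100.62 |
| LGB | Bank | 6.79 | 3.78 |
| LGB | Mobile | 58.51 | 53.91 |
| LGB | Telco1 | 5.56 | 7.48 |
| LGB | Telco2 | 4.72 | 6.12 |
| LGB | Telco3 | 11.03 | 7.67 |
| LGB | Insurance | 17.52 | 15.19 |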

  1. Bold values represent the execution time of the customer churn prediction algorithms. Although the improvement in prediction performance over CTGAN is not significant in some cases, CTGAN-ENN has a shorter execution time in most cases, which indicates that CTGAN-ENN is suitable for large-scale data.

Rights and permissions

Open Access This article is licensed under a Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License, which permits any non-commercial use, sharing, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if you modified the licensed material. You do not have permission under this licence to share adapted material derived from this article or parts of it. The images or other third party material in this article are included in the article’s Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article’s Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by-nc-nd/4.0/.


About this article


Cite this article

Adiputra, I.N.M., Wanchai, P. CTGAN-ENN: a tabular GAN-based hybrid sampling method for imbalanced and overlapped data in customer churn prediction. J Big Data 11, 121 (2024). https://doi.org/10.1186/s40537-024-00982-x



  • DOI: https://doi.org/10.1186/s40537-024-00982-x

Keywords