# Import necessary libraries
from dataidea.packages import *
from dataidea.datasets import loadDataset
from sklearn.ensemble import RandomForestClassifier
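The wildcard import from `dataidea.packages` is assumed to expose the usual analysis stack used throughout this walkthrough (pandas as `pd`, NumPy as `np`, matplotlib's pyplot as `plt`, and seaborn as `sns`). If you are not using `dataidea`, a minimal sketch of the equivalent explicit imports would be:

```python
# Explicit equivalents of the names assumed to come from dataidea.packages
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
```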
Visual EDA
# Load the dataset
demo_df = loadDataset('../assets/demo_cleaned.csv', inbuilt=False, file_type='csv')
demo_df.head()
|  | age | gender | marital_status | address | income | income_category | job_category |
|---|---|---|---|---|---|---|---|
| 0 | 55 | f | 1 | 12 | 72.0 | 3.0 | 3 |
| 1 | 56 | m | 0 | 29 | 153.0 | 4.0 | 3 |
| 2 | 24 | m | 1 | 4 | 26.0 | 2.0 | 1 |
| 3 | 45 | m | 0 | 9 | 76.0 | 4.0 | 2 |
| 4 | 44 | m | 1 | 17 | 144.0 | 4.0 | 3 |
demo_df2 = pd.get_dummies(demo_df, columns=['gender'], dtype=int, drop_first=True)
demo_df2.head()
|  | age | marital_status | address | income | income_category | job_category | gender_m |
|---|---|---|---|---|---|---|---|
| 0 | 55 | 1 | 12 | 72.0 | 3.0 | 3 | 0 |
| 1 | 56 | 0 | 29 | 153.0 | 4.0 | 3 | 1 |
| 2 | 24 | 1 | 4 | 26.0 | 2.0 | 1 | 1 |
| 3 | 45 | 0 | 9 | 76.0 | 4.0 | 2 | 1 |
| 4 | 44 | 1 | 17 | 144.0 | 4.0 | 3 | 1 |
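A quick sanity check helps confirm that the encoding behaved as expected: `gender_m` should be 1 wherever the original `gender` was 'm' and 0 for 'f', and every column left in `demo_df2` should now be numeric. A minimal sketch, assuming both frames above are still in memory:

```python
# Cross-tabulate the original gender column against the new dummy column
print(pd.crosstab(demo_df['gender'], demo_df2['gender_m']))

# All remaining columns should be numeric dtypes
print(demo_df2.dtypes)
```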
Data Distribution Plots:
- Histograms for numerical features (age and income).
- Bar plots for categorical features (gender, income_category, job_category, etc.).
# 1. Data Distribution Plots
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
sns.histplot(demo_df2['age'], ax=axes[0, 0], kde=True, color='skyblue')
sns.histplot(demo_df2['income'], ax=axes[0, 1], kde=True, color='salmon')
sns.countplot(x='gender_m', data=demo_df2, ax=axes[0, 2])
sns.countplot(x='income_category', data=demo_df2, ax=axes[1, 0])
sns.countplot(x='job_category', data=demo_df2, ax=axes[1, 1])
sns.countplot(x='marital_status', data=demo_df2, ax=axes[1, 2])
plt.tight_layout()
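The grid above renders without panel titles. If you want the figure to be self-explanatory (for a report, say), you can label each panel and save it to disk; the titles and filename below are illustrative, not part of the original example:

```python
# Optional: add a title to each panel and save the whole grid
titles = ['Age', 'Income', 'Gender (m = 1)',
          'Income category', 'Job category', 'Marital status']
for ax, title in zip(axes.flat, titles):
    ax.set_title(title)
fig.savefig('distribution_plots.png', dpi=150, bbox_inches='tight')
```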
Pairwise Feature Scatter Plots:
- Scatter plots of age vs. income, colored by marital_status.
# 2. Pairwise Feature Scatter Plots
g = sns.pairplot(demo_df2, vars=['age', 'income'], hue='marital_status', palette={0: 'blue', 1: 'orange'})
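With only two numeric variables, a single scatter plot conveys the same information more compactly; the pairplot becomes more useful as you add columns to `vars`. A minimal alternative sketch:

```python
# Direct scatter plot of age vs. income, colored by marital_status
plt.figure(figsize=(8, 6))
sns.scatterplot(data=demo_df2, x='age', y='income',
                hue='marital_status', palette={0: 'blue', 1: 'orange'})
plt.show()
```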
Correlation Heatmap:
- A heatmap showing the correlation between numerical features.
# 3. Correlation Heatmap
plt.figure(figsize=(8, 6))
sns.heatmap(demo_df2[['age', 'income']].corr(), annot=True, cmap='coolwarm', fmt=".2f")
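The heatmap above only looks at `age` and `income`. Because every column in `demo_df2` is numeric after the dummy encoding, the same call can be pointed at the full frame to check for other relationships; a minimal sketch:

```python
# Correlation heatmap across all columns of the encoded frame
plt.figure(figsize=(8, 6))
sns.heatmap(demo_df2.corr(), annot=True, cmap='coolwarm', fmt=".2f")
plt.show()
```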
Missing Values Matrix:
- A matrix indicating missing values in different features.
# 4. Missing Values Matrix
plt.figure(figsize=(8, 6))
sns.heatmap(demo_df2.isnull(), cmap='viridis')
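The matrix gives a visual overview, but a per-column count is usually quicker to act on when deciding whether to impute or drop. A minimal sketch:

```python
# Count and percentage of missing values per column
missing_counts = demo_df2.isnull().sum()
print(missing_counts)
print((missing_counts / len(demo_df2) * 100).round(2))
```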
Feature Importance Plot:
- After training a model (e.g., a random forest), we can visualize feature importances to see which features contribute the most to predicting marital_status.
# 5. Feature Importance Plot
# Prepare data for training
X = demo_df2.drop(['marital_status'], axis=1)
y = demo_df2['marital_status']

# Train Random Forest Classifier
rf_classifier = RandomForestClassifier()
rf_classifier.fit(X, y)

# Plot feature importances
plt.figure(figsize=(10, 6))
importances = rf_classifier.feature_importances_
indices = np.argsort(importances)[::-1]
sns.barplot(x=importances[indices], y=X.columns[indices], hue=X.columns[indices], palette='viridis')
plt.show()
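Impurity-based importances from a forest fitted on the full dataset can be optimistic and tend to favour high-cardinality features. A common cross-check is permutation importance measured on held-out data; the sketch below assumes scikit-learn's `permutation_importance` and the same `X` and `y` as above, with an illustrative 70/30 split:

```python
from sklearn.model_selection import train_test_split
from sklearn.inspection import permutation_importance

# Hold out a test set so importances are not measured on the training data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
rf = RandomForestClassifier(random_state=42).fit(X_train, y_train)

# Permutation importance: mean drop in score when each feature is shuffled
result = permutation_importance(rf, X_test, y_test, n_repeats=10, random_state=42)
for i in np.argsort(result.importances_mean)[::-1]:
    print(f"{X.columns[i]}: {result.importances_mean[i]:.3f}")
```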