Drift analyses are crucial to understanding how your model will fare on different distributions. This notebook covers how one might use TruEra's Python SDK to examine the drift of a model between train and test data -- but this could also be used to examine a model's behavior over time.
Before you begin ⬇️¶
- Install the TruEra Python SDK
- Check our primer on Explainer objects
- Read our tutorial on local compute mode for the Python SDK.
What we'll cover ☑️¶
- First, we'll create a TruEra project. We'll use sample data from scikit-learn, and create a project with a sample gradient boosted tree model. We'll also ingest train and test split data for this model.
- We'll then track the performance of your model between train and test sets with an Explainer object.
- Finally, we'll drill into the root causes of the instability between distributions, so we can understand and debug your model.
Step 1: Connect to TruEra¶
from truera.client.truera_workspace import TrueraWorkspace
from truera.client.truera_authentication import TokenAuthentication

TOKEN = '<AUTH_TOKEN>'
TRUERA_URL = '<TRUERA_URL>'

auth = TokenAuthentication(TOKEN)
tru = TrueraWorkspace(TRUERA_URL, auth)
Step 2: Download sample data¶
Here we'll use the data from scikit-learn's California Housing dataset, which is a regression dataset. This is available directly from the sklearn.datasets module.
# Retrieve the data.
import pandas as pd
from sklearn.datasets import fetch_california_housing

data_bunch = fetch_california_housing()
XS_ALL = pd.DataFrame(data=data_bunch["data"], columns=data_bunch["feature_names"])
YS_ALL = pd.Series(data_bunch["target"], name="label")  # name the labels "label" to match the ColumnSpec below
# Create train and test data splits.
from sklearn.model_selection import train_test_split

XS_TRAIN, XS_TEST, YS_TRAIN, YS_TEST = train_test_split(XS_ALL, YS_ALL, test_size=0.5, random_state=0)
We'll add two kinds of noise to exacerbate the differences between our train and test sets for the purpose of this notebook:
1. Shift the HouseAge feature in the test data (but not the train data) by 10. This is an example of data drift.
2. When the HouseAge feature is between 20 and 30 in the train data, set the label to 0. This is an example of mislabelled data points.
XS_TEST["HouseAge"] += 10 YS_TRAIN[(20 <= XS_TRAIN["HouseAge"]) & (XS_TRAIN["HouseAge"] < 30)] = 0
Step 3: Create a project and ingest data¶
tru.add_project("California Housing", score_type="regression")
from truera.client.ingestion import ColumnSpec

column_spec = ColumnSpec(
    id_col_name="id",
    label_col_names="label",
    pre_data_col_names=data_bunch["feature_names"]
)
data_test = XS_TEST.merge(YS_TEST, left_index=True, right_index=True).reset_index(names="id")
data_train = XS_TRAIN.merge(YS_TRAIN, left_index=True, right_index=True).reset_index(names="id")
tru.add_data_collection("sklearn_data")
tru.add_data(data_train, data_split_name="train", column_spec=column_spec)
tru.add_data(data_test, data_split_name="test", column_spec=column_spec)
Step 4: Train and ingest the model¶
# Train the model.
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error

gb_model = GradientBoostingRegressor(random_state=0)
gb_model.fit(XS_TRAIN, YS_TRAIN)
# Add to TruEra workspace.
tru.add_python_model("gradient boosted", gb_model)
Step 5: Examine model accuracy between train and test¶
Here, we create an Explainer object, setting train as our base data split and test as our comparison data split. This enables us to easily compare performance across splits.
explainer = tru.get_explainer(base_data_split="train", comparison_data_splits=["test"]) explainer.compute_performance("RMSE")
We can see there is a marked gap between the RMSE of our train and test splits. Let's find out which features contribute most to this instability.
# Find feature that has shifted the most.
instability = explainer.compute_feature_contributors_to_instability("regression")
instability.T.sort_values(by="test", ascending=False)
Since the HouseAge feature has shifted so heavily, let's plot its distribution in both the train and test splits.
import matplotlib.pyplot as plt

plt.figure(figsize=(21, 6))
XS_TRAIN["HouseAge"].hist()
XS_TEST["HouseAge"].hist()
plt.legend(["Train", "Test"])
plt.xlabel("`HouseAge` value")
plt.ylabel("Frequency")
This shows some odd behavior: the distribution of the HouseAge feature seems to have shifted between the train data and the test data. In fact, it appears that the data has shifted by around 10. So we were able to catch the issue!
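To put a rough number on that shift without any TruEra calls, we can simply compare the mean of HouseAge across the two splits; a minimal sketch (the shift_estimate name is just for illustration):

# Estimate the size of the HouseAge shift between train and test.
shift_estimate = XS_TEST["HouseAge"].mean() - XS_TRAIN["HouseAge"].mean()
print(f"Estimated HouseAge shift: {shift_estimate:.2f}")  # should be close to the 10 we added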
Given the problematic behavior, let's also look at the influence sensitivity plot (ISP) of the feature.
explainer = tru.get_explainer(base_data_split="train") explainer.plot_isp("HouseAge")
The data does appear quite fishy in the 20 to 30 region, as we might expect!
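To confirm that this is the mislabelling we injected earlier rather than a genuine pattern in the data, we can inspect the train labels in that region directly. This is a minimal sketch using the in-memory splits from earlier; the suspicious mask name is just for illustration.

# Check the train labels where HouseAge falls in the suspicious 20-30 range.
suspicious = (20 <= XS_TRAIN["HouseAge"]) & (XS_TRAIN["HouseAge"] < 30)
print(YS_TRAIN[suspicious].describe())   # these labels were overwritten with 0
print(YS_TRAIN[~suspicious].describe())  # the untouched labels, for comparison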