Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Dec 29, 2017
1 parent 81a6a72 commit 6ee5585
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
4 changes: 2 additions & 2 deletions lifetimes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,8 @@ def expected_cumulative_transactions(model, transactions, datetime_col, customer

# make sure the date column uses datetime objects, and use Pandas' DateTimeIndex.to_period()
# to convert the column to a PeriodIndex which is useful for time-wise grouping and truncating
transactions[datetime_col] = pd.to_datetime(transactions[datetime_col], format=datetime_format)
transactions = transactions.set_index(datetime_col).to_period(freq).reset_index()
transactions[datetime_col] = pd.to_datetime(transactions[datetime_col], format=datetime_format).dt.to_period(freq)
transactions = transactions.drop_duplicates()

# find birth dates of users
birth_dates = transactions.groupby(customer_id_col, sort=False, as_index=True)[datetime_col].min()
Expand Down
14 changes: 10 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,14 +392,20 @@ def test_customer_lifetime_value_with_known_values(fitted_bg):
t = fitted_bg.data.head()
expected = np.array([0.016053, 0.021171, 0.030461, 0.031686, 0.001607])
# discount_rate=0 means the clv will be the same as the predicted
clv_d0 = utils._customer_lifetime_value(fitted_bg, t['frequency'], t['recency'], t['T'], monetary_value=pd.Series([1,1,1,1,1]), time=1, discount_rate=0.)
clv_d0 = utils._customer_lifetime_value(fitted_bg, t['frequency'], t['recency'], t['T'], monetary_value=pd.Series([1, 1, 1, 1, 1]), time=1, discount_rate=0.)
assert_almost_equal(clv_d0.values, expected, decimal=5)
# discount_rate=1 means the clv will halve over a period
clv_d1 = utils._customer_lifetime_value(fitted_bg, t['frequency'], t['recency'], t['T'], monetary_value=pd.Series([1,1,1,1,1]), time=1, discount_rate=1.)
clv_d1 = utils._customer_lifetime_value(fitted_bg, t['frequency'], t['recency'], t['T'], monetary_value=pd.Series([1, 1, 1, 1, 1]), time=1, discount_rate=1.)
assert_almost_equal(clv_d1.values, expected / 2., decimal=5)
# time=2, discount_rate=0 means the clv will be twice the initial
clv_t2_d0 = utils._customer_lifetime_value(fitted_bg, t['frequency'], t['recency'], t['T'], monetary_value=pd.Series([1,1,1,1,1]), time=2, discount_rate=0)
clv_t2_d0 = utils._customer_lifetime_value(fitted_bg, t['frequency'], t['recency'], t['T'], monetary_value=pd.Series([1, 1, 1, 1, 1]), time=2, discount_rate=0)
assert_allclose(clv_t2_d0.values, expected * 2., rtol=0.1)
# time=2, discount_rate=1 means the clv will be twice the initial
clv_t2_d1 = utils._customer_lifetime_value(fitted_bg, t['frequency'], t['recency'], t['T'], monetary_value=pd.Series([1,1,1,1,1]), time=2, discount_rate=1.)
clv_t2_d1 = utils._customer_lifetime_value(fitted_bg, t['frequency'], t['recency'], t['T'], monetary_value=pd.Series([1, 1, 1, 1, 1]), time=2, discount_rate=1.)
assert_allclose(clv_t2_d1.values, expected / 2. + expected / 4., rtol=0.1)


def test_expected_cumulative_transactions_dedups_inside_a_time_period(fitted_bg, example_transaction_data):
    """Coarser time buckets can only merge (dedup) transactions, never add them.

    A customer with several purchases in one week collapses to a single
    weekly observation, so the cumulative 'actual' count at weekly frequency
    must never exceed the count at daily frequency.
    """
    common_args = (fitted_bg, example_transaction_data, 'date', 'id', 10)
    weekly = utils.expected_cumulative_transactions(*common_args, freq='W')
    daily = utils.expected_cumulative_transactions(*common_args, freq='D')
    assert (weekly['actual'] >= daily['actual']).all()

0 comments on commit 6ee5585

Please sign in to comment.