Skip to content

Commit

Permalink
Added DjstripePaymentMethod._get_or_create_from_stripe_object to allo…
Browse files Browse the repository at this point in the history
…w syncing of PaymentMethodForeignkeys

DjstripePaymentMethod._get_or_create_from_stripe_object is essentially a wrapper around the already implemented DjstripePaymentMethod._get_or_create_source and also added the Subscription Source type to the list of known sources

Updated Corresponding Tests
  • Loading branch information
arnav13081994 authored and jleclanche committed Dec 10, 2021
1 parent 767b4b8 commit 663b95c
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 16 deletions.
3 changes: 3 additions & 0 deletions djstripe/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def fraudulent(self) -> bool:
self.fraud_details and list(self.fraud_details.values())[0] == "fraudulent"
)

# todo may be unnecessary after this PR
def _attach_objects_hook(self, cls, data, current_ids=None):
from .payment_methods import DjstripePaymentMethod

Expand Down Expand Up @@ -1290,6 +1291,8 @@ def _attach_objects_post_save_hook(

save = False

# todo check all "reverse" PaymentMethod FKs model's attach and attach post swave hooks for sources syncs.
# todo should be unnecessary after this pr
customer_sources = data.get("sources")
sources = {}
if customer_sources:
Expand Down
55 changes: 54 additions & 1 deletion djstripe/models/payment_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ def from_stripe_object(cls, data):
return instance

@classmethod
def _get_or_create_source(cls, data, source_type):
def _get_or_create_source(cls, data, source_type=None):

# prefer passed in source_type
if not source_type:
source_type = data["object"]

try:
model = cls._model_for_type(source_type)
model._get_or_create_from_stripe_object(data)
Expand Down Expand Up @@ -77,6 +82,54 @@ def object_model(self):
def resolve(self):
return self.object_model.objects.get(id=self.id)

@classmethod
def _get_or_create_from_stripe_object(
cls,
data,
field_name="id",
refetch=True,
current_ids=None,
pending_relations=None,
save=True,
stripe_account=None,
):

raw_field_data = data.get(field_name)
id_ = StripeModel._id_from_data(raw_field_data)

if id_.startswith("card"):
source_cls = Card
source_type = "card"
elif id_.startswith("src"):
source_cls = Source
source_type = "source"
elif id_.startswith("ba"):
source_cls = BankAccount
source_type = "bank_account"
elif id_.startswith("acct"):
source_cls = Account
source_type = "account"
else:
# This may happen if we have source types we don't know about.
# Let's not make dj-stripe entirely unusable if that happens.
logger.warning(f"Unknown Object. Could not sync source with id: {id_}")
return cls.objects.get_or_create(
id=id_, defaults={"type": f"UNSUPPORTED_{id_}"}
)

# call model's _get_or_create_from_stripe_object to ensure
# that object exists before getting or creating its source object
source_cls._get_or_create_from_stripe_object(
data,
field_name,
refetch=refetch,
current_ids=current_ids,
pending_relations=pending_relations,
stripe_account=stripe_account,
)

return cls.objects.get_or_create(id=id_, defaults={"type": source_type})


class LegacySourceMixin:
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_charge.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def test_sync_from_stripe_data_unsupported_source(

charge = Charge.sync_from_stripe_data(fake_charge_copy)
self.assertEqual("test_id", charge.source_id)
self.assertEqual("unsupported", charge.source.type)
self.assertEqual("UNSUPPORTED_test_id", charge.source.type)
self.assertEqual(charge.source, DjstripePaymentMethod.objects.get(id="test_id"))

charge_retrieve_mock.assert_not_called()
Expand Down
52 changes: 38 additions & 14 deletions tests/test_customer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Plan,
Price,
Product,
Source,
Subscription,
)
from djstripe.settings import djstripe_settings
Expand All @@ -50,6 +51,7 @@
FAKE_PRICE,
FAKE_PRODUCT,
FAKE_SOURCE,
FAKE_SOURCE_II,
FAKE_SUBSCRIPTION,
FAKE_SUBSCRIPTION_II,
FAKE_UPCOMING_INVOICE,
Expand Down Expand Up @@ -117,12 +119,11 @@ def test_customer_sync_unsupported_source(self):
user = get_user_model().objects.create_user(
username="test_user_sync_unsupported_source"
)
synced_customer = fake_customer.create_for_user(user)
self.assertEqual(0, synced_customer.legacy_cards.count())
self.assertEqual(0, synced_customer.sources.count())
self.assertEqual(
synced_customer.default_source,
DjstripePaymentMethod.objects.get(id=fake_customer["default_source"]["id"]),
self.assertRaisesRegexp(
ValueError,
"Trying to fit a 'fish' into 'Card'. Aborting.",
fake_customer.create_for_user,
user,
)

def test_customer_sync_has_subscriber_metadata(self):
Expand Down Expand Up @@ -210,24 +211,46 @@ def test_customer_create_metadata_disabled(self, customer_mock):
},
)

@patch.object(Card, "_get_or_create_from_stripe_object")
@patch("stripe.Customer.retrieve", autospec=True)
@patch(
"stripe.Card.retrieve",
return_value=FAKE_CUSTOMER_II["default_source"],
autospec=True,
)
def test_customer_sync_non_local_card(self, card_retrieve_mock):
def test_customer_sync_non_local_card(
self, card_retrieve_mock, customer_retrieve_mock, card_get_or_create_mock
):
fake_customer = deepcopy(FAKE_CUSTOMER_II)
fake_customer["id"] = fake_customer["sources"]["data"][0][
"customer"
] = "cus_test_sync_non_local_card"
fake_customer["default_source"]["id"] = fake_customer["sources"]["data"][0][
"id"
] = "card_cus_test_sync_non_local_card"

customer_retrieve_mock.return_value = fake_customer

fake_card = deepcopy(fake_customer["default_source"])
fake_card["customer"] = "cus_test_sync_non_local_card"
card_retrieve_mock.return_value = fake_card
card_get_or_create_mock.return_value = fake_card

user = get_user_model().objects.create_user(
username="test_user_sync_non_local_card"
)

# create a source object so that FAKE_CUSTOMER_III with a default source
# can be created correctly.
fake_source_data = deepcopy(FAKE_SOURCE_II)
fake_source_data["card"] = deepcopy(fake_card)
fake_source_data["customer"] = fake_customer

Source.sync_from_stripe_data(fake_source_data)

customer = fake_customer.create_for_user(user)

self.assertEqual(customer.sources.count(), 0)
self.assertEqual(customer.legacy_cards.count(), 1)
self.assertEqual(customer.sources.count(), 1)
self.assertEqual(customer.legacy_cards.count(), 0)
self.assertEqual(
customer.default_source.id, fake_customer["default_source"]["id"]
)
Expand Down Expand Up @@ -292,12 +315,13 @@ def test_customer_sync_no_sources(self, customer_mock):
def test_customer_sync_default_source_string(self):
Customer.objects.all().delete()
Card.objects.all().delete()

customer_fake = deepcopy(FAKE_CUSTOMER)
customer_fake["default_source"] = customer_fake["sources"]["data"][0][
"id"
] = "card_sync_source_string"

customer = Customer.sync_from_stripe_data(customer_fake)
self.assertEqual(customer.default_source.id, customer_fake["default_source"])
self.assertEqual(
customer.default_source.id, customer_fake["default_source"]["id"]
)
self.assertEqual(customer.legacy_cards.count(), 2)
self.assertEqual(len(list(customer.customer_payment_methods)), 2)

Expand Down
7 changes: 7 additions & 0 deletions tests/test_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ def setUp(self):
user = get_user_model().objects.create_user(
username="testuser", email="[email protected]"
)

# create a source object so that FAKE_CUSTOMER_III with a default source
# can be created correctly.
fake_source_data = deepcopy(FAKE_SOURCE)
fake_source_data["customer"] = None
self.source = Source.sync_from_stripe_data(fake_source_data)

self.customer = FAKE_CUSTOMER_III.create_for_user(user)
self.customer.sources.all().delete()
self.customer.legacy_cards.all().delete()
Expand Down
22 changes: 22 additions & 0 deletions tests/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from djstripe.models import Plan, Product, Subscription

from . import (
FAKE_CARD,
FAKE_CUSTOMER,
FAKE_CUSTOMER_II,
FAKE_PLAN,
Expand Down Expand Up @@ -144,6 +145,27 @@ def test_sync_from_stripe_data(
)
self.assertEqual(datetime_to_unix(subscription.cancel_at), 1624553655)

@patch("stripe.Plan.retrieve", return_value=deepcopy(FAKE_PLAN), autospec=True)
@patch(
"stripe.Product.retrieve", return_value=deepcopy(FAKE_PRODUCT), autospec=True
)
@patch(
"stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER), autospec=True
)
def test_sync_from_stripe_data_default_source_string(
self, customer_retrieve_mock, product_retrieve_mock, plan_retrieve_mock
):
subscription_fake = deepcopy(FAKE_SUBSCRIPTION)
subscription_fake["default_source"] = FAKE_CARD["id"]

subscription = Subscription.sync_from_stripe_data(subscription_fake)
self.assertEqual(subscription.default_source.id, FAKE_CARD["id"])

# pop out "djstripe.Subscription.default_source" from self.assert_fks
expected_blank_fks = deepcopy(self.default_expected_blank_fks)
expected_blank_fks.remove("djstripe.Subscription.default_source")
self.assert_fks(subscription, expected_blank_fks=expected_blank_fks)

@patch("stripe.Plan.retrieve", return_value=deepcopy(FAKE_PLAN_II), autospec=True)
@patch(
"stripe.Product.retrieve", return_value=deepcopy(FAKE_PRODUCT), autospec=True
Expand Down

0 comments on commit 663b95c

Please sign in to comment.