Skip to content

Commit

Permalink
Update StripeModel._stripe_object_to_record to handle DecimalField in…
Browse files Browse the repository at this point in the history
…stances

This was done because the DecimalField and StripePercentField values were getting cached and the instance so returned, therefore had stale data.

Added Tests to ensure the returned instance had the updated Python object from the db and added a testmodel just for this purpose in the fields dir in the tests folder. This was done to isolate test models as recommended by Django
  • Loading branch information
arnav13081994 authored Dec 20, 2021
1 parent d9db6ab commit 4ba6954
Show file tree
Hide file tree
Showing 8 changed files with 264 additions and 23 deletions.
13 changes: 12 additions & 1 deletion djstripe/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@

from djstripe.utils import get_friendly_currency_amount

from ..fields import JSONField, StripeDateTimeField, StripeForeignKey, StripeIdField
from ..fields import (
JSONField,
StripeDateTimeField,
StripeForeignKey,
StripeIdField,
StripePercentField,
)
from ..managers import StripeModelManager
from ..settings import djstripe_settings

Expand Down Expand Up @@ -941,6 +947,11 @@ def sync_from_stripe_data(cls, data):
instance.save()
instance._attach_objects_post_save_hook(cls, data)

for field in instance._meta.concrete_fields:
if isinstance(field, StripePercentField):
# get rid of cached values
delattr(instance, field.name)

return instance

@classmethod
Expand Down
9 changes: 9 additions & 0 deletions tests/fields/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Models used exclusively for testing"""

from django.db import models

from djstripe.fields import StripePercentField


class TestDecimalModel(models.Model):
noval = StripePercentField()
2 changes: 2 additions & 0 deletions tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@
"jsonfield",
"djstripe",
"tests",
# to load custom models defined to test fields.py
"tests.fields",
"tests.apps.testapp",
"tests.apps.example",
]
Expand Down
30 changes: 30 additions & 0 deletions tests/test_coupon.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from copy import deepcopy
from decimal import Decimal

import pytest
from django.test.testcases import TestCase
Expand Down Expand Up @@ -105,3 +106,32 @@ class TestCouponStr(TestCase):
def test_blank_coupon_str(self):
coupon = Coupon()
self.assertEqual(str(coupon).strip(), "(invalid amount) off")


class TestCouponDecimal:
@pytest.mark.parametrize(
"inputted,expected",
[
(Decimal("1"), Decimal("1.00")),
(Decimal("1.5234567"), Decimal("1.52")),
(Decimal("0"), Decimal("0.00")),
(Decimal("23.2345678"), Decimal("23.23")),
("1", Decimal("1.00")),
("1.5234567", Decimal("1.52")),
("0", Decimal("0.00")),
("23.2345678", Decimal("23.23")),
(1, Decimal("1.00")),
(1.5234567, Decimal("1.52")),
(0, Decimal("0.00")),
(23.2345678, Decimal("23.24")),
],
)
def test_decimal_percent_off_coupon(self, inputted, expected):
fake_coupon = deepcopy(FAKE_COUPON)
fake_coupon["percent_off"] = inputted

coupon = Coupon.sync_from_stripe_data(fake_coupon)
field_data = coupon.percent_off

assert isinstance(field_data, Decimal)
assert field_data == expected
31 changes: 31 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from django.test.utils import override_settings

from djstripe.fields import StripeDateTimeField, StripeDecimalCurrencyAmountField
from tests.fields.models import TestDecimalModel

pytestmark = pytest.mark.django_db

Expand Down Expand Up @@ -43,3 +44,33 @@ def test_stripe_to_db_datetime_val(self):
datetime(1997, 9, 18, 7, 48, 35, tzinfo=timezone.utc),
self.noval.stripe_to_db({"noval": 874568915}),
)


class TestStripePercentField:
@pytest.mark.parametrize(
"inputted,expected",
[
(Decimal("1"), Decimal("1.00")),
(Decimal("1.5234567"), Decimal("1.52")),
(Decimal("0"), Decimal("0.00")),
(Decimal("23.2345678"), Decimal("23.23")),
("1", Decimal("1.00")),
("1.5234567", Decimal("1.52")),
("0", Decimal("0.00")),
("23.2345678", Decimal("23.23")),
(1, Decimal("1.00")),
(1.5234567, Decimal("1.52")),
(0, Decimal("0.00")),
(23.2345678, Decimal("23.24")),
],
)
def test_stripe_percent_field(self, inputted, expected):
# create a model with the StripePercentField
model_field = TestDecimalModel(noval=inputted)
model_field.save()

# get the field data
field_data = TestDecimalModel.objects.get(pk=model_field.pk).noval

assert isinstance(field_data, Decimal)
assert field_data == expected
71 changes: 71 additions & 0 deletions tests/test_invoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
dj-stripe Invoice Model Tests.
"""
from copy import deepcopy
from decimal import Decimal
from unittest.mock import ANY, patch

import pytest
import stripe
from django.contrib.auth import get_user_model
from django.test.testcases import TestCase
from stripe.error import InvalidRequestError
Expand Down Expand Up @@ -31,6 +34,8 @@
AssertStripeFksMixin,
)

pytestmark = pytest.mark.django_db


class InvoiceTest(AssertStripeFksMixin, TestCase):
def setUp(self):
Expand Down Expand Up @@ -1350,3 +1355,69 @@ def test_no_upcoming_invoices(self, invoice_upcoming_mock):
def test_upcoming_invoice_error(self, invoice_upcoming_mock):
with self.assertRaises(InvalidRequestError):
Invoice.upcoming()


class TestInvoiceDecimal:
@pytest.mark.parametrize(
"inputted,expected",
[
(Decimal("1"), Decimal("1.00")),
(Decimal("1.5234567"), Decimal("1.52")),
(Decimal("0"), Decimal("0.00")),
(Decimal("23.2345678"), Decimal("23.23")),
("1", Decimal("1.00")),
("1.5234567", Decimal("1.52")),
("0", Decimal("0.00")),
("23.2345678", Decimal("23.23")),
(1, Decimal("1.00")),
(1.5234567, Decimal("1.52")),
(0, Decimal("0.00")),
(23.2345678, Decimal("23.24")),
],
)
def test_decimal_tax_percent(self, inputted, expected, monkeypatch):
fake_invoice = deepcopy(FAKE_INVOICE)
fake_invoice["tax_percent"] = inputted

def mock_invoice_get(*args, **kwargs):
return fake_invoice

def mock_customer_get(*args, **kwargs):
return FAKE_CUSTOMER

def mock_charge_get(*args, **kwargs):
return FAKE_CHARGE

def mock_payment_method_get(*args, **kwargs):
return FAKE_CARD_AS_PAYMENT_METHOD

def mock_payment_intent_get(*args, **kwargs):
return FAKE_PAYMENT_INTENT_I

def mock_subscription_get(*args, **kwargs):
return FAKE_SUBSCRIPTION

def mock_balance_transaction_get(*args, **kwargs):
return FAKE_BALANCE_TRANSACTION

def mock_product_get(*args, **kwargs):
return FAKE_PRODUCT

# monkeypatch stripe retrieve calls to return
# the desired json response.
monkeypatch.setattr(stripe.Invoice, "retrieve", mock_invoice_get)
monkeypatch.setattr(stripe.Customer, "retrieve", mock_customer_get)
monkeypatch.setattr(
stripe.BalanceTransaction, "retrieve", mock_balance_transaction_get
)
monkeypatch.setattr(stripe.Subscription, "retrieve", mock_subscription_get)
monkeypatch.setattr(stripe.Charge, "retrieve", mock_charge_get)
monkeypatch.setattr(stripe.PaymentMethod, "retrieve", mock_payment_method_get)
monkeypatch.setattr(stripe.PaymentIntent, "retrieve", mock_payment_intent_get)
monkeypatch.setattr(stripe.Product, "retrieve", mock_product_get)

invoice = Invoice.sync_from_stripe_data(fake_invoice)
field_data = invoice.tax_percent

assert isinstance(field_data, Decimal)
assert field_data == expected
78 changes: 78 additions & 0 deletions tests/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
"""
from copy import deepcopy
from datetime import datetime
from decimal import Decimal
from unittest.mock import patch

import pytest
import stripe
from django.contrib.auth import get_user_model
from django.test import TestCase
from django.utils import timezone
Expand Down Expand Up @@ -43,6 +46,8 @@
datetime_to_unix,
)

pytestmark = pytest.mark.django_db

# TODO: test with Prices instead of Plans when creating Subscriptions
# with Prices is fully supported

Expand Down Expand Up @@ -1094,3 +1099,76 @@ def test_sync_metered_plan(
| {"djstripe.Subscription.latest_invoice"}
),
)


class TestSubscriptionDecimal:
@pytest.mark.parametrize(
"inputted,expected",
[
(Decimal("1"), Decimal("1.00")),
(Decimal("1.5234567"), Decimal("1.52")),
(Decimal("0"), Decimal("0.00")),
(Decimal("23.2345678"), Decimal("23.23")),
("1", Decimal("1.00")),
("1.5234567", Decimal("1.52")),
("0", Decimal("0.00")),
("23.2345678", Decimal("23.23")),
(1, Decimal("1.00")),
(1.5234567, Decimal("1.52")),
(0, Decimal("0.00")),
(23.2345678, Decimal("23.24")),
],
)
def test_decimal_application_fee_percent(self, inputted, expected, monkeypatch):
fake_subscription = deepcopy(FAKE_SUBSCRIPTION)
fake_subscription["application_fee_percent"] = inputted

def mock_invoice_get(*args, **kwargs):
return FAKE_INVOICE

def mock_customer_get(*args, **kwargs):
return FAKE_CUSTOMER

def mock_charge_get(*args, **kwargs):
return FAKE_CHARGE

def mock_payment_method_get(*args, **kwargs):
return FAKE_CARD_AS_PAYMENT_METHOD

def mock_payment_intent_get(*args, **kwargs):
return FAKE_PAYMENT_INTENT_I

def mock_subscription_get(*args, **kwargs):
return fake_subscription

def mock_balance_transaction_get(*args, **kwargs):
return FAKE_BALANCE_TRANSACTION

def mock_product_get(*args, **kwargs):
return FAKE_PRODUCT

def mock_plan_get(*args, **kwargs):
return FAKE_PLAN

# monkeypatch stripe retrieve calls to return
# the desired json response.
monkeypatch.setattr(stripe.Invoice, "retrieve", mock_invoice_get)
monkeypatch.setattr(stripe.Customer, "retrieve", mock_customer_get)
monkeypatch.setattr(
stripe.BalanceTransaction, "retrieve", mock_balance_transaction_get
)
monkeypatch.setattr(stripe.Subscription, "retrieve", mock_subscription_get)
monkeypatch.setattr(stripe.Charge, "retrieve", mock_charge_get)
monkeypatch.setattr(stripe.PaymentMethod, "retrieve", mock_payment_method_get)
monkeypatch.setattr(stripe.PaymentIntent, "retrieve", mock_payment_intent_get)
monkeypatch.setattr(stripe.Product, "retrieve", mock_product_get)
monkeypatch.setattr(stripe.Plan, "retrieve", mock_plan_get)

# Create Latest Invoice
Invoice.sync_from_stripe_data(FAKE_INVOICE)

subscription = Subscription.sync_from_stripe_data(fake_subscription)
field_data = subscription.application_fee_percent

assert isinstance(field_data, Decimal)
assert field_data == expected
53 changes: 31 additions & 22 deletions tests/test_tax_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
from copy import deepcopy
from decimal import Decimal

import pytest
from django.test import TestCase

from djstripe.models import TaxRate
from tests import (
FAKE_TAX_RATE_EXAMPLE_1_VAT,
FAKE_TAX_RATE_EXAMPLE_2_SALES,
AssertStripeFksMixin,
)
from tests import FAKE_TAX_RATE_EXAMPLE_1_VAT, AssertStripeFksMixin

pytestmark = pytest.mark.django_db


class TaxRateTest(AssertStripeFksMixin, TestCase):
Expand All @@ -25,21 +24,31 @@ def test___str__(self):
str(tax_rate),
)

def test_sync_from_stripe_data(self):
tax_rate = TaxRate.sync_from_stripe_data(deepcopy(FAKE_TAX_RATE_EXAMPLE_1_VAT))
# need to refresh to load percentage as decimal
tax_rate.refresh_from_db()

self.assertIsInstance(tax_rate.percentage, Decimal)
self.assertEqual(tax_rate.percentage, Decimal("15.0"))

def test_sync_from_stripe_data_non_integer(self):
# an example non-integer taxrate
tax_rate = TaxRate.sync_from_stripe_data(
deepcopy(FAKE_TAX_RATE_EXAMPLE_2_SALES)
)
# need to refresh to load percentage as decimal
tax_rate.refresh_from_db()

self.assertIsInstance(tax_rate.percentage, Decimal)
self.assertEqual(tax_rate.percentage, Decimal("4.25"))
class TestTaxRateDecimal:
@pytest.mark.parametrize(
"inputted,expected",
[
(Decimal("1"), Decimal("1.00")),
(Decimal("1.5234567"), Decimal("1.52")),
(Decimal("0"), Decimal("0.00")),
(Decimal("23.2345678"), Decimal("23.23")),
("1", Decimal("1.00")),
("1.5234567", Decimal("1.52")),
("0", Decimal("0.00")),
("23.2345678", Decimal("23.23")),
(1, Decimal("1.00")),
(1.5234567, Decimal("1.52")),
(0, Decimal("0.00")),
(23.2345678, Decimal("23.24")),
],
)
def test_decimal_tax_percent(self, inputted, expected):
fake_tax_rate = deepcopy(FAKE_TAX_RATE_EXAMPLE_1_VAT)
fake_tax_rate["percentage"] = inputted

tax_rate = TaxRate.sync_from_stripe_data(fake_tax_rate)
field_data = tax_rate.percentage

assert isinstance(field_data, Decimal)
assert field_data == expected

0 comments on commit 4ba6954

Please sign in to comment.