Skip to content

Commit

Permalink
Merge branch 'feature/variants_and_shipping_zones_per_channel' of git…
Browse files Browse the repository at this point in the history
…hub.com:mirumee/saleor into feature/variants_and_shipping_zones_per_channel
  • Loading branch information
IKarbowiak committed Apr 12, 2021
2 parents 9738bdf + c49a0a3 commit 46e091f
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 65 deletions.
6 changes: 4 additions & 2 deletions saleor/graphql/product/dataloaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
VariantAttributesByProductTypeIdLoader,
)
from .products import (
AvailableProductVariantsByProductVariantIdAndChannel,
AvailableProductVariantsByProductIdAndChannel,
CategoryByIdLoader,
CategoryChildrenByCategoryIdLoader,
CollectionByIdLoader,
Expand All @@ -29,6 +29,7 @@
ProductTypeByVariantIdLoader,
ProductVariantByIdLoader,
ProductVariantChannelListingByIdLoader,
ProductVariantsByProductIdAndChannel,
ProductVariantsByProductIdLoader,
VariantChannelListingByVariantIdAndChannelIdLoader,
VariantChannelListingByVariantIdAndChannelSlugLoader,
Expand Down Expand Up @@ -69,5 +70,6 @@
"VariantChannelListingByVariantIdAndChannelIdLoader",
"VariantChannelListingByVariantIdLoader",
"VariantsChannelListingByProductIdAndChannelSlugLoader",
"AvailableProductVariantsByProductVariantIdAndChannel",
"ProductVariantsByProductIdAndChannel",
"AvailableProductVariantsByProductIdAndChannel",
]
57 changes: 31 additions & 26 deletions saleor/graphql/product/dataloaders/products.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,36 +180,41 @@ def batch_load(self, keys):
return [variant_map.get(product_id, []) for product_id in keys]


class AvailableProductVariantsByProductVariantIdAndChannel(
DataLoader[VariantIdAndChannelSlug, ProductVariantChannelListing]
):
context_key = "available_productvariant_by_variant_and_channel"
field = "slug"
class ProductVariantsByProductIdAndChannel(DataLoader):
context_key = "productvariant_by_product_and_channel"

def batch_load(self, keys):
# Split the list of keys by channel first. A typical query will only touch
# a handful of unique countries but may access thousands of product variants
# so it's cheaper to execute one query per channel.
available_variants_by_channel: DefaultDict[str, List[int]] = defaultdict(list)
for variants, channel in keys:
# For each channel execute a single query for all product variants.
available_variants = self.batch_load_available(channel, variants)
available_variants_by_channel[channel] = available_variants
return [available_variants_by_channel[key] for _, key in keys]

def batch_load_available(
self, channel: str, variant_ids: Iterable[int]
) -> Iterable[int]:
filter = {
f"channel__{self.field}": channel,
"variant_id__in": variant_ids,
"price_amount__isnull": False,
product_ids = [key[0] for key in keys]
channel_slugs = [key[1] for key in keys]
variants_filter = self.get_variants_filter(product_ids, channel_slugs)

variants = ProductVariant.objects.filter(**variants_filter).annotate(
channel_slug=F("channel_listings__channel__slug")
)
variant_map = defaultdict(list)
for variant in variants.iterator():
variant_map[(variant.product_id, variant.channel_slug)].append(variant)

return [variant_map.get(key, []) for key in keys]

def get_variants_filter(self, products_ids, channel_slugs):
return {
"product_id__in": products_ids,
"channel_listings__channel__slug__in": channel_slugs,
}
available_variants = ProductVariantChannelListing.objects.filter(
**filter
).values_list("variant__id", flat=True)

return list(available_variants)

class AvailableProductVariantsByProductIdAndChannel(
ProductVariantsByProductIdAndChannel
):
context_key = "available_productvariant_by_product_and_channel"

def get_variants_filter(self, products_ids, channel_slugs):
return {
"product_id__in": products_ids,
"channel_listings__channel__slug__in": channel_slugs,
"channel_listings__price_amount__isnull": False,
}


class ProductVariantChannelListingByIdLoader(DataLoader):
Expand Down
60 changes: 56 additions & 4 deletions saleor/graphql/product/tests/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ def query_categories_with_filter():
node {
id
name
variants {
id
}
}
}
}
Expand Down Expand Up @@ -1071,6 +1074,51 @@ def test_fetch_all_products_available_as_staff_user(
assert len(content["data"]["products"]["edges"]) == num_products


def test_fetch_all_product_variants_available_as_staff_user_with_channel(
staff_api_client, permission_manage_products, product_variant_list, channel_USD
):
variables = {"channel": channel_USD.slug}
response = staff_api_client.post_graphql(
QUERY_FETCH_ALL_PRODUCTS,
variables,
permissions=(permission_manage_products,),
check_no_permissions=False,
)
num_products = Product.objects.count()
num_variants = ProductVariant.objects.count()
assert num_variants > 1

content = get_graphql_content(response)
products = content["data"]["products"]
variants = products["edges"][0]["node"]["variants"]

assert products["totalCount"] == num_products
assert len(products["edges"]) == num_products
assert len(variants) == num_variants - 1


def test_fetch_all_product_variants_available_as_staff_user_without_channel(
staff_api_client, permission_manage_products, product_variant_list, channel_USD
):
response = staff_api_client.post_graphql(
QUERY_FETCH_ALL_PRODUCTS,
permissions=(permission_manage_products,),
check_no_permissions=False,
)

num_products = Product.objects.count()
num_variants = ProductVariant.objects.count()
assert num_variants > 1

content = get_graphql_content(response)
products = content["data"]["products"]
variants = products["edges"][0]["node"]["variants"]

assert products["totalCount"] == num_products
assert len(products["edges"]) == num_products
assert len(variants) == num_variants


def test_fetch_all_products_not_available_as_staff_user(
staff_api_client, permission_manage_products, product, channel_USD
):
Expand Down Expand Up @@ -7530,10 +7578,10 @@ def test_product_variant_without_price_as_staff(
stock,
channel_USD,
):
product = variant.product
ProductVariantChannelListing.objects.filter(
channel=channel_USD, variant__product_id=product.pk
).update(price_amount=None)

variant_channel_listing = variant.channel_listings.first()
variant_channel_listing.price_amount = None
variant_channel_listing.save()

query = """
query getProductVariants($id: ID!, $channel: String, $address: AddressInput) {
Expand Down Expand Up @@ -7561,8 +7609,12 @@ def test_product_variant_without_price_as_staff(
response = staff_api_client.post_graphql(query, variables)
content = get_graphql_content(response)
variants_data = content["data"]["product"]["variants"]

assert variants_data[0]["pricing"] is not None

assert variants_data[1]["id"] == variant_id
assert variants_data[1]["pricing"] is None

assert len(variants_data) == 2


Expand Down
40 changes: 11 additions & 29 deletions saleor/graphql/product/types/products.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@
OrderLinesByVariantIdAndChannelIdLoader,
)
from ...product.dataloaders.products import (
AvailableProductVariantsByProductVariantIdAndChannel,
AvailableProductVariantsByProductIdAndChannel,
ProductVariantsByProductIdAndChannel,
)
from ...translations.fields import TranslationField
from ...translations.types import (
Expand Down Expand Up @@ -780,29 +781,15 @@ def resolve_images(root: ChannelContext[models.Product], info, **_kwargs):
def resolve_variants(root: ChannelContext[models.Product], info, **_kwargs):
requestor = get_user_or_app_from_context(info.context)
is_staff = requestor_is_staff_member_or_app(requestor)

def return_available_variants(variants):
if is_staff:
return variants

variants_id = (variant.id for variant in variants)
available_variants_in_channel = (
AvailableProductVariantsByProductVariantIdAndChannel(info.context).load(
(variants_id, root.channel_slug)
)
if is_staff and not root.channel_slug:
variants = ProductVariantsByProductIdLoader(info.context).load(root.node.id)
elif is_staff and root.channel_slug:
variants = ProductVariantsByProductIdAndChannel(info.context).load(
(root.node.id, root.channel_slug)
)

def return_available_variants_in_channel(available_variants):
result = []
variants_map = {variant.id: variant for variant in variants}
for variant_id in variants_map.keys() & available_variants:
variant = variants_map.get(variant_id)
if variant:
result.append(variant)
return result

return available_variants_in_channel.then(
return_available_variants_in_channel
else:
variants = AvailableProductVariantsByProductIdAndChannel(info.context).load(
(root.node.id, root.channel_slug)
)

def map_channel_context(variants):
Expand All @@ -811,12 +798,7 @@ def map_channel_context(variants):
for variant in variants
]

return (
ProductVariantsByProductIdLoader(info.context)
.load(root.node.id)
.then(return_available_variants)
.then(map_channel_context)
)
return variants.then(map_channel_context)

@staticmethod
@permission_required(ProductPermissions.MANAGE_PRODUCTS)
Expand Down
8 changes: 4 additions & 4 deletions saleor/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,7 +1232,7 @@ def product_with_collections(


@pytest.fixture
def product_available_in_many_channels(product, channel_PLN):
def product_available_in_many_channels(product, channel_PLN, channel_USD):
ProductChannelListing.objects.create(
product=product,
channel=channel_PLN,
Expand Down Expand Up @@ -1608,7 +1608,7 @@ def variant_with_many_stocks_different_shipping_zones(


@pytest.fixture
def product_variant_list(product, channel_USD):
def product_variant_list(product, channel_USD, channel_PLN):
variants = list(
ProductVariant.objects.bulk_create(
[
Expand Down Expand Up @@ -1636,10 +1636,10 @@ def product_variant_list(product, channel_USD):
),
ProductVariantChannelListing(
variant=variants[2],
channel=channel_USD,
channel=channel_PLN,
cost_price_amount=Decimal(1),
price_amount=Decimal(10),
currency=channel_USD.currency_code,
currency=channel_PLN.currency_code,
),
]
)
Expand Down

0 comments on commit 46e091f

Please sign in to comment.