Skip to content

Commit

Permalink
fix(joins): relations should not collide with mapped scalar fields (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
Weakky authored Feb 16, 2024
1 parent ca42e86 commit 73fdee2
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ mod prisma_21182;
mod prisma_21369;
mod prisma_21901;
mod prisma_22298;
mod prisma_22971;
mod prisma_5952;
mod prisma_6173;
mod prisma_7010;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use indoc::indoc;
use query_engine_tests::*;

#[test_suite(schema(schema))]
mod prisma_22971 {
fn schema() -> String {
let schema = indoc! {
r#"model User {
#id(id, Int, @id, @map("hello"))
updatedAt String @default("now") @map("updated_at")
postId Int? @map("post")
post Post? @relation("User_post", fields: [postId], references: [id])
}
model Post {
#id(id, Int, @id, @map("world"))
updatedAt String @default("now") @map("up_at")
from_User_post User[] @relation("User_post")
}"#
};

schema.to_owned()
}

// Ensures that mapped fields are correctly resolved, even when there's a conflict between a scalar field name and a relation field name.
#[connector_test]
async fn test_22971(runner: Runner) -> TestResult<()> {
run_query!(&runner, r#"mutation { createOnePost(data: { id: 1 }) { id } }"#);
run_query!(
&runner,
r#"mutation { createOneUser(data: { id: 1, postId: 1 }) { id } }"#
);

insta::assert_snapshot!(
run_query!(&runner, r#"{
findManyUser {
id
updatedAt
post {
id
updatedAt
}
}
}"#),
@r###"{"data":{"findManyUser":[{"id":1,"updatedAt":"now","post":{"id":1,"updatedAt":"now"}}]}}"###
);

Ok(())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ fn coerce_json_relation_to_pv(value: serde_json::Value, rs: &RelationSelection)
let related_model = rs.field.related_model();

for (key, value) in obj {
match related_model.fields().all().find(|f| f.db_name() == key) {
match related_model.fields().all().find(|f| f.name() == key) {
Some(Field::Scalar(sf)) => {
map.push((key, coerce_json_scalar_to_pv(value, &sf)?));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub(crate) async fn get_single_record_joins(
ctx: &Context<'_>,
) -> crate::Result<Option<SingleRecord>> {
let selected_fields = selected_fields.to_virtuals_last();
let field_names: Vec<_> = selected_fields.db_names_grouping_virtuals().collect();
let field_names: Vec<_> = selected_fields.prisma_names_grouping_virtuals().collect();
let idents = selected_fields.type_identifiers_with_arities_grouping_virtuals();

let indexes = get_selection_indexes(
Expand Down Expand Up @@ -132,7 +132,7 @@ pub(crate) async fn get_many_records_joins(
ctx: &Context<'_>,
) -> crate::Result<ManyRecords> {
let selected_fields = selected_fields.to_virtuals_last();
let field_names: Vec<_> = selected_fields.db_names_grouping_virtuals().collect();
let field_names: Vec<_> = selected_fields.prisma_names_grouping_virtuals().collect();
let idents = selected_fields.type_identifiers_with_arities_grouping_virtuals();
let meta = column_metadata::create(field_names.as_slice(), idents.as_slice());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,7 @@ impl JoinSelectBuilder for LateralJoinSelectBuilder {
ctx: &Context<'_>,
) -> Select<'a> {
match field {
SelectedField::Scalar(sf) => select.column(
sf.as_column(ctx)
.table(parent_alias.to_table_string())
.set_is_selected(true),
),
SelectedField::Scalar(sf) => select.column(aliased_scalar_column(sf, parent_alias, ctx)),
SelectedField::Relation(rs) => {
let table_name = match rs.field.relation().is_many_to_many() {
true => m2m_join_alias_name(&rs.field),
Expand Down Expand Up @@ -170,7 +166,7 @@ impl JoinSelectBuilder for LateralJoinSelectBuilder {
.iter()
.filter_map(|field| match field {
SelectedField::Scalar(sf) => Some((
Cow::from(sf.db_name().to_owned()),
Cow::from(sf.name().to_owned()),
Expression::from(sf.as_column(ctx).table(parent_alias.to_table_string())),
)),
SelectedField::Relation(rs) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use query_structure::*;
use crate::{
context::Context,
filter::alias::Alias,
model_extensions::{AsColumns, AsTable, ColumnIterator, RelationFieldExt},
model_extensions::{AsColumn, AsColumns, AsTable, ColumnIterator, RelationFieldExt},
ordering::OrderByBuilder,
sql_trace::SqlTraceComment,
};
Expand Down Expand Up @@ -628,6 +628,19 @@ fn json_agg() -> Function<'static> {
.alias(JSON_AGG_IDENT)
}

pub(crate) fn aliased_scalar_column(sf: &ScalarField, parent_alias: Alias, ctx: &Context<'_>) -> Column<'static> {
let col = sf
.as_column(ctx)
.table(parent_alias.to_table_string())
.set_is_selected(true);

if sf.name() != sf.db_name() {
col.alias(sf.name().to_owned())
} else {
col
}
}

#[inline]
fn empty_json_array() -> serde_json::Value {
serde_json::Value::Array(Vec::new())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@ impl JoinSelectBuilder for SubqueriesSelectBuilder {
ctx: &Context<'_>,
) -> Select<'a> {
match field {
SelectedField::Scalar(sf) => select.column(
sf.as_column(ctx)
.table(parent_alias.to_table_string())
.set_is_selected(true),
),
SelectedField::Scalar(sf) => select.column(aliased_scalar_column(sf, parent_alias, ctx)),
SelectedField::Relation(rs) => self.with_relation(select, rs, Vec::new().iter(), parent_alias, ctx),
_ => select,
}
Expand Down Expand Up @@ -115,7 +111,7 @@ impl JoinSelectBuilder for SubqueriesSelectBuilder {
.iter()
.filter_map(|field| match field {
SelectedField::Scalar(sf) => Some((
Cow::from(sf.db_name().to_owned()),
Cow::from(sf.name().to_owned()),
Expression::from(sf.as_column(ctx).table(parent_alias.to_table_string())),
)),
SelectedField::Relation(rs) => Some((
Expand Down
11 changes: 2 additions & 9 deletions query-engine/core/src/response_ir/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,6 @@ impl<'a, 'b> SerializedFieldWithRelations<'a, 'b> {
Self::VirtualsGroup(name, _) => name,
}
}

fn db_name(&self) -> &str {
match self {
Self::Model(f, _) => f.db_name(),
Self::VirtualsGroup(name, _) => name,
}
}
}

// TODO: Handle errors properly
Expand Down Expand Up @@ -427,7 +420,7 @@ fn serialize_relation_selection(
let fields = collect_serialized_fields_with_relations(typ, &rrs.model, &rrs.virtuals, &rrs.fields);

for field in fields {
let value = value_obj.remove(field.db_name()).unwrap();
let value = value_obj.remove(field.name()).unwrap();

match field {
SerializedFieldWithRelations::Model(Field::Scalar(_), out_field) if !out_field.field_type().is_object() => {
Expand Down Expand Up @@ -481,7 +474,7 @@ fn collect_serialized_fields_with_relations<'a, 'b>(
model
.fields()
.all()
.find(|field| field.db_name() == name)
.find(|field| field.name() == name)
.and_then(|field| {
object_type
.find_field(field.name())
Expand Down
10 changes: 5 additions & 5 deletions query-engine/query-structure/src/field_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl FieldSelection {
/// [`FieldSelection::db_names_grouping_virtuals`] and
/// [`FieldSelection::type_identifiers_with_arities_grouping_virtuals`].
fn selections_with_virtual_group_heads(&self) -> impl Iterator<Item = &SelectedField> {
self.selections().unique_by(|f| f.db_name_grouping_virtuals())
self.selections().unique_by(|f| f.prisma_name_grouping_virtuals())
}

/// Returns all Prisma (e.g. schema model field) names of contained fields.
Expand All @@ -102,9 +102,9 @@ impl FieldSelection {
/// into the grouped containers for virtual fields, like `_count`. The names returned by this
/// method correspond to the results of queries that use JSON objects to represent joined
/// relations and relation aggregations.
pub fn db_names_grouping_virtuals(&self) -> impl Iterator<Item = String> + '_ {
pub fn prisma_names_grouping_virtuals(&self) -> impl Iterator<Item = String> + '_ {
self.selections_with_virtual_group_heads()
.map(|f| f.db_name_grouping_virtuals())
.map(|f| f.prisma_name_grouping_virtuals())
.map(Cow::into_owned)
}

Expand Down Expand Up @@ -384,10 +384,10 @@ impl SelectedField {
/// relations and relation aggregations. For those queries, the result of this method
/// corresponds to the top-level name of the value which is a JSON object that contains this
/// field inside.
pub fn db_name_grouping_virtuals(&self) -> Cow<'_, str> {
pub fn prisma_name_grouping_virtuals(&self) -> Cow<'_, str> {
match self {
SelectedField::Virtual(vs) => vs.serialized_name().0.into(),
_ => self.db_name(),
_ => self.prisma_name(),
}
}

Expand Down

0 comments on commit 73fdee2

Please sign in to comment.