Skip to content

Commit

Permalink
Fix 4 issues in codegen (edgedb#387)
Browse files Browse the repository at this point in the history
* Add missing std::json
* Support optional argument
* Fix camelcase generation
* Allow symlinks in project dir
  • Loading branch information
fantix authored Nov 2, 2022
1 parent 314ec4a commit a912511
Show file tree
Hide file tree
Showing 14 changed files with 278 additions and 100 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:

strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11.0-rc.2"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
edgedb-version: [stable , nightly]
os: [ubuntu-latest, macos-latest, windows-2019]
loop: [asyncio, uvloop]
Expand Down
58 changes: 35 additions & 23 deletions edgedb/codegen/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"std::decimal": "decimal.Decimal",
"std::datetime": "datetime.datetime",
"std::duration": "datetime.timedelta",
"std::json": "str",
"cal::local_date": "datetime.date",
"cal::local_time": "datetime.time",
"cal::local_datetime": "datetime.datetime",
Expand Down Expand Up @@ -125,9 +126,7 @@ def __init__(self, args: argparse.Namespace):
self._skip_pydantic_validation = args.skip_pydantic_validation
self._async = False
try:
self._project_dir = pathlib.Path(
find_edgedb_project_dir()
).resolve()
self._project_dir = pathlib.Path(find_edgedb_project_dir())
except edgedb.ClientConnectionError:
print(
"Cannot find edgedb.toml: "
Expand Down Expand Up @@ -174,7 +173,6 @@ def run(self):

def _process_dir(self, dir_: pathlib.Path):
for file_or_dir in dir_.iterdir():
file_or_dir = file_or_dir.resolve()
if not file_or_dir.exists():
continue
if file_or_dir.is_dir():
Expand Down Expand Up @@ -283,10 +281,7 @@ def _generate(
) -> str:
buf = io.StringIO()

if "_" in name or name.islower():
name_hint = f"{name}_result"
else:
name_hint = f"{name}Result"
name_hint = f"{self._snake_to_camel(name)}Result"
out_type = self._generate_code(dr.output_type, name_hint)
if dr.output_cardinality.is_multi():
if SYS_VERSION_INFO >= (3, 9):
Expand All @@ -306,14 +301,16 @@ def _generate(
if isinstance(dr.input_type, describe.ObjectType):
if "".join(dr.input_type.elements.keys()).isdecimal():
for el_name, el in dr.input_type.elements.items():
args[int(el_name)] = self._generate_code(
el.type, f"arg{el_name}"
args[int(el_name)] = self._generate_code_with_cardinality(
el.type, f"arg{el_name}", el.cardinality
)
args = {f"arg{i}": v for i, v in sorted(args.items())}
else:
kw_only = True
for el_name, el in dr.input_type.elements.items():
args[el_name] = self._generate_code(el.type, el_name)
args[el_name] = self._generate_code_with_cardinality(
el.type, el_name, el.cardinality
)

if self._async:
print(f"async def {name}(", file=buf)
Expand Down Expand Up @@ -417,15 +414,10 @@ def _generate_code(
for el_name, element in type_.elements.items():
if element.is_implicit and el_name != "id":
continue
el_code = self._generate_code(
element.type, f"{rv}{el_name.title()}"
name_hint = f"{rv}{self._snake_to_camel(el_name)}"
el_code = self._generate_code_with_cardinality(
element.type, name_hint, element.cardinality
)
if element.cardinality == edgedb.Cardinality.AT_MOST_ONE:
if SYS_VERSION_INFO >= (3, 10):
el_code = f"{el_code} | None"
else:
self._imports.add("typing")
el_code = f"typing.Optional[{el_code}]"
if element.kind == edgedb.ElementKind.LINK_PROPERTY:
link_props.append((el_name, el_code))
else:
Expand Down Expand Up @@ -465,7 +457,7 @@ def _generate_code(
print(f"class {rv}(typing.NamedTuple):", file=buf)
for el_name, el_type in type_.element_types.items():
el_code = self._generate_code(
el_type, f"{rv}{el_name.title()}"
el_type, f"{rv}{self._snake_to_camel(el_name)}"
)
print(f"{INDENT}{el_name}: {el_code}", file=buf)
self._defs[rv] = buf.getvalue().strip()
Expand All @@ -489,14 +481,27 @@ def _generate_code(
self._cache[type_.desc_id] = rv
return rv

def _generate_code_with_cardinality(
self,
type_: typing.Optional[describe.AnyType],
name_hint: str,
cardinality: edgedb.Cardinality,
):
rv = self._generate_code(type_, name_hint)
if cardinality == edgedb.Cardinality.AT_MOST_ONE:
if SYS_VERSION_INFO >= (3, 10):
rv = f"{rv} | None"
else:
self._imports.add("typing")
rv = f"typing.Optional[{rv}]"
return rv

def _find_name(self, name: str) -> str:
default_prefix = f"{self._default_module}::"
if name.startswith(default_prefix):
name = name[len(default_prefix) :]
mod, _, name = name.rpartition("::")
parts = name.split("_")
if len(parts) > 1 or name.islower():
name = "".join(map(str.title, parts))
name = self._snake_to_camel(name)
name = mod.title() + name
if name in self._names:
for i in range(2, 100):
Expand All @@ -512,3 +517,10 @@ def _find_name(self, name: str) -> str:
sys.exit(17)
self._names.add(name)
return name

def _snake_to_camel(self, name: str) -> str:
parts = name.split("_")
if len(parts) > 1 or name.islower():
return "".join(map(str.title, parts))
else:
return name
1 change: 1 addition & 0 deletions tests/codegen/linked/test_linked.edgeql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
select 42
16 changes: 16 additions & 0 deletions tests/codegen/linked/test_linked_async_edgeql.py.assert
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# AUTOGENERATED FROM 'linked/test_linked.edgeql' WITH:
# $ edgedb-py


from __future__ import annotations
import edgedb


async def test_linked(
client: edgedb.AsyncIOClient,
) -> int:
return await client.query_single(
"""\
select 42\
""",
)
16 changes: 16 additions & 0 deletions tests/codegen/linked/test_linked_edgeql.py.assert
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# AUTOGENERATED FROM 'linked/test_linked.edgeql' WITH:
# $ edgedb-py --target blocking --no-skip-pydantic-validation


from __future__ import annotations
import edgedb


def test_linked(
client: edgedb.Client,
) -> int:
return client.query_single(
"""\
select 42\
""",
)
45 changes: 44 additions & 1 deletion tests/codegen/test-project1/generated_async_edgeql.py.assert
Original file line number Diff line number Diff line change
@@ -1,9 +1,42 @@
# AUTOGENERATED FROM 'select_scalar.edgeql' WITH:
# AUTOGENERATED FROM:
# 'select_optional_json.edgeql'
# 'select_scalar.edgeql'
# 'linked/test_linked.edgeql'
# WITH:
# $ edgedb-py --target async --file --no-skip-pydantic-validation


from __future__ import annotations
import dataclasses
import edgedb
import uuid


@dataclasses.dataclass
class SelectOptionalJsonResultItem:
id: uuid.UUID
snake_case: SelectOptionalJsonResultItemSnakeCase | None


@dataclasses.dataclass
class SelectOptionalJsonResultItemSnakeCase:
id: uuid.UUID


async def select_optional_json(
client: edgedb.AsyncIOClient,
arg0: str | None,
) -> list[tuple[str, SelectOptionalJsonResultItem]]:
return await client.query(
"""\
create type TestCase {
create link snake_case -> TestCase;
};

select (<optional json>$0, TestCase {snake_case});\
""",
arg0,
)


async def select_scalar(
Expand All @@ -14,3 +47,13 @@ async def select_scalar(
select 1;\
""",
)


async def test_linked(
client: edgedb.AsyncIOClient,
) -> int:
return await client.query_single(
"""\
select 42\
""",
)
1 change: 1 addition & 0 deletions tests/codegen/test-project1/linked
5 changes: 5 additions & 0 deletions tests/codegen/test-project1/select_optional_json.edgeql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
create type TestCase {
create link snake_case -> TestCase;
};

select (<optional json>$0, TestCase {snake_case});
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# AUTOGENERATED FROM 'select_optional_json.edgeql' WITH:
# $ edgedb-py


from __future__ import annotations
import dataclasses
import edgedb
import typing
import uuid


class NoPydanticValidation:
@classmethod
def __get_validators__(cls):
from pydantic.dataclasses import dataclass as pydantic_dataclass
pydantic_dataclass(cls)
cls.__pydantic_model__.__get_validators__ = lambda: []
return []


@dataclasses.dataclass
class SelectOptionalJsonResultItem(NoPydanticValidation):
id: uuid.UUID
snake_case: typing.Optional[SelectOptionalJsonResultItemSnakeCase]


@dataclasses.dataclass
class SelectOptionalJsonResultItemSnakeCase(NoPydanticValidation):
id: uuid.UUID


async def select_optional_json(
client: edgedb.AsyncIOClient,
arg0: typing.Optional[str],
) -> typing.List[typing.Tuple[str, SelectOptionalJsonResultItem]]:
return await client.query(
"""\
create type TestCase {
create link snake_case -> TestCase;
};

select (<optional json>$0, TestCase {snake_case});\
""",
arg0,
)
36 changes: 36 additions & 0 deletions tests/codegen/test-project1/select_optional_json_edgeql.py.assert
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# AUTOGENERATED FROM 'select_optional_json.edgeql' WITH:
# $ edgedb-py --target blocking --no-skip-pydantic-validation


from __future__ import annotations
import dataclasses
import edgedb
import typing
import uuid


@dataclasses.dataclass
class SelectOptionalJsonResultItem:
id: uuid.UUID
snake_case: typing.Optional[SelectOptionalJsonResultItemSnakeCase]


@dataclasses.dataclass
class SelectOptionalJsonResultItemSnakeCase:
id: uuid.UUID


def select_optional_json(
client: edgedb.Client,
arg0: typing.Optional[str],
) -> list[tuple[str, SelectOptionalJsonResultItem]]:
return client.query(
"""\
create type TestCase {
create link snake_case -> TestCase;
};

select (<optional json>$0, TestCase {snake_case});\
""",
arg0,
)
Loading

0 comments on commit a912511

Please sign in to comment.