Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: as_offset implementation in embedded #1397

Open
wants to merge 45 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
5b6b8b7
as_offset implementation in embedded
nfarabullini Dec 13, 2023
81c0141
Merge branch 'main' into as_offset_embedded
nfarabullini Dec 13, 2023
1acb1d8
edit to exclusion_matrices
nfarabullini Dec 13, 2023
b94e81c
edit to exclusion_matrices
nfarabullini Dec 13, 2023
f815712
resolved some pre-commit errors
nfarabullini Dec 13, 2023
67f8117
resolved some pre-commit errors
nfarabullini Dec 13, 2023
e17ff41
implemented EXPERIMENTAL_FUN_BUILTIN_NAMES
nfarabullini Dec 14, 2023
0e61be2
edits for as_offset
nfarabullini Jan 4, 2024
0219e72
additional cleanup
nfarabullini Jan 4, 2024
762d7eb
additional cleanup
nfarabullini Jan 4, 2024
fa1c588
reverted a couple of edits
nfarabullini Jan 4, 2024
38c052c
ran pre-commit
nfarabullini Jan 4, 2024
e8d6e5e
edit to test
nfarabullini Jan 4, 2024
782375b
edit for md dimensional field
nfarabullini Jan 5, 2024
654b14d
replaced connectivity with restricted
nfarabullini Jan 5, 2024
186b81d
edit to as_offset in experimental
nfarabullini Jan 5, 2024
e50eb64
small clenaup
nfarabullini Jan 5, 2024
09a4c44
updated code to checked vars
nfarabullini Jan 5, 2024
9f2bcc6
ran pre-commit
nfarabullini Jan 5, 2024
90f6796
removed [0][0] indexing
nfarabullini Jan 5, 2024
b19ba78
edits for tests and others for as_offset
nfarabullini Jan 9, 2024
c5464f3
edits to test
nfarabullini Jan 15, 2024
7c3e9cb
ran pre-commit
nfarabullini Jan 15, 2024
2ea83f1
edits
nfarabullini Jan 15, 2024
6387049
edits for other offsets
nfarabullini Jan 15, 2024
1e892b4
changes to path
nfarabullini Jan 15, 2024
2f3d9f6
edit for dace backend
nfarabullini Jan 16, 2024
b941b92
update with main
nfarabullini Jan 16, 2024
81fc838
trout attempt for fieldoffset in cache
nfarabullini Jan 16, 2024
7261ae7
edit suggested by Edoardo
nfarabullini Jan 16, 2024
ba1a91f
edit to offset_invariants
nfarabullini Jan 16, 2024
e074b10
edits following Hannes' review
nfarabullini Jan 17, 2024
ea7f20f
ran pre-commit
nfarabullini Jan 17, 2024
c483a0f
commented test out
nfarabullini Jan 18, 2024
b363cc2
placed test back
nfarabullini Jan 18, 2024
391508b
edit to failing test
nfarabullini Jan 18, 2024
c668dc8
Update tests/next_tests/integration_tests/feature_tests/ffront_tests/…
nfarabullini Jan 18, 2024
83a4b38
Merge branch 'main' of https://github.com/nfarabullini/gt4py into as_…
nfarabullini Jan 18, 2024
7e26c20
edits to dimensions refactoring
nfarabullini Jan 22, 2024
8bc6725
minor cleanup
nfarabullini Jan 22, 2024
e3e8db8
edit to test
nfarabullini Jan 23, 2024
5b6e553
Merge branch 'ruff-config' into as_offset_embedded
egparedes Mar 4, 2024
dacc6d5
Merge new style lint config
egparedes Mar 4, 2024
34fdeb7
Merge branch 'ruff-config' into as_offset_embedded
egparedes Mar 4, 2024
84abf3d
Recover deleted pieces after merging with main
egparedes Mar 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
as_offset implementation in embedded
  • Loading branch information
nfarabullini committed Dec 13, 2023
commit 5b6b8b751c9d1d20a4769785c49013566b64ec61
35 changes: 31 additions & 4 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar
from gt4py.next import common
from gt4py.next.embedded import common as embedded_common
from gt4py.next.ffront import fbuiltins
from gt4py.next.ffront import experimental, fbuiltins


try:
Expand Down Expand Up @@ -173,13 +173,19 @@ def remap(
assert common.is_connectivity_field(connectivity)

# Compute the new domain
dim = connectivity.codomain
dim = (
connectivity.codomain
if isinstance(connectivity.codomain, common.Dimension)
else connectivity.codomain.source
)
dim_idx = self.domain.dim_index(dim)
if dim_idx is None:
raise ValueError(f"Incompatible index field, expected a field with dimension {dim}.")

current_range: common.UnitRange = self.domain[dim_idx][1]
new_ranges = connectivity.inverse_image(current_range)
if isinstance(connectivity.codomain, fbuiltins.FieldOffset):
new_ranges = [new_ranges[dim_idx]]
new_domain = self.domain.replace(dim_idx, *new_ranges)

# perform contramap
Expand All @@ -201,6 +207,14 @@ def remap(
new_idx_array = xp.asarray(restricted_connectivity.ndarray) - current_range.start
# finally, take the new array
new_buffer = xp.take(self._ndarray, new_idx_array, axis=dim_idx)
if len(new_idx_array.shape) > 1 and isinstance(
connectivity.codomain, fbuiltins.FieldOffset
):
new_buffer = (
np.diagonal(new_buffer).T
if dim.kind == "horizontal"
else np.diagonal(new_buffer.T)
)

return self.__class__.from_array(new_buffer, domain=new_domain, dtype=self.dtype)

Expand Down Expand Up @@ -343,7 +357,7 @@ def from_array( # type: ignore[override]
cls,
data: npt.ArrayLike | core_defs.NDArrayObject,
/,
codomain: common.DimT,
codomain: common.DimT | fbuiltins.FieldOffset,
*,
domain: common.DomainLike,
dtype: Optional[core_defs.DTypeLike] = None,
Expand All @@ -363,7 +377,7 @@ def from_array( # type: ignore[override]
assert len(domain) == array.ndim
assert all(len(r) == s or s == 1 for r, s in zip(domain.ranges, array.shape))

assert isinstance(codomain, common.Dimension)
assert isinstance(codomain, (common.Dimension, fbuiltins.FieldOffset))

return cls(domain, array, codomain)

Expand Down Expand Up @@ -586,7 +600,20 @@ def _astype(field: common.Field | core_defs.ScalarT | tuple, type_: type) -> NdA
raise AssertionError("This is the NdArrayField implementation of `fbuiltins.astype`.")


def _as_offset(offset_: fbuiltins.FieldOffset, field: common.Field) -> NdArrayConnectivityField:
if isinstance(field, NdArrayField):
# change field.ndarray from local to global
global_index_arr = np.arange(field.shape[0]) + field.ndarray
return NumPyArrayConnectivityField.from_array(
global_index_arr, codomain=offset_, domain=field.domain
)
raise AssertionError(
"This is the NdArrayConnectivityField implementation of `experimental.as_offset`."
)


NdArrayField.register_builtin_func(fbuiltins.astype, _astype)
NdArrayField.register_builtin_func(experimental.as_offset, _as_offset)


def _get_slices_from_domain_slice(
Expand Down
34 changes: 9 additions & 25 deletions src/gt4py/next/ffront/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,14 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from dataclasses import dataclass
from gt4py.next import common
from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset

from gt4py.next.type_system import type_specifications as ts


@dataclass
class BuiltInFunction:
__gt_type: ts.FunctionType

def __call__(self, *args, **kwargs):
"""Act as an empty place holder for the built in function."""

def __gt_type__(self):
return self.__gt_type


as_offset = BuiltInFunction(
ts.FunctionType(
pos_only_args=[
ts.DeferredType(constraint=ts.OffsetType),
ts.DeferredType(constraint=ts.FieldType),
],
pos_or_kw_args={},
kw_only_args={},
returns=ts.DeferredType(constraint=ts.OffsetType),
)
)
@BuiltInFunction
def as_offset(
offset_: FieldOffset,
field: common.Field,
/,
) -> common.ConnectivityField:
raise NotImplementedError()
6 changes: 4 additions & 2 deletions src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from gt4py._core import definitions as core_defs
from gt4py.next import common, embedded
from gt4py.next.common import Dimension, Field # noqa: F401 # direct import for TYPE_BUILTINS
from gt4py.next.ffront.experimental import as_offset # noqa: F401
from gt4py.next.iterator import runtime
from gt4py.next.type_system import type_specifications as ts

Expand Down Expand Up @@ -58,6 +57,10 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp
return ts.FieldType
elif t is common.Dimension:
return ts.DimensionType
elif t is FieldOffset:
return ts.OffsetType
elif t is common.ConnectivityField:
return ts.OffsetType
havogt marked this conversation as resolved.
Show resolved Hide resolved
elif t is core_defs.ScalarT:
return ts.ScalarType
elif t is type:
Expand Down Expand Up @@ -297,7 +300,6 @@ def impl(
"broadcast",
"where",
"astype",
"as_offset",
] + MATH_BUILTIN_NAMES

BUILTIN_NAMES = TYPE_BUILTIN_NAMES + FUN_BUILTIN_NAMES
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call:
isinstance(new_func.type, ts.FunctionType)
and not type_info.is_concrete(return_type)
and isinstance(new_func, foast.Name)
and new_func.id in fbuiltins.FUN_BUILTIN_NAMES
and new_func.id in (fbuiltins.FUN_BUILTIN_NAMES + ["as_offset"])
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
):
visitor = getattr(self, f"_visit_{new_func.id}")
return visitor(new_node, **kwargs)
Expand Down
1 change: 0 additions & 1 deletion tests/next_tests/exclusion_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
]
EMBEDDED_SKIP_LIST = [
(USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE),
(USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE),
(CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE),
]
GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [
Expand Down