Skip to content

Commit

Permalink
add receive_extras param
Browse files: browse the repository at this point in the history
  • Loading branch information
ianspektor committed Oct 10, 2023
1 parent 66f0f14 commit 0c08a20
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 34 deletions.
17 changes: 13 additions & 4 deletions temporian/core/event_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,12 +1418,13 @@ def map(
self: EventSetOrNode,
func: MapFunction,
output_dtypes: Optional[TargetDtypes] = None,
receive_extras: bool = False,
) -> EventSetOrNode:
"""Applies a function on each value of an
[`EventSet`][temporian.EventSet]'s features.
The function receives the scalar value, and optionally as second
argument a [`MapExtras`][temporian.types.MapExtras] object containing
The function receives the scalar value, and if `receive_extras` is True,
also a [`MapExtras`][temporian.types.MapExtras] object containing
information about the value's position in the EventSet. The MapExtras
object should not be modified by the function, since it is shared across
all calls.
Expand Down Expand Up @@ -1486,7 +1487,7 @@ def map(
>>> def f(value, extras):
... return f"{extras.feature_name}-{extras.timestamp}-{value}"
>>> b = a.map(f, output_dtypes=str)
>>> b = a.map(f, output_dtypes=str, receive_extras=True)
>>> b
indexes: ...
(3 events):
Expand All @@ -1506,13 +1507,21 @@ def map(
input dtypes (and not both types mixed), and the values are the
target dtypes for them. All dtypes must be Temporian types (see
`dtype.py`).
receive_extras: Whether the function should receive a
[`MapExtras`][temporian.types.MapExtras] object as second
argument.
Returns:
EventSet with the function applied on each value.
"""
from temporian.core.operators.map import map as tp_map

return tp_map(self, func=func, output_dtypes=output_dtypes)
return tp_map(
self,
func=func,
output_dtypes=output_dtypes,
receive_extras=receive_extras,
)

def moving_count(
self: EventSetOrNode,
Expand Down
14 changes: 6 additions & 8 deletions temporian/core/operators/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

"""Map operator class and public API function definitions."""

from inspect import signature
from typing import Any, Callable, Dict, Optional, Union
from temporian.core import operator_lib
from temporian.core.compilation import compile
Expand Down Expand Up @@ -51,6 +50,7 @@ def __init__(
self,
input: EventSetNode,
func: MapFunction,
receive_extras: bool,
dtype: Optional[DType] = None,
dtype_to_dtype: Optional[Dict[DType, DType]] = None,
feature_name_to_dtype: Optional[Dict[str, DType]] = None,
Expand Down Expand Up @@ -84,11 +84,7 @@ def __init__(
)
assert len(output_dtypes) == len(input.schema.features)

num_params = len(signature(func).parameters)
if num_params > 2 or num_params < 1:
raise ValueError("`func` must receive 1 or 2 arguments.")

self._receives_extras = num_params == 2
self._receive_extras = receive_extras

self.add_attribute("func", func)
self._func = func
Expand All @@ -114,8 +110,8 @@ def func(self) -> MapFunction:
return self._func

@property
def receives_extras(self) -> bool:
return self._receives_extras
def receive_extras(self) -> bool:
return self._receive_extras

@classmethod
def build_op_definition(cls) -> pb.OperatorDef:
Expand All @@ -142,6 +138,7 @@ def map(
input: EventSetOrNode,
func: MapFunction,
output_dtypes: Optional[TargetDtypes],
receive_extras: bool,
) -> EventSetOrNode:
assert isinstance(input, EventSetNode)

Expand All @@ -154,6 +151,7 @@ def map(
return Map(
input=input,
func=func,
receive_extras=receive_extras,
dtype=dtype,
feature_name_to_dtype=feature_name_to_dtype,
dtype_to_dtype=dtype_to_dtype,
Expand Down
21 changes: 1 addition & 20 deletions temporian/core/operators/test/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_basic(self):
def test_with_extras(self):
evset = event_set(timestamps=[1, 2, 3], features={"x": [10, 20, 30]})

result = evset.map(lambda v, e: v + e.timestamp)
result = evset.map(lambda v, e: v + e.timestamp, receive_extras=True)

expected = event_set(
timestamps=[1, 2, 3],
Expand Down Expand Up @@ -118,25 +118,6 @@ def test_wrong_output_dtype(self):
):
evset.map(lambda x: "v" + str(x))

def test_too_many_args(self):
evset = event_set(timestamps=[1], features={"a": [2]})

with self.assertRaisesRegex(
ValueError, "`func` must receive 1 or 2 arguments."
):
evset.map(lambda v, e, z: v + e.timestamp)

def test_too_little_args(self):
evset = event_set(timestamps=[1], features={"a": [2]})

def f():
return 0

with self.assertRaisesRegex(
ValueError, "`func` must receive 1 or 2 arguments."
):
evset.map(f)

def test_serialize_fails(self):
@compile
def f(e):
Expand Down
11 changes: 9 additions & 2 deletions temporian/implementation/numpy/operators/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __call__(self, input: EventSet) -> Dict[str, EventSet]:
output_schema = self.output_schema("output")

func = self.operator.func
receives_extras = self.operator.receives_extras
receive_extras = self.operator.receive_extras

# Create output EventSet
output_evset = EventSet(data={}, schema=output_schema)
Expand All @@ -59,15 +59,21 @@ def __call__(self, input: EventSet) -> Dict[str, EventSet]:
timestamp=0,
feature_name=feature_schema.name,
)

# TODO: preallocate numpy array directly when output dtype isn't
# string (in which case we need to know the max length of func's
# results before doing so)
output_values = [None] * len(orig_feature)

for i, (value, timestamp) in enumerate(
zip(orig_feature, index_data.timestamps)
):
extras.timestamp = timestamp
if receives_extras:
if receive_extras:
output_values[i] = func(value, extras) # type: ignore
else:
output_values[i] = func(value) # type: ignore

try:
output_arr = np.array(
output_values, dtype=tp_dtype_to_np_dtype(output_dtype)
Expand All @@ -79,6 +85,7 @@ def __call__(self, input: EventSet) -> Dict[str, EventSet]:
" correct `output_dypes` and returning those types in"
" `func`."
) from exc

features.append(output_arr)

output_evset.set_index_value(
Expand Down

0 comments on commit 0c08a20

Please sign in to comment.