reuse mapextras obj
ianspektor committed Oct 10, 2023
1 parent 3615ad6 commit 66f0f14
Showing 2 changed files with 18 additions and 16 deletions.
8 changes: 7 additions & 1 deletion temporian/core/event_set_ops.py
@@ -1424,11 +1424,17 @@ def map(
         The function receives the scalar value, and optionally as second
         argument a [`MapExtras`][temporian.types.MapExtras] object containing
-        information about the value's position in the EventSet.
+        information about the value's position in the EventSet. The MapExtras
+        object should not be modified by the function, since it is shared across
+        all calls.

         If the output of the function has a different dtype than the input, the
         `output_dtypes` argument must be specified.

+        This operator is slow. When possible, existing operators should be used.
+
+        A Temporian graph with a `map` operator is not serializable.
+
         Usage example with lambda function:
         ```python
         >>> a = tp.event_set(
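The docstring warning above follows from ordinary Python aliasing: when one object is reused for every call, a function that stores the object itself (rather than the fields it needs) ends up holding several references to the same instance. The sketch below illustrates this without importing Temporian. `Extras` and `apply_with_shared_extras` are hypothetical stand-ins written for this note, not part of the library.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Sequence, Tuple


@dataclass
class Extras:
    """Hypothetical stand-in for temporian.types.MapExtras."""

    index_key: Tuple[Any, ...]
    timestamp: float
    feature_name: str


def apply_with_shared_extras(
    func: Callable[[Any, Extras], Any],
    values: Sequence[Any],
    timestamps: Sequence[float],
) -> List[Any]:
    """Calls func on each value, reusing one Extras object for all calls."""
    extras = Extras(index_key=(), timestamp=0.0, feature_name="f")
    out: List[Any] = [None] * len(values)
    for i, (value, timestamp) in enumerate(zip(values, timestamps)):
        extras.timestamp = timestamp  # mutated in place before every call
        out[i] = func(value, extras)
    return out


retained: List[Extras] = []


def keeps_reference(value: Any, extras: Extras) -> Any:
    # Risky: stores the shared object, which is overwritten on the next call.
    retained.append(extras)
    return value


seen_timestamps: List[float] = []


def copies_field(value: Any, extras: Extras) -> Any:
    # Safe: copies the scalar it needs while the call is in progress.
    seen_timestamps.append(extras.timestamp)
    return value


apply_with_shared_extras(keeps_reference, [10, 20, 30], [1.0, 2.0, 3.0])
apply_with_shared_extras(copies_field, [10, 20, 30], [1.0, 2.0, 3.0])

# Every retained reference points at the same object, so all of them
# report the timestamp of the last call.
stale_timestamps = [e.timestamp for e in retained]
```

Under these assumptions, `stale_timestamps` contains the last timestamp three times, while `seen_timestamps` holds the per-call values. A mapped function should therefore read what it needs from the extras object during the call and neither mutate nor keep the object.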
26 changes: 11 additions & 15 deletions temporian/implementation/numpy/operators/map.py
@@ -54,24 +54,20 @@ def __call__(self, input: EventSet) -> Dict[str, EventSet]:
                 index_data.features,
                 output_schema.feature_dtypes(),
             ):
-                output_values = []
-                for value, timestamp in zip(
-                    orig_feature, index_data.timestamps
+                extras = MapExtras(
+                    index_key=index_key,
+                    timestamp=0,
+                    feature_name=feature_schema.name,
+                )
+                output_values = [None] * len(orig_feature)
+                for i, (value, timestamp) in enumerate(
+                    zip(orig_feature, index_data.timestamps)
                 ):
+                    extras.timestamp = timestamp
                     if receives_extras:
-                        output_values.append(
-                            func(
-                                value,
-                                MapExtras(
-                                    index_key=index_key,
-                                    timestamp=timestamp,
-                                    feature_name=feature_schema.name,
-                                ),
-                            )
-                        )
+                        output_values[i] = func(value, extras)  # type: ignore
                     else:
-                        assert len(signature(func).parameters) == 1
-                        output_values.append(func(value))
+                        output_values[i] = func(value)  # type: ignore
                 try:
                     output_arr = np.array(
                         output_values, dtype=tp_dtype_to_np_dtype(output_dtype)
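The change to `map.py` replaces a per-value `MapExtras` construction and `list.append` with a single reused object and a preallocated output list. The following sketch mirrors both loop shapes side by side to show that they produce the same outputs while constructing a different number of objects. `CountingExtras`, `map_fresh`, and `map_shared` are hypothetical names invented for this illustration and do not exist in Temporian.

```python
from typing import Any, Callable, List, Sequence, Tuple


class CountingExtras:
    """Hypothetical stand-in for MapExtras that counts its constructions."""

    created = 0

    def __init__(
        self, index_key: Tuple[Any, ...], timestamp: float, feature_name: str
    ) -> None:
        CountingExtras.created += 1
        self.index_key = index_key
        self.timestamp = timestamp
        self.feature_name = feature_name


def map_fresh(
    func: Callable[[Any, CountingExtras], Any],
    values: Sequence[Any],
    timestamps: Sequence[float],
) -> List[Any]:
    """Loop shape before the commit: one new extras object per value."""
    out: List[Any] = []
    for value, timestamp in zip(values, timestamps):
        out.append(func(value, CountingExtras((), timestamp, "f")))
    return out


def map_shared(
    func: Callable[[Any, CountingExtras], Any],
    values: Sequence[Any],
    timestamps: Sequence[float],
) -> List[Any]:
    """Loop shape after the commit: one extras object, updated in place."""
    extras = CountingExtras((), 0.0, "f")
    out: List[Any] = [None] * len(values)  # preallocated, filled by index
    for i, (value, timestamp) in enumerate(zip(values, timestamps)):
        extras.timestamp = timestamp
        out[i] = func(value, extras)
    return out


def add_timestamp(value: float, extras: CountingExtras) -> float:
    return value + extras.timestamp


values = [1.0, 2.0, 3.0]
timestamps = [10.0, 20.0, 30.0]

CountingExtras.created = 0
fresh_out = map_fresh(add_timestamp, values, timestamps)
fresh_created = CountingExtras.created

CountingExtras.created = 0
shared_out = map_shared(add_timestamp, values, timestamps)
shared_created = CountingExtras.created
```

In this sketch the two variants return identical lists, but the first builds one extras object per value and the second builds exactly one. Whether that yields a measurable speedup depends on the workload; the commit itself states no figures.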
