From ce7e764860ad4ac4f06543e9071ea44ebc231d58 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 8 Nov 2024 08:17:43 -0500 Subject: [PATCH 01/17] fixed repr --- pynapple/core/time_series.py | 59 +++++++++++++++++++++++++++--------- 1 file changed, 45 insertions(+), 14 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index c61b3654..88822967 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -101,6 +101,7 @@ def __init__(self, t, d, time_units="s", time_support=None, load_array=True): ) self.dtype = self.values.dtype + self._load_array = load_array def __setitem__(self, key, value): """setter for time series""" @@ -164,7 +165,7 @@ def __array_ufunc__(self, ufunc, method, *args, **kwargs): if isinstance(out, np.ndarray) or is_array_like(out): if out.shape[0] == self.index.shape[0]: - kwargs = {} + kwargs = {"load_array": self._load_array} if hasattr(self, "columns"): kwargs["columns"] = self.columns if hasattr(self, "_metadata"): @@ -212,7 +213,7 @@ def __array_function__(self, func, types, args, kwargs): # if out.ndim > self.ndim: # return out if out.shape[0] == self.index.shape[0]: - kwargs = {} + kwargs = {"load_array": self._load_array} if ( (self.ndim == 2) and (out.ndim == 2) @@ -770,16 +771,25 @@ def create_str(array): _str_ = [] if self.shape[0] > max_rows: n_rows = max_rows // 2 - for i, array in zip(self.index[0:n_rows], self.values[0:n_rows]): + top_rows = ( + self.values[0:n_rows].compute() if hasattr(self.values, "compute") + else self.values[:n_rows] + ) + bottom_rows = ( + self.values[self.values.shape[0] - n_rows : self.values.shape[0]].compute() if hasattr(self.values, "compute") + else self.values[self.values.shape[0] - n_rows : self.values.shape[0]] + ) + for i, array in zip(self.index[0:n_rows], top_rows): _str_.append([i, create_str(array)]) _str_.append(["...", ""]) for i, array in zip( self.index[-n_rows:], - self.values[self.values.shape[0] - n_rows : self.values.shape[0]], + bottom_rows, ): _str_.append([i, create_str(array)]) else: - for i, array in zip(self.index, self.values): + rows = self.values.compute() if hasattr(self.values, "compute") else self.values + for i, array in zip(self.index, rows): _str_.append([i, create_str(array)]) return tabulate(_str_, headers=headers, colalign=("left",)) + "\n" + bottom @@ -794,6 +804,7 @@ def __getitem__(self, key): "When indexing with a Tsd, it must contain boolean values" ) output = self.values[key.values] + output = output.compute() if hasattr(output, "compute") else output index = self.index[key.values] elif isinstance(key, tuple): if any( @@ -804,9 +815,11 @@ def __getitem__(self, key): ) key = tuple(k.values if isinstance(k, Tsd) else k for k in key) output = self.values.__getitem__(key) + output = output.compute() if hasattr(output, "compute") else output index = self.index.__getitem__(key[0]) else: output = self.values.__getitem__(key) + output = output.compute() if hasattr(output, "compute") else output index = self.index.__getitem__(key) if isinstance(index, Number): @@ -977,12 +990,20 @@ def __repr__(self): if len(self) > max_rows: n_rows = max_rows // 2 ends = np.array([end] * n_rows) + top_rows = ( + self.values[0:n_rows, 0:max_cols].compute() if hasattr(self.values, "compute") + else self.values[0:n_rows, 0:max_cols] + ) + bottom_rows = ( + self.values[-n_rows:, 0:max_cols].compute() if hasattr(self.values, "compute") + else self.values[-n_rows:, 0:max_cols] + ) table = np.vstack( ( np.hstack( ( self.index[0:n_rows, None], - np.round(self.values[0:n_rows, 0:max_cols], 5), + np.round(top_rows, 5), ends, ), dtype=object, @@ -998,7 +1019,7 @@ def __repr__(self): np.hstack( ( self.index[-n_rows:, None], - np.round(self.values[-n_rows:, 0:max_cols], 5), + np.round(bottom_rows, 5), ends, ), dtype=object, @@ -1007,10 +1028,14 @@ def __repr__(self): ) else: ends = np.array([end] * len(self)) + rows = ( + self.values[:, 0:max_cols].compute() if hasattr(self.values, "compute") else + self.values[:, 0:max_cols] + ) table = np.hstack( ( self.index[:, None], - np.round(self.values[:, 0:max_cols], 5), + np.round(rows, 5), ends, ), dtype=object, @@ -1114,6 +1139,7 @@ def __getitem__(self, key, *args, **kwargs): key = (slice(None, None, None), key) output = self.values.__getitem__(key) + output = output.compute() if hasattr(output, "compute") else output columns = self.columns if isinstance(key, tuple): @@ -1322,14 +1348,20 @@ def __repr__(self): if len(self) > max_rows: n_rows = max_rows // 2 table = [] - for i, v in zip(self.index[0:n_rows], self.values[0:n_rows]): + top_rows = ( + self.values[0:n_rows].compute() if hasattr(self.values, "compute") + else self.values[0:n_rows] + ) + bottom_rows = ( + self.values[self.values.shape[0] - n_rows : self.values.shape[0]].compute() if hasattr(self.values, "compute") + else self.values[self.values.shape[0] - n_rows : self.values.shape[0]] + ) + for i, v in zip(self.index[0:n_rows], top_rows): table.append([i, v]) table.append(["..."]) for i, v in zip( self.index[-n_rows:], - self.values[ - self.values.shape[0] - n_rows : self.values.shape[0] - ], + bottom_rows, ): table.append([i, v]) @@ -1384,6 +1416,7 @@ def __getitem__(self, key, *args, **kwargs): key = key.d output = self.values.__getitem__(key) + output = output.compute() if hasattr(output, "compute") else output if isinstance(key, tuple): index = self.index.__getitem__(key[0]) @@ -1638,8 +1671,6 @@ def __init__(self, t, time_units="s", time_support=None): def __repr__(self): upper = "Time (s)" - - max_rows = 2 rows = _get_terminal_size()[1] max_rows = np.maximum(rows - 10, 2) From 4ef8b2e2402a05e33f0e4c3261aef8505b26e45b Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 8 Nov 2024 08:18:00 -0500 Subject: [PATCH 02/17] formatted --- pynapple/core/time_series.py | 41 ++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 88822967..16a49e41 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -772,12 +772,18 @@ def create_str(array): if self.shape[0] > max_rows: n_rows = max_rows // 2 top_rows = ( - self.values[0:n_rows].compute() if hasattr(self.values, "compute") + self.values[0:n_rows].compute() + if hasattr(self.values, "compute") else self.values[:n_rows] ) bottom_rows = ( - self.values[self.values.shape[0] - n_rows : self.values.shape[0]].compute() if hasattr(self.values, "compute") - else self.values[self.values.shape[0] - n_rows : self.values.shape[0]] + self.values[ + self.values.shape[0] - n_rows : self.values.shape[0] + ].compute() + if hasattr(self.values, "compute") + else self.values[ + self.values.shape[0] - n_rows : self.values.shape[0] + ] ) for i, array in zip(self.index[0:n_rows], top_rows): _str_.append([i, create_str(array)]) @@ -788,7 +794,11 @@ def create_str(array): ): _str_.append([i, create_str(array)]) else: - rows = self.values.compute() if hasattr(self.values, "compute") else self.values + rows = ( + self.values.compute() + if hasattr(self.values, "compute") + else self.values + ) for i, array in zip(self.index, rows): _str_.append([i, create_str(array)]) @@ -991,11 +1001,13 @@ def __repr__(self): n_rows = max_rows // 2 ends = np.array([end] * n_rows) top_rows = ( - self.values[0:n_rows, 0:max_cols].compute() if hasattr(self.values, "compute") + self.values[0:n_rows, 0:max_cols].compute() + if hasattr(self.values, "compute") else self.values[0:n_rows, 0:max_cols] ) bottom_rows = ( - self.values[-n_rows:, 0:max_cols].compute() if hasattr(self.values, "compute") + self.values[-n_rows:, 0:max_cols].compute() + if hasattr(self.values, "compute") else self.values[-n_rows:, 0:max_cols] ) table = np.vstack( @@ -1029,8 +1041,9 @@ def __repr__(self): else: ends = np.array([end] * len(self)) rows = ( - self.values[:, 0:max_cols].compute() if hasattr(self.values, "compute") else - self.values[:, 0:max_cols] + self.values[:, 0:max_cols].compute() + if hasattr(self.values, "compute") + else self.values[:, 0:max_cols] ) table = np.hstack( ( @@ -1349,12 +1362,18 @@ def __repr__(self): n_rows = max_rows // 2 table = [] top_rows = ( - self.values[0:n_rows].compute() if hasattr(self.values, "compute") + self.values[0:n_rows].compute() + if hasattr(self.values, "compute") else self.values[0:n_rows] ) bottom_rows = ( - self.values[self.values.shape[0] - n_rows : self.values.shape[0]].compute() if hasattr(self.values, "compute") - else self.values[self.values.shape[0] - n_rows : self.values.shape[0]] + self.values[ + self.values.shape[0] - n_rows : self.values.shape[0] + ].compute() + if hasattr(self.values, "compute") + else self.values[ + self.values.shape[0] - n_rows : self.values.shape[0] + ] ) for i, v in zip(self.index[0:n_rows], top_rows): table.append([i, v]) From 212f83de2dac8da45e3917259cd8e11b4e37b20d Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 10 Dec 2024 11:39:55 -0500 Subject: [PATCH 03/17] added dask array compat --- pynapple/core/time_series.py | 75 ++++++------ pyproject.toml | 3 +- tests/test_lazy_loading.py | 222 +++++++++++++++++++++++++++++++++++ 3 files changed, 261 insertions(+), 39 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index bdf53485..5e8df475 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -19,6 +19,7 @@ import importlib import warnings from numbers import Number +from typing import Callable import numpy as np import pandas as pd @@ -63,6 +64,27 @@ def _get_class(data): return TsdTensor +def _initialize_tsd_output(inp, out): + + if isinstance(out, np.ndarray) or is_array_like(out): + # # if dims increased in any case, we can't return safely a time series + # if out.ndim > self.ndim: + # return out + if out.shape[0] == inp.index.shape[0]: + kwargs = {"load_array": inp._load_array} + if (inp.ndim == 2) and (out.ndim == 2) and (out.shape[1] == inp.shape[1]): + # only pass columns and metadata if number of columns is preserved + if hasattr(inp, "columns"): + kwargs["columns"] = inp.columns + if hasattr(inp, "_metadata"): + kwargs["metadata"] = inp._metadata + return _get_class(out)( + t=inp.index, d=out, time_support=inp.time_support, **kwargs + ) + + return out + + class _BaseTsd(_Base, NDArrayOperatorsMixin, abc.ABC): """ Abstract base class for time series objects. @@ -125,6 +147,19 @@ def method(*args, **kwargs): return np_func(self, *args, **kwargs) return method + if name not in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"): + # apply array specific methods + attr = getattr(self.d, name, None) + + if isinstance(attr, Callable): + + def method(*args, **kwargs): + out = attr(*args, **kwargs) + return _initialize_tsd_output(self, out) + + return method + elif attr: + return attr raise AttributeError( "Time series object does not have the attribute {}".format(name) @@ -167,20 +202,7 @@ def __array_ufunc__(self, ufunc, method, *args, **kwargs): else: out = ufunc(*new_args, **kwargs) - if isinstance(out, np.ndarray) or is_array_like(out): - if out.shape[0] == self.index.shape[0]: - kwargs = {"load_array": self._load_array} - if hasattr(self, "columns"): - kwargs["columns"] = self.columns - if hasattr(self, "_metadata"): - kwargs["metadata"] = self._metadata - return _get_class(out)( - t=self.index, d=out, time_support=self.time_support, **kwargs - ) - else: - return out - else: - return out + return _initialize_tsd_output(self, out) else: return NotImplemented @@ -211,30 +233,7 @@ def __array_function__(self, func, types, args, kwargs): new_args.append(a) out = func._implementation(*new_args, **kwargs) - - if isinstance(out, np.ndarray) or is_array_like(out): - # # if dims increased in any case, we can't return safely a time series - # if out.ndim > self.ndim: - # return out - if out.shape[0] == self.index.shape[0]: - kwargs = {"load_array": self._load_array} - if ( - (self.ndim == 2) - and (out.ndim == 2) - and (out.shape[1] == self.shape[1]) - ): - # only pass columns and metadata if number of columns is preserved - if hasattr(self, "columns"): - kwargs["columns"] = self.columns - if hasattr(self, "_metadata"): - kwargs["metadata"] = self._metadata - return _get_class(out)( - t=self.index, d=out, time_support=self.time_support, **kwargs - ) - else: - return out - else: - return out + return _initialize_tsd_output(self, out) def as_array(self): """ diff --git a/pyproject.toml b/pyproject.toml index 00e0a8f0..f5ba5afd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,7 +86,8 @@ docs = [ "matplotlib", "seaborn", "zarr", - "dandi" + "dandi", + "dask", ] dandi = [ "dandi", # Dandi package diff --git a/tests/test_lazy_loading.py b/tests/test_lazy_loading.py index 9695497d..8078c738 100644 --- a/tests/test_lazy_loading.py +++ b/tests/test_lazy_loading.py @@ -1,17 +1,78 @@ import warnings from contextlib import nullcontext as does_not_raise from pathlib import Path +from tempfile import TemporaryDirectory +import dask.array as da import h5py import numpy as np import pandas as pd import pytest +import zarr from pynwb.testing.mock.base import mock_TimeSeries from pynwb.testing.mock.file import mock_NWBFile import pynapple as nap +@pytest.fixture +def dask_array_tsdframe(): + """Fixture for a Dask array.""" + array = da.random.random((100, 100), chunks=(10, 10)) + return array + + +@pytest.fixture +def dask_array_tsd(): + """Fixture for a Dask array.""" + array = da.random.random((100,), chunks=(10,)) + return array + + +@pytest.fixture +def dask_array_tsdtensor(): + """Fixture for a Dask array.""" + array = da.random.random((100, 10, 5), chunks=(10, 1, 2)) + return array + + +@pytest.fixture +def zarr_tsd(): + """Fixture for a Zarr array.""" + with TemporaryDirectory() as tmpdir: + store = zarr.DirectoryStore(tmpdir) + root = zarr.open(store, mode="w") + array = root.create_dataset("data", shape=(100,), chunks=(10,), dtype="f8") + array[:] = np.random.random((100,)) + yield array + + +@pytest.fixture +def zarr_tsdframe(): + """Fixture for a Zarr array.""" + with TemporaryDirectory() as tmpdir: + store = zarr.DirectoryStore(tmpdir) + root = zarr.open(store, mode="w") + array = root.create_dataset( + "data", shape=(100, 100), chunks=(10, 10), dtype="f8" + ) + array[:] = np.random.random((100, 100)) + yield array + + +@pytest.fixture +def zarr_tsdtensor(): + """Fixture for a Zarr array.""" + with TemporaryDirectory() as tmpdir: + store = zarr.DirectoryStore(tmpdir) + root = zarr.open(store, mode="w") + array = root.create_dataset( + "data", shape=(100, 10, 2), chunks=(10, 9, 1), dtype="f8" + ) + array[:] = np.random.random((100, 10, 2)) + yield array + + @pytest.mark.parametrize( "time, data, expectation", [ @@ -301,3 +362,164 @@ def test_tsgroup_no_warnings(tmp_path): # default fixture # file_path = Path(f'data_{k}.h5') # if file_path.exists(): # file_path.unlink() + + +def test_dask_lazy_loading_tsd(dask_array_tsd): + tsd = nap.Tsd( + t=np.arange(dask_array_tsd.shape[0]), d=dask_array_tsd, load_array=False + ) + assert isinstance(tsd.d, da.Array) + assert isinstance(tsd.restrict(nap.IntervalSet(0, 10)).d, np.ndarray) + repr(tsd) + assert isinstance(tsd.d, da.Array) + assert isinstance(tsd[1:10].d, np.ndarray) + tsd = nap.Tsd( + t=np.arange(dask_array_tsd.shape[0]), d=dask_array_tsd, load_array=True + ) + assert isinstance(tsd.d, np.ndarray) + + +def test_dask_lazy_compute_tsd(dask_array_tsd): + tsd = nap.Tsd( + t=np.arange(dask_array_tsd.shape[0]), d=dask_array_tsd, load_array=False + ) + tsd = tsd + 1 + assert isinstance(tsd.d, da.Array) + assert isinstance(tsd[:10].d, np.ndarray) + assert tsd[:10]._load_array is True + + out = tsd.compute() + assert isinstance(out.d, np.ndarray) + assert isinstance(tsd.chunks, tuple) + assert tsd._load_array is False + + out2 = tsd.map_blocks(np.exp) + assert isinstance(out2.d, da.Array) + assert out2._load_array is False + + assert isinstance(np.exp(tsd).d, da.Array) + + +def test_dask_lazy_loading_tsdframe(dask_array_tsdframe): + tsdframe = nap.TsdFrame( + t=np.arange(dask_array_tsdframe.shape[0]), + d=dask_array_tsdframe, + load_array=False, + ) + assert isinstance(tsdframe.d, da.Array) + assert isinstance(tsdframe.restrict(nap.IntervalSet(0, 10)).d, np.ndarray) + repr(tsdframe) + assert isinstance(tsdframe.d, da.Array) + assert isinstance(tsdframe[1:10].d, np.ndarray) + assert isinstance(tsdframe.loc[1].d, np.ndarray) + tsdframe = nap.TsdFrame( + t=np.arange(dask_array_tsdframe.shape[0]), + d=dask_array_tsdframe, + load_array=True, + ) + assert isinstance(tsdframe.d, np.ndarray) + + +def test_dask_lazy_compute_tsdframe(dask_array_tsdframe): + tsdframe = nap.TsdFrame( + t=np.arange(dask_array_tsdframe.shape[0]), + d=dask_array_tsdframe, + load_array=False, + ) + tsdframe = tsdframe**2 + assert isinstance(tsdframe.d, da.Array) + assert isinstance(tsdframe[:10].d, np.ndarray) + assert tsdframe[:10]._load_array is True + + out = tsdframe.compute() + assert isinstance(out.d, np.ndarray) + assert isinstance(tsdframe.chunks, tuple) + assert tsdframe._load_array is False + + out2 = tsdframe.map_blocks(np.exp) + assert isinstance(out2.d, da.Array) + assert out2._load_array is False + assert isinstance(np.exp(tsdframe).d, da.Array) + + +def test_dask_lazy_loading_tsdtensor(dask_array_tsdtensor): + tsdtensor = nap.TsdTensor( + t=np.arange(dask_array_tsdtensor.shape[0]), + d=dask_array_tsdtensor, + load_array=False, + ) + assert isinstance(tsdtensor.d, da.Array) + assert isinstance(tsdtensor.restrict(nap.IntervalSet(0, 10)).d, np.ndarray) + repr(tsdtensor) + assert isinstance(tsdtensor.d, da.Array) + assert isinstance(tsdtensor[1:10].d, np.ndarray) + tsdtensor = nap.TsdTensor( + t=np.arange(dask_array_tsdtensor.shape[0]), + d=dask_array_tsdtensor, + load_array=True, + ) + assert isinstance(tsdtensor.d, np.ndarray) + + +def test_dask_lazy_compute_tsdtensor(dask_array_tsdtensor): + tsdtensor = nap.TsdTensor( + t=np.arange(dask_array_tsdtensor.shape[0]), + d=dask_array_tsdtensor, + load_array=False, + ) + tsdtensor = tsdtensor + 1 + assert isinstance(tsdtensor.d, da.Array) + assert isinstance(tsdtensor[:10].d, np.ndarray) + assert tsdtensor[:10]._load_array is True + + out = tsdtensor.compute() + assert isinstance(out.d, np.ndarray) + assert isinstance(tsdtensor.chunks, tuple) + assert tsdtensor._load_array is False + + out2 = tsdtensor.map_blocks(np.exp) + assert isinstance(out2.d, da.Array) + assert out2._load_array is False + + assert isinstance(np.exp(tsdtensor).d, da.Array) + + +def test_lazy_load_zarr_tsd(zarr_tsd): + tsd = nap.Tsd(t=np.arange(zarr_tsd.shape[0]), d=zarr_tsd, load_array=False) + assert isinstance(tsd.d, zarr.Array) + assert isinstance(tsd.restrict(nap.IntervalSet(0, 10)).d, np.ndarray) + repr(tsd) + assert isinstance(tsd.d, zarr.Array) + assert isinstance(tsd[1:10].d, np.ndarray) + tsd = nap.TsdFrame(t=np.arange(zarr_tsd.shape[0]), d=zarr_tsd, load_array=True) + assert isinstance(tsd.d, np.ndarray) + + +def test_lazy_load_zarr_tsdframe(zarr_tsdframe): + tsdframe = nap.TsdFrame( + t=np.arange(zarr_tsdframe.shape[0]), d=zarr_tsdframe, load_array=False + ) + assert isinstance(tsdframe.d, zarr.Array) + assert isinstance(tsdframe.restrict(nap.IntervalSet(0, 10)).d, np.ndarray) + repr(tsdframe) + assert isinstance(tsdframe.d, zarr.Array) + assert isinstance(tsdframe[1:10].d, np.ndarray) + tsdframe = nap.TsdFrame( + t=np.arange(zarr_tsdframe.shape[0]), d=zarr_tsdframe, load_array=True + ) + assert isinstance(tsdframe.d, np.ndarray) + + +def test_lazy_load_zarr_tsdtensor(zarr_tsdtensor): + tsdtensor = nap.TsdTensor( + t=np.arange(zarr_tsdtensor.shape[0]), d=zarr_tsdtensor, load_array=False + ) + assert isinstance(tsdtensor.d, zarr.Array) + assert isinstance(tsdtensor.restrict(nap.IntervalSet(0, 10)).d, np.ndarray) + repr(tsdtensor) + assert isinstance(tsdtensor.d, zarr.Array) + assert isinstance(tsdtensor[1:10].d, np.ndarray) + tsdtensor = nap.TsdTensor( + t=np.arange(zarr_tsdtensor.shape[0]), d=zarr_tsdtensor, load_array=True + ) + assert isinstance(tsdtensor.d, np.ndarray) From 386ee03a6fb9efa837b2aa15c457348e1c8ddce5 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 10 Dec 2024 12:04:26 -0500 Subject: [PATCH 04/17] add dask to dev deps --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index f5ba5afd..01b03138 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,6 +94,7 @@ dandi = [ "fsspec", "requests", "aiohttp", + "dask", ] From cb2239333f5db25dcb609c90202f689e63d73956 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 10 Dec 2024 12:05:01 -0500 Subject: [PATCH 05/17] removed auto added import --- tests/test_misc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_misc.py b/tests/test_misc.py index e2f4469f..4b88d3b3 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -16,7 +16,6 @@ import pytest import pynapple as nap -from docs.generated.api_guide.tutorial_pynapple_nwb import n_channels # look for tests folder path = Path(__file__).parent From 6b5e594e9d756a2c291a227955b04ff5c8847050 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 10 Dec 2024 12:07:58 -0500 Subject: [PATCH 06/17] added deps --- pyproject.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 01b03138..082e7525 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ dev = [ "pytest", # Testing framework "flake8", # Code linter "coverage", # Test coverage measurement + "dask", ] doc = [ "matplotlib", @@ -72,7 +73,8 @@ doc = [ "myst-nb", "dandi", "sphinx-autobuild", - "sphinx-contributors" + "sphinx-contributors", + "dask", # "sphinx-exec-code" ] docs = [ @@ -94,7 +96,6 @@ dandi = [ "fsspec", "requests", "aiohttp", - "dask", ] From 84cf039c84ae75cd9d0a74408be3d11d9d01e976 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 10 Dec 2024 12:13:45 -0500 Subject: [PATCH 07/17] added deps --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 082e7525..49bb7458 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ doc = [ "sphinx-autobuild", "sphinx-contributors", "dask", + "zarr", # "sphinx-exec-code" ] docs = [ From dc8d2db225c10c36a6328b48f586ecee0c4f5de0 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 10 Dec 2024 12:16:22 -0500 Subject: [PATCH 08/17] added deps --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 49bb7458..e30bf0e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dev = [ "flake8", # Code linter "coverage", # Test coverage measurement "dask", + "zarr", ] doc = [ "matplotlib", From 766a46967e45fa8fa593858b69b96522e87e07a0 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 10 Dec 2024 12:17:39 -0500 Subject: [PATCH 09/17] remove redundant dep --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e30bf0e0..1685f0cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,6 @@ doc = [ "sphinx-autobuild", "sphinx-contributors", "dask", - "zarr", # "sphinx-exec-code" ] docs = [ From 8339e40eb3e8a4a6ab200222dd70a0500f505e36 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 10 Dec 2024 17:11:37 -0500 Subject: [PATCH 10/17] replaced all __class__(...) calls to enforce that the metadata is kept --- pynapple/core/base_class.py | 30 ++++++++-------- pynapple/core/time_series.py | 67 +++++++++++++++-------------------- pynapple/core/utils.py | 6 ++-- pynapple/process/filtering.py | 5 +-- tests/test_abstract_tsd.py | 3 ++ tests/test_lazy_loading.py | 6 ++-- 6 files changed, 52 insertions(+), 65 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 37a5a779..22a650b5 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -57,6 +57,14 @@ def __init__(self, t, time_units="s", time_support=None): self.rate = np.nan self.time_support = IntervalSet(start=[], end=[]) + @abc.abstractmethod + def _define_instance(self, time, iset, data=None, **kwargs): + """Return a new class instance. + + Grab "columns", "metadata" and other and other + """ + pass + @property def t(self): """The time index of the time series""" @@ -368,25 +376,15 @@ def restrict(self, iset): ends = iset.end idx = _restrict(time_array, starts, ends) - - kwargs = {} - if hasattr(self, "columns"): - kwargs["columns"] = self.columns - - if hasattr(self, "_metadata"): - kwargs["metadata"] = self._metadata - - if hasattr(self, "values"): - data_array = self.values - return self.__class__( - t=time_array[idx], d=data_array[idx], time_support=iset, **kwargs - ) - else: - return self.__class__(t=time_array[idx], time_support=iset) + data = None if not hasattr(self, "values") else self.values[idx] + return self._define_instance(time_array[idx] , iset, data=data) def copy(self): """Copy the data, index and time support""" - return self.__class__(t=self.index.copy(), time_support=self.time_support) + data = getattr(self, "values", None) + if data is not None: + data = data.copy() if hasattr(data, "copy") else data[:].copy() + return self._define_instance(self.index.copy(), self.time_support, data=data) def find_support(self, min_gap, time_units="s"): """ diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 5e8df475..88018ce4 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -129,6 +129,26 @@ def __init__(self, t, d, time_units="s", time_support=None, load_array=True): self.dtype = self.values.dtype self._load_array = load_array + + def _define_instance(self, time, iset, data=None, **kwargs): + """ + Define a new class instance. + + Optional parameters for initialization are either passed to the function or are grabbed from self. + """ + for key in ["columns", "metadata", "load_array"]: + if hasattr(self, key): + kwargs[key] = kwargs.get(key, getattr(self, key)) + return self.__class__( + t=time, d=data, time_support=iset, **kwargs + ) + + + @property + def load_array(self): + """Read-only property load-array.""" + return self._load_array + def __setitem__(self, key, value): """setter for time series""" if isinstance(key, _BaseTsd): @@ -265,12 +285,6 @@ def to_numpy(self): """ return np.asarray(self.values) - def copy(self): - """Copy the data, index and time support""" - return self.__class__( - t=self.index.copy(), d=self.values[:].copy(), time_support=self.time_support - ) - def value_from(self, data, ep=None): """ Replace the value with the closest value from Tsd/TsdFrame/TsdTensor argument @@ -314,7 +328,7 @@ def value_from(self, data, ep=None): ), "First argument should be an instance of Tsd, TsdFrame or TsdTensor" t, d, time_support, kwargs = super().value_from(data, ep) - return data.__class__(t=t, d=d, time_support=time_support, **kwargs) + return data._define_instance(t, time_support, data=d, **kwargs) def count(self, *args, dtype=None, **kwargs): """ @@ -428,13 +442,7 @@ def bin_average(self, bin_size, ep=None, time_units="s"): t, d = _bin_average(time_array, data_array, starts, ends, bin_size) - kwargs = {} - if hasattr(self, "columns"): - kwargs["columns"] = self.columns - if hasattr(self, "_metadata"): - kwargs["metadata"] = self._metadata - - return self.__class__(t=t, d=d, time_support=ep, **kwargs) + return self._define_instance(t, ep, data=d) def dropna(self, update_time_support=True): """Drop every rows containing NaNs. By default, the time support is updated to start and end around the time points that are non NaNs. @@ -468,13 +476,7 @@ def dropna(self, update_time_support=True): else: ep = self.time_support - kwargs = {} - if hasattr(self, "columns"): - kwargs["columns"] = self.columns - if hasattr(self, "_metadata"): - kwargs["metadata"] = self._metadata - - return self.__class__(t=t, d=d, time_support=ep, **kwargs) + return self._define_instance(t, ep, data=d) def convolve(self, array, ep=None, trim="both"): """Return the discrete linear convolution of the time series with a one dimensional sequence. @@ -698,12 +700,8 @@ def interpolate(self, ts, ep=None, left=None, right=None): new_d[start : start + len(t), ...] = interpolated_values start += len(t) - kwargs_dict = dict(time_support=ep) - if hasattr(self, "columns"): - kwargs_dict["columns"] = self.columns - if hasattr(self, "_metadata"): - kwargs_dict["metadata"] = self._metadata - return self.__class__(t=new_t, d=new_d, **kwargs_dict) + + return self._define_instance(new_t, ep, data=new_d) class TsdTensor(_BaseTsd): @@ -1351,16 +1349,6 @@ def as_units(self, units="s"): df.columns = self.columns.copy() return df - def copy(self): - """Copy the data, index, time support, columns and metadata of the TsdFrame object.""" - return self.__class__( - t=self.index.copy(), - d=self.values[:].copy(), - time_support=self.time_support, - columns=self.columns.copy(), - metadata=self._metadata, - ) - def save(self, filename): """ Save TsdFrame object in npz format. The file will contain the timestamps, the @@ -2025,6 +2013,9 @@ def __init__(self, t, time_units="s", time_support=None): self.nap_class = self.__class__.__name__ self._initialized = True + def _define_instance(self, time, iset, data=None, **kwargs): + return self.__class__(t=time, time_support=iset) + def __repr__(self): upper = "Time (s)" rows = _get_terminal_size()[1] @@ -2130,7 +2121,7 @@ def value_from(self, data, ep=None): t, d, time_support, kwargs = super().value_from(data, ep) - return data.__class__(t, d, time_support=time_support, **kwargs) + return data._define_instance(t, time_support, data=d, **kwargs) def count(self, *args, dtype=None, **kwargs): """ diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index eb8410be..d0b83ce4 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -233,12 +233,10 @@ def _split_tsd(func, tsd, indices_or_sections, axis=0): if func in [np.split, np.array_split, np.vsplit] and axis == 0: out = func._implementation(tsd.values, indices_or_sections) index_list = np.split(tsd.index.values, indices_or_sections) - kwargs = {"columns": tsd.columns.values} if hasattr(tsd, "columns") else {} - return [tsd.__class__(t=t, d=d, **kwargs) for t, d in zip(index_list, out)] + return [tsd._define_instance(t, None, data=d) for t, d in zip(index_list, out)] elif func in [np.dsplit, np.hsplit]: out = func._implementation(tsd.values, indices_or_sections) - kwargs = {"columns": tsd.columns.values} if hasattr(tsd, "columns") else {} - return [tsd.__class__(t=tsd.index, d=d, **kwargs) for d in out] + return [tsd._define_instance(tsd.index, None, data=d) for d in out] else: return func._implementation(tsd.values, indices_or_sections, axis) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 01303e02..912c5242 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -89,10 +89,7 @@ def _compute_butterworth_filter( slc = data.get_slice(start=ep.start[0], end=ep.end[0]) out[slc] = sosfiltfilt(sos, data.d[slc], axis=0) - kwargs = dict(t=data.t, d=out, time_support=data.time_support) - if isinstance(data, nap.TsdFrame): - kwargs["columns"] = data.columns - return data.__class__(**kwargs) + return data._define_instance(data.t, data.time_support, data=out) def _compute_spectral_inversion(kernel): diff --git a/tests/test_abstract_tsd.py b/tests/test_abstract_tsd.py index 3a184dd3..709d1450 100644 --- a/tests/test_abstract_tsd.py +++ b/tests/test_abstract_tsd.py @@ -37,6 +37,9 @@ def __str__(self): def __repr__(self): return "In repr" + def _define_instance(self, time, iset, data=None, **kwargs): + pass + def test_create_atsd(): a = MyClass(t=np.arange(10), d=np.arange(10)) diff --git a/tests/test_lazy_loading.py b/tests/test_lazy_loading.py index 8078c738..efad83e8 100644 --- a/tests/test_lazy_loading.py +++ b/tests/test_lazy_loading.py @@ -369,7 +369,7 @@ def test_dask_lazy_loading_tsd(dask_array_tsd): t=np.arange(dask_array_tsd.shape[0]), d=dask_array_tsd, load_array=False ) assert isinstance(tsd.d, da.Array) - assert isinstance(tsd.restrict(nap.IntervalSet(0, 10)).d, np.ndarray) + assert isinstance(tsd.restrict(nap.IntervalSet(0, 10)).d, da.Array) repr(tsd) assert isinstance(tsd.d, da.Array) assert isinstance(tsd[1:10].d, np.ndarray) @@ -407,7 +407,7 @@ def test_dask_lazy_loading_tsdframe(dask_array_tsdframe): load_array=False, ) assert isinstance(tsdframe.d, da.Array) - assert isinstance(tsdframe.restrict(nap.IntervalSet(0, 10)).d, np.ndarray) + assert isinstance(tsdframe.restrict(nap.IntervalSet(0, 10)).d, da.Array) repr(tsdframe) assert isinstance(tsdframe.d, da.Array) assert isinstance(tsdframe[1:10].d, np.ndarray) @@ -449,7 +449,7 @@ def test_dask_lazy_loading_tsdtensor(dask_array_tsdtensor): load_array=False, ) assert isinstance(tsdtensor.d, da.Array) - assert isinstance(tsdtensor.restrict(nap.IntervalSet(0, 10)).d, np.ndarray) + assert isinstance(tsdtensor.restrict(nap.IntervalSet(0, 10)).d, da.Array) repr(tsdtensor) assert isinstance(tsdtensor.d, da.Array) assert isinstance(tsdtensor[1:10].d, np.ndarray) From d76c80e076a3d83a4d8d5bef753675ff30b3cac3 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 10 Dec 2024 18:01:17 -0500 Subject: [PATCH 11/17] added test for define instance --- tests/test_time_series.py | 60 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index e1064da9..1d6444ab 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -2207,3 +2207,63 @@ def test_get_slice_public(start, end, expected_slice, expected_array, ts): out_array = ts.t[out_slice] assert out_slice == expected_slice assert np.all(out_array == expected_array) + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"columns": [1, 2]}, + {"metadata": {"banana": [3, 4]}}, + {"load_array": False}, + { + "columns": ["a", "b"], + "metadata": {"banana": [3, 4]}, + "load_array": False + }, + ] +) +@pytest.mark.parametrize( + "tsd", + [ + nap.Ts(t=np.arange(10), time_support=nap.IntervalSet(0,15)), + nap.Ts(t=np.arange(10), time_support=nap.IntervalSet(0,15)), + nap.Tsd(t=np.arange(10), d=np.arange(10), time_support=nap.IntervalSet(0,15)), + nap.Tsd(t=np.arange(10), d=np.arange(10), time_support=nap.IntervalSet(0,15)), + nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 2)), time_support=nap.IntervalSet(0,15), columns=["a", "b"]), + nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 2)), time_support=nap.IntervalSet(0,15), metadata={"pineapple": [1, 2]}), + nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 2)), time_support=nap.IntervalSet(0,15), load_array=True), + nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 2)), time_support=nap.IntervalSet(0,15),load_array=True, columns=["a", "b"], metadata={"pineapple": [1, 2]}), + nap.TsdTensor(t=np.arange(10), d=np.zeros((10, 2, 3)), time_support=nap.IntervalSet(0,15)), + ] +) +def test_define_instance(tsd, kwargs): + t = tsd.t + d = getattr(tsd, "d", None) + iset = tsd.time_support + cols = kwargs.get("columns", None) + + # metadata index must be cols if provided. + # clear metadata if cols are provided to avoid errors + if (cols is not None) and ("metadata" not in kwargs): + kwargs["metadata"] = {} + + out = tsd._define_instance(t, iset, data=d, **kwargs) + + # check data + np.testing.assert_array_equal(out.t, t) + np.testing.assert_array_equal(out.time_support, iset) + if hasattr(tsd, "d"): + np.testing.assert_array_equal(out.d, d) + + # if TsdFrame check kwargs + if isinstance(tsd, nap.TsdFrame): + for key in ["columns", "load_array"]: + val = kwargs.get(key, getattr(tsd, key)) + assert np.all(val == getattr(out, key)) + # get expected metadata + meta = kwargs.get("metadata", getattr(tsd, "metadata")) + for key, val, in meta.items(): + assert np.all(out.metadata[key] == val) + + From 43ff23a4c8fdb96d84c7c2e949ace4c38efa3763 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 10 Dec 2024 18:02:06 -0500 Subject: [PATCH 12/17] linted --- pynapple/core/base_class.py | 2 +- pynapple/core/time_series.py | 6 +--- tests/test_time_series.py | 59 ++++++++++++++++++++++++------------ 3 files changed, 42 insertions(+), 25 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 22a650b5..8beaa5cc 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -377,7 +377,7 @@ def restrict(self, iset): idx = _restrict(time_array, starts, ends) data = None if not hasattr(self, "values") else self.values[idx] - return self._define_instance(time_array[idx] , iset, data=data) + return self._define_instance(time_array[idx], iset, data=data) def copy(self): """Copy the data, index and time support""" diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 88018ce4..3a38e0a1 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -129,7 +129,6 @@ def __init__(self, t, d, time_units="s", time_support=None, load_array=True): self.dtype = self.values.dtype self._load_array = load_array - def _define_instance(self, time, iset, data=None, **kwargs): """ Define a new class instance. @@ -139,10 +138,7 @@ def _define_instance(self, time, iset, data=None, **kwargs): for key in ["columns", "metadata", "load_array"]: if hasattr(self, key): kwargs[key] = kwargs.get(key, getattr(self, key)) - return self.__class__( - t=time, d=data, time_support=iset, **kwargs - ) - + return self.__class__(t=time, d=data, time_support=iset, **kwargs) @property def load_array(self): diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 1d6444ab..3bc4d38b 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -2216,26 +2216,46 @@ def test_get_slice_public(start, end, expected_slice, expected_array, ts): {"columns": [1, 2]}, {"metadata": {"banana": [3, 4]}}, {"load_array": False}, - { - "columns": ["a", "b"], - "metadata": {"banana": [3, 4]}, - "load_array": False - }, - ] + {"columns": ["a", "b"], "metadata": {"banana": [3, 4]}, "load_array": False}, + ], ) @pytest.mark.parametrize( "tsd", [ - nap.Ts(t=np.arange(10), time_support=nap.IntervalSet(0,15)), - nap.Ts(t=np.arange(10), time_support=nap.IntervalSet(0,15)), - nap.Tsd(t=np.arange(10), d=np.arange(10), time_support=nap.IntervalSet(0,15)), - nap.Tsd(t=np.arange(10), d=np.arange(10), time_support=nap.IntervalSet(0,15)), - nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 2)), time_support=nap.IntervalSet(0,15), columns=["a", "b"]), - nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 2)), time_support=nap.IntervalSet(0,15), metadata={"pineapple": [1, 2]}), - nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 2)), time_support=nap.IntervalSet(0,15), load_array=True), - nap.TsdFrame(t=np.arange(10), d=np.zeros((10, 2)), time_support=nap.IntervalSet(0,15),load_array=True, columns=["a", "b"], metadata={"pineapple": [1, 2]}), - nap.TsdTensor(t=np.arange(10), d=np.zeros((10, 2, 3)), time_support=nap.IntervalSet(0,15)), - ] + nap.Ts(t=np.arange(10), time_support=nap.IntervalSet(0, 15)), + nap.Ts(t=np.arange(10), time_support=nap.IntervalSet(0, 15)), + nap.Tsd(t=np.arange(10), d=np.arange(10), time_support=nap.IntervalSet(0, 15)), + nap.Tsd(t=np.arange(10), d=np.arange(10), time_support=nap.IntervalSet(0, 15)), + nap.TsdFrame( + t=np.arange(10), + d=np.zeros((10, 2)), + time_support=nap.IntervalSet(0, 15), + columns=["a", "b"], + ), + nap.TsdFrame( + t=np.arange(10), + d=np.zeros((10, 2)), + time_support=nap.IntervalSet(0, 15), + metadata={"pineapple": [1, 2]}, + ), + nap.TsdFrame( + t=np.arange(10), + d=np.zeros((10, 2)), + time_support=nap.IntervalSet(0, 15), + load_array=True, + ), + nap.TsdFrame( + t=np.arange(10), + d=np.zeros((10, 2)), + time_support=nap.IntervalSet(0, 15), + load_array=True, + columns=["a", "b"], + metadata={"pineapple": [1, 2]}, + ), + nap.TsdTensor( + t=np.arange(10), d=np.zeros((10, 2, 3)), time_support=nap.IntervalSet(0, 15) + ), + ], ) def test_define_instance(tsd, kwargs): t = tsd.t @@ -2263,7 +2283,8 @@ def test_define_instance(tsd, kwargs): assert np.all(val == getattr(out, key)) # get expected metadata meta = kwargs.get("metadata", getattr(tsd, "metadata")) - for key, val, in meta.items(): + for ( + key, + val, + ) in meta.items(): assert np.all(out.metadata[key] == val) - - From e7337e0fa2d448dc4a604c99588e33829f76ce9f Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 11 Dec 2024 09:51:09 -0500 Subject: [PATCH 13/17] fix class init --- pynapple/core/base_class.py | 31 +++++---- pynapple/core/time_series.py | 122 ++++++++++++---------------------- pynapple/core/utils.py | 6 +- pynapple/process/filtering.py | 5 +- tests/test_abstract_tsd.py | 3 + tests/test_misc.py | 1 - tests/test_time_series.py | 81 ++++++++++++++++++++++ 7 files changed, 146 insertions(+), 103 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 37a5a779..6f659a8d 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -57,6 +57,15 @@ def __init__(self, t, time_units="s", time_support=None): self.rate = np.nan self.time_support = IntervalSet(start=[], end=[]) + @abc.abstractmethod + def _define_instance(self, time, iset, data=None, **kwargs): + """Return a new class instance. + + Pass "columns", "metadata" and other attributes of self + to the new instance unless specified in kwargs. + """ + pass + @property def t(self): """The time index of the time series""" @@ -368,25 +377,15 @@ def restrict(self, iset): ends = iset.end idx = _restrict(time_array, starts, ends) - - kwargs = {} - if hasattr(self, "columns"): - kwargs["columns"] = self.columns - - if hasattr(self, "_metadata"): - kwargs["metadata"] = self._metadata - - if hasattr(self, "values"): - data_array = self.values - return self.__class__( - t=time_array[idx], d=data_array[idx], time_support=iset, **kwargs - ) - else: - return self.__class__(t=time_array[idx], time_support=iset) + data = None if not hasattr(self, "values") else self.values[idx] + return self._define_instance(time_array[idx], iset, data=data) def copy(self): """Copy the data, index and time support""" - return self.__class__(t=self.index.copy(), time_support=self.time_support) + data = getattr(self, "values", None) + if data is not None: + data = data.copy() if hasattr(data, "copy") else data[:].copy() + return self._define_instance(self.index.copy(), self.time_support, data=data) def find_support(self, min_gap, time_units="s"): """ diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 8f3faecc..d5a0ab8b 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -19,6 +19,7 @@ import importlib import warnings from numbers import Number +from typing import Callable import numpy as np import pandas as pd @@ -63,6 +64,27 @@ def _get_class(data): return TsdTensor +def _initialize_tsd_output(inp, out): + + if isinstance(out, np.ndarray) or is_array_like(out): + # # if dims increased in any case, we can't return safely a time series + # if out.ndim > self.ndim: + # return out + if out.shape[0] == inp.index.shape[0]: + kwargs = {"load_array": inp._load_array} + if (inp.ndim == 2) and (out.ndim == 2) and (out.shape[1] == inp.shape[1]): + # only pass columns and metadata if number of columns is preserved + if hasattr(inp, "columns"): + kwargs["columns"] = inp.columns + if hasattr(inp, "_metadata"): + kwargs["metadata"] = inp._metadata + return _get_class(out)( + t=inp.index, d=out, time_support=inp.time_support, **kwargs + ) + + return out + + class _BaseTsd(_Base, NDArrayOperatorsMixin, abc.ABC): """ Abstract base class for time series objects. @@ -106,6 +128,17 @@ def __init__(self, t, d, time_units="s", time_support=None, load_array=True): self.dtype = self.values.dtype + def _define_instance(self, time, iset, data=None, **kwargs): + """ + Define a new class instance. + + Optional parameters for initialization are either passed to the function or are grabbed from self. + """ + for key in ["columns", "metadata"]: + if hasattr(self, key): + kwargs[key] = kwargs.get(key, getattr(self, key)) + return self.__class__(t=time, d=data, time_support=iset, **kwargs) + def __setitem__(self, key, value): """setter for time series""" if isinstance(key, _BaseTsd): @@ -166,20 +199,7 @@ def __array_ufunc__(self, ufunc, method, *args, **kwargs): else: out = ufunc(*new_args, **kwargs) - if isinstance(out, np.ndarray) or is_array_like(out): - if out.shape[0] == self.index.shape[0]: - kwargs = {} - if hasattr(self, "columns"): - kwargs["columns"] = self.columns - if hasattr(self, "_metadata"): - kwargs["metadata"] = self._metadata - return _get_class(out)( - t=self.index, d=out, time_support=self.time_support, **kwargs - ) - else: - return out - else: - return out + return _initialize_tsd_output(self, out) else: return NotImplemented @@ -210,30 +230,7 @@ def __array_function__(self, func, types, args, kwargs): new_args.append(a) out = func._implementation(*new_args, **kwargs) - - if isinstance(out, np.ndarray) or is_array_like(out): - # # if dims increased in any case, we can't return safely a time series - # if out.ndim > self.ndim: - # return out - if out.shape[0] == self.index.shape[0]: - kwargs = {} - if ( - (self.ndim == 2) - and (out.ndim == 2) - and (out.shape[1] == self.shape[1]) - ): - # only pass columns and metadata if number of columns is preserved - if hasattr(self, "columns"): - kwargs["columns"] = self.columns - if hasattr(self, "_metadata"): - kwargs["metadata"] = self._metadata - return _get_class(out)( - t=self.index, d=out, time_support=self.time_support, **kwargs - ) - else: - return out - else: - return out + return _initialize_tsd_output(self, out) def as_array(self): """ @@ -265,12 +262,6 @@ def to_numpy(self): """ return np.asarray(self.values) - def copy(self): - """Copy the data, index and time support""" - return self.__class__( - t=self.index.copy(), d=self.values[:].copy(), time_support=self.time_support - ) - def value_from(self, data, ep=None): """ Replace the value with the closest value from Tsd/TsdFrame/TsdTensor argument @@ -314,7 +305,7 @@ def value_from(self, data, ep=None): ), "First argument should be an instance of Tsd, TsdFrame or TsdTensor" t, d, time_support, kwargs = super().value_from(data, ep) - return data.__class__(t=t, d=d, time_support=time_support, **kwargs) + return data._define_instance(t, time_support, data=d, **kwargs) def count(self, *args, dtype=None, **kwargs): """ @@ -428,13 +419,7 @@ def bin_average(self, bin_size, ep=None, time_units="s"): t, d = _bin_average(time_array, data_array, starts, ends, bin_size) - kwargs = {} - if hasattr(self, "columns"): - kwargs["columns"] = self.columns - if hasattr(self, "_metadata"): - kwargs["metadata"] = self._metadata - - return self.__class__(t=t, d=d, time_support=ep, **kwargs) + return self._define_instance(t, ep, data=d) def dropna(self, update_time_support=True): """Drop every rows containing NaNs. By default, the time support is updated to start and end around the time points that are non NaNs. @@ -468,13 +453,7 @@ def dropna(self, update_time_support=True): else: ep = self.time_support - kwargs = {} - if hasattr(self, "columns"): - kwargs["columns"] = self.columns - if hasattr(self, "_metadata"): - kwargs["metadata"] = self._metadata - - return self.__class__(t=t, d=d, time_support=ep, **kwargs) + return self._define_instance(t, ep, data=d) def convolve(self, array, ep=None, trim="both"): """Return the discrete linear convolution of the time series with a one dimensional sequence. @@ -698,12 +677,8 @@ def interpolate(self, ts, ep=None, left=None, right=None): new_d[start : start + len(t), ...] = interpolated_values start += len(t) - kwargs_dict = dict(time_support=ep) - if hasattr(self, "columns"): - kwargs_dict["columns"] = self.columns - if hasattr(self, "_metadata"): - kwargs_dict["metadata"] = self._metadata - return self.__class__(t=new_t, d=new_d, **kwargs_dict) + + return self._define_instance(new_t, ep, data=new_d) class TsdTensor(_BaseTsd): @@ -1313,16 +1288,6 @@ def as_units(self, units="s"): df.columns = self.columns.copy() return df - def copy(self): - """Copy the data, index, time support, columns and metadata of the TsdFrame object.""" - return self.__class__( - t=self.index.copy(), - d=self.values[:].copy(), - time_support=self.time_support, - columns=self.columns.copy(), - metadata=self._metadata, - ) - def save(self, filename): """ Save TsdFrame object in npz format. The file will contain the timestamps, the @@ -1974,10 +1939,11 @@ def __init__(self, t, time_units="s", time_support=None): self.nap_class = self.__class__.__name__ self._initialized = True + def _define_instance(self, time, iset, data=None, **kwargs): + return self.__class__(t=time, time_support=iset) + def __repr__(self): upper = "Time (s)" - - max_rows = 2 rows = _get_terminal_size()[1] max_rows = np.maximum(rows - 10, 2) @@ -2081,7 +2047,7 @@ def value_from(self, data, ep=None): t, d, time_support, kwargs = super().value_from(data, ep) - return data.__class__(t, d, time_support=time_support, **kwargs) + return data._define_instance(t, time_support, data=d, **kwargs) def count(self, *args, dtype=None, **kwargs): """ diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index eb8410be..d0b83ce4 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -233,12 +233,10 @@ def _split_tsd(func, tsd, indices_or_sections, axis=0): if func in [np.split, np.array_split, np.vsplit] and axis == 0: out = func._implementation(tsd.values, indices_or_sections) index_list = np.split(tsd.index.values, indices_or_sections) - kwargs = {"columns": tsd.columns.values} if hasattr(tsd, "columns") else {} - return [tsd.__class__(t=t, d=d, **kwargs) for t, d in zip(index_list, out)] + return [tsd._define_instance(t, None, data=d) for t, d in zip(index_list, out)] elif func in [np.dsplit, np.hsplit]: out = func._implementation(tsd.values, indices_or_sections) - kwargs = {"columns": tsd.columns.values} if hasattr(tsd, "columns") else {} - return [tsd.__class__(t=tsd.index, d=d, **kwargs) for d in out] + return [tsd._define_instance(tsd.index, None, data=d) for d in out] else: return func._implementation(tsd.values, indices_or_sections, axis) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 01303e02..912c5242 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -89,10 +89,7 @@ def _compute_butterworth_filter( slc = data.get_slice(start=ep.start[0], end=ep.end[0]) out[slc] = sosfiltfilt(sos, data.d[slc], axis=0) - kwargs = dict(t=data.t, d=out, time_support=data.time_support) - if isinstance(data, nap.TsdFrame): - kwargs["columns"] = data.columns - return data.__class__(**kwargs) + return data._define_instance(data.t, data.time_support, data=out) def _compute_spectral_inversion(kernel): diff --git a/tests/test_abstract_tsd.py b/tests/test_abstract_tsd.py index 3a184dd3..709d1450 100644 --- a/tests/test_abstract_tsd.py +++ b/tests/test_abstract_tsd.py @@ -37,6 +37,9 @@ def __str__(self): def __repr__(self): return "In repr" + def _define_instance(self, time, iset, data=None, **kwargs): + pass + def test_create_atsd(): a = MyClass(t=np.arange(10), d=np.arange(10)) diff --git a/tests/test_misc.py b/tests/test_misc.py index e2f4469f..4b88d3b3 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -16,7 +16,6 @@ import pytest import pynapple as nap -from docs.generated.api_guide.tutorial_pynapple_nwb import n_channels # look for tests folder path = Path(__file__).parent diff --git a/tests/test_time_series.py b/tests/test_time_series.py index e1064da9..3bc4d38b 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -2207,3 +2207,84 @@ def test_get_slice_public(start, end, expected_slice, expected_array, ts): out_array = ts.t[out_slice] assert out_slice == expected_slice assert np.all(out_array == expected_array) + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"columns": [1, 2]}, + {"metadata": {"banana": [3, 4]}}, + {"load_array": False}, + {"columns": ["a", "b"], "metadata": {"banana": [3, 4]}, "load_array": False}, + ], +) +@pytest.mark.parametrize( + "tsd", + [ + nap.Ts(t=np.arange(10), time_support=nap.IntervalSet(0, 15)), + nap.Ts(t=np.arange(10), time_support=nap.IntervalSet(0, 15)), + nap.Tsd(t=np.arange(10), d=np.arange(10), time_support=nap.IntervalSet(0, 15)), + nap.Tsd(t=np.arange(10), d=np.arange(10), time_support=nap.IntervalSet(0, 15)), + nap.TsdFrame( + t=np.arange(10), + d=np.zeros((10, 2)), + time_support=nap.IntervalSet(0, 15), + columns=["a", "b"], + ), + nap.TsdFrame( + t=np.arange(10), + d=np.zeros((10, 2)), + time_support=nap.IntervalSet(0, 15), + metadata={"pineapple": [1, 2]}, + ), + nap.TsdFrame( + t=np.arange(10), + d=np.zeros((10, 2)), + time_support=nap.IntervalSet(0, 15), + load_array=True, + ), + nap.TsdFrame( + t=np.arange(10), + d=np.zeros((10, 2)), + time_support=nap.IntervalSet(0, 15), + load_array=True, + columns=["a", "b"], + metadata={"pineapple": [1, 2]}, + ), + nap.TsdTensor( + t=np.arange(10), d=np.zeros((10, 2, 3)), time_support=nap.IntervalSet(0, 15) + ), + ], +) +def test_define_instance(tsd, kwargs): + t = tsd.t + d = getattr(tsd, "d", None) + iset = tsd.time_support + cols = kwargs.get("columns", None) + + # metadata index must be cols if provided. + # clear metadata if cols are provided to avoid errors + if (cols is not None) and ("metadata" not in kwargs): + kwargs["metadata"] = {} + + out = tsd._define_instance(t, iset, data=d, **kwargs) + + # check data + np.testing.assert_array_equal(out.t, t) + np.testing.assert_array_equal(out.time_support, iset) + if hasattr(tsd, "d"): + np.testing.assert_array_equal(out.d, d) + + # if TsdFrame check kwargs + if isinstance(tsd, nap.TsdFrame): + for key in ["columns", "load_array"]: + val = kwargs.get(key, getattr(tsd, key)) + assert np.all(val == getattr(out, key)) + # get expected metadata + meta = kwargs.get("metadata", getattr(tsd, "metadata")) + for ( + key, + val, + ) in meta.items(): + assert np.all(out.metadata[key] == val) From aa4254b8a457296be575954ed421cd61535b3ac3 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 11 Dec 2024 09:53:44 -0500 Subject: [PATCH 14/17] fix merge --- pynapple/core/time_series.py | 2 +- tests/test_time_series.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index d5a0ab8b..4d1a71c7 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -71,7 +71,7 @@ def _initialize_tsd_output(inp, out): # if out.ndim > self.ndim: # return out if out.shape[0] == inp.index.shape[0]: - kwargs = {"load_array": inp._load_array} + kwargs = {} if (inp.ndim == 2) and (out.ndim == 2) and (out.shape[1] == inp.shape[1]): # only pass columns and metadata if number of columns is preserved if hasattr(inp, "columns"): diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 3bc4d38b..dbf18ab6 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -2215,8 +2215,7 @@ def test_get_slice_public(start, end, expected_slice, expected_array, ts): {}, {"columns": [1, 2]}, {"metadata": {"banana": [3, 4]}}, - {"load_array": False}, - {"columns": ["a", "b"], "metadata": {"banana": [3, 4]}, "load_array": False}, + {"columns": ["a", "b"], "metadata": {"banana": [3, 4]}}, ], ) @pytest.mark.parametrize( @@ -2278,9 +2277,8 @@ def test_define_instance(tsd, kwargs): # if TsdFrame check kwargs if isinstance(tsd, nap.TsdFrame): - for key in ["columns", "load_array"]: - val = kwargs.get(key, getattr(tsd, key)) - assert np.all(val == getattr(out, key)) + val = kwargs.get("columns", getattr(tsd, "columns")) + assert np.all(val == getattr(out, "columns")) # get expected metadata meta = kwargs.get("metadata", getattr(tsd, "metadata")) for ( From 4b8723f646f09acff973d956e8ef455e94de0670 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 11 Dec 2024 09:54:16 -0500 Subject: [PATCH 15/17] linted --- pynapple/core/time_series.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 4d1a71c7..7d7fafbd 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -19,7 +19,6 @@ import importlib import warnings from numbers import Number -from typing import Callable import numpy as np import pandas as pd From 500ab98e4fc42c50efd72103c993b99d2c3fec48 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 11 Dec 2024 11:15:52 -0500 Subject: [PATCH 16/17] added test for __class__ direct invocations --- tests/test_call_invocation.py | 221 ++++++++++++++++++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 tests/test_call_invocation.py diff --git a/tests/test_call_invocation.py b/tests/test_call_invocation.py new file mode 100644 index 00000000..b1d26bac --- /dev/null +++ b/tests/test_call_invocation.py @@ -0,0 +1,221 @@ +import ast +import importlib +import inspect +import pkgutil +import sys +import textwrap + +import pytest +from numba import jit + +import pynapple as nap + + +def valid_func(x): + # do not call + x.__class__ + + +@jit +def valid_func_decorated(x): + x.__class__ + + +def invalid_func(x): + x.__class__() + + +@jit +def invalid_func_decorated(x): + x.__class__() + + +class BaseClass: + def method(self): + pass + + +class ValidClass(BaseClass): + def __init__(self): + pass + + def method(self): + self.__class__ + + +class InvalidClass(BaseClass): + def __init__(self): + pass + + def method(self): + self.__class__() + + +class ValidClassNoInheritance: + def __init__(self): + pass + + def method(self): + self.__class__() + + +def is_function_or_wrapped_function(obj): + """ + Custom predicate to identify functions, including those wrapped by decorators. + + Args: + obj: The object to check. + + Returns: + bool: True if the object is a function or a wrapped function. + """ + # Unwrap the object if it’s wrapped by decorators + unwrapped = inspect.unwrap( + obj, stop=(lambda f: inspect.isfunction(f) or inspect.isbuiltin(f)) + ) + return inspect.isfunction(unwrapped) + + +def class_class_invocations(cls): + class_results = [] + for name, method in inspect.getmembers(cls, predicate=inspect.isfunction): + try: + # Get the source code of the method + source = textwrap.dedent(inspect.getsource(method)) + # Parse the source into an abstract syntax tree + tree = ast.parse(source) + # Walk the AST to check for `__call__` invocations + for node in ast.walk(tree): + if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute): + if node.func.attr == "__class__": + class_results.append(name) + break + except Exception as e: + # cannot grab source code of inherited methods. + print(cls, name, method, repr(e)) + pass + try: + class_results.remove("_define_instance") + except ValueError: + pass + return class_results + + +def subclass_class_invocations(base_class): + """ + Finds methods in subclasses of a base class where the `__call__` method is invoked. + + Args: + base_class (type): The base class to inspect. + + Returns: + dict: A dictionary with subclass names as keys and a list of method names invoking `__call__`. + """ + results = {} + + cls_results = class_class_invocations(base_class) + + if cls_results: + results[base_class.__name__] = cls_results + + for subclass in base_class.__subclasses__(): + + subclass_results = class_class_invocations(subclass) + if subclass_results: + results[subclass.__name__] = subclass_results + + return results + + +def find_class_invocations_in_function(func): + """ + Checks if a function contains a call to `__call__`. + + Args: + func (callable): The function to analyze. + + Returns: + bool: True if `__call__` is invoked in the function, False otherwise. + """ + try: + # Get the source code of the function + source = textwrap.dedent(inspect.getsource(func)) + # Parse the source into an abstract syntax tree + tree = ast.parse(source) + # Walk the AST to check for `__call__` invocations + for node in ast.walk(tree): + if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute): + if node.func.attr == "__class__": + return True + except Exception as e: + # Log the function that couldn't be analyzed + print(f"Could not analyze function {func}: {e}") + return False + + +def find_class_invocations_in_module_functions(module): + """ + Recursively find all functions in a module that invoke `__class__`. + + Args: + module (module): The module to inspect. + + Returns: + dict: A dictionary with module, class, or function names as keys and + a list of function/method names invoking `__class__`. + """ + results_func = {} + + # Inspect functions directly defined in the module + for name, func in inspect.getmembers( + module, predicate=is_function_or_wrapped_function + ): + if find_class_invocations_in_function(func): + results_func[module.__name__ + f".{name}"] = name + + # Recursively inspect submodules + if hasattr(module, "__path__"): # Only packages have a __path__ + for submodule_info in pkgutil.iter_modules(module.__path__): + submodule_name = f"{module.__name__}.{submodule_info.name}" + submodule = importlib.import_module(submodule_name) + submodule_results = find_class_invocations_in_module_functions(submodule) + if submodule_results: + results_func.update(submodule_results) + + return results_func + + +def test_find_func(): + # Get the current module + current_module = sys.modules[__name__] + + # Run the detection function + results = find_class_invocations_in_module_functions(current_module) + expected_results = { + "tests.test_call_invocation.invalid_func": "invalid_func", + "tests.test_call_invocation.invalid_func_decorated": "invalid_func_decorated", + } + assert results == expected_results + + +def test_find_class(): + # Run the detection function + results = subclass_class_invocations(BaseClass) + expected_results = {"InvalidClass": ["method"]} + assert results == expected_results + + +def test_no_direct__class__invocation_in_base_subclasses(): + results_func = find_class_invocations_in_module_functions(nap) + results_cls = subclass_class_invocations(nap.core.base_class._Base) + if results_cls != {}: + raise ValueError( + f"Direct use of __class__ found in the following _Base objects and methods: {results_cls}. \n" + "Please, replace them with `_define_instance`." + ) + + if results_cls != {}: + raise ValueError( + f"Direct use of __class__ found in the following modules and functions: {results_func}. \n" + "Please, replace them with `_define_instance`." + ) From f259096cb8327fd97d18b28d0367d054d22bea3f Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 11 Dec 2024 11:31:29 -0500 Subject: [PATCH 17/17] fixed docstrings --- tests/test_call_invocation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_call_invocation.py b/tests/test_call_invocation.py index b1d26bac..bd7aa27f 100644 --- a/tests/test_call_invocation.py +++ b/tests/test_call_invocation.py @@ -103,13 +103,13 @@ def class_class_invocations(cls): def subclass_class_invocations(base_class): """ - Finds methods in subclasses of a base class where the `__call__` method is invoked. + Finds methods in subclasses of a base class where the `__class__` method is invoked. Args: base_class (type): The base class to inspect. Returns: - dict: A dictionary with subclass names as keys and a list of method names invoking `__call__`. + dict: A dictionary with subclass names as keys and a list of method names invoking `__class__`. """ results = {} @@ -129,13 +129,13 @@ def subclass_class_invocations(base_class): def find_class_invocations_in_function(func): """ - Checks if a function contains a call to `__call__`. + Checks if a function contains a call to `__class__`. Args: func (callable): The function to analyze. Returns: - bool: True if `__call__` is invoked in the function, False otherwise. + bool: True if `__class__` is invoked in the function, False otherwise. """ try: # Get the source code of the function