Skip to content

Commit

Permalink
change 'num_timestamp' to 'num_timestamps'
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiminzhang0830 committed May 10, 2023
1 parent 6493522 commit 1c0221e
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 71 deletions.
8 changes: 4 additions & 4 deletions ppsci/arch/afno.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ class AFNONet(base.Arch):
num_blocks (int, optional): Number of blocks. Defaults to 8.
sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01.
hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0.
num_timestamp (int, optional): Number of timestamp. Defaults to 1.
num_timestamps (int, optional): Number of timestamp. Defaults to 1.
Examples:
>>> import ppsci
Expand All @@ -445,7 +445,7 @@ def __init__(
num_blocks: int = 8,
sparsity_threshold: float = 0.01,
hard_thresholding_fraction: float = 1.0,
num_timestamp: int = 1,
num_timestamps: int = 1,
):
super().__init__()
self.input_keys = input_keys
Expand All @@ -457,7 +457,7 @@ def __init__(
self.out_channels = out_channels
self.embed_dim = embed_dim
self.num_blocks = num_blocks
self.num_timestamp = num_timestamp
self.num_timestamps = num_timestamps
norm_layer = partial(nn.LayerNorm, epsilon=1e-6)

self.patch_embed = PatchEmbed(
Expand Down Expand Up @@ -555,7 +555,7 @@ def forward(self, x):

y = []
input = x
for i in range(self.num_timestamp):
for i in range(self.num_timestamps):
out = self.forward_tensor(input)
y.append(out)
input = out
Expand Down
14 changes: 7 additions & 7 deletions ppsci/data/dataset/era5_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ERA5Dataset(io.Dataset):
precip_file_path (Optional[str]): Precipitation data set path. Defaults to None.
weight_dict (Optional[Dict[str, float]]): Weight dictionary. Defaults to None.
vars_channel (Optional[Tuple[int, ...]]): The variable channel index in ERA5 dataset. Defaults to None.
num_label_timestamp (int, optional): Number of timestamp of label. Defaults to 1.
num_label_timestamps (int, optional): Number of timestamp of label. Defaults to 1.
transforms (Optional[vision.Compose]): Compose object contains sample wise
transform(s). Defaults to None.
training (bool, optional): Whether in train mode. Defaults to True.
Expand All @@ -56,7 +56,7 @@ def __init__(
precip_file_path: Optional[str] = None,
weight_dict: Optional[Dict[str, float]] = None,
vars_channel: Optional[Tuple[int, ...]] = None,
num_label_timestamp: int = 1,
num_label_timestamps: int = 1,
transforms: Optional[vision.Compose] = None,
training: bool = True,
):
Expand All @@ -74,7 +74,7 @@ def __init__(
self.vars_channel = (
vars_channel if vars_channel is not None else [i for i in range(20)]
)
self.num_label_timestamp = num_label_timestamp
self.num_label_timestamps = num_label_timestamps
self.transforms = transforms
self.training = training

Expand Down Expand Up @@ -102,9 +102,9 @@ def __getitem__(self, global_idx):
local_idx = global_idx % self.n_samples_per_year
step = 0 if local_idx >= self.n_samples_per_year - 1 else 1

if self.num_label_timestamp > 1:
if local_idx >= self.n_samples_per_year - self.num_label_timestamp:
local_idx = self.n_samples_per_year - self.num_label_timestamp - 1
if self.num_label_timestamps > 1:
if local_idx >= self.n_samples_per_year - self.num_label_timestamps:
local_idx = self.n_samples_per_year - self.num_label_timestamps - 1

input_file = self.files[year_idx]
label_file = (
Expand All @@ -124,7 +124,7 @@ def __getitem__(self, global_idx):

input_item = {self.input_keys[0]: input_file[input_idx, self.vars_channel]}
label_item = {}
for i in range(self.num_label_timestamp):
for i in range(self.num_label_timestamps):
if self.precip_file_path is not None:
label_item[self.label_keys[i]] = np.expand_dims(
label_file[label_idx + i], 0
Expand Down
10 changes: 5 additions & 5 deletions ppsci/geometry/timedomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def __init__(
if time_step is not None:
if time_step <= 0:
raise ValueError(f"time_step({time_step}) must be larger than 0.")
self.num_timestamp = int(np.ceil((t1 - t0) / time_step)) + 1
self.num_timestamps = int(np.ceil((t1 - t0) / time_step)) + 1
elif timestamps is not None:
self.num_timestamp = len(timestamps)
self.num_timestamps = len(timestamps)

def on_initial(self, t):
return np.isclose(t, self.t0).flatten()
Expand Down Expand Up @@ -117,7 +117,7 @@ def uniform_points(self, n, boundary=True):
nx = int(np.ceil(n / nt))
elif self.timedomain.timestamps is not None:
# exclude start time t0
nt = self.timedomain.num_timestamp - 1
nt = self.timedomain.num_timestamps - 1
nx = int(np.ceil(n / nt))
else:
nx = int(
Expand Down Expand Up @@ -205,7 +205,7 @@ def random_points(self, n, random="pseudo", criteria=None):
tx = tx[:n]
return tx
elif self.timedomain.timestamps is not None:
nt = self.timedomain.num_timestamp - 1
nt = self.timedomain.num_timestamps - 1
t = self.timedomain.timestamps[1:]
nx = int(np.ceil(n / nt))

Expand Down Expand Up @@ -402,7 +402,7 @@ def random_boundary_points(self, n, random="pseudo", criteria=None):
return t_x
elif self.timedomain.timestamps is not None:
# exclude start time t0
nt = self.timedomain.num_timestamp - 1
nt = self.timedomain.num_timestamps - 1
t = self.timedomain.timestamps[1:]
nx = int(np.ceil(n / nt))

Expand Down
24 changes: 12 additions & 12 deletions ppsci/validate/geo_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,19 @@ def __init__(
self.output_keys = list(label_dict.keys())

nx = dataloader_cfg["total_size"]
self.num_timestamp = 1
self.num_timestamps = 1
# TODO(sensen): simplify code below
if isinstance(geom, geometry.TimeXGeometry):
if geom.timedomain.num_timestamp is not None:
if geom.timedomain.num_timestamps is not None:
if with_initial:
# include t0
self.num_timestamp = geom.timedomain.num_timestamp
self.num_timestamps = geom.timedomain.num_timestamps
assert (
nx % self.num_timestamp == 0
), f"{nx} % {self.num_timestamp} != 0"
nx //= self.num_timestamp
nx % self.num_timestamps == 0
), f"{nx} % {self.num_timestamps} != 0"
nx //= self.num_timestamps
input = geom.sample_interior(
nx * (geom.timedomain.num_timestamp - 1),
nx * (geom.timedomain.num_timestamps - 1),
random,
criteria,
evenly,
Expand All @@ -114,13 +114,13 @@ def __init__(
}
else:
# exclude t0
self.num_timestamp = geom.timedomain.num_timestamp - 1
self.num_timestamps = geom.timedomain.num_timestamps - 1
assert (
nx % self.num_timestamp == 0
), f"{nx} % {self.num_timestamp} != 0"
nx //= self.num_timestamp
nx % self.num_timestamps == 0
), f"{nx} % {self.num_timestamps} != 0"
nx //= self.num_timestamps
input = geom.sample_interior(
nx * (geom.timedomain.num_timestamp - 1),
nx * (geom.timedomain.num_timestamps - 1),
random,
criteria,
evenly,
Expand Down
62 changes: 31 additions & 31 deletions ppsci/visualize/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,21 @@
]


def _save_plot_from_1d_array(filename, coord, value, value_keys, num_timestamp=1):
def _save_plot_from_1d_array(filename, coord, value, value_keys, num_timestamps=1):
"""Save plot from given 1D data.
Args:
filename (str): Filename.
coord (np.ndarray): Coordinate array.
value (Dict[str, np.ndarray]): Dict of value array.
value_keys (Tuple[str, ...]): Value keys.
num_timestamp (int, optional): Number of timestamps coord/value contains. Defaults to 1.
num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1.
"""
fig, a = plt.subplots(len(value_keys), num_timestamp, squeeze=False)
fig, a = plt.subplots(len(value_keys), num_timestamps, squeeze=False)
fig.subplots_adjust(hspace=0.8)

len_ts = len(coord) // num_timestamp
for t in range(num_timestamp):
len_ts = len(coord) // num_timestamps
for t in range(num_timestamps):
st = t * len_ts
ed = (t + 1) * len_ts
coord_t = coord[st:ed]
Expand All @@ -96,29 +96,29 @@ def _save_plot_from_1d_array(filename, coord, value, value_keys, num_timestamp=1
color=cnames[i],
label=key,
)
if num_timestamp > 1:
if num_timestamps > 1:
a[i][t].set_title(f"{key}(t={t})")
else:
a[i][t].set_title(f"{key}")
a[i][t].grid()
a[i][t].legend()

if num_timestamp == 1:
if num_timestamps == 1:
fig.savefig(filename, dpi=300)
else:
fig.savefig(f"{filename}_{t}", dpi=300)

if num_timestamp == 1:
if num_timestamps == 1:
logger.info(f"1D result is saved to {filename}.png")
else:
logger.info(
f"1D result is saved to {filename}_0.png"
f" ~ {filename}_{num_timestamp - 1}.png"
f" ~ {filename}_{num_timestamps - 1}.png"
)


def save_plot_from_1d_dict(
filename, data_dict, coord_keys, value_keys, num_timestamp=1
filename, data_dict, coord_keys, value_keys, num_timestamps=1
):
"""Plot dict data as file.
Expand All @@ -127,7 +127,7 @@ def save_plot_from_1d_dict(
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
coord_keys (Tuple[str, ...]): Tuple of coord key. such as ("x", "y").
value_keys (Tuple[str, ...]): Tuple of value key. such as ("u", "v").
num_timestamp (int, optional): Number of timestamp in data_dict. Defaults to 1.
num_timestamps (int, optional): Number of timestamp in data_dict. Defaults to 1.
"""
space_ndim = len(coord_keys) - int("t" in coord_keys)
if space_ndim not in [1, 2, 3]:
Expand All @@ -149,14 +149,14 @@ def save_plot_from_1d_dict(
value = [x for x in value]
value = np.concatenate(value, axis=1)

_save_plot_from_1d_array(filename, coord, value, value_keys, num_timestamp)
_save_plot_from_1d_array(filename, coord, value, value_keys, num_timestamps)


def _save_plot_from_2d_array(
filename: str,
visu_data: Tuple[np.ndarray, ...],
visu_keys: Tuple[str, ...],
num_timestamp: int = 1,
num_timestamps: int = 1,
stride: int = 1,
xticks: Optional[Tuple[float, ...]] = None,
yticks: Optional[Tuple[float, ...]] = None,
Expand All @@ -167,7 +167,7 @@ def _save_plot_from_2d_array(
filename (str): Filename.
visu_data (Tuple[np.ndarray, ...]): Data that requires visualization.
visu_keys (Tuple[str, ...]]): Keys for visualizing data. such as ("u", "v").
num_timestamp (int, optional): Number of timestamps coord/value contains. Defaults to 1.
num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1.
stride (int, optional): The time stride of visualization. Defaults to 1.
xticks (Optional[Tuple[float, ...]]): Tuple of xtick locations. Defaults to None.
yticks (Optional[Tuple[float, ...]]): Tuple of ytick locations. Defaults to None.
Expand All @@ -179,10 +179,10 @@ def _save_plot_from_2d_array(

fig, ax = plt.subplots(
len(visu_keys),
num_timestamp,
num_timestamps,
squeeze=False,
sharey=True,
figsize=(num_timestamp, len(visu_keys)),
figsize=(num_timestamps, len(visu_keys)),
)
fig.subplots_adjust(hspace=0.3)
target_flag = any(["target" in key for key in visu_keys])
Expand All @@ -191,7 +191,7 @@ def _save_plot_from_2d_array(
c_max = np.amax(data)
c_min = np.amin(data)

for t_idx in range(num_timestamp):
for t_idx in range(num_timestamps):
t = t_idx * stride
ax[i, t_idx].imshow(
data[t, :, :],
Expand Down Expand Up @@ -226,7 +226,7 @@ def save_plot_from_2d_dict(
filename: str,
data_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
visu_keys: Tuple[str, ...],
num_timestamp: int = 1,
num_timestamps: int = 1,
stride: int = 1,
xticks: Optional[Tuple[float, ...]] = None,
yticks: Optional[Tuple[float, ...]] = None,
Expand All @@ -237,7 +237,7 @@ def save_plot_from_2d_dict(
filename (str): Output filename.
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
visu_keys (Tuple[str, ...]): Keys for visualizing data. such as ("u", "v").
num_timestamp (int, optional): Number of timestamp in data_dict. Defaults to 1.
num_timestamps (int, optional): Number of timestamp in data_dict. Defaults to 1.
stride (int, optional): The time stride of visualization. Defaults to 1.
xticks (Optional[Tuple[float,...]]): The list of xtick locations. Defaults to None.
yticks (Optional[Tuple[float,...]]): The list of ytick locations. Defaults to None.
Expand All @@ -246,7 +246,7 @@ def save_plot_from_2d_dict(
if isinstance(visu_data[0], paddle.Tensor):
visu_data = [x.numpy() for x in visu_data]
_save_plot_from_2d_array(
filename, visu_data, visu_keys, num_timestamp, stride, xticks, yticks
filename, visu_data, visu_keys, num_timestamps, stride, xticks, yticks
)


Expand Down Expand Up @@ -308,21 +308,21 @@ def _save_plot_from_3d_array(
filename: str,
visu_data: Tuple[np.ndarray, ...],
visu_keys: Tuple[str, ...],
num_timestamp: int = 1,
num_timestamps: int = 1,
):
"""Save plot from given 3D data.
Args:
filename (str): Filename.
visu_data (Tuple[np.ndarray, ...]): Data that requires visualization.
visu_keys (Tuple[str, ...]]): Keys for visualizing data. such as ("u", "v").
num_timestamp (int, optional): Number of timestamps coord/value contains. Defaults to 1.
num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1.
"""

fig = plt.figure(figsize=(10, 10))
len_ts = len(visu_data[0]) // num_timestamp
for t in range(num_timestamp):
ax = fig.add_subplot(1, num_timestamp, t + 1, projection="3d")
len_ts = len(visu_data[0]) // num_timestamps
for t in range(num_timestamps):
ax = fig.add_subplot(1, num_timestamps, t + 1, projection="3d")
st = t * len_ts
ed = (t + 1) * len_ts
visu_data_t = [data[st:ed] for data in visu_data]
Expand All @@ -343,40 +343,40 @@ def _save_plot_from_3d_array(
loc="upper right",
framealpha=0.95,
)
if num_timestamp == 1:
if num_timestamps == 1:
fig.savefig(filename, dpi=300)
else:
fig.savefig(f"{filename}_{t}", dpi=300)

if num_timestamp == 1:
if num_timestamps == 1:
logger.info(f"3D result is saved to {filename}.png")
else:
logger.info(
f"3D result is saved to {filename}_0.png"
f" ~ {filename}_{num_timestamp - 1}.png"
f" ~ {filename}_{num_timestamps - 1}.png"
)


def save_plot_from_3d_dict(
filename: str,
data_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
visu_keys: Tuple[str, ...],
num_timestamp: int = 1,
num_timestamps: int = 1,
):
"""Plot dict data as file.
Args:
filename (str): Output filename.
data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
visu_keys (Tuple[str, ...]): Keys for visualizing data. such as ("u", "v").
num_timestamp (int, optional): Number of timestamp in data_dict. Defaults to 1.
num_timestamps (int, optional): Number of timestamp in data_dict. Defaults to 1.
"""

visu_data = [data_dict[k] for k in visu_keys]
if isinstance(visu_data[0], paddle.Tensor):
visu_data = [x.numpy() for x in visu_data]

_save_plot_from_3d_array(filename, visu_data, visu_keys, num_timestamp)
_save_plot_from_3d_array(filename, visu_data, visu_keys, num_timestamps)


def _save_plot_weather_from_array(
Expand Down
Loading

0 comments on commit 1c0221e

Please sign in to comment.