Commit 8516aec

More fixes.

simonpf committed Jun 24, 2024
1 parent 8c1505d commit 8516aec

Showing 4 changed files with 163 additions and 90 deletions.
132 changes: 84 additions & 48 deletions speed/cli.py
@@ -121,14 +121,20 @@ def extract_data(
 )
 @click.option("--include_geo", help="Include GEO observations.", is_flag=True, default=False)
 @click.option("--include_geo_ir", help="Include geo IR observations.", is_flag=True, default=False)
+@click.option(
+    "--n_processes",
+    help="The number of processes to use for the data extraction.",
+    default=1
+)
 def extract_training_data_spatial(
     collocation_path: str,
     output_folder: str,
     overlap: float = 0.0,
     size: int = 256,
     min_input_frac: float = None,
     include_geo: bool = False,
-    include_geo_ir: bool = False
+    include_geo_ir: bool = False,
+    n_processes: int = 1
 ) -> int:
     """
     Extract spatial training scenes from collocations in COLLOCATION_PATH and write scenes
@@ -148,24 +154,47 @@ def extract_training_data_spatial(


     collocation_files = sorted(list(collocation_path.glob("*.nc")))
-    for collocation_file in track(collocation_files, "Extracting spatial training data:"):
-
-        sensor_name = collocation_file.name.split("_")[1]
-
-        try:
-            extract_scenes(
-                collocation_file,
-                output_folder,
-                overlap=overlap,
-                size=size,
-                include_geo=include_geo,
-                include_geo_ir=include_geo_ir,
-            )
-        except Exception:
-            LOGGER.exception(
-                "Encountered an error when processing file %s.",
-                collocation_file
-            )
+    if n_processes < 2:
+        for collocation_file in track(collocation_files, "Extracting spatial training data:"):
+            try:
+                extract_scenes(
+                    collocation_file,
+                    output_folder,
+                    overlap=overlap,
+                    size=size,
+                    include_geo=include_geo,
+                    include_geo_ir=include_geo_ir,
+                )
+            except Exception:
+                LOGGER.exception(
+                    "Encountered an error when processing file %s.",
+                    collocation_file
+                )
+    else:
+        pool = ProcessPoolExecutor(max_workers=n_processes)
+        manager = multiprocessing.Manager()
+        tasks = {}
+        for collocation_file in collocation_files:
+            task = pool.submit(
+                extract_scenes,
+                collocation_file,
+                output_folder,
+                overlap=overlap,
+                size=size,
+                include_geo=include_geo,
+                include_geo_ir=include_geo_ir
+            )
+            tasks[task] = collocation_file
+
+        for task in track(tasks, "Extracting spatial training data:"):
+            try:
+                task.result()
+            except Exception:
+                LOGGER.exception(
+                    "Encountered an error when processing file %s.",
+                    tasks[task]
+                )

     return 1
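
Note on the parallel branch above: it follows the standard submit-then-collect pattern from concurrent.futures, and looking the input up in the task-to-file mapping is what lets a failure be attributed to the right file. A minimal standalone sketch of the same pattern; the work function and file names here are illustrative placeholders, not names from this repository:

    from concurrent.futures import ProcessPoolExecutor, as_completed

    def process(path):
        # Stand-in for the real per-file work (extract_scenes above).
        return path

    if __name__ == "__main__":
        files = ["collocations_gmi_a.nc", "collocations_gmi_b.nc"]
        with ProcessPoolExecutor(max_workers=4) as pool:
            # Map each future back to its input so failures can be attributed.
            tasks = {pool.submit(process, path): path for path in files}
            for task in as_completed(tasks):
                try:
                    task.result()  # Re-raises any exception from the worker process.
                except Exception:
                    print(f"Failed to process {tasks[task]}.")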

@@ -246,7 +275,7 @@ def extract_training_data_tabular(

     encoding = {
         "observations": {
-            "zlib": True,
+            "compression": "zstd",
             "scale_factor": 0.01,
             "dtype": "uint16",
             "_FillValue": uint16_max
@@ -257,8 +286,8 @@ def extract_training_data_tabular(

     # PMW data
     encoding_pmw = {
-        "observations": {"dtype": "uint16", "_FillValue": uint16_max, "scale_factor": 0.01, "zlib": True},
-        "earth_incidence_angle": {"dtype": "int16", "_FillValue": -(2e-15), "scale_factor": 0.01, "zlib": True},
+        "observations": {"dtype": "uint16", "_FillValue": uint16_max, "scale_factor": 0.01, "compression": "zstd"},
+        "earth_incidence_angle": {"dtype": "int16", "_FillValue": -(2**15), "scale_factor": 0.01, "compression": "zstd"},
     }
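
The corrected fill value for earth_incidence_angle is the minimum of the int16 range; the previous -(2e-15) was a floating-point number vanishingly close to zero rather than an integer sentinel, so valid near-zero incidence angles would have collided with it. A quick check:

    import numpy as np

    print(np.iinfo(np.int16).min)  # -32768, i.e. -(2**15): a usable int16 sentinel
    print(-(2e-15))                # a float near zero, not a valid int16 fill value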
     (output_folder / sensor).mkdir(exist_ok=True)
     pmw_data.to_netcdf(
@@ -269,23 +298,23 @@ def extract_training_data_tabular(

     # Ancillary data
     encoding_anc = {
-        "wet_bulb_temperature": {"dtype": "uint16", "_FillValue": uint16_max, "scale_factor": 0.01, "zlib": True},
-        "two_meter_temperature": {"dtype": "uint16", "_FillValue": uint16_max, "scale_factor": 0.01, "zlib": True},
-        "lapse_rate": {"dtype": "int16", "_FillValue": -1e15, "scale_factor": 0.01, "zlib": True},
-        "total_column_water_vapor": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 0.5, "zlib": True},
-        "surface_temperature": {"dtype": "uint16", "_FillValue": uint16_max, "scale_factor": 0.01, "zlib": True},
-        "moisture_convergence": {"dtype": "float32", "zlib": True},
-        "leaf_area_index": {"dtype": "float32", "zlib": True},
-        "snow_depth": {"dtype": "float32", "zlib": True},
-        "orographic_wind": {"dtype": "float32", "zlib": True},
-        "10m_wind": {"dtype": "float32", "zlib": True},
-        "surface_type": {"dtype": "uint8", "zlib": True},
-        "mountain_type": {"dtype": "uint8", "zlib": True},
-        "quality_flag": {"dtype": "uint8", "zlib": True},
-        "land_fraction": {"dtype": "uint8", "zlib": True},
-        "ice_fraction": {"dtype": "uint8", "zlib": True},
-        "sunglint_angle": {"dtype": "int8", "_FillValue": 127, "zlib": True},
-        "airlifting_index": {"dtype": "uint8", "zlib": True},
+        "wet_bulb_temperature": {"dtype": "uint16", "_FillValue": uint16_max, "scale_factor": 0.01, "compression": "zstd"},
+        "two_meter_temperature": {"dtype": "uint16", "_FillValue": uint16_max, "scale_factor": 0.01, "compression": "zstd"},
+        "lapse_rate": {"dtype": "int16", "_FillValue": -(2**15), "scale_factor": 0.01, "compression": "zstd"},
+        "total_column_water_vapor": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 0.5, "compression": "zstd"},
+        "surface_temperature": {"dtype": "uint16", "_FillValue": uint16_max, "scale_factor": 0.01, "compression": "zstd"},
+        "moisture_convergence": {"dtype": "float32", "compression": "zstd"},
+        "leaf_area_index": {"dtype": "float32", "compression": "zstd"},
+        "snow_depth": {"dtype": "float32", "compression": "zstd"},
+        "orographic_wind": {"dtype": "float32", "compression": "zstd"},
+        "10m_wind": {"dtype": "float32", "compression": "zstd"},
+        "surface_type": {"dtype": "uint8", "compression": "zstd"},
+        "mountain_type": {"dtype": "uint8", "compression": "zstd"},
+        "quality_flag": {"dtype": "uint8", "compression": "zstd"},
+        "land_fraction": {"dtype": "uint8", "compression": "zstd"},
+        "ice_fraction": {"dtype": "uint8", "compression": "zstd"},
+        "sunglint_angle": {"dtype": "int8", "_FillValue": 127, "compression": "zstd"},
+        "airlifting_index": {"dtype": "uint8", "compression": "zstd"},
     }
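
For reference on how these encodings pack values: netCDF decodes as stored * scale_factor, so a uint8 variable with scale_factor 1/254 represents fractions from 0 to 1 in 254 steps, with 255 reserved as the fill value, while a uint16 variable with scale_factor 0.01 spans 0 to about 655 in steps of 0.01. A small round-trip illustration:

    import numpy as np

    scale = 1.0 / 254.0
    value = 0.5                                 # a fraction in [0, 1]
    stored = np.uint8(np.round(value / scale))  # -> 127
    decoded = stored * scale                    # -> 0.5, to within the 1/254 step
    print(stored, decoded)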

(output_folder / "ancillary").mkdir(exist_ok=True)
@@ -296,14 +325,14 @@ def extract_training_data_tabular(

     # Target data
     encoding_target = {
-        "surface_precip": {"dtype": "uint16", "_FillValue": uint16_max, "scale_factor": 0.01, "zlib": True},
-        "radar_quality_index": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 1.0/254.0, "zlib": True},
-        "valid_fraction": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 1.0/254.0, "zlib": True},
-        "precip_fraction": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 1.0/254.0, "zlib": True},
-        "snow_fraction": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 1.0/254.0, "zlib": True},
-        "hail_fraction": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 1.0/254.0, "zlib": True},
-        "convective_fraction": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 1.0/254.0, "zlib": True},
-        "stratiform_fraction": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 1.0/254.0, "zlib": True},
+        "surface_precip": {"dtype": "uint16", "_FillValue": uint16_max, "scale_factor": 0.01, "compression": "zstd"},
+        "radar_quality_index": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 1.0/254.0, "compression": "zstd"},
+        "valid_fraction": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 1.0/254.0, "compression": "zstd"},
+        "precip_fraction": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 1.0/254.0, "compression": "zstd"},
+        "snow_fraction": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 1.0/254.0, "compression": "zstd"},
+        "hail_fraction": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 1.0/254.0, "compression": "zstd"},
+        "convective_fraction": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 1.0/254.0, "compression": "zstd"},
+        "stratiform_fraction": {"dtype": "uint8", "_FillValue": 255, "scale_factor": 1.0/254.0, "compression": "zstd"},
     }
(output_folder / "target").mkdir(exist_ok=True)
target_data.to_netcdf(
@@ -312,7 +341,7 @@ def extract_training_data_tabular(
     )

     encoding = {
-        "observations": {"dtype": "uint16", "_FillValue": uint16_max, "scale_factor": 0.01, "zlib": True},
+        "observations": {"dtype": "uint16", "_FillValue": uint16_max, "scale_factor": 0.01, "compression": "zstd"},
     }
     if len(geo_data) > 0:
         geo_data = xr.concat(geo_data, dim="samples")
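
A note on the encoding change throughout this file: "zlib": True is the older netCDF4 flag that enables DEFLATE compression only, whereas the "compression" key also accepts newer codecs such as zstd. Writing and reading zstd-compressed files requires a sufficiently recent stack, roughly netCDF4-python 1.6+ on top of netCDF-C 4.9+ built with filter-plugin support, so consumers without the zstd plugin will not be able to open the output. A minimal sketch of writing with this encoding, using an illustrative dataset rather than one from this repository:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"observations": (("samples",), np.linspace(150.0, 300.0, 100))})
    encoding = {
        "observations": {
            "dtype": "uint16",
            "_FillValue": 2**16 - 1,
            "scale_factor": 0.01,
            "compression": "zstd",  # needs netCDF-C >= 4.9 with the zstd filter plugin
        }
    }
    ds.to_netcdf("observations_example.nc", encoding=encoding)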
@@ -380,13 +409,20 @@ def extract_evaluation_data(
LOGGER.info(f"Found {len(combined)} collocations in {collocation_path}.")

for median_time in track(combined, description="Extracting evaluation data:"):
extract_evaluation_data(
times_gridded[median_time],
times_on_swath[median_time],
output_folder,
include_geo=include_geo,
include_geo_ir=include_geo_ir
)
try:
extract_evaluation_data(
times_gridded[median_time],
times_on_swath[median_time],
output_folder,
include_geo=include_geo,
include_geo_ir=include_geo_ir
)
except Exception:
LOGGER.exception(
"Encountered an error when processing validation scene %s.",
median_time
)

4 changes: 3 additions & 1 deletion speed/data/cpcir.py
@@ -42,7 +42,7 @@ def add_cpcir_obs(
     median_time = to_datetime64(datetime.strptime(time_str, "%Y%m%d%H%M%S"))
     rounded = round_time(median_time, np.timedelta64(30, "m"))
     offsets = (np.arange(-n_steps // 2, n_steps // 2) + 1) * np.timedelta64("30", "m")
-    time_steps = rounded - offsets
+    time_steps = rounded + offsets
     cpcir_recs = merged_ir.get(TimeRange(time_steps.min(), time_steps.max()))

     with xr.open_dataset(path_gridded, group="input_data") as data_gridded:
@@ -86,10 +86,12 @@ def add_cpcir_obs(

     # Save data in gridded format.
     cpcir_data_g = xr.concat(cpcir_data_g, "time").sortby("time").rename({"Tb": "tbs_ir"})
+    cpcir_data_g = cpcir_data_g.interp(time=time_steps, method="nearest")
     cpcir_data_g.to_netcdf(path_gridded, group="geo_ir", mode="a")

     # Save data in on-swath format.
     cpcir_data_n = xr.concat(cpcir_data_n, "time").sortby("time").rename({"Tb": "tbs_ir"})
+    cpcir_data_n = cpcir_data_n.interp(time=time_steps, method="nearest")
     cpcir_data_n.to_netcdf(path_on_swath, group="geo_ir", mode="a")
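
The sign fix here (and the matching one in goes.py below) matters because the offsets are constructed to straddle the rounded median time: for n_steps = 4, np.arange(-2, 2) + 1 gives [-1, 0, 1, 2], so adding the offsets yields an ascending window around the rounded time, while subtracting them produced a descending window biased toward the past. A quick illustration:

    import numpy as np

    rounded = np.datetime64("2024-06-24T12:00")
    n_steps = 4
    offsets = (np.arange(-n_steps // 2, n_steps // 2) + 1) * np.timedelta64(30, "m")
    print(rounded + offsets)  # 11:30, 12:00, 12:30, 13:00: centered, ascending
    print(rounded - offsets)  # 12:30, 12:00, 11:30, 11:00: shifted to the past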


2 changes: 1 addition & 1 deletion speed/data/goes.py
@@ -107,7 +107,7 @@ def add_goes_obs(
     median_time = to_datetime64(datetime.strptime(time_str, "%Y%m%d%H%M%S"))
     rounded = round_time(median_time, time_step)
     offsets = (np.arange(-n_steps // 2, n_steps // 2) + 1) * time_step
-    time_steps = rounded - offsets
+    time_steps = rounded + offsets

     if sector.lower() == "conus":
         products = [GOES16L1BRadiances("C", channel) for channel in range(1, 17)]