Commands to extract spatial and tabular training data.
simonpf committed May 22, 2024
1 parent 02b0eff commit 2d59d2e
Showing 4 changed files with 497 additions and 85 deletions.
107 changes: 90 additions & 17 deletions speed/cli.py
@@ -16,6 +16,7 @@
import click

import speed.logging
from speed.data import cpcir, goes

LOGGER = logging.getLogger(__file__)

@@ -108,21 +109,16 @@ def extract_data(
type=int,
default=256,
)
@click.option(
"--filename_pattern",
type=str,
default="collocation_{time}",
)
def extract_training_data_2d(
def extract_training_data_spatial(
collocation_path: str,
output_folder: str,
overlap: float = 0.0,
size: int = 256,
filename_pattern: str = "collocation_{time}",
min_input_frac: float = None
) -> int:
"""
Extract collocations for a given date.
Extract spatial training scenes from collocations in COLLOCATION_PATH and write scenes
to OUTPUT_FOLDER.
"""
from speed.data.utils import extract_scenes

@@ -139,17 +135,94 @@ def extract_training_data_2d(

collocation_files = sorted(list(collocation_path.glob("*.nc")))
for collocation_file in tqdm(collocation_files):

sensor_name = collocation_file.name.split("_")[1]
input_data = xr.load_dataset(collocation_file, group="input_data")
reference_data = xr.load_dataset(collocation_file, group="reference_data")
extract_scenes(
input_data,
reference_data,
output_folder,
overlap=overlap,
size=size,
filename_pattern=filename_pattern
)

try:
extract_scenes(
sensor_name,
input_data,
reference_data,
output_folder,
overlap=overlap,
size=size,
)
except Exception:
LOGGER.exception(
"Encountered an error when processing file %s.",
collocation_file
)

return 1

cli.add_command(extract_training_data_2d)
cli.add_command(extract_training_data_spatial, name="extract_training_data_spatial")
cli.add_command(cpcir.cli, name="extract_cpcir_obs")
cli.add_command(goes.cli, name="extract_goes_obs")


@click.command()
@click.argument("collocation_path")
@click.argument("output_folder")
def extract_training_data_tabular(
collocation_path: str,
output_folder: str,
) -> int:
"""
Extract tabular training data from collocations in COLLOCATION_PATH and write resulting files
to OUTPUT_FOLDER.
"""
from speed.data.utils import extract_training_data

output_folder = Path(output_folder)
output_folder.mkdir(exist_ok=True, parents=True)

collocation_path = Path(collocation_path)
if not collocation_path.exists():
LOGGER.error(
"'collocation_path' must point to an existing directory."
)
return 1


collocation_files = sorted(list(collocation_path.glob("*.nc")))
for collocation_file in tqdm(collocation_files[:50]):

sensor_name = collocation_file.name.split("_")[1]
input_data = xr.load_dataset(collocation_file, group="input_data")
reference_data = xr.load_dataset(collocation_file, group="reference_data")

inpt_data = []
anc_data = []
trgt_data = []

try:
inpt, anc, trgt = extract_training_data(
input_data,
reference_data,
)
inpt_data.append(inpt)
anc_data.append(anc)
trgt_data.append(trgt)

except Exception:
LOGGER.exception(
"Encountered an error when processing file %s.",
collocation_file
)

input_data = xr.concat(inpt_data, dim="samples")
ancillary_data = xr.concat(anc_data, dim="samples")
target_data = xr.concat(trgt_data, dim="samples")

input_data.to_netcdf(output_folder / "pmw.nc")
ancillary_data.to_netcdf(output_folder / "ancillary.nc")
target_data.to_netcdf(output_folder / "target.nc")

return 0
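
# The three NetCDF files written above share a common "samples" dimension. A
# rough sketch of how the resulting tables might be read back downstream (the
# output location is a placeholder; the variable names inside the files are
# not visible in this diff):
from pathlib import Path
import xarray as xr

tabular_folder = Path("training_data/tabular")  # placeholder output location
pmw = xr.open_dataset(tabular_folder / "pmw.nc")
ancillary = xr.open_dataset(tabular_folder / "ancillary.nc")
target = xr.open_dataset(tabular_folder / "target.nc")

# The tables are aligned sample by sample.
assert pmw.sizes["samples"] == target.sizes["samples"]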

cli.add_command(extract_training_data_spatial, name="extract_training_data_spatial")
cli.add_command(extract_training_data_tabular, name="extract_training_data_tabular")
cli.add_command(cpcir.cli, name="extract_cpcir_obs")
cli.add_command(goes.cli, name="extract_goes_obs")
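
For orientation, a hypothetical invocation of the two new subcommands via click's test runner might look as follows. This assumes the click group defined in speed/cli.py is importable as speed.cli.cli (consistent with the cli.add_command calls above); the input and output paths are placeholders.

from click.testing import CliRunner
from speed.cli import cli

runner = CliRunner()

# Extract spatial training scenes from a folder of collocation files.
result = runner.invoke(cli, [
    "extract_training_data_spatial",
    "/data/collocations/gmi",      # COLLOCATION_PATH (placeholder)
    "/data/training/spatial",      # OUTPUT_FOLDER (placeholder)
])
print(result.exit_code, result.output)

# Extract tabular training data into pmw.nc, ancillary.nc and target.nc.
result = runner.invoke(cli, [
    "extract_training_data_tabular",
    "/data/collocations/gmi",
    "/data/training/tabular",
])
print(result.exit_code, result.output)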
77 changes: 61 additions & 16 deletions speed/data/gpm.py
@@ -2,7 +2,7 @@
speed.data.gpm
==============
This module contains the code to process GPM L1C data into SPREE collocations.
This module contains the code to process GPM L1C data into SPEED collocations.
"""
from datetime import datetime, timedelta
import logging
@@ -46,6 +46,7 @@
calculate_swath_resample_indices,
resample_data,
extract_rect,
get_useful_scan_range,
)
from speed.data.reference import ReferenceData
from speed.data.input import InputData
@@ -82,7 +83,8 @@ def run_preprocessor(gpm_granule: Granule) -> xr.Dataset:

preprocessor_data = preprocessor_data.rename({
"channels": "channels_gprof",
"brightness_temperatures": "tbs_mw_gprof"
"brightness_temperatures": "tbs_mw_gprof",
"earth_incidence_angle": "earth_incidence_angle_gprof"
})
invalid = preprocessor_data.tbs_mw_gprof.data < 0
preprocessor_data.tbs_mw_gprof.data[invalid] = np.nan
@@ -123,11 +125,22 @@ def load_l1c_brightness_temperatures(
source,
l1c_data.tbs.data,
target,
radius_of_influence=radius_of_influence
radius_of_influence=radius_of_influence,
fill_value=np.nan
)
return tbs
eia = kd_tree.resample_nearest(
source,
l1c_data.incidence_angle.data,
target,
radius_of_influence=radius_of_influence,
fill_value=np.nan
)
if eia.ndim < tbs.ndim:
eia = np.broadcast_to(eia[..., None], tbs.shape)
return tbs, eia

tbs = []
eia = []
swath = 1
while f"latitude_s{swath}" in l1c_data.variables:
lats = l1c_data[f"latitude_s{swath}"].data
@@ -140,15 +153,27 @@
source,
l1c_data[f"tbs_s{swath}"].data,
target,
radius_of_influence=radius_of_influence
radius_of_influence=radius_of_influence,
fill_value=np.nan
)
)
eia_s = kd_tree.resample_nearest(
source,
l1c_data[f"incidence_angle_s{swath}"].data,
target,
radius_of_influence=radius_of_influence,
fill_value=np.nan
)
if eia_s.ndim < tbs[-1].ndim:
eia_s = np.broadcast_to(eia_s[..., None], tbs[-1].shape)
eia.append(eia_s)
swath += 1

tbs = np.concatenate(tbs, axis=-1)
eia = np.concatenate(eia, axis=-1)
invalid = tbs < 0
tbs[invalid] = np.nan
return tbs
return tbs, eia
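
# A minimal, self-contained sketch of the nearest-neighbour resampling used
# above (pyresample's kd_tree.resample_nearest), with synthetic swath
# coordinates standing in for the L1C swaths and the preprocessor grid.
# Passing fill_value=np.nan leaves target pixels without a neighbour inside
# the radius of influence as NaN rather than pyresample's default fill of 0.
import numpy as np
from pyresample import kd_tree
from pyresample.geometry import SwathDefinition

lons, lats = np.meshgrid(np.linspace(0.0, 1.0, 10), np.linspace(50.0, 51.0, 10))
source = SwathDefinition(lons=lons, lats=lats)
target = SwathDefinition(lons=lons + 0.05, lats=lats + 0.05)
tbs_synthetic = np.random.uniform(150.0, 300.0, size=lons.shape)

tbs_resampled = kd_tree.resample_nearest(
    source, tbs_synthetic, target,
    radius_of_influence=10e3,
    fill_value=np.nan,
)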


class GPMInput(InputData):
@@ -228,33 +253,58 @@ def process_match(
return None

preprocessor_data = run_preprocessor(inpt_granule)
tbs_mw = load_l1c_brightness_temperatures(
preprocessor_data.attrs.pop("frequencies")

tbs_mw, eia_mw = load_l1c_brightness_temperatures(
inpt_granule,
preprocessor_data,
self.radius_of_influence
)
preprocessor_data["tbs_mw"] = (("scans", "pixels", "channels"), tbs_mw.data)
preprocessor_data["earth_incidence_angle"] = (("scans", "pixels", "channels"), eia_mw.data)
preprocessor_data["tbs_mw"].attrs.update(self.characteristics["channels"])
preprocessor_data["tbs_mw_gprof"].attrs.update(self.characteristics["channels_gprof"])

# Load and combine reference data for all matched granules
ref_data = reference_data.load_reference_data(inpt_granule, ref_granules)
ref_data, ref_data_fpavg = reference_data.load_reference_data(
inpt_granule,
ref_granules,
beam_width=0.98
)
if ref_data is None:
LOGGER.info(
"No reference data for %s.",
ref_granules
)
return None


reference_data_r = ref_data.interp(
latitude=preprocessor_data.latitude,
longitude=preprocessor_data.longitude,
method="nearest",
)
reference_data_fpavg = ref_data.load_reference_data_fpavg(inpt_granule, ref_granules)
for var in reference_data_fpavg:
reference_data_r["time"] = ref_data.time.astype(np.int64).interp(
latitude=preprocessor_data.latitude,
longitude=preprocessor_data.longitude,
method="nearest"
).astype("datetime64[ns]")

for var in ref_data_fpavg:
if var in reference_data_r:
reference_data_r[var + "_fpavg"] = reference_data_fpavg[var]
reference_data_r[var + "_fpavg"] = ref_data_fpavg[var]

# Limit the granule to the range of scans with useful data.
scan_start, scan_end = get_useful_scan_range(
reference_data_r,
"surface_precip",
min_scans=256,
margin=64
)
preprocessor_data = preprocessor_data[{"scans": slice(scan_start, scan_end)}]
reference_data_r = reference_data_r[{"scans": slice(scan_start, scan_end)}]
preprocessor_data.attrs["scan_start"] = inpt_granule.primary_index_range[0] + scan_start
preprocessor_data.attrs["scan_end"] = inpt_granule.primary_index_range[0] + scan_end

row_start = ref_data.attrs.get("lower_left_row", 0)
n_rows = ref_data.latitude.size
@@ -311,11 +361,6 @@ def process_match(
ref_data["scan_index"] = indices.scan_index
ref_data["pixel_index"] = indices.pixel_index

if "time" in ref_data:
scan_time = preprocessor_data_r.scan_time
scan_time = scan_time.fillna(value=scan_time.min())
ref_data = interp_along_swath(ref_data, scan_time, dimension="time")

LOGGER.info(
"Saving file in gridded format to %s.",
output_folder
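
The scan-range limiting above relies on get_useful_scan_range, which is imported from speed.data.utils but whose implementation is not among the rendered hunks. Judging only from the call site (variable "surface_precip", min_scans=256, margin=64), a hypothetical sketch of such a helper, not the actual implementation, might look like:

import numpy as np
import xarray as xr

def useful_scan_range_sketch(data: xr.Dataset, variable: str, min_scans: int = 256, margin: int = 64):
    # Scans that contain at least one valid (finite) value of the variable.
    valid = np.isfinite(data[variable].data).any(axis=-1)
    inds = np.where(valid)[0]
    n_scans = data.sizes["scans"]
    if inds.size == 0:
        return 0, n_scans
    # Pad the valid range by 'margin' scans on both sides.
    start = max(int(inds[0]) - margin, 0)
    end = min(int(inds[-1]) + margin + 1, n_scans)
    # Grow the window to at least 'min_scans' scans, clipped to the swath.
    if end - start < min_scans:
        start = max(end - min_scans, 0)
        end = min(start + min_scans, n_scans)
    return start, end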