diff --git a/io/eolearn/io/geopedia.py b/io/eolearn/io/geopedia.py index 25975c1de..258f5e1ee 100644 --- a/io/eolearn/io/geopedia.py +++ b/io/eolearn/io/geopedia.py @@ -26,12 +26,10 @@ class AddGeopediaFeature(EOTask): * rasterize back and add raster to EOPatch """ - def __init__(self, feature_type, feature_name, layer, theme, - raster_value, raster_dtype=np.uint8, no_data_val=0, + def __init__(self, feature, layer, theme, raster_value, raster_dtype=np.uint8, no_data_val=0, image_format=MimeType.PNG, mean_abs_difference=2): - self.feature_type = feature_type - self.feature_name = feature_name + self.feature_type, self.feature_name = next(self._parse_features(feature)()) self.raster_value = raster_value self.raster_dtype = raster_dtype @@ -96,8 +94,8 @@ def _map_from_binaries(self, eopatch, dst_shape, request_data): """ Each request represents a binary class which will be mapped to the scalar `raster_value` """ - if eopatch.feature_exists(self.feature_type, self.feature_name): - raster = eopatch.get_feature(self.feature_type, self.feature_name).squeeze() + if self.feature_name in eopatch[self.feature_type]: + raster = eopatch[self.feature_type][self.feature_name].squeeze() else: raster = np.ones(dst_shape, dtype=self.raster_dtype) * self.no_data_val @@ -137,7 +135,7 @@ def execute(self, eopatch): """ Add requested feature to this existing EOPatch. """ - data_arr = eopatch.get_feature(FeatureType.MASK, 'IS_DATA') + data_arr = eopatch[FeatureType.MASK]['IS_DATA'] _, height, width, _ = data_arr.shape request = self._get_wms_request(eopatch.bbox, width, height) @@ -151,9 +149,9 @@ def execute(self, eopatch): else: raise ValueError("Unsupported raster value type") - if (self.feature_type in [FeatureType.MASK_TIMELESS]) and raster.ndim == 2: + if self.feature_type is FeatureType.MASK_TIMELESS and raster.ndim == 2: raster = raster[..., np.newaxis] - eopatch.add_feature(self.feature_type, self.feature_name, raster) + eopatch[self.feature_type][self.feature_name] = raster return eopatch diff --git a/io/eolearn/io/local_io.py b/io/eolearn/io/local_io.py index cff616ec0..deded2a88 100644 --- a/io/eolearn/io/local_io.py +++ b/io/eolearn/io/local_io.py @@ -14,10 +14,8 @@ class ExportToTiff(SaveToDisk): """ Task exports specified feature to Geo-Tiff. - :param feature_type: Type of the raster feature which will be exported - :type feature_type: eolearn.core.FeatureType - :param feature_name: Name of the raster feature which will be exported - :type feature_name: str + :param feature: Feature which will be exported + :type feature: (FeatureType, str) :param folder: root directory where all Geo-Tiff images will be saved :type folder: str :param band_count: Number of bands to be added to tiff image @@ -28,17 +26,18 @@ class ExportToTiff(SaveToDisk): :type no_data_value: int or float """ - def __init__(self, feature_type, feature_name, folder='.', *, band_count=1, image_dtype=np.uint8, no_data_value=0): + def __init__(self, feature, folder='.', *, band_count=1, image_dtype=np.uint8, no_data_value=0): super().__init__(folder) - self.feature_type = feature_type - self.feature_name = feature_name + self.feature = self._parse_features(feature) self.band_count = band_count self.image_dtype = image_dtype self.no_data_value = no_data_value def execute(self, eopatch, *, filename): - array = eopatch.get_feature(self.feature_type, self.feature_name) + + feature_type, feature_name = next(self.feature(eopatch)) + array = eopatch[feature_type][feature_name] if self.band_count == 1: array = array[..., 0]