Refactored efast.py to leverage efast package functions.
This commit is contained in:
parent
853c1c6a30
commit
6741433228
3 changed files with 81 additions and 182 deletions
251
efast.py
251
efast.py
|
|
@ -1,36 +1,24 @@
|
|||
import json
|
||||
import shutil
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
import rasterio
|
||||
from rasterio.warp import Resampling
|
||||
from rasterio.vrt import WarpedVRT
|
||||
from rasterio import shutil as rio_shutil
|
||||
from scipy import ndimage
|
||||
|
||||
RESOLUTION_RATIO = 21
|
||||
|
||||
try:
|
||||
import efast as efast_fusion
|
||||
import efast
|
||||
from efast.s2_processing import distance_to_clouds
|
||||
from efast.s3_processing import reproject_and_crop_s3
|
||||
except ImportError:
|
||||
import site
|
||||
raise ImportError(
|
||||
"efast package not found. Install with: pip install git+https://github.com/DHI-GRAS/efast.git"
|
||||
)
|
||||
|
||||
efast_fusion = None
|
||||
for site_pkg in site.getsitepackages():
|
||||
candidate = Path(site_pkg) / "efast" / "efast.py"
|
||||
if candidate.exists():
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"efast_fusion_module", candidate
|
||||
)
|
||||
efast_fusion = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(efast_fusion)
|
||||
break
|
||||
if efast_fusion is None:
|
||||
raise ImportError(
|
||||
"efast package not found. Install with: pip install git+https://github.com/DHI-GRAS/efast.git"
|
||||
)
|
||||
RESOLUTION_RATIO = 21
|
||||
|
||||
|
||||
def _load_clouds(clouds_file):
|
||||
|
|
@ -42,27 +30,24 @@ def _load_clouds(clouds_file):
|
|||
return clouds
|
||||
|
||||
|
||||
def _reproject_to_target(
|
||||
data, src_transform, src_crs, target_bounds, target_crs, width, height, resampling
|
||||
):
|
||||
def _reproject_raster_to_target(src_path, dst_path, target_bounds, target_crs, width, height, resampling=Resampling.cubic):
|
||||
dst_transform = rasterio.transform.from_bounds(
|
||||
target_bounds.left,
|
||||
target_bounds.bottom,
|
||||
target_bounds.right,
|
||||
target_bounds.top,
|
||||
width,
|
||||
height,
|
||||
target_bounds.left, target_bounds.bottom,
|
||||
target_bounds.right, target_bounds.top,
|
||||
width, height
|
||||
)
|
||||
reprojected, _ = rasterio.warp.reproject(
|
||||
source=data,
|
||||
destination=np.zeros((data.shape[0], height, width), dtype=data.dtype),
|
||||
src_transform=src_transform,
|
||||
src_crs=src_crs,
|
||||
dst_transform=dst_transform,
|
||||
dst_crs=target_crs,
|
||||
resampling=resampling,
|
||||
)
|
||||
return reprojected, dst_transform
|
||||
with rasterio.open(src_path) as src:
|
||||
vrt_options = {
|
||||
"transform": dst_transform,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"crs": target_crs,
|
||||
"resampling": resampling,
|
||||
}
|
||||
with WarpedVRT(src, **vrt_options) as vrt:
|
||||
profile = vrt.profile.copy()
|
||||
profile.update({"dtype": "float32", "nodata": 0})
|
||||
rio_shutil.copy(vrt, dst_path, driver="GTiff", **profile)
|
||||
|
||||
|
||||
def prepare_s2(season, site_position, site_name, date_range=None):
|
||||
|
|
@ -72,7 +57,6 @@ def prepare_s2(season, site_position, site_name, date_range=None):
|
|||
clouds_file = Path(f"data/{site_name}/{season}/clouds.json")
|
||||
|
||||
clouds = _load_clouds(clouds_file)
|
||||
|
||||
s2_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
s3_files = [f for f in s3_dir.glob("*.geotiff") if f.name not in clouds["s3"]]
|
||||
|
|
@ -82,77 +66,29 @@ def prepare_s2(season, site_position, site_name, date_range=None):
|
|||
with rasterio.open(s3_files[0]) as s3_ref:
|
||||
target_bounds = s3_ref.bounds
|
||||
target_crs = s3_ref.crs
|
||||
s3_width = s3_ref.width
|
||||
s3_height = s3_ref.height
|
||||
s2_width = s3_width * RESOLUTION_RATIO
|
||||
s2_height = s3_height * RESOLUTION_RATIO
|
||||
s2_width = s3_ref.width * RESOLUTION_RATIO
|
||||
s2_height = s3_ref.height * RESOLUTION_RATIO
|
||||
|
||||
for s2_file in s2_dir.glob("*.geotiff"):
|
||||
if s2_file.name in clouds["s2"]:
|
||||
continue
|
||||
date_str = s2_file.name.split("_")[0]
|
||||
|
||||
refl_dst = s2_output_dir / f"S2A_MSIL2A_{date_str}_REFL.tif"
|
||||
if not refl_dst.exists():
|
||||
with rasterio.open(s2_file) as src:
|
||||
data = src.read().astype("float32") / 10000.0
|
||||
reprojected_data, dst_transform = _reproject_to_target(
|
||||
data,
|
||||
src.transform,
|
||||
src.crs,
|
||||
target_bounds,
|
||||
target_crs,
|
||||
s2_width,
|
||||
s2_height,
|
||||
Resampling.cubic,
|
||||
)
|
||||
profile = src.profile.copy()
|
||||
profile.update(
|
||||
{
|
||||
"dtype": "float32",
|
||||
"nodata": 0,
|
||||
"width": s2_width,
|
||||
"height": s2_height,
|
||||
"transform": dst_transform,
|
||||
"crs": target_crs,
|
||||
}
|
||||
)
|
||||
with rasterio.open(refl_dst, "w", **profile) as dst_file:
|
||||
dst_file.write(reprojected_data)
|
||||
if refl_dst.exists():
|
||||
continue
|
||||
|
||||
dist_cloud_dst = s2_output_dir / f"S2A_MSIL2A_{date_str}_DIST_CLOUD.tif"
|
||||
if not dist_cloud_dst.exists():
|
||||
with rasterio.open(refl_dst) as src:
|
||||
s2_hr = src.read(1)
|
||||
mask = s2_hr == 0
|
||||
distance_to_cloud_hr = np.clip(
|
||||
ndimage.distance_transform_edt(~mask), 0, 255
|
||||
).astype("float32")
|
||||
temp_normalized = s2_output_dir / f"temp_{s2_file.name}"
|
||||
with rasterio.open(s2_file) as src:
|
||||
data = src.read().astype("float32") / 10000.0
|
||||
profile = src.profile.copy()
|
||||
profile.update({"dtype": "float32", "nodata": 0})
|
||||
with rasterio.open(temp_normalized, "w", **profile) as dst:
|
||||
dst.write(data)
|
||||
|
||||
distance_to_cloud_lr, lr_transform = _reproject_to_target(
|
||||
distance_to_cloud_hr[np.newaxis, :, :],
|
||||
src.transform,
|
||||
src.crs,
|
||||
target_bounds,
|
||||
target_crs,
|
||||
s3_width,
|
||||
s3_height,
|
||||
Resampling.average,
|
||||
)
|
||||
distance_to_cloud_lr = distance_to_cloud_lr[0]
|
||||
_reproject_raster_to_target(temp_normalized, refl_dst, target_bounds, target_crs, s2_width, s2_height)
|
||||
temp_normalized.unlink()
|
||||
|
||||
profile = src.profile.copy()
|
||||
profile.update(
|
||||
{
|
||||
"count": 1,
|
||||
"dtype": "float32",
|
||||
"width": s3_width,
|
||||
"height": s3_height,
|
||||
"transform": lr_transform,
|
||||
}
|
||||
)
|
||||
with rasterio.open(dist_cloud_dst, "w", **profile) as dst:
|
||||
dst.write(distance_to_cloud_lr, 1)
|
||||
distance_to_clouds(s2_output_dir, ratio=RESOLUTION_RATIO)
|
||||
|
||||
|
||||
def prepare_s3(season, site_position, site_name, date_range=None):
|
||||
|
|
@ -164,72 +100,37 @@ def prepare_s3(season, site_position, site_name, date_range=None):
|
|||
clouds = _load_clouds(clouds_file)
|
||||
s3_preprocessed_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get reference profile from S2 DIST_CLOUD file
|
||||
dist_cloud_files = list(s2_prepared_dir.glob("*DIST_CLOUD.tif"))
|
||||
if not dist_cloud_files:
|
||||
raise ValueError("No S2 DIST_CLOUD files found. Run prepare_s2 first.")
|
||||
|
||||
with rasterio.open(dist_cloud_files[0]) as src:
|
||||
target_profile = src.profile
|
||||
|
||||
# Group S3 files by date
|
||||
s3_by_date = {}
|
||||
s3_by_date = defaultdict(list)
|
||||
for s3_file in s3_dir.glob("*.geotiff"):
|
||||
if s3_file.name in clouds["s3"]:
|
||||
continue
|
||||
date_str = s3_file.name.split("_")[0]
|
||||
if date_str not in s3_by_date:
|
||||
s3_by_date[date_str] = []
|
||||
s3_by_date[date_str].append(s3_file)
|
||||
if s3_file.name not in clouds["s3"]:
|
||||
s3_by_date[s3_file.name.split("_")[0]].append(s3_file)
|
||||
|
||||
temp_composite_dir = s3_preprocessed_dir / "temp_composites"
|
||||
if temp_composite_dir.exists():
|
||||
shutil.rmtree(temp_composite_dir)
|
||||
temp_composite_dir.mkdir()
|
||||
|
||||
# Process each date
|
||||
for date_str, s3_files in s3_by_date.items():
|
||||
output_path = s3_preprocessed_dir / f"composite_{date_str}.tif"
|
||||
if output_path.exists():
|
||||
continue
|
||||
|
||||
composite_path = temp_composite_dir / f"composite_{date_str}.tif"
|
||||
if len(s3_files) == 1:
|
||||
# Single file: reproject directly
|
||||
with rasterio.open(s3_files[0]) as src:
|
||||
vrt_options = {
|
||||
"transform": target_profile["transform"],
|
||||
"height": target_profile["height"],
|
||||
"width": target_profile["width"],
|
||||
"crs": target_profile["crs"],
|
||||
"resampling": Resampling.cubic,
|
||||
}
|
||||
with WarpedVRT(src, **vrt_options) as vrt:
|
||||
rio_shutil.copy(vrt, output_path, driver="GTiff")
|
||||
shutil.copy(s3_files[0], composite_path)
|
||||
else:
|
||||
# Multiple files: create weighted composite
|
||||
s3_stack = []
|
||||
for s3_file in s3_files:
|
||||
with rasterio.open(s3_file) as src:
|
||||
vrt_options = {
|
||||
"transform": target_profile["transform"],
|
||||
"height": target_profile["height"],
|
||||
"width": target_profile["width"],
|
||||
"crs": target_profile["crs"],
|
||||
"resampling": Resampling.cubic,
|
||||
}
|
||||
with WarpedVRT(src, **vrt_options) as vrt:
|
||||
data = vrt.read()
|
||||
# Remove abnormally high values (pixel-wise mean across bands)
|
||||
pixel_means = np.abs(np.nanmean(data, axis=0))
|
||||
mask = pixel_means >= 5
|
||||
data[:, mask] = np.nan
|
||||
s3_stack.append(data)
|
||||
|
||||
s3_stack = np.array(s3_stack)
|
||||
# Simple mean composite (can be enhanced with temporal weighting)
|
||||
composite = np.nanmean(s3_stack, axis=0)
|
||||
composite = composite.astype("float32")
|
||||
|
||||
profile = target_profile.copy()
|
||||
profile.update({"count": composite.shape[0], "dtype": "float32"})
|
||||
with rasterio.open(output_path, "w", **profile) as dst:
|
||||
data = src.read()
|
||||
data[:, np.abs(np.nanmean(data, axis=0)) >= 5] = np.nan
|
||||
s3_stack.append(data)
|
||||
composite = np.nanmean(np.array(s3_stack), axis=0).astype("float32")
|
||||
with rasterio.open(s3_files[0]) as src:
|
||||
profile = src.profile.copy()
|
||||
profile.update({"count": composite.shape[0], "dtype": "float32"})
|
||||
with rasterio.open(composite_path, "w", **profile) as dst:
|
||||
dst.write(composite)
|
||||
|
||||
reproject_and_crop_s3(temp_composite_dir, s2_prepared_dir, s3_preprocessed_dir)
|
||||
shutil.rmtree(temp_composite_dir)
|
||||
|
||||
|
||||
def run_efast(season, site_position, site_name, date_range=None):
|
||||
lat, lon = site_position
|
||||
|
|
@ -241,7 +142,6 @@ def run_efast(season, site_position, site_name, date_range=None):
|
|||
fusion_output_dir = efast_base_dir / "fusion"
|
||||
|
||||
fusion_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"[EFAST] Starting fusion: {site_name} ({lat:.6f}, {lon:.6f}), {season}")
|
||||
|
||||
start_str, end_str = datetime_range.split("/")
|
||||
|
|
@ -251,32 +151,19 @@ def run_efast(season, site_position, site_name, date_range=None):
|
|||
current_date = start_date
|
||||
while current_date <= end_date:
|
||||
date_str = current_date.strftime("%Y%m%d")
|
||||
|
||||
output_file = fusion_output_dir / f"REFL_{date_str}.tif"
|
||||
if output_file.exists():
|
||||
print(f"[EFAST] Skipping {date_str} (exists)")
|
||||
current_date += timedelta(days=1)
|
||||
continue
|
||||
|
||||
try:
|
||||
efast_fusion.fusion(
|
||||
current_date,
|
||||
s3_output_dir,
|
||||
s2_output_dir,
|
||||
fusion_output_dir,
|
||||
product="REFL",
|
||||
max_days=30,
|
||||
date_position=2,
|
||||
minimum_acquisition_importance=0.0,
|
||||
ratio=RESOLUTION_RATIO,
|
||||
)
|
||||
if output_file.exists():
|
||||
print(f"[EFAST] Saved: {output_file}")
|
||||
else:
|
||||
print(f"[EFAST] No output for {date_str} (insufficient nearby data)")
|
||||
except Exception as e:
|
||||
print(f"[EFAST] Error processing {date_str}: {e}")
|
||||
|
||||
else:
|
||||
try:
|
||||
efast.fusion(
|
||||
current_date, s3_output_dir, s2_output_dir, fusion_output_dir,
|
||||
product="REFL", max_days=30, date_position=2,
|
||||
minimum_acquisition_importance=0.0, ratio=RESOLUTION_RATIO,
|
||||
)
|
||||
print(f"[EFAST] Saved: {output_file}" if output_file.exists() else f"[EFAST] No output for {date_str} (insufficient nearby data)")
|
||||
except Exception as e:
|
||||
print(f"[EFAST] Error processing {date_str}: {e}")
|
||||
current_date += timedelta(days=1)
|
||||
|
||||
print("[EFAST] Completed")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue