282 lines
10 KiB
Python
282 lines
10 KiB
Python
import json
|
|
import shutil
|
|
import importlib.util
|
|
from pathlib import Path
|
|
from datetime import datetime, timedelta
|
|
import numpy as np
|
|
import rasterio
|
|
from rasterio.warp import Resampling
|
|
from rasterio.vrt import WarpedVRT
|
|
from rasterio import shutil as rio_shutil
|
|
from scipy import ndimage
|
|
|
|
# Number of S2 pixels per S3 pixel along each axis: prepare_s2 sizes its
# high-resolution grid as the S3 width/height multiplied by this factor, and
# run_efast forwards the same value to the fusion call as ``ratio``.
RESOLUTION_RATIO = 21
|
|
|
|
# Bind ``efast_fusion`` to the EFAST implementation. The regular package
# import is tried first; when it fails, ``efast/efast.py`` is looked up in
# each site-packages directory and the first hit is loaded directly from file.
try:
    import efast as efast_fusion
except ImportError:
    import site

    efast_fusion = None
    candidate = next(
        (
            module_path
            for module_path in (
                Path(pkg_dir) / "efast" / "efast.py"
                for pkg_dir in site.getsitepackages()
            )
            if module_path.exists()
        ),
        None,
    )
    if candidate is not None:
        spec = importlib.util.spec_from_file_location(
            "efast_fusion_module", candidate
        )
        efast_fusion = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(efast_fusion)
    if efast_fusion is None:
        raise ImportError(
            "efast package not found. Install with: pip install git+https://github.com/DHI-GRAS/efast.git"
        )
|
|
|
|
|
|
def _load_clouds(clouds_file):
    """Return the file names flagged as cloudy, grouped by sensor.

    Args:
        clouds_file: ``pathlib.Path`` of a JSON document holding optional
            ``"s2"`` and ``"s3"`` lists of file names.

    Returns:
        A dict with exactly the keys ``"s2"`` and ``"s3"``, each mapping to a
        set of names. A missing file, a missing key or a JSON ``null`` value
        all yield an empty set for the affected sensor(s).

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    if not clouds_file.exists():
        return {"s2": set(), "s3": set()}

    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    clouds_data = json.loads(clouds_file.read_text(encoding="utf-8"))

    # ``or ()`` turns an explicit ``null`` into "no entries" instead of
    # letting ``set(None)`` raise TypeError.
    return {sensor: set(clouds_data.get(sensor) or ()) for sensor in ("s2", "s3")}
|
|
|
|
|
|
def _reproject_to_target(
    data, src_transform, src_crs, target_bounds, target_crs, width, height, resampling
):
    """Warp a band-first array onto a ``width`` x ``height`` grid over ``target_bounds``.

    Args:
        data: Array whose first axis is the band axis (its length sizes the
            destination buffer).
        src_transform: Affine transform of ``data``.
        src_crs: CRS of ``data``.
        target_bounds: Object exposing ``left``, ``bottom``, ``right`` and
            ``top`` in ``target_crs`` units.
        target_crs: CRS of the output grid.
        width: Output width in pixels.
        height: Output height in pixels.
        resampling: Resampling method handed to ``rasterio.warp.reproject``.

    Returns:
        ``(reprojected, dst_transform)``: the warped array of shape
        ``(bands, height, width)`` with the dtype of ``data``, and the affine
        transform of the output grid.
    """
    west, south = target_bounds.left, target_bounds.bottom
    east, north = target_bounds.right, target_bounds.top
    dst_transform = rasterio.transform.from_bounds(
        west, south, east, north, width, height
    )

    # Zero-initialised output buffer with one layer per input band.
    band_count = data.shape[0]
    destination = np.zeros((band_count, height, width), dtype=data.dtype)

    reprojected, _unused_transform = rasterio.warp.reproject(
        source=data,
        destination=destination,
        src_transform=src_transform,
        src_crs=src_crs,
        dst_transform=dst_transform,
        dst_crs=target_crs,
        resampling=resampling,
    )
    return reprojected, dst_transform
|
|
|
|
|
|
def _write_s2_reflectance(s2_file, refl_dst, target_bounds, target_crs, width, height):
    """Write one raw S2 scene, scaled by 1/10000, cubic-resampled onto the target grid."""
    with rasterio.open(s2_file) as src:
        data = src.read().astype("float32") / 10000.0
        reprojected_data, dst_transform = _reproject_to_target(
            data,
            src.transform,
            src.crs,
            target_bounds,
            target_crs,
            width,
            height,
            Resampling.cubic,
        )
        profile = src.profile.copy()
        profile.update(
            {
                "dtype": "float32",
                "nodata": 0,
                "width": width,
                "height": height,
                "transform": dst_transform,
                "crs": target_crs,
            }
        )
        with rasterio.open(refl_dst, "w", **profile) as dst_file:
            dst_file.write(reprojected_data)


def _write_s2_cloud_distance(
    refl_dst, dist_cloud_dst, target_bounds, target_crs, width, height
):
    """Write the low-resolution distance-to-zero-pixel raster derived from a REFL file."""
    with rasterio.open(refl_dst) as src:
        s2_hr = src.read(1)
        # Pixels equal to 0 in band 1 are the ones distances are measured to
        # (presumably masked/cloudy pixels - confirm against the raw data).
        mask = s2_hr == 0
        distance_to_cloud_hr = np.clip(
            ndimage.distance_transform_edt(~mask), 0, 255
        ).astype("float32")

        # Average the high-resolution distances down to the coarse grid.
        distance_to_cloud_lr, lr_transform = _reproject_to_target(
            distance_to_cloud_hr[np.newaxis, :, :],
            src.transform,
            src.crs,
            target_bounds,
            target_crs,
            width,
            height,
            Resampling.average,
        )

        profile = src.profile.copy()
        profile.update(
            {
                "count": 1,
                "dtype": "float32",
                "width": width,
                "height": height,
                "transform": lr_transform,
            }
        )
        with rasterio.open(dist_cloud_dst, "w", **profile) as dst:
            dst.write(distance_to_cloud_lr[0], 1)


def prepare_s2(season, site_position, site_name, date_range=None):
    """Resample raw S2 scenes onto the S3-derived grid and derive cloud distances.

    Every ``*.geotiff`` in ``data/{site_name}/{season}/raw/s2/`` that is not
    listed under ``"s2"`` in ``clouds.json`` yields two files in
    ``data/{site_name}/{season}/prepared/s2/`` (each skipped if it already
    exists):

    * ``S2A_MSIL2A_{date}_REFL.tif``: all bands divided by 10000 and
      cubic-resampled to the bounds and CRS of the reference S3 file, with
      ``RESOLUTION_RATIO`` times its pixel count per axis.
    * ``S2A_MSIL2A_{date}_DIST_CLOUD.tif``: per-pixel distance (clipped to
      0..255) from band 1 of the REFL file to the nearest zero-valued pixel,
      averaged down to the S3 pixel count.

    ``{date}`` is the part of the raw file name before the first underscore.

    Args:
        season: Path component selecting the season directory.
        site_position: Unused here; the parameter mirrors ``prepare_s3`` and
            ``run_efast``.
        site_name: Path component selecting the site directory.
        date_range: Unused here.

    Raises:
        ValueError: If no S3 file outside the cloud list is available to
            define the target grid.
    """
    s2_dir = Path(f"data/{site_name}/{season}/raw/s2/")
    s3_dir = Path(f"data/{site_name}/{season}/raw/s3/")
    s2_output_dir = Path(f"data/{site_name}/{season}/prepared/s2/")
    clouds_file = Path(f"data/{site_name}/{season}/clouds.json")

    clouds = _load_clouds(clouds_file)

    s2_output_dir.mkdir(parents=True, exist_ok=True)

    # glob() yields entries in directory order, which is not stable across
    # file systems; sort so the reference grid is the same on every run.
    s3_files = sorted(
        f for f in s3_dir.glob("*.geotiff") if f.name not in clouds["s3"]
    )
    if not s3_files:
        raise ValueError("No non-cloud S3 files found for reference bounds")

    with rasterio.open(s3_files[0]) as s3_ref:
        target_bounds = s3_ref.bounds
        target_crs = s3_ref.crs
        s3_width = s3_ref.width
        s3_height = s3_ref.height
    s2_width = s3_width * RESOLUTION_RATIO
    s2_height = s3_height * RESOLUTION_RATIO

    # Sorted as well: when several raw files share a date, the first one in
    # name order is the one that gets written, independent of listing order.
    for s2_file in sorted(s2_dir.glob("*.geotiff")):
        if s2_file.name in clouds["s2"]:
            continue
        date_str = s2_file.name.split("_")[0]

        refl_dst = s2_output_dir / f"S2A_MSIL2A_{date_str}_REFL.tif"
        if not refl_dst.exists():
            _write_s2_reflectance(
                s2_file, refl_dst, target_bounds, target_crs, s2_width, s2_height
            )

        dist_cloud_dst = s2_output_dir / f"S2A_MSIL2A_{date_str}_DIST_CLOUD.tif"
        if not dist_cloud_dst.exists():
            _write_s2_cloud_distance(
                refl_dst, dist_cloud_dst, target_bounds, target_crs, s3_width, s3_height
            )
|
|
|
|
|
|
def prepare_s3(season, site_position, site_name, date_range=None):
    """Warp raw S3 files onto the prepared S2 DIST_CLOUD grid, one output per date.

    Raw ``*.geotiff`` files in ``data/{site_name}/{season}/raw/s3/`` that are
    not listed under ``"s3"`` in ``clouds.json`` are grouped by the part of
    their name before the first underscore. For each group
    ``data/{site_name}/{season}/prepared/s3/composite_{date}.tif`` is written
    unless it already exists:

    * one file for the date: it is cubic-warped to the target grid and copied
      as-is (no outlier masking is applied on this path);
    * several files: each is cubic-warped, pixels whose absolute mean across
      bands is >= 5 are set to NaN, and the NaN-aware mean over the files is
      written as float32.

    Args:
        season: Path component selecting the season directory.
        site_position: Unused here; the parameter mirrors ``prepare_s2`` and
            ``run_efast``.
        site_name: Path component selecting the site directory.
        date_range: Unused here.

    Raises:
        ValueError: If no ``*DIST_CLOUD.tif`` exists in the prepared S2
            directory to take the target grid from.
    """
    s3_dir = Path(f"data/{site_name}/{season}/raw/s3/")
    s2_prepared_dir = Path(f"data/{site_name}/{season}/prepared/s2/")
    s3_preprocessed_dir = Path(f"data/{site_name}/{season}/prepared/s3/")
    clouds_file = Path(f"data/{site_name}/{season}/clouds.json")

    clouds = _load_clouds(clouds_file)
    s3_preprocessed_dir.mkdir(parents=True, exist_ok=True)

    # Reference grid comes from an S2 DIST_CLOUD file. glob() order is not
    # stable across file systems, so sort to always pick the same file.
    dist_cloud_files = sorted(s2_prepared_dir.glob("*DIST_CLOUD.tif"))
    if not dist_cloud_files:
        raise ValueError("No S2 DIST_CLOUD files found. Run prepare_s2 first.")

    with rasterio.open(dist_cloud_files[0]) as src:
        target_profile = src.profile

    # Identical for every file, so build the warp options once.
    vrt_options = {
        "transform": target_profile["transform"],
        "height": target_profile["height"],
        "width": target_profile["width"],
        "crs": target_profile["crs"],
        "resampling": Resampling.cubic,
    }

    # Group S3 files by date (sorted for a reproducible processing order).
    s3_by_date = {}
    for s3_file in sorted(s3_dir.glob("*.geotiff")):
        if s3_file.name in clouds["s3"]:
            continue
        date_str = s3_file.name.split("_")[0]
        s3_by_date.setdefault(date_str, []).append(s3_file)

    for date_str, s3_files in s3_by_date.items():
        output_path = s3_preprocessed_dir / f"composite_{date_str}.tif"
        if output_path.exists():
            continue

        if len(s3_files) == 1:
            # Single file: reproject directly.
            with rasterio.open(s3_files[0]) as src:
                with WarpedVRT(src, **vrt_options) as vrt:
                    rio_shutil.copy(vrt, output_path, driver="GTiff")
            continue

        # Multiple files: unweighted NaN-aware mean of the warped rasters.
        s3_stack = []
        for s3_file in s3_files:
            with rasterio.open(s3_file) as src:
                with WarpedVRT(src, **vrt_options) as vrt:
                    data = vrt.read()
            # NaN cannot be stored in an integer array; promote first so the
            # masking below works for integer rasters too.
            if not np.issubdtype(data.dtype, np.floating):
                data = data.astype("float32")
            # Remove abnormally high values (pixel-wise mean across bands).
            pixel_means = np.abs(np.nanmean(data, axis=0))
            data[:, pixel_means >= 5] = np.nan
            s3_stack.append(data)

        composite = np.nanmean(np.array(s3_stack), axis=0).astype("float32")

        profile = target_profile.copy()
        profile.update({"count": composite.shape[0], "dtype": "float32"})
        with rasterio.open(output_path, "w", **profile) as dst:
            dst.write(composite)
|
|
|
|
|
|
def run_efast(season, site_position, site_name, date_range=None):
    """Call the EFAST fusion once per calendar day of the requested range.

    Inputs are read from ``data/{site_name}/{season}/prepared/s2`` and
    ``.../prepared/s3``; results go to ``.../prepared/fusion`` (created if
    missing). Days whose ``REFL_{YYYYMMDD}.tif`` already exists are skipped.
    An exception raised while processing a day is printed and the loop moves
    on to the next day.

    Args:
        season: Path component for the season; also used to build the default
            range ``{season}-01-01/{season}-12-31``.
        site_position: ``(lat, lon)`` pair, only used in the start message.
        site_name: Path component selecting the site directory.
        date_range: Optional ``"YYYY-MM-DD/YYYY-MM-DD"`` string (both ends
            inclusive) overriding the default range.
    """
    lat, lon = site_position
    datetime_range = date_range or f"{season}-01-01/{season}-12-31"

    prepared_root = Path(f"data/{site_name}/{season}/prepared/")
    s2_output_dir = prepared_root / "s2"
    s3_output_dir = prepared_root / "s3"
    fusion_output_dir = prepared_root / "fusion"
    fusion_output_dir.mkdir(parents=True, exist_ok=True)

    print(f"[EFAST] Starting fusion: {site_name} ({lat:.6f}, {lon:.6f}), {season}")

    first_str, last_str = datetime_range.split("/")
    first_day = datetime.strptime(first_str, "%Y-%m-%d")
    last_day = datetime.strptime(last_str, "%Y-%m-%d")

    # Both ends are inclusive; a reversed range gives an empty loop.
    total_days = (last_day - first_day).days + 1
    for offset in range(total_days):
        current_date = first_day + timedelta(days=offset)
        date_str = current_date.strftime("%Y%m%d")
        output_file = fusion_output_dir / f"REFL_{date_str}.tif"

        if output_file.exists():
            print(f"[EFAST] Skipping {date_str} (exists)")
            continue

        try:
            efast_fusion.fusion(
                current_date,
                s3_output_dir,
                s2_output_dir,
                fusion_output_dir,
                product="REFL",
                max_days=30,
                date_position=2,
                minimum_acquisition_importance=0.0,
                ratio=RESOLUTION_RATIO,
            )
            # The fusion call may return without writing anything.
            if not output_file.exists():
                print(f"[EFAST] No output for {date_str} (insufficient nearby data)")
            else:
                print(f"[EFAST] Saved: {output_file}")
        except Exception as e:
            print(f"[EFAST] Error processing {date_str}: {e}")

    print("[EFAST] Completed")
|