Refactored a bit.
This commit is contained in:
parent
290c8f8c57
commit
6bbaa4b3eb
8 changed files with 323 additions and 331 deletions
134
efast.py
134
efast.py
|
|
@ -6,28 +6,16 @@ from datetime import datetime, timedelta
|
|||
import numpy as np
|
||||
import rasterio
|
||||
from rasterio.warp import Resampling
|
||||
from rasterio.vrt import WarpedVRT
|
||||
from scipy import ndimage
|
||||
|
||||
_this_file = Path(__file__).resolve()
|
||||
_venv_lib = _this_file.parent.parent / "venv" / "lib"
|
||||
_efast_pkg_path = None
|
||||
if _venv_lib.exists():
|
||||
for py_dir in _venv_lib.glob("python*"):
|
||||
candidate = py_dir / "site-packages" / "efast" / "efast.py"
|
||||
if candidate.exists():
|
||||
_efast_pkg_path = candidate
|
||||
break
|
||||
RESOLUTION_RATIO = 21
|
||||
|
||||
if _efast_pkg_path and _efast_pkg_path.exists():
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"efast_fusion_module", _efast_pkg_path
|
||||
)
|
||||
efast_fusion = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(efast_fusion)
|
||||
else:
|
||||
try:
|
||||
import efast as efast_fusion
|
||||
except ImportError:
|
||||
import site
|
||||
|
||||
efast_fusion = None
|
||||
for site_pkg in site.getsitepackages():
|
||||
candidate = Path(site_pkg) / "efast" / "efast.py"
|
||||
if candidate.exists():
|
||||
|
|
@ -37,23 +25,51 @@ else:
|
|||
efast_fusion = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(efast_fusion)
|
||||
break
|
||||
else:
|
||||
if efast_fusion is None:
|
||||
raise ImportError(
|
||||
"efast package not found. Install with: pip install git+https://github.com/DHI-GRAS/efast.git"
|
||||
)
|
||||
|
||||
|
||||
def _load_clouds(clouds_file):
|
||||
clouds = {"s2": set(), "s3": set()}
|
||||
if clouds_file.exists():
|
||||
clouds_data = json.loads(clouds_file.read_text())
|
||||
clouds["s2"] = set(clouds_data.get("s2", []))
|
||||
clouds["s3"] = set(clouds_data.get("s3", []))
|
||||
return clouds
|
||||
|
||||
|
||||
def _reproject_to_target(
|
||||
data, src_transform, src_crs, target_bounds, target_crs, width, height, resampling
|
||||
):
|
||||
dst_transform = rasterio.transform.from_bounds(
|
||||
target_bounds.left,
|
||||
target_bounds.bottom,
|
||||
target_bounds.right,
|
||||
target_bounds.top,
|
||||
width,
|
||||
height,
|
||||
)
|
||||
reprojected, _ = rasterio.warp.reproject(
|
||||
source=data,
|
||||
destination=np.zeros((data.shape[0], height, width), dtype=data.dtype),
|
||||
src_transform=src_transform,
|
||||
src_crs=src_crs,
|
||||
dst_transform=dst_transform,
|
||||
dst_crs=target_crs,
|
||||
resampling=resampling,
|
||||
)
|
||||
return reprojected, dst_transform
|
||||
|
||||
|
||||
def prepare_s2(season, site_position, site_name, date_range=None):
|
||||
s2_dir = Path(f"data/{site_name}/{season}/raw/s2/")
|
||||
s3_dir = Path(f"data/{site_name}/{season}/raw/s3/")
|
||||
s2_output_dir = Path(f"data/{site_name}/{season}/prepared/s2/")
|
||||
clouds_file = Path(f"data/{site_name}/{season}/clouds.json")
|
||||
|
||||
clouds = {"s2": set(), "s3": set()}
|
||||
if clouds_file.exists():
|
||||
clouds_data = json.loads(clouds_file.read_text())
|
||||
clouds["s2"] = set(clouds_data.get("s2", []))
|
||||
clouds["s3"] = set(clouds_data.get("s3", []))
|
||||
clouds = _load_clouds(clouds_file)
|
||||
|
||||
s2_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
@ -66,9 +82,8 @@ def prepare_s2(season, site_position, site_name, date_range=None):
|
|||
target_crs = s3_ref.crs
|
||||
s3_width = s3_ref.width
|
||||
s3_height = s3_ref.height
|
||||
ratio = 21
|
||||
s2_width = s3_width * ratio
|
||||
s2_height = s3_height * ratio
|
||||
s2_width = s3_width * RESOLUTION_RATIO
|
||||
s2_height = s3_height * RESOLUTION_RATIO
|
||||
|
||||
for s2_file in s2_dir.glob("*.geotiff"):
|
||||
if s2_file.name in clouds["s2"]:
|
||||
|
|
@ -79,25 +94,15 @@ def prepare_s2(season, site_position, site_name, date_range=None):
|
|||
if not refl_dst.exists():
|
||||
with rasterio.open(s2_file) as src:
|
||||
data = src.read().astype("float32") / 10000.0
|
||||
s2_res = (target_bounds.right - target_bounds.left) / s2_width
|
||||
dst_transform = rasterio.transform.from_bounds(
|
||||
target_bounds.left,
|
||||
target_bounds.bottom,
|
||||
target_bounds.right,
|
||||
target_bounds.top,
|
||||
reprojected_data, dst_transform = _reproject_to_target(
|
||||
data,
|
||||
src.transform,
|
||||
src.crs,
|
||||
target_bounds,
|
||||
target_crs,
|
||||
s2_width,
|
||||
s2_height,
|
||||
)
|
||||
reprojected_data, _ = rasterio.warp.reproject(
|
||||
source=data,
|
||||
destination=np.zeros(
|
||||
(src.count, s2_height, s2_width), dtype=data.dtype
|
||||
),
|
||||
src_transform=src.transform,
|
||||
src_crs=src.crs,
|
||||
dst_transform=dst_transform,
|
||||
dst_crs=target_crs,
|
||||
resampling=Resampling.cubic,
|
||||
Resampling.cubic,
|
||||
)
|
||||
profile = src.profile.copy()
|
||||
profile.update(
|
||||
|
|
@ -118,30 +123,19 @@ def prepare_s2(season, site_position, site_name, date_range=None):
|
|||
with rasterio.open(refl_dst) as src:
|
||||
s2_hr = src.read(1)
|
||||
mask = s2_hr == 0
|
||||
distance_to_cloud_hr = ndimage.distance_transform_edt(~mask)
|
||||
distance_to_cloud_hr = np.clip(distance_to_cloud_hr, 0, 255).astype(
|
||||
"float32"
|
||||
)
|
||||
distance_to_cloud_hr = np.clip(
|
||||
ndimage.distance_transform_edt(~mask), 0, 255
|
||||
).astype("float32")
|
||||
|
||||
s3_res = (target_bounds.right - target_bounds.left) / s3_width
|
||||
lr_transform = rasterio.transform.from_bounds(
|
||||
target_bounds.left,
|
||||
target_bounds.bottom,
|
||||
target_bounds.right,
|
||||
target_bounds.top,
|
||||
distance_to_cloud_lr, lr_transform = _reproject_to_target(
|
||||
distance_to_cloud_hr[np.newaxis, :, :],
|
||||
src.transform,
|
||||
src.crs,
|
||||
target_bounds,
|
||||
target_crs,
|
||||
s3_width,
|
||||
s3_height,
|
||||
)
|
||||
distance_to_cloud_lr, _ = rasterio.warp.reproject(
|
||||
source=distance_to_cloud_hr[np.newaxis, :, :],
|
||||
destination=np.zeros(
|
||||
(1, s3_height, s3_width), dtype=distance_to_cloud_hr.dtype
|
||||
),
|
||||
src_transform=src.transform,
|
||||
src_crs=target_crs,
|
||||
dst_transform=lr_transform,
|
||||
dst_crs=target_crs,
|
||||
resampling=Resampling.average,
|
||||
Resampling.average,
|
||||
)
|
||||
distance_to_cloud_lr = distance_to_cloud_lr[0]
|
||||
|
||||
|
|
@ -164,10 +158,7 @@ def prepare_s3(season, site_position, site_name, date_range=None):
|
|||
s3_preprocessed_dir = Path(f"data/{site_name}/{season}/prepared/s3/")
|
||||
clouds_file = Path(f"data/{site_name}/{season}/clouds.json")
|
||||
|
||||
clouds = {"s3": set()}
|
||||
if clouds_file.exists():
|
||||
clouds_data = json.loads(clouds_file.read_text())
|
||||
clouds["s3"] = set(clouds_data.get("s3", []))
|
||||
clouds = _load_clouds(clouds_file)
|
||||
|
||||
s3_preprocessed_dir.mkdir(parents=True, exist_ok=True)
|
||||
for s3_file in s3_dir.glob("*.geotiff"):
|
||||
|
|
@ -193,8 +184,9 @@ def run_efast(season, site_position, site_name, date_range=None):
|
|||
|
||||
print(f"[EFAST] Starting fusion: {site_name} ({lat:.6f}, {lon:.6f}), {season}")
|
||||
|
||||
start_date = datetime.strptime(datetime_range.split("/")[0], "%Y-%m-%d")
|
||||
end_date = datetime.strptime(datetime_range.split("/")[1], "%Y-%m-%d")
|
||||
start_str, end_str = datetime_range.split("/")
|
||||
start_date = datetime.strptime(start_str, "%Y-%m-%d")
|
||||
end_date = datetime.strptime(end_str, "%Y-%m-%d")
|
||||
|
||||
current_date = start_date
|
||||
while current_date <= end_date:
|
||||
|
|
@ -216,7 +208,7 @@ def run_efast(season, site_position, site_name, date_range=None):
|
|||
max_days=30,
|
||||
date_position=2,
|
||||
minimum_acquisition_importance=0.0,
|
||||
ratio=21,
|
||||
ratio=RESOLUTION_RATIO,
|
||||
)
|
||||
if output_file.exists():
|
||||
print(f"[EFAST] Saved: {output_file}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue