Refactored a bit.

This commit is contained in:
Felix Delattre 2025-12-27 10:25:17 +01:00
parent 290c8f8c57
commit 6bbaa4b3eb
8 changed files with 323 additions and 331 deletions

View file

@ -1,28 +1,76 @@
import os
import rasterio
import xml.etree.ElementTree as ET
import requests
from pathlib import Path
from rasterio.warp import transform_geom
from rasterio.windows import from_bounds, transform as window_transform
from pystac_client import Client
# Default edge length of the square box built around a site position, in the
# same units as the lon/lat values passed to _get_bbox.
BBOX_SIZE = 0.011


def _get_bbox(lon, lat, size=None):
    """Return a square bounding box centred on (lon, lat).

    Args:
        lon: X coordinate of the box centre.
        lat: Y coordinate of the box centre.
        size: Edge length of the box. When omitted, the module-level
            BBOX_SIZE is looked up at call time, so existing two-argument
            callers keep their behaviour.

    Returns:
        A list ``[min_lon, min_lat, max_lon, max_lat]``.
    """
    edge = BBOX_SIZE if size is None else size
    half = edge / 2
    return [lon - half, lat - half, lon + half, lat + half]
def _get_window_for_bbox(src, bbox):
    """Return the window of ``src`` covered by an EPSG:4326 bounding box.

    The box is reprojected into the dataset CRS, reduced to its axis-aligned
    extent, clamped to the dataset bounds and converted to a window.

    Args:
        src: An open dataset exposing ``crs``, ``bounds`` and ``transform``
            (the caller passes the result of ``rasterio.open``).
        bbox: Sequence ``[min_lon, min_lat, max_lon, max_lat]`` in EPSG:4326.

    Returns:
        The window produced by ``from_bounds`` for the clamped extent. When
        the box does not overlap the dataset, the window has zero width
        and/or zero height, so callers can test its size to skip it.
    """
    ring = [
        [bbox[0], bbox[1]],
        [bbox[2], bbox[1]],
        [bbox[2], bbox[3]],
        [bbox[0], bbox[3]],
        [bbox[0], bbox[1]],
    ]
    bbox_geom = {"type": "Polygon", "coordinates": [ring]}

    bbox_transformed = transform_geom("EPSG:4326", src.crs, bbox_geom)
    # The fifth vertex only closes the ring; the four corners define the extent.
    corners = bbox_transformed["coordinates"][0][:4]
    x_coords = [c[0] for c in corners]
    y_coords = [c[1] for c in corners]

    src_bounds = src.bounds
    left = max(min(x_coords), src_bounds.left)
    bottom = max(min(y_coords), src_bounds.bottom)
    right = min(max(x_coords), src_bounds.right)
    top = min(max(y_coords), src_bounds.top)

    # A box that misses the dataset yields inverted bounds after clamping.
    # NOTE(review): recent rasterio releases appear to reject inverted bounds
    # in from_bounds (WindowError) instead of returning a negative-size
    # window; collapsing the extent to zero size keeps the call valid and
    # lets the caller's size check skip the band.
    right = max(right, left)
    top = max(top, bottom)

    return from_bounds(left, bottom, right, top, src.transform)
def _extract_viewing_angle(item):
    """Return the viewing zenith angle from an item's granule metadata XML.

    Downloads the ``granule_metadata`` asset, parses it as XML and returns
    the absolute value of the first usable ``ZENITH_ANGLE`` found under a
    ``Mean_Viewing_Incidence_Angle`` element.

    Args:
        item: Object with an ``assets`` mapping whose entries expose ``href``
            (presumably a STAC item from the catalog search; confirm against
            the caller).

    Returns:
        The angle as a non-negative float, or None when the asset is absent,
        no usable angle is present, or the download/parsing fails.
    """
    if "granule_metadata" not in item.assets:
        return None

    try:
        xml_url = item.assets["granule_metadata"].href
        xml_resp = requests.get(xml_url, timeout=10)
        xml_resp.raise_for_status()
        root = ET.fromstring(xml_resp.content)

        for angle_elem in root.findall(".//Mean_Viewing_Incidence_Angle"):
            zenith_elem = angle_elem.find("ZENITH_ANGLE")
            # Skip entries without a value instead of failing the whole
            # lookup; only the first usable angle is needed.
            if zenith_elem is None or not (zenith_elem.text or "").strip():
                continue
            return abs(float(zenith_elem.text))
        return None
    except Exception as e:
        # Best effort: the angle is optional metadata, so any failure is
        # reported and treated as "no angle available".
        print(f"[S2] Warning: Could not extract viewing angle: {e}")
        return None
def download_s2(season, site_position, site_name, date_range=None):
lat, lon = site_position
datetime_range = date_range or f"{season}-01-01/{season}-12-31"
output_dir = f"data/{site_name}/{season}/raw/s2/"
output_dir = Path(f"data/{site_name}/{season}/raw/s2/")
print(f"[S2] Starting download: {site_name} ({lat:.6f}, {lon:.6f}), {season}")
bbox_size = 0.011
bbox = [
lon - bbox_size / 2,
lat - bbox_size / 2,
lon + bbox_size / 2,
lat + bbox_size / 2,
]
bbox = _get_bbox(lon, lat)
bands = {"B02": "blue", "B03": "green", "B04": "red", "B8A": "nir"}
os.makedirs(output_dir, exist_ok=True)
output_dir.mkdir(parents=True, exist_ok=True)
print("[S2] Connecting to STAC catalog...")
client = Client.open("https://earth-search.aws.element84.com/v1")
@ -46,8 +94,8 @@ def download_s2(season, site_position, site_name, date_range=None):
print(f"[S2] Found {len(items_by_key)} unique items")
for (date, increment), item in items_by_key.items():
filepath = os.path.join(output_dir, f"{date}_{increment}.geotiff")
if os.path.exists(filepath):
filepath = output_dir / f"{date}_{increment}.geotiff"
if filepath.exists():
print(f"[S2] Skipping {date}_{increment}.geotiff (exists)")
continue
@ -56,88 +104,33 @@ def download_s2(season, site_position, site_name, date_range=None):
profile = None
for band_name, asset_name in bands.items():
if asset_name in item.assets:
asset = item.assets[asset_name]
with rasterio.open(asset.href) as src:
bbox_geom = {
"type": "Polygon",
"coordinates": [
[
[bbox[0], bbox[1]],
[bbox[2], bbox[1]],
[bbox[2], bbox[3]],
[bbox[0], bbox[3]],
[bbox[0], bbox[1]],
]
],
if asset_name not in item.assets:
continue
asset = item.assets[asset_name]
with rasterio.open(asset.href) as src:
window = _get_window_for_bbox(src, bbox)
if window.height <= 0 or window.width <= 0:
continue
data = src.read(window=window)
new_transform = window_transform(window, src.transform)
if profile is None:
profile = {
"driver": "GTiff",
"height": window.height,
"width": window.width,
"count": len(bands),
"dtype": data.dtype,
"crs": src.crs,
"transform": new_transform,
"compress": "lzw",
}
bbox_transformed = transform_geom("EPSG:4326", src.crs, bbox_geom)
coords = bbox_transformed["coordinates"][0]
x_coords = [c[0] for c in coords[:4]]
y_coords = [c[1] for c in coords[:4]]
bbox_crs = [
min(x_coords),
min(y_coords),
max(x_coords),
max(y_coords),
]
src_bounds = src.bounds
intersect_bbox = [
max(bbox_crs[0], src_bounds.left),
max(bbox_crs[1], src_bounds.bottom),
min(bbox_crs[2], src_bounds.right),
min(bbox_crs[3], src_bounds.top),
]
window = from_bounds(*intersect_bbox, src.transform)
if window.height > 0 and window.width > 0:
data = src.read(window=window)
new_transform = window_transform(window, src.transform)
if profile is None:
profile = {
"driver": "GTiff",
"height": window.height,
"width": window.width,
"count": len(bands),
"dtype": data.dtype,
"crs": src.crs,
"transform": new_transform,
"compress": "lzw",
}
band_idx = list(bands.keys()).index(band_name)
band_data[band_idx] = data[0]
band_idx = list(bands.keys()).index(band_name)
band_data[band_idx] = data[0]
if profile and len(band_data) == len(bands):
stacked = [band_data[i] for i in sorted(band_data.keys())]
band_names = [list(bands.keys())[i] for i in sorted(band_data.keys())]
# Extract viewing angle from granule metadata XML
viewing_angle = None
if "granule_metadata" in item.assets:
try:
xml_url = item.assets["granule_metadata"].href
xml_resp = requests.get(xml_url, timeout=10)
xml_resp.raise_for_status()
root = ET.fromstring(xml_resp.content)
# Find Mean_Viewing_Incidence_Angle ZENITH_ANGLE
for angle_elem in root.findall(".//Mean_Viewing_Incidence_Angle"):
if angle_elem.get("bandId") == "0": # Use first band or average
zenith_elem = angle_elem.find("ZENITH_ANGLE")
if zenith_elem is not None:
viewing_angle = abs(float(zenith_elem.text))
break
# If not found, try averaging all bands
if viewing_angle is None:
angles = []
for angle_elem in root.findall(
".//Mean_Viewing_Incidence_Angle"
):
zenith_elem = angle_elem.find("ZENITH_ANGLE")
if zenith_elem is not None:
angles.append(abs(float(zenith_elem.text)))
if angles:
viewing_angle = sum(angles) / len(angles)
except Exception as e:
print(f"[S2] Warning: Could not extract viewing angle: {e}")
viewing_angle = _extract_viewing_angle(item)
with rasterio.open(filepath, "w", **profile) as dst:
for i, data in enumerate(stacked, 1):
@ -146,11 +139,10 @@ def download_s2(season, site_position, site_name, date_range=None):
if viewing_angle is not None:
dst.update_tags(VIEWING_ZENITH_ANGLE=viewing_angle)
print(
f"[S2] Saved: {filepath} (viewing angle: {viewing_angle:.2f}°)"
if viewing_angle
else f"[S2] Saved: {filepath}"
angle_msg = (
f" (viewing angle: {viewing_angle:.2f}°)" if viewing_angle else ""
)
print(f"[S2] Saved: {filepath}{angle_msg}")
else:
print(f"[S2] Skipping {date}_{increment} (missing bands)")