Scale it up.

This commit is contained in:
Felix Delattre 2026-06-10 19:37:33 +02:00
parent ba36dfe914
commit c033f5f527

View file

@ -29,6 +29,7 @@ import json
import os import os
import shutil import shutil
import time import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -95,6 +96,17 @@ _BAND_ASSETS: dict[str, str] = {
_SCL_ASSET = "scl" _SCL_ASSET = "scl"
_MIN_BBOX_HALF_DEG = 0.008 _MIN_BBOX_HALF_DEG = 0.008
# GDAL configuration options applied (through ``rasterio.Env``) while range-
# reading remote COGs.  The per-key notes follow the GDAL configuration-option
# documentation; values are strings because that is how GDAL receives them.
_GDAL_COG_ENV: dict[str, str] = {
    # Ask for HTTP/2 so several range requests can share one connection.
    "GDAL_HTTP_VERSION": "2",
    # Coalesce adjacent byte ranges into a single request.
    "GDAL_HTTP_MERGE_CONSECUTIVE_RANGES": "YES",
    # Allow multiplexed requests on an HTTP/2 connection.
    "GDAL_HTTP_MULTIPLEX": "YES",
    # Keep idle TCP connections alive between reads.
    "GDAL_HTTP_TCP_KEEPALIVE": "YES",
    # Do not probe the remote "directory" for sibling/sidecar files on open.
    "GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR",
    # Size in bytes (200 MB) of the /vsicurl/ block cache.
    "CPL_VSIL_CURL_CACHE_SIZE": "200000000",
    # NOTE(review): presumably meant to raise the connection limit; confirm
    # that this key is recognised by the GDAL version in use.
    "GDAL_MAX_CONNECTIONS": "100",
    # Send unsigned requests (public buckets need no AWS credentials).
    "AWS_NO_SIGN_REQUEST": "YES",
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Internal S3 constants # Internal S3 constants
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -314,54 +326,75 @@ def stac_search_s2(
return list({item.id: item for item in search.items()}.values()) return list({item.id: item for item in search.items()}.values())
def _process_item(
item: Any,
bbox: list[float],
bands: list[str],
output_dir: Path,
ratio: int,
) -> str | None:
"""Range-read one S2 item and write a masked REFL GeoTIFF.
Returns a skip-message string when the item cannot be processed, else None.
"""
out_path = output_dir / f"{item.id}_REFL.tif"
if out_path.is_file():
return None
bands_result = _read_bands(item, bbox, bands)
if bands_result is None:
return f"[S2] Skipping {item.id}: missing asset or no bbox overlap"
band_arrays, ref_profile = bands_result
mask = _cloud_mask(item, bbox, (ref_profile["height"], ref_profile["width"]))
stacked = (np.stack(band_arrays) + _boa_offset(item)) / 10_000.0
np.clip(stacked, 0, None, out=stacked)
stacked[:, mask] = 0.0
stacked = _pad_to_multiple(stacked, ratio)
out_profile = {
"driver": "GTiff",
"count": len(bands),
"dtype": "float32",
"nodata": 0,
"crs": ref_profile["crs"],
"transform": ref_profile["transform"],
"height": stacked.shape[1],
"width": stacked.shape[2],
"compress": "lzw",
}
with rasterio.open(out_path, "w", **out_profile) as dst:
dst.write(stacked)
for i, band_name in enumerate(bands, 1):
dst.set_band_description(i, band_name)
return None
def download_s2_window(
    items: list[Any],
    bbox: list[float],
    output_dir: Path,
    bands: list[str],
    ratio: int = RESOLUTION_RATIO,
    max_workers: int = 32,
) -> None:
    """Range-read S2 L2A COG windows and write masked REFL GeoTIFFs.

    Writes ``{item.id}_REFL.tif`` directly; there is no intermediate raw
    download.  Cloud/shadow pixels (SCL 0, 3, >7) are zeroed.  BOA offset is
    inferred from ``processing:baseline``.  Output is zero-padded to
    multiples of ``ratio``.

    Items are fetched in parallel using ``max_workers`` threads, each item
    being handled by ``_process_item``.  Skip messages returned by the
    workers are printed through ``tqdm.write`` as items complete.

    Args:
        items: S2 items to process.
        bbox: Bounding box forwarded to each worker.
        output_dir: Destination directory; created if it does not exist.
        bands: Band names to read for every item.
        ratio: Padding multiple for the output height and width.
        max_workers: Size of the thread pool.

    Raises:
        Exception: The first exception raised by a worker is re-raised
            unchanged.  Items that have not started yet are cancelled so the
            failure surfaces without processing the whole remaining queue.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the GDAL options are entered on the calling thread only;
    # confirm they are also in effect inside the worker threads.
    with rasterio.Env(**_GDAL_COG_ENV), ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [
            pool.submit(_process_item, item, bbox, bands, output_dir, ratio)
            for item in items
        ]
        try:
            with tqdm(total=len(futures), unit="granule", desc="S2 COG window read") as pbar:
                for fut in as_completed(futures):
                    msg = fut.result()
                    if msg:
                        tqdm.write(msg)
                    pbar.update(1)
        except BaseException:
            # Leaving the executor context waits for every queued task.
            # Cancel the ones that have not started so one failing granule
            # (or Ctrl-C) does not first process the remaining queue.
            for pending in futures:
                pending.cancel()
            raise
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------