From b7c11b04e3a5643e8b3d055160c208ba2acc06ea Mon Sep 17 00:00:00 2001 From: Ian Langmore Date: Mon, 8 Jul 2024 14:06:31 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 650368716 --- scripts/compute_derived_variables.py | 12 ++++++++---- weatherbench2/derived_variables.py | 7 +++++-- weatherbench2/derived_variables_test.py | 13 +++++++++++++ 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/scripts/compute_derived_variables.py b/scripts/compute_derived_variables.py index 30a49e2..b47a9fb 100644 --- a/scripts/compute_derived_variables.py +++ b/scripts/compute_derived_variables.py @@ -151,6 +151,14 @@ def main(argv: list[str]) -> None: source_dataset, source_chunks = xbeam.open_zarr(INPUT_PATH.value) + for var_name in PREEXISTING_VARIABLES_TO_REMOVE.value: + if var_name in source_dataset: + del source_dataset[var_name] + source_chunks = { + # Removing variables may remove some dims. + k: v for k, v in source_chunks.items() if k in source_dataset.dims + } + # Validate and clean-up the source datset. if RENAME_RAW_TP_NAME.value: source_dataset = source_dataset.rename( @@ -168,10 +176,6 @@ def main(argv: list[str]) -> None: rename_variables.get(k, k): v for k, v in source_chunks.items() } - for var_name in PREEXISTING_VARIABLES_TO_REMOVE.value: - if var_name in source_dataset: - del source_dataset[var_name] - for var_name, dv in derived_variables.items(): if var_name in source_dataset: raise ValueError( diff --git a/weatherbench2/derived_variables.py b/weatherbench2/derived_variables.py index 09a22d2..d16de5b 100644 --- a/weatherbench2/derived_variables.py +++ b/weatherbench2/derived_variables.py @@ -472,8 +472,11 @@ def compute(self, dataset: xr.Dataset) -> xr.DataArray: class PrecipitationAccumulation(DerivedVariable): """Compute precipitation accumulation from hourly accumulations. - Accumulation is computed for the time period leading up to the lead_time. - E.g. 24h accumulation at lead_time=24h indicates 0-24h accumulation. + Accumulation is computed for the time period leading up to and including the + lead_time. E.g. 24h accumulation at lead_time=24h indicates accumulation + from lead_time=0 to lead_time=24. This is equal to the values of + `total_precipitation_name` at 24, minus the value at 0. + Caution: Small negative values sometimes appear in model output. Here, we set them to zero. diff --git a/weatherbench2/derived_variables_test.py b/weatherbench2/derived_variables_test.py index 0b7cf78..fea556a 100644 --- a/weatherbench2/derived_variables_test.py +++ b/weatherbench2/derived_variables_test.py @@ -139,6 +139,19 @@ def testPrecipitationAccumulation6hr(self): accumulation_hours=6, ) result = derived_variable.compute(dataset) + + # Test a few specific times for example's sake. + sel = lambda ds, hr: ds.sel(prediction_timedelta=f'{hr}hr') + np.testing.assert_array_equal( + (sel(dataset, 24) - sel(dataset, 24 - 6)).total_precipitation.data, + sel(result, 24), + ) + np.testing.assert_array_equal( + (sel(dataset, 18) - sel(dataset, 18 - 6)).total_precipitation.data, + sel(result, 18), + ) + + # Test every timedelta. expected = xr.DataArray( [np.nan, 5, 10, 0, 6, 10, 0], dims=['prediction_timedelta'],