diff --git a/weatherbench2/derived_variables.py b/weatherbench2/derived_variables.py index 774299c..09a22d2 100644 --- a/weatherbench2/derived_variables.py +++ b/weatherbench2/derived_variables.py @@ -664,7 +664,8 @@ def interpolate_spectral_frequencies( def interp_at_one_lat(da: xr.DataArray) -> xr.DataArray: da = ( - da.swap_dims({wavenumber_dim: 'frequency'}) # pytype: disable=wrong-arg-types + da.squeeze('latitude') + .swap_dims({wavenumber_dim: 'frequency'}) # pytype: disable=wrong-arg-types .drop_vars(wavenumber_dim) .interp(frequency=frequencies, method=method, **interp_kwargs) ) @@ -673,7 +674,7 @@ def interp_at_one_lat(da: xr.DataArray) -> xr.DataArray: da['wavelength'] = da['wavelength'].assign_attrs(units='m') return da - return spectrum.groupby('latitude').apply(interp_at_one_lat) + return spectrum.groupby('latitude', squeeze=False).apply(interp_at_one_lat) @dataclasses.dataclass