Setting start_with_input to False Specifying a Single Step leads to Nans/Incorrect Outputs #86

Open
leanderloew opened this issue Jun 10, 2024 · 5 comments

leanderloew commented Jun 10, 2024

Hey, I noticed that if you set start_with_input=False and specify a single outer step, you get NaNs in the u_component_of_wind output variable, and the outputs generally look wrong.

@leanderloew
Author

For example, here are plots of u_component_of_wind:
[Screenshot 2024-06-10 at 14 52 59]
[Screenshot 2024-06-10 at 14 53 10]
All further levels look like this. If you specify more than one step, the output looks fine.

@leanderloew
Author

You can reproduce it with this code:

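import numpy as np
from dinosaur import pytree_utils

# `model` and `eval_era5` are assumed to be set up beforehand (not shown in this snippet).
inner_steps = 24  # assumed value (not given in the original snippet): hours between saved outputs
outer_steps = 1  # a single outer step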
timedelta = np.timedelta64(1, 'h') * inner_steps
times = (np.arange(outer_steps) * inner_steps)  # time axis in hours

# initialize model state
data_dict, forcings = model.data_from_xarray(eval_era5.head(time=1))
inputs, input_forcings = pytree_utils.slice_along_axis((data_dict, forcings), axis=0, idx=0)
state = model.encode(
    inputs, forcings=input_forcings, rng_key=None
)

# make forecast
final_state, predictions = model.unroll(
    state,
    forcings,
    steps=outer_steps,
    timedelta=timedelta,
    start_with_input=False,
)
predictions_ds = model.data_to_xarray(predictions, times=times)

print(np.isnan(predictions_ds["u_component_of_wind"]).sum())

import matplotlib.pyplot as plt
for i in range(len(predictions_ds.level)):
    print(predictions_ds.level[i])
    array = predictions_ds["u_component_of_wind"][0,i,:]
    plt.imshow(array, cmap='viridis', aspect='auto')
    plt.colorbar()
    plt.title('2D Array Visualization')
    plt.xlabel('X-axis')
    plt.ylabel('Y-axis')
    plt.show()
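
As a more compact check than plotting every level, the NaNs can be counted per level. The sketch below is a minimal illustration on a synthetic array; the helper name `nan_count_per_level` is hypothetical, and it assumes the field has dimensions (time, level, longitude, latitude), as the indexing in the plotting loop above suggests. With the real output, the array would be something like `predictions_ds["u_component_of_wind"].values`.

```python
import numpy as np


def nan_count_per_level(field):
    """Count NaNs per level for an array shaped (time, level, lon, lat)."""
    # Sum the NaN mask over every axis except the level axis (axis 1).
    return np.isnan(field).sum(axis=(0, 2, 3))


# Synthetic stand-in for a model output field (NOT real NeuralGCM data).
rng = np.random.default_rng(0)
field = rng.normal(size=(1, 3, 4, 5))
field[0, 1, :2, :] = np.nan  # inject 2 * 5 = 10 NaNs on level index 1

counts = nan_count_per_level(field)
print(counts)
```

A non-zero entry identifies which levels are affected, which makes it easier to compare runs across devices or step counts.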

@leanderloew changed the title from "Setting start_with_input to False + High Resolution Model + Single Step leads to Nans/Incorrect Outputs" to "Setting start_with_input to False Specifying a Single Step leads to Nans/Incorrect Outputs" on Jun 10, 2024
@yaniyuval
Contributor

Hi,
Thanks for raising this issue.
I can reproduce this problem only when running NeuralGCM on a CPU; when I run the model on a GPU, I do not see it. We will try to look into this, but until then I suggest using a GPU (which would also reduce the computation time substantially).

@leanderloew
Author

Hey, I just reproduced it on an L4 GPU.

@yaniyuval
Contributor

Thanks for letting us know. I could verify that this does not occur on a T4 GPU, and we hope to find time to understand what fails on the L4 GPU.
As a side note, is there a specific reason you need to set start_with_input to False? I think that if you set start_with_input to True and then slice the output from index 1 to the end, you will get almost the same result as with start_with_input=False (you will just have one fewer time step in the output).
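
The slicing workaround suggested above can be sketched with a toy loop. This is a minimal illustration only: `toy_unroll` and `step_fn` are hypothetical stand-ins, not NeuralGCM's `model.unroll`, and the sketch assumes that start_with_input=True records the state before each step while start_with_input=False records it after. To recover the same number of outputs, it requests one extra step before slicing.

```python
import numpy as np


def toy_unroll(state, step_fn, steps, start_with_input):
    """Toy stand-in for an unroll loop (NOT NeuralGCM's implementation)."""
    outputs = []
    for _ in range(steps):
        if start_with_input:
            outputs.append(state)   # record the current state, then advance
            state = step_fn(state)
        else:
            state = step_fn(state)  # advance, then record the new state
            outputs.append(state)
    return state, np.stack(outputs)


def step_fn(x):
    """Hypothetical one-step update used only for this illustration."""
    return 0.5 * x + 1.0


x0 = np.array([0.0, 2.0, 4.0])
outer_steps = 1

# Direct call: skip the input, output only the state after one step.
_, direct = toy_unroll(x0, step_fn, outer_steps, start_with_input=False)

# Workaround: include the input, take one extra step, then drop index 0.
_, padded = toy_unroll(x0, step_fn, outer_steps + 1, start_with_input=True)
workaround = padded[1:]

print(np.allclose(direct, workaround))
```

In this toy setting the two results match exactly; whether the real model gives identical values on the affected hardware is something the thread leaves open.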
