From c793ddbbe98ed72d9d74480314f0984f71d47ff2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Thu, 8 Feb 2024 21:01:27 +0100
Subject: [PATCH] Ensure that parameters are leaf nodes when loading a model

There was a subtle bug where we populated models with parameters that
were not leaf nodes, because we called `to` on them for device
placement. Since `to` is an operation tracked by autograd, its output
is a new tensor with a `grad_fn` rather than a leaf. This change fixes
the issue by moving the tensor to the target device before wrapping it
in a `Parameter`, and the model tests now validate that all model
parameters are leaf nodes.
---
 curated_transformers/tests/models/util.py | 22 ++++++++++++++++------
 curated_transformers/util/serde.py        |  2 +-
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/curated_transformers/tests/models/util.py b/curated_transformers/tests/models/util.py
index c229a288..79ae7e80 100644
--- a/curated_transformers/tests/models/util.py
+++ b/curated_transformers/tests/models/util.py
@@ -96,8 +96,7 @@ def assert_causal_lm_output_equals_hf(
     )
     orig_model.eval()
 
-    for _, param in orig_model.state_dict().items():
-        assert param.device == torch_device
+    check_params_buffers(orig_model, torch_device)
 
     hf_model = transformers.AutoModelForCausalLM.from_pretrained(
         model_name,
@@ -153,8 +152,7 @@ def assert_decoder_output_equals_hf(
     )
     orig_model.eval()
 
-    for _, param in orig_model.state_dict().items():
-        assert param.device == torch_device
+    check_params_buffers(orig_model, torch_device)
 
     hf_model = transformers.AutoModel.from_pretrained(
         model_name, revision=model_revision, trust_remote_code=trust_remote_code
@@ -217,8 +215,7 @@ def assert_encoder_output_equals_hf(
     orig_model = model_class.from_hf_hub(name=model_name, device=torch_device)
     orig_model.eval()
 
-    for _, param in orig_model.state_dict().items():
-        assert param.device == torch_device
+    check_params_buffers(orig_model, torch_device)
 
     hf_model = transformers.AutoModel.from_pretrained(model_name)
     hf_model.to(torch_device)
@@ -362,3 +359,16 @@ def assert_model_config(model: TransformerModule, model_output: Tensor):
     hidden_width = model_output.size(-1)
 
     assert config.layer.feedforward.hidden_width == hidden_width
+
+
+def check_params_buffers(model: Module, device: torch.device):
+    """
+    Check that parameters/buffers are placed on the correct device and that
+    parameters are leaf nodes.
+    """
+    for buffer in model.buffers():
+        assert buffer.device == device
+
+    for param in model.parameters():
+        assert param.device == device
+        assert param.is_leaf
diff --git a/curated_transformers/util/serde.py b/curated_transformers/util/serde.py
index 10f59045..42cfc9ad 100644
--- a/curated_transformers/util/serde.py
+++ b/curated_transformers/util/serde.py
@@ -126,7 +126,7 @@ def default_tensor_to_parameter_converter(
     old_param = module._parameters[parameter_name]
     assert old_param is not None
     _validate_replacement(old_param, tensor, module_prefix)
-    return Parameter(tensor, requires_grad=old_param.requires_grad).to(device=device)  # type: ignore
+    return Parameter(tensor.to(device=device), requires_grad=old_param.requires_grad)  # type: ignore
 
 
 def _emplace_module_state_dict(
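
A minimal standalone sketch of the behavior this patch fixes, for
illustration only (not part of the patch itself). It uses a dtype
conversion rather than device placement, since `to` only produces a new
tensor when something actually changes, but the autograd effect is the
same:

import torch
from torch.nn import Parameter

tensor = torch.zeros(2, 2)

# Buggy order: wrap first, then convert. `Tensor.to` is tracked by
# autograd when its input requires gradients, so the result carries a
# grad_fn and is no longer a leaf node.
param = Parameter(tensor).to(torch.float64)
print(param.is_leaf)  # False

# Fixed order: convert first, then wrap. A freshly constructed
# Parameter is always a leaf node.
param = Parameter(tensor.to(torch.float64))
print(param.is_leaf)  # True

Leaf status matters in practice: `torch.optim` optimizers refuse
non-leaf tensors, and `.grad` is not populated for non-leaf tensors
during backpropagation.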