diff --git a/changes/372.bugfix.rst b/changes/372.bugfix.rst new file mode 100644 index 00000000..66bbde47 --- /dev/null +++ b/changes/372.bugfix.rst @@ -0,0 +1 @@ +Only use ``roman.meta`` attributes for crds parameter selection. diff --git a/src/roman_datamodels/datamodels/_core.py b/src/roman_datamodels/datamodels/_core.py index 644e7fe0..22a6d19a 100644 --- a/src/roman_datamodels/datamodels/_core.py +++ b/src/roman_datamodels/datamodels/_core.py @@ -19,7 +19,7 @@ import asdf import numpy as np from asdf.exceptions import ValidationError -from asdf.lazy_nodes import AsdfDictNode, AsdfListNode +from asdf.tags.core.ndarray import NDArrayType from astropy.time import Time from roman_datamodels import stnode, validate @@ -298,7 +298,9 @@ def convert_val(val): return val return { - f"roman.{key}": convert_val(val) for (key, val) in self.items() if include_arrays or not isinstance(val, np.ndarray) + f"roman.{key}": convert_val(val) + for (key, val) in self.items() + if include_arrays or not isinstance(val, (np.ndarray, NDArrayType)) } def items(self): @@ -315,29 +317,21 @@ def items(self): schemas directly. """ - def recurse(tree, path=[]): - if isinstance(tree, (stnode.DNode, dict, AsdfDictNode)): - for key, val in tree.items(): - yield from recurse(val, path + [key]) - elif isinstance(tree, (stnode.LNode, list, tuple, AsdfListNode)): - for i, val in enumerate(tree): - yield from recurse(val, path + [i]) - elif tree is not None: - yield (".".join(str(x) for x in path), tree) - - yield from recurse(self._instance) + yield from self._instance._recursive_items() def get_crds_parameters(self): """ Get parameters used by CRDS to select references for this model. + This will only return items under ``roman.meta``. + Returns ------- dict """ return { - key: val - for key, val in self.to_flat_dict(include_arrays=False).items() + f"roman.meta.{key}": val + for key, val in self.meta.to_flat_dict(include_arrays=False, recursive=True).items() if isinstance(val, (str, int, float, complex, bool)) } diff --git a/src/roman_datamodels/stnode/_node.py b/src/roman_datamodels/stnode/_node.py index ca987545..c1e1ccee 100644 --- a/src/roman_datamodels/stnode/_node.py +++ b/src/roman_datamodels/stnode/_node.py @@ -10,11 +10,11 @@ from collections.abc import MutableMapping import asdf -import asdf.lazy_nodes import asdf.schema as asdfschema import asdf.yamlutil as yamlutil import numpy as np from asdf.exceptions import ValidationError +from asdf.lazy_nodes import AsdfDictNode, AsdfListNode from asdf.tags.core import ndarray from asdf.util import HashableDict from astropy.time import Time @@ -166,7 +166,7 @@ def __init__(self, node=None, parent=None, name=None): # Handle if we are passed different data types if node is None: self.__dict__["_data"] = {} - elif isinstance(node, (dict, asdf.lazy_nodes.AsdfDictNode)): + elif isinstance(node, (dict, AsdfDictNode)): self.__dict__["_data"] = node else: raise ValueError("Initializer only accepts dicts") @@ -220,10 +220,10 @@ def __getattr__(self, key): value = self._convert_to_scalar(key, self._data[key]) # Return objects as node classes, if applicable - if isinstance(value, (dict, asdf.lazy_nodes.AsdfDictNode)): + if isinstance(value, (dict, AsdfDictNode)): return DNode(value, parent=self, name=key) - elif isinstance(value, (list, asdf.lazy_nodes.AsdfListNode)): + elif isinstance(value, (list, AsdfListNode)): return LNode(value) else: @@ -266,7 +266,20 @@ def _schema_attributes(self): self._x_schema_attributes = SchemaProperties.from_schema(self._schema()) return self._x_schema_attributes - def to_flat_dict(self, include_arrays=True): + def _recursive_items(self): + def recurse(tree, path=[]): + if isinstance(tree, (DNode, dict, AsdfDictNode)): + for key, val in tree.items(): + yield from recurse(val, path + [key]) + elif isinstance(tree, (LNode, list, tuple, AsdfListNode)): + for i, val in enumerate(tree): + yield from recurse(val, path + [i]) + elif tree is not None: + yield (".".join(str(x) for x in path), tree) + + yield from recurse(self) + + def to_flat_dict(self, include_arrays=True, recursive=False): """ Returns a dictionary of all of the schema items as a flat dictionary. @@ -288,11 +301,13 @@ def convert_val(val): return str(val) return val + item_getter = self._recursive_items if recursive else self.items + if include_arrays: - return {key: convert_val(val) for (key, val) in self.items()} + return {key: convert_val(val) for (key, val) in item_getter()} else: return { - key: convert_val(val) for (key, val) in self.items() if not isinstance(val, (np.ndarray, ndarray.NDArrayType)) + key: convert_val(val) for (key, val) in item_getter() if not isinstance(val, (np.ndarray, ndarray.NDArrayType)) } def _schema(self): @@ -329,7 +344,7 @@ def __setitem__(self, key, value): value = self._convert_to_scalar(key, value, self._data.get(key)) # If the value is a dictionary, loop over its keys and convert them to tagged scalars - if isinstance(value, (dict, asdf.lazy_nodes.AsdfDictNode)): + if isinstance(value, (dict, AsdfDictNode)): for sub_key, sub_value in value.items(): value[sub_key] = self._convert_to_scalar(sub_key, sub_value) @@ -366,7 +381,7 @@ class LNode(UserList): def __init__(self, node=None): if node is None: self.data = [] - elif isinstance(node, (list, asdf.lazy_nodes.AsdfListNode)): + elif isinstance(node, (list, AsdfListNode)): self.data = node elif isinstance(node, self.__class__): self.data = node.data @@ -375,9 +390,9 @@ def __init__(self, node=None): def __getitem__(self, index): value = self.data[index] - if isinstance(value, (dict, asdf.lazy_nodes.AsdfDictNode)): + if isinstance(value, (dict, AsdfDictNode)): return DNode(value) - elif isinstance(value, (list, asdf.lazy_nodes.AsdfListNode)): + elif isinstance(value, (list, AsdfListNode)): return LNode(value) else: return value diff --git a/tests/test_models.py b/tests/test_models.py index 42b73cf3..d6c785ca 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -858,18 +858,31 @@ def test_datamodel_schema_info_existence(name): assert keyword in info["roman"]["meta"] -def test_crds_parameters(tmp_path): - # CRDS uses meta.exposure.start_time to compare to USEAFTER - file_path = tmp_path / "testwfi_image.asdf" - utils.mk_level2_image(filepath=file_path) - with datamodels.open(file_path) as wfi_image: - crds_pars = wfi_image.get_crds_parameters() - assert "roman.meta.exposure.start_time" in crds_pars +@pytest.mark.parametrize("include_arrays", (True, False)) +def test_to_flat_dict(include_arrays, tmp_path): + file_path = tmp_path / "test.asdf" + utils.mk_level2_image(filepath=file_path, shape=(8, 8)) + with datamodels.open(file_path) as model: + if include_arrays: + assert "roman.data" in model.to_flat_dict() + else: + assert "roman.data" not in model.to_flat_dict(include_arrays=False) - utils.mk_ramp(filepath=file_path) - with datamodels.open(file_path) as ramp: - crds_pars = ramp.get_crds_parameters() + +@pytest.mark.parametrize("mk_model", (utils.mk_level2_image, utils.mk_ramp)) +def test_crds_parameters(mk_model, tmp_path): + # CRDS uses meta.exposure.start_time to compare to USEAFTER + file_path = tmp_path / "test.asdf" + mk_model(filepath=file_path) + with datamodels.open(file_path) as model: + # patch on a value that is valid (a simple type) + # but isn't under meta. Since it's not under meta + # it shouldn't be in the crds_pars. + model["test"] = 42 + crds_pars = model.get_crds_parameters() assert "roman.meta.exposure.start_time" in crds_pars + assert "roman.cal_logs" not in crds_pars + assert "roman.test" not in crds_pars def test_model_validate_without_save():