Skip to content

Commit

Permalink
only use roman.meta for get_crds_parameters (#372)
Browse files Browse the repository at this point in the history
  • Loading branch information
braingram authored Oct 3, 2024
1 parent dfbcdf1 commit 7485edc
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 36 deletions.
1 change: 1 addition & 0 deletions changes/372.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Only use ``roman.meta`` attributes for crds parameter selection.
24 changes: 9 additions & 15 deletions src/roman_datamodels/datamodels/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import asdf
import numpy as np
from asdf.exceptions import ValidationError
from asdf.lazy_nodes import AsdfDictNode, AsdfListNode
from asdf.tags.core.ndarray import NDArrayType
from astropy.time import Time

from roman_datamodels import stnode, validate
Expand Down Expand Up @@ -298,7 +298,9 @@ def convert_val(val):
return val

return {
f"roman.{key}": convert_val(val) for (key, val) in self.items() if include_arrays or not isinstance(val, np.ndarray)
f"roman.{key}": convert_val(val)
for (key, val) in self.items()
if include_arrays or not isinstance(val, (np.ndarray, NDArrayType))
}

def items(self):
Expand All @@ -315,29 +317,21 @@ def items(self):
schemas directly.
"""

def recurse(tree, path=[]):
if isinstance(tree, (stnode.DNode, dict, AsdfDictNode)):
for key, val in tree.items():
yield from recurse(val, path + [key])
elif isinstance(tree, (stnode.LNode, list, tuple, AsdfListNode)):
for i, val in enumerate(tree):
yield from recurse(val, path + [i])
elif tree is not None:
yield (".".join(str(x) for x in path), tree)

yield from recurse(self._instance)
yield from self._instance._recursive_items()

def get_crds_parameters(self):
"""
Get parameters used by CRDS to select references for this model.
This will only return items under ``roman.meta``.
Returns
-------
dict
"""
return {
key: val
for key, val in self.to_flat_dict(include_arrays=False).items()
f"roman.meta.{key}": val
for key, val in self.meta.to_flat_dict(include_arrays=False, recursive=True).items()
if isinstance(val, (str, int, float, complex, bool))
}

Expand Down
37 changes: 26 additions & 11 deletions src/roman_datamodels/stnode/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from collections.abc import MutableMapping

import asdf
import asdf.lazy_nodes
import asdf.schema as asdfschema
import asdf.yamlutil as yamlutil
import numpy as np
from asdf.exceptions import ValidationError
from asdf.lazy_nodes import AsdfDictNode, AsdfListNode
from asdf.tags.core import ndarray
from asdf.util import HashableDict
from astropy.time import Time
Expand Down Expand Up @@ -166,7 +166,7 @@ def __init__(self, node=None, parent=None, name=None):
# Handle if we are passed different data types
if node is None:
self.__dict__["_data"] = {}
elif isinstance(node, (dict, asdf.lazy_nodes.AsdfDictNode)):
elif isinstance(node, (dict, AsdfDictNode)):
self.__dict__["_data"] = node
else:
raise ValueError("Initializer only accepts dicts")
Expand Down Expand Up @@ -220,10 +220,10 @@ def __getattr__(self, key):
value = self._convert_to_scalar(key, self._data[key])

# Return objects as node classes, if applicable
if isinstance(value, (dict, asdf.lazy_nodes.AsdfDictNode)):
if isinstance(value, (dict, AsdfDictNode)):
return DNode(value, parent=self, name=key)

elif isinstance(value, (list, asdf.lazy_nodes.AsdfListNode)):
elif isinstance(value, (list, AsdfListNode)):
return LNode(value)

else:
Expand Down Expand Up @@ -266,7 +266,20 @@ def _schema_attributes(self):
self._x_schema_attributes = SchemaProperties.from_schema(self._schema())
return self._x_schema_attributes

def to_flat_dict(self, include_arrays=True):
def _recursive_items(self):
def recurse(tree, path=[]):
if isinstance(tree, (DNode, dict, AsdfDictNode)):
for key, val in tree.items():
yield from recurse(val, path + [key])
elif isinstance(tree, (LNode, list, tuple, AsdfListNode)):
for i, val in enumerate(tree):
yield from recurse(val, path + [i])
elif tree is not None:
yield (".".join(str(x) for x in path), tree)

yield from recurse(self)

def to_flat_dict(self, include_arrays=True, recursive=False):
"""
Returns a dictionary of all of the schema items as a flat dictionary.
Expand All @@ -288,11 +301,13 @@ def convert_val(val):
return str(val)
return val

item_getter = self._recursive_items if recursive else self.items

if include_arrays:
return {key: convert_val(val) for (key, val) in self.items()}
return {key: convert_val(val) for (key, val) in item_getter()}
else:
return {
key: convert_val(val) for (key, val) in self.items() if not isinstance(val, (np.ndarray, ndarray.NDArrayType))
key: convert_val(val) for (key, val) in item_getter() if not isinstance(val, (np.ndarray, ndarray.NDArrayType))
}

def _schema(self):
Expand Down Expand Up @@ -329,7 +344,7 @@ def __setitem__(self, key, value):
value = self._convert_to_scalar(key, value, self._data.get(key))

# If the value is a dictionary, loop over its keys and convert them to tagged scalars
if isinstance(value, (dict, asdf.lazy_nodes.AsdfDictNode)):
if isinstance(value, (dict, AsdfDictNode)):
for sub_key, sub_value in value.items():
value[sub_key] = self._convert_to_scalar(sub_key, sub_value)

Expand Down Expand Up @@ -366,7 +381,7 @@ class LNode(UserList):
def __init__(self, node=None):
if node is None:
self.data = []
elif isinstance(node, (list, asdf.lazy_nodes.AsdfListNode)):
elif isinstance(node, (list, AsdfListNode)):
self.data = node
elif isinstance(node, self.__class__):
self.data = node.data
Expand All @@ -375,9 +390,9 @@ def __init__(self, node=None):

def __getitem__(self, index):
value = self.data[index]
if isinstance(value, (dict, asdf.lazy_nodes.AsdfDictNode)):
if isinstance(value, (dict, AsdfDictNode)):
return DNode(value)
elif isinstance(value, (list, asdf.lazy_nodes.AsdfListNode)):
elif isinstance(value, (list, AsdfListNode)):
return LNode(value)
else:
return value
Expand Down
33 changes: 23 additions & 10 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,18 +858,31 @@ def test_datamodel_schema_info_existence(name):
assert keyword in info["roman"]["meta"]


def test_crds_parameters(tmp_path):
# CRDS uses meta.exposure.start_time to compare to USEAFTER
file_path = tmp_path / "testwfi_image.asdf"
utils.mk_level2_image(filepath=file_path)
with datamodels.open(file_path) as wfi_image:
crds_pars = wfi_image.get_crds_parameters()
assert "roman.meta.exposure.start_time" in crds_pars
@pytest.mark.parametrize("include_arrays", (True, False))
def test_to_flat_dict(include_arrays, tmp_path):
file_path = tmp_path / "test.asdf"
utils.mk_level2_image(filepath=file_path, shape=(8, 8))
with datamodels.open(file_path) as model:
if include_arrays:
assert "roman.data" in model.to_flat_dict()
else:
assert "roman.data" not in model.to_flat_dict(include_arrays=False)

utils.mk_ramp(filepath=file_path)
with datamodels.open(file_path) as ramp:
crds_pars = ramp.get_crds_parameters()

@pytest.mark.parametrize("mk_model", (utils.mk_level2_image, utils.mk_ramp))
def test_crds_parameters(mk_model, tmp_path):
# CRDS uses meta.exposure.start_time to compare to USEAFTER
file_path = tmp_path / "test.asdf"
mk_model(filepath=file_path)
with datamodels.open(file_path) as model:
# patch on a value that is valid (a simple type)
# but isn't under meta. Since it's not under meta
# it shouldn't be in the crds_pars.
model["test"] = 42
crds_pars = model.get_crds_parameters()
assert "roman.meta.exposure.start_time" in crds_pars
assert "roman.cal_logs" not in crds_pars
assert "roman.test" not in crds_pars


def test_model_validate_without_save():
Expand Down

0 comments on commit 7485edc

Please sign in to comment.