Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

only use roman.meta for get_crds_parameters #372

Merged
merged 4 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/372.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Only use ``roman.meta`` attributes for crds parameter selection.
24 changes: 9 additions & 15 deletions src/roman_datamodels/datamodels/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import asdf
import numpy as np
from asdf.exceptions import ValidationError
from asdf.lazy_nodes import AsdfDictNode, AsdfListNode
from asdf.tags.core.ndarray import NDArrayType
from astropy.time import Time

from roman_datamodels import stnode, validate
Expand Down Expand Up @@ -298,7 +298,9 @@ def convert_val(val):
return val

return {
f"roman.{key}": convert_val(val) for (key, val) in self.items() if include_arrays or not isinstance(val, np.ndarray)
f"roman.{key}": convert_val(val)
for (key, val) in self.items()
if include_arrays or not isinstance(val, (np.ndarray, NDArrayType))
}

def items(self):
Expand All @@ -315,29 +317,21 @@ def items(self):
schemas directly.
"""

def recurse(tree, path=[]):
if isinstance(tree, (stnode.DNode, dict, AsdfDictNode)):
for key, val in tree.items():
yield from recurse(val, path + [key])
elif isinstance(tree, (stnode.LNode, list, tuple, AsdfListNode)):
for i, val in enumerate(tree):
yield from recurse(val, path + [i])
elif tree is not None:
yield (".".join(str(x) for x in path), tree)

yield from recurse(self._instance)
yield from self._instance._recursive_items()

def get_crds_parameters(self):
"""
Get parameters used by CRDS to select references for this model.

This will only return items under ``roman.meta``.

Returns
-------
dict
"""
return {
key: val
for key, val in self.to_flat_dict(include_arrays=False).items()
f"roman.meta.{key}": val
for key, val in self.meta.to_flat_dict(include_arrays=False, recursive=True).items()
if isinstance(val, (str, int, float, complex, bool))
}

Expand Down
37 changes: 26 additions & 11 deletions src/roman_datamodels/stnode/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from collections.abc import MutableMapping

import asdf
import asdf.lazy_nodes
import asdf.schema as asdfschema
import asdf.yamlutil as yamlutil
import numpy as np
from asdf.exceptions import ValidationError
from asdf.lazy_nodes import AsdfDictNode, AsdfListNode
from asdf.tags.core import ndarray
from asdf.util import HashableDict
from astropy.time import Time
Expand Down Expand Up @@ -166,7 +166,7 @@ def __init__(self, node=None, parent=None, name=None):
# Handle if we are passed different data types
if node is None:
self.__dict__["_data"] = {}
elif isinstance(node, (dict, asdf.lazy_nodes.AsdfDictNode)):
elif isinstance(node, (dict, AsdfDictNode)):
self.__dict__["_data"] = node
else:
raise ValueError("Initializer only accepts dicts")
Expand Down Expand Up @@ -220,10 +220,10 @@ def __getattr__(self, key):
value = self._convert_to_scalar(key, self._data[key])

# Return objects as node classes, if applicable
if isinstance(value, (dict, asdf.lazy_nodes.AsdfDictNode)):
if isinstance(value, (dict, AsdfDictNode)):
return DNode(value, parent=self, name=key)

elif isinstance(value, (list, asdf.lazy_nodes.AsdfListNode)):
elif isinstance(value, (list, AsdfListNode)):
return LNode(value)

else:
Expand Down Expand Up @@ -266,7 +266,20 @@ def _schema_attributes(self):
self._x_schema_attributes = SchemaProperties.from_schema(self._schema())
return self._x_schema_attributes

def to_flat_dict(self, include_arrays=True):
def _recursive_items(self):
def recurse(tree, path=[]):
if isinstance(tree, (DNode, dict, AsdfDictNode)):
for key, val in tree.items():
yield from recurse(val, path + [key])
elif isinstance(tree, (LNode, list, tuple, AsdfListNode)):
for i, val in enumerate(tree):
yield from recurse(val, path + [i])
elif tree is not None:
yield (".".join(str(x) for x in path), tree)

yield from recurse(self)

def to_flat_dict(self, include_arrays=True, recursive=False):
"""
Returns a dictionary of all of the schema items as a flat dictionary.

Expand All @@ -288,11 +301,13 @@ def convert_val(val):
return str(val)
return val

item_getter = self._recursive_items if recursive else self.items

if include_arrays:
return {key: convert_val(val) for (key, val) in self.items()}
return {key: convert_val(val) for (key, val) in item_getter()}
else:
return {
key: convert_val(val) for (key, val) in self.items() if not isinstance(val, (np.ndarray, ndarray.NDArrayType))
key: convert_val(val) for (key, val) in item_getter() if not isinstance(val, (np.ndarray, ndarray.NDArrayType))
}

def _schema(self):
Expand Down Expand Up @@ -329,7 +344,7 @@ def __setitem__(self, key, value):
value = self._convert_to_scalar(key, value, self._data.get(key))

# If the value is a dictionary, loop over its keys and convert them to tagged scalars
if isinstance(value, (dict, asdf.lazy_nodes.AsdfDictNode)):
if isinstance(value, (dict, AsdfDictNode)):
for sub_key, sub_value in value.items():
value[sub_key] = self._convert_to_scalar(sub_key, sub_value)

Expand Down Expand Up @@ -366,7 +381,7 @@ class LNode(UserList):
def __init__(self, node=None):
if node is None:
self.data = []
elif isinstance(node, (list, asdf.lazy_nodes.AsdfListNode)):
elif isinstance(node, (list, AsdfListNode)):
self.data = node
elif isinstance(node, self.__class__):
self.data = node.data
Expand All @@ -375,9 +390,9 @@ def __init__(self, node=None):

def __getitem__(self, index):
value = self.data[index]
if isinstance(value, (dict, asdf.lazy_nodes.AsdfDictNode)):
if isinstance(value, (dict, AsdfDictNode)):
return DNode(value)
elif isinstance(value, (list, asdf.lazy_nodes.AsdfListNode)):
elif isinstance(value, (list, AsdfListNode)):
return LNode(value)
else:
return value
Expand Down
33 changes: 23 additions & 10 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,18 +858,31 @@ def test_datamodel_schema_info_existence(name):
assert keyword in info["roman"]["meta"]


def test_crds_parameters(tmp_path):
# CRDS uses meta.exposure.start_time to compare to USEAFTER
file_path = tmp_path / "testwfi_image.asdf"
utils.mk_level2_image(filepath=file_path)
with datamodels.open(file_path) as wfi_image:
crds_pars = wfi_image.get_crds_parameters()
assert "roman.meta.exposure.start_time" in crds_pars
@pytest.mark.parametrize("include_arrays", (True, False))
def test_to_flat_dict(include_arrays, tmp_path):
file_path = tmp_path / "test.asdf"
utils.mk_level2_image(filepath=file_path, shape=(8, 8))
with datamodels.open(file_path) as model:
if include_arrays:
assert "roman.data" in model.to_flat_dict()
else:
assert "roman.data" not in model.to_flat_dict(include_arrays=False)

utils.mk_ramp(filepath=file_path)
with datamodels.open(file_path) as ramp:
crds_pars = ramp.get_crds_parameters()

@pytest.mark.parametrize("mk_model", (utils.mk_level2_image, utils.mk_ramp))
def test_crds_parameters(mk_model, tmp_path):
# CRDS uses meta.exposure.start_time to compare to USEAFTER
file_path = tmp_path / "test.asdf"
mk_model(filepath=file_path)
with datamodels.open(file_path) as model:
# patch on a value that is valid (a simple type)
# but isn't under meta. Since it's not under meta
# it shouldn't be in the crds_pars.
model["test"] = 42
crds_pars = model.get_crds_parameters()
assert "roman.meta.exposure.start_time" in crds_pars
assert "roman.cal_logs" not in crds_pars
assert "roman.test" not in crds_pars


def test_model_validate_without_save():
Expand Down