Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JP-3749: Add mypy type checking #8852

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ jobs:
default_python: "3.12"
envs: |
- linux: check-dependencies
- linux: check-types
latest_crds_contexts:
uses: spacetelescope/crds/.github/workflows/contexts.yml@d96060f99a7ca75969f6652b050243592f4ebaeb # 12.0.0
crds_context:
Expand Down
1 change: 1 addition & 0 deletions changes/8852.general.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added mypy type checking to CI checks
2 changes: 1 addition & 1 deletion jwst/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
_regex_git_hash = re.compile(r".*\+g(\w+)")
__version_commit__ = ""
if "+" in __version__:
commit = _regex_git_hash.match(__version__).groups()
commit = _regex_git_hash.match(__version__).groups() # type: ignore
if commit:
__version_commit__ = commit[0]
2 changes: 1 addition & 1 deletion jwst/ami/ami_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import copy

from jwst.datamodels import CubeModel, ImageModel
from jwst.datamodels import CubeModel, ImageModel # type: ignore[attr-defined]

from .find_affine2d_parameters import find_rotation
from . import instrument_data
Expand Down
2 changes: 1 addition & 1 deletion jwst/ami/leastsqnrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@

if linfit:
try:
from linearfit import linearfit
from linearfit import linearfit # type: ignore[import-not-found]

Check warning on line 616 in jwst/ami/leastsqnrm.py

View check run for this annotation

Codecov / codecov/patch

jwst/ami/leastsqnrm.py#L616

Added line #L616 was not covered by tests

# dependent variables
M = np.asmatrix(flatimg)
Expand Down
2 changes: 1 addition & 1 deletion jwst/ami/matrix_dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def matrix_idft(*args, **kwargs):
return matrix_dft(*args, **kwargs)


matrix_idft.__doc__ = matrix_dft.__doc__.replace(
matrix_idft.__doc__ = matrix_dft.__doc__.replace( # type: ignore
'Perform a matrix discrete Fourier transform',
'Perform an inverse matrix discrete Fourier transform'
)
Expand Down
9 changes: 5 additions & 4 deletions jwst/assign_wcs/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _reproject(x, y):


def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray],
disp_axis: int = None, pscale_ratio: float = None) -> float:
disp_axis: int | None = None, pscale_ratio: float | None = None) -> float:
"""Compute scaling transform.

Parameters
Expand Down Expand Up @@ -137,8 +137,8 @@ def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray],
crval_with_offsets = wcs(*crpix_with_offsets, with_bounding_box=False)

coords = SkyCoord(ra=crval_with_offsets[spatial_idx[0]], dec=crval_with_offsets[spatial_idx[1]], unit="deg")
xscale = np.abs(coords[0].separation(coords[1]).value)
yscale = np.abs(coords[0].separation(coords[2]).value)
xscale: float = np.abs(coords[0].separation(coords[1]).value)
yscale: float = np.abs(coords[0].separation(coords[2]).value)

if pscale_ratio is not None:
xscale *= pscale_ratio
Expand All @@ -149,7 +149,8 @@ def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray],
# Assuming disp_axis is consistent with DataModel.meta.wcsinfo.dispersion.direction
return yscale if disp_axis == 1 else xscale

return np.sqrt(xscale * yscale)
scale: float = np.sqrt(xscale * yscale)
return scale


def calc_rotation_matrix(roll_ref: float, v3i_yang: float, vparity: int = 1) -> List[float]:
Expand Down
2 changes: 1 addition & 1 deletion jwst/associations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

# Take version from the upstream package
from .. import __version__
from jwst import __version__


# Utility
Expand Down
4 changes: 2 additions & 2 deletions jwst/associations/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ class Association(MutableMapping):
GLOBAL_CONSTRAINT = None
"""Global constraints"""

INVALID_VALUES = None
INVALID_VALUES: tuple | None = None
"""Attribute values that indicate the
attribute is not specified.
"""

ioregistry = IORegistry()
ioregistry: IORegistry = IORegistry()
"""The association IO registry"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion jwst/associations/association_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

__all__ = []
__all__: list = []


# Define JSON encoder to convert `Member` to `dict`
Expand Down
2 changes: 1 addition & 1 deletion jwst/associations/lib/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class SimpleConstraintABC(abc.ABC):
"""

# Attributes to show in the string representation.
_str_attrs = ('name', 'value')
_str_attrs: tuple = ('name', 'value')

def __new__(cls, *args, **kwargs):
"""Force creation of the constraint attribute dict before anything else."""
Expand Down
2 changes: 1 addition & 1 deletion jwst/associations/lib/rules_level3_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class DMS_Level3_Base(DMSBaseMixin, Association):
INVALID_VALUES = _EMPTY

# Make sequences type-dependent
_sequences = defaultdict(Counter)
_sequences: defaultdict = defaultdict(Counter)

def __init__(self, *args, **kwargs):

Expand Down
4 changes: 2 additions & 2 deletions jwst/associations/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ class BasePoolRule():

# Define the pools and testing parameters related to them.
# Each entry is a tuple starting with the path of the pool.
pools = []
pools: list = []

# Define the rules that SHOULD be present.
# Each entry is the class name of the rule.
valid_rules = []
valid_rules: list = []

def test_rules_exist(self):
rules = registry_level3_only()
Expand Down
11 changes: 5 additions & 6 deletions jwst/badpix_selfcal/badpix_selfcal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

import jwst.datamodels as dm
from jwst.datamodels import IFUImageModel # type: ignore[attr-defined]
from stcal.outlier_detection.utils import medfilt
from stdatamodels.jwst.datamodels.dqflags import pixel

Expand All @@ -12,7 +12,7 @@ def badpix_selfcal(minimg: np.ndarray,
flagfrac_lower: float = 0.001,
flagfrac_upper: float = 0.001,
kernel_size: int = 15,
dispaxis=None) -> np.ndarray:
dispaxis=None) -> tuple:
"""
Flag residual artifacts as bad pixels in the DQ array of a JWST exposure

Expand Down Expand Up @@ -59,26 +59,25 @@ def badpix_selfcal(minimg: np.ndarray,
flag_low, flag_high = np.nanpercentile(minimg_hpf, [flagfrac_lower * 100, (1 - flagfrac_upper) * 100])
bad = (minimg_hpf > flag_high) | (minimg_hpf < flag_low)
flagged_indices = np.where(bad)

return flagged_indices


def apply_flags(input_model: dm.IFUImageModel, flagged_indices: np.ndarray) -> dm.IFUImageModel:
def apply_flags(input_model: IFUImageModel, flagged_indices: np.ndarray) -> IFUImageModel:
"""
Apply the flagged indices to the input model. Sets the flagged pixels to NaN
and the DQ flag to DO_NOT_USE + OTHER_BAD_PIXEL

Parameters
----------
input_model : dm.IFUImageModel
input_model : IFUImageModel
Input science data to be corrected
flagged_indices : np.ndarray
Indices of the flagged pixels,
shaped like output from np.where

Returns
-------
output_model : dm.IFUImageModel
output_model : IFUImageModel
Flagged data model
"""

Expand Down
2 changes: 1 addition & 1 deletion jwst/cube_skymatch/cube_skymatch_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class CubeSkyMatchStep(Step):
binwidth = float(min=0.0, default=0.1) # Bin width for 'mode' and 'midpt' `skystat`, in sigma
"""

reference_file_types = []
reference_file_types: list = []

def process(self, input1, input2):
cube_models = ModelContainer(input1)
Expand Down
2 changes: 1 addition & 1 deletion jwst/datamodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
_deprecated_modules = ['schema']

# Deprecated models in stdatamodels
_deprecated_models = []
_deprecated_models: list[str] = []

# Import all submodules from stdatamodels.jwst.datamodels
for attr in dir(stdatamodels.jwst.datamodels):
Expand Down
2 changes: 1 addition & 1 deletion jwst/dq_init/dq_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from stdatamodels.jwst import datamodels

from ..lib import reffile_utils
from jwst.datamodels import dqflags
from jwst.datamodels import dqflags # type: ignore[attr-defined]

log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
Expand Down
4 changes: 2 additions & 2 deletions jwst/dq_init/tests/test_dq_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,12 @@ def test_dq_add1_groupdq():

# Set parameters for multiple runs of guider data
args = "xstart, ystart, xsize, ysize, nints, ngroups, instrument, exp_type, detector"
test_data = [(1, 1, 2048, 2048, 2, 2, 'FGS', 'FGS_ID-IMAGE', 'GUIDER1'),
test_data_multiple = [(1, 1, 2048, 2048, 2, 2, 'FGS', 'FGS_ID-IMAGE', 'GUIDER1'),
(1, 1, 1032, 1024, 1, 5, 'MIRI', 'MIR_IMAGE', 'MIRIMAGE')]
ids = ["GuiderRawModel-Image", "RampModel"]


@pytest.mark.parametrize(args, test_data, ids=ids)
@pytest.mark.parametrize(args, test_data_multiple, ids=ids)
def test_fullstep(xstart, ystart, xsize, ysize, nints, ngroups, instrument, exp_type, detector):
"""Test that the full step runs"""

Expand Down
38 changes: 19 additions & 19 deletions jwst/extract_1d/apply_apcorr.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import abc

from typing import Tuple, Union, Type
from scipy.interpolate import RectBivariateSpline, interp1d
from astropy.io import fits
from stdatamodels import DataModel
from stdatamodels.jwst.datamodels import MultiSlitModel

from ..assign_wcs.util import compute_scale
Expand Down Expand Up @@ -55,10 +52,8 @@
}
}

size_key = None

def __init__(self, input_model: DataModel, apcorr_table: fits.FITS_rec, sizeunit: str,
location: Tuple[float, float, float] = None, slit_name: str = None, **match_kwargs):
def __init__(self, input_model, apcorr_table, sizeunit,
location = None, slit_name = None, **match_kwargs):
self.correction = None

self.model = input_model
Expand All @@ -75,6 +70,11 @@
self.apcorr_func = self.approximate()
self.tabulated_correction = None

@property
@abc.abstractmethod
def size_key(self):
...

Check warning on line 76 in jwst/extract_1d/apply_apcorr.py

View check run for this annotation

Codecov / codecov/patch

jwst/extract_1d/apply_apcorr.py#L76

Added line #L76 was not covered by tests

def _convert_size_units(self):
"""If the SIZE or Radius column is in units of arcseconds, convert to pixels."""
if self.apcorr_sizeunits.startswith('arcsec'):
Expand Down Expand Up @@ -102,7 +102,7 @@
'pixels.'
)

def _get_match_keys(self) -> dict:
def _get_match_keys(self):
"""Get column keys needed for reducing the reference table based on input."""
instrument = self.model.meta.instrument.name.upper()
exptype = self.model.meta.exposure.type.upper()
Expand All @@ -113,7 +113,7 @@
if key in exptype:
return relevant_pars[key]

def _get_match_pars(self) -> dict:
def _get_match_pars(self):
"""Get meta parameters required for reference table row-selection."""
match_pars = {}

Expand All @@ -125,7 +125,7 @@

return match_pars

def _reduce_reftable(self) -> fits.FITS_record:
def _reduce_reftable(self):
"""Reduce full reference table to a single matched row."""
table = self._reference_table.copy()

Expand All @@ -145,7 +145,7 @@
"""Generate an approximate aperture correction function based on input data."""
pass

def apply(self, spec_table: fits.FITS_rec):
def apply(self, spec_table):
"""Apply interpolated aperture correction value to source-related extraction results in-place.

Parameters
Expand Down Expand Up @@ -181,14 +181,14 @@
"""
size_key = 'size'

def __init__(self, *args, pixphase: float = 0.5, **kwargs):
def __init__(self, *args, pixphase = 0.5, **kwargs):
self.phase = pixphase # In the future we'll attempt to measure the pixel phase from inputs.

super().__init__(*args, **kwargs)

def approximate(self):
"""Generate an approximate function for interpolating apcorr values to input wavelength and size."""
def _approx_func(wavelength: float, size: float, pixel_phase: float) -> RectBivariateSpline:
def _approx_func(wavelength, size, pixel_phase):
"""Create a 'custom' approximation function that approximates the aperture correction in two stages based on
input data.

Expand Down Expand Up @@ -228,7 +228,7 @@
def measure_phase(self): # Future method in determining pixel phase
pass

def tabulate_correction(self, spec_table: fits.FITS_rec):
def tabulate_correction(self, spec_table):
"""Tabulate the interpolated aperture correction value.

This will save time when applying it later, especially if it is to be applied to multiple integrations.
Expand All @@ -255,7 +255,7 @@

self.tabulated_correction = np.asarray(coefs)

def apply(self, spec_table: fits.FITS_rec, use_tabulated=False):
def apply(self, spec_table, use_tabulated=False):
"""Apply interpolated aperture correction value to source-related extraction results in-place.

Parameters
Expand Down Expand Up @@ -297,8 +297,8 @@
class ApCorrRadial(ApCorrBase):
"""Aperture correction class used with spectral data produced from an extraction aperture radius."""

def __init__(self, input_model: DataModel, apcorr_table,
location: Tuple[float, float, float] = None):
def __init__(self, input_model, apcorr_table,
location = None):

self.correction = None
self.model = input_model
Expand Down Expand Up @@ -329,7 +329,7 @@
'pixels.'
)

def apply(self, spec_table: fits.FITS_rec):
def apply(self, spec_table):
"""Apply interpolated aperture correction value to source-related extraction results in-place.

Parameters
Expand Down Expand Up @@ -410,7 +410,7 @@
return RectBivariateSpline(size, wavelength, apcorr.T, ky=1, kx=1)


def select_apcorr(input_model: DataModel) -> Union[Type[ApCorr], Type[ApCorrPhase], Type[ApCorrRadial]]:
def select_apcorr(input_model):
"""Select appropriate Aperture correction class based on input DataModel.

Parameters
Expand Down
Loading
Loading