Skip to content

Commit

Permalink
consolidate join code
Browse files Browse the repository at this point in the history
  • Loading branch information
weaverba137 committed Sep 4, 2024
1 parent 18d07ad commit 6cb211c
Showing 1 changed file with 36 additions and 43 deletions.
79 changes: 36 additions & 43 deletions py/specprodDB/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,37 @@
from specprodDB.util import cameraid


def match_rows(left, right):
"""Match rows in `left` to rows in `right`.
Parameters
----------
src : array-like
The column to be matched. This could be "artificial".
dst : array-like
The column to be matched. This could be "artificial".
Returns
-------
:class:`tuple`
The row indexes of `left` and `right` that match.
"""
left_join = Table()
left_join['JOIN_ID'] = left
left_join['LEFT_INDEX'] = np.arange(len(left))
right_join = Table()
right_join['JOIN_ID'] = right
right_join['RIGHT_INDEX'] = np.arange(len(right))
joined = join(left_join, right_join, join_type='outer', keys='JOIN_ID')
if hasattr(joined['LEFT_INDEX'], 'mask'):
good_join = (~joined['LEFT_INDEX'].mask)
else:
good_join = (joined['LEFT_INDEX'] >= 0)
if hasattr(joined['RIGHT_INDEX'], 'mask'):
good_join = good_join & (~joined['RIGHT_INDEX'].mask)
return (joined[good_join]['LEFT_INDEX'], joined[good_join]['RIGHT_INDEX'])


def patch_frames(src_frames, dst_frames):
"""Patch frames data in `dst_frames` with the data in `src_frames`.
Expand All @@ -37,17 +68,8 @@ def patch_frames(src_frames, dst_frames):
A *copy* of `dst_frames` with data replaced from `src_frames`.
"""
log = get_logger()
src_frames_join = Table()
src_frames_join['FRAMEID'] = np.array([100*row['EXPID'] + cameraid(row['CAMERA']) for row in src_frames])
src_frames_join['SRC_INDEX'] = np.arange(len(src_frames))
dst_frames_join = Table()
dst_frames_join['FRAMEID'] = np.array([100*row['EXPID'] + cameraid(row['CAMERA']) for row in dst_frames])
dst_frames_join['DST_INDEX'] = np.arange(len(dst_frames))
joined_frames = join(src_frames_join, dst_frames_join, join_type='outer', keys='FRAMEID')
src_frames_index = joined_frames[(~joined_frames['SRC_INDEX'].mask) &
(~joined_frames['DST_INDEX'].mask)]['SRC_INDEX']
dst_frames_index = joined_frames[(~joined_frames['SRC_INDEX'].mask) &
(~joined_frames['DST_INDEX'].mask)]['DST_INDEX']
src_frames_index, dst_frames_index = match_rows(np.array([100*row['EXPID'] + cameraid(row['CAMERA']) for row in src_frames]),
np.array([100*row['EXPID'] + cameraid(row['CAMERA']) for row in dst_frames]))
dst_frames_patched = dst_frames.copy()
for column in dst_frames_patched.colnames:
if (column in src_frames.colnames and hasattr(src_frames[column], 'mask') and np.any(src_frames[column].mask[src_frames_index])):
Expand Down Expand Up @@ -102,17 +124,7 @@ def patch_exposures(src_exposures, dst_exposures, first_night=None):
#
# Set up a join.
#
src_exposures_join = Table()
src_exposures_join['EXPID'] = src_exposures['EXPID']
src_exposures_join['SRC_INDEX'] = np.arange(len(src_exposures))
dst_exposures_join = Table()
dst_exposures_join['EXPID'] = dst_exposures['EXPID']
dst_exposures_join['DST_INDEX'] = np.arange(len(dst_exposures))
joined_exposures = join(src_exposures_join, dst_exposures_join, join_type='outer', keys='EXPID')
src_exposures_index = joined_exposures[(~joined_exposures['SRC_INDEX'].mask) &
(~joined_exposures['DST_INDEX'].mask)]['SRC_INDEX']
dst_exposures_index = joined_exposures[(~joined_exposures['SRC_INDEX'].mask) &
(~joined_exposures['DST_INDEX'].mask)]['DST_INDEX']
src_exposures_index, dst_exposures_index = match_rows(src_exposures['EXPID'], dst_exposures['EXPID'])
dst_exposures_bad_coord = ((dst_exposures['TILERA'][dst_exposures_index] == 0) &
(dst_exposures['TILEDEC'][dst_exposures_index] == 0))
#
Expand Down Expand Up @@ -184,9 +196,7 @@ def patch_exposures(src_exposures, dst_exposures, first_night=None):
#
# Patch missing MJD.
#
missing_mjd = np.where((dst_exposures_patched['NIGHT'] >= first_night) &
(dst_exposures_patched['EFFTIME_SPEC'] > 0) &
(dst_exposures_patched['MJD'] < 50000))[0]
missing_mjd = (dst_exposures_patched['MJD'] < 50000)
for row in dst_exposures_patched[missing_mjd]:
raw_data_file = os.path.join(os.environ['DESI_SPECTRO_DATA'],
"{0:08d}".format(row['NIGHT']),
Expand Down Expand Up @@ -232,24 +242,7 @@ def patch_tiles(src_tiles, dst_tiles):
#
# Patch TILERA, TILEDEC.
#
src_tiles_join = Table()
src_tiles_join['TILEID'] = src_tiles['TILEID']
src_tiles_join['SRC_INDEX'] = np.arange(len(src_tiles))
dst_tiles_join = Table()
dst_tiles_join['TILEID'] = dst_tiles['TILEID']
dst_tiles_join['DST_INDEX'] = np.arange(len(dst_tiles))
joined_tiles = join(src_tiles_join, dst_tiles_join, join_type='outer', keys='TILEID')
#
# Sometimes every tile in src is also in dst.
#
if hasattr(joined_tiles['SRC_INDEX'], 'mask'):
good_join = (~joined_tiles['SRC_INDEX'].mask)
else:
good_join = (joined_tiles['SRC_INDEX'] >= 0)
if hasattr(joined_tiles['DST_INDEX'], 'mask'):
good_join = good_join & (~joined_tiles['SRC_INDEX'].mask)
src_tiles_index = joined_tiles[good_join]['SRC_INDEX']
dst_tiles_index = joined_tiles[good_join]['DST_INDEX']
src_tiles_index, dst_tiles_index = match_rows(src_tiles['TILEID'], dst_tiles['TILEID'])
dst_tiles_mask_matched = ((dst_tiles['TILERA'][dst_tiles_index] == 0) &
(dst_tiles['TILEDEC'][dst_tiles_index] == 0))
dst_tiles_patched = dst_tiles.copy()
Expand Down

0 comments on commit 6cb211c

Please sign in to comment.