From 2e967df2afcd559500016fd4f6a149408a7efca3 Mon Sep 17 00:00:00 2001 From: Michal Januszewski Date: Fri, 27 Jan 2023 13:26:07 -0800 Subject: [PATCH] Generalize more operations in map_utils to 3d, add support for a batch channel in elastic_mesh_3d, and switch the map_utils and mesh modules to pep-585 type annotations. PiperOrigin-RevId: 505195694 --- map_utils.py | 448 +++++++++++++++++++++++++++++----------------- mesh.py | 19 +- stitch_elastic.py | 38 ++-- 3 files changed, 317 insertions(+), 188 deletions(-) diff --git a/map_utils.py b/map_utils.py index 6eaf84a..bf29f65 100644 --- a/map_utils.py +++ b/map_utils.py @@ -47,7 +47,7 @@ """ import collections -from typing import List, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Union from connectomics.common import bounding_box import jax import jax.numpy as jnp @@ -57,50 +57,57 @@ from scipy import spatial -def _interpolate_points(data_points: Tuple[np.ndarray, np.ndarray], - query_points: Tuple[np.ndarray, np.ndarray], - data_x: np.ndarray, - data_y: np.ndarray, - method='linear') -> Tuple[np.ndarray, np.ndarray]: +def _interpolate_points( + data_points: Sequence[np.ndarray], + query_points: Sequence[np.ndarray], + *args, + method: str = 'linear', +) -> list[np.ndarray]: """Interpolates 2d data. - This is like griddata(), but for vector fields (defined by data_x, data_y). + This is like griddata(), but for multi-dimensional fields (defined by *args). Args: - data_points: arrays of x, y coordinates where the field components are + data_points: arrays of x, y, ... coordinates where the field components are defined - query_points: arrays of x, y coordinates at which to interpolate data - data_x: horizontal component of the field - data_y: vertical component of the field + query_points: arrays of x, y, ... coordinates at which to interpolate data + *args: one or more scalar component fields to interpolate method: interpolation scheme to use (linear, nearest, cubic) Returns: - x, y components of the field sampled at 'query_points' + components of the field sampled at 'query_points', in the same order of + *args """ + if len(data_points) != len(query_points): + raise ValueError( + 'Data and query points dimensionalities needs to match, are: ' + f'{len(data_points)} and {len(query_points)}' + ) + + ret = [] if method == 'nearest': - ip = interpolate.NearestNDInterpolator(data_points, data_x) - ip_x = ip(query_points) - ip.values = data_y - ip_y = ip(query_points) - return ip_x, ip_y + ip = interpolate.NearestNDInterpolator(data_points, args[0]) + ret.append(ip(query_points)) + for arg in args[1:]: + ip.values = arg + ret.append(ip(query_points)) + + return ret assert method in ('linear', 'cubic') - point_x, point_y = data_points - data_points = np.array([point_x, point_y]).T + data_points = np.array(data_points).T tri = spatial.Delaunay(np.ascontiguousarray(data_points, dtype=np.double)) if method == 'linear': - ip = interpolate.LinearNDInterpolator(tri, data_x, fill_value=np.nan) - ip_x = ip(query_points) - ip = interpolate.LinearNDInterpolator(tri, data_y, fill_value=np.nan) - ip_y = ip(query_points) + for arg in args: + ip = interpolate.LinearNDInterpolator(tri, arg, fill_value=np.nan) + ret.append(ip(query_points)) else: - ip = interpolate.CloughTocher2DInterpolator(tri, data_x, fill_value=np.nan) - ip_x = ip(query_points) - ip = interpolate.CloughTocher2DInterpolator(tri, data_y, fill_value=np.nan) - ip_y = ip(query_points) + for arg in args: + ip = interpolate.CloughTocher2DInterpolator(tri, arg, fill_value=np.nan) + ret.append(ip(query_points)) - return ip_x, ip_y + return ret def _as_vec(value: Union[float, Sequence[float]], dim: int) -> Sequence[float]: @@ -112,13 +119,15 @@ def _as_vec(value: Union[float, Sequence[float]], dim: int) -> Sequence[float]: def _identity_map_absolute( - coord_shape: Union[Tuple[int, int], Tuple[int, int, int]], - stride: Union[float, Sequence[float]]) -> List[np.ndarray]: + coord_shape: Union[tuple[int, int], tuple[int, int, int]], + stride: Union[float, Sequence[float]], +) -> list[np.ndarray]: """Generates an identity map in absolute form. Args: coord_shape: [z, ]y, x shape of the map to generate - stride: distance between nearest neighbors of the coordinate map + stride: distance between nearest neighbors of the coordinate map ([z]yx + sequence or a single float) Returns: identity maps: [z -> z,] y -> y, x -> x @@ -134,7 +143,8 @@ def _identity_map_absolute( def to_absolute( coord_map: np.ndarray, stride: Union[float, Sequence[float]], - box: Optional[bounding_box.BoundingBoxBase] = None) -> np.ndarray: + box: Optional[bounding_box.BoundingBox] = None, +) -> np.ndarray: """Converts a coordinate map from relative to absolute representation. Args: @@ -157,7 +167,8 @@ def to_absolute( if box is not None: if not np.all(coord_map.shape[-dim:][::-1] == box.size[:dim]): raise ValueError( - f'box shape ({box.size}) mismatch with coord map ({coord_map.shape})') + f'box shape ({box.size}) mismatch with coord map ({coord_map.shape})' + ) off_zyx = [ o + start * step for o, step, start in zip(off_zyx, stride, box.start[:dim][::-1]) @@ -171,41 +182,48 @@ def to_absolute( def to_relative( coord_map: np.ndarray, - stride: float, - box: Optional[bounding_box.BoundingBoxBase] = None) -> np.ndarray: + stride: Union[float, Sequence[float]], + box: Optional[bounding_box.BoundingBox] = None, +) -> np.ndarray: """Converts a coordinate map from absolute to relative representation. Args: - coord_map: [2, z, y, x] array of coordinates, where the channels represent - an absolute (x, y) location in space - stride: distance between nearest neighbors of the coordinate map + coord_map: [2 or 3, z, y, x] array of coordinates, where the channels + represent an absolute (x, y[, z]) location in space + stride: distance between nearest neighbors of the coordinate map ([z]yx + sequence or a single float) box: bounding box from which coord_map was extracted Returns: - coordinate map where entries represent a (Δx, Δy) offset from the - original (x, y) location + coordinate map where entries represent a (Δx, Δy[, Δz]) offset from the + original (x, y[, z]) location """ coord_map = coord_map.copy() - hy, hx = _identity_map_absolute(coord_map.shape[2:4], stride) + dim = coord_map.shape[0] + stride = _as_vec(stride, dim) + identity = _identity_map_absolute(coord_map.shape[(4 - dim) : 4], stride) if box is not None: - if not np.all(coord_map.shape[2:][::-1] == box.size[:2]): + if not np.all(coord_map.shape[(4 - dim) :][::-1] == box.size[:dim]): raise ValueError( - f'box shape ({box.size}) mismatch with coord map ({coord_map.shape})') - hy += box.start[1] * stride - hx += box.start[0] * stride + f'box shape ({box.size}) mismatch with coord map ({coord_map.shape})' + ) + + for i in range(dim): + identity[dim - i - 1] += box.start[i] * stride[dim - i - 1] + + for i in range(dim): + coord_map[i] -= identity[dim - i - 1] - coord_map[0, ...] -= hx - coord_map[1, ...] -= hy return coord_map -def fill_missing(coord_map: np.ndarray, - extrapolate=False, - invalid_to_zero=False) -> np.ndarray: +def fill_missing( + coord_map: np.ndarray, extrapolate=False, invalid_to_zero=False +) -> np.ndarray: """Fills missing entries in a coordinate map. Args: - coord_map: [2, z, y, x] coordinate map in relative format + coord_map: [2 or 3, z, y, x] coordinate map in relative format extrapolate: if False, will only fill by interpolation invalid_to_zero: whether to zero out completely invalid sections (i.e., reset to identity map) @@ -218,7 +236,7 @@ def fill_missing(coord_map: np.ndarray, return coord_map ret = coord_map.copy() - hy, hx = np.mgrid[:coord_map.shape[2], :coord_map.shape[3]] + hy, hx = np.mgrid[: coord_map.shape[2], : coord_map.shape[3]] query_points = hx.ravel(), hy.ravel() for z in range(coord_map.shape[1]): @@ -235,7 +253,8 @@ def fill_missing(coord_map: np.ndarray, points, query_points, # coord_map[0, z, ...][valid], - coord_map[1, z, ...][valid]) + coord_map[1, z, ...][valid], + ) ret[0, z, ...] = u.reshape(hx.shape) ret[1, z, ...] = v.reshape(hx.shape) except spatial.qhull.QhullError: @@ -252,17 +271,20 @@ def fill_missing(coord_map: np.ndarray, query_points, ret[0, z, ...][valid], ret[1, z, ...][valid], - method='nearest') + method='nearest', + ) ret[0, z, ...] = u.reshape(hx.shape) ret[1, z, ...] = v.reshape(hy.shape) return ret -def outer_box(coord_map: np.ndarray, - box: bounding_box.BoundingBoxBase, - stride: Union[float, Sequence[float]], - target_len: Optional[float] = None) -> bounding_box.BoundingBox: +def outer_box( + coord_map: np.ndarray, + box: bounding_box.BoundingBox, + stride: Union[float, Sequence[float]], + target_len: Optional[float] = None, +) -> bounding_box.BoundingBox: """Returns a bounding box covering all target nodes. Args: @@ -276,15 +298,15 @@ def outer_box(coord_map: np.ndarray, Returns: bounding box containing all (u, v,[ w]) coordinates referenced by the input map (x, y[, z]) -> (u, v[, w]); the bounding box is for a - coordinate map - with `target_len` node spacing + coordinate map with `target_len` node spacing """ abs_map = to_absolute(coord_map, stride, box) extents_xyz = [(np.nanmin(c), np.nanmax(c)) for c in abs_map] dim = coord_map.shape[0] - target_len_xyz = _as_vec(target_len if target_len is not None else stride, - dim)[::-1] + target_len_xyz = _as_vec( + target_len if target_len is not None else stride, dim + )[::-1] start = box.start.copy() size = box.size.copy() for i, ((x_min, x_max), tl) in enumerate(zip(extents_xyz, target_len_xyz)): @@ -295,8 +317,9 @@ def outer_box(coord_map: np.ndarray, return bounding_box.BoundingBox(start, size) -def inner_box(coord_map: np.ndarray, box: bounding_box.BoundingBoxBase, - stride: float) -> bounding_box.BoundingBox: +def inner_box( + coord_map: np.ndarray, box: bounding_box.BoundingBox, stride: float +) -> bounding_box.BoundingBox: """Returns a box within which all nodes are mapped to by coord map. Args: @@ -308,6 +331,8 @@ def inner_box(coord_map: np.ndarray, box: bounding_box.BoundingBoxBase, bounding box, all (u, v) points contained within which have an entry in the (x, y) -> (u, v) map """ + assert coord_map.shape[0] == 2 + # Part of the map might be invalid, in which case we extrapolate # in order to get a fully valid array. int_map = to_absolute(fill_missing(coord_map, extrapolate=True), stride, box) @@ -322,22 +347,27 @@ def inner_box(coord_map: np.ndarray, box: bounding_box.BoundingBoxBase, y1 = y1 // stride return bounding_box.BoundingBox( - start=(x0, y0, box.start[2]), - size=(x1 - x0 + 1, y1 - y0 + 1, box.size[2])) + start=(x0, y0, box.start[2]), size=(x1 - x0 + 1, y1 - y0 + 1, box.size[2]) + ) -def invert_map(coord_map: np.ndarray, src_box: bounding_box.BoundingBoxBase, - dst_box: bounding_box.BoundingBoxBase, - stride: float) -> np.ndarray: +def invert_map( + coord_map: np.ndarray, + src_box: bounding_box.BoundingBox, + dst_box: bounding_box.BoundingBox, + stride: Union[float, Sequence[float]], +) -> np.ndarray: """Inverts a coordinate map. - Given a (x, y) -> (u, v) map, returns a (u, v) -> (x, y) map. + Given a (x, y[, z]) -> (u, v[, w]) map, returns a (u, v[, w]) -> (x, y[, z]) + map. Args: - coord_map: [2, z, y, x] coordinate map in relative format + coord_map: [2 or 3, z, y, x] coordinate map in relative format src_box: box corresponding to coord_map dst_box: uv coordinate box for which to compute output - stride: distance between nearest neighbors of the coordinate map + stride: distance between nearest neighbors of the coordinate map ([z]yx + sequence or a single float) Returns: inverted coordinate map in relative format @@ -345,49 +375,88 @@ def invert_map(coord_map: np.ndarray, src_box: bounding_box.BoundingBoxBase, # Switch to a coordinate system originating at the first target node # of the coordinate map. coord_map = coord_map.astype(np.float64) + dim = coord_map.shape[0] + stride = _as_vec(stride, dim) src_box = src_box.adjusted_by(start=-dst_box.start, end=-dst_box.start) dst_box = dst_box.adjusted_by(start=-dst_box.start, end=-dst_box.start) coord_map = to_absolute(coord_map, stride, src_box) - src_y, src_x = np.mgrid[:src_box.size[1], :src_box.size[0]] - src_x = (src_box.start[0] + src_x) * stride - src_y = (src_box.start[1] + src_y) * stride - - # (u, v) points at which the map will be evaluated. - query_v, query_u = np.mgrid[:dst_box.size[1], :dst_box.size[0]] - query_u = (dst_box.start[0] + query_u) * stride - query_v = (dst_box.start[1] + query_v) * stride - query_points = query_u.ravel(), query_v.ravel() - ret_uv = np.full((2, coord_map.shape[1], dst_box.size[1], dst_box.size[0]), - np.nan, - dtype=coord_map.dtype) + def _sel_size(box): + if dim == 2: + return np.mgrid[: box.size[1], : box.size[0]] + elif dim == 3: + return np.mgrid[: box.size[2], : box.size[1], : box.size[0]] + else: + raise NotImplementedError() - for z in range(coord_map.shape[1]): - valid = np.all(np.isfinite(coord_map[:, z, ...]), axis=0) - if not np.any(valid): - continue + src_coords = _sel_size(src_box) # [z]yx + for i, src in enumerate(src_coords): + src_coords[i] = (src + src_box.start[dim - i - 1]) * stride[i] - src_points = ( - coord_map[0, z, ...][valid], # - coord_map[1, z, ...][valid]) + # ([w, ]v, u) points at which the map will be evaluated. + query_coords = _sel_size(dst_box) + for i, query in enumerate(query_coords): + query_coords[i] = (query + dst_box.start[dim - i - 1]) * stride[i] - try: - u, v = _interpolate_points(src_points, query_points, src_x[valid], - src_y[valid]) - ret_uv[0, z, ...] = u.reshape(query_u.shape) - ret_uv[1, z, ...] = v.reshape(query_v.shape) - except spatial.qhull.QhullError: - pass + query_points = tuple([q.ravel() for q in query_coords[::-1]]) # uv[w] - return to_relative(ret_uv, stride, dst_box) - - -def resample_map(coord_map: np.ndarray, - src_box: bounding_box.BoundingBoxBase, - dst_box: bounding_box.BoundingBoxBase, - src_stride: float, - dst_stride: float, - method='linear') -> np.ndarray: + if dim == 2: + ret_uv = np.full( + (2, coord_map.shape[1], dst_box.size[1], dst_box.size[0]), + np.nan, + dtype=coord_map.dtype, + ) + + for z in range(coord_map.shape[1]): + valid = np.all(np.isfinite(coord_map[:, z, ...]), axis=0) + if not np.any(valid): + continue + + src_points = tuple([c[z][valid] for c in coord_map]) + try: + u, v = _interpolate_points( + src_points, query_points, *[s[valid] for s in src_coords[::-1]] + ) + ret_uv[0, z, ...] = u.reshape(query_coords[1].shape) + ret_uv[1, z, ...] = v.reshape(query_coords[0].shape) + except spatial.qhull.QhullError: + pass + + return to_relative(ret_uv, stride, dst_box) + + assert dim == 3 + + ret = np.full( + (3, dst_box.size[2], dst_box.size[1], dst_box.size[0]), + np.nan, + dtype=coord_map.dtype, + ) + valid = np.all(np.isfinite(coord_map), axis=0) + if not np.any(valid): + return ret + + src_points = tuple([c[valid] for c in coord_map]) + try: + u, v, w = _interpolate_points( + src_points, query_points, *[s[valid] for s in src_coords[::-1]] + ) + ret[0, ...] = u.reshape(query_coords[2].shape) + ret[1, ...] = v.reshape(query_coords[1].shape) + ret[2, ...] = w.reshape(query_coords[0].shape) + except spatial.qhull.QhullError: + pass + + return to_relative(ret, stride, dst_box) + + +def resample_map( + coord_map: np.ndarray, + src_box: bounding_box.BoundingBox, + dst_box: bounding_box.BoundingBox, + src_stride: float, + dst_stride: float, + method='linear', +) -> np.ndarray: """Resamples a coordinate map to a new grid. Args: @@ -401,18 +470,22 @@ def resample_map(coord_map: np.ndarray, Returns: resampled coordinate map with dst_stride node separation """ - src_y, src_x = np.mgrid[:src_box.size[1], :src_box.size[0]] + assert coord_map.shape[0] == 2 + + src_y, src_x = np.mgrid[: src_box.size[1], : src_box.size[0]] src_y = (src_y + src_box.start[1]) * src_stride src_x = (src_x + src_box.start[0]) * src_stride - tg_y, tg_x = np.mgrid[:dst_box.size[1], :dst_box.size[0]] + tg_y, tg_x = np.mgrid[: dst_box.size[1], : dst_box.size[0]] tg_y = (tg_y + dst_box.start[1]) * dst_stride tg_x = (tg_x + dst_box.start[0]) * dst_stride tg_points = tg_x.ravel(), tg_y.ravel() - ret = np.full((2, coord_map.shape[1], dst_box.size[1], dst_box.size[0]), - np.nan, - dtype=coord_map.dtype) + ret = np.full( + (2, coord_map.shape[1], dst_box.size[1], dst_box.size[0]), + np.nan, + dtype=coord_map.dtype, + ) for z in range(coord_map.shape[1]): valid = np.isfinite(coord_map[0, z, ...]) if not np.any(valid): @@ -425,7 +498,8 @@ def resample_map(coord_map: np.ndarray, tg_points, # coord_map[0, z, ...][valid], coord_map[1, z, ...][valid], - method=method) + method=method, + ) ret[0, z, ...] = u.reshape(tg_x.shape) ret[1, z, ...] = v.reshape(tg_y.shape) except spatial.qhull.QhullError: @@ -434,10 +508,14 @@ def resample_map(coord_map: np.ndarray, return ret -def compose_maps(map1: np.ndarray, box1: bounding_box.BoundingBoxBase, - stride1: float, map2: np.ndarray, - box2: bounding_box.BoundingBoxBase, - stride2: float) -> np.ndarray: +def compose_maps( + map1: np.ndarray, + box1: bounding_box.BoundingBox, + stride1: float, + map2: np.ndarray, + box2: bounding_box.BoundingBox, + stride2: float, +) -> np.ndarray: """Composes two coordinate maps. Invalid values in map2 are interpolated. @@ -454,12 +532,17 @@ def compose_maps(map1: np.ndarray, box1: bounding_box.BoundingBoxBase, coordinate map corresponding to map2(map1(x, y)) """ + assert map1.shape[0] == 2 + assert map2.shape[0] == 2 + abs_map1 = to_absolute(map1, stride1, box1) abs_map2 = to_absolute(map2, stride2, box2) ret = np.full_like(map1, np.nan) - src_y, src_x = np.mgrid[box2.start[1]:box2.end[1], box2.start[0]:box2.end[0]] + src_y, src_x = np.mgrid[ + box2.start[1] : box2.end[1], box2.start[0] : box2.end[0] + ] src_x = src_x * stride2 src_y = src_y * stride2 @@ -468,9 +551,7 @@ def compose_maps(map1: np.ndarray, box1: bounding_box.BoundingBoxBase, if not np.any(valid): continue - query_points = ( - abs_map1[0, z, ...][valid], # - abs_map1[1, z, ...][valid]) + query_points = (abs_map1[0, z, ...][valid], abs_map1[1, z, ...][valid]) # valid_src = np.all(np.isfinite(abs_map2[:, z, ...]), axis=0) if not np.any(valid_src): @@ -478,9 +559,12 @@ def compose_maps(map1: np.ndarray, box1: bounding_box.BoundingBoxBase, src_points = src_x[valid_src], src_y[valid_src] try: - u, v = _interpolate_points(src_points, query_points, - abs_map2[0, z, ...][valid_src], - abs_map2[1, z, ...][valid_src]) + u, v = _interpolate_points( + src_points, + query_points, + abs_map2[0, z, ...][valid_src], + abs_map2[1, z, ...][valid_src], + ) ret[0, z, ...][valid] = u ret[1, z, ...][valid] = v except spatial.qhull.QhullError: @@ -491,13 +575,15 @@ def compose_maps(map1: np.ndarray, box1: bounding_box.BoundingBoxBase, # TODO(mjanusz): Automatically split computation into smaller boxes (overlapping # as necessary) in order to improve precision of the calculations. -def compose_maps_fast(map1: jnp.ndarray, - start1: Sequence[float], - stride1: Union[float, Sequence[float]], - map2: jnp.ndarray, - start2: Sequence[float], - stride2: Union[float, Sequence[float]], - mode='nearest') -> jnp.ndarray: +def compose_maps_fast( + map1: jnp.ndarray, + start1: Sequence[float], + stride1: Union[float, Sequence[float]], + map2: jnp.ndarray, + start2: Sequence[float], + stride2: Union[float, Sequence[float]], + mode='nearest', +) -> jnp.ndarray: """Composes two cooordinate maps using JAX. Unlike compose_maps(), invalid value in either map are NOT interpolated. @@ -546,18 +632,26 @@ def _ref_grid(coord_map, start, stride): query_coords = jnp.array([qy, qx]) # [2, y, x] # Query data in absolute format and then immediately convert to relative. - xx = jax.scipy.ndimage.map_coordinates( - map2[0, z, ...] + ref2[-1], - query_coords, - order=1, - mode=mode, - cval=np.nan) - ref1[-1] - yy = jax.scipy.ndimage.map_coordinates( - map2[1, z, ...] + ref2[-2], - query_coords, - order=1, - mode=mode, - cval=np.nan) - ref1[-2] + xx = ( + jax.scipy.ndimage.map_coordinates( + map2[0, z, ...] + ref2[-1], + query_coords, + order=1, + mode=mode, + cval=np.nan, + ) + - ref1[-1] + ) + yy = ( + jax.scipy.ndimage.map_coordinates( + map2[1, z, ...] + ref2[-2], + query_coords, + order=1, + mode=mode, + cval=np.nan, + ) + - ref1[-2] + ) ret = ret.at[:, z, :, :].set(jnp.array([xx, yy])) else: qx = (ref1[-1] + map1[0, ...]) / stride2[-1] @@ -565,25 +659,48 @@ def _ref_grid(coord_map, start, stride): qz = (ref1[-3] + map1[2, ...]) / stride2[-3] query_coords = jnp.array([qz, qy, qx]) - xx = jax.scipy.ndimage.map_coordinates( - map2[0, ...] + ref2[-1], query_coords, order=1, mode=mode, - cval=np.nan) - ref1[-1] - yy = jax.scipy.ndimage.map_coordinates( - map2[1, ...] + ref2[-2], query_coords, order=1, mode=mode, - cval=np.nan) - ref1[-2] - zz = jax.scipy.ndimage.map_coordinates( - map2[2, ...] + ref2[-3], query_coords, order=1, mode=mode, - cval=np.nan) - ref1[-3] + xx = ( + jax.scipy.ndimage.map_coordinates( + map2[0, ...] + ref2[-1], + query_coords, + order=1, + mode=mode, + cval=np.nan, + ) + - ref1[-1] + ) + yy = ( + jax.scipy.ndimage.map_coordinates( + map2[1, ...] + ref2[-2], + query_coords, + order=1, + mode=mode, + cval=np.nan, + ) + - ref1[-2] + ) + zz = ( + jax.scipy.ndimage.map_coordinates( + map2[2, ...] + ref2[-3], + query_coords, + order=1, + mode=mode, + cval=np.nan, + ) + - ref1[-3] + ) ret = jnp.array([xx, yy, zz]) return ret -def mask_irregular(coord_map: np.ndarray, - stride: float, - frac: float, - max_frac: Optional[float] = None, - dilation_iters: int = 1) -> np.ndarray: +def mask_irregular( + coord_map: np.ndarray, + stride: float, + frac: float, + max_frac: Optional[float] = None, + dilation_iters: int = 1, +) -> np.ndarray: """Masks stretched/folded parts of the map. Masked entries are replaced with nan's in-place. @@ -602,6 +719,7 @@ def mask_irregular(coord_map: np.ndarray, input map """ assert len(coord_map.shape) == 3 + assert coord_map.shape[0] == 2 if max_frac is None: max_frac = 2 - frac @@ -617,15 +735,19 @@ def mask_irregular(coord_map: np.ndarray, bad = ndimage.morphology.binary_dilation( bad, ndimage.morphology.generate_binary_structure(2, 2), - iterations=dilation_iters) + iterations=dilation_iters, + ) coord_map[0, ...][bad] = np.nan coord_map[1, ...][bad] = np.nan return bad -def make_affine_map(matrix: np.ndarray, box: bounding_box.BoundingBoxBase, - stride: Union[float, Sequence[float]]) -> np.ndarray: +def make_affine_map( + matrix: np.ndarray, + box: bounding_box.BoundingBox, + stride: Union[float, Sequence[float]], +) -> np.ndarray: """Builds a coordinate map for an affine transform. Args: @@ -641,6 +763,8 @@ def make_affine_map(matrix: np.ndarray, box: bounding_box.BoundingBoxBase, coord_map[1, ...] += box.start[1] coord_map[2, ...] += box.start[2] - affine_absolute = (np.dot(matrix[:3, :3], coord_map.reshape( - (3, -1))) + matrix[:, 3][:, np.newaxis]).reshape(coord_map.shape) + affine_absolute = ( + np.dot(matrix[:3, :3], coord_map.reshape((3, -1))) + + matrix[:, 3][:, np.newaxis] + ).reshape(coord_map.shape) return affine_absolute - coord_map diff --git a/mesh.py b/mesh.py index 37f43d1..3942a8c 100644 --- a/mesh.py +++ b/mesh.py @@ -28,7 +28,7 @@ import collections import dataclasses import functools -from typing import List, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Union from absl import logging @@ -183,7 +183,7 @@ def elastic_mesh_3d(x: jnp.ndarray, """Computes internal forces on the nodes of a 3d spring mesh. Args: - x: [3, z, y, x] array of mesh node positions, in relative format + x: [3, [batch..], z, y, x] array of mesh node positions, in relative format k: spring constant for springs along the x direction; will be scaled according to `stride` for all other springs to maintain constant elasticity @@ -195,7 +195,6 @@ def elastic_mesh_3d(x: jnp.ndarray, Returns: [3, z, y, x] array of forces """ - assert x.ndim == 4 assert x.shape[0] == 3 if prefer_orig_order: raise NotImplementedError('prefer_orig_order not supported for 3d mesh.') @@ -206,11 +205,13 @@ def elastic_mesh_3d(x: jnp.ndarray, stride = np.array(stride) f_tot = None for direction in links: - l0 = np.array(stride * direction).reshape([3, 1, 1, 1]) - sel1 = [np.s_[:]] - sel2 = [np.s_[:]] - pad_neg = [(0, 0)] - pad_pos = [(0, 0)] + l0 = np.array(stride * direction).reshape([3] + [1] * (x.ndim - 1)) + # Select everything in non-spatial dimensions. + sel1 = [np.s_[:]] * (x.ndim - 3) + sel2 = [np.s_[:]] * (x.ndim - 3) + # No padding for non-spatial dimensions. + pad_neg = [(0, 0)] * (x.ndim - 3) + pad_pos = [(0, 0)] * (x.ndim - 3) for dim in direction[::-1]: # zyx if dim == -1: sel1.append(np.s_[:-1]) @@ -442,7 +443,7 @@ def relax_mesh( prev: Optional[jnp.ndarray], config: IntegrationConfig, mesh_force=inplane_force, - prev_fn=None) -> Tuple[jnp.ndarray, List[float], int]: + prev_fn=None) -> tuple[jnp.ndarray, list[float], int]: """Simulates mesh relaxation. Args: diff --git a/stitch_elastic.py b/stitch_elastic.py index ba48197..ccac008 100644 --- a/stitch_elastic.py +++ b/stitch_elastic.py @@ -331,15 +331,15 @@ def _update_mesh(mesh: jnp.ndarray, """Updates mesh with data for a neighboring tile. Args: - mesh: [2, y, x] mesh to update + mesh: [2 or 3, [z,] y, x] mesh to update nbor_data: [max(NeighborInfo)] array of neighbor info - x: [2, n, y, x] array of mesh node positions for all tiles - fx: [2, n, y, x] array of flow data for horizontal tile NNs - fy: [2, n, y, x] array of flow data for vertical tile NNs - stride: yx stride for the flow and mesh data + x: [2 o r3, n, [z,] y, x] array of mesh node positions for all tiles + fx: [2 or 3, n, [z,] y, x] array of flow data for horizontal tile NNs + fy: [2 or 3, n, [z,] y, x] array of flow data for vertical tile NNs + stride: [z]yx stride for the flow and mesh data Returns: - [2, y, x] updated mesh + [2 or 3, [z,] y, x] updated mesh """ nbor_idx = nbor_data[NeighborInfo.nbor_idx] flow_idx = nbor_data[NeighborInfo.flow_idx] @@ -379,30 +379,34 @@ def compute_target_mesh(nbor_data: jnp.ndarray, x, fx, fy, vmap(partial(compute_target_mesh, x=x, fx=fx, fy=fy))(nbors) Args: - nbor_data: [4, 8] array of neighbor info; -1 in nbor and flow indices + nbor_data: [4, 8 or 11] array of neighbor info; -1 in nbor and flow indices indicates invalid (missing) entries - x: [2, n, y, x] array with node positions - fx: [2, n, y, x] array with flow data for horizontal neighbors - fy: [2, n, y, x] array with flow data for vertical neighbors - stride: yx stride for the flow and mesh data + x: [2 or 3, n, [z, ]y, x] array with node positions + fx: [2 or 3, n, [z, ]y, x] array with flow data for horizontal neighbors + fy: [2 or 3, n, [z, ]y, x] array with flow data for vertical neighbors + stride: [z]yx stride for the flow and mesh data Returns: - [2, y, x] array of target positions + [2 or 3, [z, ]y, x] array of target positions """ # When used within vmap/jit, dynamic_update_slice with the pasted content # extending beyond the updated array will cause the whole update to fail. # To mitigate this, extend the buffer sufficiently to ensure that the # pasted content (fx, fy) will always fit. - y_size, x_size = x.shape[-2:] - y_size += max(fy.shape[-2], fx.shape[-2]) - x_size += max(fy.shape[-1], fx.shape[-1]) + dim = x.shape[0] + zyx_size = list(x.shape[-dim:]) + for i in range(dim): + zyx_size[i] += max(fy.shape[i], fx.shape[i]) # Scan over neighbors (currently this is always exactly 4 and so # could just be explicitly unrolled). - mesh = jnp.full([2, y_size, x_size], np.nan) + mesh = jnp.full([2] + zyx_size, np.nan) updated = jax.lax.scan( ft.partial(_update_mesh, x=x, fx=fx, fy=fy, stride=stride), mesh, nbor_data)[0] # Cut the array back to the desired shape. - return updated[:, :x.shape[-2], :x.shape[-1]] + if dim == 2: + return updated[:, :x.shape[-2], :x.shape[-1]] + else: + return updated[:, :x.shape[-3], :x.shape[-2], :x.shape[-1]]