Skip to content

WIP: BUGFIX for #50 #343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 64 additions & 50 deletions src/stratify/_vinterp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ cdef inline int relative_sign(double z, double z_base) nogil:
@cython.boundscheck(False)
@cython.wraparound(False)
cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
double[:, :] fz_src, bint increasing,
double[:, :] fz_src, bint rising,
bint aligned,
Interpolator interpolation,
Extrapolator extrapolation,
double [:, :] fz_target) nogil except -1:
Expand All @@ -65,7 +66,8 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
z_target - the levels to interpolate the source data ``fz_src`` to.
z_src - the levels that the source data ``fz_src`` is interpolated from.
fz_src - the source data to be interpolated.
increasing - true when increasing Z index generally implies increasing Z values
rising - true when rising Z index generally implies rising Z values
aligned - true when both src and tgt increase/decrease in the same direction
interpolation - the inner interpolation functionality. See the definition of
Interpolator.
extrapolation - the inner extrapolation functionality. See the definition of
Expand All @@ -91,7 +93,7 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
cdef unsigned int i_src, i_target, n_src, n_target, i, m
cdef bint all_nans = True
cdef double z_before, z_current, z_after, z_last
cdef int sign_after, sign_before, extrapolating
cdef int sign_after, sign_before, extrapolating, z_final

n_src = z_src.shape[0]
n_target = z_target.shape[0]
Expand All @@ -110,13 +112,12 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
fz_target[i, i_target] = NAN
return 0

interpolation.prepare_column(z_target, z_src, fz_src, increasing)
extrapolation.prepare_column(z_target, z_src, fz_src, increasing)
interpolation.prepare_column(z_target, z_src, fz_src, rising)
extrapolation.prepare_column(z_target, z_src, fz_src, rising)
with gil:
z_src = np.asarray(z_src)

if increasing:
z_before = -INFINITY
else:
z_before = INFINITY
z_before = -INFINITY if rising else INFINITY

z_last = -z_before

Expand All @@ -125,7 +126,11 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
# first window value (typically -inf, but may be +inf) and the first z_src.
# This search window will be moved along until a crossing is detected, at
# which point we will do an interpolation.
z_after = z_src[0]
with gil:
z_final = z_src.size - 1


z_after = z_src[0] if aligned else z_src[z_final]

# We start in extrapolation mode. This will be turned off as soon as we
# start increasing i_src.
Expand All @@ -151,7 +156,12 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
i_src += 1
if i_src < n_src:
extrapolating = 0
z_after = z_src[i_src]
with gil:
if aligned:
z_after = z_src[i_src]
else:
dummy = z_src.size - (i_src + 1)
z_after = z_src[dummy]
if isnan(z_after):
with gil:
raise ValueError('The source coordinate may not contain NaN values.')
Expand Down Expand Up @@ -201,7 +211,7 @@ cdef class Interpolator(object):
'the kernel function.')

cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
double[:, :] fz_src, bint increasing) nogil except -1:
double[:, :] fz_src, bint rising) nogil except -1:
# Called before all levels are interpolated.
pass

Expand Down Expand Up @@ -262,7 +272,7 @@ cdef class PyFuncInterpolator(Interpolator):
def __init__(self, use_column_prep=True):
self.use_column_prep = use_column_prep

def column_prep(self, z_target, z_src, fz_src, increasing):
def column_prep(self, z_target, z_src, fz_src, rising):
"""
Called each time this interpolator sees a new data array.
This method may be used for validation of a column, or for column
Expand All @@ -274,10 +284,10 @@ cdef class PyFuncInterpolator(Interpolator):
pass

cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
double[:, :] fz_src, bint increasing) nogil except -1:
double[:, :] fz_src, bint rising) nogil except -1:
if self.use_column_prep:
with gil:
self.column_prep(z_target, z_src, fz_src, increasing)
self.column_prep(z_target, z_src, fz_src, rising)

def interp_kernel(self, index, z_src, fz_src, level, output_array):
# Fill the output array with the fz_src data at the given index.
Expand Down Expand Up @@ -319,7 +329,7 @@ cdef class Extrapolator(object):
'the kernel function.')

cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
double[:, :] fz_src, bint increasing) nogil except -1:
double[:, :] fz_src, bint rising) nogil except -1:
pass


Expand Down Expand Up @@ -359,7 +369,7 @@ cdef class NearestNExtrapolator(Extrapolator):

cdef class LinearExtrapolator(Extrapolator):
cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
double[:, :] fz_src, bint increasing) nogil except -1:
double[:, :] fz_src, bint rising) nogil except -1:
cdef unsigned int n_src_pts = z_src.shape[0]

if n_src_pts < 2:
Expand Down Expand Up @@ -402,7 +412,7 @@ cdef class PyFuncExtrapolator(Extrapolator):
def __init__(self, use_column_prep=True):
self.use_column_prep = use_column_prep

def column_prep(self, z_target, z_src, fz_src, increasing):
def column_prep(self, z_target, z_src, fz_src, rising):
"""
Called each time this extrapolator sees a new data array.
This method may be used for validation of a column, or for column
Expand All @@ -414,10 +424,10 @@ cdef class PyFuncExtrapolator(Extrapolator):
pass

cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
double[:, :] fz_src, bint increasing) nogil except -1:
double[:, :] fz_src, bint rising) nogil except -1:
if self.use_column_prep:
with gil:
self.column_prep(z_target, z_src, fz_src, increasing)
self.column_prep(z_target, z_src, fz_src, rising)

def extrap_kernel(self, direction, z_src, fz_src, level, output_array):
# Fill the output array with nans.
Expand Down Expand Up @@ -449,7 +459,7 @@ EXTRAPOLATE_NEAREST = extrap_schemes['nearest']()
EXTRAPOLATE_LINEAR = extrap_schemes['linear']()


def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
def interpolate(z_target, z_src, fz_src, rising=None, axis=-1,
interpolation='linear', extrapolation='nan'):
"""
Interface for optimised 1d interpolation across multiple dimensions.
Expand Down Expand Up @@ -486,16 +496,6 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
the same as the shape of ``z_src``.
axis: int (default -1)
The ``fz_src`` axis to perform the interpolation over.
rising: bool (default None)
Whether the values of the source's interpolation coordinate values
are generally rising or generally falling. For example, values of
pressure levels will be generally falling as the z coordinate
increases.
This will determine whether extrapolation needs to occur for
``z_target`` below the first and above the last ``z_src``.
If rising is None, the first two interpolation coordinate values
will be used to determine the general direction. In most cases,
this is a good option.
interpolation: :class:`.Interpolator` instance or valid scheme name
The core interpolation operation to use. :attr:`.INTERPOLATE_LINEAR`
and :attr:`_INTERPOLATE_NEAREST` are provided for convenient
Expand All @@ -509,7 +509,6 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
func = functools.partial(
_interpolate,
axis=axis,
rising=rising,
interpolation=interpolation,
extrapolation=extrapolation
)
Expand Down Expand Up @@ -564,14 +563,14 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
meta=np.array((), dtype=fz_src.dtype))


def _interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
def _interpolate(z_target, z_src, fz_src, axis=-1,
interpolation='linear', extrapolation='nan'):
if interpolation in interp_schemes:
interpolation = interp_schemes[interpolation]()
if extrapolation in extrap_schemes:
extrapolation = extrap_schemes[extrapolation]()

interp = _Interpolation(z_target, z_src, fz_src, rising=rising, axis=axis,
interp = _Interpolation(z_target, z_src, fz_src, axis=axis,
interpolation=interpolation,
extrapolation=extrapolation)
if interp.z_target.ndim == 1:
Expand All @@ -583,16 +582,14 @@ def _interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
cdef class _Interpolation(object):
"""
Where the magic happens for gridwise_interp. The work of this __init__ is
mostly for putting the input nd arrays into a 3 and 4 dimensional form for
convenient (read: efficient) Cython form. Inline comments should help with
understanding.
mostly for putting the input nd.

"""
cdef Interpolator interpolation
cdef Extrapolator extrapolation

cdef public np.dtype _target_dtype
cdef int rising
cdef rising, aligned
cdef public z_target, orig_shape, axis, _zp_reshaped, _fp_reshaped
cdef public _result_working_shape, result_shape

Expand Down Expand Up @@ -692,17 +689,27 @@ cdef class _Interpolation(object):
#: The shape of the interpolated data.
self.result_shape = tuple(result_shape)

if rising is None:
if z_src.shape[zp_axis] < 2:
raise ValueError('The rising keyword must be defined when '
'the size of the source array is <2 in '
'the interpolation axis.')
z_src_indexer = [0] * z_src.ndim
z_src_indexer[zp_axis] = slice(0, 2)
first_two = z_src[tuple(z_src_indexer)]
rising = first_two[0] <= first_two[1]
if z_src.shape[zp_axis] < 2:
raise ValueError('The rising keyword must be defined when '
'the size of the source array is <2 in '
'the interpolation axis.')


self.rising = bool(rising)
z_src_indexer = [0] * z_src.ndim
z_src_indexer[zp_axis] = slice(0, 2)
src_first_two = z_src[tuple(z_src_indexer)]
src_rising = src_first_two[0] <= src_first_two[1]
src_rise = bool(src_rising)

z_tgt_indexer = [0] * z_target.ndim
z_tgt_indexer[zp_axis] = slice(0, 2)
tgt_first_two = z_target[tuple(z_tgt_indexer)]
tgt_rising = tgt_first_two[0] <= tgt_first_two[1]
tgt_rise = bool(tgt_rising)


self.rising = bool(tgt_rising)
self.aligned = src_rise == tgt_rise

# Sometimes we want to add additional constraints on our interpolation
# and extrapolation - for example, linear extrapolation requires there
Expand Down Expand Up @@ -733,13 +740,17 @@ cdef class _Interpolation(object):
# Construct a memory view of the fz_target array.
cdef double[:, :, :, :] fz_target_view = fz_target

cdef int rising = self.rising
cdef int aligned = self.aligned

# Release the GIL and do the for loop over the left-hand, and
# right-hand dimensions. The loop optimised for row-major data (C).
with nogil:
for j in range(nj):
for i in range(ni):
gridwise_interpolation(z_target, z_src[i, :, j], fz_src[:, i, :, j],
self.rising,
rising,
aligned,
self.interpolation,
self.extrapolation,
fz_target_view[:, i, :, j])
Expand All @@ -755,6 +766,8 @@ cdef class _Interpolation(object):
fz_target = np.empty(self._result_working_shape, dtype=np.float64)

cdef unsigned int i, j, ni, nj
cdef int rising = self.rising
cdef int aligned = self.aligned

ni = fz_target.shape[1]
nj = fz_target.shape[3]
Expand All @@ -775,7 +788,8 @@ cdef class _Interpolation(object):
for j in range(nj):
for i in range(ni):
gridwise_interpolation(z_target[i, :, j], z_src[i, :, j], fz_src[:, i, :, j],
self.rising,
rising,
aligned,
self.interpolation,
self.extrapolation,
fz_target_view[:, i, :, j])
Expand Down
Loading
Loading