Source code for QuickCW.lapack_wrappers

"""C 2021 Matthew Digman
various JIT-compatible interfaces to Cython LAPACK functions """
import ctypes
from numba.extending import get_cython_function_address
from numba import njit
import numpy as np

# Short aliases for the ctypes types used to declare the LAPACK signatures below.
_PTR  = ctypes.POINTER

# scalar C types
_dble = ctypes.c_double
_char = ctypes.c_char
_int  = ctypes.c_int

# pointer types; c_void_p is the documented name (c_voidp is only a legacy
# alias that ctypes keeps for backwards compatibility and refers to the same type)
_ptr_select = ctypes.c_void_p
_ptr_dble = _PTR(_dble)
_ptr_char = _PTR(_char)
_ptr_int  = _PTR(_int)


# Signature of the dtrtrs routine exported by scipy.linalg.cython_lapack
# (every argument is passed as a pointer):
# void dtrtrs(
#  char *UPLO,
#  char *TRANS,
#  char *DIAG,
#  int *N,
#  int *NRHS,
#  d *A,
#  int *LDA,
#  d *B,
#  int *LDB,
#  int *info
# )
# Bind to the real (double precision) variant of the function.
addr = get_cython_function_address('scipy.linalg.cython_lapack', 'dtrtrs')
# NOTE: the three char* flags (UPLO, TRANS, DIAG) are declared as int pointers
# here; solve_triangular passes one-element int32 arrays holding the character
# code, so the declaration and the call site must be changed together.
functype = ctypes.CFUNCTYPE(None,
                            _ptr_int, # UPLO  (char* in the C signature)
                            _ptr_int, # TRANS (char* in the C signature)
                            _ptr_int, # DIAG  (char* in the C signature)
                            _ptr_int, # N
                            _ptr_int, # NRHS
                            _ptr_dble, # A
                            _ptr_int, # LDA
                            _ptr_dble, # B
                            _ptr_int, # LDB
                            _ptr_int, # INFO
                            )
# ctypes callable wrapping the raw function address
dtrtrs_fn = functype(addr)
@njit()
def solve_triangular(x,y,lower_a=True,trans_a=True,unitdiag=False,overwrite_b=False):
    """Solve the triangular system op(x) @ B = y using LAPACK dtrtrs.

    op(x) is the transpose of x when trans_a is true (the default) and x
    itself otherwise.

    :param x: square triangular matrix (must be C or fortran contiguous)
    :param y: right hand side, either a vector or a matrix with one column per
              system; must be fortran contiguous when overwrite_b is true
    :param lower_a: if true read the lower triangle of x, otherwise the upper
    :param trans_a: if true solve with the transpose of x
    :param unitdiag: if true dtrtrs treats x as unit triangular
    :param overwrite_b: if true write the solution into y in place, otherwise
                        work on a copy and leave y untouched
    :return B: solution array with the same shape as y
    :raises ValueError: if x or the working copy of y is not suitably
                        contiguous, if x is not square, or if the first
                        dimension of y does not match x
    :raises RuntimeError: if dtrtrs returns a nonzero INFO
    """
    #if the input matrix is c contiguous but not fortran contiguous
    #transposing it will make it fortran contiguous with no copying
    #then flipping upper and lower and telling dtrtrs to undo the transpose
    #will force dtrtrs to do the correct operation
    if x.flags.c_contiguous and not x.flags.f_contiguous:
        trans_a = not trans_a
        lower_a = not lower_a
        A = x.T
    else:
        A = x

    # the char* flags are passed as one-element int32 arrays holding the
    # character code, matching how dtrtrs_fn declares them
    if trans_a:
        TRANS = np.array([ord('T')], np.int32)
    else:
        TRANS = np.array([ord('N')], np.int32)

    if lower_a:
        UPLO = np.array([ord('L')], np.int32)
    else:
        UPLO = np.array([ord('U')], np.int32)

    #TODO why was this in place? added copy to mitigate, ensure nothing relied on that behavior
    if overwrite_b:
        B = y
    else:
        # copying the transpose and transposing back yields a fortran ordered copy
        B = y.T.copy().T

    #cannot do this operation in place if y is not contiguous, though could copy
    if not (A.flags.f_contiguous and B.flags.f_contiguous):
        raise ValueError('x must be contiguous and y must be fortran contiguous')

    if unitdiag:
        DIAG = np.array([ord('U')], np.int32)
    else:
        DIAG = np.array([ord('N')], np.int32)

    _M, _N = x.shape
    if y.ndim==1:
        _LDB = y.size
        _NB = 1
    else:
        _LDB,_NB = y.shape

    if _LDB != _N or _M!=_N:
        raise ValueError('x must be square and y must have same first dimension as x')

    N = np.array(_N, np.int32)
    NRHS = np.array(_NB, np.int32)
    # LAPACK requires LDA >= max(1,N) and LDB >= max(1,N); without the clamp an
    # empty (0x0) system is rejected by dtrtrs as an illegal argument
    LDA = np.array(max(1, _N), np.int32)
    LDB = np.array(max(1, _LDB), np.int32)
    INFO = np.empty(1, dtype=np.int32)

    dtrtrs_fn(UPLO.ctypes,
              TRANS.ctypes,
              DIAG.ctypes,
              N.ctypes,
              NRHS.ctypes,
              A.ctypes,
              LDA.ctypes,
              B.ctypes,
              LDB.ctypes,
              INFO.ctypes)

    # per the LAPACK documentation a negative INFO flags an illegal argument
    # and a positive INFO flags an exactly zero diagonal element (singular x)
    if INFO[0] != 0:
        print(INFO)
        raise RuntimeError("INFO indicates problem with dtrtrs")

    return B