Source code for pisa.utils.numba_tools

# pylint: disable=invalid-name, ungrouped-imports

"""
Numba tools

This is a colection of functions used for numba functions
that work for targets cpu as well as cuda
"""

from __future__ import absolute_import, division, print_function

__all__ = [
    "cuda",
    "ctype",
    "ftype",
    "myjit",
    "conjugate_transpose",
    "conjugate_transpose_guf",
    "test_conjugate_transpose",
    "conjugate",
    "conjugate_guf",
    "test_conjugate",
    "matrix_dot_matrix",
    "matrix_dot_matrix_guf",
    "test_matrix_dot_matrix",
    "matrix_dot_vector",
    "matrix_dot_vector_guf",
    "test_matrix_dot_vector",
    "clear_matrix",
    "clear_matrix_guf",
    "test_clear_matrix",
    "copy_matrix",
    "copy_matrix_guf",
    "test_copy_matrix",
    "cuda_copy",
    "FX", "IX"
]
__version__ = "0.2"
__author__ = "Philipp Eller (pde3@psu.edu)"

from argparse import ArgumentParser
import inspect

# NOTE: Following must be imported to be in the namespace for use by `myjit`
# when re-compiling modified (external) function code
import cmath  # pylint: disable=unused-import
import math  # pylint: disable=unused-import

import numpy as np

import numba

from numba import (  # pylint: disable=unused-import
    complex64,
    complex128,
    float32,
    float64,
    int32,
    int64,
    uint32,
    uint64,
    guvectorize,
    jit,
)

from pisa import FTYPE, TARGET, PISA_NUM_THREADS
from pisa.utils.comparisons import ALLCLOSE_KW
from pisa.utils.log import Levels, logging, set_verbosity

if TARGET == "parallel":
    numba.set_num_threads(PISA_NUM_THREADS)

if TARGET is None:
    raise NotImplementedError("Numba not supported.")

if TARGET == "cuda":
    from numba import cuda

    if FTYPE == np.float64:
        ctype = complex128
        ftype = float64
    elif FTYPE == np.float32:
        ctype = complex64
        ftype = float32
else:
    if FTYPE == np.float64:
        ctype = np.complex128
        ftype = np.float64
    elif FTYPE == np.float32:
        ctype = np.complex64
        ftype = np.float32
    cuda = lambda: None
    cuda.jit = lambda x: x


if FTYPE == np.float32:
    FX = "f4"
    CX = "c8"
    IX = "i4"
elif FTYPE == np.float64:
    FX = "f8"
    CX = "c16"
    IX = "i8"


[docs] def myjit(func): """ Decorator to assign the right jit for different targets In case of non-cuda targets, all instances of `cuda.local.array` are replaced by `np.empty`. This is a dirty fix, hopefully in the near future numba will support numpy array allocation and this will not be necessary anymore Parameters ---------- func : callable Returns ------- new_nb_func: numba callable Refactored version of `func` but with `cuda.local.array` replaced by `np.empty` if `TARGET == "cpu"`. For either TARGET, the returned function will be callable within numba code for that target. """ # pylint: disable=exec-used, eval-used if TARGET == "cuda": new_nb_func = cuda.jit(func, device=True) else: source = inspect.getsource(func).splitlines() assert source[0].strip().startswith("@myjit") source = "\n".join(source[1:]) + "\n" source = source.replace("cuda.local.array", "np.empty") exec(source) new_py_func = eval(func.__name__) new_nb_func = jit(new_py_func, nopython=True) # needs to be exported to globals globals()[func.__name__] = new_nb_func return new_nb_func
# --------------------------------------------------------------------------- #
[docs] def cuda_copy(func): ''' Handle copying back device array''' def wrapper(*args, **kwargs): out = kwargs.pop("out") if TARGET == "cuda" and not isinstance(out, numba.cuda.devicearray.DeviceNDArray): d_out = numba.cuda.to_device(out) func(*args, **kwargs, out=d_out) d_out.copy_to_host(out) else: func(*args, **kwargs, out=out) return wrapper
[docs] @myjit def conjugate_transpose(A, B): """B is the conjugate (Hermitian) transpose of A .. :: B[j, i] = A[i, j]* A : 2d array of shape (M, N) B : 2d array of shape (N, M) """ for i in range(A.shape[0]): for j in range(A.shape[1]): B[j, i] = A[i, j].conjugate()
@guvectorize( [f"({XX}[:, :], {XX}[:, :])" for XX in [FX, CX]], "(i, j) -> (j, i)", target=TARGET, ) def conjugate_transpose_guf(A, out): """gufunc that calls conjugate_transpose""" conjugate_transpose(A, out)
[docs] def test_conjugate_transpose(): """Unit tests of `conjugate_transpose` and `conjugate_transpose_guf`""" A = (np.linspace(1, 12, 12) + 1j * np.linspace(21, 32, 12)).reshape(4, 3).astype(CX) B = np.ones((3, 4), dtype=CX) conjugate_transpose_guf(A, B) test = B ref = A.conj().T assert np.allclose(test, ref, **ALLCLOSE_KW), f"test:\n{test}\n!= ref:\n{ref}" A = np.linspace(1, 12, 12, dtype=FX).reshape(3, 4) B = np.ones((4, 3), dtype=FX) conjugate_transpose_guf(A, B) test = B ref = A.conj().T assert np.allclose(test, ref, **ALLCLOSE_KW), f"test:\n{test}\n!= ref:\n{ref}" logging.info("<< PASS : test_conjugate_transpose >>")
# --------------------------------------------------------------------------- #
[docs] @myjit def conjugate(A, B): """B is the element-by-element conjugate of A .. :: B[i, j] = A[i, j]* Parameters ---------- A : 2d array B : 2d array """ for i in range(A.shape[0]): for j in range(A.shape[1]): B[i, j] = A[i, j].conjugate()
@guvectorize( [f"({XX}[:, :], {XX}[:, :])" for XX in [FX, CX]], "(i, j) -> (i, j)", target=TARGET, ) def conjugate_guf(A, out): """gufunc that calls `conjugate`""" conjugate(A, out)
[docs] def test_conjugate(): """Unit tests of `conjugate` and `conjugate_guf`""" A = (np.linspace(1, 12, 12) + 1j * np.linspace(21, 32, 12)).reshape(4, 3).astype(CX) B = np.ones((4, 3), dtype=CX) conjugate_guf(A, B) test = B ref = A.conj() assert np.allclose(test, ref, **ALLCLOSE_KW), f"test:\n{test}\n!= ref:\n{ref}" A = np.linspace(1, 12, 12, dtype=FX).reshape(3, 4) B = np.ones((3, 4), dtype=FX) conjugate_guf(A, B) test = B ref = A.conj() assert np.allclose(test, ref, **ALLCLOSE_KW), f"test:\n{test}\n!= ref:\n{ref}" logging.info("<< PASS : test_conjugate >>")
# --------------------------------------------------------------------------- #
[docs] @myjit def matrix_dot_matrix(A, B, C): """Dot-product of two 2d arrays .. :: C = A * B """ for j in range(B.shape[1]): for i in range(A.shape[0]): C[i, j] = 0.0 for n in range(B.shape[0]): C[i, j] += A[i, n] * B[n, j]
@guvectorize( [f"({XX}[:, :], {XX}[:, :], {XX}[:, :])" for XX in [FX, CX]], "(i, n), (n, j) -> (i, j)", target=TARGET, ) def matrix_dot_matrix_guf(A, B, out): """gufunc that calls matrix_dot_matrix""" matrix_dot_matrix(A, B, out)
[docs] def test_matrix_dot_matrix(): """Unit tests of `matrix_dot_matrix` and `matrix_dot_matrix_guf`""" A = np.linspace(1, 12, 12, dtype=FTYPE).reshape(3, 4) B = np.linspace(1, 12, 12, dtype=FTYPE).reshape(4, 3) C = np.ones((3, 3), dtype=FTYPE) matrix_dot_matrix_guf(A, B, C) test = C ref = np.dot(A, B).astype(FTYPE) assert np.allclose(test, ref, **ALLCLOSE_KW), f"test:\n{test}\n!= ref:\n{ref}" logging.info("<< PASS : test_matrix_dot_matrix >>")
# --------------------------------------------------------------------------- #
[docs] @myjit def matrix_dot_vector(A, v, w): """Dot-product of a 2d array and a vector .. :: w = A * v """ for i in range(A.shape[0]): w[i] = 0.0 for j in range(A.shape[1]): w[i] += A[i, j] * v[j]
@guvectorize( [f"({XX}[:, :], {XX}[:], {XX}[:])" for XX in [FX, CX]], "(i, j), (j) -> (i)", target=TARGET, ) def matrix_dot_vector_guf(A, B, out): """gufunc that calls matrix_dot_vector""" matrix_dot_vector(A, B, out)
[docs] def test_matrix_dot_vector(): """Unit tests of `matrix_dot_vector` and `matrix_dot_vector_guf`""" A = np.linspace(1, 12, 12, dtype=FTYPE).reshape(4, 3) v = np.linspace(1, 3, 3, dtype=FTYPE) w = np.ones(4, dtype=FTYPE) matrix_dot_vector_guf(A, v, w) test = w ref = np.dot(A, v).astype(FTYPE) assert np.allclose(test, ref, **ALLCLOSE_KW), f"test:\n{test}\n!= ref:\n{ref}" logging.info("<< PASS : test_matrix_dot_vector >>")
# --------------------------------------------------------------------------- #
[docs] @myjit def clear_matrix(A): """Zero out 2D matrix .. :: A[i, j] = 0 """ for i in range(A.shape[0]): for j in range(A.shape[1]): A[i, j] = 0
@guvectorize( [f"({XX}[:, :], {XX}[:, :])" for XX in [FX, CX]], "(i, j) -> (i, j)", target=TARGET, ) def clear_matrix_guf(dummy, out): # pylint: disable=unused-argument """gufunc that calls `clear_matrix`""" clear_matrix(out)
[docs] def test_clear_matrix(): """Unit tests of `clear_matrix` and `clear_matrix_guf`""" A = np.ones((4, 3), dtype=FTYPE) clear_matrix_guf(A, A) test = A ref = np.zeros((4, 3), dtype=FTYPE) assert np.array_equal(test, ref), f"test:\n{test}\n!= ref:\n{ref}" logging.info("<< PASS : test_clear_matrix >>")
# --------------------------------------------------------------------------- #
[docs] @myjit def copy_matrix(A, B): """Copy elemnts of 2d array A to array B .. :: B[i, j] = A[i, j] """ for i in range(A.shape[0]): for j in range(A.shape[1]): B[i, j] = A[i, j]
@guvectorize( [f"({XX}[:, :], {XX}[:, :])" for XX in [FX, CX]], "(i, j) -> (i, j)", target=TARGET, ) def copy_matrix_guf(A, out): """gufunc that calls `copy_matrix`""" copy_matrix(A, out)
[docs] def test_copy_matrix(): """Unit tests of `copy_matrix` and `copy_matrix_guf`""" A = np.ones((3, 3), dtype=FTYPE) B = np.zeros((3, 3), dtype=FTYPE) copy_matrix_guf(A, B) test = B ref = A assert np.array_equal(test, ref), f"test:\n{test}\n!= ref:\n{ref}" logging.info("<< PASS : test_copy_matrix >>")
# --------------------------------------------------------------------------- # def parse_args(): """parse command line args""" parser = ArgumentParser() parser.add_argument("-v", action="count", default=Levels.WARN, help="Verbosity") args = parser.parse_args() return vars(args) if __name__ == "__main__": set_verbosity(parse_args()["v"]) test_conjugate_transpose() test_conjugate() test_matrix_dot_matrix() test_matrix_dot_vector() test_clear_matrix() test_copy_matrix()