From 29a6fbe1c6cee2abe1dc307a06939a76c3209e7d Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Tue, 15 Apr 2025 12:57:21 -0700 Subject: [PATCH] impl-syrk --- CHANGELOG.md | 1 + dpnp/backend/extensions/blas/CMakeLists.txt | 1 + dpnp/backend/extensions/blas/blas_py.cpp | 10 + dpnp/backend/extensions/blas/gemm.cpp | 8 +- dpnp/backend/extensions/blas/gemv.cpp | 46 ++- dpnp/backend/extensions/blas/gemv.hpp | 1 - dpnp/backend/extensions/blas/syrk.cpp | 307 ++++++++++++++++++ dpnp/backend/extensions/blas/syrk.hpp | 42 +++ dpnp/backend/extensions/blas/types_matrix.hpp | 25 ++ dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 46 ++- dpnp/tests/test_product.py | 45 ++- dpnp/tests/test_sycl_queue.py | 20 +- dpnp/tests/test_usm_type.py | 18 +- 13 files changed, 522 insertions(+), 48 deletions(-) create mode 100644 dpnp/backend/extensions/blas/syrk.cpp create mode 100644 dpnp/backend/extensions/blas/syrk.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index d5de63262567..d7b38e7b9250 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ This release achieves 100% compliance with Python Array API specification (revis * Added implementation of `dpnp.bartlett` [#2366](https://github.com/IntelPython/dpnp/pull/2366) * Added implementation of `dpnp.convolve` [#2205](https://github.com/IntelPython/dpnp/pull/2205) * Added implementation of `dpnp.kaiser` [#2387](https://github.com/IntelPython/dpnp/pull/2387) +* Added a new backend routine for performing symmetric rank-k update which is used for a specialized matrix multiplication where the result is a symmetric matrix []() ### Changed diff --git a/dpnp/backend/extensions/blas/CMakeLists.txt b/dpnp/backend/extensions/blas/CMakeLists.txt index d5639a24b268..d419e3140990 100644 --- a/dpnp/backend/extensions/blas/CMakeLists.txt +++ b/dpnp/backend/extensions/blas/CMakeLists.txt @@ -30,6 +30,7 @@ set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gemv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/syrk.cpp ) pybind11_add_module(${python_module_name} MODULE ${_module_src}) diff --git a/dpnp/backend/extensions/blas/blas_py.cpp b/dpnp/backend/extensions/blas/blas_py.cpp index 0321ff6fc6bd..6235d141cd58 100644 --- a/dpnp/backend/extensions/blas/blas_py.cpp +++ b/dpnp/backend/extensions/blas/blas_py.cpp @@ -36,6 +36,7 @@ #include "dotu.hpp" #include "gemm.hpp" #include "gemv.hpp" +#include "syrk.hpp" namespace blas_ns = dpnp::extensions::blas; namespace py = pybind11; @@ -48,6 +49,7 @@ void init_dispatch_vectors_tables(void) blas_ns::init_gemm_batch_dispatch_table(); blas_ns::init_gemm_dispatch_table(); blas_ns::init_gemv_dispatch_vector(); + blas_ns::init_syrk_dispatch_vector(); } static dot_impl_fn_ptr_t dot_dispatch_vector[dpctl_td_ns::num_types]; @@ -141,6 +143,14 @@ PYBIND11_MODULE(_blas_impl, m) py::arg("depends") = py::list()); } + { + m.def("_syrk", &blas_ns::syrk, + "Call `syrk` from OneMKL BLAS library to compute " + "the matrix-vector product with a general matrix.", + py::arg("sycl_queue"), py::arg("matrixA"), py::arg("resultC"), + py::arg("depends") = py::list()); + } + { m.def( "_using_onemkl_interfaces", diff --git a/dpnp/backend/extensions/blas/gemm.cpp b/dpnp/backend/extensions/blas/gemm.cpp index 4d674010efd7..086b40b83a1a 100644 --- a/dpnp/backend/extensions/blas/gemm.cpp +++ b/dpnp/backend/extensions/blas/gemm.cpp @@ -129,8 +129,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q, Tab(1), // Scaling factor for the product of matrices A and B. a, // Pointer to matrix A. lda, // Leading dimension of matrix A, which is the - // stride between successive rows (for row major - // layout). + // stride between successive rows (for row major layout). b, // Pointer to matrix B. ldb, // Leading dimension of matrix B, similar to lda. Tab(0), // Scaling factor for matrix C. @@ -168,7 +167,8 @@ std::tuple const int resultC_nd = resultC.get_ndim(); if ((matrixA_nd != 2) || (matrixB_nd != 2) || (resultC_nd != 2)) { - throw py::value_error("Input matrices must be two-dimensional."); + throw py::value_error( + "Input and output matrices must be two-dimensional."); } auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); @@ -286,6 +286,8 @@ std::tuple } } else { + // both A and B are f_contig so using column-major gemm and + // no transpose is needed transA = oneapi::mkl::transpose::N; transB = oneapi::mkl::transpose::N; lda = m; diff --git a/dpnp/backend/extensions/blas/gemv.cpp b/dpnp/backend/extensions/blas/gemv.cpp index 87730fbec9a8..b4416a6ed48a 100644 --- a/dpnp/backend/extensions/blas/gemv.cpp +++ b/dpnp/backend/extensions/blas/gemv.cpp @@ -118,8 +118,7 @@ static sycl::event gemv_impl(sycl::queue &exec_q, T(1), // Scaling factor for the matrix-vector product. a, // Pointer to the input matrix A. lda, // Leading dimension of matrix A, which is the - // stride between successive rows (for row major - // layout). + // stride between successive rows (for row major layout). x, // Pointer to the input vector x. incx, // The stride of vector x. T(0), // Scaling factor for vector y. @@ -190,6 +189,26 @@ std::pair const py::ssize_t *a_shape = matrixA.get_shape_raw(); const py::ssize_t *x_shape = vectorX.get_shape_raw(); const py::ssize_t *y_shape = vectorY.get_shape_raw(); + if (transpose) { + if (a_shape[0] != x_shape[0]) { + throw py::value_error("The number of rows in A must be equal to " + "the number of elements in X."); + } + if (a_shape[1] != y_shape[0]) { + throw py::value_error("The number of columns in A must be equal to " + "the number of elements in Y."); + } + } + else { + if (a_shape[1] != x_shape[0]) { + throw py::value_error("The number of columns in A must be equal to " + "the number of elements in X."); + } + if (a_shape[0] != y_shape[0]) { + throw py::value_error("The number of rows in A must be equal to " + "the number of elements in Y."); + } + } oneapi::mkl::transpose transA; std::size_t src_nelems; @@ -243,27 +262,6 @@ std::pair } #endif // USE_ONEMKL_CUBLAS - if (transpose) { - if (a_shape[0] != x_shape[0]) { - throw py::value_error("The number of rows in A must be equal to " - "the number of elements in X."); - } - if (a_shape[1] != y_shape[0]) { - throw py::value_error("The number of columns in A must be equal to " - "the number of elements in Y."); - } - } - else { - if (a_shape[1] != x_shape[0]) { - throw py::value_error("The number of columns in A must be equal to " - "the number of elements in X."); - } - if (a_shape[0] != y_shape[0]) { - throw py::value_error("The number of rows in A must be equal to " - "the number of elements in Y."); - } - } - const std::int64_t lda = is_row_major ? n : m; dpctl::tensor::validation::CheckWritable::throw_if_not_writable(vectorY); dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(vectorY, @@ -287,7 +285,7 @@ std::pair "Types of input arrays and result array are mismatched."); } - char *a_typeless_ptr = matrixA.get_data(); + const char *a_typeless_ptr = matrixA.get_data(); char *x_typeless_ptr = vectorX.get_data(); char *y_typeless_ptr = vectorY.get_data(); diff --git a/dpnp/backend/extensions/blas/gemv.hpp b/dpnp/backend/extensions/blas/gemv.hpp index 88e9f9c5c6f0..094cdafdc483 100644 --- a/dpnp/backend/extensions/blas/gemv.hpp +++ b/dpnp/backend/extensions/blas/gemv.hpp @@ -41,5 +41,4 @@ extern std::pair const std::vector &depends); extern void init_gemv_dispatch_vector(void); -extern void init_gemv_batch_dispatch_vector(void); } // namespace dpnp::extensions::blas diff --git a/dpnp/backend/extensions/blas/syrk.cpp b/dpnp/backend/extensions/blas/syrk.cpp new file mode 100644 index 000000000000..d25b2120de81 --- /dev/null +++ b/dpnp/backend/extensions/blas/syrk.cpp @@ -0,0 +1,307 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/output_validation.hpp" +#include "utils/type_utils.hpp" + +#include "syrk.hpp" +#include "types_matrix.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp::extensions::blas +{ +namespace mkl_blas = oneapi::mkl::blas; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*syrk_impl_fn_ptr_t)(sycl::queue &, + oneapi::mkl::transpose, + const std::int64_t, + const std::int64_t, + const char *, + const std::int64_t, + char *, + const std::int64_t, +#if !defined(USE_ONEMKL_CUBLAS) + const bool, +#endif // !USE_ONEMKL_CUBLAS + const std::vector &); + +static syrk_impl_fn_ptr_t syrk_dispatch_vector[dpctl_td_ns::num_types]; + +template +static sycl::event syrk_impl(sycl::queue &exec_q, + oneapi::mkl::transpose transA, + const std::int64_t n, + const std::int64_t k, + const char *matrixA, + const std::int64_t lda, + char *resultC, + const std::int64_t ldc, +#if !defined(USE_ONEMKL_CUBLAS) + const bool is_row_major, +#endif // !USE_ONEMKL_CUBLAS + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + const T *a = reinterpret_cast(matrixA); + T *res = reinterpret_cast(resultC); + + std::stringstream error_msg; + bool is_exception_caught = false; + + sycl::event syrk_event; + try { + auto syrk_func = + [&](sycl::queue &q, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose transA, const std::int64_t n, + const std::int64_t k, T alpha, const T *a, + const std::int64_t lda, T beta, T *c, const std::int64_t ldc, + const std::vector &deps) -> sycl::event { +#if defined(USE_ONEMKL_CUBLAS) + return mkl_blas::column_major::syrk(q, upper_lower, transA, n, k, + alpha, a, lda, beta, c, ldc, + deps); +#else + if (is_row_major) { + return mkl_blas::row_major::syrk(q, upper_lower, transA, n, k, + alpha, a, lda, beta, c, ldc, + deps); + } + else { + return mkl_blas::column_major::syrk(q, upper_lower, transA, n, + k, alpha, a, lda, beta, c, + ldc, deps); + } +#endif // USE_ONEMKL_CUBLAS + }; + + // we pass beta = 0, so passing upper or lower does not matter + oneapi::mkl::uplo uplo = oneapi::mkl::uplo::upper; + syrk_event = syrk_func( + exec_q, + uplo, // Specifies whether C’s data is stored in its upper + // or lower triangle + transA, // Defines the transpose operation for matrix A: + // 'N' indicates no transpose, 'T' for transpose, + // or 'C' for a conjugate transpose. + n, // Number of rows in op(A). + // Number of rows and columns in C. + k, // Number of columns in op(A). + T(1), // Scaling factor for the rank-k update. + a, // Pointer to the input matrix A. + lda, // Leading dimension of matrix A, which is the + // stride between successive rows (for row major layout). + T(0), // Scaling factor for matrix C. + res, // Pointer to output matrix c, where the result is stored. + ldc, // Leading dimension of matrix C. + depends); + } catch (oneapi::mkl::exception const &e) { + error_msg + << "Unexpected MKL exception caught during syrk() call:\nreason: " + << e.what(); + is_exception_caught = true; + } catch (sycl::exception const &e) { + error_msg << "Unexpected SYCL exception caught during syrk() call:\n" + << e.what(); + is_exception_caught = true; + } + + if (is_exception_caught) // an unexpected error occurs + { + throw std::runtime_error(error_msg.str()); + } + + // kernel to copy upper triangle to lower triangle + sycl::event copy_event = exec_q.submit([&](sycl::handler &h) { + h.depends_on(syrk_event); + + h.parallel_for( + sycl::range<2>{static_cast(n), static_cast(n)}, + [=](sycl::id<2> idx) { + std::int64_t i = idx[0]; + std::int64_t j = idx[1]; + if (j > i) { + res[j * ldc + i] = res[i * ldc + j]; + } + }); + }); + + return copy_event; + + /* Copy the triangle + syrk_event.wait(); + for (std::int64_t i = 0; i < n; i++) { + for (std::int64_t j = i + 1; j < n; j++) { + res[j * ldc + i] = res[i * ldc + j]; + } + } + + return syrk_event;*/ +} + +std::pair + syrk(sycl::queue &exec_q, + const dpctl::tensor::usm_ndarray &matrixA, + const dpctl::tensor::usm_ndarray &resultC, + const std::vector &depends) +{ + const int matrixA_nd = matrixA.get_ndim(); + const int resultC_nd = resultC.get_ndim(); + + if ((matrixA_nd != 2) || (resultC_nd != 2)) { + throw py::value_error("The given arrays have incorrect dimensions."); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(matrixA, resultC)) { + throw py::value_error("Input and output matrices are overlapping " + "segments of memory"); + } + + if (!dpctl::utils::queues_are_compatible( + exec_q, {matrixA.get_queue(), resultC.get_queue()})) + { + throw py::value_error( + "USM allocations are not compatible with the execution queue."); + } + + const py::ssize_t *a_shape = matrixA.get_shape_raw(); + const py::ssize_t *c_shape = resultC.get_shape_raw(); + if (c_shape[0] != c_shape[1]) { + throw py::value_error("The output matrix should be square."); + } + if (a_shape[0] != c_shape[0]) { + throw py::value_error("The number of rows in A must be equal to " + "the number of rows in result array."); + } + + const bool is_matrixA_f_contig = matrixA.is_f_contiguous(); + const bool is_matrixA_c_contig = matrixA.is_c_contiguous(); + if (!is_matrixA_f_contig and !is_matrixA_c_contig) { + throw py::value_error( + "Input matrix is not c-contiguous nor f-contiguous."); + } + + oneapi::mkl::transpose transA; + std::size_t src_nelems; + +// cuBLAS supports only column-major storage +#if defined(USE_ONEMKL_CUBLAS) + const bool is_row_major = false; + std::int64_t n; + std::int64_t k; + + if (is_matrixA_f_contig) { + transA = oneapi::mkl::transpose::N; + n = a_shape[0]; + k = a_shape[1]; + src_nelems = n * n; + } + else { + transA = oneapi::mkl::transpose::T; + k = a_shape[0]; + n = a_shape[1]; + src_nelems = k * k; + } +#else + bool is_row_major = true; + if (is_matrixA_f_contig) { + is_row_major = false; + } + + transA = oneapi::mkl::transpose::N; + const std::int64_t n = a_shape[0]; + const std::int64_t k = a_shape[1]; + src_nelems = n * n; +#endif // USE_ONEMKL_CUBLAS + + const std::int64_t lda = is_row_major ? k : n; + const std::int64_t ldc = n; + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(resultC); + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(resultC, + src_nelems); + + const int matrixA_typenum = matrixA.get_typenum(); + const int resultC_typenum = resultC.get_typenum(); + if (matrixA_typenum != resultC_typenum) { + throw py::value_error("Given arrays must be of the same type."); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + const int type_id = array_types.typenum_to_lookup_id(matrixA_typenum); + syrk_impl_fn_ptr_t syrk_fn = syrk_dispatch_vector[type_id]; + if (syrk_fn == nullptr) { + throw py::value_error( + "Types of input arrays and result array are mismatched."); + } + + const char *a_typeless_ptr = matrixA.get_data(); + char *r_typeless_ptr = resultC.get_data(); + +#if defined(USE_ONEMKL_CUBLAS) + sycl::event syrk_ev = syrk_fn(exec_q, transA, n, k, a_typeless_ptr, lda, + r_typeless_ptr, ldc, depends); +#else + sycl::event syrk_ev = syrk_fn(exec_q, transA, n, k, a_typeless_ptr, lda, + r_typeless_ptr, ldc, is_row_major, depends); +#endif // USE_ONEMKL_CUBLAS + + sycl::event args_ev = + dpctl::utils::keep_args_alive(exec_q, {matrixA, resultC}, {syrk_ev}); + + return std::make_pair(args_ev, syrk_ev); +} + +template +struct SyrkContigFactory +{ + fnT get() + { + if constexpr (types::SyrkTypePairSupportFactory::is_defined) { + return syrk_impl; + } + else { + return nullptr; + } + } +}; + +void init_syrk_dispatch_vector(void) +{ + dpctl_td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(syrk_dispatch_vector); +} +} // namespace dpnp::extensions::blas diff --git a/dpnp/backend/extensions/blas/syrk.hpp b/dpnp/backend/extensions/blas/syrk.hpp new file mode 100644 index 000000000000..7fd38a9abdb7 --- /dev/null +++ b/dpnp/backend/extensions/blas/syrk.hpp @@ -0,0 +1,42 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include +#include + +#include + +namespace dpnp::extensions::blas +{ +extern std::pair + syrk(sycl::queue &exec_q, + const dpctl::tensor::usm_ndarray &matrixA, + const dpctl::tensor::usm_ndarray &resultC, + const std::vector &depends); + +extern void init_syrk_dispatch_vector(void); +} // namespace dpnp::extensions::blas diff --git a/dpnp/backend/extensions/blas/types_matrix.hpp b/dpnp/backend/extensions/blas/types_matrix.hpp index 22fc98f05137..915a704e536d 100644 --- a/dpnp/backend/extensions/blas/types_matrix.hpp +++ b/dpnp/backend/extensions/blas/types_matrix.hpp @@ -186,4 +186,29 @@ struct GemvTypePairSupportFactory // fall-through dpctl_td_ns::NotDefinedEntry>::is_defined; }; + +/** + * @brief A factory to define pairs of supported types for which + * MKL BLAS library provides support in oneapi::mkl::blas::syrk + * function. + * + * @tparam T Type of input and output arrays. + */ +template +struct SyrkTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; } // namespace dpnp::extensions::blas::types diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 71314c90272c..c800430c0c2c 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -50,6 +50,29 @@ ] +def _call_syrk(x1, x2): + """ + Check to see if `syrk` can be called instead of `gemm`. + + It is assumed that `x1` and `x2` are usm_ndarrays. It is already validated + that input arrays here are 2-d and contiguous. With this assumption, here + we only check if both arrays point to the same memory, the number of rows + in the first array is equal to the number of columns in the second array, + and if one array is c_contiguous the other one is f_contiguous. + + """ + call_syrk = False + if ( + x1._pointer == x2._pointer + and x1.shape[0] == x2.shape[1] + and x1.flags.c_contiguous != x2.flags.c_contiguous + and x1.flags.f_contiguous != x2.flags.f_contiguous + ): + call_syrk = True + + return call_syrk + + def _compute_res_dtype(*arrays, sycl_queue, dtype=None, out=None, casting="no"): """ Determines the output array data type. @@ -310,10 +333,11 @@ def _gemm_batch_matmul(exec_q, x1, x2, res): def _gemm_matmul(exec_q, x1, x2, res): _manager = dpu.SequentialOrderManager[exec_q] + # it is assumed that x1 and x2 are usm_ndarrays ht_ev, gemm_ev, row_major = bi._gemm( exec_q, - dpnp.get_usm_ndarray(x1), - dpnp.get_usm_ndarray(x2), + x1, + x2, dpnp.get_usm_ndarray(res), depends=_manager.submitted_events, ) @@ -334,7 +358,7 @@ def _gemm_matmul(exec_q, x1, x2, res): def _gemm_special_case(x1, x2, res_dtype, call_flag): """ `gemm` and `gemm_batch` support these special cases of data types - while `gemv` does not. + while `gemv` or `syrk` do not. """ @@ -1062,7 +1086,6 @@ def dpnp_multiplication( x_usm = dpnp.get_usm_ndarray(x2) _manager = dpu.SequentialOrderManager[exec_q] - ht_ev, gemv_ev = bi._gemv( exec_q, a_usm, @@ -1073,7 +1096,20 @@ def dpnp_multiplication( ) _manager.add_event_pair(ht_ev, gemv_ev) elif call_flag == "gemm": - result = _gemm_matmul(exec_q, x1, x2, result) + x1_usm = dpnp.get_usm_ndarray(x1) + x2_usm = dpnp.get_usm_ndarray(x2) + call_syrk = _call_syrk(x1_usm, x2_usm) + if call_syrk: + _manager = dpu.SequentialOrderManager[exec_q] + ht_ev, gemv_ev = bi._syrk( + exec_q, + x1_usm, + dpnp.get_usm_ndarray(result), + depends=_manager.submitted_events, + ) + _manager.add_event_pair(ht_ev, gemv_ev) + else: + result = _gemm_matmul(exec_q, x1_usm, x2_usm, result) else: assert call_flag == "gemm_batch" result = _gemm_batch_matmul(exec_q, x1, x2, result) diff --git a/dpnp/tests/test_product.py b/dpnp/tests/test_product.py index 0eccd4deefc1..36a3b2ef9bcb 100644 --- a/dpnp/tests/test_product.py +++ b/dpnp/tests/test_product.py @@ -13,6 +13,7 @@ generate_random_numpy_array, get_all_dtypes, get_complex_dtypes, + get_float_complex_dtypes, is_win_platform, numpy_version, ) @@ -1059,17 +1060,19 @@ def test_strided_vec_mat(self, dtype, func, incx, incy, transpose): @pytest.mark.parametrize("dtype", _selected_dtypes) def test_out_order1(self, order1, order2, out_order, dtype): # test gemm with out keyword - a = generate_random_numpy_array((5, 4), dtype, low=-5, high=5) - b = generate_random_numpy_array((4, 7), dtype, low=-5, high=5) - a = numpy.array(a, order=order1) - b = numpy.array(b, order=order2) + a = generate_random_numpy_array( + (5, 4), dtype, order=order1, low=-5, high=5 + ) + b = generate_random_numpy_array( + (4, 7), dtype, order=order2, low=-5, high=5 + ) ia, ib = dpnp.array(a), dpnp.array(b) - iout = dpnp.empty((5, 7), dtype=dtype, order=out_order) + out = numpy.empty((5, 7), dtype=dtype, order=out_order) + iout = dpnp.array(out) result = dpnp.matmul(ia, ib, out=iout) assert result is iout - out = numpy.empty((5, 7), dtype=dtype, order=out_order) expected = numpy.matmul(a, b, out=out) assert result.flags.c_contiguous == expected.flags.c_contiguous assert result.flags.f_contiguous == expected.flags.f_contiguous @@ -1181,6 +1184,36 @@ def test_special_case(self, dt_out, shape1, shape2): result = dpnp.matmul(ia, ib, out=iout) assert_dtype_allclose(result, expected) + @pytest.mark.parametrize("dt", get_float_complex_dtypes()) + def test_syrk(self, dt): + a = generate_random_numpy_array((6, 9), dtype=dt) + ia = dpnp.array(a) + + result = dpnp.matmul(ia, ia.mT) + expected = numpy.matmul(a, a.T) + assert_dtype_allclose(result, expected) + + iout = dpnp.empty(result.shape, dtype=dt) + result = dpnp.matmul(ia, ia.mT, out=iout) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize( + "order, out_order", + [("C", "C"), ("C", "F"), ("F", "C"), ("F", "F")], + ) + def test_syrk_out_order1(self, order, out_order): + # test syrk with out keyword + a = generate_random_numpy_array((5, 4), order=order, low=-5, high=5) + out = numpy.empty((5, 5), dtype=a.dtype, order=out_order) + ia, iout = dpnp.array(a), dpnp.array(out) + + expected = numpy.matmul(a, a.T, out=out) + result = dpnp.matmul(ia, ia.mT, out=iout) + assert result is iout + assert result.flags.c_contiguous == expected.flags.c_contiguous + assert result.flags.f_contiguous == expected.flags.f_contiguous + assert_dtype_allclose(result, expected) + def test_bool(self): a = generate_random_numpy_array((3, 4), dtype=dpnp.bool) b = generate_random_numpy_array((4, 5), dtype=dpnp.bool) diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py index b0112702e308..2d1a1af8ab6f 100644 --- a/dpnp/tests/test_sycl_queue.py +++ b/dpnp/tests/test_sycl_queue.py @@ -415,9 +415,6 @@ def test_1in_1out(func, data, device): pytest.param("ldexp", [5, 5, 5, 5, 5], [0, 1, 2, 3, 4]), pytest.param("logaddexp", [-1, 2, 5, 9], [4, -3, 2, -8]), pytest.param("logaddexp2", [-1, 2, 5, 9], [4, -3, 2, -8]), - pytest.param( - "matmul", [[1.0, 0.0], [0.0, 1.0]], [[4.0, 1.0], [1.0, 2.0]] - ), pytest.param("maximum", [2.0, 3.0, 4.0], [1.0, 5.0, 2.0]), pytest.param("minimum", [2.0, 3.0, 4.0], [1.0, 5.0, 2.0]), pytest.param( @@ -633,6 +630,7 @@ def test_bitwise_op_2in(op, device): @pytest.mark.parametrize("device", valid_dev, ids=dev_ids) +@pytest.mark.parametrize("dtype", [dpnp.int32, dpnp.float32]) @pytest.mark.parametrize( "shape1, shape2", [ @@ -658,9 +656,12 @@ def test_bitwise_op_2in(op, device): "((6, 7, 4, 3), (6, 7, 3, 5))", ], ) -def test_matmul(device, shape1, shape2): - a = dpnp.arange(numpy.prod(shape1), device=device).reshape(shape1) - b = dpnp.arange(numpy.prod(shape2), device=device).reshape(shape2) +def test_matmul(device, dtype, shape1, shape2): + # dtype is needed, int32 checks dpctl implementation and float32 checks + # OneMKL implementation + a = dpnp.arange(numpy.prod(shape1), dtype=dtype, device=device) + b = dpnp.arange(numpy.prod(shape2), dtype=dtype, device=device) + a, b = a.reshape(shape1), b.reshape(shape2) result = dpnp.matmul(a, b) result_queue = result.sycl_queue @@ -668,6 +669,13 @@ def test_matmul(device, shape1, shape2): assert_sycl_queue_equal(result_queue, b.sycl_queue) +@pytest.mark.parametrize("device", valid_dev, ids=dev_ids) +def test_matmul_syrk(device): + a = dpnp.arange(20, dtype=dpnp.float32, device=device).reshape(4, 5) + result = dpnp.matmul(a, a.mT) + assert_sycl_queue_equal(result.sycl_queue, a.sycl_queue) + + @pytest.mark.parametrize("device", valid_dev, ids=dev_ids) @pytest.mark.parametrize( "shape1, shape2", diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py index 1d512ce111a6..b3d157303ba6 100644 --- a/dpnp/tests/test_usm_type.py +++ b/dpnp/tests/test_usm_type.py @@ -405,6 +405,7 @@ def test_bitwise_op_2in(op, usm_type_x, usm_type_y): @pytest.mark.parametrize("usm_type_x", list_of_usm_types) @pytest.mark.parametrize("usm_type_y", list_of_usm_types) +@pytest.mark.parametrize("dtype", [dpnp.int32, dpnp.float32]) @pytest.mark.parametrize( "shape1, shape2", [ @@ -430,9 +431,12 @@ def test_bitwise_op_2in(op, usm_type_x, usm_type_y): "((6, 7, 4, 3), (6, 7, 3, 5))", ], ) -def test_matmul(usm_type_x, usm_type_y, shape1, shape2): - x = dpnp.arange(numpy.prod(shape1), usm_type=usm_type_x).reshape(shape1) - y = dpnp.arange(numpy.prod(shape2), usm_type=usm_type_y).reshape(shape2) +def test_matmul(usm_type_x, usm_type_y, dtype, shape1, shape2): + # dtype is needed, int32 checks dpctl implementation and float32 checks + # OneMKL implementation + x = dpnp.arange(numpy.prod(shape1), dtype=dtype, usm_type=usm_type_x) + y = dpnp.arange(numpy.prod(shape2), dtype=dtype, usm_type=usm_type_y) + x, y = x.reshape(shape1), y.reshape(shape2) z = dpnp.matmul(x, y) assert x.usm_type == usm_type_x @@ -440,6 +444,14 @@ def test_matmul(usm_type_x, usm_type_y, shape1, shape2): assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) +@pytest.mark.parametrize("usm_type", list_of_usm_types) +def test_matmul_syrk(usm_type): + x = dpnp.arange(20, dtype=dpnp.float32, usm_type=usm_type).reshape(4, 5) + y = dpnp.matmul(x, x.mT) + + assert y.usm_type == usm_type + + @pytest.mark.parametrize("usm_type_x", list_of_usm_types) @pytest.mark.parametrize("usm_type_y", list_of_usm_types) @pytest.mark.parametrize(