-
Notifications
You must be signed in to change notification settings - Fork 37
Add support for specialization constants #2304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
7946152
b5befc3
8ddc440
9e264f2
1de2f07
b100cf5
e2e4826
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,7 +26,18 @@ an OpenCL source string or a SPIR-V binary file. | |
|
|
||
| """ | ||
|
|
||
| from cpython.buffer cimport ( | ||
| Py_buffer, | ||
| PyBUF_ANY_CONTIGUOUS, | ||
| PyBUF_SIMPLE, | ||
| PyBuffer_Release, | ||
| PyObject_CheckBuffer, | ||
| PyObject_GetBuffer, | ||
| ) | ||
| from cpython.bytes cimport PyBytes_FromStringAndSize | ||
| from libc.stdint cimport uint32_t | ||
| from libc.stdlib cimport free, malloc | ||
| from libc.string cimport memcmp | ||
|
|
||
| import warnings | ||
|
|
||
|
|
@@ -51,14 +62,20 @@ from dpctl._backend cimport ( # noqa: E211, E402; | |
| DPCTLSyclDeviceRef, | ||
| DPCTLSyclKernelBundleRef, | ||
| DPCTLSyclKernelRef, | ||
| _spec_const, | ||
| ) | ||
|
|
||
| import numbers | ||
|
|
||
| import numpy as np | ||
|
|
||
| __all__ = [ | ||
| "create_kernel_bundle_from_source", | ||
| "create_kernel_bundle_from_spirv", | ||
| "SyclKernel", | ||
| "SyclKernelBundle", | ||
| "SyclKernelBundleCompilationError", | ||
| "SpecializationConstant", | ||
| ] | ||
|
|
||
| cdef class SyclKernelBundleCompilationError(Exception): | ||
|
|
@@ -252,6 +269,160 @@ cdef api SyclKernelBundle SyclKernelBundle_Make(DPCTLSyclKernelBundleRef KBRef): | |
| return SyclKernelBundle._create(copied_KBRef) | ||
|
|
||
|
|
||
| cdef class SpecializationConstant: | ||
| """ | ||
| SpecializationConstant(spec_id, *args) | ||
|
|
||
| Python class representing SYCL specialization constants that can be used | ||
| when creating a :class:`dpctl.program.SyclKernelBundle` from SPIR-V. | ||
|
|
||
| There are multiple ways to create a :class:`.SpecializationConstant`: | ||
|
|
||
| - ``SpecializationConstant(spec_id, obj)`` | ||
| If the constructor is invoked with a single variadic argument, the | ||
| argument is expected to either expose the Python buffer protocol or be | ||
| coercible to a NumPy array. If the argument is coercible to a NumPy array | ||
| or is one, it must have a supported data type (bool, integral, or | ||
| floating point). The specialization constant will be constructed from the | ||
| data in the buffer | ||
|
|
||
| - ``SpecializationConstant(spec_id, dtype, obj)`` | ||
| If the constructor is invoked with two variadic arguments, and the first | ||
| argument is a string, it is interpreted as a NumPy ``dtype`` string and the | ||
| second argument will be coerced to a NumPy array with that data type. | ||
| The data type specified by the first argument must be a supported data | ||
| type (bool, integral, or floating point). | ||
|
|
||
| - ``SpecializationConstant(spec_id, nbytes, raw_ptr)`` | ||
| If the constructor is invoked with two variadic arguments where both are | ||
| integers, the first argument is interpreted as the number of bytes and | ||
| the second argument is interpreted as a pointer to the data. | ||
|
|
||
| Note that when constructing from a buffer, the | ||
| :class:`.SpecializationConstant`, shares memory with the original object. | ||
| Modifications to the original object's data after construction will be | ||
| reflected when the :class:`.SpecializationConstant` is used to create a | ||
| :class:`.SyclKernelBundle`. This is not the case when constructing from a | ||
| raw pointer, as the data is copied. | ||
|
|
||
| Args: | ||
| spec_id (int): | ||
| The SPIR-V specialization ID. | ||
| args: | ||
| Variadic argument, see class documentation. | ||
|
|
||
| Raises: | ||
| TypeError: In case of incorrect arguments given to constructor, | ||
| failure to coerce to a buffer, or unsupported data type when | ||
| coercing to a buffer. | ||
| ValueError: If the provided object fails to construct a buffer. | ||
| """ | ||
|
|
||
| cdef _spec_const _spec_const | ||
| cdef Py_buffer _buffer | ||
|
|
||
| def __cinit__(self, spec_id, *args): | ||
| cdef int ret_code = 0 | ||
| cdef object target_obj = None | ||
|
|
||
| if not isinstance(spec_id, numbers.Integral): | ||
| raise TypeError( | ||
| "Specialization constant ID must be of type `int`, got " | ||
| f"{type(spec_id)}" | ||
| ) | ||
|
|
||
| if len(args) == 0 or len(args) > 2: | ||
| raise TypeError( | ||
| f"Constructor takes 2 or 3 arguments, got {len(args)}." | ||
| ) | ||
|
|
||
| self._spec_const.id = <uint32_t>spec_id | ||
|
|
||
| if len(args) == 2: | ||
| if ( | ||
| isinstance(args[0], numbers.Integral) and | ||
| isinstance(args[1], numbers.Integral) | ||
| ): | ||
| target_obj = PyBytes_FromStringAndSize( | ||
| <const char *><size_t>args[1], <Py_ssize_t>args[0] | ||
| ) | ||
| elif isinstance(args[0], str): | ||
| target_obj = np.ascontiguousarray(args[1], dtype=args[0]) | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need an explicit error handling here in case of |
||
| elif len(args) == 1: | ||
| target_obj = args[0] | ||
| if not PyObject_CheckBuffer(target_obj): | ||
| # attempt to coerce to a numpy array | ||
| target_obj = np.ascontiguousarray(target_obj) | ||
| else: | ||
| raise TypeError( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably it'd better to move |
||
| "Invalid arguments." | ||
| ) | ||
|
|
||
| if isinstance(target_obj, np.ndarray): | ||
| if target_obj.dtype.kind not in ("b", "i", "u", "f", "c"): | ||
| raise TypeError( | ||
| "Coercion of input to buffer resulted in an unsupported " | ||
| f"data type '{target_obj.dtype}'. When coercing objects, " | ||
| "`SpecializationConstant` expects the data to coerce to a " | ||
| "supported type: bool, integral, or real or complex " | ||
| "floating point. To pass arbitrary data, use a " | ||
| "`memoryview` or `bytes` object, or pass the pointer and " | ||
| "size directly." | ||
| ) | ||
|
|
||
| ret_code = PyObject_GetBuffer( | ||
| target_obj, &(self._buffer), PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS | ||
| ) | ||
| if ret_code != 0: | ||
| raise ValueError( | ||
| "Failed to get buffer view for the provided object." | ||
| ) | ||
| self._spec_const.value = <void*>self._buffer.buf | ||
| self._spec_const.size = <size_t>self._buffer.len | ||
|
|
||
| def __dealloc__(self): | ||
| PyBuffer_Release(&(self._buffer)) | ||
|
|
||
| def __repr__(self): | ||
| return f"SpecializationConstant({self._spec_const.id})" | ||
|
|
||
| def __eq__(self, other): | ||
| if not isinstance(other, SpecializationConstant): | ||
| return False | ||
| cdef SpecializationConstant _other = <SpecializationConstant>other | ||
| if ( | ||
| self._spec_const.id != _other._spec_const.id or | ||
| self._spec_const.size != _other._spec_const.size or | ||
| self._spec_const.value != _other._spec_const.value | ||
| ): | ||
| return False | ||
| return memcmp( | ||
| self._spec_const.value, | ||
| _other._spec_const.value, | ||
| self._spec_const.size | ||
| ) == 0 | ||
|
|
||
| @property | ||
| def id(self): | ||
| """Returns the specialization ID for this specialization constant.""" | ||
| return self._spec_const.id | ||
|
|
||
| @property | ||
| def size(self): | ||
| """ | ||
| Returns the size in bytes of the data for this specialization constant. | ||
| """ | ||
| return self._spec_const.size | ||
|
|
||
| cdef size_t addressof(self): | ||
| """ | ||
| Returns the address of the _spec_const for this | ||
| :class:`.SpecializationConstant` cast to ``size_t``. | ||
| """ | ||
| return <size_t>&(self._spec_const) | ||
|
|
||
|
|
||
| cpdef create_kernel_bundle_from_source(SyclQueue q, str src, str copts=""): | ||
| """ | ||
| Creates a Sycl interoperability kernel bundle from an OpenCL source | ||
|
|
@@ -299,7 +470,10 @@ cpdef create_kernel_bundle_from_source(SyclQueue q, str src, str copts=""): | |
|
|
||
|
|
||
| cpdef create_kernel_bundle_from_spirv( | ||
| SyclQueue q, const unsigned char[:] IL, str copts="" | ||
| SyclQueue q, | ||
| const unsigned char[:] IL, | ||
| str copts="", | ||
| list specializations=None, | ||
| ): | ||
| """ | ||
| Creates a Sycl interoperability kernel bundle from an SPIR-V binary. | ||
|
|
@@ -317,7 +491,9 @@ cpdef create_kernel_bundle_from_spirv( | |
| copts (str, optional) | ||
| Optional compilation flags that will be used | ||
| when compiling the kernel bundle. Default: ``""``. | ||
|
|
||
| specializations (list, optional) | ||
| A list of :class:`.SpecializationConstant` objects to be used | ||
| when creating the kernel bundle. Default: ``None``. | ||
| Returns: | ||
| kernel_bundle (:class:`.SyclKernelBundle`) | ||
| A :class:`.SyclKernelBundle` object wrapping the | ||
|
|
@@ -336,11 +512,44 @@ cpdef create_kernel_bundle_from_spirv( | |
| cdef size_t length = IL.shape[0] | ||
| cdef bytes bCOpts = copts.encode("utf8") | ||
| cdef const char *COpts = <const char*>bCOpts | ||
| KBref = DPCTLKernelBundle_CreateFromSpirv( | ||
| CRef, DRef, <const void*>dIL, length, COpts | ||
| ) | ||
| if KBref is NULL: | ||
| raise SyclKernelBundleCompilationError() | ||
| cdef size_t num_spconsts | ||
| cdef _spec_const *spconsts | ||
| cdef SpecializationConstant spconst | ||
|
|
||
| if specializations is not None: | ||
| num_spconsts = len(specializations) | ||
| spconsts = <_spec_const *>( | ||
| malloc(num_spconsts * sizeof(_spec_const)) | ||
| ) | ||
| if spconsts == NULL: | ||
| raise MemoryError( | ||
| "Failed to allocate memory for specialization constants." | ||
| ) | ||
| for i, spconst in enumerate(specializations): | ||
| if not isinstance(spconst, SpecializationConstant): | ||
| free(spconsts) | ||
| raise TypeError( | ||
| "All items in specializations must be of type " | ||
| f"`SpecializationConstant`, got {type(spconst)}" | ||
| ) | ||
| spconsts[i] = spconst._spec_const | ||
| else: | ||
| num_spconsts = 0 | ||
| spconsts = NULL | ||
| try: | ||
| KBref = DPCTLKernelBundle_CreateFromSpirv( | ||
| CRef, | ||
| DRef, | ||
| <const void*>dIL, | ||
| length, COpts, | ||
| num_spconsts, | ||
| spconsts, | ||
| ) | ||
| if KBref is NULL: | ||
| raise SyclKernelBundleCompilationError() | ||
| finally: | ||
| if spconsts != NULL: | ||
| free(spconsts) | ||
|
|
||
| return SyclKernelBundle._create(KBref) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| # Data Parallel Control (dpctl) | ||
| # | ||
| # Copyright 2020-2025 Intel Corporation | ||
|
antonwolfy marked this conversation as resolved.
|
||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ | ||
| A collection of utility functions for dpctl.program module. | ||
| """ | ||
|
|
||
| from ._utils import parse_spirv_specializations | ||
|
|
||
| __all__ = [ | ||
| "parse_spirv_specializations", | ||
| ] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We probably need to warn also about object lifetime: if the source object is deleted, the SpecializationConstant holds a dangling pointer.