Skip to content

Commit ae41152

Browse files
committed
task: syclqueue.copy() method
1 parent 442d61f commit ae41152

File tree

3 files changed

+150
-0
lines changed

3 files changed

+150
-0
lines changed

dpctl/_sycl_queue.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ cdef public api class SyclQueue (_SyclQueue) [
103103
cdef DPCTLSyclQueueRef get_queue_ref(self)
104104
cpdef memcpy(self, dest, src, size_t count)
105105
cpdef SyclEvent memcpy_async(self, dest, src, size_t count, list dEvents=*)
106+
cpdef copy(self, dest, src, size_t count)
107+
cpdef SyclEvent copy_async(self, dest, src, size_t count, list dEvents=*)
106108
cpdef prefetch(self, ptr, size_t count=*)
107109
cpdef mem_advise(self, ptr, size_t count, int mem)
108110
cpdef SyclEvent submit_barrier(self, dependent_events=*)

dpctl/_sycl_queue.pyx

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,6 +1426,76 @@ cdef class SyclQueue(_SyclQueue):
14261426

14271427
return SyclEvent._create(ERef)
14281428

1429+
cpdef copy(self, dest, src, size_t count):
1430+
"""Copy ``count`` bytes from ``src`` to ``dest`` and wait.
1431+
1432+
This is a synchronizing variant corresponding to
1433+
:meth:`dpctl.SyclQueue.copy_async`.
1434+
"""
1435+
cdef DPCTLSyclEventRef ERef = NULL
1436+
1437+
ERef = _memcpy_impl(<SyclQueue>self, dest, src, count, NULL, 0)
1438+
if (ERef is NULL):
1439+
raise RuntimeError(
1440+
"SyclQueue.copy operation encountered an error"
1441+
)
1442+
with nogil:
1443+
DPCTLEvent_Wait(ERef)
1444+
DPCTLEvent_Delete(ERef)
1445+
1446+
cpdef SyclEvent copy_async(
1447+
self, dest, src, size_t count, list dEvents=None
1448+
):
1449+
"""Copy ``count`` bytes from ``src`` to ``dest`` asynchronously.
1450+
1451+
Args:
1452+
dest:
1453+
Destination USM object or Python object supporting
1454+
writable buffer protocol.
1455+
src:
1456+
Source USM object or Python object supporting buffer
1457+
protocol.
1458+
count (int):
1459+
Number of bytes to copy.
1460+
dEvents (List[dpctl.SyclEvent], optional):
1461+
Events that this copy depends on.
1462+
1463+
Returns:
1464+
dpctl.SyclEvent:
1465+
Event associated with the copy operation.
1466+
"""
1467+
cdef DPCTLSyclEventRef ERef = NULL
1468+
cdef DPCTLSyclEventRef *depEvents = NULL
1469+
cdef size_t nDE = 0
1470+
1471+
if dEvents is None:
1472+
ERef = _memcpy_impl(<SyclQueue>self, dest, src, count, NULL, 0)
1473+
else:
1474+
nDE = len(dEvents)
1475+
depEvents = (
1476+
<DPCTLSyclEventRef*>malloc(nDE*sizeof(DPCTLSyclEventRef))
1477+
)
1478+
if depEvents is NULL:
1479+
raise MemoryError()
1480+
else:
1481+
for idx, de in enumerate(dEvents):
1482+
if isinstance(de, SyclEvent):
1483+
depEvents[idx] = (<SyclEvent>de).get_event_ref()
1484+
else:
1485+
free(depEvents)
1486+
raise TypeError(
1487+
"A sequence of dpctl.SyclEvent is expected"
1488+
)
1489+
ERef = _memcpy_impl(self, dest, src, count, depEvents, nDE)
1490+
free(depEvents)
1491+
1492+
if (ERef is NULL):
1493+
raise RuntimeError(
1494+
"SyclQueue.copy operation encountered an error"
1495+
)
1496+
1497+
return SyclEvent._create(ERef)
1498+
14291499
cpdef prefetch(self, mem, size_t count=0):
14301500
cdef void *ptr
14311501
cdef DPCTLSyclEventRef ERef = NULL
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Defines unit test cases for the SyclQueue.copy."""
18+
19+
import pytest
20+
21+
import dpctl
22+
import dpctl.memory
23+
24+
25+
def _create_memory(q):
26+
nbytes = 1024
27+
mobj = dpctl.memory.MemoryUSMShared(nbytes, queue=q)
28+
return mobj
29+
30+
31+
def test_copy_copy_host_to_host():
32+
try:
33+
q = dpctl.SyclQueue()
34+
except dpctl.SyclQueueCreationError:
35+
pytest.skip("Default constructor for SyclQueue failed")
36+
37+
src_buf = b"abcdefghijklmnopqrstuvwxyz"
38+
dst_buf = bytearray(len(src_buf))
39+
40+
q.copy(dst_buf, src_buf, len(src_buf))
41+
42+
assert dst_buf == src_buf
43+
44+
45+
def test_copy_async():
46+
try:
47+
q = dpctl.SyclQueue()
48+
except dpctl.SyclQueueCreationError:
49+
pytest.skip("Default constructor for SyclQueue failed")
50+
51+
src_buf = b"abcdefghijklmnopqrstuvwxyz"
52+
n = len(src_buf)
53+
dst_buf = bytearray(n)
54+
dst_buf2 = bytearray(n)
55+
56+
e = q.copy_async(dst_buf, src_buf, n)
57+
e2 = q.copy_async(dst_buf2, src_buf, n, [e])
58+
59+
e.wait()
60+
e2.wait()
61+
assert dst_buf == src_buf
62+
assert dst_buf2 == src_buf
63+
64+
65+
def test_copy_type_error():
66+
try:
67+
q = dpctl.SyclQueue()
68+
except dpctl.SyclQueueCreationError:
69+
pytest.skip("Default constructor for SyclQueue failed")
70+
mobj = _create_memory(q)
71+
72+
with pytest.raises(TypeError) as cm:
73+
q.copy(None, mobj, 3)
74+
assert "_Memory" in str(cm.value)
75+
76+
with pytest.raises(TypeError) as cm:
77+
q.copy(mobj, None, 3)
78+
assert "_Memory" in str(cm.value)

0 commit comments

Comments
 (0)