Commit 19331c0

Merge pull request #1830 from IntelPython/backport-gh-1827
Backport gh-1827
2 parents 0a0e9ae + 406af46 commit 19331c0

16 files changed (+259, -83 lines)

CHANGELOG.md (+1)

@@ -40,6 +40,7 @@ The full list of changes that went into this release are:
 * Update version of 'pybind11' used [gh-1758](https://github.com/IntelPython/dpctl/pull/1758), [gh-1812](https://github.com/IntelPython/dpctl/pull/1812)
 * Handle possible exceptions by `usm_host_allocator` used with `std::vector` [gh-1791](https://github.com/IntelPython/dpctl/pull/1791)
 * Use `dpctl::tensor::offset_utils::sycl_free_noexcept` instead of `sycl::free` in `host_task` tasks associated with life-time management of temporary USM allocations [gh-1797](https://github.com/IntelPython/dpctl/pull/1797)
+* Add `"same_kind"`-style casting for in-place mathematical operators of `tensor.usm_ndarray` [gh-1827](https://github.com/IntelPython/dpctl/pull/1827), [gh-1830](https://github.com/IntelPython/dpctl/pull/1830)

 ### Fixed
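For context, a minimal sketch of what this entry enables (not part of the commit; dtypes and values are illustrative, and a device with the fp64 aspect is assumed): an in-place operator now accepts a right-hand operand whose dtype casts to the left-hand dtype under `"same_kind"` rules.

```python
import dpctl.tensor as dpt

x = dpt.ones(4, dtype="f4")        # float32 left-hand side
y = dpt.full(4, 0.5, dtype="f8")   # float64 right-hand side
x += y                             # f8 operand is cast to f4 under "same_kind"
print(dpt.asnumpy(x))              # expected: [1.5 1.5 1.5 1.5], dtype float32
```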

dpctl/tensor/_elementwise_common.py (+141, -3)

@@ -37,6 +37,7 @@
     _all_data_types,
     _find_buf_dtype,
     _find_buf_dtype2,
+    _find_buf_dtype_in_place_op,
     _resolve_weak_types,
     _to_device_supported_dtype,
 )
@@ -213,8 +214,8 @@ def __call__(self, x, /, *, out=None, order="K"):

             if res_dt != out.dtype:
                 raise ValueError(
-                    f"Output array of type {res_dt} is needed,"
-                    f" got {out.dtype}"
+                    f"Output array of type {res_dt} is needed, "
+                    f"got {out.dtype}"
                 )

             if (
@@ -650,7 +651,7 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):

             if res_dt != out.dtype:
                 raise ValueError(
-                    f"Output array of type {res_dt} is needed,"
+                    f"Output array of type {res_dt} is needed, "
                     f"got {out.dtype}"
                 )

@@ -927,3 +928,140 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
             )
             _manager.add_event_pair(ht_, bf_ev)
         return out
+
+    def _inplace_op(self, o1, o2):
+        if self.binary_inplace_fn_ is None:
+            raise ValueError(
+                "binary function does not have a dedicated in-place "
+                "implementation"
+            )
+        if not isinstance(o1, dpt.usm_ndarray):
+            raise TypeError(
+                "Expected first argument to be "
+                f"dpctl.tensor.usm_ndarray, got {type(o1)}"
+            )
+        if not o1.flags.writable:
+            raise ValueError("provided left-hand side array is read-only")
+        q1, o1_usm_type = o1.sycl_queue, o1.usm_type
+        q2, o2_usm_type = _get_queue_usm_type(o2)
+        if q2 is None:
+            exec_q = q1
+            res_usm_type = o1_usm_type
+        else:
+            exec_q = dpctl.utils.get_execution_queue((q1, q2))
+            if exec_q is None:
+                raise ExecutionPlacementError(
+                    "Execution placement can not be unambiguously inferred "
+                    "from input arguments."
+                )
+            res_usm_type = dpctl.utils.get_coerced_usm_type(
+                (
+                    o1_usm_type,
+                    o2_usm_type,
+                )
+            )
+        dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
+        o1_shape = o1.shape
+        o2_shape = _get_shape(o2)
+        if not isinstance(o2_shape, (tuple, list)):
+            raise TypeError(
+                "Shape of second argument can not be inferred. "
+                "Expected list or tuple."
+            )
+        try:
+            res_shape = _broadcast_shape_impl(
+                [
+                    o1_shape,
+                    o2_shape,
+                ]
+            )
+        except ValueError:
+            raise ValueError(
+                "operands could not be broadcast together with shapes "
+                f"{o1_shape} and {o2_shape}"
+            )
+
+        if res_shape != o1_shape:
+            raise ValueError(
+                "The shape of the non-broadcastable left-hand "
+                f"side {o1_shape} is inconsistent with the "
+                f"broadcast shape {res_shape}."
+            )
+
+        sycl_dev = exec_q.sycl_device
+        o1_dtype = o1.dtype
+        o2_dtype = _get_dtype(o2, sycl_dev)
+        if not _validate_dtype(o2_dtype):
+            raise ValueError("Operand has an unsupported data type")
+
+        o1_dtype, o2_dtype = self.weak_type_resolver_(
+            o1_dtype, o2_dtype, sycl_dev
+        )
+
+        buf_dt, res_dt = _find_buf_dtype_in_place_op(
+            o1_dtype,
+            o2_dtype,
+            self.result_type_resolver_fn_,
+            sycl_dev,
+        )
+
+        if res_dt is None:
+            raise ValueError(
+                f"function '{self.name_}' does not support input types "
+                f"({o1_dtype}, {o2_dtype}), "
+                "and the inputs could not be safely coerced to any "
+                "supported types according to the casting rule "
+                "''same_kind''."
+            )
+
+        if res_dt != o1_dtype:
+            raise ValueError(
+                f"Output array of type {res_dt} is needed, " f"got {o1_dtype}"
+            )
+
+        _manager = SequentialOrderManager[exec_q]
+        if isinstance(o2, dpt.usm_ndarray):
+            src2 = o2
+            if (
+                ti._array_overlap(o2, o1)
+                and not ti._same_logical_tensors(o2, o1)
+                and buf_dt is None
+            ):
+                buf_dt = o2_dtype
+        else:
+            src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)
+        if buf_dt is None:
+            if src2.shape != res_shape:
+                src2 = dpt.broadcast_to(src2, res_shape)
+            dep_evs = _manager.submitted_events
+            ht_, comp_ev = self.binary_inplace_fn_(
+                lhs=o1,
+                rhs=src2,
+                sycl_queue=exec_q,
+                depends=dep_evs,
+            )
+            _manager.add_event_pair(ht_, comp_ev)
+        else:
+            buf = dpt.empty_like(src2, dtype=buf_dt)
+            dep_evs = _manager.submitted_events
+            (
+                ht_copy_ev,
+                copy_ev,
+            ) = ti._copy_usm_ndarray_into_usm_ndarray(
+                src=src2,
+                dst=buf,
+                sycl_queue=exec_q,
+                depends=dep_evs,
+            )
+            _manager.add_event_pair(ht_copy_ev, copy_ev)
+
+            buf = dpt.broadcast_to(buf, res_shape)
+            ht_, bf_ev = self.binary_inplace_fn_(
+                lhs=o1,
+                rhs=buf,
+                sycl_queue=exec_q,
+                depends=[copy_ev],
+            )
+            _manager.add_event_pair(ht_, bf_ev)
+
+        return o1
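A rough usage illustration of the two paths `_inplace_op` takes above (an assumption-laden sketch, not part of the diff): when the operand pair is directly supported by the in-place kernel no temporary is made; otherwise the right-hand operand is copied into a buffer of the left-hand dtype, provided it casts under `"same_kind"`.

```python
import dpctl.tensor as dpt

lhs = dpt.zeros(8, dtype="i4")

# buf_dt is None: a natively supported pair goes straight to binary_inplace_fn_
lhs += dpt.ones(8, dtype="i4")

# If the (i4, i8) pair is not natively supported by the kernel, buf_dt is set:
# the int64 operand is first copied into an int32 buffer (int64 -> int32 is a
# "same_kind" cast), and the in-place kernel then runs against that buffer.
lhs += dpt.ones(8, dtype="i8")

print(dpt.asnumpy(lhs))  # expected: eight 2s, dtype int32
```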

dpctl/tensor/_type_utils.py (+16)

@@ -277,6 +277,21 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
     return None, None, None


+def _find_buf_dtype_in_place_op(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
+    res_dt = query_fn(arg1_dtype, arg2_dtype)
+    if res_dt:
+        return None, res_dt
+
+    _fp16 = sycl_dev.has_aspect_fp16
+    _fp64 = sycl_dev.has_aspect_fp64
+    if _can_cast(arg2_dtype, arg1_dtype, _fp16, _fp64, casting="same_kind"):
+        res_dt = query_fn(arg1_dtype, arg1_dtype)
+        if res_dt:
+            return arg1_dtype, res_dt
+
+    return None, None
+
+
 class WeakBooleanType:
     "Python type representing type of Python boolean objects"

@@ -959,4 +974,5 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
     "WeakComplexType",
     "_default_accumulation_dtype",
     "_default_accumulation_dtype_fp_types",
+    "_find_buf_dtype_in_place_op",
 ]
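To make the `(buf_dt, res_dt)` return convention of `_find_buf_dtype_in_place_op` concrete, here is a standalone sketch (an approximation, not dpctl code: `numpy.can_cast` stands in for the internal `_can_cast`, and a toy `query_fn` that only accepts equal dtypes stands in for the kernel's type resolver). `buf_dt is None` means no temporary is needed, a non-`None` `buf_dt` means the right-hand operand should first be cast into that dtype, and `(None, None)` means the operation is unsupported.

```python
import numpy as np


def sketch_find_buf_dtype_in_place_op(lhs_dt, rhs_dt, query_fn):
    # mirrors the helper above: try the pair as-is, then fall back to
    # casting the RHS to the LHS dtype under "same_kind" rules
    res_dt = query_fn(lhs_dt, rhs_dt)
    if res_dt:
        return None, res_dt
    if np.can_cast(rhs_dt, lhs_dt, casting="same_kind"):
        res_dt = query_fn(lhs_dt, lhs_dt)
        if res_dt:
            return lhs_dt, res_dt
    return None, None


def query_fn(dt1, dt2):
    # toy stand-in for the kernel's type resolver: only equal-dtype
    # pairs are "supported", anything else returns None
    return dt1 if dt1 == dt2 else None


print(sketch_find_buf_dtype_in_place_op(np.dtype("f4"), np.dtype("f8"), query_fn))
# (dtype('float32'), dtype('float32')): cast RHS to f4, compute in f4
print(sketch_find_buf_dtype_in_place_op(np.dtype("i4"), np.dtype("f4"), query_fn))
# (None, None): float32 does not cast to int32 under "same_kind"
```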

dpctl/tensor/_usmarray.pyx (+13, -13)

@@ -1508,43 +1508,43 @@ cdef class usm_ndarray:
         return dpctl.tensor.bitwise_xor(other, self)

     def __iadd__(self, other):
-        return dpctl.tensor.add(self, other, out=self)
+        return dpctl.tensor.add._inplace_op(self, other)

     def __iand__(self, other):
-        return dpctl.tensor.bitwise_and(self, other, out=self)
+        return dpctl.tensor.bitwise_and._inplace_op(self, other)

     def __ifloordiv__(self, other):
-        return dpctl.tensor.floor_divide(self, other, out=self)
+        return dpctl.tensor.floor_divide._inplace_op(self, other)

     def __ilshift__(self, other):
-        return dpctl.tensor.bitwise_left_shift(self, other, out=self)
+        return dpctl.tensor.bitwise_left_shift._inplace_op(self, other)

     def __imatmul__(self, other):
-        return dpctl.tensor.matmul(self, other, out=self)
+        return dpctl.tensor.matmul(self, other, out=self, dtype=self.dtype)

     def __imod__(self, other):
-        return dpctl.tensor.remainder(self, other, out=self)
+        return dpctl.tensor.remainder._inplace_op(self, other)

     def __imul__(self, other):
-        return dpctl.tensor.multiply(self, other, out=self)
+        return dpctl.tensor.multiply._inplace_op(self, other)

     def __ior__(self, other):
-        return dpctl.tensor.bitwise_or(self, other, out=self)
+        return dpctl.tensor.bitwise_or._inplace_op(self, other)

     def __ipow__(self, other):
-        return dpctl.tensor.pow(self, other, out=self)
+        return dpctl.tensor.pow._inplace_op(self, other)

     def __irshift__(self, other):
-        return dpctl.tensor.bitwise_right_shift(self, other, out=self)
+        return dpctl.tensor.bitwise_right_shift._inplace_op(self, other)

     def __isub__(self, other):
-        return dpctl.tensor.subtract(self, other, out=self)
+        return dpctl.tensor.subtract._inplace_op(self, other)

     def __itruediv__(self, other):
-        return dpctl.tensor.divide(self, other, out=self)
+        return dpctl.tensor.divide._inplace_op(self, other)

     def __ixor__(self, other):
-        return dpctl.tensor.bitwise_xor(self, other, out=self)
+        return dpctl.tensor.bitwise_xor._inplace_op(self, other)

     def __str__(self):
         return usm_ndarray_str(self)
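The net effect of rerouting the dunder methods above, shown as a hedged illustration (dtypes are assumptions; the exact error depends on the device and function): the operator form now tolerates `"same_kind"` casting of the right-hand operand, while an explicit `out=` call keeps the stricter behavior of `__call__`.

```python
import dpctl.tensor as dpt

a = dpt.ones(4, dtype="i4")
b = dpt.ones(4, dtype="i8")

a += b                   # now dispatches to dpt.add._inplace_op(a, b):
                         # the i8 operand is cast to i4 under "same_kind"

# dpt.add(a, b, out=a)   # by contrast, this is expected to raise ValueError,
                         # since the promoted result type (i8) differs from
                         # the dtype of the out= array (i4)
```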

dpctl/tests/elementwise/test_add.py (+67, -8)

@@ -358,7 +358,9 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
     dev = q.sycl_device
     _fp16 = dev.has_aspect_fp16
     _fp64 = dev.has_aspect_fp64
-    if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
+    # operators use a different Python implementation which permits
+    # same kind style casting
+    if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
         ar1 += ar2
         assert (
             dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype=ar1.dtype)
@@ -373,9 +375,28 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
     else:
         with pytest.raises(ValueError):
             ar1 += ar2
+
+    # here, test the special case where out is the first argument
+    # so an in-place kernel is used for efficiency
+    # this covers a specific branch in the BinaryElementwiseFunc logic
+    ar1 = dpt.ones(sz, dtype=op1_dtype)
+    ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
+    if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
+        dpt.add(ar1, ar2, out=ar1)
+        assert (
+            dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype=ar1.dtype)
+        ).all()
+
+        ar3 = dpt.ones(sz, dtype=op1_dtype)[::-1]
+        ar4 = dpt.ones(2 * sz, dtype=op2_dtype)[::2]
+        dpt.add(ar3, ar4, out=ar3)
+        assert (
+            dpt.asnumpy(ar3) == np.full(ar3.shape, 2, dtype=ar3.dtype)
+        ).all()
+    else:
+        with pytest.raises(ValueError):
             dpt.add(ar1, ar2, out=ar1)

-    # out is second arg
     ar1 = dpt.ones(sz, dtype=op1_dtype)
     ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
     if _can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64):
@@ -401,7 +422,7 @@ def test_add_inplace_broadcasting():
     m = dpt.ones((100, 5), dtype="i4")
     v = dpt.arange(5, dtype="i4")

-    m += v
+    dpt.add(m, v, out=m)
     assert (dpt.asnumpy(m) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()

     # check case where second arg is out
@@ -411,6 +432,26 @@ def test_add_inplace_broadcasting():
     ).all()


+def test_add_inplace_operator_broadcasting():
+    get_queue_or_skip()
+
+    m = dpt.ones((100, 5), dtype="i4")
+    v = dpt.arange(5, dtype="i4")
+
+    m += v
+    assert (dpt.asnumpy(m) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()
+
+
+def test_add_inplace_operator_mutual_broadcast():
+    get_queue_or_skip()
+
+    x1 = dpt.ones((1, 10), dtype="i4")
+    x2 = dpt.ones((10, 1), dtype="i4")
+
+    with pytest.raises(ValueError):
+        dpt.add._inplace_op(x1, x2)
+
+
 def test_add_inplace_errors():
     get_queue_or_skip()
     try:
@@ -425,27 +466,45 @@ def test_add_inplace_errors():
     ar1 = dpt.ones(2, dtype="float32", sycl_queue=gpu_queue)
     ar2 = dpt.ones_like(ar1, sycl_queue=cpu_queue)
     with pytest.raises(ExecutionPlacementError):
-        ar1 += ar2
+        dpt.add(ar1, ar2, out=ar1)

     ar1 = dpt.ones(2, dtype="float32")
     ar2 = dpt.ones(3, dtype="float32")
     with pytest.raises(ValueError):
-        ar1 += ar2
+        dpt.add(ar1, ar2, out=ar1)

     ar1 = np.ones(2, dtype="float32")
     ar2 = dpt.ones(2, dtype="float32")
     with pytest.raises(TypeError):
-        ar1 += ar2
+        dpt.add(ar1, ar2, out=ar1)

     ar1 = dpt.ones(2, dtype="float32")
     ar2 = dict()
     with pytest.raises(ValueError):
-        ar1 += ar2
+        dpt.add(ar1, ar2, out=ar1)

     ar1 = dpt.ones((2, 1), dtype="float32")
     ar2 = dpt.ones((1, 2), dtype="float32")
     with pytest.raises(ValueError):
-        ar1 += ar2
+        dpt.add(ar1, ar2, out=ar1)
+
+
+def test_add_inplace_operator_errors():
+    q1 = get_queue_or_skip()
+    q2 = get_queue_or_skip()
+
+    x = dpt.ones(10, dtype="i4", sycl_queue=q1)
+    with pytest.raises(TypeError):
+        dpt.add._inplace_op(dict(), x)
+
+    x.flags["W"] = False
+    with pytest.raises(ValueError):
+        dpt.add._inplace_op(x, 2)
+
+    x_q1 = dpt.ones(10, dtype="i4", sycl_queue=q1)
+    x_q2 = dpt.ones(10, dtype="i4", sycl_queue=q2)
+    with pytest.raises(ExecutionPlacementError):
+        dpt.add._inplace_op(x_q1, x_q2)


 def test_add_inplace_same_tensors():
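As a quick interactive mirror of `test_add_inplace_operator_errors` above (size and dtype are illustrative), the operator path refuses to write into a read-only array:

```python
import dpctl.tensor as dpt

x = dpt.ones(10, dtype="i4")
x.flags["W"] = False       # mark the array read-only, as in the test above

try:
    x += 2                 # routed through dpt.add._inplace_op(x, 2)
except ValueError as exc:
    print("raised as expected:", exc)
```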
