@@ -40,7 +40,7 @@ def _get_indexing_mode(name):
4040 )
4141
4242
43- def take(x, indices, /, *, axis=None, mode="wrap"):
43+ def take(x, indices, /, *, axis=None, out=None, mode="wrap"):
4444 """take(x, indices, axis=None, out=None, mode="wrap")
4545
4646 Takes elements from an array along a given axis at given indices.
@@ -54,6 +54,9 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
5454 The axis along which the values will be selected.
5555 If ``x`` is one-dimensional, this argument is optional.
5656 Default: ``None``.
57+ out (Optional[usm_ndarray]):
58+ Output array to populate. Array must have the correct
59+ shape and the expected data type.
5760 mode (str, optional):
5861 How out-of-bounds indices will be handled. Possible values
5962 are:
@@ -121,18 +124,53 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
121124 raise ValueError("`axis` must be 0 for an array of dimension 0.")
122125 res_shape = indices.shape
123126
124- res = dpt.empty(
125- res_shape, dtype=x.dtype, usm_type=res_usm_type, sycl_queue=exec_q
126- )
127+ dt = x.dtype
128+
129+ orig_out = out
130+ if out is not None:
131+ if not isinstance(out, dpt.usm_ndarray):
132+ raise TypeError(
133+ f"output array must be of usm_ndarray type, got {type(out)}"
134+ )
135+ if not out.flags.writable:
136+ raise ValueError("provided `out` array is read-only")
137+
138+ if out.shape != res_shape:
139+ raise ValueError(
140+ "The shape of input and output arrays are inconsistent. "
141+ f"Expected output shape is {res_shape}, got {out.shape}"
142+ )
143+ if dt != out.dtype:
144+ raise ValueError(
145+ f"Output array of type {dt} is needed, " f"got {out.dtype}"
146+ )
147+ if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
148+ raise dpctl.utils.ExecutionPlacementError(
149+ "Input and output allocation queues are not compatible"
150+ )
151+ if ti._array_overlap(x, out):
152+ out = dpt.empty_like(out)
153+ else:
154+ out = dpt.empty(
155+ res_shape, dtype=dt, usm_type=res_usm_type, sycl_queue=exec_q
156+ )
127157
128158 _manager = dpctl.utils.SequentialOrderManager[exec_q]
129159 deps_ev = _manager.submitted_events
130160 hev, take_ev = ti._take(
131- x, (indices,), res , axis, mode, sycl_queue=exec_q, depends=deps_ev
161+ x, (indices,), out , axis, mode, sycl_queue=exec_q, depends=deps_ev
132162 )
133163 _manager.add_event_pair(hev, take_ev)
134164
135- return res
165+ if not (orig_out is None or out is orig_out):
166+ # Copy the out data from temporary buffer to original memory
167+ ht_e_cpy, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
168+ src=out, dst=orig_out, sycl_queue=exec_q, depends=[take_ev]
169+ )
170+ _manager.add_event_pair(ht_e_cpy, cpy_ev)
171+ out = orig_out
172+
173+ return out
136174
137175
138176def put(x, indices, vals, /, *, axis=None, mode="wrap"):
0 commit comments