|
def map_concatenate(self, expr: Concatenate) -> IndexLambda:
    """Lower a concatenation node to an :class:`IndexLambda`.

    The generated scalar expression is a chain of nested ``If`` nodes that
    compares the index variable of the concatenation axis (``_<axis>``)
    against the running end offset of each input array. The branch that
    matches reads from the corresponding input (bound as ``_in<k>``) with the
    concatenation-axis index shifted back by that input's start offset.

    :arg expr: the concatenation node; this method reads its ``arrays``,
        ``axis``, ``shape``, ``dtype``, ``axes`` and ``tags`` attributes.
    :returns: an :class:`IndexLambda` whose bindings map ``_in0``, ``_in1``,
        ... to ``self.rec`` applied to each entry of ``expr.arrays``.
    """
    from pymbolic.primitives import Comparison, If, Subscript

    concat_axis = expr.axis
    num_arrays = len(expr.arrays)

    def get_subscript(array_index: int, offset: ScalarExpression) -> Subscript:
        # Index into input *array_index*: every axis uses its plain index
        # variable, except the concatenation axis, which is shifted by
        # *offset*.
        indices = []
        for iaxis in range(len(expr.shape)):
            axis_index = prim.Variable(f"_{iaxis}")
            if iaxis == concat_axis:
                axis_index = axis_index - offset
            indices.append(axis_index)
        return Subscript(prim.Variable(f"_in{array_index}"), tuple(indices))

    # Input k occupies the half-open range [starts[k], stops[k]) along the
    # concatenation axis of the result.
    starts: List[Any] = [0]
    stops: List[Any] = [expr.arrays[0].shape[concat_axis]]
    for array in expr.arrays[1:]:
        previous_stop = stops[-1]
        stops.append(previous_stop + array.shape[concat_axis])
        starts.append(previous_stop)

    # With I the concatenation axis and N the number of inputs, build
    #
    #   If(_I < stops[0],
    #      _in0[..., _I - starts[0], ...],
    #      If(_I < stops[1],
    #         _in1[..., _I - starts[1], ...],
    #         ...
    #            _in<N-1>[..., _I - starts[N-1], ...] ...))
    #
    # from the inside out: the last input is the unconditional fallback and
    # every earlier input wraps the expression built so far.
    last = num_arrays - 1
    concat_expr = get_subscript(last, starts[last])
    for iarray in reversed(range(last)):
        in_this_array = Comparison(prim.Variable(f"_{concat_axis}"),
                                   "<", stops[iarray])
        concat_expr = If(in_this_array,
                         get_subscript(iarray, starts[iarray]),
                         concat_expr)

    bindings = {}
    for iarray, array in enumerate(expr.arrays):
        bindings[f"_in{iarray}"] = self.rec(array)

    return IndexLambda(expr=concat_expr,
                       shape=self._rec_shape(expr.shape),
                       dtype=expr.dtype,
                       bindings=immutabledict(bindings),
                       axes=expr.axes,
                       var_to_reduction_descr=immutabledict(),
                       tags=expr.tags)
pytato/pytato/transform/lower_to_index_lambda.py
Lines 140 to 186 in 5aa8aa3
cc @majosm