Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 71 additions & 56 deletions sympde/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __new__(cls, arguments, expr, **options):

args = _sanitize_arguments(arguments, is_linear=True)

if not is_linear_expression(expr, args, integral=False):
if not is_linear_expression(expr, args):
msg = '> Expression is not linear'
raise UnconsistentLinearExpressionError(msg)

Expand Down Expand Up @@ -739,71 +739,86 @@ def linearize(form, fields, trials=None):
return BilinearForm((trials, tests), bilinear_expr)

#==============================================================================
def is_linear_expression(expr, args, integral=True, debug=True):
"""checks if an expression is linear with respect to the given arguments."""
# ...
left_args = []
right_args = []
def is_linear_expression(expr, args, debug=True):
"""
Checks if expression is linear with respect to ``args``:

1. Additivity: f(x + y) = f(x) + f(y)
2. Homogeneity: f(alpha * x) = alpha * f(x)

Parameters
----------
expr
Symbolic expression to test.
args : iterable
Each argument must be ScalarFunction or VectorFunction.
debug : bool, optional
Print diagnostic info if a check fails.

Returns
-------
bool
True if the expression is linear with respect to all given arguments, False otherwise.
"""

x_args = []
y_args = []

# create 2 independent copies (x and y) of every original argument
for arg in args:
tag = random_string( 4 )
tag = random_string(4)

if isinstance(arg, ScalarFunction):
left = ScalarFunction(arg.space, name='l_' + tag)
right = ScalarFunction(arg.space, name='r_' + tag)
x = ScalarFunction(arg.space, name='x_' + tag)
y = ScalarFunction(arg.space, name='y_' + tag)

elif isinstance(arg, VectorFunction):
left = VectorFunction(arg.space, name='l_' + tag)
right = VectorFunction(arg.space, name='r_' + tag)
x = VectorFunction(arg.space, name='x_' + tag)
y = VectorFunction(arg.space, name='y_' + tag)
else:
raise TypeError('argument must be a {Scalar|Vector}Function')

left_args += [left]
right_args += [right]
# ...

# ... check addition
newargs = [left + right for left, right in zip(left_args, right_args)]

newexpr = expr.subs(zip(args, newargs))
left_expr = expr.subs(zip(args, left_args))
right_expr = expr.subs(zip(args, right_args))

a = newexpr
b = left_expr + right_expr

if not( (a-b).expand() == 0 or a.expand() == b.expand()):
# TODO use a warning or exception?
if debug:
print('Failed to assert addition property')
print('{} != {}'.format(a.expand(), b.expand()))
return False

# ...

# ... check multiplication
tag = random_string( 4 )
coeff = Constant('alpha_' + tag)

newexpr = expr
for arg, left in zip(args, left_args):
newarg = coeff * left
newexpr = newexpr.subs(arg, newarg)

atoms = list(newexpr.atoms(BasicOperator))
x_args.append(x)
y_args.append(y)

# ---------------------------------------------------------------------------
# check addition property: f(x + y) = f(x) + f(y)
summed_args = [x + y for x, y in zip(x_args, y_args)]
expr_at_x = expr.subs(zip(args, x_args)) # f(x)
expr_at_y = expr.subs(zip(args, y_args)) # f(y)
expr_at_sum = expr.subs(zip(args, summed_args)) # f(x + y)
expected_sum = expr_at_x + expr_at_y # = f(x) + f (y)

if (expr_at_sum - expected_sum).expand() != 0:
expr1 = expr_at_sum.expand()
expr2 = expected_sum.expand()
if expr1 != expr2:
if debug:
print('Failed to assert addition property')
print('{} != {}'.format(expr1, expr2))
return False

# ---------------------------------------------------------------------------
# check multiplication property: f(alpha * x) = alpha * f(x)
alpha = Constant(f"alpha_{random_string(4)}")

scaled_x_args = [alpha * x for x in x_args]
expr_at_scaled_x = expr.subs(zip(args, scaled_x_args))

atoms = list(expr_at_scaled_x.atoms(BasicOperator))
subs = [e.func(*e.args, evaluate=True) for e in atoms]
newexpr = newexpr.subs(zip(atoms, subs))


left_expr = expr.subs(list(zip(args, left_args)))
left_expr = coeff * left_expr
if not( (newexpr-left_expr).expand() == 0 or newexpr.expand()==left_expr.expand()):
# TODO use a warning or exception?
if debug:
print('Failed to assert multiplication property')
print('{} != {}'.format(newexpr, left_expr))
return False
# ...
expr_at_scaled_x = expr_at_scaled_x.subs(zip(atoms, subs))

scaled_expr = alpha * expr_at_x

if (expr_at_scaled_x - scaled_expr).expand() != 0:
expr1 = expr_at_scaled_x.expand()
expr2 = scaled_expr.expand()
if expr1 != expr2:
if debug:
print('Failed to assert multiplication property')
print('{} != {}'.format(expr1, expr2))
return False

return True

Expand Down
Loading