diff --git a/sympde/expr/expr.py b/sympde/expr/expr.py index 46202054..0eeac3dd 100644 --- a/sympde/expr/expr.py +++ b/sympde/expr/expr.py @@ -127,7 +127,7 @@ def __new__(cls, arguments, expr, **options): args = _sanitize_arguments(arguments, is_linear=True) - if not is_linear_expression(expr, args, integral=False): + if not is_linear_expression(expr, args): msg = '> Expression is not linear' raise UnconsistentLinearExpressionError(msg) @@ -739,71 +739,86 @@ def linearize(form, fields, trials=None): return BilinearForm((trials, tests), bilinear_expr) #============================================================================== -def is_linear_expression(expr, args, integral=True, debug=True): - """checks if an expression is linear with respect to the given arguments.""" - # ... - left_args = [] - right_args = [] +def is_linear_expression(expr, args, debug=True): + """ + Checks if expression is linear with respect to ``args``: + + 1. Additivity: f(x + y) = f(x) + f(y) + 2. Homogeneity: f(alpha * x) = alpha * f(x) + + Parameters + ---------- + expr + Symbolic expression to test. + args : iterable + Each argument must be ScalarFunction or VectorFunction. + debug : bool, optional + Print diagnostic info if a check fails. + + Returns + ------- + bool + True if the expression is linear with respect to all given arguments, False otherwise. + """ + + x_args = [] + y_args = [] + # create 2 independent copies (x and y) of every original argument for arg in args: - tag = random_string( 4 ) + tag = random_string(4) if isinstance(arg, ScalarFunction): - left = ScalarFunction(arg.space, name='l_' + tag) - right = ScalarFunction(arg.space, name='r_' + tag) + x = ScalarFunction(arg.space, name='x_' + tag) + y = ScalarFunction(arg.space, name='y_' + tag) elif isinstance(arg, VectorFunction): - left = VectorFunction(arg.space, name='l_' + tag) - right = VectorFunction(arg.space, name='r_' + tag) + x = VectorFunction(arg.space, name='x_' + tag) + y = VectorFunction(arg.space, name='y_' + tag) else: raise TypeError('argument must be a {Scalar|Vector}Function') - left_args += [left] - right_args += [right] - # ... - - # ... check addition - newargs = [left + right for left, right in zip(left_args, right_args)] - - newexpr = expr.subs(zip(args, newargs)) - left_expr = expr.subs(zip(args, left_args)) - right_expr = expr.subs(zip(args, right_args)) - - a = newexpr - b = left_expr + right_expr - - if not( (a-b).expand() == 0 or a.expand() == b.expand()): - # TODO use a warning or exception? - if debug: - print('Failed to assert addition property') - print('{} != {}'.format(a.expand(), b.expand())) - return False - - # ... - - # ... check multiplication - tag = random_string( 4 ) - coeff = Constant('alpha_' + tag) - - newexpr = expr - for arg, left in zip(args, left_args): - newarg = coeff * left - newexpr = newexpr.subs(arg, newarg) - - atoms = list(newexpr.atoms(BasicOperator)) + x_args.append(x) + y_args.append(y) + + # --------------------------------------------------------------------------- + # check addition property: f(x + y) = f(x) + f(y) + summed_args = [x + y for x, y in zip(x_args, y_args)] + expr_at_x = expr.subs(zip(args, x_args)) # f(x) + expr_at_y = expr.subs(zip(args, y_args)) # f(y) + expr_at_sum = expr.subs(zip(args, summed_args)) # f(x + y) + expected_sum = expr_at_x + expr_at_y # = f(x) + f (y) + + if (expr_at_sum - expected_sum).expand() != 0: + expr1 = expr_at_sum.expand() + expr2 = expected_sum.expand() + if expr1 != expr2: + if debug: + print('Failed to assert addition property') + print('{} != {}'.format(expr1, expr2)) + return False + + # --------------------------------------------------------------------------- + # check multiplication property: f(alpha * x) = alpha * f(x) + alpha = Constant(f"alpha_{random_string(4)}") + + scaled_x_args = [alpha * x for x in x_args] + expr_at_scaled_x = expr.subs(zip(args, scaled_x_args)) + + atoms = list(expr_at_scaled_x.atoms(BasicOperator)) subs = [e.func(*e.args, evaluate=True) for e in atoms] - newexpr = newexpr.subs(zip(atoms, subs)) - - - left_expr = expr.subs(list(zip(args, left_args))) - left_expr = coeff * left_expr - if not( (newexpr-left_expr).expand() == 0 or newexpr.expand()==left_expr.expand()): - # TODO use a warning or exception? - if debug: - print('Failed to assert multiplication property') - print('{} != {}'.format(newexpr, left_expr)) - return False - # ... + expr_at_scaled_x = expr_at_scaled_x.subs(zip(atoms, subs)) + + scaled_expr = alpha * expr_at_x + + if (expr_at_scaled_x - scaled_expr).expand() != 0: + expr1 = expr_at_scaled_x.expand() + expr2 = scaled_expr.expand() + if expr1 != expr2: + if debug: + print('Failed to assert multiplication property') + print('{} != {}'.format(expr1, expr2)) + return False return True