Skip to content

Commit 4845a98

Browse files
fix: properly handle rational functions in HomotopyContinuation
1 parent 8098e0a commit 4845a98

File tree

4 files changed

+79
-23
lines changed

4 files changed

+79
-23
lines changed

ext/MTKHomotopyContinuationExt.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -101,15 +101,16 @@ function MTK.HomotopyContinuationProblem(
101101
return prob
102102
end
103103

104-
function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing; kwargs...)
104+
function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing;
105+
fraction_cancel_fn = SymbolicUtils.simplify_fractions, kwargs...)
105106
if !iscomplete(sys)
106107
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
107108
end
108109
transformation = MTK.PolynomialTransformation(sys)
109110
if transformation isa MTK.NotPolynomialError
110111
return transformation
111112
end
112-
result = MTK.transform_system(sys, transformation)
113+
result = MTK.transform_system(sys, transformation; fraction_cancel_fn)
113114
if result isa MTK.NotPolynomialError
114115
return result
115116
end

src/systems/nonlinear/homotopy_continuation.jl

+44-19
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,8 @@ Transform the system `sys` with `transformation` and return a
442442
`PolynomialTransformationResult`, or a `NotPolynomialError` if the system cannot
443443
be transformed.
444444
"""
445-
function transform_system(sys::NonlinearSystem, transformation::PolynomialTransformation)
445+
function transform_system(sys::NonlinearSystem, transformation::PolynomialTransformation;
446+
fraction_cancel_fn = simplify_fractions)
446447
subrules = transformation.substitution_rules
447448
dvs = unknowns(sys)
448449
eqs = full_equations(sys)
@@ -463,7 +464,7 @@ function transform_system(sys::NonlinearSystem, transformation::PolynomialTransf
463464
return NotPolynomialError(
464465
VariablesAsPolyAndNonPoly(dvs[poly_and_nonpoly]), eqs, polydata)
465466
end
466-
num, den = handle_rational_polynomials(t, new_dvs)
467+
num, den = handle_rational_polynomials(t, new_dvs; fraction_cancel_fn)
467468
# make factors different elements, otherwise the nonzero factors artificially
468469
# inflate the error of the zero factor.
469470
if iscall(den) && operation(den) == *
@@ -492,43 +493,67 @@ $(TYPEDSIGNATURES)
492493
Given a `x`, a polynomial in variables in `wrt` which may contain rational functions,
493494
express `x` as a single rational function with polynomial `num` and denominator `den`.
494495
Return `(num, den)`.
496+
497+
Keyword arguments:
498+
- `fraction_cancel_fn`: A function which takes a fraction (`operation(expr) == /`) and returns
499+
a simplified symbolic quantity with common factors in the numerator and denominator are
500+
cancelled. Defaults to `SymbolicUtils.simplify_fractions`, but can be changed to
501+
`nothing` to improve performance on large polynomials at the cost of avoiding non-trivial
502+
cancellation.
495503
"""
496-
function handle_rational_polynomials(x, wrt)
504+
function handle_rational_polynomials(x, wrt; fraction_cancel_fn = simplify_fractions)
497505
x = unwrap(x)
498506
symbolic_type(x) == NotSymbolic() && return x, 1
499507
iscall(x) || return x, 1
500508
contains_variable(x, wrt) || return x, 1
501509
any(isequal(x), wrt) && return x, 1
502510

503-
# simplify_fractions cancels out some common factors
504-
# and expands (a / b)^c to a^c / b^c, so we only need
505-
# to handle these cases
506-
x = simplify_fractions(x)
507511
op = operation(x)
508512
args = arguments(x)
509513

510514
if op == /
511515
# numerator and denominator are trivial
512516
num, den = args
513-
# but also search for rational functions in numerator
514-
n, d = handle_rational_polynomials(num, wrt)
515-
num, den = n, den * d
516-
elseif op == +
517+
n1, d1 = handle_rational_polynomials(num, wrt; fraction_cancel_fn)
518+
n2, d2 = handle_rational_polynomials(den, wrt; fraction_cancel_fn)
519+
num, den = n1 * d2, d1 * n2
520+
elseif (op == +) || (op == -)
517521
num = 0
518522
den = 1
519-
520-
# we don't need to do common denominator
521-
# because we don't care about cases where denominator
522-
# is zero. The expression is zero when all the numerators
523-
# are zero.
523+
if op == -
524+
args[2] = -args[2]
525+
end
526+
for arg in args
527+
n, d = handle_rational_polynomials(arg, wrt; fraction_cancel_fn)
528+
num = num * d + n * den
529+
den *= d
530+
end
531+
elseif op == ^
532+
base, pow = args
533+
num, den = handle_rational_polynomials(base, wrt; fraction_cancel_fn)
534+
num ^= pow
535+
den ^= pow
536+
elseif op == *
537+
num = 1
538+
den = 1
524539
for arg in args
525-
n, d = handle_rational_polynomials(arg, wrt)
526-
num += n
540+
n, d = handle_rational_polynomials(arg, wrt; fraction_cancel_fn)
541+
num *= n
527542
den *= d
528543
end
529544
else
530-
return x, 1
545+
error("Unhandled operation in `handle_rational_polynomials`. This should never happen. Please open an issue in ModelingToolkit.jl with an MWE.")
546+
end
547+
548+
if fraction_cancel_fn !== nothing
549+
expr = fraction_cancel_fn(num / den)
550+
if iscall(expr) && operation(expr) == /
551+
num, den = arguments(expr)
552+
else
553+
num, den = expr, 1
554+
end
531555
end
556+
532557
# if the denominator isn't a polynomial in `wrt`, better to not include it
533558
# to reduce the size of the gcd polynomial
534559
if !contains_variable(den, wrt)

src/systems/nonlinear/nonlinearsystem.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,8 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
501501
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblem`")
502502
end
503503
if use_homotopy_continuation
504-
prob = safe_HomotopyContinuationProblem(sys, u0map, parammap; check_length, kwargs...)
504+
prob = safe_HomotopyContinuationProblem(
505+
sys, u0map, parammap; check_length, kwargs...)
505506
if prob isa HomotopyContinuationProblem
506507
return prob
507508
end

test/extensions/homotopy_continuation.jl

+30-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ModelingToolkit, NonlinearSolve, SymbolicIndexingInterface
2+
using SymbolicUtils
23
import ModelingToolkit as MTK
34
using LinearAlgebra
45
using Test
@@ -34,6 +35,8 @@ import HomotopyContinuation
3435
sol = solve(prob2; threading = false)
3536
@test SciMLBase.successful_retcode(sol)
3637
@test norm(sol.resid)0.0 atol=1e-10
38+
39+
@test NonlinearProblem(sys, u0; use_homotopy_continuation = false) isa NonlinearProblem
3740
end
3841

3942
struct Wrapper
@@ -217,7 +220,17 @@ end
217220
@mtkbuild sys = NonlinearSystem([x^2 + y^2 - 2x - 2 ~ 0, y ~ (x - 1) / (x - 2)])
218221
prob = HomotopyContinuationProblem(sys, [])
219222
@test any(prob.denominator([2.0], parameter_values(prob)) .≈ 0.0)
220-
@test_nowarn solve(prob; threading = false)
223+
@test SciMLBase.successful_retcode(solve(prob; threading = false))
224+
end
225+
226+
@testset "Rational function forced to common denominators" begin
227+
@variables x = 1
228+
@mtkbuild sys = NonlinearSystem([0 ~ 1 / (1 + x) - x])
229+
prob = HomotopyContinuationProblem(sys, [])
230+
@test any(prob.denominator([1.0], parameter_values(prob)) .≈ 0.0)
231+
sol = solve(prob; threading = false)
232+
@test SciMLBase.successful_retcode(sol)
233+
@test 1 / (1 + sol.u[1]) - sol.u[1]0.0 atol=1e-10
221234
end
222235
end
223236

@@ -229,3 +242,19 @@ end
229242
@test sol[x] 2.0
230243
@test sol[y] sin(2.0)
231244
end
245+
246+
@testset "`fraction_cancel_fn`" begin
247+
@variables x = 1
248+
@named sys = NonlinearSystem([0 ~ ((x^2 - 5x + 6) / (x - 2) - 1) * (x^2 - 7x + 12) /
249+
(x - 4)^3])
250+
sys = complete(sys)
251+
252+
@testset "`simplify_fractions`" begin
253+
prob = HomotopyContinuationProblem(sys, [])
254+
@test prob.denominator([0.0], parameter_values(prob)) [4.0]
255+
end
256+
@testset "`nothing`" begin
257+
prob = HomotopyContinuationProblem(sys, []; fraction_cancel_fn = nothing)
258+
@test sort(prob.denominator([0.0], parameter_values(prob))) [2.0, 4.0^3]
259+
end
260+
end

0 commit comments

Comments
 (0)