Skip to content

Commit 7ceb194

Browse files
Merge pull request #3422 from AayushSabharwal/as/getproperty-adjoint
feat: mark `getproperty(::AbstractSystem, ::Symbol)` as non-differentiable
2 parents e57b2a8 + ddf0d12 commit 7ceb194

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

ext/MTKChainRulesCoreExt.jl

+2
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,6 @@ function ChainRulesCore.rrule(
103103
newbuf, pullback
104104
end
105105

106+
ChainRulesCore.@non_differentiable Base.getproperty(sys::MTK.AbstractSystem, x::Symbol)
107+
106108
end

test/extensions/ad.jl

+10
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,13 @@ fwd, back = ChainRulesCore.rrule(remake_buffer, sys, ps, idxs, vals)
124124
nsol = solve(nprob, NewtonRaphson())
125125
@test nsol[1] 10.0 / 1.0 + 9.81 * 1.0 / 2 # anal free fall solution is y = v0*t - g*t^2/2 -> v0 = y/t + g*t/2
126126
end
127+
128+
@testset "`sys.var` is non-differentiable" begin
129+
@variables x(t)
130+
@mtkbuild sys = ODESystem(D(x) ~ x, t)
131+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0))
132+
133+
grad = Zygote.gradient(prob) do prob
134+
prob[sys.x]
135+
end
136+
end

0 commit comments

Comments
 (0)