Skip to content

Commit

Permalink
test: update NNODE tests - forward pass in additional loss
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Jun 29, 2024
1 parent e625c94 commit e436d1a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions test/NNODE_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ end
luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_)
end
alg1 = NNODE(luxchain, opt, strategy = GridTraining(0.01),
additional_loss = additional_loss)
Expand All @@ -203,7 +203,7 @@ end
luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_)
end
alg1 = NNODE(luxchain, opt, additional_loss = additional_loss)
sol1 = solve(prob, alg1, verbose = false, abstol = 1e-10, maxiters = 200)
Expand All @@ -215,7 +215,7 @@ end
luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_)
end
alg1 = NNODE(luxchain, opt, strategy = StochasticTraining(1000),
additional_loss = additional_loss)
Expand Down

0 comments on commit e436d1a

Please sign in to comment.