Commit d5432c7

Merge branch 'SciML:master' into develop

AstitvaAggarwal authored Aug 29, 2023
2 parents 41bc26b + 5506fe4
Showing 5 changed files with 18 additions and 20 deletions.
Project.toml (2 changes: 1 addition & 1 deletion)

@@ -63,7 +63,7 @@ Lux = "0.4, 0.5"
 MCMCChains = "6"
 ModelingToolkit = "8"
 Optim = "1.0"
-Optimisers = "0.2"
+Optimisers = "0.2, 0.3"
 Optimization = "3"
 QuasiMonteCarlo = "0.2.1"
 RecursiveArrayTools = "2.31"
src/ode_solve.jl (14 changes: 6 additions & 8 deletions)

@@ -308,13 +308,13 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool,
     weights = strategy.weights ./ sum(strategy.weights)

     N = length(weights)
-    samples = strategy.samples
+    points = strategy.points

     difference = (maxT - minT) / N

     data = Float64[]
     for (index, item) in enumerate(weights)
-        temp_data = rand(1, trunc(Int, samples * item)) .* difference .+ minT .+
+        temp_data = rand(1, trunc(Int, points * item)) .* difference .+ minT .+
                     ((index - 1) * difference)
         data = append!(data, temp_data)
     end
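For orientation: the loop above splits the time span into `N = length(weights)` equal subintervals and draws about `trunc(Int, points * weights[i])` uniform samples from the i-th one, so the normalized weights decide how the `points` budget is spread. A standalone sketch of that sampling, with made-up values for `minT`, `maxT`, `weights`, and `points` (not part of this diff):

```julia
# Standalone sketch of the weighted sampling above; all values are hypothetical.
weights = [0.7, 0.2, 0.1]            # proportions, already summing to 1
points = 200                         # total sample budget
minT, maxT = 0.0, 3.0                # example time span
N = length(weights)
difference = (maxT - minT) / N       # width of each subinterval
data = Float64[]
for (index, item) in enumerate(weights)
    # roughly points*item uniform draws from the index-th subinterval
    temp_data = rand(1, trunc(Int, points * item)) .* difference .+ minT .+
                ((index - 1) * difference)
    append!(data, temp_data)
end
length(data)  # ≈ points (truncation can drop a few)
```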
@@ -450,19 +450,17 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
     if !(tstops isa Nothing)
         num_tstops_points = length(tstops)
         tstops_loss_func = evaluate_tstops_loss(phi, f, autodiff, tstops, p, batch)
-        total_tstops_loss = tstops_loss_func(θ, phi) * num_tstops_points
+        tstops_loss = tstops_loss_func(θ, phi)
         if strategy isa GridTraining
             num_original_points = length(tspan[1]:(strategy.dx):tspan[2])
-        elseif strategy isa WeightedIntervalTraining
-            num_original_points = strategy.samples
-        elseif strategy isa StochasticTraining
+        elseif strategy isa Union{WeightedIntervalTraining, StochasticTraining}
             num_original_points = strategy.points
         else
-            L2_loss = L2_loss + tstops_loss_func(θ, phi)
-            return L2_loss
+            return L2_loss + tstops_loss
         end

         total_original_loss = L2_loss * num_original_points
+        total_tstops_loss = tstops_loss * num_tstops_points
         total_points = num_original_points + num_tstops_points
         L2_loss = (total_original_loss + total_tstops_loss) / total_points
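In effect, this hunk replaces "add the tstops loss on top" with a point-weighted average: each mean loss is scaled back up by the number of points it covers, the two are summed, and the sum is divided by the combined point count. A tiny numeric sketch of that weighting (all values hypothetical):

```julia
# Hypothetical numbers, purely to illustrate the weighting above.
L2_loss     = 0.04   # mean loss over the strategy's own training points
tstops_loss = 0.10   # mean loss over the user-supplied tstops
num_original_points = 200
num_tstops_points   = 5

total_original_loss = L2_loss * num_original_points              # 8.0
total_tstops_loss   = tstops_loss * num_tstops_points            # 0.5
total_points = num_original_points + num_tstops_points           # 205
L2_loss = (total_original_loss + total_tstops_loss) / total_points  # ≈ 0.0415
```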
src/training_strategies.jl (8 changes: 4 additions & 4 deletions)

@@ -312,19 +312,19 @@ such that the total number of sampled points is equivalent to the given samples
 ## Positional Arguments
 * `weights`: A vector of weights that should sum to 1, representing the proportion of samples at each interval.
-* `samples`: the total number of samples that we want, across the entire time span
+* `points`: the total number of samples that we want, across the entire time span
 ## Limitations
 This training strategy can only be used with ODEs (`NNODE`).
 """
 struct WeightedIntervalTraining{T} <: AbstractTrainingStrategy
     weights::Vector{T}
-    samples::Int
+    points::Int
 end

-function WeightedIntervalTraining(weights, samples)
-    WeightedIntervalTraining(weights, samples)
+function WeightedIntervalTraining(weights, points)
+    WeightedIntervalTraining(weights, points)
 end

 function get_loss_function(loss_function, train_set, eltypeθ,
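After the rename, constructing the strategy looks like the sketch below. The `weights` and `points` values mirror the test updates further down; the network and optimizer are arbitrary stand-ins in the spirit of those tests, not part of this commit:

```julia
using NeuralPDE, Lux, Optimisers

# weights: proportion of the point budget per subinterval (should sum to 1)
# points:  total number of points across the whole time span
weights = [0.7, 0.2, 0.1]
points = 200
strategy = NeuralPDE.WeightedIntervalTraining(weights, points)

# A small network and optimizer in the style of the tests below (sizes arbitrary).
chain = Lux.Chain(Lux.Dense(1, 16, tanh), Lux.Dense(16, 1))
opt = Optimisers.Adam(0.01)
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = strategy)
```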
test/NNODE_tests.jl (4 changes: 2 additions & 2 deletions)

@@ -239,9 +239,9 @@ chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, N,

 opt = Optimisers.Adam(0.01)
 weights = [0.7, 0.2, 0.1]
-samples = 200
+points = 200
 alg = NeuralPDE.NNODE(chain, opt, autodiff = false,
-                      strategy = NeuralPDE.WeightedIntervalTraining(weights, samples))
+                      strategy = NeuralPDE.WeightedIntervalTraining(weights, points))
 sol = solve(prob_oop, alg, verbose = true, maxiters = 100000, saveat = 0.01)

 @test abs(mean(sol) - mean(true_sol)) < 0.2
test/NNODE_tstops_test.jl (10 changes: 5 additions & 5 deletions)

@@ -27,7 +27,7 @@ threshold = 0.2

 #bad choices for weights, samples and dx so that the algorithm will fail without the added points
 weights = [0.3, 0.3, 0.4]
-samples = 3
+points = 3
 dx = 1.0

 #Grid Training without added points (difference between solutions should be high)
@@ -43,25 +43,25 @@ sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
 @test abs(mean(sol) - mean(true_sol)) < threshold

 #WeightedIntervalTraining without added points (difference between solutions should be high)
-alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, samples))
+alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, points))
 sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat)

 @test abs(mean(sol) - mean(true_sol)) > threshold

 #WeightedIntervalTraining with added points (difference between solutions should be low)
-alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, samples))
+alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, points))
 sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)

 @test abs(mean(sol) - mean(true_sol)) < threshold

 #StochasticTraining without added points (difference between solutions should be high)
-alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.StochasticTraining(samples))
+alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.StochasticTraining(points))
 sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat)

 @test abs(mean(sol) - mean(true_sol)) > threshold

 #StochasticTraining with added points (difference between solutions should be low)
-alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.StochasticTraining(samples))
+alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.StochasticTraining(points))
 sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)

 @test abs(mean(sol) - mean(true_sol)) < threshold
