Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename samples field of WeightedIntervalTraining to points and clean tstops_loss logic #727

Merged
merged 4 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,13 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Boo
weights = strategy.weights ./ sum(strategy.weights)

N = length(weights)
samples = strategy.samples
points = strategy.points

difference = (maxT - minT) / N

data = Float64[]
for (index, item) in enumerate(weights)
temp_data = rand(1, trunc(Int, samples * item)) .* difference .+ minT .+
temp_data = rand(1, trunc(Int, points * item)) .* difference .+ minT .+
((index - 1) * difference)
data = append!(data, temp_data)
end
Expand Down Expand Up @@ -450,19 +450,17 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
if !(tstops isa Nothing)
num_tstops_points = length(tstops)
tstops_loss_func = evaluate_tstops_loss(phi, f, autodiff, tstops, p, batch)
total_tstops_loss = tstops_loss_func(θ, phi) * num_tstops_points
tstops_loss = tstops_loss_func(θ, phi)
if strategy isa GridTraining
num_original_points = length(tspan[1]:(strategy.dx):tspan[2])
elseif strategy isa WeightedIntervalTraining
num_original_points = strategy.samples
elseif strategy isa StochasticTraining
elseif strategy isa Union{WeightedIntervalTraining, StochasticTraining}
num_original_points = strategy.points
else
L2_loss = L2_loss + tstops_loss_func(θ, phi)
return L2_loss
return L2_loss + tstops_loss
end

total_original_loss = L2_loss * num_original_points
total_tstops_loss = tstops_loss * num_original_points
total_points = num_original_points + num_tstops_points
L2_loss = (total_original_loss + total_tstops_loss) / total_points

Expand Down
8 changes: 4 additions & 4 deletions src/training_strategies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,19 +312,19 @@ such that the total number of sampled points is equivalent to the given samples
## Positional Arguments

* `weights`: A vector of weights that should sum to 1, representing the proportion of samples at each interval.
* `samples`: the total number of samples that we want, across the entire time span
* `points`: the total number of samples that we want, across the entire time span

## Limitations

This training strategy can only be used with ODEs (`NNODE`).
"""
struct WeightedIntervalTraining{T} <: AbstractTrainingStrategy
weights::Vector{T}
samples::Int
points::Int
end

function WeightedIntervalTraining(weights, samples)
WeightedIntervalTraining(weights, samples)
function WeightedIntervalTraining(weights, points)
WeightedIntervalTraining(weights, points)
end

function get_loss_function(loss_function, train_set, eltypeθ,
Expand Down
4 changes: 2 additions & 2 deletions test/NNODE_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,9 @@ chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, N,

opt = Optimisers.Adam(0.01)
weights = [0.7, 0.2, 0.1]
samples = 200
points = 200
alg = NeuralPDE.NNODE(chain, opt, autodiff = false,
strategy = NeuralPDE.WeightedIntervalTraining(weights, samples))
strategy = NeuralPDE.WeightedIntervalTraining(weights, points))
sol = solve(prob_oop, alg, verbose = true, maxiters = 100000, saveat = 0.01)

@test abs(mean(sol) - mean(true_sol)) < 0.2
Expand Down
10 changes: 5 additions & 5 deletions test/NNODE_tstops_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ threshold = 0.2

#bad choices for weights, samples and dx so that the algorithm will fail without the added points
weights = [0.3, 0.3, 0.4]
samples = 3
points = 3
dx = 1.0

#Grid Training without added points (difference between solutions should be high)
Expand All @@ -43,25 +43,25 @@ sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, t
@test abs(mean(sol) - mean(true_sol)) < threshold

#WeightedIntervalTraining without added points (difference between solutions should be high)
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, samples))
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, points))
sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat)

@test abs(mean(sol) - mean(true_sol)) > threshold

#WeightedIntervalTraining with added points (difference between solutions should be low)
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, samples))
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, points))
sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)

@test abs(mean(sol) - mean(true_sol)) < threshold

#StochasticTraining without added points (difference between solutions should be high)
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.StochasticTraining(samples))
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.StochasticTraining(points))
sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat)

@test abs(mean(sol) - mean(true_sol)) > threshold

#StochasticTraining with added points (difference between solutions should be low)
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.StochasticTraining(samples))
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.StochasticTraining(points))
sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)

@test abs(mean(sol) - mean(true_sol)) < threshold
Loading