Skip to content

Commit

Permalink
convert
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Sep 23, 2024
1 parent 5c34346 commit 5677ade
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/discretize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,19 +219,21 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D
span = map(b -> get(dict_var_span, b, b), bt)
_set = adapt(eltypeθ,
hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
x = convert.(eltypeθ, _set)
end

pde_vars = get_variables(eqs, dict_indvars, dict_depvars)
pde_args = get_argument(eqs, dict_indvars, dict_depvars)

pde_train_set = adapt(eltypeθ,
hcat(vec(map(points -> collect(points),
Iterators.product(bc_data...)))...))
# pde_train_set = adapt(eltypeθ,
# hcat(vec(map(points -> collect(points),
# Iterators.product(bc_data...)))...))

pde_train_sets = map(pde_args) do bt
span = map(b -> get(dict_var_span_, b, b), bt)
_set = adapt(eltypeθ,
hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
x = convert.(eltypeθ, _set)
end
[pde_train_sets, bcs_train_sets]
end
Expand Down
1 change: 1 addition & 0 deletions src/neural_adapter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ function generate_training_sets(domains, dx, eqs, eltypeθ)
spans = [infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)]
train_set = adapt(eltypeθ,
hcat(vec(map(points -> collect(points), Iterators.product(spans...)))...))
convert.(eltypeθ, train_set)
end

function get_loss_function_(loss, init_params, pde_system, strategy::GridTraining)
Expand Down

0 comments on commit 5677ade

Please sign in to comment.