Merge pull request #81 from ashutosh-b-b/bb/nn_stopping
[FEAT]: Add `NNStopping`
ChrisRackauckas authored Feb 9, 2024
2 parents e44771b + 66372d3 commit ed7a4bf
Showing 26 changed files with 825 additions and 126 deletions.
5 changes: 4 additions & 1 deletion docs/pages.jl
@@ -1,13 +1,16 @@
pages = [
"Home" => "index.md",
"Getting started" => "getting_started.md",
"Problems" => "problems.md",
"Solver Algorithms" => ["MLP.md",
"DeepSplitting.md",
"DeepBSDE.md"],
"DeepBSDE.md",
"NNStopping.md"],
"Tutorials" => [
"tutorials/deepsplitting.md",
"tutorials/deepbsde.md",
"tutorials/mlp.md",
"tutorials/nnstopping.md",
],
"Feynman Kac formula" => "Feynman_Kac.md",
]
3 changes: 3 additions & 0 deletions docs/src/DeepBSDE.md
@@ -1,5 +1,8 @@
# [The `DeepBSDE` algorithm](@id deepbsde)

### Problems Supported:
1. [`ParabolicPDEProblem`](@ref)

```@autodocs
Modules = [HighDimPDE]
Pages = ["DeepBSDE.jl", "DeepBSDE_Han.jl"]
8 changes: 6 additions & 2 deletions docs/src/DeepSplitting.md
@@ -1,5 +1,9 @@
# [The `DeepSplitting` algorithm](@id deepsplitting)

### Problems Supported:
1. [`PIDEProblem`](@ref)
2. [`ParabolicPDEProblem`](@ref)

```@autodocs
Modules = [HighDimPDE]
Pages = ["DeepSplitting.jl"]
@@ -62,14 +66,14 @@ In `HighDimPDE.jl` the right parameter combination $\theta$ is found by iterativ
`DeepSplitting` allows obtaining $u(t,x)$ on a single point $x \in \Omega$ with the keyword $x$.

```julia
prob = PIDEProblem(g, f, μ, σ, x, tspan)
prob = PIDEProblem(μ, σ, x, tspan, g, f)
```

### Hypercube
Yet more generally, one wants to solve Eq. (1) on a $d$-dimensional cube $[a,b]^d$. This is offered by `HighDimPDE.jl` with the keyword `x0_sample`.

```julia
prob = PIDEProblem(g, f, μ, σ, x, tspan, x0_sample = x0_sample)
prob = PIDEProblem(μ, σ, x, tspan, g, f; x0_sample = x0_sample)
```
Internally, this is handled by assigning a random variable as the initial point of the particles, i.e.
```math
4 changes: 4 additions & 0 deletions docs/src/MLP.md
@@ -1,5 +1,9 @@
# [The `MLP` algorithm](@id mlp)

### Problems Supported:
1. [`PIDEProblem`](@ref)
2. [`ParabolicPDEProblem`](@ref)

```@autodocs
Modules = [HighDimPDE]
Pages = ["MLP.jl"]
32 changes: 32 additions & 0 deletions docs/src/NNStopping.md
@@ -0,0 +1,32 @@
# [The `NNStopping` algorithm](@id nn_stopping)

### Problems Supported:
1. [`ParabolicPDEProblem`](@ref)

```@autodocs
Modules = [HighDimPDE]
Pages = ["NNStopping.jl"]
```
## The general idea 💡

Similar to DeepSplitting and DeepBSDE, NNStopping evaluates the PDE as a Stochastic Differential Equation. Consider an Obstacle PDE of the form:
```math
\max\lbrace\partial_t u(t,x) + \mu(t, x) \nabla_x u(t,x) + \frac{1}{2} \sigma^2(t, x) \Delta_x u(t,x) , g(t,x) - u(t,x)\rbrace = 0
```

Such PDEs are commonly used to price financial derivatives that can be exercised before maturity, such as American options.

Using the Feynman-Kac formula, the underlying SDE will be:

```math
dX_{t}=\mu(t, X_t)\,dt + \sigma(t, X_t)\,dW_{t}^{Q}
```

The payoff of the option would then be:

```math
\sup_{\tau}\mathbb{E}[g(\tau, X_\tau)]
```
where τ is the stopping (exercise) time. The goal is to retrieve both the optimal exercise strategy (τ) and the expected payoff.

We approximate each stopping decision with a neural network, in order to maximise the expected payoff.
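
To make the objective concrete, here is a minimal sketch in plain Julia. It is not the package's implementation: it simulates the SDE with a hand-written Euler-Maruyama step and estimates the expected payoff at the stopping time by Monte Carlo for one fixed, hand-picked stopping rule. The names `simulate_path` and `stopped_payoff`, the threshold rule, and all coefficient values are assumptions chosen for illustration.

```julia
using Random, Statistics

# Illustrative coefficients (assumptions, not taken from the package)
r, δ, β, K = 0.05, 0.1, 0.2, 100.0
μ(t, x) = (r - δ) .* x                             # drift
σ(t, x) = β .* x                                   # diagonal diffusion
g(t, x) = exp(-r * t) * max(maximum(x) - K, 0.0)   # discounted payoff

# Euler-Maruyama discretization of dX = μ dt + σ dW on N equal steps
function simulate_path(rng, x0, T, N)
    dt = T / N
    path = Vector{Vector{Float64}}(undef, N + 1)
    path[1] = copy(x0)
    for n in 1:N
        t = (n - 1) * dt
        x = path[n]
        dW = sqrt(dt) .* randn(rng, length(x0))
        path[n + 1] = x .+ μ(t, x) .* dt .+ σ(t, x) .* dW
    end
    return path
end

# Hypothetical rule: exercise the first time the best asset reaches `level`,
# otherwise exercise at maturity
function stopped_payoff(path, T, level)
    N = length(path) - 1
    dt = T / N
    for n in 1:N                      # path[n + 1] is the state at time n * dt
        if maximum(path[n + 1]) >= level
            return g(n * dt, path[n + 1])
        end
    end
    return g(T, path[end])
end

rng = MersenneTwister(0)
x0, T, N = fill(90.0, 3), 3.0, 9
payoffs = [stopped_payoff(simulate_path(rng, x0, T, N), T, 120.0) for _ in 1:2000]
estimate = mean(payoffs)              # Monte Carlo estimate for this fixed rule
```

Any fixed rule of this kind can only give a lower bound on the supremum over stopping times. The point of `NNStopping` is to learn the rule rather than fix it by hand.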
6 changes: 3 additions & 3 deletions docs/src/getting_started.md
@@ -33,7 +33,7 @@ g(x) = exp(-sum(x.^2)) # initial condition
μ(x, p, t) = 0.0 # advection coefficients
σ(x, p, t) = 0.1 # diffusion coefficients
f(x, y, v_x, v_y, ∇v_x, ∇v_y, p, t) = max(0.0, v_x) * (1 - max(0.0, v_x)) # nonlocal nonlinear part of the
prob = PIDEProblem(g, f, μ, σ, x0, tspan) # defining the problem
prob = PIDEProblem(μ, σ, x0, tspan, g, f) # defining the problem
## Definition of the algorithm
alg = MLP() # defining the algorithm. We use the Multi Level Picard algorithm
@@ -62,7 +62,7 @@ g(x) = exp( -sum(x.^2) ) # initial condition
σ(x, p, t) = 0.1 # diffusion coefficients
mc_sample = UniformSampling(fill(-5f-1, d), fill(5f-1, d))
f(x, y, v_x, v_y, ∇v_x, ∇v_y, p, t) = max(0.0, v_x) * (1 - max(0.0, v_y))
prob = PIDEProblem(g, f, μ, σ, x0, tspan) # defining x0_sample is sufficient to implement Neumann boundary conditions
prob = PIDEProblem(μ, σ, x0, tspan, g, f) # defining x0_sample is sufficient to implement Neumann boundary conditions
## Definition of the algorithm
alg = MLP(mc_sample = mc_sample)
@@ -87,7 +87,7 @@ g(x) = exp.(-sum(x.^2, dims=1)) # initial condition
σ(x, p, t) = 0.1f0 # diffusion coefficients
x0_sample = UniformSampling(fill(-5f-1, d), fill(5f-1, d))
f(x, y, v_x, v_y, ∇v_x, ∇v_y, p, t) = v_x .* (1f0 .- v_y)
prob = PIDEProblem(g, f, μ, σ, x0, tspan,
prob = PIDEProblem(μ, σ, x0, tspan, g, f;
x0_sample = x0_sample)
## Definition of the neural network to use
15 changes: 14 additions & 1 deletion docs/src/index.md
@@ -1,8 +1,9 @@
# HighDimPDE.jl


**HighDimPDE.jl** is a Julia package to **solve Highly Dimensional non-linear, non-local PDEs** of the form
**HighDimPDE.jl** is a Julia package to **solve Highly Dimensional non-linear, non-local PDEs** of the forms:

1. Partial Integro Differential Equations:
```math
\begin{aligned}
(\partial_t u)(t,x) &= \int_{\Omega} f\big(t,x,{\bf x}, u(t,x),u(t,{\bf x}), ( \nabla_x u )(t,x ),( \nabla_x u )(t,{\bf x} ) \big) \, d{\bf x} \\
@@ -12,6 +13,18 @@

where $u \colon [0,T] \times \Omega \to \R$, $\Omega \subseteq \R^d$ is subject to initial and boundary conditions, and where $d$ is large.

2. Parabolic Partial Differential Equations:
```math
\begin{aligned}
(\partial_t u)(t,x) &= f\big(t,x, u(t,x), ( \nabla_x u )(t,x )\big)
+ \big\langle \mu(t,x), ( \nabla_x u )( t,x ) \big\rangle + \tfrac{1}{2} \text{Trace} \big(\sigma(t,x) [ \sigma(t,x) ]^* ( \text{Hess}_x u)(t, x ) \big).
\end{aligned}
```

where $u \colon [0,T] \times \Omega \to \R$, $\Omega \subseteq \R^d$ is subject to initial and boundary conditions, and where $d$ is large.

!!! note
    The difference between the two problems is that in Partial Integro Differential Equations the integrand is integrated over **x**, while in Parabolic Partial Differential Equations the function `f` is just evaluated at `x`.

**HighDimPDE.jl** implements solver algorithms that break down the curse of dimensionality, including

8 changes: 8 additions & 0 deletions docs/src/problems.md
@@ -0,0 +1,8 @@
```@docs
PIDEProblem
ParabolicPDEProblem
```

!!! note
    When defining a PDE with `PIDEProblem`, note that the integrand `f` has the form `f(x, y, v_x, v_y, ∇v_x, ∇v_y)`, where `y` is the integration variable and `x` is held constant throughout the integration.
    If a PDE has no integral and the nonlinear term `f` is simply evaluated as `f(x, v_x, ∇v_x)`, we suggest using `ParabolicPDEProblem`.
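
The relationship between the two signatures can be sketched in plain Julia, with no package calls. The wrapper below mirrors what this commit does inside `src/DeepSplitting.jl`: a local nonlinearity `f(x, v_x, ∇v_x, p, t)` is lifted to the eight-argument nonlocal form by ignoring the `y` arguments. The names `f_local` and `f_nonlocal` and the sample values are illustrative only.

```julia
# Local nonlinearity, in the form described for `ParabolicPDEProblem`
f_local(x, v_x, ∇v_x, p, t) = max(0.0, v_x) * (1 - max(0.0, v_x))

# Lift to the nonlocal form described for `PIDEProblem`;
# the arguments `y`, `v_y`, and `∇v_y` are simply ignored
f_nonlocal = (x, y, v_x, v_y, ∇v_x, ∇v_y, p, t) -> f_local(x, v_x, ∇v_x, p, t)

x, y = [0.0, 0.0], [1.0, -1.0]
v_x, v_y = 0.25, 0.9
∇v_x, ∇v_y = zeros(2), zeros(2)

# Both forms agree because the integration variable `y` never enters
local_value = f_local(x, v_x, ∇v_x, nothing, 0.0)
nonlocal_value = f_nonlocal(x, y, v_x, v_y, ∇v_x, ∇v_y, nothing, 0.0)
```

If `f` genuinely depends on `y` or `v_y`, no such wrapper exists and `PIDEProblem` is the appropriate choice.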
6 changes: 3 additions & 3 deletions docs/src/tutorials/deepbsde.md
@@ -19,7 +19,7 @@ g(X) = log(0.5f0 + 0.5f0 * sum(X.^2))
f(X,u,σᵀ∇u,p,t) = -λ * sum(σᵀ∇u.^2)
μ_f(X,p,t) = zero(X) # Vector d x 1 λ
σ_f(X,p,t) = Diagonal(sqrt(2.0f0) * ones(Float32, d)) # Matrix d x d
prob = PIDEProblem(g, f, μ_f, σ_f, X0, tspan)
prob = PIDEProblem(μ_f, σ_f, X0, tspan, g, f)
hls = 10 + d # hidden layer size
opt = Optimisers.Adam(0.01) # optimizer
# sub-neural network approximating solutions at the desired point
@@ -75,7 +75,7 @@ g(X) = log(0.5f0 + 0.5f0*sum(X.^2))
f(X,u,σᵀ∇u,p,t) = -λ*sum(σᵀ∇u.^2)
μ_f(X,p,t) = zero(X) #Vector d x 1 λ
σ_f(X,p,t) = Diagonal(sqrt(2.0f0)*ones(Float32,d)) #Matrix d x d
prob = PIDEProblem(g, f, μ_f, σ_f, X0, tspan)
prob = PIDEProblem(μ_f, σ_f, X0, tspan, g, f)
```

#### Define the Solver Algorithm
@@ -135,7 +135,7 @@ f(X,u,σᵀ∇u,p,t) = r * (u - sum(X.*σᵀ∇u))
g(X) = sum(X.^2)
μ_f(X,p,t) = zero(X) #Vector d x 1
σ_f(X,p,t) = Diagonal(sigma*X) #Matrix d x d
prob = PIDEProblem(g, f, μ_f, σ_f, X0, tspan)
prob = PIDEProblem(μ_f, σ_f, X0, tspan, g, f)
```

As described in the API docs, we now need to define our `NNPDENS` algorithm
4 changes: 1 addition & 3 deletions docs/src/tutorials/deepsplitting.md
@@ -22,9 +22,7 @@ g(x) = exp.(- sum(x.^2, dims=1) ) # initial condition
σ(x, p, t) = 0.1f0 # diffusion coefficients
x0_sample = UniformSampling(fill(-5f-1, d), fill(5f-1, d))
f(x, y, v_x, v_y, ∇v_x, ∇v_y, p, t) = v_x .* (1f0 .- v_y)
prob = PIDEProblem(g, f, μ,
σ, x0, tspan,
x0_sample = x0_sample)
prob = PIDEProblem(μ, σ, x0, tspan, g, f; x0_sample = x0_sample)

## Definition of the neural network to use
using Flux # needed to define the neural network
7 changes: 3 additions & 4 deletions docs/src/tutorials/mlp.md
@@ -16,8 +16,8 @@ x0 = fill(0.,d) # initial point
g(x) = exp(- sum(x.^2) ) # initial condition
μ(x, p, t) = 0.0 # advection coefficients
σ(x, p, t) = 0.1 # diffusion coefficients
f(x, y, v_x, v_y, ∇v_x, ∇v_y, p, t) = max(0.0, v_x) * (1 - max(0.0, v_x)) # nonlocal nonlinear part of the
prob = PIDEProblem(g, f, μ, σ, x0, tspan) # defining the problem
f(x, v_x, ∇v_x, p, t) = max(0.0, v_x) * (1 - max(0.0, v_x)) # nonlocal nonlinear part of the
prob = ParabolicPDEProblem(μ, σ, x0, tspan, g, f) # defining the problem

## Definition of the algorithm
alg = MLP() # defining the algorithm. We use the Multi Level Picard algorithm
@@ -44,8 +44,7 @@ g(x) = exp( -sum(x.^2) ) # initial condition
σ(x, p, t) = 0.1 # diffusion coefficients
mc_sample = UniformSampling(fill(-5f-1, d), fill(5f-1, d))
f(x, y, v_x, v_y, ∇v_x, ∇v_y, p, t) = max(0.0, v_x) * (1 - max(0.0, v_y))
prob = PIDEProblem(g, f, μ,
σ, x0, tspan) # defining x0_sample is sufficient to implement Neumann boundary conditions
prob = PIDEProblem(μ, σ, x0, tspan, g, f) # defining x0_sample is sufficient to implement Neumann boundary conditions

## Definition of the algorithm
alg = MLP(mc_sample = mc_sample )
52 changes: 52 additions & 0 deletions docs/src/tutorials/nnstopping.md
@@ -0,0 +1,52 @@
# `NNStopping`

## Solving for optimal strategy and expected payoff of a Bermudan Max-Call option

We will calculate the optimal strategy for a Bermudan Max-Call option with the following drift, diffusion, and payoff:
```math
\mu(x) = (r - \delta) x, \quad \sigma(x) = \beta \, \text{diag}(x_1, \dots, x_d),\\
g(t, x) = e^{-rt} \max\lbrace \max\lbrace x_1, \dots, x_d \rbrace - K, 0 \rbrace
```
We define the parameters, drift function and the diffusion function for the dynamics of the option.
```julia
using HighDimPDE, Flux, StochasticDiffEq # StochasticDiffEq provides `SRIW1`
d = 3 # number of assets
r = 0.05 # interest rate
beta = 0.2 # volatility
T = 3 # maturity
u0 = fill(90.0, d) # initial stock value
delta = 0.1 # dividend yield
f(du, u, p, t) = du .= (r - delta) * u # drift
sigma(du, u, p, t) = du .= beta * u # diffusion
tspan = (0.0, T)
N = 9 # discretization parameter
dt = T / (N)
K = 100.00 # strike price

# payoff function
function g(x, t)
return exp(-r * t) * (max(maximum(x) - K, 0))
end

```
We then define a `ParabolicPDEProblem` with no nonlinear term:
```julia
prob = ParabolicPDEProblem(f, sigma, u0, tspan; payoff = g)
```
!!! note
    We provide the payoff function with the keyword argument `payoff`.

And now we define our models:
```julia
models = [Chain(Dense(d + 1, 32, tanh), BatchNorm(32, tanh), Dense(32, 1, sigmoid))
for i in 1:N]
```
!!! note
    The number of models should be equal to the number of time discretization steps.
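
One model per time step is needed because each network decides whether to stop at its own step. As a rough sketch of how such per-step outputs could be turned into an exercise time, each sigmoid output can be thresholded at 0.5 and the first step that votes to stop is taken, with a forced stop at maturity. This is an assumption based on the usual deep optimal stopping construction, not code taken from `NNStopping`; the function `first_stop` and the sample outputs are hypothetical.

```julia
# `probs[n]` stands in for the sigmoid output of the n-th model on one path
function first_stop(probs; threshold = 0.5)
    for (n, p) in enumerate(probs)
        p >= threshold && return n
    end
    return length(probs)   # never voted to stop: exercise at maturity
end

N = 9
dt = 3 / N
probs = [0.1, 0.2, 0.4, 0.7, 0.3, 0.9, 0.1, 0.1, 0.1]  # made-up outputs for illustration
n_stop = first_stop(probs)   # index of the first step whose output is >= 0.5
τ = n_stop * dt              # corresponding exercise time
```

With `N` steps there are `N` decisions to make, which is why the `models` vector above is built with `for i in 1:N`.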

And finally we define our optimizer and algorithm, and call `solve`:
```julia
opt = Flux.Optimisers.Adam(0.01)
alg = NNStopping(models, opt)

sol = solve(prob, alg, SRIW1(); dt = dt, trajectories = 1000, maxiters = 1000, verbose = true)
```
6 changes: 3 additions & 3 deletions paper/paper.md
@@ -127,7 +127,7 @@ vol = prod(x0_sample[2] - x0_sample[1])
f(y, z, v_y, v_z, p, t) = max.(v_y, 0f0) .* (m(y) .- vol * max.(v_z, 0f0) .* m(z)) # nonlocal nonlinear part of the

# defining the problem
prob = PIDEProblem(g, f, μ, σ, tspan,
prob = PIDEProblem(μ, σ, tspan, g, f,
x0_sample = x0_sample
)
# solving
@@ -162,8 +162,8 @@ g(x) = exp( -sum(x.^2) ) # initial condition
σ(x, p, t) = 0.1 # diffusion coefficients
x0_sample = [-1/2, 1/2]
f(x, y, v_x, v_y, ∇v_x, ∇v_y, t) = max(0.0, v_x) * (1 - max(0.0, v_y))
prob = PIDEProblem(g, f, μ,
σ, x0, tspan,
prob = PIDEProblem(μ,
σ, x0, tspan, g, f,
x0_sample = x0_sample) # defining x0_sample is sufficient to implement Neumann boundary conditions

## Definition of the algorithm
8 changes: 5 additions & 3 deletions src/DeepBSDE.jl
@@ -26,7 +26,7 @@ f(X,u,σᵀ∇u,p,t) = r * (u - sum(X.*σᵀ∇u))
g(X) = sum(X.^2)
μ_f(X,p,t) = zero(X) #Vector d x 1
σ_f(X,p,t) = Diagonal(sigma*X) #Matrix d x d
prob = PIDEProblem(g, f, μ_f, σ_f, x0, tspan)
prob = PIDEProblem(μ_f, σ_f, x0, tspan, g, f)
hls = 10 + d #hidden layer size
opt = Flux.Optimise.Adam(0.001)
@@ -59,7 +59,7 @@ end
DeepBSDE(u0, σᵀ∇u; opt = Flux.Optimise.Adam(0.1)) = DeepBSDE(u0, σᵀ∇u, opt)

"""
$(SIGNATURES)
$(TYPEDSIGNATURES)
Returns a `PIDESolution` object.
@@ -73,9 +73,11 @@ Returns a `PIDESolution` object.
[DifferentialEquations.jl doc](https://diffeq.sciml.ai/stable/solvers/sde_solve/).
- `limits`: if `true`, upper and lower limits will be calculated, based on
[Deep Primal-Dual algorithm for BSDEs](https://papers.ssrn.com/sol3/papers.cfm?abstract_id=3071506).
- `maxiters`: The number of training epochs. Defaults to `300`
- `trajectories`: The number of trajectories simulated for training. Defaults to `100`
- Extra keyword arguments passed to `solve` will be further passed to the SDE solver.
"""
function DiffEqBase.solve(prob::PIDEProblem,
function DiffEqBase.solve(prob::ParabolicPDEProblem,
pdealg::DeepBSDE,
sdealg;
verbose = false,
13 changes: 12 additions & 1 deletion src/DeepBSDE_Han.jl
@@ -1,5 +1,16 @@
# called whenever sdealg is not specified.
function DiffEqBase.solve(prob::PIDEProblem,
"""
$(TYPEDSIGNATURES)
Returns a `PIDESolution` object.
# Arguments:
- `maxiters`: The number of training epochs. Defaults to `300`
- `trajectories`: The number of trajectories simulated for training. Defaults to `100`
To use [SDE Algorithms](https://diffeq.sciml.ai/stable/solvers/sde_solve/) use [`DeepBSDE`](@ref)
"""
function DiffEqBase.solve(prob::ParabolicPDEProblem,
alg::DeepBSDE;
dt,
abstol = 1.0f-6,
12 changes: 9 additions & 3 deletions src/DeepSplitting.jl
@@ -51,7 +51,7 @@ function DeepSplitting(nn;
end

"""
$(SIGNATURES)
$(TYPEDSIGNATURES)
Returns a `PIDESolution` object.
@@ -64,7 +64,7 @@ Returns a `PIDESolution` object.
- `use_cuda` : set to `true` to use CUDA.
- `cuda_device` : integer, to set the CUDA device used in the training, if `use_cuda == true`.
"""
function DiffEqBase.solve(prob::PIDEProblem,
function DiffEqBase.solve(prob::Union{PIDEProblem, ParabolicPDEProblem},
alg::DeepSplitting,
dt;
batch_size = 1,
@@ -98,7 +98,13 @@ function DiffEqBase.solve(prob::PIDEProblem,
K = alg.K
opt = alg.opt
λs = alg.λs
g, f, μ, σ, p = prob.g, prob.f, prob.μ, prob.σ, prob.p
g, μ, σ, p = prob.g, prob.μ, prob.σ, prob.p

f = if isa(prob, ParabolicPDEProblem)
(y, z, v_y, v_z, ∇v_y, ∇v_z, p, t) -> prob.f(y, v_y, ∇v_y, p, t )
else
prob.f
end
T = eltype(x0)

# neural network model