Skip to content

Commit

Permalink
Added MultiLayer QG Model. (#26)
Browse files Browse the repository at this point in the history
* Added a working implementation for QG

* Working demo for 2 Layer QG

* Added high res options.

* Added a refactored version.
Added wind forcing.

* Added book to jupyter books.
  • Loading branch information
jejjohnson committed Jul 25, 2023
1 parent 8cfbfd4 commit 589e454
Show file tree
Hide file tree
Showing 25 changed files with 13,818 additions and 588 deletions.
11 changes: 10 additions & 1 deletion jaxsw/_src/domain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,22 @@ def ndim(self) -> int:
def size(self) -> tp.Tuple[int]:
return tuple(map(len, self.coords))

@property
def Nx(self) -> tp.Tuple[int]:
return self.size

@property
def Lx(self) -> tp.Tuple[int]:
f = lambda xmin, xmax: xmax - xmin
return tuple(map(f, self.xmin, self.xmax))

@property
def cell_volume(self) -> float:
return reduce(mul, self.dx)


def make_coords(xmin, xmax, delta):
return jnp.arange(xmin, xmax, delta)
return jnp.arange(xmin, xmax + delta, delta)


def make_grid_from_coords(coords: tp.Iterable) -> Float[Array, " D"]:
Expand Down
90 changes: 90 additions & 0 deletions jaxsw/_src/domain/qg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import typing as tp
import equinox as eqx
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array


class LayerDomain(eqx.Module):
heights: Array = eqx.static_field()
reduced_gravities: Array = eqx.static_field()
Nz: float = eqx.static_field()
A: Array = eqx.static_field()
A_layer_2_mode: Array = eqx.static_field()
A_mode_2_layer: Array = eqx.static_field()
lambd: Array = eqx.static_field()

def __init__(self, heights: tp.List[float], reduced_gravities: tp.List[float]):
num_layers = len(heights)

msg = "Incorrect number of heights to reduced gravities."
msg += f"\nHeights: {heights} | {num_layers}"
msg += f"\nReduced Gravities: {reduced_gravities} | {len(reduced_gravities)}"
assert num_layers - 1 == len(reduced_gravities), msg

self.heights = jnp.asarray(heights)
self.reduced_gravities = jnp.asarray(reduced_gravities)
self.Nz = num_layers

# calculate matrix M
A = create_qg_multilayer_mat(heights, reduced_gravities)
self.A = jnp.asarray(A)

# create layer to mode matrices
lambd, A_layer_2_mode, A_mode_2_layer = compute_layer_to_mode_matrices(A)
self.lambd = jnp.asarray(lambd)
self.A_layer_2_mode = jnp.asarray(A_layer_2_mode)
self.A_mode_2_layer = jnp.asarray(A_mode_2_layer)


def create_qg_multilayer_mat(
heights: tp.List[float], reduced_gravities: tp.List[float]
) -> np.ndarray:
"""Computes the Matrix that is used to connected a stacked
isopycnal Quasi-Geostrophic model.
Args:
heights (tp.List[float]): the height for each layer
Size = [Nx]
reduced_gravities (tp.List[float]): the reduced gravities
for each layer, Size = [Nx-1]
Returns:
np.ndarray: The Matrix connecting the layers, Size = [Nz, Nx]
"""
num_heights = len(heights)

# initialize matrix
A = np.zeros((num_heights, num_heights))

# top rows
A[0, 0] = 1.0 / (heights[0] * reduced_gravities[0])
A[0, 1] = -1.0 / (heights[0] * reduced_gravities[0])

# interior rows
for i in range(1, num_heights - 1):
A[i, i - 1] = -1.0 / (heights[i] * reduced_gravities[i - 1])
A[i, i] = (
1.0 / heights[i] * (1 / reduced_gravities[i] + 1 / reduced_gravities[i - 1])
)
A[i, i + 1] = -1.0 / (heights[i] * reduced_gravities[num_heights - 2])

# bottom rows
A[-1, -1] = 1.0 / (heights[num_heights - 1] * reduced_gravities[num_heights - 2])
A[-1, -2] = -1.0 / (heights[num_heights - 1] * reduced_gravities[num_heights - 2])
return A


def compute_layer_to_mode_matrices(A):
# eigenvalue decomposition
lambd_r, R = jnp.linalg.eig(A)
_, L = jnp.linalg.eig(A.T)

# extract real components
lambd, R, L = lambd_r.real, R.real, L.real

# create matrices
Cl2m = np.diag(1.0 / np.diag(L.T @ R)) @ L.T
Cm2l = R
# create diagonal matrix
return lambd, -Cl2m, -Cm2l
8 changes: 4 additions & 4 deletions jaxsw/_src/domain/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def test_1d_domain():
domain = Domain(xmin=demo.xmin, xmax=demo.xmax, dx=demo.dx)

assert domain.ndim == 1
assert domain.size == (20,)
assert domain.grid.shape == (20, 1)
assert domain.Nx == (21,)
assert domain.grid.shape == (21, 1)
assert domain.cell_volume == 0.1


Expand All @@ -24,6 +24,6 @@ def test_2d_domain():
domain = Domain(xmin=demo.xmin, xmax=demo.xmax, dx=demo.dx)

assert domain.ndim == 2
assert domain.size == (20, 20)
assert domain.grid.shape == (20, 20, 2)
assert domain.Nx == (21, 21)
assert domain.grid.shape == (21, 21, 2)
assert domain.cell_volume == 0.05
Empty file added jaxsw/_src/forcing/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions jaxsw/_src/forcing/wind.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np
from jaxsw._src.operators.functional import grid as F_grid
import finitediffx as fdx


def init_tau(domain, tau0: float = 2.0e-5):
"""
Args
----
tau0 (float): wind stress magnitude m/s^2
default=2.0e-5"""
# initial TAU
tau = np.zeros((2, domain.Nx[0], domain.Nx[1]))

# create staggered coordinates (y-direction)
y_coords = np.arange(domain.Nx[1]) + 0.5

# create tau
tau[0, :, :] = -tau0 * np.cos(2 * np.pi * (y_coords / domain.Nx[1]))

return tau


def calculate_wind_forcing(tau, domain):
# move from edges to nodes
tau_x = F_grid.x_average_2D(tau[0], padding=((1, 0), (0, 0)))
tau_y = F_grid.y_average_2D(tau[1], padding=((0, 0), (1, 0)))

# compute finite difference
dF2dX = fdx.difference(
tau_y, axis=0, step_size=domain.dx[0], accuracy=1, method="central"
)
dF1dY = fdx.difference(
tau_x, axis=1, step_size=domain.dx[1], accuracy=1, method="central"
)
curl_stagg = dF2dX - dF1dY

return F_grid.center_average_2D(curl_stagg.squeeze()[1:, 1:])
Loading

0 comments on commit 589e454

Please sign in to comment.