Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added MultiLayer QG Model. #26

Merged
merged 5 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion jaxsw/_src/domain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,22 @@ def ndim(self) -> int:
def size(self) -> tp.Tuple[int]:
return tuple(map(len, self.coords))

@property
def Nx(self) -> tp.Tuple[int]:
return self.size

@property
def Lx(self) -> tp.Tuple[int]:
f = lambda xmin, xmax: xmax - xmin
return tuple(map(f, self.xmin, self.xmax))

@property
def cell_volume(self) -> float:
return reduce(mul, self.dx)


def make_coords(xmin, xmax, delta):
return jnp.arange(xmin, xmax, delta)
return jnp.arange(xmin, xmax + delta, delta)


def make_grid_from_coords(coords: tp.Iterable) -> Float[Array, " D"]:
Expand Down
90 changes: 90 additions & 0 deletions jaxsw/_src/domain/qg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import typing as tp
import equinox as eqx
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array


class LayerDomain(eqx.Module):
heights: Array = eqx.static_field()
reduced_gravities: Array = eqx.static_field()
Nz: float = eqx.static_field()
A: Array = eqx.static_field()
A_layer_2_mode: Array = eqx.static_field()
A_mode_2_layer: Array = eqx.static_field()
lambd: Array = eqx.static_field()

def __init__(self, heights: tp.List[float], reduced_gravities: tp.List[float]):
num_layers = len(heights)

msg = "Incorrect number of heights to reduced gravities."
msg += f"\nHeights: {heights} | {num_layers}"
msg += f"\nReduced Gravities: {reduced_gravities} | {len(reduced_gravities)}"
assert num_layers - 1 == len(reduced_gravities), msg

self.heights = jnp.asarray(heights)
self.reduced_gravities = jnp.asarray(reduced_gravities)
self.Nz = num_layers

# calculate matrix M
A = create_qg_multilayer_mat(heights, reduced_gravities)
self.A = jnp.asarray(A)

# create layer to mode matrices
lambd, A_layer_2_mode, A_mode_2_layer = compute_layer_to_mode_matrices(A)
self.lambd = jnp.asarray(lambd)
self.A_layer_2_mode = jnp.asarray(A_layer_2_mode)
self.A_mode_2_layer = jnp.asarray(A_mode_2_layer)


def create_qg_multilayer_mat(
heights: tp.List[float], reduced_gravities: tp.List[float]
) -> np.ndarray:
"""Computes the Matrix that is used to connected a stacked
isopycnal Quasi-Geostrophic model.

Args:
heights (tp.List[float]): the height for each layer
Size = [Nx]
reduced_gravities (tp.List[float]): the reduced gravities
for each layer, Size = [Nx-1]

Returns:
np.ndarray: The Matrix connecting the layers, Size = [Nz, Nx]
"""
num_heights = len(heights)

# initialize matrix
A = np.zeros((num_heights, num_heights))

# top rows
A[0, 0] = 1.0 / (heights[0] * reduced_gravities[0])
A[0, 1] = -1.0 / (heights[0] * reduced_gravities[0])

# interior rows
for i in range(1, num_heights - 1):
A[i, i - 1] = -1.0 / (heights[i] * reduced_gravities[i - 1])
A[i, i] = (
1.0 / heights[i] * (1 / reduced_gravities[i] + 1 / reduced_gravities[i - 1])
)
A[i, i + 1] = -1.0 / (heights[i] * reduced_gravities[num_heights - 2])

# bottom rows
A[-1, -1] = 1.0 / (heights[num_heights - 1] * reduced_gravities[num_heights - 2])
A[-1, -2] = -1.0 / (heights[num_heights - 1] * reduced_gravities[num_heights - 2])
return A


def compute_layer_to_mode_matrices(A):
# eigenvalue decomposition
lambd_r, R = jnp.linalg.eig(A)
_, L = jnp.linalg.eig(A.T)

# extract real components
lambd, R, L = lambd_r.real, R.real, L.real

# create matrices
Cl2m = np.diag(1.0 / np.diag(L.T @ R)) @ L.T
Cm2l = R
# create diagonal matrix
return lambd, -Cl2m, -Cm2l
8 changes: 4 additions & 4 deletions jaxsw/_src/domain/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def test_1d_domain():
domain = Domain(xmin=demo.xmin, xmax=demo.xmax, dx=demo.dx)

assert domain.ndim == 1
assert domain.size == (20,)
assert domain.grid.shape == (20, 1)
assert domain.Nx == (21,)
assert domain.grid.shape == (21, 1)
assert domain.cell_volume == 0.1


Expand All @@ -24,6 +24,6 @@ def test_2d_domain():
domain = Domain(xmin=demo.xmin, xmax=demo.xmax, dx=demo.dx)

assert domain.ndim == 2
assert domain.size == (20, 20)
assert domain.grid.shape == (20, 20, 2)
assert domain.Nx == (21, 21)
assert domain.grid.shape == (21, 21, 2)
assert domain.cell_volume == 0.05
Empty file added jaxsw/_src/forcing/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions jaxsw/_src/forcing/wind.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np
from jaxsw._src.operators.functional import grid as F_grid
import finitediffx as fdx


def init_tau(domain, tau0: float = 2.0e-5):
"""
Args
----
tau0 (float): wind stress magnitude m/s^2
default=2.0e-5"""
# initial TAU
tau = np.zeros((2, domain.Nx[0], domain.Nx[1]))

# create staggered coordinates (y-direction)
y_coords = np.arange(domain.Nx[1]) + 0.5

# create tau
tau[0, :, :] = -tau0 * np.cos(2 * np.pi * (y_coords / domain.Nx[1]))

return tau


def calculate_wind_forcing(tau, domain):
# move from edges to nodes
tau_x = F_grid.x_average_2D(tau[0], padding=((1, 0), (0, 0)))
tau_y = F_grid.y_average_2D(tau[1], padding=((0, 0), (1, 0)))

# compute finite difference
dF2dX = fdx.difference(
tau_y, axis=0, step_size=domain.dx[0], accuracy=1, method="central"
)
dF1dY = fdx.difference(
tau_x, axis=1, step_size=domain.dx[1], accuracy=1, method="central"
)
curl_stagg = dF2dX - dF1dY

return F_grid.center_average_2D(curl_stagg.squeeze()[1:, 1:])
Loading
Loading