-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added a working implementation for QG * Working demo for 2 Layer QG * Added high res options. * Added a refactored version. Added wind forcing. * Added book to jupyter books.
- Loading branch information
1 parent
8cfbfd4
commit 589e454
Showing
25 changed files
with
13,818 additions
and
588 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import typing as tp | ||
import equinox as eqx | ||
import jax.numpy as jnp | ||
import numpy as np | ||
from jaxtyping import Array | ||
|
||
|
||
class LayerDomain(eqx.Module): | ||
heights: Array = eqx.static_field() | ||
reduced_gravities: Array = eqx.static_field() | ||
Nz: float = eqx.static_field() | ||
A: Array = eqx.static_field() | ||
A_layer_2_mode: Array = eqx.static_field() | ||
A_mode_2_layer: Array = eqx.static_field() | ||
lambd: Array = eqx.static_field() | ||
|
||
def __init__(self, heights: tp.List[float], reduced_gravities: tp.List[float]): | ||
num_layers = len(heights) | ||
|
||
msg = "Incorrect number of heights to reduced gravities." | ||
msg += f"\nHeights: {heights} | {num_layers}" | ||
msg += f"\nReduced Gravities: {reduced_gravities} | {len(reduced_gravities)}" | ||
assert num_layers - 1 == len(reduced_gravities), msg | ||
|
||
self.heights = jnp.asarray(heights) | ||
self.reduced_gravities = jnp.asarray(reduced_gravities) | ||
self.Nz = num_layers | ||
|
||
# calculate matrix M | ||
A = create_qg_multilayer_mat(heights, reduced_gravities) | ||
self.A = jnp.asarray(A) | ||
|
||
# create layer to mode matrices | ||
lambd, A_layer_2_mode, A_mode_2_layer = compute_layer_to_mode_matrices(A) | ||
self.lambd = jnp.asarray(lambd) | ||
self.A_layer_2_mode = jnp.asarray(A_layer_2_mode) | ||
self.A_mode_2_layer = jnp.asarray(A_mode_2_layer) | ||
|
||
|
||
def create_qg_multilayer_mat( | ||
heights: tp.List[float], reduced_gravities: tp.List[float] | ||
) -> np.ndarray: | ||
"""Computes the Matrix that is used to connected a stacked | ||
isopycnal Quasi-Geostrophic model. | ||
Args: | ||
heights (tp.List[float]): the height for each layer | ||
Size = [Nx] | ||
reduced_gravities (tp.List[float]): the reduced gravities | ||
for each layer, Size = [Nx-1] | ||
Returns: | ||
np.ndarray: The Matrix connecting the layers, Size = [Nz, Nx] | ||
""" | ||
num_heights = len(heights) | ||
|
||
# initialize matrix | ||
A = np.zeros((num_heights, num_heights)) | ||
|
||
# top rows | ||
A[0, 0] = 1.0 / (heights[0] * reduced_gravities[0]) | ||
A[0, 1] = -1.0 / (heights[0] * reduced_gravities[0]) | ||
|
||
# interior rows | ||
for i in range(1, num_heights - 1): | ||
A[i, i - 1] = -1.0 / (heights[i] * reduced_gravities[i - 1]) | ||
A[i, i] = ( | ||
1.0 / heights[i] * (1 / reduced_gravities[i] + 1 / reduced_gravities[i - 1]) | ||
) | ||
A[i, i + 1] = -1.0 / (heights[i] * reduced_gravities[num_heights - 2]) | ||
|
||
# bottom rows | ||
A[-1, -1] = 1.0 / (heights[num_heights - 1] * reduced_gravities[num_heights - 2]) | ||
A[-1, -2] = -1.0 / (heights[num_heights - 1] * reduced_gravities[num_heights - 2]) | ||
return A | ||
|
||
|
||
def compute_layer_to_mode_matrices(A): | ||
# eigenvalue decomposition | ||
lambd_r, R = jnp.linalg.eig(A) | ||
_, L = jnp.linalg.eig(A.T) | ||
|
||
# extract real components | ||
lambd, R, L = lambd_r.real, R.real, L.real | ||
|
||
# create matrices | ||
Cl2m = np.diag(1.0 / np.diag(L.T @ R)) @ L.T | ||
Cm2l = R | ||
# create diagonal matrix | ||
return lambd, -Cl2m, -Cm2l |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import numpy as np | ||
from jaxsw._src.operators.functional import grid as F_grid | ||
import finitediffx as fdx | ||
|
||
|
||
def init_tau(domain, tau0: float = 2.0e-5): | ||
""" | ||
Args | ||
---- | ||
tau0 (float): wind stress magnitude m/s^2 | ||
default=2.0e-5""" | ||
# initial TAU | ||
tau = np.zeros((2, domain.Nx[0], domain.Nx[1])) | ||
|
||
# create staggered coordinates (y-direction) | ||
y_coords = np.arange(domain.Nx[1]) + 0.5 | ||
|
||
# create tau | ||
tau[0, :, :] = -tau0 * np.cos(2 * np.pi * (y_coords / domain.Nx[1])) | ||
|
||
return tau | ||
|
||
|
||
def calculate_wind_forcing(tau, domain): | ||
# move from edges to nodes | ||
tau_x = F_grid.x_average_2D(tau[0], padding=((1, 0), (0, 0))) | ||
tau_y = F_grid.y_average_2D(tau[1], padding=((0, 0), (1, 0))) | ||
|
||
# compute finite difference | ||
dF2dX = fdx.difference( | ||
tau_y, axis=0, step_size=domain.dx[0], accuracy=1, method="central" | ||
) | ||
dF1dY = fdx.difference( | ||
tau_x, axis=1, step_size=domain.dx[1], accuracy=1, method="central" | ||
) | ||
curl_stagg = dF2dX - dF1dY | ||
|
||
return F_grid.center_average_2D(curl_stagg.squeeze()[1:, 1:]) |
Oops, something went wrong.