Skip to content

Commit

Permalink
Added remaining Numpy NDArray single function expressions (#183)
Browse files Browse the repository at this point in the history
* Added support for single expressions involving the following functions: numpy.linalg.{matrix_power, qr, svd, det, matrix_rank, inv, pinv}.

* Fixes for CI tests.

* Fixed issues with line lengths and import order.

* Refactored code.
  • Loading branch information
aritrakar committed Nov 17, 2023
1 parent d11334e commit 1cfb8e7
Show file tree
Hide file tree
Showing 2 changed files with 297 additions and 2 deletions.
138 changes: 136 additions & 2 deletions src/latexify/codegen/expression_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def __init__(
"""Initializer.
Args:
use_math_symbols: Whether to convert identifiers with a math symbol surface
(e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
use_math_symbols: Whether to convert identifiers with a math symbol
surface (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
use_set_symbols: Whether to use set symbols or not.
"""
self._identifier_converter = identifier_converter.IdentifierConverter(
Expand Down Expand Up @@ -240,6 +240,130 @@ def _generate_transpose(self, node: ast.Call) -> str | None:
else:
return None

def _generate_determinant(self, node: ast.Call) -> str | None:
"""Generates LaTeX for numpy.linalg.det.
Args:
node: ast.Call node containing the appropriate method invocation.
Returns:
Generated LaTeX, or None if the node has unsupported syntax.
Raises:
LatexifyError: Unsupported argument type given.
"""
name = ast_utils.extract_function_name_or_none(node)
assert name == "det"

if len(node.args) != 1:
return None

func_arg = node.args[0]
if isinstance(func_arg, ast.Name):
arg_id = rf"\mathbf{{{func_arg.id}}}"
return rf"\det \mathopen{{}}\left( {arg_id} \mathclose{{}}\right)"
elif isinstance(func_arg, ast.List):
matrix = self._generate_matrix(node)
return rf"\det \mathopen{{}}\left( {matrix} \mathclose{{}}\right)"

return None

def _generate_matrix_rank(self, node: ast.Call) -> str | None:
"""Generates LaTeX for numpy.linalg.matrix_rank.
Args:
node: ast.Call node containing the appropriate method invocation.
Returns:
Generated LaTeX, or None if the node has unsupported syntax.
Raises:
LatexifyError: Unsupported argument type given.
"""
name = ast_utils.extract_function_name_or_none(node)
assert name == "matrix_rank"

if len(node.args) != 1:
return None

func_arg = node.args[0]
if isinstance(func_arg, ast.Name):
arg_id = rf"\mathbf{{{func_arg.id}}}"
return (
rf"\mathrm{{rank}} \mathopen{{}}\left( {arg_id} \mathclose{{}}\right)"
)
elif isinstance(func_arg, ast.List):
matrix = self._generate_matrix(node)
return (
rf"\mathrm{{rank}} \mathopen{{}}\left( {matrix} \mathclose{{}}\right)"
)

return None

def _generate_matrix_power(self, node: ast.Call) -> str | None:
"""Generates LaTeX for numpy.linalg.matrix_power.
Args:
node: ast.Call node containing the appropriate method invocation.
Returns:
Generated LaTeX, or None if the node has unsupported syntax.
Raises:
LatexifyError: Unsupported argument type given.
"""
name = ast_utils.extract_function_name_or_none(node)
assert name == "matrix_power"

if len(node.args) != 2:
return None

func_arg = node.args[0]
power_arg = node.args[1]
if isinstance(power_arg, ast.Num):
if isinstance(func_arg, ast.Name):
return rf"\mathbf{{{func_arg.id}}}^{{{power_arg.n}}}"
elif isinstance(func_arg, ast.List):
matrix = self._generate_matrix(node)
if matrix is not None:
return rf"{matrix}^{{{power_arg.n}}}"
return None

def _generate_inv(self, node: ast.Call) -> str | None:
"""Generates LaTeX for numpy.linalg.inv.
Args:
node: ast.Call node containing the appropriate method invocation.
Returns:
Generated LaTeX, or None if the node has unsupported syntax.
Raises:
LatexifyError: Unsupported argument type given.
"""
name = ast_utils.extract_function_name_or_none(node)
assert name == "inv"

if len(node.args) != 1:
return None

func_arg = node.args[0]
if isinstance(func_arg, ast.Name):
return rf"\mathbf{{{func_arg.id}}}^{{-1}}"
elif isinstance(func_arg, ast.List):
return rf"{self._generate_matrix(node)}^{{-1}}"
return None

def _generate_pinv(self, node: ast.Call) -> str | None:
"""Generates LaTeX for numpy.linalg.pinv.
Args:
node: ast.Call node containing the appropriate method invocation.
Returns:
Generated LaTeX, or None if the node has unsupported syntax.
Raises:
LatexifyError: Unsupported argument type given.
"""
name = ast_utils.extract_function_name_or_none(node)
assert name == "pinv"

if len(node.args) != 1:
return None

func_arg = node.args[0]
if isinstance(func_arg, ast.Name):
return rf"\mathbf{{{func_arg.id}}}^{{+}}"
elif isinstance(func_arg, ast.List):
return rf"{self._generate_matrix(node)}^{{+}}"
return None

def visit_Call(self, node: ast.Call) -> str:
"""Visit a Call node."""
func_name = ast_utils.extract_function_name_or_none(node)
Expand All @@ -256,6 +380,16 @@ def visit_Call(self, node: ast.Call) -> str:
special_latex = self._generate_identity(node)
elif func_name == "transpose":
special_latex = self._generate_transpose(node)
elif func_name == "det":
special_latex = self._generate_determinant(node)
elif func_name == "matrix_rank":
special_latex = self._generate_matrix_rank(node)
elif func_name == "matrix_power":
special_latex = self._generate_matrix_power(node)
elif func_name == "inv":
special_latex = self._generate_inv(node)
elif func_name == "pinv":
special_latex = self._generate_pinv(node)
else:
special_latex = None

Expand Down
161 changes: 161 additions & 0 deletions src/latexify/codegen/expression_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,167 @@ def test_transpose(code: str, latex: str) -> None:
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
"code,latex",
[
("det(A)", r"\det \mathopen{}\left( \mathbf{A} \mathclose{}\right)"),
("det(b)", r"\det \mathopen{}\left( \mathbf{b} \mathclose{}\right)"),
(
"det([[1, 2], [3, 4]])",
r"\det \mathopen{}\left( \begin{bmatrix} 1 & 2 \\"
r" 3 & 4 \end{bmatrix} \mathclose{}\right)",
),
(
"det([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
r"\det \mathopen{}\left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\"
r" 7 & 8 & 9 \end{bmatrix} \mathclose{}\right)",
),
# Unsupported
("det()", r"\mathrm{det} \mathopen{}\left( \mathclose{}\right)"),
("det(2)", r"\mathrm{det} \mathopen{}\left( 2 \mathclose{}\right)"),
(
"det(a, (1, 0))",
r"\mathrm{det} \mathopen{}\left( a, "
r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
),
],
)
def test_determinant(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Call)
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
"code,latex",
[
(
"matrix_rank(A)",
r"\mathrm{rank} \mathopen{}\left( \mathbf{A} \mathclose{}\right)",
),
(
"matrix_rank(b)",
r"\mathrm{rank} \mathopen{}\left( \mathbf{b} \mathclose{}\right)",
),
(
"matrix_rank([[1, 2], [3, 4]])",
r"\mathrm{rank} \mathopen{}\left( \begin{bmatrix} 1 & 2 \\"
r" 3 & 4 \end{bmatrix} \mathclose{}\right)",
),
(
"matrix_rank([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
r"\mathrm{rank} \mathopen{}\left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\"
r" 7 & 8 & 9 \end{bmatrix} \mathclose{}\right)",
),
# Unsupported
(
"matrix_rank()",
r"\mathrm{matrix\_rank} \mathopen{}\left( \mathclose{}\right)",
),
(
"matrix_rank(2)",
r"\mathrm{matrix\_rank} \mathopen{}\left( 2 \mathclose{}\right)",
),
(
"matrix_rank(a, (1, 0))",
r"\mathrm{matrix\_rank} \mathopen{}\left( a, "
r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
),
],
)
def test_matrix_rank(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Call)
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
"code,latex",
[
("matrix_power(A, 2)", r"\mathbf{A}^{2}"),
("matrix_power(b, 2)", r"\mathbf{b}^{2}"),
(
"matrix_power([[1, 2], [3, 4]], 2)",
r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{2}",
),
(
"matrix_power([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 42)",
r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{42}",
),
# Unsupported
(
"matrix_power()",
r"\mathrm{matrix\_power} \mathopen{}\left( \mathclose{}\right)",
),
(
"matrix_power(2)",
r"\mathrm{matrix\_power} \mathopen{}\left( 2 \mathclose{}\right)",
),
(
"matrix_power(a, (1, 0))",
r"\mathrm{matrix\_power} \mathopen{}\left( a, "
r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
),
],
)
def test_matrix_power(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Call)
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
"code,latex",
[
("inv(A)", r"\mathbf{A}^{-1}"),
("inv(b)", r"\mathbf{b}^{-1}"),
("inv([[1, 2], [3, 4]])", r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{-1}"),
(
"inv([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{-1}",
),
# Unsupported
("inv()", r"\mathrm{inv} \mathopen{}\left( \mathclose{}\right)"),
("inv(2)", r"\mathrm{inv} \mathopen{}\left( 2 \mathclose{}\right)"),
(
"inv(a, (1, 0))",
r"\mathrm{inv} \mathopen{}\left( a, "
r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
),
],
)
def test_inv(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Call)
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
"code,latex",
[
("pinv(A)", r"\mathbf{A}^{+}"),
("pinv(b)", r"\mathbf{b}^{+}"),
("pinv([[1, 2], [3, 4]])", r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{+}"),
(
"pinv([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{+}",
),
# Unsupported
("pinv()", r"\mathrm{pinv} \mathopen{}\left( \mathclose{}\right)"),
("pinv(2)", r"\mathrm{pinv} \mathopen{}\left( 2 \mathclose{}\right)"),
(
"pinv(a, (1, 0))",
r"\mathrm{pinv} \mathopen{}\left( a, "
r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
),
],
)
def test_pinv(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Call)
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


# Check list for #89.
# https://github.com/google/latexify_py/issues/89#issuecomment-1344967636
@pytest.mark.parametrize(
Expand Down

0 comments on commit 1cfb8e7

Please sign in to comment.