Skip to content

Commit

Permalink
traceback: limit locals stringification
Browse files Browse the repository at this point in the history
  • Loading branch information
isra17 committed Jul 26, 2023
1 parent 58f507d commit a18ff01
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 1 deletion.
53 changes: 52 additions & 1 deletion src/saturn_engine/utils/traceback_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import linecache
import sys
import traceback
from collections.abc import Collection
from collections.abc import Mapping
from contextlib import suppress
from types import FrameType
from types import TracebackType

Expand All @@ -20,6 +23,54 @@
)


def format_local(anyval: object, maxlen: int = 80) -> str:
val: str = "<???>"
with suppress(Exception):
if isinstance(anyval, (str, bytes)):
val = repr(anyval[:maxlen])
elif isinstance(anyval, (int, float)):
val = repr(anyval)
elif isinstance(anyval, Mapping):
val = ""
if not isinstance(anyval, dict):
val = f"{type(anyval)}"
val += "{"
vals = []
valslen = len(val)
for k, v in anyval.items():
ks = format_local(k, maxlen=maxlen - valslen)
valslen += len(ks)
vs = format_local(v, maxlen=maxlen - valslen)
valslen += len(vs) + 4
vals.append(f"{ks}: {vs}")
if valslen > maxlen:
return val + ", ".join(vals)
return val + ", ".join(vals) + "}"
elif isinstance(anyval, Collection):
val = ""
if not isinstance(anyval, list):
val = f"{type(anyval)}"
val += "["
vals = []
valslen = len(val)
for v in anyval:
vs = format_local(v, maxlen=maxlen - valslen)
valslen += len(vs) + 2
vals.append(vs)
if valslen > maxlen:
return val + ", ".join(vals)
return val + ", ".join(vals) + "]"
else:
val = str(type(anyval))

if len(val) >= maxlen:
if val[-1] == "'":
val = val[:-1] + "<...>"
else:
val += "<...>"
return val


@dataclasses.dataclass
class TracebackData:
"""Very alike traceback.TracebackException, with a few tweak to ensure
Expand Down Expand Up @@ -114,7 +165,7 @@ def extract_stack(
linecache.lazycache(filename, f.f_globals)
f_locals = f.f_locals
_locals: dict[str, str] = (
{k: str(v) for k, v in f_locals.items()} if f_locals else {}
{k: format_local(v) for k, v in f_locals.items()} if f_locals else {}
)

firstlineno = co.co_firstlineno
Expand Down
17 changes: 17 additions & 0 deletions tests/utils/test_traceback_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from saturn_engine.utils.traceback_data import format_local


def test_format_local() -> None:
assert format_local("123") == "'123'"
assert format_local(b"123") == "b'123'"
assert format_local(123) == "123"
assert format_local(1.23) == "1.23"
assert (
format_local({"a": 123, "b": (123, 4), "c": [1, 2, {"e": "f"}]})
== "{'a': 123, 'b': <class 'tuple'>[123, 4], 'c': [1, 2, {'e': 'f'}]}"
)

assert (
format_local({"a": "a", "b": "b" * 80, "c": "c"})
== "{'a': 'a', 'b': '" + "b" * 66 + "<...>"
)

0 comments on commit a18ff01

Please sign in to comment.