Skip to content

Commit

Permalink
Make sure correct errors thrown when adding attacker with 'bad' entry…
Browse files Browse the repository at this point in the history
… points or reached attack steps
  • Loading branch information
mrkickling committed Aug 7, 2024
1 parent c49e9dc commit 12d9dd4
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
14 changes: 10 additions & 4 deletions maltoolbox/attackgraph/attackgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,8 @@ def add_attacker(
self,
attacker: Attacker,
attacker_id: Optional[int] = None,
entry_points: list[int] = [],
reached_attack_steps: list[int] = []
entry_points: list[int | str] = [],
reached_attack_steps: list[int | str] = []
):
"""Add an attacker to the graph
Arguments:
Expand Down Expand Up @@ -653,7 +653,10 @@ def add_attacker(

self.next_attacker_id = max(attacker.id + 1, self.next_attacker_id)
for node_id in reached_attack_steps:
node = self.get_node_by_id(int(node_id))
if isinstance(node_id, str) and not node_id.isnumeric():
raise TypeError(f"Node id {node_id} in reached_attack_steps not numeric")
node_id = int(node_id)
node = self.get_node_by_id(node_id)
if node:
attacker.compromise(node)
else:
Expand All @@ -662,7 +665,10 @@ def add_attacker(
logger.error(msg, node_id)
raise AttackGraphException(msg % node_id)
for node_id in entry_points:
node = self.get_node_by_id(int(node_id))
if isinstance(node_id, str) and not node_id.isnumeric():
raise TypeError(f"Node id {node_id} in entry_points not numeric")
node_id = int(node_id)
node = self.get_node_by_id(node_id)
if node:
attacker.entry_points.append(node)
else:
Expand Down
30 changes: 29 additions & 1 deletion tests/attackgraph/test_attackgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from unittest.mock import patch

from maltoolbox.language import LanguageGraph
from maltoolbox.attackgraph import AttackGraph
from maltoolbox.attackgraph import AttackGraph, Attacker
from maltoolbox.model import Model, AttackerAttachment
from maltoolbox.exceptions import AttackGraphException

from test_model import create_application_asset, create_association

Expand Down Expand Up @@ -388,6 +389,33 @@ def test_attackgraph_regenerate_graph():
pass


def test_attackgraph_add_attacker_bad_entrypoints(example_attackgraph):
"""Make sure errors are thrown when attacker is added with bad
(wrong type/nonexistent) entrypoints or reached_attack_steps"""

# Add an attacker to the graph
attacker1 = Attacker("Attacker1")

with pytest.raises(TypeError):
example_attackgraph.add_attacker(
attacker1, reached_attack_steps=[0], entry_points=["STRING"])

with pytest.raises(TypeError):
example_attackgraph.add_attacker(
attacker1, reached_attack_steps=["STRING"], entry_points=[0])

with pytest.raises(AttackGraphException):
example_attackgraph.add_attacker(
attacker1, reached_attack_steps=["100000"], entry_points=[0])

with pytest.raises(AttackGraphException):
example_attackgraph.add_attacker(
attacker1, reached_attack_steps=[0], entry_points=["100000"])

example_attackgraph.add_attacker(
attacker1, reached_attack_steps=[0], entry_points=[0])


def test_attackgraph_remove_node(example_attackgraph: AttackGraph):
"""Make sure nodes are removed correctly"""
node_to_remove = example_attackgraph.nodes[10]
Expand Down

0 comments on commit 12d9dd4

Please sign in to comment.