Skip to content

Commit

Permalink
Add custom deepcopy to AttackGraph and AttackGraphNode
Browse files Browse the repository at this point in the history
- Make sure that each node is copied once by using 'memo'.
- Copy lookup dicts for nodes/attackers.
- Add a test to make sure deep copy works as expected.
  • Loading branch information
mrkickling committed Sep 17, 2024
1 parent 2df9b74 commit 5bdf57e
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 1 deletion.
26 changes: 25 additions & 1 deletion maltoolbox/attackgraph/attackgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
MAL-Toolbox Attack Graph Module
"""
from __future__ import annotations
import copy
import logging
import json

Expand Down Expand Up @@ -224,6 +225,29 @@ def _to_dict(self) -> dict:
'attackers': serialized_attackers,
}

def __deepcopy__(self, memo):
copied_attackgraph = AttackGraph(self.lang_graph, self.model)

# Deep copy nodes and add references to them
copied_attackgraph.nodes = copy.deepcopy(self.nodes, memo)

# Deep copy attackers and references to them
copied_attackgraph.attackers = copy.deepcopy(self.attackers, memo)

# Copy lookup dicts
copied_attackgraph._id_to_attacker = \
copy.deepcopy(self._id_to_attacker, memo)
copied_attackgraph._id_to_node = \
copy.deepcopy(self._id_to_node, memo)
copied_attackgraph._full_name_to_node = \
copy.deepcopy(self._full_name_to_node, memo)

# Copy counters
copied_attackgraph.next_node_id = self.next_node_id
copied_attackgraph.next_attacker_id = self.next_attacker_id

return copied_attackgraph

def save_to_file(self, filename: str) -> None:
"""Save to json/yml depending on extension"""
logger.debug('Save attack graph to file "%s".', filename)
Expand Down Expand Up @@ -382,7 +406,7 @@ def get_node_by_full_name(self, full_name: str) -> Optional[AttackGraphNode]:
The attack step node that matches the given full name.
"""

logger.debug(f'Looking up node with full name "{full_name}"')
logger.debug(f'Looking up node with id {full_name}')
return self._full_name_to_node.get(full_name)

def get_attacker_by_id(self, attacker_id: int) -> Optional[Attacker]:
Expand Down
34 changes: 34 additions & 0 deletions maltoolbox/attackgraph/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from __future__ import annotations
import copy
from dataclasses import field, dataclass
from typing import Any, Optional

Expand Down Expand Up @@ -69,6 +70,39 @@ def to_dict(self) -> dict:
def __repr__(self) -> str:
return str(self.to_dict())

def __deepcopy__(self, memo) -> AttackGraphNode:
"""Deep copy an attackgraph node"""

copied_node = AttackGraphNode(
self.type,
self.name,
self.ttc,
self.id,
self.asset,
[],
[],
self.defense_status,
self.existence_status,
self.is_viable,
self.is_necessary,
copy.deepcopy(self.compromised_by, memo),
self.mitre_info,
copy.deepcopy(self.tags, memo),
copy.deepcopy(self.attributes, memo),
copy.deepcopy(self.extras, memo)
)

# Remember that self was already copied
memo[id(self)] = copied_node

# Deep copy children and parents, send memo (avoid infinite recursion)
if self.parents:
copied_node.parents = copy.deepcopy(self.parents, memo)
if self.children:
copied_node.children = copy.deepcopy(self.children, memo)

return copied_node

def is_compromised(self) -> bool:
"""
Return True if any attackers have compromised this node.
Expand Down
37 changes: 37 additions & 0 deletions tests/attackgraph/test_attackgraph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Unit tests for AttackGraph functionality"""

import copy
import pytest
from unittest.mock import patch

Expand Down Expand Up @@ -403,3 +404,39 @@ def test_attackgraph_remove_node(example_attackgraph: AttackGraph):
assert node_to_remove not in parent.children
for child in children:
assert node_to_remove not in child.parents


def test_attackgraph_deepcopy(example_attackgraph: AttackGraph):
"""Try to deepcopy an attackgraph object"""
copied_attackgraph: AttackGraph = copy.deepcopy(example_attackgraph)

assert copied_attackgraph != example_attackgraph
assert copied_attackgraph._to_dict() == example_attackgraph._to_dict()

assert copied_attackgraph.next_node_id == example_attackgraph.next_node_id
assert copied_attackgraph.next_attacker_id == example_attackgraph.next_attacker_id

assert len(copied_attackgraph.nodes) == len(example_attackgraph.nodes)

assert list(copied_attackgraph._id_to_node.keys()) \
== list(example_attackgraph._id_to_node.keys())

assert list(copied_attackgraph._id_to_attacker.keys()) \
== list(example_attackgraph._id_to_attacker.keys())

assert list(copied_attackgraph._full_name_to_node.keys()) \
== list(example_attackgraph._full_name_to_node.keys())

for node in copied_attackgraph.nodes:
assert node.id is not None
original_node = example_attackgraph.get_node_by_id(node.id)
assert original_node
assert id(original_node) != id(node)
assert original_node.to_dict() == node.to_dict()

for attacker in copied_attackgraph.attackers:
assert attacker.id is not None
original_attacker = example_attackgraph.get_attacker_by_id(attacker.id)
assert original_attacker
assert id(original_attacker) != id(attacker)
assert original_attacker.to_dict() == attacker.to_dict()

0 comments on commit 5bdf57e

Please sign in to comment.