From 3fa705866bf02acd8e55218abf4951bb8d87671c Mon Sep 17 00:00:00 2001 From: Matt Craddock <5796417+craddm@users.noreply.github.com> Date: Wed, 25 Sep 2024 10:26:55 +0000 Subject: [PATCH] add function for checking entra group exists, remove try except --- data_safe_haven/external/api/graph_api.py | 32 ++++++++++++++--------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index e08311d941..2406007004 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -140,7 +140,11 @@ def add_user_to_group( """ try: user_id = self.get_id_from_username(username) - group_id = self.get_id_from_groupname(group_name) + if self.entra_group_exists(group_name): + group_id = self.get_id_from_groupname(group_name) + else: + msg = f"Group '{group_name}' not found." + raise DataSafeHavenMicrosoftGraphError(msg) json_response = self.http_get( f"{self.base_endpoint}/groups/{group_id}/members", ).json() @@ -321,7 +325,7 @@ def create_group(self, group_name: str) -> None: DataSafeHavenMicrosoftGraphError if the group could not be created """ try: - if self.get_id_from_groupname(group_name): + if self.entra_group_exists(group_name): self.logger.info( f"Found existing Entra group '[green]{group_name}[/]'.", ) @@ -515,18 +519,17 @@ def get_service_principal_by_name( except (DataSafeHavenMicrosoftGraphError, StopIteration): return None + def entra_group_exists(self, group_name: str) -> bool: + return bool(any(x["displayName"] == group_name for x in self.read_groups())) + def get_id_from_groupname(self, group_name: str) -> str: - try: - return str( - next( - group - for group in self.read_groups() - if group["displayName"] == group_name - )["id"] - ) - except (DataSafeHavenMicrosoftGraphError, StopIteration): - msg = "Admin group not found. Check that the group exists and that you have the correct permissions." - raise DataSafeHavenMicrosoftGraphError(msg) from None + return str( + next( + group + for group in self.read_groups() + if group["displayName"] == group_name + )["id"] + ) def get_id_from_username(self, username: str) -> str | None: try: @@ -1016,6 +1019,9 @@ def remove_user_from_group( """ try: user_id = self.get_id_from_username(username) + if not self.entra_group_exists(group_name): + msg = f"Group '{group_name}' not found." + raise DataSafeHavenMicrosoftGraphError(msg) group_id = self.get_id_from_groupname(group_name) # Check whether user is in group json_response = self.http_get(