diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index e08311d941..2406007004 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -140,7 +140,11 @@ def add_user_to_group( """ try: user_id = self.get_id_from_username(username) - group_id = self.get_id_from_groupname(group_name) + if self.entra_group_exists(group_name): + group_id = self.get_id_from_groupname(group_name) + else: + msg = f"Group '{group_name}' not found." + raise DataSafeHavenMicrosoftGraphError(msg) json_response = self.http_get( f"{self.base_endpoint}/groups/{group_id}/members", ).json() @@ -321,7 +325,7 @@ def create_group(self, group_name: str) -> None: DataSafeHavenMicrosoftGraphError if the group could not be created """ try: - if self.get_id_from_groupname(group_name): + if self.entra_group_exists(group_name): self.logger.info( f"Found existing Entra group '[green]{group_name}[/]'.", ) @@ -515,18 +519,17 @@ def get_service_principal_by_name( except (DataSafeHavenMicrosoftGraphError, StopIteration): return None + def entra_group_exists(self, group_name: str) -> bool: + return bool(any(x["displayName"] == group_name for x in self.read_groups())) + def get_id_from_groupname(self, group_name: str) -> str: - try: - return str( - next( - group - for group in self.read_groups() - if group["displayName"] == group_name - )["id"] - ) - except (DataSafeHavenMicrosoftGraphError, StopIteration): - msg = "Admin group not found. Check that the group exists and that you have the correct permissions." - raise DataSafeHavenMicrosoftGraphError(msg) from None + return str( + next( + group + for group in self.read_groups() + if group["displayName"] == group_name + )["id"] + ) def get_id_from_username(self, username: str) -> str | None: try: @@ -1016,6 +1019,9 @@ def remove_user_from_group( """ try: user_id = self.get_id_from_username(username) + if not self.entra_group_exists(group_name): + msg = f"Group '{group_name}' not found." + raise DataSafeHavenMicrosoftGraphError(msg) group_id = self.get_id_from_groupname(group_name) # Check whether user is in group json_response = self.http_get(