Skip to content

Commit

Permalink
add function for checking entra group exists, remove try except
Browse files Browse the repository at this point in the history
  • Loading branch information
craddm committed Sep 25, 2024
1 parent 560917f commit 3fa7058
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions data_safe_haven/external/api/graph_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,11 @@ def add_user_to_group(
"""
try:
user_id = self.get_id_from_username(username)
group_id = self.get_id_from_groupname(group_name)
if self.entra_group_exists(group_name):
group_id = self.get_id_from_groupname(group_name)
else:
msg = f"Group '{group_name}' not found."
raise DataSafeHavenMicrosoftGraphError(msg)
json_response = self.http_get(
f"{self.base_endpoint}/groups/{group_id}/members",
).json()
Expand Down Expand Up @@ -321,7 +325,7 @@ def create_group(self, group_name: str) -> None:
DataSafeHavenMicrosoftGraphError if the group could not be created
"""
try:
if self.get_id_from_groupname(group_name):
if self.entra_group_exists(group_name):
self.logger.info(
f"Found existing Entra group '[green]{group_name}[/]'.",
)
Expand Down Expand Up @@ -515,18 +519,17 @@ def get_service_principal_by_name(
except (DataSafeHavenMicrosoftGraphError, StopIteration):
return None

def entra_group_exists(self, group_name: str) -> bool:
return bool(any(x["displayName"] == group_name for x in self.read_groups()))

def get_id_from_groupname(self, group_name: str) -> str:
try:
return str(
next(
group
for group in self.read_groups()
if group["displayName"] == group_name
)["id"]
)
except (DataSafeHavenMicrosoftGraphError, StopIteration):
msg = "Admin group not found. Check that the group exists and that you have the correct permissions."
raise DataSafeHavenMicrosoftGraphError(msg) from None
return str(
next(
group
for group in self.read_groups()
if group["displayName"] == group_name
)["id"]
)

def get_id_from_username(self, username: str) -> str | None:
try:
Expand Down Expand Up @@ -1016,6 +1019,9 @@ def remove_user_from_group(
"""
try:
user_id = self.get_id_from_username(username)
if not self.entra_group_exists(group_name):
msg = f"Group '{group_name}' not found."
raise DataSafeHavenMicrosoftGraphError(msg)
group_id = self.get_id_from_groupname(group_name)
# Check whether user is in group
json_response = self.http_get(
Expand Down

0 comments on commit 3fa7058

Please sign in to comment.