Skip to content

Commit

Permalink
Fix types
Browse files Browse the repository at this point in the history
  • Loading branch information
katspaugh committed Sep 19, 2024
1 parent 42d0193 commit 04f23d0
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 134 deletions.
24 changes: 9 additions & 15 deletions src/safe_apps/admin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any

from django import forms
from django.contrib import admin
from django.contrib.admin import widgets as admin_widgets
from django.db.models import Model, QuerySet
from django.forms import ModelForm

from .models import Chain, Client, Feature, Provider, SafeApp, SocialProfile, Tag

Expand Down Expand Up @@ -38,32 +39,25 @@ class SocialProfileInline(admin.TabularInline[Model, Model]):
verbose_name_plural = "Social profiles set for this Safe App"


class SafeAppAdminForm(ModelForm):
class SafeAppForm(forms.ModelForm[SafeApp]):
class Meta:
model = SafeApp
fields = "__all__"
widgets = {
"chains": admin.widgets.FilteredSelectMultiple("Chains", False),
"chains": admin_widgets.FilteredSelectMultiple("Chains", False),
}


@admin.register(SafeApp)
class SafeAppAdmin(admin.ModelAdmin[SafeApp]):
form = SafeAppAdminForm
list_display = ("name", "url", "get_chains", "listed")
list_filter = (ChainFilter,)
form = SafeAppForm
list_display = ("name", "url", "get_chains")
search_fields = ("name", "url")
ordering = ("name",)
inlines = [
TagInline,
FeatureInline,
SocialProfileInline,
]

def get_chains(self, obj):
def get_chains(self, obj: SafeApp) -> str:
return ", ".join([chain.name for chain in obj.chains.all()])

get_chains.short_description = "Chains"
get_chains.short_description = "Chains" # type: ignore[attr-defined]


@admin.register(Provider)
Expand Down Expand Up @@ -102,6 +96,6 @@ class SocialProfileAdmin(admin.ModelAdmin[SocialProfile]):


@admin.register(Chain)
class ChainAdmin(admin.ModelAdmin):
class ChainAdmin(admin.ModelAdmin[Chain]):
list_display = ("chain_id", "name")
search_fields = ("chain_id", "name")
67 changes: 14 additions & 53 deletions src/safe_apps/migrations/0015_populate_chains.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,23 @@
from django.db import migrations, models
from typing import Any

from django.db import migrations

def populate_chains(apps, schema_editor):

def populate_chains(apps: Any, schema_editor: Any) -> None:
SafeApp = apps.get_model("safe_apps", "SafeApp")
Chain = apps.get_model("safe_apps", "Chain")
chains_data = [
{"chain_id": 1, "name": "Ethereum"},
{"chain_id": 100, "name": "Gnosis Chain"},
{"chain_id": 137, "name": "Polygon"},
{"chain_id": 1101, "name": "Polygon zkEVM"},
{"chain_id": 56, "name": "BNB Chain"},
{"chain_id": 42161, "name": "Arbitrum"},
{"chain_id": 10, "name": "Optimism"},
{"chain_id": 8453, "name": "Base"},
{"chain_id": 59144, "name": "Linea"},
{"chain_id": 324, "name": "zkSync Era"},
{"chain_id": 534352, "name": "Scroll"},
{"chain_id": 196, "name": "X Layer"},
{"chain_id": 42220, "name": "Celo"},
{"chain_id": 43114, "name": "Avalanche"},
{"chain_id": 81457, "name": "Blast"},
{"chain_id": 1313161554, "name": "Aurora"},
{"chain_id": 11155111, "name": "Sepolia"},
{"chain_id": 84532, "name": "Base Sepolia"},
{"chain_id": 10200, "name": "Gnosis Chiado"},
]

for chain in chains_data:
Chain.objects.get_or_create(
chain_id=chain["chain_id"], defaults={"name": chain["name"]}
)
for safe_app in SafeApp.objects.all():
for chain_id in safe_app.chain_ids:
chain, _ = Chain.objects.get_or_create(chain_id=chain_id)
safe_app.chains.add(chain)


def reverse_populate_chains(apps, schema_editor):
Chain = apps.get_model("safe_apps", "Chain")
Chain.objects.all().delete()
def reverse_populate_chains(apps: Any, schema_editor: Any) -> None:
SafeApp = apps.get_model("safe_apps", "SafeApp")

for safe_app in SafeApp.objects.all():
safe_app.chains.clear()


class Migration(migrations.Migration):
Expand All @@ -43,28 +27,5 @@ class Migration(migrations.Migration):
]

operations = [
migrations.CreateModel(
name="Chain",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("chain_id", models.PositiveIntegerField(unique=True)),
("name", models.CharField(max_length=100)),
],
),
migrations.AddField(
model_name="safeapp",
name="chains",
field=models.ManyToManyField(
related_name="safe_apps", to="safe_apps.chain"
),
),
migrations.RunPython(populate_chains, reverse_populate_chains),
]
35 changes: 15 additions & 20 deletions src/safe_apps/migrations/0016_chain_ids_to_chains.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,34 @@
from typing import Any

from django.db import migrations


def copy_chain_ids_to_chains(apps, schema_editor):
def migrate_chain_ids_to_chains(apps: Any, schema_editor: Any) -> None:
SafeApp = apps.get_model("safe_apps", "SafeApp")
Chain = apps.get_model("safe_apps", "Chain")

# Get all SafeApps and their chain_ids
safe_apps = SafeApp.objects.all()
for safe_app in SafeApp.objects.all():
chain_ids = safe_app.chain_ids
for chain_id in chain_ids:
chain, _ = Chain.objects.get_or_create(chain_id=chain_id)
safe_app.chains.add(chain)

# Create Chain objects for any missing chain_ids
for safe_app in safe_apps:
for chain_id in safe_app.chain_ids:
Chain.objects.get_or_create(
chain_id=chain_id,
defaults={
"name": f"Chain {chain_id}"
}, # Default name if not already set
)

def reverse_migrate_chain_ids_to_chains(apps: Any, schema_editor: Any) -> None:
SafeApp = apps.get_model("safe_apps", "SafeApp")

def reverse_copy_chain_ids(apps, schema_editor):
Chain = apps.get_model("safe_apps", "Chain")
Chain.objects.all().delete()
for safe_app in SafeApp.objects.all():
safe_app.chain_ids = list(safe_app.chains.values_list("chain_id", flat=True))
safe_app.save()


class Migration(migrations.Migration):

dependencies = [
("safe_apps", "0015_populate_chains"),
]

operations = [
migrations.RunPython(copy_chain_ids_to_chains, reverse_copy_chain_ids),
migrations.RemoveField(
model_name="safeapp",
name="chain_ids",
migrations.RunPython(
migrate_chain_ids_to_chains, reverse_migrate_chain_ids_to_chains
),
]
26 changes: 22 additions & 4 deletions src/safe_apps/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import uuid
from enum import Enum
from typing import IO, Union
from typing import IO, Any, List, Union

from django.core.exceptions import ValidationError
from django.core.files.images import get_image_dimensions
from django.core.validators import RegexValidator
from django.db import models
from django.db.models.manager import Manager

_HOSTNAME_VALIDATOR = RegexValidator(
r"^(https?:\/\/)?(www\.)?[-a-zA-Z0-9@:%._\+~#=]{2,256}\.[a-z]{2,6}\/?$",
Expand Down Expand Up @@ -55,10 +56,14 @@ class Chain(models.Model):
chain_id = models.PositiveIntegerField(unique=True)
name = models.CharField(max_length=100)

def __str__(self):
def __str__(self) -> str:
return f"{self.name} ({self.chain_id})"


class ChainManager(Manager["Chain"]):
pass


class SafeApp(models.Model):
class AccessControlPolicy(str, Enum):
NO_RESTRICTIONS = "NO_RESTRICTIONS"
Expand All @@ -77,7 +82,8 @@ class AccessControlPolicy(str, Enum):
default="safe_apps/icon_url.jpg",
)
description = models.CharField(max_length=200)
chains = models.ManyToManyField(Chain, related_name="safe_apps")
# Using Any for type annotation due to mypy limitations with ManyToManyField
chains: Any = models.ManyToManyField(Chain, related_name="safe_apps")
provider = models.ForeignKey(
Provider, null=True, blank=True, on_delete=models.SET_NULL
)
Expand All @@ -94,7 +100,19 @@ def get_access_control_type(self) -> AccessControlPolicy:
return SafeApp.AccessControlPolicy.NO_RESTRICTIONS

def __str__(self) -> str:
return f"{self.name} | {self.url} | chain_ids={self.chain_ids}"
return f"{self.name} | {self.url} | chain_ids={[chain.chain_id for chain in self.chains.all()]}"

@property
def chain_ids(self) -> List[int]:
return list(self.chains.values_list("chain_id", flat=True))

def add_chains(self, *chains: Union[Chain, int]) -> None:
for chain in chains:
if isinstance(chain, int):
chain_obj, _ = Chain.objects.get_or_create(chain_id=chain)
self.chains.add(chain_obj)
else:
self.chains.add(chain)


class Tag(models.Model):
Expand Down
42 changes: 21 additions & 21 deletions src/safe_apps/signals.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any
from typing import Any, Set

from django.core.cache import caches
from django.db.models.signals import (
Expand All @@ -22,11 +22,11 @@
def on_safe_app_update(sender: SafeApp, instance: SafeApp, **kwargs: Any) -> None:
logger.info("Clearing safe-apps cache")
caches["safe-apps"].clear()
chain_ids = set(instance.chain_ids)
chain_ids = set(chain.chain_id for chain in instance.chains.all())
if instance.app_id is not None: # existing SafeApp being updated
previous = SafeApp.objects.filter(app_id=instance.app_id).first()
if previous is not None:
chain_ids.update(previous.chain_ids)
chain_ids.update(chain.chain_id for chain in previous.chains.all())
for chain_id in chain_ids:
hook_event(HookEvent(type=HookEvent.Type.SAFE_APPS_UPDATE, chain_id=chain_id))

Expand All @@ -35,8 +35,10 @@ def on_safe_app_update(sender: SafeApp, instance: SafeApp, **kwargs: Any) -> Non
def on_safe_app_delete(sender: SafeApp, instance: SafeApp, **kwargs: Any) -> None:
logger.info("Clearing safe-apps cache")
caches["safe-apps"].clear()
for chain_id in instance.chain_ids:
hook_event(HookEvent(type=HookEvent.Type.SAFE_APPS_UPDATE, chain_id=chain_id))
for chain in instance.chains.all():
hook_event(
HookEvent(type=HookEvent.Type.SAFE_APPS_UPDATE, chain_id=chain.chain_id)
)


@receiver(post_save, sender=Provider)
Expand All @@ -45,9 +47,9 @@ def on_provider_update(sender: Provider, instance: Provider, **kwargs: Any) -> N
logger.info("Clearing safe-apps cache")
caches["safe-apps"].clear()
for safe_app in instance.safeapp_set.all():
for chain_id in safe_app.chain_ids:
for chain in safe_app.chains.all():
hook_event(
HookEvent(type=HookEvent.Type.SAFE_APPS_UPDATE, chain_id=chain_id)
HookEvent(type=HookEvent.Type.SAFE_APPS_UPDATE, chain_id=chain.chain_id)
)


Expand All @@ -59,23 +61,22 @@ def on_tag_update(sender: Tag, instance: Tag, **kwargs: Any) -> None:
logger.info("Clearing safe-apps cache")
caches["safe-apps"].clear()
for safe_app in instance.safe_apps.all():
for chain_id in safe_app.chain_ids:
for chain in safe_app.chains.all():
hook_event(
HookEvent(type=HookEvent.Type.SAFE_APPS_UPDATE, chain_id=chain_id)
HookEvent(type=HookEvent.Type.SAFE_APPS_UPDATE, chain_id=chain.chain_id)
)


@receiver(m2m_changed, sender=Tag.safe_apps.through)
def on_tag_chains_update(
sender: Tag, instance: Tag, action: str, pk_set: set[int], **kwargs: Any
sender: Tag, instance: Tag, action: str, pk_set: Set[int], **kwargs: Any
) -> None:
logger.info("TagChains update. Triggering CGW webhook")
caches["safe-apps"].clear()
if action == "post_add" or action == "post_remove":
chain_ids = set()
if action in ["post_add", "post_remove"]:
chain_ids: Set[int] = set()
for safe_app in SafeApp.objects.filter(app_id__in=pk_set):
for chain_id in safe_app.chain_ids:
chain_ids.add(chain_id)
chain_ids.update(chain.chain_id for chain in safe_app.chains.all())
for chain_id in chain_ids:
hook_event(
HookEvent(type=HookEvent.Type.SAFE_APPS_UPDATE, chain_id=chain_id)
Expand All @@ -90,23 +91,22 @@ def on_feature_update(sender: Feature, instance: Feature, **kwargs: Any) -> None
logger.info("Feature update. Triggering CGW webhook")
caches["safe-apps"].clear()
for safe_app in instance.safe_apps.all():
for chain_id in safe_app.chain_ids:
for chain in safe_app.chains.all():
hook_event(
HookEvent(type=HookEvent.Type.SAFE_APPS_UPDATE, chain_id=chain_id)
HookEvent(type=HookEvent.Type.SAFE_APPS_UPDATE, chain_id=chain.chain_id)
)


@receiver(m2m_changed, sender=Feature.safe_apps.through)
def on_feature_safe_apps_update(
sender: Feature, instance: Feature, action: str, pk_set: set[int], **kwargs: Any
sender: Feature, instance: Feature, action: str, pk_set: Set[int], **kwargs: Any
) -> None:
logger.info("FeatureSafeApps update. Triggering CGW webhook")
caches["safe-apps"].clear()
if action == "post_add" or action == "post_remove":
chain_ids = set()
if action in ["post_add", "post_remove"]:
chain_ids: Set[int] = set()
for safe_app in SafeApp.objects.filter(app_id__in=pk_set):
for chain_id in safe_app.chain_ids:
chain_ids.add(chain_id)
chain_ids.update(chain.chain_id for chain in safe_app.chains.all())
for chain_id in chain_ids:
hook_event(
HookEvent(type=HookEvent.Type.SAFE_APPS_UPDATE, chain_id=chain_id)
Expand Down
Loading

0 comments on commit 04f23d0

Please sign in to comment.