diff --git a/config/settings/base.py b/config/settings/base.py index 4a7a683..01c4628 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -45,6 +45,7 @@ "rp_transferto", "rp_recruit", "rp_interceptors", + "randomisation", ] MIDDLEWARE = [ diff --git a/config/urls.py b/config/urls.py index 9dbf067..7a2125d 100644 --- a/config/urls.py +++ b/config/urls.py @@ -8,4 +8,5 @@ path("recruit/", include("rp_recruit.urls"), name="rp_recruit"), path("interceptor/", include("rp_interceptors.urls")), path("dtone/", include("rp_dtone.urls")), + path("randomisation/", include("randomisation.urls")), ] diff --git a/randomisation/__init__.py b/randomisation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/randomisation/admin.py b/randomisation/admin.py new file mode 100644 index 0000000..8a30886 --- /dev/null +++ b/randomisation/admin.py @@ -0,0 +1,25 @@ +from django.contrib import admin + +from randomisation.models import Arm, Strata, StrataOption, Strategy + + +class ArmInline(admin.TabularInline): + model = Arm + + +class StrataOptionInline(admin.TabularInline): + model = StrataOption + + +@admin.register(Strategy) +class StrategyAdmin(admin.ModelAdmin): + list_display = ("name",) + + inlines = [ArmInline] + + +@admin.register(Strata) +class StrataAdmin(admin.ModelAdmin): + list_display = ("__str__",) + + inlines = [StrataOptionInline] diff --git a/randomisation/apps.py b/randomisation/apps.py new file mode 100644 index 0000000..02216fa --- /dev/null +++ b/randomisation/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class RandomisationConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "randomisation" diff --git a/randomisation/migrations/0001_initial.py b/randomisation/migrations/0001_initial.py new file mode 100644 index 0000000..9755dd9 --- /dev/null +++ b/randomisation/migrations/0001_initial.py @@ -0,0 +1,124 @@ +# Generated by Django 4.2.11 on 2024-04-17 09:03 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [] + + operations = [ + migrations.CreateModel( + name="Strata", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=200)), + ], + ), + migrations.CreateModel( + name="Strategy", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=200)), + ( + "stratas", + models.ManyToManyField( + related_name="stategy_stratas", to="randomisation.strata" + ), + ), + ], + options={ + "verbose_name_plural": "Strategies", + }, + ), + migrations.CreateModel( + name="StrataOption", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("description", models.CharField(max_length=200)), + ( + "strata", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="options", + to="randomisation.strata", + ), + ), + ], + ), + migrations.CreateModel( + name="StrataMatrix", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("strata_data", models.JSONField()), + ("next_index", models.IntegerField(default=0)), + ("arm_order", models.CharField(max_length=255)), + ( + "strategy", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="matrix_records", + to="randomisation.strategy", + ), + ), + ], + ), + migrations.CreateModel( + name="Arm", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=200)), + ( + "strategy", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="arms", + to="randomisation.strategy", + ), + ), + ], + ), + ] diff --git a/randomisation/migrations/__init__.py b/randomisation/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/randomisation/models.py b/randomisation/models.py new file mode 100644 index 0000000..5f10bbb --- /dev/null +++ b/randomisation/models.py @@ -0,0 +1,55 @@ +from django.db import models +from django.db.models import JSONField +from django.utils.text import slugify + + +class Strata(models.Model): + name = models.CharField(max_length=200, null=False, blank=False) + + @property + def slug(self): + return slugify(self.name) + + def __str__(self): + options = [option.description for option in self.options.all()] + return f"{self.name} - [{', '.join(options)}]" + + +class Strategy(models.Model): + class Meta: + verbose_name_plural = "Strategies" + + name = models.CharField(max_length=200, null=False, blank=False) + stratas = models.ManyToManyField(Strata, related_name="stategy_stratas") + + +class Arm(models.Model): + strategy = models.ForeignKey( + Strategy, + related_name="arms", + null=False, + on_delete=models.CASCADE, + ) + name = models.CharField(max_length=200, null=False, blank=False) + + +class StrataOption(models.Model): + strata = models.ForeignKey( + Strata, + related_name="options", + null=False, + on_delete=models.CASCADE, + ) + description = models.CharField(max_length=200, null=False, blank=False) + + +class StrataMatrix(models.Model): + strategy = models.ForeignKey( + Strategy, + related_name="matrix_records", + null=False, + on_delete=models.CASCADE, + ) + strata_data = JSONField() + next_index = models.IntegerField(default=0) + arm_order = models.CharField(max_length=255, null=False, blank=False) diff --git a/randomisation/tests/__init__.py b/randomisation/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/randomisation/tests/test_utils.py b/randomisation/tests/test_utils.py new file mode 100644 index 0000000..7463344 --- /dev/null +++ b/randomisation/tests/test_utils.py @@ -0,0 +1,72 @@ +import random +from collections import defaultdict + +from django.test import TestCase + +from randomisation.utils import ( + get_random_stratification_arm, + validate_stratification_data, +) + +from .utils import create_test_strategy + + +# TODO: add docstrings to tests +class TestValidateStratificationData(TestCase): + def setUp(self): + self.strategy = create_test_strategy() + + def test_stratification_validation_valid_data(self): + error = validate_stratification_data( + self.strategy, {"age-group": "18-29", "province": "WC"} + ) + self.assertIsNone(error) + + def test_stratification_validation_missing_key(self): + error = validate_stratification_data(self.strategy, {"age-group": "18-29"}) + self.assertEqual(error, "'province' is a required property") + + def test_stratification_validation_extra_key(self): + error = validate_stratification_data( + self.strategy, {"age-group": "18-29", "province": "WC", "extra": "key"} + ) + + self.assertEqual( + error, "Additional properties are not allowed ('extra' was unexpected)" + ) + + def test_stratification_validation_invalid_option(self): + error = validate_stratification_data( + self.strategy, {"age-group": "18-29", "province": "FS"} + ) + self.assertEqual(error, "'FS' is not one of ['WC', 'GT']") + + +class TestGetRandomStratification(TestCase): + + # TODO: add more tests for randomisation + + def test_stratification_balancing(self): + strategy = create_test_strategy() + + totals = defaultdict(int) + stratas = defaultdict(lambda: defaultdict(int)) + for i in range(100): + random_age = random.choice(["18-29", "29-39"]) + random_province = random.choice(["WC", "GT"]) + + data = {"age-group": random_age, "province": random_province} + + random_arm = get_random_stratification_arm(strategy, data) + stratas[f"{random_age}_{random_province}"][random_arm] += 1 + totals[random_arm] += 1 + + def check_arms_balanced(arms, diff, description): + values = [value for value in arms.values()] + msg = f"Arms not balanced: {description} - {values}" + assert max(values) - diff < values[0] < min(values) + diff, msg + + check_arms_balanced(totals, 3, "Totals") + + for key, arms in stratas.items(): + check_arms_balanced(arms, 3, key) diff --git a/randomisation/tests/utils.py b/randomisation/tests/utils.py new file mode 100644 index 0000000..2b03066 --- /dev/null +++ b/randomisation/tests/utils.py @@ -0,0 +1,27 @@ +from randomisation.models import Arm, Strata, StrataOption, Strategy + +DEFAULT_STRATEGY = { + "name": "Test Strategy", + "arms": ["Arm 1", "Arm 2", "Arm 3"], + "stratas": [ + {"name": "Age Group", "options": ["18-29", "29-39"]}, + {"name": "Province", "options": ["WC", "GT"]}, + ], +} + + +def create_test_strategy(data=DEFAULT_STRATEGY): + strategy = Strategy.objects.create(name=data["name"]) + + for arm in data["arms"]: + Arm.objects.create(strategy=strategy, name=arm) + + for strata_data in data["stratas"]: + strata = Strata.objects.create(name=strata_data["name"]) + + for option in strata_data["options"]: + StrataOption.objects.create(strata=strata, description=option) + + strategy.stratas.add(strata) + + return strategy diff --git a/randomisation/urls.py b/randomisation/urls.py new file mode 100644 index 0000000..f72cb6b --- /dev/null +++ b/randomisation/urls.py @@ -0,0 +1,11 @@ +from django.urls import path + +from . import views + +urlpatterns = [ + path( + "/get_random_arm/", + views.GetRandomArmView.as_view(), + name="get_random_arm", + ), +] diff --git a/randomisation/utils.py b/randomisation/utils.py new file mode 100644 index 0000000..37d7447 --- /dev/null +++ b/randomisation/utils.py @@ -0,0 +1,48 @@ +import random + +from jsonschema import validate +from jsonschema.exceptions import ValidationError + +from randomisation.models import StrataMatrix + + +def validate_stratification_data(strategy, data): + try: + schema = { + "type": "object", + "properties": {}, + "required": [strata.slug for strata in strategy.stratas.all()], + "additionalProperties": False, + } + + for strata in strategy.stratas.all(): + options = [option.description for option in strata.options.all()] + schema["properties"][strata.slug] = {"type": "string", "enum": options} + + validate(instance=data, schema=schema) + except ValidationError as e: + return e.message + + +def get_random_stratification_arm(strategy, data): + matrix, created = StrataMatrix.objects.get_or_create( + strategy=strategy, strata_data=data + ) + + if created: + study_arms = [arm.name for arm in strategy.arms.all()] + random.shuffle(study_arms) + random_arms = study_arms + matrix.arm_order = ",".join(study_arms) + else: + random_arms = matrix.arm_order.split(",") + + arm = random_arms[matrix.next_index] + + if matrix.next_index + 1 == len(random_arms): + matrix.delete() + else: + matrix.next_index += 1 + matrix.save() + + return arm diff --git a/randomisation/views.py b/randomisation/views.py new file mode 100644 index 0000000..2fa397d --- /dev/null +++ b/randomisation/views.py @@ -0,0 +1,32 @@ +from django.http import JsonResponse +from rest_framework import status +from rest_framework.views import APIView + +from randomisation.models import Strategy +from randomisation.utils import ( + get_random_stratification_arm, + validate_stratification_data, +) + +# TODO: add a endpoint to call the validate_stratification_data + + +class GetRandomArmView(APIView): + def post(self, request, *args, **kwargs): + strategy_id = kwargs["strategy_id"] + strategy = Strategy.objects.get(id=strategy_id) + + # TODO: serializer for request.data? + + error = validate_stratification_data(strategy, request.data) + if error: + return JsonResponse( + data={"error": error}, status=status.HTTP_400_BAD_REQUEST + ) + + arm = get_random_stratification_arm(strategy, request.data) + + return JsonResponse(data={"arm": arm}, status=status.HTTP_200_OK) + + +# TODO: add tests for views diff --git a/setup.py b/setup.py index 764dc50..8261170 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ "sentry-sdk==1.14.0", "dj-database-url==0.5.0", "boto3", + "jsonschema==4.21.1", ], classifiers=[ "Development Status :: 4 - Beta",