Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add randomisation tests and docstrings #177

Merged
merged 1 commit into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 73 additions & 2 deletions randomisation/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from django.test import TestCase

from randomisation.models import StrataMatrix
from randomisation.utils import (
get_random_stratification_arm,
validate_stratification_data,
Expand All @@ -11,22 +12,30 @@
from .utils import create_test_strategy


# TODO: add docstrings to tests
class TestValidateStratificationData(TestCase):
def setUp(self):
self.strategy = create_test_strategy()

def test_stratification_validation_valid_data(self):
"""
Test with valid data
"""
error = validate_stratification_data(
self.strategy, {"age-group": "18-29", "province": "WC"}
)
self.assertIsNone(error)

def test_stratification_validation_missing_key(self):
"""
Test with a missing strata key
"""
error = validate_stratification_data(self.strategy, {"age-group": "18-29"})
self.assertEqual(error, "'province' is a required property")

def test_stratification_validation_extra_key(self):
"""
Test with strata key that is not configured
"""
error = validate_stratification_data(
self.strategy, {"age-group": "18-29", "province": "WC", "extra": "key"}
)
Expand All @@ -36,17 +45,79 @@ def test_stratification_validation_extra_key(self):
)

def test_stratification_validation_invalid_option(self):
"""
Test with invalid strata value
"""
error = validate_stratification_data(
self.strategy, {"age-group": "18-29", "province": "FS"}
)
self.assertEqual(error, "'FS' is not one of ['WC', 'GT']")


class TestGetRandomStratification(TestCase):
def test_random_arm(self):
"""
Test that it returns a random arm that matches the first item in the matrix
object it created and next index is set
"""
strategy = create_test_strategy()

data = {"age-group": "18-29", "province": "WC"}
random_arm = get_random_stratification_arm(strategy, data)

strata_arm = StrataMatrix.objects.first()

self.assertEqual(random_arm, strata_arm.arm_order.split(",")[0])
self.assertEqual(strata_arm.next_index, 1)

def test_random_arm_with_matrix(self):
"""
Check the next arm from the existing matrix record and the next_idnex is updated
"""
strategy = create_test_strategy()

data = {"age-group": "18-29", "province": "WC"}

StrataMatrix.objects.create(
strategy=strategy,
strata_data=data,
next_index=1,
arm_order="Arm 1,Arm 2,Arm 3",
)

random_arm = get_random_stratification_arm(strategy, data)

strata_arm = StrataMatrix.objects.first()

self.assertEqual(random_arm, "Arm 2")
self.assertEqual(strata_arm.next_index, 2)

def test_random_arm_out_of_index(self):
"""
Test for out of index to delete the order after maximum arm
"""

strategy = create_test_strategy()

data = {"age-group": "18-29", "province": "WC"}

StrataMatrix.objects.create(
strategy=strategy,
strata_data=data,
next_index=2,
arm_order="Arm 1,Arm 2,Arm 3",
)

random_arm = get_random_stratification_arm(strategy, data)

# TODO: add more tests for randomisation
self.assertEqual(StrataMatrix.objects.count(), 0)
self.assertEqual(random_arm, "Arm 3")

def test_stratification_balancing(self):
"""
Testing that after 100 iterations that the resuls are balanced accross the
configured arms in total and per strata group
"""
strategy = create_test_strategy()

totals = defaultdict(int)
Expand Down
8 changes: 8 additions & 0 deletions randomisation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@


def validate_stratification_data(strategy, data):
"""
Validates that the data dict received is valid compared to the strategy
configuration.
"""
try:
schema = {
"type": "object",
Expand All @@ -25,6 +29,10 @@ def validate_stratification_data(strategy, data):


def get_random_stratification_arm(strategy, data):
"""
Get or create a strata matrix object for the given data and returns the next arm,
it will delete the matrix record if the last arm in the matrix was returned.
"""
matrix, created = StrataMatrix.objects.get_or_create(
strategy=strategy, strata_data=data
)
Expand Down
Loading