diff --git a/tests/test_using.py b/tests/test_using.py index 5b2200a6..96b6baff 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -1421,6 +1421,59 @@ class Params: dict(one=None, two=False, three=None, four=True, five=None), ) + def test_chaining_traits_and_related_with_nested_factories(self): + class TestRelatedObject: + def __init__(self, obj=None, attr=None, nested=None): + obj.related = self + self.nested = nested + self.attr = attr + + class TestNestedObject: + def __init__(self, attr=None): + self.attr = attr + + class TestNestedObjectFactory(factory.Factory): + class Meta: + model = TestNestedObject + attr = 1 + + class TestRelatedObjectFactory(factory.Factory): + class Meta: + model = TestRelatedObject + attr = 1 + + nested = factory.SubFactory(TestNestedObjectFactory, attr=None) + + class TestObjectFactory(factory.Factory): + class Meta: + model = TestObject + + class Params: + with_related = factory.Trait( + related_obj=factory.RelatedFactory( + TestRelatedObjectFactory, + factory_related_name='obj', + nested__attr=2, + ), + ) + with_related_nested_override = factory.Trait( + with_related=True, + related_obj__nested__attr=3, + ) + + obj = TestObjectFactory.build(with_related_nested_override=True) + self.assertEqual(1, obj.related.attr) + self.assertEqual(3, obj.related.nested.attr) + obj = TestObjectFactory.build(with_related_nested_override=True, related_obj__nested__attr=4) + self.assertEqual(1, obj.related.attr) + self.assertEqual(4, obj.related.nested.attr) + obj = TestObjectFactory.build(with_related=True) + self.assertEqual(1, obj.related.attr) + self.assertEqual(2, obj.related.nested.attr) + obj = TestObjectFactory.build(with_related_nested_override=False) + with self.assertRaises(AttributeError): + obj.related + def test_prevent_cyclic_traits(self): with self.assertRaises(errors.CyclicDefinitionError):