Skip to content

Commit

Permalink
fixes rollback creation of User when phone number exists
Browse files Browse the repository at this point in the history
Because `Contact.save` passes the `using` argument
to `transaction.atomic(using=using)`, any outer `transaction.atomic`
involving `Contact.save` (by extension `Contact.create`) will not properly
rollback unless the same database is passed through the `using` argument.
  • Loading branch information
smirolo committed Jul 29, 2024
1 parent 8e0a8e4 commit 85542f9
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 39 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
# The short X.Y version
version = '0.9'
# The full version, including alpha/beta/rc tags
release = '0.9.6'
release = '0.9.7-dev'


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion signup/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@
PEP 386-compliant version number for the signup django app.
"""

__version__ = '0.9.6'
__version__ = '0.9.7-dev'
25 changes: 14 additions & 11 deletions signup/api/users.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, DjaoDjin inc.
# Copyright (c) 2024, DjaoDjin inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -26,7 +26,7 @@

import pyotp
from django.contrib.auth import logout as auth_logout
from django.db import transaction, IntegrityError
from django.db import router, transaction, IntegrityError
from django.db.models import Q
from django.contrib.auth import update_session_auth_hash, get_user_model
from rest_framework import generics, parsers, status
Expand Down Expand Up @@ -295,7 +295,8 @@ def perform_destroy(self, instance):
email = '%s%s' % (slug, look.group(1))
# We are deleting a `User` model. Let's unlink the `Contact`
# info but otherwise leave the poor-man's CRM's data intact.
with transaction.atomic():
#pylint:disable=protected-access
with transaction.atomic(using=user._state.db):
user.contacts.all().update(user=None)
self.delete_records(user)
requires_logout = (self.request.user == user)
Expand Down Expand Up @@ -366,9 +367,10 @@ def perform_update(self, serializer):
raise ValidationError({'detail':
_("Phone verification code does not match.")})

with transaction.atomic():
user = self.get_object()
try:
user = self.get_object()
try:
#pylint:disable=protected-access
with transaction.atomic(using=user._state.db):
if user.pk:
saves_user = False
if slug:
Expand Down Expand Up @@ -464,8 +466,8 @@ def perform_update(self, serializer):
Contact.objects.update_or_create(
slug=self.kwargs.get(self.lookup_url_kwarg),
defaults=update_fields)
except IntegrityError as err:
handle_uniq_error(err)
except IntegrityError as err:
handle_uniq_error(err)
# A little patchy but it works. Otherwise we would need to override
# `update` as well.
#pylint:disable=pointless-statement,protected-access
Expand Down Expand Up @@ -683,7 +685,7 @@ def create(self, request, *args, **kwargs):
serializer.validated_data.get('get_phone'))
lang = serializer.validated_data.get('lang',
serializer.validated_data.get('get_lang'))
with transaction.atomic():
with transaction.atomic(using=router.db_for_write(Contact)):
try:
user = user_model.objects.get(
email__iexact=serializer.validated_data.get('email'))
Expand Down Expand Up @@ -985,7 +987,8 @@ def update(self, request, *args, **kwargs):

def perform_update(self, serializer):
notification_slugs = serializer.validated_data.get('notifications', [])
with transaction.atomic():
#pylint:disable=protected-access
with transaction.atomic(using=self.user._state.db):
self.user.notifications.clear()
for notification_slug in six.iterkeys(self.get_notifications(
user=self.user)):
Expand Down Expand Up @@ -1052,7 +1055,7 @@ def post(self, request, *args, **kwargs):
"", "", ""))
location = self.request.build_absolute_uri(location)
user_model = self.user_queryset.model
with transaction.atomic():
with transaction.atomic(using=router.db_for_write(Contact)):
try:
user = user_model.objects.get(
username=self.kwargs.get(self.lookup_url_kwarg))
Expand Down
45 changes: 29 additions & 16 deletions signup/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,21 +182,32 @@ def find_candidate(self, **cleaned_data):
username = cleaned_data.get('username')
email = cleaned_data.get('email')
phone = cleaned_data.get('phone')
if not username:
username = email
if not username:
username = phone
try:
user = self.model.objects.find_user(username)

if not email:
email = user.email
if username:
try:
user = self.model.objects.find_user(username)
if not email:
email = user.email
except self.model.DoesNotExist:
user = None

if not user and email:
try:
user = self.model.objects.find_user(email)
except self.model.DoesNotExist:
user = None

except self.model.DoesNotExist:
user = None
if not user and phone:
try:
user = self.model.objects.find_user(phone)
if not email:
email = user.email
except self.model.DoesNotExist:
user = None

return user, email


def auth_check_disabled(self, user):
auth_disabled = get_disabled_authentication(self.request, user)
if auth_disabled:
Expand Down Expand Up @@ -299,24 +310,24 @@ def run_pipeline(self):
# Login, Verify: Check if auth is disabled for User, or
# auth disabled globally if we only have a Contact
self.auth_check_disabled(user)
LOGGER.debug("[run_pipeline] auth_check_disabled user=%s", user)
LOGGER.debug("[run_pipeline] auth_check_disabled(user=%s)", user)

# Login, Verify: Auth rate-limiter
self.check_user_throttles(self.request, user)
LOGGER.debug("[run_pipeline] check_user_throttles user=%s", user)
LOGGER.debug("[run_pipeline] check_user_throttles(user=%s)", user)

# Login, Verify, Register:
# Redirects if email requires SSO
self.check_sso_required(email)
LOGGER.debug("[run_pipeline] check_sso_required email=%s", email)
LOGGER.debug("[run_pipeline] check_sso_required(email=%s)", email)

# Login: If login by verifying e-mail or phone, send code
# Else check password
#pylint:disable=assignment-from-none
user_with_backend = self.check_password(user, **cleaned_data)
LOGGER.debug(
"[run_pipeline] check_password(%s) returned user_with_backend=%s",
user, user_with_backend)
LOGGER.debug("[run_pipeline] check_password(user=%s, cleaned_data=%s)"\
" returned user_with_backend=%s",
user, cleaned_data, user_with_backend)

# Login, Verify: If required, check 2FA
self.auth_check_mfa(user, **cleaned_data)
Expand All @@ -331,6 +342,8 @@ def run_pipeline(self):
# want those to fall through `check_password`, end up here
# and render the link unusable before someone can click on it.
#pylint:disable=assignment-from-none
LOGGER.debug("[run_pipeline] create_user(cleaned_data=%s)",
cleaned_data)
user_with_backend = self.create_user(**cleaned_data)

if not user_with_backend:
Expand Down
16 changes: 8 additions & 8 deletions signup/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def create_user_from_email(self, email, password=None, **kwargs):
username_base = username
while trials < 10:
try:
with transaction.atomic():
with transaction.atomic(using=self._db):
return super(ActivatedUserManager, self).create_user(
username, email=email, password=password, **kwargs)
except IntegrityError as exp:
Expand Down Expand Up @@ -142,7 +142,7 @@ def create_user_from_phone(self, phone, password=None, **kwargs):
username_base = username
while trials < 10:
try:
with transaction.atomic():
with transaction.atomic(using=self._db):
return super(ActivatedUserManager, self).create_user(
username, password=password, **kwargs)
except IntegrityError as exp:
Expand Down Expand Up @@ -183,7 +183,7 @@ def create_user(self, username, email=None, password=None, **kwargs):
# By definition, all users created through a SSO provider
# are active.
password = "*****"
with transaction.atomic():
with transaction.atomic(using=self._db):
if username:
user = super(ActivatedUserManager, self).create_user(
username, email=email, password=password, **kwargs)
Expand Down Expand Up @@ -334,7 +334,7 @@ def prepare_email_verification(self, email, user=None, at_time=None,

if contact:
created = False
with transaction.atomic():
with transaction.atomic(using=self._db):
# We have to wrap in a transaction.atomic here, otherwise
# we end-up with a TransactionManager error when Contact.slug
# already exists in db and we generate new one.
Expand Down Expand Up @@ -407,7 +407,7 @@ def prepare_phone_verification(self, phone, user=None, at_time=None,

if contact:
created = False
with transaction.atomic():
with transaction.atomic(using=self._db):
# We have to wrap in a transaction.atomic here, otherwise
# we end-up with a TransactionManager error when Contact.slug
# already exists in db and we generate new one.
Expand Down Expand Up @@ -503,7 +503,7 @@ def activate_user(self, verification_key,
'email_verification_key': token.email_verification_key,
'phone_verification_key': token.phone_verification_key})
user_model = get_user_model()
with transaction.atomic():
with transaction.atomic(using=self._db):
if token.email_verification_key == verification_key:
token.email_verification_key = None
token.email_verified_at = at_time
Expand Down Expand Up @@ -718,10 +718,10 @@ def save(self, force_insert=False, force_update=False,
except IntegrityError as err:
if 'uniq' not in str(err).lower():
raise
handle_uniq_error(err)
handle_uniq_error(err) # could also be due to email or phone
except DRFValidationError as err:
if not 'slug' in err.detail:
raise err
raise
if len(slug_base) + 8 > max_length:
slug_base = slug_base[:(max_length - 8)]
self.slug = generate_random_slug(
Expand Down
6 changes: 4 additions & 2 deletions signup/views/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def form_valid(self, form):
if not contact:
contact = self.user.contacts.order_by('pk').first()
failed = False
with transaction.atomic():
#pylint:disable=protected-access
with transaction.atomic(using=self.user._state.db):
# `form.save(commit=False)` will copy the form fields values
# to the instance without committing to the database.
# `update_db_row` will commit to the database.
Expand Down Expand Up @@ -174,7 +175,8 @@ class UserNotificationsView(UserMixin, UpdateView):
template_name = 'users/notifications.html'

def form_valid(self, form):
with transaction.atomic():
#pylint:disable=protected-access
with transaction.atomic(using=self.user._state.db):
notifications = self.get_initial().get('notifications')
self.user.notifications.clear()
for notification_slug, enabled in six.iteritems(form.cleaned_data):
Expand Down

0 comments on commit 85542f9

Please sign in to comment.