diff --git a/contentcuration/contentcuration/tests/viewsets/test_user.py b/contentcuration/contentcuration/tests/viewsets/test_user.py index 1f801bd182..4e050888d2 100644 --- a/contentcuration/contentcuration/tests/viewsets/test_user.py +++ b/contentcuration/contentcuration/tests/viewsets/test_user.py @@ -6,6 +6,7 @@ from contentcuration.models import Change from contentcuration.tests import testdata from contentcuration.tests.base import StudioAPITestCase +from contentcuration.tests.helpers import reverse_with_query from contentcuration.tests.viewsets.base import generate_create_event from contentcuration.tests.viewsets.base import generate_delete_event from contentcuration.tests.viewsets.base import SyncTestMixin @@ -367,6 +368,28 @@ def test_fetch_users_no_permissions(self): self.assertEqual(response.status_code, 200, response.content) self.assertEqual(response.json(), []) + def test_remove_self_with_invalid_channel_id_returns_bad_request(self): + self.client.force_authenticate(user=self.user) + response = self.client.delete( + reverse_with_query( + "channeluser-remove-self", + kwargs={"pk": self.user.id}, + query={"channel_id": "not-a-valid-uuid"}, + ) + ) + self.assertEqual(response.status_code, 400, response.content) + + def test_remove_self_with_missing_channel_returns_not_found(self): + self.client.force_authenticate(user=self.user) + response = self.client.delete( + reverse_with_query( + "channeluser-remove-self", + kwargs={"pk": self.user.id}, + query={"channel_id": "00000000-0000-0000-0000-000000000000"}, + ) + ) + self.assertEqual(response.status_code, 404, response.content) + class MarkReadNotificationsTimestampTestCase(StudioAPITestCase): def setUp(self): diff --git a/contentcuration/contentcuration/viewsets/user.py b/contentcuration/contentcuration/viewsets/user.py index 81855739aa..126a342319 100644 --- a/contentcuration/contentcuration/viewsets/user.py +++ b/contentcuration/contentcuration/viewsets/user.py @@ -1,5 +1,6 @@ import csv import logging +import uuid from datetime import date from functools import reduce @@ -341,6 +342,10 @@ def remove_self(self, request, pk=None): if not channel_id: return HttpResponseBadRequest("Channel ID is required.") + try: + channel_id = uuid.UUID(channel_id).hex + except ValueError: + return HttpResponseBadRequest("Invalid channel ID") try: channel = Channel.objects.get(id=channel_id)