FAQ | This is a LIVE service | Changelog

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • uis/devops/django/api-gateway-auth
1 result
Show changes
Commits on Source (4)
Showing
with 277 additions and 170 deletions
...@@ -7,4 +7,4 @@ insert_final_newline = true ...@@ -7,4 +7,4 @@ insert_final_newline = true
# 4 space indentation # 4 space indentation
[*.py] [*.py]
indent_style = space indent_style = space
indent_size = 4 indent_size = 4
\ No newline at end of file
[flake8] [flake8]
extend-ignore = E203
max-line-length = 99 max-line-length = 99
exclude = venv,env,.tox,*/migrations/*,*/frontend/*,build/*,.venv exclude = venv,env,.tox,*/migrations/*,*/frontend/*,build/*,.venv
...@@ -2,19 +2,21 @@ ...@@ -2,19 +2,21 @@
include: include:
- project: 'uis/devops/continuous-delivery/ci-templates' - project: 'uis/devops/continuous-delivery/ci-templates'
file: '/auto-devops/common-pipeline.yml' file: '/auto-devops/common-pipeline.yml'
ref: v2.4.0 ref: v3.0.0
variables: variables:
# we don't have an application for DAST to run against, so disable it # we don't have an application for DAST to run against, so disable it
DAST_DISABLED: "true" DAST_DISABLED: "1"
# we don't build any Docker images
BUILD_DISABLED: "1"
python:tox: python:tox:
parallel: parallel:
matrix: matrix:
- DJANGO_VERSION: ["django3.2", "django4.1", "django4.2"] - DJANGO_VERSION: ["django3.2", "django4.1", "django4.2"]
PYTHON_VERSION: !reference [".python:versions"] PYTHON_VERSION: ["3.10", "3.11"]
- TOX_ENV: flake8 TOX_ENV: py3 # Bare "py3" required to upload coverage and unit test reports.
- TOX_ENV: py3 # Bare "py3" required to upload coverage and unit test reports.
variables: variables:
TOX_ENV: py3-$DJANGO_VERSION TOX_ENV: py3-$DJANGO_VERSION
TOX_OPTS: -e $TOX_ENV TOX_OPTS: -e $TOX_ENV
TOX_ADDITIONAL_REQUIREMENTS: poetry
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
args:
- --unsafe
- id: check-json
- id: check-toml
- id: check-xml
- id: check-added-large-files
- id: check-executables-have-shebangs
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
- id: mixed-line-ending
- id: pretty-format-json
args:
- --autofix
- --no-sort-keys
- id: debug-statements
- repo: https://github.com/python-poetry/poetry
rev: 1.5.1
hooks:
- id: poetry-check
- repo: https://github.com/editorconfig-checker/editorconfig-checker.python
rev: 2.7.2
hooks:
- id: editorconfig-checker
args: ["-disable-indent-size"]
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
- repo: https://github.com/timothycrosley/isort
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.4.1
hooks:
- id: mypy
additional_dependencies: ["types-PyYAML"]
# Changelog # Changelog
## 0.0.4
Added:
- Repackaged using poetry.
- Aligned code style with black and isort by means of pre-commit checks.
## 0.0.3 ## 0.0.3
Added: Added:
......
# This Dockerfile is just used for testing purposes and therefore builds in tox
# to the image.
FROM python:3.10
WORKDIR /usr/src/app
ADD . .
RUN pip install --upgrade pip
RUN pip install tox
RUN pip install -r requirements.txt
RUN python setup.py sdist bdist_wheel
...@@ -8,6 +8,17 @@ This is a library which contains: ...@@ -8,6 +8,17 @@ This is a library which contains:
* a set of permissions classes allowing the authentication information provided by the API Gateway * a set of permissions classes allowing the authentication information provided by the API Gateway
to be used in authorization decisions throughout an API-based app, to be used in authorization decisions throughout an API-based app,
## Library developer quick start
This library is packages using `poetry` and uses our [common Python CI
pipeline](https://gitlab.developers.cam.ac.uk/uis/devops/continuous-delivery/ci-templates/-/blob/master/auto-devops/python.md).
Make sue that `poetry` is installed and bootstrap your local environment via:
```console
$ poetry install
$ poetry run pre-commit install
```
## Required settings ## Required settings
The following Django settings are required to allow this library to be used: The following Django settings are required to allow this library to be used:
......
from .api_gateway_auth import ( # noqa: F401 from .api_gateway_auth import APIGatewayAuthentication # noqa: F401
APIGatewayAuthentication, from .api_gateway_auth import APIGatewayAuthenticationDetails # noqa: F401
APIGatewayAuthenticationDetails,
)
default_app_config = 'apigatewayauth.apps.APIGatewayAuthConfig' default_app_config = "apigatewayauth.apps.APIGatewayAuthConfig"
from typing import Optional, Set
from dataclasses import dataclass from dataclasses import dataclass
from rest_framework.request import Request from typing import Optional, Set
from identitylib.identifiers import Identifier
from identitylib.identifiers import Identifier
from rest_framework import authentication from rest_framework import authentication
from rest_framework.exceptions import AuthenticationFailed from rest_framework.exceptions import AuthenticationFailed
from rest_framework.request import Request
@dataclass(eq=True) @dataclass(eq=True)
...@@ -31,26 +31,25 @@ class APIGatewayAuthentication(authentication.BaseAuthentication): ...@@ -31,26 +31,25 @@ class APIGatewayAuthentication(authentication.BaseAuthentication):
""" """
def authenticate(self, request: Request): def authenticate(self, request: Request):
if not request.META.get('HTTP_X_API_ORG_NAME', None): if not request.META.get("HTTP_X_API_ORG_NAME", None):
# bail early if we look like we're not being called by the API Gateway # bail early if we look like we're not being called by the API Gateway
return None return None
if not request.META.get('HTTP_X_API_OAUTH2_USER', None): if not request.META.get("HTTP_X_API_OAUTH2_USER", None):
raise AuthenticationFailed('Could not authenticate using x-api-* headers') raise AuthenticationFailed("Could not authenticate using x-api-* headers")
try: try:
principal_identifier = Identifier.from_string( principal_identifier = Identifier.from_string(
request.META['HTTP_X_API_OAUTH2_USER'], request.META["HTTP_X_API_OAUTH2_USER"], find_by_alias=True
find_by_alias=True
) )
except Exception: except Exception:
raise AuthenticationFailed('Invalid principal identifier') raise AuthenticationFailed("Invalid principal identifier")
auth = APIGatewayAuthenticationDetails( auth = APIGatewayAuthenticationDetails(
principal_identifier=principal_identifier, principal_identifier=principal_identifier,
scopes=set(filter(bool, request.META.get('HTTP_X_API_OAUTH2_SCOPE', '').split(' '))), scopes=set(filter(bool, request.META.get("HTTP_X_API_OAUTH2_SCOPE", "").split(" "))),
# the following will only be populated for confidential clients # the following will only be populated for confidential clients
app_id=request.META.get('HTTP_X_API_DEVELOPER_APP_ID', None), app_id=request.META.get("HTTP_X_API_DEVELOPER_APP_ID", None),
client_id=request.META.get('HTTP_X_API_OAUTH2_CLIENT_ID', None) client_id=request.META.get("HTTP_X_API_OAUTH2_CLIENT_ID", None),
) )
# the first item in the tuple represents the 'user' which we don't have when we've # the first item in the tuple represents the 'user' which we don't have when we've
# used the API Gateway for authentication. # used the API Gateway for authentication.
......
...@@ -2,5 +2,5 @@ from django.apps import AppConfig ...@@ -2,5 +2,5 @@ from django.apps import AppConfig
class APIGatewayAuthConfig(AppConfig): class APIGatewayAuthConfig(AppConfig):
name = 'apigatewayauth' name = "apigatewayauth"
verbose_name = 'API Gateway Authentication App' verbose_name = "API Gateway Authentication App"
from logging import getLogger
from typing import Set from typing import Set
from django.core.cache import cache from django.core.cache import cache
from rest_framework import permissions, request
from identitylib.identifiers import IdentifierSchemes from identitylib.identifiers import IdentifierSchemes
from rest_framework import permissions, request
from ucamlookup.ibisclient import IbisException, PersonMethods
from ucamlookup.utils import get_connection from ucamlookup.utils import get_connection
from ucamlookup.ibisclient import PersonMethods, IbisException
from logging import getLogger
from .api_gateway_auth import APIGatewayAuthenticationDetails from .api_gateway_auth import APIGatewayAuthenticationDetails
from .permissions_spec import ( from .permissions_spec import (
get_groups_with_permission, get_permission_spec, get_principals_with_permission get_groups_with_permission,
get_permission_spec,
get_principals_with_permission,
) )
LOG = getLogger(__name__) LOG = getLogger(__name__)
...@@ -23,6 +23,7 @@ class Disallowed(permissions.BasePermission): ...@@ -23,6 +23,7 @@ class Disallowed(permissions.BasePermission):
class to stop routes being added which accidentally expose data. class to stop routes being added which accidentally expose data.
""" """
def has_permission(self, request, view): def has_permission(self, request, view):
return False return False
...@@ -37,7 +38,7 @@ class IsResourceOwningPrincipal(permissions.BasePermission): ...@@ -37,7 +38,7 @@ class IsResourceOwningPrincipal(permissions.BasePermission):
""" """
message = 'Please authenticate as the owning user using the API Gateway.' message = "Please authenticate as the owning user using the API Gateway."
@staticmethod @staticmethod
def get_queryset_for_principal(request, base_object): def get_queryset_for_principal(request, base_object):
...@@ -46,39 +47,36 @@ class IsResourceOwningPrincipal(permissions.BasePermission): ...@@ -46,39 +47,36 @@ class IsResourceOwningPrincipal(permissions.BasePermission):
if required if required
""" """
if not getattr(request, 'should_limit_to_resource_owning_principal', False): if not getattr(request, "should_limit_to_resource_owning_principal", False):
return base_object.objects.all() return base_object.objects.all()
if not isinstance(getattr(request, 'auth', None), APIGatewayAuthenticationDetails): if not isinstance(getattr(request, "auth", None), APIGatewayAuthenticationDetails):
return base_object.objects.none() return base_object.objects.none()
if not callable(getattr(base_object, 'get_queryset_for_principal', None)): if not callable(getattr(base_object, "get_queryset_for_principal", None)):
raise ValueError( raise ValueError(f"{base_object} does not implement get_queryset_for_principal")
f'{base_object} does not implement get_queryset_for_principal'
)
return base_object.get_queryset_for_principal(request.auth.principal_identifier) return base_object.get_queryset_for_principal(request.auth.principal_identifier)
def has_permission(self, request, view): def has_permission(self, request, view):
# we cannot determine permissions ownership on list routes, but rely on # we cannot determine permissions ownership on list routes, but rely on
# `get_queryset_for_principal` to be used to filter the queryset appropriately # `get_queryset_for_principal` to be used to filter the queryset appropriately
if isinstance(getattr(request, 'auth', None), APIGatewayAuthenticationDetails): if isinstance(getattr(request, "auth", None), APIGatewayAuthenticationDetails):
setattr(request, 'should_limit_to_resource_owning_principal', True) setattr(request, "should_limit_to_resource_owning_principal", True)
return True return True
return False return False
def has_object_permission(self, request, view, obj): def has_object_permission(self, request, view, obj):
if not isinstance(getattr(request, 'auth', None), APIGatewayAuthenticationDetails): if not isinstance(getattr(request, "auth", None), APIGatewayAuthenticationDetails):
return False return False
is_owned_by = getattr(obj, 'is_owned_by', None) is_owned_by = getattr(obj, "is_owned_by", None)
if not callable(is_owned_by): if not callable(is_owned_by):
LOG.warn(f'Unable to determine ownership for {obj}') LOG.warn(f"Unable to determine ownership for {obj}")
return False return False
return is_owned_by(request.auth.principal_identifier) return is_owned_by(request.auth.principal_identifier)
def HasAnyScope(*required_scopes): def HasAnyScope(*required_scopes):
class HasAnyScopesPermission(permissions.BasePermission): class HasAnyScopesPermission(permissions.BasePermission):
""" """
A permissions class which enforces that the given request has any of the given scopes. A permissions class which enforces that the given request has any of the given scopes.
...@@ -88,7 +86,7 @@ def HasAnyScope(*required_scopes): ...@@ -88,7 +86,7 @@ def HasAnyScope(*required_scopes):
message = f'Request must have one of the following scope(s) {" ".join(required_scopes)}' message = f'Request must have one of the following scope(s) {" ".join(required_scopes)}'
def has_permission(self, request, view): def has_permission(self, request, view):
request_scopes = getattr(getattr(request, 'auth', {}), 'scopes', set()) request_scopes = getattr(getattr(request, "auth", {}), "scopes", set())
return len(set(required_scopes) & request_scopes) > 0 return len(set(required_scopes) & request_scopes) > 0
def has_object_permission(self, request, view, obj): def has_object_permission(self, request, view, obj):
...@@ -98,7 +96,6 @@ def HasAnyScope(*required_scopes): ...@@ -98,7 +96,6 @@ def HasAnyScope(*required_scopes):
def SpecifiedPermission(permission: str): def SpecifiedPermission(permission: str):
class HasSpecifiedPermission(permissions.BasePermission): class HasSpecifiedPermission(permissions.BasePermission):
""" """
A permissions class which ensures that the principal has the correct permissions A permissions class which ensures that the principal has the correct permissions
...@@ -106,7 +103,7 @@ def SpecifiedPermission(permission: str): ...@@ -106,7 +103,7 @@ def SpecifiedPermission(permission: str):
""" """
message = f'Authenticated principal does not have permission {permission}' message = f"Authenticated principal does not have permission {permission}"
def has_permission(self, request, view): def has_permission(self, request, view):
principals_with_permission = get_principals_with_permission(permission) principals_with_permission = get_principals_with_permission(permission)
...@@ -114,16 +111,19 @@ def SpecifiedPermission(permission: str): ...@@ -114,16 +111,19 @@ def SpecifiedPermission(permission: str):
return True return True
if request.auth.principal_identifier.scheme != IdentifierSchemes.CRSID: if request.auth.principal_identifier.scheme != IdentifierSchemes.CRSID:
LOG.warn('Can only determine group membership for principals identified by CRSID') LOG.warn("Can only determine group membership for principals identified by CRSID")
return False return False
# special case for people identified by crsid - check whether they are in a # special case for people identified by crsid - check whether they are in a
# lookup group within our list of identities for permission # lookup group within our list of identities for permission
groups_with_permission = get_groups_with_permission(permission) groups_with_permission = get_groups_with_permission(permission)
lookup_group_ids = set([ lookup_group_ids = set(
identifier.value for identifier in groups_with_permission [
if identifier.scheme == IdentifierSchemes.LOOKUP_GROUP identifier.value
]) for identifier in groups_with_permission
if identifier.scheme == IdentifierSchemes.LOOKUP_GROUP
]
)
if not lookup_group_ids: if not lookup_group_ids:
return False return False
...@@ -148,14 +148,14 @@ def SpecifiedPermission(permission: str): ...@@ -148,14 +148,14 @@ def SpecifiedPermission(permission: str):
is_in_group = False is_in_group = False
try: try:
group_list = PersonMethods( group_list = PersonMethods(get_connection()).getGroups(
get_connection() scheme="crsid", identifier=crsid
).getGroups(scheme="crsid", identifier=crsid) )
is_in_group = any( is_in_group = any(
(group.groupid for group in group_list if group.groupid in group_ids) (group.groupid for group in group_list if group.groupid in group_ids)
) )
except IbisException as err: except IbisException as err:
LOG.warn(f'Failed to get Lookup groups for {crsid} due to {err}') LOG.warn(f"Failed to get Lookup groups for {crsid} due to {err}")
return False return False
cache.set(cache_key, is_in_group, timeout=600) cache.set(cache_key, is_in_group, timeout=600)
...@@ -172,7 +172,12 @@ def get_permissions_for_request(req: request.Request): ...@@ -172,7 +172,12 @@ def get_permissions_for_request(req: request.Request):
""" """
return [ return (
permission_name for permission_name in get_permission_spec().keys() if [
SpecifiedPermission(permission_name)().has_permission(req, None) permission_name
] if isinstance(req.auth, APIGatewayAuthenticationDetails) else [] for permission_name in get_permission_spec().keys()
if SpecifiedPermission(permission_name)().has_permission(req, None)
]
if isinstance(req.auth, APIGatewayAuthenticationDetails)
else []
)
from typing import List, Dict, Set from typing import Dict, List, Set
from django.core.cache import cache
from django.conf import settings from django.conf import settings
from yaml import safe_load from django.core.cache import cache
from geddit import geddit from geddit import geddit
from identitylib.identifiers import Identifier from identitylib.identifiers import Identifier
from yaml import safe_load
PERMISSIONS_CACHE_KEY = "__PERMISSION_CACHE__"
PERMISSIONS_CACHE_KEY = '__PERMISSION_CACHE__'
def get_permission_spec() -> Dict[str, Dict[str, List]]: def get_permission_spec() -> Dict[str, Dict[str, List]]:
...@@ -38,7 +38,7 @@ def get_principals_with_permission(permission_name: str) -> Set[Identifier]: ...@@ -38,7 +38,7 @@ def get_principals_with_permission(permission_name: str) -> Set[Identifier]:
return set( return set(
map( map(
lambda identifier_str: Identifier.from_string(identifier_str, find_by_alias=True), lambda identifier_str: Identifier.from_string(identifier_str, find_by_alias=True),
get_permission_spec().get(permission_name, {}).get('principals', []), get_permission_spec().get(permission_name, {}).get("principals", []),
) )
) )
...@@ -52,6 +52,6 @@ def get_groups_with_permission(permission_name: str) -> Set[Identifier]: ...@@ -52,6 +52,6 @@ def get_groups_with_permission(permission_name: str) -> Set[Identifier]:
return set( return set(
map( map(
lambda identifier_str: Identifier.from_string(identifier_str, find_by_alias=True), lambda identifier_str: Identifier.from_string(identifier_str, find_by_alias=True),
get_permission_spec().get(permission_name, {}).get('groups', []), get_permission_spec().get(permission_name, {}).get("groups", []),
) )
) )
default_app_config = 'apigatewayauth.tests.mocks.apps.MockAPIGatewayAuthConfig' default_app_config = "apigatewayauth.tests.mocks.apps.MockAPIGatewayAuthConfig"
...@@ -2,5 +2,5 @@ from django.apps import AppConfig ...@@ -2,5 +2,5 @@ from django.apps import AppConfig
class MockAPIGatewayAuthConfig(AppConfig): class MockAPIGatewayAuthConfig(AppConfig):
name = 'apigatewayauth.tests.mocks' name = "apigatewayauth.tests.mocks"
verbose_name = 'Mock testing app' verbose_name = "Mock testing app"
...@@ -2,19 +2,23 @@ from django.db import migrations, models ...@@ -2,19 +2,23 @@ from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
initial = True initial = True
dependencies = [ dependencies: list[str] = []
]
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
name='TestModel', name="TestModel",
fields=[ fields=[
('name', models.TextField(primary_key=True, serialize=False, verbose_name='Name')), (
('isAdmin', models.BooleanField(verbose_name='Is Admin')), "name",
('principal_identifier', models.TextField(verbose_name='Principal identifier')), models.TextField(primary_key=True, serialize=False, verbose_name="Name"),
),
("isAdmin", models.BooleanField(verbose_name="Is Admin")),
(
"principal_identifier",
models.TextField(verbose_name="Principal identifier"),
),
], ],
), ),
] ]
...@@ -13,6 +13,6 @@ class TestModel(models.Model): ...@@ -13,6 +13,6 @@ class TestModel(models.Model):
principal_identifier__iexact=principal_identifier.value, principal_identifier__iexact=principal_identifier.value,
) )
name = models.TextField('Name', 'name', primary_key=True) name = models.TextField("Name", "name", primary_key=True)
is_admin = models.BooleanField('Is Admin', 'isAdmin') is_admin = models.BooleanField("Is Admin", "isAdmin")
principal_identifier = models.TextField('Principal identifier') principal_identifier = models.TextField("Principal identifier")
from typing import Dict
from functools import wraps from functools import wraps
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
import yaml from typing import Dict
import yaml
from django.test import override_settings from django.test import override_settings
def override_permission_spec(permissions_spec: Dict[str, Dict[str, str]]): def override_permission_spec(permissions_spec: Dict[str, Dict[str, set[str]]]):
""" """
A decorator which allows the permissions specification to be mocked, allowing a A decorator which allows the permissions specification to be mocked, allowing a
permission to only be enabled for the given identities. permission to only be enabled for the given identities.
""" """
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
def wrapped_function(*args, **kwargs): def wrapped_function(*args, **kwargs):
with NamedTemporaryFile('w+') as temp_file: with NamedTemporaryFile("w+") as temp_file:
yaml.dump(permissions_spec, temp_file.file) yaml.dump(permissions_spec, temp_file.file)
with override_settings(PERMISSIONS_SPECIFICATION_URL=temp_file.name): with override_settings(PERMISSIONS_SPECIFICATION_URL=temp_file.name):
func(*args, **kwargs) func(*args, **kwargs)
return wrapped_function return wrapped_function
return decorator return decorator
...@@ -15,4 +15,4 @@ ABSTRACT_DATA_READER: ...@@ -15,4 +15,4 @@ ABSTRACT_DATA_READER:
THOUGHT_CREATOR: THOUGHT_CREATOR:
principals: principals:
- 1234@application.api.apps.cam.ac.uk - 1234@application.api.apps.cam.ac.uk
\ No newline at end of file
...@@ -2,19 +2,23 @@ from django.db import migrations, models ...@@ -2,19 +2,23 @@ from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
initial = True initial = True
dependencies = [ dependencies: list[str] = []
]
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
name='TestModel', name="TestModel",
fields=[ fields=[
('name', models.TextField(primary_key=True, serialize=False, verbose_name='Name')), (
('isAdmin', models.BooleanField(verbose_name='Is Admin')), "name",
('principal_identifier', models.TextField(verbose_name='Principal identifier')), models.TextField(primary_key=True, serialize=False, verbose_name="Name"),
),
("isAdmin", models.BooleanField(verbose_name="Is Admin")),
(
"principal_identifier",
models.TextField(verbose_name="Principal identifier"),
),
], ],
), ),
] ]
from django.test import TestCase from django.test import TestCase
from identitylib.identifiers import Identifier, IdentifierSchemes from identitylib.identifiers import Identifier, IdentifierSchemes
from rest_framework.test import APIRequestFactory
from rest_framework.exceptions import AuthenticationFailed from rest_framework.exceptions import AuthenticationFailed
from rest_framework.test import APIRequestFactory
from apigatewayauth.api_gateway_auth import ( from apigatewayauth.api_gateway_auth import (
APIGatewayAuthentication, APIGatewayAuthenticationDetails APIGatewayAuthentication,
APIGatewayAuthenticationDetails,
) )
class APIGatewayAuthTestCase(TestCase): class APIGatewayAuthTestCase(TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
...@@ -18,10 +18,9 @@ class APIGatewayAuthTestCase(TestCase): ...@@ -18,10 +18,9 @@ class APIGatewayAuthTestCase(TestCase):
def request_with_headers(self, headers={}): def request_with_headers(self, headers={}):
parsed_headers = { parsed_headers = {
f'HTTP_{key.upper().replace("-", "_")}': value f'HTTP_{key.upper().replace("-", "_")}': value for key, value in headers.items()
for key, value in headers.items()
} }
return self.request_factory.get('/', **parsed_headers) return self.request_factory.get("/", **parsed_headers)
def test_bails_early_without_api_org(self): def test_bails_early_without_api_org(self):
self.assertIsNone( self.assertIsNone(
...@@ -30,104 +29,135 @@ class APIGatewayAuthTestCase(TestCase): ...@@ -30,104 +29,135 @@ class APIGatewayAuthTestCase(TestCase):
def test_throws_without_auth_details(self): def test_throws_without_auth_details(self):
with self.assertRaisesMessage( with self.assertRaisesMessage(
AuthenticationFailed, 'Could not authenticate using x-api-* headers' AuthenticationFailed, "Could not authenticate using x-api-* headers"
): ):
self.auth.authenticate(self.request_with_headers({"x-api-org-name": "test"})) self.auth.authenticate(self.request_with_headers({"x-api-org-name": "test"}))
def test_throws_without_principal_identifier(self): def test_throws_without_principal_identifier(self):
with self.assertRaisesMessage( with self.assertRaisesMessage(
AuthenticationFailed, 'Could not authenticate using x-api-* headers' AuthenticationFailed, "Could not authenticate using x-api-* headers"
): ):
self.auth.authenticate(self.request_with_headers({ self.auth.authenticate(
"x-api-org-name": "test", self.request_with_headers(
"x-api-developer-app-class": "public" {"x-api-org-name": "test", "x-api-developer-app-class": "public"}
})) )
)
def test_throws_with_bad_principal_identifier(self): def test_throws_with_bad_principal_identifier(self):
with self.assertRaisesMessage( with self.assertRaisesMessage(AuthenticationFailed, "Invalid principal identifier"):
AuthenticationFailed, 'Invalid principal identifier' self.auth.authenticate(
): self.request_with_headers(
self.auth.authenticate(self.request_with_headers({ {
"x-api-org-name": "test", "x-api-org-name": "test",
"x-api-developer-app-class": "public", "x-api-developer-app-class": "public",
"x-api-oauth2-user": "Monty Dawson" "x-api-oauth2-user": "Monty Dawson",
})) }
)
)
def test_can_use_any_identifier_scheme_in_principal_identifier(self): def test_can_use_any_identifier_scheme_in_principal_identifier(self):
for scheme in IdentifierSchemes.get_registered_schemes(): for scheme in IdentifierSchemes.get_registered_schemes():
_, auth = self.auth.authenticate(self.request_with_headers({ _, auth = self.auth.authenticate(
"x-api-org-name": "test", self.request_with_headers(
"x-api-developer-app-class": "public", {
"x-api-oauth2-user": str(Identifier("1000", scheme)) "x-api-org-name": "test",
})) "x-api-developer-app-class": "public",
"x-api-oauth2-user": str(Identifier("1000", scheme)),
}
)
)
self.assertEqual(auth.principal_identifier, Identifier("1000", scheme)) self.assertEqual(auth.principal_identifier, Identifier("1000", scheme))
def test_throws_with_unknown_identifier_type(self): def test_throws_with_unknown_identifier_type(self):
with self.assertRaisesMessage( with self.assertRaisesMessage(AuthenticationFailed, "Invalid principal identifier"):
AuthenticationFailed, 'Invalid principal identifier' self.auth.authenticate(
): self.request_with_headers(
self.auth.authenticate(self.request_with_headers({ {
"x-api-org-name": "test", "x-api-org-name": "test",
"x-api-developer-app-class": "public", "x-api-developer-app-class": "public",
"x-api-oauth2-user": 'wgd23@gmail.com' "x-api-oauth2-user": "wgd23@gmail.com",
})) }
)
)
def test_returns_client_details_for_valid_auth(self): def test_returns_client_details_for_valid_auth(self):
user, auth = self.auth.authenticate(self.request_with_headers({ user, auth = self.auth.authenticate(
"x-api-org-name": "test", self.request_with_headers(
"x-api-developer-app-class": "public", {
"x-api-oauth2-user": str(Identifier('a123', IdentifierSchemes.CRSID)) "x-api-org-name": "test",
})) "x-api-developer-app-class": "public",
"x-api-oauth2-user": str(Identifier("a123", IdentifierSchemes.CRSID)),
}
)
)
self.assertIsNone(user) self.assertIsNone(user)
self.assertEqual( self.assertEqual(
auth, auth,
APIGatewayAuthenticationDetails( APIGatewayAuthenticationDetails(
Identifier('a123', IdentifierSchemes.CRSID), Identifier("a123", IdentifierSchemes.CRSID),
set(), set(),
None, None,
None, None,
) ),
) )
def test_will_pass_through_scopes(self): def test_will_pass_through_scopes(self):
_, auth = self.auth.authenticate(self.request_with_headers({ _, auth = self.auth.authenticate(
"x-api-org-name": "test", self.request_with_headers(
"x-api-developer-app-class": "public", {
"x-api-oauth2-user": str(Identifier('a123', IdentifierSchemes.CRSID)), "x-api-org-name": "test",
"x-api-oauth2-scope": ( "x-api-developer-app-class": "public",
"https://api.apps.cam.ac.uk/a.readonly https://api.apps.cam.ac.uk/b" "x-api-oauth2-user": str(Identifier("a123", IdentifierSchemes.CRSID)),
"x-api-oauth2-scope": (
"https://api.apps.cam.ac.uk/a.readonly https://api.apps.cam.ac.uk/b"
),
}
) )
})) )
self.assertEqual( self.assertEqual(
auth, auth,
APIGatewayAuthenticationDetails( APIGatewayAuthenticationDetails(
Identifier('a123', IdentifierSchemes.CRSID), Identifier("a123", IdentifierSchemes.CRSID),
set(['https://api.apps.cam.ac.uk/a.readonly', 'https://api.apps.cam.ac.uk/b']), set(
[
"https://api.apps.cam.ac.uk/a.readonly",
"https://api.apps.cam.ac.uk/b",
]
),
None, None,
None, None,
), ),
) )
def test_will_pass_through_app_and_client_ids(self): def test_will_pass_through_app_and_client_ids(self):
_, auth = self.auth.authenticate(self.request_with_headers({ _, auth = self.auth.authenticate(
"x-api-org-name": "test", self.request_with_headers(
"x-api-developer-app-class": "confidential", {
"x-api-oauth2-user": str(Identifier('a123', IdentifierSchemes.CRSID)), "x-api-org-name": "test",
"x-api-oauth2-scope": ( "x-api-developer-app-class": "confidential",
"https://api.apps.cam.ac.uk/a.readonly https://api.apps.cam.ac.uk/b" "x-api-oauth2-user": str(Identifier("a123", IdentifierSchemes.CRSID)),
), "x-api-oauth2-scope": (
"x-api-developer-app-id": "app-uuid-mock", "https://api.apps.cam.ac.uk/a.readonly https://api.apps.cam.ac.uk/b"
"x-api-oauth2-client-id": "client-id-uuid-mock", ),
})) "x-api-developer-app-id": "app-uuid-mock",
"x-api-oauth2-client-id": "client-id-uuid-mock",
}
)
)
self.assertEqual( self.assertEqual(
auth, auth,
APIGatewayAuthenticationDetails( APIGatewayAuthenticationDetails(
Identifier('a123', IdentifierSchemes.CRSID), Identifier("a123", IdentifierSchemes.CRSID),
set(['https://api.apps.cam.ac.uk/a.readonly', 'https://api.apps.cam.ac.uk/b']), set(
'app-uuid-mock', [
'client-id-uuid-mock' "https://api.apps.cam.ac.uk/a.readonly",
) "https://api.apps.cam.ac.uk/b",
]
),
"app-uuid-mock",
"client-id-uuid-mock",
),
) )