diff --git a/.flake8 b/.flake8 index acfc542..f0135db 100644 --- a/.flake8 +++ b/.flake8 @@ -3,5 +3,7 @@ exclude = .git, __pycache__, build, + venv, + .venv env max-complexity = 10 diff --git a/docs/source/Contributing.rst b/docs/source/Contributing.rst index 2cf25f8..729b3c1 100644 --- a/docs/source/Contributing.rst +++ b/docs/source/Contributing.rst @@ -71,6 +71,7 @@ To run uvicorn locally: export ENV_NAME='local' export AWS_DEFAULT_REGION='us-west-2' export AIND_AIRFLOW_PARAM_PREFIX='/aind/dev/airflow/variables/job_types' + export AIND_SSO_SECRET_NAME='/aind/dev/data_transfer_service/sso/secrets' uvicorn aind_data_transfer_service.server:app --host 0.0.0.0 --port 5000 --reload You can now access aind-data-transfer-service at diff --git a/pyproject.toml b/pyproject.toml index 406a7b4..f85a06d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,8 @@ server = [ 'wtforms', 'requests==2.25.0', 'openpyxl', - 'python-logging-loki' + 'python-logging-loki', + 'authlib' ] [tool.setuptools.packages.find] diff --git a/src/aind_data_transfer_service/server.py b/src/aind_data_transfer_service/server.py index 5510751..fa7b1a5 100644 --- a/src/aind_data_transfer_service/server.py +++ b/src/aind_data_transfer_service/server.py @@ -15,6 +15,7 @@ __version__ as aind_data_transfer_models_version, ) from aind_data_transfer_models.core import SubmitJobRequest, validation_context +from authlib.integrations.starlette_client import OAuth from botocore.exceptions import ClientError from fastapi import Request from fastapi.responses import JSONResponse, StreamingResponse @@ -23,6 +24,9 @@ from openpyxl import load_workbook from pydantic import SecretStr, ValidationError from starlette.applications import Starlette +from starlette.config import Config +from starlette.middleware.sessions import SessionMiddleware +from starlette.responses import RedirectResponse from starlette.routing import Route from aind_data_transfer_service import OPEN_DATA_BUCKET_NAME @@ -95,6 +99,27 @@ def get_project_names() -> List[str]: return project_names +def set_oauth() -> OAuth: + """Set up OAuth for the service""" + secrets_client = boto3.client("secretsmanager") + secret_response = secrets_client.get_secret_value( + SecretId=os.getenv("AIND_SSO_SECRET_NAME") + ) + secret_value = json.loads(secret_response["SecretString"]) + for secrets in secret_value: + os.environ[secrets] = secret_value[secrets] + config = Config() + oauth = OAuth(config) + oauth.register( + name="azure", + client_id=config("CLIENT_ID"), + client_secret=config("CLIENT_SECRET"), + server_metadata_url=config("AUTHORITY"), + client_kwargs={"scope": "openid email profile"}, + ) + return oauth + + def get_job_types(version: Optional[str] = None) -> List[str]: """Get a list of job_types""" params = get_parameter_infos(version) @@ -1090,6 +1115,60 @@ def get_parameter(request: Request): ) +async def admin(request: Request): + """Get admin page if authenticated, else redirect to login.""" + user = request.session.get("user") + if os.getenv("ENV_NAME") == "local": + user = {"name": "local user"} + if user: + return templates.TemplateResponse( + name="admin.html", + context=( + { + "request": request, + "project_names_url": project_names_url, + "user_name": user.get("name", "unknown"), + "user_email": user.get("email", "unknown"), + } + ), + ) + return RedirectResponse(url="/login") + + +async def login(request: Request): + """Redirect to Azure login page""" + oauth = set_oauth() + redirect_uri = request.url_for("auth") + response = await oauth.azure.authorize_redirect(request, redirect_uri) + return response + + +async def logout(request: Request): + """Logout user and clear session""" + request.session.pop("user", None) + return RedirectResponse(url="/") + + +async def auth(request: Request): + """Authenticate user and store user info in session""" + oauth = set_oauth() + try: + token = await oauth.azure.authorize_access_token(request) + user = token.get("userinfo") + if not user: + raise ValueError("User info not found in access token.") + request.session["user"] = dict(user) + except Exception as error: + return JSONResponse( + content={ + "message": "Error Logging In", + "data": {"error": f"{error.__class__.__name__}{error.args}"}, + }, + status_code=500, + ) + return RedirectResponse(url="/admin") + + routes = [ Route("/", endpoint=index, methods=["GET", "POST"]), Route("/api/validate_csv", endpoint=validate_csv_legacy, methods=["POST"]), @@ -1132,6 +1211,11 @@ def get_parameter(request: Request): endpoint=download_job_template, methods=["GET"], ), + Route("/login", login, methods=["GET"]), + Route("/logout", logout, methods=["GET"]), + Route("/auth", auth, methods=["GET"]), + Route("/admin", admin, methods=["GET"]), ] app = Starlette(routes=routes) +app.add_middleware(SessionMiddleware, secret_key=None) diff --git a/src/aind_data_transfer_service/templates/admin.html b/src/aind_data_transfer_service/templates/admin.html new file mode 100644 index 0000000..fba27bc --- /dev/null +++ b/src/aind_data_transfer_service/templates/admin.html @@ -0,0 +1,36 @@ + + + + + + + {% block title %} {% endblock %} AIND Data Transfer Service Admin + + + + +
+

Admin

+
Hello {{user_name}}, welcome to the admin page
+
Email: {{user_email}}
+
+ + diff --git a/src/aind_data_transfer_service/templates/index.html b/src/aind_data_transfer_service/templates/index.html index 7bd6ba1..4d28aca 100644 --- a/src/aind_data_transfer_service/templates/index.html +++ b/src/aind_data_transfer_service/templates/index.html @@ -49,7 +49,8 @@ Job Parameters | Job Submit Template | Project Names | - Help + Help | + Admin
diff --git a/src/aind_data_transfer_service/templates/job_params.html b/src/aind_data_transfer_service/templates/job_params.html index 6ccd82e..0ad5fcb 100644 --- a/src/aind_data_transfer_service/templates/job_params.html +++ b/src/aind_data_transfer_service/templates/job_params.html @@ -34,7 +34,8 @@ Job Parameters | Job Submit Template | Project Names | - Help + Help | + Admin

diff --git a/src/aind_data_transfer_service/templates/job_status.html b/src/aind_data_transfer_service/templates/job_status.html index c316901..cf64972 100644 --- a/src/aind_data_transfer_service/templates/job_status.html +++ b/src/aind_data_transfer_service/templates/job_status.html @@ -32,7 +32,8 @@ Job Parameters | Job Submit Template | Project Names | - Help + Help | + Admin
diff --git a/tests/resources/get_secrets_response.json b/tests/resources/get_secrets_response.json new file mode 100644 index 0000000..328bd9d --- /dev/null +++ b/tests/resources/get_secrets_response.json @@ -0,0 +1,19 @@ +{ + "ARN": "arn_value", + "Name": "secret_name", + "VersionId": "version_id", + "SecretString": "{\"CLIENT_ID\":\"client_id\",\"CLIENT_SECRET\":\"client_secret\",\"AUTHORITY\":\"https://authority\"}", + "VersionStages": ["AWSCURRENT"], + "CreatedDate": "2025-04-15T16:44:07.279000Z", + "ResponseMetadata": { + "RequestId": "request_id", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "2b090d64-c92d-48c5-a43a-abf5696c815e", + "content-type": "application/x-amz-json-1.1", + "content-length": "748", + "date": "Wed, 23 Apr 2025 21:19:04 GMT" + }, + "RetryAttempts": 0 + } +} \ No newline at end of file diff --git a/tests/test_server.py b/tests/test_server.py index 92779d7..1590cb0 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -8,7 +8,7 @@ from datetime import datetime, timedelta, timezone from io import BytesIO from pathlib import Path, PurePosixPath -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from aind_data_schema_models.modalities import Modality from aind_data_schema_models.platforms import Platform @@ -22,8 +22,9 @@ V0036JobProperties, ) from aind_data_transfer_models.trigger import TriggerConfigModel, ValidJobType +from authlib.integrations.starlette_client import OAuthError from botocore.exceptions import ClientError -from fastapi.responses import StreamingResponse +from fastapi.responses import JSONResponse, StreamingResponse from fastapi.testclient import TestClient from pydantic import SecretStr from requests import Response @@ -85,6 +86,9 @@ GET_PARAMETER_RESPONSE = ( TEST_DIRECTORY / "resources" / "get_parameter_response.json" ) +GET_SECRETS_RESPONSE = ( + TEST_DIRECTORY / "resources" / "get_secrets_response.json" +) class TestServer(unittest.TestCase): @@ -112,6 +116,7 @@ class TestServer(unittest.TestCase): "AIND_AIRFLOW_SERVICE_USER": "airflow_user", "AIND_AIRFLOW_SERVICE_PASSWORD": "airflow_password", "AIND_AIRFLOW_PARAM_PREFIX": "/param_prefix", + "AIND_SSO_SECRET_NAME": "/secret/name", } with open(SAMPLE_CSV, "r") as file: @@ -135,6 +140,9 @@ class TestServer(unittest.TestCase): with open(GET_PARAMETER_RESPONSE) as f: get_parameter_response = json.load(f) + with open(GET_SECRETS_RESPONSE) as f: + get_secrets_response = json.load(f) + expected_job_configs = deepcopy(TestJobConfigs.expected_job_configs) for config in expected_job_configs: config.aws_param_store_name = None @@ -1562,6 +1570,52 @@ def test_job_params(self): self.assertEqual(response.status_code, 200) self.assertIn("Job Parameters", response.text) + @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True) + @patch("fastapi.Request.session") + def test_admin(self, mock_session: MagicMock): + """Tests that the admin page renders when user is authenticated.""" + expected_user = {"name": "test_user", "email": "test_email"} + mock_session.get.return_value = expected_user + with TestClient(app) as client: + response = client.get("/admin") + mock_session.get.assert_called_once_with("user") + self.assertEqual(response.status_code, 200) + self.assertIn("Admin", response.text) + self.assertIn("test_user", response.text) + + @patch.dict( + os.environ, {**EXAMPLE_ENV_VAR1, "ENV_NAME": "local"}, clear=True + ) + def test_admin_local(self): + """Tests that the admin page renders when user is authenticated.""" + with TestClient(app) as client: + response = client.get("/admin") + self.assertEqual(response.status_code, 200) + self.assertIn("Admin", response.text) + self.assertIn("local user", response.text) + + @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True) + @patch("fastapi.Request.session") + @patch("aind_data_transfer_service.server.RedirectResponse") + def test_admin_unauthenticated( + self, mock_redirect_response: MagicMock, mock_session: MagicMock + ): + """Tests that the admin page redirects to login if user is not + authenticated.""" + expected_user = None + mock_session.get.return_value = expected_user + mock_redirect_response.return_value = JSONResponse( + content={ + "message": "Redirecting to login page", + "data": None, + }, + status_code=307, + ) + with TestClient(app) as client: + response = client.get("/admin") + mock_redirect_response.assert_called_once_with(url="/login") + self.assertEqual(response.status_code, 307) + @patch("aind_data_transfer_service.server.JobUploadTemplate") @patch("logging.Logger.exception") def test_download_invalid_job_template( @@ -2374,6 +2428,124 @@ def test_validate_json_error( mock_get_job_types.assert_called_once_with("v2") self.assertEqual(2, mock_get_project_names.call_count) + @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True) + @patch("boto3.client") + @patch("aind_data_transfer_service.server.OAuth") + def test_login( + self, mock_set_oauth: MagicMock, mock_secrets_client: MagicMock + ): + """Tests the login function.""" + mock_set_oauth.return_value.azure.authorize_redirect = AsyncMock( + return_value=JSONResponse( + content={ + "message": "mock_redirect_url", + }, + status_code=200, + ) + ) + mock_secrets_client.return_value.get_secret_value.return_value = ( + self.get_secrets_response + ) + with TestClient(app) as client: + response = client.get("/login") + mock_secrets_client.assert_called_with("secretsmanager") + mock_secrets_client.return_value.get_secret_value.assert_called_with( + SecretId="/secret/name" + ) + mock_set_oauth.return_value.register.assert_called_with( + name="azure", + client_id="client_id", + client_secret="client_secret", + server_metadata_url="https://authority", + client_kwargs={"scope": "openid email profile"}, + ) + mock_oauth = mock_set_oauth.return_value + mock_azure = mock_oauth.azure + mock_azure.authorize_redirect.assert_called_once() + self.assertEqual(response.status_code, 200) + + @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True) + @patch("aind_data_transfer_service.server.RedirectResponse") + @patch("fastapi.Request.session") + def test_logout(self, mock_session: MagicMock, mock_redirect: MagicMock): + """Tests logout clears user from session and redirects to index.""" + expected_user = {"name": "test_user", "email": "test_email"} + mock_session.get.return_value = expected_user + mock_redirect.return_value = JSONResponse( + content={"message": "Redirecting to index"}, + status_code=307, + ) + with TestClient(app) as client: + response = client.get("/logout") + mock_redirect.assert_called_once_with(url="/") + self.assertEqual(response.status_code, 307) + + @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True) + @patch("boto3.client") + @patch("aind_data_transfer_service.server.OAuth") + def test_auth( + self, mock_set_oauth: MagicMock, mock_secrets_client: MagicMock + ): + """Tests the auth callback function.""" + mock_set_oauth.return_value.azure.authorize_access_token = AsyncMock( + return_value={"userinfo": {"some_user": "info"}} + ) + mock_secrets_client.return_value.get_secret_value.return_value = ( + self.get_secrets_response + ) + with TestClient(app) as client: + response = client.get("/auth") + mock_oauth = mock_set_oauth.return_value + mock_azure = mock_oauth.azure + mock_azure.authorize_access_token.assert_called_once() + self.assertEqual(response.status_code, 200) + + @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True) + @patch("boto3.client") + @patch("aind_data_transfer_service.server.OAuth") + def test_auth_error( + self, mock_set_oauth: MagicMock, mock_secrets_client: MagicMock + ): + """Tests an error in the auth callback function.""" + mock_set_oauth.return_value.azure.authorize_access_token = AsyncMock( + side_effect=OAuthError("Error Logging In") + ) + mock_secrets_client.return_value.get_secret_value.return_value = ( + self.get_secrets_response + ) + with TestClient(app) as client: + response = client.get("/auth") + expected_response = { + "message": "Error Logging In", + "data": {"error": "OAuthError('Error Logging In: ',)"}, + } + self.assertEqual(response.status_code, 500) + self.assertEqual(response.json(), expected_response) + + @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True) + @patch("boto3.client") + @patch("aind_data_transfer_service.server.OAuth") + def test_auth_error_userinfo( + self, mock_set_oauth: MagicMock, mock_secrets_client: MagicMock + ): + """Tests the auth callback function when userinfo is not provided.""" + mock_set_oauth.return_value.azure.authorize_access_token = AsyncMock( + return_value={"invalid": {"some_user": "info"}} + ) + mock_secrets_client.return_value.get_secret_value.return_value = ( + self.get_secrets_response + ) + with TestClient(app) as client: + response = client.get("/auth") + expected_response = { + "message": "Error Logging In", + "data": { + "error": "ValueError('User info not found in access token.',)" + }, + } + self.assertEqual(response.status_code, 500) + self.assertEqual(response.json(), expected_response) + if __name__ == "__main__": unittest.main()