diff --git a/.flake8 b/.flake8 index acfc542..f0135db 100644 --- a/.flake8 +++ b/.flake8 @@ -3,5 +3,7 @@ exclude = .git, __pycache__, build, + venv, + .venv env max-complexity = 10 diff --git a/docs/source/Contributing.rst b/docs/source/Contributing.rst index 2cf25f8..729b3c1 100644 --- a/docs/source/Contributing.rst +++ b/docs/source/Contributing.rst @@ -71,6 +71,7 @@ To run uvicorn locally: export ENV_NAME='local' export AWS_DEFAULT_REGION='us-west-2' export AIND_AIRFLOW_PARAM_PREFIX='/aind/dev/airflow/variables/job_types' + export AIND_SSO_SECRET_NAME='/aind/dev/data_transfer_service/sso/secrets' uvicorn aind_data_transfer_service.server:app --host 0.0.0.0 --port 5000 --reload You can now access aind-data-transfer-service at diff --git a/pyproject.toml b/pyproject.toml index 406a7b4..f85a06d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,8 @@ server = [ 'wtforms', 'requests==2.25.0', 'openpyxl', - 'python-logging-loki' + 'python-logging-loki', + 'authlib' ] [tool.setuptools.packages.find] diff --git a/src/aind_data_transfer_service/server.py b/src/aind_data_transfer_service/server.py index 5510751..fa7b1a5 100644 --- a/src/aind_data_transfer_service/server.py +++ b/src/aind_data_transfer_service/server.py @@ -15,6 +15,7 @@ __version__ as aind_data_transfer_models_version, ) from aind_data_transfer_models.core import SubmitJobRequest, validation_context +from authlib.integrations.starlette_client import OAuth from botocore.exceptions import ClientError from fastapi import Request from fastapi.responses import JSONResponse, StreamingResponse @@ -23,6 +24,9 @@ from openpyxl import load_workbook from pydantic import SecretStr, ValidationError from starlette.applications import Starlette +from starlette.config import Config +from starlette.middleware.sessions import SessionMiddleware +from starlette.responses import RedirectResponse from starlette.routing import Route from aind_data_transfer_service import OPEN_DATA_BUCKET_NAME @@ -95,6 +99,27 @@ def get_project_names() -> List[str]: return project_names +def set_oauth() -> OAuth: + """Set up OAuth for the service""" + secrets_client = boto3.client("secretsmanager") + secret_response = secrets_client.get_secret_value( + SecretId=os.getenv("AIND_SSO_SECRET_NAME") + ) + secret_value = json.loads(secret_response["SecretString"]) + for secrets in secret_value: + os.environ[secrets] = secret_value[secrets] + config = Config() + oauth = OAuth(config) + oauth.register( + name="azure", + client_id=config("CLIENT_ID"), + client_secret=config("CLIENT_SECRET"), + server_metadata_url=config("AUTHORITY"), + client_kwargs={"scope": "openid email profile"}, + ) + return oauth + + def get_job_types(version: Optional[str] = None) -> List[str]: """Get a list of job_types""" params = get_parameter_infos(version) @@ -1090,6 +1115,60 @@ def get_parameter(request: Request): ) +async def admin(request: Request): + """Get admin page if authenticated, else redirect to login.""" + user = request.session.get("user") + if os.getenv("ENV_NAME") == "local": + user = {"name": "local user"} + if user: + return templates.TemplateResponse( + name="admin.html", + context=( + { + "request": request, + "project_names_url": project_names_url, + "user_name": user.get("name", "unknown"), + "user_email": user.get("email", "unknown"), + } + ), + ) + return RedirectResponse(url="/login") + + +async def login(request: Request): + """Redirect to Azure login page""" + oauth = set_oauth() + redirect_uri = request.url_for("auth") + response = await oauth.azure.authorize_redirect(request, redirect_uri) + return response + + +async def logout(request: Request): + """Logout user and clear session""" + request.session.pop("user", None) + return RedirectResponse(url="/") + + +async def auth(request: Request): + """Authenticate user and store user info in session""" + oauth = set_oauth() + try: + token = await oauth.azure.authorize_access_token(request) + user = token.get("userinfo") + if not user: + raise ValueError("User info not found in access token.") + request.session["user"] = dict(user) + except Exception as error: + return JSONResponse( + content={ + "message": "Error Logging In", + "data": {"error": f"{error.__class__.__name__}{error.args}"}, + }, + status_code=500, + ) + return RedirectResponse(url="/admin") + + routes = [ Route("/", endpoint=index, methods=["GET", "POST"]), Route("/api/validate_csv", endpoint=validate_csv_legacy, methods=["POST"]), @@ -1132,6 +1211,11 @@ def get_parameter(request: Request): endpoint=download_job_template, methods=["GET"], ), + Route("/login", login, methods=["GET"]), + Route("/logout", logout, methods=["GET"]), + Route("/auth", auth, methods=["GET"]), + Route("/admin", admin, methods=["GET"]), ] app = Starlette(routes=routes) +app.add_middleware(SessionMiddleware, secret_key=None) diff --git a/src/aind_data_transfer_service/templates/admin.html b/src/aind_data_transfer_service/templates/admin.html new file mode 100644 index 0000000..fba27bc --- /dev/null +++ b/src/aind_data_transfer_service/templates/admin.html @@ -0,0 +1,36 @@ + + +
+ + + +