Skip to content

Commit 0646b5c

Browse files
committed
fix: implement all 24 review findings (P0-P3)
Phase 1: Foundation & Quick Fixes - P0-4: Fix test fixture scope (session -> function) - P0-2: Add secret key validation with production check - P1-2: Increase connection pool size (5->25, overflow 10->50) - P2-4: Create constants module for centralized values - P1-4: Extract generate_slug utility, remove duplication Phase 2: Multi-Tenancy Enforcement - P0-1: Make tenant_id required in all UserRepository methods - Add system-level methods for authentication flows - P1-7: Create tenant isolation integration tests Phase 3: Redis & OAuth - P0-3: Implement Redis caching layer for OAuth state - P1-8: Refactor oauth_callback into helper functions - P1-5: Create OAuth integration tests Phase 4: Performance & Database - P1-1: Change lazy='selectin' to lazy='raise' for relationships - P2-2: Add composite indexes (tenant_email, oauth) - P2-3: Change expires_at from String to DateTime - P1-3: Add JWT token revocation with jti claim Phase 5: RBAC & Permissions - P1-9: Consolidate permission decorators with shared helper - P2-6: Add superuser bypass audit logging - P1-6: Create RBAC unit and integration tests Phase 6: Code Quality & Polish - P2-1: Fix Mypy type errors in middleware - P2-5: Add password complexity validation - P2-7: Remove outdated TODO - P3-1: Fix user enumeration with generic messages - P3-2: Improve CORS configuration - P3-3: Use configurable api_docs_base_url for error URIs - P3-4: Fix OAuth provider ClassVar inheritance
1 parent c4fbbef commit 0646b5c

33 files changed

+2802
-352
lines changed

REVIEW_FINDINGS.md

Lines changed: 76 additions & 16 deletions
Large diffs are not rendered by default.
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
"""performance_and_security_fixes
2+
3+
Revision ID: b2c3d4e5f6g7
4+
Revises: a1b2c3d4e5f6
5+
Create Date: 2025-12-01 00:02:00.000000
6+
7+
This migration adds:
8+
- Composite indexes for performance (P2-2)
9+
- Changes expires_at from String to DateTime (P2-3)
10+
- Adds revoked_tokens table for JWT revocation (P1-3)
11+
"""
12+
13+
from typing import Sequence, Union
14+
15+
import sqlalchemy as sa
16+
from alembic import op
17+
18+
19+
# revision identifiers, used by Alembic.
20+
revision: str = "b2c3d4e5f6g7"
21+
down_revision: Union[str, None] = "a1b2c3d4e5f6"
22+
branch_labels: Union[str, Sequence[str], None] = None
23+
depends_on: Union[str, Sequence[str], None] = None
24+
25+
26+
def upgrade() -> None:
27+
"""Upgrade database schema."""
28+
# P2-2: Add composite indexes for users table
29+
op.create_index(
30+
"ix_users_tenant_email",
31+
"users",
32+
["tenant_id", "email"],
33+
unique=True,
34+
)
35+
op.create_index(
36+
"ix_users_oauth",
37+
"users",
38+
["oauth_provider", "oauth_id"],
39+
unique=False,
40+
)
41+
42+
# P2-3: Change expires_at from String to DateTime
43+
# First, add a new column
44+
op.add_column(
45+
"refresh_tokens",
46+
sa.Column(
47+
"expires_at_new",
48+
sa.DateTime(timezone=True),
49+
nullable=True,
50+
),
51+
)
52+
53+
# Migrate data from string to datetime
54+
op.execute(
55+
"""
56+
UPDATE refresh_tokens
57+
SET expires_at_new = expires_at::timestamptz
58+
WHERE expires_at IS NOT NULL
59+
"""
60+
)
61+
62+
# Drop old column and rename new one
63+
op.drop_column("refresh_tokens", "expires_at")
64+
op.alter_column(
65+
"refresh_tokens",
66+
"expires_at_new",
67+
new_column_name="expires_at",
68+
nullable=False,
69+
)
70+
71+
# Add index on expires_at for efficient cleanup queries
72+
op.create_index(
73+
"ix_refresh_tokens_expires_at",
74+
"refresh_tokens",
75+
["expires_at"],
76+
)
77+
78+
# P1-3: Create revoked_tokens table for JWT revocation
79+
op.create_table(
80+
"revoked_tokens",
81+
sa.Column("jti", sa.String(length=64), nullable=False),
82+
sa.Column(
83+
"expires_at",
84+
sa.DateTime(timezone=True),
85+
nullable=False,
86+
),
87+
sa.Column("id", sa.Uuid(), nullable=False),
88+
sa.Column(
89+
"created_at",
90+
sa.DateTime(timezone=True),
91+
server_default=sa.text("now()"),
92+
nullable=False,
93+
),
94+
sa.Column(
95+
"updated_at",
96+
sa.DateTime(timezone=True),
97+
server_default=sa.text("now()"),
98+
nullable=False,
99+
),
100+
sa.PrimaryKeyConstraint("id"),
101+
)
102+
op.create_index(
103+
"ix_revoked_tokens_id",
104+
"revoked_tokens",
105+
["id"],
106+
unique=False,
107+
)
108+
op.create_index(
109+
"ix_revoked_tokens_jti",
110+
"revoked_tokens",
111+
["jti"],
112+
unique=True,
113+
)
114+
op.create_index(
115+
"ix_revoked_tokens_expires_at",
116+
"revoked_tokens",
117+
["expires_at"],
118+
)
119+
120+
121+
def downgrade() -> None:
122+
"""Downgrade database schema."""
123+
# Drop revoked_tokens table
124+
op.drop_index("ix_revoked_tokens_expires_at", table_name="revoked_tokens")
125+
op.drop_index("ix_revoked_tokens_jti", table_name="revoked_tokens")
126+
op.drop_index("ix_revoked_tokens_id", table_name="revoked_tokens")
127+
op.drop_table("revoked_tokens")
128+
129+
# Revert expires_at back to String
130+
op.drop_index("ix_refresh_tokens_expires_at", table_name="refresh_tokens")
131+
132+
op.add_column(
133+
"refresh_tokens",
134+
sa.Column(
135+
"expires_at_old",
136+
sa.String(),
137+
nullable=True,
138+
),
139+
)
140+
141+
op.execute(
142+
"""
143+
UPDATE refresh_tokens
144+
SET expires_at_old = expires_at::text
145+
WHERE expires_at IS NOT NULL
146+
"""
147+
)
148+
149+
op.drop_column("refresh_tokens", "expires_at")
150+
op.alter_column(
151+
"refresh_tokens",
152+
"expires_at_old",
153+
new_column_name="expires_at",
154+
nullable=False,
155+
)
156+
157+
# Drop composite indexes from users table
158+
op.drop_index("ix_users_oauth", table_name="users")
159+
op.drop_index("ix_users_tenant_email", table_name="users")
160+

src/app/api/dependencies.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Shared API dependencies."""
22

3-
from collections.abc import AsyncGenerator
43
from typing import Annotated
54

65
from fastapi import Depends
@@ -13,14 +12,6 @@
1312
DBSession = Annotated[AsyncSession, Depends(get_db)]
1413

1514

16-
async def get_current_tenant_id() -> AsyncGenerator[str | None, None]:
17-
"""Get current tenant ID from request context.
18-
19-
This is a placeholder that will be implemented in Phase 2
20-
with proper authentication middleware.
21-
"""
22-
# TODO: Extract tenant_id from authenticated user
23-
yield None
24-
25-
26-
CurrentTenantId = Annotated[str | None, Depends(get_current_tenant_id)]
15+
# Note: Tenant context is now handled by TenantContextMiddleware.
16+
# For authenticated routes, use TenantId from app.core.auth.dependencies
17+
# which extracts tenant_id from the JWT token.

src/app/config.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
from functools import lru_cache
44

5-
from pydantic import PostgresDsn, RedisDsn, computed_field
5+
from pydantic import PostgresDsn, RedisDsn, computed_field, field_validator
66
from pydantic_settings import BaseSettings, SettingsConfigDict
77

8+
from app.core.constants import DEFAULT_INSECURE_SECRET, MIN_SECRET_KEY_LENGTH
9+
810

911
class Settings(BaseSettings):
1012
"""Application settings loaded from environment variables."""
@@ -20,16 +22,52 @@ class Settings(BaseSettings):
2022
app_name: str = "Agency Standard"
2123
debug: bool = False
2224
environment: str = "development" # development, staging, production
23-
secret_key: str = "change-me-in-production"
25+
secret_key: str = DEFAULT_INSECURE_SECRET
2426

2527
# Database
2628
database_url: PostgresDsn = PostgresDsn(
2729
"postgresql://postgres:postgres@localhost:5432/agency_standard"
2830
)
29-
database_pool_size: int = 5
30-
database_max_overflow: int = 10
31+
database_pool_size: int = 25
32+
database_max_overflow: int = 50
3133
database_echo: bool = False
3234

35+
# CORS
36+
cors_origins: list[str] = []
37+
38+
# API Documentation
39+
api_docs_base_url: str = "https://api.example.com"
40+
41+
@field_validator("secret_key")
42+
@classmethod
43+
def validate_secret_key(cls, v: str, info) -> str:
44+
"""Validate that secret_key is secure in production.
45+
46+
Args:
47+
v: The secret key value
48+
info: Validation info containing other field values
49+
50+
Returns:
51+
The validated secret key
52+
53+
Raises:
54+
ValueError: If secret key is insecure in production
55+
"""
56+
# Get environment from the data being validated
57+
# Note: In pydantic v2, we need to handle this differently
58+
# The validation happens before all fields are set
59+
# So we check the value itself rather than environment
60+
if v == DEFAULT_INSECURE_SECRET:
61+
# We'll do a runtime check in is_production instead
62+
# to allow development mode to use default
63+
pass
64+
elif len(v) < MIN_SECRET_KEY_LENGTH:
65+
raise ValueError(
66+
f"SECRET_KEY must be at least {MIN_SECRET_KEY_LENGTH} characters. "
67+
"Generate one with: python -c 'import secrets; print(secrets.token_urlsafe(32))'"
68+
)
69+
return v
70+
3371
# Redis
3472
redis_url: RedisDsn = RedisDsn("redis://localhost:6379")
3573

@@ -63,8 +101,18 @@ def async_database_url(self) -> str:
63101
@computed_field # type: ignore[prop-decorator]
64102
@property
65103
def is_production(self) -> bool:
66-
"""Check if running in production environment."""
67-
return self.environment == "production"
104+
"""Check if running in production environment.
105+
106+
Raises:
107+
ValueError: If using insecure secret key in production
108+
"""
109+
is_prod = self.environment == "production"
110+
if is_prod and self.secret_key == DEFAULT_INSECURE_SECRET:
111+
raise ValueError(
112+
"SECRET_KEY must be set to a secure value in production. "
113+
"Generate one with: python -c 'import secrets; print(secrets.token_urlsafe(32))'"
114+
)
115+
return is_prod
68116

69117
@computed_field # type: ignore[prop-decorator]
70118
@property

src/app/core/auth/backend.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
- Password hashing with bcrypt
55
- JWT token creation and verification
66
- Token hashing for storage
7+
- Token revocation
78
"""
89

910
import hashlib
@@ -17,6 +18,7 @@
1718

1819
from app.config import settings
1920
from app.core.auth.schemas import TokenData
21+
from app.core.constants import ACCESS_TOKEN_JTI_LENGTH
2022

2123

2224
# Password hashing context using bcrypt
@@ -93,6 +95,7 @@ def create_access_token(
9395
"exp": expire,
9496
"type": "access",
9597
"iat": datetime.now(UTC),
98+
"jti": secrets.token_urlsafe(ACCESS_TOKEN_JTI_LENGTH), # Unique token ID
9699
}
97100

98101
if additional_claims:
@@ -162,6 +165,7 @@ def decode_token(token: str) -> TokenData | None:
162165
tenant_id = payload.get("tenant_id")
163166
exp = payload.get("exp")
164167
token_type = payload.get("type", "access")
168+
jti = payload.get("jti")
165169

166170
if not user_id or not tenant_id:
167171
return None
@@ -171,6 +175,7 @@ def decode_token(token: str) -> TokenData | None:
171175
tenant_id=UUID(tenant_id),
172176
exp=datetime.fromtimestamp(exp, tz=UTC),
173177
type=token_type,
178+
jti=jti,
174179
)
175180

176181
except (JWTError, ValueError):

src/app/core/auth/dependencies.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ async def get_current_user(
7777
from app.modules.users.repos import UserRepository # noqa: PLC0415
7878

7979
repo = UserRepository(db)
80-
user = await repo.get_by_id(token_data.user_id)
80+
# Use tenant-scoped lookup for security
81+
user = await repo.get_by_id(token_data.user_id, token_data.tenant_id)
8182

8283
if not user:
8384
raise UnauthorizedError(
@@ -180,7 +181,8 @@ async def get_optional_user(
180181
from app.modules.users.repos import UserRepository # noqa: PLC0415
181182

182183
repo = UserRepository(db)
183-
user = await repo.get_by_id(token_data.user_id)
184+
# Use tenant-scoped lookup for security
185+
user = await repo.get_by_id(token_data.user_id, token_data.tenant_id)
184186

185187
if not user or not user.is_active:
186188
return None

src/app/core/auth/middleware.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,19 @@
66
"""
77

88
import uuid
9-
from collections.abc import Callable
10-
from typing import Any
9+
from typing import TYPE_CHECKING
1110

1211
import structlog
1312
from fastapi import Request, Response
14-
from starlette.middleware.base import BaseHTTPMiddleware
13+
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
1514

1615
from app.core.auth.backend import decode_token
1716

1817

18+
if TYPE_CHECKING:
19+
from starlette.types import ASGIApp
20+
21+
1922
logger = structlog.get_logger()
2023

2124

@@ -31,7 +34,7 @@ class TenantContextMiddleware(BaseHTTPMiddleware):
3134

3235
def __init__(
3336
self,
34-
app: Any,
37+
app: "ASGIApp",
3538
exclude_paths: list[str] | None = None,
3639
) -> None:
3740
super().__init__(app)
@@ -49,7 +52,7 @@ def __init__(
4952
async def dispatch(
5053
self,
5154
request: Request,
52-
call_next: Callable[[Request], Any],
55+
call_next: RequestResponseEndpoint,
5356
) -> Response:
5457
"""Process the request and inject tenant context.
5558
@@ -95,7 +98,7 @@ class RequestIdMiddleware(BaseHTTPMiddleware):
9598
async def dispatch(
9699
self,
97100
request: Request,
98-
call_next: Callable[[Request], Any],
101+
call_next: RequestResponseEndpoint,
99102
) -> Response:
100103
"""Process the request and add request ID.
101104

0 commit comments

Comments
 (0)