Skip to content
Open
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ ADMYRAL_ENV=dev
ADMYRAL_LOGGING_LEVEL=DEBUG

ADMYRAL_DATABASE_URL="postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/admyral"
ADMYRAL_TEST_DATABASE_URL="postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5433/admyral"

# Optional but required if ai_action action is used
OPENAI_API_KEY=""
Expand Down
21 changes: 10 additions & 11 deletions admyral/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,21 +115,20 @@ class SecretsManagerType(str, Enum):


ENV_ADMYRAL_DATABASE_URL = "ADMYRAL_DATABASE_URL"
ENV_ADMYRAL_TEST_DATABASE_URL = "ADMYRAL_TEST_DATABASE_URL"
ENV_ADMYRAL_SECRETS_MANAGER_TYPE = "ADMYRAL_SECRETS_MANAGER"


ADMYRAL_DATABASE_URL = os.getenv(
ENV_ADMYRAL_DATABASE_URL,
"postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5432/admyral",
)
if ADMYRAL_DATABASE_URL.startswith("postgresql"):
if ADMYRAL_DATABASE_URL.startswith("postgresql://"):
ADMYRAL_DATABASE_URL = ADMYRAL_DATABASE_URL.replace(
"postgresql://", "postgresql+asyncpg://"
)
ADMYRAL_DATABASE_TYPE = DatabaseType.POSTGRES
else:
raise NotImplementedError(f"Unsupported database type: {ADMYRAL_DATABASE_URL}")
"postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5433/admyral",
).replace("postgresql://", "postgresql+asyncpg://")

ADMYRAL_TEST_DATABASE_URL = os.getenv(
ENV_ADMYRAL_TEST_DATABASE_URL,
"postgresql://postgres:your-super-secret-and-long-postgres-password@localhost:5433/admyral",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default should point to the test postgres and not the actual postgres database. Can you also adapt .env.example and set the an env var for the test database, as well, so that we're consistent with how we set the normal database url.

).replace("postgresql://", "postgresql+asyncpg://")


ADMYRAL_SECRETS_MANAGER_TYPE = SecretsManagerType(
os.getenv(ENV_ADMYRAL_SECRETS_MANAGER_TYPE, SecretsManagerType.SQL)
Expand All @@ -151,8 +150,8 @@ class GlobalConfig(BaseModel):
default_user_email: str = "default.user@admyral.ai"
telemetry_disabled: bool = ADMYRAL_DISABLE_TELEMETRY
storage_directory: str = get_local_storage_path()
database_type: DatabaseType = ADMYRAL_DATABASE_TYPE
database_url: str = ADMYRAL_DATABASE_URL
test_database_url: str = ADMYRAL_TEST_DATABASE_URL
temporal_host: str = ADMYRAL_TEMPORAL_HOST
secrets_manager_type: SecretsManagerType = ADMYRAL_SECRETS_MANAGER_TYPE
posthog_api_key: str = ADMYRAL_POSTHOG_API_KEY
Expand Down
33 changes: 26 additions & 7 deletions admyral/db/admyral_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@
UserSchema,
ApiKeySchema,
WorkflowControlResultsSchema,
ControlSchema,
)
from admyral.db.alembic.database_manager import DatabaseManager
from admyral.config.config import GlobalConfig, CONFIG
from admyral.config.config import CONFIG
from admyral.logger import get_logger
from admyral.utils.time import utc_now
from admyral.utils.crypto import generate_hs256
Expand Down Expand Up @@ -74,11 +75,11 @@ async def commit(self) -> None:


class AdmyralStore(StoreInterface):
def __init__(self, config: GlobalConfig) -> None:
self.config = config
def __init__(self, database_url: str) -> None:
self.database_url = database_url

self.engine = create_async_engine(
self.config.database_url, echo=True, future=True, pool_pre_ping=True
database_url, echo=True, future=True, pool_pre_ping=True
)
self.async_session_maker = sessionmaker(
self.engine, class_=AsyncSession, expire_on_commit=False
Expand All @@ -90,8 +91,10 @@ def __init__(self, config: GlobalConfig) -> None:

# TODO: pass down config
@classmethod
async def create_store(cls, skip_setup: bool = False) -> "AdmyralStore":
store = cls(CONFIG)
async def create_store(
cls, skip_setup: bool = False, database_url: str | None = None
) -> "AdmyralStore":
store = cls(database_url or CONFIG.database_url)
if not skip_setup:
await store.setup()
store.performed_setup = True
Expand All @@ -105,7 +108,7 @@ async def create_store(cls, skip_setup: bool = False) -> "AdmyralStore":
async def setup(self):
logger.info("Setting up Admyral store.")

database_manager = DatabaseManager(self.engine, self.config)
database_manager = DatabaseManager(self.engine, self.database_url)

does_db_exist = await database_manager.database_exists()
if not does_db_exist:
Expand Down Expand Up @@ -1074,3 +1077,19 @@ async def store_workflow_control_result(
)
)
await db.commit()

async def clean_up_controls_data(self, user_id: str) -> None:
"""Delete all data from controls-related tables. Should only be used in testing environments."""
async with self._get_async_session() as db:
# Delete workflow control results where the workflow belongs to the user
await db.exec(
delete(WorkflowControlResultsSchema).where(
WorkflowControlResultsSchema.workflow_id.in_(
select(WorkflowSchema.workflow_id).where(
WorkflowSchema.user_id == user_id
)
)
)
)
await db.exec(delete(ControlSchema).where(ControlSchema.user_id == user_id))
await db.commit()
49 changes: 18 additions & 31 deletions admyral/db/alembic/database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from sqlalchemy.engine import Connection
from functools import partial

from admyral.config.config import GlobalConfig, DatabaseType


# TODO: why are we filtering out the alembic_version table?
def include_object(object, name, type_, reflected, compare_to):
Expand All @@ -25,9 +23,9 @@ def get_admyral_dir() -> str:


class DatabaseManager:
def __init__(self, engine: AsyncEngine, config: GlobalConfig) -> None:
def __init__(self, engine: AsyncEngine, database_url: str) -> None:
self.engine = engine
self.config = config
self.database_url = database_url

self.target_metadata = SQLModel.metadata

Expand All @@ -42,39 +40,28 @@ def __init__(self, engine: AsyncEngine, config: GlobalConfig) -> None:

def _get_postgres_setup_engine(self) -> str:
# https://stackoverflow.com/questions/6506578/how-to-create-a-new-database-using-sqlalchemy/8977109#8977109
db_name = self.config.database_url.split("/")[-1]
db_url = self.config.database_url[: -len(db_name)] + "postgres"
db_name = self.database_url.split("/")[-1]
db_url = self.database_url[: -len(db_name)] + "postgres"
return create_async_engine(db_url, echo=True, future=True, pool_pre_ping=True)

async def database_exists(self) -> bool:
if self.config.database_type == DatabaseType.POSTGRES:
engine = self._get_postgres_setup_engine()
try:
async with engine.connect() as conn:
result = await conn.execute(
text(
"select exists (select 1 from pg_database where datname = 'admyral')"
)
engine = self._get_postgres_setup_engine()
try:
async with engine.connect() as conn:
result = await conn.execute(
text(
"select exists (select 1 from pg_database where datname = 'admyral')"
)
return result.scalar()
except Exception:
return False

raise NotImplementedError(
f"Unimplemented database type in database_exists: {self.database_type}"
)
)
return result.scalar()
except Exception:
return False

async def create_database(self) -> None:
if self.config.database_type == DatabaseType.POSTGRES:
engine = self._get_postgres_setup_engine()
async with engine.connect() as conn:
await conn.execute(text("commit"))
await conn.execute(text("create database admyral"))
return

raise NotImplementedError(
f"Unimplemented database type in create_database: {self.database_type}"
)
engine = self._get_postgres_setup_engine()
async with engine.connect() as conn:
await conn.execute(text("commit"))
await conn.execute(text("create database admyral"))

async def drop_database(self) -> None:
# TODO:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""add controls and controls_workflows tables
Revision ID: 3967e08c0051
Revises: 7985f1c159a3
Create Date: 2024-11-27 17:36:23.138028
"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import sqlmodel # noqa F401


# revision identifiers, used by Alembic.
revision: str = "3967e08c0051"
down_revision: Union[str, None] = "7985f1c159a3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"controls",
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("control_id", sa.TEXT(), nullable=False),
sa.Column("user_id", sa.TEXT(), nullable=False),
sa.Column("control_name", sa.TEXT(), nullable=False),
sa.Column("control_description", sa.TEXT(), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["User.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("control_id", "user_id"),
sa.UniqueConstraint("user_id", "control_id", name="unique_control_id"),
)
op.create_index(
op.f("ix_controls_control_name"), "controls", ["control_name"], unique=False
)
op.create_table(
"controls_workflows",
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("control_id", sa.TEXT(), nullable=False),
sa.Column("user_id", sa.TEXT(), nullable=False),
sa.Column("workflow_id", sa.TEXT(), nullable=False),
sa.ForeignKeyConstraint(
["control_id", "user_id"],
["controls.control_id", "controls.user_id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["workflow_id"], ["workflows.workflow_id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("control_id", "user_id", "workflow_id"),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("controls_workflows")
op.drop_index(op.f("ix_controls_control_name"), table_name="controls")
op.drop_table("controls")
# ### end Alembic commands ###
6 changes: 6 additions & 0 deletions admyral/db/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
from admyral.db.schemas.workflow_control_results_schemas import (
WorkflowControlResultsSchema,
)
from admyral.db.schemas.control_schemas import (
ControlSchema,
ControlsWorkflowsMappingSchema,
)

__all__ = [
"PipLockfileCacheSchema",
Expand All @@ -36,4 +40,6 @@
"AuthenticatorSchema",
"ApiKeySchema",
"WorkflowControlResultsSchema",
"ControlSchema",
"ControlsWorkflowsMappingSchema",
]
4 changes: 4 additions & 0 deletions admyral/db/schemas/auth_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from admyral.db.schemas.workflow_schemas import WorkflowSchema
from admyral.db.schemas.secrets_schemas import SecretsSchema
from admyral.db.schemas.base_schemas import BaseSchema
from admyral.db.schemas.control_schemas import ControlSchema


class UserSchema(BaseSchema, table=True):
Expand Down Expand Up @@ -53,6 +54,9 @@ class UserSchema(BaseSchema, table=True):
secrets: list[SecretsSchema] = Relationship(
back_populates="user", sa_relationship_kwargs=dict(cascade="all, delete")
)
controls: list[ControlSchema] = Relationship(
back_populates="user", sa_relationship_kwargs=dict(cascade="all, delete")
)

def to_model(self) -> User:
return User.model_validate(
Expand Down
Loading