Skip to content

Commit fae9869

Browse files
committed
test(graph): add unit tests for relational deletion helpers
1 parent a0f25f4 commit fae9869

File tree

2 files changed

+330
-0
lines changed

2 files changed

+330
-0
lines changed
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import uuid
2+
import pytest
3+
from sqlalchemy import select, func
4+
5+
from cognee.infrastructure.databases.relational import (
6+
create_db_and_tables,
7+
get_async_session,
8+
)
9+
from cognee.modules.graph.models import Node, Edge
10+
from cognee.modules.graph.methods import (
11+
delete_data_related_nodes,
12+
delete_data_related_edges,
13+
)
14+
15+
16+
@pytest.mark.asyncio
async def test_delete_data_related_nodes_removes_only_target_data():
    """Check that deleting nodes by data id spares nodes of other data items.

    Two nodes are stored under ``target_data`` and one under ``other_data``
    (all in the same dataset). After ``delete_data_related_nodes(target_data)``
    the count for ``target_data`` must be 0 and the count for ``other_data``
    must still be 1.
    """
    await create_db_and_tables()

    dataset_id = uuid.uuid4()
    user_id = uuid.uuid4()
    target_data = uuid.uuid4()
    other_data = uuid.uuid4()

    # (data_id, label, type, attributes) for every node seeded by this test.
    node_specs = [
        (target_data, "A1", "TypeA", {"k": "v"}),
        (target_data, "A2", "TypeA", {"k2": "v2"}),
        (other_data, "B1", "TypeB", {"k3": "v3"}),
    ]
    seeded_nodes = [
        Node(
            id=uuid.uuid4(),
            slug=uuid.uuid4(),
            user_id=user_id,
            data_id=data_id,
            dataset_id=dataset_id,
            label=label,
            type=node_type,
            indexed_fields=["text"],
            attributes=attributes,
        )
        for data_id, label, node_type, attributes in node_specs
    ]

    async with get_async_session(auto_commit=True) as session:
        session.add_all(seeded_nodes)

    await delete_data_related_nodes(target_data)

    # Build the COUNT(*) statements up front; they only run inside the session.
    target_count_query = select(func.count()).select_from(Node).where(Node.data_id == target_data)
    other_count_query = select(func.count()).select_from(Node).where(Node.data_id == other_data)

    async with get_async_session() as session:
        remaining_target = await session.scalar(target_count_query)
        remaining_other = await session.scalar(other_count_query)

    assert remaining_target == 0
    assert remaining_other == 1
74+
75+
76+
@pytest.mark.asyncio
async def test_delete_data_related_edges_removes_only_target_data():
    """Check that deleting edges by data id spares edges of other data items.

    One node and one self-referencing edge are stored for each of
    ``target_data`` and ``other_data``. After
    ``delete_data_related_edges(target_data)`` the edge count for
    ``target_data`` must be 0 and the edge count for ``other_data`` must be 1.
    """
    await create_db_and_tables()

    dataset_id = uuid.uuid4()
    user_id = uuid.uuid4()
    target_data = uuid.uuid4()
    other_data = uuid.uuid4()

    # (data_id, node label, node type, relationship name, edge label)
    fixture_specs = [
        (target_data, "N1", "TypeA", "REL_A", "LA"),
        (other_data, "N2", "TypeB", "REL_B", "LB"),
    ]

    nodes = []
    edges = []
    for data_id, node_label, node_type, relationship, edge_label in fixture_specs:
        # Node that the edge below references on both ends.
        node = Node(
            id=uuid.uuid4(),
            slug=uuid.uuid4(),
            user_id=user_id,
            data_id=data_id,
            dataset_id=dataset_id,
            label=node_label,
            type=node_type,
            indexed_fields=["text"],
            attributes={},
        )
        nodes.append(node)
        edges.append(
            Edge(
                id=uuid.uuid4(),
                slug=uuid.uuid4(),
                user_id=user_id,
                data_id=data_id,
                dataset_id=dataset_id,
                source_node_id=node.id,
                destination_node_id=node.id,
                relationship_name=relationship,
                label=edge_label,
                attributes={},
            )
        )

    async with get_async_session(auto_commit=True) as session:
        # Nodes are listed before the edges that point at them.
        session.add_all([*nodes, *edges])

    await delete_data_related_edges(target_data)

    target_count_query = select(func.count()).select_from(Edge).where(Edge.data_id == target_data)
    other_count_query = select(func.count()).select_from(Edge).where(Edge.data_id == other_data)

    async with get_async_session() as session:
        remaining_target = await session.scalar(target_count_query)
        remaining_other = await session.scalar(other_count_query)

    assert remaining_target == 0
    assert remaining_other == 1
149+
150+
151+
@pytest.mark.asyncio
async def test_delete_data_related_nodes_edges_noop_on_empty_tables():
    """Call both data-scoped delete helpers with an id this test never stored.

    There are no assertions: the test passes as long as neither helper raises.
    """
    await create_db_and_tables()

    unknown_data_id = uuid.uuid4()

    # Should not raise; nodes helper first, then edges helper.
    for delete_helper in (delete_data_related_nodes, delete_data_related_edges):
        await delete_helper(unknown_data_id)
160+
161+
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import uuid
2+
import pytest
3+
from sqlalchemy import select, func
4+
5+
from cognee.infrastructure.databases.relational import (
6+
create_db_and_tables,
7+
get_async_session,
8+
)
9+
from cognee.modules.graph.models import Node, Edge
10+
from cognee.modules.graph.methods import (
11+
delete_dataset_related_nodes,
12+
delete_dataset_related_edges,
13+
)
14+
15+
16+
@pytest.mark.asyncio
async def test_delete_dataset_related_nodes_removes_only_target_dataset():
    """Check that deleting nodes by dataset id spares nodes of other datasets.

    Two nodes are stored under ``target_dataset`` and one under
    ``other_dataset`` (all sharing one data id). After
    ``delete_dataset_related_nodes(target_dataset)`` the count for
    ``target_dataset`` must be 0 and the count for ``other_dataset`` must be 1.
    """
    await create_db_and_tables()

    target_dataset = uuid.uuid4()
    other_dataset = uuid.uuid4()
    user_id = uuid.uuid4()
    data_id = uuid.uuid4()

    # (dataset_id, label, type, attributes) for every node seeded by this test.
    node_specs = [
        (target_dataset, "A1", "TypeA", {"k": "v"}),
        (target_dataset, "A2", "TypeA", {"k2": "v2"}),
        (other_dataset, "B1", "TypeB", {"k3": "v3"}),
    ]
    seeded_nodes = [
        Node(
            id=uuid.uuid4(),
            slug=uuid.uuid4(),
            user_id=user_id,
            data_id=data_id,
            dataset_id=dataset_id,
            label=label,
            type=node_type,
            indexed_fields=["text"],
            attributes=attributes,
        )
        for dataset_id, label, node_type, attributes in node_specs
    ]

    async with get_async_session(auto_commit=True) as session:
        session.add_all(seeded_nodes)

    await delete_dataset_related_nodes(target_dataset)

    # Build the COUNT(*) statements up front; they only run inside the session.
    target_count_query = (
        select(func.count()).select_from(Node).where(Node.dataset_id == target_dataset)
    )
    other_count_query = (
        select(func.count()).select_from(Node).where(Node.dataset_id == other_dataset)
    )

    async with get_async_session() as session:
        remaining_target = await session.scalar(target_count_query)
        remaining_other = await session.scalar(other_count_query)

    assert remaining_target == 0
    assert remaining_other == 1
78+
79+
80+
@pytest.mark.asyncio
async def test_delete_dataset_related_edges_removes_only_target_dataset():
    """Check that deleting edges by dataset id spares edges of other datasets.

    One node and one self-referencing edge are stored for each of
    ``target_dataset`` and ``other_dataset``. After
    ``delete_dataset_related_edges(target_dataset)`` the edge count for
    ``target_dataset`` must be 0 and the edge count for ``other_dataset``
    must be 1.
    """
    await create_db_and_tables()

    target_dataset = uuid.uuid4()
    other_dataset = uuid.uuid4()
    user_id = uuid.uuid4()
    data_id = uuid.uuid4()

    # (dataset_id, node label, node type, relationship name, edge label)
    fixture_specs = [
        (target_dataset, "N1", "TypeA", "REL_A", "LA"),
        (other_dataset, "N2", "TypeB", "REL_B", "LB"),
    ]

    nodes = []
    edges = []
    for dataset_id, node_label, node_type, relationship, edge_label in fixture_specs:
        # Node that the edge below references on both ends.
        node = Node(
            id=uuid.uuid4(),
            slug=uuid.uuid4(),
            user_id=user_id,
            data_id=data_id,
            dataset_id=dataset_id,
            label=node_label,
            type=node_type,
            indexed_fields=["text"],
            attributes={},
        )
        nodes.append(node)
        edges.append(
            Edge(
                id=uuid.uuid4(),
                slug=uuid.uuid4(),
                user_id=user_id,
                data_id=data_id,
                dataset_id=dataset_id,
                source_node_id=node.id,
                destination_node_id=node.id,
                relationship_name=relationship,
                label=edge_label,
                attributes={},
            )
        )

    async with get_async_session(auto_commit=True) as session:
        # Nodes are listed before the edges that point at them.
        session.add_all([*nodes, *edges])

    await delete_dataset_related_edges(target_dataset)

    target_count_query = (
        select(func.count()).select_from(Edge).where(Edge.dataset_id == target_dataset)
    )
    other_count_query = (
        select(func.count()).select_from(Edge).where(Edge.dataset_id == other_dataset)
    )

    async with get_async_session() as session:
        remaining_target = await session.scalar(target_count_query)
        remaining_other = await session.scalar(other_count_query)

    assert remaining_target == 0
    assert remaining_other == 1
157+
158+
159+
@pytest.mark.asyncio
async def test_delete_dataset_related_nodes_edges_noop_on_empty_tables():
    """Call both dataset-scoped delete helpers with an id this test never stored.

    There are no assertions: the test passes as long as neither helper raises.
    """
    await create_db_and_tables()

    unknown_dataset_id = uuid.uuid4()

    # Should not raise; nodes helper first, then edges helper.
    for delete_helper in (delete_dataset_related_nodes, delete_dataset_related_edges):
        await delete_helper(unknown_dataset_id)
168+
169+

0 commit comments

Comments
 (0)