Skip to content

Commit 4f75c5a

Browse files
authored
[feat] Add shape operation tools (Layout-Parser#72)
* Add shape operation tools * Add test for tools * Add scipy as the dep
1 parent 341d3fc commit 4f75c5a

File tree

5 files changed

+246
-0
lines changed

5 files changed

+246
-0
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
install_requires=[
3838
"numpy",
3939
"opencv-python",
40+
"scipy",
4041
"pandas",
4142
"pillow",
4243
"pyyaml>=5.1",

src/layoutparser/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@
5252
"is_pytesseract_available",
5353
"is_gcv_available",
5454
"requires_backends"
55+
],
56+
"tools": [
57+
"generalized_connected_component_analysis_1d",
58+
"simple_line_detection",
59+
"group_textblocks_based_on_category"
5560
]
5661
}
5762

src/layoutparser/tools/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .shape_operations import (
2+
generalized_connected_component_analysis_1d,
3+
simple_line_detection,
4+
group_textblocks_based_on_category,
5+
)
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# Copyright 2021 The Layout Parser team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Union, Any, Callable, Iterable
16+
from functools import partial, reduce
17+
18+
import numpy as np
19+
from scipy.sparse import csr_matrix
20+
from scipy.sparse.csgraph import connected_components
21+
22+
from ..elements import BaseLayoutElement, TextBlock
23+
24+
25+
def generalized_connected_component_analysis_1d(
    sequence: List[Any],
    scoring_func: Callable[[Any, Any], int],
    aggregation_func: Callable[[List[Any]], Any] = None,
    default_score_value: int = 0,
) -> List[Any]:
    """Perform connected component analysis for any 1D sequence based on
    the scoring function and the aggregation function.

    It builds an adjacency matrix for the 1D sequence using the provided
    `scoring_func` and finds the connected components of the resulting
    undirected graph. The `aggregation_func` is then used to aggregate all
    elements within each identified component (when not set, it is the
    identity function).

    Args:
        sequence (List[Any]):
            The provided 1D sequence of objects. It must support `len()`
            and integer indexing.
        scoring_func (Callable[[Any, Any], int]):
            The scoring function used to construct the adjacency matrix.
            It takes two objects from the sequence and produces a score;
            it is called once for every pair `(sequence[i], sequence[j])`
            with `i < j`.
        aggregation_func (Callable[[List[Any]], Any], optional):
            The function used to aggregate the elements within an
            identified component.
            Defaults to the identity function: `lambda x: x`.
        default_score_value (int, optional):
            The background score. A pair whose score equals this value is
            treated as not connected.
            Defaults to 0.

    Returns:
        List[Any]: A list of length n - the number of detected components.
        Each entry is the result of `aggregation_func` applied to the list
        of elements of one component (in their original sequence order).
        An empty sequence yields an empty list.
    """

    if aggregation_func is None:
        aggregation_func = lambda x: x  # Identity function

    seq_len = len(sequence)
    if seq_len == 0:
        # Nothing to group; also avoids handing an empty graph to scipy.
        return []

    # The sparse graph treats every non-zero entry as an edge, so the matrix
    # must hold 0 for "not connected" regardless of `default_score_value`.
    # Pre-filling it with a non-zero default would connect every node via the
    # diagonal and the untouched lower triangle.
    adjacency_matrix = np.zeros((seq_len, seq_len))
    for i in range(seq_len):
        for j in range(i + 1, seq_len):
            if scoring_func(sequence[i], sequence[j]) != default_score_value:
                adjacency_matrix[i, j] = 1

    graph = csr_matrix(adjacency_matrix)
    n_components, labels = connected_components(
        csgraph=graph, directed=False, return_labels=True
    )

    # Bucket element indices by component label in a single pass.
    grouped_indices = [[] for _ in range(n_components)]
    for idx, label in enumerate(labels):
        grouped_indices[label].append(idx)

    return [
        aggregation_func([sequence[idx] for idx in indices])
        for indices in grouped_indices
    ]
78+
79+
80+
def simple_line_detection(
    layout: Iterable[BaseLayoutElement], x_tolerance: int = 10, y_tolerance: int = 10
) -> List[BaseLayoutElement]:
    """Perform line detection based on connected component analysis.

    The is_line_wise_close is the scoring function, which returns True
    if the y-difference is smaller than the y_tolerance AND the
    x-difference (the horizontal gap between two boxes) is also smaller
    than the x_tolerance, and False otherwise.

    All the detected components will then be passed into aggregation_func,
    which returns the overall union box of all the elements, or the line
    box.

    Args:
        layout (Iterable):
            A list (or Layout) of BaseLayoutElement. It is indexed
            (`layout[0]`) to find the element class whose `union` is used,
            and each element must expose a `block` attribute providing
            `center` and `coordinates`.
        x_tolerance (int, optional):
            The value used for specifying the maximum allowed horizontal gap
            when considering whether two tokens are from the same line.
            Defaults to 10.
        y_tolerance (int, optional):
            The value used for specifying the maximum allowed y-difference
            (between the block centers) when considering whether two tokens
            are from the same line.
            Defaults to 10.

    Returns:
        List[BaseLayoutElement]: A list of BaseLayoutElement, denoting the line boxes.
    """

    def is_line_wise_close(token_a, token_b, x_tolerance, y_tolerance):
        y_a = token_a.block.center[1]
        y_b = token_b.block.center[1]

        # Entries 0 and 2 of `coordinates` are used as the left/right edges.
        a_left, a_right = token_a.block.coordinates[0::2]
        b_left, b_right = token_b.block.coordinates[0::2]

        # If the y-difference is smaller than the y_tolerance AND
        # the x-difference (the horizontal gap between two boxes)
        # is also smaller than the x_tolerance threshold, then
        # these two tokens are considered as line-wise close.
        return (
            abs(y_a - y_b) <= y_tolerance
            and min(abs(a_left - b_right), abs(a_right - b_left)) <= x_tolerance
        )

    detected_lines = generalized_connected_component_analysis_1d(
        layout,
        # Each tolerance is bound to its own axis (they were previously
        # swapped, which only went unnoticed because both default to 10).
        scoring_func=partial(
            is_line_wise_close, x_tolerance=x_tolerance, y_tolerance=y_tolerance
        ),
        # `layout[0]` is only evaluated when a component is aggregated, so an
        # empty layout never reaches it.
        aggregation_func=lambda seq: reduce(layout[0].__class__.union, seq),
    )

    return detected_lines
135+
136+
137+
def group_textblocks_based_on_category(
    layout: Iterable[TextBlock], union_group: bool = True
) -> Union[List[TextBlock], List[List[TextBlock]]]:
    """Group textblocks that share the same category (block.type).

    Two blocks are linked whenever their `type` attributes compare equal;
    the groups are the connected components of those links.

    Args:
        layout (Iterable):
            A list (or Layout) of TextBlock. It is indexed (`layout[0]`)
            to find the element class whose `union` merges a group.
        union_group (bool):
            Whether to union the boxes within each group.
            Defaults to True.

    Returns:
        List[TextBlock]: When `union_group=True`, one TextBlock per group,
            denoting the boundary of that textblock group.
        List[List[TextBlock]]: When `union_group=False`, the members of
            each group are kept as a list for further processing.
    """

    def same_category(block_a, block_b):
        return block_a.type == block_b.type

    def merge_group(members):
        # The class is looked up lazily, only when a group is actually merged.
        return reduce(layout[0].__class__.union, members)

    return generalized_connected_component_analysis_1d(
        layout,
        scoring_func=same_category,
        # None lets the analysis fall back to its identity aggregation.
        aggregation_func=merge_group if union_group else None,
    )

tests/test_tools.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright 2021 The Layout Parser team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from layoutparser import load_pdf
16+
from layoutparser.tools import (
17+
generalized_connected_component_analysis_1d,
18+
simple_line_detection,
19+
group_textblocks_based_on_category,
20+
)
21+
22+
def test_generalized_connected_component_analysis_1d():
    """Check component counts and aggregation on small integer sequences."""

    def within(max_gap):
        # Scoring function: two numbers are linked when at most `max_gap` apart.
        return lambda x, y: abs(x - y) <= max_gap

    single_run = [1, 2, 3]
    two_runs = [1, 2, 3, 5, 6, 7]

    # One contiguous run collapses into a single component.
    components = generalized_connected_component_analysis_1d(
        single_run, scoring_func=within(1)
    )
    assert len(components) == 1

    # The missing 4 splits the sequence into two components.
    components = generalized_connected_component_analysis_1d(
        two_runs, scoring_func=within(1)
    )
    assert len(components) == 2

    # A wider tolerance bridges the gap between 3 and 5.
    components = generalized_connected_component_analysis_1d(
        two_runs, scoring_func=within(2)
    )
    assert len(components) == 1

    # The aggregation function is applied to each component separately.
    components = generalized_connected_component_analysis_1d(
        two_runs, scoring_func=within(1), aggregation_func=max
    )
    assert components == [3, 7]
53+
54+
def test_simple_line_detection():
    """Default-tolerance line detection on the first page of the example PDF."""

    first_page = load_pdf("tests/fixtures/io/example.pdf")[0]
    detected = simple_line_detection(first_page)

    assert len(detected) == 15
61+
62+
def test_group_textblocks_based_on_category():
    """Category grouping (with union) on the first page of the example PDF."""

    first_page = load_pdf("tests/fixtures/io/example.pdf")[0]
    grouped = group_textblocks_based_on_category(first_page)

    assert len(grouped) == 3

0 commit comments

Comments
 (0)