feat: support splitting paragraphs into multiple shards

Signed-off-by: Konstantin Fickel <mail@konstantinfickel.de>
This commit is contained in:
Konstantin Fickel 2025-06-20 15:34:02 +02:00
parent 2091e5c98d
commit 9b13370409
2 changed files with 86 additions and 19 deletions

View file

@ -5,7 +5,9 @@ from pydantic import BaseModel
from mistletoe import Document from mistletoe import Document
from mistletoe.markdown_renderer import MarkdownRenderer, Fragment from mistletoe.markdown_renderer import MarkdownRenderer, Fragment
from mistletoe.span_token import SpanToken, RawText from mistletoe.span_token import SpanToken, RawText
from mistletoe.block_token import Paragraph, BlockToken
from mistletoe.token import Token from mistletoe.token import Token
from itertools import pairwise
import re import re
@ -26,7 +28,6 @@ class TagMarkdownRenderer(MarkdownRenderer):
class Shard(BaseModel): class Shard(BaseModel):
markers: list[str] markers: list[str]
tags: list[str] tags: list[str]
content: str
start_line: int start_line: int
end_line: int end_line: int
children: list[Shard] children: list[Shard]
@ -41,13 +42,13 @@ T = TypeVar("T")
def extract_tags(tokens: list[Token]) -> list[str]: def extract_tags(tokens: list[Token]) -> list[str]:
return map( return list(map(
lambda marker: marker.content, lambda marker: marker.content,
filter(lambda token: isinstance(token, Tag), tokens), filter(lambda token: isinstance(token, Tag), tokens),
) ))
def extract_markers_and_tags(header: Token) -> tuple[list[str], list[str]]: def extract_markers_and_tags(header: Optional[Token]) -> tuple[list[str], list[str]]:
marker_boundary_check = lambda token: isinstance(token, Tag) or ( marker_boundary_check = lambda token: isinstance(token, Tag) or (
isinstance(token, RawText) and re.match(r"^[\s]*$", token.content) isinstance(token, RawText) and re.match(r"^[\s]*$", token.content)
) )
@ -57,21 +58,63 @@ def extract_markers_and_tags(header: Token) -> tuple[list[str], list[str]]:
return extract_tags(marker_region), extract_tags(tag_region) return extract_tags(marker_region), extract_tags(tag_region)
def has_markers(token: Token) -> bool:
    """Return True when *token* carries at least one marker tag.

    Delegates to ``extract_markers_and_tags`` and ignores the trailing
    (non-marker) tag region.
    """
    found_markers, _tags = extract_markers_and_tags(token)
    return bool(found_markers)
def find_shard_positions(block_tokens: list[BlockToken]) -> list[int]:
    """Indices of block tokens that open a new shard.

    A shard starts at every paragraph whose leading content contains markers.
    """
    positions: list[int] = []
    for position, candidate in enumerate(block_tokens):
        if isinstance(candidate, Paragraph) and has_markers(candidate):
            positions.append(position)
    return positions
# NOTE(review): ``T`` already appears earlier in this module (the hunk context
# above shows ``T = TypeVar("T")``) — this re-declaration is redundant and
# could be removed, keeping a single module-level TypeVar.
T = TypeVar('T')
def split_at(list_to_be_split: list[T], positions: list[int]):
positions = sorted(set([0, *positions, len(list_to_be_split)]))
return [
list_to_be_split[left : right]
for left, right in pairwise(positions)
]
def to_shard(
    tokens: list[Token],
    start_line: int,
    end_line: int,
    children: list[Shard] | None = None,
) -> Shard:
    """Build a ``Shard`` covering *tokens* between *start_line* and *end_line*.

    Markers and tags are read from the first token only; an empty token list
    yields a shard without markers or tags.

    :param tokens: block tokens belonging to this shard (may be empty)
    :param start_line: first source line of the shard (1-indexed, inclusive)
    :param end_line: last source line of the shard (inclusive)
    :param children: nested shards; ``None`` means no children.  The previous
        signature used a shared mutable ``[]`` default, which would leak
        appended children across calls — the ``None`` sentinel fixes that
        while remaining backward compatible for all callers.
    """
    # Only the leading token can carry the shard's own markers/tags.
    markers, tags = extract_markers_and_tags(tokens[0]) if tokens else ([], [])
    # TODO: also find tags of children!
    return Shard(
        markers=markers,
        tags=tags,
        start_line=start_line,
        end_line=end_line,
        children=[] if children is None else children,
    )
def parse_markdown_file(file_name: str, file_content: str) -> StreamFile: def parse_markdown_file(file_name: str, file_content: str) -> StreamFile:
shard = None shard = None
with TagMarkdownRenderer() as renderer: with TagMarkdownRenderer() as renderer:
ast = Document(file_content) ast = Document(file_content)
line_count = len(file_content.splitlines())
if block_tokes := ast.children: if block_tokens := ast.children:
markers, tags = extract_markers_and_tags(block_tokes[0]) shard_starts = find_shard_positions(block_tokens)
shard = Shard(
markers=markers, child_shards: list[Shard] = []
tags=tags, own_elements: list[BlockToken] = []
content=file_content,
start_line=1, for i in range(len(block_tokens)):
end_line=len(file_content.splitlines()), token = block_tokens[i]
children=[], if i in shard_starts:
) end_line = block_tokens[i + 1].line_number - 1 if i + 1 < len(block_tokens) else line_count
child_shards.append(to_shard([token], token.line_number, end_line))
else:
own_elements.append(token)
if len(child_shards) == 1 and len(own_elements) == 0:
shard = child_shards[0]
else:
shard = to_shard(own_elements, 1, line_count, children=child_shards)
return StreamFile(shard=shard, filename=file_name) return StreamFile(shard=shard, filename=file_name)

View file

@ -19,7 +19,6 @@ class TestParseProcess:
shard=Shard( shard=Shard(
markers=[], markers=[],
tags=[], tags=[],
content=test_file,
start_line=1, start_line=1,
end_line=1, end_line=1,
children=[], children=[],
@ -33,7 +32,6 @@ class TestParseProcess:
shard=Shard( shard=Shard(
markers=[], markers=[],
tags=[], tags=[],
content=test_file,
start_line=1, start_line=1,
end_line=2, end_line=2,
children=[], children=[],
@ -47,7 +45,6 @@ class TestParseProcess:
shard=Shard( shard=Shard(
markers=["Tag"], markers=["Tag"],
tags=[], tags=[],
content=test_file,
start_line=1, start_line=1,
end_line=1, end_line=1,
children=[], children=[],
@ -61,7 +58,6 @@ class TestParseProcess:
shard=Shard( shard=Shard(
markers=["Tag1", "Tag2"], markers=["Tag1", "Tag2"],
tags=[], tags=[],
content=test_file,
start_line=1, start_line=1,
end_line=1, end_line=1,
children=[], children=[],
@ -75,9 +71,37 @@ class TestParseProcess:
shard=Shard( shard=Shard(
markers=["Tag1", "Tag2"], markers=["Tag1", "Tag2"],
tags=["Tag3"], tags=["Tag3"],
content=test_file,
start_line=1, start_line=1,
end_line=1, end_line=1,
children=[], children=[],
), ),
) )
def test_parse_split_paragraphs_into_shards(self):
    """Each marker-bearing paragraph becomes its own child shard."""
    # Plain string literal: the former ``f"..."`` had no placeholders
    # (extraneous f-prefix, ruff F541); the runtime value is unchanged.
    file_text = "Hello World!\n\n@Tag1 Block 1\n\n@Tag2 Block 2"
    assert parse_markdown_file(self.file_name, file_text) == StreamFile(
        filename=self.file_name,
        shard=Shard(
            markers=[],
            tags=[],
            start_line=1,
            end_line=5,
            children=[
                Shard(
                    markers=["Tag1"],
                    tags=[],
                    start_line=3,
                    end_line=3,
                    children=[],
                ),
                Shard(
                    markers=["Tag2"],
                    tags=[],
                    start_line=5,
                    end_line=5,
                    children=[],
                ),
            ],
        ),
    )