feat: support splitting paragraphs into multiple shards

Signed-off-by: Konstantin Fickel <mail@konstantinfickel.de>
This commit is contained in:
Konstantin Fickel 2025-06-20 15:34:02 +02:00
parent 2091e5c98d
commit 9b13370409
2 changed files with 86 additions and 19 deletions

View file

@ -5,7 +5,9 @@ from pydantic import BaseModel
from mistletoe import Document
from mistletoe.markdown_renderer import MarkdownRenderer, Fragment
from mistletoe.span_token import SpanToken, RawText
from mistletoe.block_token import Paragraph, BlockToken
from mistletoe.token import Token
from itertools import pairwise
import re
@ -26,7 +28,6 @@ class TagMarkdownRenderer(MarkdownRenderer):
class Shard(BaseModel):
markers: list[str]
tags: list[str]
content: str
start_line: int
end_line: int
children: list[Shard]
@ -41,13 +42,13 @@ T = TypeVar("T")
def extract_tags(tokens: list[Token]) -> list[str]:
return map(
return list(map(
lambda marker: marker.content,
filter(lambda token: isinstance(token, Tag), tokens),
)
))
def extract_markers_and_tags(header: Token) -> tuple[list[str], list[str]]:
def extract_markers_and_tags(header: Optional[Token]) -> tuple[list[str], list[str]]:
marker_boundary_check = lambda token: isinstance(token, Tag) or (
isinstance(token, RawText) and re.match(r"^[\s]*$", token.content)
)
@ -57,21 +58,63 @@ def extract_markers_and_tags(header: Token) -> tuple[list[str], list[str]]:
return extract_tags(marker_region), extract_tags(tag_region)
def has_markers(token: Token) -> bool:
markers, _ = extract_markers_and_tags(token)
return len(markers) > 0
def find_shard_positions(block_tokens: list[BlockToken]) -> list[int]:
return [
index for index, block_token in enumerate(block_tokens)
if isinstance(block_token, Paragraph) and has_markers(block_token)
]
T = TypeVar('T')
def split_at(list_to_be_split: list[T], positions: list[int]):
positions = sorted(set([0, *positions, len(list_to_be_split)]))
return [
list_to_be_split[left : right]
for left, right in pairwise(positions)
]
def to_shard(tokens: list[Token], start_line: int, end_line: int, children: list[Shard] = []) -> Shard:
markers, tags = extract_markers_and_tags(tokens[0]) if len(tokens) > 0 else ([], [])
# TODO: also find tags of children!
return Shard(
markers=markers,
tags=tags,
start_line=start_line,
end_line=end_line,
children=children,
)
def parse_markdown_file(file_name: str, file_content: str) -> StreamFile:
shard = None
with TagMarkdownRenderer() as renderer:
ast = Document(file_content)
line_count = len(file_content.splitlines())
if block_tokes := ast.children:
markers, tags = extract_markers_and_tags(block_tokes[0])
shard = Shard(
markers=markers,
tags=tags,
content=file_content,
start_line=1,
end_line=len(file_content.splitlines()),
children=[],
)
if block_tokens := ast.children:
shard_starts = find_shard_positions(block_tokens)
child_shards: list[Shard] = []
own_elements: list[BlockToken] = []
for i in range(len(block_tokens)):
token = block_tokens[i]
if i in shard_starts:
end_line = block_tokens[i + 1].line_number - 1 if i + 1 < len(block_tokens) else line_count
child_shards.append(to_shard([token], token.line_number, end_line))
else:
own_elements.append(token)
if len(child_shards) == 1 and len(own_elements) == 0:
shard = child_shards[0]
else:
shard = to_shard(own_elements, 1, line_count, children=child_shards)
return StreamFile(shard=shard, filename=file_name)