refactor: use StrEnum for model capabilities instead of plain strings

This commit is contained in:
Konstantin Fickel 2026-02-15 08:04:28 +01:00
parent 6a80cfb78e
commit 8e3ed7010f
Signed by: kfickel
GPG key ID: A793722F9933C1A5

View file

@ -3,9 +3,25 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from enum import StrEnum
from typing import Literal from typing import Literal
class Capability(StrEnum):
"""Capabilities a model may support."""
TEXT_GENERATION = "text generation"
VISION = "vision"
TEXT_TO_IMAGE = "text-to-image"
HIGH_RESOLUTION = "high resolution"
REFERENCE_IMAGES = "reference images"
CONTROL_IMAGES = "control images"
EDGE_DETECTION = "edge detection"
DEPTH_MAP = "depth map"
INPAINTING = "inpainting"
OUTPAINTING = "outpainting"
@dataclass(frozen=True) @dataclass(frozen=True)
class ModelInfo: class ModelInfo:
"""Describes a supported model and its capabilities.""" """Describes a supported model and its capabilities."""
@ -13,7 +29,7 @@ class ModelInfo:
name: str name: str
provider: str provider: str
type: Literal["text", "image"] type: Literal["text", "image"]
capabilities: list[str] capabilities: list[Capability]
TEXT_MODELS: list[ModelInfo] = [ TEXT_MODELS: list[ModelInfo] = [
@ -21,25 +37,25 @@ TEXT_MODELS: list[ModelInfo] = [
name="mistral-large-latest", name="mistral-large-latest",
provider="Mistral", provider="Mistral",
type="text", type="text",
capabilities=["text generation"], capabilities=[Capability.TEXT_GENERATION],
), ),
ModelInfo( ModelInfo(
name="mistral-small-latest", name="mistral-small-latest",
provider="Mistral", provider="Mistral",
type="text", type="text",
capabilities=["text generation"], capabilities=[Capability.TEXT_GENERATION],
), ),
ModelInfo( ModelInfo(
name="pixtral-large-latest", name="pixtral-large-latest",
provider="Mistral", provider="Mistral",
type="text", type="text",
capabilities=["text generation", "vision"], capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
), ),
ModelInfo( ModelInfo(
name="pixtral-12b-latest", name="pixtral-12b-latest",
provider="Mistral", provider="Mistral",
type="text", type="text",
capabilities=["text generation", "vision"], capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
), ),
] ]
@ -48,61 +64,69 @@ IMAGE_MODELS: list[ModelInfo] = [
name="flux-dev", name="flux-dev",
provider="BlackForestLabs", provider="BlackForestLabs",
type="image", type="image",
capabilities=["text-to-image"], capabilities=[Capability.TEXT_TO_IMAGE],
), ),
ModelInfo( ModelInfo(
name="flux-pro", name="flux-pro",
provider="BlackForestLabs", provider="BlackForestLabs",
type="image", type="image",
capabilities=["text-to-image"], capabilities=[Capability.TEXT_TO_IMAGE],
), ),
ModelInfo( ModelInfo(
name="flux-pro-1.1", name="flux-pro-1.1",
provider="BlackForestLabs", provider="BlackForestLabs",
type="image", type="image",
capabilities=["text-to-image"], capabilities=[Capability.TEXT_TO_IMAGE],
), ),
ModelInfo( ModelInfo(
name="flux-pro-1.1-ultra", name="flux-pro-1.1-ultra",
provider="BlackForestLabs", provider="BlackForestLabs",
type="image", type="image",
capabilities=["text-to-image", "high resolution"], capabilities=[Capability.TEXT_TO_IMAGE, Capability.HIGH_RESOLUTION],
), ),
ModelInfo( ModelInfo(
name="flux-2-pro", name="flux-2-pro",
provider="BlackForestLabs", provider="BlackForestLabs",
type="image", type="image",
capabilities=["text-to-image", "reference images"], capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
), ),
ModelInfo( ModelInfo(
name="flux-kontext-pro", name="flux-kontext-pro",
provider="BlackForestLabs", provider="BlackForestLabs",
type="image", type="image",
capabilities=["text-to-image", "reference images"], capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
), ),
ModelInfo( ModelInfo(
name="flux-pro-1.0-canny", name="flux-pro-1.0-canny",
provider="BlackForestLabs", provider="BlackForestLabs",
type="image", type="image",
capabilities=["text-to-image", "control images", "edge detection"], capabilities=[
Capability.TEXT_TO_IMAGE,
Capability.CONTROL_IMAGES,
Capability.EDGE_DETECTION,
],
), ),
ModelInfo( ModelInfo(
name="flux-pro-1.0-depth", name="flux-pro-1.0-depth",
provider="BlackForestLabs", provider="BlackForestLabs",
type="image", type="image",
capabilities=["text-to-image", "control images", "depth map"], capabilities=[
Capability.TEXT_TO_IMAGE,
Capability.CONTROL_IMAGES,
Capability.DEPTH_MAP,
],
), ),
ModelInfo( ModelInfo(
name="flux-pro-1.0-fill", name="flux-pro-1.0-fill",
provider="BlackForestLabs", provider="BlackForestLabs",
type="image", type="image",
capabilities=["text-to-image", "inpainting"], capabilities=[Capability.TEXT_TO_IMAGE, Capability.INPAINTING],
), ),
ModelInfo( ModelInfo(
name="flux-pro-1.0-expand", name="flux-pro-1.0-expand",
provider="BlackForestLabs", provider="BlackForestLabs",
type="image", type="image",
capabilities=["text-to-image", "outpainting"], capabilities=[Capability.TEXT_TO_IMAGE, Capability.OUTPAINTING],
), ),
] ]