feat: add download target type for fetching files from URLs
This commit is contained in:
parent
a4600df4d5
commit
c1ad6e6e3c
14 changed files with 296 additions and 74 deletions
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
|
||||
from hokusai.config import (
|
||||
Defaults,
|
||||
TargetConfig,
|
||||
GenerateTargetConfig,
|
||||
)
|
||||
from hokusai.providers.models import Capability
|
||||
from hokusai.resolve import infer_required_capabilities, resolve_model
|
||||
|
|
@ -16,37 +16,37 @@ class TestInferRequiredCapabilities:
|
|||
"""Test capability inference from file extensions and target config."""
|
||||
|
||||
def test_plain_image(self) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
target = GenerateTargetConfig(prompt="x")
|
||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||
{Capability.TEXT_TO_IMAGE}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"])
|
||||
def test_image_extensions(self, name: str) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
target = GenerateTargetConfig(prompt="x")
|
||||
caps = infer_required_capabilities(name, target)
|
||||
assert Capability.TEXT_TO_IMAGE in caps
|
||||
|
||||
@pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"])
|
||||
def test_case_insensitive(self, name: str) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
target = GenerateTargetConfig(prompt="x")
|
||||
caps = infer_required_capabilities(name, target)
|
||||
assert Capability.TEXT_TO_IMAGE in caps
|
||||
|
||||
def test_image_with_reference_images(self) -> None:
|
||||
target = TargetConfig(prompt="x", reference_images=["ref.png"])
|
||||
target = GenerateTargetConfig(prompt="x", reference_images=["ref.png"])
|
||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||
{Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES}
|
||||
)
|
||||
|
||||
def test_image_with_control_images(self) -> None:
|
||||
target = TargetConfig(prompt="x", control_images=["ctrl.png"])
|
||||
target = GenerateTargetConfig(prompt="x", control_images=["ctrl.png"])
|
||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||
{Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES}
|
||||
)
|
||||
|
||||
def test_image_with_both(self) -> None:
|
||||
target = TargetConfig(
|
||||
target = GenerateTargetConfig(
|
||||
prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"]
|
||||
)
|
||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||
|
|
@ -59,35 +59,35 @@ class TestInferRequiredCapabilities:
|
|||
|
||||
@pytest.mark.parametrize("name", ["doc.md", "doc.txt"])
|
||||
def test_text_extensions(self, name: str) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
target = GenerateTargetConfig(prompt="x")
|
||||
caps = infer_required_capabilities(name, target)
|
||||
assert caps == frozenset({Capability.TEXT_GENERATION})
|
||||
|
||||
def test_text_with_text_inputs(self) -> None:
|
||||
target = TargetConfig(prompt="x", inputs=["data.txt"])
|
||||
target = GenerateTargetConfig(prompt="x", inputs=["data.txt"])
|
||||
assert infer_required_capabilities("out.md", target) == frozenset(
|
||||
{Capability.TEXT_GENERATION}
|
||||
)
|
||||
|
||||
def test_text_with_image_input(self) -> None:
|
||||
target = TargetConfig(prompt="x", inputs=["photo.png"])
|
||||
target = GenerateTargetConfig(prompt="x", inputs=["photo.png"])
|
||||
assert infer_required_capabilities("out.txt", target) == frozenset(
|
||||
{Capability.TEXT_GENERATION, Capability.VISION}
|
||||
)
|
||||
|
||||
def test_text_with_image_reference(self) -> None:
|
||||
target = TargetConfig(prompt="x", reference_images=["ref.jpg"])
|
||||
target = GenerateTargetConfig(prompt="x", reference_images=["ref.jpg"])
|
||||
assert infer_required_capabilities("out.md", target) == frozenset(
|
||||
{Capability.TEXT_GENERATION, Capability.VISION}
|
||||
)
|
||||
|
||||
def test_unsupported_extension_raises(self) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
target = GenerateTargetConfig(prompt="x")
|
||||
with pytest.raises(ValueError, match="unsupported extension"):
|
||||
_ = infer_required_capabilities("data.csv", target)
|
||||
|
||||
def test_no_extension_raises(self) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
target = GenerateTargetConfig(prompt="x")
|
||||
with pytest.raises(ValueError, match="unsupported extension"):
|
||||
_ = infer_required_capabilities("Makefile", target)
|
||||
|
||||
|
|
@ -96,50 +96,52 @@ class TestResolveModel:
|
|||
"""Test model resolution with capability validation."""
|
||||
|
||||
def test_explicit_model_wins(self) -> None:
|
||||
target = TargetConfig(prompt="x", model="mistral-small-latest")
|
||||
target = GenerateTargetConfig(prompt="x", model="mistral-small-latest")
|
||||
result = resolve_model("out.txt", target, Defaults())
|
||||
assert result.name == "mistral-small-latest"
|
||||
|
||||
def test_default_text_model(self) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
target = GenerateTargetConfig(prompt="x")
|
||||
defaults = Defaults(text_model="mistral-large-latest")
|
||||
result = resolve_model("out.md", target, defaults)
|
||||
assert result.name == "mistral-large-latest"
|
||||
|
||||
def test_default_image_model(self) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
target = GenerateTargetConfig(prompt="x")
|
||||
defaults = Defaults(image_model="flux-dev")
|
||||
result = resolve_model("out.png", target, defaults)
|
||||
assert result.name == "flux-dev"
|
||||
|
||||
def test_unknown_model_raises(self) -> None:
|
||||
target = TargetConfig(prompt="x", model="nonexistent-model")
|
||||
target = GenerateTargetConfig(prompt="x", model="nonexistent-model")
|
||||
with pytest.raises(ValueError, match="Unknown model"):
|
||||
_ = resolve_model("out.txt", target, Defaults())
|
||||
|
||||
def test_explicit_model_missing_capability_raises(self) -> None:
|
||||
# flux-dev does not support reference images
|
||||
target = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"])
|
||||
target = GenerateTargetConfig(
|
||||
prompt="x", model="flux-dev", reference_images=["r.png"]
|
||||
)
|
||||
with pytest.raises(ValueError, match="lacks required capabilities"):
|
||||
_ = resolve_model("out.png", target, Defaults())
|
||||
|
||||
def test_default_fallback_for_reference_images(self) -> None:
|
||||
# Default flux-dev lacks reference_images, should fall back to a capable model
|
||||
target = TargetConfig(prompt="x", reference_images=["r.png"])
|
||||
target = GenerateTargetConfig(prompt="x", reference_images=["r.png"])
|
||||
defaults = Defaults(image_model="flux-dev")
|
||||
result = resolve_model("out.png", target, defaults)
|
||||
assert Capability.REFERENCE_IMAGES in result.capabilities
|
||||
|
||||
def test_default_fallback_for_vision(self) -> None:
|
||||
# Default mistral-large-latest lacks vision, should fall back to a pixtral model
|
||||
target = TargetConfig(prompt="x", inputs=["photo.png"])
|
||||
target = GenerateTargetConfig(prompt="x", inputs=["photo.png"])
|
||||
defaults = Defaults(text_model="mistral-large-latest")
|
||||
result = resolve_model("out.txt", target, defaults)
|
||||
assert Capability.VISION in result.capabilities
|
||||
|
||||
def test_default_preferred_when_capable(self) -> None:
|
||||
# Default flux-2-pro already supports reference_images, should be used directly
|
||||
target = TargetConfig(prompt="x", reference_images=["r.png"])
|
||||
target = GenerateTargetConfig(prompt="x", reference_images=["r.png"])
|
||||
defaults = Defaults(image_model="flux-2-pro")
|
||||
result = resolve_model("out.png", target, defaults)
|
||||
assert result.name == "flux-2-pro"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue