diff --git a/backend/prompt_studio/prompt_profile_manager_v2/migrations/0006_make_extraction_adapters_nullable.py b/backend/prompt_studio/prompt_profile_manager_v2/migrations/0006_make_extraction_adapters_nullable.py new file mode 100644 index 0000000000..df4c7a2c1c --- /dev/null +++ b/backend/prompt_studio/prompt_profile_manager_v2/migrations/0006_make_extraction_adapters_nullable.py @@ -0,0 +1,47 @@ +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("adapter_processor_v2", "0001_initial"), + ("prompt_profile_manager_v2", "0005_profilemanager_shared_to_org_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="profilemanager", + name="vector_store", + field=models.ForeignKey( + blank=True, + db_comment="Field to store the chosen vector store.", + null=True, + on_delete=django.db.models.deletion.PROTECT, + related_name="profiles_vector_store", + to="adapter_processor_v2.adapterinstance", + ), + ), + migrations.AlterField( + model_name="profilemanager", + name="embedding_model", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.PROTECT, + related_name="profiles_embedding_model", + to="adapter_processor_v2.adapterinstance", + ), + ), + migrations.AlterField( + model_name="profilemanager", + name="x2text", + field=models.ForeignKey( + blank=True, + db_comment="Field to store the X2Text Adapter chosen by the user", + null=True, + on_delete=django.db.models.deletion.PROTECT, + related_name="profiles_x2text", + to="adapter_processor_v2.adapterinstance", + ), + ), + ] diff --git a/backend/prompt_studio/prompt_profile_manager_v2/models.py b/backend/prompt_studio/prompt_profile_manager_v2/models.py index 10a234f462..88c6accc56 100644 --- a/backend/prompt_studio/prompt_profile_manager_v2/models.py +++ b/backend/prompt_studio/prompt_profile_manager_v2/models.py @@ -61,15 +61,15 @@ class RetrievalStrategy(models.TextChoices): vector_store = models.ForeignKey( AdapterInstance, db_comment="Field to store the chosen vector store.", - blank=False, - null=False, + blank=True, + null=True, on_delete=models.PROTECT, related_name="profiles_vector_store", ) embedding_model = models.ForeignKey( AdapterInstance, - blank=False, - null=False, + blank=True, + null=True, on_delete=models.PROTECT, related_name="profiles_embedding_model", ) @@ -84,8 +84,8 @@ class RetrievalStrategy(models.TextChoices): x2text = models.ForeignKey( AdapterInstance, db_comment="Field to store the X2Text Adapter chosen by the user", - blank=False, - null=False, + blank=True, + null=True, on_delete=models.PROTECT, related_name="profiles_x2text", ) diff --git a/backend/prompt_studio/prompt_profile_manager_v2/serializers.py b/backend/prompt_studio/prompt_profile_manager_v2/serializers.py index 008fed3850..d7b6d0040d 100644 --- a/backend/prompt_studio/prompt_profile_manager_v2/serializers.py +++ b/backend/prompt_studio/prompt_profile_manager_v2/serializers.py @@ -1,6 +1,7 @@ import logging from adapter_processor_v2.adapter_processor import AdapterProcessor +from rest_framework import serializers from backend.serializers import AuditSerializer from prompt_studio.prompt_profile_manager_v2.constants import ProfileManagerKeys @@ -9,6 +10,14 @@ logger = logging.getLogger(__name__) +# Extraction adapter fields that are only required when at least one prompt +# using this profile needs text extraction (extraction_inputs != "image"). +_TEXT_EXTRACTION_FIELDS = ( + ProfileManagerKeys.VECTOR_STORE, + ProfileManagerKeys.EMBEDDING_MODEL, + ProfileManagerKeys.X2TEXT, +) + class ProfileManagerSerializer(AuditSerializer): class Meta: @@ -18,12 +27,49 @@ class Meta: # the DRF auto-validator that 400s on re-save / PUT before the view runs. validators = [] + def validate(self, attrs): + """Enforce x2text/embedding/vector_store when text extraction needed. + + These fields are nullable at the DB level to support image-only + profiles, but must be populated when any prompt using this profile + requires text extraction. + """ + attrs = super().validate(attrs) + + instance = self.instance + if instance is not None: + # Update: check prompts currently linked to this profile + needs_text = instance.tool_studio_prompts.exclude( + extraction_inputs="image" + ).exists() + else: + # Create: no prompts linked yet — require extraction adapters + # by default so existing flows are unaffected + needs_text = True + + if needs_text: + missing = [ + field + for field in _TEXT_EXTRACTION_FIELDS + if not attrs.get(field) + and (instance is None or not getattr(instance, f"{field}_id", None)) + ] + if missing: + raise serializers.ValidationError( + { + field: "This field is required when any linked prompt " + "uses text extraction." + for field in missing + } + ) + return attrs + def to_representation(self, instance): # type: ignore rep: dict[str, str] = super().to_representation(instance) llm = rep[ProfileManagerKeys.LLM] - embedding = rep[ProfileManagerKeys.EMBEDDING_MODEL] - vector_db = rep[ProfileManagerKeys.VECTOR_STORE] - x2text = rep[ProfileManagerKeys.X2TEXT] + embedding = rep.get(ProfileManagerKeys.EMBEDDING_MODEL) + vector_db = rep.get(ProfileManagerKeys.VECTOR_STORE) + x2text = rep.get(ProfileManagerKeys.X2TEXT) if llm: rep[ProfileManagerKeys.LLM] = AdapterProcessor.get_adapter_instance_by_id(llm) if embedding: diff --git a/backend/prompt_studio/prompt_studio_core_v2/constants.py b/backend/prompt_studio/prompt_studio_core_v2/constants.py index 03bd68c1d8..467c345638 100644 --- a/backend/prompt_studio/prompt_studio_core_v2/constants.py +++ b/backend/prompt_studio/prompt_studio_core_v2/constants.py @@ -108,6 +108,10 @@ class ToolStudioPromptKeys: # Webhook postprocessing settings ENABLE_POSTPROCESSING_WEBHOOK = "enable_postprocessing_webhook" POSTPROCESSING_WEBHOOK_URL = "postprocessing_webhook_url" + # Vision mode fields + EXTRACTION_INPUTS = "extraction_inputs" + SOURCE_OF_TRUTH = "source_of_truth" + SOURCE_FILE_PATH = "source_file_path" class FileViewTypes: diff --git a/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py b/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py index 91515db118..8640c5c00e 100644 --- a/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py +++ b/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py @@ -411,6 +411,10 @@ def _build_prompt_output( if lookup_config := get_lookup_config(prompt): output["lookup_config"] = lookup_config + # Vision mode fields + output[TSPKeys.EXTRACTION_INPUTS] = prompt.extraction_inputs + output[TSPKeys.SOURCE_OF_TRUTH] = prompt.source_of_truth + output[TSPKeys.EVAL_SETTINGS] = {} output[TSPKeys.EVAL_SETTINGS][TSPKeys.EVAL_SETTINGS_EVALUATE] = prompt.evaluate output[TSPKeys.EVAL_SETTINGS][TSPKeys.EVAL_SETTINGS_MONITOR_LLM] = [monitor_llm] @@ -825,6 +829,10 @@ def build_fetch_response_payload( if lookup_config := get_lookup_config(prompt): output["lookup_config"] = lookup_config + # Vision mode fields + output[TSPKeys.EXTRACTION_INPUTS] = prompt.extraction_inputs + output[TSPKeys.SOURCE_OF_TRUTH] = prompt.source_of_truth + output[TSPKeys.EVAL_SETTINGS] = {} output[TSPKeys.EVAL_SETTINGS][TSPKeys.EVAL_SETTINGS_EVALUATE] = prompt.evaluate output[TSPKeys.EVAL_SETTINGS][TSPKeys.EVAL_SETTINGS_MONITOR_LLM] = [monitor_llm] @@ -874,6 +882,7 @@ def build_fetch_response_payload( TSPKeys.FILE_NAME: doc_name, TSPKeys.FILE_HASH: file_hash, TSPKeys.FILE_PATH: extract_path, + TSPKeys.SOURCE_FILE_PATH: file_path, Common.LOG_EVENTS_ID: StateStore.get(Common.LOG_EVENTS_ID), TSPKeys.EXECUTION_SOURCE: ExecutionSource.IDE.value, TSPKeys.CUSTOM_DATA: tool.custom_data, @@ -1064,6 +1073,7 @@ def build_bulk_fetch_response_payload( TSPKeys.FILE_NAME: doc_name, TSPKeys.FILE_HASH: file_hash, TSPKeys.FILE_PATH: extract_path, + TSPKeys.SOURCE_FILE_PATH: file_path, Common.LOG_EVENTS_ID: StateStore.get(Common.LOG_EVENTS_ID), TSPKeys.EXECUTION_SOURCE: ExecutionSource.IDE.value, TSPKeys.CUSTOM_DATA: tool.custom_data, @@ -1225,6 +1235,7 @@ def build_single_pass_payload( TSPKeys.FILE_HASH: file_hash, TSPKeys.FILE_NAME: doc_name, TSPKeys.FILE_PATH: file_path, + TSPKeys.SOURCE_FILE_PATH: doc_path, Common.LOG_EVENTS_ID: StateStore.get(Common.LOG_EVENTS_ID), TSPKeys.EXECUTION_SOURCE: ExecutionSource.IDE.value, TSPKeys.CUSTOM_DATA: tool.custom_data, @@ -1950,6 +1961,9 @@ def _fetch_response( output[TSPKeys.POSTPROCESSING_WEBHOOK_URL] = webhook_url if lookup_config := get_lookup_config(prompt): output["lookup_config"] = lookup_config + # Vision mode fields + output[TSPKeys.EXTRACTION_INPUTS] = prompt.extraction_inputs + output[TSPKeys.SOURCE_OF_TRUTH] = prompt.source_of_truth # Eval settings for the prompt output[TSPKeys.EVAL_SETTINGS] = {} output[TSPKeys.EVAL_SETTINGS][TSPKeys.EVAL_SETTINGS_EVALUATE] = prompt.evaluate @@ -2000,6 +2014,7 @@ def _fetch_response( TSPKeys.FILE_NAME: doc_name, TSPKeys.FILE_HASH: file_hash, TSPKeys.FILE_PATH: doc_path, + TSPKeys.SOURCE_FILE_PATH: doc_path, Common.LOG_EVENTS_ID: StateStore.get(Common.LOG_EVENTS_ID), TSPKeys.EXECUTION_SOURCE: ExecutionSource.IDE.value, TSPKeys.CUSTOM_DATA: tool.custom_data, diff --git a/backend/prompt_studio/prompt_studio_registry_v2/constants.py b/backend/prompt_studio/prompt_studio_registry_v2/constants.py index 35d6654851..a140821061 100644 --- a/backend/prompt_studio/prompt_studio_registry_v2/constants.py +++ b/backend/prompt_studio/prompt_studio_registry_v2/constants.py @@ -106,6 +106,9 @@ class JsonSchemaKey: ENABLE_POSTPROCESSING_WEBHOOK = "enable_postprocessing_webhook" POSTPROCESSING_WEBHOOK_URL = "postprocessing_webhook_url" WORD_CONFIDENCE_POSTAMBLE = "word_confidence_postamble" + # Vision mode fields + EXTRACTION_INPUTS = "extraction_inputs" + SOURCE_OF_TRUTH = "source_of_truth" class SpecKey: diff --git a/backend/prompt_studio/prompt_studio_registry_v2/prompt_studio_registry_helper.py b/backend/prompt_studio/prompt_studio_registry_v2/prompt_studio_registry_helper.py index 4fee8c10bc..23acbcab5b 100644 --- a/backend/prompt_studio/prompt_studio_registry_v2/prompt_studio_registry_helper.py +++ b/backend/prompt_studio/prompt_studio_registry_v2/prompt_studio_registry_helper.py @@ -266,10 +266,19 @@ def frame_export_json( embedding_suffix = "" adapter_id = "" - vector_db = str(default_llm_profile.vector_store.id) - embedding_model = str(default_llm_profile.embedding_model.id) + # Extraction adapters may be null for image-only profiles + vector_db = ( + str(default_llm_profile.vector_store.id) + if default_llm_profile.vector_store + else "" + ) + embedding_model = ( + str(default_llm_profile.embedding_model.id) + if default_llm_profile.embedding_model + else "" + ) llm = str(default_llm_profile.llm.id) - x2text = str(default_llm_profile.x2text.id) + x2text = str(default_llm_profile.x2text.id) if default_llm_profile.x2text else "" # Tool settings tool_settings = {} @@ -328,36 +337,51 @@ def frame_export_json( invalidated_outputs.append(prompt.prompt_key) continue - vector_db = str(prompt.profile_manager.vector_store.id) - embedding_model = str(prompt.profile_manager.embedding_model.id) - llm = str(prompt.profile_manager.llm.id) - x2text = str(prompt.profile_manager.x2text.id) - adapter_id = str(prompt.profile_manager.embedding_model.adapter_id) - embedding_suffix = adapter_id.split("|")[0] + # Extraction adapters may be null for image-only prompts + pm = prompt.profile_manager + vector_db = str(pm.vector_store.id) if pm.vector_store else "" + embedding_model = str(pm.embedding_model.id) if pm.embedding_model else "" + llm = str(pm.llm.id) + x2text = str(pm.x2text.id) if pm.x2text else "" + if pm.embedding_model: + adapter_id = str(pm.embedding_model.adapter_id) + embedding_suffix = adapter_id.split("|")[0] + else: + adapter_id = "" + embedding_suffix = "" output[JsonSchemaKey.PROMPT] = prompt.prompt output[JsonSchemaKey.ACTIVE] = prompt.active output[JsonSchemaKey.REQUIRED] = prompt.required - output[JsonSchemaKey.CHUNK_SIZE] = prompt.profile_manager.chunk_size + output[JsonSchemaKey.CHUNK_SIZE] = pm.chunk_size output[JsonSchemaKey.VECTOR_DB] = vector_db output[JsonSchemaKey.EMBEDDING] = embedding_model output[JsonSchemaKey.X2TEXT_ADAPTER] = x2text - output[JsonSchemaKey.CHUNK_OVERLAP] = prompt.profile_manager.chunk_overlap + output[JsonSchemaKey.CHUNK_OVERLAP] = pm.chunk_overlap output[JsonSchemaKey.LLM] = llm output[JsonSchemaKey.PREAMBLE] = tool.preamble output[JsonSchemaKey.POSTAMBLE] = tool.postamble output[JsonSchemaKey.GRAMMAR] = grammar_list output[JsonSchemaKey.TYPE] = prompt.enforce_type output[JsonSchemaKey.NAME] = prompt.prompt_key - output[JsonSchemaKey.RETRIEVAL_STRATEGY] = ( - prompt.profile_manager.retrieval_strategy - ) - output[JsonSchemaKey.SIMILARITY_TOP_K] = ( - prompt.profile_manager.similarity_top_k - ) - output[JsonSchemaKey.SECTION] = prompt.profile_manager.section - output[JsonSchemaKey.REINDEX] = prompt.profile_manager.reindex + output[JsonSchemaKey.RETRIEVAL_STRATEGY] = pm.retrieval_strategy + output[JsonSchemaKey.SIMILARITY_TOP_K] = pm.similarity_top_k + output[JsonSchemaKey.SECTION] = pm.section + output[JsonSchemaKey.REINDEX] = pm.reindex output[JsonSchemaKey.EMBEDDING_SUFFIX] = embedding_suffix + # Vision mode fields — force text_only when single-pass is enabled + if tool.single_pass_extraction_mode and prompt.extraction_inputs != "text": + logger.warning( + "Single-pass extraction enabled: forcing prompt '%s' " + "from extraction_inputs='%s' to 'text' in export", + prompt.prompt_key, + prompt.extraction_inputs, + ) + output[JsonSchemaKey.EXTRACTION_INPUTS] = "text" + output[JsonSchemaKey.SOURCE_OF_TRUTH] = "text" + else: + output[JsonSchemaKey.EXTRACTION_INPUTS] = prompt.extraction_inputs + output[JsonSchemaKey.SOURCE_OF_TRUTH] = prompt.source_of_truth # Webhook postprocessing settings output[JsonSchemaKey.ENABLE_POSTPROCESSING_WEBHOOK] = ( prompt.enable_postprocessing_webhook diff --git a/backend/prompt_studio/prompt_studio_v2/migrations/0015_toolstudioprompt_extraction_inputs_and_more.py b/backend/prompt_studio/prompt_studio_v2/migrations/0015_toolstudioprompt_extraction_inputs_and_more.py new file mode 100644 index 0000000000..6d83a79dda --- /dev/null +++ b/backend/prompt_studio/prompt_studio_v2/migrations/0015_toolstudioprompt_extraction_inputs_and_more.py @@ -0,0 +1,36 @@ +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("prompt_studio_v2", "0014_alter_toolstudioprompt_enforce_type"), + ] + + operations = [ + migrations.AddField( + model_name="toolstudioprompt", + name="extraction_inputs", + field=models.TextField( + choices=[ + ("text", "Text only (default)"), + ("image", "Page image only"), + ("both", "Text and page image"), + ], + db_comment="What inputs to send to the LLM: text, image, or both", + default="text", + ), + ), + migrations.AddField( + model_name="toolstudioprompt", + name="source_of_truth", + field=models.TextField( + choices=[ + ("text", "Text is source of truth"), + ("image", "Image is source of truth"), + ], + db_comment="Which input is source of truth " + "(only meaningful when extraction_inputs=both)", + default="text", + ), + ), + ] diff --git a/backend/prompt_studio/prompt_studio_v2/models.py b/backend/prompt_studio/prompt_studio_v2/models.py index b4c64f0912..217ad41848 100644 --- a/backend/prompt_studio/prompt_studio_v2/models.py +++ b/backend/prompt_studio/prompt_studio_v2/models.py @@ -33,6 +33,15 @@ class EnforceType(models.TextChoices): TABLE = "table", "Response sent as json" AGENTIC_TABLE = "agentic_table", "Response sent as agentic table extraction" + class ExtractionInput(models.TextChoices): + TEXT = "text", "Text only (default)" + IMAGE = "image", "Page image only" + BOTH = "both", "Text and page image" + + class SourceOfTruth(models.TextChoices): + TEXT = "text", "Text is source of truth" + IMAGE = "image", "Image is source of truth" + class PromptType(models.TextChoices): PROMPT = "PROMPT", "Response sent as Text" NOTES = "NOTES", "Response sent as float" @@ -117,6 +126,17 @@ class RequiredType(models.TextChoices): default=dict, db_comment="JSON adapter metadata for the FE to load the pagination", ) + extraction_inputs = models.TextField( + choices=ExtractionInput.choices, + default=ExtractionInput.TEXT, + db_comment="What inputs to send to the LLM: text, image, or both", + ) + source_of_truth = models.TextField( + choices=SourceOfTruth.choices, + default=SourceOfTruth.TEXT, + db_comment="Which input is source of truth " + "(only meaningful when extraction_inputs=both)", + ) created_by = models.ForeignKey( User, on_delete=models.SET_NULL, diff --git a/backend/uv.lock b/backend/uv.lock index bad2b30614..d5919761db 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = "==3.12.*" [manifest] @@ -3908,6 +3908,7 @@ dependencies = [ { name = "llama-parse" }, { name = "llmwhisperer-client" }, { name = "pdfplumber" }, + { name = "pypdfium2" }, { name = "python-dotenv" }, { name = "python-magic" }, { name = "qdrant-client" }, @@ -3947,6 +3948,7 @@ requires-dist = [ { name = "llama-parse", specifier = ">=0.6.0" }, { name = "llmwhisperer-client", specifier = ">=2.6.2" }, { name = "pdfplumber", specifier = ">=0.11.2" }, + { name = "pypdfium2", specifier = ">=4.0.0" }, { name = "python-dotenv", specifier = "==1.2.2" }, { name = "python-magic", specifier = "~=0.4.27" }, { name = "qdrant-client", specifier = ">=1.16.0,<1.17.0" }, diff --git a/backend/workflow_manager/endpoint_v2/source.py b/backend/workflow_manager/endpoint_v2/source.py index 23b0d04f2c..44ed741b8f 100644 --- a/backend/workflow_manager/endpoint_v2/source.py +++ b/backend/workflow_manager/endpoint_v2/source.py @@ -103,6 +103,7 @@ def __init__( self.hash_value_of_file_content: str | None = None self.workflow_log = workflow_log self.use_file_history = use_file_history + self._vision_signature = self._compute_vision_signature(workflow) def _get_endpoint_for_workflow( self, @@ -122,6 +123,67 @@ def _get_endpoint_for_workflow( ) return endpoint + @staticmethod + def _compute_vision_signature(workflow: Workflow) -> str: + """Compute a vision mode signature for cache key discrimination. + + Queries the workflow's tool configuration to determine whether any + prompt uses vision mode (extraction_inputs != 'text'). Returns a + deterministic signature string that is appended to the file content + hash, preventing cache collisions between text-only and vision runs + of the same file. + + Returns empty string when no vision prompts exist, preserving + backward-compatible hash values for existing text-only workflows. + + Note: For ETL/TASK connectors using provider_file_uuid-based file + history lookups, this discriminator is not applied because the + content hash is not available at lookup time. Users must clear + file markers after changing vision mode on such workflows. + """ + try: + from prompt_studio.prompt_studio_registry_v2.models import ( + PromptStudioRegistry, + ) + from tool_instance_v2.models import ToolInstance + + tool_instance = ToolInstance.objects.filter(workflow=workflow).first() + if not tool_instance or not tool_instance.metadata: + return "" + + prompt_registry_id = tool_instance.metadata.get("prompt_registry_id") + if not prompt_registry_id: + return "" + + registry = PromptStudioRegistry.objects.filter( + prompt_registry_id=prompt_registry_id + ).first() + if not registry or not registry.tool_metadata: + return "" + + outputs = registry.tool_metadata.get("outputs", []) + vision_parts: list[str] = [] + for output in outputs: + extraction_inputs = output.get("extraction_inputs", "text") + source_of_truth = output.get("source_of_truth", "text") + if extraction_inputs != "text": + name = output.get("name", "") + vision_parts.append(f"{name}:{extraction_inputs}:{source_of_truth}") + + if not vision_parts: + return "" + + # Sort for determinism across runs + return "|vision:" + ",".join(sorted(vision_parts)) + except Exception: + logger.warning( + "Failed to compute vision signature for workflow %s, " + "using content-only hash", + workflow.id, + exc_info=True, + ) + return "" + def validate(self) -> None: connection_type = self.endpoint.connection_type connector: ConnectorInstance = self.endpoint.connector_instance @@ -938,19 +1000,24 @@ def list_files_from_source( def get_file_content_hash(self, source_fs: UnstractFileSystem, file_path: str) -> str: """Generate a hash value from the file content. + Includes vision signature in the hash when any prompt uses vision mode, + preventing cache collisions between text-only and vision runs. + Args: source_fs (UnstractFileSystem): The file system object used for reading the file. file_path (str): The path of the file. Returns: - str: The hash value of the file content. + str: The hash value of the file content (with vision discriminator). """ file_content_hash = sha256() source = source_fs.get_fsspec_fs() with source.open(file_path, "rb") as remote_file: while chunk := remote_file.read(self.READ_CHUNK_SIZE): file_content_hash.update(chunk) + if self._vision_signature: + file_content_hash.update(self._vision_signature.encode()) return file_content_hash.hexdigest() def copy_file_to_infile_dir(self, source_file_path: str, infile_path: str) -> None: @@ -1021,6 +1088,8 @@ def add_input_from_connector_to_volume(self, file_hash: FileHash) -> str: # This function is typically relevant for extracted text content, # may not be necessary for PDFs, images, or other non-text formats. self.publish_input_file_content(input_file_path, input_log) + if self._vision_signature: + file_content_hash.update(self._vision_signature.encode()) hash_value_of_file_content = file_content_hash.hexdigest() file_hash.mime_type = mime_type logger.info( @@ -1221,6 +1290,7 @@ def add_input_file_to_api_storage( workflow_id=workflow_id, execution_id=execution_id ) workflow: Workflow = Workflow.objects.get(id=workflow_id) + vision_signature = cls._compute_vision_signature(workflow) file_hashes: dict[str, FileHash] = {} unique_file_hashes: set[str] = set() connection_type = WorkflowEndpoint.ConnectionType.API @@ -1260,6 +1330,8 @@ def add_input_file_to_api_storage( for chunk in file.chunks(chunk_size=cls.READ_CHUNK_SIZE): file_hash.update(chunk) file_storage.write(path=destination_path, mode="ab", data=chunk) + if vision_signature: + file_hash.update(vision_signature.encode()) file_hash = file_hash.hexdigest() # Skip duplicate files diff --git a/backend/workflow_manager/workflow_v2/file_history_helper.py b/backend/workflow_manager/workflow_v2/file_history_helper.py index 687f88dfe9..67b67d26b4 100644 --- a/backend/workflow_manager/workflow_v2/file_history_helper.py +++ b/backend/workflow_manager/workflow_v2/file_history_helper.py @@ -58,8 +58,14 @@ def get_file_history( filters = Q(workflow=workflow) if cache_key: + # cache_key includes vision mode signature when applicable, + # so lookups by cache_key are vision-aware automatically. filters &= Q(cache_key=cache_key) elif provider_file_uuid: + # Note: provider_file_uuid lookups are NOT vision-discriminated. + # If vision mode changes on a deployed workflow using file connectors + # with provider UUIDs (e.g., Google Drive), users must clear file + # markers to trigger reprocessing. filters &= Q(provider_file_uuid=provider_file_uuid) file_history: FileHistory | None diff --git a/backend/workflow_manager/workflow_v2/models/file_history.py b/backend/workflow_manager/workflow_v2/models/file_history.py index c45ec609dc..a8a76c3fce 100644 --- a/backend/workflow_manager/workflow_v2/models/file_history.py +++ b/backend/workflow_manager/workflow_v2/models/file_history.py @@ -46,7 +46,7 @@ def has_exceeded_limit(self, workflow: "Workflow") -> bool: id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) cache_key = models.CharField( max_length=HASH_LENGTH, - db_comment="Hash value of file contents, WF and tool modified times", + db_comment="SHA256 of file contents, includes vision mode signature when any prompt uses vision", ) provider_file_uuid = models.CharField( max_length=HASH_LENGTH, diff --git a/frontend/src/components/custom-tools/prompt-card/PromptCardItems.jsx b/frontend/src/components/custom-tools/prompt-card/PromptCardItems.jsx index 5cb5d099b1..01fa1eb048 100644 --- a/frontend/src/components/custom-tools/prompt-card/PromptCardItems.jsx +++ b/frontend/src/components/custom-tools/prompt-card/PromptCardItems.jsx @@ -47,6 +47,14 @@ try { // The component will remain null of it is not available } +let VisionModeSelector; +try { + const mod = await import("../../../plugins/prompt-card/VisionModeSelector"); + VisionModeSelector = mod.VisionModeSelector; +} catch { + // Cloud-only vision mode selector; stays undefined in OSS builds +} + function PromptCardItems({ promptDetails, enforceTypeList, @@ -308,6 +316,19 @@ function PromptCardItems({ )} + {VisionModeSelector && ( + + )} {TableExtractionSettingsBtn && ( =0.25.2", "pdfplumber>=0.11.2", + "pypdfium2>=4.0.0", "redis>=5.2.1", # # LLMWhisperer client "llmwhisperer-client>=2.6.2", diff --git a/unstract/sdk1/src/unstract/sdk1/constants.py b/unstract/sdk1/src/unstract/sdk1/constants.py index 3494a8e2a6..f27760668a 100644 --- a/unstract/sdk1/src/unstract/sdk1/constants.py +++ b/unstract/sdk1/src/unstract/sdk1/constants.py @@ -1,3 +1,4 @@ +import os from enum import Enum @@ -205,3 +206,62 @@ class RequestHeader: REQUEST_ID = "X-Request-ID" AUTHORIZATION = "Authorization" + + +# --------------------------------------------------------------------------- +# Vision mode constants +# --------------------------------------------------------------------------- + +MAX_VISION_PAGES = int(os.environ.get("MAX_VISION_PAGES", "10")) + + +class VisionMode: + """Derived vision mode values used at the answer step.""" + + TEXT_ONLY = "text_only" + SPATIAL_HELPER = "spatial_helper" + SOURCE_OF_TRUTH = "source_of_truth" + + +class ExtractionInputs: + """Values for the per-prompt extraction_inputs field.""" + + TEXT = "text" + IMAGE = "image" + BOTH = "both" + + +class SourceOfTruthValues: + """Values for the per-prompt source_of_truth field.""" + + TEXT = "text" + IMAGE = "image" + + +def derive_vision_mode(extraction_inputs: str, source_of_truth: str) -> str: + """Derive the vision mode from two orthogonal per-prompt UI fields. + + +-------------------+-----------------+---------------------+ + | extraction_inputs | source_of_truth | vision_mode | + +-------------------+-----------------+---------------------+ + | text | (ignored) | text_only | + | image | (ignored) | source_of_truth | + | both | text | spatial_helper | + | both | image | source_of_truth | + +-------------------+-----------------+---------------------+ + + Args: + extraction_inputs: One of "text", "image", "both". + source_of_truth: One of "text", "image". + + Returns: + One of VisionMode.TEXT_ONLY, SPATIAL_HELPER, SOURCE_OF_TRUTH. + """ + if extraction_inputs == ExtractionInputs.TEXT: + return VisionMode.TEXT_ONLY + if extraction_inputs == ExtractionInputs.IMAGE: + return VisionMode.SOURCE_OF_TRUTH + # extraction_inputs == "both" + if source_of_truth == SourceOfTruthValues.IMAGE: + return VisionMode.SOURCE_OF_TRUTH + return VisionMode.SPATIAL_HELPER diff --git a/unstract/sdk1/src/unstract/sdk1/llm.py b/unstract/sdk1/src/unstract/sdk1/llm.py index b45b856239..8b12f00422 100644 --- a/unstract/sdk1/src/unstract/sdk1/llm.py +++ b/unstract/sdk1/src/unstract/sdk1/llm.py @@ -414,6 +414,9 @@ def complete_vision( ) -> dict[str, object]: """Chat completion with multimodal (text + image) messages. + Full parity with complete(): retry logic, extract_json, + process_text callback, usage recording, and error wrapping. + Accepts pre-built messages with image_url content blocks:: [ @@ -432,18 +435,17 @@ def complete_vision( LiteLLM auto-translates the OpenAI-style image format for Anthropic, Bedrock, Vertex, and other providers. - Same error handling, usage tracking, and metrics as complete(). - Args: messages: List of message dicts with multimodal content. **kwargs: Additional arguments passed to litellm.completion(). + Supports extract_json (bool) and process_text (callable) + same as complete(). Returns: - dict with "response" key containing LLMResponseCompat. + dict with "response" key containing LLMResponseCompat, + plus any post-processed output from extract_json/process_text. """ try: - litellm.drop_params = True - logger.debug( f"[sdk1][LLM]Invoking {self.adapter.get_provider()} " f"vision completion API" @@ -452,9 +454,14 @@ def complete_vision( completion_kwargs = self.adapter.validate({**self.kwargs, **kwargs}) completion_kwargs.pop("cost_model", None) - response: dict[str, object] = litellm.completion( - messages=messages, - **completion_kwargs, + max_retries = pop_litellm_retry_kwargs( + completion_kwargs, self._get_adapter_info() + ) + response: dict[str, object] = call_with_retry( + lambda: litellm.completion(messages=messages, **completion_kwargs), + max_retries=max_retries, + retry_predicate=is_retryable_litellm_error, + description=self._get_adapter_info(), ) response_text = response["choices"][0]["message"]["content"] @@ -471,9 +478,21 @@ def complete_vision( if response_text is None: self._raise_for_empty_response(finish_reason) + extract_json: bool = cast("bool", kwargs.get("extract_json", False)) + post_process_fn: ( + Callable[[LLMResponseCompat, bool, str], dict[str, object]] | None + ) = cast( + "Callable[[LLMResponseCompat, bool, str], dict[str, object]] | None", + kwargs.get("process_text", None), + ) + + response_text, post_processed_output = self._post_process_response( + response_text, extract_json, post_process_fn + ) + response_object = LLMResponseCompat(response_text) response_object.raw = response - return {"response": response_object} + return {"response": response_object, **post_processed_output} except LLMError: raise diff --git a/unstract/sdk1/src/unstract/sdk1/rasteriser.py b/unstract/sdk1/src/unstract/sdk1/rasteriser.py new file mode 100644 index 0000000000..095393238f --- /dev/null +++ b/unstract/sdk1/src/unstract/sdk1/rasteriser.py @@ -0,0 +1,164 @@ +"""In-memory PDF page rasteriser for VLM vision calls. + +Renders PDF pages to preprocessed PNG bytes without writing to disk. +Preprocessing mirrors the agentic table pipeline: + 1. Render at DPI via pypdfium2 + 2. Upscale 2x with LANCZOS + 3. Gaussian blur (radius=0.5) to smooth aliasing + 4. UnsharpMask to restore text sharpness + 5. Constrain to max_dimension +""" + +import io +import logging +import os +from dataclasses import dataclass + +import pypdfium2 as pdfium +from PIL import Image, ImageFilter + +logger = logging.getLogger(__name__) + +MAX_VISION_PAGES = int(os.environ.get("MAX_VISION_PAGES", "10")) + + +@dataclass +class RenderSettings: + """Configuration for PDF page rendering and preprocessing.""" + + dpi: int = 150 + max_dimension: int = 1568 + + @property + def scale(self) -> float: + """Convert DPI to pypdfium2 scale factor (PDF points are 72 DPI).""" + return self.dpi / 72.0 + + +def _preprocess_image(img: Image.Image, max_dimension: int) -> Image.Image: + """Apply the proven agentic-table preprocessing pipeline. + + Args: + img: Raw rendered PIL image (RGB). + max_dimension: Maximum width or height after preprocessing. + + Returns: + Preprocessed PIL image ready for PNG encoding. + """ + # Upscale 2x with high-quality resampling + upscaled = img.resize( + (img.width * 2, img.height * 2), + Image.Resampling.LANCZOS, + ) + + # Smooth aliasing artifacts + smoothed = upscaled.filter(ImageFilter.GaussianBlur(radius=0.5)) + + # Restore text sharpness + sharpened = smoothed.filter( + ImageFilter.UnsharpMask(radius=1, percent=50, threshold=3) + ) + + # Constrain to max_dimension + if max(sharpened.size) > max_dimension: + ratio = max_dimension / max(sharpened.size) + new_size = (int(sharpened.width * ratio), int(sharpened.height * ratio)) + sharpened = sharpened.resize(new_size, Image.Resampling.LANCZOS) + + return sharpened + + +def rasterise_pages( + file_bytes: bytes, + settings: RenderSettings | None = None, + page_set: set[int] | None = None, + max_pages: int = MAX_VISION_PAGES, +) -> list[tuple[int, bytes]]: + """Render PDF pages to preprocessed PNG bytes in memory. + + Args: + file_bytes: Raw PDF file content. + settings: Render configuration. Uses defaults if None. + page_set: 0-indexed page numbers to render. None renders all pages. + max_pages: Maximum number of pages to render. Logs a warning + if the document has more pages than this limit. + + Returns: + List of (page_number, png_bytes) tuples. Page numbers are 0-indexed. + + Raises: + ValueError: If file_bytes is empty or not a valid PDF. + """ + if not file_bytes: + raise ValueError("file_bytes is empty") + + if settings is None: + settings = RenderSettings() + + pdf = pdfium.PdfDocument(file_bytes) + try: + total_pages = len(pdf) + if total_pages == 0: + logger.warning("PDF has 0 pages, nothing to rasterise") + return [] + + # Determine which pages to render + if page_set is not None: + # Filter to valid pages and sort + pages_to_render = sorted(p for p in page_set if 0 <= p < total_pages) + invalid = page_set - set(range(total_pages)) + if invalid: + logger.warning( + "Skipping invalid page numbers %s (PDF has %d pages)", + sorted(invalid), + total_pages, + ) + else: + pages_to_render = list(range(total_pages)) + + # Apply max_pages cap + if len(pages_to_render) > max_pages: + logger.warning( + "Truncating from %d to %d pages (MAX_VISION_PAGES=%d)", + len(pages_to_render), + max_pages, + max_pages, + ) + pages_to_render = pages_to_render[:max_pages] + + results: list[tuple[int, bytes]] = [] + for page_num in pages_to_render: + page = pdf[page_num] + bitmap = page.render(scale=settings.scale) + pil_image = bitmap.to_pil() + + if pil_image.mode != "RGB": + pil_image = pil_image.convert("RGB") + + # Apply preprocessing pipeline + processed = _preprocess_image(pil_image, settings.max_dimension) + + # Encode as PNG bytes + buffer = io.BytesIO() + processed.save(buffer, format="PNG", optimize=True) + results.append((page_num, buffer.getvalue())) + + logger.debug( + "Rasterised page %d: %dx%d → %dx%d", + page_num, + pil_image.width, + pil_image.height, + processed.width, + processed.height, + ) + + logger.info( + "Rasterised %d/%d pages (dpi=%d, max_dim=%d)", + len(results), + total_pages, + settings.dpi, + settings.max_dimension, + ) + return results + finally: + pdf.close() diff --git a/unstract/sdk1/src/unstract/sdk1/vision.py b/unstract/sdk1/src/unstract/sdk1/vision.py new file mode 100644 index 0000000000..3e9283f3c7 --- /dev/null +++ b/unstract/sdk1/src/unstract/sdk1/vision.py @@ -0,0 +1,124 @@ +"""Vision message builder for VLM completions. + +Assembles OpenAI-style messages with text + base64 image blocks. +LiteLLM auto-translates for Anthropic, Bedrock, Vertex, etc. +""" + +import base64 +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def build_vision_messages( + system_prompt: str, + text_context: str | None, + page_images: list[tuple[int, bytes]], + prompt: str, + mode: str, +) -> list[dict[str, Any]]: + """Assemble OpenAI-style multimodal messages for VLM completion. + + The ordering of text and image content blocks depends on the mode: + + - ``spatial_helper``: Text context first (primary source of truth), + then images as spatial aids for layout understanding. + - ``source_of_truth``: Images first (primary source of truth), + then text as an optional secondary hint. + + Args: + system_prompt: System message content for the LLM. + text_context: Extracted text from the document. May be None or + empty for image-only extraction. + page_images: List of (page_number, png_bytes) from the rasteriser. + prompt: The user's extraction prompt / question. + mode: Vision mode — ``"spatial_helper"`` or ``"source_of_truth"``. + + Returns: + List of message dicts ready for ``llm.complete_vision()``. + + Raises: + ValueError: If mode is not one of the supported values. + """ + valid_modes = ("spatial_helper", "source_of_truth") + if mode not in valid_modes: + raise ValueError(f"Invalid vision mode '{mode}'. Must be one of {valid_modes}") + + if not page_images: + raise ValueError("page_images is empty — cannot build vision messages") + + messages: list[dict[str, Any]] = [] + + # System message + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + # Build multimodal user content blocks + content: list[dict[str, Any]] = [] + + if mode == "spatial_helper": + # Text is primary, images are spatial aids + _append_text_context(content, text_context, is_primary=True) + _append_image_blocks(content, page_images, is_primary=False) + else: + # source_of_truth: Images are primary, text is a hint + _append_image_blocks(content, page_images, is_primary=True) + _append_text_context(content, text_context, is_primary=False) + + # Append the extraction prompt last + content.append({"type": "text", "text": prompt}) + + messages.append({"role": "user", "content": content}) + + logger.debug( + "Built vision messages: mode=%s, images=%d, has_text=%s", + mode, + len(page_images), + bool(text_context), + ) + return messages + + +def _append_text_context( + content: list[dict[str, Any]], + text_context: str | None, + is_primary: bool, +) -> None: + """Append text context block with appropriate framing.""" + if not text_context: + if is_primary: + # This shouldn't happen for spatial_helper — text is expected + logger.warning("No text context available for primary text source") + return + + if is_primary: + label = "DOCUMENT TEXT (primary source — use for extraction):" + else: + label = "DOCUMENT TEXT (supplementary reference):" + + content.append({"type": "text", "text": f"{label}\n{text_context}"}) + + +def _append_image_blocks( + content: list[dict[str, Any]], + page_images: list[tuple[int, bytes]], + is_primary: bool, +) -> None: + """Append image blocks with page labels.""" + if is_primary: + label = "DOCUMENT PAGES (primary source — use for extraction):" + else: + label = "DOCUMENT PAGES (spatial reference for layout context):" + + content.append({"type": "text", "text": label}) + + for page_num, png_bytes in page_images: + b64 = base64.standard_b64encode(png_bytes).decode("utf-8") + content.append({"type": "text", "text": f"Page {page_num + 1}:"}) + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{b64}"}, + } + ) diff --git a/unstract/sdk1/tests/test_complete_vision.py b/unstract/sdk1/tests/test_complete_vision.py new file mode 100644 index 0000000000..d12d827e0b --- /dev/null +++ b/unstract/sdk1/tests/test_complete_vision.py @@ -0,0 +1,219 @@ +"""Tests for LLM.complete_vision() method.""" + +from typing import Self +from unittest.mock import MagicMock, patch + +import pytest +from unstract.sdk1.exceptions import LLMError +from unstract.sdk1.llm import LLM +from unstract.sdk1.utils.common import LLMResponseCompat + + +def _make_llm() -> LLM: + """Create an LLM instance with mocked internals (bypassing __init__).""" + llm = object.__new__(LLM) + + # Adapter + llm.adapter = MagicMock() + llm.adapter.get_provider.return_value = "test-provider" + llm.adapter.validate.side_effect = lambda kwargs: kwargs + + # LLM kwargs + llm.kwargs = {"model": "test-vision-model"} + llm._cost_model = None + llm._adapter_name = "Test Adapter" + + # Metrics — capture_metrics decorator checks these + llm._capture_metrics = False + llm._run_id = None + llm._metrics = {} + + # Usage recording + llm._record_usage = MagicMock() + llm._pending_usage = [] + + return llm + + +def _make_litellm_response( + content: str = "extracted value", + finish_reason: str = "stop", + usage: dict | None = None, +) -> dict: + """Build a mock litellm.completion() response dict.""" + return { + "choices": [ + { + "message": {"content": content}, + "finish_reason": finish_reason, + } + ], + "usage": usage + or { + "prompt_tokens": 100, + "completion_tokens": 20, + "total_tokens": 120, + }, + } + + +class TestCompleteVision: + """Tests for LLM.complete_vision().""" + + @patch("unstract.sdk1.llm.litellm.completion") + @patch("unstract.sdk1.llm.pop_litellm_retry_kwargs", return_value=0) + def test_returns_response_compat( + self: Self, + _mock_pop: MagicMock, + mock_completion: MagicMock, + ) -> None: + """Should return dict with 'response' key containing LLMResponseCompat.""" + mock_completion.return_value = _make_litellm_response("answer") + llm = _make_llm() + + messages = [{"role": "user", "content": [{"type": "text", "text": "hi"}]}] + result = llm.complete_vision(messages) + + assert "response" in result + assert isinstance(result["response"], LLMResponseCompat) + assert result["response"].text == "answer" + + @patch("unstract.sdk1.llm.litellm.completion") + @patch("unstract.sdk1.llm.pop_litellm_retry_kwargs", return_value=0) + def test_messages_passed_to_litellm( + self: Self, + _mock_pop: MagicMock, + mock_completion: MagicMock, + ) -> None: + """Messages should be forwarded to litellm.completion().""" + mock_completion.return_value = _make_litellm_response() + llm = _make_llm() + + messages = [ + {"role": "system", "content": "sys prompt"}, + { + "role": "user", + "content": [ + {"type": "text", "text": "context"}, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,AAAA"}, + }, + ], + }, + ] + llm.complete_vision(messages) + + mock_completion.assert_called_once() + call_kwargs = mock_completion.call_args + assert call_kwargs.kwargs["messages"] is messages + + @patch("unstract.sdk1.llm.litellm.completion") + @patch("unstract.sdk1.llm.pop_litellm_retry_kwargs", return_value=0) + def test_usage_recorded( + self: Self, + _mock_pop: MagicMock, + mock_completion: MagicMock, + ) -> None: + """Usage data should be recorded via _record_usage.""" + usage = { + "prompt_tokens": 500, + "completion_tokens": 50, + "total_tokens": 550, + } + mock_completion.return_value = _make_litellm_response(usage=usage) + llm = _make_llm() + + messages = [{"role": "user", "content": [{"type": "text", "text": "q"}]}] + llm.complete_vision(messages) + + llm._record_usage.assert_called_once() + call_args = llm._record_usage.call_args + assert call_args.args[0] == "test-vision-model" # model + assert call_args.args[2] == usage # usage dict + assert call_args.args[3] == "complete_vision" # llm_api + + @patch("unstract.sdk1.llm.litellm.completion") + @patch("unstract.sdk1.llm.pop_litellm_retry_kwargs", return_value=0) + def test_extract_json_strips_markers( + self: Self, + _mock_pop: MagicMock, + mock_completion: MagicMock, + ) -> None: + """extract_json=True should strip JSON code markers.""" + json_response = '```json\n{"key": "value"}\n```' + mock_completion.return_value = _make_litellm_response(content=json_response) + llm = _make_llm() + + messages = [{"role": "user", "content": [{"type": "text", "text": "q"}]}] + result = llm.complete_vision(messages, extract_json=True) + + # After extract_json, the markers should be stripped + response_text = result["response"].text + assert "```" not in response_text + + @patch("unstract.sdk1.llm.litellm.completion") + @patch("unstract.sdk1.llm.pop_litellm_retry_kwargs", return_value=0) + def test_none_response_raises_llm_error( + self: Self, + _mock_pop: MagicMock, + mock_completion: MagicMock, + ) -> None: + """None response content should raise LLMError.""" + mock_completion.return_value = _make_litellm_response( + content=None, finish_reason="content_filter" + ) + llm = _make_llm() + + messages = [{"role": "user", "content": [{"type": "text", "text": "q"}]}] + with pytest.raises(LLMError): + llm.complete_vision(messages) + + @patch("unstract.sdk1.llm.litellm.completion") + @patch("unstract.sdk1.llm.pop_litellm_retry_kwargs", return_value=0) + def test_exception_wrapped_in_llm_error( + self: Self, + _mock_pop: MagicMock, + mock_completion: MagicMock, + ) -> None: + """Non-LLM exceptions should be wrapped in LLMError.""" + mock_completion.side_effect = RuntimeError("provider error") + llm = _make_llm() + + messages = [{"role": "user", "content": [{"type": "text", "text": "q"}]}] + with pytest.raises(LLMError, match="Error from LLM adapter"): + llm.complete_vision(messages) + + @patch("unstract.sdk1.llm.litellm.completion") + @patch("unstract.sdk1.llm.pop_litellm_retry_kwargs", return_value=0) + def test_llm_error_reraised_as_is( + self: Self, + _mock_pop: MagicMock, + mock_completion: MagicMock, + ) -> None: + """LLMError from the call should be re-raised without wrapping.""" + original_err = LLMError(message="original error") + mock_completion.side_effect = original_err + llm = _make_llm() + + messages = [{"role": "user", "content": [{"type": "text", "text": "q"}]}] + with pytest.raises(LLMError) as exc_info: + llm.complete_vision(messages) + assert exc_info.value is original_err + + @patch("unstract.sdk1.llm.litellm.completion") + @patch("unstract.sdk1.llm.pop_litellm_retry_kwargs", return_value=0) + def test_raw_response_attached( + self: Self, + _mock_pop: MagicMock, + mock_completion: MagicMock, + ) -> None: + """Raw litellm response should be attached to response_object.raw.""" + raw_resp = _make_litellm_response("text") + mock_completion.return_value = raw_resp + llm = _make_llm() + + messages = [{"role": "user", "content": [{"type": "text", "text": "q"}]}] + result = llm.complete_vision(messages) + + assert result["response"].raw is raw_resp diff --git a/unstract/sdk1/tests/test_derive_vision_mode.py b/unstract/sdk1/tests/test_derive_vision_mode.py new file mode 100644 index 0000000000..06a106326e --- /dev/null +++ b/unstract/sdk1/tests/test_derive_vision_mode.py @@ -0,0 +1,81 @@ +"""Tests for derive_vision_mode() and vision constants.""" + +from typing import Self + +from unstract.sdk1.constants import ( + ExtractionInputs, + SourceOfTruthValues, + VisionMode, + derive_vision_mode, +) + + +class TestVisionModeConstants: + """Verify vision mode string constants are correct.""" + + def test_vision_mode_values(self: Self) -> None: + """VisionMode should expose the expected string values.""" + assert VisionMode.TEXT_ONLY == "text_only" + assert VisionMode.SPATIAL_HELPER == "spatial_helper" + assert VisionMode.SOURCE_OF_TRUTH == "source_of_truth" + + def test_extraction_inputs_values(self: Self) -> None: + """ExtractionInputs should expose the expected string values.""" + assert ExtractionInputs.TEXT == "text" + assert ExtractionInputs.IMAGE == "image" + assert ExtractionInputs.BOTH == "both" + + def test_source_of_truth_values(self: Self) -> None: + """SourceOfTruthValues should expose the expected string values.""" + assert SourceOfTruthValues.TEXT == "text" + assert SourceOfTruthValues.IMAGE == "image" + + +class TestDeriveVisionMode: + """Tests for derive_vision_mode() derivation logic. + + Derivation table: + | extraction_inputs | source_of_truth | vision_mode | + |-------------------|-----------------|-------------------| + | text | (ignored) | text_only | + | image | (ignored) | source_of_truth | + | both | text | spatial_helper | + | both | image | source_of_truth | + """ + + def test_text_returns_text_only(self: Self) -> None: + """extraction_inputs=text -> text_only, source_of_truth ignored.""" + assert derive_vision_mode("text", "text") == VisionMode.TEXT_ONLY + assert derive_vision_mode("text", "image") == VisionMode.TEXT_ONLY + + def test_image_returns_source_of_truth(self: Self) -> None: + """extraction_inputs=image -> source_of_truth, always.""" + assert derive_vision_mode("image", "text") == VisionMode.SOURCE_OF_TRUTH + assert derive_vision_mode("image", "image") == VisionMode.SOURCE_OF_TRUTH + + def test_both_text_sot_returns_spatial_helper(self: Self) -> None: + """extraction_inputs=both, source_of_truth=text -> spatial_helper.""" + assert derive_vision_mode("both", "text") == VisionMode.SPATIAL_HELPER + + def test_both_image_sot_returns_source_of_truth(self: Self) -> None: + """extraction_inputs=both, source_of_truth=image -> source_of_truth.""" + assert derive_vision_mode("both", "image") == VisionMode.SOURCE_OF_TRUTH + + def test_with_enum_constants(self: Self) -> None: + """Verify the function works with class constants.""" + assert ( + derive_vision_mode(ExtractionInputs.TEXT, SourceOfTruthValues.TEXT) + == VisionMode.TEXT_ONLY + ) + assert ( + derive_vision_mode(ExtractionInputs.IMAGE, SourceOfTruthValues.TEXT) + == VisionMode.SOURCE_OF_TRUTH + ) + assert ( + derive_vision_mode(ExtractionInputs.BOTH, SourceOfTruthValues.TEXT) + == VisionMode.SPATIAL_HELPER + ) + assert ( + derive_vision_mode(ExtractionInputs.BOTH, SourceOfTruthValues.IMAGE) + == VisionMode.SOURCE_OF_TRUTH + ) diff --git a/unstract/sdk1/tests/test_rasteriser.py b/unstract/sdk1/tests/test_rasteriser.py new file mode 100644 index 0000000000..559001aa4f --- /dev/null +++ b/unstract/sdk1/tests/test_rasteriser.py @@ -0,0 +1,282 @@ +"""Tests for the PDF rasteriser module.""" + +import io +from typing import Self +from unittest.mock import MagicMock, patch + +import pytest +from PIL import Image +from unstract.sdk1.rasteriser import ( + RenderSettings, + _preprocess_image, + rasterise_pages, +) + + +def _make_pil_image( + width: int = 100, + height: int = 100, + mode: str = "RGB", +) -> Image.Image: + """Create a simple test PIL image.""" + return Image.new(mode, (width, height), color="red") + + +# --------------------------------------------------------------------------- +# RenderSettings +# --------------------------------------------------------------------------- + + +class TestRenderSettings: + """Tests for the RenderSettings dataclass.""" + + def test_defaults(self: Self) -> None: + """Default DPI=150, max_dimension=1568.""" + s = RenderSettings() + assert s.dpi == 150 + assert s.max_dimension == 1568 + + def test_scale_at_72_dpi(self: Self) -> None: + """At 72 DPI, scale should be 1.0.""" + assert RenderSettings(dpi=72).scale == pytest.approx(1.0) + + def test_scale_at_150_dpi(self: Self) -> None: + """At 150 DPI, scale should be 150/72.""" + assert RenderSettings(dpi=150).scale == pytest.approx(150.0 / 72.0) + + def test_custom_values(self: Self) -> None: + """Custom values should be stored correctly.""" + s = RenderSettings(dpi=300, max_dimension=2048) + assert s.dpi == 300 + assert s.max_dimension == 2048 + + +# --------------------------------------------------------------------------- +# _preprocess_image +# --------------------------------------------------------------------------- + + +class TestPreprocessImage: + """Tests for _preprocess_image() preprocessing pipeline.""" + + def test_small_image_upscaled_2x(self: Self) -> None: + """Small images should be upscaled 2x (under max_dimension).""" + img = _make_pil_image(100, 100) + result = _preprocess_image(img, max_dimension=1568) + # 100 -> 200 (2x upscale), stays under 1568 + assert result.width == 200 + assert result.height == 200 + + def test_large_image_constrained(self: Self) -> None: + """After 2x upscale, images exceeding max_dimension get constrained.""" + img = _make_pil_image(1000, 800) + result = _preprocess_image(img, max_dimension=500) + # 1000 -> 2000 (2x), then constrained to 500 max dim + assert max(result.size) <= 500 + + def test_preserves_rgb_mode(self: Self) -> None: + """Output should always be RGB.""" + img = _make_pil_image(50, 50) + result = _preprocess_image(img, max_dimension=1568) + assert result.mode == "RGB" + + def test_aspect_ratio_preserved(self: Self) -> None: + """Constraining should preserve the aspect ratio.""" + img = _make_pil_image(800, 400) # 2:1 ratio + result = _preprocess_image(img, max_dimension=500) + ratio_input = 800 / 400 + ratio_output = result.width / result.height + assert ratio_input == pytest.approx(ratio_output, abs=0.15) + + +# --------------------------------------------------------------------------- +# rasterise_pages — mocked pypdfium2 +# --------------------------------------------------------------------------- + + +def _make_mock_pdf( + num_pages: int = 1, + img_size: tuple[int, int] = (200, 200), +) -> MagicMock: + """Create a mock PdfDocument.""" + mock_doc = MagicMock() + mock_doc.__len__ = MagicMock(return_value=num_pages) + + def _getitem(idx: int) -> MagicMock: + mock_page = MagicMock() + mock_bitmap = MagicMock() + mock_bitmap.to_pil.return_value = _make_pil_image(*img_size) + mock_page.render.return_value = mock_bitmap + return mock_page + + mock_doc.__getitem__ = MagicMock(side_effect=_getitem) + return mock_doc + + +class TestRasterisePages: + """Tests for rasterise_pages() with mocked pypdfium2.""" + + @patch("unstract.sdk1.rasteriser.pdfium.PdfDocument") + def test_single_page( + self: Self, + mock_pdf_cls: MagicMock, + ) -> None: + """Should produce one (page_num, png_bytes) tuple for a 1-page PDF.""" + mock_pdf_cls.return_value = _make_mock_pdf(1) + + results = rasterise_pages(b"fake-pdf-bytes") + + assert len(results) == 1 + page_num, png_bytes = results[0] + assert page_num == 0 + assert len(png_bytes) > 0 + + # Verify output is valid PNG + img = Image.open(io.BytesIO(png_bytes)) + assert img.format == "PNG" + + @patch("unstract.sdk1.rasteriser.pdfium.PdfDocument") + def test_multi_page( + self: Self, + mock_pdf_cls: MagicMock, + ) -> None: + """Should rasterise all pages of a multi-page PDF.""" + mock_pdf_cls.return_value = _make_mock_pdf(3) + + results = rasterise_pages(b"fake-pdf") + + assert len(results) == 3 + assert [r[0] for r in results] == [0, 1, 2] + + @patch("unstract.sdk1.rasteriser.pdfium.PdfDocument") + def test_page_set_filtering( + self: Self, + mock_pdf_cls: MagicMock, + ) -> None: + """Should only render pages in page_set.""" + mock_pdf_cls.return_value = _make_mock_pdf(5) + + results = rasterise_pages(b"fake-pdf", page_set={1, 3}) + + assert len(results) == 2 + assert [r[0] for r in results] == [1, 3] + + @patch("unstract.sdk1.rasteriser.pdfium.PdfDocument") + def test_page_set_skips_invalid( + self: Self, + mock_pdf_cls: MagicMock, + ) -> None: + """Should skip page numbers beyond the PDF's page count.""" + mock_pdf_cls.return_value = _make_mock_pdf(3) + + results = rasterise_pages(b"fake-pdf", page_set={0, 1, 10, 20}) + + assert len(results) == 2 + assert [r[0] for r in results] == [0, 1] + + @patch("unstract.sdk1.rasteriser.pdfium.PdfDocument") + def test_max_pages_truncation( + self: Self, + mock_pdf_cls: MagicMock, + ) -> None: + """Should cap at max_pages.""" + mock_pdf_cls.return_value = _make_mock_pdf(20) + + results = rasterise_pages(b"fake-pdf", max_pages=5) + + assert len(results) == 5 + assert [r[0] for r in results] == [0, 1, 2, 3, 4] + + def test_empty_bytes_raises_value_error(self: Self) -> None: + """Empty file_bytes should raise ValueError.""" + with pytest.raises(ValueError, match="file_bytes is empty"): + rasterise_pages(b"") + + @patch("unstract.sdk1.rasteriser.pdfium.PdfDocument") + def test_zero_page_pdf_returns_empty( + self: Self, + mock_pdf_cls: MagicMock, + ) -> None: + """PDF with 0 pages should return empty list.""" + mock_pdf_cls.return_value = _make_mock_pdf(0) + + results = rasterise_pages(b"fake-pdf") + + assert results == [] + + @patch("unstract.sdk1.rasteriser.pdfium.PdfDocument") + def test_rgba_converted_to_rgb( + self: Self, + mock_pdf_cls: MagicMock, + ) -> None: + """RGBA images from the renderer should be converted to RGB.""" + mock_doc = MagicMock() + mock_doc.__len__ = MagicMock(return_value=1) + + mock_page = MagicMock() + mock_bitmap = MagicMock() + # Return an RGBA image + mock_bitmap.to_pil.return_value = Image.new("RGBA", (100, 100), (255, 0, 0, 128)) + mock_page.render.return_value = mock_bitmap + mock_doc.__getitem__ = MagicMock(return_value=mock_page) + mock_pdf_cls.return_value = mock_doc + + results = rasterise_pages(b"fake-pdf") + + assert len(results) == 1 + output_img = Image.open(io.BytesIO(results[0][1])) + assert output_img.mode == "RGB" + + @patch("unstract.sdk1.rasteriser.pdfium.PdfDocument") + def test_pdf_closed_after_processing( + self: Self, + mock_pdf_cls: MagicMock, + ) -> None: + """PDF document should be closed after successful processing.""" + mock_doc = _make_mock_pdf(1) + mock_pdf_cls.return_value = mock_doc + + rasterise_pages(b"fake-pdf") + + mock_doc.close.assert_called_once() + + @patch("unstract.sdk1.rasteriser.pdfium.PdfDocument") + def test_pdf_closed_on_error( + self: Self, + mock_pdf_cls: MagicMock, + ) -> None: + """PDF document should be closed even when an error occurs.""" + mock_doc = MagicMock() + mock_doc.__len__ = MagicMock(side_effect=RuntimeError("bad pdf")) + mock_pdf_cls.return_value = mock_doc + + with pytest.raises(RuntimeError, match="bad pdf"): + rasterise_pages(b"fake-pdf") + + mock_doc.close.assert_called_once() + + @patch("unstract.sdk1.rasteriser.pdfium.PdfDocument") + def test_default_settings_when_none( + self: Self, + mock_pdf_cls: MagicMock, + ) -> None: + """When settings=None, default RenderSettings should be used.""" + mock_pdf_cls.return_value = _make_mock_pdf(1) + + # Should not raise + results = rasterise_pages(b"fake-pdf", settings=None) + + assert len(results) == 1 + + @patch("unstract.sdk1.rasteriser.pdfium.PdfDocument") + def test_page_set_sorted_output( + self: Self, + mock_pdf_cls: MagicMock, + ) -> None: + """Pages should be rendered in sorted order.""" + mock_pdf_cls.return_value = _make_mock_pdf(10) + + results = rasterise_pages(b"fake-pdf", page_set={5, 2, 8, 0}) + + page_nums = [r[0] for r in results] + assert page_nums == [0, 2, 5, 8] diff --git a/unstract/sdk1/tests/test_vision_messages.py b/unstract/sdk1/tests/test_vision_messages.py new file mode 100644 index 0000000000..f01ab68f64 --- /dev/null +++ b/unstract/sdk1/tests/test_vision_messages.py @@ -0,0 +1,235 @@ +"""Tests for vision message builder (build_vision_messages).""" + +import base64 +from typing import Self + +import pytest +from unstract.sdk1.vision import build_vision_messages + + +@pytest.fixture() +def sample_images() -> list[tuple[int, bytes]]: + """Create sample page images for testing.""" + return [ + (0, b"fake-png-page-0"), + (1, b"fake-png-page-1"), + ] + + +class TestBuildVisionMessages: + """Tests for build_vision_messages().""" + + def test_spatial_helper_text_before_images( + self: Self, + sample_images: list[tuple[int, bytes]], + ) -> None: + """In spatial_helper mode, text context appears before images.""" + messages = build_vision_messages( + system_prompt="Be helpful", + text_context="Document text here", + page_images=sample_images, + prompt="Extract the value", + mode="spatial_helper", + ) + + # System message first + assert messages[0]["role"] == "system" + assert messages[0]["content"] == "Be helpful" + + # User message with content blocks + user_msg = messages[1] + assert user_msg["role"] == "user" + content = user_msg["content"] + + # First content block: text (primary) + assert content[0]["type"] == "text" + assert "primary source" in content[0]["text"] + assert "Document text here" in content[0]["text"] + + # Images come after text + image_blocks = [b for b in content if b.get("type") == "image_url"] + assert len(image_blocks) == 2 + + # Last block is the extraction prompt + assert content[-1]["type"] == "text" + assert content[-1]["text"] == "Extract the value" + + def test_source_of_truth_images_before_text( + self: Self, + sample_images: list[tuple[int, bytes]], + ) -> None: + """In source_of_truth mode, images appear before text context.""" + messages = build_vision_messages( + system_prompt="Be helpful", + text_context="Document text", + page_images=sample_images, + prompt="Extract", + mode="source_of_truth", + ) + + user_content = messages[1]["content"] + + # First content block: images (primary) + assert user_content[0]["type"] == "text" + assert "primary source" in user_content[0]["text"] + assert "PAGES" in user_content[0]["text"] + + # Text comes after images — find supplementary text + text_blocks = [ + b + for b in user_content + if b.get("type") == "text" and "supplementary" in b.get("text", "") + ] + assert len(text_blocks) == 1 + assert "Document text" in text_blocks[0]["text"] + + def test_images_encoded_as_base64(self: Self) -> None: + """Image bytes should be base64-encoded in data: URLs.""" + page_images = [(0, b"test-png-bytes")] + messages = build_vision_messages( + system_prompt="", + text_context="text", + page_images=page_images, + prompt="Extract", + mode="spatial_helper", + ) + + # No system prompt -> user message is messages[0] + user_content = messages[0]["content"] + image_blocks = [b for b in user_content if b.get("type") == "image_url"] + assert len(image_blocks) == 1 + + url = image_blocks[0]["image_url"]["url"] + assert url.startswith("data:image/png;base64,") + + # Verify round-trip decoding + b64_part = url.split(",", 1)[1] + assert base64.standard_b64decode(b64_part) == b"test-png-bytes" + + def test_page_numbers_1_indexed_in_labels(self: Self) -> None: + """Each image should have a 'Page N:' label (1-indexed).""" + page_images = [(0, b"p0"), (3, b"p3")] + messages = build_vision_messages( + system_prompt="sys", + text_context="txt", + page_images=page_images, + prompt="prompt", + mode="spatial_helper", + ) + + user_content = messages[1]["content"] + text_values = [b["text"] for b in user_content if b.get("type") == "text"] + + # Page labels are 1-indexed (page_num + 1) + assert any("Page 1:" in t for t in text_values) + assert any("Page 4:" in t for t in text_values) + + def test_empty_system_prompt_omitted( + self: Self, + sample_images: list[tuple[int, bytes]], + ) -> None: + """Empty system prompt should not produce a system message.""" + messages = build_vision_messages( + system_prompt="", + text_context="text", + page_images=sample_images, + prompt="Extract", + mode="spatial_helper", + ) + + # No system message, just user message + assert len(messages) == 1 + assert messages[0]["role"] == "user" + + def test_none_text_context_spatial_helper( + self: Self, + sample_images: list[tuple[int, bytes]], + ) -> None: + """None text_context in spatial_helper should omit text block.""" + messages = build_vision_messages( + system_prompt="sys", + text_context=None, + page_images=sample_images, + prompt="Extract", + mode="spatial_helper", + ) + + user_content = messages[1]["content"] + doc_text_blocks = [ + b + for b in user_content + if b.get("type") == "text" and "DOCUMENT TEXT" in b.get("text", "") + ] + assert len(doc_text_blocks) == 0 + + def test_none_text_context_source_of_truth( + self: Self, + sample_images: list[tuple[int, bytes]], + ) -> None: + """None text_context in source_of_truth is normal (image-only).""" + messages = build_vision_messages( + system_prompt="sys", + text_context=None, + page_images=sample_images, + prompt="Extract", + mode="source_of_truth", + ) + + user_content = messages[1]["content"] + doc_text_blocks = [ + b + for b in user_content + if b.get("type") == "text" and "DOCUMENT TEXT" in b.get("text", "") + ] + assert len(doc_text_blocks) == 0 + + def test_invalid_mode_raises_value_error( + self: Self, + sample_images: list[tuple[int, bytes]], + ) -> None: + """Invalid mode should raise ValueError.""" + with pytest.raises(ValueError, match="Invalid vision mode"): + build_vision_messages( + system_prompt="sys", + text_context="text", + page_images=sample_images, + prompt="Extract", + mode="invalid_mode", + ) + + def test_empty_page_images_raises_value_error(self: Self) -> None: + """Empty page_images should raise ValueError.""" + with pytest.raises(ValueError, match="page_images is empty"): + build_vision_messages( + system_prompt="sys", + text_context="text", + page_images=[], + prompt="Extract", + mode="spatial_helper", + ) + + def test_message_structure_complete( + self: Self, + sample_images: list[tuple[int, bytes]], + ) -> None: + """Verify the full message structure is OpenAI-compatible.""" + messages = build_vision_messages( + system_prompt="System prompt", + text_context="Document text", + page_images=sample_images, + prompt="What is the value?", + mode="spatial_helper", + ) + + assert len(messages) == 2 + assert messages[0] == { + "role": "system", + "content": "System prompt", + } + assert messages[1]["role"] == "user" + assert isinstance(messages[1]["content"], list) + + # Every content block must have a "type" key + for block in messages[1]["content"]: + assert "type" in block + assert block["type"] in ("text", "image_url") diff --git a/unstract/sdk1/uv.lock b/unstract/sdk1/uv.lock index a482fe5b95..de04cde5a6 100644 --- a/unstract/sdk1/uv.lock +++ b/unstract/sdk1/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = "==3.12.*" [[package]] @@ -2734,6 +2734,7 @@ dependencies = [ { name = "llama-parse" }, { name = "llmwhisperer-client" }, { name = "pdfplumber" }, + { name = "pypdfium2" }, { name = "python-dotenv" }, { name = "python-magic" }, { name = "qdrant-client" }, @@ -2795,6 +2796,7 @@ requires-dist = [ { name = "llama-parse", specifier = ">=0.6.0" }, { name = "llmwhisperer-client", specifier = ">=2.6.2" }, { name = "pdfplumber", specifier = ">=0.11.2" }, + { name = "pypdfium2", specifier = ">=4.0.0" }, { name = "python-dotenv", specifier = "==1.2.2" }, { name = "python-magic", specifier = "~=0.4.27" }, { name = "qdrant-client", specifier = ">=1.16.0,<1.17.0" }, diff --git a/workers/executor/executors/answer_prompt.py b/workers/executor/executors/answer_prompt.py index d1eef5b3be..1b04db1200 100644 --- a/workers/executor/executors/answer_prompt.py +++ b/workers/executor/executors/answer_prompt.py @@ -20,6 +20,14 @@ from executor.executors.constants import PromptServiceConstants as PSKeys from executor.executors.exceptions import LegacyExecutorError, RateLimitError +from unstract.sdk1.constants import ( + ExtractionInputs, + VisionMode, + derive_vision_mode, +) +from unstract.sdk1.rasteriser import RenderSettings, rasterise_pages +from unstract.sdk1.vision import build_vision_messages + logger = logging.getLogger(__name__) @@ -115,6 +123,7 @@ def construct_and_run_prompt( file_path: str = "", execution_source: str | None = "ide", process_text: Any = None, + source_file_path: str = "", ) -> str: """Construct the full prompt and run LLM completion. @@ -129,10 +138,17 @@ def construct_and_run_prompt( execution_source: "ide" or "tool". process_text: Optional callback for text processing during completion (e.g. highlight-data plugin's ``run`` method). + source_file_path: Path to the original source file (PDF) for + vision mode rasterisation. Empty string disables vision. Returns: The LLM answer string. """ + # Derive vision mode from per-prompt fields + extraction_inputs = output.get(PSKeys.EXTRACTION_INPUTS, ExtractionInputs.TEXT) + source_of_truth = output.get(PSKeys.SOURCE_OF_TRUTH, "text") + vision_mode = derive_vision_mode(extraction_inputs, source_of_truth) + platform_postamble = tool_settings.get(PSKeys.PLATFORM_POSTAMBLE, "") word_confidence_postamble = tool_settings.get( PSKeys.WORD_CONFIDENCE_POSTAMBLE, "" @@ -142,6 +158,16 @@ def construct_and_run_prompt( enable_word_confidence = tool_settings.get(PSKeys.ENABLE_WORD_CONFIDENCE, False) if not enable_highlight: enable_word_confidence = False + + # Vision mode: suppress highlights and postambles (no OCR line + # metadata to ground against) + if vision_mode != VisionMode.TEXT_ONLY: + platform_postamble = "" + word_confidence_postamble = "" + enable_highlight = False + enable_word_confidence = False + process_text = None + prompt_type = output.get(PSKeys.TYPE, PSKeys.TEXT) if not enable_highlight or summarize_as_source: platform_postamble = "" @@ -159,6 +185,21 @@ def construct_and_run_prompt( prompt_type=prompt_type, ) output[PSKeys.COMBINED_PROMPT] = prompt + + if vision_mode != VisionMode.TEXT_ONLY: + return AnswerPromptService.run_vision_completion( + llm=llm, + text_prompt=prompt, + text_context=context, + source_file_path=source_file_path, + vision_mode=vision_mode, + metadata=metadata, + prompt_key=output[PSKeys.NAME], + prompt_type=prompt_type, + preamble=tool_settings.get(PSKeys.PREAMBLE, ""), + execution_source=execution_source or "ide", + ) + return AnswerPromptService.run_completion( llm=llm, prompt=prompt, @@ -280,6 +321,132 @@ def run_completion( status_code = getattr(e, "status_code", None) or 500 raise LegacyExecutorError(message=str(e), code=status_code) from e + @staticmethod + def run_vision_completion( + llm: Any, + text_prompt: str, + text_context: str, + source_file_path: str, + vision_mode: str, + metadata: dict[str, Any], + prompt_key: str, + prompt_type: str = "text", + preamble: str = "", + execution_source: str = "ide", + ) -> str: + """Run VLM completion with page images and optional text context. + + Reads the source file, rasterises pages, builds multimodal messages, + and calls ``llm.complete_vision()``. + + Args: + llm: LLM adapter instance. + text_prompt: The constructed prompt string (preamble + question + + postamble + context). + text_context: Retrieved text context (may be empty for image-only). + source_file_path: Path to the original source file (PDF) in + file storage, for rasterisation. + vision_mode: One of VisionMode.SPATIAL_HELPER or + VisionMode.SOURCE_OF_TRUTH. + metadata: Metadata dict (updated in place). + prompt_key: The prompt name for metadata keying. + prompt_type: "text" or "json" — controls extract_json. + preamble: The preamble text used as system prompt for the VLM. + execution_source: "ide" or "tool" — determines storage backend. + """ + try: + from unstract.sdk1.exceptions import RateLimitError as _sdk_rate_limit_error + from unstract.sdk1.exceptions import SdkError as _sdk_error + except ImportError: + _sdk_rate_limit_error = Exception + _sdk_error = Exception + + if not source_file_path: + raise LegacyExecutorError( + message=( + f"Vision mode '{vision_mode}' requires a source file path " + f"for rasterisation, but none was provided for prompt " + f"'{prompt_key}'." + ), + code=400, + ) + + try: + # Read source file bytes from file storage + from executor.executors.file_utils import FileUtils + + fs = FileUtils.get_fs_instance(execution_source=execution_source) + file_bytes: bytes = fs.read(path=source_file_path, mode="rb") + + # Rasterise pages + settings = RenderSettings() + page_images = rasterise_pages( + file_bytes=file_bytes, + settings=settings, + ) + if not page_images: + raise LegacyExecutorError( + message=( + f"No pages could be rasterised from '{source_file_path}' " + f"for prompt '{prompt_key}'." + ), + code=500, + ) + + logger.info( + "Vision mode=%s: rasterised %d pages for prompt=%s", + vision_mode, + len(page_images), + prompt_key, + ) + + # Map VisionMode constants to build_vision_messages mode strings + mode_str = ( + "spatial_helper" + if vision_mode == VisionMode.SPATIAL_HELPER + else "source_of_truth" + ) + + # Build multimodal messages + messages = build_vision_messages( + system_prompt=preamble, + text_context=text_context if text_context.strip() else None, + page_images=page_images, + prompt=text_prompt, + mode=mode_str, + ) + + # Call VLM completion + completion = llm.complete_vision( + messages=messages, + extract_json=prompt_type.lower() != PSKeys.TEXT, + ) + + answer: str = completion[PSKeys.RESPONSE].text + return answer + + except _sdk_rate_limit_error as e: + raise RateLimitError(f"Rate limit error. {str(e)}") from e + except (LegacyExecutorError, RateLimitError): + raise + except _sdk_error as e: + logger.error( + "Error during vision completion for prompt %s: %s", + prompt_key, + e, + ) + status_code = getattr(e, "status_code", None) or 500 + raise LegacyExecutorError(message=str(e), code=status_code) from e + except Exception as e: + logger.error( + "Unexpected error during vision completion for prompt %s: %s", + prompt_key, + e, + ) + raise LegacyExecutorError( + message=f"Vision completion failed: {e}", code=500 + ) from e + @staticmethod def _run_webhook_postprocess( parsed_data: Any, diff --git a/workers/executor/executors/constants.py b/workers/executor/executors/constants.py index 9eddab8423..013ceafd11 100644 --- a/workers/executor/executors/constants.py +++ b/workers/executor/executors/constants.py @@ -102,6 +102,10 @@ class PromptServiceConstants: # Webhook postprocessing settings ENABLE_POSTPROCESSING_WEBHOOK = "enable_postprocessing_webhook" POSTPROCESSING_WEBHOOK_URL = "postprocessing_webhook_url" + # Vision mode fields + EXTRACTION_INPUTS = "extraction_inputs" + SOURCE_OF_TRUTH = "source_of_truth" + SOURCE_FILE_PATH = "source_file_path" class RunLevel(Enum): diff --git a/workers/executor/executors/legacy_executor.py b/workers/executor/executors/legacy_executor.py index ce7fbea0d1..a20fbd8368 100644 --- a/workers/executor/executors/legacy_executor.py +++ b/workers/executor/executors/legacy_executor.py @@ -1349,6 +1349,7 @@ def _handle_answer_prompt(self, context: ExecutionContext) -> ExecutionResult: tool_id: str = params.get(PSKeys.TOOL_ID, "") run_id: str = context.run_id file_path = params.get(PSKeys.FILE_PATH) + source_file_path: str = params.get(PSKeys.SOURCE_FILE_PATH, "") doc_name = str(params.get(PSKeys.FILE_NAME, "")) execution_source = params.get(PSKeys.EXECUTION_SOURCE, context.execution_source) platform_api_key: str = params.get(PSKeys.PLATFORM_SERVICE_API_KEY, "") @@ -1659,6 +1660,7 @@ def _execute_single_prompt( execution_id = params.get(PSKeys.EXECUTION_ID, "") file_hash = params.get(PSKeys.FILE_HASH) file_path = params.get(PSKeys.FILE_PATH) + source_file_path: str = params.get(PSKeys.SOURCE_FILE_PATH, "") doc_name = str(params.get(PSKeys.FILE_NAME, "")) log_events_id = params.get(PSKeys.LOG_EVENTS_ID, "") tool_id = params.get(PSKeys.TOOL_ID, "") @@ -1818,6 +1820,7 @@ def _execute_single_prompt( execution_source=execution_source, file_path=file_path, process_text=process_text_fn, + source_file_path=source_file_path, ) else: logger.warning( diff --git a/workers/file_processing/structure_tool_task.py b/workers/file_processing/structure_tool_task.py index cfa0e316a6..feb5caae86 100644 --- a/workers/file_processing/structure_tool_task.py +++ b/workers/file_processing/structure_tool_task.py @@ -114,6 +114,9 @@ class _SK: CUSTOM_DATA = "custom_data" SINGLE_PASS_EXTRACTION_MODE = "single_pass_extraction_mode" CHALLENGE_LLM_ADAPTER_ID = "challenge_llm_adapter_id" + EXTRACTION_INPUTS = "extraction_inputs" + SOURCE_OF_TRUTH = "source_of_truth" + SOURCE_FILE_PATH = "source_file_path" # ----------------------------------------------------------------------- @@ -207,6 +210,19 @@ def _should_skip_extraction_for_smart_table( return False +def _should_skip_extraction_for_vision( + outputs: list[dict[str, Any]], +) -> bool: + """Check if extraction/indexing should be skipped for image-only vision. + + When ALL prompts use extraction_inputs="image" (page image only), + text extraction is unnecessary since the LLM receives only page images. + """ + if not outputs: + return False + return all(output.get(_SK.EXTRACTION_INPUTS, "text") == "image" for output in outputs) + + # ----------------------------------------------------------------------- # Main Celery task # ----------------------------------------------------------------------- @@ -368,12 +384,18 @@ def _execute_structure_tool_impl(params: dict) -> dict: execution_run_data_folder = Path(execution_data_dir) extracted_input_file = str(execution_run_data_folder / _SK.EXTRACT) - # ---- Step 4: Smart table detection ---- + # ---- Step 4: Smart table / vision-only detection ---- skip_extraction_and_indexing = _should_skip_extraction_for_smart_table(outputs) if skip_extraction_and_indexing: logger.info( "Skipping extraction and indexing for Excel table with valid JSON schema" ) + if not skip_extraction_and_indexing: + skip_extraction_and_indexing = _should_skip_extraction_for_vision(outputs) + if skip_extraction_and_indexing: + logger.info( + "Skipping extraction and indexing: all prompts use image-only vision mode" + ) # ---- Step 5: Build pipeline params ---- usage_kwargs: dict[Any, Any] = {} @@ -383,6 +405,9 @@ def _execute_structure_tool_impl(params: dict) -> dict: usage_kwargs[UsageKwargs.EXECUTION_ID] = execution_id custom_data = exec_metadata.get(_SK.CUSTOM_DATA, {}) + # SOURCE is the immutable original file (e.g. PDF) written by the + # source connector alongside INFILE; used for vision mode rasterisation. + source_file_path = str(execution_run_data_folder / "SOURCE") answer_params = { _SK.RUN_ID: file_execution_id, _SK.EXECUTION_ID: execution_id, @@ -394,6 +419,7 @@ def _execute_structure_tool_impl(params: dict) -> dict: _SK.FILE_PATH: extracted_input_file, _SK.EXECUTION_SOURCE: _SK.TOOL, _SK.CUSTOM_DATA: custom_data, + _SK.SOURCE_FILE_PATH: source_file_path, "PLATFORM_SERVICE_API_KEY": platform_service_api_key, } diff --git a/workers/tests/test_single_pass_vision_guard.py b/workers/tests/test_single_pass_vision_guard.py new file mode 100644 index 0000000000..aae11bf575 --- /dev/null +++ b/workers/tests/test_single_pass_vision_guard.py @@ -0,0 +1,147 @@ +"""Tests for the single-pass extraction vision mode guard (Phase 9a). + +Verifies that SinglePassExtractionExecutor rejects payloads containing +vision-enabled prompts, since single-pass merges all prompts into one +LLM call that cannot apply per-prompt vision modes. + +The single_pass_extraction plugin is installed separately (cloud-only), +so we add its src directory to sys.path for testing. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import pytest +from executor.executors.constants import PromptServiceConstants as PSKeys + +from unstract.sdk1.execution.context import ExecutionContext, Operation + +# Add plugin src to path so we can import it +_plugin_src = str( + Path(__file__).resolve().parent.parent / "plugins" / "single_pass_extraction" / "src" +) +if _plugin_src not in sys.path: + sys.path.insert(0, _plugin_src) + + +def _make_output( + name: str = "field_a", + extraction_inputs: str = "text", +) -> dict: + """Build a minimal prompt output dict for single-pass.""" + return { + PSKeys.NAME: name, + PSKeys.PROMPT: "What is the value?", + PSKeys.TYPE: "text", + PSKeys.CHUNK_SIZE: 0, + PSKeys.CHUNK_OVERLAP: 0, + PSKeys.RETRIEVAL_STRATEGY: "simple", + PSKeys.LLM: "llm-1", + PSKeys.EMBEDDING: "emb-1", + PSKeys.VECTOR_DB: "vdb-1", + PSKeys.X2TEXT_ADAPTER: "x2t-1", + PSKeys.SIMILARITY_TOP_K: 5, + PSKeys.EXTRACTION_INPUTS: extraction_inputs, + } + + +def _make_context(outputs: list[dict]) -> ExecutionContext: + """Build an ExecutionContext for single-pass extraction.""" + params = { + PSKeys.OUTPUTS: outputs, + PSKeys.TOOL_SETTINGS: {PSKeys.LLM: "llm-1"}, + PSKeys.TOOL_ID: "tool-1", + PSKeys.FILE_PATH: "/data/doc.txt", + PSKeys.FILE_NAME: "doc.txt", + PSKeys.EXECUTION_SOURCE: "tool", + PSKeys.CUSTOM_DATA: {}, + PSKeys.PLATFORM_SERVICE_API_KEY: "pk-test", + } + return ExecutionContext( + executor_name="single_pass_extraction", + operation=Operation.SINGLE_PASS_EXTRACTION.value, + executor_params=params, + run_id="run-1", + execution_source="tool", + ) + + +class TestSinglePassVisionGuard: + """Single-pass extraction must reject vision-enabled prompts.""" + + def test_text_only_prompts_pass_guard(self): + """Text-only prompts should not trigger the guard.""" + from single_pass_extraction.executor import SinglePassExtractionExecutor + + executor = SinglePassExtractionExecutor() + ctx = _make_context([ + _make_output("field_a", "text"), + _make_output("field_b", "text"), + ]) + + # The guard only checks vision; execution will fail later on LLM init. + # We just need to verify the guard itself doesn't block text-only. + result = executor.execute(ctx) + # It should NOT fail with the vision guard message + if not result.success: + assert "vision" not in result.error.lower() + + def test_image_only_prompt_triggers_guard(self): + """A prompt with extraction_inputs='image' should be rejected.""" + from single_pass_extraction.executor import SinglePassExtractionExecutor + + executor = SinglePassExtractionExecutor() + ctx = _make_context([ + _make_output("field_a", "image"), + ]) + result = executor.execute(ctx) + assert not result.success + assert "vision" in result.error.lower() + assert "field_a" in result.error + + def test_both_mode_prompt_triggers_guard(self): + """A prompt with extraction_inputs='both' should be rejected.""" + from single_pass_extraction.executor import SinglePassExtractionExecutor + + executor = SinglePassExtractionExecutor() + ctx = _make_context([ + _make_output("field_a", "both"), + ]) + result = executor.execute(ctx) + assert not result.success + assert "vision" in result.error.lower() + assert "field_a" in result.error + + def test_mixed_prompts_triggers_guard(self): + """If ANY prompt has vision, the guard should reject.""" + from single_pass_extraction.executor import SinglePassExtractionExecutor + + executor = SinglePassExtractionExecutor() + ctx = _make_context([ + _make_output("text_field", "text"), + _make_output("vision_field", "image"), + ]) + result = executor.execute(ctx) + assert not result.success + assert "vision_field" in result.error + assert "text_field" not in result.error + + def test_missing_extraction_inputs_defaults_to_text(self): + """Prompts without extraction_inputs field should default to text.""" + from single_pass_extraction.executor import SinglePassExtractionExecutor + + executor = SinglePassExtractionExecutor() + output = _make_output("field_a", "text") + del output[PSKeys.EXTRACTION_INPUTS] + ctx = _make_context([output]) + + result = executor.execute(ctx) + # Should not fail with vision guard message + if not result.success: + assert "vision" not in result.error.lower() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/workers/tests/test_structure_tool_task.py b/workers/tests/test_structure_tool_task.py index 8bdb541eca..a136562800 100644 --- a/workers/tests/test_structure_tool_task.py +++ b/workers/tests/test_structure_tool_task.py @@ -19,7 +19,10 @@ import pytest -from file_processing.structure_tool_task import _fairness_headers +from file_processing.structure_tool_task import ( + _fairness_headers, + _should_skip_extraction_for_vision, +) from queue_backend.fairness import WorkloadType @@ -50,5 +53,51 @@ def test_workload_type_is_non_api_not_api(self): assert wire["x-fairness-key"]["workload_type"] != WorkloadType.API.value +class TestShouldSkipExtractionForVision: + """Tests for ``_should_skip_extraction_for_vision``. + + Extraction should be skipped only when ALL prompts use image-only + vision mode (extraction_inputs="image"). + """ + + def test_empty_outputs_returns_false(self): + assert _should_skip_extraction_for_vision([]) is False + + def test_all_text_only_returns_false(self): + outputs = [ + {"name": "p1", "extraction_inputs": "text"}, + {"name": "p2", "extraction_inputs": "text"}, + ] + assert _should_skip_extraction_for_vision(outputs) is False + + def test_all_image_only_returns_true(self): + outputs = [ + {"name": "p1", "extraction_inputs": "image"}, + {"name": "p2", "extraction_inputs": "image"}, + ] + assert _should_skip_extraction_for_vision(outputs) is True + + def test_mixed_modes_returns_false(self): + """If any prompt needs text, extraction must run.""" + outputs = [ + {"name": "p1", "extraction_inputs": "image"}, + {"name": "p2", "extraction_inputs": "both"}, + ] + assert _should_skip_extraction_for_vision(outputs) is False + + def test_both_mode_returns_false(self): + outputs = [{"name": "p1", "extraction_inputs": "both"}] + assert _should_skip_extraction_for_vision(outputs) is False + + def test_missing_field_defaults_to_text(self): + """Outputs without extraction_inputs should default to text.""" + outputs = [{"name": "p1"}] + assert _should_skip_extraction_for_vision(outputs) is False + + def test_single_image_only_returns_true(self): + outputs = [{"name": "p1", "extraction_inputs": "image"}] + assert _should_skip_extraction_for_vision(outputs) is True + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/workers/uv.lock b/workers/uv.lock index d4f3b24046..281d4c41ab 100644 --- a/workers/uv.lock +++ b/workers/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.12" resolution-markers = [ "python_full_version >= '3.14'", @@ -4788,6 +4788,7 @@ dependencies = [ { name = "llama-parse" }, { name = "llmwhisperer-client" }, { name = "pdfplumber" }, + { name = "pypdfium2" }, { name = "python-dotenv" }, { name = "python-magic" }, { name = "qdrant-client" }, @@ -4827,6 +4828,7 @@ requires-dist = [ { name = "llama-parse", specifier = ">=0.6.0" }, { name = "llmwhisperer-client", specifier = ">=2.6.2" }, { name = "pdfplumber", specifier = ">=0.11.2" }, + { name = "pypdfium2", specifier = ">=4.0.0" }, { name = "python-dotenv", specifier = "==1.2.2" }, { name = "python-magic", specifier = "~=0.4.27" }, { name = "qdrant-client", specifier = ">=1.16.0,<1.17.0" },