diff --git a/src/AiClient.php b/src/AiClient.php index ebfeec75..071d3226 100644 --- a/src/AiClient.php +++ b/src/AiClient.php @@ -14,6 +14,7 @@ use WordPress\AiClient\Providers\Models\Contracts\ModelInterface; use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\ProviderRegistry; +use WordPress\AiClient\Results\DTO\EmbeddingResult; use WordPress\AiClient\Results\DTO\GenerativeAiResult; /** @@ -386,6 +387,71 @@ public static function generateVideoResult( return self::getConfiguredPromptBuilder($prompt, $modelOrConfig, $registry)->generateVideoResult(); } + /** + * Generates embeddings using the traditional API approach. + * + * @since 1.4.0 + * + * @param Prompt $prompt The prompt content. + * @param ModelInterface|ModelConfig|null $modelOrConfig Optional specific model to use, + * or model configuration for auto-discovery, + * or null for defaults. + * @param ProviderRegistry|null $registry Optional custom registry. If null, uses default. + * @return EmbeddingResult The embedding result. + * + * @throws \InvalidArgumentException If the prompt format is invalid. + * @throws \RuntimeException If no suitable model is found. + */ + public static function generateEmbeddingResult( + $prompt, + $modelOrConfig = null, + ?ProviderRegistry $registry = null + ): EmbeddingResult { + self::validateModelOrConfigParameter($modelOrConfig); + return self::getConfiguredPromptBuilder($prompt, $modelOrConfig, $registry)->generateEmbeddingResult(); + } + + /** + * Generates an embedding using the traditional API approach. + * + * @since 1.4.0 + * + * @param Prompt $prompt The prompt content. + * @param ModelInterface|ModelConfig|null $modelOrConfig Optional specific model to use, + * or model configuration for auto-discovery, + * or null for defaults. + * @param ProviderRegistry|null $registry Optional custom registry. If null, uses default. + * @return list The generated embedding vector. + */ + public static function generateEmbedding( + $prompt, + $modelOrConfig = null, + ?ProviderRegistry $registry = null + ): array { + return self::generateEmbeddingResult($prompt, $modelOrConfig, $registry)->getEmbedding(); + } + + /** + * Generates embeddings for a list of prompts using the traditional API approach. + * + * @since 1.4.0 + * + * @param list $prompts The prompts to embed. + * @param ModelInterface|ModelConfig|null $modelOrConfig Optional specific model to use, + * or model configuration for auto-discovery, + * or null for defaults. + * @param ProviderRegistry|null $registry Optional custom registry. If null, uses default. + * @return list> The generated embedding vectors. + */ + public static function generateEmbeddings( + array $prompts, + $modelOrConfig = null, + ?ProviderRegistry $registry = null + ): array { + self::validateModelOrConfigParameter($modelOrConfig); + return self::getConfiguredPromptBuilder(null, $modelOrConfig, $registry)->generateEmbeddings($prompts); + } + /** * Creates a new message builder for fluent API usage. * diff --git a/src/Builders/PromptBuilder.php b/src/Builders/PromptBuilder.php index 130fc574..14067030 100644 --- a/src/Builders/PromptBuilder.php +++ b/src/Builders/PromptBuilder.php @@ -23,6 +23,7 @@ use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; use WordPress\AiClient\Providers\Models\DTO\ModelRequirements; +use WordPress\AiClient\Providers\Models\EmbeddingGeneration\Contracts\EmbeddingGenerationModelInterface; use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface; use WordPress\AiClient\Providers\Models\SpeechGeneration\Contracts\SpeechGenerationModelInterface; @@ -30,6 +31,7 @@ use WordPress\AiClient\Providers\Models\TextToSpeechConversion\Contracts\TextToSpeechConversionModelInterface; use WordPress\AiClient\Providers\Models\VideoGeneration\Contracts\VideoGenerationModelInterface; use WordPress\AiClient\Providers\ProviderRegistry; +use WordPress\AiClient\Results\DTO\EmbeddingResult; use WordPress\AiClient\Results\DTO\GenerativeAiResult; use WordPress\AiClient\Tools\DTO\FunctionDeclaration; use WordPress\AiClient\Tools\DTO\FunctionResponse; @@ -479,6 +481,34 @@ public function usingCandidateCount(int $candidateCount): self return $this; } + /** + * Sets the embedding dimensions. + * + * @since 1.4.0 + * + * @param int $dimensions The embedding dimensions. + * @return self + */ + public function usingDimensions(int $dimensions): self + { + $this->modelConfig->setDimensions($dimensions); + return $this; + } + + /** + * Sets the embedding encoding format. + * + * @since 1.4.0 + * + * @param string $encodingFormat The embedding encoding format. + * @return self + */ + public function usingEncodingFormat(string $encodingFormat): self + { + $this->modelConfig->setEncodingFormat($encodingFormat); + return $this; + } + /** * Sets the function declarations available to the model. * @@ -761,7 +791,6 @@ private function inferCapabilityFromModelInterfaces(ModelInterface $model): ?Cap if ($model instanceof VideoGenerationModelInterface) { return CapabilityEnum::videoGeneration(); } - // No supported interface found return null; } @@ -1120,6 +1149,36 @@ public function generateVideoResult(): GenerativeAiResult return $this->generateResult(CapabilityEnum::videoGeneration()); } + /** + * Generates an embedding result from the prompt. + * + * @since 1.4.0 + * + * @return EmbeddingResult The generated embedding result. + * @throws InvalidArgumentException If the prompt or model validation fails. + * @throws RuntimeException If the model doesn't support embedding generation. + */ + public function generateEmbeddingResult(): EmbeddingResult + { + $this->validateMessages(); + + $capability = CapabilityEnum::embeddingGeneration(); + $model = $this->getConfiguredModel($capability); + + if (!$model instanceof EmbeddingGenerationModelInterface) { + throw new RuntimeException( + sprintf( + 'Model "%s" does not support embedding generation.', + $model->metadata()->getId() + ) + ); + } + + $this->dispatchEvent(new BeforeGenerateResultEvent($this->messages, $model, $capability)); + + return $model->generateEmbeddingResult([$this->messages]); + } + /** * Generates text from the prompt. * @@ -1166,6 +1225,62 @@ public function generateImage(): File return $this->generateImageResult()->toFile(); } + /** + * Generates an embedding from the prompt. + * + * @since 1.4.0 + * + * @return list The generated embedding vector. + * @throws InvalidArgumentException If the prompt or model validation fails. + */ + public function generateEmbedding(): array + { + return $this->generateEmbeddingResult()->getEmbedding(); + } + + /** + * Generates embeddings from the prompt or from the provided prompt list. + * + * @since 1.4.0 + * + * @param list|null $prompts Optional prompts to embed as a batch. + * @return list> The generated embedding vectors. + * @throws InvalidArgumentException If a prompt or model validation fails. + */ + public function generateEmbeddings(?array $prompts = null): array + { + if ($prompts === null) { + return $this->generateEmbeddingResult()->getEmbeddings(); + } + + if (!array_is_list($prompts)) { + throw new InvalidArgumentException('Prompts must be a list array.'); + } + + if (empty($prompts)) { + throw new InvalidArgumentException('Cannot generate embeddings from an empty prompt list.'); + } + + $promptMessages = []; + foreach ($prompts as $prompt) { + $promptMessages[] = $this->parsePromptToMessages($prompt); + } + + $capability = CapabilityEnum::embeddingGeneration(); + $model = $this->getConfiguredModel($capability); + + if (!$model instanceof EmbeddingGenerationModelInterface) { + throw new RuntimeException( + sprintf( + 'Model "%s" does not support embedding generation.', + $model->metadata()->getId() + ) + ); + } + + return $model->generateEmbeddingResult($promptMessages)->getEmbeddings(); + } + /** * Generates multiple images from the prompt. * @@ -1595,6 +1710,33 @@ private function parseMessage($input, MessageRoleEnum $defaultRole): Message return new Message($defaultRole, $parts); } + /** + * Parses prompt input into a message list. + * + * @since 1.4.0 + * + * @param Prompt $prompt The prompt to parse. + * @return list The parsed messages. + */ + private function parsePromptToMessages($prompt): array + { + if ($this->isMessagesList($prompt)) { + $messages = $prompt; + } else { + $messages = [$this->parseMessage($prompt, MessageRoleEnum::user())]; + } + + $originalMessages = $this->messages; + try { + $this->messages = $messages; + $this->validateMessages(); + } finally { + $this->messages = $originalMessages; + } + + return $messages; + } + /** * Validates the messages array for prompt generation. * diff --git a/src/Providers/Models/DTO/ModelConfig.php b/src/Providers/Models/DTO/ModelConfig.php index 20166520..8acaf060 100644 --- a/src/Providers/Models/DTO/ModelConfig.php +++ b/src/Providers/Models/DTO/ModelConfig.php @@ -45,6 +45,8 @@ * outputMediaOrientation?: string, * outputMediaAspectRatio?: string, * outputSpeechVoice?: string, + * dimensions?: int, + * encodingFormat?: string, * customOptions?: array * } * @@ -72,6 +74,8 @@ class ModelConfig extends AbstractDataTransferObject public const KEY_OUTPUT_MEDIA_ORIENTATION = 'outputMediaOrientation'; public const KEY_OUTPUT_MEDIA_ASPECT_RATIO = 'outputMediaAspectRatio'; public const KEY_OUTPUT_SPEECH_VOICE = 'outputSpeechVoice'; + public const KEY_DIMENSIONS = 'dimensions'; + public const KEY_ENCODING_FORMAT = 'encodingFormat'; public const KEY_CUSTOM_OPTIONS = 'customOptions'; /* @@ -181,6 +185,16 @@ class ModelConfig extends AbstractDataTransferObject */ protected ?string $outputSpeechVoice = null; + /** + * @var int|null Embedding vector dimensions. + */ + protected ?int $dimensions = null; + + /** + * @var string|null Embedding encoding format. + */ + protected ?string $encodingFormat = null; + /** * @var array Custom provider-specific options. */ @@ -771,6 +785,58 @@ public function getOutputSpeechVoice(): ?string return $this->outputSpeechVoice; } + /** + * Sets the embedding dimensions. + * + * @since 1.4.0 + * + * @param int $dimensions The embedding dimensions. + */ + public function setDimensions(int $dimensions): void + { + if ($dimensions < 1) { + throw new InvalidArgumentException('Dimensions must be greater than zero.'); + } + + $this->dimensions = $dimensions; + } + + /** + * Gets the embedding dimensions. + * + * @since 1.4.0 + * + * @return int|null The embedding dimensions. + */ + public function getDimensions(): ?int + { + return $this->dimensions; + } + + /** + * Sets the embedding encoding format. + * + * @since 1.4.0 + * + * @param string $encodingFormat The embedding encoding format. + */ + public function setEncodingFormat(string $encodingFormat): void + { + $this->encodingFormat = $encodingFormat; + } + + /** + * Gets the embedding encoding format. + * + * @since 1.4.0 + * + * @return string|null The embedding encoding format. + */ + public function getEncodingFormat(): ?string + { + return $this->encodingFormat; + } + /** * Sets a single custom option. * @@ -915,6 +981,15 @@ public static function getJsonSchema(): array 'type' => 'string', 'description' => 'Output speech voice.', ], + self::KEY_DIMENSIONS => [ + 'type' => 'integer', + 'minimum' => 1, + 'description' => 'Embedding vector dimensions.', + ], + self::KEY_ENCODING_FORMAT => [ + 'type' => 'string', + 'description' => 'Embedding encoding format.', + ], self::KEY_CUSTOM_OPTIONS => [ 'type' => 'object', 'additionalProperties' => true, @@ -1026,6 +1101,14 @@ static function (FunctionDeclaration $functionDeclaration): array { $data[self::KEY_OUTPUT_SPEECH_VOICE] = $this->outputSpeechVoice; } + if ($this->dimensions !== null) { + $data[self::KEY_DIMENSIONS] = $this->dimensions; + } + + if ($this->encodingFormat !== null) { + $data[self::KEY_ENCODING_FORMAT] = $this->encodingFormat; + } + if (!empty($this->customOptions)) { $data[self::KEY_CUSTOM_OPTIONS] = $this->customOptions; } @@ -1131,6 +1214,14 @@ static function (array $functionDeclarationData): FunctionDeclaration { $config->setOutputSpeechVoice($array[self::KEY_OUTPUT_SPEECH_VOICE]); } + if (isset($array[self::KEY_DIMENSIONS])) { + $config->setDimensions($array[self::KEY_DIMENSIONS]); + } + + if (isset($array[self::KEY_ENCODING_FORMAT])) { + $config->setEncodingFormat($array[self::KEY_ENCODING_FORMAT]); + } + if (isset($array[self::KEY_CUSTOM_OPTIONS])) { $config->setCustomOptions($array[self::KEY_CUSTOM_OPTIONS]); } diff --git a/src/Providers/Models/DTO/ModelRequirements.php b/src/Providers/Models/DTO/ModelRequirements.php index 7b92d804..31ecbbeb 100644 --- a/src/Providers/Models/DTO/ModelRequirements.php +++ b/src/Providers/Models/DTO/ModelRequirements.php @@ -342,6 +342,14 @@ private static function toRequiredOptions(ModelConfig $modelConfig): array ); } + if ($modelConfig->getDimensions() !== null) { + $requiredOptions[] = new RequiredOption(OptionEnum::dimensions(), $modelConfig->getDimensions()); + } + + if ($modelConfig->getEncodingFormat() !== null) { + $requiredOptions[] = new RequiredOption(OptionEnum::encodingFormat(), $modelConfig->getEncodingFormat()); + } + // Add custom options as individual RequiredOptions foreach ($modelConfig->getCustomOptions() as $key => $value) { $requiredOptions[] = new RequiredOption(OptionEnum::customOptions(), [$key => $value]); diff --git a/src/Providers/Models/EmbeddingGeneration/Contracts/EmbeddingGenerationModelInterface.php b/src/Providers/Models/EmbeddingGeneration/Contracts/EmbeddingGenerationModelInterface.php new file mode 100644 index 00000000..ca087ad8 --- /dev/null +++ b/src/Providers/Models/EmbeddingGeneration/Contracts/EmbeddingGenerationModelInterface.php @@ -0,0 +1,26 @@ +> $prompts Array of message lists to embed. + * @return EmbeddingResult Result containing generated embedding vectors. + */ + public function generateEmbeddingResult(array $prompts): EmbeddingResult; +} diff --git a/src/Providers/Models/Enums/OptionEnum.php b/src/Providers/Models/Enums/OptionEnum.php index 27b2248f..b8f209e3 100644 --- a/src/Providers/Models/Enums/OptionEnum.php +++ b/src/Providers/Models/Enums/OptionEnum.php @@ -21,6 +21,8 @@ * Dynamically loaded from ModelConfig KEY_* constants: * @method static self candidateCount() Creates an instance for CANDIDATE_COUNT option. * @method static self customOptions() Creates an instance for CUSTOM_OPTIONS option. + * @method static self dimensions() Creates an instance for DIMENSIONS option. + * @method static self encodingFormat() Creates an instance for ENCODING_FORMAT option. * @method static self frequencyPenalty() Creates an instance for FREQUENCY_PENALTY option. * @method static self functionDeclarations() Creates an instance for FUNCTION_DECLARATIONS option. * @method static self logprobs() Creates an instance for LOGPROBS option. @@ -42,6 +44,8 @@ * @method static self webSearch() Creates an instance for WEB_SEARCH option. * @method bool isCandidateCount() Checks if the option is CANDIDATE_COUNT. * @method bool isCustomOptions() Checks if the option is CUSTOM_OPTIONS. + * @method bool isDimensions() Checks if the option is DIMENSIONS. + * @method bool isEncodingFormat() Checks if the option is ENCODING_FORMAT. * @method bool isFrequencyPenalty() Checks if the option is FREQUENCY_PENALTY. * @method bool isFunctionDeclarations() Checks if the option is FUNCTION_DECLARATIONS. * @method bool isLogprobs() Checks if the option is LOGPROBS. diff --git a/src/Results/DTO/EmbeddingResult.php b/src/Results/DTO/EmbeddingResult.php new file mode 100644 index 00000000..dd2ee86f --- /dev/null +++ b/src/Results/DTO/EmbeddingResult.php @@ -0,0 +1,250 @@ +>, + * dimensions: int, + * tokenUsage: TokenUsageArrayShape, + * providerMetadata: ProviderMetadataArrayShape, + * modelMetadata: ModelMetadataArrayShape, + * additionalData?: array + * } + * + * @extends AbstractDataTransferObject + */ +class EmbeddingResult extends AbstractDataTransferObject implements ResultInterface +{ + public const KEY_ID = 'id'; + public const KEY_EMBEDDINGS = 'embeddings'; + public const KEY_DIMENSIONS = 'dimensions'; + public const KEY_TOKEN_USAGE = 'tokenUsage'; + public const KEY_PROVIDER_METADATA = 'providerMetadata'; + public const KEY_MODEL_METADATA = 'modelMetadata'; + public const KEY_ADDITIONAL_DATA = 'additionalData'; + + private string $id; + + /** + * @var list> + */ + private array $embeddings; + + private int $dimensions; + private TokenUsage $tokenUsage; + private ProviderMetadata $providerMetadata; + private ModelMetadata $modelMetadata; + + /** + * @var array + */ + private array $additionalData; + + /** + * Constructor. + * + * @since 1.4.0 + * + * @param string $id Unique identifier for this result. + * @param list> $embeddings The generated embedding vectors. + * @param int $dimensions The vector dimension count. + * @param TokenUsage $tokenUsage Token usage statistics. + * @param ProviderMetadata $providerMetadata Provider metadata. + * @param ModelMetadata $modelMetadata Model metadata. + * @param array $additionalData Additional data. + */ + public function __construct( + string $id, + array $embeddings, + int $dimensions, + TokenUsage $tokenUsage, + ProviderMetadata $providerMetadata, + ModelMetadata $modelMetadata, + array $additionalData = [] + ) { + if (empty($embeddings)) { + throw new InvalidArgumentException('At least one embedding must be provided'); + } + + if ($dimensions < 1) { + throw new InvalidArgumentException('Embedding dimensions must be greater than zero'); + } + + foreach ($embeddings as $embedding) { + if (!array_is_list($embedding)) { + throw new InvalidArgumentException('Embeddings must be list arrays.'); + } + + if (count($embedding) !== $dimensions) { + throw new InvalidArgumentException('Embedding vector length must match dimensions.'); + } + } + + $this->id = $id; + $this->embeddings = $embeddings; + $this->dimensions = $dimensions; + $this->tokenUsage = $tokenUsage; + $this->providerMetadata = $providerMetadata; + $this->modelMetadata = $modelMetadata; + $this->additionalData = $additionalData; + } + + public function getId(): string + { + return $this->id; + } + + /** + * Gets the generated embedding vectors. + * + * @since 1.4.0 + * + * @return list> The embeddings. + */ + public function getEmbeddings(): array + { + return $this->embeddings; + } + + /** + * Gets the first generated embedding vector. + * + * @since 1.4.0 + * + * @return list The first embedding. + */ + public function getEmbedding(): array + { + return $this->embeddings[0]; + } + + public function getDimensions(): int + { + return $this->dimensions; + } + + public function getTokenUsage(): TokenUsage + { + return $this->tokenUsage; + } + + public function getProviderMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + + public function getModelMetadata(): ModelMetadata + { + return $this->modelMetadata; + } + + public function getAdditionalData(): array + { + return $this->additionalData; + } + + public static function getJsonSchema(): array + { + return [ + 'type' => 'object', + 'properties' => [ + self::KEY_ID => [ + 'type' => 'string', + 'description' => 'Unique identifier for this result.', + ], + self::KEY_EMBEDDINGS => [ + 'type' => 'array', + 'items' => [ + 'type' => 'array', + 'items' => [ + 'type' => 'number', + ], + ], + 'description' => 'Generated embedding vectors.', + ], + self::KEY_DIMENSIONS => [ + 'type' => 'integer', + 'minimum' => 1, + 'description' => 'Embedding vector dimensions.', + ], + self::KEY_TOKEN_USAGE => TokenUsage::getJsonSchema(), + self::KEY_PROVIDER_METADATA => ProviderMetadata::getJsonSchema(), + self::KEY_MODEL_METADATA => ModelMetadata::getJsonSchema(), + self::KEY_ADDITIONAL_DATA => [ + 'type' => 'object', + 'additionalProperties' => true, + 'description' => 'Additional provider-specific data.', + ], + ], + 'required' => [ + self::KEY_ID, + self::KEY_EMBEDDINGS, + self::KEY_DIMENSIONS, + self::KEY_TOKEN_USAGE, + self::KEY_PROVIDER_METADATA, + self::KEY_MODEL_METADATA, + ], + ]; + } + + /** + * @return EmbeddingResultArrayShape + */ + public function toArray(): array + { + $data = [ + self::KEY_ID => $this->id, + self::KEY_EMBEDDINGS => $this->embeddings, + self::KEY_DIMENSIONS => $this->dimensions, + self::KEY_TOKEN_USAGE => $this->tokenUsage->toArray(), + self::KEY_PROVIDER_METADATA => $this->providerMetadata->toArray(), + self::KEY_MODEL_METADATA => $this->modelMetadata->toArray(), + ]; + + if (!empty($this->additionalData)) { + $data[self::KEY_ADDITIONAL_DATA] = $this->additionalData; + } + + return $data; + } + + public static function fromArray(array $array): self + { + static::validateFromArrayData($array, [ + self::KEY_ID, + self::KEY_EMBEDDINGS, + self::KEY_DIMENSIONS, + self::KEY_TOKEN_USAGE, + self::KEY_PROVIDER_METADATA, + self::KEY_MODEL_METADATA, + ]); + + return new self( + $array[self::KEY_ID], + $array[self::KEY_EMBEDDINGS], + $array[self::KEY_DIMENSIONS], + TokenUsage::fromArray($array[self::KEY_TOKEN_USAGE]), + ProviderMetadata::fromArray($array[self::KEY_PROVIDER_METADATA]), + ModelMetadata::fromArray($array[self::KEY_MODEL_METADATA]), + $array[self::KEY_ADDITIONAL_DATA] ?? [] + ); + } +} diff --git a/tests/traits/MockModelCreationTrait.php b/tests/traits/MockModelCreationTrait.php index d330c518..4ca4784b 100644 --- a/tests/traits/MockModelCreationTrait.php +++ b/tests/traits/MockModelCreationTrait.php @@ -11,12 +11,14 @@ use WordPress\AiClient\Providers\Models\Contracts\ModelInterface; use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; +use WordPress\AiClient\Providers\Models\EmbeddingGeneration\Contracts\EmbeddingGenerationModelInterface; use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface; use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\TextGenerationModelInterface; use WordPress\AiClient\Providers\Models\VideoGeneration\Contracts\VideoGenerationModelInterface; use WordPress\AiClient\Providers\ProviderRegistry; use WordPress\AiClient\Results\DTO\Candidate; +use WordPress\AiClient\Results\DTO\EmbeddingResult; use WordPress\AiClient\Results\DTO\GenerativeAiResult; use WordPress\AiClient\Results\DTO\TokenUsage; use WordPress\AiClient\Results\Enums\FinishReasonEnum; @@ -76,6 +78,38 @@ protected function createTestResult(string $content = 'Test response'): Generati ); } + /** + * Creates a test EmbeddingResult for testing purposes. + * + * @param list>|null $embeddings Optional embeddings for the response. + * @return EmbeddingResult + */ + protected function createTestEmbeddingResult(?array $embeddings = null): EmbeddingResult + { + $embeddings = $embeddings ?? [[0.1, 0.2, 0.3]]; + + $providerMetadata = new ProviderMetadata( + 'mock', + 'Mock Provider', + ProviderTypeEnum::cloud() + ); + $modelMetadata = new ModelMetadata( + 'mock-embedding-model', + 'Mock Embedding Model', + [CapabilityEnum::embeddingGeneration()], + [] + ); + + return new EmbeddingResult( + 'test-embedding-result-id', + $embeddings, + count($embeddings[0]), + new TokenUsage(10, 0, 10), + $providerMetadata, + $modelMetadata + ); + } + /** * Creates a test model metadata instance for text generation. * @@ -133,6 +167,25 @@ protected function createTestVideoModelMetadata( ); } + /** + * Creates a test model metadata instance for embedding generation. + * + * @param string $id Optional model ID. + * @param string $name Optional model name. + * @return ModelMetadata + */ + protected function createTestEmbeddingModelMetadata( + string $id = 'test-embedding-model', + string $name = 'Test Embedding Model' + ): ModelMetadata { + return new ModelMetadata( + $id, + $name, + [CapabilityEnum::embeddingGeneration()], + [] + ); + } + /** * Creates a mock text generation model using anonymous class. * @@ -334,6 +387,73 @@ public function generateVideoResult(array $prompt): GenerativeAiResult }; } + /** + * Creates a mock embedding generation model using anonymous class. + * + * @param EmbeddingResult $result The result to return from generation. + * @param ModelMetadata|null $metadata Optional metadata (uses default if not provided). + * @return ModelInterface&EmbeddingGenerationModelInterface The mock model. + */ + protected function createMockEmbeddingGenerationModel( + EmbeddingResult $result, + ?ModelMetadata $metadata = null + ): ModelInterface { + $metadata = $metadata ?? $this->createTestEmbeddingModelMetadata(); + + $providerMetadata = new ProviderMetadata( + 'mock', + 'Mock Provider', + ProviderTypeEnum::cloud() + ); + + return new class ( + $metadata, + $providerMetadata, + $result + ) implements ModelInterface, EmbeddingGenerationModelInterface { + private ModelMetadata $metadata; + private ProviderMetadata $providerMetadata; + private EmbeddingResult $result; + private ModelConfig $config; + + public function __construct( + ModelMetadata $metadata, + ProviderMetadata $providerMetadata, + EmbeddingResult $result + ) { + $this->metadata = $metadata; + $this->providerMetadata = $providerMetadata; + $this->result = $result; + $this->config = new ModelConfig(); + } + + public function metadata(): ModelMetadata + { + return $this->metadata; + } + + public function providerMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + + public function setConfig(ModelConfig $config): void + { + $this->config = $config; + } + + public function getConfig(): ModelConfig + { + return $this->config; + } + + public function generateEmbeddingResult(array $prompts): EmbeddingResult + { + return $this->result; + } + }; + } + /** * Creates a mock model that doesn't implement any generation interfaces. * diff --git a/tests/unit/AiClientTest.php b/tests/unit/AiClientTest.php index c573c756..e8360cf3 100644 --- a/tests/unit/AiClientTest.php +++ b/tests/unit/AiClientTest.php @@ -166,6 +166,65 @@ public function testGenerateVideoResultWithInvalidModel(): void AiClient::generateVideoResult($prompt, $invalidModel, $registry); } + /** + * Tests generateEmbeddingResult with string prompt and provided model. + */ + public function testGenerateEmbeddingResultWithStringAndModel(): void + { + $prompt = 'Generate embedding'; + $expectedResult = $this->createTestEmbeddingResult(); + $mockModel = $this->createMockEmbeddingGenerationModel($expectedResult); + $registry = $this->createRegistryWithMockProvider(); + + $result = AiClient::generateEmbeddingResult($prompt, $mockModel, $registry); + + $this->assertSame($expectedResult, $result); + } + + /** + * Tests generateEmbedding returns the first vector. + */ + public function testGenerateEmbeddingReturnsFirstVector(): void + { + $expectedResult = $this->createTestEmbeddingResult([[0.1, 0.2], [0.3, 0.4]]); + $mockModel = $this->createMockEmbeddingGenerationModel($expectedResult); + $registry = $this->createRegistryWithMockProvider(); + + $embedding = AiClient::generateEmbedding('Generate embedding', $mockModel, $registry); + + $this->assertSame([0.1, 0.2], $embedding); + } + + /** + * Tests generateEmbeddings returns batch vectors. + */ + public function testGenerateEmbeddingsReturnsBatchVectors(): void + { + $expectedEmbeddings = [[0.1, 0.2], [0.3, 0.4]]; + $expectedResult = $this->createTestEmbeddingResult($expectedEmbeddings); + $mockModel = $this->createMockEmbeddingGenerationModel($expectedResult); + $registry = $this->createRegistryWithMockProvider(); + + $embeddings = AiClient::generateEmbeddings(['First prompt', 'Second prompt'], $mockModel, $registry); + + $this->assertSame($expectedEmbeddings, $embeddings); + } + + /** + * Tests generateEmbeddingResult throws exception for model without embedding generation interface. + */ + public function testGenerateEmbeddingResultWithInvalidModel(): void + { + $prompt = 'Generate embedding'; + $invalidModel = $this->createMockUnsupportedModel('invalid-embedding-model'); + $registry = $this->createRegistryWithMockProvider(); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage('Model "invalid-embedding-model" does not support embedding generation.'); + + AiClient::generateEmbeddingResult($prompt, $invalidModel, $registry); + } + /** * Tests generateTextResult with Message object. diff --git a/tests/unit/Builders/PromptBuilderTest.php b/tests/unit/Builders/PromptBuilderTest.php index ce68a223..0a549763 100644 --- a/tests/unit/Builders/PromptBuilderTest.php +++ b/tests/unit/Builders/PromptBuilderTest.php @@ -1815,6 +1815,83 @@ public function testGenerateImageResult(): void $this->assertTrue($modalities[0]->isImage()); } + /** + * Tests generateEmbeddingResult method. + * + * @return void + */ + public function testGenerateEmbeddingResult(): void + { + $result = $this->createTestEmbeddingResult([[0.1, 0.2, 0.3]]); + $model = $this->createMockEmbeddingGenerationModel($result); + + $builder = new PromptBuilder($this->registry, 'Generate embedding'); + $builder->usingModel($model); + $builder->usingDimensions(3); + $builder->usingEncodingFormat('float'); + + $actualResult = $builder->generateEmbeddingResult(); + $this->assertSame($result, $actualResult); + + $this->assertSame(3, $model->getConfig()->getDimensions()); + $this->assertSame('float', $model->getConfig()->getEncodingFormat()); + } + + /** + * Tests generateEmbedding returns the first vector. + * + * @return void + */ + public function testGenerateEmbedding(): void + { + $result = $this->createTestEmbeddingResult([[0.1, 0.2], [0.3, 0.4]]); + $model = $this->createMockEmbeddingGenerationModel($result); + + $builder = new PromptBuilder($this->registry, 'Generate embedding'); + $builder->usingModel($model); + + $this->assertSame([0.1, 0.2], $builder->generateEmbedding()); + } + + /** + * Tests generateEmbeddings returns batch vectors. + * + * @return void + */ + public function testGenerateEmbeddings(): void + { + $embeddings = [[0.1, 0.2], [0.3, 0.4]]; + $result = $this->createTestEmbeddingResult($embeddings); + $model = $this->createMockEmbeddingGenerationModel($result); + + $builder = new PromptBuilder($this->registry); + $builder->usingModel($model); + + $this->assertSame($embeddings, $builder->generateEmbeddings(['First prompt', 'Second prompt'])); + } + + /** + * Tests generateEmbeddingResult throws exception for unsupported model. + * + * @return void + */ + public function testGenerateEmbeddingResultThrowsExceptionForUnsupportedModel(): void + { + $metadata = $this->createMock(ModelMetadata::class); + $metadata->method('getId')->willReturn('test-model'); + + $model = $this->createMock(ModelInterface::class); + $model->method('metadata')->willReturn($metadata); + + $builder = new PromptBuilder($this->registry, 'Generate embedding'); + $builder->usingModel($model); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage('Model "test-model" does not support embedding generation'); + + $builder->generateEmbeddingResult(); + } + /** * Tests generateVideoResult method. * diff --git a/tests/unit/Providers/Models/DTO/ModelConfigTest.php b/tests/unit/Providers/Models/DTO/ModelConfigTest.php index 578e10e9..74483acb 100644 --- a/tests/unit/Providers/Models/DTO/ModelConfigTest.php +++ b/tests/unit/Providers/Models/DTO/ModelConfigTest.php @@ -74,6 +74,8 @@ public function testDefaultConstructor(): void $this->assertNull($config->getOutputMediaOrientation()); $this->assertNull($config->getOutputMediaAspectRatio()); $this->assertNull($config->getOutputSpeechVoice()); + $this->assertNull($config->getDimensions()); + $this->assertNull($config->getEncodingFormat()); $this->assertEquals([], $config->getCustomOptions()); } @@ -177,6 +179,14 @@ public function testSettersAndGetters(): void $config->setOutputSpeechVoice('alloy'); $this->assertEquals('alloy', $config->getOutputSpeechVoice()); + // Test embedding dimensions + $config->setDimensions(1536); + $this->assertEquals(1536, $config->getDimensions()); + + // Test embedding encoding format + $config->setEncodingFormat('float'); + $this->assertEquals('float', $config->getEncodingFormat()); + // Test custom options $customOptions = ['custom_param' => 'value', 'another_param' => 123]; $config->setCustomOptions($customOptions); @@ -219,6 +229,8 @@ public function testGetJsonSchema(): void ModelConfig::KEY_OUTPUT_MEDIA_ORIENTATION, ModelConfig::KEY_OUTPUT_MEDIA_ASPECT_RATIO, ModelConfig::KEY_OUTPUT_SPEECH_VOICE, + ModelConfig::KEY_DIMENSIONS, + ModelConfig::KEY_ENCODING_FORMAT, ModelConfig::KEY_CUSTOM_OPTIONS ]; @@ -238,6 +250,8 @@ public function testGetJsonSchema(): void $this->assertEquals('string', $schema['properties'][ModelConfig::KEY_OUTPUT_MEDIA_ORIENTATION]['type']); $this->assertEquals('string', $schema['properties'][ModelConfig::KEY_OUTPUT_MEDIA_ASPECT_RATIO]['type']); $this->assertEquals('string', $schema['properties'][ModelConfig::KEY_OUTPUT_SPEECH_VOICE]['type']); + $this->assertEquals('integer', $schema['properties'][ModelConfig::KEY_DIMENSIONS]['type']); + $this->assertEquals('string', $schema['properties'][ModelConfig::KEY_ENCODING_FORMAT]['type']); $this->assertEquals('object', $schema['properties'][ModelConfig::KEY_CUSTOM_OPTIONS]['type']); // Check constraints @@ -277,6 +291,8 @@ public function testToArrayAllProperties(): void $config->setOutputMediaOrientation(MediaOrientationEnum::portrait()); $config->setOutputMediaAspectRatio('9:16'); $config->setOutputSpeechVoice('onyx'); + $config->setDimensions(768); + $config->setEncodingFormat('base64'); $config->setCustomOptions(['key' => 'value']); $array = $config->toArray(); @@ -302,6 +318,8 @@ public function testToArrayAllProperties(): void $this->assertEquals('portrait', $array[ModelConfig::KEY_OUTPUT_MEDIA_ORIENTATION]); $this->assertEquals('9:16', $array[ModelConfig::KEY_OUTPUT_MEDIA_ASPECT_RATIO]); $this->assertEquals('onyx', $array[ModelConfig::KEY_OUTPUT_SPEECH_VOICE]); + $this->assertEquals(768, $array[ModelConfig::KEY_DIMENSIONS]); + $this->assertEquals('base64', $array[ModelConfig::KEY_ENCODING_FORMAT]); $this->assertEquals(['key' => 'value'], $array[ModelConfig::KEY_CUSTOM_OPTIONS]); } @@ -414,6 +432,8 @@ public function testFromArrayAllProperties(): void ModelConfig::KEY_OUTPUT_MEDIA_ORIENTATION => 'landscape', ModelConfig::KEY_OUTPUT_MEDIA_ASPECT_RATIO => '16:9', ModelConfig::KEY_OUTPUT_SPEECH_VOICE => 'fable', + ModelConfig::KEY_DIMENSIONS => 1024, + ModelConfig::KEY_ENCODING_FORMAT => 'float', ModelConfig::KEY_CUSTOM_OPTIONS => ['custom' => true] ]; @@ -444,6 +464,8 @@ public function testFromArrayAllProperties(): void $this->assertEquals(MediaOrientationEnum::landscape(), $config->getOutputMediaOrientation()); $this->assertEquals('16:9', $config->getOutputMediaAspectRatio()); $this->assertEquals('fable', $config->getOutputSpeechVoice()); + $this->assertEquals(1024, $config->getDimensions()); + $this->assertEquals('float', $config->getEncodingFormat()); $this->assertEquals(['custom' => true], $config->getCustomOptions()); } @@ -479,6 +501,8 @@ public function testArrayRoundTrip(): void $original->setOutputMediaOrientation(MediaOrientationEnum::square()); $original->setOutputMediaAspectRatio('1:1'); $original->setOutputSpeechVoice('shimmer'); + $original->setDimensions(256); + $original->setEncodingFormat('float'); $original->setCustomOptions(['test' => 'value']); $array = $original->toArray(); @@ -493,6 +517,8 @@ public function testArrayRoundTrip(): void $this->assertEquals($original->getOutputMediaOrientation(), $restored->getOutputMediaOrientation()); $this->assertEquals($original->getOutputMediaAspectRatio(), $restored->getOutputMediaAspectRatio()); $this->assertEquals($original->getOutputSpeechVoice(), $restored->getOutputSpeechVoice()); + $this->assertEquals($original->getDimensions(), $restored->getDimensions()); + $this->assertEquals($original->getEncodingFormat(), $restored->getEncodingFormat()); $this->assertEquals($original->getCustomOptions(), $restored->getCustomOptions()); } diff --git a/tests/unit/Providers/Models/DTO/ModelRequirementsTest.php b/tests/unit/Providers/Models/DTO/ModelRequirementsTest.php index 1cb7897e..74566e9a 100644 --- a/tests/unit/Providers/Models/DTO/ModelRequirementsTest.php +++ b/tests/unit/Providers/Models/DTO/ModelRequirementsTest.php @@ -598,6 +598,8 @@ public function testFromPromptDataWithModelConfigOptions(): void $modelConfig->setMaxTokens(2000); $modelConfig->setTopP(0.95); $modelConfig->setStopSequences(['END']); + $modelConfig->setDimensions(1536); + $modelConfig->setEncodingFormat('float'); $requirements = ModelRequirements::fromPromptData( CapabilityEnum::textGeneration(), @@ -612,6 +614,8 @@ public function testFromPromptDataWithModelConfigOptions(): void $hasTemperature = false; $hasMaxTokens = false; $hasTopP = false; + $hasDimensions = false; + $hasEncodingFormat = false; foreach ($options as $option) { if ($option->getName()->isTemperature()) { @@ -626,10 +630,20 @@ public function testFromPromptDataWithModelConfigOptions(): void $hasTopP = true; $this->assertEquals(0.95, $option->getValue()); } + if ($option->getName()->isDimensions()) { + $hasDimensions = true; + $this->assertEquals(1536, $option->getValue()); + } + if ($option->getName()->isEncodingFormat()) { + $hasEncodingFormat = true; + $this->assertEquals('float', $option->getValue()); + } } $this->assertTrue($hasTemperature, 'Temperature option should be present'); $this->assertTrue($hasMaxTokens, 'Max tokens option should be present'); $this->assertTrue($hasTopP, 'Top P option should be present'); + $this->assertTrue($hasDimensions, 'Dimensions option should be present'); + $this->assertTrue($hasEncodingFormat, 'Encoding format option should be present'); } } diff --git a/tests/unit/Providers/Models/Enums/OptionEnumTest.php b/tests/unit/Providers/Models/Enums/OptionEnumTest.php index 9248a9ad..b905ddc9 100644 --- a/tests/unit/Providers/Models/Enums/OptionEnumTest.php +++ b/tests/unit/Providers/Models/Enums/OptionEnumTest.php @@ -57,6 +57,8 @@ protected function getExpectedValues(): array 'OUTPUT_MEDIA_ORIENTATION' => 'outputMediaOrientation', 'OUTPUT_MEDIA_ASPECT_RATIO' => 'outputMediaAspectRatio', 'OUTPUT_SPEECH_VOICE' => 'outputSpeechVoice', + 'DIMENSIONS' => 'dimensions', + 'ENCODING_FORMAT' => 'encodingFormat', 'CUSTOM_OPTIONS' => 'customOptions', ]; } @@ -111,6 +113,8 @@ public function testDynamicallyLoadedConstants(): void $this->assertInstanceOf(OptionEnum::class, OptionEnum::outputFileType()); $this->assertInstanceOf(OptionEnum::class, OptionEnum::outputMediaOrientation()); $this->assertInstanceOf(OptionEnum::class, OptionEnum::outputMediaAspectRatio()); + $this->assertInstanceOf(OptionEnum::class, OptionEnum::dimensions()); + $this->assertInstanceOf(OptionEnum::class, OptionEnum::encodingFormat()); $this->assertInstanceOf(OptionEnum::class, OptionEnum::customOptions()); } @@ -134,6 +138,8 @@ public function testGetValuesIncludesDynamicConstants(): void $this->assertContains('outputFileType', $values); $this->assertContains('outputMediaOrientation', $values); $this->assertContains('outputMediaAspectRatio', $values); + $this->assertContains('dimensions', $values); + $this->assertContains('encodingFormat', $values); $this->assertContains('customOptions', $values); } } diff --git a/tests/unit/Results/DTO/EmbeddingResultTest.php b/tests/unit/Results/DTO/EmbeddingResultTest.php new file mode 100644 index 00000000..e1e4b2c7 --- /dev/null +++ b/tests/unit/Results/DTO/EmbeddingResultTest.php @@ -0,0 +1,99 @@ + 'provider-123'] + ); + } + + public function testGetters(): void + { + $result = $this->createEmbeddingResult(); + + $this->assertSame('embedding-result-id', $result->getId()); + $this->assertSame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], $result->getEmbeddings()); + $this->assertSame([0.1, 0.2, 0.3], $result->getEmbedding()); + $this->assertSame(3, $result->getDimensions()); + $this->assertSame(4, $result->getTokenUsage()->getPromptTokens()); + $this->assertSame('mock', $result->getProviderMetadata()->getId()); + $this->assertSame('mock-embedding-model', $result->getModelMetadata()->getId()); + $this->assertSame(['providerResultId' => 'provider-123'], $result->getAdditionalData()); + } + + public function testArrayRoundTrip(): void + { + $result = $this->createEmbeddingResult(); + + $this->assertEquals($result, EmbeddingResult::fromArray($result->toArray())); + } + + public function testRequiresAtLeastOneEmbedding(): void + { + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('At least one embedding must be provided'); + + new EmbeddingResult( + 'embedding-result-id', + [], + 3, + new TokenUsage(4, 0, 4), + new ProviderMetadata('mock', 'Mock Provider', ProviderTypeEnum::cloud()), + new ModelMetadata( + 'mock-embedding-model', + 'Mock Embedding Model', + [CapabilityEnum::embeddingGeneration()], + [] + ) + ); + } + + public function testEmbeddingLengthMustMatchDimensions(): void + { + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('Embedding vector length must match dimensions.'); + + new EmbeddingResult( + 'embedding-result-id', + [[0.1, 0.2]], + 3, + new TokenUsage(4, 0, 4), + new ProviderMetadata('mock', 'Mock Provider', ProviderTypeEnum::cloud()), + new ModelMetadata( + 'mock-embedding-model', + 'Mock Embedding Model', + [CapabilityEnum::embeddingGeneration()], + [] + ) + ); + } +}