diff --git a/src/Builders/PromptBuilder.php b/src/Builders/PromptBuilder.php index 130fc574..54f391e0 100644 --- a/src/Builders/PromptBuilder.php +++ b/src/Builders/PromptBuilder.php @@ -23,6 +23,7 @@ use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; use WordPress\AiClient\Providers\Models\DTO\ModelRequirements; +use WordPress\AiClient\Providers\Models\EmbeddingGeneration\Contracts\EmbeddingGenerationModelInterface; use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface; use WordPress\AiClient\Providers\Models\SpeechGeneration\Contracts\SpeechGenerationModelInterface; @@ -30,6 +31,7 @@ use WordPress\AiClient\Providers\Models\TextToSpeechConversion\Contracts\TextToSpeechConversionModelInterface; use WordPress\AiClient\Providers\Models\VideoGeneration\Contracts\VideoGenerationModelInterface; use WordPress\AiClient\Providers\ProviderRegistry; +use WordPress\AiClient\Results\DTO\EmbeddingResult; use WordPress\AiClient\Results\DTO\GenerativeAiResult; use WordPress\AiClient\Tools\DTO\FunctionDeclaration; use WordPress\AiClient\Tools\DTO\FunctionResponse; @@ -674,6 +676,20 @@ public function asOutputSpeechVoice(string $voice): self return $this; } + /** + * Sets the embedding vector dimension count. + * + * @since n.e.x.t + * + * @param int $dimensions The embedding vector dimension count. + * @return self + */ + public function usingEmbeddingDimensions(int $dimensions): self + { + $this->modelConfig->setEmbeddingDimensions($dimensions); + return $this; + } + /** * Configures the prompt for JSON response output. * @@ -761,6 +777,9 @@ private function inferCapabilityFromModelInterfaces(ModelInterface $model): ?Cap if ($model instanceof VideoGenerationModelInterface) { return CapabilityEnum::videoGeneration(); } + if ($model instanceof EmbeddingGenerationModelInterface) { + return CapabilityEnum::embeddingGeneration(); + } // No supported interface found return null; @@ -1120,6 +1139,40 @@ public function generateVideoResult(): GenerativeAiResult return $this->generateResult(CapabilityEnum::videoGeneration()); } + /** + * Generates an embedding result from the prompt. + * + * @since n.e.x.t + * + * @return EmbeddingResult The generated result containing embedding vectors. + * @throws InvalidArgumentException If the prompt or model validation fails. + * @throws RuntimeException If the model doesn't support embedding generation. + */ + public function generateEmbeddingResult(): EmbeddingResult + { + $this->validateMessages(); + $capability = CapabilityEnum::embeddingGeneration(); + $model = $this->getConfiguredModel($capability); + + $this->dispatchEvent( + new BeforeGenerateResultEvent($this->messages, $model, $capability) + ); + + if (!$model instanceof EmbeddingGenerationModelInterface) { + throw new RuntimeException( + sprintf('Model "%s" does not support embedding generation.', $model->metadata()->getId()) + ); + } + + $result = $model->generateEmbeddingResult($this->messages); + + $this->dispatchEvent( + new AfterGenerateResultEvent($this->messages, $model, $capability, $result) + ); + + return $result; + } + /** * Generates text from the prompt. * @@ -1133,6 +1186,30 @@ public function generateText(): string return $this->generateTextResult()->toText(); } + /** + * Generates the first embedding vector from the prompt. + * + * @since n.e.x.t + * + * @return list The generated embedding vector. + */ + public function generateEmbedding(): array + { + return $this->generateEmbeddingResult()->getEmbedding(); + } + + /** + * Generates embedding vectors from the prompt. + * + * @since n.e.x.t + * + * @return list> The generated embedding vectors. + */ + public function generateEmbeddings(): array + { + return $this->generateEmbeddingResult()->getEmbeddings(); + } + /** * Generates multiple text candidates from the prompt. * diff --git a/src/Events/AfterGenerateResultEvent.php b/src/Events/AfterGenerateResultEvent.php index 321ce088..be9d1c62 100644 --- a/src/Events/AfterGenerateResultEvent.php +++ b/src/Events/AfterGenerateResultEvent.php @@ -7,7 +7,7 @@ use WordPress\AiClient\Messages\DTO\Message; use WordPress\AiClient\Providers\Models\Contracts\ModelInterface; use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; -use WordPress\AiClient\Results\DTO\GenerativeAiResult; +use WordPress\AiClient\Results\Contracts\ResultInterface; /** * Event dispatched after a prompt has been sent to the AI model and a response received. @@ -35,9 +35,9 @@ class AfterGenerateResultEvent private ?CapabilityEnum $capability; /** - * @var GenerativeAiResult The result from the model. + * @var ResultInterface The result from the model. */ - private GenerativeAiResult $result; + private ResultInterface $result; /** * Constructor. @@ -47,13 +47,13 @@ class AfterGenerateResultEvent * @param list $messages The messages that were sent to the model. * @param ModelInterface $model The model that processed the prompt. * @param CapabilityEnum|null $capability The capability that was used for generation. - * @param GenerativeAiResult $result The result from the model. + * @param ResultInterface $result The result from the model. */ public function __construct( array $messages, ModelInterface $model, ?CapabilityEnum $capability, - GenerativeAiResult $result + ResultInterface $result ) { $this->messages = $messages; $this->model = $model; @@ -102,9 +102,9 @@ public function getCapability(): ?CapabilityEnum * * @since 0.4.0 * - * @return GenerativeAiResult The result. + * @return ResultInterface The result. */ - public function getResult(): GenerativeAiResult + public function getResult(): ResultInterface { return $this->result; } diff --git a/src/Providers/Models/DTO/ModelConfig.php b/src/Providers/Models/DTO/ModelConfig.php index 20166520..25520219 100644 --- a/src/Providers/Models/DTO/ModelConfig.php +++ b/src/Providers/Models/DTO/ModelConfig.php @@ -45,6 +45,7 @@ * outputMediaOrientation?: string, * outputMediaAspectRatio?: string, * outputSpeechVoice?: string, + * embeddingDimensions?: int, * customOptions?: array * } * @@ -72,6 +73,7 @@ class ModelConfig extends AbstractDataTransferObject public const KEY_OUTPUT_MEDIA_ORIENTATION = 'outputMediaOrientation'; public const KEY_OUTPUT_MEDIA_ASPECT_RATIO = 'outputMediaAspectRatio'; public const KEY_OUTPUT_SPEECH_VOICE = 'outputSpeechVoice'; + public const KEY_EMBEDDING_DIMENSIONS = 'embeddingDimensions'; public const KEY_CUSTOM_OPTIONS = 'customOptions'; /* @@ -181,6 +183,11 @@ class ModelConfig extends AbstractDataTransferObject */ protected ?string $outputSpeechVoice = null; + /** + * @var int|null Embedding vector dimension count. + */ + protected ?int $embeddingDimensions = null; + /** * @var array Custom provider-specific options. */ @@ -771,6 +778,34 @@ public function getOutputSpeechVoice(): ?string return $this->outputSpeechVoice; } + /** + * Sets the embedding vector dimension count. + * + * @since n.e.x.t + * + * @param int $embeddingDimensions The embedding vector dimension count. + */ + public function setEmbeddingDimensions(int $embeddingDimensions): void + { + if ($embeddingDimensions < 1) { + throw new InvalidArgumentException('Embedding dimensions must be greater than 0.'); + } + + $this->embeddingDimensions = $embeddingDimensions; + } + + /** + * Gets the embedding vector dimension count. + * + * @since n.e.x.t + * + * @return int|null The embedding vector dimension count. + */ + public function getEmbeddingDimensions(): ?int + { + return $this->embeddingDimensions; + } + /** * Sets a single custom option. * @@ -915,6 +950,11 @@ public static function getJsonSchema(): array 'type' => 'string', 'description' => 'Output speech voice.', ], + self::KEY_EMBEDDING_DIMENSIONS => [ + 'type' => 'integer', + 'minimum' => 1, + 'description' => 'Embedding vector dimension count.', + ], self::KEY_CUSTOM_OPTIONS => [ 'type' => 'object', 'additionalProperties' => true, @@ -1026,6 +1066,10 @@ static function (FunctionDeclaration $functionDeclaration): array { $data[self::KEY_OUTPUT_SPEECH_VOICE] = $this->outputSpeechVoice; } + if ($this->embeddingDimensions !== null) { + $data[self::KEY_EMBEDDING_DIMENSIONS] = $this->embeddingDimensions; + } + if (!empty($this->customOptions)) { $data[self::KEY_CUSTOM_OPTIONS] = $this->customOptions; } @@ -1131,6 +1175,10 @@ static function (array $functionDeclarationData): FunctionDeclaration { $config->setOutputSpeechVoice($array[self::KEY_OUTPUT_SPEECH_VOICE]); } + if (isset($array[self::KEY_EMBEDDING_DIMENSIONS])) { + $config->setEmbeddingDimensions($array[self::KEY_EMBEDDING_DIMENSIONS]); + } + if (isset($array[self::KEY_CUSTOM_OPTIONS])) { $config->setCustomOptions($array[self::KEY_CUSTOM_OPTIONS]); } diff --git a/src/Providers/Models/EmbeddingGeneration/Contracts/EmbeddingGenerationModelInterface.php b/src/Providers/Models/EmbeddingGeneration/Contracts/EmbeddingGenerationModelInterface.php new file mode 100644 index 00000000..a8fa24f5 --- /dev/null +++ b/src/Providers/Models/EmbeddingGeneration/Contracts/EmbeddingGenerationModelInterface.php @@ -0,0 +1,26 @@ + $prompt Array of messages containing the embedding prompt. + * @return EmbeddingResult Result containing embedding vectors. + */ + public function generateEmbeddingResult(array $prompt): EmbeddingResult; +} diff --git a/src/Providers/Models/Enums/OptionEnum.php b/src/Providers/Models/Enums/OptionEnum.php index 27b2248f..e29b7cdd 100644 --- a/src/Providers/Models/Enums/OptionEnum.php +++ b/src/Providers/Models/Enums/OptionEnum.php @@ -21,6 +21,7 @@ * Dynamically loaded from ModelConfig KEY_* constants: * @method static self candidateCount() Creates an instance for CANDIDATE_COUNT option. * @method static self customOptions() Creates an instance for CUSTOM_OPTIONS option. + * @method static self embeddingDimensions() Creates an instance for EMBEDDING_DIMENSIONS option. * @method static self frequencyPenalty() Creates an instance for FREQUENCY_PENALTY option. * @method static self functionDeclarations() Creates an instance for FUNCTION_DECLARATIONS option. * @method static self logprobs() Creates an instance for LOGPROBS option. @@ -42,6 +43,7 @@ * @method static self webSearch() Creates an instance for WEB_SEARCH option. * @method bool isCandidateCount() Checks if the option is CANDIDATE_COUNT. * @method bool isCustomOptions() Checks if the option is CUSTOM_OPTIONS. + * @method bool isEmbeddingDimensions() Checks if the option is EMBEDDING_DIMENSIONS. * @method bool isFrequencyPenalty() Checks if the option is FREQUENCY_PENALTY. * @method bool isFunctionDeclarations() Checks if the option is FUNCTION_DECLARATIONS. * @method bool isLogprobs() Checks if the option is LOGPROBS. diff --git a/src/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleEmbeddingGenerationModel.php b/src/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleEmbeddingGenerationModel.php new file mode 100644 index 00000000..98efe58b --- /dev/null +++ b/src/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleEmbeddingGenerationModel.php @@ -0,0 +1,221 @@ +, index?: int} + * @phpstan-type UsageData array{prompt_tokens?: int, total_tokens?: int} + * @phpstan-type ResponseData array{id?: string, data?: list, usage?: UsageData} + */ +abstract class AbstractOpenAiCompatibleEmbeddingGenerationModel extends AbstractApiBasedModel implements + EmbeddingGenerationModelInterface +{ + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + final public function generateEmbeddingResult(array $prompt): EmbeddingResult + { + $params = $this->prepareGenerateEmbeddingParams($prompt); + + $request = $this->createRequest( + HttpMethodEnum::POST(), + 'embeddings', + ['Content-Type' => 'application/json'], + $params + ); + + $request = $this->getRequestAuthentication()->authenticateRequest($request); + $response = $this->getHttpTransporter()->send($request); + $this->throwIfNotSuccessful($response); + + return $this->parseResponseToEmbeddingResult($response); + } + + /** + * Prepares embedding API request parameters. + * + * @since n.e.x.t + * + * @param list $prompt The prompt to generate embeddings for. + * @return array The parameters for the API request. + */ + protected function prepareGenerateEmbeddingParams(array $prompt): array + { + $params = [ + 'model' => $this->metadata()->getId(), + 'input' => $this->prepareInputParam($prompt), + ]; + + $dimensions = $this->getConfig()->getEmbeddingDimensions(); + if ($dimensions !== null) { + $params['dimensions'] = $dimensions; + } + + foreach ($this->getConfig()->getCustomOptions() as $key => $value) { + if (isset($params[$key])) { + throw new InvalidArgumentException( + sprintf('The custom option "%s" conflicts with an existing parameter.', $key) + ); + } + $params[$key] = $value; + } + + return $params; + } + + /** + * Prepares the input parameter for the embedding API request. + * + * @since n.e.x.t + * + * @param list $messages The messages to prepare. + * @return string|list The input parameter. + */ + protected function prepareInputParam(array $messages) + { + $inputs = []; + foreach ($messages as $message) { + if (!$message->getRole()->isUser()) { + throw new InvalidArgumentException('The API requires user messages as embedding input.'); + } + + foreach ($message->getParts() as $part) { + $text = $part->getText(); + if ($text === null) { + throw new InvalidArgumentException('The API requires text message parts as embedding input.'); + } + $inputs[] = $text; + } + } + + if (empty($inputs)) { + throw new InvalidArgumentException('The API requires at least one text input.'); + } + + return count($inputs) === 1 ? $inputs[0] : $inputs; + } + + /** + * Creates a request object for the provider's API. + * + * @since n.e.x.t + * + * @param HttpMethodEnum $method The HTTP method. + * @param string $path The API endpoint path, relative to the base URI. + * @param array> $headers The request headers. + * @param string|array|null $data The request data. + * @return Request The request object. + */ + abstract protected function createRequest( + HttpMethodEnum $method, + string $path, + array $headers = [], + $data = null + ): Request; + + /** + * Throws an exception if the response is not successful. + * + * @since n.e.x.t + * + * @param Response $response The HTTP response to check. + * @throws ResponseException If the response is not successful. + */ + protected function throwIfNotSuccessful(Response $response): void + { + ResponseUtil::throwIfNotSuccessful($response); + } + + /** + * Parses the response from the API endpoint to an embedding result. + * + * @since n.e.x.t + * + * @param Response $response The response from the API endpoint. + * @return EmbeddingResult The parsed embedding result. + */ + protected function parseResponseToEmbeddingResult(Response $response): EmbeddingResult + { + /** @var ResponseData $responseData */ + $responseData = $response->getData(); + if (!isset($responseData['data']) || !$responseData['data']) { + throw ResponseException::fromMissingData($this->providerMetadata()->getName(), 'data'); + } + if (!is_array($responseData['data'])) { + throw ResponseException::fromInvalidData( + $this->providerMetadata()->getName(), + 'data', + 'The value must be an array.' + ); + } + + $embeddings = []; + foreach ($responseData['data'] as $index => $embeddingData) { + if ( + !is_array($embeddingData) || + !isset($embeddingData['embedding']) || + !is_array($embeddingData['embedding']) + ) { + throw ResponseException::fromInvalidData( + $this->providerMetadata()->getName(), + "data[{$index}]", + 'The value must contain an embedding array.' + ); + } + + $embeddings[] = array_map('floatval', $embeddingData['embedding']); + } + + $usage = isset($responseData['usage']) && is_array($responseData['usage']) ? $responseData['usage'] : []; + $tokenUsage = new TokenUsage( + $usage['prompt_tokens'] ?? 0, + 0, + $usage['total_tokens'] ?? ($usage['prompt_tokens'] ?? 0) + ); + + $providerMetadata = $responseData; + unset($providerMetadata['id'], $providerMetadata['data'], $providerMetadata['usage']); + + return new EmbeddingResult( + $this->getResultId($responseData), + $embeddings, + $tokenUsage, + $this->providerMetadata(), + $this->metadata(), + $providerMetadata + ); + } + + /** + * Extracts the result ID from the API response data. + * + * @since n.e.x.t + * + * @param array $responseData The response data from the API. + * @return string The result ID. + */ + protected function getResultId(array $responseData): string + { + return isset($responseData['id']) && is_string($responseData['id']) ? $responseData['id'] : ''; + } +} diff --git a/src/Results/DTO/EmbeddingResult.php b/src/Results/DTO/EmbeddingResult.php new file mode 100644 index 00000000..2837a3b5 --- /dev/null +++ b/src/Results/DTO/EmbeddingResult.php @@ -0,0 +1,275 @@ +>, + * tokenUsage: TokenUsageArrayShape, + * providerMetadata: ProviderMetadataArrayShape, + * modelMetadata: ModelMetadataArrayShape, + * additionalData?: array + * } + * + * @extends AbstractDataTransferObject + */ +class EmbeddingResult extends AbstractDataTransferObject implements ResultInterface +{ + public const KEY_ID = 'id'; + public const KEY_EMBEDDINGS = 'embeddings'; + public const KEY_TOKEN_USAGE = 'tokenUsage'; + public const KEY_PROVIDER_METADATA = 'providerMetadata'; + public const KEY_MODEL_METADATA = 'modelMetadata'; + public const KEY_ADDITIONAL_DATA = 'additionalData'; + + /** + * @var string Unique identifier for this result. + */ + private string $id; + + /** + * @var list> Embedding vectors. + */ + private array $embeddings; + + /** + * @var TokenUsage Token usage statistics. + */ + private TokenUsage $tokenUsage; + + /** + * @var ProviderMetadata Provider metadata. + */ + private ProviderMetadata $providerMetadata; + + /** + * @var ModelMetadata Model metadata. + */ + private ModelMetadata $modelMetadata; + + /** + * @var array Additional data. + */ + private array $additionalData; + + /** + * Constructor. + * + * @since n.e.x.t + * + * @param string $id Unique identifier for this result. + * @param list> $embeddings Embedding vectors. + * @param TokenUsage $tokenUsage Token usage statistics. + * @param ProviderMetadata $providerMetadata Provider metadata. + * @param ModelMetadata $modelMetadata Model metadata. + * @param array $additionalData Additional data. + */ + public function __construct( + string $id, + array $embeddings, + TokenUsage $tokenUsage, + ProviderMetadata $providerMetadata, + ModelMetadata $modelMetadata, + array $additionalData = [] + ) { + if (empty($embeddings)) { + throw new InvalidArgumentException('At least one embedding must be provided'); + } + + $this->id = $id; + $this->embeddings = $embeddings; + $this->tokenUsage = $tokenUsage; + $this->providerMetadata = $providerMetadata; + $this->modelMetadata = $modelMetadata; + $this->additionalData = $additionalData; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public function getId(): string + { + return $this->id; + } + + /** + * Gets the embedding vectors. + * + * @since n.e.x.t + * + * @return list> The embedding vectors. + */ + public function getEmbeddings(): array + { + return $this->embeddings; + } + + /** + * Gets the first embedding vector. + * + * @since n.e.x.t + * + * @return list The first embedding vector. + */ + public function getEmbedding(): array + { + return $this->embeddings[0]; + } + + /** + * Gets the embedding vector dimension. + * + * @since n.e.x.t + * + * @return int The vector dimension. + */ + public function getDimensions(): int + { + return count($this->embeddings[0]); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public function getTokenUsage(): TokenUsage + { + return $this->tokenUsage; + } + + /** + * Gets the provider metadata. + * + * @since n.e.x.t + * + * @return ProviderMetadata The provider metadata. + */ + public function getProviderMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + + /** + * Gets the model metadata. + * + * @since n.e.x.t + * + * @return ModelMetadata The model metadata. + */ + public function getModelMetadata(): ModelMetadata + { + return $this->modelMetadata; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public function getAdditionalData(): array + { + return $this->additionalData; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public static function getJsonSchema(): array + { + return [ + 'type' => 'object', + 'properties' => [ + self::KEY_ID => ['type' => 'string'], + self::KEY_EMBEDDINGS => [ + 'type' => 'array', + 'items' => [ + 'type' => 'array', + 'items' => ['type' => 'number'], + ], + ], + self::KEY_TOKEN_USAGE => TokenUsage::getJsonSchema(), + self::KEY_PROVIDER_METADATA => ProviderMetadata::getJsonSchema(), + self::KEY_MODEL_METADATA => ModelMetadata::getJsonSchema(), + self::KEY_ADDITIONAL_DATA => ['type' => 'object'], + ], + 'required' => [ + self::KEY_ID, + self::KEY_EMBEDDINGS, + self::KEY_TOKEN_USAGE, + self::KEY_PROVIDER_METADATA, + self::KEY_MODEL_METADATA, + ], + ]; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + * + * @return EmbeddingResultArrayShape + */ + public function toArray(): array + { + $data = [ + self::KEY_ID => $this->id, + self::KEY_EMBEDDINGS => $this->embeddings, + self::KEY_TOKEN_USAGE => $this->tokenUsage->toArray(), + self::KEY_PROVIDER_METADATA => $this->providerMetadata->toArray(), + self::KEY_MODEL_METADATA => $this->modelMetadata->toArray(), + ]; + + if ($this->additionalData !== []) { + $data[self::KEY_ADDITIONAL_DATA] = $this->additionalData; + } + + return $data; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public static function fromArray(array $array): self + { + static::validateFromArrayData($array, [ + self::KEY_ID, + self::KEY_EMBEDDINGS, + self::KEY_TOKEN_USAGE, + self::KEY_PROVIDER_METADATA, + self::KEY_MODEL_METADATA, + ]); + + return new self( + $array[self::KEY_ID], + $array[self::KEY_EMBEDDINGS], + TokenUsage::fromArray($array[self::KEY_TOKEN_USAGE]), + ProviderMetadata::fromArray($array[self::KEY_PROVIDER_METADATA]), + ModelMetadata::fromArray($array[self::KEY_MODEL_METADATA]), + $array[self::KEY_ADDITIONAL_DATA] ?? [] + ); + } +} diff --git a/tests/mocks/MockOpenAiCompatibleEmbeddingGenerationModel.php b/tests/mocks/MockOpenAiCompatibleEmbeddingGenerationModel.php new file mode 100644 index 00000000..cd761a4b --- /dev/null +++ b/tests/mocks/MockOpenAiCompatibleEmbeddingGenerationModel.php @@ -0,0 +1,38 @@ + The parameters for the API request. + */ + public function exposePrepareGenerateEmbeddingParams(array $prompt): array + { + return $this->prepareGenerateEmbeddingParams($prompt); + } +} diff --git a/tests/unit/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleEmbeddingGenerationModelTest.php b/tests/unit/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleEmbeddingGenerationModelTest.php new file mode 100644 index 00000000..1218834d --- /dev/null +++ b/tests/unit/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleEmbeddingGenerationModelTest.php @@ -0,0 +1,155 @@ +modelMetadata = $this->createStub(ModelMetadata::class); + $this->modelMetadata->method('getId')->willReturn('test-embedding-model'); + $this->providerMetadata = $this->createStub(ProviderMetadata::class); + $this->providerMetadata->method('getName')->willReturn('TestProvider'); + $this->mockHttpTransporter = $this->createMock(HttpTransporterInterface::class); + $this->mockRequestAuthentication = $this->createMock(RequestAuthenticationInterface::class); + } + + public function testGenerateEmbeddingResultSuccess(): void + { + $response = new Response( + 200, + [], + json_encode([ + 'id' => 'emb-result-123', + 'data' => [ + ['embedding' => [0.1, 0.2, 0.3], 'index' => 0], + ], + 'usage' => [ + 'prompt_tokens' => 4, + 'total_tokens' => 4, + ], + ]) + ); + + $this->mockRequestAuthentication + ->expects($this->once()) + ->method('authenticateRequest') + ->willReturnArgument(0); + + $this->mockHttpTransporter + ->expects($this->once()) + ->method('send') + ->willReturn($response); + + $result = $this->createModel()->generateEmbeddingResult([ + new Message(MessageRoleEnum::user(), [new MessagePart('Hello world')]), + ]); + + $this->assertInstanceOf(EmbeddingResult::class, $result); + $this->assertEquals('emb-result-123', $result->getId()); + $this->assertEquals([[0.1, 0.2, 0.3]], $result->getEmbeddings()); + $this->assertEquals(4, $result->getTokenUsage()->getPromptTokens()); + $this->assertEquals(0, $result->getTokenUsage()->getCompletionTokens()); + $this->assertEquals(4, $result->getTokenUsage()->getTotalTokens()); + } + + public function testPrepareGenerateEmbeddingParamsWithBatchAndDimensions(): void + { + $modelConfig = ModelConfig::fromArray(['embeddingDimensions' => 3]); + $model = $this->createModel($modelConfig); + + $params = $model->exposePrepareGenerateEmbeddingParams([ + new Message(MessageRoleEnum::user(), [new MessagePart('First'), new MessagePart('Second')]), + ]); + + $this->assertEquals('test-embedding-model', $params['model']); + $this->assertEquals(['First', 'Second'], $params['input']); + $this->assertEquals(3, $params['dimensions']); + } + + public function testPrepareGenerateEmbeddingParamsRejectsNonTextParts(): void + { + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('The API requires text message parts as embedding input.'); + + $this->createModel()->exposePrepareGenerateEmbeddingParams([ + new Message(MessageRoleEnum::user(), [new MessagePart(new File('https://example.com/image.png'))]), + ]); + } + + public function testParseResponseWithoutDataThrows(): void + { + $this->mockRequestAuthentication + ->expects($this->once()) + ->method('authenticateRequest') + ->willReturnArgument(0); + + $this->mockHttpTransporter + ->expects($this->once()) + ->method('send') + ->willReturn(new Response(200, [], json_encode(['id' => 'missing-data']))); + + $this->expectException(ResponseException::class); + + $this->createModel()->generateEmbeddingResult([ + new Message(MessageRoleEnum::user(), [new MessagePart('Hello world')]), + ]); + } + + private function createModel(?ModelConfig $modelConfig = null): MockOpenAiCompatibleEmbeddingGenerationModel + { + $model = new MockOpenAiCompatibleEmbeddingGenerationModel( + $this->modelMetadata, + $this->providerMetadata + ); + $model->setHttpTransporter($this->mockHttpTransporter); + $model->setRequestAuthentication($this->mockRequestAuthentication); + if ($modelConfig) { + $model->setConfig($modelConfig); + } + return $model; + } +} diff --git a/tests/unit/Results/DTO/EmbeddingResultTest.php b/tests/unit/Results/DTO/EmbeddingResultTest.php new file mode 100644 index 00000000..89ea6a0c --- /dev/null +++ b/tests/unit/Results/DTO/EmbeddingResultTest.php @@ -0,0 +1,84 @@ +createProviderMetadata(), + $this->createModelMetadata() + ); + + $this->assertInstanceOf(ResultInterface::class, $result); + $this->assertEquals('emb-123', $result->getId()); + $this->assertEquals([[0.1, 0.2, 0.3]], $result->getEmbeddings()); + $this->assertEquals([0.1, 0.2, 0.3], $result->getEmbedding()); + $this->assertEquals(3, $result->getDimensions()); + $this->assertEquals(5, $result->getTokenUsage()->getPromptTokens()); + } + + public function testCreateWithoutEmbeddingsThrows(): void + { + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('At least one embedding must be provided'); + + new EmbeddingResult( + 'emb-empty', + [], + new TokenUsage(0, 0, 0), + $this->createProviderMetadata(), + $this->createModelMetadata() + ); + } + + public function testArrayRoundTrip(): void + { + $result = new EmbeddingResult( + 'emb-456', + [[1.0, 2.0], [3.0, 4.0]], + new TokenUsage(7, 0, 7), + $this->createProviderMetadata(), + $this->createModelMetadata(), + ['object' => 'list'] + ); + + $restored = EmbeddingResult::fromArray($result->toArray()); + + $this->assertEquals($result->toArray(), $restored->toArray()); + } + + private function createProviderMetadata(): ProviderMetadata + { + return new ProviderMetadata('test-provider', 'Test Provider', ProviderTypeEnum::cloud()); + } + + private function createModelMetadata(): ModelMetadata + { + return new ModelMetadata( + 'test-embedding-model', + 'Test Embedding Model', + [CapabilityEnum::embeddingGeneration()], + [] + ); + } +}