Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions src/Builders/PromptBuilder.php
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
use WordPress\AiClient\Providers\Models\DTO\ModelConfig;
use WordPress\AiClient\Providers\Models\DTO\ModelMetadata;
use WordPress\AiClient\Providers\Models\DTO\ModelRequirements;
use WordPress\AiClient\Providers\Models\EmbeddingGeneration\Contracts\EmbeddingGenerationModelInterface;
use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum;
use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface;
use WordPress\AiClient\Providers\Models\SpeechGeneration\Contracts\SpeechGenerationModelInterface;
use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\TextGenerationModelInterface;
use WordPress\AiClient\Providers\Models\TextToSpeechConversion\Contracts\TextToSpeechConversionModelInterface;
use WordPress\AiClient\Providers\Models\VideoGeneration\Contracts\VideoGenerationModelInterface;
use WordPress\AiClient\Providers\ProviderRegistry;
use WordPress\AiClient\Results\DTO\EmbeddingResult;
use WordPress\AiClient\Results\DTO\GenerativeAiResult;
use WordPress\AiClient\Tools\DTO\FunctionDeclaration;
use WordPress\AiClient\Tools\DTO\FunctionResponse;
Expand Down Expand Up @@ -674,6 +676,20 @@ public function asOutputSpeechVoice(string $voice): self
return $this;
}

/**
* Sets the embedding vector dimension count.
*
* @since n.e.x.t
*
* @param int $dimensions The embedding vector dimension count.
* @return self
*/
public function usingEmbeddingDimensions(int $dimensions): self
{
$this->modelConfig->setEmbeddingDimensions($dimensions);
return $this;
}

/**
* Configures the prompt for JSON response output.
*
Expand Down Expand Up @@ -761,6 +777,9 @@ private function inferCapabilityFromModelInterfaces(ModelInterface $model): ?Cap
if ($model instanceof VideoGenerationModelInterface) {
return CapabilityEnum::videoGeneration();
}
if ($model instanceof EmbeddingGenerationModelInterface) {
return CapabilityEnum::embeddingGeneration();
}

// No supported interface found
return null;
Expand Down Expand Up @@ -1120,6 +1139,40 @@ public function generateVideoResult(): GenerativeAiResult
return $this->generateResult(CapabilityEnum::videoGeneration());
}

/**
* Generates an embedding result from the prompt.
*
* @since n.e.x.t
*
* @return EmbeddingResult The generated result containing embedding vectors.
* @throws InvalidArgumentException If the prompt or model validation fails.
* @throws RuntimeException If the model doesn't support embedding generation.
*/
public function generateEmbeddingResult(): EmbeddingResult
{
$this->validateMessages();
$capability = CapabilityEnum::embeddingGeneration();
$model = $this->getConfiguredModel($capability);

$this->dispatchEvent(
new BeforeGenerateResultEvent($this->messages, $model, $capability)
);

if (!$model instanceof EmbeddingGenerationModelInterface) {
throw new RuntimeException(
sprintf('Model "%s" does not support embedding generation.', $model->metadata()->getId())
);
}

$result = $model->generateEmbeddingResult($this->messages);

$this->dispatchEvent(
new AfterGenerateResultEvent($this->messages, $model, $capability, $result)
);

return $result;
}

/**
* Generates text from the prompt.
*
Expand All @@ -1133,6 +1186,30 @@ public function generateText(): string
return $this->generateTextResult()->toText();
}

/**
* Generates the first embedding vector from the prompt.
*
* @since n.e.x.t
*
* @return list<float> The generated embedding vector.
*/
public function generateEmbedding(): array
{
return $this->generateEmbeddingResult()->getEmbedding();
}

/**
* Generates embedding vectors from the prompt.
*
* @since n.e.x.t
*
* @return list<list<float>> The generated embedding vectors.
*/
public function generateEmbeddings(): array
{
return $this->generateEmbeddingResult()->getEmbeddings();
}

/**
* Generates multiple text candidates from the prompt.
*
Expand Down
14 changes: 7 additions & 7 deletions src/Events/AfterGenerateResultEvent.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use WordPress\AiClient\Messages\DTO\Message;
use WordPress\AiClient\Providers\Models\Contracts\ModelInterface;
use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum;
use WordPress\AiClient\Results\DTO\GenerativeAiResult;
use WordPress\AiClient\Results\Contracts\ResultInterface;

/**
* Event dispatched after a prompt has been sent to the AI model and a response received.
Expand Down Expand Up @@ -35,9 +35,9 @@ class AfterGenerateResultEvent
private ?CapabilityEnum $capability;

/**
* @var GenerativeAiResult The result from the model.
* @var ResultInterface The result from the model.
*/
private GenerativeAiResult $result;
private ResultInterface $result;

/**
* Constructor.
Expand All @@ -47,13 +47,13 @@ class AfterGenerateResultEvent
* @param list<Message> $messages The messages that were sent to the model.
* @param ModelInterface $model The model that processed the prompt.
* @param CapabilityEnum|null $capability The capability that was used for generation.
* @param GenerativeAiResult $result The result from the model.
* @param ResultInterface $result The result from the model.
*/
public function __construct(
array $messages,
ModelInterface $model,
?CapabilityEnum $capability,
GenerativeAiResult $result
ResultInterface $result
) {
$this->messages = $messages;
$this->model = $model;
Expand Down Expand Up @@ -102,9 +102,9 @@ public function getCapability(): ?CapabilityEnum
*
* @since 0.4.0
*
* @return GenerativeAiResult The result.
* @return ResultInterface The result.
*/
public function getResult(): GenerativeAiResult
public function getResult(): ResultInterface
{
return $this->result;
}
Expand Down
48 changes: 48 additions & 0 deletions src/Providers/Models/DTO/ModelConfig.php
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
* outputMediaOrientation?: string,
* outputMediaAspectRatio?: string,
* outputSpeechVoice?: string,
* embeddingDimensions?: int,
* customOptions?: array<string, mixed>
* }
*
Expand Down Expand Up @@ -72,6 +73,7 @@ class ModelConfig extends AbstractDataTransferObject
public const KEY_OUTPUT_MEDIA_ORIENTATION = 'outputMediaOrientation';
public const KEY_OUTPUT_MEDIA_ASPECT_RATIO = 'outputMediaAspectRatio';
public const KEY_OUTPUT_SPEECH_VOICE = 'outputSpeechVoice';
public const KEY_EMBEDDING_DIMENSIONS = 'embeddingDimensions';
public const KEY_CUSTOM_OPTIONS = 'customOptions';

/*
Expand Down Expand Up @@ -181,6 +183,11 @@ class ModelConfig extends AbstractDataTransferObject
*/
protected ?string $outputSpeechVoice = null;

/**
* @var int|null Embedding vector dimension count.
*/
protected ?int $embeddingDimensions = null;

/**
* @var array<string, mixed> Custom provider-specific options.
*/
Expand Down Expand Up @@ -771,6 +778,34 @@ public function getOutputSpeechVoice(): ?string
return $this->outputSpeechVoice;
}

/**
* Sets the embedding vector dimension count.
*
* @since n.e.x.t
*
* @param int $embeddingDimensions The embedding vector dimension count.
*/
public function setEmbeddingDimensions(int $embeddingDimensions): void
{
if ($embeddingDimensions < 1) {
throw new InvalidArgumentException('Embedding dimensions must be greater than 0.');
}

$this->embeddingDimensions = $embeddingDimensions;
}

/**
* Gets the embedding vector dimension count.
*
* @since n.e.x.t
*
* @return int|null The embedding vector dimension count.
*/
public function getEmbeddingDimensions(): ?int
{
return $this->embeddingDimensions;
}

/**
* Sets a single custom option.
*
Expand Down Expand Up @@ -915,6 +950,11 @@ public static function getJsonSchema(): array
'type' => 'string',
'description' => 'Output speech voice.',
],
self::KEY_EMBEDDING_DIMENSIONS => [
'type' => 'integer',
'minimum' => 1,
'description' => 'Embedding vector dimension count.',
],
self::KEY_CUSTOM_OPTIONS => [
'type' => 'object',
'additionalProperties' => true,
Expand Down Expand Up @@ -1026,6 +1066,10 @@ static function (FunctionDeclaration $functionDeclaration): array {
$data[self::KEY_OUTPUT_SPEECH_VOICE] = $this->outputSpeechVoice;
}

if ($this->embeddingDimensions !== null) {
$data[self::KEY_EMBEDDING_DIMENSIONS] = $this->embeddingDimensions;
}

if (!empty($this->customOptions)) {
$data[self::KEY_CUSTOM_OPTIONS] = $this->customOptions;
}
Expand Down Expand Up @@ -1131,6 +1175,10 @@ static function (array $functionDeclarationData): FunctionDeclaration {
$config->setOutputSpeechVoice($array[self::KEY_OUTPUT_SPEECH_VOICE]);
}

if (isset($array[self::KEY_EMBEDDING_DIMENSIONS])) {
$config->setEmbeddingDimensions($array[self::KEY_EMBEDDING_DIMENSIONS]);
}

if (isset($array[self::KEY_CUSTOM_OPTIONS])) {
$config->setCustomOptions($array[self::KEY_CUSTOM_OPTIONS]);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
<?php

declare(strict_types=1);

namespace WordPress\AiClient\Providers\Models\EmbeddingGeneration\Contracts;

use WordPress\AiClient\Messages\DTO\Message;
use WordPress\AiClient\Results\DTO\EmbeddingResult;

/**
* Interface for models that support embedding generation.
*
* @since n.e.x.t
*/
interface EmbeddingGenerationModelInterface
{
/**
* Generates embeddings from a prompt.
*
* @since n.e.x.t
*
* @param list<Message> $prompt Array of messages containing the embedding prompt.
* @return EmbeddingResult Result containing embedding vectors.
*/
public function generateEmbeddingResult(array $prompt): EmbeddingResult;
}
2 changes: 2 additions & 0 deletions src/Providers/Models/Enums/OptionEnum.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* Dynamically loaded from ModelConfig KEY_* constants:
* @method static self candidateCount() Creates an instance for CANDIDATE_COUNT option.
* @method static self customOptions() Creates an instance for CUSTOM_OPTIONS option.
* @method static self embeddingDimensions() Creates an instance for EMBEDDING_DIMENSIONS option.
* @method static self frequencyPenalty() Creates an instance for FREQUENCY_PENALTY option.
* @method static self functionDeclarations() Creates an instance for FUNCTION_DECLARATIONS option.
* @method static self logprobs() Creates an instance for LOGPROBS option.
Expand All @@ -42,6 +43,7 @@
* @method static self webSearch() Creates an instance for WEB_SEARCH option.
* @method bool isCandidateCount() Checks if the option is CANDIDATE_COUNT.
* @method bool isCustomOptions() Checks if the option is CUSTOM_OPTIONS.
* @method bool isEmbeddingDimensions() Checks if the option is EMBEDDING_DIMENSIONS.
* @method bool isFrequencyPenalty() Checks if the option is FREQUENCY_PENALTY.
* @method bool isFunctionDeclarations() Checks if the option is FUNCTION_DECLARATIONS.
* @method bool isLogprobs() Checks if the option is LOGPROBS.
Expand Down
Loading
Loading