Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions src/AiClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
use WordPress\AiClient\Providers\Models\Contracts\ModelInterface;
use WordPress\AiClient\Providers\Models\DTO\ModelConfig;
use WordPress\AiClient\Providers\ProviderRegistry;
use WordPress\AiClient\Results\DTO\EmbeddingResult;
use WordPress\AiClient\Results\DTO\GenerativeAiResult;

/**
Expand Down Expand Up @@ -386,6 +387,71 @@ public static function generateVideoResult(
return self::getConfiguredPromptBuilder($prompt, $modelOrConfig, $registry)->generateVideoResult();
}

/**
* Generates embeddings using the traditional API approach.
*
* @since 1.4.0
*
* @param Prompt $prompt The prompt content.
* @param ModelInterface|ModelConfig|null $modelOrConfig Optional specific model to use,
* or model configuration for auto-discovery,
* or null for defaults.
* @param ProviderRegistry|null $registry Optional custom registry. If null, uses default.
* @return EmbeddingResult The embedding result.
*
* @throws \InvalidArgumentException If the prompt format is invalid.
* @throws \RuntimeException If no suitable model is found.
*/
public static function generateEmbeddingResult(
$prompt,
$modelOrConfig = null,
?ProviderRegistry $registry = null
): EmbeddingResult {
self::validateModelOrConfigParameter($modelOrConfig);
return self::getConfiguredPromptBuilder($prompt, $modelOrConfig, $registry)->generateEmbeddingResult();
}

/**
* Generates an embedding using the traditional API approach.
*
* @since 1.4.0
*
* @param Prompt $prompt The prompt content.
* @param ModelInterface|ModelConfig|null $modelOrConfig Optional specific model to use,
* or model configuration for auto-discovery,
* or null for defaults.
* @param ProviderRegistry|null $registry Optional custom registry. If null, uses default.
* @return list<float|int> The generated embedding vector.
*/
public static function generateEmbedding(
$prompt,
$modelOrConfig = null,
?ProviderRegistry $registry = null
): array {
return self::generateEmbeddingResult($prompt, $modelOrConfig, $registry)->getEmbedding();
}

/**
* Generates embeddings for a list of prompts using the traditional API approach.
*
* @since 1.4.0
*
* @param list<Prompt> $prompts The prompts to embed.
* @param ModelInterface|ModelConfig|null $modelOrConfig Optional specific model to use,
* or model configuration for auto-discovery,
* or null for defaults.
* @param ProviderRegistry|null $registry Optional custom registry. If null, uses default.
* @return list<list<float|int>> The generated embedding vectors.
*/
public static function generateEmbeddings(
array $prompts,
$modelOrConfig = null,
?ProviderRegistry $registry = null
): array {
self::validateModelOrConfigParameter($modelOrConfig);
return self::getConfiguredPromptBuilder(null, $modelOrConfig, $registry)->generateEmbeddings($prompts);
}

/**
* Creates a new message builder for fluent API usage.
*
Expand Down
144 changes: 143 additions & 1 deletion src/Builders/PromptBuilder.php
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
use WordPress\AiClient\Providers\Models\DTO\ModelConfig;
use WordPress\AiClient\Providers\Models\DTO\ModelMetadata;
use WordPress\AiClient\Providers\Models\DTO\ModelRequirements;
use WordPress\AiClient\Providers\Models\EmbeddingGeneration\Contracts\EmbeddingGenerationModelInterface;
use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum;
use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface;
use WordPress\AiClient\Providers\Models\SpeechGeneration\Contracts\SpeechGenerationModelInterface;
use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\TextGenerationModelInterface;
use WordPress\AiClient\Providers\Models\TextToSpeechConversion\Contracts\TextToSpeechConversionModelInterface;
use WordPress\AiClient\Providers\Models\VideoGeneration\Contracts\VideoGenerationModelInterface;
use WordPress\AiClient\Providers\ProviderRegistry;
use WordPress\AiClient\Results\DTO\EmbeddingResult;
use WordPress\AiClient\Results\DTO\GenerativeAiResult;
use WordPress\AiClient\Tools\DTO\FunctionDeclaration;
use WordPress\AiClient\Tools\DTO\FunctionResponse;
Expand Down Expand Up @@ -479,6 +481,34 @@ public function usingCandidateCount(int $candidateCount): self
return $this;
}

/**
* Sets the embedding dimensions.
*
* @since 1.4.0
*
* @param int $dimensions The embedding dimensions.
* @return self
*/
public function usingDimensions(int $dimensions): self
{
$this->modelConfig->setDimensions($dimensions);
return $this;
}

/**
* Sets the embedding encoding format.
*
* @since 1.4.0
*
* @param string $encodingFormat The embedding encoding format.
* @return self
*/
public function usingEncodingFormat(string $encodingFormat): self
{
$this->modelConfig->setEncodingFormat($encodingFormat);
return $this;
}

/**
* Sets the function declarations available to the model.
*
Expand Down Expand Up @@ -761,7 +791,6 @@ private function inferCapabilityFromModelInterfaces(ModelInterface $model): ?Cap
if ($model instanceof VideoGenerationModelInterface) {
return CapabilityEnum::videoGeneration();
}

// No supported interface found
return null;
}
Expand Down Expand Up @@ -1120,6 +1149,36 @@ public function generateVideoResult(): GenerativeAiResult
return $this->generateResult(CapabilityEnum::videoGeneration());
}

/**
* Generates an embedding result from the prompt.
*
* @since 1.4.0
*
* @return EmbeddingResult The generated embedding result.
* @throws InvalidArgumentException If the prompt or model validation fails.
* @throws RuntimeException If the model doesn't support embedding generation.
*/
public function generateEmbeddingResult(): EmbeddingResult
{
$this->validateMessages();

$capability = CapabilityEnum::embeddingGeneration();
$model = $this->getConfiguredModel($capability);

if (!$model instanceof EmbeddingGenerationModelInterface) {
throw new RuntimeException(
sprintf(
'Model "%s" does not support embedding generation.',
$model->metadata()->getId()
)
);
}

$this->dispatchEvent(new BeforeGenerateResultEvent($this->messages, $model, $capability));

return $model->generateEmbeddingResult([$this->messages]);
}

/**
* Generates text from the prompt.
*
Expand Down Expand Up @@ -1166,6 +1225,62 @@ public function generateImage(): File
return $this->generateImageResult()->toFile();
}

/**
* Generates an embedding from the prompt.
*
* @since 1.4.0
*
* @return list<float|int> The generated embedding vector.
* @throws InvalidArgumentException If the prompt or model validation fails.
*/
public function generateEmbedding(): array
{
return $this->generateEmbeddingResult()->getEmbedding();
}

/**
* Generates embeddings from the prompt or from the provided prompt list.
*
* @since 1.4.0
*
* @param list<Prompt>|null $prompts Optional prompts to embed as a batch.
* @return list<list<float|int>> The generated embedding vectors.
* @throws InvalidArgumentException If a prompt or model validation fails.
*/
public function generateEmbeddings(?array $prompts = null): array
{
if ($prompts === null) {
return $this->generateEmbeddingResult()->getEmbeddings();
}

if (!array_is_list($prompts)) {
throw new InvalidArgumentException('Prompts must be a list array.');
}

if (empty($prompts)) {
throw new InvalidArgumentException('Cannot generate embeddings from an empty prompt list.');
}

$promptMessages = [];
foreach ($prompts as $prompt) {
$promptMessages[] = $this->parsePromptToMessages($prompt);
}

$capability = CapabilityEnum::embeddingGeneration();
$model = $this->getConfiguredModel($capability);

if (!$model instanceof EmbeddingGenerationModelInterface) {
throw new RuntimeException(
sprintf(
'Model "%s" does not support embedding generation.',
$model->metadata()->getId()
)
);
}

return $model->generateEmbeddingResult($promptMessages)->getEmbeddings();
}

/**
* Generates multiple images from the prompt.
*
Expand Down Expand Up @@ -1595,6 +1710,33 @@ private function parseMessage($input, MessageRoleEnum $defaultRole): Message
return new Message($defaultRole, $parts);
}

/**
* Parses prompt input into a message list.
*
* @since 1.4.0
*
* @param Prompt $prompt The prompt to parse.
* @return list<Message> The parsed messages.
*/
private function parsePromptToMessages($prompt): array
{
if ($this->isMessagesList($prompt)) {
$messages = $prompt;
} else {
$messages = [$this->parseMessage($prompt, MessageRoleEnum::user())];
}

$originalMessages = $this->messages;
try {
$this->messages = $messages;
$this->validateMessages();
} finally {
$this->messages = $originalMessages;
}

return $messages;
}

/**
* Validates the messages array for prompt generation.
*
Expand Down
Loading
Loading