Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
<?php

declare(strict_types=1);

namespace WordPress\AiClient\Providers\Contracts;

use WordPress\AiClient\Providers\Http\Contracts\RequestAuthenticationInterface;

/**
* Interface for providers that supply their own request authentication.
*
* @since n.e.x.t
*/
interface ProviderWithRequestAuthenticationInterface
{
/**
* Gets the request authentication instance for the provider.
*
* @since n.e.x.t
*
* @return RequestAuthenticationInterface|null The request authentication instance, or null if not configured.
*/
public static function requestAuthentication(): ?RequestAuthenticationInterface;
}
47 changes: 27 additions & 20 deletions src/Providers/ProviderRegistry.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
use WordPress\AiClient\Common\Exception\RuntimeException;
use WordPress\AiClient\Providers\Contracts\ProviderInterface;
use WordPress\AiClient\Providers\Contracts\ProviderWithOperationsHandlerInterface;
use WordPress\AiClient\Providers\Contracts\ProviderWithRequestAuthenticationInterface;
use WordPress\AiClient\Providers\DTO\ProviderMetadata;
use WordPress\AiClient\Providers\DTO\ProviderModelsMetadata;
use WordPress\AiClient\Providers\Http\Contracts\HttpTransporterInterface;
Expand Down Expand Up @@ -464,27 +465,29 @@ private function setRequestAuthenticationForProvider(
string $className,
RequestAuthenticationInterface $requestAuthentication
): void {
$authenticationMethod = $className::metadata()->getAuthenticationMethod();
if ($authenticationMethod === null) {
throw new InvalidArgumentException(
sprintf(
'Provider %s does not expect any authentication, but got %s.',
$className,
get_class($requestAuthentication)
)
);
}
if (!is_subclass_of($className, ProviderWithRequestAuthenticationInterface::class)) {
$authenticationMethod = $className::metadata()->getAuthenticationMethod();
if ($authenticationMethod === null) {
throw new InvalidArgumentException(
sprintf(
'Provider %s does not expect any authentication, but got %s.',
$className,
get_class($requestAuthentication)
)
);
}

$expectedClass = $authenticationMethod->getImplementationClass();
if (!$requestAuthentication instanceof $expectedClass) {
throw new InvalidArgumentException(
sprintf(
'Provider %s expects authentication of type %s, but got %s.',
$className,
$expectedClass,
get_class($requestAuthentication)
)
);
$expectedClass = $authenticationMethod->getImplementationClass();
if (!$requestAuthentication instanceof $expectedClass) {
throw new InvalidArgumentException(
sprintf(
'Provider %s expects authentication of type %s, but got %s.',
$className,
$expectedClass,
get_class($requestAuthentication)
)
);
}
}

$availability = $className::availability();
Expand Down Expand Up @@ -521,6 +524,10 @@ private function createDefaultProviderRequestAuthentication(
$providerId = $providerMetadata->getId();
$authenticationMethod = $providerMetadata->getAuthenticationMethod();

if (is_subclass_of($className, ProviderWithRequestAuthenticationInterface::class)) {
return $className::requestAuthentication();
}

if ($authenticationMethod === null) {
return null;
}
Expand Down
21 changes: 21 additions & 0 deletions tests/mocks/MockCustomAuthModel.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
<?php

declare(strict_types=1);

namespace WordPress\AiClient\Tests\mocks;

use WordPress\AiClient\Providers\DTO\ProviderMetadata;

/**
* Mock model for the custom authentication provider.
*/
class MockCustomAuthModel extends MockModel
{
/**
* {@inheritDoc}
*/
public function providerMetadata(): ProviderMetadata
{
return MockCustomAuthProvider::metadata();
}
}
73 changes: 73 additions & 0 deletions tests/mocks/MockCustomAuthProvider.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
<?php

declare(strict_types=1);

namespace WordPress\AiClient\Tests\mocks;

use WordPress\AiClient\Providers\Contracts\ProviderWithRequestAuthenticationInterface;
use WordPress\AiClient\Providers\DTO\ProviderMetadata;
use WordPress\AiClient\Providers\Enums\ProviderTypeEnum;
use WordPress\AiClient\Providers\Http\Contracts\RequestAuthenticationInterface;
use WordPress\AiClient\Providers\Models\Contracts\ModelInterface;
use WordPress\AiClient\Providers\Models\DTO\ModelConfig;

/**
* Mock provider with custom request authentication for testing purposes.
*/
class MockCustomAuthProvider extends MockProvider implements ProviderWithRequestAuthenticationInterface
{
/**
* @var RequestAuthenticationInterface|null Custom request authentication instance.
*/
private static ?RequestAuthenticationInterface $requestAuthentication = null;

/**
* {@inheritDoc}
*/
public static function metadata(): ProviderMetadata
{
return new ProviderMetadata(
'mock-custom-auth',
'Mock Custom Auth Provider',
ProviderTypeEnum::cloud()
);
}

/**
* {@inheritDoc}
*/
public static function requestAuthentication(): ?RequestAuthenticationInterface
{
return static::$requestAuthentication;
}

/**
* {@inheritDoc}
*/
public static function model(string $modelId, ?ModelConfig $modelConfig = null): ModelInterface
{
$modelMetadata = static::modelMetadataDirectory()->getModelMetadata($modelId);
$config = $modelConfig ?? new ModelConfig();

return new MockCustomAuthModel($modelMetadata, $config);
}

/**
* Sets the request authentication for testing.
*
* @param RequestAuthenticationInterface|null $requestAuthentication The request authentication instance.
*/
public static function setRequestAuthentication(?RequestAuthenticationInterface $requestAuthentication): void
{
static::$requestAuthentication = $requestAuthentication;
}

/**
* Resets static state for testing.
*/
public static function reset(): void
{
parent::reset();
static::$requestAuthentication = null;
}
}
30 changes: 30 additions & 0 deletions tests/mocks/MockCustomRequestAuthentication.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
<?php

declare(strict_types=1);

namespace WordPress\AiClient\Tests\mocks;

use WordPress\AiClient\Providers\Http\Contracts\RequestAuthenticationInterface;
use WordPress\AiClient\Providers\Http\DTO\Request;

/**
* Mock custom request authentication for testing purposes.
*/
class MockCustomRequestAuthentication implements RequestAuthenticationInterface
{
/**
* {@inheritDoc}
*/
public function authenticateRequest(Request $request): Request
{
return $request->withHeader('X-Mock-Auth', 'custom');
}

/**
* {@inheritDoc}
*/
public static function getJsonSchema(): array
{
return [];
}
}
75 changes: 72 additions & 3 deletions tests/unit/Providers/ProviderRegistryTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
use WordPress\AiClient\Providers\Models\DTO\ModelRequirements;
use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum;
use WordPress\AiClient\Providers\ProviderRegistry;
use WordPress\AiClient\Tests\mocks\MockCustomAuthProvider;
use WordPress\AiClient\Tests\mocks\MockCustomRequestAuthentication;
use WordPress\AiClient\Tests\mocks\MockHttpTransporter;
use WordPress\AiClient\Tests\mocks\MockModel;
use WordPress\AiClient\Tests\mocks\MockModelMetadataDirectory;
Expand All @@ -32,12 +34,14 @@ protected function setUp(): void
{
parent::setUp();
$this->registry = new ProviderRegistry();
MockCustomAuthProvider::reset();
MockProvider::reset(); // Reset static state of mock provider before each test.
}

protected function tearDown(): void
{
MockProvider::reset(); // Reset static state of mock provider after each test.
MockCustomAuthProvider::reset();
parent::tearDown();
}

Expand Down Expand Up @@ -376,6 +380,65 @@ public function testGetProviderRequestAuthenticationReturnsDefault(): void
$this->assertNull($retrievedAuth);
}

/**
* Tests that provider-supplied request authentication is used when available.
*
* @return void
*/
public function testRegisterProviderUsesProviderSuppliedRequestAuthentication(): void
{
$requestAuthentication = new MockCustomRequestAuthentication();
MockCustomAuthProvider::setRequestAuthentication($requestAuthentication);

$this->registry->registerProvider(MockCustomAuthProvider::class);

$this->assertSame(
$requestAuthentication,
$this->registry->getProviderRequestAuthentication('mock-custom-auth')
);
}

/**
* Tests that provider-supplied request authentication is bound to models.
*
* @return void
*/
public function testRegisterProviderBindsProviderSuppliedRequestAuthenticationToModels(): void
{
$requestAuthentication = new MockCustomRequestAuthentication();
MockCustomAuthProvider::setRequestAuthentication($requestAuthentication);

$this->registry->registerProvider(MockCustomAuthProvider::class);

$model = $this->registry->getProviderModel('mock-custom-auth', 'mock-text-model');

$this->assertInstanceOf(MockModel::class, $model);
$this->assertSame($requestAuthentication, $model->getRequestAuthentication());
}

/**
* Tests that explicit request authentication overrides provider-supplied authentication.
*
* @return void
*/
public function testSetProviderRequestAuthenticationOverridesProviderSuppliedRequestAuthentication(): void
{
MockCustomAuthProvider::setRequestAuthentication(new MockCustomRequestAuthentication());

$this->registry->registerProvider(MockCustomAuthProvider::class);

$requestAuthentication = new MockCustomRequestAuthentication();
$this->registry->setProviderRequestAuthentication('mock-custom-auth', $requestAuthentication);

$model = $this->registry->getProviderModel('mock-custom-auth', 'mock-text-model');

$this->assertSame(
$requestAuthentication,
$this->registry->getProviderRequestAuthentication('mock-custom-auth')
);
$this->assertSame($requestAuthentication, $model->getRequestAuthentication());
}

/**
* Tests the internal getEnvVarName method using reflection.
*
Expand All @@ -388,7 +451,9 @@ public function testGetProviderRequestAuthenticationReturnsDefault(): void
public function testGetEnvVarName(string $providerId, string $field, string $expected): void
{
$method = new \ReflectionMethod(ProviderRegistry::class, 'getEnvVarName');
$method->setAccessible(true);
if (PHP_VERSION_ID < 80100) {
$method->setAccessible(true);
}

$result = $method->invoke($this->registry, $providerId, $field); // Invoke on instance

Expand Down Expand Up @@ -424,7 +489,9 @@ public function testCreateDefaultProviderRequestAuthenticationWithEnvVar(): void
$this->registry->registerProvider(MockProvider::class);

$method = new \ReflectionMethod(ProviderRegistry::class, 'createDefaultProviderRequestAuthentication');
$method->setAccessible(true);
if (PHP_VERSION_ID < 80100) {
$method->setAccessible(true);
}

$auth = $method->invoke($this->registry, MockProvider::class);

Expand All @@ -448,7 +515,9 @@ public function testCreateDefaultProviderRequestAuthenticationWithoutEnvVar(): v
$this->registry->registerProvider(MockProvider::class);

$method = new \ReflectionMethod(ProviderRegistry::class, 'createDefaultProviderRequestAuthentication');
$method->setAccessible(true);
if (PHP_VERSION_ID < 80100) {
$method->setAccessible(true);
}

$auth = $method->invoke($this->registry, MockProvider::class);

Expand Down
Loading