<?php

namespace Drupal\ai_provider_cohere\Plugin\AiProvider;

use Drupal\ai\Attribute\AiProvider;
use Drupal\ai\Base\AiProviderClientBase;
use Drupal\ai_provider_cohere\CohereClient;
use Drupal\ai\OperationType\AiRerank\AiReRankInput;
use Drupal\ai\OperationType\AiRerank\AiReRankInterface;
use Drupal\ai\OperationType\AiRerank\AiReRankOutput;
use Drupal\Core\Config\ImmutableConfig;
use Drupal\Core\Plugin\ContainerFactoryPluginInterface;
use Drupal\Core\StringTranslation\TranslatableMarkup;
use Symfony\Component\Yaml\Yaml;

/**
 * Plugin implementation of the 'cohere' provider.
 *
 * Exposes Cohere through the AI module's rerank operation type
 * (AiReRankInterface). Every request builds a CohereClient from the injected
 * HTTP client and the key returned by the inherited loadApiKey().
 */
#[AiProvider(
  id: 'cohere',
  label: new TranslatableMarkup('Cohere'),
)]
class CohereProvider extends AiProviderClientBase implements
  ContainerFactoryPluginInterface,
  AiReRankInterface {

  /**
   * The Cohere Client.
   *
   * Initialized to NULL so that reading it before assignment does not throw
   * an "accessed before initialization" error on the typed property.
   *
   * @var \Drupal\ai_provider_cohere\CohereClient|null
   */
  protected ?CohereClient $client = NULL;

  /**
   * {@inheritdoc}
   *
   * Returns the Cohere model names keyed by themselves. When an operation
   * type is given, only models whose 'endpoints' list contains the matching
   * Cohere endpoint are returned; otherwise every listed model is returned.
   */
  public function getConfiguredModels(
    ?string $operation_type = NULL,
    array $capabilities = [],
  ): array {
    $response = $this->buildCohereClient()->listModels();
    // Tolerate a response without a 'models' list instead of passing NULL
    // to the array functions below.
    $models = $response['models'] ?? [];

    if (($operation_type ?? '') !== '') {
      // Cohere names some endpoints differently from the AI module's
      // operation types: it uses 'rerank' instead of 'ai_rerank'.
      $endpoint = match ($operation_type) {
        'ai_rerank' => 'rerank',
        default => $operation_type,
      };
      $models = array_filter(
        $models,
        static fn(array $model): bool => in_array($endpoint, $model['endpoints'] ?? [], TRUE),
      );
    }

    $names = array_column($models, 'name');
    return array_combine($names, $names);
  }

  /**
   * {@inheritdoc}
   *
   * The provider is unusable without an 'api_key' in its settings. When an
   * operation type is given, it must be one of getSupportedOperationTypes().
   * Capabilities are not checked.
   */
  public function isUsable(
    ?string $operation_type = NULL,
    array $capabilities = [],
  ): bool {
    if (!$this->getConfig()->get('api_key')) {
      return FALSE;
    }
    if ($operation_type) {
      return in_array($operation_type, $this->getSupportedOperationTypes(), TRUE);
    }

    return TRUE;
  }

  /**
   * {@inheritdoc}
   */
  public function getSupportedOperationTypes(): array {
    return [
      'ai_rerank',
    ];
  }

  /**
   * {@inheritdoc}
   *
   * Cohere advertises no optional capabilities through this provider.
   */
  public function getSupportedCapabilities(): array {
    return [];
  }

  /**
   * {@inheritdoc}
   *
   * Reads the 'ai_provider_cohere.settings' configuration object.
   */
  public function getConfig(): ImmutableConfig {
    return $this->configFactory->get('ai_provider_cohere.settings');
  }

  /**
   * {@inheritdoc}
   *
   * Parses definitions/api_defaults.yml from this module's directory.
   *
   * @throws \Symfony\Component\Yaml\Exception\ParseException
   *   If the file cannot be read or is not valid YAML.
   */
  public function getApiDefinition(): array {
    $module_path = $this->moduleHandler->getModule('ai_provider_cohere')->getPath();

    return Yaml::parseFile($module_path . '/definitions/api_defaults.yml');
  }

  /**
   * {@inheritdoc}
   *
   * No per-model settings are exposed for Cohere.
   */
  public function getModelSettings(
    string $model_id,
    array $generalConfig = [],
  ): array {
    return [];
  }

  /**
   * {@inheritdoc}
   *
   * Intentionally a no-op: $authentication is ignored, and every client is
   * built with the key returned by loadApiKey().
   */
  public function setAuthentication(mixed $authentication): void {}

  /**
   * {@inheritdoc}
   */
  public function rerank(AiReRankInput $input, string $model): AiReRankOutput {
    // NOTE(review): $model is not forwarded to CohereClient::rerank(); confirm
    // whether the client reads the model from $input or applies a default.
    return $this->buildCohereClient()->rerank($input);
  }

  /**
   * Builds a Cohere client for a single request.
   *
   * @return \Drupal\ai_provider_cohere\CohereClient
   *   A client bound to this plugin's HTTP client and the loaded API key.
   */
  private function buildCohereClient(): CohereClient {
    return CohereClient::fromClientAndApiKey(
      $this->httpClient,
      $this->loadApiKey(),
    );
  }

}
