<?php

declare(strict_types=1);

namespace Drupal\ai_provider_cohere;

use Drupal\ai\OperationType\AiRerank\AiReRankInput;
use Drupal\ai\OperationType\AiRerank\AiReRankOutput;
use GuzzleHttp\Exception\GuzzleException;
use Psr\Http\Client\ClientInterface;

/**
 * A client for the Cohere API.
 *
 * Wraps the "list models" (v1) and "rerank" (v2) endpoints and authenticates
 * every request with a bearer token.
 *
 * NOTE(review): the injected client is type-hinted as the PSR-18
 * ClientInterface, which only defines sendRequest(); this class calls get()
 * and post(), so the client is presumably a Guzzle client — confirm against
 * the service wiring.
 */
class CohereClient {

  private const API_BASE_V1_URL = 'https://api.cohere.ai/v1';

  private const API_BASE_V2_URL = 'https://api.cohere.ai/v2';

  /**
   * CohereClient constructor.
   *
   * Private on purpose: instances are created through
   * ::fromClientAndApiKey().
   *
   * @param \Psr\Http\Client\ClientInterface $httpClient
   *   The HTTP client.
   * @param string $apiKey
   *   The API key.
   */
  private function __construct(
    private readonly ClientInterface $httpClient,
    private readonly string $apiKey,
  ) {}

  /**
   * Create a new CohereClient from an HTTP client and an API key.
   *
   * @param \Psr\Http\Client\ClientInterface $httpClient
   *   The HTTP client.
   * @param string $apiKey
   *   The API key.
   *
   * @return \Drupal\ai_provider_cohere\CohereClient
   *   The new client instance.
   */
  public static function fromClientAndApiKey(
    ClientInterface $httpClient,
    string $apiKey,
  ): self {
    return new self($httpClient, $apiKey);
  }

  /**
   * List the available models.
   *
   * @return array
   *   The decoded JSON body of the models endpoint.
   *
   * @throws \Exception
   *   When the request fails or the response body is not a JSON object/array.
   */
  public function listModels(): array {
    try {
      $httpResponse = $this->httpClient->get(self::API_BASE_V1_URL . '/models', [
        'headers' => $this->authorizationHeaders(),
      ]);
    }
    catch (GuzzleException $e) {
      throw new \Exception('Failed to list models', previous: $e);
    }

    return $this->decodeJson($httpResponse->getBody()->getContents());
  }

  /**
   * Rerank a list of documents.
   *
   * Each entry of the response's "results" list refers back to an input
   * document by index; that document is returned with the relevance score
   * attached under the "rerank_score" key, in the order given by the API.
   *
   * @param \Drupal\ai\OperationType\AiRerank\AiReRankInput $data
   *   The data to rerank.
   *
   * @return \Drupal\ai\OperationType\AiRerank\AiReRankOutput
   *   The reranked documents.
   *
   * @throws \Exception
   *   When the request fails or the response cannot be interpreted.
   */
  public function rerank(AiReRankInput $data): AiReRankOutput {
    try {
      $httpResponse = $this->httpClient->post(self::API_BASE_V2_URL . '/rerank', [
        'headers' => $this->authorizationHeaders(),
        'json' => $data->toArray(),
      ]);
    }
    catch (GuzzleException $e) {
      throw new \Exception('Failed to rerank documents', previous: $e);
    }

    $response = $this->decodeJson($httpResponse->getBody()->getContents());

    // Fail loudly instead of iterating over a missing key and returning an
    // empty result that looks like a successful rerank.
    if (!isset($response['results']) || !is_array($response['results'])) {
      throw new \Exception('Unexpected rerank response: missing "results"');
    }

    $reranked = [];
    foreach ($response['results'] as $result) {
      $document = $data->documents[$result['index']];
      $document['rerank_score'] = $result['relevance_score'];
      $reranked[] = $document;
    }

    return new AiReRankOutput(
      $reranked,
      $response['id'],
      $response['meta'],
    );
  }

  /**
   * Build the request headers carrying the bearer token.
   *
   * @return array
   *   The headers, keyed by header name.
   */
  private function authorizationHeaders(): array {
    return [
      'Authorization' => 'Bearer ' . $this->apiKey,
    ];
  }

  /**
   * Decode a JSON response body into an associative array.
   *
   * @param string $body
   *   The raw response body.
   *
   * @return array
   *   The decoded body.
   *
   * @throws \Exception
   *   When the body is not valid JSON or does not decode to an array.
   */
  private function decodeJson(string $body): array {
    try {
      $decoded = json_decode($body, TRUE, flags: JSON_THROW_ON_ERROR);
    }
    catch (\JsonException $e) {
      throw new \Exception('Failed to decode the Cohere API response', previous: $e);
    }

    // Valid JSON can still be a scalar or null; callers rely on an array.
    if (!is_array($decoded)) {
      throw new \Exception('Unexpected Cohere API response: expected a JSON object');
    }

    return $decoded;
  }

}
