<?php

namespace Drupal\ai_provider_google_vertex\Plugin\AiProvider;

use Drupal\Component\Serialization\Json;
use Drupal\Core\Config\ImmutableConfig;
use Drupal\Core\Plugin\ContainerFactoryPluginInterface;
use Drupal\Core\StringTranslation\StringTranslationTrait;
use Drupal\Core\StringTranslation\TranslatableMarkup;
use Drupal\ai\Attribute\AiProvider;
use Drupal\ai\Base\AiProviderClientBase;
use Drupal\ai\Exception\AiBadRequestException;
use Drupal\ai\Exception\AiSetupFailureException;
use Drupal\ai\OperationType\Chat\ChatInput;
use Drupal\ai\OperationType\Chat\ChatInterface;
use Drupal\ai\OperationType\Chat\ChatMessage;
use Drupal\ai\OperationType\Chat\ChatOutput;
use Drupal\ai\OperationType\Chat\StreamedChatMessageIteratorInterface;
use Drupal\ai\OperationType\Embeddings\EmbeddingsInput;
use Drupal\ai\OperationType\Embeddings\EmbeddingsInterface;
use Drupal\ai\OperationType\Embeddings\EmbeddingsOutput;
use Drupal\ai\Traits\OperationType\ChatTrait;
use Drupal\ai_provider_google_vertex\GoogleVertexChatIterator;
use Google\Cloud\AIPlatform\V1\Blob;
use Google\Cloud\AIPlatform\V1\Client\PredictionServiceClient;
use Google\Cloud\AIPlatform\V1\Content;
use Google\Cloud\AIPlatform\V1\GenerateContentRequest;
use Google\Cloud\AIPlatform\V1\Part;
use Google\Cloud\AIPlatform\V1\PredictRequest;
use Google\Protobuf\ListValue;
use Google\Protobuf\Struct;
use Google\Protobuf\Value;
use Symfony\Component\Yaml\Yaml;

/**
 * Plugin implementation of the 'google vertex' provider.
 *
 * Exposes Google Vertex AI publisher models through the AI module for the
 * 'chat' and 'embeddings' operation types, using the Google Cloud AI Platform
 * PredictionServiceClient.
 */
#[AiProvider(
  id: 'google_vertex',
  label: new TranslatableMarkup('Google Vertex'),
)]
class VertexProvider extends AiProviderClientBase implements
  ContainerFactoryPluginInterface,
  ChatInterface,
  EmbeddingsInterface {

  use StringTranslationTrait;
  use ChatTrait;

  /**
   * The Google Vertex prediction client for API calls.
   *
   * Lazily created by loadClient() and reset by setAuthentication().
   *
   * @var \Google\Cloud\AIPlatform\V1\Client\PredictionServiceClient|null
   */
  protected $client;

  /**
   * Credential file location.
   *
   * Exposed to the Google client through the GOOGLE_APPLICATION_CREDENTIALS
   * environment variable in loadClient().
   *
   * @var string
   */
  protected string $credentialFile = '';

  /**
   * We want to add models to the provider dynamically.
   *
   * @var bool
   */
  protected bool $hasPredefinedModels = FALSE;

  /**
   * {@inheritdoc}
   */
  public function isUsable(?string $operation_type = NULL, array $capabilities = []): bool {
    // If it's not configured, it is not usable.
    if (!$this->getConfig()->get('general_credential_file')) {
      return FALSE;
    }
    // If it's one of the operation types that Vertex supports, it's usable.
    if ($operation_type) {
      return in_array($operation_type, $this->getSupportedOperationTypes(), TRUE);
    }
    return TRUE;
  }

  /**
   * {@inheritdoc}
   */
  public function getSupportedOperationTypes(): array {
    return [
      'chat',
      'embeddings',
    ];
  }

  /**
   * {@inheritdoc}
   */
  public function getConfig(): ImmutableConfig {
    return $this->configFactory->get('ai_provider_google_vertex.settings');
  }

  /**
   * {@inheritdoc}
   */
  public function getApiDefinition(): array {
    // Load the API defaults shipped with this module.
    $definition = Yaml::parseFile($this->moduleHandler->getModule('ai_provider_google_vertex')->getPath() . '/definitions/api_defaults.yml');
    return $definition;
  }

  /**
   * {@inheritdoc}
   */
  public function getModelSettings(string $model_id, array $generalConfig = []): array {
    return $generalConfig;
  }

  /**
   * {@inheritdoc}
   *
   * Stores the new credential file location and drops any existing client so
   * the next call rebuilds it with the new credentials.
   */
  public function setAuthentication(mixed $file_location): void {
    $this->credentialFile = $file_location;
    $this->client = NULL;
  }

  /**
   * Gets the raw chat client.
   *
   * This is the client for inference.
   *
   * @return \Google\Cloud\AIPlatform\V1\Client\PredictionServiceClient
   *   The Google Vertex client.
   */
  public function getClient(): PredictionServiceClient {
    $this->loadClient();
    return $this->client;
  }

  /**
   * Loads the Google Prediction Client with authentication if not initialized.
   *
   * When no credential file has been set explicitly, it is loaded from the key
   * configured in the module settings.
   */
  protected function loadClient(): void {
    if (!$this->client) {
      if (!$this->credentialFile) {
        $this->setAuthentication($this->loadCredentialFile());
      }
      putenv('GOOGLE_APPLICATION_CREDENTIALS=' . $this->credentialFile);
      $this->client = new PredictionServiceClient();
    }
  }

  /**
   * {@inheritdoc}
   *
   * Accepts a ChatInput (converted to Vertex Content messages, with the
   * configured system role sent as the system instruction), a plain string
   * (sent as a single user message), or a raw array of Vertex Content
   * messages passed through unchanged.
   *
   * @throws \Drupal\ai\Exception\AiBadRequestException
   *   When the model is not configured, or no usable response is returned.
   */
  public function chat(array|string|ChatInput $input, string $model_id, array $tags = []): ChatOutput {
    $info = $this->getModelInfo('chat', $model_id);
    if (!isset($info['vertex_model_id'], $info['project_id'], $info['location'])) {
      throw new AiBadRequestException('The model does not exist.');
    }
    $this->loadClient();

    $request = new GenerateContentRequest([
      'model' => $this->buildModelPath($info),
    ]);

    if ($input instanceof ChatInput) {
      // If the system message is set.
      if ($this->chatSystemRole) {
        $request->setSystemInstruction(new Content([
          'role' => 'system',
          'parts' => [
            new Part([
              'text' => $this->chatSystemRole,
            ]),
          ],
        ]));
      }
      $chats = $this->buildContents($input);
    }
    elseif (is_string($input)) {
      // setContents() needs a list of Content messages, so a plain string is
      // wrapped as a single user message.
      $chats = [
        new Content([
          'role' => 'user',
          'parts' => [
            new Part([
              'text' => $input,
            ]),
          ],
        ]),
      ];
    }
    else {
      // Raw input, presumably already a list of Vertex Content messages.
      $chats = $input;
    }

    $raw_output = [];
    $response = NULL;
    $return = NULL;
    try {
      $request->setContents($chats);

      if ($this->streamed) {
        // Streamed response.
        $response = $this->client->streamGenerateContent($request);
        $return = new GoogleVertexChatIterator(new \ArrayIterator([]));
        $return->setStreamedResponse($response);
      }
      else {
        // Non-streamed response.
        $response = $this->client->generateContent($request);
        foreach ($response->getCandidates() as $candidate) {
          // The content is unset (NULL) for e.g. blocked candidates.
          $content = $candidate->getContent();
          $parts = $content?->getParts();
          if (!isset($parts[0])) {
            throw new AiBadRequestException('Could not get a text response from the model.');
          }
          $raw_output['candidates'][] = [
            'text' => $parts[0]->getText(),
            'role' => $content->getRole(),
          ];
          // The last candidate becomes the normalized response.
          $return = new ChatMessage($content->getRole(), $parts[0]->getText());
        }
      }
    }
    catch (AiBadRequestException $e) {
      // Our own errors already carry a meaningful message; don't mask them.
      throw $e;
    }
    catch (\Exception $e) {
      $this->throwError($e->getMessage());
    }
    if ($return === NULL) {
      throw new AiBadRequestException('Could not get a return from the model.');
    }

    // Streamed responses have no raw output or metadata up front.
    if ($return instanceof StreamedChatMessageIteratorInterface) {
      return new ChatOutput($return, [], []);
    }

    // Make the raw output.
    $raw_output['model_version'] = $response->getModelVersion();
    if ($response->hasPromptFeedback()) {
      $feedback = $response->getPromptFeedback();
      $raw_output['prompt_feedback']['block_reason'] = $feedback->getBlockReason();
      $raw_output['prompt_feedback']['block_reason_message'] = $feedback->getBlockReasonMessage();
      $raw_output['prompt_feedback']['safety_rating'] = $feedback->getSafetyRatings();
    }
    if ($response->hasUsageMetadata()) {
      $usage = $response->getUsageMetadata();
      $raw_output['usage_metadata']['prompt_token_count'] = $usage->getPromptTokenCount();
      $raw_output['usage_metadata']['candidates_token_count'] = $usage->getCandidatesTokenCount();
      $raw_output['usage_metadata']['total_token_count'] = $usage->getTotalTokenCount();
    }
    // Usage metadata is optional in the response.
    return new ChatOutput($return, $raw_output, $raw_output['usage_metadata'] ?? []);
  }

  /**
   * {@inheritdoc}
   *
   * @throws \Drupal\ai\Exception\AiBadRequestException
   *   When the model is not configured, or no embedding values are returned.
   */
  public function embeddings(string|EmbeddingsInput $input, string $model_id, array $tags = []): EmbeddingsOutput {
    $info = $this->getModelInfo('embeddings', $model_id);
    if (!isset($info['vertex_model_id'], $info['project_id'], $info['location'])) {
      throw new AiBadRequestException('The model does not exist.');
    }
    $this->loadClient();

    $text = $input instanceof EmbeddingsInput ? $input->getPrompt() : $input;

    // Create the instance payload: {"content": "<text>"}.
    $instance = new Struct();
    $instance->setFields([
      'content' => (new Value())->setStringValue($text),
    ]);

    $instance_value = new Value();
    $instance_value->setStructValue($instance);

    // Create the PredictRequest.
    $request = (new PredictRequest())
      ->setEndpoint($this->buildModelPath($info))
      ->setInstances([$instance_value]);

    $response = NULL;
    $embedding_values = [];
    try {
      // Perform the prediction with the already authenticated client.
      $response = $this->client->predict($request);
      foreach ($response->getPredictions() as $prediction) {
        // Each prediction is a struct of structs; every list value found at
        // the second level is collected as embedding values. Non-struct
        // values are skipped instead of causing a fatal error.
        $embeddings = $prediction->getStructValue()?->getFields()->getIterator() ?? [];
        foreach ($embeddings as $embedding) {
          $vector_lists = $embedding->getStructValue()?->getFields()->getIterator() ?? [];
          foreach ($vector_lists as $vector_list) {
            $list_value = $vector_list->getListValue();
            if ($list_value instanceof ListValue) {
              foreach ($list_value->getValues()->getIterator() as $vector) {
                $embedding_values[] = $vector->getNumberValue();
              }
            }
          }
        }
      }
    }
    catch (\Exception $e) {
      $this->throwError($e->getMessage());
    }

    if (empty($embedding_values)) {
      throw new AiBadRequestException('Could not get a return from the model.');
    }
    // Make the raw output.
    $raw_output = [
      'model_version' => $response->getModelDisplayName(),
      'prediction' => $embedding_values,
    ];
    return new EmbeddingsOutput($embedding_values, $raw_output, []);
  }

  /**
   * {@inheritdoc}
   */
  public function maxEmbeddingsInput(string $model_id = ''): int {
    return 4096;
  }

  /**
   * {@inheritdoc}
   */
  public function embeddingsVectorSize(string $model_id): int {
    return 768;
  }

  /**
   * {@inheritdoc}
   *
   * Adds the Vertex-specific project, location and model id fields.
   */
  public function loadModelsForm(array $form, $form_state, string $operation_type, string|NULL $model_id = NULL): array {
    $form = parent::loadModelsForm($form, $form_state, $operation_type, $model_id);
    $config = $this->loadModelConfig($operation_type, $model_id);

    $form['model_data']['project_id'] = [
      '#type' => 'textfield',
      '#title' => $this->t('Project ID'),
      '#description' => $this->t('The Google Vertex project id needed to access the model.'),
      '#default_value' => $config['project_id'] ?? '',
      '#required' => TRUE,
    ];

    $form['model_data']['location'] = [
      '#type' => 'textfield',
      '#title' => $this->t('Location'),
      '#description' => $this->t('The Google Vertex location needed to access the model. Example: us-central1'),
      '#default_value' => $config['location'] ?? '',
      '#required' => TRUE,
    ];

    $form['model_data']['vertex_model_id'] = [
      '#type' => 'textfield',
      '#title' => $this->t('Vertex Model ID'),
      '#description' => $this->t('The Google Vertex model id. Example: gemini-1.0-pro'),
      '#default_value' => $config['vertex_model_id'] ?? '',
      '#required' => TRUE,
    ];

    return $form;
  }

  /**
   * Builds the fully-qualified Vertex publisher model resource name.
   *
   * @param array $info
   *   The model info; must contain 'project_id', 'location' and
   *   'vertex_model_id'.
   *
   * @return string
   *   The resource name in the form
   *   projects/PROJECT/locations/LOCATION/publishers/google/models/MODEL.
   */
  protected function buildModelPath(array $info): string {
    return sprintf('projects/%s/locations/%s/publishers/google/models/%s', $info['project_id'], $info['location'], $info['vertex_model_id']);
  }

  /**
   * Converts the messages of a chat input into Vertex Content messages.
   *
   * @param \Drupal\ai\OperationType\Chat\ChatInput $input
   *   The chat input.
   *
   * @return \Google\Cloud\AIPlatform\V1\Content[]
   *   One Content per message: a text part first, followed by one inline data
   *   part per attached image.
   */
  protected function buildContents(ChatInput $input): array {
    $contents = [];
    foreach ($input->getMessages() as $message) {
      $parts = [
        new Part([
          'text' => $message->getText(),
        ]),
      ];
      foreach ($message->getImages() as $image) {
        $parts[] = new Part([
          'inline_data' => new Blob([
            'data' => $image->getAsBinary(),
            'mime_type' => $image->getMimeType(),
          ]),
        ]);
      }
      // Special rule for Vertex: it expects 'model' instead of 'assistant'.
      $role = $message->getRole() === 'assistant' ? 'model' : $message->getRole();
      $content = new Content([
        'role' => $role,
      ]);
      $content->setParts($parts);
      $contents[] = $content;
    }
    return $contents;
  }

  /**
   * Throw errors.
   *
   * Always throws. If the message is JSON with a 'reason' of
   * 'CONSUMER_INVALID', a dedicated access error is thrown; otherwise a
   * generic error is thrown.
   *
   * @param string $message
   *   The message that got thrown.
   *
   * @throws \Drupal\ai\Exception\AiBadRequestException
   */
  protected function throwError(string $message): void {
    $data = Json::decode($message);
    if (is_array($data) && ($data['reason'] ?? NULL) === 'CONSUMER_INVALID') {
      throw new AiBadRequestException('You do not have access to the model.');
    }
    throw new AiBadRequestException('Unknown Error: Could not get a response from the model.');
  }

  /**
   * Load the credentials file from the key module.
   *
   * @return string
   *   The credential file location.
   *
   * @throws \Drupal\ai\Exception\AiSetupFailureException
   *   When the configured key is missing or has no value.
   */
  protected function loadCredentialFile(): string {
    $key_id = $this->getConfig()->get('general_credential_file');
    $key = $this->keyRepository->getKey($key_id);
    // If it came here, but the key is missing, something is wrong with the env.
    if (!$key || !($file = $key->getKeyValue())) {
      throw new AiSetupFailureException(sprintf('Could not load the %s credential file or its not available, please check your environment settings or your setup.', $this->getPluginDefinition()['label']));
    }
    return $file;
  }

}
