<?php

namespace Drupal\ai_provider_ollama\Plugin\AiProvider;

use Drupal\Core\Config\ImmutableConfig;
use Drupal\Core\Plugin\ContainerFactoryPluginInterface;
use Drupal\Core\StringTranslation\StringTranslationTrait;
use Drupal\Core\StringTranslation\TranslatableMarkup;
use Drupal\ai\Attribute\AiProvider;
use Drupal\ai\Base\AiProviderClientBase;
use Drupal\ai\Exception\AiRequestErrorException;
use Drupal\ai\OperationType\Chat\ChatInput;
use Drupal\ai\OperationType\Chat\ChatInterface;
use Drupal\ai\OperationType\Chat\ChatMessage;
use Drupal\ai\OperationType\Chat\ChatOutput;
use Drupal\ai\OperationType\Chat\Tools\ToolsFunctionOutput;
use Drupal\ai\OperationType\Embeddings\EmbeddingsInput;
use Drupal\ai\OperationType\Embeddings\EmbeddingsInterface;
use Drupal\ai\OperationType\Embeddings\EmbeddingsOutput;
use Drupal\ai\OperationType\Moderation\ModerationInput;
use Drupal\ai\OperationType\Moderation\ModerationOutput;
use Drupal\ai\OperationType\Moderation\ModerationResponse;
use Drupal\ai\Traits\OperationType\ChatTrait;
use Drupal\ai_provider_ollama\Models\Moderation\LlamaGuard3;
use Drupal\ai_provider_ollama\Models\Moderation\ShieldGemma;
use Drupal\ai_provider_ollama\OllamaChatMessageIterator;
use Drupal\ai_provider_ollama\OllamaControlApi;
use Drupal\Component\Serialization\Json;
use GuzzleHttp\Client as GuzzleClient;
use OpenAI\Client;
use Symfony\Component\DependencyInjection\ContainerInterface;
use Symfony\Component\Yaml\Yaml;

/**
 * Plugin implementation of the 'ollama' provider.
 *
 * Chat and moderation requests go through the OpenAI-compatible "/v1"
 * endpoint of the configured host using the OpenAI PHP client. Model
 * listing, embeddings, and chat requests that carry images or a structured
 * JSON schema go through the Ollama control API service.
 */
#[AiProvider(
  id: 'ollama',
  label: new TranslatableMarkup('Ollama'),
)]
class OllamaProvider extends AiProviderClientBase implements
  ContainerFactoryPluginInterface,
  ChatInterface,
  EmbeddingsInterface {

  use StringTranslationTrait;
  use ChatTrait;

  /**
   * Root model names (model ID without ":tag") usable for moderation.
   *
   * Shared by getConfiguredModels() (to filter the list) and moderation()
   * (to pick the response parser), so both always agree.
   */
  protected const MODERATION_MODELS = [
    'shieldgemma',
    'llama-guard3',
  ];

  /**
   * The OpenAI Client for API calls.
   *
   * @var \OpenAI\Client|null
   */
  protected $client;

  /**
   * The Ollama Control API for configuration calls.
   *
   * @var \Drupal\ai_provider_ollama\OllamaControlApi
   */
  protected $controlApi;

  /**
   * The current user.
   *
   * @var \Drupal\Core\Session\AccountProxyInterface
   */
  protected $currentUser;

  /**
   * The messenger service.
   *
   * @var \Drupal\Core\Messenger\MessengerInterface
   */
  protected $messenger;

  /**
   * Dependency Injection for the Ollama Control API.
   *
   * Besides the parent's services, this injects the control API (already
   * pointed at the configured base host), the current user and the messenger.
   */
  public static function create(ContainerInterface $container, array $configuration, $plugin_id, $plugin_definition) {
    $instance = parent::create($container, $configuration, $plugin_id, $plugin_definition);
    $instance->controlApi = $container->get('ai_provider_ollama.control_api');
    $instance->controlApi->setConnectData($instance->getBaseHost());
    $instance->currentUser = $container->get('current_user');
    $instance->messenger = $container->get('messenger');
    return $instance;
  }

  /**
   * {@inheritdoc}
   *
   * The list is fetched from the control API and keyed by model ID. For the
   * 'moderation' operation type only models whose root name is in
   * MODERATION_MODELS are returned. If the request fails, the error is
   * logged, shown to users who may administer AI providers, and an empty
   * list is returned.
   */
  public function getConfiguredModels(?string $operation_type = NULL, array $capabilities = []): array {
    // Graceful failure: report the problem but let callers continue.
    try {
      $response = $this->controlApi->getModels();
    }
    catch (\Exception $e) {
      if ($this->currentUser->hasPermission('administer ai providers')) {
        $this->messenger->addError($this->t('Failed to get models from Ollama: @error', ['@error' => $e->getMessage()]));
      }
      $this->loggerFactory->get('ai_provider_ollama')->error('Failed to get models from Ollama: @error', ['@error' => $e->getMessage()]);
      return [];
    }

    $models = [];
    foreach ($response['models'] ?? [] as $model) {
      if ($operation_type === 'moderation'
        && !in_array($this->getRootModelId($model['model']), self::MODERATION_MODELS, TRUE)) {
        continue;
      }
      $models[$model['model']] = $model['name'];
    }
    return $models;
  }

  /**
   * {@inheritdoc}
   *
   * The provider is unusable without a configured host. When an operation
   * type is given it must be one of getSupportedOperationTypes().
   */
  public function isUsable(?string $operation_type = NULL, array $capabilities = []): bool {
    if ($this->getBaseHost() === '') {
      return FALSE;
    }
    if (!$operation_type) {
      return TRUE;
    }
    return in_array($operation_type, $this->getSupportedOperationTypes(), TRUE);
  }

  /**
   * {@inheritdoc}
   */
  public function getSupportedOperationTypes(): array {
    return [
      'chat',
      'embeddings',
      'moderation',
    ];
  }

  /**
   * {@inheritdoc}
   */
  public function getConfig(): ImmutableConfig {
    return $this->configFactory->get('ai_provider_ollama.settings');
  }

  /**
   * {@inheritdoc}
   *
   * Read from definitions/api_defaults.yml inside this module.
   */
  public function getApiDefinition(): array {
    $path = $this->moduleHandler->getModule('ai_provider_ollama')->getPath();
    return Yaml::parseFile($path . '/definitions/api_defaults.yml');
  }

  /**
   * {@inheritdoc}
   *
   * No model-specific settings: the general configuration is returned as is.
   */
  public function getModelSettings(string $model_id, array $generalConfig = []): array {
    return $generalConfig;
  }

  /**
   * {@inheritdoc}
   *
   * This provider does not use the authentication value; it only drops the
   * cached client so loadClient() builds a fresh one on next use.
   */
  public function setAuthentication(mixed $authentication): void {
    $this->client = NULL;
  }

  /**
   * Gets the raw client.
   *
   * This is the client for inference.
   *
   * @return \OpenAI\Client
   *   The OpenAI client.
   */
  public function getClient(): Client {
    $this->loadClient();
    return $this->client;
  }

  /**
   * Get control client.
   *
   * This is the client for controlling the Ollama API.
   *
   * @return \Drupal\ai_provider_ollama\OllamaControlApi
   *   The control client.
   */
  public function getControlClient(): OllamaControlApi {
    return $this->controlApi;
  }

  /**
   * Lazily builds the OpenAI client for "<base host>/v1".
   *
   * A 600 second Guzzle timeout is used because local inference can be slow.
   */
  protected function loadClient(): void {
    if ($this->client) {
      return;
    }
    $http_client = new GuzzleClient(['timeout' => 600]);
    $this->client = \OpenAI::factory()
      ->withHttpClient($http_client)
      ->withBaseUri($this->getBaseHost() . '/v1')
      ->withHttpHeader('content-type', 'application/json')
      ->make();
  }

  /**
   * {@inheritdoc}
   *
   * Requests carrying images or a structured JSON schema are sent through the
   * control API and parsed from its top-level "message". All other requests
   * use the OpenAI-compatible client, streamed when $this->streamed is set.
   * Tool calls are only mapped back onto the message when the response is an
   * array with a "choices" list (the non-streamed OpenAI-compatible path).
   */
  public function chat(array|string|ChatInput $input, string $model_id, array $tags = []): ChatOutput {
    $this->loadClient();

    // Only a ChatInput object can carry tools or a structured output schema.
    // The instanceof guard keeps method_exists() away from array input (a
    // TypeError) and from plain strings (which it would treat as class names
    // and try to autoload).
    $chat_tools = NULL;
    $json_schema = NULL;
    if ($input instanceof ChatInput) {
      if (method_exists($input, 'getChatTools')) {
        $chat_tools = $input->getChatTools();
      }
      if (method_exists($input, 'getChatStructuredJsonSchema')) {
        $json_schema = $input->getChatStructuredJsonSchema();
      }
    }

    // Normalize the input if needed.
    $chat_input = $input;
    $images_found = FALSE;
    if ($input instanceof ChatInput) {
      $chat_input = [];
      // Add a system role if wanted.
      if ($this->chatSystemRole) {
        $chat_input[] = [
          'role' => 'system',
          'content' => $this->chatSystemRole,
        ];
      }
      /** @var \Drupal\ai\OperationType\Chat\ChatMessage $message */
      foreach ($input->getMessages() as $message) {
        $images = [];
        foreach ($message->getImages() as $image) {
          $images[] = $image->getAsBase64EncodedString('');
        }
        if ($images !== []) {
          $images_found = TRUE;
        }
        $chat_input[] = [
          'role' => $message->getRole(),
          'content' => $message->getText(),
          'images' => $images,
        ];
      }
    }

    $payload = [
      'model' => $model_id,
      'messages' => $chat_input,
    ] + $this->configuration;
    if ($chat_tools) {
      $payload['tools'] = $this->normalizeToolsForOllama($chat_tools->renderToolsArray());
    }
    if ($json_schema) {
      $payload['format'] = $json_schema['schema'];
    }

    if ($json_schema || $images_found) {
      $response = $this->controlApi->chat($payload);
      $message = new ChatMessage($response['message']['role'], $response['message']['content']);
    }
    elseif ($this->streamed) {
      $response = $this->client->chat()->createStreamed($payload);
      $message = new OllamaChatMessageIterator($response);
    }
    else {
      $response = $this->client->chat()->create($payload)->toArray();
      $message = new ChatMessage($response['choices'][0]['message']['role'], $response['choices'][0]['message']['content']);
    }

    // A streamed response is an object, not an array; indexing it would
    // throw, so tool calls are only read from array responses.
    if ($chat_tools && is_array($response) && !empty($response['choices'][0]['message']['tool_calls'])) {
      $message->setTools($this->buildToolsOutput($chat_tools, $response['choices'][0]['message']['tool_calls']));
    }
    return new ChatOutput($message, $response, []);
  }

  /**
   * Adapts rendered tool definitions to what Ollama accepts.
   *
   * Ollama does only support string enums, so every parameter that has a
   * non-empty enum gets type "string" and its enum values cast to strings.
   * Tools without parameters are passed through unchanged.
   *
   * @param array $tools
   *   The tools array from the chat tools' renderToolsArray().
   *
   * @return array
   *   The adapted tools array.
   */
  protected function normalizeToolsForOllama(array $tools): array {
    foreach ($tools as $key => $tool) {
      foreach ($tool['function']['parameters']['properties'] ?? [] as $param_id => $param) {
        if (empty($param['enum'])) {
          continue;
        }
        $property = &$tools[$key]['function']['parameters']['properties'][$param_id];
        $property['type'] = 'string';
        $property['enum'] = array_map(static fn ($value): string => (string) $value, $param['enum']);
        unset($property);
      }
    }
    return $tools;
  }

  /**
   * Maps OpenAI-style tool calls onto the functions offered in the request.
   *
   * Calls without a "function" entry, or naming a function that was not
   * offered, are skipped.
   *
   * @param object $chat_tools
   *   The object returned by ChatInput::getChatTools().
   * @param array $tool_calls
   *   The "tool_calls" list of the first response choice.
   *
   * @return \Drupal\ai\OperationType\Chat\Tools\ToolsFunctionOutput[]
   *   One output per matched tool call.
   */
  protected function buildToolsOutput(object $chat_tools, array $tool_calls): array {
    $functions = [];
    foreach ($tool_calls as $tool_call) {
      if (!isset($tool_call['function'])) {
        continue;
      }
      $input_function = $chat_tools->getFunctionByName($tool_call['function']['name']);
      if (!$input_function) {
        continue;
      }
      $arguments = Json::decode($tool_call['function']['arguments']);
      $functions[] = new ToolsFunctionOutput($input_function, $tool_call['id'], $arguments);
    }
    return $functions;
  }

  /**
   * {@inheritdoc}
   *
   * Embeddings are computed through the control API; the "embedding" entry
   * of its response is the vector.
   */
  public function embeddings(string|EmbeddingsInput $input, string $model_id, array $tags = []): EmbeddingsOutput {
    $this->loadClient();
    // Normalize the input if needed.
    if ($input instanceof EmbeddingsInput) {
      $input = $input->getPrompt();
    }
    $response = $this->controlApi->embeddings($input, $model_id);

    return new EmbeddingsOutput($response['embedding'], $response, []);
  }

  /**
   * {@inheritdoc}
   *
   * Only models whose root name is in MODERATION_MODELS are supported; any
   * other (or missing) model ID is rejected before a request is sent. The
   * prompt is sent as a single user message and the reply is interpreted by
   * the matching model class.
   *
   * @throws \Drupal\ai\Exception\AiRequestErrorException
   *   When the model is not supported or the response has no content.
   */
  public function moderation(string|ModerationInput $input, ?string $model_id = NULL, array $tags = []): ModerationOutput {
    $root_model_id = $this->getRootModelId((string) $model_id);
    if (!in_array($root_model_id, self::MODERATION_MODELS, TRUE)) {
      throw new AiRequestErrorException('Model not supported for moderation.');
    }

    $this->loadClient();
    $payload = [
      'model' => $model_id,
      'messages' => [
        [
          'role' => 'user',
          'content' => $input instanceof ModerationInput ? $input->getPrompt() : $input,
        ],
      ],
    ] + $this->configuration;

    $response = $this->client->chat()->create($payload)->toArray();
    if (!isset($response['choices'][0]['message']['content'])) {
      throw new AiRequestErrorException('No content in moderation response.');
    }
    $message = $response['choices'][0]['message']['content'];

    $moderation_response = match ($root_model_id) {
      'llama-guard3' => LlamaGuard3::moderationRules($message),
      'shieldgemma' => ShieldGemma::moderationRules($message),
      default => throw new AiRequestErrorException('Model not supported for moderation.'),
    };

    return new ModerationOutput($moderation_response, $message, $response);
  }

  /**
   * {@inheritdoc}
   *
   * Asks the control API first and falls back to the parent implementation
   * when it returns nothing usable.
   */
  public function embeddingsVectorSize(string $model_id): int {
    $this->loadClient();
    $data = $this->controlApi->embeddingsVectorSize($model_id);
    if ($data) {
      return $data;
    }
    // Fallback to parent method.
    return parent::embeddingsVectorSize($model_id);
  }

  /**
   * Gets the base host.
   *
   * Built from the "host_name" setting (trailing slashes removed) plus
   * ":port" when a port is configured.
   *
   * @return string
   *   The base host, or an empty string when no host name is configured
   *   (a port on its own is not a usable address).
   */
  protected function getBaseHost(): string {
    $config = $this->getConfig();
    $host = rtrim((string) $config->get('host_name'), '/');
    if ($host === '') {
      return '';
    }
    $port = $config->get('port');
    if ($port) {
      $host .= ':' . $port;
    }
    return $host;
  }

  /**
   * Gets the root model name, i.e. the model ID without its ":tag" suffix.
   *
   * @param string $model_id
   *   A model ID such as "llama-guard3:8b".
   *
   * @return string
   *   The part before the first colon, e.g. "llama-guard3".
   */
  protected function getRootModelId(string $model_id): string {
    return explode(':', $model_id, 2)[0];
  }

  /**
   * {@inheritdoc}
   *
   * The limit is read from the control API.
   */
  public function maxEmbeddingsInput($model_id = ''): int {
    $this->loadClient();
    return $this->controlApi->embeddingsContextSize($model_id);
  }

}
