<?php

namespace Drupal\ai_provider_groq\Plugin\AiProvider;

use Drupal\Component\Serialization\Json;
use Drupal\Core\Config\ImmutableConfig;
use Drupal\Core\Plugin\ContainerFactoryPluginInterface;
use Drupal\Core\StringTranslation\TranslatableMarkup;
use Drupal\ai\Attribute\AiProvider;
use Drupal\ai\Base\AiProviderClientBase;
use Drupal\ai\Enum\AiModelCapability;
use Drupal\ai\Exception\AiRateLimitException;
use Drupal\ai\OperationType\Chat\ChatInput;
use Drupal\ai\OperationType\Chat\ChatInterface;
use Drupal\ai\OperationType\Chat\ChatMessage;
use Drupal\ai\OperationType\Chat\ChatOutput;
use Drupal\ai\OperationType\Chat\Tools\ToolsFunctionOutput;
use Drupal\ai\Traits\OperationType\ChatTrait;
use Drupal\Component\Utility\Crypt;
use OpenAI\Client;
use Symfony\Component\Yaml\Yaml;

/**
 * Plugin implementation of the 'groq' provider.
 *
 * Talks to Groq's OpenAI-compatible endpoint through the OpenAI PHP client
 * and exposes it as a chat-only AI provider.
 */
#[AiProvider(
  id: 'groq',
  label: new TranslatableMarkup('Groq'),
)]
class GroqProvider extends AiProviderClientBase implements
  ContainerFactoryPluginInterface,
  ChatInterface {

  use ChatTrait;

  /**
   * The OpenAI Client for API calls, lazily created by loadClient().
   *
   * @var \OpenAI\Client|null
   */
  protected ?Client $client = NULL;

  /**
   * API Key.
   *
   * Empty until set via setAuthentication() or loaded on first client use.
   *
   * @var string
   */
  protected string $apiKey = '';

  /**
   * The list of models that support function calling.
   *
   * The models endpoint exposes no capability metadata, so this list is
   * maintained by hand and must be updated when Groq adds models.
   *
   * @var array
   */
  protected array $toolCallingModels = [
    'qwen-qwq-32b',
    'qwen-2.5-coder-32b',
    'qwen-2.5-32b',
    'deepseek-r1-distill-qwen-32b',
    'deepseek-r1-distill-llama-70b',
    'llama-3.3-70b-versatile',
    'llama-3.1-8b-instant',
    'mixtral-8x7b-32768',
    'gemma2-9b-it',
  ];

  /**
   * {@inheritdoc}
   */
  public function getConfiguredModels(?string $operation_type = NULL, array $capabilities = []): array {
    // The filtering below depends on both the operation type and the
    // requested capabilities, so both are part of the cache key.
    $cache_key = 'groq_models_' . $operation_type . '_' . Crypt::hashBase64(Json::encode($capabilities));
    $cache_data = $this->cacheBackend->get($cache_key);
    if (!empty($cache_data)) {
      return $cache_data->data;
    }

    $wants_vision = in_array(AiModelCapability::ChatWithImageVision, $capabilities, TRUE);
    $wants_tools = in_array(AiModelCapability::ChatTools, $capabilities, TRUE);

    // Get all available models with a single request.
    $response = $this->getClient()->models()->list()->toArray();

    $models = [];
    foreach ($response['data'] ?? [] as $model) {
      $model_id = $model['id'];

      // Skip text-to-speech (for now) and speech-to-text (Whisper) models.
      if (str_starts_with($model_id, 'playai-tts')
        || str_starts_with($model_id, 'whisper')
        || str_starts_with($model_id, 'distil-whisper')) {
        continue;
      }

      // Vision requests only get models with "vision" in their ID.
      if ($wants_vision && !str_contains($model_id, 'vision')) {
        continue;
      }

      // Function calling support comes from the hardcoded list.
      if ($wants_tools && !in_array($model_id, $this->toolCallingModels, TRUE)) {
        continue;
      }

      $models[$model_id] = $model_id;
    }

    // Only cache non-empty results so a failed/empty listing is retried.
    if (!empty($models)) {
      asort($models);
      $this->cacheBackend->set($cache_key, $models);
    }

    return $models;
  }

  /**
   * Determines if a model supports reasoning format.
   *
   * @param string $model_id
   *   The model ID to check.
   *
   * @return bool
   *   TRUE if the model ID contains '-qwq-' or 'deepseek', FALSE otherwise.
   */
  protected function modelSupportsReasoning(string $model_id): bool {
    return str_contains($model_id, '-qwq-') || str_contains($model_id, 'deepseek');
  }

  /**
   * {@inheritdoc}
   */
  public function isUsable(?string $operation_type = NULL, array $capabilities = []): bool {
    // Without an API key nothing can be called.
    if (!$this->getConfig()->get('api_key')) {
      return FALSE;
    }
    // With an operation type, it must be one Groq supports.
    if ($operation_type) {
      return in_array($operation_type, $this->getSupportedOperationTypes(), TRUE);
    }
    return TRUE;
  }

  /**
   * {@inheritdoc}
   */
  public function getSupportedOperationTypes(): array {
    return [
      'chat',
    ];
  }

  /**
   * {@inheritdoc}
   */
  public function getConfig(): ImmutableConfig {
    return $this->configFactory->get('ai_provider_groq.settings');
  }

  /**
   * Get operation-specific settings or default settings.
   *
   * Defaults come from 'ai_provider_groq.settings'. Per-operation overrides
   * from 'ai_provider_groq.overrides' (key 'operation_overrides') win over
   * the defaults when present.
   *
   * @param string $operation_type
   *   The type of operation.
   *
   * @return array
   *   Array of settings for the operation, always containing the keys
   *   'reasoning_format', 'temperature', 'max_tokens' and 'json_mode'.
   */
  protected function getOperationSettings(string $operation_type): array {
    $config = $this->getConfig();

    // A temperature of 0 is valid (deterministic sampling), so only fall
    // back to the default when no numeric value is configured.
    $temperature = $config->get('temperature');

    $default_settings = [
      'reasoning_format' => $config->get('reasoning_format') ?: 'hidden',
      'temperature' => is_numeric($temperature) ? (float) $temperature : 0.6,
      'max_tokens' => $config->get('max_tokens') ?: 1024,
      'json_mode' => (bool) $config->get('json_mode'),
    ];

    $operation_overrides = $this->configFactory
      ->get('ai_provider_groq.overrides')
      ->get('operation_overrides') ?? [];

    if (!empty($operation_overrides[$operation_type])) {
      return array_merge($default_settings, $operation_overrides[$operation_type]);
    }

    return $default_settings;
  }

  /**
   * {@inheritdoc}
   */
  public function getApiDefinition(): array {
    // The defaults ship with the module in definitions/api_defaults.yml.
    return Yaml::parseFile($this->moduleHandler->getModule('ai_provider_groq')->getPath() . '/definitions/api_defaults.yml');
  }

  /**
   * {@inheritdoc}
   */
  public function getModelSettings(string $model_id, array $generalConfig = []): array {
    // Operation-specific settings override the general configuration.
    $operation_type = $generalConfig['operation_type'] ?? NULL;
    if ($operation_type) {
      return array_merge($generalConfig, $this->getOperationSettings($operation_type));
    }
    return $generalConfig;
  }

  /**
   * {@inheritdoc}
   */
  public function setAuthentication(mixed $authentication): void {
    // Reset the client so the next call re-creates it with the new key.
    $this->apiKey = $authentication;
    $this->client = NULL;
  }

  /**
   * Gets the raw client.
   *
   * This is the client for inference.
   *
   * @return \OpenAI\Client
   *   The OpenAI client, configured for Groq's endpoint.
   */
  public function getClient(): Client {
    $this->loadClient();
    return $this->client;
  }

  /**
   * Loads the Groq Client with authentication if not initialized.
   */
  protected function loadClient(): void {
    if (!$this->client) {
      if (!$this->apiKey) {
        $this->setAuthentication($this->loadApiKey());
      }

      $this->client = \OpenAI::factory()
        ->withApiKey($this->apiKey)
        ->withBaseUri('https://api.groq.com/openai/v1')
        ->withHttpClient($this->httpClient)
        ->make();
    }
  }

  /**
   * {@inheritdoc}
   */
  public function chat(array|string|ChatInput $input, string $model_id, array $tags = []): ChatOutput {
    $this->loadClient();

    $payload = $this->buildGroqChatPayload($input, $model_id);

    try {
      $response = $this->client->chat()->create($payload)->toArray();
    }
    catch (\Exception $e) {
      // Rate limiting is only recognizable from the error message text.
      if (str_contains($e->getMessage(), 'Rate limit reached for model')) {
        throw new AiRateLimitException($e->getMessage());
      }
      throw $e;
    }

    $choice = $response['choices'][0];
    $message = new ChatMessage($choice['message']['role'], $choice['message']['content'] ?? '');

    // Tool calls can only be mapped back when the input carried tools.
    if (($choice['finish_reason'] ?? NULL) === 'tool_calls' && $input instanceof ChatInput) {
      $tools = $this->extractGroqToolCalls($input, $choice['message']['tool_calls'] ?? []);
      if (!empty($tools)) {
        $message->setTools($tools);
      }
    }

    return new ChatOutput($message, $response, []);
  }

  /**
   * Converts chat input into the API's list of message arrays.
   *
   * @param array|string|\Drupal\ai\OperationType\Chat\ChatInput $input
   *   Raw message arrays (passed through unchanged), a single user prompt,
   *   or a ChatInput object.
   *
   * @return array
   *   The messages to send. For string and ChatInput input, the configured
   *   system role (if any) is prepended.
   */
  protected function normalizeGroqMessages(array|string|ChatInput $input): array {
    // Pre-built message arrays are the caller's responsibility.
    if (is_array($input)) {
      return $input;
    }

    $messages = [];
    if ($this->chatSystemRole) {
      $messages[] = [
        'role' => 'system',
        'content' => $this->chatSystemRole,
      ];
    }

    // The API requires a list of messages, so wrap a bare prompt.
    if (is_string($input)) {
      $messages[] = [
        'role' => 'user',
        'content' => $input,
      ];
      return $messages;
    }

    foreach ($input->getMessages() as $message) {
      $new_message = [
        'role' => $message->getRole(),
        'content' => $message->getText(),
      ];
      // A tool result references the tool call it answers.
      if ($message->getToolsId()) {
        $new_message['tool_call_id'] = $message->getToolsId();
      }
      // Replay tool calls made earlier in the conversation.
      if ($message->getTools()) {
        $new_message['tool_calls'] = $message->getRenderedTools();
      }
      $messages[] = $new_message;
    }

    return $messages;
  }

  /**
   * Builds the chat completion request payload.
   *
   * @param array|string|\Drupal\ai\OperationType\Chat\ChatInput $input
   *   The chat input.
   * @param string $model_id
   *   The model to call.
   *
   * @return array
   *   The request payload. Provider configuration overrides operation
   *   settings, which override the base request.
   */
  protected function buildGroqChatPayload(array|string|ChatInput $input, string $model_id): array {
    $settings = $this->getOperationSettings('chat');

    $payload = [
      'model' => $model_id,
      'messages' => $this->normalizeGroqMessages($input),
    ];

    // Only ChatInput objects can carry tool definitions.
    if ($input instanceof ChatInput && $input->getChatTools()) {
      $tools = $input->getChatTools()->renderToolsArray();
      foreach ($tools as $key => $tool) {
        // Strict function schemas are always disabled for this provider.
        $tools[$key]['function']['strict'] = FALSE;
      }
      $payload['tools'] = $tools;
    }

    $payload = array_merge($payload, $settings, $this->configuration);

    // Reasoning format only applies to reasoning models, and 'n/a' is a UI
    // placeholder rather than a valid API value.
    if (!$this->modelSupportsReasoning($model_id)) {
      unset($payload['reasoning_format']);
    }
    elseif (($payload['reasoning_format'] ?? NULL) === 'n/a') {
      $fallback = $settings['reasoning_format'] ?? 'hidden';
      $payload['reasoning_format'] = $fallback === 'n/a' ? 'hidden' : $fallback;
    }

    // 'json_mode' is a module setting, not an API parameter: translate it
    // into response_format and never send it to the API.
    if (!empty($payload['json_mode'])) {
      $payload['response_format'] = ['type' => 'json_object'];
    }
    unset($payload['json_mode']);

    return $payload;
  }

  /**
   * Maps tool calls from the API response onto the input's tool functions.
   *
   * @param \Drupal\ai\OperationType\Chat\ChatInput $input
   *   The chat input that declared the tools.
   * @param array $tool_calls
   *   The 'tool_calls' entries of the response message.
   *
   * @return \Drupal\ai\OperationType\Chat\Tools\ToolsFunctionOutput[]
   *   One output per tool call; empty when the input has no tools.
   */
  protected function extractGroqToolCalls(ChatInput $input, array $tool_calls): array {
    $chat_tools = $input->getChatTools();
    if (!$chat_tools) {
      return [];
    }

    $tools = [];
    foreach ($tool_calls as $tool) {
      $arguments = Json::decode($tool['function']['arguments']);
      $tools[] = new ToolsFunctionOutput($chat_tools->getFunctionByName($tool['function']['name']), $tool['id'], $arguments);
    }
    return $tools;
  }

  /**
   * {@inheritdoc}
   */
  public function getSetupData(): array {
    return [
      'key_config_name' => 'api_key',
      'default_models' => [
        'chat' => 'llama-3.3-70b-versatile',
      ],
    ];
  }

}
