#' Get DRI LLM Response
#'
#' `get_dri_llm_response` uses <https://openrouter.ai> to generate artificial LLM
#' responses to DRI survey questions
#'
#' For every requested response the survey statements are shuffled and three
#' chained requests are sent: considerations, policy preferences and reason.
#' The second and third request are given the context returned by the previous
#' request. The process time zone is set to UTC while the function runs and is
#' restored to its previous value when the function exits.
#'
#' @param model_id a model_id string from openrouter.ai
#' @param survey_info a list with survey question information, including `type`,
#' `order`, `statement`, `name`, `scale_max`, and `q_method`
#' @param api_key the API key generated by OpenRouter
#' @param role_info a named list with basic data of a role (i.e., `uid`,
#' `role`, `description`)
#' @param n the number of responses requested (default = 1); `0` sends no
#' request and returns an empty data frame
#' @param request_log_path an optional path to a file where the request texts
#' are saved; rows are appended as CSV. `NA` (the default) or `NULL` disables
#' the request log
#'
#' @returns a dataframe with `n` survey responses by `model_id`, including a
#' unique identifier, `uuid`, a creation timestamp, `created_at_utc`, the time
#' it took to generate the response, `time_s`, the estimated cost in USD,
#' `est_cost_usd`, whether the response is valid, `is_valid`, and the reason it
#' is not, `invalid_reason`
#'
#' @export get_dri_llm_response
#'
#' @import tibble
#' @importFrom uuid UUIDgenerate
#' @importFrom tidyr pivot_wider
#' @importFrom readr write_csv
#' @importFrom rlang .data
#' @importFrom stats IQR median
#'
#' @seealso [get_model_ids()] for all currently available model ids from
#' openrouter.ai
#'
#' @family LLM methods
#'
#' @examples
#'
#' # get DRI survey
#' survey_info <- surveys[surveys$name == "acp",]
#'
#' # select a model from openrouter
#' model_id <- "google/gemini-2.5-flash-lite"
#'
#' # send request to openrouter API
#' \dontrun{
#' llm_data <- get_dri_llm_response(model_id, survey_info)}
#'
get_dri_llm_response <- function(model_id,
                                 survey_info = list(
                                   type = NA_character_,
                                   order = NA_integer_,
                                   statement = NA_character_,
                                   name = NA_character_,
                                   scale_max = NA_integer_,
                                   q_method = NA
                                 ),
                                 api_key = Sys.getenv("OPENROUTER_API_KEY"),
                                 role_info = list(
                                   uid = NA_character_,
                                   role = NA_character_,
                                   description = NA_character_
                                 ),
                                 n = 1,
                                 request_log_path = NA_character_) {

  # `n` drives the loop below; `1:n` would run twice for n = 0 and count
  # downwards for negative values, so validate it and use `seq_len()` instead
  if (!is.numeric(n) || length(n) != 1L || is.na(n) || n < 0) {
    stop("`n` must be a single non-negative number.", call. = FALSE)
  }
  n_responses <- as.integer(n)

  # set time to UTC for consistent logging, and put the caller's setting back
  # on exit so the change does not leak into the rest of the R session
  previous_tz <- Sys.getenv("TZ", unset = NA_character_)
  on.exit(
    if (is.na(previous_tz)) Sys.unsetenv("TZ") else Sys.setenv(TZ = previous_tz),
    add = TRUE
  )
  Sys.setenv(TZ = "UTC")

  # survey_names <- sort(unique(surveys$name))
  #
  # if (!survey_name %in% survey_names) {
  #   stop("Invalid survey name: ",
  #        survey_name,
  #        "\nValid names include:\n",
  #        paste(paste0(1:length(survey_names), ". ", sort(survey_names)), collapse = "\n"))
  # }

  # "<provider>/<model>" -> c(provider, model)
  split_parts <- strsplit(model_id, "/")[[1]]

  # request logging is optional; a NULL path counts as "no log" just like NA
  write_log <- length(request_log_path) == 1L && !is.na(request_log_path)

  ### FORMAT SURVEY
  dri_survey <- format_dri_survey(survey_info)

  # one slot per requested response, filled inside the loop
  llm_data <- vector("list", n_responses)

  for (i in seq_len(n_responses)) {

    # CREATE META DATA to be attached to request logs
    # The timestamp is tagged as UTC so it keeps printing as UTC after the
    # session time zone has been restored.
    created_at <- Sys.time()
    attr(created_at, "tzone") <- "UTC"

    meta <- tibble(
      uuid = UUIDgenerate(),
      created_at_utc = created_at,
      provider = split_parts[1],
      model = split_parts[2],
      survey = dri_survey$name,
      role_uid = if (length(role_info$uid) == 0) NA else role_info$uid
    )


    ### MAKE PROMPT and shuffle statements
    shuffled_info <- .shuffle_statements(dri_survey)

    prompts <- make_dri_llm_prompts(shuffled_info, role_info)


    ### GET LLM RESPONSES
    est_cost_usd <- 0

    # considerations (C): first request, carries the system prompt
    res <- get_llm_response(prompts$considerations,
                            model_id = model_id,
                            system_prompt = prompts$system,
                            api_key = api_key)
    response_c <- res$response
    est_cost_usd <- est_cost_usd + .calculate_cost(res$usage, model_id)
    request_log <- .log_request(meta, "considerations", prompts$considerations, res)

    # policy preferences (P): continues from the context of the C request
    res <- get_llm_response(prompts$policies,
                            model_id = model_id,
                            context = res$context,
                            api_key = api_key)
    response_p <- res$response
    est_cost_usd <- est_cost_usd + .calculate_cost(res$usage, model_id)
    request_log <- bind_rows(
      request_log,
      .log_request(meta, "policies", prompts$policies, res)
    )

    # reason (R): continues from the context of the P request
    res <- get_llm_response(prompts$reason,
                            model_id = model_id,
                            context = res$context,
                            api_key = api_key)
    response_r <- res$response
    est_cost_usd <- est_cost_usd + .calculate_cost(res$usage, model_id)
    request_log <- bind_rows(
      request_log,
      .log_request(meta, "reason", prompts$reason, res)
    )

    ## PARSE RESPONSES
    # results are padded to 50 consideration and 10 policy columns
    considerations <- .parse_llm_response(response_c, 50, "C", shuffled_info)
    policies <- .parse_llm_response(response_p, 10, "P", shuffled_info)
    reason <- .parse_llm_response(response_r, 1, "R")

    validity <- .is_valid_response(considerations, policies, dri_survey)

    # elapsed time covers prompt creation, all three requests and parsing
    end_time <- Sys.time()
    time_s <- as.numeric(difftime(end_time, meta$created_at_utc, units = "secs"))

    # unnamed data frames are spliced into the row by `tibble()`
    llm_data[[i]] <- tibble(
      meta,
      time_s,
      est_cost_usd,
      validity,
      considerations,
      policies,
      reason
    )

    ## append log to the request log file
    # NOTE(review): with `append = TRUE` readr writes no header row, even when
    # the file is new - confirm that readers of the log expect that.
    if (write_log) {
      write_csv(request_log, request_log_path, append = TRUE)
    }

    # log result
    progress <- if (n > 1) paste0("[",i,"/",n,"] ") else ""
    status <- if (validity$is_valid) "SUCCESS: " else "ERROR! "
    message(progress,
            status,"LLM response generated in ",
            round(time_s, 1), "s")

  }

  bind_rows(llm_data)

}


# --- Internal Caching for Model Prices ---
# Environment that keeps the model list fetched from OpenRouter in `$models`,
# so later price look-ups reuse it instead of requesting it again.
.openrouter_cache <- new.env(parent = emptyenv())

# Look up the prices of `model_id` in the OpenRouter model list.
#
# The list is downloaded on first use and cached in `.openrouter_cache`.
#
# @param model_id model id string as used by OpenRouter
# @return a list with numeric elements `prompt` and `completion`. When the
#   model, or one of its two price fields, cannot be found a warning is raised
#   and the missing price is reported as 0.
# Stops with an error when the model list cannot be fetched (status != 200).
.get_model_pricing <- function(model_id) {
  # Check if model data is already cached
  if (is.null(.openrouter_cache$models)) {
    # message("Fetching model pricing information from OpenRouter...")
    response <- httr::GET("https://openrouter.ai/api/v1/models")
    if (httr::status_code(response) == 200) {
      .openrouter_cache$models <- httr::content(response, "parsed")$data
    } else {
      stop("Failed to fetch model pricing information.")
    }
  }

  # Find the specific model in the cached data.
  # `isTRUE()` keeps the mask at exactly one value per entry: an entry without
  # an `id` yields `logical(0)` from `==`, which would otherwise shift the
  # positions of all following matches.
  models <- .openrouter_cache$models
  is_match <- vapply(
    models,
    function(m) isTRUE(m$id == model_id),
    logical(1)
  )
  model_info <- models[is_match]

  if (length(model_info) == 0) {
    warning("Could not find pricing information for model: ", model_id)
    return(list(prompt = 0, completion = 0))
  }

  pricing <- model_info[[1]]$pricing

  # A price field may be absent or not numeric; map both cases to NA first so
  # they can be reported once below.
  parse_price <- function(value) {
    price <- suppressWarnings(as.numeric(value))
    if (length(price) == 1L && !is.na(price)) price else NA_real_
  }
  prompt_price <- parse_price(pricing$prompt)
  completion_price <- parse_price(pricing$completion)

  if (is.na(prompt_price) || is.na(completion_price)) {
    warning("Could not find pricing information for model: ", model_id)
    # same fallback as for an unknown model
    if (is.na(prompt_price)) prompt_price <- 0
    if (is.na(completion_price)) completion_price <- 0
  }

  list(
    prompt = prompt_price,
    completion = completion_price
  )
}

# Estimate the cost of one request from its token usage.
#
# @param usage object with `prompt_tokens` and `completion_tokens`
# @param model_id model id passed on to `.get_model_pricing()`
# @return prompt tokens times the prompt price plus completion tokens times
#   the completion price
.calculate_cost <- function(usage, model_id) {

  # read the token counts before the price look-up
  tokens <- list(
    prompt = usage$prompt_tokens,
    completion = usage$completion_tokens
  )

  rates <- .get_model_pricing(model_id)

  tokens$prompt * rates$prompt + tokens$completion * rates$completion
}

# Put the considerations and the policies of a survey into random row order.
#
# Each table gets a `shuffle` column holding a random permutation of its
# `order` column and is then sorted by that column.
#
# @param dri_survey survey list with data frames `considerations` and
#   `policies`
# @return `dri_survey` with both tables reordered
.shuffle_statements <- function(dri_survey) {
  reshuffle <- function(statements) {
    with_key <- mutate(statements, shuffle = sample(order))
    arrange(with_key, .data$shuffle)
  }

  # considerations are drawn first, policies second
  dri_survey$considerations <- reshuffle(dri_survey$considerations)
  dri_survey$policies <- reshuffle(dri_survey$policies)

  dri_survey
}

# Parse one raw LLM answer into a one-row data frame.
#
# Reason answers (`col_prefix = "R"`) are returned as text in column `R` with
# trailing line breaks removed. For any other prefix the answer is read line by
# line as "<number>. <value>" entries: each line is split on ". " and the
# second field is converted to a number. Values are assigned to statements by
# position, using the shuffled `order` column of `shuffled_info`, then put back
# into statement order and spread into columns named `<col_prefix><order>`.
# The result is padded with NA columns up to `max_cols`.
#
# @param response raw response text
# @param max_cols number of columns to pad to (`NA` = no padding); when not
#   supplied it follows the prefix (C = 50, P = 10, R = 1)
# @param col_prefix "C" (considerations), "P" (policies) or "R" (reason); when
#   not supplied "C" is used
# @param shuffled_info survey list as returned by `.shuffle_statements()`;
#   when it holds no order for the prefix, the numbers written in the response
#   itself are used as order
# @return a one-row data frame / tibble
.parse_llm_response <- function(response,
                                max_cols = c(50, 10, 1),
                                col_prefix = c("C", "P", "R"),
                                shuffled_info = NULL) {

  # The vector defaults only list the supported choices. Reduce them to single
  # values, because a condition of length > 1 inside `if ()` is an error.
  col_prefix <- col_prefix[[1L]]
  if (missing(max_cols)) {
    default_cols <- c(C = 50, P = 10, R = 1)
    max_cols <- unname(default_cols[col_prefix]) # NA for any other prefix
  } else {
    max_cols <- max_cols[[1L]]
  }

  # a missing answer is handled like an empty one instead of failing
  if (length(response) == 0L) {
    response <- NA_character_
  }

  # check for reasoning case
  if (col_prefix == "R") {
    return(tibble(
      R = gsub("[\r\n]+$", "", response) ## remove trailing newlines
    ))
  }

  # one entry per non-empty line
  lines <- trimws(unlist(strsplit(as.character(response), "\n", fixed = TRUE)))
  lines <- lines[!is.na(lines) & nzchar(lines)]
  fields <- strsplit(lines, "\\. ")

  # i-th field of every line, NA where a line has fewer fields
  field_at <- function(i) {
    vapply(
      fields,
      function(parts) if (length(parts) >= i) parts[[i]] else NA_character_,
      character(1)
    )
  }
  values <- as.numeric(field_at(2L))

  # retrieve shuffled order
  shuffled_order <- switch(col_prefix,
    C = shuffled_info$considerations$order,
    P = shuffled_info$policies$order,
    NULL
  )

  if (is.null(shuffled_order)) {
    # no shuffle information: use the numbering written in the response
    order <- as.numeric(field_at(1L))
  } else {
    order <- shuffled_order
    if (length(values) != length(order)) {
      # Values are matched to statements by position, so a different number of
      # lines cannot be mapped. Report every value as NA; the validity check
      # then flags the response instead of this function failing.
      warning(
        "Expected ", length(order), " '", col_prefix, "' values but parsed ",
        length(values), "; all values set to NA.",
        call. = FALSE
      )
      values <- rep(NA_real_, length(order))
    }
  }

  df <- data.frame(
    order = order,
    value = values
  )

  # unshuffle statements
  # (columns are selected by name; `.data$` is deprecated inside `select()`)
  df <- df %>%
    arrange(.data$order) %>%
    mutate(sid = paste0(col_prefix, .data$order)) %>%
    select("sid", "value")

  df <- pivot_wider(df, names_from = "sid", values_from = "value")

  num_cols <- ncol(df)

  if (is.na(max_cols) || num_cols >= max_cols) {
    return(df)
  }

  # Calculate the number of columns to add
  cols_to_add <- max_cols - num_cols

  # Create a new data frame with the columns to be added
  new_cols <- data.frame(matrix(NA, nrow = nrow(df), ncol = cols_to_add))

  # Generate names for the new columns, continuing after the existing ones
  names(new_cols) <- paste0(col_prefix, (num_cols + 1):max_cols)

  # Combine the original data frame with the new columns
  # The use of `cbind` ensures the new columns are appended
  cbind(df, new_cols)
}

# Build the log entry for a single request.
#
# @param meta tibble with the meta data of the response; its columns come first
# @param type label of the request step
# @param prompt prompt text that was sent
# @param res request result providing `usage` and `response`
# @return a tibble with the `meta` columns followed by `type`,
#   `prompt_tokens`, `completion_tokens`, `prompt` and `response`
.log_request <- function(meta, type, prompt, res) {

  token_usage <- res$usage

  tibble(
    meta,
    type = type,
    prompt_tokens = token_usage$prompt_tokens,
    completion_tokens = token_usage$completion_tokens,
    prompt = prompt,
    response = res$response
  )

}

# Check a parsed response against the survey definition.
#
# The checks run in the order below and the first failing one decides the
# result; its explanation is also printed with `message()`:
#   c_length_mismatch  number of rated considerations differs from the survey
#   p_length_mismatch  number of ranked policies differs from the survey
#   c_invalid_values   a consideration rating is below 1 or above `scale_max`
#   p_invalid_values   a policy rank is below 1 or above the number of policies
#   p_duplicate_ranks  a policy rank is used more than once
#   c_not_q_method     Q-method survey whose ratings fail the normality check
#   c_all_equal        all considerations have the same rating
#
# @param considerations one-row data frame with columns `C<number>`
# @param policies one-row data frame with columns `P<number>`
# @param dri_survey survey list with `considerations`, `policies`,
#   `scale_max` and `q_method`
# @return a one-row tibble with `is_valid` and `invalid_reason` (NA if valid)
.is_valid_response <- function(considerations, policies, dri_survey) {

  # Extract relevant data: only columns that actually hold a value count
  c_ranks <- considerations %>% select(matches("^C\\d+$") & where(~!all(is.na(.))))
  p_ranks <- policies %>% select(matches("^P\\d+$") & where(~!all(is.na(.))))
  scale_max <- dri_survey$scale_max
  q_method <- dri_survey$q_method

  # plain vectors for the value checks; the statistics used by the
  # quasi-normality check do not accept a data frame
  c_values <- unlist(c_ranks, use.names = FALSE)
  p_values <- unlist(p_ranks, use.names = FALSE)

  # print the explanation and build the result for a failed check
  invalid <- function(reason, note) {
    message(note)
    tibble(
      is_valid = FALSE,
      invalid_reason = reason
    )
  }

  # Check if data is valid (length mismatch)
  if (ncol(c_ranks) != nrow(dri_survey$considerations)) {
    return(invalid(
      "c_length_mismatch",
      paste("- Considerations length mismatch (", ncol(c_ranks), "/", nrow(dri_survey$considerations), ").")
    ))
  }

  if (ncol(p_ranks) != nrow(dri_survey$policies)) {
    return(invalid(
      "p_length_mismatch",
      paste("- Policies length mismatch (", ncol(p_ranks), "/", nrow(dri_survey$policies), ").")
    ))
  }

  # Check if c_ranks contains invalid values.
  # `na.rm` keeps the condition defined when `scale_max` is NA; in that case
  # only the lower bound can be checked.
  if (isTRUE(any(c_values > scale_max | c_values < 1, na.rm = TRUE))) {
    return(invalid(
      "c_invalid_values",
      "- Consideration ranks contain invalid values."
    ))
  }

  # Check if p_ranks contains invalid values
  if (isTRUE(any(p_values > ncol(p_ranks) | p_values < 1, na.rm = TRUE))) {
    return(invalid(
      "p_invalid_values",
      "- Policy ranks contain invalid values."
    ))
  }

  # Check for duplicate values in p_ranks
  if (ncol(p_ranks) != length(unique(p_values))) {
    return(invalid(
      "p_duplicate_ranks",
      "- Policy ranks contains duplicate values."
    ))
  }

  # Check for quasi-normality, only for surveys flagged as Q-method.
  # A missing `q_method` (NA / NULL) is treated as "not a Q-method survey".
  uses_q_method <- isTRUE(as.logical(q_method))
  if (uses_q_method && !isTRUE(.quasi_normality_check(c_values))) {
    return(invalid(
      "c_not_q_method",
      "- Considerations do not follow a Fixed Quasi-Normal Distribution."
    ))
  }

  # Check if all considerations are the same value
  if (length(unique(c_values)) == 1) {
    return(invalid(
      "c_all_equal",
      "- All considerations have the same rating."
    ))
  }

  tibble(
    is_valid = TRUE,
    invalid_reason = NA_character_
  )
}


# FIXME: make check more robust
# Rough test of whether a set of ratings looks quasi-normal.
#
# @param ratings numeric vector of ratings, or a list / data frame of numeric
#   columns (e.g. a one-row data frame with one column per statement)
# @return `TRUE` when mean and median differ by less than 10 and the
#   interquartile range is below 30; `FALSE` otherwise, including when
#   `ratings` holds no non-missing value
.quasi_normality_check <- function(ratings) {

  # `median()` stops on a data frame and `mean()` returns NA for one, so
  # flatten the input to a plain numeric vector first
  values <- as.numeric(unlist(ratings, use.names = FALSE))

  # missing values would make `IQR()` fail; ignore them
  values <- values[!is.na(values)]
  if (length(values) == 0L) {
    return(FALSE)
  }

  mean_val <- mean(values)
  median_val <- median(values)
  iqr_val <- IQR(values)

  # Define rough criteria (adjust as needed)
  abs(mean_val - median_val) < 10 && iqr_val < 30
}

