Skip to content

Commit

Permalink
Merge pull request MichelNivard#91 from calderonsamuel/httr2
Browse files Browse the repository at this point in the history
Migrate internals to httr2 and handle chat completion stream
  • Loading branch information
JamesHWade authored May 17, 2023
2 parents 94b2c0c + 081799f commit 6905643
Show file tree
Hide file tree
Showing 49 changed files with 1,324 additions and 402 deletions.
6 changes: 5 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,22 @@ Imports:
bslib (>= 0.4.2),
cli,
colorspace,
curl,
fontawesome,
glue,
grDevices,
htmltools,
httr,
htmlwidgets,
httr2,
jsonlite,
magrittr,
methods,
purrr,
R6,
rlang,
rstudioapi (>= 0.12),
shiny,
stringr,
usethis,
utils,
waiter
Expand Down
10 changes: 10 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ export(addin_spelling_grammar)
export(check_api)
export(check_api_connection)
export(check_api_key)
export(get_available_endpoints)
export(get_available_models)
export(get_ide_theme_info)
export(gpt_chat)
export(gpt_chat_in_source)
Expand All @@ -19,11 +21,19 @@ export(openai_create_edit)
export(run_chatgpt_app)
import(cli)
import(htmltools)
import(htmlwidgets)
import(rlang)
import(shiny)
importFrom(R6,R6Class)
importFrom(assertthat,assert_that)
importFrom(assertthat,is.count)
importFrom(assertthat,is.number)
importFrom(assertthat,is.string)
importFrom(glue,glue)
importFrom(jsonlite,fromJSON)
importFrom(magrittr,"%>%")
importFrom(purrr,map)
importFrom(purrr,map_chr)
importFrom(rlang,"%||%")
importFrom(stringr,str_remove)
importFrom(stringr,str_split_1)
162 changes: 162 additions & 0 deletions R/StreamHandler.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#' Stream handler for chat completions
#'
#' R6 class that allows to handle chat completions chunk by chunk.
#' It also adds methods to retrieve relevant data. This class DOES NOT make the request.
#'
#' Because `curl::curl_fetch_stream` blocks the R console until the stream finishes,
#' this class can take a shiny session object to handle communication with JS
#' without recurring to a `shiny::observe` inside a module server.
#'
#' @param session The shiny session it will send the message to (optional).
#' @param user_prompt The prompt for the chat completion. Only to be displayed in an HTML tag containing the prompt. (Optional).
#' @importFrom rlang %||%
#' @importFrom magrittr %>%
#' @importFrom R6 R6Class
#' @importFrom stringr str_remove str_split_1
#' @importFrom purrr map_chr map
#' @importFrom jsonlite fromJSON
StreamHandler <- R6::R6Class(
classname = "StreamHandler",
public = list(

#' @field current_value The content of the stream. It updates constantly until the stream ends.
current_value = NULL,

#' @field chunks The list of chunks streamed. It updates constantly until the stream ends.
chunks = list(),

#' @field shinySession Holds the `session` provided at initialization
shinySession = NULL,

#' @field user_message The `user_prompt` provided at initialization after being formatted with markdown.
user_message = NULL,

#' @description Start a StreamHandler. Recommended to be assigned to the `stream_handler` name.
initialize = function(session = NULL, user_prompt = NULL) {
self$current_value <- ""
self$shinySession <- session
self$user_message <- shiny::markdown(user_prompt)
},

#' @description The main reason this class exists. It reduces to stream to chunks and its current value. If the object finds a shiny session will send a `render-stream` message to JS.
#' @param x The streamed element. Preferably after conversion from raw.
handle_streamed_element = function(x) {
translated <- private$translate_element(x)
self$chunks <- c(self$chunks, translated)
self$current_value <- private$convert_chunks_into_response_str()

if (!is.null(self$shinySession)) {
# any communication with JS should be handled here!!
self$shinySession$sendCustomMessage(
type = "render-stream",
message = list(
user = self$user_message,
assistant = shiny::markdown(self$current_value)
)
)
}
},

#' @description Extract the message content as a message ready to be styled or appended to the chat history. Useful after the stream ends.
extract_message = function() {
list(
role = "assistant",
content = self$current_value
)
}
),
private = list(
# Translates a streamed element and converts it to chunk.
# Also handles the case of multiple elements in a single stream.
translate_element = function(x) {
x %>%
stringr::str_remove("^data: ") %>% # handle first element
stringr::str_remove("(\n\ndata: \\[DONE\\])?\n\n$") %>% # handle last element
stringr::str_split_1("\n\ndata: ") %>%
purrr::map(\(x) jsonlite::fromJSON(x, simplifyVector = FALSE))
},
# Reduces the chuks into just the message content.
convert_chunks_into_response_str = function() {
self$chunks %>%
purrr::map_chr(~ .x$choices[[1]]$delta$content %||% "") %>%
paste0(collapse = "")
}
)
)

#' Stream Chat Completion
#'
#' This function sends a prompt to the OpenAI API for chat-based completion and retrieves the streamed response.
#'
#' @param prompt The user's message or prompt.
#' @param history A list of previous messages in the conversation (optional).
#' @param element_callback A callback function to handle each element of the streamed response (optional).
#' @param style The style of the chat conversation (optional). Default is retrieved from the "gptstudio.code_style" option.
#' @param skill The skill to use for the chat conversation (optional). Default is retrieved from the "gptstudio.skill" option.
#' @param model The model to use for chat completion (optional). Default is "gpt-3.5-turbo".
#' @param openai_api_key The OpenAI API key (optional). By default, it is fetched from the "OPENAI_API_KEY" environment variable.
#'
#' @return the same as `curl::curl_fetch_stream`
#'
stream_chat_completion <-
function(prompt,
history = NULL,
element_callback = cat,
style = getOption("gptstudio.code_style"),
skill = getOption("gptstudio.skill"),
model = "gpt-3.5-turbo",
openai_api_key = Sys.getenv("OPENAI_API_KEY")) {
# Set the API endpoint URL
url <- "https://api.openai.com/v1/chat/completions"

# Set the request headers
headers <- list(
"Content-Type" = "application/json",
"Authorization" = paste0("Bearer ", openai_api_key)
)

# Set the new chat history so the system prompt depends
# on the current parameters and not in previous ones
instructions <- list(
list(
role = "system",
content = chat_create_system_prompt(style, skill, in_source = FALSE)
),
list(
role = "user",
content = prompt
)
)

history <- purrr::discard(history, ~ .x$role == "system")

messages <- c(history, instructions)

# Set the request body
body <- list(
"model" = model,
"stream" = TRUE,
"messages" = messages
)

# Create a new curl handle object
handle <- curl::new_handle() %>%
curl::handle_setheaders(.list = headers) %>%
curl::handle_setopt(postfields = jsonlite::toJSON(body, auto_unbox = TRUE)) # request body


# Make the streaming request using curl_fetch_stream()
curl::curl_fetch_stream(
url = url,
fun = \(x) {
element <- rawToChar(x)
element_callback(element) # Do whatever element_callback does
},
handle = handle
)
}

# stream_handler <- StreamHandler$new()

# stream_chat_completion(messages = "Count from 1 to 10")
# stream_chat_completion(messages = "Count from 1 to 10", element_callback = stream_handler$handle_streamed_element)
8 changes: 5 additions & 3 deletions R/addin_chatgpt.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ random_port <- function() {
#' @return This function returns nothing because is meant to run an app as a
#' side effect.
run_app_as_bg_job <- function(appDir = ".", job_name, host, port) {
job_script <- create_tmp_job_script(appDir = appDir,
port = port,
host = host)
job_script <- create_tmp_job_script(
appDir = appDir,
port = port,
host = host
)
rstudioapi::jobRunScript(job_script, name = job_name)
cli::cli_alert_success(
paste0("'", job_name, "'", " initialized as background job in RStudio")
Expand Down
9 changes: 4 additions & 5 deletions R/check_api.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,10 @@ check_api <- function() {
}

simple_api_check <- function(api_key = Sys.getenv("OPENAI_API_KEY")) {
response <- httr::GET(
"https://api.openai.com/v1/models",
httr::add_headers(Authorization = paste0("Bearer ", api_key))
)
httr::status_code(response)
request_base(task = "models", token = api_key) |>
httr2::req_error(is_error = \(resp) FALSE) |>
httr2::req_perform() |>
httr2::resp_status()
}

set_openai_api_key <- function() {
Expand Down
5 changes: 3 additions & 2 deletions R/gpt_queries.R
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ chat_create_system_prompt <- function(style, skill, in_source) {
arg_match(style, c("tidyverse", "base", "no preference"))
arg_match(skill, c("beginner", "intermediate", "advanced", "genius"))
assert_that(is.logical(in_source),
msg = "chat system prompt creation needs logical `in_source`")
msg = "chat system prompt creation needs logical `in_source`"
)

# nolint start
intro <- "You are a helpful chat bot that answers questions for an R programmer working in the RStudio IDE."
Expand All @@ -291,7 +292,7 @@ chat_create_system_prompt <- function(style, skill, in_source) {
} else {
""
}
#nolint end
# nolint end

glue("{intro} {about_skill} {about_style} {in_source_intructions}")
}
80 changes: 71 additions & 9 deletions R/mod_chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ mod_chat_ui <- function(id) {

bslib::card(
class = "h-100",

bslib::card_body(
class = "py-2 h-100",
div(
class = "d-flex flex-column h-100",
div(
class = "p-2 mh-100 overflow-auto",
shiny::uiOutput(ns("all_chats_box")),
welcomeMessageOutput(ns("welcome")),
shiny::uiOutput(ns("history")),
streamingMessageOutput(ns("streaming"))
),
div(
class = "mt-auto",
Expand All @@ -32,12 +33,68 @@ mod_chat_ui <- function(id) {
#'
mod_chat_server <- function(id, ide_colors = get_ide_theme_info()) {
moduleServer(id, function(input, output, session) {
prompt <- mod_prompt_server("prompt", ide_colors)
rv <- reactiveValues()
rv$stream_ended <- 0L

waiter_color <-
if (ide_colors$is_dark) "rgba(255,255,255,0.5)" else "rgba(0,0,0,0.5)"

prompt <- mod_prompt_server("prompt")

output$welcome <- renderWelcomeMessage({
welcomeMessage(ide_colors)
}) %>%
bindEvent(prompt$clear_history)


output$streaming <- renderStreamingMessage({
# This has display: none by default. It is inly shown when receiving an stream
# After the stream is completed it will reset.
streamingMessage(ide_colors)
}) %>%
bindEvent(rv$stream_ended)

output$all_chats_box <- shiny::renderUI({

output$history <- shiny::renderUI({
prompt$chat_history %>%
style_chat_history(ide_colors = ide_colors)
})
}) |>
bindEvent(prompt$chat_history, prompt$clear_history)


shiny::observe({
# waiter::waiter_show(
# html = shiny::tagList(waiter::spin_flower(),
# shiny::h3("Asking ChatGPT...")),
# color = waiter_color
# )

stream_handler <- StreamHandler$new(
session = session,
user_prompt = prompt$input_prompt
)

stream_chat_completion(
prompt = prompt$input_prompt,
history = prompt$chat_history,
style = prompt$input_style,
skill = prompt$input_skill,
element_callback = stream_handler$handle_streamed_element
)

prompt$chat_history <- chat_history_append(
history = prompt$chat_history,
role = "assistant",
content = stream_handler$current_value
)

rv$stream_ended <- rv$stream_ended + 1L

# showNotification("test", session = session)

# waiter::waiter_hide()
}) %>%
shiny::bindEvent(prompt$start_stream, ignoreInit = TRUE)

# testing ----
exportTestValues(
Expand Down Expand Up @@ -90,12 +147,14 @@ style_chat_message <- function(message, ide_colors = get_ide_theme_info()) {

icon_name <- switch(message$role,
"user" = "fas fa-user",
"assistant" = "fas fa-robot")
"assistant" = "fas fa-robot"
)

# nolint start
position_class <- switch(message$role,
"user" = "justify-content-end",
"assistant" = "justify-content-start")
"assistant" = "justify-content-start"
)
# nolint end

htmltools::div(
Expand All @@ -107,8 +166,11 @@ style_chat_message <- function(message, ide_colors = get_ide_theme_info()) {
`background-color` = colors$bg_color
),
fontawesome::fa(icon_name),
htmltools::tagList(
shiny::markdown(message$content)
htmltools::tags$div(
class = glue("{message$role}-message-wrapper"),
htmltools::tagList(
shiny::markdown(message$content)
)
)
)
)
Expand Down
Loading

0 comments on commit 6905643

Please sign in to comment.