Skip to content

Commit

Permalink
Fix content-part encoding and decoding for Google API. (#212)
Browse files Browse the repository at this point in the history
* Make JsonProcessor process ContentPart properly

* Explicitly remove ```json ```

* Add a failing test for #209

* Pass the tests for #209

* Fix JsonProcessor content processing when regex is present

* Add live google ai call tests for messages with image parts
  • Loading branch information
vkryukov authored Dec 12, 2024
1 parent 315e787 commit 693e918
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 14 deletions.
53 changes: 50 additions & 3 deletions lib/chat_models/chat_google_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -211,17 +211,60 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
}
end

def for_api(%Message{} = message) do
# Serializes a plain-text `Message` into the Google API shape: a mapped
# role plus a single-element "parts" list holding the text.
def for_api(%Message{content: content, role: role}) when is_binary(content) do
  %{"role" => map_role(role), "parts" => [%{"text" => content}]}
end

# Serializes a multi-part `Message` (list of ContentParts) for the Google API.
#
# Fix: the role must go through `map_role/1` — as the binary-content clause
# above does — so the role atom is translated to the string value Google
# expects (e.g. `:assistant` -> "model"). The previous version emitted the
# raw atom, producing an invalid role in the encoded request.
def for_api(%Message{content: content} = message) when is_list(content) do
  %{
    "role" => map_role(message.role),
    "parts" => Enum.map(content, &for_api/1)
  }
end

# A text ContentPart maps to a bare text part in the Google request body.
def for_api(%ContentPart{type: :text, content: content}) do
  %{"text" => content}
end

# Supported image types: png, jpeg, webp, heic, heif: https://ai.google.dev/gemini-api/docs/vision?lang=rest#technical-details-image
# Serializes an image ContentPart as inline base64 data with a mime type.
#
# Supported image types: png, jpeg, webp, heic, heif: https://ai.google.dev/gemini-api/docs/vision?lang=rest#technical-details-image
#
# The `:media` option may be one of the supported atoms above, or a binary
# that is assumed to already be a complete mime type (e.g. "image/png") and
# is passed through unchanged — TODO confirm callers pass full mime strings.
# Any other value raises a LangChainError.
#
# Fix: the binary branch previously returned the *literal* string
# "image/type" (a botched interpolation), discarding the caller's value.
def for_api(%ContentPart{type: :image} = part) do
  mime_type =
    case Keyword.get(part.options || [], :media, nil) do
      :png ->
        "image/png"

      type when type in [:jpeg, :jpg] ->
        "image/jpeg"

      :webp ->
        "image/webp"

      :heic ->
        "image/heic"

      :heif ->
        "image/heif"

      # Pass a caller-supplied mime-type string through as-is.
      type when is_binary(type) ->
        type

      other ->
        message = "Received unsupported media type for ContentPart: #{inspect(other)}"
        Logger.error(message)
        raise LangChainError, message
    end

  %{
    "inline_data" => %{
      "mime_type" => mime_type,
      "data" => part.content
    }
  }
end

def for_api(%ToolCall{} = call) do
%{
"functionCall" => %{
Expand Down Expand Up @@ -598,12 +641,16 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
# Wraps a Jason decode failure in a typed "invalid_json" LangChainError,
# logging the problem and keeping the original error for inspection.
def do_process_response(_model, {:error, %Jason.DecodeError{} = response}, _) do
  message = "Received invalid JSON: #{inspect(response)}"
  Logger.error(message)

  exception =
    LangChainError.exception(type: "invalid_json", message: message, original: response)

  {:error, exception}
end

# Catch-all clause: anything not matched by an earlier clause is logged
# and surfaced as an "unexpected_response" error tuple.
def do_process_response(_model, other, _) do
  Logger.error("Trying to process an unexpected response. #{inspect(other)}")

  exception =
    LangChainError.exception(type: "unexpected_response", message: "Unexpected response")

  {:error, exception}
end

@doc false
Expand Down
13 changes: 11 additions & 2 deletions lib/message_processors/json_processor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ defmodule LangChain.MessageProcessors.JsonProcessor do
@spec run(LLMChain.t(), Message.t()) ::
{:cont, Message.t()} | {:halt, Message.t()}
def run(%LLMChain{} = chain, %Message{} = message) do
case Jason.decode(message.processed_content) do
case Jason.decode(content_to_string(message.processed_content)) do
{:ok, parsed} ->
if chain.verbose, do: IO.puts("Parsed JSON text to a map")
{:cont, %Message{message | processed_content: parsed}}
Expand All @@ -122,7 +122,9 @@ defmodule LangChain.MessageProcessors.JsonProcessor do
end

def run(%LLMChain{} = chain, %Message{} = message, regex_pattern) do
case Regex.run(regex_pattern, message.processed_content, capture: :all_but_first) do
case Regex.run(regex_pattern, content_to_string(message.processed_content),
capture: :all_but_first
) do
[json] ->
if chain.verbose, do: IO.puts("Extracted JSON text from message")
# run recursive call on just the extracted JSON
Expand All @@ -132,4 +134,11 @@ defmodule LangChain.MessageProcessors.JsonProcessor do
{:halt, Message.new_user!("ERROR: No JSON found")}
end
end

# Unwraps a single-element list holding one text ContentPart into its raw
# string, so Jason.decode/1 and Regex.run/3 can work on it. Any other
# shape (already a binary, multi-part lists, etc.) passes through untouched.
defp content_to_string([%LangChain.Message.ContentPart{type: :text, content: text}]) do
  text
end

defp content_to_string(content), do: content
end
8 changes: 3 additions & 5 deletions notebooks/context-specific-image-descriptions.livemd
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

```elixir
Mix.install([
{:langchain, "~> 0.3.0-rc.0"},
{:langchain, github: "brainlid/langchain"},
{:kino, "~> 0.12.0"}
])
```
Expand Down Expand Up @@ -181,7 +181,7 @@ image_data_from_other_system = "image of urban art mural on underpass at 507 Kin
%{llm: openai_chat_model, verbose: true}
|> LLMChain.new!()
|> LLMChain.apply_prompt_templates(messages, %{extra_image_info: image_data_from_other_system})
|> LLMChain.message_processors([JsonProcessor.new!()])
|> LLMChain.message_processors([JsonProcessor.new!(~r/```json(.*?)```/s)])
|> LLMChain.run(mode: :until_success)

updated_chain.last_message.processed_content
Expand Down Expand Up @@ -242,7 +242,7 @@ image_data_from_other_system = "image of urban art mural on underpass at 507 Kin
%{llm: anthropic_chat_model, verbose: true}
|> LLMChain.new!()
|> LLMChain.apply_prompt_templates(messages, %{extra_image_info: image_data_from_other_system})
|> LLMChain.message_processors([JsonProcessor.new!()])
|> LLMChain.message_processors([JsonProcessor.new!(~r/```json(.*?)```/s)])
|> LLMChain.run(mode: :until_success)

updated_chain.last_message.processed_content
Expand All @@ -262,5 +262,3 @@ Here's what I got from it:
```

We would want to run multiple tests on a small sampling of images and tweak our prompt until we are happy with the result. Then, we can process full batch and save our work as a template for future projects as well.

<!-- livebook:{"offset":12761,"stamp":{"token":"XCP.W16VHoMa17Ik5HZEODX4xG_3efAZOoT53nwCTV0ILJBlJPOfjaoVorequscNTIpjctd5Dd_rFjn2mYnQ3HBs-HEgL3Ndv-JDxG2NMRBdcbJi_vREiaEJT2lrNKafOvhP9ZvW698i28G9jhon35Zc","version":2}} -->
86 changes: 82 additions & 4 deletions test/chat_models/chat_google_ai_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,47 @@ defmodule ChatModels.ChatGoogleAITest do
} = tool_result
end

# Regression test for #209: a user message mixing a text part and an image
# part must encode into Google's "parts" list with text and inline_data.
test "generate a map containing a text and an image part (bug #209)", %{google_ai: google_ai} do
  system_message = %LangChain.Message{
    role: :system,
    content:
      "You are an expert at providing an image description for assistive technology and SEO benefits."
  }

  user_message = %LangChain.Message{
    role: :user,
    content: [
      %LangChain.Message.ContentPart{type: :text, content: "This is the text."},
      %LangChain.Message.ContentPart{
        type: :image,
        content: "/9j/4AAQSkz",
        options: [media: :jpg, detail: "low"]
      }
    ]
  }

  api_data = ChatGoogleAI.for_api(google_ai, [system_message, user_message], [])
  assert %{"contents" => [msg1]} = api_data

  assert %{
           "parts" => [
             %{"text" => "This is the text."},
             %{
               "inline_data" => %{
                 "mime_type" => "image/jpeg",
                 "data" => "/9j/4AAQSkz"
               }
             }
           ]
         } = msg1
end

test "translates a Message with function results to the expected structure" do
expected =
%{
Expand Down Expand Up @@ -402,7 +443,9 @@ defmodule ChatModels.ChatGoogleAITest do
]
}

assert [{:error, %LangChainError{} = error}] = ChatGoogleAI.do_process_response(model, response)
assert [{:error, %LangChainError{} = error}] =
ChatGoogleAI.do_process_response(model, response)

assert error.type == "changeset"
assert error.message == "role: is invalid"
end
Expand Down Expand Up @@ -483,22 +526,29 @@ defmodule ChatModels.ChatGoogleAITest do
}
}

assert {:error, %LangChainError{} = error} = ChatGoogleAI.do_process_response(model, response)
assert {:error, %LangChainError{} = error} =
ChatGoogleAI.do_process_response(model, response)

assert error.type == nil
assert error.message == "Invalid request"
end

# A Jason.DecodeError must be converted into an "invalid_json" LangChainError.
test "handles Jason.DecodeError", %{model: model} do
  decode_failure = {:error, %Jason.DecodeError{}}

  assert {:error, %LangChainError{} = error} =
           ChatGoogleAI.do_process_response(model, decode_failure)

  assert error.type == "invalid_json"
  assert "Received invalid JSON:" <> _ = error.message
end

# Anything unrecognized falls into the catch-all "unexpected_response" error.
test "handles unexpected response with error", %{model: model} do
  unrecognized = %{}

  assert {:error, %LangChainError{} = error} =
           ChatGoogleAI.do_process_response(model, unrecognized)

  assert error.type == "unexpected_response"
  assert error.message == "Unexpected response"
end
Expand Down Expand Up @@ -766,4 +816,32 @@ defmodule ChatModels.ChatGoogleAITest do
assert message.role == :assistant
end
end

# Live API test (opt-in via tags): sends a real image to Gemini and checks
# the generated description mentions the subject of the photo.
@tag live_call: true, live_google_ai: true
test "image classification with Google AI model" do
  alias LangChain.Chains.LLMChain
  alias LangChain.Message
  alias LangChain.Message.ContentPart
  alias LangChain.Utils.ChainResult

  chat_model = ChatGoogleAI.new!(%{temperature: 0, stream: false, model: "gemini-1.5-flash"})

  encoded_image =
    "test/support/images/barn_owl.jpg"
    |> File.read!()
    |> Base.encode64()

  user_message =
    Message.new_user!([
      ContentPart.text!("Please describe the image."),
      ContentPart.image!(encoded_image, media: :jpg)
    ])

  {:ok, updated_chain} =
    %{llm: chat_model, verbose: false, stream: false}
    |> LLMChain.new!()
    |> LLMChain.add_message(user_message)
    |> LLMChain.run()

  {:ok, description} = ChainResult.to_string(updated_chain)
  assert description =~ "owl"
end
end

0 comments on commit 693e918

Please sign in to comment.