Skip to content

Commit

Permalink
Harrison/jina (langchain-ai#2043)
Browse files Browse the repository at this point in the history
Co-authored-by: numb3r3 <[email protected]>
Co-authored-by: felix-wang <[email protected]>
  • Loading branch information
3 people authored Mar 28, 2023
1 parent d0a56f4 commit eff5eed
Show file tree
Hide file tree
Showing 7 changed files with 1,272 additions and 115 deletions.
18 changes: 18 additions & 0 deletions docs/ecosystem/jina.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Jina

This page covers how to use the Jina ecosystem within LangChain.
It is broken into two parts: installation and setup, and then references to specific Jina wrappers.

## Installation and Setup
- Install the Python SDK with `pip install jina`
- Get a Jina AI Cloud auth token from [here](https://cloud.jina.ai/settings/tokens) and set it as an environment variable (`JINA_AUTH_TOKEN`)

## Wrappers

### Embeddings

There exists a Jina Embeddings wrapper, which you can access with
```python
from langchain.embeddings import JinaEmbeddings
```
For a more detailed walkthrough of this, see [this notebook](../modules/indexes/examples/embeddings.ipynb)
101 changes: 101 additions & 0 deletions docs/modules/models/text_embedding/examples/jina.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "1c0cf975",
"metadata": {},
"source": [
"# Jina\n",
"\n",
"Let's load the Jina Embedding class."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d94c62b4",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import JinaEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "523a09e3",
"metadata": {},
"outputs": [],
"source": [
"embeddings = JinaEmbeddings(jina_auth_token=jina_auth_token, model_name=\"ViT-B-32::openai\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b212bd5a",
"metadata": {},
"outputs": [],
"source": [
"text = \"This is a test document.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "57db66bd",
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b790fd09",
"metadata": {},
"outputs": [],
"source": [
"doc_result = embeddings.embed_documents([text])"
]
},
{
"cell_type": "markdown",
"id": "6f3607a0",
"metadata": {},
"source": [
"In the above example, `ViT-B-32::openai`, OpenAI's pretrained `ViT-B-32` model is used. For a full list of models, see [here](https://cloud.jina.ai/user/inference/model/63dca9df5a0da83009d519cd)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cd5f148e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 2 additions & 0 deletions langchain/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
HuggingFaceInstructEmbeddings,
)
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
from langchain.embeddings.jina import JinaEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings.sagemaker_endpoint import SagemakerEndpointEmbeddings
from langchain.embeddings.self_hosted import SelfHostedEmbeddings
Expand All @@ -24,6 +25,7 @@
"OpenAIEmbeddings",
"HuggingFaceEmbeddings",
"CohereEmbeddings",
"JinaEmbeddings",
"HuggingFaceHubEmbeddings",
"TensorflowHubEmbeddings",
"SagemakerEndpointEmbeddings",
Expand Down
98 changes: 98 additions & 0 deletions langchain/embeddings/jina.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Wrapper around Jina embedding models."""

import os
from typing import Any, Dict, List, Optional

import requests
from pydantic import BaseModel, root_validator

from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env


class JinaEmbeddings(BaseModel, Embeddings):
client: Any #: :meta private:

model_name: str = "ViT-B-32::openai"
"""Model name to use."""

jina_auth_token: Optional[str] = None
jina_api_url: str = "https://api.clip.jina.ai/api/v1/models/"
request_headers: Optional[dict] = None

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that auth token exists in environment."""
# Set Auth
jina_auth_token = get_from_dict_or_env(
values, "jina_auth_token", "JINA_AUTH_TOKEN"
)
values["jina_auth_token"] = jina_auth_token
values["request_headers"] = (("authorization", jina_auth_token),)

# Test that package is installed
try:
import jina
except ImportError:
raise ValueError(
"Could not import `jina` python package. "
"Please it install it with `pip install jina`."
)

# Setup client
jina_api_url = os.environ.get("JINA_API_URL", values["jina_api_url"])
model_name = values["model_name"]
try:
resp = requests.get(
jina_api_url + f"?model_name={model_name}",
headers={"Authorization": jina_auth_token},
)

if resp.status_code == 401:
raise ValueError(
"The given Jina auth token is invalid. "
"Please check your Jina auth token."
)
elif resp.status_code == 404:
raise ValueError(
f"The given model name `{model_name}` is not valid. "
f"Please go to https://cloud.jina.ai/user/inference "
f"and create a model with the given model name."
)
resp.raise_for_status()

endpoint = resp.json()["endpoints"]["grpc"]
values["client"] = jina.Client(host=endpoint)
except requests.exceptions.HTTPError as err:
raise ValueError(f"Error: {err!r}")
return values

def _post(self, docs: List[Any], **kwargs: Any) -> Any:
payload = dict(inputs=docs, metadata=self.request_headers, **kwargs)
return self.client.post(on="/encode", **payload)

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to Jina's embedding endpoint.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
from docarray import Document, DocumentArray

embeddings = self._post(
docs=DocumentArray([Document(text=t) for t in texts])
).embeddings
return [list(map(float, e)) for e in embeddings]

def embed_query(self, text: str) -> List[float]:
"""Call out to Jina's embedding endpoint.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
from docarray import Document, DocumentArray

embedding = self._post(docs=DocumentArray([Document(text=text)])).embeddings[0]
return list(map(float, embedding))
Loading

0 comments on commit eff5eed

Please sign in to comment.