update MyMagicAI (run-llama#11263)
* update

* address comments
logan-markewich authored Feb 22, 2024
1 parent 9a13dee commit 439e1b6
Showing 3 changed files with 39 additions and 22 deletions.
34 changes: 19 additions & 15 deletions docs/examples/llm/mymagic.ipynb
@@ -60,11 +60,12 @@
 "outputs": [],
 "source": [
 "llm = MyMagicAI(\n",
-"    api_key=\"your_api_key\",\n",
-"    storage_provider=\"your_storage_provider\",  # s3, gcs\n",
-"    bucket_name=\"your_bucket_name\",\n",
-"    session=\"your_session\",  # files should be located in this folder on which batch inference will be run\n",
-"    system_prompt=\"Answer the question succinctly\",\n",
+"    api_key=\"your-api-key\",\n",
+"    storage_provider=\"s3\",  # s3, gcs\n",
+"    bucket_name=\"your-bucket-name\",\n",
+"    session=\"your-session-name\",  # folder containing the files to run batch inference on\n",
+"    role_arn=\"your-role-arn\",\n",
+"    system_prompt=\"your-system-prompt\",\n",
 ")"
 ]
 },
@@ -75,9 +76,9 @@
 "outputs": [],
 "source": [
 "resp = llm.complete(\n",
-"    question=\"Summarize the document!\",\n",
-"    model=\"mistral7b\",\n",
-"    max_tokens=10,  # currently we support mistral7b, llama7b, mixtral8x7b,codellama70b, llama70b, more to come...\n",
+"    question=\"your-question\",\n",
+"    model=\"choose-model\",  # currently we support mistral7b, llama7b, mixtral8x7b, codellama70b, llama70b, more to come...\n",
+"    max_tokens=5,  # number of tokens to generate, default is 10\n",
 ")"
 ]
 },
@@ -116,14 +117,17 @@
 "source": [
 "async def main():\n",
 "    allm = MyMagicAI(\n",
-"        api_key=\"your_api_key\",\n",
-"        storage_provider=\"your_storage_provider\",\n",
-"        bucket_name=\"your_bucket_name\",\n",
-"        session=\"your_session_name\",\n",
-"        system_prompt=\"your_system_prompt\",\n",
+"        api_key=\"your-api-key\",\n",
+"        storage_provider=\"s3\",  # s3, gcs\n",
+"        bucket_name=\"your-bucket-name\",\n",
+"        session=\"your-session-name\",  # folder containing the files to run batch inference on\n",
+"        role_arn=\"your-role-arn\",\n",
+"        system_prompt=\"your-system-prompt\",\n",
 "    )\n",
 "    response = await allm.acomplete(\n",
-"        question=\"your_question\", model=\"mistral7b\", max_tokens=10\n",
+"        question=\"your-question\",\n",
+"        model=\"choose-model\",  # currently we support mistral7b, llama7b, mixtral8x7b, codellama70b, llama70b, more to come...\n",
+"        max_tokens=5,  # number of tokens to generate, default is 10\n",
 "    )\n",
 "\n",
 "    print(\"Async completion response:\", response)"
@@ -135,7 +139,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"asyncio.run(main())"
+"await main()"
 ]
 }
 ],
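The last notebook change swaps `asyncio.run(main())` for `await main()`. Jupyter/IPython already runs an event loop, and calling `asyncio.run` from inside a running loop raises a `RuntimeError`, while a notebook cell can simply await the coroutine at top level. A minimal illustration of the two entry points:

```python
import asyncio


async def main() -> str:
    await asyncio.sleep(0)  # stand-in for the real async work
    return "done"


# Script context: no event loop is running yet, so asyncio.run() creates one.
print(asyncio.run(main()))

# Notebook context: a loop is already running, so asyncio.run() would raise
# "RuntimeError: asyncio.run() cannot be called from a running event loop";
# the cell instead uses top-level `await main()`.
```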
25 changes: 19 additions & 6 deletions llama-index-integrations/llms/llama-index-llms-mymagic/llama_index/llms/mymagic/base.py
@@ -23,6 +23,7 @@ class MyMagicAI(LLM):
     max_tokens: int = Field(
         default=10, description="The maximum number of tokens to generate."
     )
+    question = Field(default="", description="The user question.")
     storage_provider: str = Field(
         default="gcs", description="The storage provider to use."
     )
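One detail in the hunk above: unlike the neighboring fields, the new `question` attribute is declared without a type annotation. If it is intended as a regular Pydantic field like the others, the annotated form would presumably be:

```python
question: str = Field(default="", description="The user question.")
```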
@@ -105,11 +106,17 @@ async def _get_result(self, task_id: str) -> Dict[str, Any]:
         return resp.json()
 
     async def acomplete(
-        self, question: str, model: str, max_tokens: int, poll_interval: float = 1.0
+        self,
+        question: str,
+        model: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        poll_interval: float = 1.0,
     ) -> CompletionResponse:
         self.question_data["question"] = question
-        self.question_data["model"] = model
-        self.question_data["max_tokens"] = max_tokens
+        self.model = self.question_data["model"] = model or self.model
+        self.max_tokens = self.question_data["max_tokens"] = (
+            max_tokens or self.max_tokens
+        )
 
         task_response = await self._submit_question(self.question_data)
         task_id = task_response.get("task_id")
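The diff collapses the method body that follows. Judging from the visible pieces (`_get_result`, the `poll_interval` parameter, and the `await asyncio.sleep(poll_interval)` context line in the next hunk), the method presumably polls until the task completes. A sketch of that pattern under those assumptions — the `get_result` callable stands in for `MyMagicAI._get_result`, and the `"status"` key and its values are guesses, not confirmed by the diff:

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict


async def poll_until_done(
    get_result: Callable[[str], Awaitable[Dict[str, Any]]],
    task_id: str,
    poll_interval: float = 1.0,
) -> Dict[str, Any]:
    # Repeatedly fetch the task result, sleeping between polls, until the
    # task leaves the pending state. The "status" field is an assumption.
    while True:
        result = await get_result(task_id)
        if result.get("status") != "PENDING":
            return result
        await asyncio.sleep(poll_interval)
```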
@@ -120,11 +127,17 @@ async def acomplete(
             await asyncio.sleep(poll_interval)
 
     def complete(
-        self, question: str, model: str, max_tokens: int, poll_interval: float = 1.0
+        self,
+        question: str,
+        model: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        poll_interval: float = 1.0,
     ) -> CompletionResponse:
         self.question_data["question"] = question
-        self.question_data["model"] = model
-        self.question_data["max_tokens"] = max_tokens
+        self.model = self.question_data["model"] = model or self.model
+        self.max_tokens = self.question_data["max_tokens"] = (
+            max_tokens or self.max_tokens
+        )
 
         task_response = self._submit_question_sync(self.question_data)
         task_id = task_response.get("task_id")
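One side effect of the new `model or self.model` / `max_tokens or self.max_tokens` fallbacks is worth flagging: `or` falls back on any falsy value, not just `None`, so an explicit `max_tokens=0` would silently become the instance default. A general Python note, not something this commit addresses:

```python
max_tokens = 0  # caller explicitly asks for zero tokens
default = 10

print(max_tokens or default)  # -> 10: 0 is falsy, so the default wins
print(max_tokens if max_tokens is not None else default)  # -> 0: only None falls back
```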
2 changes: 1 addition & 1 deletion llama-index-integrations/llms/llama-index-llms-mymagic/pyproject.toml
@@ -26,7 +26,7 @@ description = "llama-index llms mymagic integration"
 license = "MIT"
 name = "llama-index-llms-mymagic"
 readme = "README.md"
-version = "0.1.0"
+version = "0.1.1"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<3.12"
