Skip to content

Commit

Permalink
modifying freeze.txt and adjusting max_tokens in models.py
Browse files Browse the repository at this point in the history
reducing freeze.txt to only relevant modules which allow notebook to run in colab. Updating models.py to add option for max_tokens to give concise outputs for the phi2 model
  • Loading branch information
faaiz-25 committed Mar 25, 2024
1 parent c3a763a commit cf496f0
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 236 deletions.
Binary file added .DS_Store
Binary file not shown.
239 changes: 10 additions & 229 deletions freeze.txt
Original file line number Diff line number Diff line change
@@ -1,229 +1,10 @@
aiohttp == 3.9.1
aiosignal == 1.3.1
annotated-types == 0.6.0
anthropic == 0.10.0
anyio == 4.2.0
appnope == 0.1.3
argon2-cffi == 23.1.0
argon2-cffi-bindings == 21.2.0
arrow == 1.3.0
asgiref == 3.7.2
asttokens == 2.4.1
async-lru == 2.0.4
attrs == 23.2.0
Automat == 22.10.0
Babel == 2.14.0
backoff == 2.2.1
bcrypt == 4.1.2
beautifulsoup4 == 4.12.3
bleach == 6.1.0
build == 1.0.3
cachetools == 5.3.2
certifi == 2023.11.17
cffi == 1.16.0
charset-normalizer == 3.3.2
chroma-hnswlib == 0.7.3
chromadb == 0.4.22
click == 8.1.7
coloredlogs == 15.0.1
comm == 0.2.1
constantly == 23.10.4
cryptography == 42.0.0
cssselect == 1.2.0
dataclasses-json == 0.6.3
debugpy == 1.8.0
decorator == 5.1.1
defusedxml == 0.7.1
Deprecated == 1.2.14
distro == 1.9.0
executing == 2.0.1
fastapi == 0.109.0
fastjsonschema == 2.19.1
filelock == 3.13.1
flatbuffers == 23.5.26
fqdn == 1.5.1
frozenlist == 1.4.1
fsspec == 2023.12.2
google-auth == 2.26.2
googleapis-common-protos == 1.62.0
greenlet == 3.0.3
grpcio == 1.60.0
h11 == 0.14.0
httpcore == 1.0.2
httptools == 0.6.1
httpx == 0.26.0
huggingface-hub == 0.20.3
humanfriendly == 10.0
hyperlink == 21.0.0
idna == 3.6
importlib-metadata == 6.11.0
importlib-resources == 6.1.1
incremental == 22.10.0
iniconfig == 2.0.0
ipykernel == 6.29.0
ipython == 8.20.0
ipywidgets == 8.1.1
isoduration == 20.11.0
itemadapter == 0.8.0
itemloaders == 1.1.0
jedi == 0.19.1
Jinja2 == 3.1.3
jmespath == 1.0.1
joblib == 1.3.2
json5 == 0.9.14
jsonlines == 4.0.0
jsonpatch == 1.33
jsonpointer == 2.4
jsonschema == 4.21.1
jsonschema-specifications == 2023.12.1
jupyter == 1.0.0
jupyter-console == 6.6.3
jupyter-events == 0.9.0
jupyter-lsp == 2.2.2
jupyter_client == 8.6.0
jupyter_core == 5.7.1
jupyter_server == 2.12.5
jupyter_server_terminals == 0.5.2
jupyterlab == 4.0.11
jupyterlab-widgets == 3.0.9
jupyterlab_pygments == 0.3.0
jupyterlab_server == 2.25.2
kubernetes == 29.0.0
langchain == 0.1.2
langchain-community == 0.0.14
langchain-core == 0.1.14
langsmith == 0.0.83
lxml == 5.1.0
MarkupSafe == 2.1.4
marshmallow == 3.20.2
matplotlib-inline == 0.1.6
mistune == 3.0.2
mmh3 == 4.1.0
monotonic == 1.6
mpmath == 1.3.0
multidict == 6.0.4
mypy-extensions == 1.0.0
nbclient == 0.9.0
nbconvert == 7.14.2
nbformat == 5.9.2
nest-asyncio == 1.6.0
networkx == 3.2.1
nltk == 3.8.1
notebook == 7.0.7
notebook_shim == 0.2.3
numpy == 1.26.3
oauthlib == 3.2.2
onnxruntime == 1.16.3
opentelemetry-api == 1.22.0
opentelemetry-exporter-otlp-proto-common == 1.22.0
opentelemetry-exporter-otlp-proto-grpc == 1.22.0
opentelemetry-instrumentation == 0.43b0
opentelemetry-instrumentation-asgi == 0.43b0
opentelemetry-instrumentation-fastapi == 0.43b0
opentelemetry-proto == 1.22.0
opentelemetry-sdk == 1.22.0
opentelemetry-semantic-conventions == 0.43b0
opentelemetry-util-http == 0.43b0
overrides == 7.6.0
packaging == 23.2
pandas == 2.2.0
pandocfilters == 1.5.1
parsel == 1.8.1
parso == 0.8.3
pexpect == 4.9.0
pillow == 10.2.0
pip == 23.2.1
platformdirs == 4.1.0
pluggy == 1.3.0
posthog == 3.3.2
prometheus-client == 0.19.0
prompt-toolkit == 3.0.43
Protego == 0.3.0
protobuf == 4.25.2
psutil == 5.9.8
ptyprocess == 0.7.0
pulsar-client == 3.4.0
pure-eval == 0.2.2
pyasn1 == 0.5.1
pyasn1-modules == 0.3.0
pycparser == 2.21
pydantic == 2.5.3
pydantic_core == 2.14.6
PyDispatcher == 2.0.7
Pygments == 2.17.2
pyOpenSSL == 24.0.0
PyPika == 0.48.9
pyproject_hooks == 1.0.0
pytest == 7.4.4
pytest-html == 4.1.1
pytest-metadata == 3.0.0
python-dateutil == 2.8.2
python-dotenv == 1.0.1
python-json-logger == 2.0.7
pytz == 2023.3.post1
PyYAML == 6.0.1
pyzmq == 25.1.2
qtconsole == 5.5.1
QtPy == 2.4.1
queuelib == 1.6.2
referencing == 0.32.1
regex == 2023.12.25
requests == 2.31.0
requests-file == 1.5.1
requests-oauthlib == 1.3.1
rfc3339-validator == 0.1.4
rfc3986-validator == 0.1.1
rpds-py == 0.17.1
rsa == 4.9
safetensors == 0.4.2
scikit-learn == 1.4.0
scipy == 1.12.0
Scrapy == 2.11.0
Send2Trash == 1.8.2
sentence-transformers == 2.2.2
sentencepiece == 0.1.99
service-identity == 24.1.0
setuptools == 69.0.3
six == 1.16.0
sniffio == 1.3.0
soupsieve == 2.5
SQLAlchemy == 2.0.25
stack-data == 0.6.3
starlette == 0.35.1
sympy == 1.12
# tdqm == 0.0.1  -- removed: "tdqm" is a typosquat of tqdm; the real tqdm == 4.66.1 is pinned below
tenacity == 8.2.3
terminado == 0.18.0
threadpoolctl == 3.2.0
tinycss2 == 1.2.1
tldextract == 5.1.1
tokenizers == 0.15.1
toml == 0.10.2
torch == 2.1.2
torchvision == 0.16.2
tornado == 6.4
tqdm == 4.66.1
traitlets == 5.14.1
transformers == 4.37.0
Twisted == 22.10.0
typer == 0.9.0
types-python-dateutil == 2.8.19.20240106
typing-inspect == 0.9.0
typing_extensions == 4.9.0
tzdata == 2023.4
uri-template == 1.3.0
urllib3 == 2.1.0
uvicorn == 0.27.0
uvloop == 0.19.0
w3lib == 2.1.2
watchfiles == 0.21.0
wcwidth == 0.2.13
webcolors == 1.13
webencodings == 0.5.1
websocket-client == 1.7.0
websockets == 12.0
widgetsnbextension == 4.0.9
wrapt == 1.16.0
yarl == 1.9.4
zipp == 3.17.0
zope.interface == 6.1
transformers
langchain_community
bitsandbytes
langchain
accelerate
tensorflow == 2.15
chromadb
unstructured
sentence-transformers
faiss-cpu
14 changes: 7 additions & 7 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch


def initialise_phi2():
def initialise_phi2(max_tokens):
"""initialise phi2 model from HuggingFace and output as a langchain model object
"""
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
Expand All @@ -25,12 +25,12 @@ def initialise_phi2():

#load in phi-2 model - a small model with 2B parameters
model_id = "microsoft/phi-2"
#set max tokens to 1000 as small models such as phi-2 will produce verbose outputs
max_new_tokens = 1000


#create hugging face pipeline for phi2 model using the max_tokens parameter
#max_tokens defaults to 500 tokens to produce concise, relevant responses
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id,quantization_config=quantization_config)#, device_map='auto')
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1000)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=max_tokens)

#set logging information to info to avoid warnings
logging.set_verbosity_error()
Expand Down Expand Up @@ -84,7 +84,7 @@ def initialise_anthropic():


class RagPipeline:
def __init__(self, EMBEDDING_MODEL, PERSIST_DIRECTORY, stuff_documents_prompt=STUFF_DOCUMENTS_PROMPT, inject_metadata_prompt=INJECT_METADATA_PROMPT, hyde_prompt = HYDE_PROMPT, device=None, model_type="anthropic"):
def __init__(self, EMBEDDING_MODEL, PERSIST_DIRECTORY, stuff_documents_prompt=STUFF_DOCUMENTS_PROMPT, inject_metadata_prompt=INJECT_METADATA_PROMPT, hyde_prompt = HYDE_PROMPT, device=None, model_type="anthropic", max_tokens=500):

if device is None:
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
Expand All @@ -93,7 +93,7 @@ def __init__(self, EMBEDDING_MODEL, PERSIST_DIRECTORY, stuff_documents_prompt=ST

#if user wants to run phi2 model insert this as the prompt for the stuff documents chain if not default to anthropic prompt
if model_type == 'phi2':
self.llm = initialise_phi2()
self.llm = initialise_phi2(max_tokens=max_tokens)
stuff_documents_prompt = PHI2_PROMPT
else:
self.llm = initialise_anthropic()
Expand Down

0 comments on commit cf496f0

Please sign in to comment.