Queer European MD passionate about IT
Bladeren bron

Merge pull request #197 from dataquestio/llm-class-updates-gp909-and-gp903

Llm class updates gp909 and gp903
acstrahl 6 maanden geleden
bovenliggende
commit
9aafb0b164
2 gewijzigde bestanden met toevoegingen van 58 en 67 verwijderingen
  1. +30 −50
      Mission903Solutions.py
  2. +28 −17
      Mission909Solutions.ipynb

+ 30 - 50
Mission903Solutions.py

@@ -5,21 +5,29 @@ from datetime import datetime
 import os
 import streamlit as st
 
-api_key = os.environ["OPENAI_API_KEY"]
+DEFAULT_API_KEY = os.environ.get("TOGETHER_API_KEY")
+DEFAULT_BASE_URL = "https://api.together.xyz/v1"
+DEFAULT_MODEL = "meta-llama/Llama-3-8b-chat-hf"
+DEFAULT_TEMPERATURE = 0.7
+DEFAULT_MAX_TOKENS = 512
+DEFAULT_TOKEN_BUDGET = 4096
 
 class ConversationManager:
-    def __init__(self, api_key, base_url="https://api.openai.com/v1", history_file=None, default_model="gpt-3.5-turbo", default_temperature=0.7, default_max_tokens=150, token_budget=4096):
-        self.client = OpenAI(api_key=api_key)
-        self.base_url = base_url
-        if history_file is None:
-            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-            self.history_file = f"conversation_history_{timestamp}.json"
-        else:
-            self.history_file = history_file
-        self.default_model = default_model
-        self.default_temperature = default_temperature
-        self.default_max_tokens = default_max_tokens
-        self.token_budget = token_budget
+    def __init__(self, api_key=None, base_url=None, model=None, history_file=None, temperature=None, max_tokens=None, token_budget=None):
+        if not api_key:
+            api_key = DEFAULT_API_KEY
+        if not base_url:
+            base_url = DEFAULT_BASE_URL
+            
+        self.client = OpenAI(
+            api_key=api_key,
+            base_url=base_url
+        )
+
+        self.model = model if model else DEFAULT_MODEL
+        self.temperature = temperature if temperature else DEFAULT_TEMPERATURE
+        self.max_tokens = max_tokens if max_tokens else DEFAULT_MAX_TOKENS
+        self.token_budget = token_budget if token_budget else DEFAULT_TOKEN_BUDGET
 
         self.system_messages = {
             "sassy_assistant": "You are a sassy assistant that is fed up with answering questions.",
@@ -28,19 +36,18 @@ class ConversationManager:
             "custom": "Enter your custom system message here."
         }
         self.system_message = self.system_messages["sassy_assistant"]  # Default persona
-
-        self.load_conversation_history()
+        self.conversation_history = [{"role": "system", "content": self.system_message}]
 
     def count_tokens(self, text):
         try:
-            encoding = tiktoken.encoding_for_model(self.default_model)
+            encoding = tiktoken.encoding_for_model(self.model)
         except KeyError:
-            print(f"Warning: Model '{self.default_model}' not found. Using 'gpt-3.5-turbo' encoding as default.")
-            encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
+            encoding = tiktoken.get_encoding("cl100k_base")
 
         tokens = encoding.encode(text)
         return len(tokens)
 
+
     def total_tokens_used(self):
         try:
             return sum(self.count_tokens(message['content']) for message in self.conversation_history)
@@ -80,9 +87,9 @@ class ConversationManager:
             print(f"An unexpected error occurred while updating the system message in the conversation history: {e}")
 
     def chat_completion(self, prompt, temperature=None, max_tokens=None, model=None):
-        temperature = temperature if temperature is not None else self.default_temperature
-        max_tokens = max_tokens if max_tokens is not None else self.default_max_tokens
-        model = model if model is not None else self.default_model
+        temperature = temperature if temperature is not None else self.temperature
+        max_tokens = max_tokens if max_tokens is not None else self.max_tokens
+        model = model if model is not None else self.model
 
         self.conversation_history.append({"role": "user", "content": prompt})
 
@@ -101,35 +108,11 @@ class ConversationManager:
 
         ai_response = response.choices[0].message.content
         self.conversation_history.append({"role": "assistant", "content": ai_response})
-        self.save_conversation_history()
 
         return ai_response
     
-    def load_conversation_history(self):
-        try:
-            with open(self.history_file, "r") as file:
-                self.conversation_history = json.load(file)
-        except FileNotFoundError:
-            self.conversation_history = [{"role": "system", "content": self.system_message}]
-        except json.JSONDecodeError:
-            print("Error reading the conversation history file. Starting with an empty history.")
-            self.conversation_history = [{"role": "system", "content": self.system_message}]
-
-    def save_conversation_history(self):
-        try:
-            with open(self.history_file, "w") as file:
-                json.dump(self.conversation_history, file, indent=4)
-        except IOError as e:
-            print(f"An I/O error occurred while saving the conversation history: {e}")
-        except Exception as e:
-            print(f"An unexpected error occurred while saving the conversation history: {e}")
-
     def reset_conversation_history(self):
         self.conversation_history = [{"role": "system", "content": self.system_message}]
-        try:
-            self.save_conversation_history()  # Attempt to save the reset history to the file
-        except Exception as e:
-            print(f"An unexpected error occurred while resetting the conversation history: {e}")
 
 ### Streamlit code ###
 st.title("Sassy Chatbot :face_with_rolling_eyes:")
@@ -139,7 +122,7 @@ st.sidebar.header("Options")
 
 # Initialize the ConversationManager object
 if 'chat_manager' not in st.session_state:
-    st.session_state['chat_manager'] = ConversationManager(api_key)
+    st.session_state['chat_manager'] = ConversationManager()
 
 chat_manager = st.session_state['chat_manager']
 
@@ -181,7 +164,4 @@ if user_input:
 for message in conversation_history:
     if message["role"] != "system":
         with st.chat_message(message["role"]):
-            st.write(message["content"])
-
-
-
+            st.write(message["content"])

+ 28 - 17
Mission909Solutions.ipynb

@@ -28,7 +28,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## API Variables"
+    "## Default Global Variables"
    ]
   },
   {
@@ -37,9 +37,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "api_key = os.environ[\"OPENAI_API_KEY\"] # or paste your API key here\n",
-    "base_url = \"https://api.openai.com/v1\"\n",
-    "model_name =\"gpt-3.5-turbo\""
+    "DEFAULT_API_KEY = os.environ.get(\"TOGETHER_API_KEY\")\n",
+    "DEFAULT_BASE_URL = \"https://api.together.xyz/v1\"\n",
+    "DEFAULT_MODEL = \"meta-llama/Llama-3-8b-chat-hf\"\n",
+    "DEFAULT_TEMPERATURE = 0.7\n",
+    "DEFAULT_MAX_TOKENS = 512\n",
+    "DEFAULT_TOKEN_BUDGET = 4096"
    ]
   },
   {
@@ -62,18 +65,26 @@
     "    \"\"\"\n",
     "\n",
     "    # The __init__ method stores the API key, the base URL, the default model, the default temperature, the default max tokens, and the token budget.\n",
-    "    def __init__(self, api_key=api_key, base_url=base_url, history_file=None, default_model=model_name, default_temperature=0.7, default_max_tokens=120, token_budget=1500):\n",
-    "        self.client = OpenAI(api_key=api_key, base_url=base_url)\n",
-    "        self.base_url = base_url\n",
+    "    def __init__(self, api_key=None, base_url=None, model=None, history_file=None, temperature=None, max_tokens=None, token_budget=None):\n",
+    "        if not api_key:\n",
+    "            api_key = DEFAULT_API_KEY\n",
+    "        if not base_url:\n",
+    "            base_url = DEFAULT_BASE_URL\n",
+    "            \n",
+    "        self.client = OpenAI(\n",
+    "            api_key=api_key,\n",
+    "            base_url=base_url\n",
+    "        )\n",
     "        if history_file is None:\n",
     "            timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
     "            self.history_file = f\"conversation_history_{timestamp}.json\"\n",
     "        else:\n",
     "            self.history_file = history_file\n",
-    "        self.default_model = default_model\n",
-    "        self.default_temperature = default_temperature\n",
-    "        self.default_max_tokens = default_max_tokens\n",
-    "        self.token_budget = token_budget\n",
+    "\n",
+    "        self.model = model if model else DEFAULT_MODEL\n",
+    "        self.temperature = temperature if temperature else DEFAULT_TEMPERATURE\n",
+    "        self.max_tokens = max_tokens if max_tokens else DEFAULT_MAX_TOKENS\n",
+    "        self.token_budget = token_budget if token_budget else DEFAULT_TOKEN_BUDGET\n",
     "\n",
     "        self.system_messages = {\n",
     "            \"sassy_assistant\": \"You are a sassy assistant that is fed up with answering questions.\",\n",
@@ -89,7 +100,7 @@
     "    # The count_tokens method counts the number of tokens in a text.\n",
     "    def count_tokens(self, text):\n",
     "        try:\n",
-    "            encoding = tiktoken.encoding_for_model(self.default_model)\n",
+    "            encoding = tiktoken.encoding_for_model(self.model)\n",
     "        except KeyError:\n",
     "            encoding = tiktoken.get_encoding(\"cl100k_base\")\n",
     "\n",
@@ -146,10 +157,10 @@
     "\n",
     "        try:\n",
     "            response = self.client.chat.completions.create(\n",
-    "                model=self.default_model,\n",
-    "                messages=self.conversation_history, # type: ignore\n",
-    "                temperature=self.default_temperature,\n",
-    "                max_tokens=self.default_max_tokens,\n",
+    "                model=self.model,\n",
+    "                messages=self.conversation_history,\n",
+    "                temperature=self.temperature,\n",
+    "                max_tokens=self.max_tokens,\n",
     "            )\n",
     "        except Exception as e:\n",
     "            print(f\"An error occurred while generating a response: {e}\")\n",
@@ -204,7 +215,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "conv_manager = ConversationManager(api_key)"
+    "conv_manager = ConversationManager()"
    ]
   },
   {