|
@@ -5,21 +5,29 @@ from datetime import datetime
|
|
|
import os
|
|
|
import streamlit as st
|
|
|
|
|
|
-api_key = os.environ["OPENAI_API_KEY"]
|
|
|
+DEFAULT_API_KEY = os.environ.get("TOGETHER_API_KEY")
|
|
|
+DEFAULT_BASE_URL = "https://api.together.xyz/v1"
|
|
|
+DEFAULT_MODEL = "meta-llama/Llama-3-8b-chat-hf"
|
|
|
+DEFAULT_TEMPERATURE = 0.7
|
|
|
+DEFAULT_MAX_TOKENS = 512
|
|
|
+DEFAULT_TOKEN_BUDGET = 4096
|
|
|
|
|
|
class ConversationManager:
|
|
|
- def __init__(self, api_key, base_url="https://api.openai.com/v1", history_file=None, default_model="gpt-3.5-turbo", default_temperature=0.7, default_max_tokens=150, token_budget=4096):
|
|
|
- self.client = OpenAI(api_key=api_key)
|
|
|
- self.base_url = base_url
|
|
|
- if history_file is None:
|
|
|
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
|
- self.history_file = f"conversation_history_{timestamp}.json"
|
|
|
- else:
|
|
|
- self.history_file = history_file
|
|
|
- self.default_model = default_model
|
|
|
- self.default_temperature = default_temperature
|
|
|
- self.default_max_tokens = default_max_tokens
|
|
|
- self.token_budget = token_budget
|
|
|
+ def __init__(self, api_key=None, base_url=None, model=None, history_file=None, temperature=None, max_tokens=None, token_budget=None):
|
|
|
+ if not api_key:
|
|
|
+ api_key = DEFAULT_API_KEY
|
|
|
+ if not base_url:
|
|
|
+ base_url = DEFAULT_BASE_URL
|
|
|
+
|
|
|
+ self.client = OpenAI(
|
|
|
+ api_key=api_key,
|
|
|
+ base_url=base_url
|
|
|
+ )
|
|
|
+
|
|
|
+ self.model = model if model else DEFAULT_MODEL
|
|
|
+ self.temperature = temperature if temperature is not None else DEFAULT_TEMPERATURE
|
|
|
+ self.max_tokens = max_tokens if max_tokens is not None else DEFAULT_MAX_TOKENS
|
|
|
+ self.token_budget = token_budget if token_budget is not None else DEFAULT_TOKEN_BUDGET
|
|
|
|
|
|
self.system_messages = {
|
|
|
"sassy_assistant": "You are a sassy assistant that is fed up with answering questions.",
|
|
@@ -28,19 +36,18 @@ class ConversationManager:
|
|
|
"custom": "Enter your custom system message here."
|
|
|
}
|
|
|
self.system_message = self.system_messages["sassy_assistant"] # Default persona
|
|
|
-
|
|
|
- self.load_conversation_history()
|
|
|
+ self.conversation_history = [{"role": "system", "content": self.system_message}]
|
|
|
|
|
|
def count_tokens(self, text):
|
|
|
try:
|
|
|
- encoding = tiktoken.encoding_for_model(self.default_model)
|
|
|
+ encoding = tiktoken.encoding_for_model(self.model)
|
|
|
except KeyError:
|
|
|
- print(f"Warning: Model '{self.default_model}' not found. Using 'gpt-3.5-turbo' encoding as default.")
|
|
|
- encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
|
|
|
+ encoding = tiktoken.get_encoding("cl100k_base")
|
|
|
|
|
|
tokens = encoding.encode(text)
|
|
|
return len(tokens)
|
|
|
|
|
|
+
|
|
|
def total_tokens_used(self):
|
|
|
try:
|
|
|
return sum(self.count_tokens(message['content']) for message in self.conversation_history)
|
|
@@ -80,9 +87,9 @@ class ConversationManager:
|
|
|
print(f"An unexpected error occurred while updating the system message in the conversation history: {e}")
|
|
|
|
|
|
def chat_completion(self, prompt, temperature=None, max_tokens=None, model=None):
|
|
|
- temperature = temperature if temperature is not None else self.default_temperature
|
|
|
- max_tokens = max_tokens if max_tokens is not None else self.default_max_tokens
|
|
|
- model = model if model is not None else self.default_model
|
|
|
+ temperature = temperature if temperature is not None else self.temperature
|
|
|
+ max_tokens = max_tokens if max_tokens is not None else self.max_tokens
|
|
|
+ model = model if model is not None else self.model
|
|
|
|
|
|
self.conversation_history.append({"role": "user", "content": prompt})
|
|
|
|
|
@@ -101,35 +108,11 @@ class ConversationManager:
|
|
|
|
|
|
ai_response = response.choices[0].message.content
|
|
|
self.conversation_history.append({"role": "assistant", "content": ai_response})
|
|
|
- self.save_conversation_history()
|
|
|
|
|
|
return ai_response
|
|
|
|
|
|
- def load_conversation_history(self):
|
|
|
- try:
|
|
|
- with open(self.history_file, "r") as file:
|
|
|
- self.conversation_history = json.load(file)
|
|
|
- except FileNotFoundError:
|
|
|
- self.conversation_history = [{"role": "system", "content": self.system_message}]
|
|
|
- except json.JSONDecodeError:
|
|
|
- print("Error reading the conversation history file. Starting with an empty history.")
|
|
|
- self.conversation_history = [{"role": "system", "content": self.system_message}]
|
|
|
-
|
|
|
- def save_conversation_history(self):
|
|
|
- try:
|
|
|
- with open(self.history_file, "w") as file:
|
|
|
- json.dump(self.conversation_history, file, indent=4)
|
|
|
- except IOError as e:
|
|
|
- print(f"An I/O error occurred while saving the conversation history: {e}")
|
|
|
- except Exception as e:
|
|
|
- print(f"An unexpected error occurred while saving the conversation history: {e}")
|
|
|
-
|
|
|
def reset_conversation_history(self):
|
|
|
self.conversation_history = [{"role": "system", "content": self.system_message}]
|
|
|
- try:
|
|
|
- self.save_conversation_history() # Attempt to save the reset history to the file
|
|
|
- except Exception as e:
|
|
|
- print(f"An unexpected error occurred while resetting the conversation history: {e}")
|
|
|
|
|
|
### Streamlit code ###
|
|
|
st.title("Sassy Chatbot :face_with_rolling_eyes:")
|
|
@@ -139,7 +122,7 @@ st.sidebar.header("Options")
|
|
|
|
|
|
# Initialize the ConversationManager object
|
|
|
if 'chat_manager' not in st.session_state:
|
|
|
- st.session_state['chat_manager'] = ConversationManager(api_key)
|
|
|
+ st.session_state['chat_manager'] = ConversationManager()
|
|
|
|
|
|
chat_manager = st.session_state['chat_manager']
|
|
|
|
|
@@ -181,7 +164,4 @@ if user_input:
|
|
|
for message in conversation_history:
|
|
|
if message["role"] != "system":
|
|
|
with st.chat_message(message["role"]):
|
|
|
- st.write(message["content"])
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
+ st.write(message["content"])
|