@@ -28,7 +28,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## API Variables"
+    "## Default Global Variables"
    ]
   },
   {
@@ -37,9 +37,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "api_key = os.environ[\"OPENAI_API_KEY\"] # or paste your API key here\n",
-    "base_url = \"https://api.openai.com/v1\"\n",
-    "model_name =\"gpt-3.5-turbo\""
+    "DEFAULT_API_KEY = os.environ.get(\"TOGETHER_API_KEY\")\n",
+    "DEFAULT_BASE_URL = \"https://api.together.xyz/v1\"\n",
+    "DEFAULT_MODEL = \"meta-llama/Llama-3-8b-chat-hf\"\n",
+    "DEFAULT_TEMPERATURE = 0.7\n",
+    "DEFAULT_MAX_TOKENS = 512\n",
+    "DEFAULT_TOKEN_BUDGET = 4096"
    ]
   },
   {
@@ -62,18 +65,26 @@
     "    \"\"\"\n",
     "\n",
-    "    # The __init__ method stores the API key, the base URL, the default model, the default temperature, the default max tokens, and the token budget.\n",
+    "    # The __init__ method configures the API client and stores the model, temperature, max tokens, and token budget, falling back to the defaults above.\n",
- " def __init__(self, api_key=api_key, base_url=base_url, history_file=None, default_model=model_name, default_temperature=0.7, default_max_tokens=120, token_budget=1500):\n",
|
|
|
- " self.client = OpenAI(api_key=api_key, base_url=base_url)\n",
|
|
|
- " self.base_url = base_url\n",
|
|
|
+ " def __init__(self, api_key=None, base_url=None, model=None, history_file=None, temperature=None, max_tokens=None, token_budget=None):\n",
|
|
|
+ " if not api_key:\n",
|
|
|
+ " api_key = DEFAULT_API_KEY\n",
|
|
|
+ " if not base_url:\n",
|
|
|
+ " base_url = DEFAULT_BASE_URL\n",
|
|
|
+ " \n",
|
|
|
+ " self.client = OpenAI(\n",
|
|
|
+ " api_key=api_key,\n",
|
|
|
+ " base_url=base_url\n",
|
|
|
+ " )\n",
|
|
|
" if history_file is None:\n",
|
|
|
" timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
|
|
|
" self.history_file = f\"conversation_history_{timestamp}.json\"\n",
|
|
|
" else:\n",
|
|
|
" self.history_file = history_file\n",
|
|
|
- " self.default_model = default_model\n",
|
|
|
- " self.default_temperature = default_temperature\n",
|
|
|
- " self.default_max_tokens = default_max_tokens\n",
|
|
|
- " self.token_budget = token_budget\n",
|
|
|
+ "\n",
|
|
|
+ " self.model = model if model else DEFAULT_MODEL\n",
|
|
|
+ " self.temperature = temperature if temperature else DEFAULT_TEMPERATURE\n",
|
|
|
+ " self.max_tokens = max_tokens if max_tokens else DEFAULT_MAX_TOKENS\n",
|
|
|
+ " self.token_budget = token_budget if token_budget else DEFAULT_TOKEN_BUDGET\n",
     "\n",
     "        self.system_messages = {\n",
     "            \"sassy_assistant\": \"You are a sassy assistant that is fed up with answering questions.\",\n",
@@ -89,7 +100,7 @@
     "    # The count_tokens method counts the number of tokens in a text.\n",
     "    def count_tokens(self, text):\n",
     "        try:\n",
-    "            encoding = tiktoken.encoding_for_model(self.default_model)\n",
+    "            encoding = tiktoken.encoding_for_model(self.model)\n",
     "        except KeyError:\n",
     "            encoding = tiktoken.get_encoding(\"cl100k_base\")\n",
     "\n",
@@ -146,10 +157,10 @@
     "\n",
     "        try:\n",
     "            response = self.client.chat.completions.create(\n",
-    "                model=self.default_model,\n",
-    "                messages=self.conversation_history, # type: ignore\n",
-    "                temperature=self.default_temperature,\n",
-    "                max_tokens=self.default_max_tokens,\n",
+    "                model=self.model,\n",
+    "                messages=self.conversation_history,\n",
+    "                temperature=self.temperature,\n",
+    "                max_tokens=self.max_tokens,\n",
     "            )\n",
     "        except Exception as e:\n",
     "            print(f\"An error occurred while generating a response: {e}\")\n",
@@ -204,7 +215,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "conv_manager = ConversationManager(api_key)"
+    "conv_manager = ConversationManager()"
    ]
   },
   {