ThalamOS
A powerful Flask web application designed to enhance your storage management.
Loading...
Searching...
No Matches
ollama_manager.py
Go to the documentation of this file.
"""
This module manages interactions with the Ollama model and the SQLite database.

It includes functions to set up the Ollama pipeline, check if the Ollama feature is enabled,
fetch Ollama models, load the storage table into a pandas DataFrame, and ask questions to the
Ollama model.

Importing this module loads environment variables from ``data/.env`` and builds the
module-level ``global_Pipeline`` by calling ``setup_ollama()``.

Classes:
    SQLQuery: A haystack component that executes SQL queries on the SQLite database
        (opened read-only).

Functions:
    setup_ollama: Sets up the Ollama pipeline.
    pre_check_ollama_enabled: Checks if the Ollama feature is enabled.
    check_ollama_enabled: A decorator to check if Ollama is enabled before executing a function.
    get_ollama_models: Fetches a list of Ollama model names from the configured Ollama host.
    storage_table_to_csv: Loads the 'storage' table from the SQLite database into a pandas
        DataFrame (despite the name, no CSV file is written).
    ask_question: Asks a question to the Ollama model using a haystack pipeline.
"""
18from functools import wraps
19import os
20from typing import Annotated, List
21import sqlite3
22import requests
23from dotenv import load_dotenv
24from haystack import Pipeline, component
25from haystack.components.builders import PromptBuilder
26from haystack.components.routers import ConditionalRouter
27from haystack_integrations.components.generators.ollama import OllamaGenerator
28import pandas as pd
29
30from logger_config import logger
31
# The .env file and the SQLite database both live in the "data" folder next to this module.
ENV_PATH: Annotated[str, "path to environment variables"] = os.path.join(
    os.path.dirname(__file__), "data/.env"
)
DATABASE_PATH: Annotated[str, "path to database"] = os.path.join(
    os.path.dirname(__file__), "data/storage.db"
)

# Populate os.environ from data/.env before any of the settings below are read.
load_dotenv(dotenv_path=ENV_PATH)

# Ollama is opt-in: only the string "true" (any letter case) enables it.
is_ollama_enabled: Annotated[bool, "environment variable for ollama enabled"] = (
    os.environ.get("IS_OLLAMA_ENABLED", "false").lower() == "true"
)

# NOTE(review): both values are None when the variable is unset; nothing here validates them.
ollama_host: Annotated[str, "environment variable for ollama host"] = os.environ.get(
    "OLLAMA_HOST"
)
default_model: Annotated[str, "environment variable for default model"] = (
    os.environ.get("DEFAULT_MODEL")
)
50
# Prompt for the first LLM call: turns a user question into a single SQLite query
# against the `storage` table, or the literal "no_answer". The setup_ollama() router
# inspects the reply for "no_answer" to pick the SQL path or the fallback path.
# Template variables: {{columns}} (column names) and {{question}} (user text).
prompt_instance = PromptBuilder(
    template="""The table **`storage`** contains information about items stored inside a shelf. It has the following columns: **{{columns}}**.

### Rules for Generating an SQL Query:
1. **Only generate a query if the question is directly answerable using only the `storage` table**.
2. **First, check if the question is about stored items**. If the question is about anything else (e.g., news, weather, prices, etc.), return `"no_answer"` and **do not generate a query**.
3. **The `info` column contains JSON data**. To access specific fields within the JSON, use the appropriate SQL functions for JSON extraction:
   - For **SQLite**, use `JSON_EXTRACT(info, '$.field_name')` to extract a field.
4. **If the question asks for ordering based on a field inside the `info` JSON** (like `length`), **extract the field from the JSON** and **order by it** (use `CAST(JSON_EXTRACT(info, '$.length') AS INTEGER)`).
5. **Do not use columns like `id` or `position` for ordering unless explicitly mentioned**. If the question is about sorting based on a JSON value (like `length`), **use the JSON extraction in the `ORDER BY` clause**.
6. **If the question asks for the "longest screw" or something similar**, order by `length` from the JSON data, **not by `id`** or other irrelevant columns.
7. **If the question cannot be answered with the given columns**, return exactly `"no_answer"` (without explanation).
8. **Do not attempt to match unrelated concepts to column names**.
9. **Do not modify the database (no DELETE, UPDATE, or INSERT operations).**
10. **Ensure the query returns the entire row of the matched item**.
11. **Ensure the SQL syntax is correct and valid for SQLite**.
12. **The possible values for the `type` column are: `screw`, `nail`, `display`, `cable`, `misc`, `motor-driver`**. These types are case-sensitive.
13. Always take columns other than `info` into account when answering the question, for example `name` or `type`.


**Output (only one of the following):**
- A valid **SQL query** that returns the row that matches the user request, if and only if the question is **directly answerable**. **Do not output anything except the SQL query** (no explanations and no Markdown code fences).
- `"no_answer"` (exactly this string) if the question is irrelevant or unanswerable.


**Question:** {{question}}
"""
)

# Prompt for the fallback LLM call, used when the first reply contained "no_answer".
# {{question}} arrives from the router's go_to_fallback output; {{columns}} is passed
# in directly by ask_question().
fallback_prompt_instance = PromptBuilder(
    template="""User entered a query that cannot be answered with the given table.
    The query was: {{question}} and the table had columns: {{columns}}.
    Let the user know why the question cannot be answered using the table, but try to answer it with your general knowledge."""
)
85
86
@component
class SQLQuery:
    """
    A haystack component that executes SQL queries on the SQLite database.

    Attributes:
        connection (sqlite3.Connection): A read-only connection to the SQLite database.

    Methods:
        run(queries: List[str]) -> dict: Executes the provided SQL queries and returns the results.
    """

    def __init__(self, sql_database: str):
        # mode=ro opens the file read-only, so LLM-generated SQL cannot modify it.
        # check_same_thread=False: this component is created when the module is imported,
        # but run() may be invoked from a different thread (presumably a web-server worker
        # thread — confirm against how the Flask app is served). With the default
        # check_same_thread=True, sqlite3 raises ProgrammingError in that case. Sharing is
        # acceptable here because the connection never writes.
        self.connection = sqlite3.connect(
            f"file:{sql_database}?mode=ro", uri=True, check_same_thread=False
        )

    @component.output_types(results=List[str], queries=List[str])
    def run(self, queries: List[str]):
        """
        Executes the provided SQL queries and returns the results.

        Each query is run on the read-only connection, and its rows are serialized
        with ``DataFrame.to_json(orient="records")``. If any query fails, the whole
        call reports a single ``"error"`` result instead of partial results.

        Args:
            queries (List[str]): SQL queries to execute (the LLM replies routed here).

        Returns:
            dict: ``{"results": [...], "queries": queries}``. ``results`` holds one
            JSON string per query, or ``["error"]`` if a query failed.
        """
        results = []
        for query in queries:
            logger.info(f"query: {query}")
            try:
                frame = pd.read_sql(query, self.connection)
            except (ValueError, sqlite3.Error, pd.errors.DatabaseError) as e:
                # pandas wraps SQL execution failures from a DBAPI connection in
                # pandas.errors.DatabaseError (not ValueError). Without this, a malformed
                # or write-attempting LLM query would crash the whole pipeline.
                logger.error(f"Error parsing SQL query: {e}")
                return {"results": ["error"], "queries": queries}
            results.append(frame.to_json(orient="records"))

        return {"results": results, "queries": queries}

    def __str__(self):
        return "<SQLQuery Object>"
127
128
def setup_ollama() -> Pipeline:
    """
    Sets up the Ollama pipeline with the necessary components and connections.

    Flow: ``prompt -> llm -> router``. The router then either sends the LLM reply
    to ``sql_querier`` (the reply is treated as SQL) or sends the original question
    to ``fallback_prompt -> fallback_llm`` when the reply contains "no_answer".
    Both generators use ``default_model`` on ``ollama_host``.

    Returns:
        Pipeline: The configured Ollama pipeline.
    """
    sql_query = SQLQuery(DATABASE_PATH)

    # ConditionalRouter selects the first route whose condition is true. Both conditions
    # must therefore test "no_answer" case-insensitively; otherwise a reply like
    # "NO_ANSWER" satisfies the SQL route first and is executed as a query.
    routes = [
        {
            "condition": "{{'no_answer' not in replies[0]|lower}}",
            "output": "{{replies}}",
            "output_name": "sql",
            "output_type": List[str],
        },
        {
            "condition": "{{'no_answer' in replies[0]|lower}}",
            "output": "{{question}}",
            "output_name": "go_to_fallback",
            "output_type": str,
        },
    ]
    router = ConditionalRouter(routes)
    # Two separate generator instances: a pipeline component instance can occupy only one slot.
    llm = OllamaGenerator(model=default_model, url=ollama_host)
    fallback_llm = OllamaGenerator(model=default_model, url=ollama_host)

    conditional_sql_pipeline = Pipeline()
    conditional_sql_pipeline.add_component("prompt", prompt_instance)
    conditional_sql_pipeline.add_component("llm", llm)
    conditional_sql_pipeline.add_component("router", router)
    conditional_sql_pipeline.add_component("fallback_prompt", fallback_prompt_instance)
    conditional_sql_pipeline.add_component("fallback_llm", fallback_llm)
    conditional_sql_pipeline.add_component("sql_querier", sql_query)

    conditional_sql_pipeline.connect("prompt", "llm")
    conditional_sql_pipeline.connect("llm.replies", "router.replies")
    conditional_sql_pipeline.connect("router.sql", "sql_querier.queries")
    conditional_sql_pipeline.connect(
        "router.go_to_fallback", "fallback_prompt.question"
    )
    conditional_sql_pipeline.connect("fallback_prompt", "fallback_llm")

    return conditional_sql_pipeline
178
179
def pre_check_ollama_enabled() -> bool:
    """
    Report whether the Ollama feature is switched on.

    Returns:
        bool: The truthiness of the module-level ``is_ollama_enabled`` flag, which is
        read from IS_OLLAMA_ENABLED when the module is imported.
    """
    return bool(is_ollama_enabled)
189
190
def check_ollama_enabled(func):
    """
    Decorator that runs ``func`` only while the Ollama feature is enabled.

    The flag is checked on every call via ``pre_check_ollama_enabled``, not once at
    decoration time. When Ollama is disabled, the call is skipped, an info message
    naming the function is logged, and None is returned to the caller.

    Args:
        func (callable): The function to guard.

    Returns:
        callable: A ``functools.wraps``-preserving wrapper around ``func``.
    """

    @wraps(func)
    def guarded(*args, **kwargs):
        if not pre_check_ollama_enabled():
            logger.info(
                f"Ollama is not enabled. Execution of function {func.__name__} skipped."
            )
            return None
        return func(*args, **kwargs)

    return guarded
215
216
@check_ollama_enabled
def get_ollama_models():
    """
    Fetches the names of the models available on the configured Ollama host.

    Sends a GET request (10 s timeout) to ``{OLLAMA_HOST}/api/tags`` and collects
    the ``model`` field of every entry in the response's ``models`` list.

    Returns:
        list: Model names (strings); empty if the response has no ``models`` key.
        The ``check_ollama_enabled`` decorator returns None instead when Ollama
        is disabled.

    Raises:
        requests.exceptions.RequestException: for an invalid/unset host, connection
            failures, or timeouts. A JSON decoding error propagates if the body is
            not JSON.
    """
    # Tolerate a trailing slash in OLLAMA_HOST so the path is "/api/tags", not "//api/tags".
    # An unset host still yields a scheme-less URL, which requests rejects.
    base_url = (ollama_host or "").rstrip("/")
    response = requests.get(f"{base_url}/api/tags", timeout=10)
    data = response.json()
    return [model["model"] for model in data.get("models", [])]
234
235
def storage_table_to_csv(path: str) -> pd.DataFrame:
    """
    Loads the 'storage' table from a SQLite database into a DataFrame.

    Despite the function's name, nothing is written to disk; the table is only
    read into memory.

    Args:
        path (str): The file path to the SQLite database.

    Returns:
        pd.DataFrame: A DataFrame containing all rows and columns of the 'storage' table.

    Raises:
        pandas.errors.DatabaseError: if the query fails, e.g. the table does not exist.
    """
    conn = sqlite3.connect(path)
    try:
        return pd.read_sql_query("SELECT * FROM storage", conn)
    finally:
        # Close in all cases so a failing query does not leak the connection.
        conn.close()
250
251
@check_ollama_enabled
def ask_question(msg: str):
    """
    Asks a question to the Ollama model using a haystack pipeline.

    The current column names of the 'storage' table are injected into both
    prompts. The pipeline then either executes LLM-generated SQL or falls back
    to a free-text answer.

    Args:
        msg (str): The question to ask the Ollama model.

    Returns:
        tuple: ``("Item", results)``, where ``results`` is the JSON string produced
        by ``SQLQuery.run`` (or ``"error"``). Otherwise ``("Fallback", reply)`` with
        the fallback LLM's text reply. None is returned if the pipeline produced
        neither output, or (via the decorator) when Ollama is disabled.
    """
    columns = storage_table_to_csv(DATABASE_PATH).columns.tolist()
    output = global_Pipeline.run(
        {
            "prompt": {"question": msg, "columns": columns},
            "router": {"question": msg},
            "fallback_prompt": {"columns": columns},
        }
    )

    if "sql_querier" in output:
        sql_output = output["sql_querier"]
        answer = sql_output["results"][0]
        logger.info(
            f"llm answered with the following SQL query: {sql_output['queries']} "
            f"which returned: {answer} with type {type(answer)}"
        )
        return "Item", answer
    if "fallback_llm" in output:
        answer = output["fallback_llm"]["replies"][0]
        logger.info(
            f"llm answered with the following fallback: {answer} with type {type(answer)}"
        )
        return "Fallback", answer

    # Neither branch produced output; make the None return visible in the logs.
    logger.warning(
        f"Pipeline produced neither an SQL result nor a fallback reply for question: {msg}"
    )
    return None
285
286
# Build the haystack pipeline once, at import time; ask_question() reuses this shared
# instance for every question.
# NOTE(review): this runs even when IS_OLLAMA_ENABLED is false, so importing the module
# always constructs the SQLQuery component and both OllamaGenerators.
global_Pipeline = setup_ollama()
run(self, List[str] queries)
__init__(self, str sql_database)
pd.DataFrame storage_table_to_csv(str path)