|
| 1 | +from __future__ import annotations |
| 2 | +import time |
| 3 | + |
| 4 | +import json |
| 5 | +import logging |
| 6 | +from pathlib import Path |
| 7 | +from typing import Any, Dict, Optional |
| 8 | +import base64 |
| 9 | +from io import BytesIO |
| 10 | +import qrcode |
| 11 | +import requests |
| 12 | + |
| 13 | + |
| 14 | +import oci |
| 15 | +from flask import Flask, jsonify, request, send_from_directory |
| 16 | +from flask_cors import CORS |
| 17 | +from oci.generative_ai_agent_runtime.models import ChatDetails, CreateSessionDetails, FunctionCallingPerformedAction |
| 18 | + |
| 19 | +# --------------------------------------------------------------------------- |
| 20 | +# Configuration UPDATE with your endpoint OCID and your Service Endpoint |
| 21 | +# --------------------------------------------------------------------------- |
| 22 | + |
| 23 | +AGENT_ENDPOINT_ID = "ocid1.genaiagentendpoint.oc1.eu-frankfurt-1." |
| 24 | +SERVICE_ENDPOINT = "https://agent-runtime.generativeai.eu-frankfurt-1.oci.oraclecloud.com" |
| 25 | +PROFILE_NAME = "DEFAULT" # profile in ~/.oci/config |
| 26 | + |
| 27 | +logging.basicConfig(format="%(asctime)s %(levelname)s: %(message)s", level=logging.INFO) |
| 28 | +logger = logging.getLogger(__name__) |
| 29 | + |
| 30 | +class ChatAgent: |
| 31 | + |
| 32 | + def __init__(self) -> None: |
| 33 | + cfg = oci.config.from_file("~/.oci/config", profile_name=PROFILE_NAME) |
| 34 | + self.client = oci.generative_ai_agent_runtime.GenerativeAiAgentRuntimeClient( |
| 35 | + config=cfg, |
| 36 | + service_endpoint=SERVICE_ENDPOINT, |
| 37 | + retry_strategy=oci.retry.NoneRetryStrategy(), |
| 38 | + timeout=(10, 240), |
| 39 | + ) |
| 40 | + self.session_id: Optional[str] = None |
| 41 | + self.current_user: Optional[str] = None |
| 42 | + # ------------------------------------------------------------------ |
| 43 | + def _create_session(self, *, user: str) -> str: |
| 44 | + details = CreateSessionDetails( |
| 45 | + display_name=f"Flask-session for {user}", |
| 46 | + description="Session created by flask_chat_server.py", |
| 47 | + ) |
| 48 | + resp = self.client.create_session(details, AGENT_ENDPOINT_ID) |
| 49 | + self.session_id = resp.data.id |
| 50 | + self.current_user = user |
| 51 | + logger.debug("Created new session %s for %s", self.session_id, user) |
| 52 | + return self.session_id |
| 53 | + |
| 54 | + # ------------------------------------------------------------------ |
| 55 | + # QR |
| 56 | + @staticmethod |
| 57 | + def generate_qr_code(url: str, size: int = 200) -> dict: |
| 58 | + """Generates a QR code for the given URL and returns it as a base64-encoded PNG.""" |
| 59 | + qr = qrcode.QRCode(box_size=10, border=2) |
| 60 | + qr.add_data(url) |
| 61 | + qr.make(fit=True) |
| 62 | + img = qr.make_image(fill="black", back_color="white").resize((size, size)) |
| 63 | + buffer = BytesIO() |
| 64 | + img.save(buffer, format="PNG") |
| 65 | + b64 = base64.b64encode(buffer.getvalue()).decode() |
| 66 | + return {"image_base64": b64} |
| 67 | + |
| 68 | + # ------------------------------------------------------------------ |
| 69 | + @staticmethod |
| 70 | + def _extract_text(message_obj: Any) -> str: |
| 71 | + try: |
| 72 | + content_list = getattr(message_obj, "content", None) |
| 73 | + if content_list: |
| 74 | + first = content_list[0] if isinstance(content_list, list) else content_list |
| 75 | + return getattr(first, "text", str(first)) |
| 76 | + return str(message_obj) |
| 77 | + except Exception as exc: |
| 78 | + logger.warning("Failed to extract text: %s", exc) |
| 79 | + return str(message_obj) |
| 80 | + |
| 81 | + def convert_sql_to_human_readable(self, sql_output: str, user: str) -> str: |
| 82 | + """Convert SQL output to human readable format using OCI LLM.""" |
| 83 | + try: |
| 84 | + if len(sql_output) > 1000: |
| 85 | + return f"Hi {user}, the data you requested is quite large. Could you please specify what specific information you're looking for? This will help me provide you with a more focused and useful response." |
| 86 | + |
| 87 | + prompt = f"""Please convert this SQL output into a friendly, human-readable format. |
| 88 | + Address the user as '{user}' once and present the information in a clear, conversational way. |
| 89 | + If it's a list or table, format it nicely Please do not provide in the answer SQL query that was passed only human nice language. |
| 90 | + Here's the SQL output: |
| 91 | + {sql_output}""" |
| 92 | + |
| 93 | + chat_details = ChatDetails( |
| 94 | + user_message=prompt, |
| 95 | + session_id=self.session_id, |
| 96 | + should_stream=False |
| 97 | + ) |
| 98 | + reply = self.client.chat(AGENT_ENDPOINT_ID, chat_details) |
| 99 | + return self._extract_text(reply.data.message) |
| 100 | + |
| 101 | + except Exception as e: |
| 102 | + logger.error(f"Error converting SQL output: {e}") |
| 103 | + return f"Hi {user}, I encountered an error while processing your data. Please try again." |
| 104 | + |
| 105 | + def _is_sql_output(self, text: str) -> bool: |
| 106 | + """Check if the text appears to be SQL output.""" |
| 107 | + sql_indicators = [ |
| 108 | + "SELECT", "FROM", "WHERE", "JOIN", "GROUP BY", "ORDER BY", |
| 109 | + "INSERT INTO", "UPDATE", "DELETE FROM", "CREATE TABLE" |
| 110 | + ] |
| 111 | + return any(indicator in text.upper() for indicator in sql_indicators) |
| 112 | + |
| 113 | + # ------------------------------------------------------------------ |
| 114 | + def send(self, *, prompt: str, user: str, max_tokens: Optional[int] = None) -> str: |
| 115 | + if self.session_id is None or self.current_user != user: |
| 116 | + self._create_session(user=user) |
| 117 | + |
| 118 | + weather_response = self.process_weather_query(prompt) |
| 119 | + if weather_response: |
| 120 | + return weather_response |
| 121 | + |
| 122 | + filter_conditions: Dict[str, Any] = { |
| 123 | + "filterConditions": [ |
| 124 | + { |
| 125 | + "field": "person", |
| 126 | + "field_type": "list_of_string", |
| 127 | + "operation": "contains", |
| 128 | + "value": user, |
| 129 | + } |
| 130 | + ] |
| 131 | + } |
| 132 | + |
| 133 | + logger.info("Applying RAG filter with author='%s'", user) |
| 134 | + |
| 135 | + tool_params: Dict[str, str] = {"rag": json.dumps(filter_conditions)} |
| 136 | + if max_tokens is not None: |
| 137 | + tool_params["max_tokens"] = str(max_tokens) |
| 138 | + |
| 139 | + chat_details = ChatDetails( |
| 140 | + user_message=prompt, |
| 141 | + session_id=self.session_id, |
| 142 | + should_stream=False, |
| 143 | + tool_parameters=tool_params, |
| 144 | + ) |
| 145 | + reply = self.client.chat(AGENT_ENDPOINT_ID, chat_details) |
| 146 | + data = reply.data |
| 147 | + |
| 148 | + response_text = self._extract_text(data.message) |
| 149 | + |
| 150 | + if self._is_sql_output(response_text): |
| 151 | + return self.convert_sql_to_human_readable(response_text, user) |
| 152 | + |
| 153 | + if getattr(data, "required_actions", None): |
| 154 | + action = data.required_actions[0] |
| 155 | + func_call = action.function_call |
| 156 | + args = json.loads(func_call.arguments) |
| 157 | + if func_call.name == "generateQrCode": |
| 158 | + result = self.generate_qr_code(**args) |
| 159 | + performed = FunctionCallingPerformedAction( |
| 160 | + action_id=action.action_id, |
| 161 | + function_call_output=json.dumps(result) |
| 162 | + ) |
| 163 | + final_reply = self.client.chat( |
| 164 | + AGENT_ENDPOINT_ID, |
| 165 | + ChatDetails( |
| 166 | + session_id=self.session_id, |
| 167 | + user_message="", |
| 168 | + performed_actions=[performed], |
| 169 | + should_stream=False, |
| 170 | + tool_parameters=tool_params, |
| 171 | + ) |
| 172 | + ) |
| 173 | + |
| 174 | + raw = self._extract_text(final_reply.data.message) |
| 175 | + payload = result["image_base64"] |
| 176 | + if payload in raw: |
| 177 | + raw = raw.replace(payload, "").strip() |
| 178 | + if not raw: |
| 179 | + raw = "Here's your QR code!" |
| 180 | + return { |
| 181 | + "message": raw, |
| 182 | + "qr_image_base64": payload |
| 183 | + } |
| 184 | + |
| 185 | + return self._extract_text(reply.data.message) |
| 186 | + |
| 187 | + |
| 188 | +# --------------------------------------------------------------------------- |
| 189 | +# Flask wiring |
| 190 | +# --------------------------------------------------------------------------- |
| 191 | + |
| 192 | +app = Flask(__name__) |
| 193 | +CORS(app) |
| 194 | +chat_agent = ChatAgent() |
| 195 | + |
| 196 | +@app.route("/") |
| 197 | +def index(): |
| 198 | + return send_from_directory(Path(app.root_path), "index.html") |
| 199 | + |
| 200 | +@app.route("/config.js") |
| 201 | +def config_js(): |
| 202 | + return send_from_directory(Path(app.root_path), "config.js") |
| 203 | + |
| 204 | +@app.route("/chat", methods=["POST"]) |
| 205 | +def chat() -> tuple[Any, int] | Any: |
| 206 | + data = request.get_json(silent=True) or {} |
| 207 | + prompt_raw = str(data.get("message", "").strip()) |
| 208 | + user_raw = str(data.get("user", "anonymous").strip() or "anonymous") |
| 209 | + |
| 210 | + #user = user_raw[0].upper() + user_raw[1:] if user_raw else user_raw |
| 211 | + user = user_raw |
| 212 | + max_tokens_raw = data.get("max_tokens") |
| 213 | + |
| 214 | + if not prompt_raw: |
| 215 | + return jsonify({"error": "Missing 'message'"}), 400 |
| 216 | + |
| 217 | + try: |
| 218 | + max_tokens = int(max_tokens_raw) if max_tokens_raw is not None else None |
| 219 | + except (TypeError, ValueError): |
| 220 | + return jsonify({"error": "'max_tokens' must be an integer"}), 400 |
| 221 | + |
| 222 | + try: |
| 223 | + reply = chat_agent.send(prompt=prompt_raw, user=user, max_tokens=max_tokens) |
| 224 | + if isinstance(reply, dict): |
| 225 | + return jsonify({"user": user, **reply}) |
| 226 | + else: |
| 227 | + return jsonify({"user": user, "message": reply}) |
| 228 | + except Exception as exc: |
| 229 | + logger.exception("/chat failed: %s", exc) |
| 230 | + return jsonify({"error": str(exc)}), 500 |
| 231 | + |
| 232 | + |
| 233 | +if __name__ == "__main__": |
| 234 | + app.run(debug=True, port=5000) |
0 commit comments